polars_io/cloud/
polars_object_store.rs

use std::ops::Range;

use bytes::Bytes;
use futures::{StreamExt, TryStreamExt};
use hashbrown::hash_map::RawEntryMut;
use object_store::path::Path;
use object_store::{ObjectMeta, ObjectStore};
use polars_core::prelude::{InitHashMaps, PlHashMap};
use polars_error::{PolarsError, PolarsResult};
use polars_utils::mmap::MemSlice;
use tokio::io::{AsyncSeekExt, AsyncWriteExt};

use crate::pl_async::{
    self, MAX_BUDGET_PER_REQUEST, get_concurrency_limit, get_download_chunk_size,
    tune_with_concurrency_budget, with_concurrency_budget,
};

mod inner {
    use std::future::Future;
    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;

    use object_store::ObjectStore;
    use polars_core::config;
    use polars_error::PolarsResult;

    use crate::cloud::PolarsObjectStoreBuilder;

    #[derive(Debug)]
    struct Inner {
        store: tokio::sync::Mutex<Arc<dyn ObjectStore>>,
        builder: PolarsObjectStoreBuilder,
    }

    /// Polars wrapper around [`ObjectStore`] functionality. This struct is cheaply cloneable.
    #[derive(Debug)]
    pub struct PolarsObjectStore {
        inner: Arc<Inner>,
        /// Avoid contending the Mutex `lock()` until the first re-build.
        initial_store: std::sync::Arc<dyn ObjectStore>,
        /// Used for interior mutability. Doesn't need to be shared with other threads so it's not
        /// inside `Arc<>`.
        rebuilt: AtomicBool,
    }

    impl Clone for PolarsObjectStore {
        fn clone(&self) -> Self {
            Self {
                inner: self.inner.clone(),
                initial_store: self.initial_store.clone(),
                rebuilt: AtomicBool::new(self.rebuilt.load(std::sync::atomic::Ordering::Relaxed)),
            }
        }
    }

    impl PolarsObjectStore {
        pub(crate) fn new_from_inner(
            store: Arc<dyn ObjectStore>,
            builder: PolarsObjectStoreBuilder,
        ) -> Self {
            let initial_store = store.clone();
            Self {
                inner: Arc::new(Inner {
                    store: tokio::sync::Mutex::new(store),
                    builder,
                }),
                initial_store,
                rebuilt: AtomicBool::new(false),
            }
        }

        /// Gets the underlying [`ObjectStore`] implementation.
        pub async fn to_dyn_object_store(&self) -> Arc<dyn ObjectStore> {
            if !self.rebuilt.load(std::sync::atomic::Ordering::Relaxed) {
                self.initial_store.clone()
            } else {
                self.inner.store.lock().await.clone()
            }
        }

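        /// Re-builds the inner [`ObjectStore`] from the stored builder. If `from_version` no
        /// longer matches the currently stored `Arc` (i.e. another caller already re-built it),
        /// the rebuild is skipped and the existing store is returned.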
        pub async fn rebuild_inner(
            &self,
            from_version: &Arc<dyn ObjectStore>,
        ) -> PolarsResult<Arc<dyn ObjectStore>> {
            let mut current_store = self.inner.store.lock().await;

            self.rebuilt
                .store(true, std::sync::atomic::Ordering::Relaxed);

            // If the pointers are not equal, `inner` was already re-built by another thread.
            if Arc::ptr_eq(&*current_store, from_version) {
                *current_store = self.inner.builder.clone().build_impl().await.map_err(|e| {
                    e.wrap_msg(|e| format!("attempt to rebuild object store failed: {e}"))
                })?;
            }

            Ok((*current_store).clone())
        }

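        /// Executes `func` with the current [`ObjectStore`]. If it fails, the store is re-built
        /// via `rebuild_inner` and `func` is retried once against the new store.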
        pub async fn try_exec_rebuild_on_err<Fn, Fut, O>(&self, mut func: Fn) -> PolarsResult<O>
        where
            Fn: FnMut(&Arc<dyn ObjectStore>) -> Fut,
            Fut: Future<Output = PolarsResult<O>>,
        {
            let store = self.to_dyn_object_store().await;

            let out = func(&store).await;

            let orig_err = match out {
                Ok(v) => return Ok(v),
                Err(e) => e,
            };

            if config::verbose() {
                eprintln!(
                    "[PolarsObjectStore]: got error: {}, will attempt re-build",
                    &orig_err
                );
            }

            let store = self
                .rebuild_inner(&store)
                .await
                .map_err(|e| e.wrap_msg(|e| format!("{e}; original error: {orig_err}")))?;

            func(&store).await.map_err(|e| {
                if self.inner.builder.is_azure()
                    && std::env::var("POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY").as_deref()
                        != Ok("1")
                {
                    // Note: This error is intended for Python audiences. The logic for retrieving
                    // these keys exists only on the Python side.
                    e.wrap_msg(|e| {
                        format!(
                            "{e}; note: if you are using Python, consider setting \
POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY=1 if you would like polars to try to retrieve \
and use the storage account keys from Azure CLI to authenticate"
                        )
                    })
                } else {
                    e
                }
            })
        }
    }
}

pub use inner::PolarsObjectStore;

pub type ObjectStorePath = object_store::path::Path;

impl PolarsObjectStore {
    /// Returns a buffered stream that downloads concurrently up to the concurrency limit.
    fn get_buffered_ranges_stream<'a, T: Iterator<Item = Range<usize>>>(
        store: &'a dyn ObjectStore,
        path: &'a Path,
        ranges: T,
    ) -> impl StreamExt<Item = PolarsResult<Bytes>>
    + TryStreamExt<Ok = Bytes, Error = PolarsError, Item = PolarsResult<Bytes>>
    + use<'a, T> {
        futures::stream::iter(ranges.map(move |range| async move {
            let out = store
                .get_range(path, range.start as u64..range.end as u64)
                .await?;
            Ok(out)
        }))
        // Add a limit locally as this gets run inside a single `tune_with_concurrency_budget`.
        .buffered(get_concurrency_limit() as usize)
    }

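    /// Fetches the given byte range for `path`. Large ranges are split (see `split_range`) and
    /// downloaded concurrently, then re-assembled into a single buffer.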
    pub async fn get_range(&self, path: &Path, range: Range<usize>) -> PolarsResult<Bytes> {
        self.try_exec_rebuild_on_err(move |store| {
            let range = range.clone();
            let st = store.clone();

            async move {
                let store = st;
                let parts = split_range(range.clone());

                if parts.len() == 1 {
                    let out = tune_with_concurrency_budget(1, move || async move {
                        store
                            .get_range(path, range.start as u64..range.end as u64)
                            .await
                    })
                    .await?;

                    Ok(out)
                } else {
                    let parts = tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || {
                            Self::get_buffered_ranges_stream(&store, path, parts)
                                .try_collect::<Vec<Bytes>>()
                        },
                    )
                    .await?;

                    let mut combined = Vec::with_capacity(range.len());

                    for part in parts {
                        combined.extend_from_slice(&part)
                    }

                    assert_eq!(combined.len(), range.len());

                    PolarsResult::Ok(Bytes::from(combined))
                }
            }
        })
        .await
    }

    /// Fetch byte ranges into a HashMap keyed by the range start. This will mutably sort the
    /// `ranges` slice for coalescing.
    ///
    /// # Panics
    /// Panics if the same range start is used by more than 1 range.
    pub async fn get_ranges_sort(
        &self,
        path: &Path,
        ranges: &mut [Range<usize>],
    ) -> PolarsResult<PlHashMap<usize, MemSlice>> {
        if ranges.is_empty() {
            return Ok(Default::default());
        }

        ranges.sort_unstable_by_key(|x| x.start);

        let ranges_len = ranges.len();
        let (merged_ranges, merged_ends): (Vec<_>, Vec<_>) = merge_ranges(ranges).unzip();

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            async {
                let store = st;
                let mut out = PlHashMap::with_capacity(ranges_len);

                let mut stream =
                    Self::get_buffered_ranges_stream(&store, path, merged_ranges.iter().cloned());

                tune_with_concurrency_budget(
                    merged_ranges.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                    || async {
                        let mut len = 0;
                        let mut current_offset = 0;
                        let mut ends_iter = merged_ends.iter();

                        let mut splitted_parts = vec![];

                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            let end = *ends_iter.next().unwrap();

                            if end == 0 {
                                splitted_parts.push(bytes);
                                continue;
                            }

                            let full_range = ranges[current_offset..end]
                                .iter()
                                .cloned()
                                .reduce(|l, r| l.start.min(r.start)..l.end.max(r.end))
                                .unwrap();

                            let bytes = if splitted_parts.is_empty() {
                                bytes
                            } else {
                                let mut out = Vec::with_capacity(full_range.len());

                                for x in splitted_parts.drain(..) {
                                    out.extend_from_slice(&x);
                                }

                                out.extend_from_slice(&bytes);
                                Bytes::from(out)
                            };

                            assert_eq!(bytes.len(), full_range.len());

                            let bytes = MemSlice::from_bytes(bytes);

                            for range in &ranges[current_offset..end] {
                                let mem_slice = bytes.slice(
                                    range.start - full_range.start..range.end - full_range.start,
                                );

                                match out.raw_entry_mut().from_key(&range.start) {
                                    RawEntryMut::Vacant(slot) => {
                                        slot.insert(range.start, mem_slice);
                                    },
                                    RawEntryMut::Occupied(mut slot) => {
                                        if slot.get_mut().len() < mem_slice.len() {
                                            *slot.get_mut() = mem_slice;
                                        }
                                    },
                                }
                            }

                            current_offset = end;
                        }

                        assert!(splitted_parts.is_empty());

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    },
                )
                .await?;

                Ok(out)
            }
        })
        .await
    }

    pub async fn download(&self, path: &Path, file: &mut tokio::fs::File) -> PolarsResult<()> {
        let opt_size = self.head(path).await.ok().map(|x| x.size);

        let initial_pos = file.stream_position().await?;

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            // Workaround for "can't move captured variable".
            let file: &mut tokio::fs::File = unsafe { std::mem::transmute_copy(&file) };

            async {
                file.set_len(initial_pos).await?; // Reset if this function was called again.

                let store = st;
                let parts = opt_size
                    .map(|x| split_range(0..x as usize))
                    .filter(|x| x.len() > 1);

                if let Some(parts) = parts {
                    tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || async {
                            let mut stream = Self::get_buffered_ranges_stream(&store, path, parts);
                            let mut len = 0;
                            while let Some(bytes) = stream.try_next().await? {
                                len += bytes.len();
                                file.write_all(&bytes).await?;
                            }

                            assert_eq!(len, opt_size.unwrap() as usize);

                            PolarsResult::Ok(pl_async::Size::from(len as u64))
                        },
                    )
                    .await?
                } else {
                    tune_with_concurrency_budget(1, || async {
                        let mut stream = store.get(path).await?.into_stream();

                        let mut len = 0;
                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            file.write_all(&bytes).await?;
                        }

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    })
                    .await?
                };

                // Dropping is delayed for tokio async files so we need to explicitly
                // flush here (https://github.com/tokio-rs/tokio/issues/2307#issuecomment-596336451).
                file.sync_all().await.map_err(PolarsError::from)?;

                Ok(())
            }
        })
        .await
    }

    /// Fetch the metadata of the parquet file, do not memoize it.
    pub async fn head(&self, path: &Path) -> PolarsResult<ObjectMeta> {
        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            async {
                with_concurrency_budget(1, || async {
                    let store = st;
                    let head_result = store.head(path).await;

                    if head_result.is_err() {
                        // Pre-signed URLs forbid the HEAD method, but we can still retrieve the header
                        // information with a range 0-0 request.
                        let get_range_0_0_result = store
                            .get_opts(
                                path,
                                object_store::GetOptions {
                                    range: Some((0..1).into()),
                                    ..Default::default()
                                },
                            )
                            .await;

                        if let Ok(v) = get_range_0_0_result {
                            return Ok(v.meta);
                        }
                    }

                    let out = head_result?;

                    Ok(out)
                })
                .await
            }
        })
        .await
    }
}

/// Splits a single range into multiple smaller ranges, which can be downloaded concurrently for
/// much higher throughput.
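///
/// For example (illustrative): a range no larger than the download chunk size is returned as a
/// single part, while larger ranges are divided into roughly equally sized parts, with any
/// remainder added to the first part.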
fn split_range(range: Range<usize>) -> impl ExactSizeIterator<Item = Range<usize>> {
    let chunk_size = get_download_chunk_size();

    // Calculate n_parts such that we are as close as possible to the `chunk_size`.
    let n_parts = [
        (range.len().div_ceil(chunk_size)).max(1),
        (range.len() / chunk_size).max(1),
    ]
    .into_iter()
    .min_by_key(|x| (range.len() / *x).abs_diff(chunk_size))
    .unwrap();

    let chunk_size = (range.len() / n_parts).max(1);

    assert_eq!(n_parts, (range.len() / chunk_size).max(1));
    let bytes_rem = range.len() % chunk_size;

    (0..n_parts).map(move |part_no| {
        let (start, end) = if part_no == 0 {
            // Download remainder length in the first chunk since it starts downloading first.
            let end = range.start + chunk_size + bytes_rem;
            let end = if end > range.end { range.end } else { end };
            (range.start, end)
        } else {
            let start = bytes_rem + range.start + part_no * chunk_size;
            (start, start + chunk_size)
        };

        start..end
    })
}

/// Note: For optimal performance, `ranges` should be sorted. More generally,
/// ranges placed next to each other should also be close in range value.
///
/// # Returns
/// `[(range1, end1), (range2, end2)]`, where:
/// * `range1` contains bytes for the ranges from `ranges[0..end1]`
/// * `range2` contains bytes for the ranges from `ranges[end1..end2]`
/// * etc.
///
/// Note that if an end value is 0, it means the range is a split part and should be combined.
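///
/// For example (illustrative): `merge_ranges(&[0..10, 12..20])` coalesces the two nearby ranges
/// into a single fetch and yields `[(0..20, 2)]`.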
fn merge_ranges(ranges: &[Range<usize>]) -> impl Iterator<Item = (Range<usize>, usize)> + '_ {
    let chunk_size = get_download_chunk_size();

    let mut current_merged_range = ranges.first().map_or(0..0, Clone::clone);
    // Number of fetched bytes excluding excess.
    let mut current_n_bytes = current_merged_range.len();

    (0..ranges.len())
        .filter_map(move |current_idx| {
            let current_idx = 1 + current_idx;

            if current_idx == ranges.len() {
                // No more items - flush current state.
                Some((current_merged_range.clone(), current_idx))
            } else {
                let range = ranges[current_idx].clone();

                let new_merged = current_merged_range.start.min(range.start)
                    ..current_merged_range.end.max(range.end);

                // E.g.:
                // |--------|
                //  oo        // range1
                //       oo   // range2
                //    ^^^     // distance = 3, is_overlapping = false
                // E.g.:
                // |--------|
                //  ooooo     // range1
                //     ooooo  // range2
                //     ^^     // distance = 2, is_overlapping = true
                let (distance, is_overlapping) = {
                    let l = current_merged_range.end.min(range.end);
                    let r = current_merged_range.start.max(range.start);

                    (r.abs_diff(l), r < l)
                };

                let should_merge = is_overlapping || {
                    let leq_current_len_dist_to_chunk_size = new_merged.len().abs_diff(chunk_size)
                        <= current_merged_range.len().abs_diff(chunk_size);
                    let gap_tolerance =
                        (current_n_bytes.max(range.len()) / 8).clamp(1024 * 1024, 8 * 1024 * 1024);

                    leq_current_len_dist_to_chunk_size && distance <= gap_tolerance
                };

                if should_merge {
                    // Merge to existing range
                    current_merged_range = new_merged;
                    current_n_bytes += if is_overlapping {
                        range.len() - distance
                    } else {
                        range.len()
                    };
                    None
                } else {
                    let out = (current_merged_range.clone(), current_idx);
                    current_merged_range = range;
                    current_n_bytes = current_merged_range.len();
                    Some(out)
                }
            }
        })
        .flat_map(|x| {
            // Split large individual ranges within the list of ranges.
            let (range, end) = x;
            let split = split_range(range.clone());
            let len = split.len();

            split
                .enumerate()
                .map(move |(i, range)| (range, if 1 + i == len { end } else { 0 }))
        })
}

#[cfg(test)]
mod tests {

    #[test]
    fn test_split_range() {
        use super::{get_download_chunk_size, split_range};

        let chunk_size = get_download_chunk_size();

        assert_eq!(chunk_size, 64 * 1024 * 1024);

        #[allow(clippy::single_range_in_vec_init)]
        {
            // Round-trip empty ranges.
            assert_eq!(split_range(0..0).collect::<Vec<_>>(), [0..0]);
            assert_eq!(split_range(3..3).collect::<Vec<_>>(), [3..3]);
        }

        // Threshold to start splitting to 2 ranges
        //
        // n - chunk_size == chunk_size - n / 2
        // n + n / 2 == 2 * chunk_size
        // 3 * n == 4 * chunk_size
        // n = 4 * chunk_size / 3
        let n = 4 * chunk_size / 3;

        #[allow(clippy::single_range_in_vec_init)]
        {
            assert_eq!(split_range(0..n).collect::<Vec<_>>(), [0..89478485]);
        }

        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..44739243, 44739243..89478486]
        );

        // Threshold to start splitting to 3 ranges
        //
        // n / 2 - chunk_size == chunk_size - n / 3
        // n / 2 + n / 3 == 2 * chunk_size
        // 5 * n == 12 * chunk_size
        // n == 12 * chunk_size / 5
        let n = 12 * chunk_size / 5;

        assert_eq!(
            split_range(0..n).collect::<Vec<_>>(),
            [0..80530637, 80530637..161061273]
        );

        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..53687092, 53687092..107374183, 107374183..161061274]
        );
    }

    #[test]
    fn test_merge_ranges() {
        use super::{get_download_chunk_size, merge_ranges};

        let chunk_size = get_download_chunk_size();

        assert_eq!(chunk_size, 64 * 1024 * 1024);

        // Round-trip empty slice
        assert_eq!(merge_ranges(&[]).collect::<Vec<_>>(), []);

        // We have 1 tiny request followed by 1 huge request. They are combined, as merging reduces
        // the `abs_diff()` to the `chunk_size`, but afterwards they are split into 2 evenly sized
        // requests.
        assert_eq!(
            merge_ranges(&[0..1, 1..127 * 1024 * 1024]).collect::<Vec<_>>(),
            [(0..66584576, 0), (66584576..133169152, 2)]
        );

        // <= 1MiB gap, merge
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 1..1024 * 1024 + 2]).collect::<Vec<_>>(),
            [(0..1048578, 2)]
        );

        // > 1MiB gap, do not merge
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 2..1024 * 1024 + 3]).collect::<Vec<_>>(),
            [(0..1, 1), (1048578..1048579, 2)]
        );

        // <= 12.5% gap, merge
        assert_eq!(
            merge_ranges(&[0..8, 10..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        // <= 12.5% gap relative to RHS, merge
        assert_eq!(
            merge_ranges(&[0..1, 3..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        // Overlapping range, merge
        assert_eq!(
            merge_ranges(&[0..80 * 1024 * 1024, 10 * 1024 * 1024..70 * 1024 * 1024])
                .collect::<Vec<_>>(),
            [(0..80 * 1024 * 1024, 2)]
        );
    }
}