polars_io/cloud/polars_object_store.rs

use std::ops::Range;

use bytes::Bytes;
use futures::{StreamExt, TryStreamExt};
use object_store::path::Path;
use object_store::{ObjectMeta, ObjectStore};
use polars_core::prelude::{InitHashMaps, PlHashMap};
use polars_error::{PolarsError, PolarsResult};
use tokio::io::{AsyncSeekExt, AsyncWriteExt};

use crate::pl_async::{
    self, MAX_BUDGET_PER_REQUEST, get_concurrency_limit, get_download_chunk_size,
    tune_with_concurrency_budget, with_concurrency_budget,
};

mod inner {
    use std::future::Future;
    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;

    use object_store::ObjectStore;
    use polars_core::config;
    use polars_error::PolarsResult;

    use crate::cloud::PolarsObjectStoreBuilder;

    #[derive(Debug)]
    struct Inner {
        store: tokio::sync::Mutex<Arc<dyn ObjectStore>>,
        builder: PolarsObjectStoreBuilder,
    }

    /// Polars wrapper around [`ObjectStore`] functionality. This struct is cheaply cloneable.
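    /// Clones share the same inner store and builder (see the `Clone` impl below). A hedged
    /// usage sketch; `new_from_inner` is crate-internal, and the `store` / `builder` values here
    /// are assumed to already exist:
    ///
    /// ```ignore
    /// let polars_store = PolarsObjectStore::new_from_inner(store, builder);
    /// let same_backend = polars_store.clone(); // cheap: only `Arc` handles are cloned
    /// ```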
    #[derive(Debug)]
    pub struct PolarsObjectStore {
        inner: Arc<Inner>,
        /// Avoid contending the Mutex `lock()` until the first re-build.
        initial_store: std::sync::Arc<dyn ObjectStore>,
        /// Used for interior mutability. Doesn't need to be shared with other threads so it's not
        /// inside `Arc<>`.
        rebuilt: AtomicBool,
    }

    impl Clone for PolarsObjectStore {
        fn clone(&self) -> Self {
            Self {
                inner: self.inner.clone(),
                initial_store: self.initial_store.clone(),
                rebuilt: AtomicBool::new(self.rebuilt.load(std::sync::atomic::Ordering::Relaxed)),
            }
        }
    }

    impl PolarsObjectStore {
        pub(crate) fn new_from_inner(
            store: Arc<dyn ObjectStore>,
            builder: PolarsObjectStoreBuilder,
        ) -> Self {
            let initial_store = store.clone();
            Self {
                inner: Arc::new(Inner {
                    store: tokio::sync::Mutex::new(store),
                    builder,
                }),
                initial_store,
                rebuilt: AtomicBool::new(false),
            }
        }

        /// Gets the underlying [`ObjectStore`] implementation.
        pub async fn to_dyn_object_store(&self) -> Arc<dyn ObjectStore> {
            if !self.rebuilt.load(std::sync::atomic::Ordering::Relaxed) {
                self.initial_store.clone()
            } else {
                self.inner.store.lock().await.clone()
            }
        }

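        /// Re-creates the inner [`ObjectStore`] from the stored builder. If `from_version` is no
        /// longer the current store (i.e. another task already rebuilt it), the rebuild is skipped
        /// and the existing store is returned.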
        pub async fn rebuild_inner(
            &self,
            from_version: &Arc<dyn ObjectStore>,
        ) -> PolarsResult<Arc<dyn ObjectStore>> {
            let mut current_store = self.inner.store.lock().await;

            self.rebuilt
                .store(true, std::sync::atomic::Ordering::Relaxed);

            // If these are not pointer-equal, `inner` was already re-built by another thread.
            if Arc::ptr_eq(&*current_store, from_version) {
                *current_store = self.inner.builder.clone().build_impl().await.map_err(|e| {
                    e.wrap_msg(|e| format!("attempt to rebuild object store failed: {}", e))
                })?;
            }

            Ok((*current_store).clone())
        }

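        /// Runs `func` against the current store and, if it fails, rebuilds the store once from
        /// the builder and retries. A hedged usage sketch; the `path` value is assumed, and error
        /// conversion relies on the `From<object_store::Error>` impl for `PolarsError`:
        ///
        /// ```ignore
        /// let meta = polars_store
        ///     .try_exec_rebuild_on_err(|store| {
        ///         let store = store.clone();
        ///         async move { store.head(&path).await.map_err(PolarsError::from) }
        ///     })
        ///     .await?;
        /// ```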
        pub async fn try_exec_rebuild_on_err<Fn, Fut, O>(&self, mut func: Fn) -> PolarsResult<O>
        where
            Fn: FnMut(&Arc<dyn ObjectStore>) -> Fut,
            Fut: Future<Output = PolarsResult<O>>,
        {
            let store = self.to_dyn_object_store().await;

            let out = func(&store).await;

            let orig_err = match out {
                Ok(v) => return Ok(v),
                Err(e) => e,
            };

            if config::verbose() {
                eprintln!(
                    "[PolarsObjectStore]: got error: {}, will attempt re-build",
                    &orig_err
                );
            }

            let store = self
                .rebuild_inner(&store)
                .await
                .map_err(|e| e.wrap_msg(|e| format!("{}; original error: {}", e, orig_err)))?;

            func(&store).await.map_err(|e| {
                if self.inner.builder.is_azure()
                    && std::env::var("POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY").as_deref()
                        != Ok("1")
                {
                    // Note: This error is intended for Python audiences. The logic for retrieving
                    // these keys exists only on the Python side.
                    e.wrap_msg(|e| {
                        format!(
                            "{}; note: if you are using Python, consider setting \
POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY=1 if you would like polars to try to retrieve \
and use the storage account keys from Azure CLI to authenticate",
                            e
                        )
                    })
                } else {
                    e
                }
            })
        }
    }
}

pub use inner::PolarsObjectStore;

pub type ObjectStorePath = object_store::path::Path;

impl PolarsObjectStore {
    /// Returns a buffered stream that downloads concurrently up to the concurrency limit.
    fn get_buffered_ranges_stream<'a, T: Iterator<Item = Range<usize>>>(
        store: &'a dyn ObjectStore,
        path: &'a Path,
        ranges: T,
    ) -> impl StreamExt<Item = PolarsResult<Bytes>>
    + TryStreamExt<Ok = Bytes, Error = PolarsError, Item = PolarsResult<Bytes>>
    + use<'a, T> {
        futures::stream::iter(ranges.map(move |range| async move {
            let out = store
                .get_range(path, range.start as u64..range.end as u64)
                .await?;
            Ok(out)
        }))
        // Add a limit locally as this gets run inside a single `tune_with_concurrency_budget`.
        .buffered(get_concurrency_limit() as usize)
    }

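    /// Fetch the byte range `range` of the object at `path`. Large ranges are split via
    /// `split_range` and downloaded concurrently, then reassembled into a single `Bytes` buffer.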
    pub async fn get_range(&self, path: &Path, range: Range<usize>) -> PolarsResult<Bytes> {
        self.try_exec_rebuild_on_err(move |store| {
            let range = range.clone();
            let st = store.clone();

            async move {
                let store = st;
                let parts = split_range(range.clone());

                if parts.len() == 1 {
                    let out = tune_with_concurrency_budget(1, move || async move {
                        store
                            .get_range(path, range.start as u64..range.end as u64)
                            .await
                    })
                    .await?;

                    Ok(out)
                } else {
                    let parts = tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || {
                            Self::get_buffered_ranges_stream(&store, path, parts)
                                .try_collect::<Vec<Bytes>>()
                        },
                    )
                    .await?;

                    let mut combined = Vec::with_capacity(range.len());

                    for part in parts {
                        combined.extend_from_slice(&part)
                    }

                    assert_eq!(combined.len(), range.len());

                    PolarsResult::Ok(Bytes::from(combined))
                }
            }
        })
        .await
    }

    /// Fetch byte ranges into a HashMap keyed by the range start. This will mutably sort the
    /// `ranges` slice for coalescing.
    ///
    /// # Panics
    /// Panics if the same range start is used by more than 1 range.
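    ///
    /// For example, fetching `ranges = [0..5, 10..20]` produces a map with keys `0` and `10` (the
    /// range starts converted to `K`) mapped to the downloaded bytes of each range. This is
    /// illustrative; the actual requests issued depend on `merge_ranges`.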
    pub async fn get_ranges_sort<
        K: TryFrom<usize, Error = impl std::fmt::Debug> + std::hash::Hash + Eq,
        T: From<Bytes>,
    >(
        &self,
        path: &Path,
        ranges: &mut [Range<usize>],
    ) -> PolarsResult<PlHashMap<K, T>> {
        if ranges.is_empty() {
            return Ok(Default::default());
        }

        ranges.sort_unstable_by_key(|x| x.start);

        let ranges_len = ranges.len();
        let (merged_ranges, merged_ends): (Vec<_>, Vec<_>) = merge_ranges(ranges).unzip();

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            async {
                let store = st;
                let mut out = PlHashMap::with_capacity(ranges_len);

                let mut stream =
                    Self::get_buffered_ranges_stream(&store, path, merged_ranges.iter().cloned());

                tune_with_concurrency_budget(
                    merged_ranges.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                    || async {
                        let mut len = 0;
                        let mut current_offset = 0;
                        let mut ends_iter = merged_ends.iter();

                        let mut splitted_parts = vec![];

                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            let end = *ends_iter.next().unwrap();

                            if end == 0 {
                                splitted_parts.push(bytes);
                                continue;
                            }

                            let full_range = ranges[current_offset..end]
                                .iter()
                                .cloned()
                                .reduce(|l, r| l.start.min(r.start)..l.end.max(r.end))
                                .unwrap();

                            let bytes = if splitted_parts.is_empty() {
                                bytes
                            } else {
                                let mut out = Vec::with_capacity(full_range.len());

                                for x in splitted_parts.drain(..) {
                                    out.extend_from_slice(&x);
                                }

                                out.extend_from_slice(&bytes);
                                Bytes::from(out)
                            };

                            assert_eq!(bytes.len(), full_range.len());

                            for range in &ranges[current_offset..end] {
                                let v = out.insert(
                                    K::try_from(range.start).unwrap(),
                                    T::from(bytes.slice(
                                        range.start - full_range.start
                                            ..range.end - full_range.start,
                                    )),
                                );

                                assert!(v.is_none()); // duplicate range start
                            }

                            current_offset = end;
                        }

                        assert!(splitted_parts.is_empty());

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    },
                )
                .await?;

                Ok(out)
            }
        })
        .await
    }

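    /// Download the object at `path` into `file`, writing from the file's current position. If
    /// the object size is known from a `head` request and splits into multiple parts, the parts
    /// are downloaded concurrently; otherwise the object is streamed sequentially.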
    pub async fn download(&self, path: &Path, file: &mut tokio::fs::File) -> PolarsResult<()> {
        let opt_size = self.head(path).await.ok().map(|x| x.size);

        let initial_pos = file.stream_position().await?;

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            // Workaround for "can't move captured variable".
            let file: &mut tokio::fs::File = unsafe { std::mem::transmute_copy(&file) };

            async {
                file.set_len(initial_pos).await?; // Reset if this function was called again.

                let store = st;
                let parts = opt_size
                    .map(|x| split_range(0..x as usize))
                    .filter(|x| x.len() > 1);

                if let Some(parts) = parts {
                    tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || async {
                            let mut stream = Self::get_buffered_ranges_stream(&store, path, parts);
                            let mut len = 0;
                            while let Some(bytes) = stream.try_next().await? {
                                len += bytes.len();
                                file.write_all(&bytes).await?;
                            }

                            assert_eq!(len, opt_size.unwrap() as usize);

                            PolarsResult::Ok(pl_async::Size::from(len as u64))
                        },
                    )
                    .await?
                } else {
                    tune_with_concurrency_budget(1, || async {
                        let mut stream = store.get(path).await?.into_stream();

                        let mut len = 0;
                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            file.write_all(&bytes).await?;
                        }

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    })
                    .await?
                };

                // Dropping is delayed for tokio async files so we need to explicitly
                // flush here (https://github.com/tokio-rs/tokio/issues/2307#issuecomment-596336451).
                file.sync_all().await.map_err(PolarsError::from)?;

                Ok(())
            }
        })
        .await
    }

    /// Fetch the metadata of the object at `path`. The result is not memoized.
    pub async fn head(&self, path: &Path) -> PolarsResult<ObjectMeta> {
        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            async {
                with_concurrency_budget(1, || async {
                    let store = st;
                    let head_result = store.head(path).await;

                    if head_result.is_err() {
                        // Pre-signed URLs forbid the HEAD method, but we can still retrieve the header
                        // information with a range 0-0 request.
                        let get_range_0_0_result = store
                            .get_opts(
                                path,
                                object_store::GetOptions {
                                    range: Some((0..1).into()),
                                    ..Default::default()
                                },
                            )
                            .await;

                        if let Ok(v) = get_range_0_0_result {
                            return Ok(v.meta);
                        }
                    }

                    let out = head_result?;

                    Ok(out)
                })
                .await
            }
        })
        .await
    }
}

/// Splits a single range into multiple smaller ranges, which can be downloaded concurrently for
/// much higher throughput.
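///
/// For example (mirroring the unit tests below), with the 64 MiB chunk size used in the tests,
/// `split_range(0..(4 * chunk_size / 3 + 1))` splits into two roughly even parts:
/// `[0..44739243, 44739243..89478486]`.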
fn split_range(range: Range<usize>) -> impl ExactSizeIterator<Item = Range<usize>> {
    let chunk_size = get_download_chunk_size();

    // Calculate n_parts such that we are as close as possible to the `chunk_size`.
    let n_parts = [
        (range.len().div_ceil(chunk_size)).max(1),
        (range.len() / chunk_size).max(1),
    ]
    .into_iter()
    .min_by_key(|x| (range.len() / *x).abs_diff(chunk_size))
    .unwrap();

    let chunk_size = (range.len() / n_parts).max(1);

    assert_eq!(n_parts, (range.len() / chunk_size).max(1));
    let bytes_rem = range.len() % chunk_size;

    (0..n_parts).map(move |part_no| {
        let (start, end) = if part_no == 0 {
            // Download remainder length in the first chunk since it starts downloading first.
            let end = range.start + chunk_size + bytes_rem;
            let end = if end > range.end { range.end } else { end };
            (range.start, end)
        } else {
            let start = bytes_rem + range.start + part_no * chunk_size;
            (start, start + chunk_size)
        };

        start..end
    })
}

/// Note: For optimal performance, `ranges` should be sorted. More generally, ranges that are
/// adjacent in the slice should also be close to each other in value.
///
/// # Returns
/// `[(range1, end1), (range2, end2)]`, where:
/// * `range1` contains bytes for the ranges from `ranges[0..end1]`
/// * `range2` contains bytes for the ranges from `ranges[end1..end2]`
/// * etc.
///
/// Note that if an end value is 0, the range is a split part of a larger range and should be
/// combined with the following parts until a non-zero end value is reached.
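///
/// For example (mirroring the unit tests below), `merge_ranges(&[0..1, 1024 * 1024 + 2..1024 * 1024 + 3])`
/// yields `[(0..1, 1), (1048578..1048579, 2)]`: the gap exceeds the merge tolerance, so each
/// input range is emitted separately, with its end index pointing one past the last input range
/// it covers.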
fn merge_ranges(ranges: &[Range<usize>]) -> impl Iterator<Item = (Range<usize>, usize)> + '_ {
    let chunk_size = get_download_chunk_size();

    let mut current_merged_range = ranges.first().map_or(0..0, Clone::clone);
    // Number of fetched bytes excluding excess.
    let mut current_n_bytes = current_merged_range.len();

    (0..ranges.len())
        .filter_map(move |current_idx| {
            let current_idx = 1 + current_idx;

            if current_idx == ranges.len() {
                // No more items - flush current state.
                Some((current_merged_range.clone(), current_idx))
            } else {
                let range = ranges[current_idx].clone();

                let new_merged = current_merged_range.start.min(range.start)
                    ..current_merged_range.end.max(range.end);

                // E.g.:
                // |--------|
                //  oo        // range1
                //       oo   // range2
                //    ^^^     // distance = 3, is_overlapping = false
                // E.g.:
                // |--------|
                //  ooooo     // range1
                //     ooooo  // range2
                //     ^^     // distance = 2, is_overlapping = true
                let (distance, is_overlapping) = {
                    let l = current_merged_range.end.min(range.end);
                    let r = current_merged_range.start.max(range.start);

                    (r.abs_diff(l), r < l)
                };

                let should_merge = is_overlapping || {
                    let leq_current_len_dist_to_chunk_size = new_merged.len().abs_diff(chunk_size)
                        <= current_merged_range.len().abs_diff(chunk_size);
                    let gap_tolerance =
                        (current_n_bytes.max(range.len()) / 8).clamp(1024 * 1024, 8 * 1024 * 1024);

                    leq_current_len_dist_to_chunk_size && distance <= gap_tolerance
                };

                if should_merge {
                    // Merge to existing range
                    current_merged_range = new_merged;
                    current_n_bytes += if is_overlapping {
                        range.len() - distance
                    } else {
                        range.len()
                    };
                    None
                } else {
                    let out = (current_merged_range.clone(), current_idx);
                    current_merged_range = range;
                    current_n_bytes = current_merged_range.len();
                    Some(out)
                }
            }
        })
        .flat_map(|x| {
            // Split large individual ranges within the list of ranges.
            let (range, end) = x;
            let split = split_range(range.clone());
            let len = split.len();

            split
                .enumerate()
                .map(move |(i, range)| (range, if 1 + i == len { end } else { 0 }))
        })
}

#[cfg(test)]
mod tests {

    #[test]
    fn test_split_range() {
        use super::{get_download_chunk_size, split_range};

        let chunk_size = get_download_chunk_size();

        assert_eq!(chunk_size, 64 * 1024 * 1024);

        #[allow(clippy::single_range_in_vec_init)]
        {
            // Round-trip empty ranges.
            assert_eq!(split_range(0..0).collect::<Vec<_>>(), [0..0]);
            assert_eq!(split_range(3..3).collect::<Vec<_>>(), [3..3]);
        }

        // Threshold to start splitting to 2 ranges
        //
        // n - chunk_size == chunk_size - n / 2
        // n + n / 2 == 2 * chunk_size
        // 3 * n == 4 * chunk_size
        // n = 4 * chunk_size / 3
        let n = 4 * chunk_size / 3;

        #[allow(clippy::single_range_in_vec_init)]
        {
            assert_eq!(split_range(0..n).collect::<Vec<_>>(), [0..89478485]);
        }

        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..44739243, 44739243..89478486]
        );

        // Threshold to start splitting to 3 ranges
        //
        // n / 2 - chunk_size == chunk_size - n / 3
        // n / 2 + n / 3 == 2 * chunk_size
        // 5 * n == 12 * chunk_size
        // n == 12 * chunk_size / 5
        let n = 12 * chunk_size / 5;

        assert_eq!(
            split_range(0..n).collect::<Vec<_>>(),
            [0..80530637, 80530637..161061273]
        );

        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..53687092, 53687092..107374183, 107374183..161061274]
        );
    }

    #[test]
    fn test_merge_ranges() {
        use super::{get_download_chunk_size, merge_ranges};

        let chunk_size = get_download_chunk_size();

        assert_eq!(chunk_size, 64 * 1024 * 1024);

        // Round-trip empty slice
        assert_eq!(merge_ranges(&[]).collect::<Vec<_>>(), []);

        // We have 1 tiny request followed by 1 huge request. They are merged because doing so
        // reduces the `abs_diff()` to the `chunk_size`, and the merged range is afterwards split
        // into 2 evenly sized requests.
        assert_eq!(
            merge_ranges(&[0..1, 1..127 * 1024 * 1024]).collect::<Vec<_>>(),
            [(0..66584576, 0), (66584576..133169152, 2)]
        );

        // <= 1MiB gap, merge
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 1..1024 * 1024 + 2]).collect::<Vec<_>>(),
            [(0..1048578, 2)]
        );

        // > 1MiB gap, do not merge
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 2..1024 * 1024 + 3]).collect::<Vec<_>>(),
            [(0..1, 1), (1048578..1048579, 2)]
        );

        // <= 12.5% gap, merge
        assert_eq!(
            merge_ranges(&[0..8, 10..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        // <= 12.5% gap relative to RHS, merge
        assert_eq!(
            merge_ranges(&[0..1, 3..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        // Overlapping range, merge
        assert_eq!(
            merge_ranges(&[0..80 * 1024 * 1024, 10 * 1024 * 1024..70 * 1024 * 1024])
                .collect::<Vec<_>>(),
            [(0..80 * 1024 * 1024, 2)]
        );
    }
}