polars_io/cloud/polars_object_store.rs

use std::ops::Range;

use futures::{StreamExt, TryStreamExt};
use hashbrown::hash_map::RawEntryMut;
use object_store::path::Path;
use object_store::{ObjectMeta, ObjectStore, ObjectStoreExt};
use polars_buffer::Buffer;
use polars_core::prelude::{InitHashMaps, PlHashMap};
use polars_error::{PolarsError, PolarsResult};
use tokio::io::{AsyncSeekExt, AsyncWriteExt};

use crate::pl_async::{
    self, MAX_BUDGET_PER_REQUEST, get_concurrency_limit, get_download_chunk_size,
    tune_with_concurrency_budget, with_concurrency_budget,
};

mod inner {
    use std::future::Future;
    use std::sync::Arc;

    use object_store::ObjectStore;
    use polars_core::config;
    use polars_error::PolarsResult;
    use polars_utils::relaxed_cell::RelaxedCell;

    use crate::cloud::PolarsObjectStoreBuilder;

    #[derive(Debug)]
    struct Inner {
        store: tokio::sync::Mutex<Arc<dyn ObjectStore>>,
        builder: PolarsObjectStoreBuilder,
    }

    /// Polars wrapper around [`ObjectStore`] functionality. This struct is cheaply cloneable.
    #[derive(Clone, Debug)]
    pub struct PolarsObjectStore {
        inner: Arc<Inner>,
        /// Avoid contending the Mutex `lock()` until the first re-build.
        initial_store: std::sync::Arc<dyn ObjectStore>,
        /// Used for interior mutability. Doesn't need to be shared with other threads so it's not
        /// inside `Arc<>`.
        rebuilt: RelaxedCell<bool>,
    }

    impl PolarsObjectStore {
        pub(crate) fn new_from_inner(
            store: Arc<dyn ObjectStore>,
            builder: PolarsObjectStoreBuilder,
        ) -> Self {
            let initial_store = store.clone();
            Self {
                inner: Arc::new(Inner {
                    store: tokio::sync::Mutex::new(store),
                    builder,
                }),
                initial_store,
                rebuilt: RelaxedCell::from(false),
            }
        }

        /// Gets the underlying [`ObjectStore`] implementation.
        pub async fn to_dyn_object_store(&self) -> Arc<dyn ObjectStore> {
            if !self.rebuilt.load() {
                self.initial_store.clone()
            } else {
                self.inner.store.lock().await.clone()
            }
        }

        pub async fn rebuild_inner(
            &self,
            from_version: &Arc<dyn ObjectStore>,
        ) -> PolarsResult<Arc<dyn ObjectStore>> {
            let mut current_store = self.inner.store.lock().await;

            // If the pointers are no longer equal, `inner` was already re-built by another thread.
            if Arc::ptr_eq(&*current_store, from_version) {
                *current_store =
                    self.inner
                        .builder
                        .clone()
                        .build_impl(true)
                        .await
                        .map_err(|e| {
                            e.wrap_msg(|e| format!("attempt to rebuild object store failed: {e}"))
                        })?;
            }

            self.rebuilt.store(true);

            Ok((*current_store).clone())
        }

        pub async fn try_exec_rebuild_on_err<Fn, Fut, O>(&self, mut func: Fn) -> PolarsResult<O>
        where
            Fn: FnMut(&Arc<dyn ObjectStore>) -> Fut,
            Fut: Future<Output = PolarsResult<O>>,
        {
            let store = self.to_dyn_object_store().await;

            let out = func(&store).await;

            let orig_err = match out {
                Ok(v) => return Ok(v),
                Err(e) => e,
            };

            if config::verbose() {
                eprintln!(
                    "[PolarsObjectStore]: got error: {}, will attempt re-build",
                    &orig_err
                );
            }

            let store = self
                .rebuild_inner(&store)
                .await
                .map_err(|e| e.wrap_msg(|e| format!("{e}; original error: {orig_err}")))?;

            func(&store).await.map_err(|e| {
                if self.inner.builder.is_azure()
                    && std::env::var("POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY").as_deref()
                        != Ok("1")
                {
                    // Note: This error is intended for Python audiences. The logic for retrieving
                    // these keys exists only on the Python side.
                    e.wrap_msg(|e| {
                        format!(
                            "{e}; note: if you are using Python, consider setting \
POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY=1 if you would like polars to try to retrieve \
and use the storage account keys from Azure CLI to authenticate"
                        )
                    })
                } else {
                    e
                }
            })
        }
    }
}

pub use inner::PolarsObjectStore;

pub type ObjectStorePath = object_store::path::Path;

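// Illustrative usage sketch (not part of the original source): how a caller might drive
// `try_exec_rebuild_on_err`. The closure receives the current `Arc<dyn ObjectStore>`; if the
// future it returns fails, the store is rebuilt once from the saved builder and the closure is
// retried against the fresh store. `example_head_with_rebuild` is a hypothetical helper.
#[allow(dead_code)]
async fn example_head_with_rebuild(
    store: &PolarsObjectStore,
    path: &Path,
) -> PolarsResult<ObjectMeta> {
    store
        .try_exec_rebuild_on_err(|st| {
            // Clone the `Arc` so the async block can own its handle to the store.
            let st = st.clone();
            async move { Ok(st.head(path).await?) }
        })
        .await
}
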
impl PolarsObjectStore {
    /// Returns a buffered stream that downloads concurrently up to the concurrency limit.
    fn get_buffered_ranges_stream<'a, T: Iterator<Item = Range<usize>>>(
        store: &'a dyn ObjectStore,
        path: &'a Path,
        ranges: T,
    ) -> impl StreamExt<Item = PolarsResult<Buffer<u8>>>
    + TryStreamExt<Ok = Buffer<u8>, Error = PolarsError, Item = PolarsResult<Buffer<u8>>>
    + use<'a, T> {
        futures::stream::iter(ranges.map(move |range| async move {
            if range.is_empty() {
                return Ok(Buffer::new());
            }

            let out = store
                .get_range(path, range.start as u64..range.end as u64)
                .await?;
            Ok(Buffer::from_owner(out))
        }))
        // Add a limit locally as this gets run inside a single `tune_with_concurrency_budget`.
        .buffered(get_concurrency_limit() as usize)
    }

    pub async fn get_range(&self, path: &Path, range: Range<usize>) -> PolarsResult<Buffer<u8>> {
        if range.is_empty() {
            return Ok(Buffer::new());
        }

        self.try_exec_rebuild_on_err(move |store| {
            let range = range.clone();
            let st = store.clone();

            async move {
                let store = st;
                let parts = split_range(range.clone());

                if parts.len() == 1 {
                    let out = tune_with_concurrency_budget(1, move || async move {
                        let bytes = store
                            .get_range(path, range.start as u64..range.end as u64)
                            .await?;
                        PolarsResult::Ok(Buffer::from_owner(bytes))
                    })
                    .await?;

                    Ok(out)
                } else {
                    let parts = tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || {
                            Self::get_buffered_ranges_stream(&store, path, parts)
                                .try_collect::<Vec<Buffer<u8>>>()
                        },
                    )
                    .await?;

                    let mut combined = Vec::with_capacity(range.len());

                    for part in parts {
                        combined.extend_from_slice(&part)
                    }

                    assert_eq!(combined.len(), range.len());

                    PolarsResult::Ok(Buffer::from_vec(combined))
                }
            }
        })
        .await
    }

    /// Fetch byte ranges into a HashMap keyed by the range start. This will mutably sort the
    /// `ranges` slice for coalescing.
    ///
    /// # Panics
    /// Panics if the same range start is used by more than 1 range.
    pub async fn get_ranges_sort(
        &self,
        path: &Path,
        ranges: &mut [Range<usize>],
    ) -> PolarsResult<PlHashMap<usize, Buffer<u8>>> {
        if ranges.is_empty() {
            return Ok(Default::default());
        }

        ranges.sort_unstable_by_key(|x| x.start);

        let ranges_len = ranges.len();
        let (merged_ranges, merged_ends): (Vec<_>, Vec<_>) = merge_ranges(ranges).unzip();

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            async {
                let store = st;
                let mut out = PlHashMap::with_capacity(ranges_len);

                let mut stream =
                    Self::get_buffered_ranges_stream(&store, path, merged_ranges.iter().cloned());

                tune_with_concurrency_budget(
                    merged_ranges.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                    || async {
                        let mut len = 0;
                        let mut current_offset = 0;
                        let mut ends_iter = merged_ends.iter();

                        let mut splitted_parts = vec![];

                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            let end = *ends_iter.next().unwrap();

                            if end == 0 {
                                splitted_parts.push(bytes);
                                continue;
                            }

                            let full_range = ranges[current_offset..end]
                                .iter()
                                .cloned()
                                .reduce(|l, r| l.start.min(r.start)..l.end.max(r.end))
                                .unwrap();

                            let bytes = if splitted_parts.is_empty() {
                                bytes
                            } else {
                                let mut out = Vec::with_capacity(full_range.len());

                                for x in splitted_parts.drain(..) {
                                    out.extend_from_slice(&x);
                                }

                                out.extend_from_slice(&bytes);
                                Buffer::from(out)
                            };

                            assert_eq!(bytes.len(), full_range.len());

                            for range in &ranges[current_offset..end] {
                                let slice = bytes.clone().sliced(
                                    range.start - full_range.start..range.end - full_range.start,
                                );

                                match out.raw_entry_mut().from_key(&range.start) {
                                    RawEntryMut::Vacant(slot) => {
                                        slot.insert(range.start, slice);
                                    },
                                    RawEntryMut::Occupied(mut slot) => {
                                        if slot.get_mut().len() < slice.len() {
                                            *slot.get_mut() = slice;
                                        }
                                    },
                                }
                            }

                            current_offset = end;
                        }

                        assert!(splitted_parts.is_empty());

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    },
                )
                .await?;

                Ok(out)
            }
        })
        .await
    }
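
    // Illustrative sketch (not part of the original source): `get_ranges_sort` keys the fetched
    // bytes by each requested range's start offset, so results can still be looked up after the
    // input slice has been re-sorted in place. `example_get_ranges_sort` is a hypothetical helper.
    #[allow(dead_code)]
    async fn example_get_ranges_sort(&self, path: &Path) -> PolarsResult<()> {
        // Two out-of-order ranges; the call sorts them and may coalesce the underlying requests.
        let mut ranges = vec![100..200, 0..50];
        let fetched = self.get_ranges_sort(path, &mut ranges).await?;
        // Each buffer is keyed by the start offset of the range it was fetched for.
        debug_assert_eq!(fetched.get(&0).map(|b| b.len()), Some(50));
        debug_assert_eq!(fetched.get(&100).map(|b| b.len()), Some(100));
        Ok(())
    }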

    pub async fn download(&self, path: &Path, file: &mut tokio::fs::File) -> PolarsResult<()> {
        let opt_size = self.head(path).await.ok().map(|x| x.size);

        let initial_pos = file.stream_position().await?;

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            // Workaround for "can't move captured variable".
            let file: &mut tokio::fs::File = unsafe { std::mem::transmute_copy(&file) };

            async {
                file.set_len(initial_pos).await?; // Reset if this function was called again.

                let store = st;
                let parts = opt_size
                    .map(|x| split_range(0..x as usize))
                    .filter(|x| x.len() > 1);

                if let Some(parts) = parts {
                    tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || async {
                            let mut stream = Self::get_buffered_ranges_stream(&store, path, parts);
                            let mut len = 0;
                            while let Some(bytes) = stream.try_next().await? {
                                len += bytes.len();
                                file.write_all(&bytes).await?;
                            }

                            assert_eq!(len, opt_size.unwrap() as usize);

                            PolarsResult::Ok(pl_async::Size::from(len as u64))
                        },
                    )
                    .await?
                } else {
                    tune_with_concurrency_budget(1, || async {
                        let mut stream = store.get(path).await?.into_stream();

                        let mut len = 0;
                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            file.write_all(&bytes).await?;
                        }

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    })
                    .await?
                };

                // Dropping is delayed for tokio async files so we need to explicitly
                // flush here (https://github.com/tokio-rs/tokio/issues/2307#issuecomment-596336451).
                file.sync_all().await.map_err(PolarsError::from)?;

                Ok(())
            }
        })
        .await
    }
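
    // Illustrative sketch (not part of the original source): downloading an object into a local
    // file. The destination path "/tmp/example_download.parquet" is a hypothetical value used
    // only for illustration.
    #[allow(dead_code)]
    async fn example_download(&self, path: &Path) -> PolarsResult<()> {
        let mut file = tokio::fs::File::create("/tmp/example_download.parquet").await?;
        self.download(path, &mut file).await
    }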

    /// Fetch the metadata of the parquet file without memoizing it.
    pub async fn head(&self, path: &Path) -> PolarsResult<ObjectMeta> {
        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            async {
                with_concurrency_budget(1, || async {
                    let store = st;
                    let head_result = store.head(path).await;

                    if head_result.is_err() {
                        // Pre-signed URLs forbid the HEAD method, but we can still retrieve the header
                        // information with a range 0-0 request.
                        let get_range_0_0_result = store
                            .get_opts(
                                path,
                                object_store::GetOptions {
                                    range: Some((0..1).into()),
                                    ..Default::default()
                                },
                            )
                            .await;

                        if let Ok(v) = get_range_0_0_result {
                            return Ok(v.meta);
                        }
                    }

                    let out = head_result?;

                    Ok(out)
                })
                .await
            }
        })
        .await
    }
}

/// Splits a single range into multiple smaller ranges, which can be downloaded concurrently for
/// much higher throughput.
fn split_range(range: Range<usize>) -> impl ExactSizeIterator<Item = Range<usize>> {
    let chunk_size = get_download_chunk_size();

    // Calculate n_parts such that we are as close as possible to the `chunk_size`.
    let n_parts = [
        (range.len().div_ceil(chunk_size)).max(1),
        (range.len() / chunk_size).max(1),
    ]
    .into_iter()
    .min_by_key(|x| (range.len() / *x).abs_diff(chunk_size))
    .unwrap();

    let chunk_size = (range.len() / n_parts).max(1);

    assert_eq!(n_parts, (range.len() / chunk_size).max(1));
    let bytes_rem = range.len() % chunk_size;

    (0..n_parts).map(move |part_no| {
        let (start, end) = if part_no == 0 {
            // Download the remainder in the first chunk since it starts downloading first.
            let end = range.start + chunk_size + bytes_rem;
            let end = if end > range.end { range.end } else { end };
            (range.start, end)
        } else {
            let start = bytes_rem + range.start + part_no * chunk_size;
            (start, start + chunk_size)
        };

        start..end
    })
}
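
// Worked example (illustrative, assuming the default 64 MiB download chunk size): for a 100 MiB
// range, the candidate part counts are ceil(100 / 64) = 2 and floor(100 / 64) = 1. The resulting
// per-part sizes of 50 MiB and 100 MiB differ from 64 MiB by 14 MiB and 36 MiB respectively, so
// 2 parts are chosen and the range is split into two 50 MiB sub-ranges (any remainder bytes would
// go to the first part).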

/// Note: For optimal performance, `ranges` should be sorted. More generally, ranges that are
/// adjacent in the slice should also be close together in byte offset.
///
/// # Returns
/// `[(range1, end1), (range2, end2)]`, where:
/// * `range1` contains bytes for the ranges from `ranges[0..end1]`
/// * `range2` contains bytes for the ranges from `ranges[end1..end2]`
/// * etc..
///
/// Note that if an end value is 0, it means the range is a split part and should be combined.
fn merge_ranges(ranges: &[Range<usize>]) -> impl Iterator<Item = (Range<usize>, usize)> + '_ {
    let chunk_size = get_download_chunk_size();

    let mut current_merged_range = ranges.first().map_or(0..0, Clone::clone);
    // Number of fetched bytes excluding excess.
    let mut current_n_bytes = current_merged_range.len();

    (0..ranges.len())
        .filter_map(move |current_idx| {
            let current_idx = 1 + current_idx;

            if current_idx == ranges.len() {
                // No more items - flush current state.
                Some((current_merged_range.clone(), current_idx))
            } else {
                let range = ranges[current_idx].clone();

                let new_merged = current_merged_range.start.min(range.start)
                    ..current_merged_range.end.max(range.end);

                // E.g.:
                // |--------|
                //  oo        // range1
                //       oo   // range2
                //    ^^^     // distance = 3, is_overlapping = false
                // E.g.:
                // |--------|
                //  ooooo     // range1
                //     ooooo  // range2
                //     ^^     // distance = 2, is_overlapping = true
                let (distance, is_overlapping) = {
                    let l = current_merged_range.end.min(range.end);
                    let r = current_merged_range.start.max(range.start);

                    (r.abs_diff(l), r < l)
                };

                let should_merge = is_overlapping || {
                    let leq_current_len_dist_to_chunk_size = new_merged.len().abs_diff(chunk_size)
                        <= current_merged_range.len().abs_diff(chunk_size);
                    let gap_tolerance =
                        (current_n_bytes.max(range.len()) / 8).clamp(1024 * 1024, 8 * 1024 * 1024);

                    leq_current_len_dist_to_chunk_size && distance <= gap_tolerance
                };

                if should_merge {
                    // Merge to existing range
                    current_merged_range = new_merged;
                    current_n_bytes += if is_overlapping {
                        range.len() - distance
                    } else {
                        range.len()
                    };
                    None
                } else {
                    let out = (current_merged_range.clone(), current_idx);
                    current_merged_range = range;
                    current_n_bytes = current_merged_range.len();
                    Some(out)
                }
            }
        })
        .flat_map(|x| {
            // Split large individual ranges within the list of ranges.
            let (range, end) = x;
            let split = split_range(range);
            let len = split.len();

            split
                .enumerate()
                .map(move |(i, range)| (range, if 1 + i == len { end } else { 0 }))
        })
}
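
// Worked example (illustrative, assuming the default 64 MiB chunk size): for the requests
// `0..8 MiB` and `9 MiB..10 MiB`, the gap is 1 MiB and the gap tolerance is
// `(current_n_bytes.max(range.len()) / 8).clamp(1 MiB, 8 MiB)` = 1 MiB. Merging also brings the
// merged length (10 MiB) closer to the 64 MiB chunk size than the current 8 MiB, so the two
// requests coalesce into a single `0..10 MiB` request reported as `(0..10 MiB, 2)`, i.e. it
// covers `ranges[0..2]`.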

#[cfg(test)]
mod tests {

    #[test]
    fn test_split_range() {
        use super::{get_download_chunk_size, split_range};

        let chunk_size = get_download_chunk_size();

        assert_eq!(chunk_size, 64 * 1024 * 1024);

        #[allow(clippy::single_range_in_vec_init)]
        {
            // Round-trip empty ranges.
            assert_eq!(split_range(0..0).collect::<Vec<_>>(), [0..0]);
            assert_eq!(split_range(3..3).collect::<Vec<_>>(), [3..3]);
        }

        // Threshold to start splitting to 2 ranges
        //
        // n - chunk_size == chunk_size - n / 2
        // n + n / 2 == 2 * chunk_size
        // 3 * n == 4 * chunk_size
        // n = 4 * chunk_size / 3
        let n = 4 * chunk_size / 3;

        #[allow(clippy::single_range_in_vec_init)]
        {
            assert_eq!(split_range(0..n).collect::<Vec<_>>(), [0..89478485]);
        }

        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..44739243, 44739243..89478486]
        );

        // Threshold to start splitting to 3 ranges
        //
        // n / 2 - chunk_size == chunk_size - n / 3
        // n / 2 + n / 3 == 2 * chunk_size
        // 5 * n == 12 * chunk_size
        // n == 12 * chunk_size / 5
        let n = 12 * chunk_size / 5;

        assert_eq!(
            split_range(0..n).collect::<Vec<_>>(),
            [0..80530637, 80530637..161061273]
        );

        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..53687092, 53687092..107374183, 107374183..161061274]
        );
    }

    #[test]
    fn test_merge_ranges() {
        use super::{get_download_chunk_size, merge_ranges};

        let chunk_size = get_download_chunk_size();

        assert_eq!(chunk_size, 64 * 1024 * 1024);

        // Round-trip empty slice
        assert_eq!(merge_ranges(&[]).collect::<Vec<_>>(), []);

        // We have 1 tiny request followed by 1 huge request. They are combined because merging
        // brings the total length closer to the `chunk_size` (a smaller `abs_diff()`), and the
        // merged range is afterwards split into 2 evenly sized requests.
        assert_eq!(
            merge_ranges(&[0..1, 1..127 * 1024 * 1024]).collect::<Vec<_>>(),
            [(0..66584576, 0), (66584576..133169152, 2)]
        );

        // <= 1MiB gap, merge
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 1..1024 * 1024 + 2]).collect::<Vec<_>>(),
            [(0..1048578, 2)]
        );

        // > 1MiB gap, do not merge
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 2..1024 * 1024 + 3]).collect::<Vec<_>>(),
            [(0..1, 1), (1048578..1048579, 2)]
        );

        // <= 12.5% gap, merge
        assert_eq!(
            merge_ranges(&[0..8, 10..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        // <= 12.5% gap relative to RHS, merge
        assert_eq!(
            merge_ranges(&[0..1, 3..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        // Overlapping range, merge
        assert_eq!(
            merge_ranges(&[0..80 * 1024 * 1024, 10 * 1024 * 1024..70 * 1024 * 1024])
                .collect::<Vec<_>>(),
            [(0..80 * 1024 * 1024, 2)]
        );
    }
}