1use std::ops::Range;
2
3use bytes::Bytes;
4use futures::{StreamExt, TryStreamExt};
5use hashbrown::hash_map::RawEntryMut;
6use object_store::path::Path;
7use object_store::{ObjectMeta, ObjectStore};
8use polars_core::prelude::{InitHashMaps, PlHashMap};
9use polars_error::{PolarsError, PolarsResult};
10use polars_utils::mmap::MemSlice;
11use tokio::io::{AsyncSeekExt, AsyncWriteExt};
12
13use crate::pl_async::{
14 self, MAX_BUDGET_PER_REQUEST, get_concurrency_limit, get_download_chunk_size,
15 tune_with_concurrency_budget, with_concurrency_budget,
16};
17
mod inner {
    use std::future::Future;
    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;

    use object_store::ObjectStore;
    use polars_core::config;
    use polars_error::PolarsResult;

    use crate::cloud::PolarsObjectStoreBuilder;

    /// Shared state: the current store plus the builder needed to construct a
    /// replacement store when an operation fails (e.g. expired credentials).
    #[derive(Debug)]
    struct Inner {
        // tokio (async) mutex: the guard is held across `.await` points in
        // `rebuild_inner`, which a `std::sync::Mutex` would not allow.
        store: tokio::sync::Mutex<Arc<dyn ObjectStore>>,
        builder: PolarsObjectStoreBuilder,
    }

    /// An `ObjectStore` wrapper that can rebuild its inner store after a
    /// failed operation (see `try_exec_rebuild_on_err`).
    #[derive(Debug)]
    pub struct PolarsObjectStore {
        inner: Arc<Inner>,
        /// Cached copy of the store created at construction time, so the
        /// common (never-rebuilt) path in `to_dyn_object_store` can avoid
        /// locking the mutex.
        initial_store: std::sync::Arc<dyn ObjectStore>,
        /// Set once `rebuild_inner` has run; from then on the mutex-guarded
        /// store is authoritative and `initial_store` must not be used.
        rebuilt: AtomicBool,
    }

    impl Clone for PolarsObjectStore {
        // Manual impl because `AtomicBool` is not `Clone`; the flag's current
        // value is carried over into the new instance.
        fn clone(&self) -> Self {
            Self {
                inner: self.inner.clone(),
                initial_store: self.initial_store.clone(),
                rebuilt: AtomicBool::new(self.rebuilt.load(std::sync::atomic::Ordering::Relaxed)),
            }
        }
    }

    impl PolarsObjectStore {
        /// Wraps an already-built `store`, keeping `builder` around so the
        /// store can be re-built later if needed.
        pub(crate) fn new_from_inner(
            store: Arc<dyn ObjectStore>,
            builder: PolarsObjectStoreBuilder,
        ) -> Self {
            let initial_store = store.clone();
            Self {
                inner: Arc::new(Inner {
                    store: tokio::sync::Mutex::new(store),
                    builder,
                }),
                initial_store,
                rebuilt: AtomicBool::new(false),
            }
        }

        /// Returns the current `ObjectStore`, skipping the mutex lock while no
        /// rebuild has happened yet (the common case).
        pub async fn to_dyn_object_store(&self) -> Arc<dyn ObjectStore> {
            if !self.rebuilt.load(std::sync::atomic::Ordering::Relaxed) {
                self.initial_store.clone()
            } else {
                self.inner.store.lock().await.clone()
            }
        }

        /// Re-builds the inner store from the builder and returns the store
        /// that is current afterwards.
        ///
        /// `from_version` is the store the caller saw fail; the rebuild only
        /// happens if it is still the current store (pointer equality), so
        /// concurrent callers reacting to the same failure trigger at most
        /// one rebuild.
        pub async fn rebuild_inner(
            &self,
            from_version: &Arc<dyn ObjectStore>,
        ) -> PolarsResult<Arc<dyn ObjectStore>> {
            let mut current_store = self.inner.store.lock().await;

            // Flip the flag while holding the lock so readers switch over to
            // the mutex-guarded store.
            self.rebuilt
                .store(true, std::sync::atomic::Ordering::Relaxed);

            if Arc::ptr_eq(&*current_store, from_version) {
                *current_store = self.inner.builder.clone().build_impl().await.map_err(|e| {
                    e.wrap_msg(|e| format!("attempt to rebuild object store failed: {e}"))
                })?;
            }

            Ok((*current_store).clone())
        }

        /// Runs `func` against the current store; on error, rebuilds the store
        /// once and retries. If the rebuild itself fails, the original error is
        /// appended to the rebuild error. On Azure, a hint about
        /// `POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY` is attached to a failed
        /// retry unless that env var is already set to "1".
        pub async fn try_exec_rebuild_on_err<Fn, Fut, O>(&self, mut func: Fn) -> PolarsResult<O>
        where
            Fn: FnMut(&Arc<dyn ObjectStore>) -> Fut,
            Fut: Future<Output = PolarsResult<O>>,
        {
            let store = self.to_dyn_object_store().await;

            let out = func(&store).await;

            let orig_err = match out {
                Ok(v) => return Ok(v),
                Err(e) => e,
            };

            if config::verbose() {
                eprintln!(
                    "[PolarsObjectStore]: got error: {}, will attempt re-build",
                    &orig_err
                );
            }

            let store = self
                .rebuild_inner(&store)
                .await
                .map_err(|e| e.wrap_msg(|e| format!("{e}; original error: {orig_err}")))?;

            // Single retry with the (possibly) rebuilt store.
            func(&store).await.map_err(|e| {
                if self.inner.builder.is_azure()
                    && std::env::var("POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY").as_deref()
                        != Ok("1")
                {
                    e.wrap_msg(|e| {
                        format!(
                            "{e}; note: if you are using Python, consider setting \
POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY=1 if you would like polars to try to retrieve \
and use the storage account keys from Azure CLI to authenticate"
                        )
                    })
                } else {
                    e
                }
            })
        }
    }
}
147
148pub use inner::PolarsObjectStore;
149
/// Convenience alias for the `object_store` path type used by this module's API.
pub type ObjectStorePath = object_store::path::Path;
151
impl PolarsObjectStore {
    /// Builds a stream that fetches the given byte `ranges` concurrently
    /// (bounded by the global concurrency limit), yielding results in the
    /// order of `ranges`.
    fn get_buffered_ranges_stream<'a, T: Iterator<Item = Range<usize>>>(
        store: &'a dyn ObjectStore,
        path: &'a Path,
        ranges: T,
    ) -> impl StreamExt<Item = PolarsResult<Bytes>>
    + TryStreamExt<Ok = Bytes, Error = PolarsError, Item = PolarsResult<Bytes>>
    + use<'a, T> {
        futures::stream::iter(ranges.map(move |range| async move {
            let out = store
                .get_range(path, range.start as u64..range.end as u64)
                .await?;
            Ok(out)
        }))
        // `buffered` (not `buffer_unordered`): consumers rely on results
        // arriving in the same order as the input ranges.
        .buffered(get_concurrency_limit() as usize)
    }

    /// Fetches a single byte range, transparently splitting it into multiple
    /// concurrent requests when it exceeds the download chunk size.
    pub async fn get_range(&self, path: &Path, range: Range<usize>) -> PolarsResult<Bytes> {
        self.try_exec_rebuild_on_err(move |store| {
            // Clone captured state: this closure may be invoked a second time
            // on retry after a store rebuild.
            let range = range.clone();
            let st = store.clone();

            async move {
                let store = st;
                let parts = split_range(range.clone());

                if parts.len() == 1 {
                    // Small request: a single GET under a budget of 1.
                    let out = tune_with_concurrency_budget(1, move || async move {
                        store
                            .get_range(path, range.start as u64..range.end as u64)
                            .await
                    })
                    .await?;

                    Ok(out)
                } else {
                    // Large request: fetch the parts concurrently, then stitch
                    // them back together in order.
                    let parts = tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || {
                            Self::get_buffered_ranges_stream(&store, path, parts)
                                .try_collect::<Vec<Bytes>>()
                        },
                    )
                    .await?;

                    let mut combined = Vec::with_capacity(range.len());

                    for part in parts {
                        combined.extend_from_slice(&part)
                    }

                    assert_eq!(combined.len(), range.len());

                    PolarsResult::Ok(Bytes::from(combined))
                }
            }
        })
        .await
    }

    /// Fetches multiple ranges with coalescing: `ranges` is sorted by start
    /// in place (hence `&mut`), nearby/overlapping ranges are merged into
    /// fewer larger requests (see `merge_ranges`), and each requested range is
    /// sliced back out of the downloaded buffers.
    ///
    /// Returns a map from each range's `start` offset to its data. If two
    /// requested ranges share a start offset, the longer one wins.
    pub async fn get_ranges_sort(
        &self,
        path: &Path,
        ranges: &mut [Range<usize>],
    ) -> PolarsResult<PlHashMap<usize, MemSlice>> {
        if ranges.is_empty() {
            return Ok(Default::default());
        }

        ranges.sort_unstable_by_key(|x| x.start);

        let ranges_len = ranges.len();
        // `merged_ends[i] == 0` marks an intermediate piece of a split merged
        // request whose bytes must be buffered until the final piece arrives;
        // a non-zero value is the exclusive end index into `ranges` that this
        // download resolves.
        let (merged_ranges, merged_ends): (Vec<_>, Vec<_>) = merge_ranges(ranges).unzip();

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            async {
                let store = st;
                let mut out = PlHashMap::with_capacity(ranges_len);

                let mut stream =
                    Self::get_buffered_ranges_stream(&store, path, merged_ranges.iter().cloned());

                tune_with_concurrency_budget(
                    merged_ranges.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                    || async {
                        let mut len = 0;
                        // Index into `ranges` of the first range not yet resolved.
                        let mut current_offset = 0;
                        let mut ends_iter = merged_ends.iter();

                        // Buffers pieces of a split merged range until the final
                        // piece (the one with `end != 0`) arrives.
                        let mut splitted_parts = vec![];

                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            let end = *ends_iter.next().unwrap();

                            if end == 0 {
                                // Partial piece: stash it and keep reading.
                                splitted_parts.push(bytes);
                                continue;
                            }

                            // Bounding range of all requested ranges covered by
                            // this (possibly reassembled) download.
                            let full_range = ranges[current_offset..end]
                                .iter()
                                .cloned()
                                .reduce(|l, r| l.start.min(r.start)..l.end.max(r.end))
                                .unwrap();

                            let bytes = if splitted_parts.is_empty() {
                                bytes
                            } else {
                                // Reassemble the buffered pieces plus this final one.
                                let mut out = Vec::with_capacity(full_range.len());

                                for x in splitted_parts.drain(..) {
                                    out.extend_from_slice(&x);
                                }

                                out.extend_from_slice(&bytes);
                                Bytes::from(out)
                            };

                            assert_eq!(bytes.len(), full_range.len());

                            let bytes = MemSlice::from_bytes(bytes);

                            // Slice each requested range out of the merged buffer.
                            for range in &ranges[current_offset..end] {
                                let mem_slice = bytes.slice(
                                    range.start - full_range.start..range.end - full_range.start,
                                );

                                match out.raw_entry_mut().from_key(&range.start) {
                                    RawEntryMut::Vacant(slot) => {
                                        slot.insert(range.start, mem_slice);
                                    },
                                    RawEntryMut::Occupied(mut slot) => {
                                        // Duplicate start offset: keep the
                                        // longer slice.
                                        if slot.get_mut().len() < mem_slice.len() {
                                            *slot.get_mut() = mem_slice;
                                        }
                                    },
                                }
                            }

                            current_offset = end;
                        }

                        assert!(splitted_parts.is_empty());

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    },
                )
                .await?;

                Ok(out)
            }
        })
        .await
    }

    /// Downloads an object into `file`, starting at the file's current
    /// position. When the size is known (via `head`) and the object spans
    /// multiple parts, the parts are fetched concurrently and written in order.
    pub async fn download(&self, path: &Path, file: &mut tokio::fs::File) -> PolarsResult<()> {
        // Best-effort size probe; `None` falls back to a single streamed GET.
        let opt_size = self.head(path).await.ok().map(|x| x.size);

        let initial_pos = file.stream_position().await?;

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            // SAFETY(review): duplicates the `&mut File` so it can be used
            // inside this `FnMut` closure across a potential retry. The two
            // invocations never run concurrently, but this bypasses the borrow
            // checker's aliasing rules — confirm soundness before changing.
            let file: &mut tokio::fs::File = unsafe { std::mem::transmute_copy(&file) };

            async {
                // Truncate back to the starting position so a retry after a
                // partial write does not leave stale bytes behind.
                file.set_len(initial_pos).await?;
                let store = st;
                let parts = opt_size
                    .map(|x| split_range(0..x as usize))
                    .filter(|x| x.len() > 1);

                if let Some(parts) = parts {
                    // Known size, multiple parts: concurrent ranged GETs,
                    // written to the file in range order.
                    tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || async {
                            let mut stream = Self::get_buffered_ranges_stream(&store, path, parts);
                            let mut len = 0;
                            while let Some(bytes) = stream.try_next().await? {
                                len += bytes.len();
                                file.write_all(&bytes).await?;
                            }

                            assert_eq!(len, opt_size.unwrap() as usize);

                            PolarsResult::Ok(pl_async::Size::from(len as u64))
                        },
                    )
                    .await?
                } else {
                    // Unknown size (or a single part): stream the whole object.
                    tune_with_concurrency_budget(1, || async {
                        let mut stream = store.get(path).await?.into_stream();

                        let mut len = 0;
                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            file.write_all(&bytes).await?;
                        }

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    })
                    .await?
                };

                // Flush data and metadata to disk before reporting success.
                file.sync_all().await.map_err(PolarsError::from)?;

                Ok(())
            }
        })
        .await
    }

    /// Fetches object metadata. If the HEAD request fails, falls back to a
    /// 1-byte ranged GET and uses the metadata from its response — presumably
    /// for stores/credentials that permit GET but not HEAD (confirm).
    pub async fn head(&self, path: &Path) -> PolarsResult<ObjectMeta> {
        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            async {
                with_concurrency_budget(1, || async {
                    let store = st;
                    let head_result = store.head(path).await;

                    if head_result.is_err() {
                        // Fallback: request byte range 0..1 and take the
                        // metadata attached to the GET response.
                        let get_range_0_0_result = store
                            .get_opts(
                                path,
                                object_store::GetOptions {
                                    range: Some((0..1).into()),
                                    ..Default::default()
                                },
                            )
                            .await;

                        if let Ok(v) = get_range_0_0_result {
                            return Ok(v.meta);
                        }
                    }

                    // Fallback failed too: surface the original HEAD error.
                    let out = head_result?;

                    Ok(out)
                })
                .await
            }
        })
        .await
    }
}
416
417fn split_range(range: Range<usize>) -> impl ExactSizeIterator<Item = Range<usize>> {
420 let chunk_size = get_download_chunk_size();
421
422 let n_parts = [
424 (range.len().div_ceil(chunk_size)).max(1),
425 (range.len() / chunk_size).max(1),
426 ]
427 .into_iter()
428 .min_by_key(|x| (range.len() / *x).abs_diff(chunk_size))
429 .unwrap();
430
431 let chunk_size = (range.len() / n_parts).max(1);
432
433 assert_eq!(n_parts, (range.len() / chunk_size).max(1));
434 let bytes_rem = range.len() % chunk_size;
435
436 (0..n_parts).map(move |part_no| {
437 let (start, end) = if part_no == 0 {
438 let end = range.start + chunk_size + bytes_rem;
440 let end = if end > range.end { range.end } else { end };
441 (range.start, end)
442 } else {
443 let start = bytes_rem + range.start + part_no * chunk_size;
444 (start, start + chunk_size)
445 };
446
447 start..end
448 })
449}
450
/// Coalesces the (pre-sorted by start) `ranges` into fewer, larger download
/// requests, tolerating small gaps between them.
///
/// Yields `(merged_range, end)` pairs where `end` is the exclusive index into
/// `ranges` covered once this piece has been downloaded. Merged ranges larger
/// than the download chunk size are re-split via `split_range`; every piece of
/// such a split except the last carries `end == 0`, signalling "partial data,
/// buffer and keep reading" to the consumer (see `get_ranges_sort`).
fn merge_ranges(ranges: &[Range<usize>]) -> impl Iterator<Item = (Range<usize>, usize)> + '_ {
    let chunk_size = get_download_chunk_size();

    // Running merge accumulator; starts as the first range (or 0..0 if empty).
    let mut current_merged_range = ranges.first().map_or(0..0, Clone::clone);
    // Number of distinct requested bytes in the accumulator (overlap-adjusted);
    // used to scale the allowed gap below.
    let mut current_n_bytes = current_merged_range.len();

    (0..ranges.len())
        .filter_map(move |current_idx| {
            // Look one range ahead of the loop index.
            let current_idx = 1 + current_idx;

            if current_idx == ranges.len() {
                // Past the last range: flush the accumulator.
                Some((current_merged_range.clone(), current_idx))
            } else {
                let range = ranges[current_idx].clone();

                let new_merged = current_merged_range.start.min(range.start)
                    ..current_merged_range.end.max(range.end);

                // `distance` is the gap size when the ranges are disjoint, or
                // the overlap size when they intersect.
                let (distance, is_overlapping) = {
                    let l = current_merged_range.end.min(range.end);
                    let r = current_merged_range.start.max(range.start);

                    (r.abs_diff(l), r < l)
                };

                let should_merge = is_overlapping || {
                    // Merge only if it moves the merged length no further from
                    // the ideal chunk size...
                    let leq_current_len_dist_to_chunk_size = new_merged.len().abs_diff(chunk_size)
                        <= current_merged_range.len().abs_diff(chunk_size);
                    // ...and the gap is small relative to the data being
                    // fetched (1/8th, clamped between 1 MiB and 8 MiB).
                    let gap_tolerance =
                        (current_n_bytes.max(range.len()) / 8).clamp(1024 * 1024, 8 * 1024 * 1024);

                    leq_current_len_dist_to_chunk_size && distance <= gap_tolerance
                };

                if should_merge {
                    current_merged_range = new_merged;
                    // Avoid double-counting overlapping bytes.
                    current_n_bytes += if is_overlapping {
                        range.len() - distance
                    } else {
                        range.len()
                    };
                    None
                } else {
                    // Emit the accumulator and restart it from this range.
                    let out = (current_merged_range.clone(), current_idx);
                    current_merged_range = range;
                    current_n_bytes = current_merged_range.len();
                    Some(out)
                }
            }
        })
        .flat_map(|x| {
            // Over-large merged ranges are re-split; only the final piece keeps
            // the real `end` index, earlier pieces are tagged with 0.
            let (range, end) = x;
            let split = split_range(range.clone());
            let len = split.len();

            split
                .enumerate()
                .map(move |(i, range)| (range, if 1 + i == len { end } else { 0 }))
        })
}
535
#[cfg(test)]
mod tests {

    /// `split_range` behavior around the 64 MiB download chunk size.
    #[test]
    fn test_split_range() {
        use super::{get_download_chunk_size, split_range};

        let chunk_size = get_download_chunk_size();

        assert_eq!(chunk_size, 64 * 1024 * 1024);

        // Empty ranges produce a single empty part at the same offset.
        #[allow(clippy::single_range_in_vec_init)]
        {
            assert_eq!(split_range(0..0).collect::<Vec<_>>(), [0..0]);
            assert_eq!(split_range(3..3).collect::<Vec<_>>(), [3..3]);
        }

        // 4/3 of the chunk size: one part is closer to the target than two.
        let n = 4 * chunk_size / 3;

        #[allow(clippy::single_range_in_vec_init)]
        {
            assert_eq!(split_range(0..n).collect::<Vec<_>>(), [0..89478485]);
        }

        // One byte more tips the balance to two parts.
        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..44739243, 44739243..89478486]
        );

        // 12/5 of the chunk size: two parts is closest.
        let n = 12 * chunk_size / 5;

        assert_eq!(
            split_range(0..n).collect::<Vec<_>>(),
            [0..80530637, 80530637..161061273]
        );

        // One byte more: three parts (first part absorbs the remainder).
        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..53687092, 53687092..107374183, 107374183..161061274]
        );
    }

    #[test]
    fn test_merge_ranges() {
        use super::{get_download_chunk_size, merge_ranges};

        let chunk_size = get_download_chunk_size();

        assert_eq!(chunk_size, 64 * 1024 * 1024);

        assert_eq!(merge_ranges(&[]).collect::<Vec<_>>(), []);

        // A merged range larger than the chunk size is re-split; only the last
        // piece carries the real end index, earlier pieces are tagged 0.
        assert_eq!(
            merge_ranges(&[0..1, 1..127 * 1024 * 1024]).collect::<Vec<_>>(),
            [(0..66584576, 0), (66584576..133169152, 2)]
        );

        // A gap of exactly 1 MiB (the minimum gap tolerance) still merges.
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 1..1024 * 1024 + 2]).collect::<Vec<_>>(),
            [(0..1048578, 2)]
        );

        // One byte over the minimum gap tolerance: not merged.
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 2..1024 * 1024 + 3]).collect::<Vec<_>>(),
            [(0..1, 1), (1048578..1048579, 2)]
        );

        // Small gaps between small ranges are merged.
        assert_eq!(
            merge_ranges(&[0..8, 10..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        assert_eq!(
            merge_ranges(&[0..1, 3..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        // Overlapping ranges always merge.
        assert_eq!(
            merge_ranges(&[0..80 * 1024 * 1024, 10 * 1024 * 1024..70 * 1024 * 1024])
                .collect::<Vec<_>>(),
            [(0..80 * 1024 * 1024, 2)]
        );
    }
}
641}