use std::ops::Range;

use bytes::Bytes;
use futures::{StreamExt, TryStreamExt};
use object_store::path::Path;
use object_store::{ObjectMeta, ObjectStore};
use polars_core::prelude::{InitHashMaps, PlHashMap};
use polars_error::{PolarsError, PolarsResult};
use tokio::io::{AsyncSeekExt, AsyncWriteExt};

use crate::pl_async::{
    self, MAX_BUDGET_PER_REQUEST, get_concurrency_limit, get_download_chunk_size,
    tune_with_concurrency_budget, with_concurrency_budget,
};

mod inner {
    use std::future::Future;
    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;

    use object_store::ObjectStore;
    use polars_core::config;
    use polars_error::PolarsResult;

    use crate::cloud::PolarsObjectStoreBuilder;

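    /// The currently active object store plus the builder needed to re-create it.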
    #[derive(Debug)]
    struct Inner {
        store: tokio::sync::Mutex<Arc<dyn ObjectStore>>,
        builder: PolarsObjectStoreBuilder,
    }

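    /// Shared handle to a dynamic [`ObjectStore`] that also keeps the builder it was
    /// constructed from, so the store can be rebuilt (e.g. if credentials need to be
    /// re-resolved) and a failed request retried once against the rebuilt store.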
    #[derive(Debug)]
    pub struct PolarsObjectStore {
        inner: Arc<Inner>,
        initial_store: std::sync::Arc<dyn ObjectStore>,
        rebuilt: AtomicBool,
    }

    impl Clone for PolarsObjectStore {
        fn clone(&self) -> Self {
            Self {
                inner: self.inner.clone(),
                initial_store: self.initial_store.clone(),
                rebuilt: AtomicBool::new(self.rebuilt.load(std::sync::atomic::Ordering::Relaxed)),
            }
        }
    }

    impl PolarsObjectStore {
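        /// Wraps an already-built store together with the builder that produced it; the
        /// builder is kept so the store can later be re-created by [`Self::rebuild_inner`].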
        pub(crate) fn new_from_inner(
            store: Arc<dyn ObjectStore>,
            builder: PolarsObjectStoreBuilder,
        ) -> Self {
            let initial_store = store.clone();
            Self {
                inner: Arc::new(Inner {
                    store: tokio::sync::Mutex::new(store),
                    builder,
                }),
                initial_store,
                rebuilt: AtomicBool::new(false),
            }
        }

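        /// Returns the underlying [`ObjectStore`]. As long as no rebuild has happened this
        /// clones the initial store without taking the async mutex.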
        pub async fn to_dyn_object_store(&self) -> Arc<dyn ObjectStore> {
            if !self.rebuilt.load(std::sync::atomic::Ordering::Relaxed) {
                self.initial_store.clone()
            } else {
                self.inner.store.lock().await.clone()
            }
        }

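        /// Re-creates the underlying store from the builder and returns the new store. The
        /// rebuild is skipped if the currently held store is no longer `from_version`, so
        /// concurrent callers that observed errors from the same store rebuild it only once.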
        pub async fn rebuild_inner(
            &self,
            from_version: &Arc<dyn ObjectStore>,
        ) -> PolarsResult<Arc<dyn ObjectStore>> {
            let mut current_store = self.inner.store.lock().await;

            self.rebuilt
                .store(true, std::sync::atomic::Ordering::Relaxed);

            if Arc::ptr_eq(&*current_store, from_version) {
                *current_store = self.inner.builder.clone().build_impl().await.map_err(|e| {
                    e.wrap_msg(|e| format!("attempt to rebuild object store failed: {}", e))
                })?;
            }

            Ok((*current_store).clone())
        }

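        /// Runs `func` with the current store. On error the store is rebuilt and `func` is
        /// retried exactly once; if the retry also fails against an Azure store, a hint
        /// about `POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY` is appended to the error.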
        pub async fn try_exec_rebuild_on_err<Fn, Fut, O>(&self, mut func: Fn) -> PolarsResult<O>
        where
            Fn: FnMut(&Arc<dyn ObjectStore>) -> Fut,
            Fut: Future<Output = PolarsResult<O>>,
        {
            let store = self.to_dyn_object_store().await;

            let out = func(&store).await;

            let orig_err = match out {
                Ok(v) => return Ok(v),
                Err(e) => e,
            };

            if config::verbose() {
                eprintln!(
                    "[PolarsObjectStore]: got error: {}, will attempt re-build",
                    &orig_err
                );
            }

            let store = self
                .rebuild_inner(&store)
                .await
                .map_err(|e| e.wrap_msg(|e| format!("{}; original error: {}", e, orig_err)))?;

            func(&store).await.map_err(|e| {
                if self.inner.builder.is_azure()
                    && std::env::var("POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY").as_deref()
                        != Ok("1")
                {
                    e.wrap_msg(|e| {
                        format!(
                            "{}; note: if you are using Python, consider setting \
POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY=1 if you would like polars to try to retrieve \
and use the storage account keys from Azure CLI to authenticate",
                            e
                        )
                    })
                } else {
                    e
                }
            })
        }
    }
}

pub use inner::PolarsObjectStore;

pub type ObjectStorePath = object_store::path::Path;

impl PolarsObjectStore {
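    /// Returns a stream that fetches the given byte ranges concurrently, buffered up to the
    /// global concurrency limit, while preserving the input order.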
    fn get_buffered_ranges_stream<'a, T: Iterator<Item = Range<usize>>>(
        store: &'a dyn ObjectStore,
        path: &'a Path,
        ranges: T,
    ) -> impl StreamExt<Item = PolarsResult<Bytes>>
    + TryStreamExt<Ok = Bytes, Error = PolarsError, Item = PolarsResult<Bytes>>
    + use<'a, T> {
        futures::stream::iter(ranges.map(move |range| async move {
            let out = store
                .get_range(path, range.start as u64..range.end as u64)
                .await?;
            Ok(out)
        }))
        .buffered(get_concurrency_limit() as usize)
    }

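    /// Fetches a single byte range, transparently splitting it into multiple concurrent
    /// requests when it is larger than the configured download chunk size.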
    pub async fn get_range(&self, path: &Path, range: Range<usize>) -> PolarsResult<Bytes> {
        self.try_exec_rebuild_on_err(move |store| {
            let range = range.clone();
            let st = store.clone();

            async move {
                let store = st;
                let parts = split_range(range.clone());

                if parts.len() == 1 {
                    let out = tune_with_concurrency_budget(1, move || async move {
                        store
                            .get_range(path, range.start as u64..range.end as u64)
                            .await
                    })
                    .await?;

                    Ok(out)
                } else {
                    let parts = tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || {
                            Self::get_buffered_ranges_stream(&store, path, parts)
                                .try_collect::<Vec<Bytes>>()
                        },
                    )
                    .await?;

                    let mut combined = Vec::with_capacity(range.len());

                    for part in parts {
                        combined.extend_from_slice(&part)
                    }

                    assert_eq!(combined.len(), range.len());

                    PolarsResult::Ok(Bytes::from(combined))
                }
            }
        })
        .await
    }

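    /// Fetches multiple byte ranges, merging nearby and overlapping ranges into larger
    /// requests where beneficial. `ranges` is sorted in place by start offset; the returned
    /// map is keyed by the start offset of each requested range.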
    pub async fn get_ranges_sort<
        K: TryFrom<usize, Error = impl std::fmt::Debug> + std::hash::Hash + Eq,
        T: From<Bytes>,
    >(
        &self,
        path: &Path,
        ranges: &mut [Range<usize>],
    ) -> PolarsResult<PlHashMap<K, T>> {
        if ranges.is_empty() {
            return Ok(Default::default());
        }

        ranges.sort_unstable_by_key(|x| x.start);

        let ranges_len = ranges.len();
        let (merged_ranges, merged_ends): (Vec<_>, Vec<_>) = merge_ranges(ranges).unzip();

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            async {
                let store = st;
                let mut out = PlHashMap::with_capacity(ranges_len);

                let mut stream =
                    Self::get_buffered_ranges_stream(&store, path, merged_ranges.iter().cloned());

                tune_with_concurrency_budget(
                    merged_ranges.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                    || async {
                        let mut len = 0;
                        let mut current_offset = 0;
                        let mut ends_iter = merged_ends.iter();

                        let mut splitted_parts = vec![];

                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            let end = *ends_iter.next().unwrap();

                            if end == 0 {
                                splitted_parts.push(bytes);
                                continue;
                            }

                            let full_range = ranges[current_offset..end]
                                .iter()
                                .cloned()
                                .reduce(|l, r| l.start.min(r.start)..l.end.max(r.end))
                                .unwrap();

                            let bytes = if splitted_parts.is_empty() {
                                bytes
                            } else {
                                let mut out = Vec::with_capacity(full_range.len());

                                for x in splitted_parts.drain(..) {
                                    out.extend_from_slice(&x);
                                }

                                out.extend_from_slice(&bytes);
                                Bytes::from(out)
                            };

                            assert_eq!(bytes.len(), full_range.len());

                            for range in &ranges[current_offset..end] {
                                let v = out.insert(
                                    K::try_from(range.start).unwrap(),
                                    T::from(bytes.slice(
                                        range.start - full_range.start
                                            ..range.end - full_range.start,
                                    )),
                                );

                                assert!(v.is_none());
                            }

                            current_offset = end;
                        }

                        assert!(splitted_parts.is_empty());

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    },
                )
                .await?;

                Ok(out)
            }
        })
        .await
    }

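    /// Downloads the object at `path` into `file`, using multiple concurrent ranged requests
    /// when the object size is known and large enough to be split.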
    pub async fn download(&self, path: &Path, file: &mut tokio::fs::File) -> PolarsResult<()> {
        let opt_size = self.head(path).await.ok().map(|x| x.size);

        let initial_pos = file.stream_position().await?;

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            // Lifetime/borrow workaround: copy the `&mut File` reference so it can be captured
            // by a closure that may be invoked more than once. The invocations run
            // sequentially, so the file is never written to concurrently.
            let file: &mut tokio::fs::File = unsafe { std::mem::transmute_copy(&file) };

            async {
                // Reset the file length in case this closure is retried after a partial attempt.
                file.set_len(initial_pos).await?;

                let store = st;
                let parts = opt_size
                    .map(|x| split_range(0..x as usize))
                    .filter(|x| x.len() > 1);

                if let Some(parts) = parts {
                    tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || async {
                            let mut stream = Self::get_buffered_ranges_stream(&store, path, parts);
                            let mut len = 0;
                            while let Some(bytes) = stream.try_next().await? {
                                len += bytes.len();
                                file.write_all(&bytes).await?;
                            }

                            assert_eq!(len, opt_size.unwrap() as usize);

                            PolarsResult::Ok(pl_async::Size::from(len as u64))
                        },
                    )
                    .await?
                } else {
                    tune_with_concurrency_budget(1, || async {
                        let mut stream = store.get(path).await?.into_stream();

                        let mut len = 0;
                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            file.write_all(&bytes).await?;
                        }

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    })
                    .await?
                };

                file.sync_all().await.map_err(PolarsError::from)?;

                Ok(())
            }
        })
        .await
    }

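    /// Fetches the [`ObjectMeta`] for `path`. If the HEAD request fails, a ranged GET of the
    /// first byte is attempted as a fallback to retrieve the metadata.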
    pub async fn head(&self, path: &Path) -> PolarsResult<ObjectMeta> {
        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            async {
                with_concurrency_budget(1, || async {
                    let store = st;
                    let head_result = store.head(path).await;

                    if head_result.is_err() {
                        // If HEAD is not permitted (e.g. for some pre-signed URLs), a ranged
                        // GET of the first byte can still return the object metadata.
                        let get_range_0_0_result = store
                            .get_opts(
                                path,
                                object_store::GetOptions {
                                    range: Some((0..1).into()),
                                    ..Default::default()
                                },
                            )
                            .await;

                        if let Ok(v) = get_range_0_0_result {
                            return Ok(v.meta);
                        }
                    }

                    let out = head_result?;

                    Ok(out)
                })
                .await
            }
        })
        .await
    }
}

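/// Splits a byte range into contiguous parts whose sizes are as close as possible to the
/// configured download chunk size.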
fn split_range(range: Range<usize>) -> impl ExactSizeIterator<Item = Range<usize>> {
    let chunk_size = get_download_chunk_size();

    // Pick the number of parts (ceil or floor of len / chunk_size) that makes the per-part
    // size closest to the configured chunk size.
    let n_parts = [
        (range.len().div_ceil(chunk_size)).max(1),
        (range.len() / chunk_size).max(1),
    ]
    .into_iter()
    .min_by_key(|x| (range.len() / *x).abs_diff(chunk_size))
    .unwrap();

    let chunk_size = (range.len() / n_parts).max(1);

    assert_eq!(n_parts, (range.len() / chunk_size).max(1));
    let bytes_rem = range.len() % chunk_size;

    (0..n_parts).map(move |part_no| {
        let (start, end) = if part_no == 0 {
            // The first part absorbs the remainder bytes.
            let end = range.start + chunk_size + bytes_rem;
            let end = if end > range.end { range.end } else { end };
            (range.start, end)
        } else {
            let start = bytes_rem + range.start + part_no * chunk_size;
            (start, start + chunk_size)
        };

        start..end
    })
}

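/// Merges overlapping and nearby ranges into larger requests and then splits every merged
/// range using [`split_range`]. Expects `ranges` to be sorted by start offset.
///
/// Each yielded item is `(part, end)`, where `end` is the exclusive end index into `ranges`
/// of the group covered by the merged range; it is attached only to the last split part of
/// that group and is `0` for all earlier parts.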
fn merge_ranges(ranges: &[Range<usize>]) -> impl Iterator<Item = (Range<usize>, usize)> + '_ {
    let chunk_size = get_download_chunk_size();

    let mut current_merged_range = ranges.first().map_or(0..0, Clone::clone);
    let mut current_n_bytes = current_merged_range.len();

    (0..ranges.len())
        .filter_map(move |current_idx| {
            let current_idx = 1 + current_idx;

            if current_idx == ranges.len() {
                Some((current_merged_range.clone(), current_idx))
            } else {
                let range = ranges[current_idx].clone();

                let new_merged = current_merged_range.start.min(range.start)
                    ..current_merged_range.end.max(range.end);

                // `distance` is the size of the gap between the two ranges when they are
                // disjoint, or the size of their overlap when they intersect.
                let (distance, is_overlapping) = {
                    let l = current_merged_range.end.min(range.end);
                    let r = current_merged_range.start.max(range.start);

                    (r.abs_diff(l), r < l)
                };

                let should_merge = is_overlapping || {
                    let leq_current_len_dist_to_chunk_size = new_merged.len().abs_diff(chunk_size)
                        <= current_merged_range.len().abs_diff(chunk_size);
                    let gap_tolerance =
                        (current_n_bytes.max(range.len()) / 8).clamp(1024 * 1024, 8 * 1024 * 1024);

                    leq_current_len_dist_to_chunk_size && distance <= gap_tolerance
                };

                if should_merge {
                    current_merged_range = new_merged;
                    // Only count the bytes not already covered by the current merged range.
                    current_n_bytes += if is_overlapping {
                        range.len() - distance
                    } else {
                        range.len()
                    };
                    None
                } else {
                    let out = (current_merged_range.clone(), current_idx);
                    current_merged_range = range;
                    current_n_bytes = current_merged_range.len();
                    Some(out)
                }
            }
        })
        .flat_map(|x| {
            let (range, end) = x;
            let split = split_range(range.clone());
            let len = split.len();

            split
                .enumerate()
                .map(move |(i, range)| (range, if 1 + i == len { end } else { 0 }))
        })
}

#[cfg(test)]
mod tests {

    #[test]
    fn test_split_range() {
        use super::{get_download_chunk_size, split_range};

        let chunk_size = get_download_chunk_size();

        // The hard-coded offsets below assume this chunk size.
        assert_eq!(chunk_size, 64 * 1024 * 1024);

        #[allow(clippy::single_range_in_vec_init)]
        {
            assert_eq!(split_range(0..0).collect::<Vec<_>>(), [0..0]);
            assert_eq!(split_range(3..3).collect::<Vec<_>>(), [3..3]);
        }

        // A range of 4/3 chunk size is still downloaded as a single part...
        let n = 4 * chunk_size / 3;

        #[allow(clippy::single_range_in_vec_init)]
        {
            assert_eq!(split_range(0..n).collect::<Vec<_>>(), [0..89478485]);
        }

        // ...while one byte more tips it over into 2 parts.
        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..44739243, 44739243..89478486]
        );

        // A range of 12/5 (2.4x) chunk size splits into 2 parts...
        let n = 12 * chunk_size / 5;

        assert_eq!(
            split_range(0..n).collect::<Vec<_>>(),
            [0..80530637, 80530637..161061273]
        );

        // ...and one byte more tips it over into 3 parts.
        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..53687092, 53687092..107374183, 107374183..161061274]
        );
    }

    #[test]
    fn test_merge_ranges() {
        use super::{get_download_chunk_size, merge_ranges};

        let chunk_size = get_download_chunk_size();

        // The merge tolerances and split points below assume this chunk size.
        assert_eq!(chunk_size, 64 * 1024 * 1024);

        assert_eq!(merge_ranges(&[]).collect::<Vec<_>>(), []);

        // Adjacent ranges are merged, and the merged range is split back into parts
        // close to the chunk size.
        assert_eq!(
            merge_ranges(&[0..1, 1..127 * 1024 * 1024]).collect::<Vec<_>>(),
            [(0..66584576, 0), (66584576..133169152, 2)]
        );

        // A gap of exactly 1 MiB is within the merge tolerance.
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 1..1024 * 1024 + 2]).collect::<Vec<_>>(),
            [(0..1048578, 2)]
        );

        // One byte more and the gap exceeds the tolerance, so the ranges stay separate.
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 2..1024 * 1024 + 3]).collect::<Vec<_>>(),
            [(0..1, 1), (1048578..1048579, 2)]
        );

        // Small gaps are merged.
        assert_eq!(
            merge_ranges(&[0..8, 10..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        assert_eq!(
            merge_ranges(&[0..1, 3..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        // Overlapping ranges are always merged.
        assert_eq!(
            merge_ranges(&[0..80 * 1024 * 1024, 10 * 1024 * 1024..70 * 1024 * 1024])
                .collect::<Vec<_>>(),
            [(0..80 * 1024 * 1024, 2)]
        );
    }
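
    // Property-style check: the parts returned by `split_range` should tile the input range
    // contiguously from `range.start` to `range.end`. The probed lengths are arbitrary.
    #[test]
    fn test_split_range_covers_input() {
        use super::{get_download_chunk_size, split_range};

        let chunk_size = get_download_chunk_size();

        for len in [0, 1, chunk_size - 1, chunk_size, chunk_size + 1, 3 * chunk_size + 7] {
            let range = 11..11 + len;
            let parts = split_range(range.clone()).collect::<Vec<_>>();

            assert_eq!(parts.first().unwrap().start, range.start);
            assert_eq!(parts.last().unwrap().end, range.end);

            for pair in parts.windows(2) {
                assert_eq!(pair[0].end, pair[1].start);
            }
        }
    }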
}