use std::ops::Range;

use futures::{StreamExt, TryStreamExt};
use hashbrown::hash_map::RawEntryMut;
use object_store::path::Path;
use object_store::{ObjectMeta, ObjectStore};
use polars_buffer::Buffer;
use polars_core::prelude::{InitHashMaps, PlHashMap};
use polars_error::{PolarsError, PolarsResult};
use tokio::io::{AsyncSeekExt, AsyncWriteExt};

use crate::pl_async::{
    self, MAX_BUDGET_PER_REQUEST, get_concurrency_limit, get_download_chunk_size,
    tune_with_concurrency_budget, with_concurrency_budget,
};

mod inner {
    use std::future::Future;
    use std::sync::Arc;

    use object_store::ObjectStore;
    use polars_core::config;
    use polars_error::PolarsResult;
    use polars_utils::relaxed_cell::RelaxedCell;

    use crate::cloud::PolarsObjectStoreBuilder;

    #[derive(Debug)]
    struct Inner {
        store: tokio::sync::Mutex<Arc<dyn ObjectStore>>,
        builder: PolarsObjectStoreBuilder,
    }

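    /// Wraps an [`ObjectStore`] and transparently rebuilds it (e.g. to refresh expired
    /// credentials) when an operation returns an error.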
    #[derive(Clone, Debug)]
    pub struct PolarsObjectStore {
        inner: Arc<Inner>,
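        /// Cached copy of the initial store so the fast path can avoid contending the
        /// `tokio::sync::Mutex` until the store has been rebuilt at least once.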
        initial_store: std::sync::Arc<dyn ObjectStore>,
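        /// Set once the store has been rebuilt; read with relaxed ordering on the fast path.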
        rebuilt: RelaxedCell<bool>,
    }

    impl PolarsObjectStore {
        pub(crate) fn new_from_inner(
            store: Arc<dyn ObjectStore>,
            builder: PolarsObjectStoreBuilder,
        ) -> Self {
            let initial_store = store.clone();
            Self {
                inner: Arc::new(Inner {
                    store: tokio::sync::Mutex::new(store),
                    builder,
                }),
                initial_store,
                rebuilt: RelaxedCell::from(false),
            }
        }

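        /// Returns the underlying [`ObjectStore`] implementation.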
        pub async fn to_dyn_object_store(&self) -> Arc<dyn ObjectStore> {
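            // Fast path: if the store was never rebuilt, return the cached copy without
            // touching the Mutex.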
            if !self.rebuilt.load() {
                self.initial_store.clone()
            } else {
                self.inner.store.lock().await.clone()
            }
        }

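        /// Rebuilds the underlying store. `from_version` is the store that raised the error:
        /// the rebuild only runs if it is still the current store, so concurrent callers do
        /// not trigger multiple rebuilds for the same failure.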
        pub async fn rebuild_inner(
            &self,
            from_version: &Arc<dyn ObjectStore>,
        ) -> PolarsResult<Arc<dyn ObjectStore>> {
            let mut current_store = self.inner.store.lock().await;

            if Arc::ptr_eq(&*current_store, from_version) {
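                // No other caller has rebuilt the store in the meantime; do it now.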
                *current_store =
                    self.inner
                        .builder
                        .clone()
                        .build_impl(true)
                        .await
                        .map_err(|e| {
                            e.wrap_msg(|e| format!("attempt to rebuild object store failed: {e}"))
                        })?;
            }

            self.rebuilt.store(true);

            Ok((*current_store).clone())
        }

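        /// Runs `func` with the current store and, if it returns an error, rebuilds the
        /// store and retries `func` once against the fresh store.
        ///
        /// A minimal usage sketch (hypothetical caller; `path` is assumed to be in scope):
        /// ```ignore
        /// let meta = polars_store
        ///     .try_exec_rebuild_on_err(|store| {
        ///         let store = store.clone();
        ///         async move { Ok(store.head(&path).await?) }
        ///     })
        ///     .await?;
        /// ```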
        pub async fn try_exec_rebuild_on_err<Fn, Fut, O>(&self, mut func: Fn) -> PolarsResult<O>
        where
            Fn: FnMut(&Arc<dyn ObjectStore>) -> Fut,
            Fut: Future<Output = PolarsResult<O>>,
        {
            let store = self.to_dyn_object_store().await;

            let out = func(&store).await;

            let orig_err = match out {
                Ok(v) => return Ok(v),
                Err(e) => e,
            };

            if config::verbose() {
                eprintln!(
                    "[PolarsObjectStore]: got error: {}, will attempt re-build",
                    &orig_err
                );
            }

            let store = self
                .rebuild_inner(&store)
                .await
                .map_err(|e| e.wrap_msg(|e| format!("{e}; original error: {orig_err}")))?;

            func(&store).await.map_err(|e| {
                if self.inner.builder.is_azure()
                    && std::env::var("POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY").as_deref()
                        != Ok("1")
                {
                    e.wrap_msg(|e| {
                        format!(
                            "{e}; note: if you are using Python, consider setting \
POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY=1 if you would like polars to try to retrieve \
and use the storage account keys from Azure CLI to authenticate"
                        )
                    })
                } else {
                    e
                }
            })
        }
    }
}

pub use inner::PolarsObjectStore;

pub type ObjectStorePath = object_store::path::Path;

impl PolarsObjectStore {
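    /// Creates a stream that fetches the given byte ranges concurrently, buffered up to the
    /// global concurrency limit. Empty ranges resolve to empty buffers without a request.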
    fn get_buffered_ranges_stream<'a, T: Iterator<Item = Range<usize>>>(
        store: &'a dyn ObjectStore,
        path: &'a Path,
        ranges: T,
    ) -> impl StreamExt<Item = PolarsResult<Buffer<u8>>>
    + TryStreamExt<Ok = Buffer<u8>, Error = PolarsError, Item = PolarsResult<Buffer<u8>>>
    + use<'a, T> {
        futures::stream::iter(ranges.map(move |range| async move {
            if range.is_empty() {
                return Ok(Buffer::new());
            }

            let out = store
                .get_range(path, range.start as u64..range.end as u64)
                .await?;
            Ok(Buffer::from_owner(out))
        }))
        .buffered(get_concurrency_limit() as usize)
    }

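    /// Fetches a single byte range, transparently splitting large ranges into multiple
    /// concurrent requests.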
    pub async fn get_range(&self, path: &Path, range: Range<usize>) -> PolarsResult<Buffer<u8>> {
        if range.is_empty() {
            return Ok(Buffer::new());
        }

        self.try_exec_rebuild_on_err(move |store| {
            let range = range.clone();
            let st = store.clone();

            async move {
                let store = st;
                let parts = split_range(range.clone());

                if parts.len() == 1 {
                    let out = tune_with_concurrency_budget(1, move || async move {
                        let bytes = store
                            .get_range(path, range.start as u64..range.end as u64)
                            .await?;
                        PolarsResult::Ok(Buffer::from_owner(bytes))
                    })
                    .await?;

                    Ok(out)
                } else {
                    let parts = tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || {
                            Self::get_buffered_ranges_stream(&store, path, parts)
                                .try_collect::<Vec<Buffer<u8>>>()
                        },
                    )
                    .await?;

                    let mut combined = Vec::with_capacity(range.len());

                    for part in parts {
                        combined.extend_from_slice(&part);
                    }

                    assert_eq!(combined.len(), range.len());

                    PolarsResult::Ok(Buffer::from_vec(combined))
                }
            }
        })
        .await
    }

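    /// Fetches multiple byte ranges, sorting `ranges` in place by start offset so that nearby
    /// ranges can be merged into fewer requests.
    ///
    /// Returns a map keyed by the `start` of each requested range. If several ranges share a
    /// start offset, only the longest one is kept.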
    pub async fn get_ranges_sort(
        &self,
        path: &Path,
        ranges: &mut [Range<usize>],
    ) -> PolarsResult<PlHashMap<usize, Buffer<u8>>> {
        if ranges.is_empty() {
            return Ok(Default::default());
        }

        ranges.sort_unstable_by_key(|x| x.start);

        let ranges_len = ranges.len();
        let (merged_ranges, merged_ends): (Vec<_>, Vec<_>) = merge_ranges(ranges).unzip();

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            async {
                let store = st;
                let mut out = PlHashMap::with_capacity(ranges_len);

                let mut stream =
                    Self::get_buffered_ranges_stream(&store, path, merged_ranges.iter().cloned());

                tune_with_concurrency_budget(
                    merged_ranges.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                    || async {
                        let mut len = 0;
                        let mut current_offset = 0;
                        let mut ends_iter = merged_ends.iter();

                        let mut splitted_parts = vec![];

                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            let end = *ends_iter.next().unwrap();

                            if end == 0 {
                                splitted_parts.push(bytes);
                                continue;
                            }

                            let full_range = ranges[current_offset..end]
                                .iter()
                                .cloned()
                                .reduce(|l, r| l.start.min(r.start)..l.end.max(r.end))
                                .unwrap();

                            let bytes = if splitted_parts.is_empty() {
                                bytes
                            } else {
                                let mut out = Vec::with_capacity(full_range.len());

                                for x in splitted_parts.drain(..) {
                                    out.extend_from_slice(&x);
                                }

                                out.extend_from_slice(&bytes);
                                Buffer::from(out)
                            };

                            assert_eq!(bytes.len(), full_range.len());

                            for range in &ranges[current_offset..end] {
                                let slice = bytes.clone().sliced(
                                    range.start - full_range.start..range.end - full_range.start,
                                );

                                match out.raw_entry_mut().from_key(&range.start) {
                                    RawEntryMut::Vacant(slot) => {
                                        slot.insert(range.start, slice);
                                    },
                                    RawEntryMut::Occupied(mut slot) => {
                                        if slot.get_mut().len() < slice.len() {
                                            *slot.get_mut() = slice;
                                        }
                                    },
                                }
                            }

                            current_offset = end;
                        }

                        assert!(splitted_parts.is_empty());

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    },
                )
                .await?;

                Ok(out)
            }
        })
        .await
    }

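    /// Downloads the object at `path` into `file`, using concurrent range requests when the
    /// object size is known, and a single streaming GET otherwise.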
    pub async fn download(&self, path: &Path, file: &mut tokio::fs::File) -> PolarsResult<()> {
        let opt_size = self.head(path).await.ok().map(|x| x.size);

        let initial_pos = file.stream_position().await?;

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

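            // SAFETY: This only re-borrows `file` so that the `FnMut` closure can hand it to
            // the future it returns. The closure is invoked at most twice and never
            // concurrently, and `file` outlives every invocation.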
            let file: &mut tokio::fs::File = unsafe { std::mem::transmute_copy(&file) };

            async {
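                // Truncate back to the starting position in case this closure is being
                // retried after a partial download.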
                file.set_len(initial_pos).await?;

                let store = st;
                let parts = opt_size
                    .map(|x| split_range(0..x as usize))
                    .filter(|x| x.len() > 1);

                if let Some(parts) = parts {
                    tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || async {
                            let mut stream = Self::get_buffered_ranges_stream(&store, path, parts);
                            let mut len = 0;
                            while let Some(bytes) = stream.try_next().await? {
                                len += bytes.len();
                                file.write_all(&bytes).await?;
                            }

                            assert_eq!(len, opt_size.unwrap() as usize);

                            PolarsResult::Ok(pl_async::Size::from(len as u64))
                        },
                    )
                    .await?
                } else {
                    tune_with_concurrency_budget(1, || async {
                        let mut stream = store.get(path).await?.into_stream();

                        let mut len = 0;
                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            file.write_all(&bytes).await?;
                        }

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    })
                    .await?
                };

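                // Ensure the data is flushed to disk before returning.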
                file.sync_all().await.map_err(PolarsError::from)?;

                Ok(())
            }
        })
        .await
    }

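    /// Fetches the metadata for an object. Falls back to reading the metadata from a 1-byte
    /// ranged GET, since some stores (e.g. pre-signed URLs) do not support HEAD requests.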
    pub async fn head(&self, path: &Path) -> PolarsResult<ObjectMeta> {
        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            async {
                with_concurrency_budget(1, || async {
                    let store = st;
                    let head_result = store.head(path).await;

                    if head_result.is_err() {
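                        // Fall back to a ranged GET of the first byte: some stores reject
                        // HEAD requests but still return the object metadata alongside a
                        // GET response.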
                        let get_range_0_0_result = store
                            .get_opts(
                                path,
                                object_store::GetOptions {
                                    range: Some((0..1).into()),
                                    ..Default::default()
                                },
                            )
                            .await;

                        if let Ok(v) = get_range_0_0_result {
                            return Ok(v.meta);
                        }
                    }

                    let out = head_result?;

                    Ok(out)
                })
                .await
            }
        })
        .await
    }
}

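/// Splits a range into sub-ranges of roughly `get_download_chunk_size()` bytes each, so that a
/// large download can be fetched with multiple concurrent requests.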
fn split_range(range: Range<usize>) -> impl ExactSizeIterator<Item = Range<usize>> {
    let chunk_size = get_download_chunk_size();

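    // Choose between rounding the part count up or down: pick whichever gives a per-part
    // length closest to the configured chunk size.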
    let n_parts = [
        (range.len().div_ceil(chunk_size)).max(1),
        (range.len() / chunk_size).max(1),
    ]
    .into_iter()
    .min_by_key(|x| (range.len() / *x).abs_diff(chunk_size))
    .unwrap();

    let chunk_size = (range.len() / n_parts).max(1);

    assert_eq!(n_parts, (range.len() / chunk_size).max(1));
    let bytes_rem = range.len() % chunk_size;

    (0..n_parts).map(move |part_no| {
        let (start, end) = if part_no == 0 {
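            // The first part absorbs the remainder so that all later parts are exactly
            // `chunk_size` long.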
            let end = range.start + chunk_size + bytes_rem;
            let end = if end > range.end { range.end } else { end };
            (range.start, end)
        } else {
            let start = bytes_rem + range.start + part_no * chunk_size;
            (start, start + chunk_size)
        };

        start..end
    })
}

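/// Merges overlapping and nearby ranges into larger requests, then splits any oversized merged
/// range back into chunk-sized parts.
///
/// `ranges` should be sorted by start offset; ranges adjacent in the slice are the candidates
/// for merging.
///
/// Returns `(range, end)` tuples, where `range` is a byte range to fetch and `end` is the
/// exclusive index into `ranges` of the requests that are fully covered once this part (and
/// the parts before it) complete. An `end` of 0 marks an intermediate part of a split range.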
fn merge_ranges(ranges: &[Range<usize>]) -> impl Iterator<Item = (Range<usize>, usize)> + '_ {
    let chunk_size = get_download_chunk_size();

    let mut current_merged_range = ranges.first().map_or(0..0, Clone::clone);
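    // Number of bytes requested so far, counting overlap between merged ranges only once.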
    let mut current_n_bytes = current_merged_range.len();

    (0..ranges.len())
        .filter_map(move |current_idx| {
            let current_idx = 1 + current_idx;

            if current_idx == ranges.len() {
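                // Reached the end of the input: flush the final merged range.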
                Some((current_merged_range.clone(), current_idx))
            } else {
                let range = ranges[current_idx].clone();

                let new_merged = current_merged_range.start.min(range.start)
                    ..current_merged_range.end.max(range.end);

                let (distance, is_overlapping) = {
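                    // Take the min of the ends and the max of the starts: if the ranges
                    // overlap, `r < l` and `r.abs_diff(l)` is the size of the overlap,
                    // otherwise it is the size of the gap between them. E.g.:
                    //
                    //   |--------|            |--------|
                    //        |--------|                    |--------|
                    //        ^^^^ overlap              ^^^^ gap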
                    let l = current_merged_range.end.min(range.end);
                    let r = current_merged_range.start.max(range.start);

                    (r.abs_diff(l), r < l)
                };

                let should_merge = is_overlapping || {
                    let leq_current_len_dist_to_chunk_size = new_merged.len().abs_diff(chunk_size)
                        <= current_merged_range.len().abs_diff(chunk_size);
                    let gap_tolerance =
                        (current_n_bytes.max(range.len()) / 8).clamp(1024 * 1024, 8 * 1024 * 1024);

                    leq_current_len_dist_to_chunk_size && distance <= gap_tolerance
                };

                if should_merge {
                    current_merged_range = new_merged;
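                    // Only count the non-overlapping bytes toward the fetched total.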
                    current_n_bytes += if is_overlapping {
                        range.len() - distance
                    } else {
                        range.len()
                    };
                    None
                } else {
                    let out = (current_merged_range.clone(), current_idx);
                    current_merged_range = range;
                    current_n_bytes = current_merged_range.len();
                    Some(out)
                }
            }
        })
        .flat_map(|x| {
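            // Re-split merged ranges that ended up larger than the chunk size; only the last
            // part of a split carries the `end` offset, the rest are marked with 0.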
            let (range, end) = x;
            let split = split_range(range);
            let len = split.len();

            split
                .enumerate()
                .map(move |(i, range)| (range, if 1 + i == len { end } else { 0 }))
        })
}

#[cfg(test)]
mod tests {

    #[test]
    fn test_split_range() {
        use super::{get_download_chunk_size, split_range};

        let chunk_size = get_download_chunk_size();

        assert_eq!(chunk_size, 64 * 1024 * 1024);

        #[allow(clippy::single_range_in_vec_init)]
        {
            assert_eq!(split_range(0..0).collect::<Vec<_>>(), [0..0]);
            assert_eq!(split_range(3..3).collect::<Vec<_>>(), [3..3]);
        }

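        // Threshold at which a range is split into 2 parts. A single part is kept while its
        // length is closer to chunk_size than two halves would be:
        //   n - chunk_size <= chunk_size - n / 2  =>  n <= (4 / 3) * chunk_size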
        let n = 4 * chunk_size / 3;

        #[allow(clippy::single_range_in_vec_init)]
        {
            assert_eq!(split_range(0..n).collect::<Vec<_>>(), [0..89478485]);
        }

        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..44739243, 44739243..89478486]
        );

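        // Threshold at which 2 parts become 3:
        //   n / 2 - chunk_size <= chunk_size - n / 3  =>  n <= (12 / 5) * chunk_size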
        let n = 12 * chunk_size / 5;

        assert_eq!(
            split_range(0..n).collect::<Vec<_>>(),
            [0..80530637, 80530637..161061273]
        );

        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..53687092, 53687092..107374183, 107374183..161061274]
        );
    }

    #[test]
    fn test_merge_ranges() {
        use super::{get_download_chunk_size, merge_ranges};

        let chunk_size = get_download_chunk_size();

        assert_eq!(chunk_size, 64 * 1024 * 1024);

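        // No input ranges: nothing to merge.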
        assert_eq!(merge_ranges(&[]).collect::<Vec<_>>(), []);

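        // Two adjacent ranges merge into a single 127 MiB range, which is then split back
        // into two chunk-sized parts; only the last part carries the end offset (2).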
        assert_eq!(
            merge_ranges(&[0..1, 1..127 * 1024 * 1024]).collect::<Vec<_>>(),
            [(0..66584576, 0), (66584576..133169152, 2)]
        );

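        // A gap of exactly 1 MiB (the minimum gap tolerance) still merges.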
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 1..1024 * 1024 + 2]).collect::<Vec<_>>(),
            [(0..1048578, 2)]
        );

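        // One byte past the gap tolerance: the ranges are emitted separately.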
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 2..1024 * 1024 + 3]).collect::<Vec<_>>(),
            [(0..1, 1), (1048578..1048579, 2)]
        );

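        // Small gaps between small ranges merge.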
        assert_eq!(
            merge_ranges(&[0..8, 10..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

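        // A 2-byte gap between ranges also merges.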
        assert_eq!(
            merge_ranges(&[0..1, 3..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

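        // A range fully contained within the previous range merges without growing the result.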
        assert_eq!(
            merge_ranges(&[0..80 * 1024 * 1024, 10 * 1024 * 1024..70 * 1024 * 1024])
                .collect::<Vec<_>>(),
            [(0..80 * 1024 * 1024, 2)]
        );
    }
}