use std::fmt::Display;
use std::ops::Range;
use std::sync::Arc;
use futures::{Stream, StreamExt as _, TryStreamExt as _};
use hashbrown::hash_map::RawEntryMut;
use object_store::path::Path;
use object_store::{ObjectMeta, ObjectStore, ObjectStoreExt};
use polars_buffer::Buffer;
use polars_core::prelude::{InitHashMaps, PlHashMap};
use polars_error::{PolarsError, PolarsResult};
use polars_utils::pl_path::PlRefPath;
use tokio::io::AsyncWriteExt;
use crate::pl_async::{
self, MAX_BUDGET_PER_REQUEST, get_concurrency_limit, get_download_chunk_size,
tune_with_concurrency_budget, with_concurrency_budget,
};
/// An `object_store` error annotated with the path it was raised for.
///
/// The `Display` impl renders both the underlying error and `base_url`.
#[derive(Debug)]
pub struct PolarsObjectStoreError {
    /// Path associated with the failing store/request; shown as `path:` in the message.
    pub base_url: PlRefPath,
    /// Underlying error returned by the `object_store` crate.
    pub source: object_store::Error,
}
impl PolarsObjectStoreError {
pub fn from_url(base_url: &PlRefPath) -> impl FnOnce(object_store::Error) -> Self {
|error| Self {
base_url: base_url.clone(),
source: error,
}
}
}
impl Display for PolarsObjectStoreError {
    /// Formats as `object-store error: <source> (path: <base_url>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { base_url, source } = self;
        write!(f, "object-store error: {} (path: {})", source, base_url)
    }
}
impl std::error::Error for PolarsObjectStoreError {
    /// Exposes the wrapped `object_store` error as the cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let Self { source, .. } = self;
        Some(source)
    }
}
impl From<PolarsObjectStoreError> for std::io::Error {
    /// Wraps the error as an `io::Error` of kind `Other`.
    fn from(value: PolarsObjectStoreError) -> Self {
        Self::other(value)
    }
}
impl From<PolarsObjectStoreError> for PolarsError {
fn from(value: PolarsObjectStoreError) -> Self {
PolarsError::IO {
error: Arc::new(value.into()),
msg: None,
}
}
}
mod inner {
    use std::borrow::Cow;
    use std::future::Future;
    use std::sync::Arc;
    use object_store::ObjectStore;
    use polars_core::config;
    use polars_error::{PolarsError, PolarsResult};
    use polars_utils::relaxed_cell::RelaxedCell;
    use crate::cloud::{ObjectStoreErrorContext, PolarsObjectStoreBuilder};
    use crate::metrics::{IOMetrics, OptIOMetrics};

    /// State shared (via `Arc`) by all clones of a [`PolarsObjectStore`].
    #[derive(Debug)]
    struct Inner {
        // The current store; replaced in place by `rebuild_inner`.
        store: tokio::sync::RwLock<Arc<dyn ObjectStore>>,
        // Builder used to construct a fresh store when rebuilding.
        builder: PolarsObjectStoreBuilder,
        // Set to true by `rebuild_inner`. While false, `to_dyn_object_store`
        // returns `initial_store` without taking the lock.
        rebuilt: RelaxedCell<bool>,
    }

    /// Wrapper around an `ObjectStore` that, when a request fails, rebuilds the
    /// underlying store from its builder and retries once
    /// (see [`PolarsObjectStore::exec_with_rebuild_retry_on_err`]).
    ///
    /// Cloning copies `Arc`s only. Clones share `Inner`, so a rebuild through one
    /// clone is seen by all of them. IO metrics are per clone (`set_io_metrics`
    /// takes `&mut self`).
    #[derive(Clone, Debug)]
    pub struct PolarsObjectStore {
        inner: Arc<Inner>,
        // The store this wrapper was created with; used until the first rebuild.
        initial_store: std::sync::Arc<dyn ObjectStore>,
        io_metrics: OptIOMetrics,
    }

    impl PolarsObjectStore {
        /// Creates a wrapper around `store`. `builder` is kept so the store can be
        /// rebuilt later. IO metrics start disabled.
        pub(crate) fn new_from_inner(
            store: Arc<dyn ObjectStore>,
            builder: PolarsObjectStoreBuilder,
        ) -> Self {
            let initial_store = store.clone();
            Self {
                inner: Arc::new(Inner {
                    store: tokio::sync::RwLock::new(store),
                    builder,
                    rebuilt: RelaxedCell::from(false),
                }),
                initial_store,
                io_metrics: OptIOMetrics(None),
            }
        }

        /// Replaces this clone's IO metrics sink (`None` disables recording).
        pub fn set_io_metrics(&mut self, io_metrics: Option<Arc<IOMetrics>>) -> &mut Self {
            self.io_metrics = OptIOMetrics(io_metrics);
            self
        }

        /// Returns this clone's IO metrics sink.
        pub fn io_metrics(&self) -> &OptIOMetrics {
            &self.io_metrics
        }

        /// Returns the store requests should currently use.
        ///
        /// Before any rebuild this borrows `initial_store` without locking.
        /// Afterwards it clones the `Arc` out of the read-locked `store`.
        pub async fn to_dyn_object_store(&self) -> Cow<'_, Arc<dyn ObjectStore>> {
            if !self.inner.rebuilt.load() {
                Cow::Borrowed(&self.initial_store)
            } else {
                Cow::Owned(self.inner.store.read().await.clone())
            }
        }

        /// Rebuilds the store, but only if the current store is still
        /// `from_version` (pointer equality).
        ///
        /// If another caller has already replaced it, the existing replacement is
        /// returned instead of rebuilding again. In both cases `rebuilt` is set and
        /// the now-current store is returned.
        ///
        /// # Errors
        /// Returns the builder's error, prefixed with
        /// "attempt to rebuild object store failed".
        pub async fn rebuild_inner(
            &self,
            from_version: &Arc<dyn ObjectStore>,
        ) -> PolarsResult<Arc<dyn ObjectStore>> {
            // The write lock serializes concurrent rebuild attempts.
            let mut current_store = self.inner.store.write().await;
            if Arc::ptr_eq(&*current_store, from_version) {
                *current_store =
                    self.inner
                        .builder
                        .clone()
                        .build_impl(true)
                        .await
                        .map_err(|e| {
                            e.wrap_msg(|e| format!("attempt to rebuild object store failed: {e}"))
                        })?;
            }
            self.inner.rebuilt.store(true);
            Ok((*current_store).clone())
        }

        /// Runs `func` against the current store. If it fails, rebuilds the store
        /// and runs `func` exactly once more.
        ///
        /// The first attempt's error is only used for the verbose log line and as
        /// context if the rebuild itself fails. If the retry fails, its error is
        /// returned with the builder's path attached. For Azure stores, a hint about
        /// `POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY` is appended unless that
        /// variable is set to `1`.
        pub async fn exec_with_rebuild_retry_on_err<'s, 'f, Fn, Fut, O>(
            &'s self,
            mut func: Fn,
        ) -> PolarsResult<O>
        where
            Fn: FnMut(Cow<'s, Arc<dyn ObjectStore>>) -> Fut + 'f,
            Fut: Future<Output = object_store::Result<O>>,
        {
            let store = self.to_dyn_object_store().await;
            let out = func(store.clone()).await;
            let orig_err = match out {
                Ok(v) => return Ok(v),
                Err(e) => e,
            };
            if config::verbose() {
                eprintln!(
                    "[PolarsObjectStore]: got error: {}, will rebuild store and retry",
                    &orig_err
                );
            }
            // Pass the store that failed, so the rebuild is skipped if someone
            // else already replaced it.
            let store = self
                .rebuild_inner(&store)
                .await
                .map_err(|e| e.wrap_msg(|e| format!("{e}; original error: {orig_err}")))?;
            func(Cow::Owned(store)).await.map_err(|e| {
                let e: PolarsError = self.error_context().attach_err_info(e).into();
                if self.inner.builder.is_azure()
                    && std::env::var("POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY").as_deref()
                        != Ok("1")
                {
                    e.wrap_msg(|e| {
                        format!(
                            "{e}; note: if you are using Python, consider setting \
                            POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY=1 if you would like polars to try to retrieve \
                            and use the storage account keys from Azure CLI to authenticate"
                        )
                    })
                } else {
                    e
                }
            })
        }

        /// Returns an error context that tags errors with this store's builder path.
        pub fn error_context(&self) -> ObjectStoreErrorContext {
            ObjectStoreErrorContext::new(self.inner.builder.path().clone())
        }
    }
}
/// Carries the path to attach to `object_store` errors so they can be turned
/// into [`PolarsObjectStoreError`]s.
#[derive(Clone)]
pub struct ObjectStoreErrorContext {
    // Becomes `PolarsObjectStoreError::base_url`.
    path: PlRefPath,
}
impl ObjectStoreErrorContext {
    /// Creates a context that tags errors with `path`.
    pub fn new(path: PlRefPath) -> Self {
        ObjectStoreErrorContext { path }
    }

    /// Consumes the context and wraps `err` into a [`PolarsObjectStoreError`]
    /// whose `base_url` is this context's path.
    pub fn attach_err_info(self, err: object_store::Error) -> PolarsObjectStoreError {
        PolarsObjectStoreError {
            source: err,
            base_url: self.path,
        }
    }
}
// Re-exported so the wrapper's private fields stay encapsulated in `inner`.
pub use inner::PolarsObjectStore;
/// Path type used to address objects within an object store.
pub type ObjectStorePath = object_store::path::Path;
impl PolarsObjectStore {
    /// Builds a stream that fetches each range in `ranges` from `path` and
    /// yields the bytes in the same order as `ranges`.
    ///
    /// At most `get_concurrency_limit()` requests are in flight at once. An
    /// empty range resolves to an empty buffer without issuing a request. Every
    /// fetch goes through `exec_with_rebuild_retry_on_err` and is recorded in
    /// this store's IO metrics.
    pub fn build_buffered_ranges_stream<'a, T: Iterator<Item = Range<usize>>>(
        &'a self,
        path: &'a Path,
        ranges: T,
    ) -> impl Stream<Item = PolarsResult<Buffer<u8>>> + use<'a, T> {
        futures::stream::iter(ranges.map(move |range| async move {
            if range.is_empty() {
                return Ok(Buffer::new());
            }
            let out = self
                .io_metrics()
                .record_io_read(
                    range.len() as u64,
                    self.exec_with_rebuild_retry_on_err(|s| async move {
                        s.get_range(path, range.start as u64..range.end as u64)
                            .await
                    }),
                )
                .await?;
            Ok(Buffer::from_owner(out))
        }))
        .buffered(get_concurrency_limit() as usize)
    }

    /// Fetches bytes `range` of the object at `path`.
    ///
    /// Ranges larger than the download chunk size are split with `split_range`,
    /// fetched concurrently and concatenated. Smaller ranges use one request.
    ///
    /// # Panics
    /// Panics if the concatenated parts do not add up to `range.len()` bytes.
    pub async fn get_range(&self, path: &Path, range: Range<usize>) -> PolarsResult<Buffer<u8>> {
        if range.is_empty() {
            return Ok(Buffer::new());
        }
        let parts = split_range(range.clone());
        if parts.len() == 1 {
            let out = tune_with_concurrency_budget(1, move || async move {
                let bytes = self
                    .io_metrics()
                    .record_io_read(
                        range.len() as u64,
                        self.exec_with_rebuild_retry_on_err(|s| async move {
                            s.get_range(path, range.start as u64..range.end as u64)
                                .await
                        }),
                    )
                    .await?;
                PolarsResult::Ok(Buffer::from_owner(bytes))
            })
            .await?;
            Ok(out)
        } else {
            // Request one unit of concurrency budget per part, capped per request.
            let parts = tune_with_concurrency_budget(
                parts.len().min(MAX_BUDGET_PER_REQUEST) as u32,
                || {
                    self.build_buffered_ranges_stream(path, parts)
                        .try_collect::<Vec<Buffer<u8>>>()
                },
            )
            .await?;
            let mut combined = Vec::with_capacity(range.len());
            for part in parts {
                combined.extend_from_slice(&part);
            }
            assert_eq!(combined.len(), range.len());
            PolarsResult::Ok(Buffer::from_vec(combined))
        }
    }

    /// Fetches several byte ranges of the object at `path`.
    ///
    /// `ranges` is sorted in place by `start`, so the caller's slice is mutated.
    /// Nearby ranges are then coalesced with `merge_ranges` so fewer requests are
    /// issued. The result maps each range's `start` offset to its bytes. When
    /// several ranges share a `start`, the longest of them is kept.
    ///
    /// # Panics
    /// Panics if a coalesced range comes back with an unexpected length.
    pub async fn get_ranges_sort(
        &self,
        path: &Path,
        ranges: &mut [Range<usize>],
    ) -> PolarsResult<PlHashMap<usize, Buffer<u8>>> {
        if ranges.is_empty() {
            return Ok(Default::default());
        }
        ranges.sort_unstable_by_key(|x| x.start);
        let ranges_len = ranges.len();
        // `merged_ends[i]` is 0 for every part except the last part of a merged
        // range. For that last part it is the exclusive end index into `ranges`.
        let (merged_ranges, merged_ends): (Vec<_>, Vec<_>) = merge_ranges(ranges).unzip();
        let mut out = PlHashMap::with_capacity(ranges_len);
        let mut stream = self.build_buffered_ranges_stream(path, merged_ranges.iter().cloned());
        tune_with_concurrency_budget(
            merged_ranges.len().min(MAX_BUDGET_PER_REQUEST) as u32,
            || async {
                let mut len = 0;
                let mut current_offset = 0;
                let mut ends_iter = merged_ends.iter();
                // Parts of a merged range that was split, waiting for its last part.
                let mut splitted_parts = vec![];
                while let Some(bytes) = stream.try_next().await? {
                    len += bytes.len();
                    let end = *ends_iter.next().unwrap();
                    if end == 0 {
                        splitted_parts.push(bytes);
                        continue;
                    }
                    let full_range = ranges[current_offset..end]
                        .iter()
                        .cloned()
                        .reduce(|l, r| l.start.min(r.start)..l.end.max(r.end))
                        .unwrap();
                    let bytes = if splitted_parts.is_empty() {
                        bytes
                    } else {
                        let mut out = Vec::with_capacity(full_range.len());
                        for x in splitted_parts.drain(..) {
                            out.extend_from_slice(&x);
                        }
                        out.extend_from_slice(&bytes);
                        Buffer::from(out)
                    };
                    assert_eq!(bytes.len(), full_range.len());
                    for range in &ranges[current_offset..end] {
                        let slice = bytes
                            .clone()
                            .sliced(range.start - full_range.start..range.end - full_range.start);
                        match out.raw_entry_mut().from_key(&range.start) {
                            RawEntryMut::Vacant(slot) => {
                                slot.insert(range.start, slice);
                            },
                            RawEntryMut::Occupied(mut slot) => {
                                // Same start offset requested twice: keep the longer slice.
                                if slot.get_mut().len() < slice.len() {
                                    *slot.get_mut() = slice;
                                }
                            },
                        }
                    }
                    current_offset = end;
                }
                assert!(splitted_parts.is_empty());
                PolarsResult::Ok(pl_async::Size::from(len as u64))
            },
        )
        .await?;
        Ok(out)
    }

    /// Downloads the whole object at `path` into `file`, writing parts in order,
    /// then flushes and syncs the file.
    ///
    /// The object size is taken from [`Self::head`], and the object is fetched in
    /// `split_range` parts.
    ///
    /// # Errors
    /// Returns request errors as well as any write, flush or sync failure on `file`.
    ///
    /// # Panics
    /// Panics if the number of bytes received differs from the size reported by
    /// `head`.
    pub async fn download(&self, path: &Path, file: &mut tokio::fs::File) -> PolarsResult<()> {
        let size = self.head(path).await?.size;
        let parts = split_range(0..size as usize);
        tune_with_concurrency_budget(
            parts.len().min(MAX_BUDGET_PER_REQUEST) as u32,
            || async {
                let mut stream = self.build_buffered_ranges_stream(path, parts);
                let mut len = 0;
                while let Some(bytes) = stream.try_next().await? {
                    len += bytes.len();
                    file.write_all(&bytes).await?;
                }
                assert_eq!(len, size as usize);
                PolarsResult::Ok(pl_async::Size::from(len as u64))
            },
        )
        .await?;
        // tokio's `File` completes writes in the background. `sync_all` waits for
        // them but does not report a failure of the last buffered write, so flush
        // explicitly first to surface such an error here.
        file.flush().await.map_err(PolarsError::from)?;
        file.sync_all().await.map_err(PolarsError::from)?;
        Ok(())
    }

    /// Returns metadata for the object at `path`.
    ///
    /// If the HEAD request fails, falls back to a GET of byte range `0..1` and uses
    /// that response's metadata. If the fallback also fails, the original HEAD
    /// error is returned. The whole operation goes through the rebuild-and-retry
    /// logic.
    pub async fn head(&self, path: &Path) -> PolarsResult<ObjectMeta> {
        with_concurrency_budget(1, || {
            self.exec_with_rebuild_retry_on_err(|s| {
                async move {
                    let head_result = self.io_metrics().record_io_read(0, s.head(path)).await;
                    if head_result.is_err() {
                        let get_range_0_1_result = self
                            .io_metrics()
                            .record_io_read(
                                0,
                                s.get_opts(
                                    path,
                                    object_store::GetOptions {
                                        range: Some((0..1).into()),
                                        ..Default::default()
                                    },
                                ),
                            )
                            .await;
                        if let Ok(v) = get_range_0_1_result {
                            return Ok(v.meta);
                        }
                    }
                    let out = head_result?;
                    Ok(out)
                }
            })
        })
        .await
    }
}
/// Splits `range` into contiguous parts whose size is close to the download
/// chunk size.
///
/// The number of parts is `ceil` or `floor` of `len / chunk_size`, whichever
/// gives part sizes closer to the chunk size (ties go to `ceil`). The remainder
/// bytes are folded into the first part. An empty range yields a single empty
/// part.
fn split_range(range: Range<usize>) -> impl ExactSizeIterator<Item = Range<usize>> {
    let target = get_download_chunk_size();
    let total = range.len();

    let ceil_parts = total.div_ceil(target).max(1);
    let floor_parts = (total / target).max(1);
    let dist = |parts: usize| (total / parts).abs_diff(target);
    let n_parts = if dist(floor_parts) < dist(ceil_parts) {
        floor_parts
    } else {
        ceil_parts
    };

    let chunk_size = (total / n_parts).max(1);
    assert_eq!(n_parts, (total / chunk_size).max(1));
    let bytes_rem = total % chunk_size;

    let Range {
        start: base,
        end: limit,
    } = range;
    (0..n_parts).map(move |part_no| {
        if part_no == 0 {
            // The first part also absorbs the remainder. It is capped at `limit`
            // so empty ranges stay empty.
            base..(base + chunk_size + bytes_rem).min(limit)
        } else {
            let start = base + bytes_rem + part_no * chunk_size;
            start..start + chunk_size
        }
    })
}
/// Coalesces byte ranges into fewer, larger requests, then splits each
/// coalesced range into download-sized parts with `split_range`.
///
/// Callers should pass `ranges` sorted by `start`, as `get_ranges_sort` does.
/// Each range is only compared against the running merged range.
///
/// A range is merged into the running merged range when either:
/// - the two overlap, or
/// - merging does not move the merged length further from the download chunk
///   size, and the gap between them is at most 1/8 of the larger of (bytes
///   accumulated so far, the next range's length), clamped to `[1 MiB, 8 MiB]`.
///
/// Yields `(part, end)` pairs. The last part of each coalesced range carries
/// `end`, the exclusive index into `ranges` up to which that coalesced range
/// covers. All other parts carry `end == 0`.
fn merge_ranges(ranges: &[Range<usize>]) -> impl Iterator<Item = (Range<usize>, usize)> + '_ {
    let chunk_size = get_download_chunk_size();
    let mut current_merged_range = ranges.first().map_or(0..0, Clone::clone);
    // Bytes requested by the ranges merged so far (gaps excluded, overlaps counted once).
    let mut current_n_bytes = current_merged_range.len();
    (0..ranges.len())
        .filter_map(move |current_idx| {
            // Index of the range being considered. `ranges.len()` acts as a
            // sentinel that flushes the final merged range.
            let current_idx = 1 + current_idx;
            if current_idx == ranges.len() {
                Some((current_merged_range.clone(), current_idx))
            } else {
                let range = ranges[current_idx].clone();
                let new_merged = current_merged_range.start.min(range.start)
                    ..current_merged_range.end.max(range.end);
                // `distance` is the gap length when disjoint, or the overlap length
                // when overlapping.
                let (distance, is_overlapping) = {
                    let l = current_merged_range.end.min(range.end);
                    let r = current_merged_range.start.max(range.start);
                    (r.abs_diff(l), r < l)
                };
                let should_merge = is_overlapping || {
                    let leq_current_len_dist_to_chunk_size = new_merged.len().abs_diff(chunk_size)
                        <= current_merged_range.len().abs_diff(chunk_size);
                    let gap_tolerance =
                        (current_n_bytes.max(range.len()) / 8).clamp(1024 * 1024, 8 * 1024 * 1024);
                    leq_current_len_dist_to_chunk_size && distance <= gap_tolerance
                };
                if should_merge {
                    current_merged_range = new_merged;
                    current_n_bytes += if is_overlapping {
                        range.len() - distance
                    } else {
                        range.len()
                    };
                    None
                } else {
                    // Emit the finished merged range (covering up to, but not
                    // including, `current_idx`) and start a new one.
                    let out = (current_merged_range.clone(), current_idx);
                    current_merged_range = range;
                    current_n_bytes = current_merged_range.len();
                    Some(out)
                }
            }
        })
        .flat_map(|x| {
            let (range, end) = x;
            let split = split_range(range);
            let len = split.len();
            // Only the final part of a split carries `end`, marking that the
            // merged range is complete once that part arrives.
            split
                .enumerate()
                .map(move |(i, range)| (range, if 1 + i == len { end } else { 0 }))
        })
}
#[cfg(test)]
mod tests {
    /// Pins the exact part boundaries chosen by `split_range` for the default
    /// 64 MiB download chunk size.
    #[test]
    fn test_split_range() {
        use super::{get_download_chunk_size, split_range};
        let chunk_size = get_download_chunk_size();
        // The expected boundaries below are only valid for this chunk size.
        assert_eq!(chunk_size, 64 * 1024 * 1024);
        // Empty ranges produce a single empty part.
        #[allow(clippy::single_range_in_vec_init)]
        {
            assert_eq!(split_range(0..0).collect::<Vec<_>>(), [0..0]);
            assert_eq!(split_range(3..3).collect::<Vec<_>>(), [3..3]);
        }
        // 4/3 of a chunk: one part is closer to the chunk size than two.
        let n = 4 * chunk_size / 3;
        #[allow(clippy::single_range_in_vec_init)]
        {
            assert_eq!(split_range(0..n).collect::<Vec<_>>(), [0..89478485]);
        }
        // One more byte flips the choice to two equal parts.
        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..44739243, 44739243..89478486]
        );
        // 12/5 of a chunk: two parts, with the remainder byte folded into the first.
        let n = 12 * chunk_size / 5;
        assert_eq!(
            split_range(0..n).collect::<Vec<_>>(),
            [0..80530637, 80530637..161061273]
        );
        // One more byte: tie between two and three parts goes to three (ceil).
        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..53687092, 53687092..107374183, 107374183..161061274]
        );
    }

    /// Pins `merge_ranges` output (split parts plus end markers) for the default
    /// 64 MiB download chunk size.
    #[test]
    fn test_merge_ranges() {
        use super::{get_download_chunk_size, merge_ranges};
        let chunk_size = get_download_chunk_size();
        assert_eq!(chunk_size, 64 * 1024 * 1024);
        assert_eq!(merge_ranges(&[]).collect::<Vec<_>>(), []);
        // Adjacent ranges merged, then split into two parts; only the last part
        // carries the end index.
        assert_eq!(
            merge_ranges(&[0..1, 1..127 * 1024 * 1024]).collect::<Vec<_>>(),
            [(0..66584576, 0), (66584576..133169152, 2)]
        );
        // A gap of exactly 1 MiB (the minimum tolerance) is merged.
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 1..1024 * 1024 + 2]).collect::<Vec<_>>(),
            [(0..1048578, 2)]
        );
        // A gap one byte over 1 MiB is not merged.
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 2..1024 * 1024 + 3]).collect::<Vec<_>>(),
            [(0..1, 1), (1048578..1048579, 2)]
        );
        // Small gaps are merged.
        assert_eq!(
            merge_ranges(&[0..8, 10..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );
        assert_eq!(
            merge_ranges(&[0..1, 3..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );
        // A range fully contained in the previous one is merged regardless of size.
        assert_eq!(
            merge_ranges(&[0..80 * 1024 * 1024, 10 * 1024 * 1024..70 * 1024 * 1024])
                .collect::<Vec<_>>(),
            [(0..80 * 1024 * 1024, 2)]
        );
    }
}