use std::ops::Range;
use std::sync::Arc;
use polars_buffer::Buffer;
use polars_core::prelude::PlHashMap;
use polars_core::series::IsSorted;
use polars_core::utils::arrow::bitmap::Bitmap;
use polars_error::PolarsResult;
use polars_io::predicates::ScanIOPredicate;
use polars_io::prelude::{FileMetadata, create_sorting_map};
use polars_io::utils::byte_source::{ByteSource, DynByteSource};
use polars_parquet::read::RowGroupMetadata;
use polars_utils::pl_str::PlSmallStr;
use crate::metrics::OptIOMetrics;
use crate::nodes::io_sources::parquet::projection::ArrowFieldProjection;
use crate::utils::tokio_handle_ext;
pub(super) struct RowGroupData {
pub(super) fetched_bytes: FetchedBytes,
pub(super) row_offset: usize,
pub(super) slice: Option<(usize, usize)>,
pub(super) row_group_metadata: RowGroupMetadata,
pub(super) sorting_map: Vec<(usize, IsSorted)>,
}
pub(super) struct RowGroupDataFetcher {
pub(super) projection: Arc<[ArrowFieldProjection]>,
pub(super) is_full_projection: bool,
#[allow(unused)] pub(super) predicate: Option<ScanIOPredicate>,
pub(super) slice_range: Option<Range<usize>>,
pub(super) memory_prefetch_func: fn(&[u8]) -> (),
pub(super) metadata: Arc<FileMetadata>,
pub(super) byte_source: Arc<DynByteSource>,
pub(super) io_metrics: OptIOMetrics,
pub(super) row_group_slice: Range<usize>,
pub(super) row_group_mask: Option<Bitmap>,
pub(super) row_offset: usize,
}
impl RowGroupDataFetcher {
pub(super) async fn next(
&mut self,
) -> Option<PolarsResult<tokio_handle_ext::AbortOnDropHandle<PolarsResult<RowGroupData>>>> {
while !self.row_group_slice.is_empty() {
let idx = self.row_group_slice.start;
self.row_group_slice.start += 1;
let row_group_metadata = &self.metadata.row_groups[idx];
let current_row_offset = self.row_offset;
let num_rows = row_group_metadata.num_rows();
let sorting_map = create_sorting_map(row_group_metadata);
self.row_offset = current_row_offset.saturating_add(num_rows);
let slice = if let Some(slice_range) = self.slice_range.as_mut() {
let rg_row_start = slice_range.start;
let rg_row_end = slice_range.end.min(num_rows);
*slice_range = slice_range.start.saturating_sub(num_rows)
..slice_range.end.saturating_sub(num_rows);
Some((rg_row_start, rg_row_end - rg_row_start))
} else {
None
};
if let Some(row_group_mask) = self.row_group_mask.as_mut() {
let do_skip = row_group_mask.get_bit(0);
row_group_mask.slice(1, self.row_group_slice.len());
if do_skip {
continue;
}
}
let metadata = self.metadata.clone();
let current_byte_source = self.byte_source.clone();
let io_metrics = self.io_metrics.clone();
let projection = self.projection.clone();
let is_full_projection = self.is_full_projection;
let memory_prefetch_func = self.memory_prefetch_func;
let io_runtime = polars_io::pl_async::get_runtime();
let handle = io_runtime.spawn(async move {
let row_group_metadata = &metadata.row_groups[idx];
let fetched_bytes = if let DynByteSource::Buffer(mem_slice) =
current_byte_source.as_ref()
{
if memory_prefetch_func as usize
!= polars_utils::mem::prefetch::no_prefetch as *const () as usize
{
let slice = mem_slice.0.as_ref();
if !is_full_projection {
for range in get_row_group_byte_ranges_for_projection(
row_group_metadata,
&mut projection.iter().map(|x| &x.arrow_field().name),
) {
memory_prefetch_func(unsafe { slice.get_unchecked(range) })
}
} else {
let range = row_group_metadata.full_byte_range();
let range = range.start as usize..range.end as usize;
memory_prefetch_func(unsafe { slice.get_unchecked(range) })
};
}
let mem_slice = mem_slice.0.clone();
FetchedBytes::Buffer {
offset: 0,
buffer: mem_slice,
}
} else if !is_full_projection {
let mut total_bytes: u64 = 0;
let mut ranges = get_row_group_byte_ranges_for_projection(
row_group_metadata,
&mut projection.iter().map(|x| &x.arrow_field().name),
)
.inspect(|range| total_bytes += range.len() as u64)
.collect::<Vec<_>>();
let n_ranges = ranges.len();
let bytes_map = io_metrics
.record_download(total_bytes, current_byte_source.get_ranges(&mut ranges))
.await?;
assert_eq!(bytes_map.len(), n_ranges);
FetchedBytes::BytesMap(bytes_map)
} else {
let mut total_bytes: u64 = 0;
let mut ranges = row_group_metadata
.byte_ranges_iter()
.map(|x| x.start as usize..x.end as usize)
.inspect(|range| total_bytes += range.len() as u64)
.collect::<Vec<_>>();
let n_ranges = ranges.len();
let bytes_map = io_metrics
.record_download(total_bytes, current_byte_source.get_ranges(&mut ranges))
.await?;
assert_eq!(bytes_map.len(), n_ranges);
FetchedBytes::BytesMap(bytes_map)
};
PolarsResult::Ok(RowGroupData {
fetched_bytes,
row_offset: current_row_offset,
slice,
row_group_metadata: row_group_metadata.clone(),
sorting_map,
})
});
let handle = tokio_handle_ext::AbortOnDropHandle(handle);
return Some(Ok(handle));
}
None
}
}
pub(super) enum FetchedBytes {
Buffer { buffer: Buffer<u8>, offset: usize },
BytesMap(PlHashMap<usize, Buffer<u8>>),
}
impl FetchedBytes {
pub(super) fn get_range(&self, range: std::ops::Range<usize>) -> Buffer<u8> {
match self {
Self::Buffer { buffer, offset } => {
let offset = *offset;
debug_assert!(range.start >= offset);
buffer
.clone()
.sliced(range.start - offset..range.end - offset)
},
Self::BytesMap(v) => {
let v = v.get(&range.start).unwrap();
debug_assert_eq!(v.len(), range.len());
v.clone()
},
}
}
}
fn get_row_group_byte_ranges_for_projection<'a>(
row_group_metadata: &'a RowGroupMetadata,
columns: &'a mut dyn Iterator<Item = &PlSmallStr>,
) -> impl Iterator<Item = std::ops::Range<usize>> + 'a {
columns.flat_map(|col_name| {
row_group_metadata
.columns_under_root_iter(col_name)
.into_iter()
.flatten()
.map(|col| {
let byte_range = col.byte_range();
byte_range.start as usize..byte_range.end as usize
})
})
}