use std::sync::Arc;
use arrow::{array::RecordBatch, datatypes::SchemaRef};
#[cfg(feature = "tokio-runtime")]
use tokio::sync::mpsc;
use crate::{
error::{Error, Result},
streaming::DataSource,
};
#[cfg(feature = "tokio-runtime")]
/// Dataset that prefetches `RecordBatch`es from a [`DataSource`] on a
/// background tokio task, buffering up to a configured number of batches
/// ahead of the consumer.
pub struct AsyncPrefetchDataset {
// Bounded channel delivering prefetched batches; at most one `Err` is sent
// (the producer task stops after forwarding the first error).
receiver: mpsc::Receiver<Result<RecordBatch>>,
// Schema captured from the source before it is moved into the producer task.
schema: SchemaRef,
// JoinHandle of the producer task; currently unused (hence `allow(dead_code)`)
// but retained on the struct alongside the receiver.
#[allow(dead_code)] handle: tokio::task::JoinHandle<()>,
}
#[cfg(feature = "tokio-runtime")]
impl AsyncPrefetchDataset {
    /// Starts a background producer task that pulls batches from `source`
    /// and queues them on a bounded channel of capacity `prefetch_size`
    /// (clamped to at least 1).
    ///
    /// The producer stops when the source is exhausted, when the consumer
    /// side is dropped, or after forwarding the first error.
    ///
    /// NOTE(review): `DataSource::next_batch` runs synchronously on an async
    /// worker thread here; for sources doing heavy blocking I/O,
    /// `tokio::task::spawn_blocking` may be preferable — confirm against the
    /// concrete source implementations.
    pub fn new(mut source: Box<dyn DataSource>, prefetch_size: usize) -> Self {
        let schema = source.schema();
        let capacity = prefetch_size.max(1);
        let (sender, receiver) = mpsc::channel(capacity);
        let handle = tokio::spawn(async move {
            loop {
                match source.next_batch() {
                    Ok(Some(batch)) => {
                        // A send error means the receiver was dropped: stop producing.
                        if sender.send(Ok(batch)).await.is_err() {
                            break;
                        }
                    }
                    // Source exhausted: ending the task closes the channel.
                    Ok(None) => break,
                    Err(err) => {
                        // Best-effort forward of the error, then terminate.
                        let _ = sender.send(Err(err)).await;
                        break;
                    }
                }
            }
        });
        Self {
            receiver,
            schema,
            handle,
        }
    }

    /// Convenience constructor: stream a parquet file with `batch_size` rows
    /// per batch and a prefetch depth of `prefetch_size` batches.
    ///
    /// # Errors
    /// Returns an error if the parquet source cannot be created.
    pub fn from_parquet(
        path: impl AsRef<std::path::Path>,
        batch_size: usize,
        prefetch_size: usize,
    ) -> Result<Self> {
        let parquet = crate::streaming::ParquetSource::new(path, batch_size)?;
        Ok(Self::new(Box::new(parquet), prefetch_size))
    }

    /// Schema of the batches this dataset yields.
    pub fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    /// Awaits the next prefetched batch; `None` once the channel is closed
    /// (source exhausted, or producer stopped after delivering an error).
    pub async fn next(&mut self) -> Option<Result<RecordBatch>> {
        self.receiver.recv().await
    }

    /// Non-blocking poll: returns an already-buffered item if one is ready,
    /// otherwise `None` (also `None` once the channel is closed and drained).
    pub fn try_next(&mut self) -> Option<Result<RecordBatch>> {
        match self.receiver.try_recv() {
            Ok(item) => Some(item),
            Err(_) => None,
        }
    }

    /// Number of items currently sitting in the prefetch buffer.
    pub fn buffered_count(&self) -> usize {
        self.receiver.len()
    }
}
#[cfg(feature = "tokio-runtime")]
impl std::fmt::Debug for AsyncPrefetchDataset {
    // The channel receiver and join handle have no useful Debug output, so
    // only report how many batches are currently buffered.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let buffered = self.receiver.len();
        f.debug_struct("AsyncPrefetchDataset")
            .field("buffered", &buffered)
            .finish_non_exhaustive()
    }
}
#[cfg(feature = "tokio-runtime")]
#[derive(Debug, Default)]
/// Builder for [`AsyncPrefetchDataset`] with optional batch and prefetch
/// sizes; unset values fall back to defaults when the dataset is built
/// (1024 rows per batch, prefetch depth of 4).
pub struct AsyncPrefetchBuilder {
// Rows per batch for parquet reads; defaults to 1024 when unset.
batch_size: Option<usize>,
// Prefetch channel depth in batches; defaults to 4 when unset.
prefetch_size: Option<usize>,
}
#[cfg(feature = "tokio-runtime")]
impl AsyncPrefetchBuilder {
    /// Creates an empty builder; defaults are applied at build time.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the rows-per-batch used for parquet reads (default: 1024).
    #[must_use]
    pub fn batch_size(mut self, size: usize) -> Self {
        self.batch_size = Some(size);
        self
    }

    /// Sets the prefetch depth in batches (default: 4).
    #[must_use]
    pub fn prefetch_size(mut self, size: usize) -> Self {
        self.prefetch_size = Some(size);
        self
    }

    /// Builds a dataset streaming from a parquet file.
    ///
    /// # Errors
    /// Returns an error when `batch_size` was explicitly set to 0, or when
    /// the parquet source cannot be created.
    pub fn from_parquet(self, path: impl AsRef<std::path::Path>) -> Result<AsyncPrefetchDataset> {
        let batch_size = self.batch_size.unwrap_or(1024);
        if batch_size == 0 {
            return Err(Error::invalid_config("batch_size must be greater than 0"));
        }
        AsyncPrefetchDataset::from_parquet(path, batch_size, self.prefetch_size.unwrap_or(4))
    }

    /// Builds a dataset over an arbitrary [`DataSource`]. The builder's
    /// `batch_size` is not used here — the source controls its own batching.
    pub fn from_source(self, source: Box<dyn DataSource>) -> AsyncPrefetchDataset {
        AsyncPrefetchDataset::new(source, self.prefetch_size.unwrap_or(4))
    }
}
#[cfg(feature = "tokio-runtime")]
/// Blocking facade over [`AsyncPrefetchDataset`] for synchronous callers;
/// drives the inner dataset's async `next` via a tokio runtime handle.
pub struct SyncPrefetchDataset {
inner: AsyncPrefetchDataset,
// Handle used to block on the inner dataset's async receive.
runtime: tokio::runtime::Handle,
}
#[cfg(feature = "tokio-runtime")]
impl SyncPrefetchDataset {
    /// Pairs an async dataset with the runtime handle that will drive it.
    pub fn new(dataset: AsyncPrefetchDataset, runtime: tokio::runtime::Handle) -> Self {
        Self {
            inner: dataset,
            runtime,
        }
    }

    /// Schema of the batches this dataset yields.
    pub fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }

    /// Blocks the current thread until the next batch (or `None`) arrives.
    ///
    /// # Panics
    /// Panics if called from within an async execution context, per the
    /// contract of `tokio::runtime::Handle::block_on`.
    pub fn next_blocking(&mut self) -> Option<Result<RecordBatch>> {
        let next = self.inner.next();
        self.runtime.block_on(next)
    }
}
#[cfg(feature = "tokio-runtime")]
impl std::fmt::Debug for SyncPrefetchDataset {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The runtime handle carries no useful Debug output; show only the
        // inner dataset.
        let inner = &self.inner;
        f.debug_struct("SyncPrefetchDataset")
            .field("inner", inner)
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    use std::sync::Arc;
    use arrow::{
        array::{Int32Array, StringArray},
        datatypes::{DataType, Field, Schema},
    };
    use super::*;
    use crate::streaming::MemorySource;

    // NOTE: the previous `.ok().unwrap_or_else(|| panic!("…"))` pattern
    // converted every `Result` to an `Option` before panicking, discarding
    // the underlying error from the failure message. `.expect("…")` keeps
    // the message AND prints the error payload on failure.

    /// Builds `count` two-column (`id: Int32`, `name: Utf8`) batches of
    /// `rows_per_batch` rows each; ids run sequentially across batches.
    fn create_test_batches(count: usize, rows_per_batch: usize) -> Vec<RecordBatch> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        (0..count)
            .map(|batch_idx| {
                let start = (batch_idx * rows_per_batch) as i32;
                let ids: Vec<i32> = (start..start + rows_per_batch as i32).collect();
                let names: Vec<String> = ids.iter().map(|i| format!("item_{}", i)).collect();
                RecordBatch::try_new(
                    Arc::clone(&schema),
                    vec![
                        Arc::new(Int32Array::from(ids)),
                        Arc::new(StringArray::from(names)),
                    ],
                )
                .expect("Should create batch")
            })
            .collect()
    }

    #[tokio::test]
    async fn test_async_prefetch_creation() {
        let batches = create_test_batches(5, 10);
        let source = MemorySource::new(batches).expect("Should create source");
        let dataset = AsyncPrefetchDataset::new(Box::new(source), 4);
        assert_eq!(dataset.schema().fields().len(), 2);
    }

    #[tokio::test]
    async fn test_async_prefetch_iteration() {
        let batches = create_test_batches(5, 10);
        let source = MemorySource::new(batches).expect("Should create source");
        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 4);
        let mut count = 0;
        let mut total_rows = 0;
        while let Some(result) = dataset.next().await {
            let batch = result.expect("Should get batch");
            count += 1;
            total_rows += batch.num_rows();
        }
        assert_eq!(count, 5);
        assert_eq!(total_rows, 50);
    }

    #[tokio::test]
    async fn test_async_prefetch_try_next() {
        let batches = create_test_batches(3, 10);
        let source = MemorySource::new(batches).expect("Should create source");
        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 10);
        // Give the producer task a chance to run and fill the buffer.
        tokio::task::yield_now().await;
        tokio::task::yield_now().await;
        let mut count = 0;
        while dataset.try_next().is_some() {
            count += 1;
        }
        assert!(count > 0, "Should have prefetched some batches");
    }

    #[tokio::test]
    async fn test_async_prefetch_buffered_count() {
        let batches = create_test_batches(10, 5);
        let source = MemorySource::new(batches).expect("Should create source");
        let dataset = AsyncPrefetchDataset::new(Box::new(source), 4);
        for _ in 0..10 {
            tokio::task::yield_now().await;
        }
        // The bounded channel caps buffering at the prefetch size.
        let buffered = dataset.buffered_count();
        assert!(buffered <= 4, "Should not exceed prefetch size");
    }

    #[tokio::test]
    async fn test_async_prefetch_builder() {
        let batches = create_test_batches(3, 10);
        let source = MemorySource::new(batches).expect("Should create source");
        let mut dataset = AsyncPrefetchBuilder::new()
            .batch_size(10)
            .prefetch_size(2)
            .from_source(Box::new(source));
        let mut count = 0;
        while let Some(result) = dataset.next().await {
            assert!(result.is_ok());
            count += 1;
        }
        assert_eq!(count, 3);
    }

    #[tokio::test]
    async fn test_async_prefetch_debug() {
        let batches = create_test_batches(2, 5);
        let source = MemorySource::new(batches).expect("Should create source");
        let dataset = AsyncPrefetchDataset::new(Box::new(source), 4);
        let debug_str = format!("{:?}", dataset);
        assert!(debug_str.contains("AsyncPrefetchDataset"));
    }

    #[tokio::test]
    async fn test_async_prefetch_parquet_roundtrip() {
        let batch = create_test_batches(1, 100)[0].clone();
        let dataset = crate::ArrowDataset::from_batch(batch).expect("Should create dataset");
        let temp_dir = tempfile::tempdir().expect("Should create temp dir");
        let path = temp_dir.path().join("async_test.parquet");
        dataset.to_parquet(&path).expect("Should write parquet");
        let mut async_dataset =
            AsyncPrefetchDataset::from_parquet(&path, 25, 4).expect("Should create async dataset");
        let mut total = 0;
        while let Some(result) = async_dataset.next().await {
            let batch = result.expect("Should get batch");
            total += batch.num_rows();
        }
        assert_eq!(total, 100);
    }

    #[tokio::test]
    async fn test_sync_prefetch_wrapper() {
        let batches = create_test_batches(3, 10);
        let source = MemorySource::new(batches).expect("Should create source");
        let async_dataset = AsyncPrefetchDataset::new(Box::new(source), 4);
        let handle = tokio::runtime::Handle::current();
        let sync_dataset = SyncPrefetchDataset::new(async_dataset, handle);
        assert_eq!(sync_dataset.schema().fields().len(), 2);
        let debug_str = format!("{:?}", sync_dataset);
        assert!(debug_str.contains("SyncPrefetchDataset"));
    }

    #[tokio::test]
    async fn test_builder_zero_batch_size_error() {
        // Zero batch_size must fail validation before the path is ever opened.
        let result = AsyncPrefetchBuilder::new()
            .batch_size(0)
            .from_parquet("/nonexistent.parquet");
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_builder_defaults() {
        let batches = create_test_batches(2, 5);
        let source = MemorySource::new(batches).expect("Should create source");
        let dataset = AsyncPrefetchBuilder::new().from_source(Box::new(source));
        assert_eq!(dataset.schema().fields().len(), 2);
    }

    #[tokio::test]
    async fn test_async_prefetch_quick_exhaustion() {
        /// Source that yields exactly one batch, then reports exhaustion.
        struct QuickExhaustSource {
            schema: SchemaRef,
            exhausted: bool,
        }
        impl crate::streaming::DataSource for QuickExhaustSource {
            fn schema(&self) -> SchemaRef {
                Arc::clone(&self.schema)
            }
            fn next_batch(&mut self) -> crate::Result<Option<RecordBatch>> {
                if self.exhausted {
                    Ok(None)
                } else {
                    self.exhausted = true;
                    Ok(Some(create_test_batches(1, 1)[0].clone()))
                }
            }
        }
        let source = QuickExhaustSource {
            schema: create_test_batches(1, 1)[0].schema(),
            exhausted: false,
        };
        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 4);
        let first = dataset.next().await;
        assert!(first.is_some());
        assert!(first.unwrap().is_ok());
        let second = dataset.next().await;
        assert!(second.is_none());
    }

    #[tokio::test]
    async fn test_async_prefetch_single_batch() {
        let batches = create_test_batches(1, 100);
        let source = MemorySource::new(batches).expect("Should create source");
        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 4);
        let batch = dataset
            .next()
            .await
            .expect("Should have batch")
            .expect("Batch should be ok");
        assert_eq!(batch.num_rows(), 100);
        assert!(dataset.next().await.is_none());
    }

    #[tokio::test]
    async fn test_async_prefetch_large_prefetch_size() {
        let batches = create_test_batches(3, 10);
        let source = MemorySource::new(batches).expect("Should create source");
        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 100);
        let mut count = 0;
        while let Some(result) = dataset.next().await {
            assert!(result.is_ok());
            count += 1;
        }
        assert_eq!(count, 3);
    }

    #[tokio::test]
    async fn test_async_prefetch_prefetch_size_one() {
        let batches = create_test_batches(5, 10);
        let source = MemorySource::new(batches).expect("Should create source");
        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 1);
        let mut count = 0;
        while let Some(result) = dataset.next().await {
            assert!(result.is_ok());
            count += 1;
        }
        assert_eq!(count, 5);
    }

    #[tokio::test]
    async fn test_async_prefetch_error_source() {
        /// Source that succeeds twice, then returns an error.
        struct ErrorSource {
            schema: SchemaRef,
            calls: usize,
        }
        impl crate::streaming::DataSource for ErrorSource {
            fn schema(&self) -> SchemaRef {
                Arc::clone(&self.schema)
            }
            fn next_batch(&mut self) -> crate::Result<Option<RecordBatch>> {
                self.calls += 1;
                if self.calls > 2 {
                    Err(crate::Error::storage("Simulated error"))
                } else {
                    Ok(Some(create_test_batches(1, 5)[0].clone()))
                }
            }
        }
        let source = ErrorSource {
            schema: create_test_batches(1, 1)[0].schema(),
            calls: 0,
        };
        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 4);
        let b1 = dataset.next().await;
        assert!(b1.is_some());
        assert!(b1.unwrap().is_ok());
        let b2 = dataset.next().await;
        assert!(b2.is_some());
        assert!(b2.unwrap().is_ok());
        // The error is forwarded as the third item, then the stream ends.
        let b3 = dataset.next().await;
        assert!(b3.is_some());
        assert!(b3.unwrap().is_err());
    }

    #[tokio::test]
    async fn test_async_prefetch_try_next_after_exhaustion() {
        let batches = create_test_batches(1, 5);
        let source = MemorySource::new(batches).expect("Should create source");
        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 4);
        let _ = dataset.next().await;
        tokio::task::yield_now().await;
        let result = dataset.try_next();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_builder_with_prefetch_size() {
        let batches = create_test_batches(5, 10);
        let source = MemorySource::new(batches).expect("Should create source");
        let mut dataset = AsyncPrefetchBuilder::new()
            .prefetch_size(2)
            .from_source(Box::new(source));
        let mut count = 0;
        while let Some(result) = dataset.next().await {
            assert!(result.is_ok());
            count += 1;
        }
        assert_eq!(count, 5);
    }

    #[tokio::test]
    async fn test_builder_from_parquet_roundtrip() {
        let batch = create_test_batches(1, 50)[0].clone();
        let dataset = crate::ArrowDataset::from_batch(batch).expect("Should create dataset");
        let temp_dir = tempfile::tempdir().expect("Should create temp dir");
        let path = temp_dir.path().join("builder_test.parquet");
        dataset.to_parquet(&path).expect("Should write parquet");
        let mut async_dataset = AsyncPrefetchBuilder::new()
            .batch_size(10)
            .prefetch_size(3)
            .from_parquet(&path)
            .expect("Should create async dataset");
        let mut total = 0;
        while let Some(result) = async_dataset.next().await {
            total += result.expect("Should get batch").num_rows();
        }
        assert_eq!(total, 50);
    }

    #[test]
    fn test_builder_debug() {
        let builder = AsyncPrefetchBuilder::new().batch_size(32).prefetch_size(8);
        let debug_str = format!("{:?}", builder);
        assert!(debug_str.contains("AsyncPrefetchBuilder"));
    }
}