1use std::sync::Arc;
7
8use arrow::{array::RecordBatch, datatypes::SchemaRef};
9#[cfg(feature = "tokio-runtime")]
10use tokio::sync::mpsc;
11
12use crate::{
13 error::{Error, Result},
14 streaming::DataSource,
15};
16
/// Dataset wrapper that prefetches record batches on a background tokio task.
///
/// Batches are pulled eagerly from the wrapped [`DataSource`] and buffered in
/// a bounded channel, letting consumers overlap their work with source reads.
#[cfg(feature = "tokio-runtime")]
pub struct AsyncPrefetchDataset {
    /// Receiving end of the bounded prefetch channel.
    receiver: mpsc::Receiver<Result<RecordBatch>>,
    /// Schema captured from the source at construction time.
    schema: SchemaRef,
    /// Handle to the background producer task. It is never awaited here
    /// (hence `dead_code`); note that dropping a tokio `JoinHandle` detaches
    /// the task rather than cancelling it — TODO confirm intended use.
    #[allow(dead_code)] handle: tokio::task::JoinHandle<()>,
}
47
#[cfg(feature = "tokio-runtime")]
impl AsyncPrefetchDataset {
    /// Spawns a background task that eagerly pulls batches from `source`
    /// into a bounded channel of capacity `prefetch_size` (clamped to >= 1).
    pub fn new(mut source: Box<dyn DataSource>, prefetch_size: usize) -> Self {
        let schema = source.schema();
        let capacity = prefetch_size.max(1);
        let (tx, rx) = mpsc::channel(capacity);

        let handle = tokio::spawn(async move {
            // Produce until the source is exhausted (`transpose` turns
            // `Ok(None)` into `None`), an error occurs, or the consumer
            // drops the receiving end.
            while let Some(item) = source.next_batch().transpose() {
                let was_err = item.is_err();
                // A failed send means the consumer is gone; an error item is
                // forwarded once and then production stops either way.
                if tx.send(item).await.is_err() || was_err {
                    break;
                }
            }
        });

        Self {
            receiver: rx,
            schema,
            handle,
        }
    }

    /// Convenience constructor that streams a Parquet file through the
    /// prefetching pipeline.
    ///
    /// # Errors
    /// Returns an error when the Parquet source cannot be opened.
    pub fn from_parquet(
        path: impl AsRef<std::path::Path>,
        batch_size: usize,
        prefetch_size: usize,
    ) -> Result<Self> {
        let source = crate::streaming::ParquetSource::new(path, batch_size)?;
        Ok(Self::new(Box::new(source), prefetch_size))
    }

    /// Schema shared by every batch this dataset produces.
    pub fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    /// Awaits the next prefetched batch; `None` once the source is exhausted
    /// and the buffer is drained.
    pub async fn next(&mut self) -> Option<Result<RecordBatch>> {
        self.receiver.recv().await
    }

    /// Non-blocking poll: yields a batch only if one is already buffered.
    pub fn try_next(&mut self) -> Option<Result<RecordBatch>> {
        match self.receiver.try_recv() {
            Ok(item) => Some(item),
            Err(_) => None,
        }
    }

    /// Number of batches currently waiting in the prefetch buffer.
    pub fn buffered_count(&self) -> usize {
        self.receiver.len()
    }
}
129
#[cfg(feature = "tokio-runtime")]
impl std::fmt::Debug for AsyncPrefetchDataset {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the buffered-batch count is reported; the remaining fields
        // are elided via `finish_non_exhaustive`.
        let buffered = self.receiver.len();
        f.debug_struct("AsyncPrefetchDataset")
            .field("buffered", &buffered)
            .finish_non_exhaustive()
    }
}
138
/// Builder for [`AsyncPrefetchDataset`] with optional sizing knobs.
#[cfg(feature = "tokio-runtime")]
#[derive(Debug, Default)]
pub struct AsyncPrefetchBuilder {
    /// Rows per batch; defaults to 1024 when unset (used by the Parquet path).
    batch_size: Option<usize>,
    /// Channel capacity for prefetched batches; defaults to 4 when unset.
    prefetch_size: Option<usize>,
}
146
#[cfg(feature = "tokio-runtime")]
impl AsyncPrefetchBuilder {
    /// Creates a builder with nothing configured; defaults apply on build.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the number of rows per batch (default: 1024).
    #[must_use]
    pub fn batch_size(mut self, size: usize) -> Self {
        self.batch_size = Some(size);
        self
    }

    /// Sets how many batches may be buffered ahead (default: 4).
    #[must_use]
    pub fn prefetch_size(mut self, size: usize) -> Self {
        self.prefetch_size = Some(size);
        self
    }

    /// Builds a prefetching dataset over a Parquet file.
    ///
    /// # Errors
    /// Fails when `batch_size` was explicitly set to 0, or when the file
    /// cannot be opened as a Parquet source.
    pub fn from_parquet(self, path: impl AsRef<std::path::Path>) -> Result<AsyncPrefetchDataset> {
        let batch_size = match self.batch_size {
            Some(size) => size,
            None => 1024,
        };
        let prefetch_size = self.prefetch_size.unwrap_or(4);

        if batch_size == 0 {
            return Err(Error::invalid_config("batch_size must be greater than 0"));
        }

        AsyncPrefetchDataset::from_parquet(path, batch_size, prefetch_size)
    }

    /// Builds a prefetching dataset over an arbitrary source.
    ///
    /// Note: any configured `batch_size` is ignored here — the source
    /// already determines its own batching.
    pub fn from_source(self, source: Box<dyn DataSource>) -> AsyncPrefetchDataset {
        AsyncPrefetchDataset::new(source, self.prefetch_size.unwrap_or(4))
    }
}
190
/// Blocking facade over [`AsyncPrefetchDataset`] for synchronous callers.
#[cfg(feature = "tokio-runtime")]
pub struct SyncPrefetchDataset {
    /// The async dataset being driven.
    inner: AsyncPrefetchDataset,
    /// Runtime handle used to `block_on` the async `next`.
    runtime: tokio::runtime::Handle,
}
200
#[cfg(feature = "tokio-runtime")]
impl SyncPrefetchDataset {
    /// Pairs an async dataset with the runtime handle used to drive it
    /// from synchronous code.
    pub fn new(dataset: AsyncPrefetchDataset, runtime: tokio::runtime::Handle) -> Self {
        Self { inner: dataset, runtime }
    }

    /// Schema of the wrapped dataset.
    pub fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }

    /// Blocks the current thread until the next batch is available.
    ///
    /// # Panics
    /// `Handle::block_on` panics when invoked from within an asynchronous
    /// execution context (e.g. on a tokio worker thread).
    pub fn next_blocking(&mut self) -> Option<Result<RecordBatch>> {
        let fut = self.inner.next();
        self.runtime.block_on(fut)
    }
}
226
#[cfg(feature = "tokio-runtime")]
impl std::fmt::Debug for SyncPrefetchDataset {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The runtime handle is elided; only the wrapped dataset is shown.
        let mut dbg = f.debug_struct("SyncPrefetchDataset");
        dbg.field("inner", &self.inner);
        dbg.finish_non_exhaustive()
    }
}
235
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    use std::sync::Arc;

    use arrow::{
        array::{Int32Array, StringArray},
        datatypes::{DataType, Field, Schema},
    };

    use super::*;
    use crate::streaming::MemorySource;

    /// Builds `count` batches of `rows_per_batch` rows each with a fixed
    /// (id: Int32, name: Utf8) schema and ids sequential across batches.
    fn create_test_batches(count: usize, rows_per_batch: usize) -> Vec<RecordBatch> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        (0..count)
            .map(|batch_idx| {
                let start = (batch_idx * rows_per_batch) as i32;
                let ids: Vec<i32> = (start..start + rows_per_batch as i32).collect();
                let names: Vec<String> = ids.iter().map(|i| format!("item_{}", i)).collect();

                RecordBatch::try_new(
                    Arc::clone(&schema),
                    vec![
                        Arc::new(Int32Array::from(ids)),
                        Arc::new(StringArray::from(names)),
                    ],
                )
                .ok()
                .unwrap_or_else(|| panic!("Should create batch"))
            })
            .collect()
    }

    // Construction exposes the source's two-column schema.
    #[tokio::test]
    async fn test_async_prefetch_creation() {
        let batches = create_test_batches(5, 10);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let dataset = AsyncPrefetchDataset::new(Box::new(source), 4);
        assert_eq!(dataset.schema().fields().len(), 2);
    }

    // Draining via `next` yields every batch and every row exactly once.
    #[tokio::test]
    async fn test_async_prefetch_iteration() {
        let batches = create_test_batches(5, 10);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 4);

        let mut count = 0;
        let mut total_rows = 0;
        while let Some(result) = dataset.next().await {
            let batch = result.ok().unwrap_or_else(|| panic!("Should get batch"));
            count += 1;
            total_rows += batch.num_rows();
        }

        assert_eq!(count, 5);
        assert_eq!(total_rows, 50);
    }

    // `try_next` returns only what the background task has already buffered.
    #[tokio::test]
    async fn test_async_prefetch_try_next() {
        let batches = create_test_batches(3, 10);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 10);

        // Yield so the spawned producer task gets a chance to run.
        tokio::task::yield_now().await;
        tokio::task::yield_now().await;

        let mut count = 0;
        while dataset.try_next().is_some() {
            count += 1;
        }

        assert!(count > 0, "Should have prefetched some batches");
    }

    // The buffer is bounded by the prefetch size even when the source has more.
    #[tokio::test]
    async fn test_async_prefetch_buffered_count() {
        let batches = create_test_batches(10, 5);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let dataset = AsyncPrefetchDataset::new(Box::new(source), 4);

        // Give the producer ample opportunity to fill the channel.
        for _ in 0..10 {
            tokio::task::yield_now().await;
        }

        let buffered = dataset.buffered_count();
        assert!(buffered <= 4, "Should not exceed prefetch size");
    }

    // Builder-produced datasets iterate the same as directly constructed ones.
    #[tokio::test]
    async fn test_async_prefetch_builder() {
        let batches = create_test_batches(3, 10);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let mut dataset = AsyncPrefetchBuilder::new()
            .batch_size(10)
            .prefetch_size(2)
            .from_source(Box::new(source));

        let mut count = 0;
        while let Some(result) = dataset.next().await {
            assert!(result.is_ok());
            count += 1;
        }
        assert_eq!(count, 3);
    }

    // Debug output carries the type name (manual impl, non-exhaustive).
    #[tokio::test]
    async fn test_async_prefetch_debug() {
        let batches = create_test_batches(2, 5);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let dataset = AsyncPrefetchDataset::new(Box::new(source), 4);
        let debug_str = format!("{:?}", dataset);
        assert!(debug_str.contains("AsyncPrefetchDataset"));
    }

    // Write a batch to Parquet, then read it back through the async path.
    #[tokio::test]
    async fn test_async_prefetch_parquet_roundtrip() {
        let batch = create_test_batches(1, 100)[0].clone();
        let dataset = crate::ArrowDataset::from_batch(batch)
            .ok()
            .unwrap_or_else(|| panic!("Should create dataset"));

        let temp_dir = tempfile::tempdir()
            .ok()
            .unwrap_or_else(|| panic!("Should create temp dir"));
        let path = temp_dir.path().join("async_test.parquet");
        dataset
            .to_parquet(&path)
            .ok()
            .unwrap_or_else(|| panic!("Should write parquet"));

        let mut async_dataset = AsyncPrefetchDataset::from_parquet(&path, 25, 4)
            .ok()
            .unwrap_or_else(|| panic!("Should create async dataset"));

        let mut total = 0;
        while let Some(result) = async_dataset.next().await {
            let batch = result.ok().unwrap_or_else(|| panic!("Should get batch"));
            total += batch.num_rows();
        }
        assert_eq!(total, 100);
    }

    // The sync facade exposes schema and a Debug impl; iteration is not
    // exercised here (next_blocking would panic inside an async context).
    #[tokio::test]
    async fn test_sync_prefetch_wrapper() {
        let batches = create_test_batches(3, 10);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let async_dataset = AsyncPrefetchDataset::new(Box::new(source), 4);
        let handle = tokio::runtime::Handle::current();
        let sync_dataset = SyncPrefetchDataset::new(async_dataset, handle);

        assert_eq!(sync_dataset.schema().fields().len(), 2);

        let debug_str = format!("{:?}", sync_dataset);
        assert!(debug_str.contains("SyncPrefetchDataset"));
    }

    // batch_size == 0 is rejected before the file is even touched.
    #[tokio::test]
    async fn test_builder_zero_batch_size_error() {
        let result = AsyncPrefetchBuilder::new()
            .batch_size(0)
            .from_parquet("/nonexistent.parquet");

        assert!(result.is_err());
    }

    // An unconfigured builder still produces a working dataset via defaults.
    #[tokio::test]
    async fn test_builder_defaults() {
        let batches = create_test_batches(2, 5);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let dataset = AsyncPrefetchBuilder::new().from_source(Box::new(source));

        assert_eq!(dataset.schema().fields().len(), 2);
    }

    // A source that yields one batch then Ok(None) terminates the stream.
    #[tokio::test]
    async fn test_async_prefetch_quick_exhaustion() {
        /// Yields a single one-row batch, then reports exhaustion.
        struct QuickExhaustSource {
            schema: SchemaRef,
            exhausted: bool,
        }

        impl crate::streaming::DataSource for QuickExhaustSource {
            fn schema(&self) -> SchemaRef {
                Arc::clone(&self.schema)
            }

            fn next_batch(&mut self) -> crate::Result<Option<RecordBatch>> {
                if self.exhausted {
                    Ok(None)
                } else {
                    self.exhausted = true;
                    Ok(Some(create_test_batches(1, 1)[0].clone()))
                }
            }
        }

        let source = QuickExhaustSource {
            schema: create_test_batches(1, 1)[0].schema(),
            exhausted: false,
        };

        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 4);

        let first = dataset.next().await;
        assert!(first.is_some());
        assert!(first.unwrap().is_ok());

        let second = dataset.next().await;
        assert!(second.is_none());
    }

    // Single-batch source: one batch out, then the stream ends.
    #[tokio::test]
    async fn test_async_prefetch_single_batch() {
        let batches = create_test_batches(1, 100);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 4);

        let batch = dataset
            .next()
            .await
            .unwrap_or_else(|| panic!("Should have batch"))
            .ok()
            .unwrap_or_else(|| panic!("Batch should be ok"));
        assert_eq!(batch.num_rows(), 100);

        assert!(dataset.next().await.is_none());
    }

    // A prefetch size much larger than the batch count is harmless.
    #[tokio::test]
    async fn test_async_prefetch_large_prefetch_size() {
        let batches = create_test_batches(3, 10);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 100);

        let mut count = 0;
        while let Some(result) = dataset.next().await {
            assert!(result.is_ok());
            count += 1;
        }
        assert_eq!(count, 3);
    }

    // The minimum channel capacity (1) still delivers every batch.
    #[tokio::test]
    async fn test_async_prefetch_prefetch_size_one() {
        let batches = create_test_batches(5, 10);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 1);

        let mut count = 0;
        while let Some(result) = dataset.next().await {
            assert!(result.is_ok());
            count += 1;
        }
        assert_eq!(count, 5);
    }

    // A source error is forwarded to the consumer after the good batches.
    #[tokio::test]
    async fn test_async_prefetch_error_source() {
        /// Yields two good batches, then fails on every further call.
        struct ErrorSource {
            schema: SchemaRef,
            calls: usize,
        }

        impl crate::streaming::DataSource for ErrorSource {
            fn schema(&self) -> SchemaRef {
                Arc::clone(&self.schema)
            }

            fn next_batch(&mut self) -> crate::Result<Option<RecordBatch>> {
                self.calls += 1;
                if self.calls > 2 {
                    Err(crate::Error::storage("Simulated error"))
                } else {
                    Ok(Some(create_test_batches(1, 5)[0].clone()))
                }
            }
        }

        let source = ErrorSource {
            schema: create_test_batches(1, 1)[0].schema(),
            calls: 0,
        };

        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 4);

        let b1 = dataset.next().await;
        assert!(b1.is_some());
        assert!(b1.unwrap().is_ok());

        let b2 = dataset.next().await;
        assert!(b2.is_some());
        assert!(b2.unwrap().is_ok());

        let b3 = dataset.next().await;
        assert!(b3.is_some());
        assert!(b3.unwrap().is_err());
    }

    // After exhaustion the non-blocking poll reports nothing buffered.
    #[tokio::test]
    async fn test_async_prefetch_try_next_after_exhaustion() {
        let batches = create_test_batches(1, 5);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 4);

        // Consume the only batch.
        let _ = dataset.next().await;

        // Let the producer observe exhaustion and close the channel.
        tokio::task::yield_now().await;

        let result = dataset.try_next();
        assert!(result.is_none());
    }

    // Configuring only the prefetch size still yields the full stream.
    #[tokio::test]
    async fn test_builder_with_prefetch_size() {
        let batches = create_test_batches(5, 10);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let mut dataset = AsyncPrefetchBuilder::new()
            .prefetch_size(2)
            .from_source(Box::new(source));

        let mut count = 0;
        while let Some(result) = dataset.next().await {
            assert!(result.is_ok());
            count += 1;
        }
        assert_eq!(count, 5);
    }

    // End-to-end: builder-configured Parquet read recovers every row.
    #[tokio::test]
    async fn test_builder_from_parquet_roundtrip() {
        let batch = create_test_batches(1, 50)[0].clone();
        let dataset = crate::ArrowDataset::from_batch(batch)
            .ok()
            .unwrap_or_else(|| panic!("Should create dataset"));

        let temp_dir = tempfile::tempdir()
            .ok()
            .unwrap_or_else(|| panic!("Should create temp dir"));
        let path = temp_dir.path().join("builder_test.parquet");
        dataset
            .to_parquet(&path)
            .ok()
            .unwrap_or_else(|| panic!("Should write parquet"));

        let mut async_dataset = AsyncPrefetchBuilder::new()
            .batch_size(10)
            .prefetch_size(3)
            .from_parquet(&path)
            .ok()
            .unwrap_or_else(|| panic!("Should create async dataset"));

        let mut total = 0;
        while let Some(result) = async_dataset.next().await {
            total += result.ok().unwrap().num_rows();
        }
        assert_eq!(total, 50);
    }

    // The derived Debug impl includes the builder's type name.
    #[test]
    fn test_builder_debug() {
        let builder = AsyncPrefetchBuilder::new().batch_size(32).prefetch_size(8);

        let debug_str = format!("{:?}", builder);
        assert!(debug_str.contains("AsyncPrefetchBuilder"));
    }
}