// alimentar/async_prefetch.rs
//! Async prefetch for parallel I/O in streaming datasets.
//!
//! Provides [`AsyncPrefetchDataset`] which spawns a background task to read
//! batches ahead of time, reducing I/O latency in the training loop.

use std::sync::Arc;

use arrow::{array::RecordBatch, datatypes::SchemaRef};
#[cfg(feature = "tokio-runtime")]
use tokio::sync::mpsc;

use crate::{
    error::{Error, Result},
    streaming::DataSource,
};

/// A streaming dataset with async prefetch for parallel I/O.
///
/// Unlike [`StreamingDataset`](crate::streaming::StreamingDataset) which reads
/// synchronously, `AsyncPrefetchDataset` spawns a background task that reads
/// batches into a channel, allowing the main thread to process while I/O
/// happens.
///
/// # Example
///
/// ```ignore
/// use alimentar::async_prefetch::AsyncPrefetchDataset;
///
/// #[tokio::main]
/// async fn main() {
///     // `from_parquet` is synchronous; only `next` is awaited.
///     let mut dataset = AsyncPrefetchDataset::from_parquet("data.parquet", 1024, 4).unwrap();
///
///     // `next` yields `Option<Result<RecordBatch>>`.
///     while let Some(batch) = dataset.next().await {
///         let batch = batch.unwrap();
///         println!("Processing batch with {} rows", batch.num_rows());
///     }
/// }
/// ```
#[cfg(feature = "tokio-runtime")]
pub struct AsyncPrefetchDataset {
    // Receives prefetched batches (or a terminal error) from the reader task.
    receiver: mpsc::Receiver<Result<RecordBatch>>,
    // Schema captured from the source at construction time.
    schema: SchemaRef,
    #[allow(dead_code)] // Kept alive to prevent task cancellation
    handle: tokio::task::JoinHandle<()>,
}

#[cfg(feature = "tokio-runtime")]
impl AsyncPrefetchDataset {
    /// Creates a new async prefetch dataset from a data source.
    ///
    /// Spawns a background task that eagerly reads batches from `source`
    /// into a bounded channel of capacity `prefetch_size.max(1)`; the full
    /// channel provides backpressure so at most that many batches are
    /// buffered ahead of the consumer.
    ///
    /// # Arguments
    ///
    /// * `source` - The data source to read from
    /// * `prefetch_size` - Number of batches to buffer ahead (0 is treated as 1)
    ///
    /// # Panics
    ///
    /// Panics if called outside of a tokio runtime, because it uses
    /// [`tokio::spawn`].
    pub fn new(mut source: Box<dyn DataSource>, prefetch_size: usize) -> Self {
        let schema = source.schema();
        let (tx, rx) = mpsc::channel(prefetch_size.max(1));

        // NOTE(review): `DataSource::next_batch` is a synchronous call, so the
        // actual I/O runs on a runtime worker thread. If sources become slow
        // (network, cold disk), consider moving the loop to `spawn_blocking`.
        let handle = tokio::spawn(async move {
            loop {
                match source.next_batch() {
                    Ok(Some(batch)) => {
                        if tx.send(Ok(batch)).await.is_err() {
                            // Receiver dropped, stop reading
                            break;
                        }
                    }
                    Ok(None) => break, // End of source
                    Err(e) => {
                        // Forward the error as the final item, then stop; the
                        // send result is ignored because a dropped receiver
                        // means nobody cares about the error anymore.
                        let _ = tx.send(Err(e)).await;
                        break;
                    }
                }
            }
        });

        Self {
            receiver: rx,
            schema,
            handle,
        }
    }

    /// Creates an async prefetch dataset from a Parquet file.
    ///
    /// # Arguments
    ///
    /// * `path` - Path to the Parquet file
    /// * `batch_size` - Number of rows per batch (must be greater than 0)
    /// * `prefetch_size` - Number of batches to buffer ahead
    ///
    /// # Errors
    ///
    /// Returns an error if `batch_size` is 0 or the file cannot be opened.
    pub fn from_parquet(
        path: impl AsRef<std::path::Path>,
        batch_size: usize,
        prefetch_size: usize,
    ) -> Result<Self> {
        // Mirror the validation done by `AsyncPrefetchBuilder::from_parquet`
        // so direct callers get the same explicit error instead of relying on
        // `ParquetSource`'s behavior for a zero batch size.
        if batch_size == 0 {
            return Err(Error::invalid_config("batch_size must be greater than 0"));
        }
        let source = crate::streaming::ParquetSource::new(path, batch_size)?;
        Ok(Self::new(Box::new(source), prefetch_size))
    }

    /// Returns the schema of the dataset.
    pub fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    /// Receives the next batch asynchronously.
    ///
    /// Returns `None` when the source is exhausted. A source error is yielded
    /// once as `Some(Err(_))` and the stream ends after it.
    pub async fn next(&mut self) -> Option<Result<RecordBatch>> {
        self.receiver.recv().await
    }

    /// Tries to receive a batch without waiting.
    ///
    /// Returns `None` if no batch is currently buffered or the source is
    /// exhausted.
    pub fn try_next(&mut self) -> Option<Result<RecordBatch>> {
        self.receiver.try_recv().ok()
    }

    /// Returns the number of batches currently buffered in the channel.
    pub fn buffered_count(&self) -> usize {
        self.receiver.len()
    }
}

#[cfg(feature = "tokio-runtime")]
impl std::fmt::Debug for AsyncPrefetchDataset {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Manual impl: the receiver and join handle have no useful Debug
        // output, so only the buffer fill level is reported.
        let buffered = self.receiver.len();
        f.debug_struct("AsyncPrefetchDataset")
            .field("buffered", &buffered)
            .finish_non_exhaustive()
    }
}

/// Builder for creating async prefetch datasets.
#[cfg(feature = "tokio-runtime")]
#[derive(Debug, Default)]
pub struct AsyncPrefetchBuilder {
    // Rows per batch; `None` falls back to 1024 at build time.
    batch_size: Option<usize>,
    // Batches to buffer ahead; `None` falls back to 4 at build time.
    prefetch_size: Option<usize>,
}

#[cfg(feature = "tokio-runtime")]
impl AsyncPrefetchBuilder {
    /// Creates a new builder with no options set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the batch size (rows per batch).
    #[must_use]
    pub fn batch_size(self, size: usize) -> Self {
        Self {
            batch_size: Some(size),
            ..self
        }
    }

    /// Sets the prefetch buffer size (number of batches).
    #[must_use]
    pub fn prefetch_size(self, size: usize) -> Self {
        Self {
            prefetch_size: Some(size),
            ..self
        }
    }

    /// Builds an async prefetch dataset from a Parquet file.
    ///
    /// Unset options default to a batch size of 1024 rows and a prefetch
    /// depth of 4 batches.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened.
    pub fn from_parquet(self, path: impl AsRef<std::path::Path>) -> Result<AsyncPrefetchDataset> {
        let prefetch_size = self.prefetch_size.unwrap_or(4);
        match self.batch_size.unwrap_or(1024) {
            0 => Err(Error::invalid_config("batch_size must be greater than 0")),
            batch_size => AsyncPrefetchDataset::from_parquet(path, batch_size, prefetch_size),
        }
    }

    /// Builds an async prefetch dataset from a data source.
    ///
    /// An unset prefetch size defaults to 4 batches; the batch size option
    /// is not used here because the source already produces batches.
    pub fn from_source(self, source: Box<dyn DataSource>) -> AsyncPrefetchDataset {
        AsyncPrefetchDataset::new(source, self.prefetch_size.unwrap_or(4))
    }
}

/// Synchronous wrapper for async prefetch that works with DataLoader.
///
/// This allows using async prefetch with the existing synchronous DataLoader
/// API by blocking on the async operations internally.
#[cfg(feature = "tokio-runtime")]
pub struct SyncPrefetchDataset {
    // The wrapped async dataset.
    inner: AsyncPrefetchDataset,
    // Runtime handle used to block on `inner`'s async operations.
    runtime: tokio::runtime::Handle,
}

#[cfg(feature = "tokio-runtime")]
impl SyncPrefetchDataset {
    /// Creates a new sync wrapper around an async prefetch dataset.
    ///
    /// # Arguments
    ///
    /// * `dataset` - The async dataset to wrap
    /// * `runtime` - Handle to the tokio runtime
    pub fn new(dataset: AsyncPrefetchDataset, runtime: tokio::runtime::Handle) -> Self {
        Self {
            inner: dataset,
            runtime,
        }
    }

    /// Returns the schema.
    pub fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }

    /// Gets the next batch, blocking if necessary.
    ///
    /// Returns `None` when the source is exhausted.
    ///
    /// # Panics
    ///
    /// `Handle::block_on` panics when called from within an async context
    /// (e.g. inside another tokio task), so this must only be invoked from a
    /// plain synchronous thread.
    pub fn next_blocking(&mut self) -> Option<Result<RecordBatch>> {
        self.runtime.block_on(self.inner.next())
    }
}

#[cfg(feature = "tokio-runtime")]
impl std::fmt::Debug for SyncPrefetchDataset {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Manual impl: the runtime handle carries nothing worth printing, so
        // delegate to the inner dataset and elide the rest.
        let inner = &self.inner;
        f.debug_struct("SyncPrefetchDataset")
            .field("inner", inner)
            .finish_non_exhaustive()
    }
}

#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    use std::sync::Arc;

    use arrow::{
        array::{Int32Array, StringArray},
        datatypes::{DataType, Field, Schema},
    };

    use super::*;
    use crate::streaming::MemorySource;

    /// Builds `count` batches of `rows_per_batch` rows each with schema
    /// `(id: Int32 non-null, name: Utf8 non-null)`; ids are contiguous
    /// across batches so row totals are easy to assert.
    fn create_test_batches(count: usize, rows_per_batch: usize) -> Vec<RecordBatch> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        (0..count)
            .map(|batch_idx| {
                let start = (batch_idx * rows_per_batch) as i32;
                let ids: Vec<i32> = (start..start + rows_per_batch as i32).collect();
                let names: Vec<String> = ids.iter().map(|i| format!("item_{}", i)).collect();

                RecordBatch::try_new(
                    Arc::clone(&schema),
                    vec![
                        Arc::new(Int32Array::from(ids)),
                        Arc::new(StringArray::from(names)),
                    ],
                )
                .ok()
                .unwrap_or_else(|| panic!("Should create batch"))
            })
            .collect()
    }

    #[tokio::test]
    async fn test_async_prefetch_creation() {
        let batches = create_test_batches(5, 10);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let dataset = AsyncPrefetchDataset::new(Box::new(source), 4);
        assert_eq!(dataset.schema().fields().len(), 2);
    }

    #[tokio::test]
    async fn test_async_prefetch_iteration() {
        // 5 batches x 10 rows = 50 rows total.
        let batches = create_test_batches(5, 10);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 4);

        let mut count = 0;
        let mut total_rows = 0;
        while let Some(result) = dataset.next().await {
            let batch = result.ok().unwrap_or_else(|| panic!("Should get batch"));
            count += 1;
            total_rows += batch.num_rows();
        }

        assert_eq!(count, 5);
        assert_eq!(total_rows, 50);
    }

    #[tokio::test]
    async fn test_async_prefetch_try_next() {
        let batches = create_test_batches(3, 10);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 10);

        // Yield to let background task run
        tokio::task::yield_now().await;
        tokio::task::yield_now().await;

        // Should have some batches ready
        let mut count = 0;
        while dataset.try_next().is_some() {
            count += 1;
        }

        assert!(count > 0, "Should have prefetched some batches");
    }

    #[tokio::test]
    async fn test_async_prefetch_buffered_count() {
        // 10 batches available but only 4 channel slots, so the buffer can
        // never report more than the prefetch size.
        let batches = create_test_batches(10, 5);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let dataset = AsyncPrefetchDataset::new(Box::new(source), 4);

        // Yield to let background task fill buffer
        for _ in 0..10 {
            tokio::task::yield_now().await;
        }

        // Buffer should have some items (up to prefetch_size)
        let buffered = dataset.buffered_count();
        assert!(buffered <= 4, "Should not exceed prefetch size");
    }

    #[tokio::test]
    async fn test_async_prefetch_builder() {
        let batches = create_test_batches(3, 10);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let mut dataset = AsyncPrefetchBuilder::new()
            .batch_size(10)
            .prefetch_size(2)
            .from_source(Box::new(source));

        let mut count = 0;
        while let Some(result) = dataset.next().await {
            assert!(result.is_ok());
            count += 1;
        }
        assert_eq!(count, 3);
    }

    #[tokio::test]
    async fn test_async_prefetch_debug() {
        let batches = create_test_batches(2, 5);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let dataset = AsyncPrefetchDataset::new(Box::new(source), 4);
        let debug_str = format!("{:?}", dataset);
        assert!(debug_str.contains("AsyncPrefetchDataset"));
    }

    #[tokio::test]
    async fn test_async_prefetch_parquet_roundtrip() {
        // Create test data
        let batch = create_test_batches(1, 100)[0].clone();
        let dataset = crate::ArrowDataset::from_batch(batch)
            .ok()
            .unwrap_or_else(|| panic!("Should create dataset"));

        // Write to temp file
        let temp_dir = tempfile::tempdir()
            .ok()
            .unwrap_or_else(|| panic!("Should create temp dir"));
        let path = temp_dir.path().join("async_test.parquet");
        dataset
            .to_parquet(&path)
            .ok()
            .unwrap_or_else(|| panic!("Should write parquet"));

        // Read back via async prefetch
        let mut async_dataset = AsyncPrefetchDataset::from_parquet(&path, 25, 4)
            .ok()
            .unwrap_or_else(|| panic!("Should create async dataset"));

        let mut total = 0;
        while let Some(result) = async_dataset.next().await {
            let batch = result.ok().unwrap_or_else(|| panic!("Should get batch"));
            total += batch.num_rows();
        }
        assert_eq!(total, 100);
    }

    #[tokio::test]
    async fn test_sync_prefetch_wrapper() {
        // Only exercises schema/Debug; `next_blocking` cannot be called here
        // because `Handle::block_on` would panic inside an async test.
        let batches = create_test_batches(3, 10);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let async_dataset = AsyncPrefetchDataset::new(Box::new(source), 4);
        let handle = tokio::runtime::Handle::current();
        let sync_dataset = SyncPrefetchDataset::new(async_dataset, handle);

        assert_eq!(sync_dataset.schema().fields().len(), 2);

        let debug_str = format!("{:?}", sync_dataset);
        assert!(debug_str.contains("SyncPrefetchDataset"));
    }

    #[tokio::test]
    async fn test_builder_zero_batch_size_error() {
        // batch_size validation fires before the path is ever opened.
        let result = AsyncPrefetchBuilder::new()
            .batch_size(0)
            .from_parquet("/nonexistent.parquet");

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_builder_defaults() {
        let batches = create_test_batches(2, 5);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        // Use default values
        let dataset = AsyncPrefetchBuilder::new().from_source(Box::new(source));

        assert_eq!(dataset.schema().fields().len(), 2);
    }

    #[tokio::test]
    async fn test_async_prefetch_quick_exhaustion() {
        // Test with source that quickly exhausts
        struct QuickExhaustSource {
            schema: SchemaRef,
            exhausted: bool,
        }

        impl crate::streaming::DataSource for QuickExhaustSource {
            fn schema(&self) -> SchemaRef {
                Arc::clone(&self.schema)
            }

            fn next_batch(&mut self) -> crate::Result<Option<RecordBatch>> {
                if self.exhausted {
                    Ok(None)
                } else {
                    self.exhausted = true;
                    Ok(Some(create_test_batches(1, 1)[0].clone()))
                }
            }
        }

        let source = QuickExhaustSource {
            schema: create_test_batches(1, 1)[0].schema(),
            exhausted: false,
        };

        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 4);

        // First should succeed
        let first = dataset.next().await;
        assert!(first.is_some());
        assert!(first.unwrap().is_ok());

        // Second should be None (exhausted)
        let second = dataset.next().await;
        assert!(second.is_none());
    }

    #[tokio::test]
    async fn test_async_prefetch_single_batch() {
        let batches = create_test_batches(1, 100);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 4);

        let batch = dataset
            .next()
            .await
            .unwrap_or_else(|| panic!("Should have batch"))
            .ok()
            .unwrap_or_else(|| panic!("Batch should be ok"));
        assert_eq!(batch.num_rows(), 100);

        // No more batches
        assert!(dataset.next().await.is_none());
    }

    #[tokio::test]
    async fn test_async_prefetch_large_prefetch_size() {
        // Prefetch size larger than available batches
        let batches = create_test_batches(3, 10);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 100);

        let mut count = 0;
        while let Some(result) = dataset.next().await {
            assert!(result.is_ok());
            count += 1;
        }
        assert_eq!(count, 3);
    }

    #[tokio::test]
    async fn test_async_prefetch_prefetch_size_one() {
        // Minimal prefetch
        let batches = create_test_batches(5, 10);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 1);

        let mut count = 0;
        while let Some(result) = dataset.next().await {
            assert!(result.is_ok());
            count += 1;
        }
        assert_eq!(count, 5);
    }

    #[tokio::test]
    async fn test_async_prefetch_error_source() {
        // Test with source that errors
        struct ErrorSource {
            schema: SchemaRef,
            calls: usize,
        }

        impl crate::streaming::DataSource for ErrorSource {
            fn schema(&self) -> SchemaRef {
                Arc::clone(&self.schema)
            }

            fn next_batch(&mut self) -> crate::Result<Option<RecordBatch>> {
                self.calls += 1;
                if self.calls > 2 {
                    Err(crate::Error::storage("Simulated error"))
                } else {
                    Ok(Some(create_test_batches(1, 5)[0].clone()))
                }
            }
        }

        let source = ErrorSource {
            schema: create_test_batches(1, 1)[0].schema(),
            calls: 0,
        };

        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 4);

        // First two should succeed
        let b1 = dataset.next().await;
        assert!(b1.is_some());
        assert!(b1.unwrap().is_ok());

        let b2 = dataset.next().await;
        assert!(b2.is_some());
        assert!(b2.unwrap().is_ok());

        // Third should be an error
        let b3 = dataset.next().await;
        assert!(b3.is_some());
        assert!(b3.unwrap().is_err());
    }

    #[tokio::test]
    async fn test_async_prefetch_try_next_after_exhaustion() {
        // Create source with one batch
        let batches = create_test_batches(1, 5);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let mut dataset = AsyncPrefetchDataset::new(Box::new(source), 4);

        // Consume the single batch
        let _ = dataset.next().await;

        // Allow background task to complete
        tokio::task::yield_now().await;

        // try_next should return None (exhausted)
        let result = dataset.try_next();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_builder_with_prefetch_size() {
        let batches = create_test_batches(5, 10);
        let source = MemorySource::new(batches)
            .ok()
            .unwrap_or_else(|| panic!("Should create source"));

        let mut dataset = AsyncPrefetchBuilder::new()
            .prefetch_size(2)
            .from_source(Box::new(source));

        let mut count = 0;
        while let Some(result) = dataset.next().await {
            assert!(result.is_ok());
            count += 1;
        }
        assert_eq!(count, 5);
    }

    #[tokio::test]
    async fn test_builder_from_parquet_roundtrip() {
        // Create test data
        let batch = create_test_batches(1, 50)[0].clone();
        let dataset = crate::ArrowDataset::from_batch(batch)
            .ok()
            .unwrap_or_else(|| panic!("Should create dataset"));

        let temp_dir = tempfile::tempdir()
            .ok()
            .unwrap_or_else(|| panic!("Should create temp dir"));
        let path = temp_dir.path().join("builder_test.parquet");
        dataset
            .to_parquet(&path)
            .ok()
            .unwrap_or_else(|| panic!("Should write parquet"));

        // Read with builder
        let mut async_dataset = AsyncPrefetchBuilder::new()
            .batch_size(10)
            .prefetch_size(3)
            .from_parquet(&path)
            .ok()
            .unwrap_or_else(|| panic!("Should create async dataset"));

        let mut total = 0;
        while let Some(result) = async_dataset.next().await {
            total += result.ok().unwrap().num_rows();
        }
        assert_eq!(total, 50);
    }

    #[test]
    fn test_builder_debug() {
        let builder = AsyncPrefetchBuilder::new().batch_size(32).prefetch_size(8);

        let debug_str = format!("{:?}", builder);
        assert!(debug_str.contains("AsyncPrefetchBuilder"));
    }
}