lance_datafusion/spill.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::{
    io::{BufReader, BufWriter},
    path::PathBuf,
    sync::{Arc, Mutex},
};

use arrow::ipc::{reader::StreamReader, writer::StreamWriter};
use arrow_array::RecordBatch;
use arrow_schema::{ArrowError, Schema};
use datafusion::{
    execution::SendableRecordBatchStream, physical_plan::stream::RecordBatchStreamAdapter,
};
use datafusion_common::DataFusionError;
use lance_arrow::memory::MemoryAccumulator;
use lance_core::error::LanceOptionExt;

/// Start a spill of Arrow data to a file that can be read later multiple times.
///
/// Up to `memory_limit` bytes of data can be buffered in memory before a spill
/// is created. If the memory limit is never reached before [`SpillSender::finish()`]
/// is called, then the data will simply be kept in memory and no spill will be
/// created.
///
/// `path` is the path to the file that may be created. It should not already
/// exist. It is the responsibility of the caller to delete the file after it is
/// no longer needed.
///
/// The [`SpillSender`] allows you to write batches to the spill.
///
/// The [`SpillReceiver`] can open a [`SendableRecordBatchStream`] that reads
/// batches from the spill. This can be opened before, during, or after batches
/// have been written to the spill.
///
/// If the [`SpillSender`] is dropped before [`SpillSender::finish()`] is called,
/// any still-open [`SpillReceiver`] stream will return an error.
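///
/// # Example
///
/// A minimal usage sketch, modeled on this module's tests. The spill path and
/// the surrounding async context are assumptions of the snippet, so it is
/// marked `ignore`:
///
/// ```ignore
/// use std::sync::Arc;
/// use arrow_array::{Int32Array, RecordBatch};
/// use arrow_schema::{DataType, Field, Schema};
/// use futures::TryStreamExt;
///
/// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
/// let batch = RecordBatch::try_new(
///     schema.clone(),
///     vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
/// )?;
///
/// // Hypothetical path; the file must not already exist.
/// let path = std::env::temp_dir().join("replay.spill");
/// let (mut sender, receiver) = create_replay_spill(path.clone(), schema, 1024 * 1024);
///
/// sender.write(batch).await?;
/// sender.finish().await?;
///
/// // The spill can be replayed by any number of readers.
/// let replayed = receiver.read().try_collect::<Vec<_>>().await?;
/// assert_eq!(replayed.len(), 1);
///
/// // The caller is responsible for removing the file if one was created.
/// let _ = std::fs::remove_file(&path);
/// ```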
pub fn create_replay_spill(
    path: std::path::PathBuf,
    schema: Arc<Schema>,
    memory_limit: usize,
) -> (SpillSender, SpillReceiver) {
    let initial_status = WriteStatus::default();
    let (status_sender, status_receiver) = tokio::sync::watch::channel(initial_status);
    let sender = SpillSender {
        memory_limit,
        path: path.clone(),
        schema: schema.clone(),
        state: SpillState::default(),
        status_sender,
    };

    let receiver = SpillReceiver {
        status_receiver,
        path,
        schema,
    };

    (sender, receiver)
}

#[derive(Clone)]
pub struct SpillReceiver {
    status_receiver: tokio::sync::watch::Receiver<WriteStatus>,
    path: PathBuf,
    schema: Arc<Schema>,
}

impl SpillReceiver {
    /// Returns a stream of batches from the spill. The stream will emit
    /// batches as they are written to the spill. If the spill has already
    /// been finished, the stream will emit all batches in the spill.
    ///
    /// The stream will not complete until [`SpillSender::finish()`] is called.
    ///
    /// If the [`SpillSender`] has been dropped before finishing, the stream
    /// will return an error.
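    ///
    /// A sketch of consuming the stream concurrently with the writer (the
    /// async context and the `process` helper are assumptions):
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = receiver.read();
    /// // Each `next()` resolves as soon as the sender has written another
    /// // batch; the stream ends once the sender calls `finish()`.
    /// while let Some(batch) = stream.next().await.transpose()? {
    ///     process(batch); // hypothetical consumer
    /// }
    /// ```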
    pub fn read(&self) -> SendableRecordBatchStream {
        let rx = self.status_receiver.clone();
        let reader = SpillReader::new(rx, self.path.clone());

        let stream = futures::stream::try_unfold(reader, move |mut reader| async move {
            match reader.read().await {
                Ok(None) => Ok(None),
                Ok(Some(batch)) => Ok(Some((batch, reader))),
                Err(err) => Err(err),
            }
        });

        Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))
    }
}

struct SpillReader {
    pub batches_read: usize,
    receiver: tokio::sync::watch::Receiver<WriteStatus>,
    state: SpillReaderState,
}

enum SpillReaderState {
    Buffered { spill_path: PathBuf },
    Reader { reader: AsyncStreamReader },
}

impl SpillReader {
    fn new(receiver: tokio::sync::watch::Receiver<WriteStatus>, spill_path: PathBuf) -> Self {
        Self {
            batches_read: 0,
            receiver,
            state: SpillReaderState::Buffered { spill_path },
        }
    }

    async fn wait_for_more_data(&mut self) -> Result<Option<Arc<[RecordBatch]>>, DataFusionError> {
        let status = self
            .receiver
            .wait_for(|status| {
                status.error.is_some()
                    || status.finished
                    || status.batches_written() > self.batches_read
            })
            .await
            .map_err(|_| {
                DataFusionError::Execution(
                    "Spill has been dropped before reader has finished.".into(),
                )
            })?;

        if let Some(error) = &status.error {
            let mut guard = error.lock().ok().expect_ok()?;
            return Err(DataFusionError::from(&mut (*guard)));
        }

        if let DataLocation::Buffered { batches } = &status.data_location {
            Ok(Some(batches.clone()))
        } else {
            Ok(None)
        }
    }

    async fn get_reader(&mut self) -> Result<&AsyncStreamReader, ArrowError> {
        if let SpillReaderState::Buffered { spill_path } = &self.state {
            let reader = AsyncStreamReader::open(spill_path.clone()).await?;
            // Skip batches we've already read before the writer started spilling.
            // The read batches were spilled to the file for the benefit of
            // future readers, as the spill is replay-able.
            for _ in 0..self.batches_read {
                reader.read().await?;
            }
            self.state = SpillReaderState::Reader { reader };
        }

        if let SpillReaderState::Reader { reader } = &mut self.state {
            Ok(reader)
        } else {
            unreachable!()
        }
    }

    async fn read(&mut self) -> Result<Option<RecordBatch>, DataFusionError> {
        let maybe_data = self.wait_for_more_data().await?;

        if let Some(batches) = maybe_data {
            if self.batches_read < batches.len() {
                let batch = batches[self.batches_read].clone();
                self.batches_read += 1;
                Ok(Some(batch))
            } else {
                Ok(None)
            }
        } else {
            let reader = self.get_reader().await?;
            let batch = reader.read().await?;
            if batch.is_some() {
                self.batches_read += 1;
            }
            Ok(batch)
        }
    }
}

/// The sender side of the spill. This is used to write batches to the spill.
///
/// Note: this must be kept alive until after the readers are done reading the
/// spill. Otherwise, they will return an error.
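///
/// A sketch of that failure mode, assuming a `sender`/`receiver` pair from
/// [`create_replay_spill`] and an async context:
///
/// ```ignore
/// use futures::StreamExt;
///
/// let mut stream = receiver.read();
/// // Dropping the sender before `finish()` closes the status channel, so the
/// // reader observes an execution error instead of more batches.
/// drop(sender);
/// assert!(stream.next().await.unwrap().is_err());
/// ```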
pub struct SpillSender {
    memory_limit: usize,
    schema: Arc<Schema>,
    path: PathBuf,
    state: SpillState,
    status_sender: tokio::sync::watch::Sender<WriteStatus>,
}

enum SpillState {
    Buffering {
        batches: Vec<RecordBatch>,
        memory_accumulator: MemoryAccumulator,
    },
    Spilling {
        writer: AsyncStreamWriter,
        batches_written: usize,
    },
    Finished {
        batches: Option<Arc<[RecordBatch]>>,
        batches_written: usize,
    },
    Errored {
        error: Arc<Mutex<SpillError>>,
    },
}

impl Default for SpillState {
    fn default() -> Self {
        Self::Buffering {
            batches: Vec::new(),
            memory_accumulator: MemoryAccumulator::default(),
        }
    }
}

#[derive(Clone, Debug, Default)]
struct WriteStatus {
    error: Option<Arc<Mutex<SpillError>>>,
    finished: bool,
    data_location: DataLocation,
}

impl WriteStatus {
    fn batches_written(&self) -> usize {
        match &self.data_location {
            DataLocation::Buffered { batches } => batches.len(),
            DataLocation::Spilled {
                batches_written, ..
            } => *batches_written,
        }
    }
}

#[derive(Clone, Debug)]
enum DataLocation {
    Buffered { batches: Arc<[RecordBatch]> },
    Spilled { batches_written: usize },
}

impl Default for DataLocation {
    fn default() -> Self {
        Self::Buffered {
            batches: Arc::new([]),
        }
    }
}

/// A DataFusion error that can be emitted multiple times. We provide the
/// original error first, and subsequent conversions provide a copy with a
/// string representation of the original error.
#[derive(Debug)]
enum SpillError {
    Original(DataFusionError),
    Copy(DataFusionError),
}

impl From<DataFusionError> for SpillError {
    fn from(err: DataFusionError) -> Self {
        Self::Original(err)
    }
}

impl From<&mut SpillError> for DataFusionError {
    fn from(err: &mut SpillError) -> Self {
        match err {
            SpillError::Original(inner) => {
                let copy = Self::Execution(inner.to_string());
                let original = std::mem::replace(err, SpillError::Copy(copy));
                if let SpillError::Original(inner) = original {
                    inner
                } else {
                    unreachable!()
                }
            }
            SpillError::Copy(Self::Execution(message)) => Self::Execution(message.clone()),
            _ => unreachable!(),
        }
    }
}

impl From<&SpillState> for WriteStatus {
    fn from(state: &SpillState) -> Self {
        match state {
            SpillState::Buffering { batches, .. } => Self {
                finished: false,
                data_location: DataLocation::Buffered {
                    batches: batches.clone().into(),
                },
                error: None,
            },
            SpillState::Spilling {
                batches_written, ..
            } => Self {
                finished: false,
                data_location: DataLocation::Spilled {
                    batches_written: *batches_written,
                },
                error: None,
            },
            SpillState::Finished {
                batches_written,
                batches,
            } => {
                let data_location = if let Some(batches) = batches {
                    DataLocation::Buffered {
                        batches: batches.clone(),
                    }
                } else {
                    DataLocation::Spilled {
                        batches_written: *batches_written,
                    }
                };
                Self {
                    finished: true,
                    data_location,
                    error: None,
                }
            }
            SpillState::Errored { error } => Self {
                finished: true,
                data_location: DataLocation::default(), // Doesn't matter.
                error: Some(error.clone()),
            },
        }
    }
}

impl SpillSender {
    /// Write a batch to the spill.
    ///
    /// If there is room within `memory_limit`, the batch is queued in memory.
    /// If this batch causes `memory_limit` to be exceeded for the first time,
    /// all queued batches, along with this one, are written to disk as part of
    /// this call. If we are already spilling, the batch is written to disk as
    /// part of this call.
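    ///
    /// A sketch of the buffering-to-spilling transition, mirroring the
    /// `test_spill_buffered_transition` test below (sizes and async context
    /// assumed):
    ///
    /// ```ignore
    /// // With a 1 MiB limit, a first ~0.75 MiB batch stays buffered in memory
    /// // and no file is created yet.
    /// spill.write(first_batch).await?;
    /// assert!(!path.exists());
    ///
    /// // A second ~0.5 MiB batch pushes the accumulator past the limit, so the
    /// // buffered batch and this one are both written to the spill file.
    /// spill.write(second_batch).await?;
    /// assert!(path.exists());
    /// ```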
    pub async fn write(&mut self, batch: RecordBatch) -> Result<(), DataFusionError> {
        if let SpillState::Finished { .. } = self.state {
            return Err(DataFusionError::Execution(
                "Spill has already been finished".to_string(),
            ));
        }

        if let SpillState::Errored { .. } = &self.state {
            return Err(DataFusionError::Execution(
                "Spill has sent an error".to_string(),
            ));
        }

        let (writer, batches_written) = match &mut self.state {
            SpillState::Buffering {
                batches,
                ref mut memory_accumulator,
            } => {
                memory_accumulator.record_batch(&batch);

                if memory_accumulator.total() > self.memory_limit {
                    let writer =
                        AsyncStreamWriter::open(self.path.clone(), self.schema.clone()).await?;
                    let batches_written = batches.len();
                    for batch in batches.drain(..) {
                        writer.write(batch).await?;
                    }
                    self.state = SpillState::Spilling {
                        writer,
                        batches_written,
                    };
                    if let SpillState::Spilling {
                        writer,
                        batches_written,
                    } = &mut self.state
                    {
                        (writer, batches_written)
                    } else {
                        unreachable!()
                    }
                } else {
                    batches.push(batch);
                    self.status_sender
                        .send_replace(WriteStatus::from(&self.state));
                    return Ok(());
                }
            }
            SpillState::Spilling {
                writer,
                batches_written,
            } => (writer, batches_written),
            _ => unreachable!(),
        };

        writer.write(batch).await?;
        *batches_written += 1;
        self.status_sender
            .send_replace(WriteStatus::from(&self.state));

        Ok(())
    }

    /// Send an error to the spill. This will be sent to all readers of the
    /// spill.
    pub fn send_error(&mut self, err: DataFusionError) {
        let error = Arc::new(Mutex::new(err.into()));
        self.state = SpillState::Errored { error };
        self.status_sender
            .send_replace(WriteStatus::from(&self.state));
    }

    /// Complete the spill write. This will finalize the Arrow IPC stream file.
    /// If a spill file was created, it remains available for reading until the
    /// caller deletes it.
    pub async fn finish(&mut self) -> Result<(), DataFusionError> {
        // We create a temporary state to get an owned copy of current state.
        // Since we hold an exclusive reference to `self`, no one should be
        // able to see this temporary state.
        let tmp_state = SpillState::Finished {
            batches_written: 0,
            batches: None,
        };
        match std::mem::replace(&mut self.state, tmp_state) {
            SpillState::Buffering { batches, .. } => {
                let batches_written = batches.len();
                self.state = SpillState::Finished {
                    batches_written,
                    batches: Some(batches.into()),
                };
                self.status_sender
                    .send_replace(WriteStatus::from(&self.state));
            }
            SpillState::Spilling {
                writer,
                batches_written,
            } => {
                writer.finish().await?;
                self.state = SpillState::Finished {
                    batches_written,
                    batches: None,
                };
                self.status_sender
                    .send_replace(WriteStatus::from(&self.state));
            }
            SpillState::Finished { .. } => {
                return Err(DataFusionError::Execution(
                    "Spill has already been finished".to_string(),
                ));
            }
            SpillState::Errored { .. } => {
                return Err(DataFusionError::Execution(
                    "Spill has sent an error".to_string(),
                ));
            }
        };

        Ok(())
    }
}

/// An async wrapper around [`StreamWriter`]. Each call uses [`tokio::task::spawn_blocking`]
/// to spawn a blocking task to write the batch.
struct AsyncStreamWriter {
    writer: Arc<Mutex<StreamWriter<BufWriter<std::fs::File>>>>,
}

impl AsyncStreamWriter {
    pub async fn open(path: PathBuf, schema: Arc<Schema>) -> Result<Self, ArrowError> {
        let writer = tokio::task::spawn_blocking(move || {
            let file = std::fs::File::create(&path).map_err(ArrowError::from)?;
            let writer = BufWriter::new(file);
            StreamWriter::try_new(writer, &schema)
        })
        .await
        .unwrap()?;
        let writer = Arc::new(Mutex::new(writer));
        Ok(Self { writer })
    }

    pub async fn write(&self, batch: RecordBatch) -> Result<(), ArrowError> {
        let writer = self.writer.clone();
        tokio::task::spawn_blocking(move || {
            let mut writer = writer.lock().unwrap();
            writer.write(&batch)?;
            writer.flush()
        })
        .await
        .unwrap()
    }

    pub async fn finish(self) -> Result<(), ArrowError> {
        let writer = self.writer.clone();
        tokio::task::spawn_blocking(move || {
            let mut writer = writer.lock().unwrap();
            writer.finish()
        })
        .await
        .unwrap()
    }
}

struct AsyncStreamReader {
    reader: Arc<Mutex<StreamReader<BufReader<std::fs::File>>>>,
}

impl AsyncStreamReader {
    pub async fn open(path: PathBuf) -> Result<Self, ArrowError> {
        let reader = tokio::task::spawn_blocking(move || {
            let file = std::fs::File::open(&path).map_err(ArrowError::from)?;
            let reader = BufReader::new(file);
            StreamReader::try_new(reader, None)
        })
        .await
        .unwrap()?;
        let reader = Arc::new(Mutex::new(reader));
        Ok(Self { reader })
    }

    pub async fn read(&self) -> Result<Option<RecordBatch>, ArrowError> {
        let reader = self.reader.clone();
        tokio::task::spawn_blocking(move || {
            let mut reader = reader.lock().unwrap();
            reader.next()
        })
        .await
        .unwrap()
        .transpose()
    }
}

#[cfg(test)]
mod tests {
    use arrow_array::Int32Array;
    use arrow_schema::{DataType, Field};
    use futures::{poll, StreamExt, TryStreamExt};
    use lance_core::utils::tempfile::{TempStdFile, TempStdPath};

    use super::*;

    #[tokio::test]
    async fn test_spill() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batches = [
            RecordBatch::try_new(
                schema.clone(),
                vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
            )
            .unwrap(),
            RecordBatch::try_new(
                schema.clone(),
                vec![Arc::new(Int32Array::from(vec![4, 5, 6]))],
            )
            .unwrap(),
        ];

        // Create a stream
        let path = TempStdFile::default();
        let (mut spill, receiver) = create_replay_spill(path.to_owned(), schema.clone(), 0);

        // We can open a reader prior to writing any data. No batches will be ready.
        let mut stream_before = receiver.read();
        let mut stream_before_next = stream_before.next();
        let poll_res = poll!(&mut stream_before_next);
        assert!(poll_res.is_pending());

        // If we write a batch, the existing reader can now receive it.
        spill.write(batches[0].clone()).await.unwrap();
        let stream_before_batch1 = stream_before_next
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_before_batch1, &batches[0]);
        let mut stream_before_next = stream_before.next();
        let poll_res = poll!(&mut stream_before_next);
        assert!(poll_res.is_pending());

        // We can also open a reader while the spill is being written to. We can
        // retrieve batches written so far immediately.
        let mut stream_during = receiver.read();
        let stream_during_batch1 = stream_during
            .next()
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_during_batch1, &batches[0]);
        let mut stream_during_next = stream_during.next();
        let poll_res = poll!(&mut stream_during_next);
        assert!(poll_res.is_pending());

        // Once we finish the spill, readers can get remaining batches and will
        // reach the end of the stream.
        spill.write(batches[1].clone()).await.unwrap();
        spill.finish().await.unwrap();

        let stream_before_batch2 = stream_before_next
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_before_batch2, &batches[1]);
        assert!(stream_before.next().await.is_none());

        let stream_during_batch2 = stream_during_next
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_during_batch2, &batches[1]);
        assert!(stream_during.next().await.is_none());

        // Can also start a reader after finishing.
        let stream_after = receiver.read();
        let stream_after_batches = stream_after.try_collect::<Vec<_>>().await.unwrap();
        assert_eq!(&stream_after_batches, &batches);

        std::fs::remove_file(path).unwrap();
    }

    #[tokio::test]
    async fn test_spill_error() {
        // Create a spill
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let path = TempStdFile::default();
        let (mut spill, receiver) =
            create_replay_spill(path.as_ref().to_owned(), schema.clone(), 0);
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )
        .unwrap();

        spill.write(batch.clone()).await.unwrap();

        let mut stream = receiver.read();
        let stream_batch = stream
            .next()
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_batch, &batch);

        spill.send_error(DataFusionError::ResourcesExhausted("🥱".into()));
        let stream_error = stream
            .next()
            .await
            .expect("Expected an error")
            .expect_err("Expected an error");
        assert!(matches!(
            stream_error,
            DataFusionError::ResourcesExhausted(message) if message == "🥱"
        ));

        // If we try to write after sending an error, it should return an error.
        let err = spill.write(batch).await;
        assert!(matches!(
            err,
            Err(DataFusionError::Execution(message)) if message == "Spill has sent an error"
        ));

        // If we try to finish after sending an error, it should return an error.
        let err = spill.finish().await;
        assert!(matches!(
            err,
            Err(DataFusionError::Execution(message)) if message == "Spill has sent an error"
        ));

        // If we try to read after sending an error, it should return an error.
        let mut stream = receiver.read();
        let stream_error = stream
            .next()
            .await
            .expect("Expected an error")
            .expect_err("Expected an error");
        assert!(matches!(
            stream_error,
            DataFusionError::Execution(message) if message.contains("🥱")
        ));

        std::fs::remove_file(path).unwrap();
    }

    #[tokio::test]
    async fn test_spill_buffered() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let path = TempStdPath::default();
        let memory_limit = 1024 * 1024; // 1 MiB
        let (mut spill, receiver) = create_replay_spill(path.clone(), schema.clone(), memory_limit);

        // 0.5 MB batch
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1; (512 * 1024) / 4]))],
        )
        .unwrap();
        spill.write(batch.clone()).await.unwrap();
        assert!(!std::fs::exists(&path).unwrap());

        spill.finish().await.unwrap();
        assert!(!std::fs::exists(&path).unwrap());

        let mut stream = receiver.read();
        let stream_batch = stream
            .next()
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_batch, &batch);

        assert!(!std::fs::exists(&path).unwrap());
    }

    #[tokio::test]
    async fn test_spill_buffered_transition() {
        // Starts as buffered, then spills, then finished.
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let path = TempStdPath::default();
        let memory_limit = 1024 * 1024; // 1 MiB
        let (mut spill, receiver) = create_replay_spill(path.clone(), schema.clone(), memory_limit);

        // 0.7 MB batch
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1; (768 * 1024) / 4]))],
        )
        .unwrap();
        spill.write(batch.clone()).await.unwrap();
        assert!(!std::fs::exists(&path).unwrap());

        let mut stream = receiver.read();
        let stream_batch = stream
            .next()
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_batch, &batch);
        assert!(!std::fs::exists(&path).unwrap());

        // 0.5 MB batch
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1; (512 * 1024) / 4]))],
        )
        .unwrap();
        spill.write(batch.clone()).await.unwrap();
        assert!(std::fs::exists(&path).unwrap());

        let stream_batch = stream
            .next()
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_batch, &batch);
        assert!(std::fs::exists(&path).unwrap());

        spill.finish().await.unwrap();

        assert!(stream.next().await.is_none());

        std::fs::remove_file(path).unwrap();
    }
}