// lance_datafusion/spill.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{
5    io::{BufReader, BufWriter},
6    path::PathBuf,
7    sync::{Arc, Mutex},
8};
9
10use arrow::ipc::{reader::StreamReader, writer::StreamWriter};
11use arrow_array::RecordBatch;
12use arrow_schema::{ArrowError, Schema};
13use datafusion::{
14    execution::SendableRecordBatchStream, physical_plan::stream::RecordBatchStreamAdapter,
15};
16use datafusion_common::DataFusionError;
17use lance_arrow::memory::MemoryAccumulator;
18use lance_core::error::LanceOptionExt;
19
20/// Start a spill of Arrow data to a file that can be read later multiple times.
21///
22/// Up to `memory_limit` bytes of data can be buffered in memory before a spill
23/// is created. If the memory limit is never reached before [`SpillSender::finish()`]
24/// is called, then the data will simply be kept in memory and no spill will be
25/// created.
26///
27/// `path` is the path to the file that may be created. It should not already
28/// exist. It is the responsibility of the caller to delete the file after it is
29/// no longer needed.
30///
31/// The [`SpillSender`] allows you to write batches to the spill.
32///
33/// The [`SpillReceiver`] can open a [`SendableRecordBatchStream`] that reads
34/// batches from the spill. This can be opened before, during, or after batches
35/// have been written to the spill.
36///
37/// Once [`SpillSender`] is dropped, the temporary file is deleted. This will
38/// cause the [`SpillReceiver`] to return an error if it is still open.
39pub fn create_replay_spill(
40    path: std::path::PathBuf,
41    schema: Arc<Schema>,
42    memory_limit: usize,
43) -> (SpillSender, SpillReceiver) {
44    let initial_status = WriteStatus::default();
45    let (status_sender, status_receiver) = tokio::sync::watch::channel(initial_status);
46    let sender = SpillSender {
47        memory_limit,
48        path: path.clone(),
49        schema: schema.clone(),
50        state: SpillState::default(),
51        status_sender,
52    };
53
54    let receiver = SpillReceiver {
55        status_receiver,
56        path,
57        schema,
58    };
59
60    (sender, receiver)
61}
62
/// Cloneable handle for opening read streams over a spill.
#[derive(Clone)]
pub struct SpillReceiver {
    // Watches the writer's progress; readers wake when new batches or a
    // terminal status (finished / error) is published.
    status_receiver: tokio::sync::watch::Receiver<WriteStatus>,
    // Path of the spill file the sender may create.
    path: PathBuf,
    // Schema of all batches in the spill.
    schema: Arc<Schema>,
}
69
70impl SpillReceiver {
71    /// Returns a stream of batches from the spill. The stream will emit
72    /// batches as they are written to the spill. If the spill has already
73    /// been finished, the stream will emit all batches in the spill.
74    ///
75    /// The stream will not complete until [`SpillSender::finish()`] is called.
76    ///
77    /// If the spill has been dropped, an error will be returned.
78    pub fn read(&self) -> SendableRecordBatchStream {
79        let rx = self.status_receiver.clone();
80        let reader = SpillReader::new(rx, self.path.clone());
81
82        let stream = futures::stream::try_unfold(reader, move |mut reader| async move {
83            match reader.read().await {
84                Ok(None) => Ok(None),
85                Ok(Some(batch)) => Ok(Some((batch, reader))),
86                Err(err) => Err(err),
87            }
88        });
89
90        Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))
91    }
92}
93
/// Pull-based reader over a spill. Tracks how many batches it has consumed
/// so it can resume correctly when the writer moves data from memory to disk.
struct SpillReader {
    // Number of batches already yielded to the stream.
    pub batches_read: usize,
    // Receives the writer's latest status (progress, finished, or error).
    receiver: tokio::sync::watch::Receiver<WriteStatus>,
    state: SpillReaderState,
}
99
/// Where this reader is currently sourcing batches from.
enum SpillReaderState {
    // Reading from the writer's in-memory buffer; holds the path to open
    // if/when the writer spills to disk.
    Buffered { spill_path: PathBuf },
    // Reading from the spill file on disk.
    Reader { reader: AsyncStreamReader },
}
104
105impl SpillReader {
106    fn new(receiver: tokio::sync::watch::Receiver<WriteStatus>, spill_path: PathBuf) -> Self {
107        Self {
108            batches_read: 0,
109            receiver,
110            state: SpillReaderState::Buffered { spill_path },
111        }
112    }
113
114    async fn wait_for_more_data(&mut self) -> Result<Option<Arc<[RecordBatch]>>, DataFusionError> {
115        let status = self
116            .receiver
117            .wait_for(|status| {
118                status.error.is_some()
119                    || status.finished
120                    || status.batches_written() > self.batches_read
121            })
122            .await
123            .map_err(|_| {
124                DataFusionError::Execution(
125                    "Spill has been dropped before reader has finish.".into(),
126                )
127            })?;
128
129        if let Some(error) = &status.error {
130            let mut guard = error.lock().ok().expect_ok()?;
131            return Err(DataFusionError::from(&mut (*guard)));
132        }
133
134        if let DataLocation::Buffered { batches } = &status.data_location {
135            Ok(Some(batches.clone()))
136        } else {
137            Ok(None)
138        }
139    }
140
141    async fn get_reader(&mut self) -> Result<&AsyncStreamReader, ArrowError> {
142        if let SpillReaderState::Buffered { spill_path } = &self.state {
143            let reader = AsyncStreamReader::open(spill_path.clone()).await?;
144            // Skip batches we've already read before the writer started spilling.
145            // The read batches were spilled to the file for the benefit of
146            // future readers, as the spill is replay-able.
147            for _ in 0..self.batches_read {
148                reader.read().await?;
149            }
150            self.state = SpillReaderState::Reader { reader };
151        }
152
153        if let SpillReaderState::Reader { reader } = &mut self.state {
154            Ok(reader)
155        } else {
156            unreachable!()
157        }
158    }
159
160    async fn read(&mut self) -> Result<Option<RecordBatch>, DataFusionError> {
161        let maybe_data = self.wait_for_more_data().await?;
162
163        if let Some(batches) = maybe_data {
164            if self.batches_read < batches.len() {
165                let batch = batches[self.batches_read].clone();
166                self.batches_read += 1;
167                Ok(Some(batch))
168            } else {
169                Ok(None)
170            }
171        } else {
172            let reader = self.get_reader().await?;
173            let batch = reader.read().await?;
174            if batch.is_some() {
175                self.batches_read += 1;
176            }
177            Ok(batch)
178        }
179    }
180}
181
/// The sender side of the spill. This is used to write batches to the spill.
///
/// Note: this must be kept alive until after the readers are done reading the
/// spill. Otherwise, they will return an error.
pub struct SpillSender {
    // Bytes of batch data buffered in memory before spilling to disk.
    memory_limit: usize,
    // Schema shared by all batches written to this spill.
    schema: Arc<Schema>,
    // Where the spill file is created once `memory_limit` is exceeded.
    path: PathBuf,
    state: SpillState,
    // Broadcasts progress so readers can wake when new data is available.
    status_sender: tokio::sync::watch::Sender<WriteStatus>,
}
193
/// Internal state machine for [`SpillSender`].
enum SpillState {
    // Batches are held in memory until `memory_limit` is exceeded.
    Buffering {
        batches: Vec<RecordBatch>,
        // Tracks the memory footprint of the buffered batches.
        memory_accumulator: MemoryAccumulator,
    },
    // The memory limit was hit; batches are appended to the spill file.
    Spilling {
        writer: AsyncStreamWriter,
        batches_written: usize,
    },
    // `finish()` was called; no further writes are accepted.
    Finished {
        // `Some` if the data never spilled and is still entirely in memory.
        batches: Option<Arc<[RecordBatch]>>,
        batches_written: usize,
    },
    // `send_error()` was called; the error is handed out to each reader.
    Errored {
        error: Arc<Mutex<SpillError>>,
    },
}
211
212impl Default for SpillState {
213    fn default() -> Self {
214        Self::Buffering {
215            batches: Vec::new(),
216            memory_accumulator: MemoryAccumulator::default(),
217        }
218    }
219}
220
/// Snapshot of the writer's progress, broadcast to readers via a watch channel.
#[derive(Clone, Debug, Default)]
struct WriteStatus {
    // Set once the writer reports an error; shared so the original error can
    // be handed out once, with string copies thereafter.
    error: Option<Arc<Mutex<SpillError>>>,
    // True once `finish()` or `send_error()` has been called.
    finished: bool,
    // Whether the written batches live in memory or in the spill file.
    data_location: DataLocation,
}
227
228impl WriteStatus {
229    fn batches_written(&self) -> usize {
230        match &self.data_location {
231            DataLocation::Buffered { batches } => batches.len(),
232            DataLocation::Spilled {
233                batches_written, ..
234            } => *batches_written,
235        }
236    }
237}
238
/// Where the spill's batches currently live.
#[derive(Clone, Debug)]
enum DataLocation {
    // Still in the writer's memory; readers clone directly from this buffer.
    Buffered { batches: Arc<[RecordBatch]> },
    // On disk; only the count is broadcast, readers open the file themselves.
    Spilled { batches_written: usize },
}
244
245impl Default for DataLocation {
246    fn default() -> Self {
247        Self::Buffered {
248            batches: Arc::new([]),
249        }
250    }
251}
252
/// A DataFusion error that can be emitted multiple times. We provide the
/// original error first, and subsequent conversions provide a copy with a
/// string representation of the original error.
#[derive(Debug)]
enum SpillError {
    // The error exactly as the writer reported it; handed out once.
    Original(DataFusionError),
    // Stringified stand-in, handed out on every conversion after the first.
    Copy(DataFusionError),
}
261
262impl From<DataFusionError> for SpillError {
263    fn from(err: DataFusionError) -> Self {
264        Self::Original(err)
265    }
266}
267
268impl From<&mut SpillError> for DataFusionError {
269    fn from(err: &mut SpillError) -> Self {
270        match err {
271            SpillError::Original(inner) => {
272                let copy = Self::Execution(inner.to_string());
273                let original = std::mem::replace(err, SpillError::Copy(copy));
274                if let SpillError::Original(inner) = original {
275                    inner
276                } else {
277                    unreachable!()
278                }
279            }
280            SpillError::Copy(Self::Execution(message)) => Self::Execution(message.clone()),
281            _ => unreachable!(),
282        }
283    }
284}
285
286impl From<&SpillState> for WriteStatus {
287    fn from(state: &SpillState) -> Self {
288        match state {
289            SpillState::Buffering { batches, .. } => Self {
290                finished: false,
291                data_location: DataLocation::Buffered {
292                    batches: batches.clone().into(),
293                },
294                error: None,
295            },
296            SpillState::Spilling {
297                batches_written, ..
298            } => Self {
299                finished: false,
300                data_location: DataLocation::Spilled {
301                    batches_written: *batches_written,
302                },
303                error: None,
304            },
305            SpillState::Finished {
306                batches_written,
307                batches,
308            } => {
309                let data_location = if let Some(batches) = batches {
310                    DataLocation::Buffered {
311                        batches: batches.clone(),
312                    }
313                } else {
314                    DataLocation::Spilled {
315                        batches_written: *batches_written,
316                    }
317                };
318                Self {
319                    finished: true,
320                    data_location,
321                    error: None,
322                }
323            }
324            SpillState::Errored { error } => Self {
325                finished: true,
326                data_location: DataLocation::default(), // Doesn't matter.
327                error: Some(error.clone()),
328            },
329        }
330    }
331}
332
impl SpillSender {
    /// Write a batch to the spill.
    ///
    /// If there is room within `memory_limit`, the batch is queued in memory.
    /// If this batch first exceeds `memory_limit`, all queued batches, and
    /// this one, are written to disk as part of this call. If we are already
    /// spilling, the batch is written to disk as part of this call.
    pub async fn write(&mut self, batch: RecordBatch) -> Result<(), DataFusionError> {
        if let SpillState::Finished { .. } = self.state {
            return Err(DataFusionError::Execution(
                "Spill has already been finished".to_string(),
            ));
        }

        if let SpillState::Errored { .. } = &self.state {
            return Err(DataFusionError::Execution(
                "Spill has sent an error".to_string(),
            ));
        }

        let (writer, batches_written) = match &mut self.state {
            SpillState::Buffering {
                batches,
                ref mut memory_accumulator,
            } => {
                memory_accumulator.record_batch(&batch);

                if memory_accumulator.total() > self.memory_limit {
                    // Memory limit exceeded: transition to spilling by flushing
                    // all previously buffered batches to a new IPC stream file.
                    let writer =
                        AsyncStreamWriter::open(self.path.clone(), self.schema.clone()).await?;
                    let batches_written = batches.len();
                    for batch in batches.drain(..) {
                        writer.write(batch).await?;
                    }
                    self.state = SpillState::Spilling {
                        writer,
                        batches_written,
                    };
                    // Re-borrow out of the state we just stored so the common
                    // write path below can append the current batch.
                    if let SpillState::Spilling {
                        writer,
                        batches_written,
                    } = &mut self.state
                    {
                        (writer, batches_written)
                    } else {
                        unreachable!()
                    }
                } else {
                    // Still under the limit: keep the batch in memory and
                    // notify readers of the new buffered data.
                    batches.push(batch);
                    self.status_sender
                        .send_replace(WriteStatus::from(&self.state));
                    return Ok(());
                }
            }
            SpillState::Spilling {
                writer,
                batches_written,
            } => (writer, batches_written),
            // Finished / Errored were rejected by the guards above.
            _ => unreachable!(),
        };

        writer.write(batch).await?;
        *batches_written += 1;
        self.status_sender
            .send_replace(WriteStatus::from(&self.state));

        Ok(())
    }

    /// Send an error to the spill. This will be sent to all readers of the
    /// spill.
    pub fn send_error(&mut self, err: DataFusionError) {
        let error = Arc::new(Mutex::new(err.into()));
        self.state = SpillState::Errored { error };
        self.status_sender
            .send_replace(WriteStatus::from(&self.state));
    }

    /// Complete the spill write. This will finalize the Arrow IPC stream file.
    /// The file will remain available for reading until the spill is dropped.
    pub async fn finish(&mut self) -> Result<(), DataFusionError> {
        // We create a temporary state to get an owned copy of current state.
        // Since we hold an exclusive reference to `self`, no one should be
        // able to see this temporary state.
        let tmp_state = SpillState::Finished {
            batches_written: 0,
            batches: None,
        };
        match std::mem::replace(&mut self.state, tmp_state) {
            SpillState::Buffering { batches, .. } => {
                // Never exceeded the memory limit: the data stays in memory
                // and no file is ever created.
                let batches_written = batches.len();
                self.state = SpillState::Finished {
                    batches_written,
                    batches: Some(batches.into()),
                };
                self.status_sender
                    .send_replace(WriteStatus::from(&self.state));
            }
            SpillState::Spilling {
                writer,
                batches_written,
            } => {
                // NOTE(review): if `writer.finish()` fails, `self.state` is
                // left as the placeholder Finished state — confirm that
                // rejecting further writes in that case is intended.
                writer.finish().await?;
                self.state = SpillState::Finished {
                    batches_written,
                    batches: None,
                };
                self.status_sender
                    .send_replace(WriteStatus::from(&self.state));
            }
            // NOTE(review): in both arms below, the prior state has already
            // been replaced by the placeholder above, so a later call would
            // report "already been finished" instead. Readers are unaffected
            // because no new status is sent; confirm this is acceptable.
            SpillState::Finished { .. } => {
                return Err(DataFusionError::Execution(
                    "Spill has already been finished".to_string(),
                ));
            }
            SpillState::Errored { .. } => {
                return Err(DataFusionError::Execution(
                    "Spill has sent an error".to_string(),
                ));
            }
        };

        Ok(())
    }
}
459
460/// An async wrapper around [`StreamWriter`]. Each call uses [`tokio::task::spawn_blocking`]
461/// to spawn a blocking task to write the batch.
462struct AsyncStreamWriter {
463    writer: Arc<Mutex<StreamWriter<BufWriter<std::fs::File>>>>,
464}
465
466impl AsyncStreamWriter {
467    pub async fn open(path: PathBuf, schema: Arc<Schema>) -> Result<Self, ArrowError> {
468        let writer = tokio::task::spawn_blocking(move || {
469            let file = std::fs::File::create(&path).map_err(ArrowError::from)?;
470            let writer = BufWriter::new(file);
471            StreamWriter::try_new(writer, &schema)
472        })
473        .await
474        .unwrap()?;
475        let writer = Arc::new(Mutex::new(writer));
476        Ok(Self { writer })
477    }
478
479    pub async fn write(&self, batch: RecordBatch) -> Result<(), ArrowError> {
480        let writer = self.writer.clone();
481        tokio::task::spawn_blocking(move || {
482            let mut writer = writer.lock().unwrap();
483            writer.write(&batch)?;
484            writer.flush()
485        })
486        .await
487        .unwrap()
488    }
489
490    pub async fn finish(self) -> Result<(), ArrowError> {
491        let writer = self.writer.clone();
492        tokio::task::spawn_blocking(move || {
493            let mut writer = writer.lock().unwrap();
494            writer.finish()
495        })
496        .await
497        .unwrap()
498    }
499}
500
501struct AsyncStreamReader {
502    reader: Arc<Mutex<StreamReader<BufReader<std::fs::File>>>>,
503}
504
505impl AsyncStreamReader {
506    pub async fn open(path: PathBuf) -> Result<Self, ArrowError> {
507        let reader = tokio::task::spawn_blocking(move || {
508            let file = std::fs::File::open(&path).map_err(ArrowError::from)?;
509            let reader = BufReader::new(file);
510            StreamReader::try_new(reader, None)
511        })
512        .await
513        .unwrap()?;
514        let reader = Arc::new(Mutex::new(reader));
515        Ok(Self { reader })
516    }
517
518    pub async fn read(&self) -> Result<Option<RecordBatch>, ArrowError> {
519        let reader = self.reader.clone();
520        tokio::task::spawn_blocking(move || {
521            let mut reader = reader.lock().unwrap();
522            reader.next()
523        })
524        .await
525        .unwrap()
526        .transpose()
527    }
528}
529
// Tests cover all reader timings (before, during, after writes), error
// propagation, fully-buffered spills, and the buffered-to-disk transition.
#[cfg(test)]
mod tests {
    use arrow_array::Int32Array;
    use arrow_schema::{DataType, Field};
    use futures::{poll, StreamExt, TryStreamExt};
    use lance_core::utils::tempfile::{TempStdFile, TempStdPath};

    use super::*;

    #[tokio::test]
    async fn test_spill() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batches = [
            RecordBatch::try_new(
                schema.clone(),
                vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
            )
            .unwrap(),
            RecordBatch::try_new(
                schema.clone(),
                vec![Arc::new(Int32Array::from(vec![4, 5, 6]))],
            )
            .unwrap(),
        ];

        // Create a stream. A memory limit of 0 forces a spill on first write.
        let path = TempStdFile::default();
        let (mut spill, receiver) = create_replay_spill(path.to_owned(), schema.clone(), 0);

        // We can open a reader prior to writing any data. No batches will be ready.
        let mut stream_before = receiver.read();
        let mut stream_before_next = stream_before.next();
        let poll_res = poll!(&mut stream_before_next);
        assert!(poll_res.is_pending());

        // If we write a batch, the existing reader can now receive it.
        spill.write(batches[0].clone()).await.unwrap();
        let stream_before_batch1 = stream_before_next
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_before_batch1, &batches[0]);
        let mut stream_before_next = stream_before.next();
        let poll_res = poll!(&mut stream_before_next);
        assert!(poll_res.is_pending());

        // We can also open a reader while the spill is being written to. We can
        // retrieve batches written so far immediately.
        let mut stream_during = receiver.read();
        let stream_during_batch1 = stream_during
            .next()
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_during_batch1, &batches[0]);
        let mut stream_during_next = stream_during.next();
        let poll_res = poll!(&mut stream_during_next);
        assert!(poll_res.is_pending());

        // Once we finish the spill, readers can get remaining batches and will
        // reach the end of the stream.
        spill.write(batches[1].clone()).await.unwrap();
        spill.finish().await.unwrap();

        let stream_before_batch2 = stream_before_next
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_before_batch2, &batches[1]);
        assert!(stream_before.next().await.is_none());

        let stream_during_batch2 = stream_during_next
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_during_batch2, &batches[1]);
        assert!(stream_during.next().await.is_none());

        // Can also start a reader after finishing.
        let stream_after = receiver.read();
        let stream_after_batches = stream_after.try_collect::<Vec<_>>().await.unwrap();
        assert_eq!(&stream_after_batches, &batches);

        std::fs::remove_file(path).unwrap();
    }

    #[tokio::test]
    async fn test_spill_error() {
        // Create a spill
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let path = TempStdFile::default();
        let (mut spill, receiver) =
            create_replay_spill(path.as_ref().to_owned(), schema.clone(), 0);
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )
        .unwrap();

        spill.write(batch.clone()).await.unwrap();

        let mut stream = receiver.read();
        let stream_batch = stream
            .next()
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_batch, &batch);

        spill.send_error(DataFusionError::ResourcesExhausted("🥱".into()));
        // The first reader receives the original error variant.
        let stream_error = stream
            .next()
            .await
            .expect("Expected an error")
            .expect_err("Expected an error");
        assert!(matches!(
            stream_error,
            DataFusionError::ResourcesExhausted(message) if message == "🥱"
        ));

        // If we try to write after sending an error, it should return an error.
        let err = spill.write(batch).await;
        assert!(matches!(
            err,
            Err(DataFusionError::Execution(message)) if message == "Spill has sent an error"
        ));

        // If we try to finish after sending an error, it should return an error.
        let err = spill.finish().await;
        assert!(matches!(
            err,
            Err(DataFusionError::Execution(message)) if message == "Spill has sent an error"
        ));

        // If we try to read after sending an error, it should return an error.
        // Subsequent readers get a stringified copy of the original error.
        let mut stream = receiver.read();
        let stream_error = stream
            .next()
            .await
            .expect("Expected an error")
            .expect_err("Expected an error");
        assert!(matches!(
            stream_error,
            DataFusionError::Execution(message) if message.contains("🥱")
        ));

        std::fs::remove_file(path).unwrap();
    }

    #[tokio::test]
    async fn test_spill_buffered() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let path = TempStdPath::default();
        let memory_limit = 1024 * 1024; // 1 MiB
        let (mut spill, receiver) = create_replay_spill(path.clone(), schema.clone(), memory_limit);

        // 0.5 MB batch
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1; (512 * 1024) / 4]))],
        )
        .unwrap();
        spill.write(batch.clone()).await.unwrap();
        assert!(!std::fs::exists(&path).unwrap());

        spill.finish().await.unwrap();
        assert!(!std::fs::exists(&path).unwrap());

        let mut stream = receiver.read();
        let stream_batch = stream
            .next()
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_batch, &batch);

        // The data stayed under the limit, so no file was ever created.
        assert!(!std::fs::exists(&path).unwrap());
    }

    #[tokio::test]
    async fn test_spill_buffered_transition() {
        // Starts as buffered, then spills, then finished.
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let path = TempStdPath::default();
        let memory_limit = 1024 * 1024; // 1 MiB
        let (mut spill, receiver) = create_replay_spill(path.clone(), schema.clone(), memory_limit);

        // 0.7 MB batch
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1; (768 * 1024) / 4]))],
        )
        .unwrap();
        spill.write(batch.clone()).await.unwrap();
        assert!(!std::fs::exists(&path).unwrap());

        let mut stream = receiver.read();
        let stream_batch = stream
            .next()
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_batch, &batch);
        assert!(!std::fs::exists(&path).unwrap());

        // 0.5 MB batch — pushes the total over the 1 MiB limit, so the
        // spill file should now exist.
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1; (512 * 1024) / 4]))],
        )
        .unwrap();
        spill.write(batch.clone()).await.unwrap();
        assert!(std::fs::exists(&path).unwrap());

        let stream_batch = stream
            .next()
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_batch, &batch);
        assert!(std::fs::exists(&path).unwrap());

        spill.finish().await.unwrap();

        assert!(stream.next().await.is_none());

        std::fs::remove_file(path).unwrap();
    }
}