1use std::{
5 io::{BufReader, BufWriter},
6 path::PathBuf,
7 sync::{Arc, Mutex},
8};
9
10use arrow::ipc::{reader::StreamReader, writer::StreamWriter};
11use arrow_array::RecordBatch;
12use arrow_schema::{ArrowError, Schema};
13use datafusion::{
14 execution::SendableRecordBatchStream, physical_plan::stream::RecordBatchStreamAdapter,
15};
16use datafusion_common::DataFusionError;
17use lance_arrow::memory::MemoryAccumulator;
18use lance_core::error::LanceOptionExt;
19
20pub fn create_replay_spill(
40 path: std::path::PathBuf,
41 schema: Arc<Schema>,
42 memory_limit: usize,
43) -> (SpillSender, SpillReceiver) {
44 let initial_status = WriteStatus::default();
45 let (status_sender, status_receiver) = tokio::sync::watch::channel(initial_status);
46 let sender = SpillSender {
47 memory_limit,
48 path: path.clone(),
49 schema: schema.clone(),
50 state: SpillState::default(),
51 status_sender,
52 };
53
54 let receiver = SpillReceiver {
55 status_receiver,
56 path,
57 schema,
58 };
59
60 (sender, receiver)
61}
62
/// Read half of a replay spill. Cloneable: each clone can independently open
/// replay streams via [`SpillReceiver::read`].
#[derive(Clone)]
pub struct SpillReceiver {
    // Observes the writer's progress so streams know when new batches exist.
    status_receiver: tokio::sync::watch::Receiver<WriteStatus>,
    // Location of the spill file once the writer starts spilling to disk.
    path: PathBuf,
    schema: Arc<Schema>,
}
69
70impl SpillReceiver {
71 pub fn read(&self) -> SendableRecordBatchStream {
79 let rx = self.status_receiver.clone();
80 let reader = SpillReader::new(rx, self.path.clone());
81
82 let stream = futures::stream::try_unfold(reader, move |mut reader| async move {
83 match reader.read().await {
84 Ok(None) => Ok(None),
85 Ok(Some(batch)) => Ok(Some((batch, reader))),
86 Err(err) => Err(err),
87 }
88 });
89
90 Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))
91 }
92}
93
/// Cursor over a spill's batches. Tracks how many batches have been yielded
/// so it can read from the in-memory buffer or replay the spill file from
/// the correct position.
struct SpillReader {
    // Number of batches already yielded to the stream.
    pub batches_read: usize,
    receiver: tokio::sync::watch::Receiver<WriteStatus>,
    state: SpillReaderState,
}
99
/// Where this reader currently sources its batches from.
enum SpillReaderState {
    /// The writer still buffers batches in memory; `spill_path` is retained
    /// so a file reader can be opened if the writer starts spilling.
    Buffered { spill_path: PathBuf },
    /// The writer has spilled to disk; batches come from the IPC file.
    Reader { reader: AsyncStreamReader },
}
104
105impl SpillReader {
106 fn new(receiver: tokio::sync::watch::Receiver<WriteStatus>, spill_path: PathBuf) -> Self {
107 Self {
108 batches_read: 0,
109 receiver,
110 state: SpillReaderState::Buffered { spill_path },
111 }
112 }
113
114 async fn wait_for_more_data(&mut self) -> Result<Option<Arc<[RecordBatch]>>, DataFusionError> {
115 let status = self
116 .receiver
117 .wait_for(|status| {
118 status.error.is_some()
119 || status.finished
120 || status.batches_written() > self.batches_read
121 })
122 .await
123 .map_err(|_| {
124 DataFusionError::Execution(
125 "Spill has been dropped before reader has finish.".into(),
126 )
127 })?;
128
129 if let Some(error) = &status.error {
130 let mut guard = error.lock().ok().expect_ok()?;
131 return Err(DataFusionError::from(&mut (*guard)));
132 }
133
134 if let DataLocation::Buffered { batches } = &status.data_location {
135 Ok(Some(batches.clone()))
136 } else {
137 Ok(None)
138 }
139 }
140
141 async fn get_reader(&mut self) -> Result<&AsyncStreamReader, ArrowError> {
142 if let SpillReaderState::Buffered { spill_path } = &self.state {
143 let reader = AsyncStreamReader::open(spill_path.clone()).await?;
144 for _ in 0..self.batches_read {
148 reader.read().await?;
149 }
150 self.state = SpillReaderState::Reader { reader };
151 }
152
153 if let SpillReaderState::Reader { reader } = &mut self.state {
154 Ok(reader)
155 } else {
156 unreachable!()
157 }
158 }
159
160 async fn read(&mut self) -> Result<Option<RecordBatch>, DataFusionError> {
161 let maybe_data = self.wait_for_more_data().await?;
162
163 if let Some(batches) = maybe_data {
164 if self.batches_read < batches.len() {
165 let batch = batches[self.batches_read].clone();
166 self.batches_read += 1;
167 Ok(Some(batch))
168 } else {
169 Ok(None)
170 }
171 } else {
172 let reader = self.get_reader().await?;
173 let batch = reader.read().await?;
174 if batch.is_some() {
175 self.batches_read += 1;
176 }
177 Ok(batch)
178 }
179 }
180}
181
/// Write half of a replay spill. Owns the spill state machine and broadcasts
/// write progress to readers through a watch channel.
pub struct SpillSender {
    // Once accumulated batch memory exceeds this many bytes, all batches are
    // flushed to the IPC file at `path` and subsequent writes go to disk.
    memory_limit: usize,
    schema: Arc<Schema>,
    path: PathBuf,
    state: SpillState,
    status_sender: tokio::sync::watch::Sender<WriteStatus>,
}
193
/// Lifecycle of the spill from the writer's perspective.
enum SpillState {
    /// Batches are held in memory; `memory_accumulator` tracks their total
    /// size against the configured limit.
    Buffering {
        batches: Vec<RecordBatch>,
        memory_accumulator: MemoryAccumulator,
    },
    /// The memory limit was exceeded; batches are appended to the IPC file.
    Spilling {
        writer: AsyncStreamWriter,
        batches_written: usize,
    },
    /// `finish` was called. `batches` is `Some` if the data never left
    /// memory, `None` if it was spilled to disk.
    Finished {
        batches: Option<Arc<[RecordBatch]>>,
        batches_written: usize,
    },
    /// The writer reported an error, shared with every reader.
    Errored {
        error: Arc<Mutex<SpillError>>,
    },
}
211
212impl Default for SpillState {
213 fn default() -> Self {
214 Self::Buffering {
215 batches: Vec::new(),
216 memory_accumulator: MemoryAccumulator::default(),
217 }
218 }
219}
220
/// Snapshot of the writer's progress, broadcast to readers over the watch
/// channel after every state change.
#[derive(Clone, Debug, Default)]
struct WriteStatus {
    // Set when the writer reported an error via `send_error`.
    error: Option<Arc<Mutex<SpillError>>>,
    // True once the writer has finished or errored.
    finished: bool,
    data_location: DataLocation,
}
227
228impl WriteStatus {
229 fn batches_written(&self) -> usize {
230 match &self.data_location {
231 DataLocation::Buffered { batches } => batches.len(),
232 DataLocation::Spilled {
233 batches_written, ..
234 } => *batches_written,
235 }
236 }
237}
238
/// Where the written batches currently live.
#[derive(Clone, Debug)]
enum DataLocation {
    /// Still in the writer's memory; readers clone batches directly from
    /// this shared slice.
    Buffered { batches: Arc<[RecordBatch]> },
    /// Spilled to the IPC file; only the count is shared and readers open
    /// the file themselves.
    Spilled { batches_written: usize },
}
244
245impl Default for DataLocation {
246 fn default() -> Self {
247 Self::Buffered {
248 batches: Arc::new([]),
249 }
250 }
251}
252
/// An error reported by the spill writer, shared with all readers.
///
/// `DataFusionError` is not `Clone`, so the first reader extracts the
/// original and leaves behind a message-only copy for later readers.
#[derive(Debug)]
enum SpillError {
    /// The original error, available to the first reader that takes it.
    Original(DataFusionError),
    /// A copy reconstructed from the original's message after it was taken.
    Copy(DataFusionError),
}
261
262impl From<DataFusionError> for SpillError {
263 fn from(err: DataFusionError) -> Self {
264 Self::Original(err)
265 }
266}
267
/// Extracts a [`DataFusionError`] from a shared [`SpillError`] slot.
///
/// The first caller takes the original error, replacing it in place with a
/// `Copy` built from its message; every later caller receives a clone of
/// that copy. This works around `DataFusionError` not implementing `Clone`.
impl From<&mut SpillError> for DataFusionError {
    fn from(err: &mut SpillError) -> Self {
        match err {
            SpillError::Original(inner) => {
                // Build the message-only replacement first, then swap it in
                // with `mem::replace` so we can move the original out.
                let copy = Self::Execution(inner.to_string());
                let original = std::mem::replace(err, SpillError::Copy(copy));
                if let SpillError::Original(inner) = original {
                    inner
                } else {
                    // We matched `Original` above; the swap can't change that.
                    unreachable!()
                }
            }
            SpillError::Copy(Self::Execution(message)) => Self::Execution(message.clone()),
            // A `Copy` is only ever constructed with the `Execution` variant
            // (see above), so no other shape can appear here.
            _ => unreachable!(),
        }
    }
}
285
286impl From<&SpillState> for WriteStatus {
287 fn from(state: &SpillState) -> Self {
288 match state {
289 SpillState::Buffering { batches, .. } => Self {
290 finished: false,
291 data_location: DataLocation::Buffered {
292 batches: batches.clone().into(),
293 },
294 error: None,
295 },
296 SpillState::Spilling {
297 batches_written, ..
298 } => Self {
299 finished: false,
300 data_location: DataLocation::Spilled {
301 batches_written: *batches_written,
302 },
303 error: None,
304 },
305 SpillState::Finished {
306 batches_written,
307 batches,
308 } => {
309 let data_location = if let Some(batches) = batches {
310 DataLocation::Buffered {
311 batches: batches.clone(),
312 }
313 } else {
314 DataLocation::Spilled {
315 batches_written: *batches_written,
316 }
317 };
318 Self {
319 finished: true,
320 data_location,
321 error: None,
322 }
323 }
324 SpillState::Errored { error } => Self {
325 finished: true,
326 data_location: DataLocation::default(), error: Some(error.clone()),
328 },
329 }
330 }
331}
332
333impl SpillSender {
334 pub async fn write(&mut self, batch: RecordBatch) -> Result<(), DataFusionError> {
342 if let SpillState::Finished { .. } = self.state {
343 return Err(DataFusionError::Execution(
344 "Spill has already been finished".to_string(),
345 ));
346 }
347
348 if let SpillState::Errored { .. } = &self.state {
349 return Err(DataFusionError::Execution(
350 "Spill has sent an error".to_string(),
351 ));
352 }
353
354 let (writer, batches_written) = match &mut self.state {
355 SpillState::Buffering {
356 batches,
357 ref mut memory_accumulator,
358 } => {
359 memory_accumulator.record_batch(&batch);
360
361 if memory_accumulator.total() > self.memory_limit {
362 let writer =
363 AsyncStreamWriter::open(self.path.clone(), self.schema.clone()).await?;
364 let batches_written = batches.len();
365 for batch in batches.drain(..) {
366 writer.write(batch).await?;
367 }
368 self.state = SpillState::Spilling {
369 writer,
370 batches_written,
371 };
372 if let SpillState::Spilling {
373 writer,
374 batches_written,
375 } = &mut self.state
376 {
377 (writer, batches_written)
378 } else {
379 unreachable!()
380 }
381 } else {
382 batches.push(batch);
383 self.status_sender
384 .send_replace(WriteStatus::from(&self.state));
385 return Ok(());
386 }
387 }
388 SpillState::Spilling {
389 writer,
390 batches_written,
391 } => (writer, batches_written),
392 _ => unreachable!(),
393 };
394
395 writer.write(batch).await?;
396 *batches_written += 1;
397 self.status_sender
398 .send_replace(WriteStatus::from(&self.state));
399
400 Ok(())
401 }
402
403 pub fn send_error(&mut self, err: DataFusionError) {
406 let error = Arc::new(Mutex::new(err.into()));
407 self.state = SpillState::Errored { error };
408 self.status_sender
409 .send_replace(WriteStatus::from(&self.state));
410 }
411
412 pub async fn finish(&mut self) -> Result<(), DataFusionError> {
415 let tmp_state = SpillState::Finished {
419 batches_written: 0,
420 batches: None,
421 };
422 match std::mem::replace(&mut self.state, tmp_state) {
423 SpillState::Buffering { batches, .. } => {
424 let batches_written = batches.len();
425 self.state = SpillState::Finished {
426 batches_written,
427 batches: Some(batches.into()),
428 };
429 self.status_sender
430 .send_replace(WriteStatus::from(&self.state));
431 }
432 SpillState::Spilling {
433 writer,
434 batches_written,
435 } => {
436 writer.finish().await?;
437 self.state = SpillState::Finished {
438 batches_written,
439 batches: None,
440 };
441 self.status_sender
442 .send_replace(WriteStatus::from(&self.state));
443 }
444 SpillState::Finished { .. } => {
445 return Err(DataFusionError::Execution(
446 "Spill has already been finished".to_string(),
447 ));
448 }
449 SpillState::Errored { .. } => {
450 return Err(DataFusionError::Execution(
451 "Spill has sent an error".to_string(),
452 ));
453 }
454 };
455
456 Ok(())
457 }
458}
459
/// Wraps a synchronous Arrow IPC [`StreamWriter`] so it can be driven from
/// async code; each operation runs on tokio's blocking thread pool.
struct AsyncStreamWriter {
    // Arc<Mutex<..>> so the writer can be moved into `spawn_blocking`
    // closures, which require `'static` ownership.
    writer: Arc<Mutex<StreamWriter<BufWriter<std::fs::File>>>>,
}
465
466impl AsyncStreamWriter {
467 pub async fn open(path: PathBuf, schema: Arc<Schema>) -> Result<Self, ArrowError> {
468 let writer = tokio::task::spawn_blocking(move || {
469 let file = std::fs::File::create(&path).map_err(ArrowError::from)?;
470 let writer = BufWriter::new(file);
471 StreamWriter::try_new(writer, &schema)
472 })
473 .await
474 .unwrap()?;
475 let writer = Arc::new(Mutex::new(writer));
476 Ok(Self { writer })
477 }
478
479 pub async fn write(&self, batch: RecordBatch) -> Result<(), ArrowError> {
480 let writer = self.writer.clone();
481 tokio::task::spawn_blocking(move || {
482 let mut writer = writer.lock().unwrap();
483 writer.write(&batch)?;
484 writer.flush()
485 })
486 .await
487 .unwrap()
488 }
489
490 pub async fn finish(self) -> Result<(), ArrowError> {
491 let writer = self.writer.clone();
492 tokio::task::spawn_blocking(move || {
493 let mut writer = writer.lock().unwrap();
494 writer.finish()
495 })
496 .await
497 .unwrap()
498 }
499}
500
/// Wraps a synchronous Arrow IPC [`StreamReader`] so it can be driven from
/// async code; each read runs on tokio's blocking thread pool.
struct AsyncStreamReader {
    // Arc<Mutex<..>> so the reader can be moved into `spawn_blocking`
    // closures, which require `'static` ownership.
    reader: Arc<Mutex<StreamReader<BufReader<std::fs::File>>>>,
}
504
505impl AsyncStreamReader {
506 pub async fn open(path: PathBuf) -> Result<Self, ArrowError> {
507 let reader = tokio::task::spawn_blocking(move || {
508 let file = std::fs::File::open(&path).map_err(ArrowError::from)?;
509 let reader = BufReader::new(file);
510 StreamReader::try_new(reader, None)
511 })
512 .await
513 .unwrap()?;
514 let reader = Arc::new(Mutex::new(reader));
515 Ok(Self { reader })
516 }
517
518 pub async fn read(&self) -> Result<Option<RecordBatch>, ArrowError> {
519 let reader = self.reader.clone();
520 tokio::task::spawn_blocking(move || {
521 let mut reader = reader.lock().unwrap();
522 reader.next()
523 })
524 .await
525 .unwrap()
526 .transpose()
527 }
528}
529
#[cfg(test)]
mod tests {
    use arrow_array::Int32Array;
    use arrow_schema::{DataType, Field};
    use futures::{poll, StreamExt, TryStreamExt};
    use lance_core::utils::tempfile::{TempStdFile, TempStdPath};

    use super::*;

    /// With a memory limit of zero every write spills to disk. Verifies that
    /// streams opened before any writes, mid-write, and after finishing all
    /// replay the same sequence of batches.
    #[tokio::test]
    async fn test_spill() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batches = [
            RecordBatch::try_new(
                schema.clone(),
                vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
            )
            .unwrap(),
            RecordBatch::try_new(
                schema.clone(),
                vec![Arc::new(Int32Array::from(vec![4, 5, 6]))],
            )
            .unwrap(),
        ];

        let path = TempStdFile::default();
        let (mut spill, receiver) = create_replay_spill(path.to_owned(), schema.clone(), 0);

        // A stream opened before any writes should be pending.
        let mut stream_before = receiver.read();
        let mut stream_before_next = stream_before.next();
        let poll_res = poll!(&mut stream_before_next);
        assert!(poll_res.is_pending());

        // After the first write, the pending future resolves to that batch.
        spill.write(batches[0].clone()).await.unwrap();
        let stream_before_batch1 = stream_before_next
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_before_batch1, &batches[0]);
        let mut stream_before_next = stream_before.next();
        let poll_res = poll!(&mut stream_before_next);
        assert!(poll_res.is_pending());

        // A stream opened mid-write replays the batch already written, then
        // blocks waiting for more data.
        let mut stream_during = receiver.read();
        let stream_during_batch1 = stream_during
            .next()
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_during_batch1, &batches[0]);
        let mut stream_during_next = stream_during.next();
        let poll_res = poll!(&mut stream_during_next);
        assert!(poll_res.is_pending());

        spill.write(batches[1].clone()).await.unwrap();
        spill.finish().await.unwrap();

        // Both open streams see the second batch and then terminate.
        let stream_before_batch2 = stream_before_next
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_before_batch2, &batches[1]);
        assert!(stream_before.next().await.is_none());

        let stream_during_batch2 = stream_during_next
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_during_batch2, &batches[1]);
        assert!(stream_during.next().await.is_none());

        // A stream opened after finishing replays everything at once.
        let stream_after = receiver.read();
        let stream_after_batches = stream_after.try_collect::<Vec<_>>().await.unwrap();
        assert_eq!(&stream_after_batches, &batches);

        std::fs::remove_file(path).unwrap();
    }

    /// Errors reported by the writer surface on all current and future
    /// reader streams, and reject further `write` / `finish` calls.
    #[tokio::test]
    async fn test_spill_error() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let path = TempStdFile::default();
        let (mut spill, receiver) =
            create_replay_spill(path.as_ref().to_owned(), schema.clone(), 0);
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )
        .unwrap();

        spill.write(batch.clone()).await.unwrap();

        let mut stream = receiver.read();
        let stream_batch = stream
            .next()
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_batch, &batch);

        // The first reader to observe the error gets the original variant.
        spill.send_error(DataFusionError::ResourcesExhausted("🥱".into()));
        let stream_error = stream
            .next()
            .await
            .expect("Expected an error")
            .expect_err("Expected an error");
        assert!(matches!(
            stream_error,
            DataFusionError::ResourcesExhausted(message) if message == "🥱"
        ));

        let err = spill.write(batch).await;
        assert!(matches!(
            err,
            Err(DataFusionError::Execution(message)) if message == "Spill has sent an error"
        ));

        let err = spill.finish().await;
        assert!(matches!(
            err,
            Err(DataFusionError::Execution(message)) if message == "Spill has sent an error"
        ));

        // Later readers get a message-only copy of the original error.
        let mut stream = receiver.read();
        let stream_error = stream
            .next()
            .await
            .expect("Expected an error")
            .expect_err("Expected an error");
        assert!(matches!(
            stream_error,
            DataFusionError::Execution(message) if message.contains("🥱")
        ));

        std::fs::remove_file(path).unwrap();
    }

    /// When writes stay under the memory limit, no spill file is created and
    /// batches are served entirely from memory.
    #[tokio::test]
    async fn test_spill_buffered() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let path = TempStdPath::default();
        let memory_limit = 1024 * 1024;
        let (mut spill, receiver) = create_replay_spill(path.clone(), schema.clone(), memory_limit);

        // A ~512 KiB batch stays below the 1 MiB limit.
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1; (512 * 1024) / 4]))],
        )
        .unwrap();
        spill.write(batch.clone()).await.unwrap();
        assert!(!std::fs::exists(&path).unwrap());

        spill.finish().await.unwrap();
        assert!(!std::fs::exists(&path).unwrap());

        let mut stream = receiver.read();
        let stream_batch = stream
            .next()
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_batch, &batch);

        assert!(!std::fs::exists(&path).unwrap());
    }

    /// Crossing the memory limit mid-stream transitions the spill from the
    /// in-memory buffer to the file without dropping or reordering batches.
    #[tokio::test]
    async fn test_spill_buffered_transition() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let path = TempStdPath::default();
        let memory_limit = 1024 * 1024;
        let (mut spill, receiver) = create_replay_spill(path.clone(), schema.clone(), memory_limit);

        // First batch (~768 KiB) fits within the limit; no file yet.
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1; (768 * 1024) / 4]))],
        )
        .unwrap();
        spill.write(batch.clone()).await.unwrap();
        assert!(!std::fs::exists(&path).unwrap());

        let mut stream = receiver.read();
        let stream_batch = stream
            .next()
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_batch, &batch);
        assert!(!std::fs::exists(&path).unwrap());

        // Second batch pushes the total over the limit; the spill file
        // appears and subsequent batches are read from disk.
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1; (512 * 1024) / 4]))],
        )
        .unwrap();
        spill.write(batch.clone()).await.unwrap();
        assert!(std::fs::exists(&path).unwrap());

        let stream_batch = stream
            .next()
            .await
            .expect("Expected a batch")
            .expect("Expected no error");
        assert_eq!(&stream_batch, &batch);
        assert!(std::fs::exists(&path).unwrap());

        spill.finish().await.unwrap();

        assert!(stream.next().await.is_none());

        std::fs::remove_file(path).unwrap();
    }
}