1use std::{
5 io::{BufReader, BufWriter},
6 path::PathBuf,
7 sync::{Arc, Mutex},
8};
9
10use arrow::ipc::{reader::StreamReader, writer::StreamWriter};
11use arrow_array::RecordBatch;
12use arrow_schema::{ArrowError, Schema};
13use datafusion::{
14 execution::SendableRecordBatchStream, physical_plan::stream::RecordBatchStreamAdapter,
15};
16use datafusion_common::DataFusionError;
17use lance_arrow::memory::MemoryAccumulator;
18use lance_core::error::LanceOptionExt;
19
20pub fn create_replay_spill(
40 path: std::path::PathBuf,
41 schema: Arc<Schema>,
42 memory_limit: usize,
43) -> (SpillSender, SpillReceiver) {
44 let initial_status = WriteStatus::default();
45 let (status_sender, status_receiver) = tokio::sync::watch::channel(initial_status);
46 let sender = SpillSender {
47 memory_limit,
48 path: path.clone(),
49 schema: schema.clone(),
50 state: SpillState::default(),
51 status_sender,
52 };
53
54 let receiver = SpillReceiver {
55 status_receiver,
56 path,
57 schema,
58 };
59
60 (sender, receiver)
61}
62
/// Read half of a replay spill.
///
/// The receiver is cloneable, and every call to `read` builds an independent
/// stream that starts from the first batch.
#[derive(Clone)]
pub struct SpillReceiver {
    /// Watch-channel receiver carrying the sender's most recent [`WriteStatus`].
    status_receiver: tokio::sync::watch::Receiver<WriteStatus>,
    /// Path passed to each reader so it can open the spill file when needed.
    path: PathBuf,
    /// Schema reported by the streams this receiver creates.
    schema: Arc<Schema>,
}
69
70impl SpillReceiver {
71 pub fn read(&self) -> SendableRecordBatchStream {
79 let rx = self.status_receiver.clone();
80 let reader = SpillReader::new(rx, self.path.clone());
81
82 let stream = futures::stream::try_unfold(reader, move |mut reader| async move {
83 match reader.read().await {
84 Ok(None) => Ok(None),
85 Ok(Some(batch)) => Ok(Some((batch, reader))),
86 Err(err) => Err(err),
87 }
88 });
89
90 Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))
91 }
92}
93
/// Cursor over a spill used by a single stream.
struct SpillReader {
    /// Number of batches this reader has returned so far. It is also the number
    /// of batches skipped when the spill file is first opened.
    pub batches_read: usize,
    /// Watch-channel receiver used to wait for status changes from the sender.
    receiver: tokio::sync::watch::Receiver<WriteStatus>,
    /// Whether this reader has opened the spill file yet.
    state: SpillReaderState,
}
99
/// Where a [`SpillReader`] is currently taking its batches from.
enum SpillReaderState {
    /// The spill file has not been opened by this reader; `spill_path` is the
    /// path that will be opened if the data turns out to be spilled.
    Buffered { spill_path: PathBuf },
    /// The spill file is open and positioned after the batches already read.
    Reader { reader: AsyncStreamReader },
}
104
105impl SpillReader {
106 fn new(receiver: tokio::sync::watch::Receiver<WriteStatus>, spill_path: PathBuf) -> Self {
107 Self {
108 batches_read: 0,
109 receiver,
110 state: SpillReaderState::Buffered { spill_path },
111 }
112 }
113
114 async fn wait_for_more_data(&mut self) -> Result<Option<Arc<[RecordBatch]>>, DataFusionError> {
115 let status = self
116 .receiver
117 .wait_for(|status| {
118 status.error.is_some()
119 || status.finished
120 || status.batches_written() > self.batches_read
121 })
122 .await
123 .map_err(|_| {
124 DataFusionError::Execution(
125 "Spill has been dropped before reader has finish.".into(),
126 )
127 })?;
128
129 if let Some(error) = &status.error {
130 let mut guard = error.lock().ok().expect_ok()?;
131 return Err(DataFusionError::from(&mut (*guard)));
132 }
133
134 if let DataLocation::Buffered { batches } = &status.data_location {
135 Ok(Some(batches.clone()))
136 } else {
137 Ok(None)
138 }
139 }
140
141 async fn get_reader(&mut self) -> Result<&AsyncStreamReader, ArrowError> {
142 if let SpillReaderState::Buffered { spill_path } = &self.state {
143 let reader = AsyncStreamReader::open(spill_path.clone()).await?;
144 for _ in 0..self.batches_read {
148 reader.read().await?;
149 }
150 self.state = SpillReaderState::Reader { reader };
151 }
152
153 if let SpillReaderState::Reader { reader } = &mut self.state {
154 Ok(reader)
155 } else {
156 unreachable!()
157 }
158 }
159
160 async fn read(&mut self) -> Result<Option<RecordBatch>, DataFusionError> {
161 let maybe_data = self.wait_for_more_data().await?;
162
163 if let Some(batches) = maybe_data {
164 if self.batches_read < batches.len() {
165 let batch = batches[self.batches_read].clone();
166 self.batches_read += 1;
167 Ok(Some(batch))
168 } else {
169 Ok(None)
170 }
171 } else {
172 let reader = self.get_reader().await?;
173 let batch = reader.read().await?;
174 if batch.is_some() {
175 self.batches_read += 1;
176 }
177 Ok(batch)
178 }
179 }
180}
181
/// Write half of a replay spill.
///
/// Batches are buffered in memory until the memory accumulator's total exceeds
/// `memory_limit`; from then on they are written to the spill file at `path`.
pub struct SpillSender {
    /// Threshold compared against the memory accumulator's total on each
    /// buffered write.
    memory_limit: usize,
    /// Schema used when creating the spill file.
    schema: Arc<Schema>,
    /// Location of the spill file.
    path: PathBuf,
    /// Current stage of the sender (buffering, spilling, finished or errored).
    state: SpillState,
    /// Watch-channel sender used to publish a [`WriteStatus`] after state changes.
    status_sender: tokio::sync::watch::Sender<WriteStatus>,
}
193
/// Internal state machine of a [`SpillSender`].
enum SpillState {
    /// Batches are being kept in memory.
    Buffering {
        /// Batches written so far.
        batches: Vec<RecordBatch>,
        /// Tracks the size of the batches recorded so far.
        memory_accumulator: MemoryAccumulator,
    },
    /// Batches are being written to the spill file.
    Spilling {
        /// Writer for the spill file.
        writer: AsyncStreamWriter,
        /// Number of batches written so far.
        batches_written: usize,
    },
    /// No more batches will be written.
    Finished {
        /// The buffered batches if the sender finished while still buffering;
        /// `None` if the data is in the spill file.
        batches: Option<Arc<[RecordBatch]>>,
        /// Total number of batches written.
        batches_written: usize,
    },
    /// An error was sent through the spill.
    Errored {
        /// The error, shared with every status snapshot that refers to it.
        error: Arc<Mutex<SpillError>>,
    },
}
211
212impl Default for SpillState {
213 fn default() -> Self {
214 Self::Buffering {
215 batches: Vec::new(),
216 memory_accumulator: MemoryAccumulator::default(),
217 }
218 }
219}
220
/// Snapshot of the sender's progress, published to readers over the watch channel.
#[derive(Clone, Debug, Default)]
struct WriteStatus {
    /// Set when an error has been sent through the spill.
    error: Option<Arc<Mutex<SpillError>>>,
    /// True once the sender will write no more batches (finished or errored).
    finished: bool,
    /// Where the batches written so far can be found.
    data_location: DataLocation,
}
227
228impl WriteStatus {
229 fn batches_written(&self) -> usize {
230 match &self.data_location {
231 DataLocation::Buffered { batches } => batches.len(),
232 DataLocation::Spilled {
233 batches_written, ..
234 } => *batches_written,
235 }
236 }
237}
238
/// Where the batches described by a [`WriteStatus`] live.
#[derive(Clone, Debug)]
enum DataLocation {
    /// The batches are in memory and shared through this list.
    Buffered { batches: Arc<[RecordBatch]> },
    /// The batches are in the spill file; `batches_written` is how many.
    Spilled { batches_written: usize },
}
244
245impl Default for DataLocation {
246 fn default() -> Self {
247 Self::Buffered {
248 batches: Arc::new([]),
249 }
250 }
251}
252
/// Error stored in the spill so that it can be handed to more than one reader.
///
/// The conversion into `DataFusionError` gives the original error to the first
/// caller and an execution error with the same message to later callers.
#[derive(Debug)]
enum SpillError {
    /// The error as it was sent; not yet handed to any reader.
    Original(DataFusionError),
    /// Stand-in left behind after the original was handed out. The conversion
    /// expects this to be a `DataFusionError::Execution` holding the message.
    Copy(DataFusionError),
}
261
262impl From<DataFusionError> for SpillError {
263 fn from(err: DataFusionError) -> Self {
264 Self::Original(err)
265 }
266}
267
268impl From<&mut SpillError> for DataFusionError {
269 fn from(err: &mut SpillError) -> Self {
270 match err {
271 SpillError::Original(inner) => {
272 let copy = Self::Execution(inner.to_string());
273 let original = std::mem::replace(err, SpillError::Copy(copy));
274 if let SpillError::Original(inner) = original {
275 inner
276 } else {
277 unreachable!()
278 }
279 }
280 SpillError::Copy(Self::Execution(message)) => Self::Execution(message.clone()),
281 _ => unreachable!(),
282 }
283 }
284}
285
286impl From<&SpillState> for WriteStatus {
287 fn from(state: &SpillState) -> Self {
288 match state {
289 SpillState::Buffering { batches, .. } => Self {
290 finished: false,
291 data_location: DataLocation::Buffered {
292 batches: batches.clone().into(),
293 },
294 error: None,
295 },
296 SpillState::Spilling {
297 batches_written, ..
298 } => Self {
299 finished: false,
300 data_location: DataLocation::Spilled {
301 batches_written: *batches_written,
302 },
303 error: None,
304 },
305 SpillState::Finished {
306 batches_written,
307 batches,
308 } => {
309 let data_location = if let Some(batches) = batches {
310 DataLocation::Buffered {
311 batches: batches.clone(),
312 }
313 } else {
314 DataLocation::Spilled {
315 batches_written: *batches_written,
316 }
317 };
318 Self {
319 finished: true,
320 data_location,
321 error: None,
322 }
323 }
324 SpillState::Errored { error } => Self {
325 finished: true,
326 data_location: DataLocation::default(), error: Some(error.clone()),
328 },
329 }
330 }
331}
332
333impl SpillSender {
334 pub async fn write(&mut self, batch: RecordBatch) -> Result<(), DataFusionError> {
342 if let SpillState::Finished { .. } = self.state {
343 return Err(DataFusionError::Execution(
344 "Spill has already been finished".to_string(),
345 ));
346 }
347
348 if let SpillState::Errored { .. } = &self.state {
349 return Err(DataFusionError::Execution(
350 "Spill has sent an error".to_string(),
351 ));
352 }
353
354 let (writer, batches_written) = match &mut self.state {
355 SpillState::Buffering {
356 batches,
357 ref mut memory_accumulator,
358 } => {
359 memory_accumulator.record_batch(&batch);
360
361 if memory_accumulator.total() > self.memory_limit {
362 let writer =
363 AsyncStreamWriter::open(self.path.clone(), self.schema.clone()).await?;
364 let batches_written = batches.len();
365 for batch in batches.drain(..) {
366 writer.write(batch).await?;
367 }
368 self.state = SpillState::Spilling {
369 writer,
370 batches_written,
371 };
372 if let SpillState::Spilling {
373 writer,
374 batches_written,
375 } = &mut self.state
376 {
377 (writer, batches_written)
378 } else {
379 unreachable!()
380 }
381 } else {
382 batches.push(batch);
383 self.status_sender
384 .send_replace(WriteStatus::from(&self.state));
385 return Ok(());
386 }
387 }
388 SpillState::Spilling {
389 writer,
390 batches_written,
391 } => (writer, batches_written),
392 _ => unreachable!(),
393 };
394
395 writer.write(batch).await?;
396 *batches_written += 1;
397 self.status_sender
398 .send_replace(WriteStatus::from(&self.state));
399
400 Ok(())
401 }
402
403 pub fn send_error(&mut self, err: DataFusionError) {
406 let error = Arc::new(Mutex::new(err.into()));
407 self.state = SpillState::Errored { error };
408 self.status_sender
409 .send_replace(WriteStatus::from(&self.state));
410 }
411
412 pub async fn finish(&mut self) -> Result<(), DataFusionError> {
416 let tmp_state = SpillState::Finished {
420 batches_written: 0,
421 batches: None,
422 };
423 match std::mem::replace(&mut self.state, tmp_state) {
424 SpillState::Buffering { batches, .. } => {
425 let batches_written = batches.len();
426 self.state = SpillState::Finished {
427 batches_written,
428 batches: Some(batches.into()),
429 };
430 self.status_sender
431 .send_replace(WriteStatus::from(&self.state));
432 }
433 SpillState::Spilling {
434 writer,
435 batches_written,
436 } => {
437 writer.finish().await?;
438 self.state = SpillState::Finished {
439 batches_written,
440 batches: None,
441 };
442 self.status_sender
443 .send_replace(WriteStatus::from(&self.state));
444 }
445 SpillState::Finished { .. } => {
446 return Err(DataFusionError::Execution(
447 "Spill has already been finished".to_string(),
448 ));
449 }
450 SpillState::Errored { .. } => {
451 return Err(DataFusionError::Execution(
452 "Spill has sent an error".to_string(),
453 ));
454 }
455 };
456
457 Ok(())
458 }
459}
460
/// Arrow IPC stream writer whose blocking file operations are run on tokio's
/// blocking thread pool.
struct AsyncStreamWriter {
    /// The underlying writer, shared with the blocking tasks that use it.
    writer: Arc<Mutex<StreamWriter<BufWriter<std::fs::File>>>>,
}
466
467impl AsyncStreamWriter {
468 pub async fn open(path: PathBuf, schema: Arc<Schema>) -> Result<Self, ArrowError> {
469 let writer = tokio::task::spawn_blocking(move || {
470 let file = std::fs::File::create(&path).map_err(ArrowError::from)?;
471 let writer = BufWriter::new(file);
472 StreamWriter::try_new(writer, &schema)
473 })
474 .await
475 .unwrap()?;
476 let writer = Arc::new(Mutex::new(writer));
477 Ok(Self { writer })
478 }
479
480 pub async fn write(&self, batch: RecordBatch) -> Result<(), ArrowError> {
481 let writer = self.writer.clone();
482 tokio::task::spawn_blocking(move || {
483 let mut writer = writer.lock().unwrap();
484 writer.write(&batch)?;
485 writer.flush()
486 })
487 .await
488 .unwrap()
489 }
490
491 pub async fn finish(self) -> Result<(), ArrowError> {
492 let writer = self.writer.clone();
493 tokio::task::spawn_blocking(move || {
494 let mut writer = writer.lock().unwrap();
495 writer.finish()
496 })
497 .await
498 .unwrap()
499 }
500}
501
/// Arrow IPC stream reader whose blocking file operations are run on tokio's
/// blocking thread pool.
struct AsyncStreamReader {
    /// The underlying reader, shared with the blocking tasks that use it.
    reader: Arc<Mutex<StreamReader<BufReader<std::fs::File>>>>,
}
505
506impl AsyncStreamReader {
507 pub async fn open(path: PathBuf) -> Result<Self, ArrowError> {
508 let reader = tokio::task::spawn_blocking(move || {
509 let file = std::fs::File::open(&path).map_err(ArrowError::from)?;
510 let reader = BufReader::new(file);
511 StreamReader::try_new(reader, None)
512 })
513 .await
514 .unwrap()?;
515 let reader = Arc::new(Mutex::new(reader));
516 Ok(Self { reader })
517 }
518
519 pub async fn read(&self) -> Result<Option<RecordBatch>, ArrowError> {
520 let reader = self.reader.clone();
521 tokio::task::spawn_blocking(move || {
522 let mut reader = reader.lock().unwrap();
523 reader.next()
524 })
525 .await
526 .unwrap()
527 .transpose()
528 }
529}
530
#[cfg(test)]
mod tests {
    use arrow_array::Int32Array;
    use arrow_schema::{DataType, Field};
    use futures::{poll, StreamExt, TryStreamExt};
    use lance_core::utils::tempfile::{TempStdFile, TempStdPath};

    use super::*;

    /// Schema with a single non-nullable Int32 column named "a".
    fn int32_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]))
    }

    /// Builds a one-column batch holding `values`.
    fn int32_batch(schema: &Arc<Schema>, values: Vec<i32>) -> RecordBatch {
        RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(values))]).unwrap()
    }

    /// Unwraps a stream item that must be a successfully read batch.
    fn expect_batch(item: Option<Result<RecordBatch, DataFusionError>>) -> RecordBatch {
        item.expect("Expected a batch").expect("Expected no error")
    }

    /// Unwraps a stream item that must be an error.
    fn expect_error(item: Option<Result<RecordBatch, DataFusionError>>) -> DataFusionError {
        item.expect("Expected an error")
            .expect_err("Expected an error")
    }

    #[tokio::test]
    async fn test_spill() {
        let schema = int32_schema();
        let batches = [
            int32_batch(&schema, vec![1, 2, 3]),
            int32_batch(&schema, vec![4, 5, 6]),
        ];

        let path = TempStdFile::default();
        // A limit of zero is used so the sender goes to the file straight away.
        let (mut spill, receiver) = create_replay_spill(path.to_owned(), schema.clone(), 0);

        // A stream opened before anything is written has to wait.
        let mut early_stream = receiver.read();
        let mut early_next = early_stream.next();
        assert!(poll!(&mut early_next).is_pending());

        // After the first write it sees the first batch, then waits again.
        spill.write(batches[0].clone()).await.unwrap();
        let early_first = expect_batch(early_next.await);
        assert_eq!(&early_first, &batches[0]);
        let mut early_next = early_stream.next();
        assert!(poll!(&mut early_next).is_pending());

        // A stream opened mid-write replays the first batch, then waits.
        let mut mid_stream = receiver.read();
        let mid_first = expect_batch(mid_stream.next().await);
        assert_eq!(&mid_first, &batches[0]);
        let mut mid_next = mid_stream.next();
        assert!(poll!(&mut mid_next).is_pending());

        spill.write(batches[1].clone()).await.unwrap();
        spill.finish().await.unwrap();

        // Both open streams receive the second batch and then end.
        let early_second = expect_batch(early_next.await);
        assert_eq!(&early_second, &batches[1]);
        assert!(early_stream.next().await.is_none());

        let mid_second = expect_batch(mid_next.await);
        assert_eq!(&mid_second, &batches[1]);
        assert!(mid_stream.next().await.is_none());

        // A stream opened after finishing replays everything.
        let replayed = receiver.read().try_collect::<Vec<_>>().await.unwrap();
        assert_eq!(&replayed, &batches);

        std::fs::remove_file(path).unwrap();
    }

    #[tokio::test]
    async fn test_spill_error() {
        let schema = int32_schema();
        let path = TempStdFile::default();
        let (mut spill, receiver) =
            create_replay_spill(path.as_ref().to_owned(), schema.clone(), 0);
        let batch = int32_batch(&schema, vec![1, 2, 3]);

        spill.write(batch.clone()).await.unwrap();

        let mut stream = receiver.read();
        let first = expect_batch(stream.next().await);
        assert_eq!(&first, &batch);

        // The first reader to hit the error receives the original error value.
        spill.send_error(DataFusionError::ResourcesExhausted("🥱".into()));
        let original_error = expect_error(stream.next().await);
        assert!(matches!(
            original_error,
            DataFusionError::ResourcesExhausted(message) if message == "🥱"
        ));

        // The sender refuses further writes and refuses to finish.
        let write_result = spill.write(batch).await;
        assert!(matches!(
            write_result,
            Err(DataFusionError::Execution(message)) if message == "Spill has sent an error"
        ));

        let finish_result = spill.finish().await;
        assert!(matches!(
            finish_result,
            Err(DataFusionError::Execution(message)) if message == "Spill has sent an error"
        ));

        // A later reader receives an execution error carrying the same message.
        let mut late_stream = receiver.read();
        let copied_error = expect_error(late_stream.next().await);
        assert!(matches!(
            copied_error,
            DataFusionError::Execution(message) if message.contains("🥱")
        ));

        std::fs::remove_file(path).unwrap();
    }

    #[tokio::test]
    async fn test_spill_buffered() {
        let schema = int32_schema();
        let path = TempStdPath::default();
        let memory_limit = 1024 * 1024;
        let (mut spill, receiver) = create_replay_spill(path.clone(), schema.clone(), memory_limit);

        // One batch that stays under the limit: no file should ever appear.
        let batch = int32_batch(&schema, vec![1; (512 * 1024) / 4]);
        spill.write(batch.clone()).await.unwrap();
        assert!(!std::fs::exists(&path).unwrap());

        spill.finish().await.unwrap();
        assert!(!std::fs::exists(&path).unwrap());

        let mut stream = receiver.read();
        let replayed = expect_batch(stream.next().await);
        assert_eq!(&replayed, &batch);

        assert!(!std::fs::exists(&path).unwrap());
    }

    #[tokio::test]
    async fn test_spill_buffered_transition() {
        let schema = int32_schema();
        let path = TempStdPath::default();
        let memory_limit = 1024 * 1024;
        let (mut spill, receiver) = create_replay_spill(path.clone(), schema.clone(), memory_limit);

        // The first batch stays under the limit, so it is only buffered.
        let batch = int32_batch(&schema, vec![1; (768 * 1024) / 4]);
        spill.write(batch.clone()).await.unwrap();
        assert!(!std::fs::exists(&path).unwrap());

        let mut stream = receiver.read();
        let from_memory = expect_batch(stream.next().await);
        assert_eq!(&from_memory, &batch);
        assert!(!std::fs::exists(&path).unwrap());

        // The second batch goes over the limit, so the file is created.
        let batch = int32_batch(&schema, vec![1; (512 * 1024) / 4]);
        spill.write(batch.clone()).await.unwrap();
        assert!(std::fs::exists(&path).unwrap());

        // The stream that started on the buffer carries on from the file.
        let from_file = expect_batch(stream.next().await);
        assert_eq!(&from_file, &batch);
        assert!(std::fs::exists(&path).unwrap());

        spill.finish().await.unwrap();

        assert!(stream.next().await.is_none());

        std::fs::remove_file(path).unwrap();
    }
}