//! Job output destinations: sinks that streams of record batches can be
//! written to, plus helpers shared by the concrete [`LanceDestination`] and
//! [`SqlDmlDestination`] implementations.

use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::{Context, Poll};

use anyhow::Result;
use arrow::datatypes::{Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion::error::DataFusionError;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::RecordBatchStream;
use datafusion::prelude::SessionContext;
use futures::Stream;

use super::definition::DestinationMode;

mod lance;
mod sql_dml;

pub use lance::LanceDestination;
pub use sql_dml::SqlDmlDestination;
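
/// Summary of a completed write, returned by [`JobDestination::write`].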
#[derive(Debug, Clone)]
pub struct WriteOutcome {
    /// Total number of rows committed to the destination.
    pub rows_written: u64,
    /// Backend-specific snapshot or version identifier, if the destination
    /// produces one.
    pub snapshot_id: Option<String>,
}
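
/// Coarse classification of a destination backend.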
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JobDestinationKind {
    /// A data-lake table destination (see [`LanceDestination`]).
    Lake,
    /// A database destination written via SQL DML (see [`SqlDmlDestination`]).
    Db,
}
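
/// An output sink for job results: `write` consumes a stream of record
/// batches and commits it according to the requested [`DestinationMode`].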
#[async_trait]
pub trait JobDestination: Send + Sync {
    /// Which backend category this destination targets.
    fn kind(&self) -> JobDestinationKind;

    /// Whether the destination object already exists.
    async fn exists(&self) -> Result<bool>;

    /// The Arrow schema of the existing destination, or `None` if it does
    /// not exist yet.
    async fn schema(&self) -> Result<Option<Arc<Schema>>>;

    /// Consume `stream` and write it to the destination using `mode`,
    /// returning a summary of what was committed.
    async fn write(
        &self,
        stream: SendableRecordBatchStream,
        mode: DestinationMode,
    ) -> Result<WriteOutcome>;
}
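
/// Wraps a [`SendableRecordBatchStream`] with a shared cancellation flag.
/// Once the flag is set, every subsequent poll yields an execution error
/// instead of forwarding batches, so a long-running write can abort between
/// batches.
///
/// A minimal usage sketch (not compiled; `inner` stands for any
/// `SendableRecordBatchStream` you already have):
///
/// ```ignore
/// let cancelled = Arc::new(AtomicBool::new(false));
/// let guarded = CancellableStream::new(inner, cancelled.clone()).boxed();
/// // Flipping the flag from another task makes the next poll return an error.
/// cancelled.store(true, Ordering::SeqCst);
/// ```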
pub struct CancellableStream {
    inner: SendableRecordBatchStream,
    cancelled: Arc<AtomicBool>,
}

impl CancellableStream {
    pub fn new(inner: SendableRecordBatchStream, cancelled: Arc<AtomicBool>) -> Self {
        Self { inner, cancelled }
    }

    /// Box the wrapper so it can be passed wherever a
    /// [`SendableRecordBatchStream`] is expected.
    pub fn boxed(self) -> SendableRecordBatchStream {
        Box::pin(self)
    }
}

impl Stream for CancellableStream {
    type Item = std::result::Result<RecordBatch, DataFusionError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        // Check the flag before polling the inner stream so a cancelled job
        // stops producing batches as early as possible.
        if this.cancelled.load(Ordering::SeqCst) {
            return Poll::Ready(Some(Err(DataFusionError::Execution(
                "job cancelled before commit".to_string(),
            ))));
        }
        Pin::new(&mut this.inner).poll_next(cx)
    }
}

impl RecordBatchStream for CancellableStream {
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }
}
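
/// Quote a possibly dotted table reference for SQL: each dot-separated
/// segment is double-quoted and embedded `"` characters are doubled, so
/// `cat.schema.tbl` becomes `"cat"."schema"."tbl"` and `has"quote` becomes
/// `"has""quote"`. Dots are always treated as separators, so identifiers
/// that themselves contain a dot cannot be expressed.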
pub(crate) fn quote_table_ref(name: &str) -> String {
    name.split('.')
        .map(|seg| format!("\"{}\"", seg.replace('"', "\"\"")))
        .collect::<Vec<_>>()
        .join(".")
}
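
/// Resolve the Arrow schema of `table` by planning `SELECT * FROM ... LIMIT 0`
/// against the session context. Returns `Ok(None)` when the error text
/// suggests the table is missing; note this matches on DataFusion's error
/// messages rather than a structured error kind.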
pub(crate) async fn lookup_table_schema(
    ctx: &SessionContext,
    table: &str,
) -> Result<Option<Arc<Schema>>> {
    let sql = format!("SELECT * FROM {} LIMIT 0", quote_table_ref(table));
    match ctx.sql(&sql).await {
        Ok(df) => Ok(Some(Arc::new(df.schema().as_arrow().clone()))),
        Err(e) => {
            let msg = e.to_string();
            // Heuristic: DataFusion reports a missing table with different
            // messages depending on how the name fails to resolve.
            let missing = msg.contains("not found")
                || msg.contains("does not exist")
                || msg.contains("Unsupported compound identifier");
            if missing {
                Ok(None)
            } else {
                Err(anyhow::anyhow!(
                    "Failed to resolve destination table '{}': {}",
                    table,
                    e
                ))
            }
        }
    }
}

#[cfg(test)]
pub(crate) mod test_util {
    use super::*;
    use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
    use futures::stream;

    /// Build a [`SendableRecordBatchStream`] that yields the given batches in
    /// order; a convenience for driving destinations in tests.
    pub fn vec_to_stream(
        batches: Vec<RecordBatch>,
        schema: SchemaRef,
    ) -> SendableRecordBatchStream {
        Box::pin(RecordBatchStreamAdapter::new(
            schema,
            stream::iter(batches.into_iter().map(Ok)),
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::test_util::vec_to_stream;
    use super::*;
    use arrow::array::Int64Array;
    use arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
    use futures::StreamExt;

    #[test]
    fn quote_table_ref_handles_dotted_idents_and_quotes() {
        assert_eq!(quote_table_ref("plain"), "\"plain\"");
        assert_eq!(
            quote_table_ref("cat.schema.tbl"),
            "\"cat\".\"schema\".\"tbl\""
        );
        assert_eq!(quote_table_ref("has\"quote"), "\"has\"\"quote\"");
    }

    #[tokio::test]
    async fn cancellable_stream_errors_when_flag_is_set_before_first_poll() {
        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "id",
            DataType::Int64,
            false,
        )]));
        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(Int64Array::from(vec![1]))])
            .unwrap();
        let inner = vec_to_stream(vec![batch], schema);
        let flag = Arc::new(AtomicBool::new(true));
        let mut s = CancellableStream::new(inner, flag).boxed();
        let item = s.next().await.expect("one item");
        assert!(item.is_err(), "expected Err, got {:?}", item);
        // The flag stays set, so polling again keeps yielding the
        // cancellation error rather than ending the stream.
        assert!(s.next().await.expect("another item").is_err());
    }

    #[tokio::test]
    async fn cancellable_stream_passes_through_when_not_cancelled() {
        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "id",
            DataType::Int64,
            false,
        )]));
        let batch =
            RecordBatch::try_new(schema.clone(), vec![Arc::new(Int64Array::from(vec![1, 2]))])
                .unwrap();
        let inner = vec_to_stream(vec![batch], schema);
        let flag = Arc::new(AtomicBool::new(false));
        let mut s = CancellableStream::new(inner, flag).boxed();
        let b = s.next().await.unwrap().unwrap();
        assert_eq!(b.num_rows(), 2);
        assert!(s.next().await.is_none());
    }
}