1use arrow::datatypes::SchemaRef;
2use arrow::record_batch::RecordBatch;
3use datafusion::catalog::TableProvider;
4use datafusion::catalog::streaming::StreamingTable;
5use datafusion::error::DataFusionError;
6use datafusion::execution::TaskContext;
7use datafusion::physical_plan::SendableRecordBatchStream;
8use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
9use datafusion::physical_plan::streaming::PartitionStream;
10use futures::StreamExt;
11use std::sync::{Arc, Mutex as StdMutex};
12use tokio::sync::{Mutex as AsyncMutex, mpsc};
13use tokio_stream::wrappers::ReceiverStream;
14
15use core::fmt;
16
/// Default capacity, in record batches, of the bounded channel created by
/// [`create_continuous_table`].
pub const CONTINUOUS_TABLE_CHANNEL_CAPACITY: usize = 64;
24
/// Errors reported by [`ContinuousTableInput`] when feeding or closing a
/// continuous table.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum ContinuousInputError {
    /// The batch schema differs from the table schema. Both sides are carried
    /// as their `Debug` rendering.
    #[error("continuous table batch schema mismatch: expected {expected}, got {actual}")]
    SchemaMismatch { expected: String, actual: String },
    /// A non-blocking send found the bounded channel without a free slot.
    #[error("continuous table input queue is full")]
    QueueFull,
    /// The input has been closed or cancelled, or the channel itself reported
    /// that it is closed.
    #[error("continuous table input is closed")]
    Closed,
    /// The mutex guarding the sender was poisoned; carries the poison error's
    /// message.
    #[error("continuous table input lock is poisoned: {0}")]
    LockPoisoned(String),
}
41
/// A [`PartitionStream`] fed by the receiving half of a bounded channel of
/// record batches.
///
/// The receiver is handed out by the first successful `execute` call; later
/// calls find the slot empty and get an error stream instead.
pub struct ChannelPartitionStream {
    /// Schema reported for the partition and attached to every stream it builds.
    schema: SchemaRef,
    /// Receiver slot; `None` once an execution has taken the receiver.
    receiver: AsyncMutex<Option<mpsc::Receiver<RecordBatch>>>,
}
47
48impl fmt::Debug for ChannelPartitionStream {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 f.debug_struct("ChannelPartitionStream")
51 .field("schema", &self.schema)
52 .finish()
53 }
54}
55
56impl ChannelPartitionStream {
57 pub fn new(schema: SchemaRef, receiver: mpsc::Receiver<RecordBatch>) -> Self {
58 Self {
59 schema,
60 receiver: AsyncMutex::new(Some(receiver)),
61 }
62 }
63
64 fn error_stream(&self, message: impl Into<String>) -> SendableRecordBatchStream {
65 let message = message.into();
66 let stream = futures::stream::once(async move { Err(DataFusionError::Execution(message)) });
67 Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))
68 }
69}
70
71impl PartitionStream for ChannelPartitionStream {
72 fn schema(&self) -> &SchemaRef {
73 &self.schema
74 }
75
76 fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
77 let mut rx_guard = match self.receiver.try_lock() {
78 Ok(guard) => guard,
79 Err(_) => {
80 return self.error_stream(
81 "continuous table partition is already executing in another query",
82 );
83 }
84 };
85 let Some(rx) = rx_guard.take() else {
86 return self.error_stream(
87 "continuous table partition has already been consumed by another query",
88 );
89 };
90
91 let stream = ReceiverStream::new(rx).map(Ok::<RecordBatch, DataFusionError>);
92 Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))
93 }
94}
95
/// Producer-side handle used to push record batches into a continuous table.
///
/// Batches are checked against `schema` before being handed to the channel
/// sender. Closing or cancelling the input removes the stored sender.
pub struct ContinuousTableInput {
    /// Schema every submitted batch must match.
    schema: SchemaRef,
    /// Sender slot; `None` once the input has been closed or cancelled.
    sender: StdMutex<Option<mpsc::Sender<RecordBatch>>>,
}
101
102impl fmt::Debug for ContinuousTableInput {
103 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104 f.debug_struct("ContinuousTableInput")
105 .field("schema", &self.schema)
106 .field("closed", &self.is_closed().ok())
107 .finish()
108 }
109}
110
111impl ContinuousTableInput {
112 fn new(schema: SchemaRef, sender: mpsc::Sender<RecordBatch>) -> Self {
113 Self {
114 schema,
115 sender: StdMutex::new(Some(sender)),
116 }
117 }
118
119 pub fn schema(&self) -> &SchemaRef {
121 &self.schema
122 }
123
124 pub fn try_send(&self, batch: RecordBatch) -> Result<(), ContinuousInputError> {
126 self.validate_schema(&batch)?;
127 let sender = self.sender_clone()?;
128 sender.try_send(batch).map_err(|error| match error {
129 mpsc::error::TrySendError::Full(_) => ContinuousInputError::QueueFull,
130 mpsc::error::TrySendError::Closed(_) => ContinuousInputError::Closed,
131 })
132 }
133
134 pub async fn send(&self, batch: RecordBatch) -> Result<(), ContinuousInputError> {
136 self.validate_schema(&batch)?;
137 self.sender_clone()?
138 .send(batch)
139 .await
140 .map_err(|_| ContinuousInputError::Closed)
141 }
142
143 pub fn close(&self) -> Result<bool, ContinuousInputError> {
148 let mut sender = self
149 .sender
150 .lock()
151 .map_err(|error| ContinuousInputError::LockPoisoned(error.to_string()))?;
152 Ok(sender.take().is_some())
153 }
154
155 pub fn cancel(&self) {
159 if let Ok(mut sender) = self.sender.lock() {
164 sender.take();
165 }
166 }
167
168 pub fn is_closed(&self) -> Result<bool, ContinuousInputError> {
170 self.sender
171 .lock()
172 .map(|sender| sender.is_none())
173 .map_err(|error| ContinuousInputError::LockPoisoned(error.to_string()))
174 }
175
176 fn sender_clone(&self) -> Result<mpsc::Sender<RecordBatch>, ContinuousInputError> {
177 self.sender
178 .lock()
179 .map_err(|error| ContinuousInputError::LockPoisoned(error.to_string()))?
180 .clone()
181 .ok_or(ContinuousInputError::Closed)
182 }
183
184 fn validate_schema(&self, batch: &RecordBatch) -> Result<(), ContinuousInputError> {
185 if batch.schema().as_ref() != self.schema.as_ref() {
186 return Err(ContinuousInputError::SchemaMismatch {
187 expected: format!("{:?}", self.schema),
188 actual: format!("{:?}", batch.schema()),
189 });
190 }
191 Ok(())
192 }
193}
194
195pub fn create_continuous_table(
201 schema: SchemaRef,
202) -> datafusion::error::Result<(Arc<dyn TableProvider>, Arc<ContinuousTableInput>)> {
203 create_continuous_table_with_capacity(schema, CONTINUOUS_TABLE_CHANNEL_CAPACITY)
204}
205
206pub fn create_continuous_table_with_capacity(
210 schema: SchemaRef,
211 capacity: usize,
212) -> datafusion::error::Result<(Arc<dyn TableProvider>, Arc<ContinuousTableInput>)> {
213 let (tx, rx) = mpsc::channel(capacity.max(1));
214 let partition = Arc::new(ChannelPartitionStream::new(schema.clone(), rx));
215 let table = StreamingTable::try_new(schema.clone(), vec![partition])?;
216 Ok((
217 Arc::new(table),
218 Arc::new(ContinuousTableInput::new(schema, tx)),
219 ))
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int32Array, Int64Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Schema with a single non-nullable `Int32` column named `x`.
    fn make_schema() -> SchemaRef {
        let fields = vec![Field::new("x", DataType::Int32, false)];
        Arc::new(Schema::new(fields))
    }

    /// One-column batch matching `make_schema`, holding `values`.
    fn make_batch(values: Vec<i32>) -> RecordBatch {
        let schema = make_schema();
        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(values))]).unwrap()
    }

    #[tokio::test]
    async fn create_continuous_table_with_capacity_zero_is_clamped_to_one() {
        // `_table` is held until the end of the test so the receiver is not dropped.
        let (_table, input) = create_continuous_table_with_capacity(make_schema(), 0).unwrap();

        // The first batch must fit.
        input
            .try_send(make_batch(vec![1]))
            .expect("capacity should be >= 1");

        // The second must be rejected.
        let second = input.try_send(make_batch(vec![2]));
        assert!(second.is_err());
    }

    #[tokio::test]
    async fn bounded_channel_rejects_oversized_queue_via_try_send() {
        let (_table, input) = create_continuous_table_with_capacity(make_schema(), 2).unwrap();

        // Fill both slots.
        for value in [1, 2] {
            assert!(input.try_send(make_batch(vec![value])).is_ok());
        }

        // Nothing is draining the queue, so the next send must report it as full.
        let third = input.try_send(make_batch(vec![3]));
        assert!(
            matches!(third, Err(ContinuousInputError::QueueFull)),
            "expected Full, got {third:?}"
        );
    }

    #[tokio::test]
    async fn continuous_input_rejects_schema_mismatch_and_close_is_idempotent() {
        let (_table, input) = create_continuous_table(make_schema()).unwrap();

        // Same column name, different data type.
        let mismatched_schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
        let mismatched_batch = RecordBatch::try_new(
            mismatched_schema,
            vec![Arc::new(Int64Array::from(vec![1]))],
        )
        .unwrap();

        let rejection = input
            .try_send(mismatched_batch)
            .expect_err("schema mismatch must fail");
        assert!(matches!(
            rejection,
            ContinuousInputError::SchemaMismatch { .. }
        ));

        // The first close removes the sender; the second finds nothing to remove.
        let first_close = input.close().unwrap();
        let second_close = input.close().unwrap();
        assert!(first_close);
        assert!(!second_close);
        assert!(input.is_closed().unwrap());

        // A matching batch is rejected once the input is closed.
        let after_close = input.try_send(make_batch(vec![1]));
        assert!(matches!(after_close, Err(ContinuousInputError::Closed)));
    }
}