Skip to main content

krishiv_sql/
streaming.rs

1use arrow::datatypes::SchemaRef;
2use arrow::record_batch::RecordBatch;
3use datafusion::catalog::TableProvider;
4use datafusion::catalog::streaming::StreamingTable;
5use datafusion::error::DataFusionError;
6use datafusion::execution::TaskContext;
7use datafusion::physical_plan::SendableRecordBatchStream;
8use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
9use datafusion::physical_plan::streaming::PartitionStream;
10use futures::StreamExt;
11use std::sync::{Arc, Mutex as StdMutex};
12use tokio::sync::{Mutex as AsyncMutex, mpsc};
13use tokio_stream::wrappers::ReceiverStream;
14
15use core::fmt;
16
/// Default capacity, in batches, of the bounded channel behind each
/// continuous table built by [`create_continuous_table`].
///
/// The bound is what keeps memory finite. Once this many batches are queued
/// and unread, [`ContinuousTableInput::send`] waits and
/// [`ContinuousTableInput::try_send`] fails with a full-queue error. A slow
/// downstream consumer (an expensive join, say) therefore cannot let a fast
/// producer buffer without limit.
///
/// At roughly 1k rows per batch, about 64k rows are in flight. (The 1k figure
/// is an assumed typical producer batch size, not something enforced here.)
/// That is enough to absorb short stalls without imposing visible
/// backpressure on typical CDC / streaming-SQL workloads.
pub const CONTINUOUS_TABLE_CHANNEL_CAPACITY: usize = 64;
24
25/// Errors returned by a continuous table producer.
26#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
27pub enum ContinuousInputError {
28    /// Submitted batch schema does not match the registered table schema.
29    #[error("continuous table batch schema mismatch: expected {expected}, got {actual}")]
30    SchemaMismatch { expected: String, actual: String },
31    /// The bounded producer queue has no remaining capacity.
32    #[error("continuous table input queue is full")]
33    QueueFull,
34    /// The producer was explicitly closed or its consumer was dropped.
35    #[error("continuous table input is closed")]
36    Closed,
37    /// Internal producer state was poisoned by a panic while locked.
38    #[error("continuous table input lock is poisoned: {0}")]
39    LockPoisoned(String),
40}
41
42/// A partition stream that reads from an MPSC channel.
43pub struct ChannelPartitionStream {
44    schema: SchemaRef,
45    receiver: AsyncMutex<Option<mpsc::Receiver<RecordBatch>>>,
46}
47
48impl fmt::Debug for ChannelPartitionStream {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        f.debug_struct("ChannelPartitionStream")
51            .field("schema", &self.schema)
52            .finish()
53    }
54}
55
56impl ChannelPartitionStream {
57    pub fn new(schema: SchemaRef, receiver: mpsc::Receiver<RecordBatch>) -> Self {
58        Self {
59            schema,
60            receiver: AsyncMutex::new(Some(receiver)),
61        }
62    }
63
64    fn error_stream(&self, message: impl Into<String>) -> SendableRecordBatchStream {
65        let message = message.into();
66        let stream = futures::stream::once(async move { Err(DataFusionError::Execution(message)) });
67        Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))
68    }
69}
70
71impl PartitionStream for ChannelPartitionStream {
72    fn schema(&self) -> &SchemaRef {
73        &self.schema
74    }
75
76    fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
77        let mut rx_guard = match self.receiver.try_lock() {
78            Ok(guard) => guard,
79            Err(_) => {
80                return self.error_stream(
81                    "continuous table partition is already executing in another query",
82                );
83            }
84        };
85        let Some(rx) = rx_guard.take() else {
86            return self.error_stream(
87                "continuous table partition has already been consumed by another query",
88            );
89        };
90
91        let stream = ReceiverStream::new(rx).map(Ok::<RecordBatch, DataFusionError>);
92        Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))
93    }
94}
95
96/// Schema-bound producer handle for one continuous SQL table.
97pub struct ContinuousTableInput {
98    schema: SchemaRef,
99    sender: StdMutex<Option<mpsc::Sender<RecordBatch>>>,
100}
101
102impl fmt::Debug for ContinuousTableInput {
103    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104        f.debug_struct("ContinuousTableInput")
105            .field("schema", &self.schema)
106            .field("closed", &self.is_closed().ok())
107            .finish()
108    }
109}
110
111impl ContinuousTableInput {
112    fn new(schema: SchemaRef, sender: mpsc::Sender<RecordBatch>) -> Self {
113        Self {
114            schema,
115            sender: StdMutex::new(Some(sender)),
116        }
117    }
118
119    /// Expected Arrow schema for every submitted batch.
120    pub fn schema(&self) -> &SchemaRef {
121        &self.schema
122    }
123
124    /// Submit a batch without waiting for queue capacity.
125    pub fn try_send(&self, batch: RecordBatch) -> Result<(), ContinuousInputError> {
126        self.validate_schema(&batch)?;
127        let sender = self.sender_clone()?;
128        sender.try_send(batch).map_err(|error| match error {
129            mpsc::error::TrySendError::Full(_) => ContinuousInputError::QueueFull,
130            mpsc::error::TrySendError::Closed(_) => ContinuousInputError::Closed,
131        })
132    }
133
134    /// Submit a batch, asynchronously waiting for queue capacity.
135    pub async fn send(&self, batch: RecordBatch) -> Result<(), ContinuousInputError> {
136        self.validate_schema(&batch)?;
137        self.sender_clone()?
138            .send(batch)
139            .await
140            .map_err(|_| ContinuousInputError::Closed)
141    }
142
143    /// Close the input. The consumer observes end-of-stream after queued data.
144    ///
145    /// Returns `true` when this call closed an open input and `false` when it
146    /// was already closed.
147    pub fn close(&self) -> Result<bool, ContinuousInputError> {
148        let mut sender = self
149            .sender
150            .lock()
151            .map_err(|error| ContinuousInputError::LockPoisoned(error.to_string()))?;
152        Ok(sender.take().is_some())
153    }
154
155    /// A-8: hard-cancel the stream. Drops any queued batches in the channel
156    /// before closing it, so the consumer sees an immediate end-of-stream
157    /// without flushing. Idempotent.
158    pub fn cancel(&self) {
159        // Take the sender and drop it; the receiver end of an mpsc channel
160        // returns `None` once all senders are gone, which our SQL consumer
161        // surfaces as end-of-stream. Dropping the sender also drops any
162        // batches still buffered in the channel — they are lost.
163        if let Ok(mut sender) = self.sender.lock() {
164            sender.take();
165        }
166    }
167
168    /// Whether the producer side has been closed.
169    pub fn is_closed(&self) -> Result<bool, ContinuousInputError> {
170        self.sender
171            .lock()
172            .map(|sender| sender.is_none())
173            .map_err(|error| ContinuousInputError::LockPoisoned(error.to_string()))
174    }
175
176    fn sender_clone(&self) -> Result<mpsc::Sender<RecordBatch>, ContinuousInputError> {
177        self.sender
178            .lock()
179            .map_err(|error| ContinuousInputError::LockPoisoned(error.to_string()))?
180            .clone()
181            .ok_or(ContinuousInputError::Closed)
182    }
183
184    fn validate_schema(&self, batch: &RecordBatch) -> Result<(), ContinuousInputError> {
185        if batch.schema().as_ref() != self.schema.as_ref() {
186            return Err(ContinuousInputError::SchemaMismatch {
187                expected: format!("{:?}", self.schema),
188                actual: format!("{:?}", batch.schema()),
189            });
190        }
191        Ok(())
192    }
193}
194
195/// Creates a new continuous-table provider and its schema-bound producer.
196/// The channel is bounded (capacity
197/// `CONTINUOUS_TABLE_CHANNEL_CAPACITY`) so a slow DataFusion consumer
198/// applies backpressure via [`ContinuousTableInput::send`], or
199/// [`ContinuousTableInput::try_send`] returns a resource-exhausted error.
200pub fn create_continuous_table(
201    schema: SchemaRef,
202) -> datafusion::error::Result<(Arc<dyn TableProvider>, Arc<ContinuousTableInput>)> {
203    create_continuous_table_with_capacity(schema, CONTINUOUS_TABLE_CHANNEL_CAPACITY)
204}
205
206/// Same as [`create_continuous_table`] but with a caller-supplied
207/// capacity. Useful for tests that want to exercise the full/empty
208/// channel boundary without needing to push 64 batches.
209pub fn create_continuous_table_with_capacity(
210    schema: SchemaRef,
211    capacity: usize,
212) -> datafusion::error::Result<(Arc<dyn TableProvider>, Arc<ContinuousTableInput>)> {
213    let (tx, rx) = mpsc::channel(capacity.max(1));
214    let partition = Arc::new(ChannelPartitionStream::new(schema.clone(), rx));
215    let table = StreamingTable::try_new(schema.clone(), vec![partition])?;
216    Ok((
217        Arc::new(table),
218        Arc::new(ContinuousTableInput::new(schema, tx)),
219    ))
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    fn make_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]))
    }

    fn make_batch(values: Vec<i32>) -> RecordBatch {
        RecordBatch::try_new(make_schema(), vec![Arc::new(Int32Array::from(values))]).unwrap()
    }

    #[tokio::test]
    async fn create_continuous_table_with_capacity_zero_is_clamped_to_one() {
        let schema = make_schema();
        // Capacity 0 is clamped to 1. Without the clamp, tokio's
        // `mpsc::channel(0)` would panic inside the constructor.
        let (table, tx) = create_continuous_table_with_capacity(schema, 0).unwrap();
        tx.try_send(make_batch(vec![1]))
            .expect("capacity should be >= 1");
        // With exactly one slot and no consumer, the second try_send is Full.
        let second = tx.try_send(make_batch(vec![2]));
        assert!(
            matches!(second, Err(ContinuousInputError::QueueFull)),
            "expected QueueFull, got {second:?}"
        );
        drop(table);
    }

    #[tokio::test]
    async fn bounded_channel_rejects_oversized_queue_via_try_send() {
        let schema = make_schema();
        let (table, tx) = create_continuous_table_with_capacity(schema, 2).unwrap();
        // Fill to capacity (DataFusion does not pull until execute is
        // called by the query plan). try_send must return Full once full.
        assert!(tx.try_send(make_batch(vec![1])).is_ok());
        assert!(tx.try_send(make_batch(vec![2])).is_ok());
        let third = tx.try_send(make_batch(vec![3]));
        assert!(
            matches!(third, Err(ContinuousInputError::QueueFull)),
            "expected Full, got {third:?}"
        );
        drop(table);
    }

    #[tokio::test]
    async fn continuous_input_rejects_schema_mismatch_and_close_is_idempotent() {
        let (table, input) = create_continuous_table(make_schema()).unwrap();
        let wrong_schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
        let wrong_batch = RecordBatch::try_new(
            wrong_schema,
            vec![Arc::new(arrow::array::Int64Array::from(vec![1]))],
        )
        .unwrap();

        let error = input
            .try_send(wrong_batch)
            .expect_err("schema mismatch must fail");
        assert!(matches!(error, ContinuousInputError::SchemaMismatch { .. }));
        assert!(input.close().unwrap());
        assert!(!input.close().unwrap());
        assert!(input.is_closed().unwrap());
        assert!(matches!(
            input.try_send(make_batch(vec![1])),
            Err(ContinuousInputError::Closed)
        ));
        drop(table);
    }

    #[tokio::test]
    async fn cancel_is_idempotent_and_rejects_further_input() {
        let (table, input) = create_continuous_table(make_schema()).unwrap();
        input.cancel();
        input.cancel();
        assert!(input.is_closed().unwrap());
        assert!(
            !input.close().unwrap(),
            "cancel must already have closed the input"
        );
        assert!(matches!(
            input.try_send(make_batch(vec![1])),
            Err(ContinuousInputError::Closed)
        ));
        drop(table);
    }

    #[tokio::test]
    async fn partition_stream_yields_queued_batches_and_is_single_use() {
        let (tx, rx) = mpsc::channel(4);
        let partition = ChannelPartitionStream::new(make_schema(), rx);
        tx.send(make_batch(vec![1, 2])).await.unwrap();
        // Dropping the only sender lets the stream end after the queued batch.
        drop(tx);

        let ctx = Arc::new(TaskContext::default());
        let mut stream = partition.execute(Arc::clone(&ctx));
        let batch = stream
            .next()
            .await
            .expect("one queued batch")
            .expect("batch should be Ok");
        assert_eq!(batch.num_rows(), 2);
        assert!(stream.next().await.is_none());

        // The receiver has been handed out; a second execution must fail.
        let mut second = partition.execute(ctx);
        assert!(matches!(second.next().await, Some(Err(_))));
    }
}