// laminar_sql/datafusion/channel_source.rs

//! Channel-based streaming source implementation
//!
//! This module provides `ChannelStreamSource`, the primary integration point
//! between `LaminarDB`'s Reactor and `DataFusion`'s query engine.

use std::fmt::{Debug, Formatter};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use arrow_array::RecordBatch;
use arrow_schema::SchemaRef;
use async_trait::async_trait;
use datafusion::physical_plan::RecordBatchStream;
use datafusion_common::DataFusionError;
use datafusion_expr::Expr;
use futures::Stream;
use parking_lot::Mutex;

use super::bridge::{BridgeSender, StreamBridge};
use super::source::{SortColumn, StreamSource};

/// Default bounded-channel capacity (in batches) used by [`ChannelStreamSource::new`].
const DEFAULT_CHANNEL_CAPACITY: usize = 1024;

/// A streaming source that receives data through a channel.
///
/// This is the primary integration point between `LaminarDB`'s push-based
/// Reactor and `DataFusion`'s pull-based query execution. Data is pushed
/// into the source via `BridgeSender`, and `DataFusion` pulls it through
/// the stream.
///
/// # Important Usage Pattern
///
/// The sender must be taken (not cloned) to ensure proper channel closure:
///
/// ```rust,ignore
/// // Create the source and take the sender
/// let source = ChannelStreamSource::new(schema);
/// let sender = source.take_sender().expect("sender available");
///
/// // Register with `DataFusion`
/// let provider = StreamingTableProvider::new("events", Arc::new(source));
/// ctx.register_table("events", Arc::new(provider))?;
///
/// // Push data from Reactor
/// sender.send(batch).await?;
///
/// // IMPORTANT: Drop the sender to close the channel before querying
/// drop(sender);
///
/// // Execute query
/// let df = ctx.sql("SELECT * FROM events").await?;
/// let results = df.collect().await?;
/// ```
///
/// # Thread Safety
///
/// The source is thread-safe and can be shared across threads. The sender
/// can be cloned after being taken to allow multiple producers.
pub struct ChannelStreamSource {
    /// Schema of the data produced by this source.
    schema: SchemaRef,
    /// The bridge connecting sender and receivers; taken (left `None`)
    /// when `stream()` is called, restored by `reset()`.
    bridge: Mutex<Option<StreamBridge>>,
    /// Sender for pushing data — must be taken, not cloned, so the caller
    /// owns channel closure (see usage pattern above).
    sender: Mutex<Option<BridgeSender>>,
    /// Channel capacity, reused when `reset()` rebuilds the bridge.
    capacity: usize,
    /// Declared output ordering (for ORDER BY elision).
    ordering: Option<Vec<SortColumn>>,
}

74impl ChannelStreamSource {
75    /// Creates a new channel stream source with default capacity.
76    #[must_use]
77    pub fn new(schema: SchemaRef) -> Self {
78        Self::with_capacity(schema, DEFAULT_CHANNEL_CAPACITY)
79    }
80
81    /// Creates a new channel stream source with the given channel capacity.
82    #[must_use]
83    pub fn with_capacity(schema: SchemaRef, capacity: usize) -> Self {
84        let bridge = StreamBridge::new(Arc::clone(&schema), capacity);
85        let sender = bridge.sender();
86        Self {
87            schema,
88            bridge: Mutex::new(Some(bridge)),
89            sender: Mutex::new(Some(sender)),
90            capacity,
91            ordering: None,
92        }
93    }
94
95    /// Declares that this source produces data in the given sort order.
96    /// When set, `DataFusion` can elide `SortExec` for ORDER BY queries
97    /// that match the declared ordering.
98    #[must_use]
99    pub fn with_ordering(mut self, ordering: Vec<SortColumn>) -> Self {
100        self.ordering = Some(ordering);
101        self
102    }
103
104    /// Takes the sender for pushing batches into this source.
105    ///
106    /// This method can only be called once. The sender is moved out of
107    /// the source to ensure the caller has full ownership and can close
108    /// the channel by dropping the sender.
109    ///
110    /// The returned sender can be cloned to allow multiple producers.
111    ///
112    /// Returns `None` if the sender was already taken.
113    #[must_use]
114    pub fn take_sender(&self) -> Option<BridgeSender> {
115        self.sender.lock().take()
116    }
117
118    /// Returns a clone of the sender if it hasn't been taken yet.
119    ///
120    /// **Warning**: Using this method can lead to channel leak issues if
121    /// the original sender is never dropped. Prefer `take_sender()` for
122    /// proper channel lifecycle management.
123    #[must_use]
124    pub fn sender(&self) -> Option<BridgeSender> {
125        self.sender.lock().as_ref().map(BridgeSender::clone)
126    }
127
128    /// Resets the source with a new bridge and sender.
129    ///
130    /// This is useful when you need to reuse the source after the previous
131    /// stream has been consumed. Any data sent before the reset but not
132    /// yet consumed will be lost.
133    ///
134    /// Returns the new sender.
135    pub fn reset(&self) -> BridgeSender {
136        let bridge = StreamBridge::new(Arc::clone(&self.schema), self.capacity);
137        let sender = bridge.sender();
138        *self.bridge.lock() = Some(bridge);
139        *self.sender.lock() = Some(sender.clone());
140        sender
141    }
142}
143
144impl Debug for ChannelStreamSource {
145    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
146        f.debug_struct("ChannelStreamSource")
147            .field("schema", &self.schema)
148            .field("capacity", &self.capacity)
149            .finish_non_exhaustive()
150    }
151}
152
153#[async_trait]
154impl StreamSource for ChannelStreamSource {
155    fn schema(&self) -> SchemaRef {
156        Arc::clone(&self.schema)
157    }
158
159    fn output_ordering(&self) -> Option<Vec<SortColumn>> {
160        self.ordering.clone()
161    }
162
163    fn stream(
164        &self,
165        projection: Option<Vec<usize>>,
166        _filters: Vec<Expr>,
167    ) -> Result<datafusion::physical_plan::SendableRecordBatchStream, DataFusionError> {
168        let mut bridge_guard = self.bridge.lock();
169        let bridge = bridge_guard.take().ok_or_else(|| {
170            DataFusionError::Execution(
171                "Stream already taken; call reset() to create a new bridge".to_string(),
172            )
173        })?;
174
175        let inner_stream = bridge.into_stream();
176
177        // Apply projection if specified
178        let stream: datafusion::physical_plan::SendableRecordBatchStream =
179            if let Some(indices) = projection {
180                let projected_schema = {
181                    let fields: Vec<_> = indices
182                        .iter()
183                        .map(|&i| self.schema.field(i).clone())
184                        .collect();
185                    Arc::new(arrow_schema::Schema::new(fields))
186                };
187                Box::pin(ProjectingStream::new(
188                    inner_stream,
189                    projected_schema,
190                    indices,
191                ))
192            } else {
193                Box::pin(inner_stream)
194            };
195
196        Ok(stream)
197    }
198}
199
/// A stream that applies column projection to record batches.
struct ProjectingStream<S> {
    /// Upstream stream of full-width batches.
    inner: S,
    /// Schema of the projected output (subset of the upstream schema).
    schema: SchemaRef,
    /// Column indices (into the upstream schema) to keep, in output order.
    indices: Vec<usize>,
}

207impl<S> ProjectingStream<S> {
208    fn new(inner: S, schema: SchemaRef, indices: Vec<usize>) -> Self {
209        Self {
210            inner,
211            schema,
212            indices,
213        }
214    }
215}
216
217impl<S> Debug for ProjectingStream<S> {
218    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
219        f.debug_struct("ProjectingStream")
220            .field("schema", &self.schema)
221            .field("indices", &self.indices)
222            .finish_non_exhaustive()
223    }
224}
225
226impl<S> Stream for ProjectingStream<S>
227where
228    S: Stream<Item = Result<RecordBatch, DataFusionError>> + Unpin,
229{
230    type Item = Result<RecordBatch, DataFusionError>;
231
232    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
233        match Pin::new(&mut self.inner).poll_next(cx) {
234            Poll::Ready(Some(Ok(batch))) => {
235                // Project columns using built-in projection (avoids intermediate Vec alloc)
236                let projected = batch.project(&self.indices).map_err(|e| {
237                    DataFusionError::ArrowError(Box::new(e), Some("projection failed".to_string()))
238                });
239                Poll::Ready(Some(projected))
240            }
241            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
242            Poll::Ready(None) => Poll::Ready(None),
243            Poll::Pending => Poll::Pending,
244        }
245    }
246}
247
248impl<S> RecordBatchStream for ProjectingStream<S>
249where
250    S: Stream<Item = Result<RecordBatch, DataFusionError>> + Unpin,
251{
252    fn schema(&self) -> SchemaRef {
253        Arc::clone(&self.schema)
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};
    use futures::StreamExt;

    /// Two non-nullable Int64 columns, `id` and `value`, shared by all tests.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Int64, false),
        ]))
    }

    /// Builds a batch from parallel `ids`/`values` vectors under `schema`.
    fn test_batch(schema: &SchemaRef, ids: Vec<i64>, values: Vec<i64>) -> RecordBatch {
        RecordBatch::try_new(
            Arc::clone(schema),
            vec![
                Arc::new(Int64Array::from(ids)),
                Arc::new(Int64Array::from(values)),
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_channel_source_schema() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        assert_eq!(source.schema(), schema);
    }

    #[tokio::test]
    async fn test_channel_source_stream() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));
        let sender = source.take_sender().unwrap();

        let mut stream = source.stream(None, vec![]).unwrap();

        // Send data
        sender
            .send(test_batch(&schema, vec![1, 2], vec![10, 20]))
            .await
            .unwrap();
        // Dropping the sender closes the channel so the stream terminates.
        drop(sender);

        // Receive data
        let batch = stream.next().await.unwrap().unwrap();
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 2);
    }

    #[tokio::test]
    async fn test_channel_source_projection() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));
        let sender = source.take_sender().unwrap();

        // Project only the "value" column (index 1)
        let mut stream = source.stream(Some(vec![1]), vec![]).unwrap();

        sender
            .send(test_batch(&schema, vec![1, 2], vec![100, 200]))
            .await
            .unwrap();
        drop(sender);

        // The projected batch should expose only the "value" column.
        let batch = stream.next().await.unwrap().unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.schema().field(0).name(), "value");

        let values = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(values.value(0), 100);
        assert_eq!(values.value(1), 200);
    }

    #[tokio::test]
    async fn test_channel_source_stream_already_taken() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        // First stream takes ownership of the bridge
        let _stream = source.stream(None, vec![]).unwrap();

        // Second stream should fail until reset() is called
        let result = source.stream(None, vec![]);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_channel_source_multiple_batches() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));
        let sender = source.take_sender().unwrap();
        let mut stream = source.stream(None, vec![]).unwrap();

        // Send multiple batches
        for i in 0..5i64 {
            sender
                .send(test_batch(&schema, vec![i], vec![i * 10]))
                .await
                .unwrap();
        }
        drop(sender);

        // Receive all batches; the stream ends once the sender is dropped.
        let mut count = 0;
        while let Some(result) = stream.next().await {
            result.unwrap();
            count += 1;
        }
        assert_eq!(count, 5);
    }

    #[tokio::test]
    async fn test_channel_source_take_sender_once() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        // First take succeeds
        let sender = source.take_sender();
        assert!(sender.is_some());

        // Second take returns None
        let sender2 = source.take_sender();
        assert!(sender2.is_none());
    }

    #[tokio::test]
    async fn test_channel_source_reset() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        // Take sender and stream, exhausting the original bridge
        let _sender = source.take_sender().unwrap();
        let _stream = source.stream(None, vec![]).unwrap();

        // Reset creates new bridge and sender
        let new_sender = source.reset();
        let mut new_stream = source.stream(None, vec![]).unwrap();

        // Can use the new sender and stream
        new_sender
            .send(test_batch(&schema, vec![1], vec![10]))
            .await
            .unwrap();
        drop(new_sender);

        let batch = new_stream.next().await.unwrap().unwrap();
        assert_eq!(batch.num_rows(), 1);
    }

    #[test]
    fn test_channel_source_debug() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        let debug_str = format!("{source:?}");
        assert!(debug_str.contains("ChannelStreamSource"));
        assert!(debug_str.contains("capacity"));
    }

    #[test]
    fn test_channel_source_default_no_ordering() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        assert!(source.output_ordering().is_none());
    }

    #[test]
    fn test_channel_source_with_ordering() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema))
            .with_ordering(vec![SortColumn::ascending("id")]);

        let ordering = source.output_ordering();
        assert!(ordering.is_some());
        let cols = ordering.unwrap();
        assert_eq!(cols.len(), 1);
        assert_eq!(cols[0].name, "id");
        assert!(!cols[0].descending);
    }
}