// laminar_sql/datafusion/channel_source.rs
1//! Channel-based streaming source implementation
2//!
3//! This module provides `ChannelStreamSource`, the primary integration point
4//! between `LaminarDB`'s Reactor and `DataFusion`'s query engine.
5
6use std::fmt::{Debug, Formatter};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use arrow_array::RecordBatch;
12use arrow_schema::SchemaRef;
13use async_trait::async_trait;
14use datafusion::physical_plan::RecordBatchStream;
15use datafusion_common::DataFusionError;
16use datafusion_expr::Expr;
17use futures::Stream;
18use parking_lot::Mutex;
19
20use super::bridge::{BridgeSender, StreamBridge};
21use super::source::{SortColumn, StreamSource};
22
23/// Default channel capacity for the stream source.
24const DEFAULT_CHANNEL_CAPACITY: usize = 1024;
25
26/// A streaming source that receives data through a channel.
27///
28/// This is the primary integration point between `LaminarDB`'s push-based
29/// Reactor and `DataFusion`'s pull-based query execution. Data is pushed
30/// into the source via `BridgeSender`, and `DataFusion` pulls it through
31/// the stream.
32///
33/// # Important Usage Pattern
34///
35/// The sender must be taken (not cloned) to ensure proper channel closure:
36///
37/// ```rust,ignore
38/// // Create the source and take the sender
39/// let source = ChannelStreamSource::new(schema);
40/// let sender = source.take_sender().expect("sender available");
41///
42/// // Register with `DataFusion`
43/// let provider = StreamingTableProvider::new("events", Arc::new(source));
44/// ctx.register_table("events", Arc::new(provider))?;
45///
46/// // Push data from Reactor
47/// sender.send(batch).await?;
48///
49/// // IMPORTANT: Drop the sender to close the channel before querying
50/// drop(sender);
51///
52/// // Execute query
53/// let df = ctx.sql("SELECT * FROM events").await?;
54/// let results = df.collect().await?;
55/// ```
56///
57/// # Thread Safety
58///
59/// The source is thread-safe and can be shared across threads. The sender
60/// can be cloned after being taken to allow multiple producers.
pub struct ChannelStreamSource {
    /// Schema of the data; every `RecordBatch` pushed through the sender
    /// is expected to match this schema.
    schema: SchemaRef,
    /// The bridge connecting sender and receivers. Moved out by `stream()`,
    /// so `None` means a stream has already been created; `reset()`
    /// installs a fresh bridge.
    bridge: Mutex<Option<StreamBridge>>,
    /// Sender for pushing data - must be taken, not cloned, so the caller
    /// can close the channel by dropping it (see `take_sender`).
    sender: Mutex<Option<BridgeSender>>,
    /// Channel capacity (maximum buffered batches); reused by `reset()`.
    capacity: usize,
    /// Declared output ordering (for ORDER BY elision); `None` means the
    /// source makes no ordering guarantee.
    ordering: Option<Vec<SortColumn>>,
}
73
74impl ChannelStreamSource {
75    /// Creates a new channel stream source with default capacity.
76    ///
77    /// # Arguments
78    ///
79    /// * `schema` - Schema of the `RecordBatch` instances that will be pushed
80    #[must_use]
81    pub fn new(schema: SchemaRef) -> Self {
82        Self::with_capacity(schema, DEFAULT_CHANNEL_CAPACITY)
83    }
84
85    /// Creates a new channel stream source with specified capacity.
86    ///
87    /// # Arguments
88    ///
89    /// * `schema` - Schema of the `RecordBatch` instances that will be pushed
90    /// * `capacity` - Maximum number of batches that can be buffered
91    #[must_use]
92    pub fn with_capacity(schema: SchemaRef, capacity: usize) -> Self {
93        let bridge = StreamBridge::new(Arc::clone(&schema), capacity);
94        let sender = bridge.sender();
95        Self {
96            schema,
97            bridge: Mutex::new(Some(bridge)),
98            sender: Mutex::new(Some(sender)),
99            capacity,
100            ordering: None,
101        }
102    }
103
104    /// Declares that this source produces data in the given sort order.
105    ///
106    /// When set, `DataFusion` can elide `SortExec` for ORDER BY queries
107    /// that match the declared ordering.
108    ///
109    /// # Arguments
110    ///
111    /// * `ordering` - The columns that the output is sorted by
112    #[must_use]
113    pub fn with_ordering(mut self, ordering: Vec<SortColumn>) -> Self {
114        self.ordering = Some(ordering);
115        self
116    }
117
118    /// Takes the sender for pushing batches into this source.
119    ///
120    /// This method can only be called once. The sender is moved out of
121    /// the source to ensure the caller has full ownership and can close
122    /// the channel by dropping the sender.
123    ///
124    /// The returned sender can be cloned to allow multiple producers.
125    ///
126    /// Returns `None` if the sender was already taken.
127    #[must_use]
128    pub fn take_sender(&self) -> Option<BridgeSender> {
129        self.sender.lock().take()
130    }
131
132    /// Returns a clone of the sender if it hasn't been taken yet.
133    ///
134    /// **Warning**: Using this method can lead to channel leak issues if
135    /// the original sender is never dropped. Prefer `take_sender()` for
136    /// proper channel lifecycle management.
137    #[must_use]
138    pub fn sender(&self) -> Option<BridgeSender> {
139        self.sender.lock().as_ref().map(BridgeSender::clone)
140    }
141
142    /// Resets the source with a new bridge and sender.
143    ///
144    /// This is useful when you need to reuse the source after the previous
145    /// stream has been consumed. Any data sent before the reset but not
146    /// yet consumed will be lost.
147    ///
148    /// Returns the new sender.
149    pub fn reset(&self) -> BridgeSender {
150        let bridge = StreamBridge::new(Arc::clone(&self.schema), self.capacity);
151        let sender = bridge.sender();
152        *self.bridge.lock() = Some(bridge);
153        *self.sender.lock() = Some(sender.clone());
154        sender
155    }
156}
157
158impl Debug for ChannelStreamSource {
159    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
160        f.debug_struct("ChannelStreamSource")
161            .field("schema", &self.schema)
162            .field("capacity", &self.capacity)
163            .finish_non_exhaustive()
164    }
165}
166
167#[async_trait]
168impl StreamSource for ChannelStreamSource {
169    fn schema(&self) -> SchemaRef {
170        Arc::clone(&self.schema)
171    }
172
173    fn output_ordering(&self) -> Option<Vec<SortColumn>> {
174        self.ordering.clone()
175    }
176
177    fn stream(
178        &self,
179        projection: Option<Vec<usize>>,
180        _filters: Vec<Expr>,
181    ) -> Result<datafusion::physical_plan::SendableRecordBatchStream, DataFusionError> {
182        let mut bridge_guard = self.bridge.lock();
183        let bridge = bridge_guard.take().ok_or_else(|| {
184            DataFusionError::Execution(
185                "Stream already taken; call reset() to create a new bridge".to_string(),
186            )
187        })?;
188
189        let inner_stream = bridge.into_stream();
190
191        // Apply projection if specified
192        let stream: datafusion::physical_plan::SendableRecordBatchStream =
193            if let Some(indices) = projection {
194                let projected_schema = {
195                    let fields: Vec<_> = indices
196                        .iter()
197                        .map(|&i| self.schema.field(i).clone())
198                        .collect();
199                    Arc::new(arrow_schema::Schema::new(fields))
200                };
201                Box::pin(ProjectingStream::new(
202                    inner_stream,
203                    projected_schema,
204                    indices,
205                ))
206            } else {
207                Box::pin(inner_stream)
208            };
209
210        Ok(stream)
211    }
212}
213
/// A stream that applies column projection to record batches.
struct ProjectingStream<S> {
    /// Underlying stream yielding unprojected batches.
    inner: S,
    /// Schema of the projected output (subset of the source schema).
    schema: SchemaRef,
    /// Column indices (into the source batches) to keep, in output order.
    indices: Vec<usize>,
}
220
221impl<S> ProjectingStream<S> {
222    fn new(inner: S, schema: SchemaRef, indices: Vec<usize>) -> Self {
223        Self {
224            inner,
225            schema,
226            indices,
227        }
228    }
229}
230
231impl<S> Debug for ProjectingStream<S> {
232    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
233        f.debug_struct("ProjectingStream")
234            .field("schema", &self.schema)
235            .field("indices", &self.indices)
236            .finish_non_exhaustive()
237    }
238}
239
240impl<S> Stream for ProjectingStream<S>
241where
242    S: Stream<Item = Result<RecordBatch, DataFusionError>> + Unpin,
243{
244    type Item = Result<RecordBatch, DataFusionError>;
245
246    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
247        match Pin::new(&mut self.inner).poll_next(cx) {
248            Poll::Ready(Some(Ok(batch))) => {
249                // Project columns
250                let columns: Vec<_> = self
251                    .indices
252                    .iter()
253                    .map(|&i| Arc::clone(batch.column(i)))
254                    .collect();
255                let projected =
256                    RecordBatch::try_new(Arc::clone(&self.schema), columns).map_err(|e| {
257                        DataFusionError::ArrowError(
258                            Box::new(e),
259                            Some("projection failed".to_string()),
260                        )
261                    });
262                Poll::Ready(Some(projected))
263            }
264            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
265            Poll::Ready(None) => Poll::Ready(None),
266            Poll::Pending => Poll::Pending,
267        }
268    }
269}
270
271impl<S> RecordBatchStream for ProjectingStream<S>
272where
273    S: Stream<Item = Result<RecordBatch, DataFusionError>> + Unpin,
274{
275    fn schema(&self) -> SchemaRef {
276        Arc::clone(&self.schema)
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};
    use futures::StreamExt;

    /// Two-column Int64 schema (`id`, `value`) shared by all tests.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Int64, false),
        ]))
    }

    /// Builds a batch over `test_schema()` from parallel id/value vectors.
    fn test_batch(schema: &SchemaRef, ids: Vec<i64>, values: Vec<i64>) -> RecordBatch {
        RecordBatch::try_new(
            Arc::clone(schema),
            vec![
                Arc::new(Int64Array::from(ids)),
                Arc::new(Int64Array::from(values)),
            ],
        )
        .unwrap()
    }

    // The source reports the schema it was constructed with.
    #[test]
    fn test_channel_source_schema() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        assert_eq!(source.schema(), schema);
    }

    // End-to-end push/pull: a batch sent via the sender is received
    // unmodified through the stream.
    #[tokio::test]
    async fn test_channel_source_stream() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));
        let sender = source.take_sender().unwrap();

        let mut stream = source.stream(None, vec![]).unwrap();

        // Send data
        sender
            .send(test_batch(&schema, vec![1, 2], vec![10, 20]))
            .await
            .unwrap();
        drop(sender);

        // Receive data
        let batch = stream.next().await.unwrap().unwrap();
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 2);
    }

    // A projection narrows both the schema and the batch columns.
    #[tokio::test]
    async fn test_channel_source_projection() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));
        let sender = source.take_sender().unwrap();

        // Project only the "value" column (index 1)
        let mut stream = source.stream(Some(vec![1]), vec![]).unwrap();

        sender
            .send(test_batch(&schema, vec![1, 2], vec![100, 200]))
            .await
            .unwrap();
        drop(sender);

        let batch = stream.next().await.unwrap().unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.schema().field(0).name(), "value");

        let values = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(values.value(0), 100);
        assert_eq!(values.value(1), 200);
    }

    // The bridge is single-use: a second stream() without reset() errors.
    #[tokio::test]
    async fn test_channel_source_stream_already_taken() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        // First stream takes ownership
        let _stream = source.stream(None, vec![]).unwrap();

        // Second stream should fail
        let result = source.stream(None, vec![]);
        assert!(result.is_err());
    }

    // All buffered batches are delivered, and the stream terminates once
    // the sender is dropped.
    #[tokio::test]
    async fn test_channel_source_multiple_batches() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));
        let sender = source.take_sender().unwrap();
        let mut stream = source.stream(None, vec![]).unwrap();

        // Send multiple batches
        for i in 0..5i64 {
            sender
                .send(test_batch(&schema, vec![i], vec![i * 10]))
                .await
                .unwrap();
        }
        drop(sender);

        // Receive all batches
        let mut count = 0;
        while let Some(result) = stream.next().await {
            result.unwrap();
            count += 1;
        }
        assert_eq!(count, 5);
    }

    // take_sender() moves the sender out exactly once.
    #[tokio::test]
    async fn test_channel_source_take_sender_once() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        // First take succeeds
        let sender = source.take_sender();
        assert!(sender.is_some());

        // Second take returns None
        let sender2 = source.take_sender();
        assert!(sender2.is_none());
    }

    // reset() installs a fresh bridge/sender pair after the originals
    // have been consumed, making the source usable again.
    #[tokio::test]
    async fn test_channel_source_reset() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        // Take sender and stream
        let _sender = source.take_sender().unwrap();
        let _stream = source.stream(None, vec![]).unwrap();

        // Reset creates new bridge and sender
        let new_sender = source.reset();
        let mut new_stream = source.stream(None, vec![]).unwrap();

        // Can use the new sender and stream
        new_sender
            .send(test_batch(&schema, vec![1], vec![10]))
            .await
            .unwrap();
        drop(new_sender);

        let batch = new_stream.next().await.unwrap().unwrap();
        assert_eq!(batch.num_rows(), 1);
    }

    // Debug output names the type and includes the capacity field.
    #[test]
    fn test_channel_source_debug() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        let debug_str = format!("{source:?}");
        assert!(debug_str.contains("ChannelStreamSource"));
        assert!(debug_str.contains("capacity"));
    }

    // A freshly-constructed source declares no output ordering.
    #[test]
    fn test_channel_source_default_no_ordering() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        assert!(source.output_ordering().is_none());
    }

    // with_ordering() is surfaced verbatim through output_ordering().
    #[test]
    fn test_channel_source_with_ordering() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema))
            .with_ordering(vec![SortColumn::ascending("id")]);

        let ordering = source.output_ordering();
        assert!(ordering.is_some());
        let cols = ordering.unwrap();
        assert_eq!(cols.len(), 1);
        assert_eq!(cols[0].name, "id");
        assert!(!cols[0].descending);
    }
}
469}