Skip to main content

laminar_sql/datafusion/
bridge.rs

//! Bridges `LaminarDB`'s push-based reactor with `DataFusion`'s pull-based
//! query execution via a crossfire mpsc channel wrapped as a `RecordBatchStream`.
3
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use arrow_array::RecordBatch;
9use arrow_schema::SchemaRef;
10use crossfire::stream::AsyncStream;
11use crossfire::{mpsc, AsyncRx, MAsyncTx, TrySendError};
12use datafusion::physical_plan::RecordBatchStream;
13use datafusion_common::DataFusionError;
14use futures::Stream;
15
/// Channel capacity applied when a bridge is built through
/// `StreamBridge::with_default_capacity` instead of an explicit size.
const DEFAULT_CHANNEL_CAPACITY: usize = 1024;
18
19/// Push-to-pull bridge carrying `RecordBatch` results from the reactor
20/// into a `DataFusion` query execution plan.
21#[derive(Debug)]
22pub struct StreamBridge {
23    schema: SchemaRef,
24    sender: BridgeSender,
25    receiver: Option<AsyncRx<mpsc::Array<Result<RecordBatch, DataFusionError>>>>,
26}
27
28impl StreamBridge {
29    /// Creates a new bridge with the given schema and channel capacity.
30    #[must_use]
31    pub fn new(schema: SchemaRef, capacity: usize) -> Self {
32        let (tx, rx) = mpsc::bounded_async::<Result<RecordBatch, DataFusionError>>(capacity);
33        Self {
34            schema,
35            sender: BridgeSender { tx },
36            receiver: Some(rx),
37        }
38    }
39
40    /// Creates a new bridge with default capacity.
41    #[must_use]
42    pub fn with_default_capacity(schema: SchemaRef) -> Self {
43        Self::new(schema, DEFAULT_CHANNEL_CAPACITY)
44    }
45
46    /// Returns the schema for this bridge.
47    #[must_use]
48    pub fn schema(&self) -> SchemaRef {
49        Arc::clone(&self.schema)
50    }
51
52    /// Returns a cloneable sender for pushing batches into the bridge.
53    ///
54    /// Multiple senders can be created by cloning the returned sender.
55    #[must_use]
56    pub fn sender(&self) -> BridgeSender {
57        self.sender.clone()
58    }
59
60    /// Converts this bridge into a `RecordBatchStream` for `DataFusion`.
61    ///
62    /// This consumes the bridge, taking ownership of the receiver.
63    /// After calling this, you can still use senders obtained from `sender()`.
64    ///
65    /// # Panics
66    ///
67    /// Panics if called more than once (the receiver can only be taken once).
68    #[must_use]
69    pub fn into_stream(mut self) -> BridgeStream {
70        BridgeStream {
71            schema: self.schema,
72            receiver: self
73                .receiver
74                .take()
75                .expect("receiver already taken")
76                .into_stream(),
77        }
78    }
79
80    /// Creates a stream without consuming the bridge.
81    ///
82    /// This takes ownership of the receiver, so subsequent calls will return `None`.
83    #[must_use]
84    pub fn take_stream(&mut self) -> Option<BridgeStream> {
85        self.receiver.take().map(|receiver| BridgeStream {
86            schema: Arc::clone(&self.schema),
87            receiver: receiver.into_stream(),
88        })
89    }
90}
91
92/// A cloneable sender for pushing `RecordBatch` instances into a bridge.
93///
94/// Multiple producers can share senders by cloning this type.
95#[derive(Debug, Clone)]
96pub struct BridgeSender {
97    tx: MAsyncTx<mpsc::Array<Result<RecordBatch, DataFusionError>>>,
98}
99
100impl BridgeSender {
101    /// Sends a batch to the bridge.
102    ///
103    /// # Errors
104    ///
105    /// Returns an error if the receiver has been dropped.
106    pub async fn send(&self, batch: RecordBatch) -> Result<(), BridgeSendError> {
107        self.tx
108            .send(Ok(batch))
109            .await
110            .map_err(|_| BridgeSendError::ReceiverDropped)
111    }
112
113    /// Sends an error to the bridge.
114    ///
115    /// This allows the producer to signal errors to the consumer.
116    ///
117    /// # Errors
118    ///
119    /// Returns an error if the receiver has been dropped.
120    pub async fn send_error(&self, error: DataFusionError) -> Result<(), BridgeSendError> {
121        self.tx
122            .send(Err(error))
123            .await
124            .map_err(|_| BridgeSendError::ReceiverDropped)
125    }
126
127    /// Attempts to send a batch without waiting.
128    ///
129    /// # Errors
130    ///
131    /// Returns an error if the channel is full or the receiver is dropped.
132    pub fn try_send(&self, batch: RecordBatch) -> Result<(), BridgeTrySendError> {
133        self.tx.try_send(Ok(batch)).map_err(|e| match e {
134            TrySendError::Full(_) => BridgeTrySendError::Full,
135            TrySendError::Disconnected(_) => BridgeTrySendError::ReceiverDropped,
136        })
137    }
138
139    /// Returns true if the receiver has been dropped.
140    #[must_use]
141    pub fn is_closed(&self) -> bool {
142        self.tx.is_disconnected()
143    }
144}
145
146/// A stream that pulls `RecordBatch` instances from the bridge.
147///
148/// This implements both `Stream` and `DataFusion`'s `RecordBatchStream`
149/// so it can be used directly in `DataFusion` query execution.
150pub struct BridgeStream {
151    schema: SchemaRef,
152    receiver: AsyncStream<mpsc::Array<Result<RecordBatch, DataFusionError>>>,
153}
154
155impl std::fmt::Debug for BridgeStream {
156    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157        f.debug_struct("BridgeStream")
158            .field("schema", &self.schema)
159            .finish_non_exhaustive()
160    }
161}
162
163impl Stream for BridgeStream {
164    type Item = Result<RecordBatch, DataFusionError>;
165
166    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
167        Pin::new(&mut self.receiver).poll_next(cx)
168    }
169}
170
171impl RecordBatchStream for BridgeStream {
172    fn schema(&self) -> SchemaRef {
173        Arc::clone(&self.schema)
174    }
175}
176
177/// Error when sending a batch to the bridge.
178#[derive(Debug, thiserror::Error)]
179pub enum BridgeSendError {
180    /// The receiver has been dropped.
181    #[error("bridge receiver has been dropped")]
182    ReceiverDropped,
183}
184
185/// Error when trying to send a batch without blocking.
186#[derive(Debug, thiserror::Error)]
187pub enum BridgeTrySendError {
188    /// The channel is full.
189    #[error("bridge channel is full")]
190    Full,
191    /// The receiver has been dropped.
192    #[error("bridge receiver has been dropped")]
193    ReceiverDropped,
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};
    use futures::StreamExt;

    /// Single non-nullable `id: Int64` column shared by all tests.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]))
    }

    /// Builds a one-column batch for `schema` holding `values`.
    fn test_batch(schema: &SchemaRef, values: Vec<i64>) -> RecordBatch {
        let array = Arc::new(Int64Array::from(values));
        RecordBatch::try_new(Arc::clone(schema), vec![array]).unwrap()
    }

    /// Drains `stream`, unwrapping every item, and returns how many arrived.
    async fn drain_count(stream: &mut BridgeStream) -> usize {
        let mut count = 0;
        while let Some(result) = stream.next().await {
            result.unwrap();
            count += 1;
        }
        count
    }

    #[tokio::test]
    async fn test_bridge_send_receive() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10);
        let sender = bridge.sender();
        let mut stream = bridge.into_stream();

        // The batch is not used after sending, so it is moved, not cloned.
        sender.send(test_batch(&schema, vec![1, 2, 3])).await.unwrap();
        drop(sender); // Close the channel

        let received = stream.next().await.unwrap().unwrap();
        assert_eq!(received.num_rows(), 3);

        // Stream should end once every sender is gone.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_bridge_multiple_batches() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10);
        let sender = bridge.sender();
        let mut stream = bridge.into_stream();

        for i in 0..5 {
            let batch = test_batch(&schema, vec![i64::from(i)]);
            sender.send(batch).await.unwrap();
        }
        drop(sender);

        assert_eq!(drain_count(&mut stream).await, 5);
    }

    #[tokio::test]
    async fn test_bridge_sender_clone() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10);
        let sender1 = bridge.sender();
        let sender2 = sender1.clone();
        let mut stream = bridge.into_stream();

        // Both handles feed the same channel.
        sender1.send(test_batch(&schema, vec![1])).await.unwrap();
        sender2.send(test_batch(&schema, vec![2])).await.unwrap();
        drop(sender1);
        drop(sender2);

        assert_eq!(drain_count(&mut stream).await, 2);
    }

    #[tokio::test]
    async fn test_bridge_send_error() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10);
        let sender = bridge.sender();
        let mut stream = bridge.into_stream();

        sender
            .send_error(DataFusionError::Plan("test error".to_string()))
            .await
            .unwrap();
        drop(sender);

        // The error must arrive unchanged, not merely as "some error".
        let result = stream.next().await.unwrap();
        assert!(matches!(
            &result,
            Err(DataFusionError::Plan(msg)) if msg.as_str() == "test error"
        ));

        // And the stream still terminates after delivering it.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_bridge_try_send() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 2);
        let sender = bridge.sender();
        // Keep the stream alive to prevent channel close
        let _stream = bridge.into_stream();

        // Fill the channel
        sender.try_send(test_batch(&schema, vec![1])).unwrap();
        sender.try_send(test_batch(&schema, vec![2])).unwrap();

        // Should fail when full
        let result = sender.try_send(test_batch(&schema, vec![3]));
        assert!(matches!(result, Err(BridgeTrySendError::Full)));
    }

    #[tokio::test]
    async fn test_bridge_receiver_dropped() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10);
        let sender = bridge.sender();
        let stream = bridge.into_stream();
        drop(stream);

        // Should detect closed channel
        assert!(sender.is_closed());

        let result = sender.send(test_batch(&schema, vec![1])).await;
        assert!(matches!(result, Err(BridgeSendError::ReceiverDropped)));
    }

    #[tokio::test]
    async fn test_bridge_take_stream() {
        let schema = test_schema();
        let mut bridge = StreamBridge::new(Arc::clone(&schema), 10);
        let sender = bridge.sender();

        // Only the first call can hand out the receiver.
        let mut stream = bridge
            .take_stream()
            .expect("first take_stream call yields the stream");
        assert!(bridge.take_stream().is_none());
        assert_eq!(RecordBatchStream::schema(&stream), schema);

        sender.send(test_batch(&schema, vec![7, 8])).await.unwrap();
        let received = stream.next().await.unwrap().unwrap();
        assert_eq!(received.num_rows(), 2);
    }

    #[test]
    fn test_bridge_with_default_capacity() {
        let schema = test_schema();
        let bridge = StreamBridge::with_default_capacity(Arc::clone(&schema));

        assert_eq!(bridge.schema(), schema);
    }

    #[test]
    fn test_bridge_stream_schema() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10);
        let stream = bridge.into_stream();

        assert_eq!(RecordBatchStream::schema(&stream), schema);
    }
}