laminar_sql/datafusion/
bridge.rs1use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24
25use arrow_array::RecordBatch;
26use arrow_schema::SchemaRef;
27use datafusion::physical_plan::RecordBatchStream;
28use datafusion_common::DataFusionError;
29use futures::Stream;
30use tokio::sync::mpsc;
31
/// Number of in-flight items buffered by `StreamBridge::with_default_capacity`
/// (1024 entries).
const DEFAULT_CHANNEL_CAPACITY: usize = 1 << 10;
34
/// Connects push-style producers to a pull-style DataFusion [`RecordBatchStream`].
///
/// Producers obtain [`BridgeSender`] handles via [`StreamBridge::sender`] and push
/// batches (or errors) into a bounded tokio `mpsc` channel. The consuming side is
/// obtained exactly once, via [`StreamBridge::into_stream`] or
/// [`StreamBridge::take_stream`].
#[derive(Debug)]
pub struct StreamBridge {
    /// Schema reported by the bridge and by the stream it produces.
    /// Batches sent through the bridge are not validated against it.
    schema: SchemaRef,
    /// Sender retained by the bridge itself; `sender()` hands out clones of it.
    sender: BridgeSender,
    /// Receiving half of the channel; `None` once the stream has been taken.
    receiver: Option<mpsc::Receiver<Result<RecordBatch, DataFusionError>>>,
}
64
65impl StreamBridge {
66 #[must_use]
73 pub fn new(schema: SchemaRef, capacity: usize) -> Self {
74 let (tx, rx) = mpsc::channel(capacity);
75 Self {
76 schema,
77 sender: BridgeSender { tx },
78 receiver: Some(rx),
79 }
80 }
81
82 #[must_use]
84 pub fn with_default_capacity(schema: SchemaRef) -> Self {
85 Self::new(schema, DEFAULT_CHANNEL_CAPACITY)
86 }
87
88 #[must_use]
90 pub fn schema(&self) -> SchemaRef {
91 Arc::clone(&self.schema)
92 }
93
94 #[must_use]
98 pub fn sender(&self) -> BridgeSender {
99 self.sender.clone()
100 }
101
102 #[must_use]
111 pub fn into_stream(mut self) -> BridgeStream {
112 BridgeStream {
113 schema: self.schema,
114 receiver: self.receiver.take().expect("receiver already taken"),
115 }
116 }
117
118 #[must_use]
122 pub fn take_stream(&mut self) -> Option<BridgeStream> {
123 self.receiver.take().map(|receiver| BridgeStream {
124 schema: Arc::clone(&self.schema),
125 receiver,
126 })
127 }
128}
129
/// Cloneable producer handle for a [`StreamBridge`].
///
/// All clones share the same underlying channel. The consuming stream ends only
/// after every handle, including the one retained by the owning `StreamBridge`,
/// has been dropped.
#[derive(Debug, Clone)]
pub struct BridgeSender {
    /// Sending half of the bridge channel; each item is a batch or an error.
    tx: mpsc::Sender<Result<RecordBatch, DataFusionError>>,
}
137
138impl BridgeSender {
139 pub async fn send(&self, batch: RecordBatch) -> Result<(), BridgeSendError> {
145 self.tx
146 .send(Ok(batch))
147 .await
148 .map_err(|_| BridgeSendError::ReceiverDropped)
149 }
150
151 pub async fn send_error(&self, error: DataFusionError) -> Result<(), BridgeSendError> {
159 self.tx
160 .send(Err(error))
161 .await
162 .map_err(|_| BridgeSendError::ReceiverDropped)
163 }
164
165 pub fn try_send(&self, batch: RecordBatch) -> Result<(), BridgeTrySendError> {
171 self.tx.try_send(Ok(batch)).map_err(|e| match e {
172 mpsc::error::TrySendError::Full(_) => BridgeTrySendError::Full,
173 mpsc::error::TrySendError::Closed(_) => BridgeTrySendError::ReceiverDropped,
174 })
175 }
176
177 #[must_use]
179 pub fn is_closed(&self) -> bool {
180 self.tx.is_closed()
181 }
182}
183
/// Consuming side of a [`StreamBridge`], usable as a DataFusion [`RecordBatchStream`].
///
/// Yields whatever the bridge's senders push: `Ok` batches or `Err` values sent
/// via `BridgeSender::send_error`.
pub struct BridgeStream {
    /// Schema reported through [`RecordBatchStream::schema`].
    schema: SchemaRef,
    /// Receiving half of the bridge channel.
    receiver: mpsc::Receiver<Result<RecordBatch, DataFusionError>>,
}
192
193impl std::fmt::Debug for BridgeStream {
194 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195 f.debug_struct("BridgeStream")
196 .field("schema", &self.schema)
197 .finish_non_exhaustive()
198 }
199}
200
201impl Stream for BridgeStream {
202 type Item = Result<RecordBatch, DataFusionError>;
203
204 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
205 Pin::new(&mut self.receiver).poll_recv(cx)
206 }
207}
208
209impl RecordBatchStream for BridgeStream {
210 fn schema(&self) -> SchemaRef {
211 Arc::clone(&self.schema)
212 }
213}
214
215#[derive(Debug, thiserror::Error)]
217pub enum BridgeSendError {
218 #[error("bridge receiver has been dropped")]
220 ReceiverDropped,
221}
222
223#[derive(Debug, thiserror::Error)]
225pub enum BridgeTrySendError {
226 #[error("bridge channel is full")]
228 Full,
229 #[error("bridge receiver has been dropped")]
231 ReceiverDropped,
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};
    use futures::StreamExt;

    /// Single non-nullable `id: Int64` column shared by every test.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]))
    }

    /// Builds a batch for `schema` whose only column holds `values`.
    fn test_batch(schema: &SchemaRef, values: Vec<i64>) -> RecordBatch {
        let array = Arc::new(Int64Array::from(values));
        RecordBatch::try_new(Arc::clone(schema), vec![array]).unwrap()
    }

    #[tokio::test]
    async fn test_bridge_send_receive() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10);
        let sender = bridge.sender();
        let mut stream = bridge.into_stream();

        let batch = test_batch(&schema, vec![1, 2, 3]);
        sender.send(batch.clone()).await.unwrap();
        // Dropping the last sender lets the stream end after draining the buffer.
        drop(sender);

        let received = stream.next().await.unwrap().unwrap();
        assert_eq!(received.num_rows(), 3);
        assert_eq!(received, batch);

        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_bridge_multiple_batches() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10);
        let sender = bridge.sender();
        let mut stream = bridge.into_stream();

        let expected: Vec<RecordBatch> = (0..5).map(|i| test_batch(&schema, vec![i])).collect();
        for batch in &expected {
            sender.send(batch.clone()).await.unwrap();
        }
        drop(sender);

        let mut received = Vec::new();
        while let Some(result) = stream.next().await {
            received.push(result.unwrap());
        }
        // Batches must arrive complete and in send order.
        assert_eq!(received, expected);
    }

    #[tokio::test]
    async fn test_bridge_sender_clone() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10);
        let sender1 = bridge.sender();
        let sender2 = sender1.clone();
        let mut stream = bridge.into_stream();

        sender1.send(test_batch(&schema, vec![1])).await.unwrap();
        sender2.send(test_batch(&schema, vec![2])).await.unwrap();
        drop(sender1);
        drop(sender2);

        let mut count = 0;
        while let Some(result) = stream.next().await {
            result.unwrap();
            count += 1;
        }
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn test_bridge_send_error() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10);
        let sender = bridge.sender();
        let mut stream = bridge.into_stream();

        sender
            .send_error(DataFusionError::Plan("test error".to_string()))
            .await
            .unwrap();
        drop(sender);

        let result = stream.next().await.unwrap();
        assert!(matches!(result, Err(DataFusionError::Plan(ref msg)) if msg == "test error"));
        assert!(stream.next().await.is_none());
    }

    #[test]
    fn test_bridge_try_send() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 2);
        let sender = bridge.sender();
        // Keep the receiver alive so the channel reports Full rather than closed.
        let _stream = bridge.into_stream();

        sender.try_send(test_batch(&schema, vec![1])).unwrap();
        sender.try_send(test_batch(&schema, vec![2])).unwrap();

        let result = sender.try_send(test_batch(&schema, vec![3]));
        assert!(matches!(result, Err(BridgeTrySendError::Full)));
    }

    #[test]
    fn test_bridge_try_send_receiver_dropped() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10);
        let sender = bridge.sender();
        drop(bridge.into_stream());

        let result = sender.try_send(test_batch(&schema, vec![1]));
        assert!(matches!(result, Err(BridgeTrySendError::ReceiverDropped)));
    }

    #[tokio::test]
    async fn test_bridge_receiver_dropped() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10);
        let sender = bridge.sender();
        let stream = bridge.into_stream();
        drop(stream);

        assert!(sender.is_closed());

        let result = sender.send(test_batch(&schema, vec![1])).await;
        assert!(matches!(result, Err(BridgeSendError::ReceiverDropped)));
    }

    #[test]
    fn test_bridge_take_stream_once() {
        let schema = test_schema();
        let mut bridge = StreamBridge::new(Arc::clone(&schema), 10);

        let stream = bridge.take_stream().expect("first take yields the stream");
        assert_eq!(RecordBatchStream::schema(&stream), schema);
        assert!(bridge.take_stream().is_none());
    }

    #[test]
    #[should_panic(expected = "receiver already taken")]
    fn test_bridge_into_stream_after_take_panics() {
        let mut bridge = StreamBridge::new(test_schema(), 10);
        let _taken = bridge.take_stream();
        let _ = bridge.into_stream();
    }

    #[test]
    fn test_bridge_stream_schema() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10);
        let stream = bridge.into_stream();

        assert_eq!(RecordBatchStream::schema(&stream), schema);
    }
}