laminar_sql/datafusion/
bridge.rs1use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use arrow_array::RecordBatch;
9use arrow_schema::SchemaRef;
10use crossfire::stream::AsyncStream;
11use crossfire::{mpsc, AsyncRx, MAsyncTx, TrySendError};
12use datafusion::physical_plan::RecordBatchStream;
13use datafusion_common::DataFusionError;
14use futures::Stream;
15
/// Number of channel slots allocated by [`StreamBridge::with_default_capacity`].
const DEFAULT_CHANNEL_CAPACITY: usize = 1_024;

19#[derive(Debug)]
22pub struct StreamBridge {
23 schema: SchemaRef,
24 sender: BridgeSender,
25 receiver: Option<AsyncRx<mpsc::Array<Result<RecordBatch, DataFusionError>>>>,
26}
27
28impl StreamBridge {
29 #[must_use]
31 pub fn new(schema: SchemaRef, capacity: usize) -> Self {
32 let (tx, rx) = mpsc::bounded_async::<Result<RecordBatch, DataFusionError>>(capacity);
33 Self {
34 schema,
35 sender: BridgeSender { tx },
36 receiver: Some(rx),
37 }
38 }
39
40 #[must_use]
42 pub fn with_default_capacity(schema: SchemaRef) -> Self {
43 Self::new(schema, DEFAULT_CHANNEL_CAPACITY)
44 }
45
46 #[must_use]
48 pub fn schema(&self) -> SchemaRef {
49 Arc::clone(&self.schema)
50 }
51
52 #[must_use]
56 pub fn sender(&self) -> BridgeSender {
57 self.sender.clone()
58 }
59
60 #[must_use]
69 pub fn into_stream(mut self) -> BridgeStream {
70 BridgeStream {
71 schema: self.schema,
72 receiver: self
73 .receiver
74 .take()
75 .expect("receiver already taken")
76 .into_stream(),
77 }
78 }
79
80 #[must_use]
84 pub fn take_stream(&mut self) -> Option<BridgeStream> {
85 self.receiver.take().map(|receiver| BridgeStream {
86 schema: Arc::clone(&self.schema),
87 receiver: receiver.into_stream(),
88 })
89 }
90}
91
92#[derive(Debug, Clone)]
96pub struct BridgeSender {
97 tx: MAsyncTx<mpsc::Array<Result<RecordBatch, DataFusionError>>>,
98}
99
100impl BridgeSender {
101 pub async fn send(&self, batch: RecordBatch) -> Result<(), BridgeSendError> {
107 self.tx
108 .send(Ok(batch))
109 .await
110 .map_err(|_| BridgeSendError::ReceiverDropped)
111 }
112
113 pub async fn send_error(&self, error: DataFusionError) -> Result<(), BridgeSendError> {
121 self.tx
122 .send(Err(error))
123 .await
124 .map_err(|_| BridgeSendError::ReceiverDropped)
125 }
126
127 pub fn try_send(&self, batch: RecordBatch) -> Result<(), BridgeTrySendError> {
133 self.tx.try_send(Ok(batch)).map_err(|e| match e {
134 TrySendError::Full(_) => BridgeTrySendError::Full,
135 TrySendError::Disconnected(_) => BridgeTrySendError::ReceiverDropped,
136 })
137 }
138
139 #[must_use]
141 pub fn is_closed(&self) -> bool {
142 self.tx.is_disconnected()
143 }
144}
145
146pub struct BridgeStream {
151 schema: SchemaRef,
152 receiver: AsyncStream<mpsc::Array<Result<RecordBatch, DataFusionError>>>,
153}
154
155impl std::fmt::Debug for BridgeStream {
156 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157 f.debug_struct("BridgeStream")
158 .field("schema", &self.schema)
159 .finish_non_exhaustive()
160 }
161}
162
163impl Stream for BridgeStream {
164 type Item = Result<RecordBatch, DataFusionError>;
165
166 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
167 Pin::new(&mut self.receiver).poll_next(cx)
168 }
169}
170
171impl RecordBatchStream for BridgeStream {
172 fn schema(&self) -> SchemaRef {
173 Arc::clone(&self.schema)
174 }
175}
176
177#[derive(Debug, thiserror::Error)]
179pub enum BridgeSendError {
180 #[error("bridge receiver has been dropped")]
182 ReceiverDropped,
183}
184
185#[derive(Debug, thiserror::Error)]
187pub enum BridgeTrySendError {
188 #[error("bridge channel is full")]
190 Full,
191 #[error("bridge receiver has been dropped")]
193 ReceiverDropped,
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};
    use futures::StreamExt;

    /// Single non-nullable `Int64` column named `id`.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]))
    }

    /// Builds a batch of `schema` whose only column holds `values`.
    fn test_batch(schema: &SchemaRef, values: Vec<i64>) -> RecordBatch {
        let array = Arc::new(Int64Array::from(values));
        RecordBatch::try_new(Arc::clone(schema), vec![array]).unwrap()
    }

    #[tokio::test]
    async fn test_bridge_send_receive() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10);
        let sender = bridge.sender();
        let mut stream = bridge.into_stream();

        let batch = test_batch(&schema, vec![1, 2, 3]);
        sender.send(batch.clone()).await.unwrap();
        // Dropping the last external sender lets the stream terminate.
        drop(sender);

        let received = stream.next().await.unwrap().unwrap();
        assert_eq!(received.num_rows(), 3);
        assert_eq!(received, batch);

        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_bridge_multiple_batches() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10);
        let sender = bridge.sender();
        let mut stream = bridge.into_stream();

        let expected: Vec<RecordBatch> = (0..5_i64)
            .map(|i| test_batch(&schema, vec![i]))
            .collect();
        for batch in &expected {
            sender.send(batch.clone()).await.unwrap();
        }
        drop(sender);

        let mut received = Vec::new();
        while let Some(result) = stream.next().await {
            received.push(result.unwrap());
        }
        // A single sender's batches must arrive complete and in send order.
        assert_eq!(received, expected);
    }

    #[tokio::test]
    async fn test_bridge_sender_clone() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10);
        let sender1 = bridge.sender();
        let sender2 = sender1.clone();
        let mut stream = bridge.into_stream();

        let first = test_batch(&schema, vec![1]);
        let second = test_batch(&schema, vec![2]);
        sender1.send(first.clone()).await.unwrap();
        sender2.send(second.clone()).await.unwrap();
        drop(sender1);
        drop(sender2);

        let mut received = Vec::new();
        while let Some(result) = stream.next().await {
            received.push(result.unwrap());
        }
        assert_eq!(received, vec![first, second]);
    }

    #[tokio::test]
    async fn test_bridge_send_error() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10);
        let sender = bridge.sender();
        let mut stream = bridge.into_stream();

        sender
            .send_error(DataFusionError::Plan("test error".to_string()))
            .await
            .unwrap();
        drop(sender);

        let err = stream.next().await.unwrap().unwrap_err();
        assert!(matches!(&err, DataFusionError::Plan(msg) if msg == "test error"));
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_bridge_try_send() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 2);
        let sender = bridge.sender();
        // Keep the receiver alive so the channel fills instead of disconnecting.
        let _stream = bridge.into_stream();

        sender.try_send(test_batch(&schema, vec![1])).unwrap();
        sender.try_send(test_batch(&schema, vec![2])).unwrap();

        let result = sender.try_send(test_batch(&schema, vec![3]));
        assert!(matches!(result, Err(BridgeTrySendError::Full)));
    }

    #[tokio::test]
    async fn test_bridge_try_send_receiver_dropped() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 2);
        let sender = bridge.sender();
        drop(bridge.into_stream());

        let result = sender.try_send(test_batch(&schema, vec![1]));
        assert!(matches!(result, Err(BridgeTrySendError::ReceiverDropped)));
    }

    #[tokio::test]
    async fn test_bridge_receiver_dropped() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10);
        let sender = bridge.sender();
        let stream = bridge.into_stream();
        drop(stream);

        assert!(sender.is_closed());

        let result = sender.send(test_batch(&schema, vec![1])).await;
        assert!(matches!(result, Err(BridgeSendError::ReceiverDropped)));
    }

    #[tokio::test]
    async fn test_bridge_take_stream() {
        let schema = test_schema();
        let mut bridge = StreamBridge::new(Arc::clone(&schema), 10);
        let sender = bridge.sender();

        let mut stream = bridge.take_stream().expect("first take must succeed");
        assert!(bridge.take_stream().is_none());

        let batch = test_batch(&schema, vec![7]);
        sender.send(batch.clone()).await.unwrap();
        // Do not wait for end-of-stream here: the bridge still owns a sender.
        let received = stream.next().await.unwrap().unwrap();
        assert_eq!(received, batch);
    }

    #[test]
    #[should_panic(expected = "receiver already taken")]
    fn test_bridge_into_stream_after_take_panics() {
        let mut bridge = StreamBridge::new(test_schema(), 10);
        let _stream = bridge.take_stream();
        let _ = bridge.into_stream();
    }

    #[test]
    fn test_bridge_default_capacity_schema() {
        let schema = test_schema();
        let bridge = StreamBridge::with_default_capacity(Arc::clone(&schema));
        assert_eq!(bridge.schema(), schema);
    }

    #[test]
    fn test_bridge_stream_schema() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10);
        let stream = bridge.into_stream();

        assert_eq!(RecordBatchStream::schema(&stream), schema);
    }
}