1use std::fmt::{Debug, Formatter};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use arrow_array::RecordBatch;
12use arrow_schema::SchemaRef;
13use async_trait::async_trait;
14use datafusion::physical_plan::RecordBatchStream;
15use datafusion_common::DataFusionError;
16use datafusion_expr::Expr;
17use futures::Stream;
18use parking_lot::Mutex;
19
20use super::bridge::{BridgeSender, StreamBridge};
21use super::source::{SortColumn, StreamSource};
22
23const DEFAULT_CHANNEL_CAPACITY: usize = 1024;
25
/// A push-based stream source fed through an in-memory channel bridge.
///
/// Producers obtain a [`BridgeSender`] (via `take_sender`/`sender`) and push
/// `RecordBatch`es; the consuming side is handed out once by `stream()`.
/// The stream side is single-shot: after `stream()` has taken the bridge,
/// `reset()` must be called to create a fresh channel.
pub struct ChannelStreamSource {
    // Schema of the batches this source produces.
    schema: SchemaRef,
    // Receiving half of the channel; consumed by the first `stream()` call.
    bridge: Mutex<Option<StreamBridge>>,
    // Sending half handed out to producers; `take_sender()` removes it.
    sender: Mutex<Option<BridgeSender>>,
    // Bounded channel capacity, in batches; reused by `reset()`.
    capacity: usize,
    // Optional sort-order metadata returned by `output_ordering()`.
    ordering: Option<Vec<SortColumn>>,
}
73
74impl ChannelStreamSource {
75 #[must_use]
81 pub fn new(schema: SchemaRef) -> Self {
82 Self::with_capacity(schema, DEFAULT_CHANNEL_CAPACITY)
83 }
84
85 #[must_use]
92 pub fn with_capacity(schema: SchemaRef, capacity: usize) -> Self {
93 let bridge = StreamBridge::new(Arc::clone(&schema), capacity);
94 let sender = bridge.sender();
95 Self {
96 schema,
97 bridge: Mutex::new(Some(bridge)),
98 sender: Mutex::new(Some(sender)),
99 capacity,
100 ordering: None,
101 }
102 }
103
104 #[must_use]
113 pub fn with_ordering(mut self, ordering: Vec<SortColumn>) -> Self {
114 self.ordering = Some(ordering);
115 self
116 }
117
118 #[must_use]
128 pub fn take_sender(&self) -> Option<BridgeSender> {
129 self.sender.lock().take()
130 }
131
132 #[must_use]
138 pub fn sender(&self) -> Option<BridgeSender> {
139 self.sender.lock().as_ref().map(BridgeSender::clone)
140 }
141
142 pub fn reset(&self) -> BridgeSender {
150 let bridge = StreamBridge::new(Arc::clone(&self.schema), self.capacity);
151 let sender = bridge.sender();
152 *self.bridge.lock() = Some(bridge);
153 *self.sender.lock() = Some(sender.clone());
154 sender
155 }
156}
157
158impl Debug for ChannelStreamSource {
159 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
160 f.debug_struct("ChannelStreamSource")
161 .field("schema", &self.schema)
162 .field("capacity", &self.capacity)
163 .finish_non_exhaustive()
164 }
165}
166
167#[async_trait]
168impl StreamSource for ChannelStreamSource {
169 fn schema(&self) -> SchemaRef {
170 Arc::clone(&self.schema)
171 }
172
173 fn output_ordering(&self) -> Option<Vec<SortColumn>> {
174 self.ordering.clone()
175 }
176
177 fn stream(
178 &self,
179 projection: Option<Vec<usize>>,
180 _filters: Vec<Expr>,
181 ) -> Result<datafusion::physical_plan::SendableRecordBatchStream, DataFusionError> {
182 let mut bridge_guard = self.bridge.lock();
183 let bridge = bridge_guard.take().ok_or_else(|| {
184 DataFusionError::Execution(
185 "Stream already taken; call reset() to create a new bridge".to_string(),
186 )
187 })?;
188
189 let inner_stream = bridge.into_stream();
190
191 let stream: datafusion::physical_plan::SendableRecordBatchStream =
193 if let Some(indices) = projection {
194 let projected_schema = {
195 let fields: Vec<_> = indices
196 .iter()
197 .map(|&i| self.schema.field(i).clone())
198 .collect();
199 Arc::new(arrow_schema::Schema::new(fields))
200 };
201 Box::pin(ProjectingStream::new(
202 inner_stream,
203 projected_schema,
204 indices,
205 ))
206 } else {
207 Box::pin(inner_stream)
208 };
209
210 Ok(stream)
211 }
212}
213
/// Wraps an inner batch stream and projects each batch down to a subset of
/// its columns.
struct ProjectingStream<S> {
    // Upstream stream yielding full-width batches.
    inner: S,
    // Schema of the projected output batches.
    schema: SchemaRef,
    // Column indices (into the input batches) to keep, in output order.
    indices: Vec<usize>,
}
220
221impl<S> ProjectingStream<S> {
222 fn new(inner: S, schema: SchemaRef, indices: Vec<usize>) -> Self {
223 Self {
224 inner,
225 schema,
226 indices,
227 }
228 }
229}
230
231impl<S> Debug for ProjectingStream<S> {
232 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
233 f.debug_struct("ProjectingStream")
234 .field("schema", &self.schema)
235 .field("indices", &self.indices)
236 .finish_non_exhaustive()
237 }
238}
239
240impl<S> Stream for ProjectingStream<S>
241where
242 S: Stream<Item = Result<RecordBatch, DataFusionError>> + Unpin,
243{
244 type Item = Result<RecordBatch, DataFusionError>;
245
246 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
247 match Pin::new(&mut self.inner).poll_next(cx) {
248 Poll::Ready(Some(Ok(batch))) => {
249 let columns: Vec<_> = self
251 .indices
252 .iter()
253 .map(|&i| Arc::clone(batch.column(i)))
254 .collect();
255 let projected =
256 RecordBatch::try_new(Arc::clone(&self.schema), columns).map_err(|e| {
257 DataFusionError::ArrowError(
258 Box::new(e),
259 Some("projection failed".to_string()),
260 )
261 });
262 Poll::Ready(Some(projected))
263 }
264 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
265 Poll::Ready(None) => Poll::Ready(None),
266 Poll::Pending => Poll::Pending,
267 }
268 }
269}
270
impl<S> RecordBatchStream for ProjectingStream<S>
where
    S: Stream<Item = Result<RecordBatch, DataFusionError>> + Unpin,
{
    fn schema(&self) -> SchemaRef {
        // Report the projected schema, not the inner stream's full schema.
        Arc::clone(&self.schema)
    }
}
279
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};
    use futures::StreamExt;

    /// Two-column (`id`, `value`) non-nullable Int64 schema shared by the
    /// tests below.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Int64, false),
        ]))
    }

    /// Builds a batch of parallel `ids`/`values` columns against `schema`.
    fn test_batch(schema: &SchemaRef, ids: Vec<i64>, values: Vec<i64>) -> RecordBatch {
        RecordBatch::try_new(
            Arc::clone(schema),
            vec![
                Arc::new(Int64Array::from(ids)),
                Arc::new(Int64Array::from(values)),
            ],
        )
        .unwrap()
    }

    /// The source reports the schema it was constructed with.
    #[test]
    fn test_channel_source_schema() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        assert_eq!(source.schema(), schema);
    }

    /// A batch pushed through the sender arrives intact (no projection).
    /// Sending before polling relies on the channel's buffering capacity.
    #[tokio::test]
    async fn test_channel_source_stream() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));
        let sender = source.take_sender().unwrap();

        let mut stream = source.stream(None, vec![]).unwrap();

        sender
            .send(test_batch(&schema, vec![1, 2], vec![10, 20]))
            .await
            .unwrap();
        drop(sender);

        let batch = stream.next().await.unwrap().unwrap();
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 2);
    }

    /// Projecting to index 1 keeps only the `value` column, with data intact.
    #[tokio::test]
    async fn test_channel_source_projection() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));
        let sender = source.take_sender().unwrap();

        let mut stream = source.stream(Some(vec![1]), vec![]).unwrap();

        sender
            .send(test_batch(&schema, vec![1, 2], vec![100, 200]))
            .await
            .unwrap();
        drop(sender);

        let batch = stream.next().await.unwrap().unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.schema().field(0).name(), "value");

        let values = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(values.value(0), 100);
        assert_eq!(values.value(1), 200);
    }

    /// `stream()` is single-shot: the second call without `reset()` fails.
    #[tokio::test]
    async fn test_channel_source_stream_already_taken() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        let _stream = source.stream(None, vec![]).unwrap();

        let result = source.stream(None, vec![]);
        assert!(result.is_err());
    }

    /// All buffered batches are delivered, and the stream terminates once
    /// the sender is dropped.
    #[tokio::test]
    async fn test_channel_source_multiple_batches() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));
        let sender = source.take_sender().unwrap();
        let mut stream = source.stream(None, vec![]).unwrap();

        for i in 0..5i64 {
            sender
                .send(test_batch(&schema, vec![i], vec![i * 10]))
                .await
                .unwrap();
        }
        drop(sender);

        let mut count = 0;
        while let Some(result) = stream.next().await {
            result.unwrap();
            count += 1;
        }
        assert_eq!(count, 5);
    }

    /// `take_sender()` removes the sender; a second call yields `None`.
    #[tokio::test]
    async fn test_channel_source_take_sender_once() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        let sender = source.take_sender();
        assert!(sender.is_some());

        let sender2 = source.take_sender();
        assert!(sender2.is_none());
    }

    /// After `reset()`, a fresh sender/stream pair works end to end even
    /// though the original bridge was already consumed.
    #[tokio::test]
    async fn test_channel_source_reset() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        let _sender = source.take_sender().unwrap();
        let _stream = source.stream(None, vec![]).unwrap();

        let new_sender = source.reset();
        let mut new_stream = source.stream(None, vec![]).unwrap();

        new_sender
            .send(test_batch(&schema, vec![1], vec![10]))
            .await
            .unwrap();
        drop(new_sender);

        let batch = new_stream.next().await.unwrap().unwrap();
        assert_eq!(batch.num_rows(), 1);
    }

    /// The `Debug` impl names the type and includes the capacity field.
    #[test]
    fn test_channel_source_debug() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        let debug_str = format!("{source:?}");
        assert!(debug_str.contains("ChannelStreamSource"));
        assert!(debug_str.contains("capacity"));
    }

    /// Without `with_ordering`, no output ordering is advertised.
    #[test]
    fn test_channel_source_default_no_ordering() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        assert!(source.output_ordering().is_none());
    }

    /// `with_ordering` round-trips through `output_ordering()`.
    #[test]
    fn test_channel_source_with_ordering() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema))
            .with_ordering(vec![SortColumn::ascending("id")]);

        let ordering = source.output_ordering();
        assert!(ordering.is_some());
        let cols = ordering.unwrap();
        assert_eq!(cols.len(), 1);
        assert_eq!(cols[0].name, "id");
        assert!(!cols[0].descending);
    }
}