1use std::fmt::{Debug, Formatter};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use arrow_array::RecordBatch;
12use arrow_schema::SchemaRef;
13use async_trait::async_trait;
14use datafusion::physical_plan::RecordBatchStream;
15use datafusion_common::DataFusionError;
16use datafusion_expr::Expr;
17use futures::Stream;
18use parking_lot::Mutex;
19
20use super::bridge::{BridgeSender, StreamBridge};
21use super::source::{SortColumn, StreamSource};
22
/// Default bounded-channel capacity (in record batches) used by
/// [`ChannelStreamSource::new`]; override via `with_capacity`.
const DEFAULT_CHANNEL_CAPACITY: usize = 1024;
25
/// A [`StreamSource`] fed through an in-process channel: producers push
/// `RecordBatch`es via a [`BridgeSender`] while DataFusion consumes them as a
/// record-batch stream obtained from `stream()`.
pub struct ChannelStreamSource {
    // Schema shared by every batch pushed through the channel.
    schema: SchemaRef,
    // Consumer half of the channel; taken (at most once) by `stream()` and
    // re-created by `reset()`.
    bridge: Mutex<Option<StreamBridge>>,
    // Producer half; handed out via `take_sender()` / `sender()`.
    sender: Mutex<Option<BridgeSender>>,
    // Channel capacity used whenever a bridge is (re)created.
    capacity: usize,
    // Optional sort order reported to the planner via `output_ordering()`.
    ordering: Option<Vec<SortColumn>>,
}
73
74impl ChannelStreamSource {
75 #[must_use]
77 pub fn new(schema: SchemaRef) -> Self {
78 Self::with_capacity(schema, DEFAULT_CHANNEL_CAPACITY)
79 }
80
81 #[must_use]
83 pub fn with_capacity(schema: SchemaRef, capacity: usize) -> Self {
84 let bridge = StreamBridge::new(Arc::clone(&schema), capacity);
85 let sender = bridge.sender();
86 Self {
87 schema,
88 bridge: Mutex::new(Some(bridge)),
89 sender: Mutex::new(Some(sender)),
90 capacity,
91 ordering: None,
92 }
93 }
94
95 #[must_use]
99 pub fn with_ordering(mut self, ordering: Vec<SortColumn>) -> Self {
100 self.ordering = Some(ordering);
101 self
102 }
103
104 #[must_use]
114 pub fn take_sender(&self) -> Option<BridgeSender> {
115 self.sender.lock().take()
116 }
117
118 #[must_use]
124 pub fn sender(&self) -> Option<BridgeSender> {
125 self.sender.lock().as_ref().map(BridgeSender::clone)
126 }
127
128 pub fn reset(&self) -> BridgeSender {
136 let bridge = StreamBridge::new(Arc::clone(&self.schema), self.capacity);
137 let sender = bridge.sender();
138 *self.bridge.lock() = Some(bridge);
139 *self.sender.lock() = Some(sender.clone());
140 sender
141 }
142}
143
144impl Debug for ChannelStreamSource {
145 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
146 f.debug_struct("ChannelStreamSource")
147 .field("schema", &self.schema)
148 .field("capacity", &self.capacity)
149 .finish_non_exhaustive()
150 }
151}
152
153#[async_trait]
154impl StreamSource for ChannelStreamSource {
155 fn schema(&self) -> SchemaRef {
156 Arc::clone(&self.schema)
157 }
158
159 fn output_ordering(&self) -> Option<Vec<SortColumn>> {
160 self.ordering.clone()
161 }
162
163 fn stream(
164 &self,
165 projection: Option<Vec<usize>>,
166 _filters: Vec<Expr>,
167 ) -> Result<datafusion::physical_plan::SendableRecordBatchStream, DataFusionError> {
168 let mut bridge_guard = self.bridge.lock();
169 let bridge = bridge_guard.take().ok_or_else(|| {
170 DataFusionError::Execution(
171 "Stream already taken; call reset() to create a new bridge".to_string(),
172 )
173 })?;
174
175 let inner_stream = bridge.into_stream();
176
177 let stream: datafusion::physical_plan::SendableRecordBatchStream =
179 if let Some(indices) = projection {
180 let projected_schema = {
181 let fields: Vec<_> = indices
182 .iter()
183 .map(|&i| self.schema.field(i).clone())
184 .collect();
185 Arc::new(arrow_schema::Schema::new(fields))
186 };
187 Box::pin(ProjectingStream::new(
188 inner_stream,
189 projected_schema,
190 indices,
191 ))
192 } else {
193 Box::pin(inner_stream)
194 };
195
196 Ok(stream)
197 }
198}
199
/// Stream adapter that applies a column projection to every batch produced
/// by the wrapped stream.
struct ProjectingStream<S> {
    // Upstream record-batch stream being adapted.
    inner: S,
    // Schema of the projected output (only the columns at `indices`).
    schema: SchemaRef,
    // Column indices (into the source schema) to keep in each batch.
    indices: Vec<usize>,
}
206
207impl<S> ProjectingStream<S> {
208 fn new(inner: S, schema: SchemaRef, indices: Vec<usize>) -> Self {
209 Self {
210 inner,
211 schema,
212 indices,
213 }
214 }
215}
216
217impl<S> Debug for ProjectingStream<S> {
218 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
219 f.debug_struct("ProjectingStream")
220 .field("schema", &self.schema)
221 .field("indices", &self.indices)
222 .finish_non_exhaustive()
223 }
224}
225
226impl<S> Stream for ProjectingStream<S>
227where
228 S: Stream<Item = Result<RecordBatch, DataFusionError>> + Unpin,
229{
230 type Item = Result<RecordBatch, DataFusionError>;
231
232 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
233 match Pin::new(&mut self.inner).poll_next(cx) {
234 Poll::Ready(Some(Ok(batch))) => {
235 let projected = batch.project(&self.indices).map_err(|e| {
237 DataFusionError::ArrowError(Box::new(e), Some("projection failed".to_string()))
238 });
239 Poll::Ready(Some(projected))
240 }
241 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
242 Poll::Ready(None) => Poll::Ready(None),
243 Poll::Pending => Poll::Pending,
244 }
245 }
246}
247
impl<S> RecordBatchStream for ProjectingStream<S>
where
    S: Stream<Item = Result<RecordBatch, DataFusionError>> + Unpin,
{
    // Reports the projected schema, not the upstream schema.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
256
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};
    use futures::StreamExt;

    /// Two-column Int64 schema (`id`, `value`) shared by every test.
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Int64, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Builds a two-column batch over `schema` from parallel `ids`/`values`.
    fn test_batch(schema: &SchemaRef, ids: Vec<i64>, values: Vec<i64>) -> RecordBatch {
        let id_col: arrow_array::ArrayRef = Arc::new(Int64Array::from(ids));
        let value_col: arrow_array::ArrayRef = Arc::new(Int64Array::from(values));
        RecordBatch::try_new(Arc::clone(schema), vec![id_col, value_col]).unwrap()
    }

    #[test]
    fn test_channel_source_schema() {
        let schema = test_schema();
        let src = ChannelStreamSource::new(Arc::clone(&schema));
        assert_eq!(src.schema(), schema);
    }

    #[tokio::test]
    async fn test_channel_source_stream() {
        let schema = test_schema();
        let src = ChannelStreamSource::new(Arc::clone(&schema));
        let tx = src.take_sender().unwrap();
        let mut rx = src.stream(None, vec![]).unwrap();

        tx.send(test_batch(&schema, vec![1, 2], vec![10, 20]))
            .await
            .unwrap();
        drop(tx);

        let got = rx.next().await.unwrap().unwrap();
        assert_eq!(got.num_rows(), 2);
        assert_eq!(got.num_columns(), 2);
    }

    #[tokio::test]
    async fn test_channel_source_projection() {
        let schema = test_schema();
        let src = ChannelStreamSource::new(Arc::clone(&schema));
        let tx = src.take_sender().unwrap();
        let mut rx = src.stream(Some(vec![1]), vec![]).unwrap();

        tx.send(test_batch(&schema, vec![1, 2], vec![100, 200]))
            .await
            .unwrap();
        drop(tx);

        let got = rx.next().await.unwrap().unwrap();
        assert_eq!(got.num_columns(), 1);
        assert_eq!(got.schema().field(0).name(), "value");

        let col = got
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(col.value(0), 100);
        assert_eq!(col.value(1), 200);
    }

    #[tokio::test]
    async fn test_channel_source_stream_already_taken() {
        let src = ChannelStreamSource::new(test_schema());

        let _first = src.stream(None, vec![]).unwrap();

        // The bridge was consumed above, so a second stream must fail.
        assert!(src.stream(None, vec![]).is_err());
    }

    #[tokio::test]
    async fn test_channel_source_multiple_batches() {
        let schema = test_schema();
        let src = ChannelStreamSource::new(Arc::clone(&schema));
        let tx = src.take_sender().unwrap();
        let mut rx = src.stream(None, vec![]).unwrap();

        for i in 0..5i64 {
            tx.send(test_batch(&schema, vec![i], vec![i * 10]))
                .await
                .unwrap();
        }
        drop(tx);

        let mut seen = 0;
        while let Some(item) = rx.next().await {
            item.unwrap();
            seen += 1;
        }
        assert_eq!(seen, 5);
    }

    #[tokio::test]
    async fn test_channel_source_take_sender_once() {
        let src = ChannelStreamSource::new(test_schema());

        assert!(src.take_sender().is_some());

        // The sender can only be taken once.
        assert!(src.take_sender().is_none());
    }

    #[tokio::test]
    async fn test_channel_source_reset() {
        let schema = test_schema();
        let src = ChannelStreamSource::new(Arc::clone(&schema));

        let _old_sender = src.take_sender().unwrap();
        let _old_stream = src.stream(None, vec![]).unwrap();

        // reset() installs a fresh bridge/sender pair after consumption.
        let tx = src.reset();
        let mut rx = src.stream(None, vec![]).unwrap();

        tx.send(test_batch(&schema, vec![1], vec![10]))
            .await
            .unwrap();
        drop(tx);

        let got = rx.next().await.unwrap().unwrap();
        assert_eq!(got.num_rows(), 1);
    }

    #[test]
    fn test_channel_source_debug() {
        let src = ChannelStreamSource::new(test_schema());

        let rendered = format!("{src:?}");
        assert!(rendered.contains("ChannelStreamSource"));
        assert!(rendered.contains("capacity"));
    }

    #[test]
    fn test_channel_source_default_no_ordering() {
        let src = ChannelStreamSource::new(test_schema());
        assert!(src.output_ordering().is_none());
    }

    #[test]
    fn test_channel_source_with_ordering() {
        let src = ChannelStreamSource::new(test_schema())
            .with_ordering(vec![SortColumn::ascending("id")]);

        let cols = src.output_ordering().expect("ordering was configured");
        assert_eq!(cols.len(), 1);
        assert_eq!(cols[0].name, "id");
        assert!(!cols[0].descending);
    }
}