clickhouse_datafusion/
stream.rs

1use std::pin::Pin;
2use std::sync::Arc;
3use std::task::Poll;
4
5use datafusion::arrow::array::RecordBatch;
6use datafusion::arrow::compute::cast;
7use datafusion::arrow::datatypes::SchemaRef;
8use datafusion::common::exec_err;
9use datafusion::error::Result as DataFusionResult;
10use futures_util::{Stream, StreamExt, TryStreamExt, ready};
11use pin_project::pin_project;
12
13use crate::ClickHouseConnectionPool;
14
15pub type RecordBatchStreamWrapper =
16    RecordBatchStream<Pin<Box<dyn Stream<Item = DataFusionResult<RecordBatch>> + Send>>>;
17
18// TODO: Docs - also does DataFusion provide anything that makes this unnecessary?
19//
20// Stream adapter for ClickHouse query results
21#[pin_project]
22pub struct RecordBatchStream<S> {
23    schema:        SchemaRef,
24    #[pin]
25    stream:        S,
26    // TODO: Support actual arrow CastOptions, using that as indicator of whether to coerce types
27    coerce_schema: bool,
28}
29
30impl<S> RecordBatchStream<S> {
31    pub fn new(stream: S, schema: SchemaRef) -> Self {
32        Self { schema, stream, coerce_schema: false }
33    }
34
35    #[must_use]
36    pub fn with_coercion(mut self, coerce: bool) -> Self {
37        self.coerce_schema = coerce;
38        self
39    }
40
41    /// Coerce the schema of the returned record batch to match the expected schema
42    ///
43    /// Assumes the fields are in the same order. Will blindly attempt to cast, returning an error
44    /// if the arrow cast fails.
45    fn coerce_batch_schema(&self, batch: RecordBatch) -> DataFusionResult<RecordBatch> {
46        if self.coerce_schema {
47            let (batch_schema, mut arrays, _) = batch.into_parts();
48
49            let from_fields = batch_schema.fields();
50            let to_fields = self.schema.fields();
51            if from_fields.len() != to_fields.len() {
52                return exec_err!("Cannot coerce types, incompatible schemas");
53            }
54
55            let mut new_arrays = Vec::with_capacity(arrays.len());
56            let field_map = batch_schema.fields().iter().zip(self.schema.fields().iter());
57
58            // Reverse to allow popping from the end
59            for (from_field, to_field) in field_map.rev() {
60                let Some(current_array) = arrays.pop() else {
61                    return exec_err!("Cannot coerce types, missing array");
62                };
63
64                if from_field.data_type() == to_field.data_type() {
65                    new_arrays.push(current_array);
66                } else {
67                    let new_array = cast(&current_array, to_field.data_type())?;
68                    new_arrays.push(new_array);
69                }
70            }
71
72            // Reverse the arrays back
73            new_arrays.reverse();
74            Ok(RecordBatch::try_new(Arc::clone(&self.schema), new_arrays)?)
75        } else {
76            Ok(batch)
77        }
78    }
79}
80
81impl RecordBatchStreamWrapper {
82    pub fn new_from_stream(
83        stream: Pin<Box<dyn Stream<Item = DataFusionResult<RecordBatch>> + Send>>,
84        schema: SchemaRef,
85    ) -> Self {
86        Self { schema, stream, coerce_schema: false }
87    }
88
89    pub fn new_from_query(
90        sql: impl Into<String>,
91        pool: Arc<ClickHouseConnectionPool>,
92        schema: SchemaRef,
93        coerce_schema: bool,
94    ) -> Self {
95        let sql = sql.into();
96        let pool_schema = Arc::clone(&schema);
97        let stream = Box::pin(
98            futures_util::stream::once(async move {
99                pool.connect()
100                    .await?
101                    .query_arrow_with_schema(&sql, &[], pool_schema, coerce_schema)
102                    .await
103            })
104            .try_flatten(),
105        );
106        Self { schema, stream, coerce_schema: false }
107    }
108}
109
110impl<S> Stream for RecordBatchStream<S>
111where
112    S: Stream<Item = DataFusionResult<RecordBatch>>,
113{
114    type Item = DataFusionResult<RecordBatch>;
115
116    fn poll_next(
117        mut self: Pin<&mut Self>,
118        cx: &mut std::task::Context<'_>,
119    ) -> Poll<Option<Self::Item>> {
120        if !self.coerce_schema {
121            return self.as_mut().project().stream.poll_next(cx);
122        }
123
124        Poll::Ready(match ready!(self.as_mut().project().stream.poll_next(cx)) {
125            Some(batch) => Some(self.coerce_batch_schema(batch?)),
126            None => None,
127        })
128    }
129
130    fn size_hint(&self) -> (usize, Option<usize>) { self.stream.size_hint() }
131}
132
133impl<S> datafusion::physical_plan::RecordBatchStream for RecordBatchStream<S>
134where
135    S: Stream<Item = DataFusionResult<RecordBatch>>,
136{
137    fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) }
138}
139
140/// Helper function to create a `SendableRecordBatchStream` from a stream of `RecordBatch`es where
141/// the schema must be extracted from the first batch.
142///
143/// # Errors
144/// - Returns an error if the stream is empty or the first batch fails.
145pub async fn record_batch_stream_from_stream(
146    mut stream: impl Stream<Item = DataFusionResult<RecordBatch>> + Send + Unpin + 'static,
147) -> DataFusionResult<RecordBatchStreamWrapper> {
148    let Some(first_batch) = stream.next().await else {
149        return exec_err!("No schema provided and record batch stream is empty");
150    };
151    let first_batch = first_batch?;
152    let schema = first_batch.schema();
153    let stream = Box::pin(futures_util::stream::once(async { Ok(first_batch) }).chain(stream));
154    Ok(RecordBatchStream::new_from_stream(stream, schema))
155}
156
#[cfg(all(test, feature = "test-utils"))]
mod tests {
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use datafusion::arrow::array::{Int32Array, Int64Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::physical_plan::RecordBatchStream as RecordBatchStreamTrait;
    use futures_util::stream;

    use super::*;

    /// Three-row batch with a non-null `id: Int32` column and a non-null `name: Utf8` column.
    fn create_test_record_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let id_array = Int32Array::from(vec![1, 2, 3]);
        let name_array = StringArray::from(vec!["Alice", "Bob", "Charlie"]);

        RecordBatch::try_new(schema, vec![Arc::new(id_array), Arc::new(name_array)]).unwrap()
    }

    #[test]
    fn test_record_batch_stream_new() {
        let batch = create_test_record_batch();
        let schema = batch.schema();
        let stream = Box::pin(stream::once(async move { Ok(batch) }));

        let record_batch_stream =
            RecordBatchStreamWrapper::new_from_stream(stream, Arc::clone(&schema));
        assert_eq!(record_batch_stream.schema(), schema);
    }

    #[test]
    fn test_record_batch_stream_schema() {
        let batch = create_test_record_batch();
        let schema = batch.schema();
        let stream = Box::pin(stream::once(async move { Ok(batch) }));

        let record_batch_stream =
            RecordBatchStreamWrapper::new_from_stream(stream, Arc::clone(&schema));
        let returned_schema = record_batch_stream.schema();

        assert_eq!(returned_schema.fields().len(), 2);
        assert_eq!(returned_schema.field(0).name(), "id");
        assert_eq!(returned_schema.field(1).name(), "name");
    }

    #[tokio::test]
    async fn test_record_batch_stream_poll_next() {
        let batch = create_test_record_batch();
        let schema = batch.schema();
        let stream = Box::pin(stream::once(async move { Ok(batch) }));

        let mut record_batch_stream = RecordBatchStreamWrapper::new_from_stream(stream, schema);

        // Create a mock context for polling
        let waker = futures_util::task::noop_waker();
        let mut context = Context::from_waker(&waker);

        // Poll the stream
        let pinned = Pin::new(&mut record_batch_stream);
        let Poll::Ready(Some(result)) = pinned.poll_next(&mut context) else {
            panic!("Expected Poll::Ready with batch");
        };
        let received_batch = result.unwrap();
        assert_eq!(received_batch.num_rows(), 3);
        assert_eq!(received_batch.num_columns(), 2);
    }

    #[tokio::test]
    async fn test_record_batch_stream_with_coercion_casts_columns() {
        let batch = create_test_record_batch();
        // Same layout as the test batch, except `id` is widened to Int64.
        let target_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        let stream = Box::pin(stream::once(async move { Ok(batch) }));

        let mut record_batch_stream =
            RecordBatchStreamWrapper::new_from_stream(stream, Arc::clone(&target_schema))
                .with_coercion(true);

        let coerced = record_batch_stream.next().await.unwrap().unwrap();
        let expected = RecordBatch::try_new(target_schema, vec![
            Arc::new(Int64Array::from(vec![1_i64, 2, 3])),
            Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
        ])
        .unwrap();
        assert_eq!(coerced, expected);
    }

    #[tokio::test]
    async fn test_record_batch_stream_coercion_rejects_field_count_mismatch() {
        let batch = create_test_record_batch();
        let target_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
        let stream = Box::pin(stream::once(async move { Ok(batch) }));

        let mut record_batch_stream =
            RecordBatchStreamWrapper::new_from_stream(stream, target_schema).with_coercion(true);

        match record_batch_stream.next().await {
            Some(Err(error)) => {
                assert!(error.to_string().contains("incompatible schemas"));
            }
            _ => panic!("Expected coercion error for mismatched field counts"),
        }
    }

    #[tokio::test]
    async fn test_record_batch_stream_from_stream_success() {
        let batch1 = create_test_record_batch();
        let batch2 = create_test_record_batch();
        let test_stream = stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let mut sendable_stream = record_batch_stream_from_stream(test_stream)
            .await
            .expect("stream with batches should produce a wrapper");

        let first_batch = sendable_stream.next().await.unwrap().unwrap();
        assert_eq!(first_batch.num_rows(), 3);
        assert_eq!(first_batch.num_columns(), 2);
    }

    #[tokio::test]
    async fn test_record_batch_stream_from_stream_empty() {
        let empty_stream = stream::iter(Vec::<DataFusionResult<RecordBatch>>::new());

        match record_batch_stream_from_stream(empty_stream).await {
            Ok(_) => panic!("Expected error for empty stream"),
            Err(error) => {
                assert!(error.to_string().contains("record batch stream is empty"));
            }
        }
    }

    #[tokio::test]
    async fn test_record_batch_stream_from_stream_first_batch_error() {
        use datafusion::common::DataFusionError;

        let error_stream =
            stream::iter(vec![Err(DataFusionError::Internal("test error".to_string()))]);

        match record_batch_stream_from_stream(error_stream).await {
            Ok(_) => panic!("Expected error from first batch"),
            Err(error) => {
                assert!(error.to_string().contains("test error"));
            }
        }
    }

    #[tokio::test]
    async fn test_record_batch_stream_from_stream_multiple_batches() {
        let batch1 = create_test_record_batch();
        let batch2 = create_test_record_batch();
        let batch3 = create_test_record_batch();

        let test_stream = stream::iter(vec![Ok(batch1), Ok(batch2), Ok(batch3)]);

        let mut sendable_stream = record_batch_stream_from_stream(test_stream)
            .await
            .expect("stream with batches should produce a wrapper");

        // The batch consumed for schema discovery must be replayed, so all three come back.
        let mut count = 0;
        while let Some(batch_result) = sendable_stream.next().await {
            let batch = batch_result.unwrap();
            assert_eq!(batch.num_rows(), 3);
            assert_eq!(batch.num_columns(), 2);
            count += 1;
        }
        assert_eq!(count, 3);
    }
}