clickhouse_datafusion/
stream.rs1use std::pin::Pin;
2use std::sync::Arc;
3use std::task::Poll;
4
5use datafusion::arrow::array::RecordBatch;
6use datafusion::arrow::compute::cast;
7use datafusion::arrow::datatypes::SchemaRef;
8use datafusion::common::exec_err;
9use datafusion::error::Result as DataFusionResult;
10use futures_util::{Stream, StreamExt, TryStreamExt, ready};
11use pin_project::pin_project;
12
13use crate::ClickHouseConnectionPool;
14
/// A [`RecordBatchStream`] whose inner stream is a boxed, pinned, `Send` trait object yielding
/// `DataFusionResult<RecordBatch>` items.
pub type RecordBatchStreamWrapper =
    RecordBatchStream<Pin<Box<dyn Stream<Item = DataFusionResult<RecordBatch>> + Send>>>;
17
/// Pairs a stream of record batches with the schema reported for it.
///
/// When `coerce_schema` is set, batches yielded by the inner stream are cast to `schema`
/// before being handed to the consumer; otherwise they are passed through untouched.
#[pin_project]
pub struct RecordBatchStream<S> {
    /// Schema reported by this stream, and the cast target when coercion is enabled.
    schema: SchemaRef,
    /// The wrapped stream; structurally pinned so it can be polled through `Pin<&mut Self>`.
    #[pin]
    stream: S,
    /// Whether polled batches are cast to `schema`.
    coerce_schema: bool,
}
29
30impl<S> RecordBatchStream<S> {
31 pub fn new(stream: S, schema: SchemaRef) -> Self {
32 Self { schema, stream, coerce_schema: false }
33 }
34
35 #[must_use]
36 pub fn with_coercion(mut self, coerce: bool) -> Self {
37 self.coerce_schema = coerce;
38 self
39 }
40
41 fn coerce_batch_schema(&self, batch: RecordBatch) -> DataFusionResult<RecordBatch> {
46 if self.coerce_schema {
47 let (batch_schema, mut arrays, _) = batch.into_parts();
48
49 let from_fields = batch_schema.fields();
50 let to_fields = self.schema.fields();
51 if from_fields.len() != to_fields.len() {
52 return exec_err!("Cannot coerce types, incompatible schemas");
53 }
54
55 let mut new_arrays = Vec::with_capacity(arrays.len());
56 let field_map = batch_schema.fields().iter().zip(self.schema.fields().iter());
57
58 for (from_field, to_field) in field_map.rev() {
60 let Some(current_array) = arrays.pop() else {
61 return exec_err!("Cannot coerce types, missing array");
62 };
63
64 if from_field.data_type() == to_field.data_type() {
65 new_arrays.push(current_array);
66 } else {
67 let new_array = cast(¤t_array, to_field.data_type())?;
68 new_arrays.push(new_array);
69 }
70 }
71
72 new_arrays.reverse();
74 Ok(RecordBatch::try_new(Arc::clone(&self.schema), new_arrays)?)
75 } else {
76 Ok(batch)
77 }
78 }
79}
80
81impl RecordBatchStreamWrapper {
82 pub fn new_from_stream(
83 stream: Pin<Box<dyn Stream<Item = DataFusionResult<RecordBatch>> + Send>>,
84 schema: SchemaRef,
85 ) -> Self {
86 Self { schema, stream, coerce_schema: false }
87 }
88
89 pub fn new_from_query(
90 sql: impl Into<String>,
91 pool: Arc<ClickHouseConnectionPool>,
92 schema: SchemaRef,
93 coerce_schema: bool,
94 ) -> Self {
95 let sql = sql.into();
96 let pool_schema = Arc::clone(&schema);
97 let stream = Box::pin(
98 futures_util::stream::once(async move {
99 pool.connect()
100 .await?
101 .query_arrow_with_schema(&sql, &[], pool_schema, coerce_schema)
102 .await
103 })
104 .try_flatten(),
105 );
106 Self { schema, stream, coerce_schema: false }
107 }
108}
109
110impl<S> Stream for RecordBatchStream<S>
111where
112 S: Stream<Item = DataFusionResult<RecordBatch>>,
113{
114 type Item = DataFusionResult<RecordBatch>;
115
116 fn poll_next(
117 mut self: Pin<&mut Self>,
118 cx: &mut std::task::Context<'_>,
119 ) -> Poll<Option<Self::Item>> {
120 if !self.coerce_schema {
121 return self.as_mut().project().stream.poll_next(cx);
122 }
123
124 Poll::Ready(match ready!(self.as_mut().project().stream.poll_next(cx)) {
125 Some(batch) => Some(self.coerce_batch_schema(batch?)),
126 None => None,
127 })
128 }
129
130 fn size_hint(&self) -> (usize, Option<usize>) { self.stream.size_hint() }
131}
132
/// Exposes the stored schema through DataFusion's `RecordBatchStream` trait.
impl<S> datafusion::physical_plan::RecordBatchStream for RecordBatchStream<S>
where
    S: Stream<Item = DataFusionResult<RecordBatch>>,
{
    /// Returns a new handle to the schema this stream was constructed with.
    fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) }
}
139
140pub async fn record_batch_stream_from_stream(
146 mut stream: impl Stream<Item = DataFusionResult<RecordBatch>> + Send + Unpin + 'static,
147) -> DataFusionResult<RecordBatchStreamWrapper> {
148 let Some(first_batch) = stream.next().await else {
149 return exec_err!("No schema provided and record batch stream is empty");
150 };
151 let first_batch = first_batch?;
152 let schema = first_batch.schema();
153 let stream = Box::pin(futures_util::stream::once(async { Ok(first_batch) }).chain(stream));
154 Ok(RecordBatchStream::new_from_stream(stream, schema))
155}
156
#[cfg(all(test, feature = "test-utils"))]
mod tests {
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use datafusion::arrow::array::{Int32Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::physical_plan::RecordBatchStream as RecordBatchStreamTrait;
    use futures_util::stream;

    use super::*;

    /// Builds a three-row batch with a non-nullable `id: Int32` and `name: Utf8` column.
    fn create_test_record_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let id_array = Int32Array::from(vec![1, 2, 3]);
        let name_array = StringArray::from(vec!["Alice", "Bob", "Charlie"]);

        RecordBatch::try_new(schema, vec![Arc::new(id_array), Arc::new(name_array)]).unwrap()
    }

    #[test]
    fn test_record_batch_stream_new() {
        let batch = create_test_record_batch();
        let schema = batch.schema();
        let stream = Box::pin(stream::once(async move { Ok(batch) }));

        let record_batch_stream =
            RecordBatchStreamWrapper::new_from_stream(stream, Arc::clone(&schema));
        assert_eq!(record_batch_stream.schema(), schema);
    }

    #[test]
    fn test_record_batch_stream_schema() {
        let batch = create_test_record_batch();
        let schema = batch.schema();
        let stream = Box::pin(stream::once(async move { Ok(batch) }));

        let record_batch_stream =
            RecordBatchStreamWrapper::new_from_stream(stream, Arc::clone(&schema));
        let returned_schema = record_batch_stream.schema();

        assert_eq!(returned_schema.fields().len(), 2);
        assert_eq!(returned_schema.field(0).name(), "id");
        assert_eq!(returned_schema.field(1).name(), "name");
    }

    #[tokio::test]
    async fn test_record_batch_stream_poll_next() {
        let batch = create_test_record_batch();
        let schema = batch.schema();
        // The batch is moved into the stream; no copy is needed.
        let stream = Box::pin(stream::once(async move { Ok(batch) }));

        let mut record_batch_stream = RecordBatchStreamWrapper::new_from_stream(stream, schema);

        // The inner future completes immediately, so a no-op waker is sufficient.
        let waker = futures_util::task::noop_waker();
        let mut context = Context::from_waker(&waker);

        let pinned = Pin::new(&mut record_batch_stream);
        if let Poll::Ready(Some(result)) = pinned.poll_next(&mut context) {
            let received_batch = result.unwrap();
            assert_eq!(received_batch.num_rows(), 3);
            assert_eq!(received_batch.num_columns(), 2);
        } else {
            panic!("Expected Poll::Ready with batch");
        }
    }

    #[tokio::test]
    async fn test_record_batch_stream_coerces_to_target_schema() {
        let batch = create_test_record_batch();
        let target_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        let stream = Box::pin(stream::once(async move { Ok(batch) }));

        let mut coerced_stream =
            RecordBatchStreamWrapper::new_from_stream(stream, Arc::clone(&target_schema))
                .with_coercion(true);

        let coerced = coerced_stream.next().await.unwrap().unwrap();
        assert_eq!(coerced.schema(), target_schema);
        assert_eq!(coerced.column(0).data_type(), &DataType::Int64);
        assert_eq!(coerced.column(1).data_type(), &DataType::Utf8);
        assert_eq!(coerced.num_rows(), 3);
    }

    #[tokio::test]
    async fn test_record_batch_stream_coercion_rejects_field_count_mismatch() {
        let batch = create_test_record_batch();
        let target_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let stream = Box::pin(stream::once(async move { Ok(batch) }));

        let mut coerced_stream =
            RecordBatchStreamWrapper::new_from_stream(stream, target_schema).with_coercion(true);

        let error = coerced_stream.next().await.unwrap().unwrap_err();
        assert!(error.to_string().contains("incompatible schemas"));
    }

    #[tokio::test]
    async fn test_record_batch_stream_from_stream_success() {
        let batch1 = create_test_record_batch();
        let batch2 = create_test_record_batch();
        let test_stream = stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let mut sendable_stream = record_batch_stream_from_stream(test_stream)
            .await
            .expect("a non-empty stream of Ok batches should be wrapped successfully");

        let first_batch = sendable_stream.next().await.unwrap().unwrap();
        assert_eq!(first_batch.num_rows(), 3);
        assert_eq!(first_batch.num_columns(), 2);
    }

    #[tokio::test]
    async fn test_record_batch_stream_from_stream_empty() {
        let empty_stream = stream::iter(Vec::<DataFusionResult<RecordBatch>>::new());

        match record_batch_stream_from_stream(empty_stream).await {
            Ok(_) => panic!("Expected error for empty stream"),
            Err(error) => {
                assert!(error.to_string().contains("record batch stream is empty"));
            }
        }
    }

    #[tokio::test]
    async fn test_record_batch_stream_from_stream_first_batch_error() {
        use datafusion::common::DataFusionError;

        let error_stream =
            stream::iter(vec![Err(DataFusionError::Internal("test error".to_string()))]);

        match record_batch_stream_from_stream(error_stream).await {
            Ok(_) => panic!("Expected error from first batch"),
            Err(error) => {
                assert!(error.to_string().contains("test error"));
            }
        }
    }

    #[tokio::test]
    async fn test_record_batch_stream_from_stream_multiple_batches() {
        let batch1 = create_test_record_batch();
        let batch2 = create_test_record_batch();
        let batch3 = create_test_record_batch();

        let test_stream = stream::iter(vec![Ok(batch1), Ok(batch2), Ok(batch3)]);

        let mut sendable_stream = record_batch_stream_from_stream(test_stream)
            .await
            .expect("a non-empty stream of Ok batches should be wrapped successfully");

        // The batch consumed to discover the schema must still be yielded.
        let mut count = 0;
        while let Some(batch_result) = sendable_stream.next().await {
            let batch = batch_result.unwrap();
            assert_eq!(batch.num_rows(), 3);
            assert_eq!(batch.num_columns(), 2);
            count += 1;
        }
        assert_eq!(count, 3);
    }
}