// clickhouse_datafusion/sink.rs

1use std::fmt;
2use std::sync::Arc;
3
4use datafusion::arrow::datatypes::SchemaRef;
5use datafusion::common::exec_err;
6use datafusion::error::Result;
7use datafusion::execution::SendableRecordBatchStream;
8use datafusion::physical_plan::DisplayAs;
9use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
10use datafusion::sql::TableReference;
11use futures_util::{StreamExt, TryStreamExt};
12
13use crate::connection::ClickHouseConnectionPool;
14
/// [`datafusion::datasource::sink::DataSink`] for `ClickHouse`
///
/// Writes Arrow record batches into a single ClickHouse table using
/// connections drawn from a shared pool, inserting batches concurrently.
#[derive(Debug)]
pub struct ClickHouseDataSink {
    // Pool the sink draws write connections from. With the `mocks` feature
    // enabled the insert path is stubbed out, so the field goes unused.
    #[cfg_attr(feature = "mocks", expect(unused))]
    writer:            Arc<ClickHouseConnectionPool>,
    // Destination table, optionally schema-qualified (see `write_all`).
    table:             TableReference,
    // Arrow schema every incoming batch is validated against.
    schema:            SchemaRef,
    // Metrics exposed through `DataSink::metrics`.
    metrics:           ExecutionPlanMetricsSet,
    // Max batches inserted in flight at once; taken from the pool's config.
    write_concurrency: usize,
}
26
27impl ClickHouseDataSink {
28    pub fn new(
29        writer: Arc<ClickHouseConnectionPool>,
30        table: TableReference,
31        schema: SchemaRef,
32    ) -> Self {
33        let write_concurrency = writer.write_concurrency();
34        Self { writer, table, schema, metrics: ExecutionPlanMetricsSet::new(), write_concurrency }
35    }
36
37    /// Verify that a passed in schema aligns with the sink schema
38    ///
39    /// Ordering and metadata don't matter
40    ///
41    /// # Errors
42    /// - Returns an error if the field lengths don't match
43    /// - Returns an error if data types don't match
44    /// - Returns an error if names don't match
45    /// - Returns an error if nullability doesn't match
46    pub fn verify_input_schema(&self, input: &SchemaRef) -> Result<()> {
47        let sink_fields = self.schema.fields();
48        let input_fields = input.fields();
49        if sink_fields.len() != input_fields.len() {
50            let (input_len, sink_len) = (input_fields.len(), sink_fields.len());
51            return exec_err!(
52                "Schema fields must match, input has {input_len} fields, sink {sink_len}"
53            );
54        }
55
56        for field in sink_fields {
57            let name = field.name();
58            let data_type = field.data_type();
59            let is_nullable = field.is_nullable();
60
61            let Some((_, input_field)) = input_fields.find(name) else {
62                return exec_err!("Sink field {name} missing from input");
63            };
64
65            if data_type != input_field.data_type() {
66                return exec_err!(
67                    "Sink field {name} expected data type {data_type:?} but found {:?}",
68                    input_field.data_type()
69                );
70            }
71
72            if is_nullable != input_field.is_nullable() {
73                return exec_err!(
74                    "Sink field {name} expected nullability {is_nullable} but found {}",
75                    input_field.is_nullable()
76                );
77            }
78        }
79
80        Ok(())
81    }
82}
83
84impl DisplayAs for ClickHouseDataSink {
85    fn fmt_as(
86        &self,
87        _t: datafusion::physical_plan::DisplayFormatType,
88        f: &mut fmt::Formatter<'_>,
89    ) -> fmt::Result {
90        write!(f, "ClickHouseDataSink: table={}", self.table)
91    }
92}
93
94#[async_trait::async_trait]
95impl datafusion::datasource::sink::DataSink for ClickHouseDataSink {
96    fn as_any(&self) -> &dyn std::any::Any { self }
97
98    fn schema(&self) -> &SchemaRef { &self.schema }
99
100    fn metrics(&self) -> Option<MetricsSet> { Some(self.metrics.clone_inner()) }
101
102    async fn write_all(
103        &self,
104        data: SendableRecordBatchStream,
105        _context: &Arc<datafusion::execution::TaskContext>,
106    ) -> Result<u64> {
107        #[cfg(not(feature = "mocks"))]
108        use datafusion::error::DataFusionError;
109
110        // Create baseline metrics for this partition
111        // DataSink always runs on partition 0 (by DataFusion design)
112        let partition = 0;
113        let baseline = BaselineMetrics::new(&self.metrics, partition);
114        let _timer = baseline.elapsed_compute().timer();
115
116        let db = self.table.schema();
117        let table = self.table.table();
118
119        let query = if let Some(db) = db {
120            format!("INSERT INTO {db}.{table} FORMAT Native")
121        } else {
122            format!("INSERT INTO {table} FORMAT Native")
123        };
124
125        #[cfg(not(feature = "mocks"))]
126        let writer = Arc::clone(&self.writer);
127        let schema = Arc::clone(&self.schema);
128        let concurrency = self.write_concurrency;
129        let baseline_clone = baseline.clone();
130
131        // Process batches concurrently using buffer_unordered
132        let row_count = data
133            .map(move |batch_result| {
134                #[cfg(not(feature = "mocks"))]
135                let writer_clone = Arc::clone(&writer);
136                let query = query.clone();
137                let schema = Arc::clone(&schema);
138                let baseline = baseline_clone.clone();
139
140                async move {
141                    let batch = batch_result?;
142
143                    // Runtime schema validation
144                    let sink_fields = schema.fields();
145                    let input_fields = batch.schema_ref().fields();
146                    if sink_fields.len() != input_fields.len() {
147                        let (input_len, sink_len) = (input_fields.len(), sink_fields.len());
148                        return exec_err!(
149                            "Schema fields must match, input has {input_len} fields, sink \
150                             {sink_len}"
151                        );
152                    }
153
154                    for field in sink_fields {
155                        let name = field.name();
156                        let data_type = field.data_type();
157                        let is_nullable = field.is_nullable();
158
159                        let Some((_, input_field)) = input_fields.find(name) else {
160                            return exec_err!("Sink field {name} missing from input");
161                        };
162
163                        if data_type != input_field.data_type() {
164                            return exec_err!(
165                                "Sink field {name} expected data type {data_type:?} but found {:?}",
166                                input_field.data_type()
167                            );
168                        }
169
170                        if is_nullable != input_field.is_nullable() {
171                            return exec_err!(
172                                "Sink field {name} expected nullability {is_nullable} but found {}",
173                                input_field.is_nullable()
174                            );
175                        }
176                    }
177
178                    let num_rows = batch.num_rows();
179
180                    #[cfg(not(feature = "mocks"))]
181                    {
182                        let pool_conn = writer_clone
183                            .pool()
184                            .get()
185                            .await
186                            .map_err(|e| DataFusionError::External(Box::new(e)))?;
187
188                        let mut results = pool_conn
189                            .insert(&query, batch, None)
190                            .await
191                            .map_err(|e| DataFusionError::External(Box::new(e)))?;
192
193                        // Drain the result stream to ensure the insert completes
194                        while let Some(result) = results.next().await {
195                            result.map_err(|e| DataFusionError::External(Box::new(e)))?;
196                        }
197                    }
198
199                    #[cfg(feature = "mocks")]
200                    eprintln!("Mocking query: {query}");
201
202                    baseline.record_output(num_rows);
203                    Ok(num_rows as u64)
204                }
205            })
206            .buffer_unordered(concurrency)
207            .try_fold(0u64, |acc, rows| async move { Ok(acc + rows) })
208            .await?;
209
210        Ok(row_count)
211    }
212}
213
#[cfg(all(test, feature = "mocks"))]
mod tests {
    use std::sync::Arc;

    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::datasource::sink::DataSink;
    use datafusion::sql::TableReference;

    use super::*;

    /// Schema used by every test: (id: Int32 NOT NULL, name: Utf8, value: Float64 NOT NULL).
    fn sink_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Float64, false),
        ]))
    }

    /// Build a sink over the mock connection pool, targeting `test_table`.
    fn create_test_sink() -> ClickHouseDataSink {
        let pool = Arc::new(ClickHouseConnectionPool::new("test_pool", ()));
        ClickHouseDataSink::new(pool, TableReference::bare("test_table"), sink_schema())
    }

    /// Validate `fields` against the sink schema and return the error text.
    fn schema_error_for(fields: Vec<Field>) -> String {
        let input = Arc::new(Schema::new(fields));
        create_test_sink().verify_input_schema(&input).unwrap_err().to_string()
    }

    #[test]
    fn test_verify_input_schema_valid() {
        // An identical schema must pass verification.
        assert!(create_test_sink().verify_input_schema(&sink_schema()).is_ok());
    }

    #[test]
    fn test_verify_input_schema_field_count_mismatch() {
        let err = schema_error_for(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]);
        assert!(err.contains("Schema fields must match"));
        assert!(err.contains("input has 2 fields, sink 3"));
    }

    #[test]
    fn test_verify_input_schema_missing_field() {
        let err = schema_error_for(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("wrong_name", DataType::Utf8, true),
            Field::new("value", DataType::Float64, false),
        ]);
        assert!(err.contains("missing from input"));
    }

    #[test]
    fn test_verify_input_schema_data_type_mismatch() {
        let err = schema_error_for(vec![
            Field::new("id", DataType::Int64, false), // Wrong type
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Float64, false),
        ]);
        assert!(err.contains("expected data type"));
    }

    #[test]
    fn test_verify_input_schema_nullability_mismatch() {
        let err = schema_error_for(vec![
            Field::new("id", DataType::Int32, true), // Wrong nullability
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Float64, false),
        ]);
        assert!(err.contains("expected nullability"));
    }

    #[test]
    fn test_new_sink() {
        let sink = create_test_sink();
        // Concurrency comes from the mock pool; table must round-trip intact.
        assert_eq!(sink.write_concurrency, 4);
        assert_eq!(sink.table, TableReference::bare("test_table"));
    }

    #[test]
    fn test_as_any() {
        let sink = create_test_sink();
        assert!(sink.as_any().downcast_ref::<ClickHouseDataSink>().is_some());
    }

    #[test]
    fn test_schema() {
        let sink = create_test_sink();
        let names: Vec<_> =
            sink.schema().fields().iter().map(|f| f.name().as_str()).collect();
        assert_eq!(names, ["id", "name", "value"]);
    }

    #[test]
    fn test_metrics() {
        assert!(create_test_sink().metrics().is_some());
    }
}