clickhouse_datafusion/
sink.rs

1use std::fmt;
2use std::sync::Arc;
3
4use datafusion::arrow::datatypes::SchemaRef;
5use datafusion::common::exec_err;
6use datafusion::error::Result;
7use datafusion::execution::SendableRecordBatchStream;
8use datafusion::physical_plan::DisplayAs;
9use datafusion::sql::TableReference;
10use futures_util::StreamExt;
11
12use crate::connection::ClickHouseConnectionPool;
13
14// TODO: Docs
15/// [`datafusion::datasource::sink::DataSink`] for `ClickHouse`
16#[derive(Debug)]
17pub struct ClickHouseDataSink {
18    #[cfg_attr(feature = "mocks", expect(unused))]
19    writer: Arc<ClickHouseConnectionPool>,
20    table:  TableReference,
21    schema: SchemaRef,
22}
23
24impl ClickHouseDataSink {
25    pub fn new(
26        writer: Arc<ClickHouseConnectionPool>,
27        table: TableReference,
28        schema: SchemaRef,
29    ) -> Self {
30        Self { writer, table, schema }
31    }
32
33    /// Verify that a passed in schema aligns with the sink schema
34    ///
35    /// Ordering and metadata don't matter
36    ///
37    /// # Errors
38    /// - Returns an error if the field lengths don't match
39    /// - Returns an error if data types don't match
40    /// - Returns an error if names don't match
41    /// - Returns an error if nullability doesn't match
42    pub fn verify_input_schema(&self, input: &SchemaRef) -> Result<()> {
43        let sink_fields = self.schema.fields();
44        let input_fields = input.fields();
45        if sink_fields.len() != input_fields.len() {
46            let (input_len, sink_len) = (input_fields.len(), sink_fields.len());
47            return exec_err!(
48                "Schema fields must match, input has {input_len} fields, sink {sink_len}"
49            );
50        }
51
52        for field in sink_fields {
53            let name = field.name();
54            let data_type = field.data_type();
55            let is_nullable = field.is_nullable();
56
57            let Some((_, input_field)) = input_fields.find(name) else {
58                return exec_err!("Sink field {name} missing from input");
59            };
60
61            if data_type != input_field.data_type() {
62                return exec_err!(
63                    "Sink field {name} expected data type {data_type:?} but found {:?}",
64                    input_field.data_type()
65                );
66            }
67
68            if is_nullable != input_field.is_nullable() {
69                return exec_err!(
70                    "Sink field {name} expected nullability {is_nullable} but found {}",
71                    input_field.is_nullable()
72                );
73            }
74        }
75
76        Ok(())
77    }
78}
79
80impl DisplayAs for ClickHouseDataSink {
81    fn fmt_as(
82        &self,
83        _t: datafusion::physical_plan::DisplayFormatType,
84        f: &mut fmt::Formatter<'_>,
85    ) -> fmt::Result {
86        write!(f, "ClickHouseDataSink: table={}", self.table)
87    }
88}
89
90#[async_trait::async_trait]
91impl datafusion::datasource::sink::DataSink for ClickHouseDataSink {
92    fn as_any(&self) -> &dyn std::any::Any { self }
93
94    fn schema(&self) -> &SchemaRef { &self.schema }
95
96    async fn write_all(
97        &self,
98        mut data: SendableRecordBatchStream,
99        _context: &Arc<datafusion::execution::TaskContext>,
100    ) -> Result<u64> {
101        #[cfg(not(feature = "mocks"))]
102        use datafusion::error::DataFusionError;
103
104        let db = self.table.schema();
105        let table = self.table.table();
106
107        let query = if let Some(db) = db {
108            format!("INSERT INTO {db}.{table} FORMAT Native")
109        } else {
110            format!("INSERT INTO {table} FORMAT Native")
111        };
112
113        let mut row_count = 0;
114
115        #[cfg(not(feature = "mocks"))]
116        let pool =
117            self.writer.pool().get().await.map_err(|e| DataFusionError::External(Box::new(e)))?;
118
119        while let Some(batch) = data.next().await.transpose()? {
120            // Runtime schema validation
121            self.verify_input_schema(batch.schema_ref())?;
122
123            let num_rows = batch.num_rows();
124
125            #[cfg(not(feature = "mocks"))]
126            let mut results = pool
127                .insert(&query, batch, None)
128                .await
129                .map_err(|e| DataFusionError::External(Box::new(e)))?;
130
131            #[cfg(feature = "mocks")]
132            eprintln!("Mocking query: {query}");
133
134            // Drain the result stream to ensure the insert completes
135            #[cfg(not(feature = "mocks"))]
136            while let Some(result) = results.next().await {
137                result.map_err(|e| DataFusionError::External(Box::new(e)))?;
138            }
139
140            row_count += num_rows as u64;
141        }
142
143        Ok(row_count)
144    }
145}