1use std::fmt;
2use std::sync::Arc;
3
4use datafusion::arrow::datatypes::SchemaRef;
5use datafusion::common::exec_err;
6use datafusion::error::Result;
7use datafusion::execution::SendableRecordBatchStream;
8use datafusion::physical_plan::DisplayAs;
9use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
10use datafusion::sql::TableReference;
11use futures_util::{StreamExt, TryStreamExt};
12
13use crate::connection::ClickHouseConnectionPool;
14
/// DataFusion `DataSink` that writes Arrow record batches into a ClickHouse
/// table via a pooled connection, with a bounded number of concurrent inserts.
#[derive(Debug)]
pub struct ClickHouseDataSink {
    /// Connection pool used to obtain ClickHouse connections for inserts.
    /// Marked unused under the `mocks` feature, where inserts are stubbed out.
    #[cfg_attr(feature = "mocks", expect(unused))]
    writer: Arc<ClickHouseConnectionPool>,
    /// Target table (optionally schema/db qualified) for `INSERT` statements.
    table: TableReference,
    /// Arrow schema every incoming batch is validated against before insert.
    schema: SchemaRef,
    /// Execution metrics (output rows, elapsed compute) reported to DataFusion.
    metrics: ExecutionPlanMetricsSet,
    /// Maximum number of batch inserts in flight at once; snapshotted from the
    /// pool's `write_concurrency()` at construction time.
    write_concurrency: usize,
}
26
27impl ClickHouseDataSink {
28 pub fn new(
29 writer: Arc<ClickHouseConnectionPool>,
30 table: TableReference,
31 schema: SchemaRef,
32 ) -> Self {
33 let write_concurrency = writer.write_concurrency();
34 Self { writer, table, schema, metrics: ExecutionPlanMetricsSet::new(), write_concurrency }
35 }
36
37 pub fn verify_input_schema(&self, input: &SchemaRef) -> Result<()> {
47 let sink_fields = self.schema.fields();
48 let input_fields = input.fields();
49 if sink_fields.len() != input_fields.len() {
50 let (input_len, sink_len) = (input_fields.len(), sink_fields.len());
51 return exec_err!(
52 "Schema fields must match, input has {input_len} fields, sink {sink_len}"
53 );
54 }
55
56 for field in sink_fields {
57 let name = field.name();
58 let data_type = field.data_type();
59 let is_nullable = field.is_nullable();
60
61 let Some((_, input_field)) = input_fields.find(name) else {
62 return exec_err!("Sink field {name} missing from input");
63 };
64
65 if data_type != input_field.data_type() {
66 return exec_err!(
67 "Sink field {name} expected data type {data_type:?} but found {:?}",
68 input_field.data_type()
69 );
70 }
71
72 if is_nullable != input_field.is_nullable() {
73 return exec_err!(
74 "Sink field {name} expected nullability {is_nullable} but found {}",
75 input_field.is_nullable()
76 );
77 }
78 }
79
80 Ok(())
81 }
82}
83
84impl DisplayAs for ClickHouseDataSink {
85 fn fmt_as(
86 &self,
87 _t: datafusion::physical_plan::DisplayFormatType,
88 f: &mut fmt::Formatter<'_>,
89 ) -> fmt::Result {
90 write!(f, "ClickHouseDataSink: table={}", self.table)
91 }
92}
93
94#[async_trait::async_trait]
95impl datafusion::datasource::sink::DataSink for ClickHouseDataSink {
96 fn as_any(&self) -> &dyn std::any::Any { self }
97
98 fn schema(&self) -> &SchemaRef { &self.schema }
99
100 fn metrics(&self) -> Option<MetricsSet> { Some(self.metrics.clone_inner()) }
101
102 async fn write_all(
103 &self,
104 data: SendableRecordBatchStream,
105 _context: &Arc<datafusion::execution::TaskContext>,
106 ) -> Result<u64> {
107 #[cfg(not(feature = "mocks"))]
108 use datafusion::error::DataFusionError;
109
110 let partition = 0;
113 let baseline = BaselineMetrics::new(&self.metrics, partition);
114 let _timer = baseline.elapsed_compute().timer();
115
116 let db = self.table.schema();
117 let table = self.table.table();
118
119 let query = if let Some(db) = db {
120 format!("INSERT INTO {db}.{table} FORMAT Native")
121 } else {
122 format!("INSERT INTO {table} FORMAT Native")
123 };
124
125 #[cfg(not(feature = "mocks"))]
126 let writer = Arc::clone(&self.writer);
127 let schema = Arc::clone(&self.schema);
128 let concurrency = self.write_concurrency;
129 let baseline_clone = baseline.clone();
130
131 let row_count = data
133 .map(move |batch_result| {
134 #[cfg(not(feature = "mocks"))]
135 let writer_clone = Arc::clone(&writer);
136 let query = query.clone();
137 let schema = Arc::clone(&schema);
138 let baseline = baseline_clone.clone();
139
140 async move {
141 let batch = batch_result?;
142
143 let sink_fields = schema.fields();
145 let input_fields = batch.schema_ref().fields();
146 if sink_fields.len() != input_fields.len() {
147 let (input_len, sink_len) = (input_fields.len(), sink_fields.len());
148 return exec_err!(
149 "Schema fields must match, input has {input_len} fields, sink \
150 {sink_len}"
151 );
152 }
153
154 for field in sink_fields {
155 let name = field.name();
156 let data_type = field.data_type();
157 let is_nullable = field.is_nullable();
158
159 let Some((_, input_field)) = input_fields.find(name) else {
160 return exec_err!("Sink field {name} missing from input");
161 };
162
163 if data_type != input_field.data_type() {
164 return exec_err!(
165 "Sink field {name} expected data type {data_type:?} but found {:?}",
166 input_field.data_type()
167 );
168 }
169
170 if is_nullable != input_field.is_nullable() {
171 return exec_err!(
172 "Sink field {name} expected nullability {is_nullable} but found {}",
173 input_field.is_nullable()
174 );
175 }
176 }
177
178 let num_rows = batch.num_rows();
179
180 #[cfg(not(feature = "mocks"))]
181 {
182 let pool_conn = writer_clone
183 .pool()
184 .get()
185 .await
186 .map_err(|e| DataFusionError::External(Box::new(e)))?;
187
188 let mut results = pool_conn
189 .insert(&query, batch, None)
190 .await
191 .map_err(|e| DataFusionError::External(Box::new(e)))?;
192
193 while let Some(result) = results.next().await {
195 result.map_err(|e| DataFusionError::External(Box::new(e)))?;
196 }
197 }
198
199 #[cfg(feature = "mocks")]
200 eprintln!("Mocking query: {query}");
201
202 baseline.record_output(num_rows);
203 Ok(num_rows as u64)
204 }
205 })
206 .buffer_unordered(concurrency)
207 .try_fold(0u64, |acc, rows| async move { Ok(acc + rows) })
208 .await?;
209
210 Ok(row_count)
211 }
212}
213
#[cfg(all(test, feature = "mocks"))]
mod tests {
    use std::sync::Arc;

    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::datasource::sink::DataSink;
    use datafusion::sql::TableReference;

    use super::*;

    /// Three-field schema (`id`, `name`, `value`) shared across the tests.
    fn make_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Float64, false),
        ]))
    }

    /// Builds a sink over a mock pool targeting the bare table `test_table`.
    fn create_test_sink() -> ClickHouseDataSink {
        let pool = Arc::new(ClickHouseConnectionPool::new("test_pool", ()));
        ClickHouseDataSink::new(pool, TableReference::bare("test_table"), make_schema())
    }

    /// Runs `verify_input_schema` against `fields` and returns the error text.
    fn verify_err(fields: Vec<Field>) -> String {
        let input = Arc::new(Schema::new(fields));
        create_test_sink()
            .verify_input_schema(&input)
            .expect_err("schema mismatch should be rejected")
            .to_string()
    }

    #[test]
    fn test_verify_input_schema_valid() {
        assert!(create_test_sink().verify_input_schema(&make_schema()).is_ok());
    }

    #[test]
    fn test_verify_input_schema_field_count_mismatch() {
        let err = verify_err(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]);
        assert!(err.contains("Schema fields must match"));
        assert!(err.contains("input has 2 fields, sink 3"));
    }

    #[test]
    fn test_verify_input_schema_missing_field() {
        let err = verify_err(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("wrong_name", DataType::Utf8, true),
            Field::new("value", DataType::Float64, false),
        ]);
        assert!(err.contains("missing from input"));
    }

    #[test]
    fn test_verify_input_schema_data_type_mismatch() {
        // `id` is Int64 here but Int32 in the sink schema.
        let err = verify_err(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Float64, false),
        ]);
        assert!(err.contains("expected data type"));
    }

    #[test]
    fn test_verify_input_schema_nullability_mismatch() {
        // `id` is nullable here but non-nullable in the sink schema.
        let err = verify_err(vec![
            Field::new("id", DataType::Int32, true),
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Float64, false),
        ]);
        assert!(err.contains("expected nullability"));
    }

    #[test]
    fn test_new_sink() {
        let sink = create_test_sink();
        assert_eq!(sink.write_concurrency, 4);
        assert_eq!(sink.table, TableReference::bare("test_table"));
    }

    #[test]
    fn test_as_any() {
        let sink = create_test_sink();
        assert!(sink.as_any().downcast_ref::<ClickHouseDataSink>().is_some());
    }

    #[test]
    fn test_schema() {
        let sink = create_test_sink();
        let schema = sink.schema();
        assert_eq!(schema.fields().len(), 3);
        let names: Vec<_> = schema.fields().iter().map(|f| f.name().as_str()).collect();
        assert_eq!(names, ["id", "name", "value"]);
    }

    #[test]
    fn test_metrics() {
        assert!(create_test_sink().metrics().is_some());
    }
}