clickhouse_datafusion/
sink.rs1use std::fmt;
2use std::sync::Arc;
3
4use datafusion::arrow::datatypes::SchemaRef;
5use datafusion::common::exec_err;
6use datafusion::error::Result;
7use datafusion::execution::SendableRecordBatchStream;
8use datafusion::physical_plan::DisplayAs;
9use datafusion::sql::TableReference;
10use futures_util::StreamExt;
11
12use crate::connection::ClickHouseConnectionPool;
13
14#[derive(Debug)]
17pub struct ClickHouseDataSink {
18 #[cfg_attr(feature = "mocks", expect(unused))]
19 writer: Arc<ClickHouseConnectionPool>,
20 table: TableReference,
21 schema: SchemaRef,
22}
23
24impl ClickHouseDataSink {
25 pub fn new(
26 writer: Arc<ClickHouseConnectionPool>,
27 table: TableReference,
28 schema: SchemaRef,
29 ) -> Self {
30 Self { writer, table, schema }
31 }
32
33 pub fn verify_input_schema(&self, input: &SchemaRef) -> Result<()> {
43 let sink_fields = self.schema.fields();
44 let input_fields = input.fields();
45 if sink_fields.len() != input_fields.len() {
46 let (input_len, sink_len) = (input_fields.len(), sink_fields.len());
47 return exec_err!(
48 "Schema fields must match, input has {input_len} fields, sink {sink_len}"
49 );
50 }
51
52 for field in sink_fields {
53 let name = field.name();
54 let data_type = field.data_type();
55 let is_nullable = field.is_nullable();
56
57 let Some((_, input_field)) = input_fields.find(name) else {
58 return exec_err!("Sink field {name} missing from input");
59 };
60
61 if data_type != input_field.data_type() {
62 return exec_err!(
63 "Sink field {name} expected data type {data_type:?} but found {:?}",
64 input_field.data_type()
65 );
66 }
67
68 if is_nullable != input_field.is_nullable() {
69 return exec_err!(
70 "Sink field {name} expected nullability {is_nullable} but found {}",
71 input_field.is_nullable()
72 );
73 }
74 }
75
76 Ok(())
77 }
78}
79
80impl DisplayAs for ClickHouseDataSink {
81 fn fmt_as(
82 &self,
83 _t: datafusion::physical_plan::DisplayFormatType,
84 f: &mut fmt::Formatter<'_>,
85 ) -> fmt::Result {
86 write!(f, "ClickHouseDataSink: table={}", self.table)
87 }
88}
89
90#[async_trait::async_trait]
91impl datafusion::datasource::sink::DataSink for ClickHouseDataSink {
92 fn as_any(&self) -> &dyn std::any::Any { self }
93
94 fn schema(&self) -> &SchemaRef { &self.schema }
95
96 async fn write_all(
97 &self,
98 mut data: SendableRecordBatchStream,
99 _context: &Arc<datafusion::execution::TaskContext>,
100 ) -> Result<u64> {
101 #[cfg(not(feature = "mocks"))]
102 use datafusion::error::DataFusionError;
103
104 let db = self.table.schema();
105 let table = self.table.table();
106
107 let query = if let Some(db) = db {
108 format!("INSERT INTO {db}.{table} FORMAT Native")
109 } else {
110 format!("INSERT INTO {table} FORMAT Native")
111 };
112
113 let mut row_count = 0;
114
115 #[cfg(not(feature = "mocks"))]
116 let pool =
117 self.writer.pool().get().await.map_err(|e| DataFusionError::External(Box::new(e)))?;
118
119 while let Some(batch) = data.next().await.transpose()? {
120 self.verify_input_schema(batch.schema_ref())?;
122
123 let num_rows = batch.num_rows();
124
125 #[cfg(not(feature = "mocks"))]
126 let mut results = pool
127 .insert(&query, batch, None)
128 .await
129 .map_err(|e| DataFusionError::External(Box::new(e)))?;
130
131 #[cfg(feature = "mocks")]
132 eprintln!("Mocking query: {query}");
133
134 #[cfg(not(feature = "mocks"))]
136 while let Some(result) = results.next().await {
137 result.map_err(|e| DataFusionError::External(Box::new(e)))?;
138 }
139
140 row_count += num_rows as u64;
141 }
142
143 Ok(row_count)
144 }
145}