1use std::{collections::HashMap, fs, io::BufReader, sync::Arc, time::Duration};
14
15use anyhow::{anyhow, Result};
16
17use apache_avro::{from_avro_datum, schema::RecordSchema, types::Value, Schema};
18use chrono::DateTime;
19use chrono_tz::{self};
20
21use clickhouse_rs::types::{DateTimeType, SqlType, Value as CHValue};
22use reqwest::Url;
23use schema_registry_api::{SchemaRegistry, SchemaVersion, SubjectName};
24use serde::Deserialize;
25use uuid::Uuid;
26
27use super::{Row, CONFLUENT_HEADER_LEN};
28
29#[derive(Deserialize)]
30pub struct Settings {
31 pub field_names: Option<HashMap<String, String>>,
32 pub include_fields: Option<Vec<String>>,
33 pub exclude_fields: Option<Vec<String>>,
34 pub schema_file: Option<String>,
35 pub registry_url: Option<String>,
36}
37
38pub struct Decoder {
39 schema: Schema,
40 array_types: TypeMapping,
41 map_types: TypeMapping,
42 name_overrides: Vec<(String, String)>,
43 include_fields: Vec<String>,
44 exclude_fields: Vec<String>,
45 null_values: NullValues,
46}
47
48struct TypeMapping(Vec<(String, &'static SqlType)>);
49struct NullValues(Vec<(String, CHValue)>);
50
51impl Decoder {
52 fn avro2ch(&self, column_name: &str, v: Value) -> Result<CHValue, anyhow::Error> {
53 match v {
54 Value::Null => Err(anyhow!("unexpected null")),
55 Value::Record(_) => Err(anyhow!("unsupported nested record")),
56 Value::Boolean(x) => Ok(CHValue::from(x)),
57 Value::Int(x) => Ok(CHValue::from(x)),
58 Value::Long(x) => Ok(CHValue::from(x)),
59 Value::Float(x) => Ok(CHValue::from(x)),
60 Value::Double(x) => Ok(CHValue::from(x)),
61 Value::Bytes(x) => Ok(CHValue::from(x)),
62 Value::String(x) => Ok(CHValue::from(x)),
63 Value::Fixed(_, x) => Ok(CHValue::String(Arc::new(x))),
64 Value::Enum(_, y) => Ok(CHValue::from(y)),
65 Value::Union(_, v) => match *v {
66 Value::Null => match self.null_values.0.iter().find(|(n, _)| column_name.eq(n)) {
67 None => Err(anyhow!("cannot find nullable field {}", column_name)),
68 Some(x) => Ok(x.1.clone()),
69 },
70 v => self.avro2ch(column_name, v),
71 },
72 Value::Array(x) => {
73 let mut arr: Vec<CHValue> = Vec::new();
74 for elem in x {
75 arr.push(self.avro2ch(column_name, elem)?)
76 }
77 let column_type = self.array_types.get_type(column_name).unwrap();
78 Ok(CHValue::Array(column_type, Arc::new(arr)))
79 }
80 Value::Map(x) => {
81 let mut m: HashMap<CHValue, CHValue> = HashMap::new();
82 for (k, v) in x {
83 m.insert(CHValue::from(k), self.avro2ch(column_name, v)?);
84 }
85 let column_type = self.map_types.get_type(column_name).unwrap();
86 Ok(CHValue::Map(&SqlType::String, &column_type, Arc::new(m)))
87 }
88 Value::Date(x) => Ok(CHValue::Date(x as u16)),
89 Value::Decimal(_) => Err(anyhow!("unsupported decimal type")),
90 Value::TimeMillis(x) => Ok(CHValue::from(x)),
91 Value::TimeMicros(x) => Ok(CHValue::from(x)),
92 Value::TimestampMillis(x) => {
93 Ok(CHValue::from(DateTime::from_timestamp_millis(x).unwrap()))
94 }
95 Value::TimestampMicros(x) => {
96 Ok(CHValue::from(DateTime::from_timestamp_micros(x).unwrap()))
97 }
98 Value::LocalTimestampMillis(x) => {
99 Ok(CHValue::from(DateTime::from_timestamp_millis(x).unwrap()))
100 }
101 Value::LocalTimestampMicros(x) => {
102 Ok(CHValue::from(DateTime::from_timestamp_micros(x).unwrap()))
103 }
104 Value::Duration(x) => {
105 let duration = Duration::from_millis(u32::from(x.millis()) as u64)
107 + Duration::from_secs(86400 * u32::from(x.days()) as u64)
108 + Duration::from_secs(30 * 86400 * u32::from(x.months()) as u64);
109 Ok(CHValue::from(duration.as_millis() as u64))
110 }
111 Value::Uuid(x) => Ok(CHValue::from(x)),
112 }
113 }
114}
115
116impl super::Decoder for Decoder {
117 fn get_name(&self) -> String {
118 String::from("avro")
119 }
120
121 fn decode(&self, message: &[u8]) -> Result<Row> {
122 let mut datum = BufReader::new(&message[CONFLUENT_HEADER_LEN..]);
123 let record;
124 let mut row = Row::new();
125 match from_avro_datum(&self.schema, &mut datum, None)? {
126 Value::Record(x) => record = x,
127 _ => return Err(anyhow!("avro message must be a record")),
128 };
129 for (column, value) in record {
130 if self.exclude_fields.iter().any(|c| c == column.as_str()) {
131 continue;
132 }
133 if !self.include_fields.is_empty()
134 && !self.include_fields.iter().any(|c| c == column.as_str())
135 {
136 continue;
137 }
138 let v = self.avro2ch(&column, value)?;
139 let column_name = match self.name_overrides.iter().find(|m| m.0 == column) {
140 None => column,
141 Some((_, n)) => n.to_owned(),
142 };
143 row.push((column_name, v));
144 }
145 Ok(row)
146 }
147}
148
149impl TypeMapping {
150 fn new(r: &RecordSchema) -> Result<(Self, Self)> {
151 let mut map_types: Vec<(String, &SqlType)> = Vec::new();
152 let mut arr_types: Vec<(String, &SqlType)> = Vec::new();
153 for field in &r.fields {
154 match &field.schema {
155 Schema::Array(v) => arr_types.push((field.name.clone(), get_schema_type(v)?.0)),
156 Schema::Map(v) => map_types.push((field.name.clone(), get_schema_type(v)?.0)),
157 _ => (),
158 };
159 }
160 Ok((TypeMapping(arr_types), TypeMapping(map_types)))
161 }
162
163 fn get_type(&self, column: &str) -> Option<&'static SqlType> {
164 match self.0.iter().find(|x| x.0 == column) {
165 None => None,
166 Some((_, t)) => Some(t),
167 }
168 }
169}
170
171pub async fn new(topic: &str, settings: Settings) -> Result<Decoder> {
172 let schema = get_schema(&topic, &settings).await?;
173 let mut name_overrides: Vec<(String, String)> = Vec::new();
174 let mut include_fields: Vec<String> = Vec::new();
175 let mut exclude_fields: Vec<String> = Vec::new();
176 if let Some(names) = settings.field_names {
177 names
178 .iter()
179 .for_each(|(k, v)| name_overrides.push((k.to_owned(), v.to_owned())));
180 }
181 if let Some(flds) = settings.include_fields {
182 flds.iter().for_each(|e| include_fields.push(e.to_owned()));
183 }
184 if let Some(flds) = settings.exclude_fields {
185 flds.iter().for_each(|e| exclude_fields.push(e.to_owned()));
186 }
187 match schema {
188 Schema::Record(record) => {
189 let (array_types, map_types, null_values) = analyze_schema(&record)?;
190 Ok(Decoder {
191 array_types,
192 map_types,
193 schema: Schema::Record(record),
194 name_overrides,
195 include_fields,
196 exclude_fields,
197 null_values,
198 })
199 }
200 _ => Err(anyhow!("avro schema root must be a record")),
201 }
202}
203
204async fn get_schema(topic: &str, settings: &Settings) -> Result<Schema> {
205 match &settings.schema_file {
206 Some(f) => Ok(Schema::parse_str(&fs::read_to_string(f)?)?),
207 None => match &settings.registry_url {
208 None => Err(anyhow!("registry_url or schema_file must be specified")),
209 Some(registry_url) => {
210 let sr_client = SchemaRegistry::build_default(Url::parse(®istry_url)?)?;
211 let subject_name = format!("{topic}-value").parse::<SubjectName>()?;
212 let subject = sr_client
213 .subject()
214 .version(&subject_name, SchemaVersion::Latest)
215 .await?;
216 match subject {
217 None => Err(anyhow!("subject {} not found", subject_name)),
218 Some(s) => Ok(Schema::parse_str(&s.schema)?),
219 }
220 }
221 },
222 }
223}
224
225fn analyze_schema(s: &RecordSchema) -> Result<(TypeMapping, TypeMapping, NullValues)> {
228 let (array_types, map_types) = TypeMapping::new(s)?;
229 let mut null_values: Vec<(String, CHValue)> = Vec::new();
230 for fld in &s.fields {
231 match &fld.schema {
232 Schema::Record(_) => {
233 return Err(anyhow!(
234 "field {}: nested records are not supported",
235 fld.name
236 ))
237 }
238 Schema::Union(union) => {
239 let schemas = union.variants();
240 if schemas.len() != 2 {
241 return Err(anyhow!(
242 "field {}: only supported union type is [null, <type>]",
243 fld.name
244 ));
245 }
246 match schemas[0] {
247 Schema::Null => (),
248 _ => {
249 return Err(anyhow!(
250 "field {}: only supported union type is [null, <type>]",
251 fld.name
252 ))
253 }
254 }
255 null_values.push((fld.name.clone(), get_schema_type(&schemas[1])?.1))
256 }
257 _ => (),
258 }
259 }
260 Ok((array_types, map_types, NullValues(null_values)))
261}
262
263fn get_schema_type(s: &Schema) -> Result<(&'static SqlType, CHValue)> {
266 match s {
267 Schema::Boolean => Ok((&SqlType::Bool, CHValue::from(false))),
268 Schema::Int => Ok((&SqlType::Int32, CHValue::Int32(0))),
269 Schema::Long => Ok((&SqlType::Int64, CHValue::Int64(0))),
270 Schema::Float => Ok((&SqlType::Float32, CHValue::Float32(0.0))),
271 Schema::Double => Ok((&SqlType::Float64, CHValue::Float64(0.0))),
272 Schema::Bytes => Ok((&SqlType::String, CHValue::from(Vec::<u8>::new()))),
273 Schema::String => Ok((&SqlType::String, CHValue::from(String::new()))),
274 Schema::Uuid => Ok((&SqlType::Uuid, CHValue::from(Uuid::nil()))),
275 Schema::Date => Ok((&SqlType::Date, CHValue::Date(0u16))),
277 Schema::TimeMillis => Ok((&SqlType::Int32, CHValue::DateTime(0, chrono_tz::UTC))),
278 Schema::TimeMicros => Ok((&SqlType::Int32, CHValue::DateTime(0, chrono_tz::UTC))),
279 Schema::TimestampMillis => Ok((
280 &SqlType::DateTime(DateTimeType::DateTime64(3, chrono_tz::UTC)),
281 CHValue::DateTime64(0, (3, chrono_tz::UTC)),
282 )),
283 Schema::TimestampMicros => Ok((
284 &SqlType::DateTime(DateTimeType::DateTime64(6, chrono_tz::UTC)),
285 CHValue::DateTime64(0, (6, chrono_tz::UTC)),
286 )),
287 Schema::LocalTimestampMillis => Ok((
288 &SqlType::DateTime(DateTimeType::DateTime64(3, chrono_tz::UTC)),
289 CHValue::DateTime64(0, (3, chrono_tz::UTC)),
290 )),
291 Schema::LocalTimestampMicros => Ok((
292 &SqlType::DateTime(DateTimeType::DateTime64(6, chrono_tz::UTC)),
293 CHValue::DateTime64(0, (6, chrono_tz::UTC)),
294 )),
295 Schema::Duration => Ok((&SqlType::Int64, CHValue::UInt64(0))),
296 _ => Err(anyhow!("unsupported type")),
297 }
298}