chafka/decoder/
avro.rs

//! Generic Avro message decoder.
//!
//! Takes the schema either from a configured file, or from the [Schema Registry]
//! using the Confluent-compatible message [header].
//!
//! Supports all primitive types, maps and arrays of primitive types, nullable values,
//! and most of the logical types, like timestamps, UUIDs etc.
//! Union types other than nullables, nested records, and decimals are not supported.
//! All time types are assumed to be in the UTC timezone.
10//!
11//! [Schema Registry]: https://docs.confluent.io/platform/current/schema-registry/index.html
12//! [header]: https://docs.confluent.io/platform/current/schema-registry/fundamentals/serdes-develop/index.html#wire-format
13use std::{collections::HashMap, fs, io::BufReader, sync::Arc, time::Duration};
14
15use anyhow::{anyhow, Result};
16
17use apache_avro::{from_avro_datum, schema::RecordSchema, types::Value, Schema};
18use chrono::DateTime;
19use chrono_tz::{self};
20
21use clickhouse_rs::types::{DateTimeType, SqlType, Value as CHValue};
22use reqwest::Url;
23use schema_registry_api::{SchemaRegistry, SchemaVersion, SubjectName};
24use serde::Deserialize;
25use uuid::Uuid;
26
27use super::{Row, CONFLUENT_HEADER_LEN};
28
29#[derive(Deserialize)]
30pub struct Settings {
31    pub field_names: Option<HashMap<String, String>>,
32    pub include_fields: Option<Vec<String>>,
33    pub exclude_fields: Option<Vec<String>>,
34    pub schema_file: Option<String>,
35    pub registry_url: Option<String>,
36}
37
38pub struct Decoder {
39    schema: Schema,
40    array_types: TypeMapping,
41    map_types: TypeMapping,
42    name_overrides: Vec<(String, String)>,
43    include_fields: Vec<String>,
44    exclude_fields: Vec<String>,
45    null_values: NullValues,
46}
47
48struct TypeMapping(Vec<(String, &'static SqlType)>);
49struct NullValues(Vec<(String, CHValue)>);
50
51impl Decoder {
52    fn avro2ch(&self, column_name: &str, v: Value) -> Result<CHValue, anyhow::Error> {
53        match v {
54            Value::Null => Err(anyhow!("unexpected null")),
55            Value::Record(_) => Err(anyhow!("unsupported nested record")),
56            Value::Boolean(x) => Ok(CHValue::from(x)),
57            Value::Int(x) => Ok(CHValue::from(x)),
58            Value::Long(x) => Ok(CHValue::from(x)),
59            Value::Float(x) => Ok(CHValue::from(x)),
60            Value::Double(x) => Ok(CHValue::from(x)),
61            Value::Bytes(x) => Ok(CHValue::from(x)),
62            Value::String(x) => Ok(CHValue::from(x)),
63            Value::Fixed(_, x) => Ok(CHValue::String(Arc::new(x))),
64            Value::Enum(_, y) => Ok(CHValue::from(y)),
65            Value::Union(_, v) => match *v {
66                Value::Null => match self.null_values.0.iter().find(|(n, _)| column_name.eq(n)) {
67                    None => Err(anyhow!("cannot find nullable field {}", column_name)),
68                    Some(x) => Ok(x.1.clone()),
69                },
70                v => self.avro2ch(column_name, v),
71            },
72            Value::Array(x) => {
73                let mut arr: Vec<CHValue> = Vec::new();
74                for elem in x {
75                    arr.push(self.avro2ch(column_name, elem)?)
76                }
77                let column_type = self.array_types.get_type(column_name).unwrap();
78                Ok(CHValue::Array(column_type, Arc::new(arr)))
79            }
80            Value::Map(x) => {
81                let mut m: HashMap<CHValue, CHValue> = HashMap::new();
82                for (k, v) in x {
83                    m.insert(CHValue::from(k), self.avro2ch(column_name, v)?);
84                }
85                let column_type = self.map_types.get_type(column_name).unwrap();
86                Ok(CHValue::Map(&SqlType::String, &column_type, Arc::new(m)))
87            }
88            Value::Date(x) => Ok(CHValue::Date(x as u16)),
89            Value::Decimal(_) => Err(anyhow!("unsupported decimal type")),
90            Value::TimeMillis(x) => Ok(CHValue::from(x)),
91            Value::TimeMicros(x) => Ok(CHValue::from(x)),
92            Value::TimestampMillis(x) => {
93                Ok(CHValue::from(DateTime::from_timestamp_millis(x).unwrap()))
94            }
95            Value::TimestampMicros(x) => {
96                Ok(CHValue::from(DateTime::from_timestamp_micros(x).unwrap()))
97            }
98            Value::LocalTimestampMillis(x) => {
99                Ok(CHValue::from(DateTime::from_timestamp_millis(x).unwrap()))
100            }
101            Value::LocalTimestampMicros(x) => {
102                Ok(CHValue::from(DateTime::from_timestamp_micros(x).unwrap()))
103            }
104            Value::Duration(x) => {
105                // don't ask, programmers and time ¯\_(ツ)_/¯
106                let duration = Duration::from_millis(u32::from(x.millis()) as u64)
107                    + Duration::from_secs(86400 * u32::from(x.days()) as u64)
108                    + Duration::from_secs(30 * 86400 * u32::from(x.months()) as u64);
109                Ok(CHValue::from(duration.as_millis() as u64))
110            }
111            Value::Uuid(x) => Ok(CHValue::from(x)),
112        }
113    }
114}
115
116impl super::Decoder for Decoder {
117    fn get_name(&self) -> String {
118        String::from("avro")
119    }
120
121    fn decode(&self, message: &[u8]) -> Result<Row> {
122        let mut datum = BufReader::new(&message[CONFLUENT_HEADER_LEN..]);
123        let record;
124        let mut row = Row::new();
125        match from_avro_datum(&self.schema, &mut datum, None)? {
126            Value::Record(x) => record = x,
127            _ => return Err(anyhow!("avro message must be a record")),
128        };
129        for (column, value) in record {
130            if self.exclude_fields.iter().any(|c| c == column.as_str()) {
131                continue;
132            }
133            if !self.include_fields.is_empty()
134                && !self.include_fields.iter().any(|c| c == column.as_str())
135            {
136                continue;
137            }
138            let v = self.avro2ch(&column, value)?;
139            let column_name = match self.name_overrides.iter().find(|m| m.0 == column) {
140                None => column,
141                Some((_, n)) => n.to_owned(),
142            };
143            row.push((column_name, v));
144        }
145        Ok(row)
146    }
147}
148
149impl TypeMapping {
150    fn new(r: &RecordSchema) -> Result<(Self, Self)> {
151        let mut map_types: Vec<(String, &SqlType)> = Vec::new();
152        let mut arr_types: Vec<(String, &SqlType)> = Vec::new();
153        for field in &r.fields {
154            match &field.schema {
155                Schema::Array(v) => arr_types.push((field.name.clone(), get_schema_type(v)?.0)),
156                Schema::Map(v) => map_types.push((field.name.clone(), get_schema_type(v)?.0)),
157                _ => (),
158            };
159        }
160        Ok((TypeMapping(arr_types), TypeMapping(map_types)))
161    }
162
163    fn get_type(&self, column: &str) -> Option<&'static SqlType> {
164        match self.0.iter().find(|x| x.0 == column) {
165            None => None,
166            Some((_, t)) => Some(t),
167        }
168    }
169}
170
171pub async fn new(topic: &str, settings: Settings) -> Result<Decoder> {
172    let schema = get_schema(&topic, &settings).await?;
173    let mut name_overrides: Vec<(String, String)> = Vec::new();
174    let mut include_fields: Vec<String> = Vec::new();
175    let mut exclude_fields: Vec<String> = Vec::new();
176    if let Some(names) = settings.field_names {
177        names
178            .iter()
179            .for_each(|(k, v)| name_overrides.push((k.to_owned(), v.to_owned())));
180    }
181    if let Some(flds) = settings.include_fields {
182        flds.iter().for_each(|e| include_fields.push(e.to_owned()));
183    }
184    if let Some(flds) = settings.exclude_fields {
185        flds.iter().for_each(|e| exclude_fields.push(e.to_owned()));
186    }
187    match schema {
188        Schema::Record(record) => {
189            let (array_types, map_types, null_values) = analyze_schema(&record)?;
190            Ok(Decoder {
191                array_types,
192                map_types,
193                schema: Schema::Record(record),
194                name_overrides,
195                include_fields,
196                exclude_fields,
197                null_values,
198            })
199        }
200        _ => Err(anyhow!("avro schema root must be a record")),
201    }
202}
203
204async fn get_schema(topic: &str, settings: &Settings) -> Result<Schema> {
205    match &settings.schema_file {
206        Some(f) => Ok(Schema::parse_str(&fs::read_to_string(f)?)?),
207        None => match &settings.registry_url {
208            None => Err(anyhow!("registry_url or schema_file must be specified")),
209            Some(registry_url) => {
210                let sr_client = SchemaRegistry::build_default(Url::parse(&registry_url)?)?;
211                let subject_name = format!("{topic}-value").parse::<SubjectName>()?;
212                let subject = sr_client
213                    .subject()
214                    .version(&subject_name, SchemaVersion::Latest)
215                    .await?;
216                match subject {
217                    None => Err(anyhow!("subject {} not found", subject_name)),
218                    Some(s) => Ok(Schema::parse_str(&s.schema)?),
219                }
220            }
221        },
222    }
223}
224
225/// checks schema for compatibility and returns type mappings for arrays and maps,
226/// as well as mapping of nullable fields to its zero values
227fn analyze_schema(s: &RecordSchema) -> Result<(TypeMapping, TypeMapping, NullValues)> {
228    let (array_types, map_types) = TypeMapping::new(s)?;
229    let mut null_values: Vec<(String, CHValue)> = Vec::new();
230    for fld in &s.fields {
231        match &fld.schema {
232            Schema::Record(_) => {
233                return Err(anyhow!(
234                    "field {}: nested records are not supported",
235                    fld.name
236                ))
237            }
238            Schema::Union(union) => {
239                let schemas = union.variants();
240                if schemas.len() != 2 {
241                    return Err(anyhow!(
242                        "field {}: only supported union type is [null, <type>]",
243                        fld.name
244                    ));
245                }
246                match schemas[0] {
247                    Schema::Null => (),
248                    _ => {
249                        return Err(anyhow!(
250                            "field {}: only supported union type is [null, <type>]",
251                            fld.name
252                        ))
253                    }
254                }
255                null_values.push((fld.name.clone(), get_schema_type(&schemas[1])?.1))
256            }
257            _ => (),
258        }
259    }
260    Ok((array_types, map_types, NullValues(null_values)))
261}
262
263/// translates avro type into clickhouse type
264/// and returns relevant SqlType and its zero value
265fn get_schema_type(s: &Schema) -> Result<(&'static SqlType, CHValue)> {
266    match s {
267        Schema::Boolean => Ok((&SqlType::Bool, CHValue::from(false))),
268        Schema::Int => Ok((&SqlType::Int32, CHValue::Int32(0))),
269        Schema::Long => Ok((&SqlType::Int64, CHValue::Int64(0))),
270        Schema::Float => Ok((&SqlType::Float32, CHValue::Float32(0.0))),
271        Schema::Double => Ok((&SqlType::Float64, CHValue::Float64(0.0))),
272        Schema::Bytes => Ok((&SqlType::String, CHValue::from(Vec::<u8>::new()))),
273        Schema::String => Ok((&SqlType::String, CHValue::from(String::new()))),
274        Schema::Uuid => Ok((&SqlType::Uuid, CHValue::from(Uuid::nil()))),
275        //Schema::Fixed(s) => Ok((&SqlType::FixedString(s.size))),
276        Schema::Date => Ok((&SqlType::Date, CHValue::Date(0u16))),
277        Schema::TimeMillis => Ok((&SqlType::Int32, CHValue::DateTime(0, chrono_tz::UTC))),
278        Schema::TimeMicros => Ok((&SqlType::Int32, CHValue::DateTime(0, chrono_tz::UTC))),
279        Schema::TimestampMillis => Ok((
280            &SqlType::DateTime(DateTimeType::DateTime64(3, chrono_tz::UTC)),
281            CHValue::DateTime64(0, (3, chrono_tz::UTC)),
282        )),
283        Schema::TimestampMicros => Ok((
284            &SqlType::DateTime(DateTimeType::DateTime64(6, chrono_tz::UTC)),
285            CHValue::DateTime64(0, (6, chrono_tz::UTC)),
286        )),
287        Schema::LocalTimestampMillis => Ok((
288            &SqlType::DateTime(DateTimeType::DateTime64(3, chrono_tz::UTC)),
289            CHValue::DateTime64(0, (3, chrono_tz::UTC)),
290        )),
291        Schema::LocalTimestampMicros => Ok((
292            &SqlType::DateTime(DateTimeType::DateTime64(6, chrono_tz::UTC)),
293            CHValue::DateTime64(0, (6, chrono_tz::UTC)),
294        )),
295        Schema::Duration => Ok((&SqlType::Int64, CHValue::UInt64(0))),
296        _ => Err(anyhow!("unsupported type")),
297    }
298}