Skip to main content

floe_core/io/read/
avro.rs

1use std::collections::{HashMap, HashSet};
2use std::path::Path;
3
4use apache_avro::types::Value;
5use apache_avro::{Reader, Schema};
6use polars::prelude::{DataFrame, NamedFrom, Series};
7
8use crate::checks::normalize::normalize_name;
9use crate::io::format::{self, FileReadError, InputAdapter, InputFile, ReadInput};
10use crate::{config, FloeResult};
11
12struct AvroInputAdapter;
13
14static AVRO_INPUT_ADAPTER: AvroInputAdapter = AvroInputAdapter;
15
16pub(crate) fn avro_input_adapter() -> &'static dyn InputAdapter {
17    &AVRO_INPUT_ADAPTER
18}
19
/// Error produced while opening, decoding, or converting an Avro input.
///
/// Rendered by `Display` as `"<rule>: <message>"`.
#[derive(Debug, Clone)]
pub struct AvroReadError {
    /// Short identifier of the failure category (this module uses
    /// `avro_read_error`, `avro_schema_error`, and `avro_cast_error`).
    pub rule: String,
    /// Human-readable description of what went wrong.
    pub message: String,
}

impl std::fmt::Display for AvroReadError {
    /// Writes the rule and the message separated by `": "`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { rule, message } = self;
        write!(f, "{rule}: {message}")
    }
}

// The default `Error` methods are sufficient: there is no wrapped source error.
impl std::error::Error for AvroReadError {}
33
34fn read_avro_schema_fields(input_path: &Path) -> Result<Vec<String>, AvroReadError> {
35    let file = std::fs::File::open(input_path).map_err(|err| AvroReadError {
36        rule: "avro_read_error".to_string(),
37        message: format!("failed to open avro at {}: {err}", input_path.display()),
38    })?;
39    let reader = Reader::new(file).map_err(|err| AvroReadError {
40        rule: "avro_read_error".to_string(),
41        message: format!(
42            "failed to read avro header at {}: {err}",
43            input_path.display()
44        ),
45    })?;
46    schema_field_names(reader.writer_schema())
47}
48
49fn schema_field_names(schema: &Schema) -> Result<Vec<String>, AvroReadError> {
50    match schema {
51        Schema::Record(record) => Ok(record
52            .fields
53            .iter()
54            .map(|field| field.name.clone())
55            .collect()),
56        Schema::Union(union) => {
57            let mut record_fields = None;
58            for variant in union.variants() {
59                match variant {
60                    Schema::Null => continue,
61                    Schema::Record(record) => {
62                        if record_fields.is_some() {
63                            return Err(AvroReadError {
64                                rule: "avro_schema_error".to_string(),
65                                message: "avro schema has multiple record variants at root"
66                                    .to_string(),
67                            });
68                        }
69                        record_fields = Some(
70                            record
71                                .fields
72                                .iter()
73                                .map(|field| field.name.clone())
74                                .collect(),
75                        );
76                    }
77                    _ => {}
78                }
79            }
80            record_fields.ok_or_else(|| AvroReadError {
81                rule: "avro_schema_error".to_string(),
82                message: "expected avro record schema at root".to_string(),
83            })
84        }
85        _ => Err(AvroReadError {
86            rule: "avro_schema_error".to_string(),
87            message: "expected avro record schema at root".to_string(),
88        }),
89    }
90}
91
92fn read_avro_file(
93    input_path: &Path,
94    cast_mode: &str,
95    declared_columns: &HashSet<String>,
96    normalize_strategy: Option<&str>,
97) -> Result<DataFrame, AvroReadError> {
98    let file = std::fs::File::open(input_path).map_err(|err| AvroReadError {
99        rule: "avro_read_error".to_string(),
100        message: format!("failed to open avro at {}: {err}", input_path.display()),
101    })?;
102    let reader = Reader::new(file).map_err(|err| AvroReadError {
103        rule: "avro_read_error".to_string(),
104        message: format!(
105            "failed to read avro header at {}: {err}",
106            input_path.display()
107        ),
108    })?;
109    let columns = schema_field_names(reader.writer_schema())?;
110    if columns.is_empty() {
111        return Err(AvroReadError {
112            rule: "avro_schema_error".to_string(),
113            message: format!("avro schema has no fields at {}", input_path.display()),
114        });
115    }
116
117    let mut indices = HashMap::with_capacity(columns.len());
118    for (idx, name) in columns.iter().enumerate() {
119        indices.insert(name.clone(), idx);
120    }
121
122    let mut values = vec![Vec::new(); columns.len()];
123    for item in reader {
124        let value = item.map_err(|err| AvroReadError {
125            rule: "avro_read_error".to_string(),
126            message: format!(
127                "failed to read avro record at {}: {err}",
128                input_path.display()
129            ),
130        })?;
131        let mut row = vec![None; columns.len()];
132        if let Some(fields) = record_fields_from_root_value(value, input_path)? {
133            for (name, value) in fields {
134                if !is_declared_field(&name, declared_columns, normalize_strategy) {
135                    // Extra Avro fields are handled by mismatch policy; skip conversion here.
136                    continue;
137                }
138                if let Some(index) = indices.get(&name) {
139                    row[*index] = value_to_string(&value, cast_mode)?;
140                }
141            }
142        }
143        for (index, cell) in row.into_iter().enumerate() {
144            values[index].push(cell);
145        }
146    }
147
148    let mut series = Vec::with_capacity(columns.len());
149    for (idx, name) in columns.iter().enumerate() {
150        series.push(Series::new(name.as_str().into(), std::mem::take(&mut values[idx])).into());
151    }
152    DataFrame::new(series).map_err(|err| AvroReadError {
153        rule: "avro_read_error".to_string(),
154        message: format!("failed to build dataframe: {err}"),
155    })
156}
157
158fn record_fields_from_root_value(
159    value: Value,
160    input_path: &Path,
161) -> Result<Option<Vec<(String, Value)>>, AvroReadError> {
162    match value {
163        Value::Record(fields) => Ok(Some(fields)),
164        Value::Null => Ok(None),
165        Value::Union(_, boxed) => match *boxed {
166            Value::Record(fields) => Ok(Some(fields)),
167            Value::Null => Ok(None),
168            other => Err(AvroReadError {
169                rule: "avro_schema_error".to_string(),
170                message: format!(
171                    "expected avro record value at {}, got {:?}",
172                    input_path.display(),
173                    other
174                ),
175            }),
176        },
177        other => Err(AvroReadError {
178            rule: "avro_schema_error".to_string(),
179            message: format!(
180                "expected avro record value at {}, got {:?}",
181                input_path.display(),
182                other
183            ),
184        }),
185    }
186}
187
188fn is_declared_field(
189    field_name: &str,
190    declared_columns: &HashSet<String>,
191    normalize_strategy: Option<&str>,
192) -> bool {
193    if declared_columns.contains(field_name) {
194        return true;
195    }
196    normalize_strategy
197        .is_some_and(|strategy| declared_columns.contains(&normalize_name(field_name, strategy)))
198}
199
200fn value_to_string(value: &Value, cast_mode: &str) -> Result<Option<String>, AvroReadError> {
201    match value {
202        Value::Null => Ok(None),
203        Value::Boolean(value) => Ok(Some(value.to_string())),
204        Value::Int(value) => Ok(Some(value.to_string())),
205        Value::Long(value) => Ok(Some(value.to_string())),
206        Value::Float(value) => Ok(Some(value.to_string())),
207        Value::Double(value) => Ok(Some(value.to_string())),
208        Value::String(value) => Ok(Some(value.clone())),
209        Value::Enum(_, value) => Ok(Some(value.clone())),
210        Value::Bytes(value) => Ok(Some(String::from_utf8_lossy(value).to_string())),
211        Value::Fixed(_, value) => Ok(Some(String::from_utf8_lossy(value).to_string())),
212        Value::Date(value) => Ok(Some(value.to_string())),
213        Value::TimeMillis(value) => Ok(Some(value.to_string())),
214        Value::TimeMicros(value) => Ok(Some(value.to_string())),
215        Value::TimestampMillis(value) => Ok(Some(value.to_string())),
216        Value::TimestampMicros(value) => Ok(Some(value.to_string())),
217        Value::LocalTimestampMillis(value) => Ok(Some(value.to_string())),
218        Value::LocalTimestampMicros(value) => Ok(Some(value.to_string())),
219        Value::Uuid(value) => Ok(Some(value.to_string())),
220        Value::Union(_, value) => value_to_string(value, cast_mode),
221        other => {
222            if cast_mode == "coerce" {
223                Ok(Some(format!("{other:?}")))
224            } else {
225                Err(AvroReadError {
226                    rule: "avro_cast_error".to_string(),
227                    message: format!("unsupported avro value: {other:?}"),
228                })
229            }
230        }
231    }
232}
233
234impl InputAdapter for AvroInputAdapter {
235    fn format(&self) -> &'static str {
236        "avro"
237    }
238
239    fn read_input_columns(
240        &self,
241        _entity: &config::EntityConfig,
242        input_file: &InputFile,
243        _columns: &[config::ColumnConfig],
244    ) -> Result<Vec<String>, FileReadError> {
245        read_avro_schema_fields(&input_file.source_local_path).map_err(|err| FileReadError {
246            rule: err.rule,
247            message: err.message,
248        })
249    }
250
251    fn read_inputs(
252        &self,
253        entity: &config::EntityConfig,
254        files: &[InputFile],
255        columns: &[config::ColumnConfig],
256        normalize_strategy: Option<&str>,
257        collect_raw: bool,
258    ) -> FloeResult<Vec<ReadInput>> {
259        let cast_mode = entity.source.cast_mode.as_deref().unwrap_or("strict");
260        let declared_columns = columns
261            .iter()
262            .map(|column| column.name.clone())
263            .collect::<HashSet<_>>();
264        let mut inputs = Vec::with_capacity(files.len());
265        for input_file in files {
266            match read_avro_file(
267                &input_file.source_local_path,
268                cast_mode,
269                &declared_columns,
270                normalize_strategy,
271            ) {
272                Ok(df) => {
273                    let input = format::read_input_from_df(
274                        input_file,
275                        &df,
276                        columns,
277                        normalize_strategy,
278                        collect_raw,
279                    )?;
280                    inputs.push(input);
281                }
282                Err(err) => {
283                    inputs.push(ReadInput::FileError {
284                        input_file: input_file.clone(),
285                        error: FileReadError {
286                            rule: err.rule,
287                            message: err.message,
288                        },
289                    });
290                }
291            }
292        }
293        Ok(inputs)
294    }
295}