1use std::collections::{HashMap, HashSet};
2use std::path::Path;
3
4use apache_avro::types::Value;
5use apache_avro::{Reader, Schema};
6use polars::prelude::{DataFrame, NamedFrom, Series};
7
8use crate::checks::normalize::normalize_name;
9use crate::io::format::{self, FileReadError, InputAdapter, InputFile, ReadInput};
10use crate::{config, FloeResult};
11
/// Unit type carrying the `InputAdapter` implementation for the `"avro"` format.
struct AvroInputAdapter;

/// The single adapter instance that `avro_input_adapter` hands out by reference.
static AVRO_INPUT_ADAPTER: AvroInputAdapter = AvroInputAdapter;

/// Returns a `'static` reference to the Avro adapter as an `InputAdapter` trait object.
pub(crate) fn avro_input_adapter() -> &'static dyn InputAdapter {
    &AVRO_INPUT_ADAPTER
}
19
/// Error produced while opening, decoding or converting an Avro file.
///
/// `rule` is a short machine-readable identifier (for example `avro_read_error`),
/// and `message` is the human-readable detail. The `Display` form is
/// `"<rule>: <message>"`.
#[derive(Debug, Clone)]
pub struct AvroReadError {
    pub rule: String,
    pub message: String,
}

impl std::fmt::Display for AvroReadError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Emit the three pieces directly; width/fill flags on the outer
        // formatter are intentionally not applied.
        f.write_str(&self.rule)?;
        f.write_str(": ")?;
        f.write_str(&self.message)
    }
}

impl std::error::Error for AvroReadError {}
33
34fn read_avro_schema_fields(input_path: &Path) -> Result<Vec<String>, AvroReadError> {
35 let file = std::fs::File::open(input_path).map_err(|err| AvroReadError {
36 rule: "avro_read_error".to_string(),
37 message: format!("failed to open avro at {}: {err}", input_path.display()),
38 })?;
39 let reader = Reader::new(file).map_err(|err| AvroReadError {
40 rule: "avro_read_error".to_string(),
41 message: format!(
42 "failed to read avro header at {}: {err}",
43 input_path.display()
44 ),
45 })?;
46 schema_field_names(reader.writer_schema())
47}
48
49fn schema_field_names(schema: &Schema) -> Result<Vec<String>, AvroReadError> {
50 match schema {
51 Schema::Record(record) => Ok(record
52 .fields
53 .iter()
54 .map(|field| field.name.clone())
55 .collect()),
56 Schema::Union(union) => {
57 let mut record_fields = None;
58 for variant in union.variants() {
59 match variant {
60 Schema::Null => continue,
61 Schema::Record(record) => {
62 if record_fields.is_some() {
63 return Err(AvroReadError {
64 rule: "avro_schema_error".to_string(),
65 message: "avro schema has multiple record variants at root"
66 .to_string(),
67 });
68 }
69 record_fields = Some(
70 record
71 .fields
72 .iter()
73 .map(|field| field.name.clone())
74 .collect(),
75 );
76 }
77 _ => {}
78 }
79 }
80 record_fields.ok_or_else(|| AvroReadError {
81 rule: "avro_schema_error".to_string(),
82 message: "expected avro record schema at root".to_string(),
83 })
84 }
85 _ => Err(AvroReadError {
86 rule: "avro_schema_error".to_string(),
87 message: "expected avro record schema at root".to_string(),
88 }),
89 }
90}
91
92fn read_avro_file(
93 input_path: &Path,
94 cast_mode: &str,
95 declared_columns: &HashSet<String>,
96 normalize_strategy: Option<&str>,
97) -> Result<DataFrame, AvroReadError> {
98 let file = std::fs::File::open(input_path).map_err(|err| AvroReadError {
99 rule: "avro_read_error".to_string(),
100 message: format!("failed to open avro at {}: {err}", input_path.display()),
101 })?;
102 let reader = Reader::new(file).map_err(|err| AvroReadError {
103 rule: "avro_read_error".to_string(),
104 message: format!(
105 "failed to read avro header at {}: {err}",
106 input_path.display()
107 ),
108 })?;
109 let columns = schema_field_names(reader.writer_schema())?;
110 if columns.is_empty() {
111 return Err(AvroReadError {
112 rule: "avro_schema_error".to_string(),
113 message: format!("avro schema has no fields at {}", input_path.display()),
114 });
115 }
116
117 let mut indices = HashMap::with_capacity(columns.len());
118 for (idx, name) in columns.iter().enumerate() {
119 indices.insert(name.clone(), idx);
120 }
121
122 let mut values = vec![Vec::new(); columns.len()];
123 for item in reader {
124 let value = item.map_err(|err| AvroReadError {
125 rule: "avro_read_error".to_string(),
126 message: format!(
127 "failed to read avro record at {}: {err}",
128 input_path.display()
129 ),
130 })?;
131 let mut row = vec![None; columns.len()];
132 if let Some(fields) = record_fields_from_root_value(value, input_path)? {
133 for (name, value) in fields {
134 if !is_declared_field(&name, declared_columns, normalize_strategy) {
135 continue;
137 }
138 if let Some(index) = indices.get(&name) {
139 row[*index] = value_to_string(&value, cast_mode)?;
140 }
141 }
142 }
143 for (index, cell) in row.into_iter().enumerate() {
144 values[index].push(cell);
145 }
146 }
147
148 let mut series = Vec::with_capacity(columns.len());
149 for (idx, name) in columns.iter().enumerate() {
150 series.push(Series::new(name.as_str().into(), std::mem::take(&mut values[idx])).into());
151 }
152 DataFrame::new(series).map_err(|err| AvroReadError {
153 rule: "avro_read_error".to_string(),
154 message: format!("failed to build dataframe: {err}"),
155 })
156}
157
158fn record_fields_from_root_value(
159 value: Value,
160 input_path: &Path,
161) -> Result<Option<Vec<(String, Value)>>, AvroReadError> {
162 match value {
163 Value::Record(fields) => Ok(Some(fields)),
164 Value::Null => Ok(None),
165 Value::Union(_, boxed) => match *boxed {
166 Value::Record(fields) => Ok(Some(fields)),
167 Value::Null => Ok(None),
168 other => Err(AvroReadError {
169 rule: "avro_schema_error".to_string(),
170 message: format!(
171 "expected avro record value at {}, got {:?}",
172 input_path.display(),
173 other
174 ),
175 }),
176 },
177 other => Err(AvroReadError {
178 rule: "avro_schema_error".to_string(),
179 message: format!(
180 "expected avro record value at {}, got {:?}",
181 input_path.display(),
182 other
183 ),
184 }),
185 }
186}
187
188fn is_declared_field(
189 field_name: &str,
190 declared_columns: &HashSet<String>,
191 normalize_strategy: Option<&str>,
192) -> bool {
193 if declared_columns.contains(field_name) {
194 return true;
195 }
196 normalize_strategy
197 .is_some_and(|strategy| declared_columns.contains(&normalize_name(field_name, strategy)))
198}
199
200fn value_to_string(value: &Value, cast_mode: &str) -> Result<Option<String>, AvroReadError> {
201 match value {
202 Value::Null => Ok(None),
203 Value::Boolean(value) => Ok(Some(value.to_string())),
204 Value::Int(value) => Ok(Some(value.to_string())),
205 Value::Long(value) => Ok(Some(value.to_string())),
206 Value::Float(value) => Ok(Some(value.to_string())),
207 Value::Double(value) => Ok(Some(value.to_string())),
208 Value::String(value) => Ok(Some(value.clone())),
209 Value::Enum(_, value) => Ok(Some(value.clone())),
210 Value::Bytes(value) => Ok(Some(String::from_utf8_lossy(value).to_string())),
211 Value::Fixed(_, value) => Ok(Some(String::from_utf8_lossy(value).to_string())),
212 Value::Date(value) => Ok(Some(value.to_string())),
213 Value::TimeMillis(value) => Ok(Some(value.to_string())),
214 Value::TimeMicros(value) => Ok(Some(value.to_string())),
215 Value::TimestampMillis(value) => Ok(Some(value.to_string())),
216 Value::TimestampMicros(value) => Ok(Some(value.to_string())),
217 Value::LocalTimestampMillis(value) => Ok(Some(value.to_string())),
218 Value::LocalTimestampMicros(value) => Ok(Some(value.to_string())),
219 Value::Uuid(value) => Ok(Some(value.to_string())),
220 Value::Union(_, value) => value_to_string(value, cast_mode),
221 other => {
222 if cast_mode == "coerce" {
223 Ok(Some(format!("{other:?}")))
224 } else {
225 Err(AvroReadError {
226 rule: "avro_cast_error".to_string(),
227 message: format!("unsupported avro value: {other:?}"),
228 })
229 }
230 }
231 }
232}
233
234impl InputAdapter for AvroInputAdapter {
235 fn format(&self) -> &'static str {
236 "avro"
237 }
238
239 fn read_input_columns(
240 &self,
241 _entity: &config::EntityConfig,
242 input_file: &InputFile,
243 _columns: &[config::ColumnConfig],
244 ) -> Result<Vec<String>, FileReadError> {
245 read_avro_schema_fields(&input_file.source_local_path).map_err(|err| FileReadError {
246 rule: err.rule,
247 message: err.message,
248 })
249 }
250
251 fn read_inputs(
252 &self,
253 entity: &config::EntityConfig,
254 files: &[InputFile],
255 columns: &[config::ColumnConfig],
256 normalize_strategy: Option<&str>,
257 collect_raw: bool,
258 ) -> FloeResult<Vec<ReadInput>> {
259 let cast_mode = entity.source.cast_mode.as_deref().unwrap_or("strict");
260 let declared_columns = columns
261 .iter()
262 .map(|column| column.name.clone())
263 .collect::<HashSet<_>>();
264 let mut inputs = Vec::with_capacity(files.len());
265 for input_file in files {
266 match read_avro_file(
267 &input_file.source_local_path,
268 cast_mode,
269 &declared_columns,
270 normalize_strategy,
271 ) {
272 Ok(df) => {
273 let input = format::read_input_from_df(
274 input_file,
275 &df,
276 columns,
277 normalize_strategy,
278 collect_raw,
279 )?;
280 inputs.push(input);
281 }
282 Err(err) => {
283 inputs.push(ReadInput::FileError {
284 input_file: input_file.clone(),
285 error: FileReadError {
286 rule: err.rule,
287 message: err.message,
288 },
289 });
290 }
291 }
292 }
293 Ok(inputs)
294 }
295}