helios_sof/
parquet_schema.rs

1use arrow::array::{
2    ArrayRef, BooleanBuilder, Float64Builder, Int32Builder, ListBuilder, StringBuilder,
3};
4use arrow::datatypes::{DataType, Field, Schema};
5use serde_json::Value;
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use crate::{ProcessedRow, SofError};
10
11pub fn infer_arrow_type(values: &[Option<Value>]) -> DataType {
12    let mut type_counts: HashMap<String, usize> = HashMap::new();
13    let mut has_array = false;
14    let mut has_object = false;
15    let mut array_element_type = None;
16
17    for value in values.iter().flatten() {
18        match value {
19            Value::Bool(_) => {
20                *type_counts.entry("bool".to_string()).or_insert(0) += 1;
21            }
22            Value::Number(n) => {
23                if n.is_i64() || n.is_u64() {
24                    *type_counts.entry("integer".to_string()).or_insert(0) += 1;
25                } else {
26                    *type_counts.entry("decimal".to_string()).or_insert(0) += 1;
27                }
28            }
29            Value::String(_) => {
30                *type_counts.entry("string".to_string()).or_insert(0) += 1;
31            }
32            Value::Array(arr) => {
33                has_array = true;
34                if !arr.is_empty() && array_element_type.is_none() {
35                    let element_values: Vec<Option<Value>> =
36                        arr.iter().map(|v| Some(v.clone())).collect();
37                    array_element_type = Some(infer_arrow_type(&element_values));
38                }
39            }
40            Value::Object(_) => {
41                has_object = true;
42            }
43            Value::Null => {}
44        }
45    }
46
47    if has_array {
48        if let Some(element_type) = array_element_type {
49            return DataType::List(Arc::new(Field::new("item", element_type, true)));
50        } else {
51            return DataType::List(Arc::new(Field::new("item", DataType::Utf8, true)));
52        }
53    }
54
55    if has_object {
56        return DataType::Utf8;
57    }
58
59    let most_common = type_counts
60        .into_iter()
61        .max_by_key(|(_, count)| *count)
62        .map(|(type_name, _)| type_name);
63
64    match most_common.as_deref() {
65        Some("bool") => DataType::Boolean,
66        Some("integer") => DataType::Int32,
67        Some("decimal") => DataType::Float64,
68        Some("string") => DataType::Utf8,
69        _ => DataType::Utf8,
70    }
71}
72
73pub fn create_arrow_schema(columns: &[String], rows: &[ProcessedRow]) -> Result<Schema, SofError> {
74    let sample_size = std::cmp::min(100, rows.len());
75    let mut fields = Vec::new();
76
77    for (col_idx, column_name) in columns.iter().enumerate() {
78        let sample_values: Vec<Option<Value>> = rows
79            .iter()
80            .take(sample_size)
81            .map(|row| row.values.get(col_idx).cloned().flatten())
82            .collect();
83
84        let data_type = infer_arrow_type(&sample_values);
85        let field = Field::new(column_name, data_type, true);
86        fields.push(field);
87    }
88
89    Ok(Schema::new(fields))
90}
91
92fn build_array_from_values(
93    values: Vec<Option<Value>>,
94    data_type: &DataType,
95) -> Result<ArrayRef, SofError> {
96    match data_type {
97        DataType::Boolean => {
98            let mut builder = BooleanBuilder::new();
99            for value in values {
100                match value {
101                    Some(Value::Bool(b)) => builder.append_value(b),
102                    _ => builder.append_null(),
103                }
104            }
105            Ok(Arc::new(builder.finish()))
106        }
107        DataType::Int32 => {
108            let mut builder = Int32Builder::new();
109            for value in values {
110                match value {
111                    Some(Value::Number(n)) if n.is_i64() => {
112                        if let Some(i) = n.as_i64() {
113                            builder.append_value(i as i32);
114                        } else {
115                            builder.append_null();
116                        }
117                    }
118                    _ => builder.append_null(),
119                }
120            }
121            Ok(Arc::new(builder.finish()))
122        }
123        DataType::Float64 => {
124            let mut builder = Float64Builder::new();
125            for value in values {
126                match value {
127                    Some(Value::Number(n)) => {
128                        if let Some(f) = n.as_f64() {
129                            builder.append_value(f);
130                        } else {
131                            builder.append_null();
132                        }
133                    }
134                    _ => builder.append_null(),
135                }
136            }
137            Ok(Arc::new(builder.finish()))
138        }
139        DataType::Utf8 => {
140            let mut builder = StringBuilder::new();
141            for value in values {
142                match value {
143                    Some(Value::String(s)) => builder.append_value(s),
144                    Some(Value::Number(n)) => builder.append_value(n.to_string()),
145                    Some(Value::Bool(b)) => builder.append_value(b.to_string()),
146                    Some(Value::Object(_)) | Some(Value::Array(_)) => {
147                        builder.append_value(
148                            serde_json::to_string(&value.unwrap())
149                                .unwrap_or_else(|_| "null".to_string()),
150                        );
151                    }
152                    _ => builder.append_null(),
153                }
154            }
155            Ok(Arc::new(builder.finish()))
156        }
157        DataType::List(field) => {
158            let element_type = field.data_type();
159            match element_type {
160                DataType::Utf8 => {
161                    let mut builder = ListBuilder::new(StringBuilder::new());
162                    for value in values {
163                        match value {
164                            Some(Value::Array(arr)) => {
165                                for elem in arr {
166                                    match elem {
167                                        Value::String(s) => builder.values().append_value(s),
168                                        _ => builder.values().append_value(elem.to_string()),
169                                    }
170                                }
171                                builder.append(true);
172                            }
173                            _ => builder.append(false),
174                        }
175                    }
176                    Ok(Arc::new(builder.finish()))
177                }
178                _ => {
179                    let mut string_builder = ListBuilder::new(StringBuilder::new());
180                    for value in values {
181                        match value {
182                            Some(Value::Array(arr)) => {
183                                for elem in arr {
184                                    string_builder.values().append_value(elem.to_string());
185                                }
186                                string_builder.append(true);
187                            }
188                            _ => string_builder.append(false),
189                        }
190                    }
191                    Ok(Arc::new(string_builder.finish()))
192                }
193            }
194        }
195        _ => Err(SofError::ParquetConversionError(format!(
196            "Unsupported data type for Parquet conversion: {:?}",
197            data_type
198        ))),
199    }
200}
201
202pub fn process_to_arrow_arrays(
203    schema: &Schema,
204    _columns: &[String],
205    rows: &[ProcessedRow],
206) -> Result<Vec<ArrayRef>, SofError> {
207    let mut arrays = Vec::new();
208
209    for (col_idx, field) in schema.fields().iter().enumerate() {
210        let values: Vec<Option<Value>> = rows
211            .iter()
212            .map(|row| row.values.get(col_idx).cloned().flatten())
213            .collect();
214
215        let array = build_array_from_values(values, field.data_type())?;
216        arrays.push(array);
217    }
218
219    Ok(arrays)
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Array;
    use serde_json::json;

    /// Booleans mixed with a missing value infer as `Boolean`.
    #[test]
    fn test_infer_boolean_type() {
        let samples = [Some(json!(true)), Some(json!(false)), None, Some(json!(true))];
        assert_eq!(infer_arrow_type(&samples), DataType::Boolean);
    }

    /// Positive and negative integers infer as `Int32`.
    #[test]
    fn test_infer_integer_type() {
        let samples = [Some(json!(42)), Some(json!(100)), None, Some(json!(-5))];
        assert_eq!(infer_arrow_type(&samples), DataType::Int32);
    }

    /// Floating-point numbers infer as `Float64`.
    #[test]
    fn test_infer_decimal_type() {
        let samples = [
            Some(json!(std::f64::consts::PI)),
            Some(json!(std::f64::consts::E)),
            None,
            Some(json!(1.0)),
        ];
        assert_eq!(infer_arrow_type(&samples), DataType::Float64);
    }

    /// Strings infer as `Utf8`.
    #[test]
    fn test_infer_string_type() {
        let samples = [Some(json!("hello")), Some(json!("world")), None, Some(json!("test"))];
        assert_eq!(infer_arrow_type(&samples), DataType::Utf8);
    }

    /// Arrays of strings infer as a list with a `Utf8` element named "item".
    #[test]
    fn test_infer_array_type() {
        let samples = [Some(json!(["a", "b", "c"])), Some(json!(["d", "e"])), None];
        let inferred = infer_arrow_type(&samples);
        match inferred {
            DataType::List(element) => {
                assert_eq!(element.name(), "item");
                assert_eq!(element.data_type(), &DataType::Utf8);
            }
            _ => panic!("Expected List type"),
        }
    }

    /// Objects are represented as `Utf8`.
    #[test]
    fn test_infer_object_type_as_string() {
        let samples = [Some(json!({"key": "value"})), Some(json!({"foo": "bar"})), None];
        assert_eq!(infer_arrow_type(&samples), DataType::Utf8);
    }

    /// With mixed scalar kinds, the most frequent one decides the type.
    #[test]
    fn test_mixed_types_favor_most_common() {
        let samples = [
            Some(json!("string1")),
            Some(json!("string2")),
            Some(json!(42)),
            Some(json!("string3")),
        ];
        assert_eq!(infer_arrow_type(&samples), DataType::Utf8);
    }

    /// Schema creation names each field after its column and infers its type from the rows.
    #[test]
    fn test_create_schema_basic() {
        let columns: Vec<String> = ["id", "name", "age"]
            .iter()
            .map(|name| name.to_string())
            .collect();
        let make_row = |id: &str, name: &str, age: i64| ProcessedRow {
            values: vec![Some(json!(id)), Some(json!(name)), Some(json!(age))],
        };
        let rows = vec![make_row("123", "John Doe", 42), make_row("456", "Jane Smith", 35)];

        let schema = create_arrow_schema(&columns, &rows).unwrap();

        let expected = [
            ("id", DataType::Utf8),
            ("name", DataType::Utf8),
            ("age", DataType::Int32),
        ];
        assert_eq!(schema.fields().len(), 3);
        for (idx, (name, data_type)) in expected.iter().enumerate() {
            assert_eq!(schema.field(idx).name(), *name);
            assert_eq!(schema.field(idx).data_type(), data_type);
        }
    }

    /// Boolean arrays keep values in order and turn missing entries into nulls.
    #[test]
    fn test_build_boolean_array() {
        let input = vec![Some(json!(true)), None, Some(json!(false)), Some(json!(true))];
        let built = build_array_from_values(input, &DataType::Boolean).unwrap();
        let booleans = built
            .as_any()
            .downcast_ref::<arrow::array::BooleanArray>()
            .unwrap();

        assert_eq!(built.len(), 4);
        assert!(booleans.value(0));
        assert!(built.is_null(1));
        assert!(!booleans.value(2));
        assert!(booleans.value(3));
    }

    /// String arrays render non-string JSON values as text and missing entries as nulls.
    #[test]
    fn test_build_string_array_with_mixed_types() {
        let input = vec![
            Some(json!("text")),
            Some(json!(42)),
            Some(json!(true)),
            Some(json!({"key": "value"})),
            None,
        ];
        let built = build_array_from_values(input, &DataType::Utf8).unwrap();
        let strings = built
            .as_any()
            .downcast_ref::<arrow::array::StringArray>()
            .unwrap();

        assert_eq!(built.len(), 5);
        assert_eq!(strings.value(0), "text");
        assert_eq!(strings.value(1), "42");
        assert_eq!(strings.value(2), "true");
        assert!(strings.value(3).contains("\"key\""));
        assert!(built.is_null(4));
    }
}
369}