Skip to main content

chartml_datafusion/
conversion.rs

1//! Conversion between `Vec<Row>` (HashMap<String, serde_json::Value>) and Arrow RecordBatch.
2
3use arrow::array::{
4    ArrayRef, BooleanArray, Float64Array, RecordBatch, StringBuilder, StringArray,
5};
6use arrow::datatypes::{DataType, Field, Schema};
7use chartml_core::data::Row;
8use chartml_core::error::ChartError;
9use std::sync::Arc;
10
/// Column type inferred from the JSON values seen in that column.
///
/// `Null` is a transient inference state meaning "no non-null value seen
/// yet"; `rows_to_record_batch` resolves any column still `Null` to `Utf8`
/// before it builds the schema.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InferredType {
    /// Every non-null value is a JSON number.
    Float64,
    /// Every non-null value is a JSON boolean.
    Boolean,
    /// Strings, arrays/objects, or a mix of conflicting types.
    Utf8,
    /// No non-null value observed (yet).
    Null,
}
19
20/// Convert `Vec<Row>` into an Arrow `RecordBatch`.
21///
22/// Type inference strategy:
23/// - Numbers → Float64
24/// - Booleans → Boolean
25/// - Strings → Utf8
26/// - Null → nullable (skipped during inference)
27/// - Mixed types → coerced to Utf8
28pub fn rows_to_record_batch(rows: &[Row]) -> Result<RecordBatch, ChartError> {
29    if rows.is_empty() {
30        // Return an empty RecordBatch with no columns
31        let schema = Arc::new(Schema::new(Vec::<Field>::new()));
32        return Ok(RecordBatch::new_empty(schema));
33    }
34
35    // 1. Collect unique column names preserving insertion order
36    let mut column_names: Vec<String> = Vec::new();
37    let mut seen = std::collections::HashSet::new();
38    for row in rows {
39        for key in row.keys() {
40            if seen.insert(key.clone()) {
41                column_names.push(key.clone());
42            }
43        }
44    }
45    // Sort for deterministic column order
46    column_names.sort();
47
48    // 2. Infer types for each column
49    let mut col_types: Vec<InferredType> = vec![InferredType::Null; column_names.len()];
50    for row in rows {
51        for (i, name) in column_names.iter().enumerate() {
52            if let Some(val) = row.get(name) {
53                let val_type = match val {
54                    serde_json::Value::Number(_) => InferredType::Float64,
55                    serde_json::Value::Bool(_) => InferredType::Boolean,
56                    serde_json::Value::String(_) => InferredType::Utf8,
57                    serde_json::Value::Null => InferredType::Null,
58                    _ => InferredType::Utf8, // arrays/objects → string
59                };
60
61                col_types[i] = merge_types(col_types[i], val_type);
62            }
63        }
64    }
65
66    // Convert Null columns to Utf8 (no data → string)
67    for t in &mut col_types {
68        if *t == InferredType::Null {
69            *t = InferredType::Utf8;
70        }
71    }
72
73    // 3. Build schema
74    let fields: Vec<Field> = column_names
75        .iter()
76        .zip(col_types.iter())
77        .map(|(name, typ)| {
78            let dt = match typ {
79                InferredType::Float64 => DataType::Float64,
80                InferredType::Boolean => DataType::Boolean,
81                InferredType::Utf8 | InferredType::Null => DataType::Utf8,
82            };
83            Field::new(name, dt, true) // all columns nullable
84        })
85        .collect();
86    let schema = Arc::new(Schema::new(fields));
87
88    // 4. Build arrays column by column
89    let mut arrays: Vec<ArrayRef> = Vec::with_capacity(column_names.len());
90    for (i, name) in column_names.iter().enumerate() {
91        let arr: ArrayRef = match col_types[i] {
92            InferredType::Float64 => {
93                let values: Vec<Option<f64>> = rows
94                    .iter()
95                    .map(|row| {
96                        row.get(name).and_then(|v| match v {
97                            serde_json::Value::Number(n) => n.as_f64(),
98                            serde_json::Value::String(s) => s.parse::<f64>().ok(),
99                            serde_json::Value::Null => None,
100                            _ => None,
101                        })
102                    })
103                    .collect();
104                Arc::new(Float64Array::from(values))
105            }
106            InferredType::Boolean => {
107                let values: Vec<Option<bool>> = rows
108                    .iter()
109                    .map(|row| {
110                        row.get(name).and_then(|v| match v {
111                            serde_json::Value::Bool(b) => Some(*b),
112                            serde_json::Value::Null => None,
113                            _ => None,
114                        })
115                    })
116                    .collect();
117                Arc::new(BooleanArray::from(values))
118            }
119            InferredType::Utf8 | InferredType::Null => {
120                let mut builder = StringBuilder::new();
121                for row in rows {
122                    match row.get(name) {
123                        Some(serde_json::Value::String(s)) => builder.append_value(s),
124                        Some(serde_json::Value::Number(n)) => {
125                            builder.append_value(n.to_string())
126                        }
127                        Some(serde_json::Value::Bool(b)) => {
128                            builder.append_value(b.to_string())
129                        }
130                        Some(serde_json::Value::Null) | None => builder.append_null(),
131                        Some(other) => builder.append_value(other.to_string()),
132                    }
133                }
134                Arc::new(builder.finish())
135            }
136        };
137        arrays.push(arr);
138    }
139
140    RecordBatch::try_new(schema, arrays)
141        .map_err(|e| ChartError::DataError(format!("Failed to create RecordBatch: {}", e)))
142}
143
144/// Convert Arrow `RecordBatch` slices back into `Vec<Row>`.
145pub fn record_batch_to_rows(batches: &[RecordBatch]) -> Vec<Row> {
146    let mut rows = Vec::new();
147
148    for batch in batches {
149        let schema = batch.schema();
150        for row_idx in 0..batch.num_rows() {
151            let mut row = Row::new();
152            for (col_idx, field) in schema.fields().iter().enumerate() {
153                let col = batch.column(col_idx);
154                let value = arrow_value_to_json(col, row_idx);
155                row.insert(field.name().clone(), value);
156            }
157            rows.push(row);
158        }
159    }
160
161    rows
162}
163
164/// Extract a single cell from an Arrow array as serde_json::Value.
165fn arrow_value_to_json(array: &dyn arrow::array::Array, idx: usize) -> serde_json::Value {
166    if array.is_null(idx) {
167        return serde_json::Value::Null;
168    }
169
170    match array.data_type() {
171        DataType::Float64 => {
172            let arr = array
173                .as_any()
174                .downcast_ref::<Float64Array>()
175                .unwrap();
176            let v = arr.value(idx);
177            serde_json::json!(v)
178        }
179        DataType::Float32 => {
180            let arr = array
181                .as_any()
182                .downcast_ref::<arrow::array::Float32Array>()
183                .unwrap();
184            serde_json::json!(arr.value(idx) as f64)
185        }
186        DataType::Int8 => {
187            let arr = array.as_any().downcast_ref::<arrow::array::Int8Array>().unwrap();
188            serde_json::json!(arr.value(idx))
189        }
190        DataType::Int16 => {
191            let arr = array.as_any().downcast_ref::<arrow::array::Int16Array>().unwrap();
192            serde_json::json!(arr.value(idx))
193        }
194        DataType::Int32 => {
195            let arr = array.as_any().downcast_ref::<arrow::array::Int32Array>().unwrap();
196            serde_json::json!(arr.value(idx))
197        }
198        DataType::Int64 => {
199            let arr = array.as_any().downcast_ref::<arrow::array::Int64Array>().unwrap();
200            serde_json::json!(arr.value(idx))
201        }
202        DataType::UInt8 => {
203            let arr = array.as_any().downcast_ref::<arrow::array::UInt8Array>().unwrap();
204            serde_json::json!(arr.value(idx))
205        }
206        DataType::UInt16 => {
207            let arr = array.as_any().downcast_ref::<arrow::array::UInt16Array>().unwrap();
208            serde_json::json!(arr.value(idx))
209        }
210        DataType::UInt32 => {
211            let arr = array.as_any().downcast_ref::<arrow::array::UInt32Array>().unwrap();
212            serde_json::json!(arr.value(idx))
213        }
214        DataType::UInt64 => {
215            let arr = array.as_any().downcast_ref::<arrow::array::UInt64Array>().unwrap();
216            serde_json::json!(arr.value(idx))
217        }
218        DataType::Boolean => {
219            let arr = array
220                .as_any()
221                .downcast_ref::<BooleanArray>()
222                .unwrap();
223            serde_json::json!(arr.value(idx))
224        }
225        DataType::Utf8 => {
226            let arr = array
227                .as_any()
228                .downcast_ref::<StringArray>()
229                .unwrap();
230            serde_json::json!(arr.value(idx))
231        }
232        DataType::LargeUtf8 => {
233            let arr = array
234                .as_any()
235                .downcast_ref::<arrow::array::LargeStringArray>()
236                .unwrap();
237            serde_json::json!(arr.value(idx))
238        }
239        DataType::Date32 => {
240            // Date32 stores days since epoch — convert to ISO date string
241            let arr = array
242                .as_any()
243                .downcast_ref::<arrow::array::Date32Array>()
244                .unwrap();
245            let days = arr.value(idx);
246            // Convert days since epoch to YYYY-MM-DD
247            let naive = days_to_iso(days as i64);
248            serde_json::json!(naive)
249        }
250        _ => {
251            // Fallback: use debug representation
252            serde_json::Value::String(format!("{:?}", array.data_type()))
253        }
254    }
255}
256
257/// Convert days-since-epoch to ISO date string (YYYY-MM-DD).
258fn days_to_iso(days: i64) -> String {
259    let (year, month, day) = civil_from_days(days);
260    format!("{:04}-{:02}-{:02}", year, month, day)
261}
262
/// Convert a count of days since the Unix epoch (1970-01-01) into a
/// proleptic-Gregorian `(year, month, day)` triple.
///
/// Based on Howard Hinnant's `civil_from_days`: dates are processed in
/// 400-year eras whose years start on March 1st, so the leap day is the last
/// day of each computational year.
fn civil_from_days(days: i64) -> (i64, u32, u32) {
    // Re-base the epoch from 1970-01-01 to 0000-03-01.
    let shifted = days + 719_468;
    // Floor division/modulo, so dates before 0000-03-01 land in earlier eras.
    let era = shifted.div_euclid(146_097);
    let day_of_era = shifted.rem_euclid(146_097) as u32; // [0, 146096]
    let year_of_era = (day_of_era - day_of_era / 1_460 + day_of_era / 36_524
        - day_of_era / 146_096)
        / 365; // [0, 399]
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100); // [0, 365]
    let month_index = (5 * day_of_year + 2) / 153; // [0, 11], March == 0
    let day = day_of_year - (153 * month_index + 2) / 5 + 1; // [1, 31]
    let month = if month_index < 10 {
        month_index + 3
    } else {
        month_index - 9
    };
    // January and February belong to the following civil year.
    let year = i64::from(year_of_era) + era * 400 + i64::from(month <= 2);
    (year, month, day)
}
278
279/// Merge two inferred types, handling conflicts by coercing to Utf8.
280fn merge_types(existing: InferredType, new: InferredType) -> InferredType {
281    if new == InferredType::Null {
282        return existing;
283    }
284    if existing == InferredType::Null {
285        return new;
286    }
287    if existing == new {
288        return existing;
289    }
290    // Types conflict → coerce to string
291    InferredType::Utf8
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a `Row` from `(key, value)` pairs.
    fn make_row(pairs: Vec<(&str, serde_json::Value)>) -> Row {
        pairs
            .into_iter()
            .map(|(k, v)| (k.to_string(), v))
            .collect()
    }

    #[test]
    fn test_rows_to_batch_roundtrip() {
        let rows = vec![
            make_row(vec![
                ("name", json!("Alice")),
                ("age", json!(30)),
                ("active", json!(true)),
            ]),
            make_row(vec![
                ("name", json!("Bob")),
                ("age", json!(25)),
                ("active", json!(false)),
            ]),
            make_row(vec![
                ("name", json!("Charlie")),
                ("age", json!(35)),
                ("active", json!(true)),
            ]),
        ];

        let batch = rows_to_record_batch(&rows).unwrap();
        assert_eq!(batch.num_rows(), 3);
        assert_eq!(batch.num_columns(), 3);

        let result = record_batch_to_rows(&[batch]);
        assert_eq!(result.len(), 3);

        // Verify values roundtripped correctly
        for (orig, converted) in rows.iter().zip(result.iter()) {
            assert_eq!(
                orig.get("name").and_then(|v| v.as_str()),
                converted.get("name").and_then(|v| v.as_str()),
            );
            assert_eq!(
                orig.get("age").and_then(|v| v.as_f64()),
                converted.get("age").and_then(|v| v.as_f64()),
            );
            assert_eq!(
                orig.get("active").and_then(|v| v.as_bool()),
                converted.get("active").and_then(|v| v.as_bool()),
            );
        }
    }

    #[test]
    fn test_empty_rows() {
        let rows: Vec<Row> = vec![];
        let batch = rows_to_record_batch(&rows).unwrap();
        assert_eq!(batch.num_rows(), 0);
        let result = record_batch_to_rows(&[batch]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_null_values() {
        let rows = vec![
            make_row(vec![("x", json!(1.0)), ("y", json!(null))]),
            make_row(vec![("x", json!(null)), ("y", json!("hello"))]),
        ];

        let batch = rows_to_record_batch(&rows).unwrap();
        assert_eq!(batch.num_rows(), 2);

        // Nulls are skipped during inference, so each column keeps the type
        // of its non-null values.
        let schema = batch.schema();
        assert_eq!(
            schema.field_with_name("x").unwrap().data_type(),
            &DataType::Float64
        );
        assert_eq!(
            schema.field_with_name("y").unwrap().data_type(),
            &DataType::Utf8
        );

        let result = record_batch_to_rows(&[batch]);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].get("x"), Some(&json!(1.0)));
        assert_eq!(result[0].get("y"), Some(&serde_json::Value::Null));
        assert_eq!(result[1].get("x"), Some(&serde_json::Value::Null));
        assert_eq!(result[1].get("y"), Some(&json!("hello")));
    }

    #[test]
    fn test_mixed_types_coerce_to_string() {
        // First row has number, second has string for same column
        let rows = vec![
            make_row(vec![("val", json!(42))]),
            make_row(vec![("val", json!("hello"))]),
        ];

        let batch = rows_to_record_batch(&rows).unwrap();
        // Should coerce to Utf8
        assert_eq!(batch.schema().field(0).data_type(), &DataType::Utf8);

        // The number is kept in textual form rather than dropped.
        let result = record_batch_to_rows(&[batch]);
        assert_eq!(result[0].get("val"), Some(&json!("42")));
        assert_eq!(result[1].get("val"), Some(&json!("hello")));
    }

    #[test]
    fn test_missing_fields() {
        // Rows with different keys
        let rows = vec![
            make_row(vec![("a", json!(1.0))]),
            make_row(vec![("b", json!("x"))]),
        ];

        let batch = rows_to_record_batch(&rows).unwrap();
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 2);
        // Columns are ordered by name.
        assert_eq!(batch.schema().field(0).name(), "a");
        assert_eq!(batch.schema().field(1).name(), "b");

        // Absent keys come back as explicit nulls.
        let result = record_batch_to_rows(&[batch]);
        assert_eq!(result[0].get("a"), Some(&json!(1.0)));
        assert_eq!(result[0].get("b"), Some(&serde_json::Value::Null));
        assert_eq!(result[1].get("a"), Some(&serde_json::Value::Null));
        assert_eq!(result[1].get("b"), Some(&json!("x")));
    }

    #[test]
    fn test_merge_types() {
        assert_eq!(
            merge_types(InferredType::Null, InferredType::Float64),
            InferredType::Float64
        );
        assert_eq!(
            merge_types(InferredType::Boolean, InferredType::Null),
            InferredType::Boolean
        );
        assert_eq!(
            merge_types(InferredType::Float64, InferredType::Float64),
            InferredType::Float64
        );
        assert_eq!(
            merge_types(InferredType::Float64, InferredType::Boolean),
            InferredType::Utf8
        );
    }

    #[test]
    fn test_days_to_iso() {
        assert_eq!(days_to_iso(0), "1970-01-01");
        assert_eq!(days_to_iso(-1), "1969-12-31");
        assert_eq!(days_to_iso(11016), "2000-02-29");
        assert_eq!(days_to_iso(19723), "2024-01-01");
    }

    #[test]
    fn test_arrow_value_to_json_date32_and_null() {
        let dates = arrow::array::Date32Array::from(vec![Some(19723), None]);
        assert_eq!(arrow_value_to_json(&dates, 0), json!("2024-01-01"));
        assert_eq!(arrow_value_to_json(&dates, 1), serde_json::Value::Null);
    }
}