polars_json/json/infer_schema.rs

1use std::borrow::Borrow;
2
3use arrow::datatypes::{ArrowDataType, Field};
4use indexmap::map::Entry;
5use polars_utils::pl_str::PlSmallStr;
6use simd_json::borrowed::Object;
7use simd_json::{BorrowedValue, StaticNode};
8
9use super::*;
10
/// Name of the child field of every `LargeList` type produced by schema inference.
const ITEM_NAME: &str = "item";
12
13/// Infers [`ArrowDataType`] from [`Value`][Value].
14///
15/// [Value]: simd_json::value::Value
16pub fn infer(json: &BorrowedValue) -> PolarsResult<ArrowDataType> {
17    Ok(match json {
18        BorrowedValue::Static(StaticNode::Bool(_)) => ArrowDataType::Boolean,
19        BorrowedValue::Static(StaticNode::U64(_) | StaticNode::I64(_)) => ArrowDataType::Int64,
20        BorrowedValue::Static(StaticNode::F64(_)) => ArrowDataType::Float64,
21        BorrowedValue::Static(StaticNode::Null) => ArrowDataType::Null,
22        BorrowedValue::Array(array) => infer_array(array)?,
23        BorrowedValue::String(_) => ArrowDataType::LargeUtf8,
24        BorrowedValue::Object(inner) => infer_object(inner)?,
25    })
26}
27
28fn infer_object(inner: &Object) -> PolarsResult<ArrowDataType> {
29    let fields = inner
30        .iter()
31        .map(|(key, value)| infer(value).map(|dt| (key, dt)))
32        .map(|maybe_dt| {
33            let (key, dt) = maybe_dt?;
34            Ok(Field::new(key.as_ref().into(), dt, true))
35        })
36        .collect::<PolarsResult<Vec<_>>>()?;
37    Ok(ArrowDataType::Struct(fields))
38}
39
40fn infer_array(values: &[BorrowedValue]) -> PolarsResult<ArrowDataType> {
41    let types = values
42        .iter()
43        .map(infer)
44        // deduplicate entries
45        .collect::<PolarsResult<PlHashSet<_>>>()?;
46
47    let dt = if !types.is_empty() {
48        let types = types.into_iter().collect::<Vec<_>>();
49        coerce_dtype(&types)
50    } else {
51        ArrowDataType::Null
52    };
53
54    Ok(ArrowDataType::LargeList(Box::new(Field::new(
55        PlSmallStr::from_static(ITEM_NAME),
56        dt,
57        true,
58    ))))
59}
60
61/// Coerce an heterogeneous set of [`ArrowDataType`] into a single one. Rules:
62/// * The empty set is coerced to `Null`
63/// * `Int64` and `Float64` are `Float64`
64/// * Lists and scalars are coerced to a list of a compatible scalar
65/// * Structs contain the union of all fields
66/// * All other types are coerced to `Utf8`
67pub(crate) fn coerce_dtype<A: Borrow<ArrowDataType>>(datatypes: &[A]) -> ArrowDataType {
68    use ArrowDataType::*;
69
70    if datatypes.is_empty() {
71        return Null;
72    }
73
74    let are_all_equal = datatypes.windows(2).all(|w| w[0].borrow() == w[1].borrow());
75
76    if are_all_equal {
77        return datatypes[0].borrow().clone();
78    }
79    let mut are_all_structs = true;
80    let mut are_all_lists = true;
81    for dt in datatypes {
82        are_all_structs &= matches!(dt.borrow(), Struct(_));
83        are_all_lists &= matches!(dt.borrow(), LargeList(_));
84    }
85
86    if are_all_structs {
87        // all are structs => union of all fields (that may have equal names)
88        let fields = datatypes.iter().fold(vec![], |mut acc, dt| {
89            if let Struct(new_fields) = dt.borrow() {
90                acc.extend(new_fields);
91            };
92            acc
93        });
94        // group fields by unique
95        let fields = fields.iter().fold(
96            PlIndexMap::<&str, PlHashSet<&ArrowDataType>>::default(),
97            |mut acc, field| {
98                match acc.entry(field.name.as_str()) {
99                    Entry::Occupied(mut v) => {
100                        v.get_mut().insert(&field.dtype);
101                    },
102                    Entry::Vacant(v) => {
103                        let mut a = PlHashSet::default();
104                        a.insert(&field.dtype);
105                        v.insert(a);
106                    },
107                }
108                acc
109            },
110        );
111        // and finally, coerce each of the fields within the same name
112        let fields = fields
113            .into_iter()
114            .map(|(name, dts)| {
115                let dts = dts.into_iter().collect::<Vec<_>>();
116                Field::new(name.into(), coerce_dtype(&dts), true)
117            })
118            .collect();
119        return Struct(fields);
120    } else if are_all_lists {
121        let inner_types: Vec<&ArrowDataType> = datatypes
122            .iter()
123            .map(|dt| {
124                if let LargeList(inner) = dt.borrow() {
125                    inner.dtype()
126                } else {
127                    unreachable!();
128                }
129            })
130            .collect();
131        return LargeList(Box::new(Field::new(
132            PlSmallStr::from_static(ITEM_NAME),
133            coerce_dtype(inner_types.as_slice()),
134            true,
135        )));
136    } else if datatypes.len() > 2 {
137        return datatypes
138            .iter()
139            .map(|t| t.borrow().clone())
140            .reduce(|a, b| coerce_dtype(&[a, b]))
141            .expect("not empty");
142    }
143    let (lhs, rhs) = (datatypes[0].borrow(), datatypes[1].borrow());
144
145    match (lhs, rhs) {
146        (lhs, rhs) if lhs == rhs => lhs.clone(),
147        (LargeList(lhs), LargeList(rhs)) => {
148            let inner = coerce_dtype(&[lhs.dtype(), rhs.dtype()]);
149            LargeList(Box::new(Field::new(
150                PlSmallStr::from_static(ITEM_NAME),
151                inner,
152                true,
153            )))
154        },
155        (scalar, LargeList(list)) => {
156            let inner = coerce_dtype(&[scalar, list.dtype()]);
157            LargeList(Box::new(Field::new(
158                PlSmallStr::from_static(ITEM_NAME),
159                inner,
160                true,
161            )))
162        },
163        (LargeList(list), scalar) => {
164            let inner = coerce_dtype(&[scalar, list.dtype()]);
165            LargeList(Box::new(Field::new(
166                PlSmallStr::from_static(ITEM_NAME),
167                inner,
168                true,
169            )))
170        },
171        (Float64, Int64) => Float64,
172        (Int64, Float64) => Float64,
173        (Int64, Boolean) => Int64,
174        (Boolean, Int64) => Int64,
175        (Null, rhs) => rhs.clone(),
176        (lhs, Null) => lhs.clone(),
177        (_, _) => LargeUtf8,
178    }
179}