Skip to main content

rs_arrow_merge_schema/
lib.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use arrow_schema::DataType;
5use arrow_schema::Field;
6use arrow_schema::Schema;
7
8#[derive(Debug, Clone, PartialEq, Eq)]
9pub struct FieldDef {
10    pub name: String,
11    pub dtyp: DataType,
12    pub null: bool,
13}
14
15pub fn promote_unsigned(a: &DataType, b: &DataType) -> Option<DataType> {
16    use DataType::*;
17    let order = [UInt8, UInt16, UInt32, UInt64];
18    let idx_a = order.iter().position(|t| t == a);
19    let idx_b = order.iter().position(|t| t == b);
20
21    match (idx_a, idx_b) {
22        (Some(i), Some(j)) => Some(order[usize::max(i, j)].clone()),
23        _ => None,
24    }
25}
26
27pub fn promote_signed(a: &DataType, b: &DataType) -> Option<DataType> {
28    use DataType::*;
29    let order = [Int8, Int16, Int32, Int64];
30    let idx_a = order.iter().position(|t| t == a);
31    let idx_b = order.iter().position(|t| t == b);
32
33    match (idx_a, idx_b) {
34        (Some(i), Some(j)) => Some(order[usize::max(i, j)].clone()),
35        _ => None,
36    }
37}
38
39pub fn promote_integer(a: &DataType, b: &DataType) -> Option<DataType> {
40    promote_signed(a, b).or_else(|| promote_unsigned(a, b))
41}
42
43pub fn promote_float(a: &DataType, b: &DataType) -> Option<DataType> {
44    use DataType::*;
45    let order = [Float16, Float32, Float64];
46
47    let idx_a = order.iter().position(|t| t == a);
48    let idx_b = order.iter().position(|t| t == b);
49
50    match (idx_a, idx_b) {
51        (Some(i), Some(j)) => Some(order[usize::max(i, j)].clone()),
52        _ => None,
53    }
54}
55
56pub fn promote_numeric(a: &DataType, b: &DataType) -> Option<DataType> {
57    use DataType::*;
58
59    let is_float = |t: &DataType| matches!(t, Float16 | Float32 | Float64);
60    let is_int = |t: &DataType| {
61        matches!(
62            t,
63            Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
64        )
65    };
66
67    if is_float(a) || is_float(b) {
68        return promote_float(a, b);
69    }
70
71    if is_int(a) && is_int(b) {
72        return promote_integer(a, b);
73    }
74
75    None
76}
77
78pub enum TypeMergeState {
79    Promoted(DataType),
80    NotChanged(DataType),
81    UserNewer(DataType),
82}
83
84impl TypeMergeState {
85    pub fn new(old: DataType, neo: DataType) -> Self {
86        if old == neo {
87            return TypeMergeState::NotChanged(old);
88        }
89
90        if let Some(promoted) = promote_numeric(&old, &neo) {
91            return TypeMergeState::Promoted(promoted);
92        }
93
94        TypeMergeState::UserNewer(neo)
95    }
96}
97
98impl FieldDef {
99    /// Merges the types(assuming same name).
100    pub fn merge(self, old: Self) -> Self {
101        let merged_null = old.null || self.null;
102
103        let merged_name = old.name;
104
105        let merged_type = match TypeMergeState::new(old.dtyp, self.dtyp) {
106            TypeMergeState::NotChanged(t) => t,
107            TypeMergeState::Promoted(t) => t,
108            TypeMergeState::UserNewer(t) => t,
109        };
110
111        FieldDef {
112            name: merged_name,
113            dtyp: merged_type,
114            null: merged_null,
115        }
116    }
117}
118
119impl From<&Field> for FieldDef {
120    fn from(f: &Field) -> Self {
121        let name: String = f.name().to_string();
122        let dtyp: DataType = f.data_type().clone();
123        let null: bool = f.is_nullable();
124
125        FieldDef { name, dtyp, null }
126    }
127}
128
129impl From<&Arc<Field>> for FieldDef {
130    fn from(af: &Arc<Field>) -> Self {
131        Self::from(af.as_ref())
132    }
133}
134
135impl From<FieldDef> for Field {
136    fn from(f: FieldDef) -> Self {
137        Field::new(&f.name, f.dtyp, f.null)
138    }
139}
140
141pub fn merge_field(neo: &Field, old: &Field) -> Field {
142    let neo_def: FieldDef = FieldDef::from(neo);
143    let old_def: FieldDef = FieldDef::from(old);
144
145    let merged_def: FieldDef = neo_def.merge(old_def);
146
147    merged_def.into()
148}
149
150pub fn merge_unordered<D, G>(defined: D, guess: G) -> HashMap<String, FieldDef>
151where
152    D: Iterator<Item = FieldDef>,
153    G: Iterator<Item = FieldDef>,
154{
155    let mut map: HashMap<String, FieldDef> = HashMap::new();
156    for f in defined {
157        map.insert(f.name.clone(), f);
158    }
159
160    guess.fold(map, |mut state, next| {
161        let name: &str = &next.name;
162
163        match state.remove_entry(name) {
164            None => {
165                state.insert(name.into(), next);
166                state
167            }
168            Some((name, specified)) => {
169                let merged = specified.merge(next);
170                state.insert(name, merged);
171                state
172            }
173        }
174    })
175}
176
177pub fn merge_schema_unordered(defined: Schema, guess: Schema) -> Schema {
178    let Schema {
179        fields: def_fields,
180        metadata: def_meta,
181    } = defined;
182    let Schema {
183        fields: guess_fields,
184        metadata: guess_meta,
185    } = guess;
186
187    let def_iter = def_fields.into_iter().map(FieldDef::from);
188    let guess_iter = guess_fields.into_iter().map(FieldDef::from);
189
190    let merged_map = merge_unordered(def_iter, guess_iter);
191    let fields: Vec<Field> = merged_map.into_values().map(Field::from).collect();
192
193    let mut metadata = def_meta;
194    metadata.extend(guess_meta);
195    Schema::new_with_metadata(fields, metadata)
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::{DataType, Field, Schema};
    use std::collections::HashMap;

    // ---- Shared helpers --------------------------------------------------

    /// Shorthand for building a `Field`.
    fn field(name: &str, typ: DataType, nullable: bool) -> Field {
        Field::new(name, typ, nullable)
    }

    /// Builds an owned metadata map from borrowed key/value pairs.
    fn metadata(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(k, v)| (k.to_owned(), v.to_owned()))
            .collect()
    }

    // ---- Promotion helpers -----------------------------------------------

    #[test]
    fn test_promote_unsigned() {
        let small_pair = promote_unsigned(&DataType::UInt8, &DataType::UInt16);
        assert_eq!(small_pair, Some(DataType::UInt16));

        let large_pair = promote_unsigned(&DataType::UInt32, &DataType::UInt64);
        assert_eq!(large_pair, Some(DataType::UInt64));

        // A signed operand is outside the unsigned ladder.
        let mixed = promote_unsigned(&DataType::UInt8, &DataType::Int8);
        assert_eq!(mixed, None);
    }

    #[test]
    fn test_promote_signed() {
        let ascending = promote_signed(&DataType::Int8, &DataType::Int32);
        assert_eq!(ascending, Some(DataType::Int32));

        let descending = promote_signed(&DataType::Int64, &DataType::Int16);
        assert_eq!(descending, Some(DataType::Int64));

        // An unsigned operand is outside the signed ladder.
        let mixed = promote_signed(&DataType::Int8, &DataType::UInt8);
        assert_eq!(mixed, None);
    }

    #[test]
    fn test_promote_integer() {
        let signed = promote_integer(&DataType::Int8, &DataType::Int16);
        assert_eq!(signed, Some(DataType::Int16));

        let unsigned = promote_integer(&DataType::UInt8, &DataType::UInt32);
        assert_eq!(unsigned, Some(DataType::UInt32));

        // Mixed signedness cannot be promoted.
        let mixed = promote_integer(&DataType::Int8, &DataType::UInt8);
        assert_eq!(mixed, None);
    }

    #[test]
    fn test_promote_float() {
        let widest = promote_float(&DataType::Float16, &DataType::Float64);
        assert_eq!(widest, Some(DataType::Float64));

        let reversed = promote_float(&DataType::Float32, &DataType::Float16);
        assert_eq!(reversed, Some(DataType::Float32));
    }

    #[test]
    fn test_promote_numeric() {
        let ints = promote_numeric(&DataType::Int8, &DataType::Int16);
        assert_eq!(ints, Some(DataType::Int16));

        // Float paired with integer, and string paired with integer, do not promote.
        let float_and_int = promote_numeric(&DataType::Float32, &DataType::Int16);
        assert_eq!(float_and_int, None);

        let string_and_int = promote_numeric(&DataType::Utf8, &DataType::Int32);
        assert_eq!(string_and_int, None);
    }

    // ---- Field merging ---------------------------------------------------

    #[test]
    fn test_merge_field_nullability() {
        let strict = field("a", DataType::Int32, false);
        let relaxed = field("a", DataType::Int32, true);

        let merged = merge_field(&strict, &relaxed);

        assert_eq!(merged.name(), "a");
        assert_eq!(merged.data_type(), &DataType::Int32);
        assert!(merged.is_nullable());
    }

    #[test]
    fn test_merge_field_type_promotion() {
        let narrow = field("a", DataType::Int8, false);
        let wide = field("a", DataType::Int16, false);

        let merged = merge_field(&narrow, &wide);

        assert_eq!(merged.data_type(), &DataType::Int16);
    }

    #[test]
    fn test_merge_field_user_newer() {
        let neo = field("a", DataType::Int8, false);
        let old = field("a", DataType::UInt8, false);

        let merged = merge_field(&neo, &old);

        // No promotion exists between Int8 and UInt8, so the newer type wins.
        assert_eq!(merged.data_type(), &DataType::Int8);
    }

    // ---- Schema merging --------------------------------------------------

    #[test]
    fn test_merge_schema_unordered() {
        let defined = Schema::new_with_metadata(
            vec![field("a", DataType::Int8, false)],
            metadata(&[("k1", "v1")]),
        );
        let guessed = Schema::new_with_metadata(
            vec![field("a", DataType::Int16, true)],
            metadata(&[("k2", "v2")]),
        );

        let merged = merge_schema_unordered(defined, guessed);

        let merged_field = merged.field_with_name("a").expect("field 'a' missing");
        assert_eq!(merged_field.data_type(), &DataType::Int16);
        assert!(merged_field.is_nullable());

        let expected_meta = metadata(&[("k1", "v1"), ("k2", "v2")]);
        assert_eq!(merged.metadata(), &expected_meta);
    }

    #[test]
    fn test_merge_schema_unordered_no_overlap() {
        let left = Schema::new(vec![field("x", DataType::Int8, false)]);
        let right = Schema::new(vec![field("y", DataType::Float32, true)]);

        let merged = merge_schema_unordered(left, right);

        let mut names: Vec<&str> = merged
            .fields()
            .iter()
            .map(|f| f.name().as_str())
            .collect();
        names.sort();
        assert_eq!(names, ["x", "y"]);
    }
}