1use std::collections::HashMap;
2use std::sync::Arc;
3
4use arrow_schema::DataType;
5use arrow_schema::Field;
6use arrow_schema::Schema;
7
/// Owned, plain-data description of a single schema field.
///
/// It carries the same three pieces of information that are read from an
/// Arrow `Field` when converting: its name, its data type and whether it is
/// nullable.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FieldDef {
    /// Field name.
    pub name: String,
    /// Data type of the field.
    pub dtyp: DataType,
    /// `true` when the field is nullable.
    pub null: bool,
}
14
15pub fn promote_unsigned(a: &DataType, b: &DataType) -> Option<DataType> {
16 use DataType::*;
17 let order = [UInt8, UInt16, UInt32, UInt64];
18 let idx_a = order.iter().position(|t| t == a);
19 let idx_b = order.iter().position(|t| t == b);
20
21 match (idx_a, idx_b) {
22 (Some(i), Some(j)) => Some(order[usize::max(i, j)].clone()),
23 _ => None,
24 }
25}
26
27pub fn promote_signed(a: &DataType, b: &DataType) -> Option<DataType> {
28 use DataType::*;
29 let order = [Int8, Int16, Int32, Int64];
30 let idx_a = order.iter().position(|t| t == a);
31 let idx_b = order.iter().position(|t| t == b);
32
33 match (idx_a, idx_b) {
34 (Some(i), Some(j)) => Some(order[usize::max(i, j)].clone()),
35 _ => None,
36 }
37}
38
39pub fn promote_integer(a: &DataType, b: &DataType) -> Option<DataType> {
40 promote_signed(a, b).or_else(|| promote_unsigned(a, b))
41}
42
43pub fn promote_float(a: &DataType, b: &DataType) -> Option<DataType> {
44 use DataType::*;
45 let order = [Float16, Float32, Float64];
46
47 let idx_a = order.iter().position(|t| t == a);
48 let idx_b = order.iter().position(|t| t == b);
49
50 match (idx_a, idx_b) {
51 (Some(i), Some(j)) => Some(order[usize::max(i, j)].clone()),
52 _ => None,
53 }
54}
55
56pub fn promote_numeric(a: &DataType, b: &DataType) -> Option<DataType> {
57 use DataType::*;
58
59 let is_float = |t: &DataType| matches!(t, Float16 | Float32 | Float64);
60 let is_int = |t: &DataType| {
61 matches!(
62 t,
63 Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
64 )
65 };
66
67 if is_float(a) || is_float(b) {
68 return promote_float(a, b);
69 }
70
71 if is_int(a) && is_int(b) {
72 return promote_integer(a, b);
73 }
74
75 None
76}
77
78pub enum TypeMergeState {
79 Promoted(DataType),
80 NotChanged(DataType),
81 UserNewer(DataType),
82}
83
84impl TypeMergeState {
85 pub fn new(old: DataType, neo: DataType) -> Self {
86 if old == neo {
87 return TypeMergeState::NotChanged(old);
88 }
89
90 if let Some(promoted) = promote_numeric(&old, &neo) {
91 return TypeMergeState::Promoted(promoted);
92 }
93
94 TypeMergeState::UserNewer(neo)
95 }
96}
97
98impl FieldDef {
99 pub fn merge(self, old: Self) -> Self {
101 let merged_null = old.null || self.null;
102
103 let merged_name = old.name;
104
105 let merged_type = match TypeMergeState::new(old.dtyp, self.dtyp) {
106 TypeMergeState::NotChanged(t) => t,
107 TypeMergeState::Promoted(t) => t,
108 TypeMergeState::UserNewer(t) => t,
109 };
110
111 FieldDef {
112 name: merged_name,
113 dtyp: merged_type,
114 null: merged_null,
115 }
116 }
117}
118
119impl From<&Field> for FieldDef {
120 fn from(f: &Field) -> Self {
121 let name: String = f.name().to_string();
122 let dtyp: DataType = f.data_type().clone();
123 let null: bool = f.is_nullable();
124
125 FieldDef { name, dtyp, null }
126 }
127}
128
129impl From<&Arc<Field>> for FieldDef {
130 fn from(af: &Arc<Field>) -> Self {
131 Self::from(af.as_ref())
132 }
133}
134
135impl From<FieldDef> for Field {
136 fn from(f: FieldDef) -> Self {
137 Field::new(&f.name, f.dtyp, f.null)
138 }
139}
140
141pub fn merge_field(neo: &Field, old: &Field) -> Field {
142 let neo_def: FieldDef = FieldDef::from(neo);
143 let old_def: FieldDef = FieldDef::from(old);
144
145 let merged_def: FieldDef = neo_def.merge(old_def);
146
147 merged_def.into()
148}
149
150pub fn merge_unordered<D, G>(defined: D, guess: G) -> HashMap<String, FieldDef>
151where
152 D: Iterator<Item = FieldDef>,
153 G: Iterator<Item = FieldDef>,
154{
155 let mut map: HashMap<String, FieldDef> = HashMap::new();
156 for f in defined {
157 map.insert(f.name.clone(), f);
158 }
159
160 guess.fold(map, |mut state, next| {
161 let name: &str = &next.name;
162
163 match state.remove_entry(name) {
164 None => {
165 state.insert(name.into(), next);
166 state
167 }
168 Some((name, specified)) => {
169 let merged = specified.merge(next);
170 state.insert(name, merged);
171 state
172 }
173 }
174 })
175}
176
177pub fn merge_schema_unordered(defined: Schema, guess: Schema) -> Schema {
178 let Schema {
179 fields: def_fields,
180 metadata: def_meta,
181 } = defined;
182 let Schema {
183 fields: guess_fields,
184 metadata: guess_meta,
185 } = guess;
186
187 let def_iter = def_fields.into_iter().map(FieldDef::from);
188 let guess_iter = guess_fields.into_iter().map(FieldDef::from);
189
190 let merged_map = merge_unordered(def_iter, guess_iter);
191 let fields: Vec<Field> = merged_map.into_values().map(Field::from).collect();
192
193 let mut metadata = def_meta;
194 metadata.extend(guess_meta);
195 Schema::new_with_metadata(fields, metadata)
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::{DataType, Field, Schema};
    use std::collections::HashMap;

    // Same-family unsigned pairs widen; a signed partner is rejected.
    #[test]
    fn test_promote_unsigned() {
        let cases = [
            (DataType::UInt8, DataType::UInt16, Some(DataType::UInt16)),
            (DataType::UInt32, DataType::UInt64, Some(DataType::UInt64)),
            (DataType::UInt8, DataType::Int8, None),
        ];
        for (a, b, expected) in cases {
            assert_eq!(promote_unsigned(&a, &b), expected);
        }
    }

    // Same-family signed pairs widen regardless of argument order; an
    // unsigned partner is rejected.
    #[test]
    fn test_promote_signed() {
        let cases = [
            (DataType::Int8, DataType::Int32, Some(DataType::Int32)),
            (DataType::Int64, DataType::Int16, Some(DataType::Int64)),
            (DataType::Int8, DataType::UInt8, None),
        ];
        for (a, b, expected) in cases {
            assert_eq!(promote_signed(&a, &b), expected);
        }
    }

    // Integer promotion covers both signed and unsigned pairs, but not a mix.
    #[test]
    fn test_promote_integer() {
        let cases = [
            (DataType::Int8, DataType::Int16, Some(DataType::Int16)),
            (DataType::UInt8, DataType::UInt32, Some(DataType::UInt32)),
            (DataType::Int8, DataType::UInt8, None),
        ];
        for (a, b, expected) in cases {
            assert_eq!(promote_integer(&a, &b), expected);
        }
    }

    #[test]
    fn test_promote_float() {
        let cases = [
            (DataType::Float16, DataType::Float64, Some(DataType::Float64)),
            (DataType::Float32, DataType::Float16, Some(DataType::Float32)),
        ];
        for (a, b, expected) in cases {
            assert_eq!(promote_float(&a, &b), expected);
        }
    }

    // Float/integer mixes and non-numeric types are not promotable.
    #[test]
    fn test_promote_numeric() {
        let cases = [
            (DataType::Int8, DataType::Int16, Some(DataType::Int16)),
            (DataType::Float32, DataType::Int16, None),
            (DataType::Utf8, DataType::Int32, None),
        ];
        for (a, b, expected) in cases {
            assert_eq!(promote_numeric(&a, &b), expected);
        }
    }

    // Shorthand for building a field in the tests below.
    fn field(name: &str, typ: DataType, nullable: bool) -> Field {
        Field::new(name, typ, nullable)
    }

    // Nullability is the logical OR of both inputs.
    #[test]
    fn test_merge_field_nullability() {
        let merged = merge_field(
            &field("a", DataType::Int32, false),
            &field("a", DataType::Int32, true),
        );
        assert_eq!(merged.name(), "a");
        assert_eq!(merged.data_type(), &DataType::Int32);
        assert!(merged.is_nullable());
    }

    #[test]
    fn test_merge_field_type_promotion() {
        let merged = merge_field(
            &field("a", DataType::Int8, false),
            &field("a", DataType::Int16, false),
        );
        assert_eq!(merged.data_type(), &DataType::Int16);
    }

    // Types that cannot be promoted resolve to the type of the first argument.
    #[test]
    fn test_merge_field_user_newer() {
        let neo = field("a", DataType::Int8, false);
        let old = field("a", DataType::UInt8, false);
        assert_eq!(merge_field(&neo, &old).data_type(), &DataType::Int8);
    }

    // Overlapping field is promoted and made nullable; metadata is unioned.
    #[test]
    fn test_merge_schema_unordered() {
        let defined = Schema::new_with_metadata(
            vec![field("a", DataType::Int8, false)],
            HashMap::from([("k1".to_string(), "v1".to_string())]),
        );
        let guess = Schema::new_with_metadata(
            vec![field("a", DataType::Int16, true)],
            HashMap::from([("k2".to_string(), "v2".to_string())]),
        );

        let merged = merge_schema_unordered(defined, guess);

        let merged_field = merged.field_with_name("a").expect("field 'a' missing");
        assert_eq!(merged_field.data_type(), &DataType::Int16);
        assert!(merged_field.is_nullable());

        let expected_meta: HashMap<String, String> = HashMap::from([
            ("k1".to_string(), "v1".to_string()),
            ("k2".to_string(), "v2".to_string()),
        ]);
        assert_eq!(merged.metadata(), &expected_meta);
    }

    // Disjoint schemas keep every field; names are sorted before comparing
    // because the merged field order is not fixed.
    #[test]
    fn test_merge_schema_unordered_no_overlap() {
        let defined = Schema::new(vec![field("x", DataType::Int8, false)]);
        let guess = Schema::new(vec![field("y", DataType::Float32, true)]);

        let merged = merge_schema_unordered(defined, guess);

        let mut names: Vec<&str> = merged
            .fields()
            .iter()
            .map(|f| f.name().as_str())
            .collect();
        names.sort_unstable();
        assert_eq!(names, ["x", "y"]);
    }
}