1use std::ops::Deref;
19use std::sync::Arc;
20
21use crate::{ArrowError, DataType, Field, FieldRef};
22
/// A shared, cheaply cloneable list of [`FieldRef`].
///
/// The fields live in an `Arc<[FieldRef]>`, so cloning a [`Fields`] only bumps a
/// reference count instead of copying the fields. Equality, ordering and hashing
/// are derived and therefore compare the contained fields element by element.
/// With the `serde` feature enabled, it (de)serializes transparently as the
/// underlying sequence of fields.
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
pub struct Fields(Arc<[FieldRef]>);
61
62impl std::fmt::Debug for Fields {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        self.0.as_ref().fmt(f)
65    }
66}
67
68impl Fields {
69    pub fn empty() -> Self {
71        Self(Arc::new([]))
72    }
73
74    pub fn size(&self) -> usize {
76        self.iter()
77            .map(|field| field.size() + std::mem::size_of::<FieldRef>())
78            .sum()
79    }
80
81    pub fn find(&self, name: &str) -> Option<(usize, &FieldRef)> {
83        self.0.iter().enumerate().find(|(_, b)| b.name() == name)
84    }
85
86    pub fn contains(&self, other: &Fields) -> bool {
93        if Arc::ptr_eq(&self.0, &other.0) {
94            return true;
95        }
96        self.len() == other.len()
97            && self
98                .iter()
99                .zip(other.iter())
100                .all(|(a, b)| Arc::ptr_eq(a, b) || a.contains(b))
101    }
102
103    pub fn filter_leaves<F: FnMut(usize, &FieldRef) -> bool>(&self, mut filter: F) -> Self {
140        self.try_filter_leaves(|idx, field| Ok(filter(idx, field)))
141            .unwrap()
142    }
143
144    pub fn try_filter_leaves<F: FnMut(usize, &FieldRef) -> Result<bool, ArrowError>>(
149        &self,
150        mut filter: F,
151    ) -> Result<Self, ArrowError> {
152        fn filter_field<F: FnMut(&FieldRef) -> Result<bool, ArrowError>>(
153            f: &FieldRef,
154            filter: &mut F,
155        ) -> Result<Option<FieldRef>, ArrowError> {
156            use DataType::*;
157
158            let v = match f.data_type() {
159                Dictionary(_, v) => v.as_ref(),       RunEndEncoded(_, v) => v.data_type(), d => d,
162            };
163            let d = match v {
164                List(child) => {
165                    let fields = filter_field(child, filter)?;
166                    if let Some(fields) = fields {
167                        List(fields)
168                    } else {
169                        return Ok(None);
170                    }
171                }
172                LargeList(child) => {
173                    let fields = filter_field(child, filter)?;
174                    if let Some(fields) = fields {
175                        LargeList(fields)
176                    } else {
177                        return Ok(None);
178                    }
179                }
180                Map(child, ordered) => {
181                    let fields = filter_field(child, filter)?;
182                    if let Some(fields) = fields {
183                        Map(fields, *ordered)
184                    } else {
185                        return Ok(None);
186                    }
187                }
188                FixedSizeList(child, size) => {
189                    let fields = filter_field(child, filter)?;
190                    if let Some(fields) = fields {
191                        FixedSizeList(fields, *size)
192                    } else {
193                        return Ok(None);
194                    }
195                }
196                Struct(fields) => {
197                    let filtered: Result<Vec<_>, _> =
198                        fields.iter().map(|f| filter_field(f, filter)).collect();
199                    let filtered: Fields = filtered?
200                        .iter()
201                        .filter_map(|f| f.as_ref().cloned())
202                        .collect();
203
204                    if filtered.is_empty() {
205                        return Ok(None);
206                    }
207
208                    Struct(filtered)
209                }
210                Union(fields, mode) => {
211                    let filtered: Result<Vec<_>, _> = fields
212                        .iter()
213                        .map(|(id, f)| filter_field(f, filter).map(|f| f.map(|f| (id, f))))
214                        .collect();
215                    let filtered: UnionFields = filtered?
216                        .iter()
217                        .filter_map(|f| f.as_ref().cloned())
218                        .collect();
219
220                    if filtered.is_empty() {
221                        return Ok(None);
222                    }
223
224                    Union(filtered, *mode)
225                }
226                _ => {
227                    let filtered = filter(f)?;
228                    return Ok(filtered.then(|| f.clone()));
229                }
230            };
231            let d = match f.data_type() {
232                Dictionary(k, _) => Dictionary(k.clone(), Box::new(d)),
233                RunEndEncoded(v, f) => {
234                    RunEndEncoded(v.clone(), Arc::new(f.as_ref().clone().with_data_type(d)))
235                }
236                _ => d,
237            };
238            Ok(Some(Arc::new(f.as_ref().clone().with_data_type(d))))
239        }
240
241        let mut leaf_idx = 0;
242        let mut filter = |f: &FieldRef| {
243            let t = filter(leaf_idx, f)?;
244            leaf_idx += 1;
245            Ok(t)
246        };
247
248        let filtered: Result<Vec<_>, _> = self
249            .0
250            .iter()
251            .map(|f| filter_field(f, &mut filter))
252            .collect();
253        let filtered = filtered?
254            .iter()
255            .filter_map(|f| f.as_ref().cloned())
256            .collect();
257        Ok(filtered)
258    }
259}
260
261impl Default for Fields {
262    fn default() -> Self {
263        Self::empty()
264    }
265}
266
267impl FromIterator<Field> for Fields {
268    fn from_iter<T: IntoIterator<Item = Field>>(iter: T) -> Self {
269        iter.into_iter().map(Arc::new).collect()
270    }
271}
272
273impl FromIterator<FieldRef> for Fields {
274    fn from_iter<T: IntoIterator<Item = FieldRef>>(iter: T) -> Self {
275        Self(iter.into_iter().collect())
276    }
277}
278
279impl From<Vec<Field>> for Fields {
280    fn from(value: Vec<Field>) -> Self {
281        value.into_iter().collect()
282    }
283}
284
285impl From<Vec<FieldRef>> for Fields {
286    fn from(value: Vec<FieldRef>) -> Self {
287        Self(value.into())
288    }
289}
290
291impl From<&[FieldRef]> for Fields {
292    fn from(value: &[FieldRef]) -> Self {
293        Self(value.into())
294    }
295}
296
297impl<const N: usize> From<[FieldRef; N]> for Fields {
298    fn from(value: [FieldRef; N]) -> Self {
299        Self(Arc::new(value))
300    }
301}
302
303impl Deref for Fields {
304    type Target = [FieldRef];
305
306    fn deref(&self) -> &Self::Target {
307        self.0.as_ref()
308    }
309}
310
311impl<'a> IntoIterator for &'a Fields {
312    type Item = &'a FieldRef;
313    type IntoIter = std::slice::Iter<'a, FieldRef>;
314
315    fn into_iter(self) -> Self::IntoIter {
316        self.0.iter()
317    }
318}
319
/// A shared, cheaply cloneable list of `(type_id, FieldRef)` pairs, used as the
/// children of a `DataType::Union`.
///
/// The pairs live in an `Arc<[(i8, FieldRef)]>`, so cloning only bumps a
/// reference count. Equality, ordering and hashing are derived and compare the
/// pairs element by element. With the `serde` feature enabled, it
/// (de)serializes transparently as the underlying sequence of pairs.
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
pub struct UnionFields(Arc<[(i8, FieldRef)]>);
325
326impl std::fmt::Debug for UnionFields {
327    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
328        self.0.as_ref().fmt(f)
329    }
330}
331
332impl UnionFields {
333    pub fn empty() -> Self {
335        Self(Arc::from([]))
336    }
337
338    pub fn new<F, T>(type_ids: T, fields: F) -> Self
356    where
357        F: IntoIterator,
358        F::Item: Into<FieldRef>,
359        T: IntoIterator<Item = i8>,
360    {
361        let fields = fields.into_iter().map(Into::into);
362        let mut set = 0_u128;
363        type_ids
364            .into_iter()
365            .inspect(|&idx| {
366                let mask = 1_u128 << idx;
367                if (set & mask) != 0 {
368                    panic!("duplicate type id: {idx}");
369                } else {
370                    set |= mask;
371                }
372            })
373            .zip(fields)
374            .collect()
375    }
376
377    pub fn size(&self) -> usize {
379        self.iter()
380            .map(|(_, field)| field.size() + std::mem::size_of::<(i8, FieldRef)>())
381            .sum()
382    }
383
384    pub fn len(&self) -> usize {
386        self.0.len()
387    }
388
389    pub fn is_empty(&self) -> bool {
391        self.0.is_empty()
392    }
393
394    pub fn iter(&self) -> impl Iterator<Item = (i8, &FieldRef)> + '_ {
396        self.0.iter().map(|(id, f)| (*id, f))
397    }
398
399    pub(crate) fn try_merge(&mut self, other: &Self) -> Result<(), ArrowError> {
403        let mut output: Vec<_> = self.iter().map(|(id, f)| (id, f.clone())).collect();
405        for (field_type_id, from_field) in other.iter() {
406            let mut is_new_field = true;
407            for (self_type_id, self_field) in output.iter_mut() {
408                if from_field == self_field {
409                    if *self_type_id != field_type_id {
412                        return Err(ArrowError::SchemaError(format!(
413                            "Fail to merge schema field '{}' because the self_type_id = {} does not equal field_type_id = {}",
414                            self_field.name(),
415                            self_type_id,
416                            field_type_id
417                        )));
418                    }
419
420                    is_new_field = false;
421                    break;
422                }
423            }
424
425            if is_new_field {
426                output.push((field_type_id, from_field.clone()))
427            }
428        }
429        *self = output.into_iter().collect();
430        Ok(())
431    }
432}
433
434impl FromIterator<(i8, FieldRef)> for UnionFields {
435    fn from_iter<T: IntoIterator<Item = (i8, FieldRef)>>(iter: T) -> Self {
436        Self(iter.into_iter().collect())
438    }
439}
440
#[cfg(test)]
mod tests {
    use super::*;
    use crate::UnionMode;

    /// Exercises `Fields::filter_leaves` / `try_filter_leaves` across each nested
    /// type the traversal understands.
    ///
    /// Depth-first leaf indices of `fields` below (the assertions rely on these):
    ///   0: a
    ///   1, 2: floats.{a, b}
    ///   3: b
    ///   4: c (dictionary of Utf8; the dictionary field itself is the leaf)
    ///   5, 6: d (dictionary of struct) -> {a, b}
    ///   7, 8: e (list of struct) -> floats.{a, b}
    ///   9: f (fixed-size list) -> its Int32 child
    ///   10, 11: g (map) -> entries.{keys, values}
    ///   12, 13: h (union) -> {field1, field3}
    ///   14, 15: i (run-end encoded) -> values.{a, b}; run_ends is not a leaf
    #[test]
    fn test_filter() {
        let floats = Fields::from(vec![
            Field::new("a", DataType::Float32, false),
            Field::new("b", DataType::Float32, false),
        ]);
        let fields = Fields::from(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("floats", DataType::Struct(floats.clone()), true),
            Field::new("b", DataType::Int16, true),
            Field::new(
                "c",
                DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                false,
            ),
            Field::new(
                "d",
                DataType::Dictionary(
                    Box::new(DataType::Int32),
                    Box::new(DataType::Struct(floats.clone())),
                ),
                false,
            ),
            Field::new_list(
                "e",
                Field::new("floats", DataType::Struct(floats.clone()), true),
                true,
            ),
            Field::new_fixed_size_list(
                "f",
                Field::new_list_field(DataType::Int32, false),
                3,
                false,
            ),
            Field::new_map(
                "g",
                "entries",
                Field::new("keys", DataType::LargeUtf8, false),
                Field::new("values", DataType::Int32, true),
                false,
                false,
            ),
            Field::new(
                "h",
                DataType::Union(
                    UnionFields::new(
                        vec![1, 3],
                        vec![
                            Field::new("field1", DataType::UInt8, false),
                            Field::new("field3", DataType::Utf8, false),
                        ],
                    ),
                    UnionMode::Dense,
                ),
                true,
            ),
            Field::new(
                "i",
                DataType::RunEndEncoded(
                    Arc::new(Field::new("run_ends", DataType::Int32, false)),
                    Arc::new(Field::new("values", DataType::Struct(floats.clone()), true)),
                ),
                false,
            ),
        ]);

        // `floats` reduced to just its "a" child.
        let floats_a = DataType::Struct(vec![floats[0].clone()].into());

        // Leaves 0 and 1: top-level "a" plus "floats" pruned down to "floats.a".
        let r = fields.filter_leaves(|idx, _| idx == 0 || idx == 1);
        assert_eq!(r.len(), 2);
        assert_eq!(r[0], fields[0]);
        assert_eq!(r[1].data_type(), &floats_a);

        // Keep every leaf named "a": matches at leaves 0, 1, 5, 7 and 14, so the
        // parents a, floats, d, e and i survive with only their "a" children.
        let r = fields.filter_leaves(|_, f| f.name() == "a");
        assert_eq!(r.len(), 5);
        assert_eq!(r[0], fields[0]);
        assert_eq!(r[1].data_type(), &floats_a);
        assert_eq!(
            r[2].data_type(),
            &DataType::Dictionary(Box::new(DataType::Int32), Box::new(floats_a.clone()))
        );
        assert_eq!(
            r[3].as_ref(),
            &Field::new_list("e", Field::new("floats", floats_a.clone(), true), true)
        );
        assert_eq!(
            r[4].as_ref(),
            &Field::new(
                "i",
                DataType::RunEndEncoded(
                    Arc::new(Field::new("run_ends", DataType::Int32, false)),
                    Arc::new(Field::new("values", floats_a.clone(), true)),
                ),
                false,
            )
        );

        // "floats" only names struct (non-leaf) fields, so no leaf matches.
        let r = fields.filter_leaves(|_, f| f.name() == "floats");
        assert_eq!(r.len(), 0);

        // Leaf 9 is the only child of fixed-size list "f".
        let r = fields.filter_leaves(|idx, _| idx == 9);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0], fields[6]);

        // Leaves 10 and 11 are the map's keys and values.
        let r = fields.filter_leaves(|idx, _| idx == 10 || idx == 11);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0], fields[7]);

        // Leaf 12 keeps only "field1" of union "h", preserving its type id.
        let union = DataType::Union(
            UnionFields::new(vec![1], vec![Field::new("field1", DataType::UInt8, false)]),
            UnionMode::Dense,
        );

        let r = fields.filter_leaves(|idx, _| idx == 12);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].data_type(), &union);

        // Leaves 14 and 15 are both children of the run-end-encoded values struct.
        let r = fields.filter_leaves(|idx, _| idx == 14 || idx == 15);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0], fields[9]);

        // An error from the filter is propagated.
        let r = fields.try_filter_leaves(|_, _| Err(ArrowError::SchemaError("error".to_string())));
        assert!(r.is_err());
    }
}