1use std::ops::Deref;
19use std::sync::Arc;
20
21use crate::{ArrowError, DataType, Field, FieldRef};
22
23#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
58#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
59#[cfg_attr(feature = "serde", serde(transparent))]
60pub struct Fields(Arc<[FieldRef]>);
61
62impl std::fmt::Debug for Fields {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 self.0.as_ref().fmt(f)
65 }
66}
67
68impl Fields {
69 pub fn empty() -> Self {
71 Self(Arc::new([]))
72 }
73
74 pub fn size(&self) -> usize {
76 self.iter()
77 .map(|field| field.size() + std::mem::size_of::<FieldRef>())
78 .sum()
79 }
80
81 pub fn find(&self, name: &str) -> Option<(usize, &FieldRef)> {
83 self.0.iter().enumerate().find(|(_, b)| b.name() == name)
84 }
85
86 pub fn contains(&self, other: &Fields) -> bool {
93 if Arc::ptr_eq(&self.0, &other.0) {
94 return true;
95 }
96 self.len() == other.len()
97 && self
98 .iter()
99 .zip(other.iter())
100 .all(|(a, b)| Arc::ptr_eq(a, b) || a.contains(b))
101 }
102
103 pub fn filter_leaves<F: FnMut(usize, &FieldRef) -> bool>(&self, mut filter: F) -> Self {
140 self.try_filter_leaves(|idx, field| Ok(filter(idx, field)))
141 .unwrap()
142 }
143
144 pub fn try_filter_leaves<F: FnMut(usize, &FieldRef) -> Result<bool, ArrowError>>(
149 &self,
150 mut filter: F,
151 ) -> Result<Self, ArrowError> {
152 fn filter_field<F: FnMut(&FieldRef) -> Result<bool, ArrowError>>(
153 f: &FieldRef,
154 filter: &mut F,
155 ) -> Result<Option<FieldRef>, ArrowError> {
156 use DataType::*;
157
158 let v = match f.data_type() {
159 Dictionary(_, v) => v.as_ref(), RunEndEncoded(_, v) => v.data_type(), d => d,
162 };
163 let d = match v {
164 List(child) => {
165 let fields = filter_field(child, filter)?;
166 if let Some(fields) = fields {
167 List(fields)
168 } else {
169 return Ok(None);
170 }
171 }
172 LargeList(child) => {
173 let fields = filter_field(child, filter)?;
174 if let Some(fields) = fields {
175 LargeList(fields)
176 } else {
177 return Ok(None);
178 }
179 }
180 Map(child, ordered) => {
181 let fields = filter_field(child, filter)?;
182 if let Some(fields) = fields {
183 Map(fields, *ordered)
184 } else {
185 return Ok(None);
186 }
187 }
188 FixedSizeList(child, size) => {
189 let fields = filter_field(child, filter)?;
190 if let Some(fields) = fields {
191 FixedSizeList(fields, *size)
192 } else {
193 return Ok(None);
194 }
195 }
196 Struct(fields) => {
197 let filtered: Result<Vec<_>, _> =
198 fields.iter().map(|f| filter_field(f, filter)).collect();
199 let filtered: Fields = filtered?
200 .iter()
201 .filter_map(|f| f.as_ref().cloned())
202 .collect();
203
204 if filtered.is_empty() {
205 return Ok(None);
206 }
207
208 Struct(filtered)
209 }
210 Union(fields, mode) => {
211 let filtered: Result<Vec<_>, _> = fields
212 .iter()
213 .map(|(id, f)| filter_field(f, filter).map(|f| f.map(|f| (id, f))))
214 .collect();
215 let filtered: UnionFields = filtered?
216 .iter()
217 .filter_map(|f| f.as_ref().cloned())
218 .collect();
219
220 if filtered.is_empty() {
221 return Ok(None);
222 }
223
224 Union(filtered, *mode)
225 }
226 _ => {
227 let filtered = filter(f)?;
228 return Ok(filtered.then(|| f.clone()));
229 }
230 };
231 let d = match f.data_type() {
232 Dictionary(k, _) => Dictionary(k.clone(), Box::new(d)),
233 RunEndEncoded(v, f) => {
234 RunEndEncoded(v.clone(), Arc::new(f.as_ref().clone().with_data_type(d)))
235 }
236 _ => d,
237 };
238 Ok(Some(Arc::new(f.as_ref().clone().with_data_type(d))))
239 }
240
241 let mut leaf_idx = 0;
242 let mut filter = |f: &FieldRef| {
243 let t = filter(leaf_idx, f)?;
244 leaf_idx += 1;
245 Ok(t)
246 };
247
248 let filtered: Result<Vec<_>, _> = self
249 .0
250 .iter()
251 .map(|f| filter_field(f, &mut filter))
252 .collect();
253 let filtered = filtered?
254 .iter()
255 .filter_map(|f| f.as_ref().cloned())
256 .collect();
257 Ok(filtered)
258 }
259}
260
261impl Default for Fields {
262 fn default() -> Self {
263 Self::empty()
264 }
265}
266
267impl FromIterator<Field> for Fields {
268 fn from_iter<T: IntoIterator<Item = Field>>(iter: T) -> Self {
269 iter.into_iter().map(Arc::new).collect()
270 }
271}
272
273impl FromIterator<FieldRef> for Fields {
274 fn from_iter<T: IntoIterator<Item = FieldRef>>(iter: T) -> Self {
275 Self(iter.into_iter().collect())
276 }
277}
278
279impl From<Vec<Field>> for Fields {
280 fn from(value: Vec<Field>) -> Self {
281 value.into_iter().collect()
282 }
283}
284
285impl From<Vec<FieldRef>> for Fields {
286 fn from(value: Vec<FieldRef>) -> Self {
287 Self(value.into())
288 }
289}
290
291impl From<&[FieldRef]> for Fields {
292 fn from(value: &[FieldRef]) -> Self {
293 Self(value.into())
294 }
295}
296
297impl<const N: usize> From<[FieldRef; N]> for Fields {
298 fn from(value: [FieldRef; N]) -> Self {
299 Self(Arc::new(value))
300 }
301}
302
303impl Deref for Fields {
304 type Target = [FieldRef];
305
306 fn deref(&self) -> &Self::Target {
307 self.0.as_ref()
308 }
309}
310
311impl<'a> IntoIterator for &'a Fields {
312 type Item = &'a FieldRef;
313 type IntoIter = std::slice::Iter<'a, FieldRef>;
314
315 fn into_iter(self) -> Self::IntoIter {
316 self.0.iter()
317 }
318}
319
320#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
322#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
323#[cfg_attr(feature = "serde", serde(transparent))]
324pub struct UnionFields(Arc<[(i8, FieldRef)]>);
325
326impl std::fmt::Debug for UnionFields {
327 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
328 self.0.as_ref().fmt(f)
329 }
330}
331
332impl std::ops::Index<usize> for UnionFields {
341 type Output = (i8, FieldRef);
342
343 fn index(&self, index: usize) -> &Self::Output {
344 &self.0[index]
345 }
346}
347
348impl UnionFields {
349 pub fn empty() -> Self {
351 Self(Arc::from([]))
352 }
353
354 pub fn try_new<F, T>(type_ids: T, fields: F) -> Result<Self, ArrowError>
391 where
392 F: IntoIterator,
393 F::Item: Into<FieldRef>,
394 T: IntoIterator<Item = i8>,
395 {
396 let mut type_ids_iter = type_ids.into_iter();
397 let mut fields_iter = fields.into_iter().map(Into::into);
398
399 let mut seen_type_ids = 0u128;
400
401 let mut out = Vec::new();
402
403 loop {
404 match (type_ids_iter.next(), fields_iter.next()) {
405 (None, None) => return Ok(Self(out.into())),
406 (Some(type_id), Some(field)) => {
407 if type_id < 0 {
409 return Err(ArrowError::InvalidArgumentError(format!(
410 "type ids must be non-negative: {type_id}"
411 )));
412 }
413
414 let mask = 1_u128 << type_id;
416 if (seen_type_ids & mask) != 0 {
417 return Err(ArrowError::InvalidArgumentError(format!(
418 "duplicate type id: {type_id}"
419 )));
420 }
421
422 seen_type_ids |= mask;
423
424 out.push((type_id, field));
425 }
426 (None, Some(_)) => {
427 return Err(ArrowError::InvalidArgumentError(
428 "fields iterator has more elements than type_ids iterator".to_string(),
429 ));
430 }
431 (Some(_), None) => {
432 return Err(ArrowError::InvalidArgumentError(
433 "type_ids iterator has more elements than fields iterator".to_string(),
434 ));
435 }
436 }
437 }
438 }
439
440 pub fn from_fields<F>(fields: F) -> Self
468 where
469 F: IntoIterator,
470 F::Item: Into<FieldRef>,
471 {
472 fields
473 .into_iter()
474 .enumerate()
475 .map(|(i, field)| {
476 let id = i8::try_from(i).expect("UnionFields cannot contain more than 128 fields");
477
478 (id, field.into())
479 })
480 .collect()
481 }
482
483 pub fn try_from_fields<F>(fields: F) -> Result<Self, ArrowError>
518 where
519 F: IntoIterator,
520 F::Item: Into<FieldRef>,
521 {
522 let mut out = Vec::with_capacity(i8::MAX as usize + 1);
523
524 for (i, field) in fields.into_iter().enumerate() {
525 let id = i8::try_from(i).map_err(|_| {
526 ArrowError::InvalidArgumentError(
527 "UnionFields cannot contain more than 128 fields".into(),
528 )
529 })?;
530
531 out.push((id, field.into()));
532 }
533
534 Ok(Self(out.into()))
535 }
536
537 pub fn size(&self) -> usize {
539 self.iter()
540 .map(|(_, field)| field.size() + std::mem::size_of::<(i8, FieldRef)>())
541 .sum()
542 }
543
544 pub fn len(&self) -> usize {
546 self.0.len()
547 }
548
549 pub fn is_empty(&self) -> bool {
551 self.0.is_empty()
552 }
553
554 pub fn iter(&self) -> impl Iterator<Item = (i8, &FieldRef)> + '_ {
556 self.0.iter().map(|(id, f)| (*id, f))
557 }
558
559 pub fn get(&self, index: usize) -> Option<&(i8, FieldRef)> {
581 self.0.get(index)
582 }
583
584 pub fn find_by_type_id(&self, type_id: i8) -> Option<(i8, &FieldRef)> {
587 self.iter().find(|&(i, _)| i == type_id)
588 }
589
590 pub fn find_by_field(&self, field: &Field) -> Option<(i8, &FieldRef)> {
593 self.iter().find(|&(_, f)| f.as_ref() == field)
594 }
595
596 pub(crate) fn try_merge(&mut self, other: &Self) -> Result<(), ArrowError> {
600 let mut output: Vec<_> = self.iter().map(|(id, f)| (id, f.clone())).collect();
602 for (field_type_id, from_field) in other.iter() {
603 let mut is_new_field = true;
604 for (self_type_id, self_field) in output.iter_mut() {
605 if from_field == self_field {
606 if *self_type_id != field_type_id {
609 return Err(ArrowError::SchemaError(format!(
610 "Fail to merge schema field '{}' because the self_type_id = {} does not equal field_type_id = {}",
611 self_field.name(),
612 self_type_id,
613 field_type_id
614 )));
615 }
616
617 is_new_field = false;
618 break;
619 }
620 }
621
622 if is_new_field {
623 output.push((field_type_id, from_field.clone()))
624 }
625 }
626 *self = output.into_iter().collect();
627 Ok(())
628 }
629}
630
631impl FromIterator<(i8, FieldRef)> for UnionFields {
632 fn from_iter<T: IntoIterator<Item = (i8, FieldRef)>>(iter: T) -> Self {
633 Self(iter.into_iter().collect())
634 }
635}
636
// Unit tests for `Fields` leaf filtering and the `UnionFields` constructors.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::UnionMode;

    // Exercises `filter_leaves` / `try_filter_leaves` across every nested type they
    // descend into. Depth-first leaf numbering of `fields` below:
    //   0: a | 1-2: floats.{a,b} | 3: b | 4: c (dictionary, the field itself is the leaf)
    //   5-6: d's dictionary values {a,b} | 7-8: e's list items {a,b} | 9: f's list item
    //   10-11: g's map {keys,values} | 12-13: h's union {field1,field3}
    //   14-15: i's run-end-encoded values {a,b}
    #[test]
    fn test_filter() {
        let floats = Fields::from(vec![
            Field::new("a", DataType::Float32, false),
            Field::new("b", DataType::Float32, false),
        ]);
        let fields = Fields::from(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("floats", DataType::Struct(floats.clone()), true),
            Field::new("b", DataType::Int16, true),
            Field::new(
                "c",
                DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                false,
            ),
            Field::new(
                "d",
                DataType::Dictionary(
                    Box::new(DataType::Int32),
                    Box::new(DataType::Struct(floats.clone())),
                ),
                false,
            ),
            Field::new_list(
                "e",
                Field::new("floats", DataType::Struct(floats.clone()), true),
                true,
            ),
            Field::new_fixed_size_list(
                "f",
                Field::new_list_field(DataType::Int32, false),
                3,
                false,
            ),
            Field::new_map(
                "g",
                "entries",
                Field::new("keys", DataType::LargeUtf8, false),
                Field::new("values", DataType::Int32, true),
                false,
                false,
            ),
            Field::new(
                "h",
                DataType::Union(
                    UnionFields::try_new(
                        vec![1, 3],
                        vec![
                            Field::new("field1", DataType::UInt8, false),
                            Field::new("field3", DataType::Utf8, false),
                        ],
                    )
                    .unwrap(),
                    UnionMode::Dense,
                ),
                true,
            ),
            Field::new(
                "i",
                DataType::RunEndEncoded(
                    Arc::new(Field::new("run_ends", DataType::Int32, false)),
                    Arc::new(Field::new("values", DataType::Struct(floats.clone()), true)),
                ),
                false,
            ),
        ]);

        // A struct containing only the first child (`a`) of `floats`.
        let floats_a = DataType::Struct(vec![floats[0].clone()].into());

        // Leaves 0 and 1: the top-level `a`, and `floats` narrowed to its `a` child.
        let r = fields.filter_leaves(|idx, _| idx == 0 || idx == 1);
        assert_eq!(r.len(), 2);
        assert_eq!(r[0], fields[0]);
        assert_eq!(r[1].data_type(), &floats_a);

        // Every leaf named "a": top-level `a`, and the `a` nested in floats, d, e and i.
        let r = fields.filter_leaves(|_, f| f.name() == "a");
        assert_eq!(r.len(), 5);
        assert_eq!(r[0], fields[0]);
        assert_eq!(r[1].data_type(), &floats_a);
        assert_eq!(
            r[2].data_type(),
            &DataType::Dictionary(Box::new(DataType::Int32), Box::new(floats_a.clone()))
        );
        assert_eq!(
            r[3].as_ref(),
            &Field::new_list("e", Field::new("floats", floats_a.clone(), true), true)
        );
        assert_eq!(
            r[4].as_ref(),
            &Field::new(
                "i",
                DataType::RunEndEncoded(
                    Arc::new(Field::new("run_ends", DataType::Int32, false)),
                    Arc::new(Field::new("values", floats_a.clone(), true)),
                ),
                false,
            )
        );

        // Container fields named "floats" are never passed to the filter, so nothing is kept.
        let r = fields.filter_leaves(|_, f| f.name() == "floats");
        assert_eq!(r.len(), 0);

        // Keeping the only leaf of the fixed-size list keeps `f` unchanged.
        let r = fields.filter_leaves(|idx, _| idx == 9);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0], fields[6]);

        // Keeping both map leaves keeps `g` unchanged.
        let r = fields.filter_leaves(|idx, _| idx == 10 || idx == 11);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0], fields[7]);

        // Keeping only `field1` narrows the union while preserving its type id (1).
        let union = DataType::Union(
            UnionFields::try_new(vec![1], vec![Field::new("field1", DataType::UInt8, false)])
                .unwrap(),
            UnionMode::Dense,
        );

        let r = fields.filter_leaves(|idx, _| idx == 12);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].data_type(), &union);

        // Keeping both run-end-encoded value leaves keeps `i` unchanged.
        let r = fields.filter_leaves(|idx, _| idx == 14 || idx == 15);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0], fields[9]);

        // An error returned by the filter is propagated by `try_filter_leaves`.
        let r = fields.try_filter_leaves(|_, _| Err(ArrowError::SchemaError("error".to_string())));
        assert!(r.is_err());
    }

    // Distinct, non-negative ids are accepted and kept in the given order.
    #[test]
    fn test_union_fields_try_new_valid() {
        let res = UnionFields::try_new(
            vec![1, 6, 7],
            vec![
                Field::new("f1", DataType::UInt8, false),
                Field::new("f6", DataType::Utf8, false),
                Field::new("f7", DataType::Int32, true),
            ],
        );
        assert!(res.is_ok());
        let union_fields = res.unwrap();
        assert_eq!(union_fields.len(), 3);
        assert_eq!(
            union_fields.iter().map(|(id, _)| id).collect::<Vec<_>>(),
            vec![1, 6, 7]
        );
    }

    // Two empty iterators produce an empty `UnionFields`.
    #[test]
    fn test_union_fields_try_new_empty() {
        let res = UnionFields::try_new(Vec::<i8>::new(), Vec::<Field>::new());
        assert!(res.is_ok());
        assert!(res.unwrap().is_empty());
    }

    // A repeated type id is rejected.
    #[test]
    fn test_union_fields_try_new_duplicate_type_id() {
        let res = UnionFields::try_new(
            vec![1, 1],
            vec![
                Field::new("f1", DataType::UInt8, false),
                Field::new("f2", DataType::Utf8, false),
            ],
        );
        assert!(res.is_err());
        assert!(
            res.unwrap_err()
                .to_string()
                .contains("duplicate type id: 1")
        );
    }

    // Identical fields are allowed as long as their type ids differ.
    #[test]
    fn test_union_fields_try_new_duplicate_field() {
        let field = Field::new("field", DataType::UInt8, false);
        let res = UnionFields::try_new(vec![1, 2], vec![field.clone(), field]);
        assert!(res.is_ok());
    }

    // More type ids than fields is an error.
    #[test]
    fn test_union_fields_try_new_more_type_ids() {
        let res = UnionFields::try_new(
            vec![1, 2, 3],
            vec![
                Field::new("f1", DataType::UInt8, false),
                Field::new("f2", DataType::Utf8, false),
            ],
        );
        assert!(res.is_err());
        assert!(
            res.unwrap_err()
                .to_string()
                .contains("type_ids iterator has more elements")
        );
    }

    // More fields than type ids is an error.
    #[test]
    fn test_union_fields_try_new_more_fields() {
        let res = UnionFields::try_new(
            vec![1, 2],
            vec![
                Field::new("f1", DataType::UInt8, false),
                Field::new("f2", DataType::Utf8, false),
                Field::new("f3", DataType::Int32, true),
            ],
        );
        assert!(res.is_err());
        assert!(
            res.unwrap_err()
                .to_string()
                .contains("fields iterator has more elements")
        );
    }

    // Negative type ids are rejected (the first one, -128, triggers the error).
    #[test]
    fn test_union_fields_try_new_negative_type_ids() {
        let res = UnionFields::try_new(
            vec![-128, -1, 0, 127],
            vec![
                Field::new("field_min", DataType::UInt8, false),
                Field::new("field_neg", DataType::Utf8, false),
                Field::new("field_zero", DataType::Int32, true),
                Field::new("field_max", DataType::Boolean, false),
            ],
        );
        assert!(res.is_err());
        assert!(
            res.unwrap_err()
                .to_string()
                .contains("type ids must be non-negative")
        )
    }

    // Nested data types are accepted as union children.
    #[test]
    fn test_union_fields_try_new_complex_types() {
        let res = UnionFields::try_new(
            vec![0, 1, 2],
            vec![
                Field::new(
                    "struct_field",
                    DataType::Struct(Fields::from(vec![
                        Field::new("a", DataType::Int32, false),
                        Field::new("b", DataType::Utf8, true),
                    ])),
                    false,
                ),
                Field::new_list(
                    "list_field",
                    Field::new("item", DataType::Float64, true),
                    true,
                ),
                Field::new(
                    "dict_field",
                    DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                    false,
                ),
            ],
        );
        assert!(res.is_ok());
        assert_eq!(res.unwrap().len(), 3);
    }

    // A single child keeps its explicit (non-zero) type id.
    #[test]
    fn test_union_fields_try_new_single_field() {
        let res = UnionFields::try_new(
            vec![42],
            vec![Field::new("only_field", DataType::Int64, false)],
        );
        assert!(res.is_ok());
        let union_fields = res.unwrap();
        assert_eq!(union_fields.len(), 1);
        assert_eq!(union_fields.iter().next().unwrap().0, 42);
    }

    // No fields yields an empty `UnionFields`.
    #[test]
    fn test_union_fields_try_from_fields_empty() {
        let res = UnionFields::try_from_fields(Vec::<Field>::new());
        assert!(res.is_ok());
        assert!(res.unwrap().is_empty());
    }

    // The first field is assigned type id 0.
    #[test]
    fn test_union_fields_try_from_fields_single() {
        let res = UnionFields::try_from_fields(vec![Field::new("only", DataType::Int64, false)]);
        assert!(res.is_ok());
        let union_fields = res.unwrap();
        assert_eq!(union_fields.len(), 1);
        assert_eq!(union_fields.iter().next().unwrap().0, 0);
    }

    // Far more than 128 fields is rejected with the documented message.
    #[test]
    fn test_union_fields_try_from_fields_too_many() {
        let many_fields: Vec<_> = (0..200)
            .map(|i| Field::new(format!("field{}", i), DataType::Int32, false))
            .collect();
        let res = UnionFields::try_from_fields(many_fields);
        assert!(res.is_err());
        assert!(
            res.unwrap_err()
                .to_string()
                .contains("UnionFields cannot contain more than 128 fields")
        );
    }

    // Exactly 128 fields is the maximum, using type ids 0..=127.
    #[test]
    fn test_union_fields_try_from_fields_max_valid() {
        let fields: Vec<_> = (0..=i8::MAX)
            .map(|i| Field::new(format!("field{}", i), DataType::Int32, false))
            .collect();
        let res = UnionFields::try_from_fields(fields);
        assert!(res.is_ok());
        let union_fields = res.unwrap();
        assert_eq!(union_fields.len(), 128);
        assert_eq!(union_fields.iter().map(|(id, _)| id).min().unwrap(), 0);
        assert_eq!(union_fields.iter().map(|(id, _)| id).max().unwrap(), 127);
    }

    // One field past the maximum (129) is rejected.
    #[test]
    fn test_union_fields_try_from_fields_over_max() {
        let fields: Vec<_> = (0..129)
            .map(|i| Field::new(format!("field{}", i), DataType::Int32, false))
            .collect();
        let res = UnionFields::try_from_fields(fields);
        assert!(res.is_err());
    }

    // Nested data types are accepted by `try_from_fields` as well.
    #[test]
    fn test_union_fields_try_from_fields_complex_types() {
        let res = UnionFields::try_from_fields(vec![
            Field::new(
                "struct_field",
                DataType::Struct(Fields::from(vec![
                    Field::new("a", DataType::Int32, false),
                    Field::new("b", DataType::Utf8, true),
                ])),
                false,
            ),
            Field::new_list(
                "list_field",
                Field::new("item", DataType::Float64, true),
                true,
            ),
            Field::new(
                "dict_field",
                DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                false,
            ),
        ]);
        assert!(res.is_ok());
        assert_eq!(res.unwrap().len(), 3);
    }
}