1use std::collections::HashSet;
4
5use arrow_schema::SchemaRef;
6
7use crate::Result;
8use crate::error::Error;
9
10pub fn resolve(
16 schema: &SchemaRef,
17 include: Option<&[String]>,
18 exclude: Option<&[String]>,
19) -> Result<Option<Vec<String>>> {
20 if include.is_none() && exclude.is_none() {
21 return Ok(None);
22 }
23
24 let all: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
25 let available_set: HashSet<&str> = all.iter().copied().collect();
26
27 if let Some(excl) = exclude {
28 check_dupes(excl)?;
29 check_present(excl, &available_set, &all)?;
30 let excl_set: HashSet<&str> = excl.iter().map(String::as_str).collect();
31 let remaining: Vec<String> = all
32 .iter()
33 .filter(|c| !excl_set.contains(*c))
34 .map(|c| (*c).to_string())
35 .collect();
36 return Ok(Some(remaining));
37 }
38
39 let incl = include.expect("checked above");
41 check_dupes(incl)?;
42 check_present(incl, &available_set, &all)?;
43 Ok(Some(incl.to_vec()))
44}
45
46fn check_dupes(cols: &[String]) -> Result<()> {
47 let mut seen = HashSet::new();
48 for c in cols {
49 if !seen.insert(c.as_str()) {
50 return Err(Error::DuplicateColumn(c.clone()));
51 }
52 }
53 Ok(())
54}
55
56fn check_present(cols: &[String], available: &HashSet<&str>, all_ordered: &[&str]) -> Result<()> {
57 for c in cols {
58 if !available.contains(c.as_str()) {
59 return Err(Error::UnknownColumn {
60 name: c.clone(),
61 available: all_ordered.join(", "),
62 });
63 }
64 }
65 Ok(())
66}
67
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::{DataType, Field, Schema};

    use super::*;

    /// Fixture schema with nullable columns `a` (Int32), `b` (Utf8), `c` (Float64).
    fn schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, true),
            Field::new("c", DataType::Float64, true),
        ]))
    }

    /// Builds an owned column list from string literals.
    fn cols(names: &[&str]) -> Vec<String> {
        names.iter().map(|n| (*n).to_string()).collect()
    }

    #[test]
    fn none_means_all() {
        let s = schema();
        assert!(resolve(&s, None, None).unwrap().is_none());
    }

    #[test]
    fn include_preserves_user_order() {
        let s = schema();
        let incl = cols(&["c", "a"]);
        let got = resolve(&s, Some(&incl), None).unwrap().unwrap();
        assert_eq!(got, cols(&["c", "a"]));
    }

    #[test]
    fn exclude_keeps_schema_order() {
        let s = schema();
        let excl = cols(&["b"]);
        let got = resolve(&s, None, Some(&excl)).unwrap().unwrap();
        assert_eq!(got, cols(&["a", "c"]));
    }

    #[test]
    fn exclude_takes_precedence() {
        let s = schema();
        let incl = cols(&["a", "b"]);
        let excl = cols(&["b"]);
        let got = resolve(&s, Some(&incl), Some(&excl)).unwrap().unwrap();
        assert_eq!(got, cols(&["a", "c"]));
    }

    #[test]
    fn exclude_all_columns_yields_empty_projection() {
        let s = schema();
        let excl = cols(&["a", "b", "c"]);
        let got = resolve(&s, None, Some(&excl)).unwrap().unwrap();
        assert!(got.is_empty());
    }

    #[test]
    fn unknown_column_errors() {
        let s = schema();
        let incl = cols(&["zzz"]);
        assert!(matches!(
            resolve(&s, Some(&incl), None),
            Err(Error::UnknownColumn { .. })
        ));
    }

    #[test]
    fn exclude_unknown_column_errors() {
        let s = schema();
        let excl = cols(&["zzz"]);
        assert!(matches!(
            resolve(&s, None, Some(&excl)),
            Err(Error::UnknownColumn { .. })
        ));
    }

    #[test]
    fn unknown_column_error_reports_name_and_schema_columns() {
        let s = schema();
        let incl = cols(&["a", "zzz"]);
        match resolve(&s, Some(&incl), None) {
            Err(Error::UnknownColumn { name, available }) => {
                assert_eq!(name, "zzz");
                assert_eq!(available, "a, b, c");
            }
            other => panic!("expected UnknownColumn, got {:?}", other),
        }
    }

    #[test]
    fn duplicate_errors() {
        let s = schema();
        let incl = cols(&["a", "a"]);
        assert!(matches!(
            resolve(&s, Some(&incl), None),
            Err(Error::DuplicateColumn(_))
        ));
    }

    #[test]
    fn exclude_duplicate_errors() {
        let s = schema();
        let excl = cols(&["b", "b"]);
        assert!(matches!(
            resolve(&s, None, Some(&excl)),
            Err(Error::DuplicateColumn(_))
        ));
    }
}