Skip to main content

arrs/
projection.rs

//! Resolve `--columns` / `--exclude-columns` against an arrow schema.
2
3use std::collections::HashSet;
4
5use arrow_schema::SchemaRef;
6
7use crate::Result;
8use crate::error::Error;
9
10/// Resolve the requested projection against `schema`. Returns `None` when neither
11/// flag was provided (callers interpret as "all columns, no filtering").
12///
13/// `exclude` takes precedence over `include`: if both are set, the result is
14/// `<all columns> \ exclude` (the include list is ignored, matching the spec).
15pub fn resolve(
16    schema: &SchemaRef,
17    include: Option<&[String]>,
18    exclude: Option<&[String]>,
19) -> Result<Option<Vec<String>>> {
20    if include.is_none() && exclude.is_none() {
21        return Ok(None);
22    }
23
24    let all: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
25    let available_set: HashSet<&str> = all.iter().copied().collect();
26
27    if let Some(excl) = exclude {
28        check_dupes(excl)?;
29        check_present(excl, &available_set, &all)?;
30        let excl_set: HashSet<&str> = excl.iter().map(String::as_str).collect();
31        let remaining: Vec<String> = all
32            .iter()
33            .filter(|c| !excl_set.contains(*c))
34            .map(|c| (*c).to_string())
35            .collect();
36        return Ok(Some(remaining));
37    }
38
39    // Only --columns is set.
40    let incl = include.expect("checked above");
41    check_dupes(incl)?;
42    check_present(incl, &available_set, &all)?;
43    Ok(Some(incl.to_vec()))
44}
45
46fn check_dupes(cols: &[String]) -> Result<()> {
47    let mut seen = HashSet::new();
48    for c in cols {
49        if !seen.insert(c.as_str()) {
50            return Err(Error::DuplicateColumn(c.clone()));
51        }
52    }
53    Ok(())
54}
55
56fn check_present(cols: &[String], available: &HashSet<&str>, all_ordered: &[&str]) -> Result<()> {
57    for c in cols {
58        if !available.contains(c.as_str()) {
59            return Err(Error::UnknownColumn {
60                name: c.clone(),
61                available: all_ordered.join(", "),
62            });
63        }
64    }
65    Ok(())
66}
67
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::{DataType, Field, Schema};

    use super::*;

    /// Fixture schema with three nullable columns, in order: `a`, `b`, `c`.
    fn schema() -> SchemaRef {
        let fields = vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, true),
            Field::new("c", DataType::Float64, true),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Turn string literals into the owned column list `resolve` expects.
    fn names(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|s| (*s).to_string()).collect()
    }

    /// With neither flag given, `resolve` reports "no projection".
    #[test]
    fn none_means_all() {
        let projection = resolve(&schema(), None, None).unwrap();
        assert!(projection.is_none());
    }

    /// An include list comes back in the order the user wrote it.
    #[test]
    fn include_preserves_user_order() {
        let wanted = names(&["c", "a"]);
        let got = resolve(&schema(), Some(wanted.as_slice()), None)
            .unwrap()
            .unwrap();
        assert_eq!(got, names(&["c", "a"]));
    }

    /// Excluding a column leaves the rest in schema order.
    #[test]
    fn exclude_keeps_schema_order() {
        let unwanted = names(&["b"]);
        let got = resolve(&schema(), None, Some(unwanted.as_slice()))
            .unwrap()
            .unwrap();
        assert_eq!(got, names(&["a", "c"]));
    }

    /// When both lists are given, only the exclude list shapes the result.
    #[test]
    fn exclude_takes_precedence() {
        let wanted = names(&["a", "b"]);
        let unwanted = names(&["b"]);
        let got = resolve(&schema(), Some(wanted.as_slice()), Some(unwanted.as_slice()))
            .unwrap()
            .unwrap();
        assert_eq!(got, names(&["a", "c"]));
    }

    /// A name that is not in the schema is rejected.
    #[test]
    fn unknown_column_errors() {
        let wanted = names(&["zzz"]);
        let outcome = resolve(&schema(), Some(wanted.as_slice()), None);
        assert!(matches!(outcome, Err(Error::UnknownColumn { .. })));
    }

    /// Listing the same column twice is rejected.
    #[test]
    fn duplicate_errors() {
        let wanted = names(&["a", "a"]);
        let outcome = resolve(&schema(), Some(wanted.as_slice()), None);
        assert!(matches!(outcome, Err(Error::DuplicateColumn(_))));
    }
}