datafusion_table_providers/util/
column_reference.rs1use itertools::Itertools;
2use snafu::prelude::*;
3use std::{fmt::Display, hash::Hash};
4
5#[derive(Debug, Snafu)]
6pub enum Error {
7 #[snafu(display(r#"The column reference "{column_ref}" is missing a closing parenthensis."#))]
8 MissingClosingParenthesisInColumnReference { column_ref: String },
9}
10
11#[derive(Debug, Clone, Eq)]
12pub struct ColumnReference {
13 columns: Vec<String>,
14}
15
16impl ColumnReference {
17 #[must_use]
18 pub fn new(columns: Vec<String>) -> Self {
19 Self {
20 columns: columns.into_iter().sorted().collect(),
21 }
22 }
23
24 #[must_use]
25 pub fn empty() -> Self {
26 Self { columns: vec![] }
27 }
28
29 pub fn iter(&self) -> impl Iterator<Item = &str> {
30 self.columns.iter().map(String::as_str)
31 }
32
33 #[must_use]
34 pub fn is_empty(&self) -> bool {
35 self.columns.is_empty()
36 }
37
38 #[must_use]
39 pub fn contains(&self, column: &String) -> bool {
40 self.columns.contains(column)
41 }
42}
43
44impl Default for ColumnReference {
45 fn default() -> Self {
46 Self::empty()
47 }
48}
49
50impl PartialEq for ColumnReference {
51 fn eq(&self, other: &Self) -> bool {
52 if self.columns.len() != other.columns.len() {
53 return false;
54 }
55
56 self.columns
57 .iter()
58 .zip(other.columns.iter())
59 .all(|(a, b)| a == b)
60 }
61}
62
63impl Hash for ColumnReference {
64 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
65 self.columns.hash(state);
66 }
67}
68
69impl TryFrom<&str> for ColumnReference {
87 type Error = Error;
88
89 fn try_from(columns: &str) -> Result<Self, Self::Error> {
90 if columns.starts_with('(') {
92 let end =
94 columns
95 .find(')')
96 .context(MissingClosingParenthesisInColumnReferenceSnafu {
97 column_ref: columns.to_string(),
98 })?;
99 Ok(Self {
100 columns: columns[1..end]
101 .split(',')
102 .map(str::trim)
103 .map(String::from)
104 .sorted()
105 .collect::<Vec<String>>(),
106 })
107 } else {
108 Ok(Self {
110 columns: vec![columns.to_string()],
111 })
112 }
113 }
114}
115
116impl Display for ColumnReference {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 if self.columns.len() == 1 {
119 write!(f, "{}", self.columns[0])
120 } else {
121 write!(f, "({})", self.columns.join(", "))
122 }
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_get_column_ref() {
        let column_ref = ColumnReference::try_from("foo").expect("valid columns");
        assert_eq!(column_ref.iter().collect::<Vec<_>>(), vec!["foo"]);

        let column_ref = ColumnReference::try_from("(foo, bar)").expect("valid columns");
        assert_eq!(column_ref.iter().collect::<Vec<_>>(), vec!["bar", "foo"]);

        let column_ref = ColumnReference::try_from("(foo,bar)").expect("valid columns");
        assert_eq!(column_ref.iter().collect::<Vec<_>>(), vec!["bar", "foo"]);
    }

    #[test]
    fn test_missing_closing_parenthesis() {
        let err = ColumnReference::try_from("(foo, bar").expect_err("unclosed list must fail");
        assert!(matches!(
            err,
            Error::MissingClosingParenthesisInColumnReference { ref column_ref }
                if column_ref == "(foo, bar"
        ));
    }

    #[test]
    fn test_column_order_does_not_affect_equality_or_hash() {
        let a = ColumnReference::try_from("(foo, bar)").expect("valid columns");
        let b = ColumnReference::try_from("(bar,foo)").expect("valid columns");
        let c = ColumnReference::new(vec!["foo".to_string(), "bar".to_string()]);
        assert_eq!(a, b);
        assert_eq!(a, c);

        let set: HashSet<ColumnReference> = vec![a, b, c].into_iter().collect();
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn test_display() {
        let single = ColumnReference::try_from("foo").expect("valid columns");
        assert_eq!(single.to_string(), "foo");

        let multi = ColumnReference::try_from("(foo, bar)").expect("valid columns");
        assert_eq!(multi.to_string(), "(bar, foo)");

        // Display output of a multi-column reference parses back to an equal reference.
        let reparsed =
            ColumnReference::try_from(multi.to_string().as_str()).expect("valid columns");
        assert_eq!(reparsed, multi);
    }

    #[test]
    fn test_contains_and_empty() {
        let column_ref = ColumnReference::try_from("(foo, bar)").expect("valid columns");
        assert!(column_ref.contains(&"foo".to_string()));
        assert!(!column_ref.contains(&"baz".to_string()));
        assert!(!column_ref.is_empty());

        assert!(ColumnReference::default().is_empty());
        assert_eq!(ColumnReference::empty(), ColumnReference::default());
    }
}