// datafusion_table_providers/util/column_reference.rs

1use itertools::Itertools;
2use snafu::prelude::*;
3use std::{fmt::Display, hash::Hash};
4
5#[derive(Debug, Snafu)]
6pub enum Error {
7    #[snafu(display(r#"The column reference "{column_ref}" is missing a closing parenthensis."#))]
8    MissingClosingParenthesisInColumnReference { column_ref: String },
9}
10
/// A set of column names identifying either a single column or a compound key
/// (for example an index or primary key).
///
/// The constructors in this module store the names in sorted order. As a result, two
/// references to the same columns compare equal and hash identically, regardless of
/// the order in which the columns were listed.
// The derived `PartialEq`/`Hash` compare and hash the sorted `columns` vector. This is
// exactly what the previous hand-written impls did, and it keeps `Eq` and `Hash`
// consistent.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct ColumnReference {
    /// Column names, kept sorted by the constructors.
    columns: Vec<String>,
}

impl ColumnReference {
    /// Creates a reference to `columns`.
    ///
    /// The names are sorted, so the result does not depend on the order in which
    /// they were supplied.
    #[must_use]
    pub fn new(mut columns: Vec<String>) -> Self {
        // Sorting is what makes equality and hashing order-insensitive. Equal `String`s
        // are indistinguishable, so an unstable sort gives the same result as a stable one.
        columns.sort_unstable();
        Self { columns }
    }

    /// Creates a reference that contains no columns.
    #[must_use]
    pub fn empty() -> Self {
        Self {
            columns: Vec::new(),
        }
    }

    /// Iterates over the column names in their stored (sorted) order.
    pub fn iter(&self) -> impl Iterator<Item = &str> {
        self.columns.iter().map(String::as_str)
    }

    /// Returns `true` if the reference contains no columns.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.columns.is_empty()
    }

    /// Returns `true` if `column` exactly matches one of the referenced column names.
    ///
    /// Takes `&str`. Callers that pass a `&String` still compile through deref coercion.
    #[must_use]
    pub fn contains(&self, column: &str) -> bool {
        self.columns.iter().any(|c| c == column)
    }
}
68
69/// Parses column references from a string into a vector of each individual column reference.
70///
71/// "foo" -> vec!["foo"]
72/// "(foo, bar)" -> vec!["foo", "bar"]
73/// "(foo, bar" -> Err(The column reference "(foo,bar" is missing a closing parenthensis.)
74///
75/// # Examples
76///
77/// ```rust,ignore
78/// use datafusion_table_providers::util::column_reference::ColumnReference;
79///
80/// let column_ref = ColumnReference::try_from("foo").expect("valid columns");
81/// assert_eq!(column_ref.iter().collect::<Vec<_>>(), vec!["foo"]);
82///
83/// let column_ref = ColumnReference::try_from("(foo, bar)").expect("valid columns");
84/// assert_eq!(column_ref.iter().collect::<Vec<_>>(), vec!["foo", "bar"]);
85/// ```
86impl TryFrom<&str> for ColumnReference {
87    type Error = Error;
88
89    fn try_from(columns: &str) -> Result<Self, Self::Error> {
90        // The index/primary key can be either a single column or a compound index
91        if columns.starts_with('(') {
92            // Compound index
93            let end =
94                columns
95                    .find(')')
96                    .context(MissingClosingParenthesisInColumnReferenceSnafu {
97                        column_ref: columns.to_string(),
98                    })?;
99            Ok(Self {
100                columns: columns[1..end]
101                    .split(',')
102                    .map(str::trim)
103                    .map(String::from)
104                    .sorted()
105                    .collect::<Vec<String>>(),
106            })
107        } else {
108            // Single column reference
109            Ok(Self {
110                columns: vec![columns.to_string()],
111            })
112        }
113    }
114}
115
116impl Display for ColumnReference {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        if self.columns.len() == 1 {
119            write!(f, "{}", self.columns[0])
120        } else {
121            write!(f, "({})", self.columns.join(", "))
122        }
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_column_ref() {
        let column_ref = ColumnReference::try_from("foo").expect("valid columns");
        assert_eq!(column_ref.iter().collect::<Vec<_>>(), vec!["foo"]);

        let column_ref = ColumnReference::try_from("(foo, bar)").expect("valid columns");
        assert_eq!(column_ref.iter().collect::<Vec<_>>(), vec!["bar", "foo"]);

        let column_ref = ColumnReference::try_from("(foo,bar)").expect("valid columns");
        assert_eq!(column_ref.iter().collect::<Vec<_>>(), vec!["bar", "foo"]);
    }

    #[test]
    fn test_missing_closing_parenthesis_is_error() {
        assert!(ColumnReference::try_from("(foo, bar").is_err());
    }

    #[test]
    fn test_column_order_does_not_affect_equality() {
        let parsed = ColumnReference::try_from("(foo, bar)").expect("valid columns");
        let constructed = ColumnReference::new(vec!["bar".to_string(), "foo".to_string()]);
        assert_eq!(parsed, constructed);
        assert!(parsed.contains(&"foo".to_string()));
    }

    #[test]
    fn test_display() {
        let single = ColumnReference::try_from("foo").expect("valid columns");
        assert_eq!(single.to_string(), "foo");

        let compound = ColumnReference::try_from("(foo, bar)").expect("valid columns");
        assert_eq!(compound.to_string(), "(bar, foo)");

        assert_eq!(ColumnReference::empty().to_string(), "()");
    }
}