1use std::collections::HashMap;
8
9use serde::Serialize;
10
11use crate::config::ColumnFilter;
12use crate::errors::AppError;
13
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
15#[serde(rename_all = "lowercase")]
16pub enum LogicalType {
17 Bool,
18 Int,
19 Float,
20 Utf8,
21 Temporal,
24 Other,
27}
28
29impl LogicalType {
30 pub fn needs_cast(self) -> bool {
33 matches!(self, LogicalType::Temporal)
34 }
35}
36
37#[derive(Debug, Clone, Serialize)]
38pub struct ColumnInfo {
39 pub name: String,
40 pub logical: LogicalType,
41 pub sql_type: String,
44 pub nullable: bool,
45}
46
47#[derive(Debug, Clone)]
48pub struct DatasetSchema {
49 pub name: String,
50 pub columns: Vec<ColumnInfo>,
51 pub by_name: HashMap<String, usize>,
53 pub predicate_filter: ColumnFilter,
55 pub projection_filter: ColumnFilter,
58}
59
60impl DatasetSchema {
61 pub fn new(name: impl Into<String>, columns: Vec<ColumnInfo>) -> Self {
62 let by_name = columns
63 .iter()
64 .enumerate()
65 .map(|(i, c)| (c.name.to_lowercase(), i))
66 .collect();
67 Self {
68 name: name.into(),
69 columns,
70 by_name,
71 predicate_filter: ColumnFilter::default(),
72 projection_filter: ColumnFilter::default(),
73 }
74 }
75
76 pub fn with_filters(
81 mut self,
82 predicate_filter: ColumnFilter,
83 projection_filter: ColumnFilter,
84 ) -> Result<Self, AppError> {
85 for (ctx, filter) in [
86 ("predicate_filter", &predicate_filter),
87 ("projection_filter", &projection_filter),
88 ] {
89 filter.validate(&self.name, ctx)?;
90 for col in filter.listed() {
91 if !self.by_name.contains_key(&col.to_lowercase()) {
92 return Err(AppError::InvalidValue(format!(
93 "dataset '{}': {ctx} references unknown column '{col}'",
94 self.name
95 )));
96 }
97 }
98 }
99 self.predicate_filter = predicate_filter;
100 self.projection_filter = projection_filter;
101 Ok(self)
102 }
103
104 pub fn has_column_filters(&self) -> bool {
106 self.predicate_filter.is_active() || self.projection_filter.is_active()
107 }
108
109 pub fn is_visible(&self, name: &str) -> bool {
112 self.projection_filter.allows(name)
113 }
114
115 pub fn visible_columns(&self) -> Vec<&ColumnInfo> {
117 self.columns
118 .iter()
119 .filter(|c| self.projection_filter.allows(&c.name))
120 .collect()
121 }
122
123 pub fn find(&self, name: &str) -> Result<&ColumnInfo, AppError> {
125 self.by_name
126 .get(&name.to_lowercase())
127 .map(|&i| &self.columns[i])
128 .ok_or_else(|| AppError::UnknownColumn(name.into()))
129 }
130
131 pub fn find_visible(&self, name: &str) -> Result<&ColumnInfo, AppError> {
135 let col = self.find(name)?;
136 if self.projection_filter.allows(&col.name) {
137 Ok(col)
138 } else {
139 Err(AppError::UnknownColumn(name.into()))
140 }
141 }
142
143 pub fn find_for_predicate(&self, name: &str) -> Result<&ColumnInfo, AppError> {
147 let col = self.find_visible(name)?;
148 if self.predicate_filter.allows(&col.name) {
149 Ok(col)
150 } else {
151 Err(AppError::Forbidden(format!(
152 "column '{}' may not be used in predicates on dataset '{}'",
153 col.name, self.name
154 )))
155 }
156 }
157
158 pub fn quote_ident(name: &str) -> String {
161 format!("\"{}\"", name.replace('"', "\"\""))
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building one column description.
    fn col(name: &str, logical: LogicalType, sql_type: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_owned(),
            logical,
            sql_type: sql_type.to_owned(),
            nullable,
        }
    }

    /// Two-column fixture: a non-null integer `Id` and a nullable temporal `When`.
    fn s() -> DatasetSchema {
        let columns = vec![
            col("Id", LogicalType::Int, "BIGINT", false),
            col("When", LogicalType::Temporal, "TIMESTAMP", true),
        ];
        DatasetSchema::new("ds", columns)
    }

    /// Copies borrowed names into owned strings.
    fn owned(cols: &[&str]) -> Vec<String> {
        cols.iter().map(|c| (*c).to_owned()).collect()
    }

    /// Filter that excludes exactly `cols`.
    fn excl(cols: &[&str]) -> ColumnFilter {
        ColumnFilter {
            include: Vec::new(),
            exclude: owned(cols),
        }
    }

    /// Filter that includes exactly `cols`.
    fn incl(cols: &[&str]) -> ColumnFilter {
        ColumnFilter {
            include: owned(cols),
            exclude: Vec::new(),
        }
    }

    #[test]
    fn quote_ident_plain() {
        let quoted = DatasetSchema::quote_ident("foo");
        assert_eq!(quoted, r#""foo""#);
    }

    #[test]
    fn quote_ident_escapes_inner_quote() {
        let quoted = DatasetSchema::quote_ident(r#"a"b"#);
        assert_eq!(quoted, r#""a""b""#);
    }

    #[test]
    fn find_case_insensitive_returns_canonical_name() {
        let schema = s();
        let found = schema.find("ID").expect("found");
        assert_eq!(found.name, "Id");
    }

    #[test]
    fn find_unknown_column() {
        let schema = s();
        assert!(matches!(
            schema.find("nope"),
            Err(AppError::UnknownColumn(_))
        ));
    }

    #[test]
    fn needs_cast_only_temporal() {
        assert!(LogicalType::Temporal.needs_cast());
        let no_cast = [
            LogicalType::Bool,
            LogicalType::Int,
            LogicalType::Float,
            LogicalType::Utf8,
            LogicalType::Other,
        ];
        assert!(no_cast.iter().all(|t| !t.needs_cast()));
    }

    #[test]
    fn with_filters_rejects_unknown_column() {
        let outcome = s().with_filters(excl(&["ghost"]), ColumnFilter::default());
        assert!(matches!(outcome, Err(AppError::InvalidValue(_))));
    }

    #[test]
    fn with_filters_rejects_include_and_exclude() {
        let both = ColumnFilter {
            include: vec!["Id".into()],
            exclude: vec!["When".into()],
        };
        let outcome = s().with_filters(ColumnFilter::default(), both);
        assert!(matches!(outcome, Err(AppError::InvalidValue(_))));
    }

    #[test]
    fn projection_exclude_hides_column() {
        // The filter entry is lowercase on purpose; the column is named "When".
        let schema = s()
            .with_filters(ColumnFilter::default(), excl(&["when"]))
            .unwrap();
        assert!(schema.is_visible("Id"));
        assert!(!schema.is_visible("When"));

        let visible: Vec<&str> = schema
            .visible_columns()
            .into_iter()
            .map(|c| c.name.as_str())
            .collect();
        assert_eq!(visible, ["Id"]);

        assert!(matches!(
            schema.find_visible("When"),
            Err(AppError::UnknownColumn(_))
        ));
    }

    #[test]
    fn projection_include_is_an_allowlist() {
        let schema = s()
            .with_filters(ColumnFilter::default(), incl(&["Id"]))
            .unwrap();
        assert!(schema.is_visible("Id"));
        assert!(!schema.is_visible("When"));
    }

    #[test]
    fn predicate_denied_column_is_forbidden_but_visible() {
        let schema = s()
            .with_filters(excl(&["When"]), ColumnFilter::default())
            .unwrap();
        assert!(schema.find_visible("When").is_ok());
        assert!(matches!(
            schema.find_for_predicate("When"),
            Err(AppError::Forbidden(_))
        ));
        assert!(schema.find_for_predicate("Id").is_ok());
    }

    #[test]
    fn hidden_column_in_predicate_is_unknown_not_forbidden() {
        let schema = s()
            .with_filters(ColumnFilter::default(), excl(&["When"]))
            .unwrap();
        assert!(matches!(
            schema.find_for_predicate("When"),
            Err(AppError::UnknownColumn(_))
        ));
    }
}
314