1use azof::AsOf;
2use azof::AsOf::{Current, EventTime};
3use chrono::{DateTime, Utc};
4use datafusion::logical_expr::sqlparser::ast::{
5 Expr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, ObjectName,
6 TableFactor, TableVersion, Value, VisitMut, VisitorMut,
7};
8use datafusion::sql::parser::Statement;
9use std::ops::ControlFlow;
10
11pub struct VersionedTable {
12 pub name: ObjectName,
13 pub versioned_name: ObjectName,
14 pub as_of: AsOf,
15}
16
17pub fn rewrite_and_extract_tables(
18 statement: &mut Statement,
19) -> Result<Vec<VersionedTable>, Box<dyn std::error::Error>> {
20 let mut visitor = RewriteVersionIntoTableIdent { relations: vec![] };
21 match statement {
22 Statement::Statement(s) => {
23 if let ControlFlow::Break(err) = s.visit(&mut visitor) {
24 Err(err)
25 } else {
26 Ok(visitor.relations)
27 }
28 }
29 _ => Ok(visitor.relations),
30 }
31}
32
33struct RewriteVersionIntoTableIdent {
34 relations: Vec<VersionedTable>,
35}
36impl VisitorMut for RewriteVersionIntoTableIdent {
37 type Break = Box<dyn std::error::Error>;
38 fn post_visit_table_factor(
39 &mut self,
40 table_factor: &mut TableFactor,
41 ) -> ControlFlow<Self::Break> {
42 match rewrite_and_extract_versioned_tables(table_factor) {
43 Ok(Some(table)) => {
44 self.relations.push(table);
45 ControlFlow::Continue(())
46 }
47 Err(e) => ControlFlow::Break(e),
48 _ => ControlFlow::Continue(()),
49 }
50 }
51}
52
53fn rewrite_and_extract_versioned_tables(
54 table_factor: &mut TableFactor,
55) -> Result<Option<VersionedTable>, Box<dyn std::error::Error>> {
56 if let TableFactor::Table { name, version, .. } = table_factor {
57 let original_name = name.clone();
58 let as_of: Result<AsOf, Box<dyn std::error::Error>> = {
59 if let Some(TableVersion::ForSystemTimeAsOf(Expr::Value(Value::SingleQuotedString(
60 str,
61 )))) = version
62 {
63 let event_time =
64 DateTime::parse_from_rfc3339(str).map(|dt| dt.with_timezone(&Utc))?;
65 let ObjectName(idents) = name;
66 let mut new_idents: Vec<Ident> = Vec::with_capacity(idents.len());
67
68 new_idents.extend(idents.iter().take(idents.len() - 1).cloned());
69
70 if let Some(last) = idents.last() {
71 new_idents.push(Ident {
72 value: format!("{}__{}", last.value, event_time.timestamp_millis()),
73 quote_style: last.quote_style,
74 span: last.span,
75 });
76
77 *name = ObjectName(new_idents);
78 *version = None;
79 }
80 Ok(EventTime(event_time))
81 } else if let Some(TableVersion::Function(Expr::Function(func))) = version {
82 if func.name.0.len() == 1 && func.name.0[0].value.to_uppercase() == "AT" {
83 let timestamp_value = extract_timestamp_from_at_function(func)?;
84 let event_time = DateTime::parse_from_rfc3339(×tamp_value)
85 .map(|dt| dt.with_timezone(&Utc))?;
86
87 let ObjectName(idents) = name;
88 let mut new_idents: Vec<Ident> = Vec::with_capacity(idents.len());
89
90 new_idents.extend(idents.iter().take(idents.len() - 1).cloned());
91
92 if let Some(last) = idents.last() {
93 new_idents.push(Ident {
94 value: format!("{}__{}", last.value, event_time.timestamp_millis()),
95 quote_style: last.quote_style,
96 span: last.span,
97 });
98
99 *name = ObjectName(new_idents);
100 *version = None;
101 }
102 Ok(EventTime(event_time))
103 } else {
104 Ok(Current)
105 }
106 } else {
107 Ok(Current)
108 }
109 };
110
111 return Ok(Some(VersionedTable {
112 name: original_name,
113 versioned_name: name.clone(),
114 as_of: as_of?,
115 }));
116 }
117 Ok(None)
118}
119
120fn extract_timestamp_from_at_function(
121 func: &Function,
122) -> Result<String, Box<dyn std::error::Error>> {
123 if let FunctionArguments::List(list) = &func.args {
124 for arg in &list.args {
125 match arg {
126 FunctionArg::Unnamed(expr) => {
127 if let FunctionArgExpr::Expr(Expr::Value(Value::SingleQuotedString(
128 timestamp,
129 ))) = expr
130 {
131 return Ok(timestamp.clone());
132 }
133 }
134 FunctionArg::Named {
135 name,
136 arg,
137 operator: _,
138 } => {
139 if name.value.to_uppercase() == "TIMESTAMP" {
140 if let FunctionArgExpr::Expr(Expr::Value(Value::SingleQuotedString(
141 timestamp,
142 ))) = arg
143 {
144 return Ok(timestamp.clone());
145 }
146 }
147 }
148 FunctionArg::ExprNamed {
149 name,
150 arg,
151 operator: _,
152 } => {
153 if let Expr::Identifier(ident) = name {
154 if ident.value.to_uppercase() == "TIMESTAMP" {
155 if let FunctionArgExpr::Expr(Expr::Value(Value::SingleQuotedString(
156 timestamp,
157 ))) = arg
158 {
159 return Ok(timestamp.clone());
160 }
161 }
162 }
163 }
164 }
165 }
166 }
167 Err("No valid timestamp found in AT function".into())
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{TimeZone, Utc};
    use datafusion::prelude::SessionContext;

    /// Parses `sql` with the dialect used throughout these tests.
    fn parse(sql: &str) -> Statement {
        SessionContext::new()
            .state()
            .sql_to_statement(sql, "snowflake")
            .unwrap()
    }

    /// Rewrites `sql` and checks that its single table `tbl` was pinned to
    /// 2019-01-17T00:00:00Z.
    fn assert_tbl_pinned_to_2019_01_17(sql: &str) {
        let mut stmt = parse(sql);

        let tables = rewrite_and_extract_tables(&mut stmt).unwrap();
        assert_eq!(tables.len(), 1);

        let table = &tables[0];
        assert_eq!(table.name.to_string(), "tbl".to_string());
        assert_eq!(
            table.versioned_name.to_string(),
            "tbl__1547683200000".to_string()
        );
        assert_eq!(
            table.as_of,
            EventTime(Utc.with_ymd_and_hms(2019, 1, 17, 0, 0, 0).unwrap()),
        );
    }

    /// Rewrites `sql` and checks that the rewrite is rejected.
    fn assert_rewrite_fails(sql: &str) {
        let mut stmt = parse(sql);
        assert!(rewrite_and_extract_tables(&mut stmt).is_err());
    }

    #[test]
    fn inserts_version_into_table_ident() {
        assert_tbl_pinned_to_2019_01_17(
            "SELECT * FROM tbl FOR SYSTEM_TIME AS OF '2019-01-17T00:00:00.000Z'",
        );
    }

    #[test]
    fn handles_at_function_with_unnamed_timestamp() {
        assert_tbl_pinned_to_2019_01_17("SELECT * FROM tbl AT('2019-01-17T00:00:00.000Z')");
    }

    #[test]
    fn handles_at_function_with_named_timestamp() {
        assert_tbl_pinned_to_2019_01_17(
            "SELECT * FROM tbl AT(TIMESTAMP => '2019-01-17T00:00:00.000Z')",
        );
    }

    #[test]
    fn returns_error_on_invalid_at_timestamp() {
        assert_rewrite_fails("SELECT * FROM tbl AT('not_a_date')");
    }

    #[test]
    fn returns_error_on_non_convertible_string() {
        assert_rewrite_fails("SELECT * FROM tbl FOR SYSTEM_TIME AS OF 'not_a_date'");
    }
}