Skip to main content

azof_datafusion/
parse.rs

1use azof::AsOf;
2use azof::AsOf::{Current, EventTime};
3use chrono::{DateTime, Utc};
4use datafusion::logical_expr::sqlparser::ast::{
5    Expr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, ObjectName,
6    TableFactor, TableVersion, Value, VisitMut, VisitorMut,
7};
8use datafusion::sql::parser::Statement;
9use std::ops::ControlFlow;
10
11pub struct VersionedTable {
12    pub name: ObjectName,
13    pub versioned_name: ObjectName,
14    pub as_of: AsOf,
15}
16
17pub fn rewrite_and_extract_tables(
18    statement: &mut Statement,
19) -> Result<Vec<VersionedTable>, Box<dyn std::error::Error>> {
20    let mut visitor = RewriteVersionIntoTableIdent { relations: vec![] };
21    match statement {
22        Statement::Statement(s) => {
23            if let ControlFlow::Break(err) = s.visit(&mut visitor) {
24                Err(err)
25            } else {
26                Ok(visitor.relations)
27            }
28        }
29        _ => Ok(visitor.relations),
30    }
31}
32
33struct RewriteVersionIntoTableIdent {
34    relations: Vec<VersionedTable>,
35}
36impl VisitorMut for RewriteVersionIntoTableIdent {
37    type Break = Box<dyn std::error::Error>;
38    fn post_visit_table_factor(
39        &mut self,
40        table_factor: &mut TableFactor,
41    ) -> ControlFlow<Self::Break> {
42        match rewrite_and_extract_versioned_tables(table_factor) {
43            Ok(Some(table)) => {
44                self.relations.push(table);
45                ControlFlow::Continue(())
46            }
47            Err(e) => ControlFlow::Break(e),
48            _ => ControlFlow::Continue(()),
49        }
50    }
51}
52
53fn rewrite_and_extract_versioned_tables(
54    table_factor: &mut TableFactor,
55) -> Result<Option<VersionedTable>, Box<dyn std::error::Error>> {
56    if let TableFactor::Table { name, version, .. } = table_factor {
57        let original_name = name.clone();
58        let as_of: Result<AsOf, Box<dyn std::error::Error>> = {
59            if let Some(TableVersion::ForSystemTimeAsOf(Expr::Value(Value::SingleQuotedString(
60                str,
61            )))) = version
62            {
63                let event_time =
64                    DateTime::parse_from_rfc3339(str).map(|dt| dt.with_timezone(&Utc))?;
65                let ObjectName(idents) = name;
66                let mut new_idents: Vec<Ident> = Vec::with_capacity(idents.len());
67
68                new_idents.extend(idents.iter().take(idents.len() - 1).cloned());
69
70                if let Some(last) = idents.last() {
71                    new_idents.push(Ident {
72                        value: format!("{}__{}", last.value, event_time.timestamp_millis()),
73                        quote_style: last.quote_style,
74                        span: last.span,
75                    });
76
77                    *name = ObjectName(new_idents);
78                    *version = None;
79                }
80                Ok(EventTime(event_time))
81            } else if let Some(TableVersion::Function(Expr::Function(func))) = version {
82                if func.name.0.len() == 1 && func.name.0[0].value.to_uppercase() == "AT" {
83                    let timestamp_value = extract_timestamp_from_at_function(func)?;
84                    let event_time = DateTime::parse_from_rfc3339(&timestamp_value)
85                        .map(|dt| dt.with_timezone(&Utc))?;
86
87                    let ObjectName(idents) = name;
88                    let mut new_idents: Vec<Ident> = Vec::with_capacity(idents.len());
89
90                    new_idents.extend(idents.iter().take(idents.len() - 1).cloned());
91
92                    if let Some(last) = idents.last() {
93                        new_idents.push(Ident {
94                            value: format!("{}__{}", last.value, event_time.timestamp_millis()),
95                            quote_style: last.quote_style,
96                            span: last.span,
97                        });
98
99                        *name = ObjectName(new_idents);
100                        *version = None;
101                    }
102                    Ok(EventTime(event_time))
103                } else {
104                    Ok(Current)
105                }
106            } else {
107                Ok(Current)
108            }
109        };
110
111        return Ok(Some(VersionedTable {
112            name: original_name,
113            versioned_name: name.clone(),
114            as_of: as_of?,
115        }));
116    }
117    Ok(None)
118}
119
120fn extract_timestamp_from_at_function(
121    func: &Function,
122) -> Result<String, Box<dyn std::error::Error>> {
123    if let FunctionArguments::List(list) = &func.args {
124        for arg in &list.args {
125            match arg {
126                FunctionArg::Unnamed(expr) => {
127                    if let FunctionArgExpr::Expr(Expr::Value(Value::SingleQuotedString(
128                        timestamp,
129                    ))) = expr
130                    {
131                        return Ok(timestamp.clone());
132                    }
133                }
134                FunctionArg::Named {
135                    name,
136                    arg,
137                    operator: _,
138                } => {
139                    if name.value.to_uppercase() == "TIMESTAMP" {
140                        if let FunctionArgExpr::Expr(Expr::Value(Value::SingleQuotedString(
141                            timestamp,
142                        ))) = arg
143                        {
144                            return Ok(timestamp.clone());
145                        }
146                    }
147                }
148                FunctionArg::ExprNamed {
149                    name,
150                    arg,
151                    operator: _,
152                } => {
153                    if let Expr::Identifier(ident) = name {
154                        if ident.value.to_uppercase() == "TIMESTAMP" {
155                            if let FunctionArgExpr::Expr(Expr::Value(Value::SingleQuotedString(
156                                timestamp,
157                            ))) = arg
158                            {
159                                return Ok(timestamp.clone());
160                            }
161                        }
162                    }
163                }
164            }
165        }
166    }
167    Err("No valid timestamp found in AT function".into())
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{TimeZone, Utc};
    use datafusion::prelude::SessionContext;

    /// Parses `sql` with the Snowflake dialect and runs the time-travel rewrite over it.
    fn rewrite(sql: &str) -> Result<Vec<VersionedTable>, Box<dyn std::error::Error>> {
        let mut stmt = SessionContext::new()
            .state()
            .sql_to_statement(sql, "snowflake")
            .unwrap();
        rewrite_and_extract_tables(&mut stmt)
    }

    /// Asserts that `sql` references exactly one table, `tbl`, pinned to 2019-01-17T00:00:00Z.
    fn assert_single_tbl_pinned_to_2019_01_17(sql: &str) {
        let tables = rewrite(sql).unwrap();
        assert_eq!(tables.len(), 1);

        let table = &tables[0];
        assert_eq!(table.name.to_string(), "tbl".to_string());
        assert_eq!(
            table.versioned_name.to_string(),
            "tbl__1547683200000".to_string()
        );
        assert_eq!(
            table.as_of,
            EventTime(Utc.with_ymd_and_hms(2019, 1, 17, 0, 0, 0).unwrap()),
        );
    }

    #[test]
    fn inserts_version_into_table_ident() {
        assert_single_tbl_pinned_to_2019_01_17(
            "SELECT * FROM tbl FOR SYSTEM_TIME AS OF '2019-01-17T00:00:00.000Z'",
        );
    }

    #[test]
    fn handles_at_function_with_unnamed_timestamp() {
        assert_single_tbl_pinned_to_2019_01_17("SELECT * FROM tbl AT('2019-01-17T00:00:00.000Z')");
    }

    #[test]
    fn handles_at_function_with_named_timestamp() {
        assert_single_tbl_pinned_to_2019_01_17(
            "SELECT * FROM tbl AT(TIMESTAMP => '2019-01-17T00:00:00.000Z')",
        );
    }

    #[test]
    fn returns_error_on_invalid_at_timestamp() {
        assert!(rewrite("SELECT * FROM tbl AT('not_a_date')").is_err());
    }

    #[test]
    fn returns_error_on_non_convertible_string() {
        assert!(rewrite("SELECT * FROM tbl FOR SYSTEM_TIME AS OF 'not_a_date'").is_err());
    }
}