Skip to main content

krishiv_sql/
streaming_tvf.rs

1#![deny(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
2//! Streaming window table-valued functions: TUMBLE, HOP, SESSION.
3//!
4//! Rewrites Flink-SQL-style window TVF syntax into standard SQL that uses
5//! the existing `tumble_start`/`tumble_end`/`hop_start`/`hop_end`/`session_start`/
6//! `session_end` scalar UDFs registered by `window_functions.rs`.
7//!
8//! # Syntax (FROM clause)
9//!
10//! ```sql
11//! -- Tumbling window: each event appears in exactly one window
12//! SELECT key, window_start, window_end, COUNT(*)
13//! FROM TUMBLE(TABLE events, DESCRIPTOR(ts), 60000)
14//! GROUP BY key, window_start, window_end
15//!
16//! -- Hopping / sliding window: each event appears in (size/slide) windows
17//! SELECT key, window_start, window_end, COUNT(*)
18//! FROM HOP(TABLE events, DESCRIPTOR(ts), 30000, 60000)
19//! GROUP BY key, window_start, window_end
20//!
21//! -- Session window: gaps between events delimit window boundaries
22//! SELECT key, window_start, window_end, COUNT(*)
23//! FROM SESSION(TABLE events, DESCRIPTOR(ts), 5000)
24//! GROUP BY key, window_start, window_end
25//! ```
26//!
27//! # Interval expressions
28//!
29//! The size / slide / gap argument can be:
30//! - An integer literal (milliseconds): `60000`
31//! - A SQL interval string: `'1 minute'`, `'30 seconds'`, `'1 hour'` → converted to ms
32//!
33//! # Rewrite output
34//!
35//! ```sql
36//! -- TUMBLE → subquery with window_start / window_end columns:
37//! SELECT key, window_start, window_end, COUNT(*)
38//! FROM (
39//!   SELECT *, tumble_start(ts, 60000) AS window_start,
40//!             tumble_end(ts, 60000)   AS window_end
41//!   FROM events
42//! ) AS _tvf_window
43//! GROUP BY key, window_start, window_end
44//! ```
45
/// Parse a quoted interval string such as `'30 seconds'` to milliseconds.
///
/// Accepts `'N unit'` where `N` is an integer and `unit` is one of
/// `millisecond`, `second`, `minute`, `hour` or `day`, in any ASCII case and
/// optionally pluralised with a single trailing `s`. Number and unit may be
/// separated by any whitespace.
///
/// Returns `None` if the text is not of that form, has trailing tokens, or the
/// millisecond value does not fit in an `i64`.
fn interval_str_to_ms(s: &str) -> Option<i64> {
    let inner = s.trim().trim_matches('\'').trim();
    let mut tokens = inner.split_whitespace();
    let n: i64 = tokens.next()?.parse().ok()?;
    let lowered = tokens.next()?.to_ascii_lowercase();
    if tokens.next().is_some() {
        return None;
    }
    let unit = lowered.strip_suffix('s').unwrap_or(lowered.as_str());
    let factor: i64 = match unit {
        "millisecond" => 1,
        "second" => 1_000,
        "minute" => 60_000,
        "hour" => 3_600_000,
        "day" => 86_400_000,
        _ => return None,
    };
    // Checked: a large `N` would otherwise overflow (a panic in debug builds,
    // silent wrap-around in release builds).
    n.checked_mul(factor)
}
64
65/// Convert a TVF interval argument (integer literal or quoted string) to a
66/// millisecond `i64` string suitable for embedding in SQL.
67fn normalise_interval_arg(arg: &str) -> String {
68    let trimmed = arg.trim();
69    // Already an integer literal.
70    if trimmed.parse::<i64>().is_ok() {
71        return trimmed.to_owned();
72    }
73    // Quoted interval string.
74    if let Some(ms) = interval_str_to_ms(trimmed) {
75        return ms.to_string();
76    }
77    // Return as-is and let DataFusion complain if invalid.
78    trimmed.to_owned()
79}
80
/// Cursor that splits a TVF argument list on top-level separators.
///
/// Text inside nested parentheses or inside a `'…'` / `"…"` quoted span is
/// opaque, so a `,` or `)` there does not end the current argument.
struct ArgScanner<'a> {
    chars: std::iter::Peekable<std::str::CharIndices<'a>>,
    src: &'a str,
}

impl<'a> ArgScanner<'a> {
    fn new(src: &'a str) -> Self {
        let chars = src.char_indices().peekable();
        ArgScanner { chars, src }
    }

    /// Advance past any ASCII whitespace at the cursor.
    fn skip_ws(&mut self) {
        while self
            .chars
            .next_if(|&(_, ch)| ch.is_ascii_whitespace())
            .is_some()
        {}
    }

    /// Consume one argument, stopping at a `,` or `)` that lies outside every
    /// parenthesis pair and quoted span.
    ///
    /// Returns the trimmed argument text together with the character that
    /// ended it. If the input runs out first, the remaining text is returned
    /// with `'\0'` as the terminator. Returns `None` when only whitespace is
    /// left.
    fn next_arg(&mut self) -> Option<(&'a str, char)> {
        self.skip_ws();
        let &(begin, _) = self.chars.peek()?;
        let src = self.src;
        let mut nesting = 0i32;
        let mut open_quote: Option<char> = None;
        let mut stop = begin;

        for (idx, ch) in self.chars.by_ref() {
            stop = idx + ch.len_utf8();
            if let Some(quote) = open_quote {
                // Inside a quoted span only the matching quote is significant.
                if ch == quote {
                    open_quote = None;
                }
                continue;
            }
            match ch {
                '\'' | '"' => open_quote = Some(ch),
                '(' => nesting += 1,
                ')' if nesting > 0 => nesting -= 1,
                // `nesting` never goes negative, so this arm covers every
                // remaining `)` as well as commas at the top level.
                ')' | ',' if nesting == 0 => return Some((src[begin..idx].trim(), ch)),
                _ => {}
            }
        }

        // Input exhausted without reaching a terminator.
        (stop > begin).then(|| (src[begin..stop].trim(), '\0'))
    }
}
150
/// A parsed window TVF call, borrowing identifiers from the original SQL.
///
/// `source` is the table name written after `TABLE`, and `ts_col` is the
/// column named in `DESCRIPTOR(...)`. As produced by [`find_window_tvf`], the
/// interval fields hold SQL text ready to embed in the rewritten query. That
/// text is a millisecond count when the argument was an integer literal or a
/// recognised interval string, and the argument as written otherwise.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WindowTvf<'a> {
    /// `TUMBLE(TABLE source, DESCRIPTOR(ts_col), size)`.
    Tumble {
        source: &'a str,
        ts_col: &'a str,
        size_ms: String,
    },
    /// `HOP(TABLE source, DESCRIPTOR(ts_col), slide, size)`.
    Hop {
        source: &'a str,
        ts_col: &'a str,
        slide_ms: String,
        size_ms: String,
    },
    /// `SESSION(TABLE source, DESCRIPTOR(ts_col), gap)`.
    Session {
        source: &'a str,
        ts_col: &'a str,
        gap_ms: String,
    },
}
171
/// Parse a `DESCRIPTOR(col)` argument and return the column name `col`.
///
/// The keyword is matched case-insensitively and may be followed by whitespace
/// before `(` (e.g. `descriptor (ts)`), just as TVF keywords may be. The
/// returned slice keeps the column's original spelling, trimmed.
///
/// Returns `None` if `s` is not of that form or the column name is empty.
fn parse_descriptor(s: &str) -> Option<&str> {
    const KEYWORD: &str = "descriptor";
    let s = s.trim();
    // `get` rather than indexing: `s` may be shorter than the keyword or have a
    // multi-byte char straddling the keyword length.
    let head = s.get(..KEYWORD.len())?;
    if !head.eq_ignore_ascii_case(KEYWORD) {
        return None;
    }
    let col = s[KEYWORD.len()..]
        .trim_start()
        .strip_prefix('(')?
        .strip_suffix(')')?
        .trim();
    if col.is_empty() {
        None
    } else {
        Some(col)
    }
}
182
/// Parse a `TABLE name` argument and return `name`.
///
/// `TABLE` is matched case-insensitively and may be separated from the name by
/// any ASCII whitespace, including the newlines common in multi-line queries.
///
/// Returns `None` if the keyword is missing, is glued to the following word
/// (`tablename`), or no name follows it.
fn parse_table_ref(s: &str) -> Option<&str> {
    const KEYWORD: &str = "table";
    let s = s.trim();
    let head = s.get(..KEYWORD.len())?;
    let rest = &s[KEYWORD.len()..];
    // Require a separator so an identifier that merely starts with `table`
    // is not split apart.
    if !head.eq_ignore_ascii_case(KEYWORD) || !rest.starts_with(|c: char| c.is_ascii_whitespace()) {
        return None;
    }
    let name = rest.trim();
    if name.is_empty() {
        None
    } else {
        Some(name)
    }
}
199
200/// Scan `sql` for the first occurrence of a window TVF in a FROM clause and
201/// return `(pre, tvf, post)` where `pre + rewrite(tvf) + post` produces the
202/// final SQL.  Returns `None` if no TVF is found.
203pub fn find_window_tvf(sql: &str) -> Option<(usize, WindowTvf<'_>, usize)> {
204    let lower = sql.to_ascii_lowercase();
205
206    for kw in ["tumble", "hop", "session"] {
207        let mut search_start = 0;
208        while let Some(pos) = lower[search_start..].find(kw) {
209            let abs = search_start + pos;
210            // Must be preceded by whitespace, comma, or start-of-string (not part of identifier).
211            let preceded_ok = abs == 0
212                || sql[..abs]
213                    .chars()
214                    .last()
215                    .map(|c| !c.is_alphanumeric() && c != '_')
216                    .unwrap_or(true);
217            // Must be followed by '(' optionally with whitespace.
218            let after = abs + kw.len();
219            let followed_ok = sql[after..].trim_start().starts_with('(');
220
221            if preceded_ok && followed_ok {
222                // Find the opening paren.
223                let paren_pos = after + sql[after..].find('(')?;
224                let inner_start = paren_pos + 1;
225                let mut scanner = ArgScanner::new(&sql[inner_start..]);
226
227                let tvf = match kw {
228                    "tumble" => {
229                        let (a0, _) = scanner.next_arg()?;
230                        let (a1, _) = scanner.next_arg()?;
231                        let (a2, term) = scanner.next_arg()?;
232                        if term != ')' && term != ',' {
233                            search_start = abs + 1;
234                            continue;
235                        }
236                        let source = parse_table_ref(a0)?;
237                        let ts_col = parse_descriptor(a1)?;
238                        let size_ms = normalise_interval_arg(a2);
239                        WindowTvf::Tumble {
240                            source,
241                            ts_col,
242                            size_ms,
243                        }
244                    }
245                    "hop" => {
246                        let (a0, _) = scanner.next_arg()?;
247                        let (a1, _) = scanner.next_arg()?;
248                        let (a2, _) = scanner.next_arg()?;
249                        let (a3, term) = scanner.next_arg()?;
250                        if term != ')' && term != ',' {
251                            search_start = abs + 1;
252                            continue;
253                        }
254                        let source = parse_table_ref(a0)?;
255                        let ts_col = parse_descriptor(a1)?;
256                        let slide_ms = normalise_interval_arg(a2);
257                        let size_ms = normalise_interval_arg(a3);
258                        WindowTvf::Hop {
259                            source,
260                            ts_col,
261                            slide_ms,
262                            size_ms,
263                        }
264                    }
265                    "session" => {
266                        let (a0, _) = scanner.next_arg()?;
267                        let (a1, _) = scanner.next_arg()?;
268                        let (a2, term) = scanner.next_arg()?;
269                        if term != ')' && term != ',' {
270                            search_start = abs + 1;
271                            continue;
272                        }
273                        let source = parse_table_ref(a0)?;
274                        let ts_col = parse_descriptor(a1)?;
275                        let gap_ms = normalise_interval_arg(a2);
276                        WindowTvf::Session {
277                            source,
278                            ts_col,
279                            gap_ms,
280                        }
281                    }
282                    _ => unreachable!(),
283                };
284
285                // The end of the TVF call is right after the closing ')' consumed
286                // by scanner (scanner consumed inner content; paren_pos is the '(').
287                // We need to find where in the original sql the scanner stopped.
288                let consumed = scanner
289                    .chars
290                    .next()
291                    .map(|(i, _)| i)
292                    .unwrap_or(sql[inner_start..].len());
293                let tvf_end = inner_start + consumed;
294                return Some((abs, tvf, tvf_end));
295            }
296            search_start = abs + 1;
297        }
298    }
299    None
300}
301
302/// Emit SQL for a window TVF call.
303fn emit_tvf_subquery(tvf: &WindowTvf<'_>) -> String {
304    match tvf {
305        WindowTvf::Tumble {
306            source,
307            ts_col,
308            size_ms,
309        } => format!(
310            "(SELECT *, tumble_start({ts_col}, {size_ms}) AS window_start, \
311             tumble_end({ts_col}, {size_ms}) AS window_end FROM {source}) AS _tvf_window"
312        ),
313        WindowTvf::Hop {
314            source,
315            ts_col,
316            slide_ms,
317            size_ms,
318        } => format!(
319            "(SELECT *, hop_start({ts_col}, {slide_ms}, {size_ms}) AS window_start, \
320             hop_end({ts_col}, {slide_ms}, {size_ms}) AS window_end FROM {source}) AS _tvf_window"
321        ),
322        WindowTvf::Session {
323            source,
324            ts_col,
325            gap_ms,
326        } => format!(
327            "(SELECT *, session_start({ts_col}, {gap_ms}) AS window_start, \
328             session_end({ts_col}, {gap_ms}) AS window_end FROM {source}) AS _tvf_window"
329        ),
330    }
331}
332
333/// Rewrite all window TVF calls in `sql` to subquery form.
334/// Iterates until no more TVFs are found (handles multiple TVFs in one query).
335pub fn rewrite_window_tvfs(sql: &str) -> String {
336    let mut current = sql.to_owned();
337    // Limit iterations to avoid infinite loop on malformed input.
338    for _ in 0..16 {
339        match find_window_tvf(&current) {
340            None => break,
341            Some((start, tvf, end)) => {
342                let subq = emit_tvf_subquery(&tvf);
343                let mut next = current[..start].to_owned();
344                next.push_str(&subq);
345                next.push_str(&current[end..]);
346                current = next;
347            }
348        }
349    }
350    current
351}
352
#[cfg(test)]
#[allow(clippy::unwrap_used)] // Tests unwrap deliberately: a `None` here is a test failure.
mod tests {
    use super::*;

    #[test]
    fn interval_str_seconds() {
        assert_eq!(interval_str_to_ms("'30 seconds'"), Some(30_000));
        assert_eq!(interval_str_to_ms("'1 minute'"), Some(60_000));
        assert_eq!(interval_str_to_ms("'2 hours'"), Some(7_200_000));
        assert_eq!(interval_str_to_ms("'1 day'"), Some(86_400_000));
    }

    #[test]
    fn interval_str_milliseconds_and_unknown_units() {
        assert_eq!(interval_str_to_ms("'250 milliseconds'"), Some(250));
        assert_eq!(interval_str_to_ms("'1 fortnight'"), None);
        assert_eq!(interval_str_to_ms("'abc seconds'"), None);
        assert_eq!(interval_str_to_ms("'5'"), None);
    }

    #[test]
    fn tumble_rewrite_integer_interval() {
        let sql = "SELECT key, COUNT(*) FROM TUMBLE(TABLE events, DESCRIPTOR(ts), 60000) GROUP BY key, window_start, window_end";
        let out = rewrite_window_tvfs(sql);
        assert!(
            out.contains("tumble_start(ts, 60000) AS window_start"),
            "{out}"
        );
        assert!(out.contains("tumble_end(ts, 60000) AS window_end"), "{out}");
        assert!(out.contains("FROM events"), "{out}");
        assert!(out.contains("_tvf_window"), "{out}");
    }

    #[test]
    fn tumble_rewrite_interval_string() {
        let sql = "SELECT key FROM TUMBLE(TABLE clicks, DESCRIPTOR(event_ts), '1 minute') GROUP BY key, window_start, window_end";
        let out = rewrite_window_tvfs(sql);
        assert!(out.contains("tumble_start(event_ts, 60000)"), "{out}");
    }

    #[test]
    fn hop_rewrite() {
        let sql = "SELECT key FROM HOP(TABLE events, DESCRIPTOR(ts), 30000, 60000) GROUP BY key, window_start, window_end";
        let out = rewrite_window_tvfs(sql);
        assert!(
            out.contains("hop_start(ts, 30000, 60000) AS window_start"),
            "{out}"
        );
        assert!(
            out.contains("hop_end(ts, 30000, 60000) AS window_end"),
            "{out}"
        );
    }

    #[test]
    fn hop_rewrite_interval_strings() {
        let sql = "SELECT key FROM HOP(TABLE events, DESCRIPTOR(ts), '30 seconds', '1 minute')";
        let out = rewrite_window_tvfs(sql);
        assert!(out.contains("hop_start(ts, 30000, 60000)"), "{out}");
    }

    #[test]
    fn session_rewrite() {
        let sql = "SELECT key FROM SESSION(TABLE events, DESCRIPTOR(ts), 5000) GROUP BY key, window_start, window_end";
        let out = rewrite_window_tvfs(sql);
        assert!(
            out.contains("session_start(ts, 5000) AS window_start"),
            "{out}"
        );
        assert!(out.contains("session_end(ts, 5000) AS window_end"), "{out}");
    }

    #[test]
    fn no_tvf_is_identity() {
        let sql = "SELECT * FROM events WHERE ts > 0";
        assert_eq!(rewrite_window_tvfs(sql), sql);
    }

    #[test]
    fn identifiers_containing_keywords_are_not_rewritten() {
        for sql in [
            "SELECT shop(a) FROM t",
            "SELECT my_tumble(TABLE e, DESCRIPTOR(ts), 1) FROM t",
            "SELECT tumble_start(ts, 1000) FROM t",
        ] {
            assert_eq!(rewrite_window_tvfs(sql), sql);
        }
    }

    #[test]
    fn lowercase_keywords_work() {
        let sql = "SELECT key FROM tumble(TABLE events, DESCRIPTOR(ts), 60000) GROUP BY key";
        let out = rewrite_window_tvfs(sql);
        assert!(out.contains("tumble_start"), "{out}");
    }

    #[test]
    fn whitespace_and_mixed_case_are_accepted() {
        let sql = "select k from Tumble (table events, descriptor(ts), '30 seconds') group by k";
        let out = rewrite_window_tvfs(sql);
        assert!(
            out.contains("tumble_start(ts, 30000) AS window_start"),
            "{out}"
        );
        assert!(out.contains("FROM events)"), "{out}");
    }

    #[test]
    fn multiple_tvfs_are_all_rewritten() {
        let sql = "SELECT * FROM TUMBLE(TABLE a, DESCRIPTOR(ts), 1000) JOIN HOP(TABLE b, DESCRIPTOR(ts), 500, 1000) ON true";
        let out = rewrite_window_tvfs(sql);
        assert!(out.contains("tumble_start(ts, 1000)"), "{out}");
        assert!(out.contains("hop_start(ts, 500, 1000)"), "{out}");
        assert!(!out.contains("TUMBLE("), "{out}");
        assert!(!out.contains("HOP("), "{out}");
    }

    #[test]
    fn find_window_tvf_reports_call_span() {
        let sql = "SELECT * FROM TUMBLE(TABLE e, DESCRIPTOR(ts), 5) WHERE x > 1";
        let (start, tvf, end) = find_window_tvf(sql).unwrap();
        assert_eq!(&sql[start..end], "TUMBLE(TABLE e, DESCRIPTOR(ts), 5)");
        assert_eq!(
            tvf,
            WindowTvf::Tumble {
                source: "e",
                ts_col: "ts",
                size_ms: "5".to_owned(),
            }
        );
    }

    #[test]
    fn find_window_tvf_call_at_end_of_input() {
        let sql = "SELECT * FROM SESSION(TABLE e, DESCRIPTOR(ts), 10)";
        let (start, _, end) = find_window_tvf(sql).unwrap();
        assert_eq!(&sql[start..end], "SESSION(TABLE e, DESCRIPTOR(ts), 10)");
        assert_eq!(end, sql.len());
    }

    #[test]
    fn arg_scanner_respects_nesting_and_quotes() {
        let mut s = ArgScanner::new("TABLE e, DESCRIPTOR(ts), 'a, b') tail");
        assert_eq!(s.next_arg(), Some(("TABLE e", ',')));
        assert_eq!(s.next_arg(), Some(("DESCRIPTOR(ts)", ',')));
        assert_eq!(s.next_arg(), Some(("'a, b'", ')')));
    }

    #[test]
    fn interval_normalisation() {
        assert_eq!(normalise_interval_arg("60000"), "60000");
        assert_eq!(normalise_interval_arg("'1 minute'"), "60000");
        assert_eq!(normalise_interval_arg("'30 seconds'"), "30000");
    }

    #[test]
    fn unparseable_interval_passes_through_trimmed() {
        assert_eq!(normalise_interval_arg("  '1 fortnight' "), "'1 fortnight'");
    }
}