Skip to main content

krishiv_sql/
streaming_window_plan.rs

1//! Compile a windowed streaming SQL query into a [`WindowExecutionSpec`].
2//!
3//! Supports the canonical keyed windowed-aggregation shape:
4//!
5//! ```sql
6//! SELECT key, AGG(col) AS out [, ...]
7//! FROM TUMBLE(TABLE src, DESCRIPTOR(ts), <size>)   -- or HOP / SESSION
8//! GROUP BY key, window_start, window_end
9//! ```
10//!
11//! The streaming engine uses the resulting [`WindowExecutionSpec`] to drive the
12//! dataflow `ContinuousWindowExecutor`. The window operator computes the
13//! aggregation itself, so the SELECT/GROUP BY is only mined for the grouping
14//! key column and the aggregate list — the rest of the query shape is the
15//! window TVF, which [`find_window_tvf`] already parses.
16
17use std::collections::HashMap;
18
19use datafusion::sql::sqlparser::ast::{
20    Expr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, SelectItem, SetExpr, Statement,
21};
22use datafusion::sql::sqlparser::dialect::GenericDialect;
23use datafusion::sql::sqlparser::parser::Parser;
24use krishiv_plan::window::{WindowAgg, WindowAggKind, WindowExecutionSpec, WindowKind};
25
26use crate::streaming_tvf::{WindowTvf, find_window_tvf, rewrite_window_tvfs};
27use crate::{SqlError, SqlResult};
28
29/// A compiled windowed streaming plan: the operator spec plus the name of the
30/// source table the window reads from.
31#[derive(Debug, Clone, PartialEq, Eq)]
32pub struct StreamingWindowPlan {
33    /// The keyed-window operator specification.
34    pub spec: WindowExecutionSpec,
35    /// The source table the window TVF reads from.
36    pub source: String,
37}
38
39fn unsupported(msg: impl Into<String>) -> SqlError {
40    SqlError::Unsupported {
41        feature: msg.into(),
42    }
43}
44
45fn parse_ms(raw: &str) -> SqlResult<u64> {
46    raw.trim().parse::<u64>().map_err(|_| {
47        unsupported(format!(
48            "window interval '{raw}' is not a millisecond count"
49        ))
50    })
51}
52
53/// Returns `true` when `sql` contains a TUMBLE/HOP/SESSION window TVF.
54pub fn is_windowed_streaming_sql(sql: &str) -> bool {
55    find_window_tvf(sql).is_some()
56}
57
58/// Compile a windowed streaming SQL query into a [`StreamingWindowPlan`].
59///
60/// Returns [`SqlError::Unsupported`] when the query is not a recognised keyed
61/// windowed aggregation.
62pub fn compile_streaming_window_sql(sql: &str) -> SqlResult<StreamingWindowPlan> {
63    let (_, tvf, _) = find_window_tvf(sql)
64        .ok_or_else(|| unsupported("query has no TUMBLE/HOP/SESSION window"))?;
65
66    let (source, event_time_column, window_kind, window_size_ms, slide_ms, session_gap_ms) =
67        match &tvf {
68            WindowTvf::Tumble {
69                source,
70                ts_col,
71                size_ms,
72            } => (
73                (*source).to_string(),
74                (*ts_col).to_string(),
75                WindowKind::Tumbling,
76                parse_ms(size_ms)?,
77                None,
78                None,
79            ),
80            WindowTvf::Hop {
81                source,
82                ts_col,
83                slide_ms,
84                size_ms,
85            } => (
86                (*source).to_string(),
87                (*ts_col).to_string(),
88                WindowKind::Sliding,
89                parse_ms(size_ms)?,
90                Some(parse_ms(slide_ms)?),
91                None,
92            ),
93            WindowTvf::Session {
94                source,
95                ts_col,
96                gap_ms,
97            } => {
98                let gap = parse_ms(gap_ms)?;
99                (
100                    (*source).to_string(),
101                    (*ts_col).to_string(),
102                    WindowKind::Session,
103                    gap,
104                    None,
105                    Some(gap),
106                )
107            }
108        };
109
110    // Mine the SELECT projection for the key column and aggregates. Rewrite the
111    // TVF to a plain subquery first so the parser accepts the SQL.
112    let rewritten = rewrite_window_tvfs(sql);
113    let (key_column, agg_exprs) = extract_key_and_aggs(&rewritten)?;
114
115    let spec = WindowExecutionSpec {
116        key_column,
117        key_column_type: String::from("utf8"),
118        event_time_column,
119        watermark_lag_ms: 0,
120        window_kind,
121        window_size_ms,
122        slide_ms,
123        session_gap_ms,
124        agg_exprs,
125        state_ttl_ms: None,
126        allowed_lateness_ms: None,
127        source_watermark_lags: HashMap::new(),
128        source_id_column: None,
129        window_timezone: None,
130    };
131    Ok(StreamingWindowPlan { spec, source })
132}
133
134const WINDOW_BOUNDARY_COLS: [&str; 2] = ["window_start", "window_end"];
135
136fn extract_key_and_aggs(sql: &str) -> SqlResult<(String, Vec<WindowAgg>)> {
137    let dialect = GenericDialect {};
138    let stmts = Parser::parse_sql(&dialect, sql)
139        .map_err(|e| unsupported(format!("streaming window query parse error: {e}")))?;
140    let query = stmts
141        .into_iter()
142        .find_map(|s| match s {
143            Statement::Query(q) => Some(q),
144            _ => None,
145        })
146        .ok_or_else(|| unsupported("streaming window query must be a SELECT"))?;
147
148    let SetExpr::Select(select) = query.body.as_ref() else {
149        return Err(unsupported("streaming window query must be a plain SELECT"));
150    };
151
152    let mut key_column: Option<String> = None;
153    let mut aggs: Vec<WindowAgg> = Vec::new();
154
155    for item in &select.projection {
156        let (expr, alias) = match item {
157            SelectItem::UnnamedExpr(e) => (e, None),
158            SelectItem::ExprWithAlias { expr, alias } => (expr, Some(alias.value.clone())),
159            _ => continue,
160        };
161        match expr {
162            Expr::Function(f) => aggs.push(function_to_agg(f, alias)?),
163            Expr::Identifier(id) => maybe_set_key(&mut key_column, &id.value),
164            Expr::CompoundIdentifier(parts) => {
165                if let Some(last) = parts.last() {
166                    maybe_set_key(&mut key_column, &last.value);
167                }
168            }
169            _ => continue,
170        }
171    }
172
173    let key_column = key_column.ok_or_else(|| {
174        unsupported("streaming window query needs a grouping key column in the SELECT list")
175    })?;
176    if aggs.is_empty() {
177        aggs.push(WindowAgg::count("count"));
178    }
179    Ok((key_column, aggs))
180}
181
182fn maybe_set_key(key: &mut Option<String>, name: &str) {
183    if key.is_none() && !WINDOW_BOUNDARY_COLS.contains(&name) {
184        *key = Some(name.to_string());
185    }
186}
187
188fn function_to_agg(f: &Function, alias: Option<String>) -> SqlResult<WindowAgg> {
189    let fname = f.name.to_string().to_ascii_lowercase();
190    let kind = match fname.as_str() {
191        "count" => WindowAggKind::Count,
192        "sum" => WindowAggKind::Sum,
193        "min" => WindowAggKind::Min,
194        "max" => WindowAggKind::Max,
195        "avg" => WindowAggKind::Avg,
196        "stddev" | "stddev_samp" => WindowAggKind::Stddev,
197        other => {
198            return Err(unsupported(format!(
199                "aggregate '{other}' is not supported in streaming windows; \
200                 use count/sum/min/max/avg/stddev"
201            )));
202        }
203    };
204    let input_column = first_arg_column(f);
205    let output_column = alias.unwrap_or_else(|| match &input_column {
206        Some(col) => format!("{fname}_{col}"),
207        None => fname.clone(),
208    });
209    Ok(WindowAgg {
210        kind,
211        input_column: input_column.unwrap_or_default(),
212        output_column,
213    })
214}
215
216fn first_arg_column(f: &Function) -> Option<String> {
217    let FunctionArguments::List(list) = &f.args else {
218        return None;
219    };
220    for fa in &list.args {
221        let expr = match fa {
222            FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
223            FunctionArg::Named {
224                arg: FunctionArgExpr::Expr(e),
225                ..
226            } => Some(e),
227            _ => None,
228        };
229        match expr {
230            Some(Expr::Identifier(id)) => return Some(id.value.clone()),
231            Some(Expr::CompoundIdentifier(parts)) => return parts.last().map(|p| p.value.clone()),
232            _ => {}
233        }
234    }
235    None
236}
237
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use super::*;

    /// Compile `sql`, failing the test if the compiler rejects it.
    fn plan_for(sql: &str) -> StreamingWindowPlan {
        compile_streaming_window_sql(sql).unwrap()
    }

    /// Assert that compiling `sql` is rejected with `SqlError::Unsupported`.
    fn assert_unsupported(sql: &str) {
        let err = compile_streaming_window_sql(sql).unwrap_err();
        assert!(matches!(err, SqlError::Unsupported { .. }));
    }

    #[test]
    fn compiles_tumbling_window() {
        let plan = plan_for(
            "SELECT user_id, SUM(amount) AS total \
             FROM TUMBLE(TABLE events, DESCRIPTOR(ts), 60000) \
             GROUP BY user_id, window_start, window_end",
        );
        assert_eq!(plan.source, "events");

        let spec = &plan.spec;
        assert_eq!(spec.window_kind, WindowKind::Tumbling);
        assert_eq!(spec.window_size_ms, 60000);
        assert_eq!(spec.event_time_column, "ts");
        assert_eq!(spec.key_column, "user_id");
        assert_eq!(spec.agg_exprs.len(), 1);

        let agg = &spec.agg_exprs[0];
        assert_eq!(agg.kind, WindowAggKind::Sum);
        assert_eq!(agg.input_column, "amount");
        assert_eq!(agg.output_column, "total");
    }

    #[test]
    fn compiles_tumbling_window_with_stddev() {
        let plan = plan_for(
            "SELECT k, STDDEV(v) AS spread \
             FROM TUMBLE(TABLE m, DESCRIPTOR(ts), 60000) \
             GROUP BY k, window_start, window_end",
        );
        let agg = &plan.spec.agg_exprs[0];
        assert_eq!(agg.kind, WindowAggKind::Stddev);
        assert_eq!(agg.input_column, "v");
        assert_eq!(agg.output_column, "spread");
    }

    #[test]
    fn compiles_hop_window_with_slide() {
        let plan = plan_for(
            "SELECT k, COUNT(*) AS c \
             FROM HOP(TABLE clicks, DESCRIPTOR(ts), 30000, 60000) \
             GROUP BY k, window_start, window_end",
        );
        let spec = &plan.spec;
        assert_eq!(spec.window_kind, WindowKind::Sliding);
        assert_eq!(spec.window_size_ms, 60000);
        assert_eq!(spec.slide_ms, Some(30000));
        assert_eq!(spec.agg_exprs[0].kind, WindowAggKind::Count);
    }

    #[test]
    fn compiles_session_window_with_gap() {
        let plan = plan_for(
            "SELECT k, MAX(v) AS hi \
             FROM SESSION(TABLE events, DESCRIPTOR(ts), 15000) \
             GROUP BY k, window_start, window_end",
        );
        let spec = &plan.spec;
        assert_eq!(spec.window_kind, WindowKind::Session);
        assert_eq!(spec.session_gap_ms, Some(15000));
        assert_eq!(spec.agg_exprs[0].kind, WindowAggKind::Max);
    }

    #[test]
    fn non_windowed_query_is_unsupported() {
        assert_unsupported("SELECT a FROM t");
    }

    #[test]
    fn unsupported_aggregate_is_rejected() {
        assert_unsupported(
            "SELECT k, MEDIAN(v) AS s \
             FROM TUMBLE(TABLE events, DESCRIPTOR(ts), 60000) \
             GROUP BY k, window_start, window_end",
        );
    }

    #[test]
    fn window_boundary_columns_are_not_treated_as_key() {
        let plan = plan_for(
            "SELECT window_start, user_id, COUNT(*) AS c \
             FROM TUMBLE(TABLE events, DESCRIPTOR(ts), 60000) \
             GROUP BY user_id, window_start, window_end",
        );
        assert_eq!(plan.spec.key_column, "user_id");
    }

    #[test]
    fn detects_windowed_sql() {
        let windowed = "SELECT k FROM TUMBLE(TABLE t, DESCRIPTOR(ts), 1000) GROUP BY k";
        assert!(is_windowed_streaming_sql(windowed));
        assert!(!is_windowed_streaming_sql("SELECT k FROM t"));
    }
}