1use std::collections::HashMap;
18
19use datafusion::sql::sqlparser::ast::{
20 Expr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, SelectItem, SetExpr, Statement,
21};
22use datafusion::sql::sqlparser::dialect::GenericDialect;
23use datafusion::sql::sqlparser::parser::Parser;
24use krishiv_plan::window::{WindowAgg, WindowAggKind, WindowExecutionSpec, WindowKind};
25
26use crate::streaming_tvf::{WindowTvf, find_window_tvf, rewrite_window_tvfs};
27use crate::{SqlError, SqlResult};
28
29#[derive(Debug, Clone, PartialEq, Eq)]
32pub struct StreamingWindowPlan {
33 pub spec: WindowExecutionSpec,
35 pub source: String,
37}
38
39fn unsupported(msg: impl Into<String>) -> SqlError {
40 SqlError::Unsupported {
41 feature: msg.into(),
42 }
43}
44
45fn parse_ms(raw: &str) -> SqlResult<u64> {
46 raw.trim().parse::<u64>().map_err(|_| {
47 unsupported(format!(
48 "window interval '{raw}' is not a millisecond count"
49 ))
50 })
51}
52
53pub fn is_windowed_streaming_sql(sql: &str) -> bool {
55 find_window_tvf(sql).is_some()
56}
57
58pub fn compile_streaming_window_sql(sql: &str) -> SqlResult<StreamingWindowPlan> {
63 let (_, tvf, _) = find_window_tvf(sql)
64 .ok_or_else(|| unsupported("query has no TUMBLE/HOP/SESSION window"))?;
65
66 let (source, event_time_column, window_kind, window_size_ms, slide_ms, session_gap_ms) =
67 match &tvf {
68 WindowTvf::Tumble {
69 source,
70 ts_col,
71 size_ms,
72 } => (
73 (*source).to_string(),
74 (*ts_col).to_string(),
75 WindowKind::Tumbling,
76 parse_ms(size_ms)?,
77 None,
78 None,
79 ),
80 WindowTvf::Hop {
81 source,
82 ts_col,
83 slide_ms,
84 size_ms,
85 } => (
86 (*source).to_string(),
87 (*ts_col).to_string(),
88 WindowKind::Sliding,
89 parse_ms(size_ms)?,
90 Some(parse_ms(slide_ms)?),
91 None,
92 ),
93 WindowTvf::Session {
94 source,
95 ts_col,
96 gap_ms,
97 } => {
98 let gap = parse_ms(gap_ms)?;
99 (
100 (*source).to_string(),
101 (*ts_col).to_string(),
102 WindowKind::Session,
103 gap,
104 None,
105 Some(gap),
106 )
107 }
108 };
109
110 let rewritten = rewrite_window_tvfs(sql);
113 let (key_column, agg_exprs) = extract_key_and_aggs(&rewritten)?;
114
115 let spec = WindowExecutionSpec {
116 key_column,
117 key_column_type: String::from("utf8"),
118 event_time_column,
119 watermark_lag_ms: 0,
120 window_kind,
121 window_size_ms,
122 slide_ms,
123 session_gap_ms,
124 agg_exprs,
125 state_ttl_ms: None,
126 allowed_lateness_ms: None,
127 source_watermark_lags: HashMap::new(),
128 source_id_column: None,
129 window_timezone: None,
130 };
131 Ok(StreamingWindowPlan { spec, source })
132}
133
134const WINDOW_BOUNDARY_COLS: [&str; 2] = ["window_start", "window_end"];
135
136fn extract_key_and_aggs(sql: &str) -> SqlResult<(String, Vec<WindowAgg>)> {
137 let dialect = GenericDialect {};
138 let stmts = Parser::parse_sql(&dialect, sql)
139 .map_err(|e| unsupported(format!("streaming window query parse error: {e}")))?;
140 let query = stmts
141 .into_iter()
142 .find_map(|s| match s {
143 Statement::Query(q) => Some(q),
144 _ => None,
145 })
146 .ok_or_else(|| unsupported("streaming window query must be a SELECT"))?;
147
148 let SetExpr::Select(select) = query.body.as_ref() else {
149 return Err(unsupported("streaming window query must be a plain SELECT"));
150 };
151
152 let mut key_column: Option<String> = None;
153 let mut aggs: Vec<WindowAgg> = Vec::new();
154
155 for item in &select.projection {
156 let (expr, alias) = match item {
157 SelectItem::UnnamedExpr(e) => (e, None),
158 SelectItem::ExprWithAlias { expr, alias } => (expr, Some(alias.value.clone())),
159 _ => continue,
160 };
161 match expr {
162 Expr::Function(f) => aggs.push(function_to_agg(f, alias)?),
163 Expr::Identifier(id) => maybe_set_key(&mut key_column, &id.value),
164 Expr::CompoundIdentifier(parts) => {
165 if let Some(last) = parts.last() {
166 maybe_set_key(&mut key_column, &last.value);
167 }
168 }
169 _ => continue,
170 }
171 }
172
173 let key_column = key_column.ok_or_else(|| {
174 unsupported("streaming window query needs a grouping key column in the SELECT list")
175 })?;
176 if aggs.is_empty() {
177 aggs.push(WindowAgg::count("count"));
178 }
179 Ok((key_column, aggs))
180}
181
182fn maybe_set_key(key: &mut Option<String>, name: &str) {
183 if key.is_none() && !WINDOW_BOUNDARY_COLS.contains(&name) {
184 *key = Some(name.to_string());
185 }
186}
187
188fn function_to_agg(f: &Function, alias: Option<String>) -> SqlResult<WindowAgg> {
189 let fname = f.name.to_string().to_ascii_lowercase();
190 let kind = match fname.as_str() {
191 "count" => WindowAggKind::Count,
192 "sum" => WindowAggKind::Sum,
193 "min" => WindowAggKind::Min,
194 "max" => WindowAggKind::Max,
195 "avg" => WindowAggKind::Avg,
196 "stddev" | "stddev_samp" => WindowAggKind::Stddev,
197 other => {
198 return Err(unsupported(format!(
199 "aggregate '{other}' is not supported in streaming windows; \
200 use count/sum/min/max/avg/stddev"
201 )));
202 }
203 };
204 let input_column = first_arg_column(f);
205 let output_column = alias.unwrap_or_else(|| match &input_column {
206 Some(col) => format!("{fname}_{col}"),
207 None => fname.clone(),
208 });
209 Ok(WindowAgg {
210 kind,
211 input_column: input_column.unwrap_or_default(),
212 output_column,
213 })
214}
215
216fn first_arg_column(f: &Function) -> Option<String> {
217 let FunctionArguments::List(list) = &f.args else {
218 return None;
219 };
220 for fa in &list.args {
221 let expr = match fa {
222 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
223 FunctionArg::Named {
224 arg: FunctionArgExpr::Expr(e),
225 ..
226 } => Some(e),
227 _ => None,
228 };
229 match expr {
230 Some(Expr::Identifier(id)) => return Some(id.value.clone()),
231 Some(Expr::CompoundIdentifier(parts)) => return parts.last().map(|p| p.value.clone()),
232 _ => {}
233 }
234 }
235 None
236}
237
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use super::*;

    /// Compiles `sql`, panicking if compilation fails.
    fn compile(sql: &str) -> StreamingWindowPlan {
        compile_streaming_window_sql(sql).unwrap()
    }

    /// Asserts that compiling `sql` fails with an `Unsupported` error.
    fn assert_unsupported(sql: &str) {
        let err = compile_streaming_window_sql(sql).unwrap_err();
        assert!(matches!(err, SqlError::Unsupported { .. }));
    }

    #[test]
    fn compiles_tumbling_window() {
        let plan = compile(
            "SELECT user_id, SUM(amount) AS total \
             FROM TUMBLE(TABLE events, DESCRIPTOR(ts), 60000) \
             GROUP BY user_id, window_start, window_end",
        );
        let spec = &plan.spec;
        assert_eq!(plan.source, "events");
        assert_eq!(spec.window_kind, WindowKind::Tumbling);
        assert_eq!(spec.window_size_ms, 60000);
        assert_eq!(spec.event_time_column, "ts");
        assert_eq!(spec.key_column, "user_id");
        assert_eq!(spec.agg_exprs.len(), 1);
        let agg = &spec.agg_exprs[0];
        assert_eq!(agg.kind, WindowAggKind::Sum);
        assert_eq!(agg.input_column, "amount");
        assert_eq!(agg.output_column, "total");
    }

    #[test]
    fn compiles_tumbling_window_with_stddev() {
        let plan = compile(
            "SELECT k, STDDEV(v) AS spread \
             FROM TUMBLE(TABLE m, DESCRIPTOR(ts), 60000) \
             GROUP BY k, window_start, window_end",
        );
        let agg = &plan.spec.agg_exprs[0];
        assert_eq!(agg.kind, WindowAggKind::Stddev);
        assert_eq!(agg.input_column, "v");
        assert_eq!(agg.output_column, "spread");
    }

    #[test]
    fn compiles_hop_window_with_slide() {
        let plan = compile(
            "SELECT k, COUNT(*) AS c \
             FROM HOP(TABLE clicks, DESCRIPTOR(ts), 30000, 60000) \
             GROUP BY k, window_start, window_end",
        );
        let spec = &plan.spec;
        assert_eq!(spec.window_kind, WindowKind::Sliding);
        assert_eq!(spec.window_size_ms, 60000);
        assert_eq!(spec.slide_ms, Some(30000));
        assert_eq!(spec.agg_exprs[0].kind, WindowAggKind::Count);
    }

    #[test]
    fn compiles_session_window_with_gap() {
        let plan = compile(
            "SELECT k, MAX(v) AS hi \
             FROM SESSION(TABLE events, DESCRIPTOR(ts), 15000) \
             GROUP BY k, window_start, window_end",
        );
        let spec = &plan.spec;
        assert_eq!(spec.window_kind, WindowKind::Session);
        assert_eq!(spec.session_gap_ms, Some(15000));
        assert_eq!(spec.agg_exprs[0].kind, WindowAggKind::Max);
    }

    #[test]
    fn non_windowed_query_is_unsupported() {
        assert_unsupported("SELECT a FROM t");
    }

    #[test]
    fn unsupported_aggregate_is_rejected() {
        assert_unsupported(
            "SELECT k, MEDIAN(v) AS s \
             FROM TUMBLE(TABLE events, DESCRIPTOR(ts), 60000) \
             GROUP BY k, window_start, window_end",
        );
    }

    #[test]
    fn window_boundary_columns_are_not_treated_as_key() {
        let plan = compile(
            "SELECT window_start, user_id, COUNT(*) AS c \
             FROM TUMBLE(TABLE events, DESCRIPTOR(ts), 60000) \
             GROUP BY user_id, window_start, window_end",
        );
        assert_eq!(plan.spec.key_column, "user_id");
    }

    #[test]
    fn detects_windowed_sql() {
        let windowed = "SELECT k FROM TUMBLE(TABLE t, DESCRIPTOR(ts), 1000) GROUP BY k";
        let plain = "SELECT k FROM t";
        assert!(is_windowed_streaming_sql(windowed));
        assert!(!is_windowed_streaming_sql(plain));
    }
}