nodedb_sql/planner/window/
extract.rs1use std::collections::HashMap;
14
15use sqlparser::ast;
16
17use crate::error::{Result, SqlError};
18use crate::functions::registry::{FunctionCategory, FunctionRegistry};
19use crate::parser::normalize::{SCHEMA_QUALIFIED_MSG, normalize_ident};
20use crate::resolver::expr::convert_expr;
21use crate::types::{SortKey, WindowSpec};
22use nodedb_query::{FrameBound, WindowFrame};
23
24use super::frame::convert_window_frame;
25use super::named::{collect_named_windows, flatten_window_spec, resolve_named_def};
26
27pub fn extract_window_functions(
29 select: &ast::Select,
30 functions: &FunctionRegistry,
31) -> Result<Vec<WindowSpec>> {
32 let named = collect_named_windows(&select.named_window)?;
33 let mut specs = Vec::new();
34 for item in &select.projection {
35 let (expr, alias) = match item {
36 ast::SelectItem::UnnamedExpr(e) => (e, format!("{e}")),
37 ast::SelectItem::ExprWithAlias { expr, alias } => (expr, normalize_ident(alias)),
38 _ => continue,
39 };
40 if let ast::Expr::Function(func) = expr
41 && func.over.is_some()
42 {
43 specs.push(convert_window_spec(func, &alias, functions, &named)?);
44 }
45 }
46 Ok(specs)
47}
48
49fn convert_window_spec(
50 func: &ast::Function,
51 alias: &str,
52 functions: &FunctionRegistry,
53 named: &HashMap<String, &ast::NamedWindowExpr>,
54) -> Result<WindowSpec> {
55 if func.name.0.len() > 1 {
56 let qualified: String = func
57 .name
58 .0
59 .iter()
60 .map(|p| match p {
61 ast::ObjectNamePart::Identifier(ident) => ident.value.clone(),
62 _ => String::new(),
63 })
64 .collect::<Vec<_>>()
65 .join(".");
66 return Err(SqlError::Unsupported {
67 detail: format!(
68 "schema-qualified window function name '{qualified}': {SCHEMA_QUALIFIED_MSG}"
69 ),
70 });
71 }
72 let name = func
73 .name
74 .0
75 .iter()
76 .map(|p| match p {
77 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
78 _ => String::new(),
79 })
80 .collect::<Vec<_>>()
81 .join(".");
82
83 match functions.lookup(&name).map(|m| m.category) {
86 Some(FunctionCategory::Window) | Some(FunctionCategory::Aggregate) => {}
87 Some(FunctionCategory::Scalar) => {
88 return Err(SqlError::InvalidFunction {
89 detail: format!(
90 "function '{name}() OVER ()' does not exist as a window function \
91 (it is a scalar function)"
92 ),
93 });
94 }
95 None => {
96 return Err(SqlError::InvalidFunction {
97 detail: format!("function '{name}() OVER ()' does not exist"),
98 });
99 }
100 }
101
102 let args = match &func.args {
103 ast::FunctionArguments::List(args) => args
104 .args
105 .iter()
106 .filter_map(|a| match a {
107 ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => convert_expr(e).ok(),
108 _ => None,
109 })
110 .collect(),
111 _ => Vec::new(),
112 };
113
114 let flat = match &func.over {
116 Some(ast::WindowType::WindowSpec(spec)) => {
117 Some(flatten_window_spec(spec, named, &mut Vec::new())?)
118 }
119 Some(ast::WindowType::NamedWindow(ident)) => {
120 let n = normalize_ident(ident);
121 let mut seen = vec![n.clone()];
122 let base = resolve_named_def(&n, named, &mut seen)?;
123 Some(flatten_window_spec(base, named, &mut seen)?)
124 }
125 None => None,
126 };
127
128 let (partition_by, order_by, frame) = match flat {
129 Some(flat) => {
130 let pb = flat
131 .partition_by
132 .iter()
133 .map(convert_expr)
134 .collect::<Result<Vec<_>>>()?;
135 let ob = flat
136 .order_by
137 .iter()
138 .map(|o| {
139 Ok(SortKey {
140 expr: convert_expr(&o.expr)?,
141 ascending: o.options.asc.unwrap_or(true),
142 nulls_first: o.options.nulls_first.unwrap_or(false),
143 })
144 })
145 .collect::<Result<Vec<_>>>()?;
146 let frame = match &flat.frame {
147 Some(f) => convert_window_frame(f, &ob)?,
148 None => {
153 if ob.is_empty() {
154 WindowFrame {
155 mode: "range".into(),
156 start: FrameBound::UnboundedPreceding,
157 end: FrameBound::UnboundedFollowing,
158 }
159 } else {
160 WindowFrame::default()
161 }
162 }
163 };
164 (pb, ob, frame)
165 }
166 None => (
168 Vec::new(),
169 Vec::new(),
170 WindowFrame {
171 mode: "range".into(),
172 start: FrameBound::UnboundedPreceding,
173 end: FrameBound::UnboundedFollowing,
174 },
175 ),
176 };
177
178 Ok(WindowSpec {
179 function: name,
180 args,
181 partition_by,
182 order_by,
183 alias: alias.into(),
184 frame,
185 })
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::functions::registry::FunctionRegistry;
    use crate::parser::statement::parse_sql;

    /// Parses `sql` and returns its top-level SELECT, panicking otherwise.
    fn select_of(sql: &str) -> Box<ast::Select> {
        match parse_sql(sql).unwrap().into_iter().next().unwrap() {
            ast::Statement::Query(q) => match *q.body {
                ast::SetExpr::Select(s) => s,
                _ => panic!("not a SELECT"),
            },
            _ => panic!("not a query"),
        }
    }

    #[test]
    fn named_window_referenced_by_multiple_functions() {
        let reg = FunctionRegistry::new();
        let select = select_of(
            "SELECT first_value(price) OVER w AS o, last_value(price) OVER w AS c, sum(volume) OVER w AS v
             FROM ticks
             WINDOW w AS (PARTITION BY bucket ORDER BY ts ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
        );
        let specs = extract_window_functions(&select, &reg).unwrap();
        assert_eq!(specs.len(), 3);
        for s in &specs {
            assert_eq!(
                s.partition_by.len(),
                1,
                "partition by must be resolved from WINDOW clause"
            );
            assert_eq!(
                s.order_by.len(),
                1,
                "order by must be resolved from WINDOW clause"
            );
            assert_eq!(s.frame.mode, "rows");
            assert!(matches!(s.frame.start, FrameBound::UnboundedPreceding));
            assert!(matches!(s.frame.end, FrameBound::UnboundedFollowing));
        }
    }

    #[test]
    fn undefined_named_window_is_rejected() {
        let reg = FunctionRegistry::new();
        let select = select_of("SELECT row_number() OVER missing AS r FROM t");
        let err = extract_window_functions(&select, &reg).unwrap_err();
        assert!(
            format!("{err}").contains("missing"),
            "error must name the missing window: {err}"
        );
    }

    #[test]
    fn window_definition_referencing_another_resolves() {
        let reg = FunctionRegistry::new();
        let select = select_of(
            "SELECT sum(x) OVER w2 AS s FROM t WINDOW w1 AS (PARTITION BY a), w2 AS (w1 ORDER BY ts)",
        );
        let specs = extract_window_functions(&select, &reg).unwrap();
        assert_eq!(specs.len(), 1);
        assert_eq!(
            specs[0].partition_by.len(),
            1,
            "PARTITION BY inherited from w1"
        );
        assert_eq!(specs[0].order_by.len(), 1, "ORDER BY added by w2");
    }

    #[test]
    fn circular_named_window_is_rejected() {
        let reg = FunctionRegistry::new();
        let select = select_of("SELECT sum(x) OVER w1 AS s FROM t WINDOW w1 AS (w2), w2 AS (w1)");
        let err = extract_window_functions(&select, &reg).unwrap_err();
        assert!(
            format!("{err}").to_lowercase().contains("circular"),
            "got: {err}"
        );
    }

    #[test]
    fn ohlcv_shape_base_window_plus_derived_ordered_window() {
        let reg = FunctionRegistry::new();
        let select = select_of(
            "SELECT first_value(price) OVER w_ord AS o, max(price) OVER w AS h,
                    min(price) OVER w AS l, last_value(price) OVER w_ord AS c, sum(volume) OVER w AS v
             FROM ticks
             WINDOW w AS (PARTITION BY time_bucket('1m', ts), symbol),
                    w_ord AS (w ORDER BY ts ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
        );
        let specs = extract_window_functions(&select, &reg).unwrap();
        assert_eq!(specs.len(), 5);
        for s in &specs {
            assert_eq!(
                s.partition_by.len(),
                2,
                "{}: partition inherited from w",
                s.function
            );
        }
        for f in ["first_value", "last_value"] {
            let s = specs.iter().find(|s| s.function == f).unwrap();
            assert_eq!(s.order_by.len(), 1, "{f}: order by from w_ord");
            assert_eq!(s.frame.mode, "rows", "{f}: frame from w_ord");
            assert!(matches!(s.frame.start, FrameBound::UnboundedPreceding));
            assert!(matches!(s.frame.end, FrameBound::UnboundedFollowing));
        }
        for f in ["max", "min", "sum"] {
            let s = specs.iter().find(|s| s.function == f).unwrap();
            assert!(s.order_by.is_empty(), "{f}: no order by");
            assert_eq!(s.frame.mode, "range");
            assert!(matches!(s.frame.start, FrameBound::UnboundedPreceding));
            assert!(matches!(s.frame.end, FrameBound::UnboundedFollowing));
        }
    }

    #[test]
    fn inline_window_referencing_named_inherits_partition() {
        let reg = FunctionRegistry::new();
        let select = select_of(
            "SELECT sum(x) OVER (w ORDER BY ts) AS s FROM t WINDOW w AS (PARTITION BY a)",
        );
        let specs = extract_window_functions(&select, &reg).unwrap();
        assert_eq!(specs[0].partition_by.len(), 1);
        assert_eq!(specs[0].order_by.len(), 1);
    }

    #[test]
    fn unknown_window_function_is_rejected() {
        let reg = FunctionRegistry::new();
        let select = select_of("SELECT no_such_fn(x) OVER () AS r FROM t");
        let err = extract_window_functions(&select, &reg).unwrap_err();
        assert!(
            format!("{err}").contains("no_such_fn"),
            "error must name the unknown function: {err}"
        );
    }

    #[test]
    fn empty_over_spans_whole_partition() {
        let reg = FunctionRegistry::new();
        let select = select_of("SELECT row_number() OVER () AS r FROM t");
        let specs = extract_window_functions(&select, &reg).unwrap();
        assert_eq!(specs.len(), 1);
        assert!(specs[0].partition_by.is_empty());
        assert!(specs[0].order_by.is_empty());
        assert_eq!(specs[0].frame.mode, "range");
        assert!(matches!(specs[0].frame.start, FrameBound::UnboundedPreceding));
        assert!(matches!(specs[0].frame.end, FrameBound::UnboundedFollowing));
    }
}
320}