1use std::collections::HashMap;
10
11use nodedb_types::Value;
12
13use super::spec::WindowFuncSpec;
14use super::value_agg::apply_v_aggregate;
15use crate::expr::types::SqlExpr;
16use crate::value_ops::compare_values;
17
18#[derive(Debug, thiserror::Error)]
20pub enum WindowError {
21 #[error("window column '{name}' not found in result columns")]
22 ColumnNotFound { name: String },
23
24 #[error("window function argument error: {detail}")]
25 ArgEval { detail: String },
26
27 #[error("window frame error: {detail}")]
28 BadFrame { detail: String },
29}
30
31pub fn evaluate_window_functions_value(
37 rows: &mut [Vec<Value>],
38 column_index: &HashMap<String, usize>,
39 specs: &[WindowFuncSpec],
40) -> Result<Vec<String>, WindowError> {
41 let mut new_cols: Vec<String> = Vec::with_capacity(specs.len());
42
43 for spec in specs {
44 let partitions = build_value_partitions(rows, column_index, spec)?;
45 let write_col = rows.first().map(|r| r.len()).unwrap_or(0);
46
47 for row in rows.iter_mut() {
48 row.push(Value::Null);
49 }
50
51 for partition_indices in &partitions {
52 match spec.func_name.as_str() {
53 "row_number" => apply_v_row_number(rows, partition_indices, write_col),
54 "rank" => apply_v_rank(rows, partition_indices, column_index, spec, write_col),
55 "dense_rank" => {
56 apply_v_dense_rank(rows, partition_indices, column_index, spec, write_col)
57 }
58 "ntile" => apply_v_ntile(rows, partition_indices, spec, write_col)?,
59 "percent_rank" => {
60 apply_v_percent_rank(rows, partition_indices, column_index, spec, write_col)
61 }
62 "cume_dist" => {
63 apply_v_cume_dist(rows, partition_indices, column_index, spec, write_col)
64 }
65 "lag" => apply_v_lag(rows, partition_indices, column_index, spec, write_col)?,
66 "lead" => apply_v_lead(rows, partition_indices, column_index, spec, write_col)?,
67 "nth_value" => {
68 apply_v_nth_value(rows, partition_indices, column_index, spec, write_col)?
69 }
70 "sum" | "count" | "avg" | "min" | "max" | "first_value" | "last_value" => {
71 apply_v_aggregate(rows, partition_indices, column_index, spec, write_col)
72 }
73 other => {
74 return Err(WindowError::ArgEval {
75 detail: format!(
76 "unknown window function '{other}'; valid names: row_number, rank, \
77 dense_rank, ntile, percent_rank, cume_dist, lag, lead, nth_value, \
78 sum, count, avg, min, max, first_value, last_value"
79 ),
80 });
81 }
82 }
83 }
84
85 new_cols.push(spec.alias.clone());
86 }
87
88 Ok(new_cols)
89}
90
91fn build_value_partitions(
94 rows: &[Vec<Value>],
95 column_index: &HashMap<String, usize>,
96 spec: &WindowFuncSpec,
97) -> Result<Vec<Vec<usize>>, WindowError> {
98 if spec.partition_by.is_empty() {
99 return Ok(vec![(0..rows.len()).collect()]);
100 }
101
102 let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
103 let mut order: Vec<String> = Vec::new();
104
105 for (i, row) in rows.iter().enumerate() {
106 let key = partition_key(row, column_index, &spec.partition_by);
107 let entry = groups.entry(key.clone()).or_default();
108 if entry.is_empty() {
109 order.push(key);
110 }
111 entry.push(i);
112 }
113
114 Ok(order.iter().filter_map(|k| groups.remove(k)).collect())
115}
116
117fn partition_key(
118 row: &[Value],
119 column_index: &HashMap<String, usize>,
120 partition_by: &[SqlExpr],
121) -> String {
122 partition_by
123 .iter()
124 .map(|expr| {
125 let v = eval_arg_for_row(expr, row, column_index);
126 format!("{v:?}")
127 })
128 .collect::<Vec<_>>()
129 .join("\x00")
130}
131
132pub(super) fn cmp_values(a: &Value, b: &Value) -> std::cmp::Ordering {
135 match (a, b) {
136 (Value::Null, Value::Null) => std::cmp::Ordering::Equal,
137 (Value::Null, _) => std::cmp::Ordering::Less,
138 (_, Value::Null) => std::cmp::Ordering::Greater,
139 (va, vb) => compare_values(va, vb),
140 }
141}
142
143pub(super) fn order_keys_equal_v(
144 rows: &[Vec<Value>],
145 a: usize,
146 b: usize,
147 column_index: &HashMap<String, usize>,
148 order_by: &[(SqlExpr, bool)],
149) -> bool {
150 order_by.iter().all(|(expr, _)| {
151 let row_a = rows.get(a).map(|r| r.as_slice()).unwrap_or(&[]);
152 let row_b = rows.get(b).map(|r| r.as_slice()).unwrap_or(&[]);
153 let va = eval_arg_for_row(expr, row_a, column_index);
154 let vb = eval_arg_for_row(expr, row_b, column_index);
155 matches!(cmp_values(&va, &vb), std::cmp::Ordering::Equal)
156 })
157}
158
159pub(super) fn eval_arg_for_row(
162 expr: &SqlExpr,
163 row: &[Value],
164 column_index: &HashMap<String, usize>,
165) -> Value {
166 match expr {
167 SqlExpr::Column(name) => column_index
168 .get(name.as_str())
169 .and_then(|&idx| row.get(idx))
170 .cloned()
171 .unwrap_or(Value::Null),
172 SqlExpr::Literal(v) => v.clone(),
173 other => {
174 let doc = row_to_obj(row, column_index);
175 other.eval(&doc)
176 }
177 }
178}
179
180fn row_to_obj(row: &[Value], column_index: &HashMap<String, usize>) -> Value {
181 let mut map = HashMap::new();
182 for (name, &idx) in column_index {
183 if let Some(v) = row.get(idx) {
184 map.insert(name.clone(), v.clone());
185 }
186 }
187 Value::Object(map)
188}
189
190fn usize_arg(spec: &WindowFuncSpec, idx: usize, default: usize) -> usize {
191 spec.args
192 .get(idx)
193 .and_then(|e| match e {
194 SqlExpr::Literal(v) => v.as_f64().map(|n| n as usize),
195 _ => None,
196 })
197 .unwrap_or(default)
198}
199
200fn default_arg_value(spec: &WindowFuncSpec, idx: usize) -> Value {
201 spec.args
202 .get(idx)
203 .and_then(|e| match e {
204 SqlExpr::Literal(v) => Some(v.clone()),
205 _ => None,
206 })
207 .unwrap_or(Value::Null)
208}
209
210pub(super) fn set_cell(rows: &mut [Vec<Value>], row_idx: usize, col_idx: usize, val: Value) {
213 if let Some(row) = rows.get_mut(row_idx)
214 && let Some(cell) = row.get_mut(col_idx)
215 {
216 *cell = val;
217 }
218}
219
220fn apply_v_row_number(rows: &mut [Vec<Value>], indices: &[usize], write_col: usize) {
223 for (rank, &i) in indices.iter().enumerate() {
224 set_cell(rows, i, write_col, Value::Integer((rank + 1) as i64));
225 }
226}
227
228fn apply_v_rank(
229 rows: &mut [Vec<Value>],
230 indices: &[usize],
231 column_index: &HashMap<String, usize>,
232 spec: &WindowFuncSpec,
233 write_col: usize,
234) {
235 if indices.is_empty() {
236 return;
237 }
238 let mut current_rank = 1usize;
239 set_cell(rows, indices[0], write_col, Value::Integer(1));
240 for pos in 1..indices.len() {
241 if !order_keys_equal_v(
242 rows,
243 indices[pos - 1],
244 indices[pos],
245 column_index,
246 &spec.order_by,
247 ) {
248 current_rank = pos + 1;
249 }
250 set_cell(
251 rows,
252 indices[pos],
253 write_col,
254 Value::Integer(current_rank as i64),
255 );
256 }
257}
258
259fn apply_v_dense_rank(
260 rows: &mut [Vec<Value>],
261 indices: &[usize],
262 column_index: &HashMap<String, usize>,
263 spec: &WindowFuncSpec,
264 write_col: usize,
265) {
266 if indices.is_empty() {
267 return;
268 }
269 let mut current_rank = 1usize;
270 set_cell(rows, indices[0], write_col, Value::Integer(1));
271 for pos in 1..indices.len() {
272 if !order_keys_equal_v(
273 rows,
274 indices[pos - 1],
275 indices[pos],
276 column_index,
277 &spec.order_by,
278 ) {
279 current_rank += 1;
280 }
281 set_cell(
282 rows,
283 indices[pos],
284 write_col,
285 Value::Integer(current_rank as i64),
286 );
287 }
288}
289
290fn apply_v_ntile(
291 rows: &mut [Vec<Value>],
292 indices: &[usize],
293 spec: &WindowFuncSpec,
294 write_col: usize,
295) -> Result<(), WindowError> {
296 let n = usize_arg(spec, 0, 1).max(1);
297 let total = indices.len();
298 if total == 0 {
299 return Ok(());
300 }
301 for (pos, &i) in indices.iter().enumerate() {
302 let bucket = (pos * n / total) + 1;
303 set_cell(rows, i, write_col, Value::Integer(bucket as i64));
304 }
305 Ok(())
306}
307
308fn apply_v_percent_rank(
309 rows: &mut [Vec<Value>],
310 indices: &[usize],
311 column_index: &HashMap<String, usize>,
312 spec: &WindowFuncSpec,
313 write_col: usize,
314) {
315 let total = indices.len();
316 if total == 0 {
317 return;
318 }
319 if total == 1 {
320 set_cell(rows, indices[0], write_col, Value::Float(0.0));
321 return;
322 }
323 let denom = (total - 1) as f64;
324 let mut current_rank = 1usize;
325 set_cell(rows, indices[0], write_col, Value::Float(0.0));
326 for pos in 1..total {
327 if !order_keys_equal_v(
328 rows,
329 indices[pos - 1],
330 indices[pos],
331 column_index,
332 &spec.order_by,
333 ) {
334 current_rank = pos + 1;
335 }
336 let pr = (current_rank - 1) as f64 / denom;
337 set_cell(rows, indices[pos], write_col, Value::Float(pr));
338 }
339}
340
341fn apply_v_cume_dist(
342 rows: &mut [Vec<Value>],
343 indices: &[usize],
344 column_index: &HashMap<String, usize>,
345 spec: &WindowFuncSpec,
346 write_col: usize,
347) {
348 let total = indices.len();
349 if total == 0 {
350 return;
351 }
352 let denom = total as f64;
353 let mut group_start = 0;
354 while group_start < total {
355 let mut group_end = group_start + 1;
356 while group_end < total
357 && order_keys_equal_v(
358 rows,
359 indices[group_start],
360 indices[group_end],
361 column_index,
362 &spec.order_by,
363 )
364 {
365 group_end += 1;
366 }
367 let cd = group_end as f64 / denom;
368 for &idx in &indices[group_start..group_end] {
369 set_cell(rows, idx, write_col, Value::Float(cd));
370 }
371 group_start = group_end;
372 }
373}
374
375fn collect_arg_values(
378 rows: &[Vec<Value>],
379 indices: &[usize],
380 column_index: &HashMap<String, usize>,
381 spec: &WindowFuncSpec,
382) -> Vec<Value> {
383 indices
384 .iter()
385 .map(|&i| {
386 rows.get(i)
387 .map(|row| {
388 spec.args
389 .first()
390 .map(|expr| eval_arg_for_row(expr, row, column_index))
391 .unwrap_or(Value::Null)
392 })
393 .unwrap_or(Value::Null)
394 })
395 .collect()
396}
397
398fn apply_v_lag(
399 rows: &mut [Vec<Value>],
400 indices: &[usize],
401 column_index: &HashMap<String, usize>,
402 spec: &WindowFuncSpec,
403 write_col: usize,
404) -> Result<(), WindowError> {
405 let offset = usize_arg(spec, 1, 1);
406 let default = default_arg_value(spec, 2);
407 let values = collect_arg_values(rows, indices, column_index, spec);
408 for (pos, &i) in indices.iter().enumerate() {
409 let val = if pos >= offset {
410 values[pos - offset].clone()
411 } else {
412 default.clone()
413 };
414 set_cell(rows, i, write_col, val);
415 }
416 Ok(())
417}
418
419fn apply_v_lead(
420 rows: &mut [Vec<Value>],
421 indices: &[usize],
422 column_index: &HashMap<String, usize>,
423 spec: &WindowFuncSpec,
424 write_col: usize,
425) -> Result<(), WindowError> {
426 let offset = usize_arg(spec, 1, 1);
427 let default = default_arg_value(spec, 2);
428 let values = collect_arg_values(rows, indices, column_index, spec);
429 for (pos, &i) in indices.iter().enumerate() {
430 let val = if pos + offset < indices.len() {
431 values[pos + offset].clone()
432 } else {
433 default.clone()
434 };
435 set_cell(rows, i, write_col, val);
436 }
437 Ok(())
438}
439
440fn apply_v_nth_value(
441 rows: &mut [Vec<Value>],
442 indices: &[usize],
443 column_index: &HashMap<String, usize>,
444 spec: &WindowFuncSpec,
445 write_col: usize,
446) -> Result<(), WindowError> {
447 let n = usize_arg(spec, 1, 1).max(1);
448 let values = collect_arg_values(rows, indices, column_index, spec);
449 for (pos, &i) in indices.iter().enumerate() {
450 let val = if pos + 1 >= n {
451 values[n - 1].clone()
452 } else {
453 Value::Null
454 };
455 set_cell(rows, i, write_col, val);
456 }
457 Ok(())
458}
459
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::types::SqlExpr;
    use crate::window::spec::WindowFrame;

    /// Builds a column index in which each name maps to its position in `names`.
    fn ci(names: &[&str]) -> HashMap<String, usize> {
        let mut index = HashMap::with_capacity(names.len());
        for (pos, name) in names.iter().enumerate() {
            index.insert(name.to_string(), pos);
        }
        index
    }

    /// Builds a spec for `func` with the default frame and alias `w_<func>`.
    fn spec(
        func: &str,
        args: Vec<SqlExpr>,
        partition_by: Vec<SqlExpr>,
        order_by: Vec<(SqlExpr, bool)>,
    ) -> WindowFuncSpec {
        let alias = format!("w_{func}");
        let func_name = func.to_string();
        WindowFuncSpec {
            alias,
            func_name,
            args,
            partition_by,
            order_by,
            frame: WindowFrame::default(),
        }
    }

    /// Column-reference expression for `name`.
    fn col(name: &str) -> SqlExpr {
        SqlExpr::Column(name.to_string())
    }

    /// One single-cell integer row per input value.
    fn rows_v(vals: &[i64]) -> Vec<Vec<Value>> {
        let mut rows = Vec::with_capacity(vals.len());
        for &v in vals {
            rows.push(vec![Value::Integer(v)]);
        }
        rows
    }

    /// Reads column `col_idx` of every row as an integer, panicking otherwise.
    fn out_int(rows: &[Vec<Value>], col_idx: usize) -> Vec<i64> {
        let mut out = Vec::with_capacity(rows.len());
        for r in rows {
            match &r[col_idx] {
                Value::Integer(n) => out.push(*n),
                other => panic!("expected integer, got {other:?}"),
            }
        }
        out
    }

    /// Reads column `col_idx` of every row as a float, panicking otherwise.
    fn out_f64(rows: &[Vec<Value>], col_idx: usize) -> Vec<f64> {
        let mut out = Vec::with_capacity(rows.len());
        for r in rows {
            out.push(r[col_idx].as_f64().unwrap());
        }
        out
    }

    /// Asserts two float slices match element-wise within 1e-9.
    fn assert_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < 1e-9);
        }
    }

    #[test]
    fn row_number_sequential() {
        let cols = ci(&["v"]);
        let mut rows = rows_v(&[5, 5, 5]);
        let specs = [spec("row_number", vec![], vec![], vec![])];
        evaluate_window_functions_value(&mut rows, &cols, &specs).unwrap();
        assert_eq!(out_int(&rows, 1), vec![1, 2, 3]);
    }

    #[test]
    fn rank_handles_ties() {
        let cols = ci(&["v"]);
        let mut rows = rows_v(&[10, 10, 20]);
        let specs = [spec("rank", vec![], vec![], vec![(col("v"), true)])];
        evaluate_window_functions_value(&mut rows, &cols, &specs).unwrap();
        assert_eq!(out_int(&rows, 1), vec![1, 1, 3]);
    }

    #[test]
    fn dense_rank_handles_ties() {
        let cols = ci(&["v"]);
        let mut rows = rows_v(&[10, 10, 20]);
        let specs = [spec("dense_rank", vec![], vec![], vec![(col("v"), true)])];
        evaluate_window_functions_value(&mut rows, &cols, &specs).unwrap();
        assert_eq!(out_int(&rows, 1), vec![1, 1, 2]);
    }

    #[test]
    fn ntile_buckets() {
        let cols = ci(&["v"]);
        let mut rows = rows_v(&[1, 2, 3]);
        let bucket_count = SqlExpr::Literal(Value::Integer(2));
        let specs = [spec(
            "ntile",
            vec![bucket_count],
            vec![],
            vec![(col("v"), true)],
        )];
        evaluate_window_functions_value(&mut rows, &cols, &specs).unwrap();
        assert_eq!(out_int(&rows, 1), vec![1, 1, 2]);
    }

    #[test]
    fn lag_default_and_offset() {
        let cols = ci(&["v"]);
        let mut rows = rows_v(&[1, 2, 3]);
        let specs = [spec("lag", vec![col("v")], vec![], vec![(col("v"), true)])];
        evaluate_window_functions_value(&mut rows, &cols, &specs).unwrap();
        // The first row has no predecessor, so it takes the (Null) default.
        assert!(matches!(rows[0][1], Value::Null));
        assert_eq!(rows[1][1].as_f64().unwrap() as i64, 1);
        assert_eq!(rows[2][1].as_f64().unwrap() as i64, 2);
    }

    #[test]
    fn lead_boundary() {
        let cols = ci(&["v"]);
        let mut rows = rows_v(&[1, 2, 3]);
        let specs = [spec("lead", vec![col("v")], vec![], vec![(col("v"), true)])];
        evaluate_window_functions_value(&mut rows, &cols, &specs).unwrap();
        assert_eq!(rows[0][1].as_f64().unwrap() as i64, 2);
        assert_eq!(rows[1][1].as_f64().unwrap() as i64, 3);
        // The last row has no successor, so it takes the (Null) default.
        assert!(matches!(rows[2][1], Value::Null));
    }

    #[test]
    fn percent_rank_and_cume_dist() {
        let cols = ci(&["v"]);

        let mut rows = rows_v(&[10, 10, 20]);
        let pr = [spec("percent_rank", vec![], vec![], vec![(col("v"), true)])];
        evaluate_window_functions_value(&mut rows, &cols, &pr).unwrap();
        assert_close(&out_f64(&rows, 1), &[0.0, 0.0, 1.0]);

        let mut rows = rows_v(&[10, 10, 20]);
        let cd = [spec("cume_dist", vec![], vec![], vec![(col("v"), true)])];
        evaluate_window_functions_value(&mut rows, &cols, &cd).unwrap();
        assert_close(&out_f64(&rows, 1), &[2.0 / 3.0, 2.0 / 3.0, 1.0]);
    }

    #[test]
    fn partition_resets_row_number() {
        let cols = ci(&["g", "v"]);
        let mut rows: Vec<Vec<Value>> = [(1, 100), (1, 101), (2, 102)]
            .iter()
            .map(|&(g, v)| vec![Value::Integer(g), Value::Integer(v)])
            .collect();
        let specs = [spec("row_number", vec![], vec![col("g")], vec![])];
        evaluate_window_functions_value(&mut rows, &cols, &specs).unwrap();
        // Numbering restarts when the partition column `g` changes.
        assert_eq!(out_int(&rows, 2), vec![1, 2, 1]);
    }

    #[test]
    fn unknown_function_errors() {
        let cols = ci(&["v"]);
        let mut rows = rows_v(&[1]);
        let specs = [spec("nonexistent", vec![], vec![], vec![])];
        let err = evaluate_window_functions_value(&mut rows, &cols, &specs).unwrap_err();
        assert!(matches!(err, WindowError::ArgEval { .. }));
    }

    #[test]
    fn missing_partition_column_is_null_keyed() {
        let cols = ci(&["v"]);
        // An unknown partition column evaluates to Null for every row, so all
        // rows land in a single partition.
        let mut rows = rows_v(&[1, 2]);
        let specs = [spec("row_number", vec![], vec![col("missing")], vec![])];
        evaluate_window_functions_value(&mut rows, &cols, &specs).unwrap();
        assert_eq!(out_int(&rows, 1), vec![1, 2]);
    }
}