1use crate::expr::SqlExpr;
7
/// One window-function call to evaluate, e.g. `rank() OVER (PARTITION BY … ORDER BY …)`.
///
/// NOTE(review): the `zerompk` MessagePack derives may encode fields positionally —
/// confirm before reordering or inserting fields.
#[derive(
    Debug,
    Clone,
    serde::Serialize,
    serde::Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
pub struct WindowFuncSpec {
    /// Key under which the result is written into each row's JSON object.
    pub alias: String,
    /// Function name dispatched by `evaluate_window_functions` (e.g. `"row_number"`,
    /// `"lag"`, `"sum"`); matched exactly, so case matters.
    pub func_name: String,
    /// Call arguments. `lag`/`lead` and the aggregates take their source column from the
    /// first argument (a column reference); `ntile` takes its bucket count from the first
    /// argument (a literal); `lag`/`lead` take offset and default from the second and third.
    pub args: Vec<SqlExpr>,
    /// Columns whose values split the rows into independent partitions.
    pub partition_by: Vec<String>,
    /// ORDER BY columns with a per-column flag. Only the column names are read by the
    /// evaluation code in this file (for peer detection); presumably the flag means
    /// ascending — TODO confirm against the planner.
    pub order_by: Vec<(String, bool)>,
    /// Frame clause used by the aggregate / value window functions.
    pub frame: WindowFrame,
}
31
/// Frame clause of a window: which rows around the current row feed the aggregate
/// window functions.
///
/// NOTE(review): the `zerompk` MessagePack derives may encode fields positionally —
/// confirm before reordering fields.
#[derive(
    Debug,
    Clone,
    serde::Serialize,
    serde::Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
pub struct WindowFrame {
    /// Frame unit as a string; `Default` sets `"range"`. Values are not validated here.
    pub mode: String,
    /// Lower edge of the frame.
    pub start: FrameBound,
    /// Upper edge of the frame.
    pub end: FrameBound,
}
49
50impl Default for WindowFrame {
51 fn default() -> Self {
52 Self {
53 mode: "range".into(),
54 start: FrameBound::UnboundedPreceding,
55 end: FrameBound::CurrentRow,
56 }
57 }
58}
59
/// One edge of a window frame, expressed relative to the current row.
///
/// NOTE(review): the `zerompk` MessagePack derives may encode variants by position —
/// confirm before reordering variants.
#[derive(
    Debug,
    Clone,
    serde::Serialize,
    serde::Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
pub enum FrameBound {
    /// SQL `UNBOUNDED PRECEDING`.
    UnboundedPreceding,
    /// SQL `n PRECEDING`.
    Preceding(u64),
    /// SQL `CURRENT ROW`.
    CurrentRow,
    /// SQL `n FOLLOWING`.
    Following(u64),
    /// SQL `UNBOUNDED FOLLOWING`.
    UnboundedFollowing,
}
76
77pub fn evaluate_window_functions(
82 rows: &mut [(String, serde_json::Value)],
83 specs: &[WindowFuncSpec],
84) {
85 for spec in specs {
86 let partitions = build_partitions(rows, &spec.partition_by);
87
88 for partition_indices in &partitions {
89 match spec.func_name.as_str() {
90 "row_number" => apply_row_number(rows, partition_indices, &spec.alias),
91 "rank" => apply_rank(rows, partition_indices, &spec.alias, &spec.order_by),
92 "dense_rank" => {
93 apply_dense_rank(rows, partition_indices, &spec.alias, &spec.order_by)
94 }
95 "lag" => apply_lag(rows, partition_indices, spec),
96 "lead" => apply_lead(rows, partition_indices, spec),
97 "ntile" => apply_ntile(rows, partition_indices, spec),
98 "sum" | "count" | "avg" | "min" | "max" | "first_value" | "last_value" => {
99 apply_aggregate_window(rows, partition_indices, spec)
100 }
101 _ => {}
102 }
103 }
104 }
105}
106
107fn build_partitions(
108 rows: &[(String, serde_json::Value)],
109 partition_by: &[String],
110) -> Vec<Vec<usize>> {
111 if partition_by.is_empty() {
112 return vec![(0..rows.len()).collect()];
113 }
114
115 let mut groups: std::collections::HashMap<String, Vec<usize>> =
116 std::collections::HashMap::new();
117 let mut order = Vec::new();
118
119 for (i, (_id, doc)) in rows.iter().enumerate() {
120 let key: String = partition_by
121 .iter()
122 .map(|col| {
123 doc.get(col)
124 .map(|v| v.to_string())
125 .unwrap_or_else(|| "null".to_string())
126 })
127 .collect::<Vec<_>>()
128 .join("\x00");
129 let entry = groups.entry(key.clone()).or_default();
130 if entry.is_empty() {
131 order.push(key);
132 }
133 entry.push(i);
134 }
135
136 order.iter().filter_map(|k| groups.remove(k)).collect()
137}
138
139fn set_window_col(row: &mut serde_json::Value, alias: &str, val: serde_json::Value) {
140 if let serde_json::Value::Object(map) = row {
141 map.insert(alias.to_string(), val);
142 }
143}
144
145fn get_field(doc: &serde_json::Value, field: &str) -> serde_json::Value {
146 doc.get(field).cloned().unwrap_or(serde_json::Value::Null)
147}
148
149fn as_f64(v: &serde_json::Value) -> Option<f64> {
150 match v {
151 serde_json::Value::Number(n) => n.as_f64(),
152 serde_json::Value::String(s) => s.parse().ok(),
153 _ => None,
154 }
155}
156
157fn apply_row_number(rows: &mut [(String, serde_json::Value)], indices: &[usize], alias: &str) {
158 for (rank, &i) in indices.iter().enumerate() {
159 set_window_col(&mut rows[i].1, alias, serde_json::json!(rank + 1));
160 }
161}
162
163fn apply_rank(
164 rows: &mut [(String, serde_json::Value)],
165 indices: &[usize],
166 alias: &str,
167 order_by: &[(String, bool)],
168) {
169 if indices.is_empty() {
170 return;
171 }
172 let mut current_rank = 1;
173 set_window_col(&mut rows[indices[0]].1, alias, serde_json::json!(1));
174
175 for pos in 1..indices.len() {
176 let prev = &rows[indices[pos - 1]].1;
177 let curr = &rows[indices[pos]].1;
178 let same = order_by
179 .iter()
180 .all(|(col, _)| get_field(prev, col) == get_field(curr, col));
181 if !same {
182 current_rank = pos + 1;
183 }
184 set_window_col(
185 &mut rows[indices[pos]].1,
186 alias,
187 serde_json::json!(current_rank),
188 );
189 }
190}
191
192fn apply_dense_rank(
193 rows: &mut [(String, serde_json::Value)],
194 indices: &[usize],
195 alias: &str,
196 order_by: &[(String, bool)],
197) {
198 if indices.is_empty() {
199 return;
200 }
201 let mut current_rank = 1;
202 set_window_col(&mut rows[indices[0]].1, alias, serde_json::json!(1));
203
204 for pos in 1..indices.len() {
205 let prev = &rows[indices[pos - 1]].1;
206 let curr = &rows[indices[pos]].1;
207 let same = order_by
208 .iter()
209 .all(|(col, _)| get_field(prev, col) == get_field(curr, col));
210 if !same {
211 current_rank += 1;
212 }
213 set_window_col(
214 &mut rows[indices[pos]].1,
215 alias,
216 serde_json::json!(current_rank),
217 );
218 }
219}
220
221fn apply_ntile(rows: &mut [(String, serde_json::Value)], indices: &[usize], spec: &WindowFuncSpec) {
222 let n = spec
223 .args
224 .first()
225 .and_then(|e| {
226 if let SqlExpr::Literal(v) = e {
227 v.as_f64().map(|x| x as usize)
228 } else {
229 None
230 }
231 })
232 .unwrap_or(1)
233 .max(1);
234 let total = indices.len();
235 for (pos, &i) in indices.iter().enumerate() {
236 let bucket = (pos * n / total) + 1;
237 set_window_col(&mut rows[i].1, &spec.alias, serde_json::json!(bucket));
238 }
239}
240
241fn apply_lag(rows: &mut [(String, serde_json::Value)], indices: &[usize], spec: &WindowFuncSpec) {
242 let field = spec
243 .args
244 .first()
245 .and_then(|e| {
246 if let SqlExpr::Column(c) = e {
247 Some(c.as_str())
248 } else {
249 None
250 }
251 })
252 .unwrap_or("*");
253 let offset = spec
254 .args
255 .get(1)
256 .and_then(|e| {
257 if let SqlExpr::Literal(v) = e {
258 v.as_f64().map(|n| n as usize)
259 } else {
260 None
261 }
262 })
263 .unwrap_or(1);
264 let default = spec
265 .args
266 .get(2)
267 .and_then(|e| {
268 if let SqlExpr::Literal(v) = e {
269 Some(serde_json::Value::from(v.clone()))
270 } else {
271 None
272 }
273 })
274 .unwrap_or(serde_json::Value::Null);
275
276 for (pos, &i) in indices.iter().enumerate() {
277 let val = if pos >= offset {
278 get_field(&rows[indices[pos - offset]].1, field)
279 } else {
280 default.clone()
281 };
282 set_window_col(&mut rows[i].1, &spec.alias, val);
283 }
284}
285
286fn apply_lead(rows: &mut [(String, serde_json::Value)], indices: &[usize], spec: &WindowFuncSpec) {
287 let field = spec
288 .args
289 .first()
290 .and_then(|e| {
291 if let SqlExpr::Column(c) = e {
292 Some(c.as_str())
293 } else {
294 None
295 }
296 })
297 .unwrap_or("*");
298 let offset = spec
299 .args
300 .get(1)
301 .and_then(|e| {
302 if let SqlExpr::Literal(v) = e {
303 v.as_f64().map(|n| n as usize)
304 } else {
305 None
306 }
307 })
308 .unwrap_or(1);
309 let default = spec
310 .args
311 .get(2)
312 .and_then(|e| {
313 if let SqlExpr::Literal(v) = e {
314 Some(serde_json::Value::from(v.clone()))
315 } else {
316 None
317 }
318 })
319 .unwrap_or(serde_json::Value::Null);
320
321 for (pos, &i) in indices.iter().enumerate() {
322 let val = if pos + offset < indices.len() {
323 get_field(&rows[indices[pos + offset]].1, field)
324 } else {
325 default.clone()
326 };
327 set_window_col(&mut rows[i].1, &spec.alias, val);
328 }
329}
330
331fn apply_aggregate_window(
332 rows: &mut [(String, serde_json::Value)],
333 indices: &[usize],
334 spec: &WindowFuncSpec,
335) {
336 let field = spec
337 .args
338 .first()
339 .and_then(|e| {
340 if let SqlExpr::Column(c) = e {
341 Some(c.as_str())
342 } else {
343 None
344 }
345 })
346 .unwrap_or("*");
347
348 let use_running = spec.frame.mode == "range"
349 && matches!(spec.frame.start, FrameBound::UnboundedPreceding)
350 && matches!(spec.frame.end, FrameBound::CurrentRow);
351
352 if use_running {
353 let mut running_sum = 0.0f64;
354 let mut running_count = 0u64;
355 let mut running_min: Option<f64> = None;
356 let mut running_max: Option<f64> = None;
357
358 for (pos, &i) in indices.iter().enumerate() {
359 let val = get_field(&rows[i].1, field);
360 if let Some(n) = as_f64(&val) {
361 running_sum += n;
362 running_count += 1;
363 running_min = Some(running_min.map_or(n, |m: f64| m.min(n)));
364 running_max = Some(running_max.map_or(n, |m: f64| m.max(n)));
365 } else if spec.func_name == "count" {
366 running_count += 1;
367 }
368
369 let result = match spec.func_name.as_str() {
370 "sum" => serde_json::json!(running_sum),
371 "count" => serde_json::json!(running_count),
372 "avg" => {
373 if running_count > 0 {
374 serde_json::json!(running_sum / running_count as f64)
375 } else {
376 serde_json::Value::Null
377 }
378 }
379 "min" => running_min
380 .map(|m| serde_json::json!(m))
381 .unwrap_or(serde_json::Value::Null),
382 "max" => running_max
383 .map(|m| serde_json::json!(m))
384 .unwrap_or(serde_json::Value::Null),
385 "first_value" => get_field(&rows[indices[0]].1, field),
386 "last_value" => get_field(&rows[indices[pos]].1, field),
387 _ => serde_json::Value::Null,
388 };
389 set_window_col(&mut rows[i].1, &spec.alias, result);
390 }
391 } else {
392 let values: Vec<f64> = indices
393 .iter()
394 .filter_map(|&i| as_f64(&get_field(&rows[i].1, field)))
395 .collect();
396
397 let rt = crate::simd_agg::ts_runtime();
399 let result = match spec.func_name.as_str() {
400 "sum" => serde_json::json!((rt.sum_f64)(&values)),
401 "count" => serde_json::json!(indices.len()),
402 "avg" => {
403 if values.is_empty() {
404 serde_json::Value::Null
405 } else {
406 serde_json::json!((rt.sum_f64)(&values) / values.len() as f64)
407 }
408 }
409 "min" => {
410 if values.is_empty() {
411 serde_json::Value::Null
412 } else {
413 serde_json::json!((rt.min_f64)(&values))
414 }
415 }
416 "max" => {
417 if values.is_empty() {
418 serde_json::Value::Null
419 } else {
420 serde_json::json!((rt.max_f64)(&values))
421 }
422 }
423 "first_value" => indices
424 .first()
425 .map(|&i| get_field(&rows[i].1, field))
426 .unwrap_or(serde_json::Value::Null),
427 "last_value" => indices
428 .last()
429 .map(|&i| get_field(&rows[i].1, field))
430 .unwrap_or(serde_json::Value::Null),
431 _ => serde_json::Value::Null,
432 };
433
434 for &i in indices {
435 set_window_col(&mut rows[i].1, &spec.alias, result.clone());
436 }
437 }
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    fn make_rows() -> Vec<(String, serde_json::Value)> {
        vec![
            (
                "1".into(),
                json!({"dept": "eng", "salary": 100, "name": "Alice"}),
            ),
            (
                "2".into(),
                json!({"dept": "eng", "salary": 120, "name": "Bob"}),
            ),
            (
                "3".into(),
                json!({"dept": "eng", "salary": 90, "name": "Carol"}),
            ),
            (
                "4".into(),
                json!({"dept": "sales", "salary": 80, "name": "Dave"}),
            ),
            (
                "5".into(),
                json!({"dept": "sales", "salary": 110, "name": "Eve"}),
            ),
        ]
    }

    /// Builds a spec with the default frame; every ORDER BY flag is `true`.
    fn spec(
        alias: &str,
        func_name: &str,
        args: Vec<SqlExpr>,
        partition_by: &[&str],
        order_by: &[&str],
    ) -> WindowFuncSpec {
        WindowFuncSpec {
            alias: alias.into(),
            func_name: func_name.into(),
            args,
            partition_by: partition_by.iter().map(|c| c.to_string()).collect(),
            order_by: order_by.iter().map(|c| (c.to_string(), true)).collect(),
            frame: WindowFrame::default(),
        }
    }

    /// Collects column `alias` from every row, in row order.
    fn column(rows: &[(String, serde_json::Value)], alias: &str) -> Vec<serde_json::Value> {
        rows.iter().map(|(_, doc)| doc[alias].clone()).collect()
    }

    #[test]
    fn row_number_single_partition() {
        let mut rows = make_rows();
        evaluate_window_functions(&mut rows, &[spec("rn", "row_number", vec![], &[], &[])]);
        assert_eq!(rows[0].1["rn"], json!(1));
        assert_eq!(rows[4].1["rn"], json!(5));
    }

    #[test]
    fn row_number_partitioned() {
        let mut rows = make_rows();
        evaluate_window_functions(
            &mut rows,
            &[spec("rn", "row_number", vec![], &["dept"], &[])],
        );
        assert_eq!(rows[0].1["rn"], json!(1));
        assert_eq!(rows[2].1["rn"], json!(3));
        assert_eq!(rows[3].1["rn"], json!(1));
        assert_eq!(rows[4].1["rn"], json!(2));
    }

    #[test]
    fn running_sum() {
        let mut rows = make_rows();
        let sum = spec(
            "running_total",
            "sum",
            vec![SqlExpr::Column("salary".into())],
            &["dept"],
            &["salary"],
        );
        evaluate_window_functions(&mut rows, &[sum]);
        assert_eq!(rows[0].1["running_total"], json!(100.0));
        assert_eq!(rows[1].1["running_total"], json!(220.0));
        assert_eq!(rows[2].1["running_total"], json!(310.0));
        assert_eq!(rows[3].1["running_total"], json!(80.0));
        assert_eq!(rows[4].1["running_total"], json!(190.0));
    }

    #[test]
    fn rank_and_dense_rank_share_ties() {
        let mut rows = make_rows();
        let specs = [
            spec("r", "rank", vec![], &[], &["dept"]),
            spec("dr", "dense_rank", vec![], &[], &["dept"]),
        ];
        evaluate_window_functions(&mut rows, &specs);
        assert_eq!(
            column(&rows, "r"),
            vec![json!(1), json!(1), json!(1), json!(4), json!(4)]
        );
        assert_eq!(
            column(&rows, "dr"),
            vec![json!(1), json!(1), json!(1), json!(2), json!(2)]
        );
    }

    #[test]
    fn lag_and_lead_stop_at_partition_edges() {
        let mut rows = make_rows();
        let name = || vec![SqlExpr::Column("name".into())];
        let specs = [
            spec("prev", "lag", name(), &["dept"], &[]),
            spec("next", "lead", name(), &["dept"], &[]),
        ];
        evaluate_window_functions(&mut rows, &specs);
        assert_eq!(
            column(&rows, "prev"),
            vec![json!(null), json!("Alice"), json!("Bob"), json!(null), json!("Dave")]
        );
        assert_eq!(
            column(&rows, "next"),
            vec![json!("Bob"), json!("Carol"), json!(null), json!("Eve"), json!(null)]
        );
    }
}