1use crate::expr::SqlExpr;
7
/// Description of one window-function call to evaluate over a row set.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct WindowFuncSpec {
    /// Key under which the computed value is inserted into each row object.
    pub alias: String,
    /// Function name. The evaluator in this file dispatches on the exact
    /// strings `row_number`, `rank`, `dense_rank`, `lag`, `lead`, `ntile`,
    /// `sum`, `count`, `avg`, `min`, `max`, `first_value` and `last_value`;
    /// any other name is silently skipped.
    pub func_name: String,
    /// Call arguments. As read in this file: `ntile` looks for a literal
    /// bucket count first; `lag`/`lead` look for a column, then an optional
    /// literal offset, then an optional literal default; the aggregates look
    /// for a column first.
    pub args: Vec<SqlExpr>,
    /// Field names whose values define the partitions. Empty means the whole
    /// row set is one partition.
    pub partition_by: Vec<String>,
    /// Ordering columns. Only the column names are consulted in this file
    /// (by `rank` and `dense_rank`, to detect ties between adjacent rows).
    /// The boolean is presumably an ascending/descending flag — TODO confirm
    /// against the code that builds this spec.
    pub order_by: Vec<(String, bool)>,
    /// Window frame. Only the aggregate-style functions inspect it.
    pub frame: WindowFrame,
}
24
/// Frame clause of a window specification: a mode plus start and end bounds.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct WindowFrame {
    /// Frame mode as a string. This file only ever compares it against the
    /// exact value `"range"`.
    pub mode: String,
    /// Where the frame begins.
    pub start: FrameBound,
    /// Where the frame ends.
    pub end: FrameBound,
}
35
36impl Default for WindowFrame {
37 fn default() -> Self {
38 Self {
39 mode: "range".into(),
40 start: FrameBound::UnboundedPreceding,
41 end: FrameBound::CurrentRow,
42 }
43 }
44}
45
/// One endpoint of a window frame.
///
/// The `u64` payloads of `Preceding` and `Following` are not read anywhere in
/// this file; presumably they are an offset relative to the current row —
/// TODO confirm against the code that builds frames.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum FrameBound {
    /// The frame starts at the first row of the partition.
    UnboundedPreceding,
    /// An offset before the current row.
    Preceding(u64),
    /// The current row itself.
    CurrentRow,
    /// An offset after the current row.
    Following(u64),
    /// The frame ends at the last row of the partition.
    UnboundedFollowing,
}
55
56pub fn evaluate_window_functions(
61 rows: &mut [(String, serde_json::Value)],
62 specs: &[WindowFuncSpec],
63) {
64 for spec in specs {
65 let partitions = build_partitions(rows, &spec.partition_by);
66
67 for partition_indices in &partitions {
68 match spec.func_name.as_str() {
69 "row_number" => apply_row_number(rows, partition_indices, &spec.alias),
70 "rank" => apply_rank(rows, partition_indices, &spec.alias, &spec.order_by),
71 "dense_rank" => {
72 apply_dense_rank(rows, partition_indices, &spec.alias, &spec.order_by)
73 }
74 "lag" => apply_lag(rows, partition_indices, spec),
75 "lead" => apply_lead(rows, partition_indices, spec),
76 "ntile" => apply_ntile(rows, partition_indices, spec),
77 "sum" | "count" | "avg" | "min" | "max" | "first_value" | "last_value" => {
78 apply_aggregate_window(rows, partition_indices, spec)
79 }
80 _ => {}
81 }
82 }
83 }
84}
85
86fn build_partitions(
87 rows: &[(String, serde_json::Value)],
88 partition_by: &[String],
89) -> Vec<Vec<usize>> {
90 if partition_by.is_empty() {
91 return vec![(0..rows.len()).collect()];
92 }
93
94 let mut groups: std::collections::HashMap<String, Vec<usize>> =
95 std::collections::HashMap::new();
96 let mut order = Vec::new();
97
98 for (i, (_id, doc)) in rows.iter().enumerate() {
99 let key: String = partition_by
100 .iter()
101 .map(|col| {
102 doc.get(col)
103 .map(|v| v.to_string())
104 .unwrap_or_else(|| "null".to_string())
105 })
106 .collect::<Vec<_>>()
107 .join("\x00");
108 let entry = groups.entry(key.clone()).or_default();
109 if entry.is_empty() {
110 order.push(key);
111 }
112 entry.push(i);
113 }
114
115 order.iter().filter_map(|k| groups.remove(k)).collect()
116}
117
118fn set_window_col(row: &mut serde_json::Value, alias: &str, val: serde_json::Value) {
119 if let serde_json::Value::Object(map) = row {
120 map.insert(alias.to_string(), val);
121 }
122}
123
124fn get_field(doc: &serde_json::Value, field: &str) -> serde_json::Value {
125 doc.get(field).cloned().unwrap_or(serde_json::Value::Null)
126}
127
128fn as_f64(v: &serde_json::Value) -> Option<f64> {
129 match v {
130 serde_json::Value::Number(n) => n.as_f64(),
131 serde_json::Value::String(s) => s.parse().ok(),
132 _ => None,
133 }
134}
135
136fn apply_row_number(rows: &mut [(String, serde_json::Value)], indices: &[usize], alias: &str) {
137 for (rank, &i) in indices.iter().enumerate() {
138 set_window_col(&mut rows[i].1, alias, serde_json::json!(rank + 1));
139 }
140}
141
142fn apply_rank(
143 rows: &mut [(String, serde_json::Value)],
144 indices: &[usize],
145 alias: &str,
146 order_by: &[(String, bool)],
147) {
148 if indices.is_empty() {
149 return;
150 }
151 let mut current_rank = 1;
152 set_window_col(&mut rows[indices[0]].1, alias, serde_json::json!(1));
153
154 for pos in 1..indices.len() {
155 let prev = &rows[indices[pos - 1]].1;
156 let curr = &rows[indices[pos]].1;
157 let same = order_by
158 .iter()
159 .all(|(col, _)| get_field(prev, col) == get_field(curr, col));
160 if !same {
161 current_rank = pos + 1;
162 }
163 set_window_col(
164 &mut rows[indices[pos]].1,
165 alias,
166 serde_json::json!(current_rank),
167 );
168 }
169}
170
171fn apply_dense_rank(
172 rows: &mut [(String, serde_json::Value)],
173 indices: &[usize],
174 alias: &str,
175 order_by: &[(String, bool)],
176) {
177 if indices.is_empty() {
178 return;
179 }
180 let mut current_rank = 1;
181 set_window_col(&mut rows[indices[0]].1, alias, serde_json::json!(1));
182
183 for pos in 1..indices.len() {
184 let prev = &rows[indices[pos - 1]].1;
185 let curr = &rows[indices[pos]].1;
186 let same = order_by
187 .iter()
188 .all(|(col, _)| get_field(prev, col) == get_field(curr, col));
189 if !same {
190 current_rank += 1;
191 }
192 set_window_col(
193 &mut rows[indices[pos]].1,
194 alias,
195 serde_json::json!(current_rank),
196 );
197 }
198}
199
200fn apply_ntile(rows: &mut [(String, serde_json::Value)], indices: &[usize], spec: &WindowFuncSpec) {
201 let n = spec
202 .args
203 .first()
204 .and_then(|e| {
205 if let SqlExpr::Literal(v) = e {
206 as_f64(v).map(|x| x as usize)
207 } else {
208 None
209 }
210 })
211 .unwrap_or(1)
212 .max(1);
213 let total = indices.len();
214 for (pos, &i) in indices.iter().enumerate() {
215 let bucket = (pos * n / total) + 1;
216 set_window_col(&mut rows[i].1, &spec.alias, serde_json::json!(bucket));
217 }
218}
219
220fn apply_lag(rows: &mut [(String, serde_json::Value)], indices: &[usize], spec: &WindowFuncSpec) {
221 let field = spec
222 .args
223 .first()
224 .and_then(|e| {
225 if let SqlExpr::Column(c) = e {
226 Some(c.as_str())
227 } else {
228 None
229 }
230 })
231 .unwrap_or("*");
232 let offset = spec
233 .args
234 .get(1)
235 .and_then(|e| {
236 if let SqlExpr::Literal(v) = e {
237 as_f64(v).map(|n| n as usize)
238 } else {
239 None
240 }
241 })
242 .unwrap_or(1);
243 let default = spec
244 .args
245 .get(2)
246 .and_then(|e| {
247 if let SqlExpr::Literal(v) = e {
248 Some(v.clone())
249 } else {
250 None
251 }
252 })
253 .unwrap_or(serde_json::Value::Null);
254
255 for (pos, &i) in indices.iter().enumerate() {
256 let val = if pos >= offset {
257 get_field(&rows[indices[pos - offset]].1, field)
258 } else {
259 default.clone()
260 };
261 set_window_col(&mut rows[i].1, &spec.alias, val);
262 }
263}
264
265fn apply_lead(rows: &mut [(String, serde_json::Value)], indices: &[usize], spec: &WindowFuncSpec) {
266 let field = spec
267 .args
268 .first()
269 .and_then(|e| {
270 if let SqlExpr::Column(c) = e {
271 Some(c.as_str())
272 } else {
273 None
274 }
275 })
276 .unwrap_or("*");
277 let offset = spec
278 .args
279 .get(1)
280 .and_then(|e| {
281 if let SqlExpr::Literal(v) = e {
282 as_f64(v).map(|n| n as usize)
283 } else {
284 None
285 }
286 })
287 .unwrap_or(1);
288 let default = spec
289 .args
290 .get(2)
291 .and_then(|e| {
292 if let SqlExpr::Literal(v) = e {
293 Some(v.clone())
294 } else {
295 None
296 }
297 })
298 .unwrap_or(serde_json::Value::Null);
299
300 for (pos, &i) in indices.iter().enumerate() {
301 let val = if pos + offset < indices.len() {
302 get_field(&rows[indices[pos + offset]].1, field)
303 } else {
304 default.clone()
305 };
306 set_window_col(&mut rows[i].1, &spec.alias, val);
307 }
308}
309
310fn apply_aggregate_window(
311 rows: &mut [(String, serde_json::Value)],
312 indices: &[usize],
313 spec: &WindowFuncSpec,
314) {
315 let field = spec
316 .args
317 .first()
318 .and_then(|e| {
319 if let SqlExpr::Column(c) = e {
320 Some(c.as_str())
321 } else {
322 None
323 }
324 })
325 .unwrap_or("*");
326
327 let use_running = spec.frame.mode == "range"
328 && matches!(spec.frame.start, FrameBound::UnboundedPreceding)
329 && matches!(spec.frame.end, FrameBound::CurrentRow);
330
331 if use_running {
332 let mut running_sum = 0.0f64;
333 let mut running_count = 0u64;
334 let mut running_min: Option<f64> = None;
335 let mut running_max: Option<f64> = None;
336
337 for (pos, &i) in indices.iter().enumerate() {
338 let val = get_field(&rows[i].1, field);
339 if let Some(n) = as_f64(&val) {
340 running_sum += n;
341 running_count += 1;
342 running_min = Some(running_min.map_or(n, |m: f64| m.min(n)));
343 running_max = Some(running_max.map_or(n, |m: f64| m.max(n)));
344 } else if spec.func_name == "count" {
345 running_count += 1;
346 }
347
348 let result = match spec.func_name.as_str() {
349 "sum" => serde_json::json!(running_sum),
350 "count" => serde_json::json!(running_count),
351 "avg" => {
352 if running_count > 0 {
353 serde_json::json!(running_sum / running_count as f64)
354 } else {
355 serde_json::Value::Null
356 }
357 }
358 "min" => running_min
359 .map(|m| serde_json::json!(m))
360 .unwrap_or(serde_json::Value::Null),
361 "max" => running_max
362 .map(|m| serde_json::json!(m))
363 .unwrap_or(serde_json::Value::Null),
364 "first_value" => get_field(&rows[indices[0]].1, field),
365 "last_value" => get_field(&rows[indices[pos]].1, field),
366 _ => serde_json::Value::Null,
367 };
368 set_window_col(&mut rows[i].1, &spec.alias, result);
369 }
370 } else {
371 let values: Vec<f64> = indices
372 .iter()
373 .filter_map(|&i| as_f64(&get_field(&rows[i].1, field)))
374 .collect();
375
376 let result = match spec.func_name.as_str() {
377 "sum" => serde_json::json!(values.iter().sum::<f64>()),
378 "count" => serde_json::json!(indices.len()),
379 "avg" => {
380 if values.is_empty() {
381 serde_json::Value::Null
382 } else {
383 serde_json::json!(values.iter().sum::<f64>() / values.len() as f64)
384 }
385 }
386 "min" => values
387 .iter()
388 .copied()
389 .reduce(f64::min)
390 .map(|m| serde_json::json!(m))
391 .unwrap_or(serde_json::Value::Null),
392 "max" => values
393 .iter()
394 .copied()
395 .reduce(f64::max)
396 .map(|m| serde_json::json!(m))
397 .unwrap_or(serde_json::Value::Null),
398 "first_value" => get_field(&rows[indices[0]].1, field),
399 "last_value" => get_field(&rows[*indices.last().unwrap()].1, field),
400 _ => serde_json::Value::Null,
401 };
402
403 for &i in indices {
404 set_window_col(&mut rows[i].1, &spec.alias, result.clone());
405 }
406 }
407}
408
// Unit tests for the window-function evaluator.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Fixture: five rows, three in department "eng" followed by two in
    // "sales". Salaries are deliberately not in sorted order.
    fn make_rows() -> Vec<(String, serde_json::Value)> {
        vec![
            (
                "1".into(),
                json!({"dept": "eng", "salary": 100, "name": "Alice"}),
            ),
            (
                "2".into(),
                json!({"dept": "eng", "salary": 120, "name": "Bob"}),
            ),
            (
                "3".into(),
                json!({"dept": "eng", "salary": 90, "name": "Carol"}),
            ),
            (
                "4".into(),
                json!({"dept": "sales", "salary": 80, "name": "Dave"}),
            ),
            (
                "5".into(),
                json!({"dept": "sales", "salary": 110, "name": "Eve"}),
            ),
        ]
    }

    // Without PARTITION BY all rows form one partition, numbered 1 through 5.
    #[test]
    fn row_number_single_partition() {
        let mut rows = make_rows();
        let spec = WindowFuncSpec {
            alias: "rn".into(),
            func_name: "row_number".into(),
            args: vec![],
            partition_by: vec![],
            order_by: vec![],
            frame: WindowFrame::default(),
        };
        evaluate_window_functions(&mut rows, &[spec]);
        assert_eq!(rows[0].1["rn"], json!(1));
        assert_eq!(rows[4].1["rn"], json!(5));
    }

    // Partitioning by "dept" restarts the numbering at the first "sales" row.
    #[test]
    fn row_number_partitioned() {
        let mut rows = make_rows();
        let spec = WindowFuncSpec {
            alias: "rn".into(),
            func_name: "row_number".into(),
            args: vec![],
            partition_by: vec!["dept".into()],
            order_by: vec![],
            frame: WindowFrame::default(),
        };
        evaluate_window_functions(&mut rows, &[spec]);
        assert_eq!(rows[0].1["rn"], json!(1));
        assert_eq!(rows[2].1["rn"], json!(3));
        assert_eq!(rows[3].1["rn"], json!(1));
        assert_eq!(rows[4].1["rn"], json!(2));
    }

    // The default frame gives a cumulative sum per department. The totals
    // accumulate in input order (100, 220, 310), not in salary order, and
    // are floats even though the inputs are integers.
    #[test]
    fn running_sum() {
        let mut rows = make_rows();
        let spec = WindowFuncSpec {
            alias: "running_total".into(),
            func_name: "sum".into(),
            args: vec![SqlExpr::Column("salary".into())],
            partition_by: vec!["dept".into()],
            order_by: vec![("salary".into(), true)],
            frame: WindowFrame::default(),
        };
        evaluate_window_functions(&mut rows, &[spec]);
        assert_eq!(rows[0].1["running_total"], json!(100.0));
        assert_eq!(rows[1].1["running_total"], json!(220.0));
        assert_eq!(rows[2].1["running_total"], json!(310.0));
        assert_eq!(rows[3].1["running_total"], json!(80.0));
        assert_eq!(rows[4].1["running_total"], json!(190.0));
    }
}
491}