1use serde::{Deserialize, Serialize};
35
36use crate::sql::DatabaseType;
37use crate::types::SortOrder;
38
/// A complete window-function expression: the function call, the window it
/// runs over, and an optional output alias.
///
/// `WindowFunction::to_sql` renders it as `<function> <over>[ AS <alias>]`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct WindowFunction {
    /// The function being applied.
    pub function: WindowFn,
    /// The window specification the function is evaluated over.
    pub over: WindowSpec,
    /// Optional alias, appended as ` AS <alias>` when present.
    pub alias: Option<String>,
}
49
/// The function part of a window expression.
///
/// All `String` payloads are SQL fragments that `WindowFn::to_sql` inserts
/// into the generated text verbatim.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum WindowFn {
    /// `ROW_NUMBER()`.
    RowNumber,
    /// `RANK()`.
    Rank,
    /// `DENSE_RANK()`.
    DenseRank,
    /// `NTILE(n)`.
    Ntile(u32),
    /// `PERCENT_RANK()`.
    PercentRank,
    /// `CUME_DIST()`.
    CumeDist,

    /// `LAG(expr[, offset[, default]])`.
    Lag {
        /// Expression passed as the first argument.
        expr: String,
        /// Optional second argument.
        offset: Option<u32>,
        /// Optional default expression passed after the offset.
        default: Option<String>,
    },
    /// `LEAD(expr[, offset[, default]])`.
    Lead {
        /// Expression passed as the first argument.
        expr: String,
        /// Optional second argument.
        offset: Option<u32>,
        /// Optional default expression passed after the offset.
        default: Option<String>,
    },
    /// `FIRST_VALUE(expr)`.
    FirstValue(String),
    /// `LAST_VALUE(expr)`.
    LastValue(String),
    /// `NTH_VALUE(expr, n)`.
    NthValue(String, u32),

    /// `SUM(expr)`.
    Sum(String),
    /// `AVG(expr)`.
    Avg(String),
    /// `COUNT(expr)`.
    Count(String),
    /// `MIN(expr)`.
    Min(String),
    /// `MAX(expr)`.
    Max(String),
    /// An arbitrary function: `name(arg1, arg2, ...)`.
    Custom { name: String, args: Vec<String> },
}
101
102impl WindowFn {
103 pub fn to_sql(&self) -> String {
105 match self {
106 Self::RowNumber => "ROW_NUMBER()".to_string(),
107 Self::Rank => "RANK()".to_string(),
108 Self::DenseRank => "DENSE_RANK()".to_string(),
109 Self::Ntile(n) => format!("NTILE({})", n),
110 Self::PercentRank => "PERCENT_RANK()".to_string(),
111 Self::CumeDist => "CUME_DIST()".to_string(),
112 Self::Lag {
113 expr,
114 offset,
115 default,
116 } => {
117 let mut sql = format!("LAG({})", expr);
118 if let Some(off) = offset {
119 sql = format!("LAG({}, {})", expr, off);
120 if let Some(def) = default {
121 sql = format!("LAG({}, {}, {})", expr, off, def);
122 }
123 }
124 sql
125 }
126 Self::Lead {
127 expr,
128 offset,
129 default,
130 } => {
131 let mut sql = format!("LEAD({})", expr);
132 if let Some(off) = offset {
133 sql = format!("LEAD({}, {})", expr, off);
134 if let Some(def) = default {
135 sql = format!("LEAD({}, {}, {})", expr, off, def);
136 }
137 }
138 sql
139 }
140 Self::FirstValue(expr) => format!("FIRST_VALUE({})", expr),
141 Self::LastValue(expr) => format!("LAST_VALUE({})", expr),
142 Self::NthValue(expr, n) => format!("NTH_VALUE({}, {})", expr, n),
143 Self::Sum(expr) => format!("SUM({})", expr),
144 Self::Avg(expr) => format!("AVG({})", expr),
145 Self::Count(expr) => format!("COUNT({})", expr),
146 Self::Min(expr) => format!("MIN({})", expr),
147 Self::Max(expr) => format!("MAX({})", expr),
148 Self::Custom { name, args } => {
149 format!("{}({})", name, args.join(", "))
150 }
151 }
152 }
153}
154
/// The contents of an `OVER` clause.
///
/// When `window_name` is set, `WindowSpec::to_sql` renders `OVER <name>` and
/// ignores the remaining fields; otherwise the partitioning, ordering and
/// frame are rendered inline.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct WindowSpec {
    /// Name of a window to reference instead of an inline definition.
    pub window_name: Option<String>,
    /// `PARTITION BY` expressions, inserted into the SQL verbatim.
    pub partition_by: Vec<String>,
    /// `ORDER BY` terms, in rendering order.
    pub order_by: Vec<OrderSpec>,
    /// Optional frame clause.
    pub frame: Option<FrameClause>,
}
167
/// A single `ORDER BY` term inside a window specification.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OrderSpec {
    /// Expression to sort by, inserted into the SQL verbatim.
    pub expr: String,
    /// Sort direction of this term.
    pub direction: SortOrder,
    /// Optional explicit placement of NULL values.
    pub nulls: Option<NullsPosition>,
}
178
/// Explicit placement of NULL values within an `ORDER BY` term.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NullsPosition {
    /// `NULLS FIRST`.
    First,
    /// `NULLS LAST`.
    Last,
}
187
/// A window frame clause such as `ROWS BETWEEN 1 PRECEDING AND CURRENT ROW`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct FrameClause {
    /// Unit the frame is measured in.
    pub frame_type: FrameType,
    /// Start bound; rendered on its own when `end` is `None`.
    pub start: FrameBound,
    /// Optional end bound; when present the frame is rendered as
    /// `BETWEEN <start> AND <end>`.
    pub end: Option<FrameBound>,
    /// Optional `EXCLUDE` option. Whether it is emitted depends on the
    /// target database (see `FrameClause::to_sql`).
    pub exclude: Option<FrameExclude>,
}
200
/// The unit a window frame is measured in.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FrameType {
    /// `ROWS`.
    Rows,
    /// `RANGE`.
    Range,
    /// `GROUPS`. `FrameClause::to_sql` emits this keyword only for
    /// PostgreSQL and SQLite and falls back to `ROWS` for other databases.
    Groups,
}
211
/// One boundary of a window frame.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum FrameBound {
    /// `UNBOUNDED PRECEDING`.
    UnboundedPreceding,
    /// `<n> PRECEDING`.
    Preceding(u32),
    /// `CURRENT ROW`.
    CurrentRow,
    /// `<n> FOLLOWING`.
    Following(u32),
    /// `UNBOUNDED FOLLOWING`.
    UnboundedFollowing,
}
226
/// The `EXCLUDE` option of a window frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FrameExclude {
    /// `EXCLUDE CURRENT ROW`.
    CurrentRow,
    /// `EXCLUDE GROUP`.
    Group,
    /// `EXCLUDE TIES`.
    Ties,
    /// `EXCLUDE NO OTHERS`.
    NoOthers,
}
239
240impl WindowSpec {
241 pub fn new() -> Self {
243 Self::default()
244 }
245
246 pub fn named(name: impl Into<String>) -> Self {
248 Self {
249 window_name: Some(name.into()),
250 ..Default::default()
251 }
252 }
253
254 pub fn partition_by<I, S>(mut self, columns: I) -> Self
256 where
257 I: IntoIterator<Item = S>,
258 S: Into<String>,
259 {
260 self.partition_by = columns.into_iter().map(Into::into).collect();
261 self
262 }
263
264 pub fn order_by(mut self, column: impl Into<String>, direction: SortOrder) -> Self {
266 self.order_by.push(OrderSpec {
267 expr: column.into(),
268 direction,
269 nulls: None,
270 });
271 self
272 }
273
274 pub fn order_by_nulls(
276 mut self,
277 column: impl Into<String>,
278 direction: SortOrder,
279 nulls: NullsPosition,
280 ) -> Self {
281 self.order_by.push(OrderSpec {
282 expr: column.into(),
283 direction,
284 nulls: Some(nulls),
285 });
286 self
287 }
288
289 pub fn rows(mut self, start: FrameBound, end: Option<FrameBound>) -> Self {
291 self.frame = Some(FrameClause {
292 frame_type: FrameType::Rows,
293 start,
294 end,
295 exclude: None,
296 });
297 self
298 }
299
300 pub fn range(mut self, start: FrameBound, end: Option<FrameBound>) -> Self {
302 self.frame = Some(FrameClause {
303 frame_type: FrameType::Range,
304 start,
305 end,
306 exclude: None,
307 });
308 self
309 }
310
311 pub fn groups(mut self, start: FrameBound, end: Option<FrameBound>) -> Self {
313 self.frame = Some(FrameClause {
314 frame_type: FrameType::Groups,
315 start,
316 end,
317 exclude: None,
318 });
319 self
320 }
321
322 pub fn rows_unbounded_preceding(self) -> Self {
324 self.rows(FrameBound::UnboundedPreceding, Some(FrameBound::CurrentRow))
325 }
326
327 pub fn rows_unbounded_following(self) -> Self {
329 self.rows(FrameBound::CurrentRow, Some(FrameBound::UnboundedFollowing))
330 }
331
332 pub fn rows_around(self, n: u32) -> Self {
334 self.rows(FrameBound::Preceding(n), Some(FrameBound::Following(n)))
335 }
336
337 pub fn range_unbounded_preceding(self) -> Self {
339 self.range(FrameBound::UnboundedPreceding, Some(FrameBound::CurrentRow))
340 }
341
342 pub fn to_sql(&self, db_type: DatabaseType) -> String {
344 if let Some(ref name) = self.window_name {
345 return format!("OVER {}", name);
346 }
347
348 let mut parts = Vec::new();
349
350 if !self.partition_by.is_empty() {
351 parts.push(format!("PARTITION BY {}", self.partition_by.join(", ")));
352 }
353
354 if !self.order_by.is_empty() {
355 let orders: Vec<String> = self
356 .order_by
357 .iter()
358 .map(|o| {
359 let mut s = format!(
360 "{} {}",
361 o.expr,
362 match o.direction {
363 SortOrder::Asc => "ASC",
364 SortOrder::Desc => "DESC",
365 }
366 );
367 if let Some(nulls) = o.nulls {
368 if db_type != DatabaseType::MSSQL {
370 s.push_str(match nulls {
371 NullsPosition::First => " NULLS FIRST",
372 NullsPosition::Last => " NULLS LAST",
373 });
374 }
375 }
376 s
377 })
378 .collect();
379 parts.push(format!("ORDER BY {}", orders.join(", ")));
380 }
381
382 if let Some(ref frame) = self.frame {
383 parts.push(frame.to_sql(db_type));
384 }
385
386 if parts.is_empty() {
387 "OVER ()".to_string()
388 } else {
389 format!("OVER ({})", parts.join(" "))
390 }
391 }
392}
393
394impl FrameClause {
395 pub fn to_sql(&self, db_type: DatabaseType) -> String {
397 let frame_type = match self.frame_type {
398 FrameType::Rows => "ROWS",
399 FrameType::Range => "RANGE",
400 FrameType::Groups => {
401 match db_type {
403 DatabaseType::PostgreSQL | DatabaseType::SQLite => "GROUPS",
404 _ => "ROWS", }
406 }
407 };
408
409 let bounds = if let Some(ref end) = self.end {
410 format!("BETWEEN {} AND {}", self.start.to_sql(), end.to_sql())
411 } else {
412 self.start.to_sql()
413 };
414
415 let mut sql = format!("{} {}", frame_type, bounds);
416
417 if db_type == DatabaseType::PostgreSQL {
419 if let Some(exclude) = self.exclude {
420 sql.push_str(match exclude {
421 FrameExclude::CurrentRow => " EXCLUDE CURRENT ROW",
422 FrameExclude::Group => " EXCLUDE GROUP",
423 FrameExclude::Ties => " EXCLUDE TIES",
424 FrameExclude::NoOthers => " EXCLUDE NO OTHERS",
425 });
426 }
427 }
428
429 sql
430 }
431}
432
433impl FrameBound {
434 pub fn to_sql(&self) -> String {
436 match self {
437 Self::UnboundedPreceding => "UNBOUNDED PRECEDING".to_string(),
438 Self::Preceding(n) => format!("{} PRECEDING", n),
439 Self::CurrentRow => "CURRENT ROW".to_string(),
440 Self::Following(n) => format!("{} FOLLOWING", n),
441 Self::UnboundedFollowing => "UNBOUNDED FOLLOWING".to_string(),
442 }
443 }
444}
445
446impl WindowFunction {
447 pub fn new(function: WindowFn) -> WindowFunctionBuilder {
449 WindowFunctionBuilder {
450 function,
451 over: None,
452 alias: None,
453 }
454 }
455
456 pub fn over(mut self, spec: WindowSpec) -> Self {
458 self.over = spec;
459 self
460 }
461
462 pub fn alias(mut self, name: impl Into<String>) -> Self {
464 self.alias = Some(name.into());
465 self
466 }
467
468 pub fn to_sql(&self, db_type: DatabaseType) -> String {
470 let mut sql = format!("{} {}", self.function.to_sql(), self.over.to_sql(db_type));
471 if let Some(ref alias) = self.alias {
472 sql.push_str(" AS ");
473 sql.push_str(alias);
474 }
475 sql
476 }
477}
478
/// Step-by-step builder for a [`WindowFunction`].
///
/// Created by `WindowFunction::new` and finished with `build`.
#[derive(Debug, Clone)]
pub struct WindowFunctionBuilder {
    // The function to apply; fixed when the builder is created.
    function: WindowFn,
    // Window specification; `build` substitutes the default when unset.
    over: Option<WindowSpec>,
    // Optional output alias.
    alias: Option<String>,
}
486
487impl WindowFunctionBuilder {
488 pub fn over(mut self, spec: WindowSpec) -> Self {
490 self.over = Some(spec);
491 self
492 }
493
494 pub fn alias(mut self, name: impl Into<String>) -> Self {
496 self.alias = Some(name.into());
497 self
498 }
499
500 pub fn build(self) -> WindowFunction {
502 WindowFunction {
503 function: self.function,
504 over: self.over.unwrap_or_default(),
505 alias: self.alias,
506 }
507 }
508}
509
/// A named window definition, rendered by `NamedWindow::to_sql` as
/// `<name> AS (<spec>)`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct NamedWindow {
    /// Name the window is defined under.
    pub name: String,
    /// Partitioning, ordering and frame of the window.
    pub spec: WindowSpec,
}
518
519impl NamedWindow {
520 pub fn new(name: impl Into<String>, spec: WindowSpec) -> Self {
522 Self {
523 name: name.into(),
524 spec,
525 }
526 }
527
528 pub fn to_sql(&self, db_type: DatabaseType) -> String {
530 let spec_parts = {
532 let mut parts = Vec::new();
533 if !self.spec.partition_by.is_empty() {
534 parts.push(format!(
535 "PARTITION BY {}",
536 self.spec.partition_by.join(", ")
537 ));
538 }
539 if !self.spec.order_by.is_empty() {
540 let orders: Vec<String> = self
541 .spec
542 .order_by
543 .iter()
544 .map(|o| {
545 format!(
546 "{} {}",
547 o.expr,
548 match o.direction {
549 SortOrder::Asc => "ASC",
550 SortOrder::Desc => "DESC",
551 }
552 )
553 })
554 .collect();
555 parts.push(format!("ORDER BY {}", orders.join(", ")));
556 }
557 if let Some(ref frame) = self.spec.frame {
558 parts.push(frame.to_sql(db_type));
559 }
560 parts.join(" ")
561 };
562
563 format!("{} AS ({})", self.name, spec_parts)
564 }
565}
566
567pub fn row_number() -> WindowFunctionBuilder {
573 WindowFunction::new(WindowFn::RowNumber)
574}
575
576pub fn rank() -> WindowFunctionBuilder {
578 WindowFunction::new(WindowFn::Rank)
579}
580
581pub fn dense_rank() -> WindowFunctionBuilder {
583 WindowFunction::new(WindowFn::DenseRank)
584}
585
586pub fn ntile(n: u32) -> WindowFunctionBuilder {
588 WindowFunction::new(WindowFn::Ntile(n))
589}
590
591pub fn percent_rank() -> WindowFunctionBuilder {
593 WindowFunction::new(WindowFn::PercentRank)
594}
595
596pub fn cume_dist() -> WindowFunctionBuilder {
598 WindowFunction::new(WindowFn::CumeDist)
599}
600
601pub fn lag(expr: impl Into<String>) -> WindowFunctionBuilder {
603 WindowFunction::new(WindowFn::Lag {
604 expr: expr.into(),
605 offset: None,
606 default: None,
607 })
608}
609
610pub fn lag_offset(expr: impl Into<String>, offset: u32) -> WindowFunctionBuilder {
612 WindowFunction::new(WindowFn::Lag {
613 expr: expr.into(),
614 offset: Some(offset),
615 default: None,
616 })
617}
618
619pub fn lag_full(
621 expr: impl Into<String>,
622 offset: u32,
623 default: impl Into<String>,
624) -> WindowFunctionBuilder {
625 WindowFunction::new(WindowFn::Lag {
626 expr: expr.into(),
627 offset: Some(offset),
628 default: Some(default.into()),
629 })
630}
631
632pub fn lead(expr: impl Into<String>) -> WindowFunctionBuilder {
634 WindowFunction::new(WindowFn::Lead {
635 expr: expr.into(),
636 offset: None,
637 default: None,
638 })
639}
640
641pub fn lead_offset(expr: impl Into<String>, offset: u32) -> WindowFunctionBuilder {
643 WindowFunction::new(WindowFn::Lead {
644 expr: expr.into(),
645 offset: Some(offset),
646 default: None,
647 })
648}
649
650pub fn lead_full(
652 expr: impl Into<String>,
653 offset: u32,
654 default: impl Into<String>,
655) -> WindowFunctionBuilder {
656 WindowFunction::new(WindowFn::Lead {
657 expr: expr.into(),
658 offset: Some(offset),
659 default: Some(default.into()),
660 })
661}
662
663pub fn first_value(expr: impl Into<String>) -> WindowFunctionBuilder {
665 WindowFunction::new(WindowFn::FirstValue(expr.into()))
666}
667
668pub fn last_value(expr: impl Into<String>) -> WindowFunctionBuilder {
670 WindowFunction::new(WindowFn::LastValue(expr.into()))
671}
672
673pub fn nth_value(expr: impl Into<String>, n: u32) -> WindowFunctionBuilder {
675 WindowFunction::new(WindowFn::NthValue(expr.into(), n))
676}
677
678pub fn sum(expr: impl Into<String>) -> WindowFunctionBuilder {
680 WindowFunction::new(WindowFn::Sum(expr.into()))
681}
682
683pub fn avg(expr: impl Into<String>) -> WindowFunctionBuilder {
685 WindowFunction::new(WindowFn::Avg(expr.into()))
686}
687
688pub fn count(expr: impl Into<String>) -> WindowFunctionBuilder {
690 WindowFunction::new(WindowFn::Count(expr.into()))
691}
692
693pub fn min(expr: impl Into<String>) -> WindowFunctionBuilder {
695 WindowFunction::new(WindowFn::Min(expr.into()))
696}
697
698pub fn max(expr: impl Into<String>) -> WindowFunctionBuilder {
700 WindowFunction::new(WindowFn::Max(expr.into()))
701}
702
703pub fn custom<I, S>(name: impl Into<String>, args: I) -> WindowFunctionBuilder
705where
706 I: IntoIterator<Item = S>,
707 S: Into<String>,
708{
709 WindowFunction::new(WindowFn::Custom {
710 name: name.into(),
711 args: args.into_iter().map(Into::into).collect(),
712 })
713}
714
715pub mod mongodb {
717 use serde::{Deserialize, Serialize};
718 use serde_json::Value as JsonValue;
719
720 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
722 pub struct SetWindowFields {
723 pub partition_by: Option<JsonValue>,
725 pub sort_by: Option<JsonValue>,
727 pub output: serde_json::Map<String, JsonValue>,
729 }
730
731 impl SetWindowFields {
732 pub fn new() -> SetWindowFieldsBuilder {
734 SetWindowFieldsBuilder::default()
735 }
736
737 pub fn to_bson(&self) -> JsonValue {
739 let mut stage = serde_json::Map::new();
740
741 if let Some(ref partition) = self.partition_by {
742 stage.insert("partitionBy".to_string(), partition.clone());
743 }
744
745 if let Some(ref sort) = self.sort_by {
746 stage.insert("sortBy".to_string(), sort.clone());
747 }
748
749 stage.insert("output".to_string(), JsonValue::Object(self.output.clone()));
750
751 serde_json::json!({ "$setWindowFields": stage })
752 }
753 }
754
755 impl Default for SetWindowFields {
756 fn default() -> Self {
757 Self {
758 partition_by: None,
759 sort_by: None,
760 output: serde_json::Map::new(),
761 }
762 }
763 }
764
765 #[derive(Debug, Clone, Default)]
767 pub struct SetWindowFieldsBuilder {
768 partition_by: Option<JsonValue>,
769 sort_by: Option<JsonValue>,
770 output: serde_json::Map<String, JsonValue>,
771 }
772
773 impl SetWindowFieldsBuilder {
774 pub fn partition_by(mut self, expr: impl Into<String>) -> Self {
776 self.partition_by = Some(JsonValue::String(format!("${}", expr.into())));
777 self
778 }
779
780 pub fn partition_by_expr(mut self, expr: JsonValue) -> Self {
782 self.partition_by = Some(expr);
783 self
784 }
785
786 pub fn sort_by(mut self, field: impl Into<String>) -> Self {
788 let mut sort = serde_json::Map::new();
789 sort.insert(field.into(), JsonValue::Number(1.into()));
790 self.sort_by = Some(JsonValue::Object(sort));
791 self
792 }
793
794 pub fn sort_by_desc(mut self, field: impl Into<String>) -> Self {
796 let mut sort = serde_json::Map::new();
797 sort.insert(field.into(), JsonValue::Number((-1).into()));
798 self.sort_by = Some(JsonValue::Object(sort));
799 self
800 }
801
802 pub fn sort_by_fields(mut self, fields: Vec<(&str, i32)>) -> Self {
804 let mut sort = serde_json::Map::new();
805 for (field, dir) in fields {
806 sort.insert(field.to_string(), JsonValue::Number(dir.into()));
807 }
808 self.sort_by = Some(JsonValue::Object(sort));
809 self
810 }
811
812 pub fn row_number(mut self, output_field: impl Into<String>) -> Self {
814 self.output
815 .insert(output_field.into(), serde_json::json!({ "$rowNumber": {} }));
816 self
817 }
818
819 pub fn rank(mut self, output_field: impl Into<String>) -> Self {
821 self.output
822 .insert(output_field.into(), serde_json::json!({ "$rank": {} }));
823 self
824 }
825
826 pub fn dense_rank(mut self, output_field: impl Into<String>) -> Self {
828 self.output
829 .insert(output_field.into(), serde_json::json!({ "$denseRank": {} }));
830 self
831 }
832
833 pub fn sum(
835 mut self,
836 output_field: impl Into<String>,
837 input: impl Into<String>,
838 window: Option<MongoWindow>,
839 ) -> Self {
840 let mut spec = serde_json::Map::new();
841 spec.insert(
842 "$sum".to_string(),
843 JsonValue::String(format!("${}", input.into())),
844 );
845 if let Some(w) = window {
846 spec.insert("window".to_string(), w.to_bson());
847 }
848 self.output
849 .insert(output_field.into(), JsonValue::Object(spec));
850 self
851 }
852
853 pub fn avg(
855 mut self,
856 output_field: impl Into<String>,
857 input: impl Into<String>,
858 window: Option<MongoWindow>,
859 ) -> Self {
860 let mut spec = serde_json::Map::new();
861 spec.insert(
862 "$avg".to_string(),
863 JsonValue::String(format!("${}", input.into())),
864 );
865 if let Some(w) = window {
866 spec.insert("window".to_string(), w.to_bson());
867 }
868 self.output
869 .insert(output_field.into(), JsonValue::Object(spec));
870 self
871 }
872
873 pub fn first(mut self, output_field: impl Into<String>, input: impl Into<String>) -> Self {
875 self.output.insert(
876 output_field.into(),
877 serde_json::json!({ "$first": format!("${}", input.into()) }),
878 );
879 self
880 }
881
882 pub fn last(mut self, output_field: impl Into<String>, input: impl Into<String>) -> Self {
884 self.output.insert(
885 output_field.into(),
886 serde_json::json!({ "$last": format!("${}", input.into()) }),
887 );
888 self
889 }
890
891 pub fn shift(
893 mut self,
894 output_field: impl Into<String>,
895 output: impl Into<String>,
896 by: i32,
897 default: Option<JsonValue>,
898 ) -> Self {
899 let mut spec = serde_json::Map::new();
900 spec.insert(
901 "output".to_string(),
902 JsonValue::String(format!("${}", output.into())),
903 );
904 spec.insert("by".to_string(), JsonValue::Number(by.into()));
905 if let Some(def) = default {
906 spec.insert("default".to_string(), def);
907 }
908 self.output
909 .insert(output_field.into(), serde_json::json!({ "$shift": spec }));
910 self
911 }
912
913 pub fn output(mut self, field: impl Into<String>, spec: JsonValue) -> Self {
915 self.output.insert(field.into(), spec);
916 self
917 }
918
919 pub fn build(self) -> SetWindowFields {
921 SetWindowFields {
922 partition_by: self.partition_by,
923 sort_by: self.sort_by,
924 output: self.output,
925 }
926 }
927 }
928
929 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
931 pub struct MongoWindow {
932 pub documents: Option<[WindowBound; 2]>,
934 pub range: Option<[WindowBound; 2]>,
936 pub unit: Option<String>,
938 }
939
940 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
942 #[serde(untagged)]
943 pub enum WindowBound {
944 Number(i64),
946 Keyword(String),
948 }
949
950 impl MongoWindow {
951 pub fn documents(start: i64, end: i64) -> Self {
953 Self {
954 documents: Some([WindowBound::Number(start), WindowBound::Number(end)]),
955 range: None,
956 unit: None,
957 }
958 }
959
960 pub fn documents_unbounded() -> Self {
962 Self {
963 documents: Some([
964 WindowBound::Keyword("unbounded".to_string()),
965 WindowBound::Keyword("unbounded".to_string()),
966 ]),
967 range: None,
968 unit: None,
969 }
970 }
971
972 pub fn documents_to_current() -> Self {
974 Self {
975 documents: Some([
976 WindowBound::Keyword("unbounded".to_string()),
977 WindowBound::Keyword("current".to_string()),
978 ]),
979 range: None,
980 unit: None,
981 }
982 }
983
984 pub fn range_with_unit(start: i64, end: i64, unit: impl Into<String>) -> Self {
986 Self {
987 documents: None,
988 range: Some([WindowBound::Number(start), WindowBound::Number(end)]),
989 unit: Some(unit.into()),
990 }
991 }
992
993 pub fn to_bson(&self) -> JsonValue {
995 let mut window = serde_json::Map::new();
996
997 if let Some(ref docs) = self.documents {
998 let arr: Vec<JsonValue> = docs
999 .iter()
1000 .map(|b| match b {
1001 WindowBound::Number(n) => JsonValue::Number((*n).into()),
1002 WindowBound::Keyword(s) => JsonValue::String(s.clone()),
1003 })
1004 .collect();
1005 window.insert("documents".to_string(), JsonValue::Array(arr));
1006 }
1007
1008 if let Some(ref range) = self.range {
1009 let arr: Vec<JsonValue> = range
1010 .iter()
1011 .map(|b| match b {
1012 WindowBound::Number(n) => JsonValue::Number((*n).into()),
1013 WindowBound::Keyword(s) => JsonValue::String(s.clone()),
1014 })
1015 .collect();
1016 window.insert("range".to_string(), JsonValue::Array(arr));
1017 }
1018
1019 if let Some(ref unit) = self.unit {
1020 window.insert("unit".to_string(), JsonValue::String(unit.clone()));
1021 }
1022
1023 JsonValue::Object(window)
1024 }
1025 }
1026
1027 pub fn set_window_fields() -> SetWindowFieldsBuilder {
1029 SetWindowFields::new()
1030 }
1031}
1032
#[cfg(test)]
mod tests {
    use super::*;

    // ROW_NUMBER() with partitioning and ordering renders all three parts.
    #[test]
    fn test_row_number() {
        let spec = WindowSpec::new()
            .partition_by(["dept"])
            .order_by("salary", SortOrder::Desc);
        let rendered = row_number()
            .over(spec)
            .build()
            .to_sql(DatabaseType::PostgreSQL);

        for fragment in ["ROW_NUMBER()", "PARTITION BY dept", "ORDER BY salary DESC"] {
            assert!(rendered.contains(fragment));
        }
    }

    // RANK() and DENSE_RANK() render their own function names.
    #[test]
    fn test_rank_functions() {
        let by_score = || WindowSpec::new().order_by("score", SortOrder::Desc);

        let plain = rank().over(by_score()).build();
        assert!(plain.to_sql(DatabaseType::PostgreSQL).contains("RANK()"));

        let dense = dense_rank().over(by_score()).build();
        assert!(dense
            .to_sql(DatabaseType::PostgreSQL)
            .contains("DENSE_RANK()"));
    }

    // NTILE carries its bucket count into the SQL.
    #[test]
    fn test_ntile() {
        let spec = WindowSpec::new().order_by("value", SortOrder::Asc);
        let rendered = ntile(4).over(spec).build().to_sql(DatabaseType::MySQL);

        assert!(rendered.contains("NTILE(4)"));
    }

    // LAG/LEAD render one, two or three arguments depending on the constructor.
    #[test]
    fn test_lag_lead() {
        let by_date = || WindowSpec::new().order_by("date", SortOrder::Asc);
        let cases = [
            (lag("price"), "LAG(price)"),
            (lag_offset("price", 2), "LAG(price, 2)"),
            (lag_full("price", 1, "0"), "LAG(price, 1, 0)"),
            (lead("price"), "LEAD(price)"),
        ];

        for (builder, expected) in cases {
            let rendered = builder
                .over(by_date())
                .build()
                .to_sql(DatabaseType::PostgreSQL);
            assert!(rendered.contains(expected));
        }
    }

    // A running total combines an aggregate, a frame and an alias.
    #[test]
    fn test_aggregate_window() {
        let spec = WindowSpec::new()
            .partition_by(["account_id"])
            .order_by("date", SortOrder::Asc)
            .rows_unbounded_preceding();
        let running_total = sum("amount").over(spec).alias("running_total").build();

        let rendered = running_total.to_sql(DatabaseType::PostgreSQL);
        assert!(rendered.contains("SUM(amount)"));
        assert!(rendered.contains("ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"));
        assert!(rendered.contains("AS running_total"));
    }

    // Explicit frame bounds are rendered as BETWEEN ... AND ...
    #[test]
    fn test_frame_clauses() {
        let start = FrameBound::Preceding(3);
        let end = Some(FrameBound::Following(3));
        let rendered = WindowSpec::new()
            .order_by("id", SortOrder::Asc)
            .rows(start, end)
            .to_sql(DatabaseType::PostgreSQL);

        assert!(rendered.contains("ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING"));
    }

    // A named window definition renders as `<name> AS (...)`.
    #[test]
    fn test_named_window() {
        let spec = WindowSpec::new()
            .partition_by(["dept"])
            .order_by("salary", SortOrder::Desc);
        let rendered = NamedWindow::new("w", spec).to_sql(DatabaseType::PostgreSQL);

        assert!(rendered.contains("w AS ("));
        assert!(rendered.contains("PARTITION BY dept"));
    }

    // Referencing a named window renders only its name.
    #[test]
    fn test_window_reference() {
        let reference = WindowSpec::named("w");
        assert_eq!(reference.to_sql(DatabaseType::PostgreSQL), "OVER w");
    }

    // NULLS placement is rendered for PostgreSQL but omitted for MSSQL.
    #[test]
    fn test_nulls_position() {
        let spec = WindowSpec::new().order_by_nulls("value", SortOrder::Desc, NullsPosition::Last);

        assert!(spec.to_sql(DatabaseType::PostgreSQL).contains("NULLS LAST"));
        assert!(!spec.to_sql(DatabaseType::MSSQL).contains("NULLS"));
    }

    // FIRST_VALUE / LAST_VALUE render their function names and argument.
    #[test]
    fn test_first_last_value() {
        let by_hire_date = || {
            WindowSpec::new()
                .partition_by(["dept"])
                .order_by("hire_date", SortOrder::Asc)
        };

        let first = first_value("salary").over(by_hire_date()).build();
        assert!(first
            .to_sql(DatabaseType::PostgreSQL)
            .contains("FIRST_VALUE(salary)"));

        let whole_partition = by_hire_date().rows(
            FrameBound::UnboundedPreceding,
            Some(FrameBound::UnboundedFollowing),
        );
        let last = last_value("salary").over(whole_partition).build();
        assert!(last
            .to_sql(DatabaseType::PostgreSQL)
            .contains("LAST_VALUE(salary)"));
    }

    mod mongodb_tests {
        use super::super::mongodb::*;

        // $rowNumber is emitted as an (empty) object under the output field.
        #[test]
        fn test_row_number() {
            let document = set_window_fields()
                .partition_by("state")
                .sort_by_desc("quantity")
                .row_number("rowNumber")
                .build()
                .to_bson();

            let output = &document["$setWindowFields"]["output"];
            assert!(output["rowNumber"]["$rowNumber"].is_object());
        }

        // $rank and $denseRank can coexist in one stage.
        #[test]
        fn test_rank() {
            let document = set_window_fields()
                .sort_by("score")
                .rank("ranking")
                .dense_rank("denseRanking")
                .build()
                .to_bson();

            let output = &document["$setWindowFields"]["output"];
            assert!(output["ranking"]["$rank"].is_object());
            assert!(output["denseRanking"]["$denseRank"].is_object());
        }

        // A windowed $sum carries both the field path and the window bounds.
        #[test]
        fn test_running_total() {
            let window = MongoWindow::documents_to_current();
            let document = set_window_fields()
                .partition_by("account")
                .sort_by("date")
                .sum("runningTotal", "amount", Some(window))
                .build()
                .to_bson();

            let running_total = &document["$setWindowFields"]["output"]["runningTotal"];
            assert!(running_total["$sum"].is_string());
            assert!(running_total["window"]["documents"].is_array());
        }

        // $shift keeps the sign of its offset.
        #[test]
        fn test_shift_lag() {
            let document = set_window_fields()
                .sort_by("date")
                .shift("prevPrice", "price", -1, Some(serde_json::json!(0)))
                .shift("nextPrice", "price", 1, None)
                .build()
                .to_bson();

            let output = &document["$setWindowFields"]["output"];
            assert!(output["prevPrice"]["$shift"]["by"] == -1);
            assert!(output["nextPrice"]["$shift"]["by"] == 1);
        }

        // Numeric document bounds and range-with-unit windows render as given.
        #[test]
        fn test_window_bounds() {
            let by_documents = MongoWindow::documents(-3, 3).to_bson();
            assert_eq!(by_documents["documents"][0], -3);
            assert_eq!(by_documents["documents"][1], 3);

            let by_range = MongoWindow::range_with_unit(-7, 0, "day").to_bson();
            assert!(by_range["range"].is_array());
            assert_eq!(by_range["unit"], "day");
        }
    }
}