prax_query/
window.rs

1//! Window functions support.
2//!
3//! This module provides types for building window functions (OVER clauses)
4//! across different database backends.
5//!
6//! # Supported Features
7//!
8//! | Feature         | PostgreSQL | MySQL | SQLite | MSSQL | MongoDB 5.0+ |
9//! |-----------------|------------|-------|--------|-------|--------------|
10//! | ROW_NUMBER      | ✅         | ✅    | ✅     | ✅    | ✅           |
11//! | RANK/DENSE_RANK | ✅         | ✅    | ✅     | ✅    | ✅           |
12//! | LAG/LEAD        | ✅         | ✅    | ✅     | ✅    | ✅           |
13//! | Frame clauses   | ✅         | ✅    | ✅     | ✅    | ✅           |
14//! | Named windows   | ✅         | ✅    | ✅     | ❌    | ❌           |
15//!
16//! # Example Usage
17//!
18//! ```rust,ignore
//! use prax_query::types::SortOrder;
//! use prax_query::window::{WindowFunction, WindowSpec, row_number, rank, sum};
20//!
21//! // ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC)
22//! let wf = row_number()
23//!     .over(WindowSpec::new()
24//!         .partition_by(["dept"])
25//!         .order_by("salary", SortOrder::Desc));
26//!
27//! // Running total
28//! let running = sum("amount")
29//!     .over(WindowSpec::new()
30//!         .order_by("date", SortOrder::Asc)
31//!         .rows_unbounded_preceding());
32//! ```
33
34use serde::{Deserialize, Serialize};
35
36use crate::sql::DatabaseType;
37use crate::types::SortOrder;
38
39/// A window function with its OVER clause.
40#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
41pub struct WindowFunction {
42    /// The function being called.
43    pub function: WindowFn,
44    /// The OVER clause specification.
45    pub over: WindowSpec,
46    /// Optional alias for the result.
47    pub alias: Option<String>,
48}
49
50/// Available window functions.
51#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
52pub enum WindowFn {
53    // Ranking functions
54    /// ROW_NUMBER() - Sequential row number.
55    RowNumber,
56    /// RANK() - Rank with gaps.
57    Rank,
58    /// DENSE_RANK() - Rank without gaps.
59    DenseRank,
60    /// NTILE(n) - Distribute rows into n buckets.
61    Ntile(u32),
62    /// PERCENT_RANK() - Relative rank (0 to 1).
63    PercentRank,
64    /// CUME_DIST() - Cumulative distribution.
65    CumeDist,
66
67    // Value functions
68    /// LAG(expr, offset, default) - Value from previous row.
69    Lag {
70        expr: String,
71        offset: Option<u32>,
72        default: Option<String>,
73    },
74    /// LEAD(expr, offset, default) - Value from next row.
75    Lead {
76        expr: String,
77        offset: Option<u32>,
78        default: Option<String>,
79    },
80    /// FIRST_VALUE(expr) - First value in frame.
81    FirstValue(String),
82    /// LAST_VALUE(expr) - Last value in frame.
83    LastValue(String),
84    /// NTH_VALUE(expr, n) - Nth value in frame.
85    NthValue(String, u32),
86
87    // Aggregate functions as window functions
88    /// SUM(expr).
89    Sum(String),
90    /// AVG(expr).
91    Avg(String),
92    /// COUNT(expr).
93    Count(String),
94    /// MIN(expr).
95    Min(String),
96    /// MAX(expr).
97    Max(String),
98    /// Custom function.
99    Custom { name: String, args: Vec<String> },
100}
101
102impl WindowFn {
103    /// Generate the function SQL.
104    pub fn to_sql(&self) -> String {
105        match self {
106            Self::RowNumber => "ROW_NUMBER()".to_string(),
107            Self::Rank => "RANK()".to_string(),
108            Self::DenseRank => "DENSE_RANK()".to_string(),
109            Self::Ntile(n) => format!("NTILE({})", n),
110            Self::PercentRank => "PERCENT_RANK()".to_string(),
111            Self::CumeDist => "CUME_DIST()".to_string(),
112            Self::Lag { expr, offset, default } => {
113                let mut sql = format!("LAG({})", expr);
114                if let Some(off) = offset {
115                    sql = format!("LAG({}, {})", expr, off);
116                    if let Some(def) = default {
117                        sql = format!("LAG({}, {}, {})", expr, off, def);
118                    }
119                }
120                sql
121            }
122            Self::Lead { expr, offset, default } => {
123                let mut sql = format!("LEAD({})", expr);
124                if let Some(off) = offset {
125                    sql = format!("LEAD({}, {})", expr, off);
126                    if let Some(def) = default {
127                        sql = format!("LEAD({}, {}, {})", expr, off, def);
128                    }
129                }
130                sql
131            }
132            Self::FirstValue(expr) => format!("FIRST_VALUE({})", expr),
133            Self::LastValue(expr) => format!("LAST_VALUE({})", expr),
134            Self::NthValue(expr, n) => format!("NTH_VALUE({}, {})", expr, n),
135            Self::Sum(expr) => format!("SUM({})", expr),
136            Self::Avg(expr) => format!("AVG({})", expr),
137            Self::Count(expr) => format!("COUNT({})", expr),
138            Self::Min(expr) => format!("MIN({})", expr),
139            Self::Max(expr) => format!("MAX({})", expr),
140            Self::Custom { name, args } => {
141                format!("{}({})", name, args.join(", "))
142            }
143        }
144    }
145}
146
147/// Window specification (OVER clause).
148#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
149pub struct WindowSpec {
150    /// Reference to a named window.
151    pub window_name: Option<String>,
152    /// PARTITION BY columns.
153    pub partition_by: Vec<String>,
154    /// ORDER BY specifications.
155    pub order_by: Vec<OrderSpec>,
156    /// Frame clause.
157    pub frame: Option<FrameClause>,
158}
159
160/// Order specification for window functions.
161#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
162pub struct OrderSpec {
163    /// Column or expression to order by.
164    pub expr: String,
165    /// Sort direction.
166    pub direction: SortOrder,
167    /// NULLS FIRST/LAST.
168    pub nulls: Option<NullsPosition>,
169}
170
171/// Position of NULL values in ordering.
172#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
173pub enum NullsPosition {
174    /// NULL values first.
175    First,
176    /// NULL values last.
177    Last,
178}
179
180/// Frame clause for window functions.
181#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
182pub struct FrameClause {
183    /// Frame type (ROWS, RANGE, GROUPS).
184    pub frame_type: FrameType,
185    /// Frame start bound.
186    pub start: FrameBound,
187    /// Frame end bound (if BETWEEN).
188    pub end: Option<FrameBound>,
189    /// Exclude clause (PostgreSQL).
190    pub exclude: Option<FrameExclude>,
191}
192
193/// Frame type.
194#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
195pub enum FrameType {
196    /// Row-based frame.
197    Rows,
198    /// Value-based frame.
199    Range,
200    /// Group-based frame (PostgreSQL).
201    Groups,
202}
203
204/// Frame boundary.
205#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
206pub enum FrameBound {
207    /// UNBOUNDED PRECEDING.
208    UnboundedPreceding,
209    /// n PRECEDING.
210    Preceding(u32),
211    /// CURRENT ROW.
212    CurrentRow,
213    /// n FOLLOWING.
214    Following(u32),
215    /// UNBOUNDED FOLLOWING.
216    UnboundedFollowing,
217}
218
219/// Frame exclusion (PostgreSQL).
220#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
221pub enum FrameExclude {
222    /// EXCLUDE CURRENT ROW.
223    CurrentRow,
224    /// EXCLUDE GROUP.
225    Group,
226    /// EXCLUDE TIES.
227    Ties,
228    /// EXCLUDE NO OTHERS.
229    NoOthers,
230}
231
232impl WindowSpec {
233    /// Create a new empty window specification.
234    pub fn new() -> Self {
235        Self::default()
236    }
237
238    /// Reference a named window.
239    pub fn named(name: impl Into<String>) -> Self {
240        Self {
241            window_name: Some(name.into()),
242            ..Default::default()
243        }
244    }
245
246    /// Add PARTITION BY columns.
247    pub fn partition_by<I, S>(mut self, columns: I) -> Self
248    where
249        I: IntoIterator<Item = S>,
250        S: Into<String>,
251    {
252        self.partition_by = columns.into_iter().map(Into::into).collect();
253        self
254    }
255
256    /// Add an ORDER BY column (ascending).
257    pub fn order_by(mut self, column: impl Into<String>, direction: SortOrder) -> Self {
258        self.order_by.push(OrderSpec {
259            expr: column.into(),
260            direction,
261            nulls: None,
262        });
263        self
264    }
265
266    /// Add an ORDER BY column with NULLS position.
267    pub fn order_by_nulls(
268        mut self,
269        column: impl Into<String>,
270        direction: SortOrder,
271        nulls: NullsPosition,
272    ) -> Self {
273        self.order_by.push(OrderSpec {
274            expr: column.into(),
275            direction,
276            nulls: Some(nulls),
277        });
278        self
279    }
280
281    /// Set ROWS frame.
282    pub fn rows(mut self, start: FrameBound, end: Option<FrameBound>) -> Self {
283        self.frame = Some(FrameClause {
284            frame_type: FrameType::Rows,
285            start,
286            end,
287            exclude: None,
288        });
289        self
290    }
291
292    /// Set RANGE frame.
293    pub fn range(mut self, start: FrameBound, end: Option<FrameBound>) -> Self {
294        self.frame = Some(FrameClause {
295            frame_type: FrameType::Range,
296            start,
297            end,
298            exclude: None,
299        });
300        self
301    }
302
303    /// Set GROUPS frame (PostgreSQL).
304    pub fn groups(mut self, start: FrameBound, end: Option<FrameBound>) -> Self {
305        self.frame = Some(FrameClause {
306            frame_type: FrameType::Groups,
307            start,
308            end,
309            exclude: None,
310        });
311        self
312    }
313
314    /// Common frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW.
315    pub fn rows_unbounded_preceding(self) -> Self {
316        self.rows(FrameBound::UnboundedPreceding, Some(FrameBound::CurrentRow))
317    }
318
319    /// Common frame: ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING.
320    pub fn rows_unbounded_following(self) -> Self {
321        self.rows(FrameBound::CurrentRow, Some(FrameBound::UnboundedFollowing))
322    }
323
324    /// Common frame: ROWS BETWEEN n PRECEDING AND n FOLLOWING.
325    pub fn rows_around(self, n: u32) -> Self {
326        self.rows(FrameBound::Preceding(n), Some(FrameBound::Following(n)))
327    }
328
329    /// Common frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW.
330    pub fn range_unbounded_preceding(self) -> Self {
331        self.range(FrameBound::UnboundedPreceding, Some(FrameBound::CurrentRow))
332    }
333
334    /// Generate the OVER clause SQL.
335    pub fn to_sql(&self, db_type: DatabaseType) -> String {
336        if let Some(ref name) = self.window_name {
337            return format!("OVER {}", name);
338        }
339
340        let mut parts = Vec::new();
341
342        if !self.partition_by.is_empty() {
343            parts.push(format!("PARTITION BY {}", self.partition_by.join(", ")));
344        }
345
346        if !self.order_by.is_empty() {
347            let orders: Vec<String> = self
348                .order_by
349                .iter()
350                .map(|o| {
351                    let mut s = format!(
352                        "{} {}",
353                        o.expr,
354                        match o.direction {
355                            SortOrder::Asc => "ASC",
356                            SortOrder::Desc => "DESC",
357                        }
358                    );
359                    if let Some(nulls) = o.nulls {
360                        // MSSQL doesn't support NULLS FIRST/LAST directly
361                        if db_type != DatabaseType::MSSQL {
362                            s.push_str(match nulls {
363                                NullsPosition::First => " NULLS FIRST",
364                                NullsPosition::Last => " NULLS LAST",
365                            });
366                        }
367                    }
368                    s
369                })
370                .collect();
371            parts.push(format!("ORDER BY {}", orders.join(", ")));
372        }
373
374        if let Some(ref frame) = self.frame {
375            parts.push(frame.to_sql(db_type));
376        }
377
378        if parts.is_empty() {
379            "OVER ()".to_string()
380        } else {
381            format!("OVER ({})", parts.join(" "))
382        }
383    }
384}
385
386impl FrameClause {
387    /// Generate frame clause SQL.
388    pub fn to_sql(&self, db_type: DatabaseType) -> String {
389        let frame_type = match self.frame_type {
390            FrameType::Rows => "ROWS",
391            FrameType::Range => "RANGE",
392            FrameType::Groups => {
393                // GROUPS only supported in PostgreSQL and SQLite
394                match db_type {
395                    DatabaseType::PostgreSQL | DatabaseType::SQLite => "GROUPS",
396                    _ => "ROWS", // Fallback
397                }
398            }
399        };
400
401        let bounds = if let Some(ref end) = self.end {
402            format!(
403                "BETWEEN {} AND {}",
404                self.start.to_sql(),
405                end.to_sql()
406            )
407        } else {
408            self.start.to_sql()
409        };
410
411        let mut sql = format!("{} {}", frame_type, bounds);
412
413        // Exclude clause (PostgreSQL only)
414        if db_type == DatabaseType::PostgreSQL {
415            if let Some(exclude) = self.exclude {
416                sql.push_str(match exclude {
417                    FrameExclude::CurrentRow => " EXCLUDE CURRENT ROW",
418                    FrameExclude::Group => " EXCLUDE GROUP",
419                    FrameExclude::Ties => " EXCLUDE TIES",
420                    FrameExclude::NoOthers => " EXCLUDE NO OTHERS",
421                });
422            }
423        }
424
425        sql
426    }
427}
428
429impl FrameBound {
430    /// Generate bound SQL.
431    pub fn to_sql(&self) -> String {
432        match self {
433            Self::UnboundedPreceding => "UNBOUNDED PRECEDING".to_string(),
434            Self::Preceding(n) => format!("{} PRECEDING", n),
435            Self::CurrentRow => "CURRENT ROW".to_string(),
436            Self::Following(n) => format!("{} FOLLOWING", n),
437            Self::UnboundedFollowing => "UNBOUNDED FOLLOWING".to_string(),
438        }
439    }
440}
441
442impl WindowFunction {
443    /// Create a new window function.
444    pub fn new(function: WindowFn) -> WindowFunctionBuilder {
445        WindowFunctionBuilder {
446            function,
447            over: None,
448            alias: None,
449        }
450    }
451
452    /// Set the OVER clause.
453    pub fn over(mut self, spec: WindowSpec) -> Self {
454        self.over = spec;
455        self
456    }
457
458    /// Set an alias for the result.
459    pub fn alias(mut self, name: impl Into<String>) -> Self {
460        self.alias = Some(name.into());
461        self
462    }
463
464    /// Generate the full SQL expression.
465    pub fn to_sql(&self, db_type: DatabaseType) -> String {
466        let mut sql = format!("{} {}", self.function.to_sql(), self.over.to_sql(db_type));
467        if let Some(ref alias) = self.alias {
468            sql.push_str(" AS ");
469            sql.push_str(alias);
470        }
471        sql
472    }
473}
474
475/// Builder for window functions.
476#[derive(Debug, Clone)]
477pub struct WindowFunctionBuilder {
478    function: WindowFn,
479    over: Option<WindowSpec>,
480    alias: Option<String>,
481}
482
483impl WindowFunctionBuilder {
484    /// Set the OVER clause.
485    pub fn over(mut self, spec: WindowSpec) -> Self {
486        self.over = Some(spec);
487        self
488    }
489
490    /// Set an alias.
491    pub fn alias(mut self, name: impl Into<String>) -> Self {
492        self.alias = Some(name.into());
493        self
494    }
495
496    /// Build the window function.
497    pub fn build(self) -> WindowFunction {
498        WindowFunction {
499            function: self.function,
500            over: self.over.unwrap_or_default(),
501            alias: self.alias,
502        }
503    }
504}
505
506/// A named window definition.
507#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
508pub struct NamedWindow {
509    /// Window name.
510    pub name: String,
511    /// Window specification.
512    pub spec: WindowSpec,
513}
514
515impl NamedWindow {
516    /// Create a new named window.
517    pub fn new(name: impl Into<String>, spec: WindowSpec) -> Self {
518        Self {
519            name: name.into(),
520            spec,
521        }
522    }
523
524    /// Generate the WINDOW clause definition.
525    pub fn to_sql(&self, db_type: DatabaseType) -> String {
526        // Named windows generate just the spec content (without OVER)
527        let spec_parts = {
528            let mut parts = Vec::new();
529            if !self.spec.partition_by.is_empty() {
530                parts.push(format!("PARTITION BY {}", self.spec.partition_by.join(", ")));
531            }
532            if !self.spec.order_by.is_empty() {
533                let orders: Vec<String> = self
534                    .spec
535                    .order_by
536                    .iter()
537                    .map(|o| {
538                        format!(
539                            "{} {}",
540                            o.expr,
541                            match o.direction {
542                                SortOrder::Asc => "ASC",
543                                SortOrder::Desc => "DESC",
544                            }
545                        )
546                    })
547                    .collect();
548                parts.push(format!("ORDER BY {}", orders.join(", ")));
549            }
550            if let Some(ref frame) = self.spec.frame {
551                parts.push(frame.to_sql(db_type));
552            }
553            parts.join(" ")
554        };
555
556        format!("{} AS ({})", self.name, spec_parts)
557    }
558}
559
560// ============================================================================
561// Helper functions for creating window functions
562// ============================================================================
563
564/// Create ROW_NUMBER() window function.
565pub fn row_number() -> WindowFunctionBuilder {
566    WindowFunction::new(WindowFn::RowNumber)
567}
568
569/// Create RANK() window function.
570pub fn rank() -> WindowFunctionBuilder {
571    WindowFunction::new(WindowFn::Rank)
572}
573
574/// Create DENSE_RANK() window function.
575pub fn dense_rank() -> WindowFunctionBuilder {
576    WindowFunction::new(WindowFn::DenseRank)
577}
578
579/// Create NTILE(n) window function.
580pub fn ntile(n: u32) -> WindowFunctionBuilder {
581    WindowFunction::new(WindowFn::Ntile(n))
582}
583
584/// Create PERCENT_RANK() window function.
585pub fn percent_rank() -> WindowFunctionBuilder {
586    WindowFunction::new(WindowFn::PercentRank)
587}
588
589/// Create CUME_DIST() window function.
590pub fn cume_dist() -> WindowFunctionBuilder {
591    WindowFunction::new(WindowFn::CumeDist)
592}
593
594/// Create LAG() window function.
595pub fn lag(expr: impl Into<String>) -> WindowFunctionBuilder {
596    WindowFunction::new(WindowFn::Lag {
597        expr: expr.into(),
598        offset: None,
599        default: None,
600    })
601}
602
603/// Create LAG() with offset.
604pub fn lag_offset(expr: impl Into<String>, offset: u32) -> WindowFunctionBuilder {
605    WindowFunction::new(WindowFn::Lag {
606        expr: expr.into(),
607        offset: Some(offset),
608        default: None,
609    })
610}
611
612/// Create LAG() with offset and default.
613pub fn lag_full(expr: impl Into<String>, offset: u32, default: impl Into<String>) -> WindowFunctionBuilder {
614    WindowFunction::new(WindowFn::Lag {
615        expr: expr.into(),
616        offset: Some(offset),
617        default: Some(default.into()),
618    })
619}
620
621/// Create LEAD() window function.
622pub fn lead(expr: impl Into<String>) -> WindowFunctionBuilder {
623    WindowFunction::new(WindowFn::Lead {
624        expr: expr.into(),
625        offset: None,
626        default: None,
627    })
628}
629
630/// Create LEAD() with offset.
631pub fn lead_offset(expr: impl Into<String>, offset: u32) -> WindowFunctionBuilder {
632    WindowFunction::new(WindowFn::Lead {
633        expr: expr.into(),
634        offset: Some(offset),
635        default: None,
636    })
637}
638
639/// Create LEAD() with offset and default.
640pub fn lead_full(expr: impl Into<String>, offset: u32, default: impl Into<String>) -> WindowFunctionBuilder {
641    WindowFunction::new(WindowFn::Lead {
642        expr: expr.into(),
643        offset: Some(offset),
644        default: Some(default.into()),
645    })
646}
647
648/// Create FIRST_VALUE() window function.
649pub fn first_value(expr: impl Into<String>) -> WindowFunctionBuilder {
650    WindowFunction::new(WindowFn::FirstValue(expr.into()))
651}
652
653/// Create LAST_VALUE() window function.
654pub fn last_value(expr: impl Into<String>) -> WindowFunctionBuilder {
655    WindowFunction::new(WindowFn::LastValue(expr.into()))
656}
657
658/// Create NTH_VALUE() window function.
659pub fn nth_value(expr: impl Into<String>, n: u32) -> WindowFunctionBuilder {
660    WindowFunction::new(WindowFn::NthValue(expr.into(), n))
661}
662
663/// Create SUM() window function.
664pub fn sum(expr: impl Into<String>) -> WindowFunctionBuilder {
665    WindowFunction::new(WindowFn::Sum(expr.into()))
666}
667
668/// Create AVG() window function.
669pub fn avg(expr: impl Into<String>) -> WindowFunctionBuilder {
670    WindowFunction::new(WindowFn::Avg(expr.into()))
671}
672
673/// Create COUNT() window function.
674pub fn count(expr: impl Into<String>) -> WindowFunctionBuilder {
675    WindowFunction::new(WindowFn::Count(expr.into()))
676}
677
678/// Create MIN() window function.
679pub fn min(expr: impl Into<String>) -> WindowFunctionBuilder {
680    WindowFunction::new(WindowFn::Min(expr.into()))
681}
682
683/// Create MAX() window function.
684pub fn max(expr: impl Into<String>) -> WindowFunctionBuilder {
685    WindowFunction::new(WindowFn::Max(expr.into()))
686}
687
688/// Create a custom window function.
689pub fn custom<I, S>(name: impl Into<String>, args: I) -> WindowFunctionBuilder
690where
691    I: IntoIterator<Item = S>,
692    S: Into<String>,
693{
694    WindowFunction::new(WindowFn::Custom {
695        name: name.into(),
696        args: args.into_iter().map(Into::into).collect(),
697    })
698}
699
700/// MongoDB $setWindowFields support.
701pub mod mongodb {
702    use serde::{Deserialize, Serialize};
703    use serde_json::Value as JsonValue;
704
705    /// A $setWindowFields stage for MongoDB aggregation pipelines.
706    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
707    pub struct SetWindowFields {
708        /// PARTITION BY equivalent.
709        pub partition_by: Option<JsonValue>,
710        /// SORT BY specification.
711        pub sort_by: Option<JsonValue>,
712        /// Output fields with window functions.
713        pub output: serde_json::Map<String, JsonValue>,
714    }
715
716    impl SetWindowFields {
717        /// Create a new $setWindowFields stage.
718        pub fn new() -> SetWindowFieldsBuilder {
719            SetWindowFieldsBuilder::default()
720        }
721
722        /// Convert to BSON document.
723        pub fn to_bson(&self) -> JsonValue {
724            let mut stage = serde_json::Map::new();
725
726            if let Some(ref partition) = self.partition_by {
727                stage.insert("partitionBy".to_string(), partition.clone());
728            }
729
730            if let Some(ref sort) = self.sort_by {
731                stage.insert("sortBy".to_string(), sort.clone());
732            }
733
734            stage.insert("output".to_string(), JsonValue::Object(self.output.clone()));
735
736            serde_json::json!({ "$setWindowFields": stage })
737        }
738    }
739
740    impl Default for SetWindowFields {
741        fn default() -> Self {
742            Self {
743                partition_by: None,
744                sort_by: None,
745                output: serde_json::Map::new(),
746            }
747        }
748    }
749
750    /// Builder for $setWindowFields.
751    #[derive(Debug, Clone, Default)]
752    pub struct SetWindowFieldsBuilder {
753        partition_by: Option<JsonValue>,
754        sort_by: Option<JsonValue>,
755        output: serde_json::Map<String, JsonValue>,
756    }
757
758    impl SetWindowFieldsBuilder {
759        /// Set PARTITION BY.
760        pub fn partition_by(mut self, expr: impl Into<String>) -> Self {
761            self.partition_by = Some(JsonValue::String(format!("${}", expr.into())));
762            self
763        }
764
765        /// Set PARTITION BY with object expression.
766        pub fn partition_by_expr(mut self, expr: JsonValue) -> Self {
767            self.partition_by = Some(expr);
768            self
769        }
770
771        /// Set SORT BY (single field ascending).
772        pub fn sort_by(mut self, field: impl Into<String>) -> Self {
773            let mut sort = serde_json::Map::new();
774            sort.insert(field.into(), JsonValue::Number(1.into()));
775            self.sort_by = Some(JsonValue::Object(sort));
776            self
777        }
778
779        /// Set SORT BY with direction.
780        pub fn sort_by_desc(mut self, field: impl Into<String>) -> Self {
781            let mut sort = serde_json::Map::new();
782            sort.insert(field.into(), JsonValue::Number((-1).into()));
783            self.sort_by = Some(JsonValue::Object(sort));
784            self
785        }
786
787        /// Set SORT BY with multiple fields.
788        pub fn sort_by_fields(mut self, fields: Vec<(&str, i32)>) -> Self {
789            let mut sort = serde_json::Map::new();
790            for (field, dir) in fields {
791                sort.insert(field.to_string(), JsonValue::Number(dir.into()));
792            }
793            self.sort_by = Some(JsonValue::Object(sort));
794            self
795        }
796
797        /// Add $rowNumber output field.
798        pub fn row_number(mut self, output_field: impl Into<String>) -> Self {
799            self.output.insert(
800                output_field.into(),
801                serde_json::json!({ "$rowNumber": {} }),
802            );
803            self
804        }
805
806        /// Add $rank output field.
807        pub fn rank(mut self, output_field: impl Into<String>) -> Self {
808            self.output.insert(
809                output_field.into(),
810                serde_json::json!({ "$rank": {} }),
811            );
812            self
813        }
814
815        /// Add $denseRank output field.
816        pub fn dense_rank(mut self, output_field: impl Into<String>) -> Self {
817            self.output.insert(
818                output_field.into(),
819                serde_json::json!({ "$denseRank": {} }),
820            );
821            self
822        }
823
824        /// Add $sum with window output field.
825        pub fn sum(
826            mut self,
827            output_field: impl Into<String>,
828            input: impl Into<String>,
829            window: Option<MongoWindow>,
830        ) -> Self {
831            let mut spec = serde_json::Map::new();
832            spec.insert("$sum".to_string(), JsonValue::String(format!("${}", input.into())));
833            if let Some(w) = window {
834                spec.insert("window".to_string(), w.to_bson());
835            }
836            self.output.insert(output_field.into(), JsonValue::Object(spec));
837            self
838        }
839
840        /// Add $avg with window output field.
841        pub fn avg(
842            mut self,
843            output_field: impl Into<String>,
844            input: impl Into<String>,
845            window: Option<MongoWindow>,
846        ) -> Self {
847            let mut spec = serde_json::Map::new();
848            spec.insert("$avg".to_string(), JsonValue::String(format!("${}", input.into())));
849            if let Some(w) = window {
850                spec.insert("window".to_string(), w.to_bson());
851            }
852            self.output.insert(output_field.into(), JsonValue::Object(spec));
853            self
854        }
855
856        /// Add $first output field.
857        pub fn first(mut self, output_field: impl Into<String>, input: impl Into<String>) -> Self {
858            self.output.insert(
859                output_field.into(),
860                serde_json::json!({ "$first": format!("${}", input.into()) }),
861            );
862            self
863        }
864
865        /// Add $last output field.
866        pub fn last(mut self, output_field: impl Into<String>, input: impl Into<String>) -> Self {
867            self.output.insert(
868                output_field.into(),
869                serde_json::json!({ "$last": format!("${}", input.into()) }),
870            );
871            self
872        }
873
874        /// Add $shift (LAG/LEAD equivalent) output field.
875        pub fn shift(
876            mut self,
877            output_field: impl Into<String>,
878            output: impl Into<String>,
879            by: i32,
880            default: Option<JsonValue>,
881        ) -> Self {
882            let mut spec = serde_json::Map::new();
883            spec.insert("output".to_string(), JsonValue::String(format!("${}", output.into())));
884            spec.insert("by".to_string(), JsonValue::Number(by.into()));
885            if let Some(def) = default {
886                spec.insert("default".to_string(), def);
887            }
888            self.output.insert(
889                output_field.into(),
890                serde_json::json!({ "$shift": spec }),
891            );
892            self
893        }
894
895        /// Add custom window function.
896        pub fn output(mut self, field: impl Into<String>, spec: JsonValue) -> Self {
897            self.output.insert(field.into(), spec);
898            self
899        }
900
901        /// Build the stage.
902        pub fn build(self) -> SetWindowFields {
903            SetWindowFields {
904                partition_by: self.partition_by,
905                sort_by: self.sort_by,
906                output: self.output,
907            }
908        }
909    }
910
911    /// MongoDB window specification.
912    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
913    pub struct MongoWindow {
914        /// Documents array [start, end].
915        pub documents: Option<[WindowBound; 2]>,
916        /// Range array [start, end].
917        pub range: Option<[WindowBound; 2]>,
918        /// Unit for range (day, week, month, etc.).
919        pub unit: Option<String>,
920    }
921
922    /// Window boundary value.
923    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
924    #[serde(untagged)]
925    pub enum WindowBound {
926        /// Numeric offset or "unbounded"/"current".
927        Number(i64),
928        /// String keyword.
929        Keyword(String),
930    }
931
932    impl MongoWindow {
933        /// Documents window (like SQL ROWS).
934        pub fn documents(start: i64, end: i64) -> Self {
935            Self {
936                documents: Some([WindowBound::Number(start), WindowBound::Number(end)]),
937                range: None,
938                unit: None,
939            }
940        }
941
942        /// Unbounded documents window.
943        pub fn documents_unbounded() -> Self {
944            Self {
945                documents: Some([
946                    WindowBound::Keyword("unbounded".to_string()),
947                    WindowBound::Keyword("unbounded".to_string()),
948                ]),
949                range: None,
950                unit: None,
951            }
952        }
953
954        /// Documents from unbounded preceding to current.
955        pub fn documents_to_current() -> Self {
956            Self {
957                documents: Some([
958                    WindowBound::Keyword("unbounded".to_string()),
959                    WindowBound::Keyword("current".to_string()),
960                ]),
961                range: None,
962                unit: None,
963            }
964        }
965
966        /// Range window with unit.
967        pub fn range_with_unit(start: i64, end: i64, unit: impl Into<String>) -> Self {
968            Self {
969                documents: None,
970                range: Some([WindowBound::Number(start), WindowBound::Number(end)]),
971                unit: Some(unit.into()),
972            }
973        }
974
975        /// Convert to BSON.
976        pub fn to_bson(&self) -> JsonValue {
977            let mut window = serde_json::Map::new();
978
979            if let Some(ref docs) = self.documents {
980                let arr: Vec<JsonValue> = docs
981                    .iter()
982                    .map(|b| match b {
983                        WindowBound::Number(n) => JsonValue::Number((*n).into()),
984                        WindowBound::Keyword(s) => JsonValue::String(s.clone()),
985                    })
986                    .collect();
987                window.insert("documents".to_string(), JsonValue::Array(arr));
988            }
989
990            if let Some(ref range) = self.range {
991                let arr: Vec<JsonValue> = range
992                    .iter()
993                    .map(|b| match b {
994                        WindowBound::Number(n) => JsonValue::Number((*n).into()),
995                        WindowBound::Keyword(s) => JsonValue::String(s.clone()),
996                    })
997                    .collect();
998                window.insert("range".to_string(), JsonValue::Array(arr));
999            }
1000
1001            if let Some(ref unit) = self.unit {
1002                window.insert("unit".to_string(), JsonValue::String(unit.clone()));
1003            }
1004
1005            JsonValue::Object(window)
1006        }
1007    }
1008
    /// Helper to create a `$setWindowFields` stage.
    ///
    /// Shorthand for [`SetWindowFields::new`]: chain builder methods on the
    /// returned builder and finish with `.build()`.
    pub fn set_window_fields() -> SetWindowFieldsBuilder {
        SetWindowFields::new()
    }
1013}
1014
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that generated SQL contains `needle`, showing the full SQL on failure.
    #[track_caller]
    fn assert_sql_contains(sql: &str, needle: &str) {
        assert!(
            sql.contains(needle),
            "expected `{}` in generated SQL: {}",
            needle,
            sql
        );
    }

    /// Assert that generated SQL does NOT contain `needle`, showing the full SQL on failure.
    #[track_caller]
    fn assert_sql_lacks(sql: &str, needle: &str) {
        assert!(
            !sql.contains(needle),
            "did not expect `{}` in generated SQL: {}",
            needle,
            sql
        );
    }

    #[test]
    fn test_row_number() {
        let wf = row_number()
            .over(
                WindowSpec::new()
                    .partition_by(["dept"])
                    .order_by("salary", SortOrder::Desc),
            )
            .build();

        let sql = wf.to_sql(DatabaseType::PostgreSQL);
        assert_sql_contains(&sql, "ROW_NUMBER()");
        assert_sql_contains(&sql, "PARTITION BY dept");
        assert_sql_contains(&sql, "ORDER BY salary DESC");
    }

    #[test]
    fn test_rank_functions() {
        let r = rank()
            .over(WindowSpec::new().order_by("score", SortOrder::Desc))
            .build();
        let rank_sql = r.to_sql(DatabaseType::PostgreSQL);
        assert_sql_contains(&rank_sql, "RANK()");
        // "RANK()" is a substring of "DENSE_RANK()", so rule that out explicitly.
        assert_sql_lacks(&rank_sql, "DENSE_RANK");

        let dr = dense_rank()
            .over(WindowSpec::new().order_by("score", SortOrder::Desc))
            .build();
        assert_sql_contains(&dr.to_sql(DatabaseType::PostgreSQL), "DENSE_RANK()");
    }

    #[test]
    fn test_ntile() {
        let wf = ntile(4)
            .over(WindowSpec::new().order_by("value", SortOrder::Asc))
            .build();

        assert_sql_contains(&wf.to_sql(DatabaseType::MySQL), "NTILE(4)");
    }

    #[test]
    fn test_lag_lead() {
        let l = lag("price")
            .over(WindowSpec::new().order_by("date", SortOrder::Asc))
            .build();
        assert_sql_contains(&l.to_sql(DatabaseType::PostgreSQL), "LAG(price)");

        let l2 = lag_offset("price", 2)
            .over(WindowSpec::new().order_by("date", SortOrder::Asc))
            .build();
        assert_sql_contains(&l2.to_sql(DatabaseType::PostgreSQL), "LAG(price, 2)");

        let l3 = lag_full("price", 1, "0")
            .over(WindowSpec::new().order_by("date", SortOrder::Asc))
            .build();
        assert_sql_contains(&l3.to_sql(DatabaseType::PostgreSQL), "LAG(price, 1, 0)");

        let ld = lead("price")
            .over(WindowSpec::new().order_by("date", SortOrder::Asc))
            .build();
        assert_sql_contains(&ld.to_sql(DatabaseType::PostgreSQL), "LEAD(price)");
    }

    #[test]
    fn test_aggregate_window() {
        let s = sum("amount")
            .over(
                WindowSpec::new()
                    .partition_by(["account_id"])
                    .order_by("date", SortOrder::Asc)
                    .rows_unbounded_preceding(),
            )
            .alias("running_total")
            .build();

        let sql = s.to_sql(DatabaseType::PostgreSQL);
        assert_sql_contains(&sql, "SUM(amount)");
        assert_sql_contains(&sql, "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW");
        assert_sql_contains(&sql, "AS running_total");
    }

    #[test]
    fn test_frame_clauses() {
        let spec = WindowSpec::new()
            .order_by("id", SortOrder::Asc)
            .rows(FrameBound::Preceding(3), Some(FrameBound::Following(3)));

        let sql = spec.to_sql(DatabaseType::PostgreSQL);
        assert_sql_contains(&sql, "ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING");
    }

    #[test]
    fn test_named_window() {
        let nw = NamedWindow::new(
            "w",
            WindowSpec::new()
                .partition_by(["dept"])
                .order_by("salary", SortOrder::Desc),
        );

        let sql = nw.to_sql(DatabaseType::PostgreSQL);
        assert_sql_contains(&sql, "w AS (");
        assert_sql_contains(&sql, "PARTITION BY dept");
    }

    #[test]
    fn test_window_reference() {
        let spec = WindowSpec::named("w");
        assert_eq!(spec.to_sql(DatabaseType::PostgreSQL), "OVER w");
    }

    #[test]
    fn test_nulls_position() {
        let spec = WindowSpec::new()
            .order_by_nulls("value", SortOrder::Desc, NullsPosition::Last);

        let pg_sql = spec.to_sql(DatabaseType::PostgreSQL);
        assert_sql_contains(&pg_sql, "NULLS LAST");

        // MSSQL doesn't support NULLS FIRST/LAST, so the clause must be omitted.
        let mssql_sql = spec.to_sql(DatabaseType::MSSQL);
        assert_sql_lacks(&mssql_sql, "NULLS");
    }

    #[test]
    fn test_first_last_value() {
        let fv = first_value("salary")
            .over(
                WindowSpec::new()
                    .partition_by(["dept"])
                    .order_by("hire_date", SortOrder::Asc),
            )
            .build();

        assert_sql_contains(&fv.to_sql(DatabaseType::PostgreSQL), "FIRST_VALUE(salary)");

        let lv = last_value("salary")
            .over(
                WindowSpec::new()
                    .partition_by(["dept"])
                    .order_by("hire_date", SortOrder::Asc)
                    .rows(FrameBound::UnboundedPreceding, Some(FrameBound::UnboundedFollowing)),
            )
            .build();

        assert_sql_contains(&lv.to_sql(DatabaseType::PostgreSQL), "LAST_VALUE(salary)");
    }

    mod mongodb_tests {
        use super::super::mongodb::*;

        #[test]
        fn test_row_number() {
            let stage = set_window_fields()
                .partition_by("state")
                .sort_by_desc("quantity")
                .row_number("rowNumber")
                .build();

            let bson = stage.to_bson();
            assert!(bson["$setWindowFields"]["output"]["rowNumber"]["$rowNumber"].is_object());
        }

        #[test]
        fn test_rank() {
            let stage = set_window_fields()
                .sort_by("score")
                .rank("ranking")
                .dense_rank("denseRanking")
                .build();

            let bson = stage.to_bson();
            let output = &bson["$setWindowFields"]["output"];
            assert!(output["ranking"]["$rank"].is_object());
            assert!(output["denseRanking"]["$denseRank"].is_object());
        }

        #[test]
        fn test_running_total() {
            let stage = set_window_fields()
                .partition_by("account")
                .sort_by("date")
                .sum("runningTotal", "amount", Some(MongoWindow::documents_to_current()))
                .build();

            let bson = stage.to_bson();
            let output = &bson["$setWindowFields"]["output"]["runningTotal"];
            assert!(output["$sum"].is_string());
            assert!(output["window"]["documents"].is_array());
        }

        #[test]
        fn test_shift_lag() {
            let stage = set_window_fields()
                .sort_by("date")
                .shift("prevPrice", "price", -1, Some(serde_json::json!(0)))
                .shift("nextPrice", "price", 1, None)
                .build();

            let bson = stage.to_bson();
            let output = &bson["$setWindowFields"]["output"];

            let prev = &output["prevPrice"]["$shift"];
            assert_eq!(prev["by"], -1);
            assert_eq!(prev["output"], "$price");
            assert_eq!(prev["default"], 0);

            let next = &output["nextPrice"]["$shift"];
            assert_eq!(next["by"], 1);
            assert_eq!(next["output"], "$price");
            // `default` is only emitted when one was supplied.
            assert!(next.get("default").is_none());
        }

        #[test]
        fn test_custom_output() {
            let spec = serde_json::json!({ "$avg": "$price" });
            let stage = set_window_fields()
                .sort_by("date")
                .output("avgPrice", spec.clone())
                .build();

            let bson = stage.to_bson();
            assert_eq!(bson["$setWindowFields"]["output"]["avgPrice"], spec);
        }

        #[test]
        fn test_window_bounds() {
            assert_eq!(
                MongoWindow::documents(-3, 3).to_bson(),
                serde_json::json!({ "documents": [-3, 3] })
            );

            assert_eq!(
                MongoWindow::range_with_unit(-7, 0, "day").to_bson(),
                serde_json::json!({ "range": [-7, 0], "unit": "day" })
            );
        }

        #[test]
        fn test_keyword_bounds() {
            assert_eq!(
                MongoWindow::documents_unbounded().to_bson(),
                serde_json::json!({ "documents": ["unbounded", "unbounded"] })
            );

            assert_eq!(
                MongoWindow::documents_to_current().to_bson(),
                serde_json::json!({ "documents": ["unbounded", "current"] })
            );
        }
    }
}
1214
1215
1216
1217