Skip to main content

prax_query/
window.rs

1//! Window functions support.
2//!
3//! This module provides types for building window functions (OVER clauses)
4//! across different database backends.
5//!
6//! # Supported Features
7//!
8//! | Feature         | PostgreSQL | MySQL | SQLite | MSSQL | MongoDB 5.0+ |
9//! |-----------------|------------|-------|--------|-------|--------------|
10//! | ROW_NUMBER      | ✅         | ✅    | ✅     | ✅    | ✅           |
11//! | RANK/DENSE_RANK | ✅         | ✅    | ✅     | ✅    | ✅           |
12//! | LAG/LEAD        | ✅         | ✅    | ✅     | ✅    | ✅           |
13//! | Frame clauses   | ✅         | ✅    | ✅     | ✅    | ✅           |
14//! | Named windows   | ✅         | ✅    | ✅     | ❌    | ❌           |
15//!
16//! # Example Usage
17//!
18//! ```rust,ignore
19//! use prax_query::window::{WindowFunction, WindowSpec, row_number, rank, sum};
20//!
21//! // ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC)
22//! let wf = row_number()
23//!     .over(WindowSpec::new()
24//!         .partition_by(["dept"])
25//!         .order_by("salary", SortOrder::Desc));
26//!
27//! // Running total
28//! let running = sum("amount")
29//!     .over(WindowSpec::new()
30//!         .order_by("date", SortOrder::Asc)
31//!         .rows_unbounded_preceding());
32//! ```
33
34use serde::{Deserialize, Serialize};
35
36use crate::sql::DatabaseType;
37use crate::types::SortOrder;
38
39/// A window function with its OVER clause.
40#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
41pub struct WindowFunction {
42    /// The function being called.
43    pub function: WindowFn,
44    /// The OVER clause specification.
45    pub over: WindowSpec,
46    /// Optional alias for the result.
47    pub alias: Option<String>,
48}
49
50/// Available window functions.
51#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
52pub enum WindowFn {
53    // Ranking functions
54    /// ROW_NUMBER() - Sequential row number.
55    RowNumber,
56    /// RANK() - Rank with gaps.
57    Rank,
58    /// DENSE_RANK() - Rank without gaps.
59    DenseRank,
60    /// NTILE(n) - Distribute rows into n buckets.
61    Ntile(u32),
62    /// PERCENT_RANK() - Relative rank (0 to 1).
63    PercentRank,
64    /// CUME_DIST() - Cumulative distribution.
65    CumeDist,
66
67    // Value functions
68    /// LAG(expr, offset, default) - Value from previous row.
69    Lag {
70        expr: String,
71        offset: Option<u32>,
72        default: Option<String>,
73    },
74    /// LEAD(expr, offset, default) - Value from next row.
75    Lead {
76        expr: String,
77        offset: Option<u32>,
78        default: Option<String>,
79    },
80    /// FIRST_VALUE(expr) - First value in frame.
81    FirstValue(String),
82    /// LAST_VALUE(expr) - Last value in frame.
83    LastValue(String),
84    /// NTH_VALUE(expr, n) - Nth value in frame.
85    NthValue(String, u32),
86
87    // Aggregate functions as window functions
88    /// SUM(expr).
89    Sum(String),
90    /// AVG(expr).
91    Avg(String),
92    /// COUNT(expr).
93    Count(String),
94    /// MIN(expr).
95    Min(String),
96    /// MAX(expr).
97    Max(String),
98    /// Custom function.
99    Custom { name: String, args: Vec<String> },
100}
101
102impl WindowFn {
103    /// Generate the function SQL.
104    pub fn to_sql(&self) -> String {
105        match self {
106            Self::RowNumber => "ROW_NUMBER()".to_string(),
107            Self::Rank => "RANK()".to_string(),
108            Self::DenseRank => "DENSE_RANK()".to_string(),
109            Self::Ntile(n) => format!("NTILE({})", n),
110            Self::PercentRank => "PERCENT_RANK()".to_string(),
111            Self::CumeDist => "CUME_DIST()".to_string(),
112            Self::Lag {
113                expr,
114                offset,
115                default,
116            } => {
117                let mut sql = format!("LAG({})", expr);
118                if let Some(off) = offset {
119                    sql = format!("LAG({}, {})", expr, off);
120                    if let Some(def) = default {
121                        sql = format!("LAG({}, {}, {})", expr, off, def);
122                    }
123                }
124                sql
125            }
126            Self::Lead {
127                expr,
128                offset,
129                default,
130            } => {
131                let mut sql = format!("LEAD({})", expr);
132                if let Some(off) = offset {
133                    sql = format!("LEAD({}, {})", expr, off);
134                    if let Some(def) = default {
135                        sql = format!("LEAD({}, {}, {})", expr, off, def);
136                    }
137                }
138                sql
139            }
140            Self::FirstValue(expr) => format!("FIRST_VALUE({})", expr),
141            Self::LastValue(expr) => format!("LAST_VALUE({})", expr),
142            Self::NthValue(expr, n) => format!("NTH_VALUE({}, {})", expr, n),
143            Self::Sum(expr) => format!("SUM({})", expr),
144            Self::Avg(expr) => format!("AVG({})", expr),
145            Self::Count(expr) => format!("COUNT({})", expr),
146            Self::Min(expr) => format!("MIN({})", expr),
147            Self::Max(expr) => format!("MAX({})", expr),
148            Self::Custom { name, args } => {
149                format!("{}({})", name, args.join(", "))
150            }
151        }
152    }
153}
154
155/// Window specification (OVER clause).
156#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
157pub struct WindowSpec {
158    /// Reference to a named window.
159    pub window_name: Option<String>,
160    /// PARTITION BY columns.
161    pub partition_by: Vec<String>,
162    /// ORDER BY specifications.
163    pub order_by: Vec<OrderSpec>,
164    /// Frame clause.
165    pub frame: Option<FrameClause>,
166}
167
168/// Order specification for window functions.
169#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
170pub struct OrderSpec {
171    /// Column or expression to order by.
172    pub expr: String,
173    /// Sort direction.
174    pub direction: SortOrder,
175    /// NULLS FIRST/LAST.
176    pub nulls: Option<NullsPosition>,
177}
178
179/// Position of NULL values in ordering.
180#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
181pub enum NullsPosition {
182    /// NULL values first.
183    First,
184    /// NULL values last.
185    Last,
186}
187
188/// Frame clause for window functions.
189#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
190pub struct FrameClause {
191    /// Frame type (ROWS, RANGE, GROUPS).
192    pub frame_type: FrameType,
193    /// Frame start bound.
194    pub start: FrameBound,
195    /// Frame end bound (if BETWEEN).
196    pub end: Option<FrameBound>,
197    /// Exclude clause (PostgreSQL).
198    pub exclude: Option<FrameExclude>,
199}
200
201/// Frame type.
202#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
203pub enum FrameType {
204    /// Row-based frame.
205    Rows,
206    /// Value-based frame.
207    Range,
208    /// Group-based frame (PostgreSQL).
209    Groups,
210}
211
212/// Frame boundary.
213#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
214pub enum FrameBound {
215    /// UNBOUNDED PRECEDING.
216    UnboundedPreceding,
217    /// n PRECEDING.
218    Preceding(u32),
219    /// CURRENT ROW.
220    CurrentRow,
221    /// n FOLLOWING.
222    Following(u32),
223    /// UNBOUNDED FOLLOWING.
224    UnboundedFollowing,
225}
226
227/// Frame exclusion (PostgreSQL).
228#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
229pub enum FrameExclude {
230    /// EXCLUDE CURRENT ROW.
231    CurrentRow,
232    /// EXCLUDE GROUP.
233    Group,
234    /// EXCLUDE TIES.
235    Ties,
236    /// EXCLUDE NO OTHERS.
237    NoOthers,
238}
239
240impl WindowSpec {
241    /// Create a new empty window specification.
242    pub fn new() -> Self {
243        Self::default()
244    }
245
246    /// Reference a named window.
247    pub fn named(name: impl Into<String>) -> Self {
248        Self {
249            window_name: Some(name.into()),
250            ..Default::default()
251        }
252    }
253
254    /// Add PARTITION BY columns.
255    pub fn partition_by<I, S>(mut self, columns: I) -> Self
256    where
257        I: IntoIterator<Item = S>,
258        S: Into<String>,
259    {
260        self.partition_by = columns.into_iter().map(Into::into).collect();
261        self
262    }
263
264    /// Add an ORDER BY column (ascending).
265    pub fn order_by(mut self, column: impl Into<String>, direction: SortOrder) -> Self {
266        self.order_by.push(OrderSpec {
267            expr: column.into(),
268            direction,
269            nulls: None,
270        });
271        self
272    }
273
274    /// Add an ORDER BY column with NULLS position.
275    pub fn order_by_nulls(
276        mut self,
277        column: impl Into<String>,
278        direction: SortOrder,
279        nulls: NullsPosition,
280    ) -> Self {
281        self.order_by.push(OrderSpec {
282            expr: column.into(),
283            direction,
284            nulls: Some(nulls),
285        });
286        self
287    }
288
289    /// Set ROWS frame.
290    pub fn rows(mut self, start: FrameBound, end: Option<FrameBound>) -> Self {
291        self.frame = Some(FrameClause {
292            frame_type: FrameType::Rows,
293            start,
294            end,
295            exclude: None,
296        });
297        self
298    }
299
300    /// Set RANGE frame.
301    pub fn range(mut self, start: FrameBound, end: Option<FrameBound>) -> Self {
302        self.frame = Some(FrameClause {
303            frame_type: FrameType::Range,
304            start,
305            end,
306            exclude: None,
307        });
308        self
309    }
310
311    /// Set GROUPS frame (PostgreSQL).
312    pub fn groups(mut self, start: FrameBound, end: Option<FrameBound>) -> Self {
313        self.frame = Some(FrameClause {
314            frame_type: FrameType::Groups,
315            start,
316            end,
317            exclude: None,
318        });
319        self
320    }
321
322    /// Common frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW.
323    pub fn rows_unbounded_preceding(self) -> Self {
324        self.rows(FrameBound::UnboundedPreceding, Some(FrameBound::CurrentRow))
325    }
326
327    /// Common frame: ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING.
328    pub fn rows_unbounded_following(self) -> Self {
329        self.rows(FrameBound::CurrentRow, Some(FrameBound::UnboundedFollowing))
330    }
331
332    /// Common frame: ROWS BETWEEN n PRECEDING AND n FOLLOWING.
333    pub fn rows_around(self, n: u32) -> Self {
334        self.rows(FrameBound::Preceding(n), Some(FrameBound::Following(n)))
335    }
336
337    /// Common frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW.
338    pub fn range_unbounded_preceding(self) -> Self {
339        self.range(FrameBound::UnboundedPreceding, Some(FrameBound::CurrentRow))
340    }
341
342    /// Generate the OVER clause SQL.
343    pub fn to_sql(&self, db_type: DatabaseType) -> String {
344        if let Some(ref name) = self.window_name {
345            return format!("OVER {}", name);
346        }
347
348        let mut parts = Vec::new();
349
350        if !self.partition_by.is_empty() {
351            parts.push(format!("PARTITION BY {}", self.partition_by.join(", ")));
352        }
353
354        if !self.order_by.is_empty() {
355            let orders: Vec<String> = self
356                .order_by
357                .iter()
358                .map(|o| {
359                    let mut s = format!(
360                        "{} {}",
361                        o.expr,
362                        match o.direction {
363                            SortOrder::Asc => "ASC",
364                            SortOrder::Desc => "DESC",
365                        }
366                    );
367                    if let Some(nulls) = o.nulls {
368                        // MSSQL doesn't support NULLS FIRST/LAST directly
369                        if db_type != DatabaseType::MSSQL {
370                            s.push_str(match nulls {
371                                NullsPosition::First => " NULLS FIRST",
372                                NullsPosition::Last => " NULLS LAST",
373                            });
374                        }
375                    }
376                    s
377                })
378                .collect();
379            parts.push(format!("ORDER BY {}", orders.join(", ")));
380        }
381
382        if let Some(ref frame) = self.frame {
383            parts.push(frame.to_sql(db_type));
384        }
385
386        if parts.is_empty() {
387            "OVER ()".to_string()
388        } else {
389            format!("OVER ({})", parts.join(" "))
390        }
391    }
392}
393
394impl FrameClause {
395    /// Generate frame clause SQL.
396    pub fn to_sql(&self, db_type: DatabaseType) -> String {
397        let frame_type = match self.frame_type {
398            FrameType::Rows => "ROWS",
399            FrameType::Range => "RANGE",
400            FrameType::Groups => {
401                // GROUPS only supported in PostgreSQL and SQLite
402                match db_type {
403                    DatabaseType::PostgreSQL | DatabaseType::SQLite => "GROUPS",
404                    _ => "ROWS", // Fallback
405                }
406            }
407        };
408
409        let bounds = if let Some(ref end) = self.end {
410            format!("BETWEEN {} AND {}", self.start.to_sql(), end.to_sql())
411        } else {
412            self.start.to_sql()
413        };
414
415        let mut sql = format!("{} {}", frame_type, bounds);
416
417        // Exclude clause (PostgreSQL only)
418        if db_type == DatabaseType::PostgreSQL {
419            if let Some(exclude) = self.exclude {
420                sql.push_str(match exclude {
421                    FrameExclude::CurrentRow => " EXCLUDE CURRENT ROW",
422                    FrameExclude::Group => " EXCLUDE GROUP",
423                    FrameExclude::Ties => " EXCLUDE TIES",
424                    FrameExclude::NoOthers => " EXCLUDE NO OTHERS",
425                });
426            }
427        }
428
429        sql
430    }
431}
432
433impl FrameBound {
434    /// Generate bound SQL.
435    pub fn to_sql(&self) -> String {
436        match self {
437            Self::UnboundedPreceding => "UNBOUNDED PRECEDING".to_string(),
438            Self::Preceding(n) => format!("{} PRECEDING", n),
439            Self::CurrentRow => "CURRENT ROW".to_string(),
440            Self::Following(n) => format!("{} FOLLOWING", n),
441            Self::UnboundedFollowing => "UNBOUNDED FOLLOWING".to_string(),
442        }
443    }
444}
445
446impl WindowFunction {
447    /// Create a new window function.
448    pub fn new(function: WindowFn) -> WindowFunctionBuilder {
449        WindowFunctionBuilder {
450            function,
451            over: None,
452            alias: None,
453        }
454    }
455
456    /// Set the OVER clause.
457    pub fn over(mut self, spec: WindowSpec) -> Self {
458        self.over = spec;
459        self
460    }
461
462    /// Set an alias for the result.
463    pub fn alias(mut self, name: impl Into<String>) -> Self {
464        self.alias = Some(name.into());
465        self
466    }
467
468    /// Generate the full SQL expression.
469    pub fn to_sql(&self, db_type: DatabaseType) -> String {
470        let mut sql = format!("{} {}", self.function.to_sql(), self.over.to_sql(db_type));
471        if let Some(ref alias) = self.alias {
472            sql.push_str(" AS ");
473            sql.push_str(alias);
474        }
475        sql
476    }
477}
478
479/// Builder for window functions.
480#[derive(Debug, Clone)]
481pub struct WindowFunctionBuilder {
482    function: WindowFn,
483    over: Option<WindowSpec>,
484    alias: Option<String>,
485}
486
487impl WindowFunctionBuilder {
488    /// Set the OVER clause.
489    pub fn over(mut self, spec: WindowSpec) -> Self {
490        self.over = Some(spec);
491        self
492    }
493
494    /// Set an alias.
495    pub fn alias(mut self, name: impl Into<String>) -> Self {
496        self.alias = Some(name.into());
497        self
498    }
499
500    /// Build the window function.
501    pub fn build(self) -> WindowFunction {
502        WindowFunction {
503            function: self.function,
504            over: self.over.unwrap_or_default(),
505            alias: self.alias,
506        }
507    }
508}
509
510/// A named window definition.
511#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
512pub struct NamedWindow {
513    /// Window name.
514    pub name: String,
515    /// Window specification.
516    pub spec: WindowSpec,
517}
518
519impl NamedWindow {
520    /// Create a new named window.
521    pub fn new(name: impl Into<String>, spec: WindowSpec) -> Self {
522        Self {
523            name: name.into(),
524            spec,
525        }
526    }
527
528    /// Generate the WINDOW clause definition.
529    pub fn to_sql(&self, db_type: DatabaseType) -> String {
530        // Named windows generate just the spec content (without OVER)
531        let spec_parts = {
532            let mut parts = Vec::new();
533            if !self.spec.partition_by.is_empty() {
534                parts.push(format!(
535                    "PARTITION BY {}",
536                    self.spec.partition_by.join(", ")
537                ));
538            }
539            if !self.spec.order_by.is_empty() {
540                let orders: Vec<String> = self
541                    .spec
542                    .order_by
543                    .iter()
544                    .map(|o| {
545                        format!(
546                            "{} {}",
547                            o.expr,
548                            match o.direction {
549                                SortOrder::Asc => "ASC",
550                                SortOrder::Desc => "DESC",
551                            }
552                        )
553                    })
554                    .collect();
555                parts.push(format!("ORDER BY {}", orders.join(", ")));
556            }
557            if let Some(ref frame) = self.spec.frame {
558                parts.push(frame.to_sql(db_type));
559            }
560            parts.join(" ")
561        };
562
563        format!("{} AS ({})", self.name, spec_parts)
564    }
565}
566
567// ============================================================================
568// Helper functions for creating window functions
569// ============================================================================
570
571/// Create ROW_NUMBER() window function.
572pub fn row_number() -> WindowFunctionBuilder {
573    WindowFunction::new(WindowFn::RowNumber)
574}
575
576/// Create RANK() window function.
577pub fn rank() -> WindowFunctionBuilder {
578    WindowFunction::new(WindowFn::Rank)
579}
580
581/// Create DENSE_RANK() window function.
582pub fn dense_rank() -> WindowFunctionBuilder {
583    WindowFunction::new(WindowFn::DenseRank)
584}
585
586/// Create NTILE(n) window function.
587pub fn ntile(n: u32) -> WindowFunctionBuilder {
588    WindowFunction::new(WindowFn::Ntile(n))
589}
590
591/// Create PERCENT_RANK() window function.
592pub fn percent_rank() -> WindowFunctionBuilder {
593    WindowFunction::new(WindowFn::PercentRank)
594}
595
596/// Create CUME_DIST() window function.
597pub fn cume_dist() -> WindowFunctionBuilder {
598    WindowFunction::new(WindowFn::CumeDist)
599}
600
601/// Create LAG() window function.
602pub fn lag(expr: impl Into<String>) -> WindowFunctionBuilder {
603    WindowFunction::new(WindowFn::Lag {
604        expr: expr.into(),
605        offset: None,
606        default: None,
607    })
608}
609
610/// Create LAG() with offset.
611pub fn lag_offset(expr: impl Into<String>, offset: u32) -> WindowFunctionBuilder {
612    WindowFunction::new(WindowFn::Lag {
613        expr: expr.into(),
614        offset: Some(offset),
615        default: None,
616    })
617}
618
619/// Create LAG() with offset and default.
620pub fn lag_full(
621    expr: impl Into<String>,
622    offset: u32,
623    default: impl Into<String>,
624) -> WindowFunctionBuilder {
625    WindowFunction::new(WindowFn::Lag {
626        expr: expr.into(),
627        offset: Some(offset),
628        default: Some(default.into()),
629    })
630}
631
632/// Create LEAD() window function.
633pub fn lead(expr: impl Into<String>) -> WindowFunctionBuilder {
634    WindowFunction::new(WindowFn::Lead {
635        expr: expr.into(),
636        offset: None,
637        default: None,
638    })
639}
640
641/// Create LEAD() with offset.
642pub fn lead_offset(expr: impl Into<String>, offset: u32) -> WindowFunctionBuilder {
643    WindowFunction::new(WindowFn::Lead {
644        expr: expr.into(),
645        offset: Some(offset),
646        default: None,
647    })
648}
649
650/// Create LEAD() with offset and default.
651pub fn lead_full(
652    expr: impl Into<String>,
653    offset: u32,
654    default: impl Into<String>,
655) -> WindowFunctionBuilder {
656    WindowFunction::new(WindowFn::Lead {
657        expr: expr.into(),
658        offset: Some(offset),
659        default: Some(default.into()),
660    })
661}
662
663/// Create FIRST_VALUE() window function.
664pub fn first_value(expr: impl Into<String>) -> WindowFunctionBuilder {
665    WindowFunction::new(WindowFn::FirstValue(expr.into()))
666}
667
668/// Create LAST_VALUE() window function.
669pub fn last_value(expr: impl Into<String>) -> WindowFunctionBuilder {
670    WindowFunction::new(WindowFn::LastValue(expr.into()))
671}
672
673/// Create NTH_VALUE() window function.
674pub fn nth_value(expr: impl Into<String>, n: u32) -> WindowFunctionBuilder {
675    WindowFunction::new(WindowFn::NthValue(expr.into(), n))
676}
677
678/// Create SUM() window function.
679pub fn sum(expr: impl Into<String>) -> WindowFunctionBuilder {
680    WindowFunction::new(WindowFn::Sum(expr.into()))
681}
682
683/// Create AVG() window function.
684pub fn avg(expr: impl Into<String>) -> WindowFunctionBuilder {
685    WindowFunction::new(WindowFn::Avg(expr.into()))
686}
687
688/// Create COUNT() window function.
689pub fn count(expr: impl Into<String>) -> WindowFunctionBuilder {
690    WindowFunction::new(WindowFn::Count(expr.into()))
691}
692
693/// Create MIN() window function.
694pub fn min(expr: impl Into<String>) -> WindowFunctionBuilder {
695    WindowFunction::new(WindowFn::Min(expr.into()))
696}
697
698/// Create MAX() window function.
699pub fn max(expr: impl Into<String>) -> WindowFunctionBuilder {
700    WindowFunction::new(WindowFn::Max(expr.into()))
701}
702
703/// Create a custom window function.
704pub fn custom<I, S>(name: impl Into<String>, args: I) -> WindowFunctionBuilder
705where
706    I: IntoIterator<Item = S>,
707    S: Into<String>,
708{
709    WindowFunction::new(WindowFn::Custom {
710        name: name.into(),
711        args: args.into_iter().map(Into::into).collect(),
712    })
713}
714
715/// MongoDB $setWindowFields support.
716pub mod mongodb {
717    use serde::{Deserialize, Serialize};
718    use serde_json::Value as JsonValue;
719
720    /// A $setWindowFields stage for MongoDB aggregation pipelines.
721    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
722    pub struct SetWindowFields {
723        /// PARTITION BY equivalent.
724        pub partition_by: Option<JsonValue>,
725        /// SORT BY specification.
726        pub sort_by: Option<JsonValue>,
727        /// Output fields with window functions.
728        pub output: serde_json::Map<String, JsonValue>,
729    }
730
731    impl SetWindowFields {
732        /// Create a new $setWindowFields stage.
733        pub fn new() -> SetWindowFieldsBuilder {
734            SetWindowFieldsBuilder::default()
735        }
736
737        /// Convert to BSON document.
738        pub fn to_bson(&self) -> JsonValue {
739            let mut stage = serde_json::Map::new();
740
741            if let Some(ref partition) = self.partition_by {
742                stage.insert("partitionBy".to_string(), partition.clone());
743            }
744
745            if let Some(ref sort) = self.sort_by {
746                stage.insert("sortBy".to_string(), sort.clone());
747            }
748
749            stage.insert("output".to_string(), JsonValue::Object(self.output.clone()));
750
751            serde_json::json!({ "$setWindowFields": stage })
752        }
753    }
754
755    impl Default for SetWindowFields {
756        fn default() -> Self {
757            Self {
758                partition_by: None,
759                sort_by: None,
760                output: serde_json::Map::new(),
761            }
762        }
763    }
764
765    /// Builder for $setWindowFields.
766    #[derive(Debug, Clone, Default)]
767    pub struct SetWindowFieldsBuilder {
768        partition_by: Option<JsonValue>,
769        sort_by: Option<JsonValue>,
770        output: serde_json::Map<String, JsonValue>,
771    }
772
773    impl SetWindowFieldsBuilder {
774        /// Set PARTITION BY.
775        pub fn partition_by(mut self, expr: impl Into<String>) -> Self {
776            self.partition_by = Some(JsonValue::String(format!("${}", expr.into())));
777            self
778        }
779
780        /// Set PARTITION BY with object expression.
781        pub fn partition_by_expr(mut self, expr: JsonValue) -> Self {
782            self.partition_by = Some(expr);
783            self
784        }
785
786        /// Set SORT BY (single field ascending).
787        pub fn sort_by(mut self, field: impl Into<String>) -> Self {
788            let mut sort = serde_json::Map::new();
789            sort.insert(field.into(), JsonValue::Number(1.into()));
790            self.sort_by = Some(JsonValue::Object(sort));
791            self
792        }
793
794        /// Set SORT BY with direction.
795        pub fn sort_by_desc(mut self, field: impl Into<String>) -> Self {
796            let mut sort = serde_json::Map::new();
797            sort.insert(field.into(), JsonValue::Number((-1).into()));
798            self.sort_by = Some(JsonValue::Object(sort));
799            self
800        }
801
802        /// Set SORT BY with multiple fields.
803        pub fn sort_by_fields(mut self, fields: Vec<(&str, i32)>) -> Self {
804            let mut sort = serde_json::Map::new();
805            for (field, dir) in fields {
806                sort.insert(field.to_string(), JsonValue::Number(dir.into()));
807            }
808            self.sort_by = Some(JsonValue::Object(sort));
809            self
810        }
811
812        /// Add $rowNumber output field.
813        pub fn row_number(mut self, output_field: impl Into<String>) -> Self {
814            self.output
815                .insert(output_field.into(), serde_json::json!({ "$rowNumber": {} }));
816            self
817        }
818
819        /// Add $rank output field.
820        pub fn rank(mut self, output_field: impl Into<String>) -> Self {
821            self.output
822                .insert(output_field.into(), serde_json::json!({ "$rank": {} }));
823            self
824        }
825
826        /// Add $denseRank output field.
827        pub fn dense_rank(mut self, output_field: impl Into<String>) -> Self {
828            self.output
829                .insert(output_field.into(), serde_json::json!({ "$denseRank": {} }));
830            self
831        }
832
833        /// Add $sum with window output field.
834        pub fn sum(
835            mut self,
836            output_field: impl Into<String>,
837            input: impl Into<String>,
838            window: Option<MongoWindow>,
839        ) -> Self {
840            let mut spec = serde_json::Map::new();
841            spec.insert(
842                "$sum".to_string(),
843                JsonValue::String(format!("${}", input.into())),
844            );
845            if let Some(w) = window {
846                spec.insert("window".to_string(), w.to_bson());
847            }
848            self.output
849                .insert(output_field.into(), JsonValue::Object(spec));
850            self
851        }
852
853        /// Add $avg with window output field.
854        pub fn avg(
855            mut self,
856            output_field: impl Into<String>,
857            input: impl Into<String>,
858            window: Option<MongoWindow>,
859        ) -> Self {
860            let mut spec = serde_json::Map::new();
861            spec.insert(
862                "$avg".to_string(),
863                JsonValue::String(format!("${}", input.into())),
864            );
865            if let Some(w) = window {
866                spec.insert("window".to_string(), w.to_bson());
867            }
868            self.output
869                .insert(output_field.into(), JsonValue::Object(spec));
870            self
871        }
872
873        /// Add $first output field.
874        pub fn first(mut self, output_field: impl Into<String>, input: impl Into<String>) -> Self {
875            self.output.insert(
876                output_field.into(),
877                serde_json::json!({ "$first": format!("${}", input.into()) }),
878            );
879            self
880        }
881
882        /// Add $last output field.
883        pub fn last(mut self, output_field: impl Into<String>, input: impl Into<String>) -> Self {
884            self.output.insert(
885                output_field.into(),
886                serde_json::json!({ "$last": format!("${}", input.into()) }),
887            );
888            self
889        }
890
891        /// Add $shift (LAG/LEAD equivalent) output field.
892        pub fn shift(
893            mut self,
894            output_field: impl Into<String>,
895            output: impl Into<String>,
896            by: i32,
897            default: Option<JsonValue>,
898        ) -> Self {
899            let mut spec = serde_json::Map::new();
900            spec.insert(
901                "output".to_string(),
902                JsonValue::String(format!("${}", output.into())),
903            );
904            spec.insert("by".to_string(), JsonValue::Number(by.into()));
905            if let Some(def) = default {
906                spec.insert("default".to_string(), def);
907            }
908            self.output
909                .insert(output_field.into(), serde_json::json!({ "$shift": spec }));
910            self
911        }
912
913        /// Add custom window function.
914        pub fn output(mut self, field: impl Into<String>, spec: JsonValue) -> Self {
915            self.output.insert(field.into(), spec);
916            self
917        }
918
919        /// Build the stage.
920        pub fn build(self) -> SetWindowFields {
921            SetWindowFields {
922                partition_by: self.partition_by,
923                sort_by: self.sort_by,
924                output: self.output,
925            }
926        }
927    }
928
929    /// MongoDB window specification.
930    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
931    pub struct MongoWindow {
932        /// Documents array [start, end].
933        pub documents: Option<[WindowBound; 2]>,
934        /// Range array [start, end].
935        pub range: Option<[WindowBound; 2]>,
936        /// Unit for range (day, week, month, etc.).
937        pub unit: Option<String>,
938    }
939
940    /// Window boundary value.
941    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
942    #[serde(untagged)]
943    pub enum WindowBound {
944        /// Numeric offset or "unbounded"/"current".
945        Number(i64),
946        /// String keyword.
947        Keyword(String),
948    }
949
950    impl MongoWindow {
951        /// Documents window (like SQL ROWS).
952        pub fn documents(start: i64, end: i64) -> Self {
953            Self {
954                documents: Some([WindowBound::Number(start), WindowBound::Number(end)]),
955                range: None,
956                unit: None,
957            }
958        }
959
960        /// Unbounded documents window.
961        pub fn documents_unbounded() -> Self {
962            Self {
963                documents: Some([
964                    WindowBound::Keyword("unbounded".to_string()),
965                    WindowBound::Keyword("unbounded".to_string()),
966                ]),
967                range: None,
968                unit: None,
969            }
970        }
971
972        /// Documents from unbounded preceding to current.
973        pub fn documents_to_current() -> Self {
974            Self {
975                documents: Some([
976                    WindowBound::Keyword("unbounded".to_string()),
977                    WindowBound::Keyword("current".to_string()),
978                ]),
979                range: None,
980                unit: None,
981            }
982        }
983
984        /// Range window with unit.
985        pub fn range_with_unit(start: i64, end: i64, unit: impl Into<String>) -> Self {
986            Self {
987                documents: None,
988                range: Some([WindowBound::Number(start), WindowBound::Number(end)]),
989                unit: Some(unit.into()),
990            }
991        }
992
993        /// Convert to BSON.
994        pub fn to_bson(&self) -> JsonValue {
995            let mut window = serde_json::Map::new();
996
997            if let Some(ref docs) = self.documents {
998                let arr: Vec<JsonValue> = docs
999                    .iter()
1000                    .map(|b| match b {
1001                        WindowBound::Number(n) => JsonValue::Number((*n).into()),
1002                        WindowBound::Keyword(s) => JsonValue::String(s.clone()),
1003                    })
1004                    .collect();
1005                window.insert("documents".to_string(), JsonValue::Array(arr));
1006            }
1007
1008            if let Some(ref range) = self.range {
1009                let arr: Vec<JsonValue> = range
1010                    .iter()
1011                    .map(|b| match b {
1012                        WindowBound::Number(n) => JsonValue::Number((*n).into()),
1013                        WindowBound::Keyword(s) => JsonValue::String(s.clone()),
1014                    })
1015                    .collect();
1016                window.insert("range".to_string(), JsonValue::Array(arr));
1017            }
1018
1019            if let Some(ref unit) = self.unit {
1020                window.insert("unit".to_string(), JsonValue::String(unit.clone()));
1021            }
1022
1023            JsonValue::Object(window)
1024        }
1025    }
1026
1027    /// Helper to create a $setWindowFields stage.
1028    pub fn set_window_fields() -> SetWindowFieldsBuilder {
1029        SetWindowFields::new()
1030    }
1031}
1032
1033#[cfg(test)]
1034mod tests {
1035    use super::*;
1036
1037    #[test]
1038    fn test_row_number() {
1039        let wf = row_number()
1040            .over(
1041                WindowSpec::new()
1042                    .partition_by(["dept"])
1043                    .order_by("salary", SortOrder::Desc),
1044            )
1045            .build();
1046
1047        let sql = wf.to_sql(DatabaseType::PostgreSQL);
1048        assert!(sql.contains("ROW_NUMBER()"));
1049        assert!(sql.contains("PARTITION BY dept"));
1050        assert!(sql.contains("ORDER BY salary DESC"));
1051    }
1052
1053    #[test]
1054    fn test_rank_functions() {
1055        let r = rank()
1056            .over(WindowSpec::new().order_by("score", SortOrder::Desc))
1057            .build();
1058        assert!(r.to_sql(DatabaseType::PostgreSQL).contains("RANK()"));
1059
1060        let dr = dense_rank()
1061            .over(WindowSpec::new().order_by("score", SortOrder::Desc))
1062            .build();
1063        assert!(dr.to_sql(DatabaseType::PostgreSQL).contains("DENSE_RANK()"));
1064    }
1065
1066    #[test]
1067    fn test_ntile() {
1068        let wf = ntile(4)
1069            .over(WindowSpec::new().order_by("value", SortOrder::Asc))
1070            .build();
1071
1072        assert!(wf.to_sql(DatabaseType::MySQL).contains("NTILE(4)"));
1073    }
1074
1075    #[test]
1076    fn test_lag_lead() {
1077        let l = lag("price")
1078            .over(WindowSpec::new().order_by("date", SortOrder::Asc))
1079            .build();
1080        assert!(l.to_sql(DatabaseType::PostgreSQL).contains("LAG(price)"));
1081
1082        let l2 = lag_offset("price", 2)
1083            .over(WindowSpec::new().order_by("date", SortOrder::Asc))
1084            .build();
1085        assert!(
1086            l2.to_sql(DatabaseType::PostgreSQL)
1087                .contains("LAG(price, 2)")
1088        );
1089
1090        let l3 = lag_full("price", 1, "0")
1091            .over(WindowSpec::new().order_by("date", SortOrder::Asc))
1092            .build();
1093        assert!(
1094            l3.to_sql(DatabaseType::PostgreSQL)
1095                .contains("LAG(price, 1, 0)")
1096        );
1097
1098        let ld = lead("price")
1099            .over(WindowSpec::new().order_by("date", SortOrder::Asc))
1100            .build();
1101        assert!(ld.to_sql(DatabaseType::PostgreSQL).contains("LEAD(price)"));
1102    }
1103
1104    #[test]
1105    fn test_aggregate_window() {
1106        let s = sum("amount")
1107            .over(
1108                WindowSpec::new()
1109                    .partition_by(["account_id"])
1110                    .order_by("date", SortOrder::Asc)
1111                    .rows_unbounded_preceding(),
1112            )
1113            .alias("running_total")
1114            .build();
1115
1116        let sql = s.to_sql(DatabaseType::PostgreSQL);
1117        assert!(sql.contains("SUM(amount)"));
1118        assert!(sql.contains("ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"));
1119        assert!(sql.contains("AS running_total"));
1120    }
1121
1122    #[test]
1123    fn test_frame_clauses() {
1124        let spec = WindowSpec::new()
1125            .order_by("id", SortOrder::Asc)
1126            .rows(FrameBound::Preceding(3), Some(FrameBound::Following(3)));
1127
1128        let sql = spec.to_sql(DatabaseType::PostgreSQL);
1129        assert!(sql.contains("ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING"));
1130    }
1131
1132    #[test]
1133    fn test_named_window() {
1134        let nw = NamedWindow::new(
1135            "w",
1136            WindowSpec::new()
1137                .partition_by(["dept"])
1138                .order_by("salary", SortOrder::Desc),
1139        );
1140
1141        let sql = nw.to_sql(DatabaseType::PostgreSQL);
1142        assert!(sql.contains("w AS ("));
1143        assert!(sql.contains("PARTITION BY dept"));
1144    }
1145
1146    #[test]
1147    fn test_window_reference() {
1148        let spec = WindowSpec::named("w");
1149        assert_eq!(spec.to_sql(DatabaseType::PostgreSQL), "OVER w");
1150    }
1151
1152    #[test]
1153    fn test_nulls_position() {
1154        let spec = WindowSpec::new().order_by_nulls("value", SortOrder::Desc, NullsPosition::Last);
1155
1156        let pg_sql = spec.to_sql(DatabaseType::PostgreSQL);
1157        assert!(pg_sql.contains("NULLS LAST"));
1158
1159        // MSSQL doesn't support NULLS FIRST/LAST
1160        let mssql_sql = spec.to_sql(DatabaseType::MSSQL);
1161        assert!(!mssql_sql.contains("NULLS"));
1162    }
1163
1164    #[test]
1165    fn test_first_last_value() {
1166        let fv = first_value("salary")
1167            .over(
1168                WindowSpec::new()
1169                    .partition_by(["dept"])
1170                    .order_by("hire_date", SortOrder::Asc),
1171            )
1172            .build();
1173
1174        assert!(
1175            fv.to_sql(DatabaseType::PostgreSQL)
1176                .contains("FIRST_VALUE(salary)")
1177        );
1178
1179        let lv = last_value("salary")
1180            .over(
1181                WindowSpec::new()
1182                    .partition_by(["dept"])
1183                    .order_by("hire_date", SortOrder::Asc)
1184                    .rows(
1185                        FrameBound::UnboundedPreceding,
1186                        Some(FrameBound::UnboundedFollowing),
1187                    ),
1188            )
1189            .build();
1190
1191        assert!(
1192            lv.to_sql(DatabaseType::PostgreSQL)
1193                .contains("LAST_VALUE(salary)")
1194        );
1195    }
1196
1197    mod mongodb_tests {
1198        use super::super::mongodb::*;
1199
1200        #[test]
1201        fn test_row_number() {
1202            let stage = set_window_fields()
1203                .partition_by("state")
1204                .sort_by_desc("quantity")
1205                .row_number("rowNumber")
1206                .build();
1207
1208            let bson = stage.to_bson();
1209            assert!(bson["$setWindowFields"]["output"]["rowNumber"]["$rowNumber"].is_object());
1210        }
1211
1212        #[test]
1213        fn test_rank() {
1214            let stage = set_window_fields()
1215                .sort_by("score")
1216                .rank("ranking")
1217                .dense_rank("denseRanking")
1218                .build();
1219
1220            let bson = stage.to_bson();
1221            assert!(bson["$setWindowFields"]["output"]["ranking"]["$rank"].is_object());
1222            assert!(bson["$setWindowFields"]["output"]["denseRanking"]["$denseRank"].is_object());
1223        }
1224
1225        #[test]
1226        fn test_running_total() {
1227            let stage = set_window_fields()
1228                .partition_by("account")
1229                .sort_by("date")
1230                .sum(
1231                    "runningTotal",
1232                    "amount",
1233                    Some(MongoWindow::documents_to_current()),
1234                )
1235                .build();
1236
1237            let bson = stage.to_bson();
1238            let output = &bson["$setWindowFields"]["output"]["runningTotal"];
1239            assert!(output["$sum"].is_string());
1240            assert!(output["window"]["documents"].is_array());
1241        }
1242
1243        #[test]
1244        fn test_shift_lag() {
1245            let stage = set_window_fields()
1246                .sort_by("date")
1247                .shift("prevPrice", "price", -1, Some(serde_json::json!(0)))
1248                .shift("nextPrice", "price", 1, None)
1249                .build();
1250
1251            let bson = stage.to_bson();
1252            assert!(bson["$setWindowFields"]["output"]["prevPrice"]["$shift"]["by"] == -1);
1253            assert!(bson["$setWindowFields"]["output"]["nextPrice"]["$shift"]["by"] == 1);
1254        }
1255
1256        #[test]
1257        fn test_window_bounds() {
1258            let w = MongoWindow::documents(-3, 3);
1259            let bson = w.to_bson();
1260            assert_eq!(bson["documents"][0], -3);
1261            assert_eq!(bson["documents"][1], 3);
1262
1263            let w2 = MongoWindow::range_with_unit(-7, 0, "day");
1264            let bson2 = w2.to_bson();
1265            assert!(bson2["range"].is_array());
1266            assert_eq!(bson2["unit"], "day");
1267        }
1268    }
1269}