Skip to main content

clicktype_query/
window.rs

1//! Type-safe window functions
2
3use std::marker::PhantomData;
4use clicktype_core::traits::{ClickTable, TypedColumn};
5
/// A window function whose SQL call text does not depend on any argument.
///
/// Implementors are zero-sized marker types; `WindowExpr::new` reads
/// `function_name()` to render the call that precedes `OVER (...)`.
pub trait WindowFunction {
    /// The complete SQL call text, including parentheses (e.g. `"RANK()"`).
    fn function_name() -> &'static str;
}

/// Marker for the `ROW_NUMBER()` window function.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct RowNumber;

impl WindowFunction for RowNumber {
    fn function_name() -> &'static str {
        "ROW_NUMBER()"
    }
}

/// Marker for the `RANK()` window function.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Rank;

impl WindowFunction for Rank {
    fn function_name() -> &'static str {
        "RANK()"
    }
}

/// Marker for the `DENSE_RANK()` window function.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct DenseRank;

impl WindowFunction for DenseRank {
    fn function_name() -> &'static str {
        "DENSE_RANK()"
    }
}
34
/// Marker for the `LAG(...)` window function over a column whose type is `T`.
///
/// `T` is tracked only at the type level (no `T` is ever stored), which lets the
/// `WindowExpr<_, Lag<ColType>>` constructors require `TypedColumn<Type = ColType>`.
pub struct Lag<T> {
    _marker: PhantomData<T>,
}

impl<T> Lag<T> {
    /// Creates the zero-sized marker value.
    pub fn new() -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

// Implemented by hand instead of derived: a derive would add a `T: Default`
// bound even though the marker never holds a `T`.
impl<T> Default for Lag<T> {
    fn default() -> Self {
        Self::new()
    }
}

// Hand-written for the same reason: no `T: Debug` bound is needed.
impl<T> std::fmt::Debug for Lag<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Lag")
    }
}
45
/// Marker for the `LEAD(...)` window function over a column whose type is `T`.
///
/// `T` is tracked only at the type level (no `T` is ever stored), which lets the
/// `WindowExpr<_, Lead<ColType>>` constructors require `TypedColumn<Type = ColType>`.
pub struct Lead<T> {
    _marker: PhantomData<T>,
}

impl<T> Lead<T> {
    /// Creates the zero-sized marker value.
    pub fn new() -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

// Implemented by hand instead of derived: a derive would add a `T: Default`
// bound even though the marker never holds a `T`.
impl<T> Default for Lead<T> {
    fn default() -> Self {
        Self::new()
    }
}

// Hand-written for the same reason: no `T: Debug` bound is needed.
impl<T> std::fmt::Debug for Lead<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Lead")
    }
}
56
/// One end of a window frame, as used in `<mode> BETWEEN <start> AND <end>`.
///
/// The `u64` carried by `Preceding` / `Following` is written verbatim as the offset.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameBoundary {
    /// Rendered as `UNBOUNDED PRECEDING`.
    UnboundedPreceding,
    /// Rendered as `<n> PRECEDING`.
    Preceding(u64),
    /// Rendered as `CURRENT ROW`.
    CurrentRow,
    /// Rendered as `<n> FOLLOWING`.
    Following(u64),
    /// Rendered as `UNBOUNDED FOLLOWING`.
    UnboundedFollowing,
}

impl FrameBoundary {
    /// Returns the SQL phrase for this boundary.
    pub fn to_sql(&self) -> String {
        match *self {
            Self::UnboundedPreceding => String::from("UNBOUNDED PRECEDING"),
            Self::Preceding(offset) => format!("{} PRECEDING", offset),
            Self::CurrentRow => String::from("CURRENT ROW"),
            Self::Following(offset) => format!("{} FOLLOWING", offset),
            Self::UnboundedFollowing => String::from("UNBOUNDED FOLLOWING"),
        }
    }
}
78
/// Frame unit keyword that opens a window frame clause (`ROWS` or `RANGE`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameMode {
    /// Rendered as `ROWS`.
    Rows,
    /// Rendered as `RANGE`.
    Range,
}

impl FrameMode {
    /// Returns the SQL keyword for this mode.
    pub fn to_sql(&self) -> &'static str {
        match *self {
            Self::Rows => "ROWS",
            Self::Range => "RANGE",
        }
    }
}
94
95/// Window frame specification
96#[derive(Debug, Clone)]
97pub struct WindowFrame {
98    pub mode: FrameMode,
99    pub start: FrameBoundary,
100    pub end: Option<FrameBoundary>,
101}
102
103impl WindowFrame {
104    pub fn to_sql(&self) -> String {
105        let mut sql = format!("{} BETWEEN {}", self.mode.to_sql(), self.start.to_sql());
106        if let Some(end) = &self.end {
107            sql.push_str(" AND ");
108            sql.push_str(&end.to_sql());
109        } else {
110            sql.push_str(" AND CURRENT ROW");
111        }
112        sql
113    }
114}
115
116/// Window specification builder
117pub struct WindowSpec<T: ClickTable> {
118    _table: PhantomData<T>,
119    partition_by: Vec<String>,
120    order_by: Vec<String>,
121    frame: Option<WindowFrame>,
122}
123
124impl<T: ClickTable> WindowSpec<T> {
125    pub fn new() -> Self {
126        Self {
127            _table: PhantomData,
128            partition_by: Vec::new(),
129            order_by: Vec::new(),
130            frame: None,
131        }
132    }
133
134    pub fn partition_by<C: TypedColumn<Table = T>>(mut self, _col: C) -> Self {
135        self.partition_by.push(C::name().to_string());
136        self
137    }
138
139    pub fn order_by<C: TypedColumn<Table = T>>(mut self, _col: C, desc: bool) -> Self {
140        let order = if desc {
141            format!("{} DESC", C::name())
142        } else {
143            C::name().to_string()
144        };
145        self.order_by.push(order);
146        self
147    }
148
149    pub fn rows_between(mut self, start: FrameBoundary, end: FrameBoundary) -> Self {
150        self.frame = Some(WindowFrame {
151            mode: FrameMode::Rows,
152            start,
153            end: Some(end),
154        });
155        self
156    }
157
158    pub fn range_between(mut self, start: FrameBoundary, end: FrameBoundary) -> Self {
159        self.frame = Some(WindowFrame {
160            mode: FrameMode::Range,
161            start,
162            end: Some(end),
163        });
164        self
165    }
166
167    pub fn to_sql(&self) -> String {
168        let mut parts = Vec::new();
169
170        if !self.partition_by.is_empty() {
171            parts.push(format!("PARTITION BY {}", self.partition_by.join(", ")));
172        }
173
174        if !self.order_by.is_empty() {
175            parts.push(format!("ORDER BY {}", self.order_by.join(", ")));
176        }
177
178        if let Some(frame) = &self.frame {
179            parts.push(frame.to_sql());
180        }
181
182        parts.join(" ")
183    }
184}
185
186impl<T: ClickTable> Default for WindowSpec<T> {
187    fn default() -> Self {
188        Self::new()
189    }
190}
191
192/// Window function expression
193pub struct WindowExpr<T: ClickTable, F> {
194    _table: PhantomData<T>,
195    _function: PhantomData<F>,
196    function_expr: String,
197    window_spec: WindowSpec<T>,
198}
199
200// Generic implementation for all WindowExpr types
201impl<T: ClickTable, F> WindowExpr<T, F> {
202    pub fn to_sql(&self) -> String {
203        format!("{} OVER ({})", self.function_expr, self.window_spec.to_sql())
204    }
205
206    pub fn alias(self, alias: &str) -> String {
207        format!("{} AS {}", self.to_sql(), alias)
208    }
209}
210
211// Implementation for standard window functions
212impl<T: ClickTable, F: WindowFunction> WindowExpr<T, F> {
213    pub fn new(window_spec: WindowSpec<T>) -> Self {
214        Self {
215            _table: PhantomData,
216            _function: PhantomData,
217            function_expr: F::function_name().to_string(),
218            window_spec,
219        }
220    }
221}
222
223/// LAG/LEAD window expressions with column argument
224impl<T: ClickTable, ColType> WindowExpr<T, Lag<ColType>> {
225    pub fn with_column<C: TypedColumn<Table = T, Type = ColType>>(
226        _col: C,
227        window_spec: WindowSpec<T>,
228    ) -> Self {
229        Self {
230            _table: PhantomData,
231            _function: PhantomData,
232            function_expr: format!("LAG({})", C::name()),
233            window_spec,
234        }
235    }
236
237    pub fn with_column_and_offset<C: TypedColumn<Table = T, Type = ColType>>(
238        _col: C,
239        offset: u64,
240        window_spec: WindowSpec<T>,
241    ) -> Self {
242        Self {
243            _table: PhantomData,
244            _function: PhantomData,
245            function_expr: format!("LAG({}, {})", C::name(), offset),
246            window_spec,
247        }
248    }
249}
250
251impl<T: ClickTable, ColType> WindowExpr<T, Lead<ColType>> {
252    pub fn with_column<C: TypedColumn<Table = T, Type = ColType>>(
253        _col: C,
254        window_spec: WindowSpec<T>,
255    ) -> Self {
256        Self {
257            _table: PhantomData,
258            _function: PhantomData,
259            function_expr: format!("LEAD({})", C::name()),
260            window_spec,
261        }
262    }
263
264    pub fn with_column_and_offset<C: TypedColumn<Table = T, Type = ColType>>(
265        _col: C,
266        offset: u64,
267        window_spec: WindowSpec<T>,
268    ) -> Self {
269        Self {
270            _table: PhantomData,
271            _function: PhantomData,
272            function_expr: format!("LEAD({}, {})", C::name(), offset),
273            window_spec,
274        }
275    }
276}
277
278/// Helper functions for creating window functions
279pub fn row_number<T: ClickTable>() -> WindowExpr<T, RowNumber> {
280    WindowExpr::new(WindowSpec::new())
281}
282
283pub fn rank<T: ClickTable>() -> WindowExpr<T, Rank> {
284    WindowExpr::new(WindowSpec::new())
285}
286
287pub fn dense_rank<T: ClickTable>() -> WindowExpr<T, DenseRank> {
288    WindowExpr::new(WindowSpec::new())
289}
290
291pub fn lag<T: ClickTable, ColType, C: TypedColumn<Table = T, Type = ColType>>(
292    col: C,
293) -> WindowExpr<T, Lag<ColType>> {
294    WindowExpr::<T, Lag<ColType>>::with_column(col, WindowSpec::new())
295}
296
297pub fn lead<T: ClickTable, ColType, C: TypedColumn<Table = T, Type = ColType>>(
298    col: C,
299) -> WindowExpr<T, Lead<ColType>> {
300    WindowExpr::<T, Lead<ColType>>::with_column(col, WindowSpec::new())
301}
302
303/// Aggregate window functions (SUM, AVG, etc. with OVER clause)
304pub struct AggregateWindow<T: ClickTable> {
305    _table: PhantomData<T>,
306    function_expr: String,
307}
308
309impl<T: ClickTable> AggregateWindow<T> {
310    pub fn sum<C: TypedColumn<Table = T>>(_col: C) -> Self {
311        Self {
312            _table: PhantomData,
313            function_expr: format!("SUM({})", C::name()),
314        }
315    }
316
317    pub fn avg<C: TypedColumn<Table = T>>(_col: C) -> Self {
318        Self {
319            _table: PhantomData,
320            function_expr: format!("AVG({})", C::name()),
321        }
322    }
323
324    pub fn count<C: TypedColumn<Table = T>>(_col: C) -> Self {
325        Self {
326            _table: PhantomData,
327            function_expr: format!("COUNT({})", C::name()),
328        }
329    }
330
331    pub fn min<C: TypedColumn<Table = T>>(_col: C) -> Self {
332        Self {
333            _table: PhantomData,
334            function_expr: format!("MIN({})", C::name()),
335        }
336    }
337
338    pub fn max<C: TypedColumn<Table = T>>(_col: C) -> Self {
339        Self {
340            _table: PhantomData,
341            function_expr: format!("MAX({})", C::name()),
342        }
343    }
344
345    pub fn over(self, window_spec: WindowSpec<T>) -> String {
346        format!("{} OVER ({})", self.function_expr, window_spec.to_sql())
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_frame_boundary_sql() {
        // (boundary, expected SQL) for every variant.
        let cases = [
            (FrameBoundary::UnboundedPreceding, "UNBOUNDED PRECEDING"),
            (FrameBoundary::Preceding(5), "5 PRECEDING"),
            (FrameBoundary::CurrentRow, "CURRENT ROW"),
            (FrameBoundary::Following(3), "3 FOLLOWING"),
            (FrameBoundary::UnboundedFollowing, "UNBOUNDED FOLLOWING"),
        ];
        for (boundary, expected) in cases.iter() {
            assert_eq!(boundary.to_sql(), *expected, "rendering {:?}", boundary);
        }
    }

    #[test]
    fn test_frame_mode_sql() {
        assert_eq!(FrameMode::Rows.to_sql(), "ROWS");
        assert_eq!(FrameMode::Range.to_sql(), "RANGE");
    }

    #[test]
    fn test_window_frame_sql() {
        let frame = WindowFrame {
            mode: FrameMode::Rows,
            start: FrameBoundary::UnboundedPreceding,
            end: Some(FrameBoundary::CurrentRow),
        };
        assert_eq!(
            frame.to_sql(),
            "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
        );

        let frame2 = WindowFrame {
            mode: FrameMode::Range,
            start: FrameBoundary::Preceding(5),
            end: Some(FrameBoundary::Following(5)),
        };
        assert_eq!(
            frame2.to_sql(),
            "RANGE BETWEEN 5 PRECEDING AND 5 FOLLOWING"
        );
    }

    // A frame with no explicit `end` must still render the full
    // `BETWEEN ... AND CURRENT ROW` form.
    #[test]
    fn test_window_frame_without_end_defaults_to_current_row() {
        let frame = WindowFrame {
            mode: FrameMode::Rows,
            start: FrameBoundary::Preceding(2),
            end: None,
        };
        assert_eq!(frame.to_sql(), "ROWS BETWEEN 2 PRECEDING AND CURRENT ROW");
    }
}