1use std::marker::PhantomData;
4use clicktype_core::traits::{ClickTable, TypedColumn};
5
/// A window function that needs no arguments and is fully described by its SQL call text.
///
/// Implementors supply the text through [`WindowFunction::function_name`]; the trait has
/// no instance methods, so unit marker types are sufficient.
pub trait WindowFunction {
    /// Returns the SQL text of the function call, parentheses included
    /// (for example `"ROW_NUMBER()"`).
    fn function_name() -> &'static str;
}
10
11pub struct RowNumber;
13impl WindowFunction for RowNumber {
14 fn function_name() -> &'static str {
15 "ROW_NUMBER()"
16 }
17}
18
19pub struct Rank;
21impl WindowFunction for Rank {
22 fn function_name() -> &'static str {
23 "RANK()"
24 }
25}
26
27pub struct DenseRank;
29impl WindowFunction for DenseRank {
30 fn function_name() -> &'static str {
31 "DENSE_RANK()"
32 }
33}
34
/// Type-level marker for a `LAG` window expression; `T` is the value type of the
/// column the function is applied to.
///
/// The struct stores no data: `T` is carried only through [`PhantomData`].
pub struct Lag<T> {
    _marker: PhantomData<T>,
}

impl<T> Lag<T> {
    /// Creates the marker value.
    pub fn new() -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

// Written by hand rather than derived: `#[derive(Default)]` would require
// `T: Default` even though no `T` is ever stored.
impl<T> Default for Lag<T> {
    /// Equivalent to [`Lag::new`].
    fn default() -> Self {
        Self::new()
    }
}
45
/// Type-level marker for a `LEAD` window expression; `T` is the value type of the
/// column the function is applied to.
///
/// The struct stores no data: `T` is carried only through [`PhantomData`].
pub struct Lead<T> {
    _marker: PhantomData<T>,
}

impl<T> Lead<T> {
    /// Creates the marker value.
    pub fn new() -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

// Written by hand rather than derived: `#[derive(Default)]` would require
// `T: Default` even though no `T` is ever stored.
impl<T> Default for Lead<T> {
    /// Equivalent to [`Lead::new`].
    fn default() -> Self {
        Self::new()
    }
}
56
/// One boundary of a window frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameBoundary {
    /// Rendered as `UNBOUNDED PRECEDING`.
    UnboundedPreceding,
    /// Rendered as `<n> PRECEDING`.
    Preceding(u64),
    /// Rendered as `CURRENT ROW`.
    CurrentRow,
    /// Rendered as `<n> FOLLOWING`.
    Following(u64),
    /// Rendered as `UNBOUNDED FOLLOWING`.
    UnboundedFollowing,
}

impl FrameBoundary {
    /// Renders the boundary as the SQL fragment used inside a frame clause.
    pub fn to_sql(&self) -> String {
        // Split every variant into an optional numeric offset and its keyword text.
        let (offset, keyword) = match *self {
            Self::UnboundedPreceding => (None, "UNBOUNDED PRECEDING"),
            Self::Preceding(distance) => (Some(distance), "PRECEDING"),
            Self::CurrentRow => (None, "CURRENT ROW"),
            Self::Following(distance) => (Some(distance), "FOLLOWING"),
            Self::UnboundedFollowing => (None, "UNBOUNDED FOLLOWING"),
        };

        match offset {
            Some(distance) => format!("{} {}", distance, keyword),
            None => keyword.to_owned(),
        }
    }
}
78
/// Keyword that introduces a window frame clause.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameMode {
    /// Rendered as `ROWS`.
    Rows,
    /// Rendered as `RANGE`.
    Range,
}

impl FrameMode {
    /// Returns the SQL keyword for this mode.
    pub fn to_sql(&self) -> &'static str {
        match *self {
            Self::Range => "RANGE",
            Self::Rows => "ROWS",
        }
    }
}
94
95#[derive(Debug, Clone)]
97pub struct WindowFrame {
98 pub mode: FrameMode,
99 pub start: FrameBoundary,
100 pub end: Option<FrameBoundary>,
101}
102
103impl WindowFrame {
104 pub fn to_sql(&self) -> String {
105 let mut sql = format!("{} BETWEEN {}", self.mode.to_sql(), self.start.to_sql());
106 if let Some(end) = &self.end {
107 sql.push_str(" AND ");
108 sql.push_str(&end.to_sql());
109 } else {
110 sql.push_str(" AND CURRENT ROW");
111 }
112 sql
113 }
114}
115
116pub struct WindowSpec<T: ClickTable> {
118 _table: PhantomData<T>,
119 partition_by: Vec<String>,
120 order_by: Vec<String>,
121 frame: Option<WindowFrame>,
122}
123
124impl<T: ClickTable> WindowSpec<T> {
125 pub fn new() -> Self {
126 Self {
127 _table: PhantomData,
128 partition_by: Vec::new(),
129 order_by: Vec::new(),
130 frame: None,
131 }
132 }
133
134 pub fn partition_by<C: TypedColumn<Table = T>>(mut self, _col: C) -> Self {
135 self.partition_by.push(C::name().to_string());
136 self
137 }
138
139 pub fn order_by<C: TypedColumn<Table = T>>(mut self, _col: C, desc: bool) -> Self {
140 let order = if desc {
141 format!("{} DESC", C::name())
142 } else {
143 C::name().to_string()
144 };
145 self.order_by.push(order);
146 self
147 }
148
149 pub fn rows_between(mut self, start: FrameBoundary, end: FrameBoundary) -> Self {
150 self.frame = Some(WindowFrame {
151 mode: FrameMode::Rows,
152 start,
153 end: Some(end),
154 });
155 self
156 }
157
158 pub fn range_between(mut self, start: FrameBoundary, end: FrameBoundary) -> Self {
159 self.frame = Some(WindowFrame {
160 mode: FrameMode::Range,
161 start,
162 end: Some(end),
163 });
164 self
165 }
166
167 pub fn to_sql(&self) -> String {
168 let mut parts = Vec::new();
169
170 if !self.partition_by.is_empty() {
171 parts.push(format!("PARTITION BY {}", self.partition_by.join(", ")));
172 }
173
174 if !self.order_by.is_empty() {
175 parts.push(format!("ORDER BY {}", self.order_by.join(", ")));
176 }
177
178 if let Some(frame) = &self.frame {
179 parts.push(frame.to_sql());
180 }
181
182 parts.join(" ")
183 }
184}
185
186impl<T: ClickTable> Default for WindowSpec<T> {
187 fn default() -> Self {
188 Self::new()
189 }
190}
191
192pub struct WindowExpr<T: ClickTable, F> {
194 _table: PhantomData<T>,
195 _function: PhantomData<F>,
196 function_expr: String,
197 window_spec: WindowSpec<T>,
198}
199
200impl<T: ClickTable, F> WindowExpr<T, F> {
202 pub fn to_sql(&self) -> String {
203 format!("{} OVER ({})", self.function_expr, self.window_spec.to_sql())
204 }
205
206 pub fn alias(self, alias: &str) -> String {
207 format!("{} AS {}", self.to_sql(), alias)
208 }
209}
210
211impl<T: ClickTable, F: WindowFunction> WindowExpr<T, F> {
213 pub fn new(window_spec: WindowSpec<T>) -> Self {
214 Self {
215 _table: PhantomData,
216 _function: PhantomData,
217 function_expr: F::function_name().to_string(),
218 window_spec,
219 }
220 }
221}
222
223impl<T: ClickTable, ColType> WindowExpr<T, Lag<ColType>> {
225 pub fn with_column<C: TypedColumn<Table = T, Type = ColType>>(
226 _col: C,
227 window_spec: WindowSpec<T>,
228 ) -> Self {
229 Self {
230 _table: PhantomData,
231 _function: PhantomData,
232 function_expr: format!("LAG({})", C::name()),
233 window_spec,
234 }
235 }
236
237 pub fn with_column_and_offset<C: TypedColumn<Table = T, Type = ColType>>(
238 _col: C,
239 offset: u64,
240 window_spec: WindowSpec<T>,
241 ) -> Self {
242 Self {
243 _table: PhantomData,
244 _function: PhantomData,
245 function_expr: format!("LAG({}, {})", C::name(), offset),
246 window_spec,
247 }
248 }
249}
250
251impl<T: ClickTable, ColType> WindowExpr<T, Lead<ColType>> {
252 pub fn with_column<C: TypedColumn<Table = T, Type = ColType>>(
253 _col: C,
254 window_spec: WindowSpec<T>,
255 ) -> Self {
256 Self {
257 _table: PhantomData,
258 _function: PhantomData,
259 function_expr: format!("LEAD({})", C::name()),
260 window_spec,
261 }
262 }
263
264 pub fn with_column_and_offset<C: TypedColumn<Table = T, Type = ColType>>(
265 _col: C,
266 offset: u64,
267 window_spec: WindowSpec<T>,
268 ) -> Self {
269 Self {
270 _table: PhantomData,
271 _function: PhantomData,
272 function_expr: format!("LEAD({}, {})", C::name(), offset),
273 window_spec,
274 }
275 }
276}
277
278pub fn row_number<T: ClickTable>() -> WindowExpr<T, RowNumber> {
280 WindowExpr::new(WindowSpec::new())
281}
282
283pub fn rank<T: ClickTable>() -> WindowExpr<T, Rank> {
284 WindowExpr::new(WindowSpec::new())
285}
286
287pub fn dense_rank<T: ClickTable>() -> WindowExpr<T, DenseRank> {
288 WindowExpr::new(WindowSpec::new())
289}
290
291pub fn lag<T: ClickTable, ColType, C: TypedColumn<Table = T, Type = ColType>>(
292 col: C,
293) -> WindowExpr<T, Lag<ColType>> {
294 WindowExpr::<T, Lag<ColType>>::with_column(col, WindowSpec::new())
295}
296
297pub fn lead<T: ClickTable, ColType, C: TypedColumn<Table = T, Type = ColType>>(
298 col: C,
299) -> WindowExpr<T, Lead<ColType>> {
300 WindowExpr::<T, Lead<ColType>>::with_column(col, WindowSpec::new())
301}
302
303pub struct AggregateWindow<T: ClickTable> {
305 _table: PhantomData<T>,
306 function_expr: String,
307}
308
309impl<T: ClickTable> AggregateWindow<T> {
310 pub fn sum<C: TypedColumn<Table = T>>(_col: C) -> Self {
311 Self {
312 _table: PhantomData,
313 function_expr: format!("SUM({})", C::name()),
314 }
315 }
316
317 pub fn avg<C: TypedColumn<Table = T>>(_col: C) -> Self {
318 Self {
319 _table: PhantomData,
320 function_expr: format!("AVG({})", C::name()),
321 }
322 }
323
324 pub fn count<C: TypedColumn<Table = T>>(_col: C) -> Self {
325 Self {
326 _table: PhantomData,
327 function_expr: format!("COUNT({})", C::name()),
328 }
329 }
330
331 pub fn min<C: TypedColumn<Table = T>>(_col: C) -> Self {
332 Self {
333 _table: PhantomData,
334 function_expr: format!("MIN({})", C::name()),
335 }
336 }
337
338 pub fn max<C: TypedColumn<Table = T>>(_col: C) -> Self {
339 Self {
340 _table: PhantomData,
341 function_expr: format!("MAX({})", C::name()),
342 }
343 }
344
345 pub fn over(self, window_spec: WindowSpec<T>) -> String {
346 format!("{} OVER ({})", self.function_expr, window_spec.to_sql())
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Every boundary variant renders to its expected SQL fragment.
    #[test]
    fn test_frame_boundary_sql() {
        let cases = [
            (FrameBoundary::UnboundedPreceding, "UNBOUNDED PRECEDING"),
            (FrameBoundary::Preceding(5), "5 PRECEDING"),
            (FrameBoundary::CurrentRow, "CURRENT ROW"),
            (FrameBoundary::Following(3), "3 FOLLOWING"),
            (FrameBoundary::UnboundedFollowing, "UNBOUNDED FOLLOWING"),
        ];

        for (boundary, expected) in cases.iter() {
            assert_eq!(boundary.to_sql(), *expected);
        }
    }

    /// Both frame modes render to their SQL keyword.
    #[test]
    fn test_frame_mode_sql() {
        let cases = [(FrameMode::Rows, "ROWS"), (FrameMode::Range, "RANGE")];

        for (mode, expected) in cases.iter() {
            assert_eq!(mode.to_sql(), *expected);
        }
    }

    /// Frames with an explicit end boundary render as `<MODE> BETWEEN <start> AND <end>`.
    #[test]
    fn test_window_frame_sql() {
        let cases = [
            (
                WindowFrame {
                    mode: FrameMode::Rows,
                    start: FrameBoundary::UnboundedPreceding,
                    end: Some(FrameBoundary::CurrentRow),
                },
                "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW",
            ),
            (
                WindowFrame {
                    mode: FrameMode::Range,
                    start: FrameBoundary::Preceding(5),
                    end: Some(FrameBoundary::Following(5)),
                },
                "RANGE BETWEEN 5 PRECEDING AND 5 FOLLOWING",
            ),
        ];

        for (frame, expected) in cases.iter() {
            assert_eq!(frame.to_sql(), *expected);
        }
    }
}