1use reinhardt_query::prelude::{Alias, Iden};
6use serde::{Deserialize, Serialize};
7use std::fmt;
8
/// The SQL aggregate function applied by an [`Aggregate`].
///
/// Its `Display` impl renders the SQL keyword; `Count` and `CountDistinct` both
/// render as `COUNT`.
// NOTE(review): keep variant order stable. serde passes the variant index to the
// serializer, and index-based formats (e.g. bincode) encode it, so reordering
// variants would change the wire format.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AggregateFunc {
    /// `COUNT(...)`.
    Count,
    /// `COUNT(DISTINCT ...)`; built by `Aggregate::count_distinct`.
    CountDistinct,
    /// `SUM(...)`.
    Sum,
    /// `AVG(...)`.
    Avg,
    /// `MAX(...)`.
    Max,
    /// `MIN(...)`.
    Min,
}
25
26impl fmt::Display for AggregateFunc {
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 match self {
29 AggregateFunc::Count => write!(f, "COUNT"),
30 AggregateFunc::CountDistinct => write!(f, "COUNT"),
31 AggregateFunc::Sum => write!(f, "SUM"),
32 AggregateFunc::Avg => write!(f, "AVG"),
33 AggregateFunc::Max => write!(f, "MAX"),
34 AggregateFunc::Min => write!(f, "MIN"),
35 }
36 }
37}
38
/// A single SQL aggregate expression such as `SUM(amount) AS total`.
///
/// Prefer the constructors on `Aggregate` (`count`, `sum`, `avg`, ...). They run
/// `validate_identifier` on column and alias names and panic on invalid input.
// NOTE(review): all fields are `pub` and the type derives `Deserialize`, so values
// built from struct literals or deserialized data never pass through
// `validate_identifier` before being rendered into SQL. Treat such data as
// untrusted and validate it before calling `to_sql` / `to_sql_expr`.
// NOTE(review): field order is kept stable because index-based serde formats
// serialize structs positionally.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Aggregate {
    /// The aggregate function to apply.
    pub func: AggregateFunc,
    /// Column to aggregate; `None` renders as `*`.
    pub field: Option<String>,
    /// Optional output alias, rendered as `AS <alias>` by `to_sql`.
    pub alias: Option<String>,
    /// Whether `DISTINCT` is emitted before the column (no effect when `field` is `None`).
    pub distinct: bool,
}
51
/// Checks that `name` can be spliced into SQL as a bare identifier.
///
/// A name is accepted when it is one of:
/// - the literal wildcard `*`;
/// - a non-empty string made only of underscores and characters for which
///   [`char::is_alphanumeric`] holds, and not starting with a numeric character.
///
/// `char::is_alphanumeric` is Unicode-aware, so non-ASCII letters are accepted.
///
/// # Errors
///
/// Returns a human-readable message when `name` is empty, contains any other
/// character (whitespace, dots, quotes, semicolons, ...), or starts with a number.
pub fn validate_identifier(name: &str) -> Result<(), String> {
    match name {
        "" => Err("Identifier cannot be empty".to_string()),
        // The wildcard passes verbatim so callers can express `COUNT(*)`.
        "*" => Ok(()),
        _ if name.chars().any(|ch| ch != '_' && !ch.is_alphanumeric()) => Err(format!(
            "Identifier '{}' contains invalid characters. Only alphanumeric characters and underscores are allowed",
            name
        )),
        _ if name.starts_with(char::is_numeric) => {
            Err(format!("Identifier '{}' cannot start with a number", name))
        }
        _ => Ok(()),
    }
}
103
104impl Aggregate {
105 pub fn count(field: Option<&str>) -> Self {
110 if let Some(f) = field {
111 validate_identifier(f).expect("Invalid field name for COUNT aggregate");
112 }
113 Self {
114 func: AggregateFunc::Count,
115 field: field.map(|s| s.to_string()),
116 alias: None,
117 distinct: false,
118 }
119 }
120
121 pub fn count_all() -> Self {
123 Self {
124 func: AggregateFunc::Count,
125 field: None,
126 alias: None,
127 distinct: false,
128 }
129 }
130
131 pub fn count_distinct(field: &str) -> Self {
136 validate_identifier(field).expect("Invalid field name for COUNT DISTINCT aggregate");
137 Self {
138 func: AggregateFunc::CountDistinct,
139 field: Some(field.to_string()),
140 alias: None,
141 distinct: true,
142 }
143 }
144
145 pub fn sum(field: &str) -> Self {
150 validate_identifier(field).expect("Invalid field name for SUM aggregate");
151 Self {
152 func: AggregateFunc::Sum,
153 field: Some(field.to_string()),
154 alias: None,
155 distinct: false,
156 }
157 }
158
159 pub fn avg(field: &str) -> Self {
164 validate_identifier(field).expect("Invalid field name for AVG aggregate");
165 Self {
166 func: AggregateFunc::Avg,
167 field: Some(field.to_string()),
168 alias: None,
169 distinct: false,
170 }
171 }
172
173 pub fn max(field: &str) -> Self {
178 validate_identifier(field).expect("Invalid field name for MAX aggregate");
179 Self {
180 func: AggregateFunc::Max,
181 field: Some(field.to_string()),
182 alias: None,
183 distinct: false,
184 }
185 }
186
187 pub fn min(field: &str) -> Self {
192 validate_identifier(field).expect("Invalid field name for MIN aggregate");
193 Self {
194 func: AggregateFunc::Min,
195 field: Some(field.to_string()),
196 alias: None,
197 distinct: false,
198 }
199 }
200
201 pub fn with_alias(mut self, alias: &str) -> Self {
206 validate_identifier(alias).expect("Invalid alias name");
207 self.alias = Some(alias.to_string());
208 self
209 }
210
211 pub fn to_sql(&self) -> String {
213 let mut parts = Vec::new();
214
215 parts.push(self.func.to_string());
217 parts.push("(".to_string());
218
219 if self.distinct && self.field.is_some() {
220 parts.push("DISTINCT ".to_string());
221 }
222
223 match &self.field {
224 Some(field) => {
225 let iden = Alias::new(field);
227 parts.push(iden.to_string());
228 }
229 None => parts.push("*".to_string()),
230 }
231
232 parts.push(")".to_string());
233
234 if let Some(alias) = &self.alias {
235 parts.push(" AS ".to_string());
236 let alias_iden = Alias::new(alias);
238 parts.push(alias_iden.to_string());
239 }
240
241 parts.join("")
242 }
243
244 pub fn to_sql_expr(&self) -> String {
247 let mut parts = Vec::new();
248
249 parts.push(self.func.to_string());
250 parts.push("(".to_string());
251
252 if self.distinct && self.field.is_some() {
253 parts.push("DISTINCT ".to_string());
254 }
255
256 match &self.field {
257 Some(field) => {
258 let iden = Alias::new(field);
260 parts.push(iden.to_string());
261 }
262 None => parts.push("*".to_string()),
263 }
264
265 parts.push(")".to_string());
266
267 parts.join("")
268 }
269}
270
/// A single value produced by evaluating an aggregate.
///
/// Implements `PartialEq` but not `Eq`, because `Float` holds an `f64` (NaN != NaN).
#[derive(Debug, Clone, PartialEq)]
pub enum AggregateValue {
    /// Integer result.
    Int(i64),
    /// Floating-point result.
    Float(f64),
    /// No value.
    Null,
}

/// Aggregate values keyed by their alias.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct AggregateResult {
    /// Map from alias to the value computed for it.
    pub values: std::collections::HashMap<String, AggregateValue>,
}

impl AggregateResult {
    /// Creates an empty result (equivalent to `AggregateResult::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the value stored under `alias`, if any.
    pub fn get(&self, alias: &str) -> Option<&AggregateValue> {
        self.values.get(alias)
    }

    /// Stores `value` under `alias`, replacing any previous value for that alias.
    pub fn insert(&mut self, alias: String, value: AggregateValue) {
        self.values.insert(alias, value);
    }
}
313
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_identifier_valid() {
        for name in ["user_id", "name123", "_internal", "CamelCase", "*"] {
            assert!(validate_identifier(name).is_ok(), "expected {name:?} to be accepted");
        }
    }

    #[test]
    fn test_validate_identifier_invalid() {
        let rejected = [
            // Leading digit.
            "123invalid",
            // Punctuation and whitespace.
            "user-id",
            "user.name",
            "user name",
            // SQL injection attempts.
            "user; DROP TABLE",
            "id' OR '1'='1",
            "id); DELETE FROM users; --",
            // Empty input.
            "",
        ];
        for name in rejected {
            assert!(validate_identifier(name).is_err(), "expected {name:?} to be rejected");
        }
    }

    #[test]
    fn test_validate_identifier_error_messages() {
        assert_eq!(
            validate_identifier(""),
            Err("Identifier cannot be empty".to_string())
        );
        assert_eq!(
            validate_identifier("1abc"),
            Err("Identifier '1abc' cannot start with a number".to_string())
        );
    }

    #[test]
    #[should_panic(expected = "Invalid field name")]
    fn test_aggregate_rejects_invalid_field() {
        let _ = Aggregate::sum("amount; DROP TABLE users");
    }

    #[test]
    #[should_panic(expected = "Invalid alias")]
    fn test_aggregate_rejects_invalid_alias() {
        let _ = Aggregate::sum("amount").with_alias("total; DROP TABLE");
    }

    #[test]
    fn test_aggregate_escapes_identifiers() {
        assert_eq!(Aggregate::sum("user_id").to_sql(), "SUM(user_id)");
    }

    #[test]
    fn test_count_aggregate() {
        assert_eq!(Aggregate::count(Some("id")).to_sql(), "COUNT(id)");
    }

    #[test]
    fn test_count_with_no_field_renders_wildcard() {
        assert_eq!(Aggregate::count(None).to_sql(), "COUNT(*)");
    }

    #[test]
    fn test_count_all_aggregate() {
        assert_eq!(Aggregate::count_all().to_sql(), "COUNT(*)");
    }

    #[test]
    fn test_count_distinct_aggregate() {
        assert_eq!(
            Aggregate::count_distinct("user_id").to_sql(),
            "COUNT(DISTINCT user_id)"
        );
    }

    #[test]
    fn test_count_distinct_with_alias() {
        assert_eq!(
            Aggregate::count_distinct("user_id")
                .with_alias("unique_users")
                .to_sql(),
            "COUNT(DISTINCT user_id) AS unique_users"
        );
    }

    #[test]
    fn test_sum_aggregate() {
        assert_eq!(Aggregate::sum("amount").to_sql(), "SUM(amount)");
    }

    #[test]
    fn test_avg_aggregate() {
        assert_eq!(Aggregate::avg("score").to_sql(), "AVG(score)");
    }

    #[test]
    fn test_max_aggregate() {
        assert_eq!(Aggregate::max("price").to_sql(), "MAX(price)");
    }

    #[test]
    fn test_min_aggregate() {
        assert_eq!(Aggregate::min("age").to_sql(), "MIN(age)");
    }

    #[test]
    fn test_aggregate_with_alias() {
        let agg = Aggregate::sum("amount").with_alias("total_amount");
        assert_eq!(agg.to_sql(), "SUM(amount) AS total_amount");
    }

    #[test]
    fn test_to_sql_expr_omits_alias() {
        let agg = Aggregate::avg("score").with_alias("mean_score");
        assert_eq!(agg.to_sql_expr(), "AVG(score)");
        assert_eq!(agg.to_sql(), "AVG(score) AS mean_score");
    }
}