Skip to main content

reinhardt_db/orm/
aggregation.rs

1//! Aggregation functions for database queries
2//!
3//! This module provides Django-inspired aggregation functionality.
4
5use reinhardt_query::prelude::{Alias, Iden};
6use serde::{Deserialize, Serialize};
7use std::fmt;
8
9/// Aggregate function types
10#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
11pub enum AggregateFunc {
12	/// COUNT aggregation
13	Count,
14	/// COUNT DISTINCT aggregation
15	CountDistinct,
16	/// SUM aggregation
17	Sum,
18	/// AVG aggregation
19	Avg,
20	/// MAX aggregation
21	Max,
22	/// MIN aggregation
23	Min,
24}
25
26impl fmt::Display for AggregateFunc {
27	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28		match self {
29			AggregateFunc::Count => write!(f, "COUNT"),
30			AggregateFunc::CountDistinct => write!(f, "COUNT"),
31			AggregateFunc::Sum => write!(f, "SUM"),
32			AggregateFunc::Avg => write!(f, "AVG"),
33			AggregateFunc::Max => write!(f, "MAX"),
34			AggregateFunc::Min => write!(f, "MIN"),
35		}
36	}
37}
38
39/// Aggregate expression
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct Aggregate {
42	/// The aggregate function
43	pub func: AggregateFunc,
44	/// The field to aggregate (None for COUNT(*))
45	pub field: Option<String>,
46	/// Alias for the result
47	pub alias: Option<String>,
48	/// Whether this is a DISTINCT aggregation
49	pub distinct: bool,
50}
51
/// Checks that `name` is safe to use as an SQL identifier (column name, alias, ...).
///
/// The checks are applied in this order, and the first failure is reported:
/// 1. the name must not be empty;
/// 2. the single-character wildcard `*` is accepted as-is;
/// 3. every character must be alphanumeric (as defined by
///    [`char::is_alphanumeric`]) or an underscore;
/// 4. the first character must not be numeric (as defined by
///    [`char::is_numeric`]).
///
/// # Arguments
/// * `name` - The identifier to validate
///
/// # Returns
/// * `Ok(())` when the identifier passes every check
/// * `Err(String)` carrying a human-readable reason otherwise
///
/// # Examples
/// ```
/// # use reinhardt_db::orm::aggregation::validate_identifier;
/// assert!(validate_identifier("user_id").is_ok());
/// assert!(validate_identifier("name123").is_ok());
/// assert!(validate_identifier("123invalid").is_err()); // Starts with number
/// assert!(validate_identifier("user-id").is_err());     // Contains hyphen
/// assert!(validate_identifier("user; DROP TABLE").is_err()); // SQL injection attempt
/// ```
pub fn validate_identifier(name: &str) -> Result<(), String> {
	// An empty name has no first character at all.
	let leading = match name.chars().next() {
		Some(c) => c,
		None => return Err("Identifier cannot be empty".to_string()),
	};

	// The bare wildcard is the one non-alphanumeric name that is let through.
	if name == "*" {
		return Ok(());
	}

	// Reject the name as soon as any character falls outside [alphanumeric, '_'].
	let has_forbidden_char = name.chars().any(|c| c != '_' && !c.is_alphanumeric());
	if has_forbidden_char {
		return Err(format!(
			"Identifier '{}' contains invalid characters. Only alphanumeric characters and underscores are allowed",
			name
		));
	}

	// Checked last so that names with both problems report the character error.
	if leading.is_numeric() {
		return Err(format!("Identifier '{}' cannot start with a number", name));
	}

	Ok(())
}
103
104impl Aggregate {
105	/// Create a COUNT aggregate
106	///
107	/// # Panics
108	/// Panics if the field name contains invalid characters
109	pub fn count(field: Option<&str>) -> Self {
110		if let Some(f) = field {
111			validate_identifier(f).expect("Invalid field name for COUNT aggregate");
112		}
113		Self {
114			func: AggregateFunc::Count,
115			field: field.map(|s| s.to_string()),
116			alias: None,
117			distinct: false,
118		}
119	}
120
121	/// Create a COUNT(*) aggregate
122	pub fn count_all() -> Self {
123		Self {
124			func: AggregateFunc::Count,
125			field: None,
126			alias: None,
127			distinct: false,
128		}
129	}
130
131	/// Create a COUNT DISTINCT aggregate
132	///
133	/// # Panics
134	/// Panics if the field name contains invalid characters
135	pub fn count_distinct(field: &str) -> Self {
136		validate_identifier(field).expect("Invalid field name for COUNT DISTINCT aggregate");
137		Self {
138			func: AggregateFunc::CountDistinct,
139			field: Some(field.to_string()),
140			alias: None,
141			distinct: true,
142		}
143	}
144
145	/// Create a SUM aggregate
146	///
147	/// # Panics
148	/// Panics if the field name contains invalid characters
149	pub fn sum(field: &str) -> Self {
150		validate_identifier(field).expect("Invalid field name for SUM aggregate");
151		Self {
152			func: AggregateFunc::Sum,
153			field: Some(field.to_string()),
154			alias: None,
155			distinct: false,
156		}
157	}
158
159	/// Create an AVG aggregate
160	///
161	/// # Panics
162	/// Panics if the field name contains invalid characters
163	pub fn avg(field: &str) -> Self {
164		validate_identifier(field).expect("Invalid field name for AVG aggregate");
165		Self {
166			func: AggregateFunc::Avg,
167			field: Some(field.to_string()),
168			alias: None,
169			distinct: false,
170		}
171	}
172
173	/// Create a MAX aggregate
174	///
175	/// # Panics
176	/// Panics if the field name contains invalid characters
177	pub fn max(field: &str) -> Self {
178		validate_identifier(field).expect("Invalid field name for MAX aggregate");
179		Self {
180			func: AggregateFunc::Max,
181			field: Some(field.to_string()),
182			alias: None,
183			distinct: false,
184		}
185	}
186
187	/// Create a MIN aggregate
188	///
189	/// # Panics
190	/// Panics if the field name contains invalid characters
191	pub fn min(field: &str) -> Self {
192		validate_identifier(field).expect("Invalid field name for MIN aggregate");
193		Self {
194			func: AggregateFunc::Min,
195			field: Some(field.to_string()),
196			alias: None,
197			distinct: false,
198		}
199	}
200
201	/// Set an alias for this aggregate
202	///
203	/// # Panics
204	/// Panics if the alias contains invalid characters
205	pub fn with_alias(mut self, alias: &str) -> Self {
206		validate_identifier(alias).expect("Invalid alias name");
207		self.alias = Some(alias.to_string());
208		self
209	}
210
211	/// Convert to SQL string using reinhardt-query for safe identifier escaping
212	pub fn to_sql(&self) -> String {
213		let mut parts = Vec::new();
214
215		// Build aggregate expression
216		parts.push(self.func.to_string());
217		parts.push("(".to_string());
218
219		if self.distinct && self.field.is_some() {
220			parts.push("DISTINCT ".to_string());
221		}
222
223		match &self.field {
224			Some(field) => {
225				// Use reinhardt-query's Alias to safely escape the identifier
226				let iden = Alias::new(field);
227				parts.push(iden.to_string());
228			}
229			None => parts.push("*".to_string()),
230		}
231
232		parts.push(")".to_string());
233
234		if let Some(alias) = &self.alias {
235			parts.push(" AS ".to_string());
236			// Safely escape the alias identifier
237			let alias_iden = Alias::new(alias);
238			parts.push(alias_iden.to_string());
239		}
240
241		parts.join("")
242	}
243
244	/// Convert to SQL string without alias (for use in SELECT expressions with expr_as)
245	/// Uses reinhardt-query for safe identifier escaping
246	pub fn to_sql_expr(&self) -> String {
247		let mut parts = Vec::new();
248
249		parts.push(self.func.to_string());
250		parts.push("(".to_string());
251
252		if self.distinct && self.field.is_some() {
253			parts.push("DISTINCT ".to_string());
254		}
255
256		match &self.field {
257			Some(field) => {
258				// Use reinhardt-query's Alias to safely escape the identifier
259				let iden = Alias::new(field);
260				parts.push(iden.to_string());
261			}
262			None => parts.push("*".to_string()),
263		}
264
265		parts.push(")".to_string());
266
267		parts.join("")
268	}
269}
270
/// A single value produced by an aggregate expression.
#[derive(Debug, Clone)]
pub enum AggregateValue {
	/// A signed 64-bit integer result
	Int(i64),
	/// A 64-bit floating-point result
	Float(f64),
	/// No value (SQL `NULL`)
	Null,
}
281
282/// Result container for aggregation queries
283#[derive(Debug, Clone)]
284pub struct AggregateResult {
285	/// Map of alias to value
286	pub values: std::collections::HashMap<String, AggregateValue>,
287}
288
289impl AggregateResult {
290	/// Create a new empty result
291	pub fn new() -> Self {
292		Self {
293			values: std::collections::HashMap::new(),
294		}
295	}
296
297	/// Get a value by alias
298	pub fn get(&self, alias: &str) -> Option<&AggregateValue> {
299		self.values.get(alias)
300	}
301
302	/// Insert a value
303	pub fn insert(&mut self, alias: String, value: AggregateValue) {
304		self.values.insert(alias, value);
305	}
306}
307
308impl Default for AggregateResult {
309	fn default() -> Self {
310		Self::new()
311	}
312}
313
#[cfg(test)]
mod tests {
	use super::*;

	/// Renders `agg` with `to_sql` and compares it with the expected SQL text.
	fn assert_sql(agg: Aggregate, expected: &str) {
		assert_eq!(agg.to_sql(), expected);
	}

	#[test]
	fn test_validate_identifier_valid() {
		// The final entry covers the wildcard special case.
		let accepted = ["user_id", "name123", "_internal", "CamelCase", "*"];
		for name in accepted {
			assert!(validate_identifier(name).is_ok());
		}
	}

	#[test]
	fn test_validate_identifier_invalid() {
		let rejected = [
			// Leading digit
			"123invalid",
			// Characters outside [alphanumeric, '_']
			"user-id",
			"user.name",
			"user name",
			// SQL injection attempts
			"user; DROP TABLE",
			"id' OR '1'='1",
			"id); DELETE FROM users; --",
			// Empty string
			"",
		];
		for name in rejected {
			assert!(validate_identifier(name).is_err());
		}
	}

	#[test]
	#[should_panic(expected = "Invalid field name")]
	fn test_aggregate_rejects_invalid_field() {
		// Constructing an aggregate from an injection attempt must panic.
		let _ = Aggregate::sum("amount; DROP TABLE users");
	}

	#[test]
	#[should_panic(expected = "Invalid alias")]
	fn test_aggregate_rejects_invalid_alias() {
		// A valid field combined with an invalid alias must panic as well.
		let _ = Aggregate::sum("amount").with_alias("total; DROP TABLE");
	}

	#[test]
	fn test_aggregate_escapes_identifiers() {
		// A plain identifier must pass through reinhardt-query unchanged.
		let rendered = Aggregate::sum("user_id").to_sql();

		// The identifier appears in the output...
		assert!(rendered.contains("user_id"));
		// ...and the whole expression has the expected shape.
		assert_eq!(rendered, "SUM(user_id)");
	}

	#[test]
	fn test_count_aggregate() {
		assert_sql(Aggregate::count(Some("id")), "COUNT(id)");
	}

	#[test]
	fn test_count_all_aggregate() {
		assert_sql(Aggregate::count_all(), "COUNT(*)");
	}

	#[test]
	fn test_count_distinct_aggregate() {
		assert_sql(Aggregate::count_distinct("user_id"), "COUNT(DISTINCT user_id)");
	}

	#[test]
	fn test_sum_aggregate() {
		assert_sql(Aggregate::sum("amount"), "SUM(amount)");
	}

	#[test]
	fn test_avg_aggregate() {
		assert_sql(Aggregate::avg("score"), "AVG(score)");
	}

	#[test]
	fn test_max_aggregate() {
		assert_sql(Aggregate::max("price"), "MAX(price)");
	}

	#[test]
	fn test_min_aggregate() {
		assert_sql(Aggregate::min("age"), "MIN(age)");
	}

	#[test]
	fn test_aggregate_with_alias() {
		assert_sql(
			Aggregate::sum("amount").with_alias("total_amount"),
			"SUM(amount) AS total_amount",
		);
	}
}