use reinhardt_query::prelude::{Alias, Iden};
use serde::{Deserialize, Serialize};
use std::fmt;
/// SQL aggregate function that an [`Aggregate`] applies to a column.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AggregateFunc {
    /// `COUNT`
    Count,
    /// `COUNT` over distinct values. Its `Display` output is plain `COUNT`; the
    /// `DISTINCT` keyword is not part of the function name.
    CountDistinct,
    /// `SUM`
    Sum,
    /// `AVG`
    Avg,
    /// `MAX`
    Max,
    /// `MIN`
    Min,
}
impl fmt::Display for AggregateFunc {
    /// Writes the bare SQL function keyword.
    ///
    /// `CountDistinct` is written as plain `COUNT`: the `DISTINCT` modifier goes inside
    /// the argument list, not into the function name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match self {
            AggregateFunc::Count | AggregateFunc::CountDistinct => "COUNT",
            AggregateFunc::Sum => "SUM",
            AggregateFunc::Avg => "AVG",
            AggregateFunc::Max => "MAX",
            AggregateFunc::Min => "MIN",
        };
        // Like `write!` with a literal, this ignores width/fill/precision flags.
        f.write_str(keyword)
    }
}
/// A single SQL aggregate call: function, optional column and optional output alias.
///
/// The constructors on `Aggregate` run names through [`validate_identifier`] before
/// storing them. NOTE(review): every field is public and the type derives
/// `Deserialize`, so values assigned directly or deserialized bypass that validation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Aggregate {
    /// Which aggregate function to call.
    pub func: AggregateFunc,
    /// Column the function is applied to; `None` is rendered as `*`.
    pub field: Option<String>,
    /// Output name rendered after `AS` by `to_sql`; `None` means no alias.
    pub alias: Option<String>,
    /// When `true` and `field` is set, `DISTINCT ` is emitted before the column.
    pub distinct: bool,
}
/// Checks whether `name` may be spliced into generated SQL as a bare identifier.
///
/// Accepted inputs are the literal wildcard `*`, or a non-empty string made only of
/// `_` and characters for which [`char::is_alphanumeric`] holds, whose first
/// character is not numeric ([`char::is_numeric`]). Both predicates are
/// Unicode-aware, so non-ASCII letters are accepted.
///
/// # Errors
///
/// Returns a human-readable message when `name` is empty, contains any other
/// character, or starts with a numeric character (checked in that order).
pub fn validate_identifier(name: &str) -> Result<(), String> {
    match name {
        "" => return Err("Identifier cannot be empty".to_string()),
        // The wildcard is let through unconditionally (presumably for `COUNT(*)`).
        "*" => return Ok(()),
        _ => {}
    }

    let is_allowed = |c: char| c == '_' || c.is_alphanumeric();
    if name.chars().any(|c| !is_allowed(c)) {
        return Err(format!(
            "Identifier '{}' contains invalid characters. Only alphanumeric characters and underscores are allowed",
            name
        ));
    }

    if name.chars().next().is_some_and(char::is_numeric) {
        return Err(format!("Identifier '{}' cannot start with a number", name));
    }

    Ok(())
}
impl Aggregate {
pub fn count(field: Option<&str>) -> Self {
if let Some(f) = field {
validate_identifier(f).expect("Invalid field name for COUNT aggregate");
}
Self {
func: AggregateFunc::Count,
field: field.map(|s| s.to_string()),
alias: None,
distinct: false,
}
}
pub fn count_all() -> Self {
Self {
func: AggregateFunc::Count,
field: None,
alias: None,
distinct: false,
}
}
pub fn count_distinct(field: &str) -> Self {
validate_identifier(field).expect("Invalid field name for COUNT DISTINCT aggregate");
Self {
func: AggregateFunc::CountDistinct,
field: Some(field.to_string()),
alias: None,
distinct: true,
}
}
pub fn sum(field: &str) -> Self {
validate_identifier(field).expect("Invalid field name for SUM aggregate");
Self {
func: AggregateFunc::Sum,
field: Some(field.to_string()),
alias: None,
distinct: false,
}
}
pub fn avg(field: &str) -> Self {
validate_identifier(field).expect("Invalid field name for AVG aggregate");
Self {
func: AggregateFunc::Avg,
field: Some(field.to_string()),
alias: None,
distinct: false,
}
}
pub fn max(field: &str) -> Self {
validate_identifier(field).expect("Invalid field name for MAX aggregate");
Self {
func: AggregateFunc::Max,
field: Some(field.to_string()),
alias: None,
distinct: false,
}
}
pub fn min(field: &str) -> Self {
validate_identifier(field).expect("Invalid field name for MIN aggregate");
Self {
func: AggregateFunc::Min,
field: Some(field.to_string()),
alias: None,
distinct: false,
}
}
pub fn with_alias(mut self, alias: &str) -> Self {
validate_identifier(alias).expect("Invalid alias name");
self.alias = Some(alias.to_string());
self
}
pub fn to_sql(&self) -> String {
let mut parts = Vec::new();
parts.push(self.func.to_string());
parts.push("(".to_string());
if self.distinct && self.field.is_some() {
parts.push("DISTINCT ".to_string());
}
match &self.field {
Some(field) => {
let iden = Alias::new(field);
parts.push(iden.to_string());
}
None => parts.push("*".to_string()),
}
parts.push(")".to_string());
if let Some(alias) = &self.alias {
parts.push(" AS ".to_string());
let alias_iden = Alias::new(alias);
parts.push(alias_iden.to_string());
}
parts.join("")
}
pub fn to_sql_expr(&self) -> String {
let mut parts = Vec::new();
parts.push(self.func.to_string());
parts.push("(".to_string());
if self.distinct && self.field.is_some() {
parts.push("DISTINCT ".to_string());
}
match &self.field {
Some(field) => {
let iden = Alias::new(field);
parts.push(iden.to_string());
}
None => parts.push("*".to_string()),
}
parts.push(")".to_string());
parts.join("")
}
}
/// A single aggregate value read back from a query.
///
/// Derives `PartialEq` so callers and tests can compare results directly; `Eq` is not
/// possible because of the `f64` payload.
#[derive(Debug, Clone, PartialEq)]
pub enum AggregateValue {
    /// Integer result (presumably `COUNT` and integer `SUM`/`MAX`/`MIN` — confirm
    /// against the code that decodes rows).
    Int(i64),
    /// Floating-point result (presumably e.g. `AVG` — confirm against the decoder).
    Float(f64),
    /// No value (presumably SQL `NULL`, e.g. an aggregate over zero rows — confirm).
    Null,
}
/// Collection of aggregate values produced by a query, keyed by name.
#[derive(Debug, Clone)]
pub struct AggregateResult {
    /// Values keyed by string — presumably the aggregate's alias, as the `alias`
    /// parameter names of `get`/`insert` suggest; confirm with the code that fills it.
    pub values: std::collections::HashMap<String, AggregateValue>,
}
impl AggregateResult {
pub fn new() -> Self {
Self {
values: std::collections::HashMap::new(),
}
}
pub fn get(&self, alias: &str) -> Option<&AggregateValue> {
self.values.get(alias)
}
pub fn insert(&mut self, alias: String, value: AggregateValue) {
self.values.insert(alias, value);
}
}
impl Default for AggregateResult {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_identifier_valid() {
        assert!(validate_identifier("user_id").is_ok());
        assert!(validate_identifier("name123").is_ok());
        assert!(validate_identifier("_internal").is_ok());
        assert!(validate_identifier("CamelCase").is_ok());
        assert!(validate_identifier("*").is_ok());
    }

    #[test]
    fn test_validate_identifier_invalid() {
        assert!(validate_identifier("123invalid").is_err());
        assert!(validate_identifier("user-id").is_err());
        assert!(validate_identifier("user.name").is_err());
        assert!(validate_identifier("user name").is_err());
        assert!(validate_identifier("user; DROP TABLE").is_err());
        assert!(validate_identifier("id' OR '1'='1").is_err());
        assert!(validate_identifier("id); DELETE FROM users; --").is_err());
        assert!(validate_identifier("").is_err());
    }

    #[test]
    fn test_validate_identifier_error_messages() {
        assert_eq!(
            validate_identifier(""),
            Err("Identifier cannot be empty".to_string())
        );
        assert!(validate_identifier("9lives")
            .unwrap_err()
            .contains("cannot start with a number"));
        assert!(validate_identifier("a-b")
            .unwrap_err()
            .contains("contains invalid characters"));
    }

    #[test]
    #[should_panic(expected = "Invalid field name")]
    fn test_aggregate_rejects_invalid_field() {
        Aggregate::sum("amount; DROP TABLE users");
    }

    #[test]
    #[should_panic(expected = "Invalid field name")]
    fn test_count_rejects_invalid_field() {
        Aggregate::count(Some("id); DELETE FROM users; --"));
    }

    #[test]
    #[should_panic(expected = "Invalid alias")]
    fn test_aggregate_rejects_invalid_alias() {
        Aggregate::sum("amount").with_alias("total; DROP TABLE");
    }

    #[test]
    fn test_aggregate_escapes_identifiers() {
        let agg = Aggregate::sum("user_id");
        let sql = agg.to_sql();
        assert!(sql.contains("user_id"));
        assert_eq!(sql, "SUM(user_id)");
    }

    #[test]
    fn test_count_aggregate() {
        let agg = Aggregate::count(Some("id"));
        assert_eq!(agg.to_sql(), "COUNT(id)");
    }

    #[test]
    fn test_count_none_is_count_star() {
        assert_eq!(Aggregate::count(None).to_sql(), "COUNT(*)");
    }

    #[test]
    fn test_count_all_aggregate() {
        let agg = Aggregate::count_all();
        assert_eq!(agg.to_sql(), "COUNT(*)");
    }

    #[test]
    fn test_count_distinct_aggregate() {
        let agg = Aggregate::count_distinct("user_id");
        assert_eq!(agg.to_sql(), "COUNT(DISTINCT user_id)");
    }

    #[test]
    fn test_count_distinct_display_is_plain_count() {
        assert_eq!(AggregateFunc::CountDistinct.to_string(), "COUNT");
    }

    #[test]
    fn test_sum_aggregate() {
        let agg = Aggregate::sum("amount");
        assert_eq!(agg.to_sql(), "SUM(amount)");
    }

    #[test]
    fn test_avg_aggregate() {
        let agg = Aggregate::avg("score");
        assert_eq!(agg.to_sql(), "AVG(score)");
    }

    #[test]
    fn test_max_aggregate() {
        let agg = Aggregate::max("price");
        assert_eq!(agg.to_sql(), "MAX(price)");
    }

    #[test]
    fn test_min_aggregate() {
        let agg = Aggregate::min("age");
        assert_eq!(agg.to_sql(), "MIN(age)");
    }

    #[test]
    fn test_aggregate_with_alias() {
        let agg = Aggregate::sum("amount").with_alias("total_amount");
        assert_eq!(agg.to_sql(), "SUM(amount) AS total_amount");
    }

    #[test]
    fn test_to_sql_expr_omits_alias() {
        let agg = Aggregate::sum("amount").with_alias("total_amount");
        assert_eq!(agg.to_sql_expr(), "SUM(amount)");
        assert_eq!(
            Aggregate::count_distinct("user_id").to_sql_expr(),
            "COUNT(DISTINCT user_id)"
        );
        assert_eq!(Aggregate::count_all().to_sql_expr(), "COUNT(*)");
    }

    #[test]
    fn test_aggregate_result_insert_and_get() {
        let mut result = AggregateResult::default();
        assert!(result.values.is_empty());
        assert!(result.get("total").is_none());
        result.insert("total".to_string(), AggregateValue::Int(3));
        result.insert("total".to_string(), AggregateValue::Float(1.5));
        assert_eq!(result.values.len(), 1);
        assert!(matches!(
            result.get("total"),
            Some(AggregateValue::Float(v)) if *v == 1.5
        ));
    }
}