use crate::error::FusekiResult;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Rewrites calls to the registered extended aggregates (`GROUP_CONCAT`, `SAMPLE`,
/// `MEDIAN`, `MODE`) found in a query string, memoising each call's rewrite.
#[derive(Debug, Clone)]
pub struct EnhancedAggregationProcessor {
/// Registered aggregates, keyed by the exact (upper-case) name matched in query text.
functions: HashMap<String, AggregationFunction>,
/// Memo of rewrites: original call text -> rewritten call text. Nothing in this file
/// evicts entries, so it grows with the number of distinct calls processed.
optimization_cache: HashMap<String, String>,
}
impl Default for EnhancedAggregationProcessor {
fn default() -> Self {
Self::new()
}
}
impl EnhancedAggregationProcessor {
pub fn new() -> Self {
let mut processor = Self {
functions: HashMap::new(),
optimization_cache: HashMap::new(),
};
processor.register_builtin_functions();
processor
}
fn register_builtin_functions(&mut self) {
self.functions.insert(
"GROUP_CONCAT".to_string(),
AggregationFunction {
name: "GROUP_CONCAT".to_string(),
return_type: "literal".to_string(),
supports_distinct: true,
supports_separator: true,
parallel_safe: true,
},
);
self.functions.insert(
"SAMPLE".to_string(),
AggregationFunction {
name: "SAMPLE".to_string(),
return_type: "any".to_string(),
supports_distinct: false,
supports_separator: false,
parallel_safe: true,
},
);
self.functions.insert(
"MEDIAN".to_string(),
AggregationFunction {
name: "MEDIAN".to_string(),
return_type: "numeric".to_string(),
supports_distinct: true,
supports_separator: false,
parallel_safe: false,
},
);
self.functions.insert(
"MODE".to_string(),
AggregationFunction {
name: "MODE".to_string(),
return_type: "any".to_string(),
supports_distinct: false,
supports_separator: false,
parallel_safe: false,
},
);
}
pub fn process_aggregations(&mut self, query: &str) -> FusekiResult<String> {
let mut processed = query.to_string();
let functions: Vec<(String, AggregationFunction)> = self
.functions
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
for (name, function) in functions {
processed = self.process_function(&processed, &name, &function)?;
}
Ok(processed)
}
fn process_function(
&mut self,
query: &str,
function_name: &str,
function: &AggregationFunction,
) -> FusekiResult<String> {
let pattern = format!("{function_name}(");
let mut result = query.to_string();
while let Some(pos) = result.find(&pattern) {
if let Some(func_call) = self.extract_function_call(&result[pos..]) {
let optimized = self.optimize_function_call(&func_call, function)?;
result = result.replace(&func_call, &optimized);
} else {
break;
}
}
Ok(result)
}
fn extract_function_call(&self, text: &str) -> Option<String> {
let mut paren_count = 0;
let mut in_string = false;
let mut escape_next = false;
for (i, ch) in text.char_indices() {
if escape_next {
escape_next = false;
continue;
}
match ch {
'\\' => escape_next = true,
'"' | '\'' => in_string = !in_string,
'(' if !in_string => paren_count += 1,
')' if !in_string => {
paren_count -= 1;
if paren_count == 0 {
return Some(text[..=i].to_string());
}
}
_ => {}
}
}
None
}
fn optimize_function_call(
&mut self,
func_call: &str,
function: &AggregationFunction,
) -> FusekiResult<String> {
if let Some(cached) = self.optimization_cache.get(func_call) {
return Ok(cached.clone());
}
let optimized = match function.name.as_str() {
"GROUP_CONCAT" => self.optimize_group_concat(func_call)?,
"SAMPLE" => self.optimize_sample(func_call)?,
"MEDIAN" => self.optimize_median(func_call)?,
"MODE" => self.optimize_mode(func_call)?,
_ => func_call.to_string(),
};
self.optimization_cache
.insert(func_call.to_string(), optimized.clone());
Ok(optimized)
}
fn optimize_group_concat(&self, func_call: &str) -> FusekiResult<String> {
let args = self.parse_function_args(func_call)?;
if args.is_empty() {
return Ok(func_call.to_string());
}
if args.len() == 1 && !func_call.contains("SEPARATOR") {
let expr = &args[0];
Ok(format!("GROUP_CONCAT({expr} ; SEPARATOR=',')"))
} else if func_call.contains("DISTINCT") {
Ok(format!("OPTIMIZED_{func_call}"))
} else {
Ok(func_call.to_string())
}
}
fn optimize_sample(&self, func_call: &str) -> FusekiResult<String> {
if func_call.contains("DISTINCT") {
Ok(format!(
"DETERMINISTIC_SAMPLE({})",
&func_call[7..func_call.len() - 1]
)) } else {
Ok(func_call.to_string())
}
}
fn optimize_median(&self, _func_call: &str) -> FusekiResult<String> {
Ok(format!("SORTED_{_func_call}"))
}
fn optimize_mode(&self, _func_call: &str) -> FusekiResult<String> {
Ok(format!("GROUPED_{_func_call}"))
}
fn parse_function_args(&self, func_call: &str) -> FusekiResult<Vec<String>> {
let open_paren = func_call
.find('(')
.ok_or_else(|| crate::error::FusekiError::query_parsing("Invalid function call"))?;
let args_str = &func_call[open_paren + 1..func_call.len() - 1];
if args_str.trim().is_empty() {
return Ok(Vec::new());
}
let args: Vec<String> = args_str
.split(',')
.map(|arg| arg.trim().to_string())
.collect();
Ok(args)
}
pub fn is_supported(&self, function_name: &str) -> bool {
self.functions.contains_key(function_name)
}
pub fn get_function(&self, function_name: &str) -> Option<&AggregationFunction> {
self.functions.get(function_name)
}
}
/// Metadata for one aggregate function known to [`EnhancedAggregationProcessor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregationFunction {
/// Upper-case function name; also the key that selects which rewrite is applied.
pub name: String,
/// Informal result category; the built-ins use `"literal"`, `"any"` or `"numeric"`.
pub return_type: String,
/// Declared `DISTINCT` support (presumably). NOTE(review): not read by the rewrite logic in this file.
pub supports_distinct: bool,
/// Declared `SEPARATOR` support; only `GROUP_CONCAT` sets it among the built-ins.
/// NOTE(review): not read by the rewrite logic in this file.
pub supports_separator: bool,
/// Presumably whether evaluation may be parallelised; `MEDIAN` and `MODE` register `false`.
/// NOTE(review): not read anywhere in this file — confirm intended semantics with callers.
pub parallel_safe: bool,
}
/// Unit tests for the aggregate rewriter's registry, individual rewrites and argument split.
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare single-argument `GROUP_CONCAT` gains an explicit `SEPARATOR` clause.
    #[test]
    fn test_group_concat_optimization() {
        let engine = EnhancedAggregationProcessor::new();
        assert!(engine.is_supported("GROUP_CONCAT"));

        let rewritten = engine
            .optimize_group_concat("GROUP_CONCAT(?name)")
            .unwrap();
        assert!(rewritten.contains("SEPARATOR"));
    }

    /// `SAMPLE` without `DISTINCT` passes through untouched.
    #[test]
    fn test_sample_optimization() {
        let engine = EnhancedAggregationProcessor::new();
        assert!(engine.is_supported("SAMPLE"));

        let input = "SAMPLE(?value)";
        assert_eq!(engine.optimize_sample(input).unwrap(), input);
    }

    /// Every built-in is registered; unknown names are not.
    #[test]
    fn test_function_detection() {
        let engine = EnhancedAggregationProcessor::new();
        for name in ["GROUP_CONCAT", "SAMPLE", "MEDIAN", "MODE"] {
            assert!(engine.is_supported(name), "{name} should be registered");
        }
        assert!(!engine.is_supported("UNKNOWN_FUNCTION"));
    }

    /// Pins the current naive comma split, which also splits inside quoted literals.
    #[test]
    fn test_argument_parsing() {
        let engine = EnhancedAggregationProcessor::new();

        assert_eq!(engine.parse_function_args("COUNT(?x)").unwrap(), vec!["?x"]);
        assert_eq!(
            engine
                .parse_function_args("GROUP_CONCAT(?name, ',')")
                .unwrap(),
            vec!["?name", "'", "'"]
        );
        assert!(engine.parse_function_args("SUM()").unwrap().is_empty());
    }
}