use crate::error::OxirsError;
use crate::model::{Literal, Term};
use crate::rdf_store::VariableBinding;
use crate::sparql::modifiers::compare_terms;
use crate::Result;
use ahash::{AHashMap, AHashSet};
use std::collections::hash_map::Entry;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// Aggregate functions understood by the evaluator in this module.
///
/// Numeric aggregates only consider literals whose lexical value parses as `f64`;
/// other values are skipped by the accumulator.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AggregateFunction {
    /// `COUNT`: number of bound values.
    Count,
    /// `SUM`: total of the numeric values.
    Sum,
    /// `AVG`: average of the numeric values.
    Avg,
    /// `MIN`: smallest bound term under the `compare_terms` ordering.
    Min,
    /// `MAX`: largest bound term under the `compare_terms` ordering.
    Max,
    /// `GROUP_CONCAT`: string forms of the values joined with `separator`.
    GroupConcat {
        /// Text inserted between consecutive values.
        separator: String,
    },
    /// `SAMPLE`: the first bound value encountered.
    Sample,
    /// Median of the numeric values (mean of the two middle values for even counts).
    Median,
    /// Sample variance of the numeric values (divisor `n - 1`).
    Variance,
    /// Square root of [`AggregateFunction::Variance`].
    StdDev,
    /// Linearly interpolated percentile of the numeric values.
    Percentile {
        /// Requested percentile, expected in `0..=100`.
        percentile: u8,
    },
}
/// One aggregate projection, e.g. `(SUM(?x) AS ?total)`.
#[derive(Debug, Clone)]
pub struct AggregateExpression {
    /// Aggregate function to apply.
    pub function: AggregateFunction,
    /// Name (without the leading `?`) of the variable whose values are aggregated.
    /// `None` stands for `*`: every solution then contributes one constant value.
    pub variable: Option<String>,
    /// Output variable the aggregate result is bound to.
    pub alias: String,
    /// Whether repeated values are fed to the aggregate only once (`DISTINCT`).
    pub distinct: bool,
}
/// Variables listed in a `GROUP BY` clause, in clause order.
#[derive(Debug, Clone)]
pub struct GroupBySpec {
    /// Grouping variable names; grouped result rows list these before the aggregate aliases.
    pub variables: Vec<String>,
}
/// Hashable identity of one `GROUP BY` group: one entry per grouping variable, in
/// the order the variables were listed.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct GroupKey(Vec<TermHash>);

/// Owned, hashable snapshot of an RDF term, used for grouping and `DISTINCT` checks.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum TermHash {
    /// IRI text.
    NamedNode(String),
    /// Blank node identifier.
    BlankNode(String),
    /// Literal: lexical value, optional datatype IRI and optional language tag.
    Literal {
        value: String,
        datatype: Option<String>,
        language: Option<String>,
    },
    /// Placeholder for a grouping variable that is not bound in a solution.
    Unbound,
}
impl From<&Term> for TermHash {
    /// Snapshots `term` into its hashable form.
    ///
    /// Literals always record their datatype IRI (as `Some`). Variables and quoted
    /// triples have no dedicated variant; they are folded into `NamedNode` using the
    /// textual forms `?name` and `<<triple>>` respectively.
    fn from(term: &Term) -> Self {
        match term {
            Term::Literal(lit) => {
                let value = lit.value().to_string();
                let datatype = Some(lit.datatype().as_str().to_string());
                let language = lit.language().map(|lang| lang.to_string());
                TermHash::Literal {
                    value,
                    datatype,
                    language,
                }
            }
            Term::NamedNode(node) => TermHash::NamedNode(node.as_str().to_string()),
            Term::BlankNode(node) => TermHash::BlankNode(node.as_str().to_string()),
            Term::Variable(var) => TermHash::NamedNode(format!("?{}", var.as_str())),
            Term::QuotedTriple(triple) => TermHash::NamedNode(format!("<<{}>>", triple)),
        }
    }
}
/// Running state of one aggregate while solutions are fed in one at a time.
#[derive(Debug, Clone)]
struct AggregateAccumulator {
    /// Aggregate being computed.
    function: AggregateFunction,
    /// Values accepted so far (after `DISTINCT` filtering, before any numeric check).
    count: usize,
    /// Running total of numeric values (`SUM` / `AVG`).
    sum: f64,
    /// Retained numeric terms for aggregates that need the individual values.
    values: Vec<Term>,
    /// Values already seen; only populated when `distinct` is set.
    seen_values: AHashSet<TermHash>,
    /// Current minimum (`MIN`).
    min_value: Option<Term>,
    /// Current maximum (`MAX`).
    max_value: Option<Term>,
    /// String forms collected for `GROUP_CONCAT`.
    concat_values: Vec<String>,
    /// First value seen (`SAMPLE`).
    sample_value: Option<Term>,
    /// Whether repeated values are skipped.
    distinct: bool,
}
impl AggregateAccumulator {
    /// Creates an empty accumulator for `function`.
    ///
    /// When `distinct` is true, repeated values (compared via `TermHash`) are
    /// accepted only once.
    fn new(function: AggregateFunction, distinct: bool) -> Self {
        Self {
            function,
            count: 0,
            sum: 0.0,
            values: Vec::new(),
            seen_values: AHashSet::new(),
            min_value: None,
            max_value: None,
            concat_values: Vec::new(),
            sample_value: None,
            distinct,
        }
    }

    /// Feeds one value into the accumulator.
    ///
    /// `None` (an unbound variable) is ignored entirely. Under `DISTINCT`, a value whose
    /// `TermHash` was already seen is ignored too. Every other value increments
    /// `count`. The numeric aggregates additionally skip terms that are not literals
    /// whose lexical value parses as `f64`.
    fn add_value(&mut self, term: Option<&Term>) {
        let Some(term) = term else {
            return;
        };
        if self.distinct {
            let term_hash = TermHash::from(term);
            if !self.seen_values.insert(term_hash) {
                // Duplicate under DISTINCT: already accounted for.
                return;
            }
        }
        self.count += 1;
        match &self.function {
            // COUNT only needs `count`, updated above.
            AggregateFunction::Count => {}
            AggregateFunction::Sum | AggregateFunction::Avg => {
                if let Term::Literal(lit) = term {
                    if let Ok(val) = lit.value().parse::<f64>() {
                        self.sum += val;
                        if matches!(self.function, AggregateFunction::Avg) {
                            // AVG keeps its numeric inputs so `finalize` knows how many
                            // values actually contributed to `sum`.
                            self.values.push(term.clone());
                        }
                    }
                }
            }
            AggregateFunction::Min => {
                if let Some(ref current_min) = self.min_value {
                    if compare_terms(term, current_min).is_lt() {
                        self.min_value = Some(term.clone());
                    }
                } else {
                    self.min_value = Some(term.clone());
                }
            }
            AggregateFunction::Max => {
                if let Some(ref current_max) = self.max_value {
                    if compare_terms(term, current_max).is_gt() {
                        self.max_value = Some(term.clone());
                    }
                } else {
                    self.max_value = Some(term.clone());
                }
            }
            AggregateFunction::GroupConcat { .. } => {
                // Literals contribute their lexical value; other terms their
                // `to_string` form.
                if let Term::Literal(lit) = term {
                    self.concat_values.push(lit.value().to_string());
                } else {
                    self.concat_values.push(term.to_string());
                }
            }
            AggregateFunction::Sample => {
                if self.sample_value.is_none() {
                    self.sample_value = Some(term.clone());
                }
            }
            AggregateFunction::Median
            | AggregateFunction::Variance
            | AggregateFunction::StdDev
            | AggregateFunction::Percentile { .. } => {
                if let Term::Literal(lit) = term {
                    if lit.value().parse::<f64>().is_ok() {
                        self.values.push(term.clone());
                    }
                }
            }
        }
    }

    /// Produces the aggregate's result term.
    ///
    /// Numeric results are literals built with `Literal::new` from the number's
    /// `to_string` form. `MIN`, `MAX` and `SAMPLE` with no input yield an empty
    /// literal; `AVG` with no numeric input yields `0`.
    fn finalize(&self) -> Term {
        match &self.function {
            AggregateFunction::Count => Term::from(Literal::new(self.count.to_string())),
            AggregateFunction::Sum => Term::from(Literal::new(self.sum.to_string())),
            AggregateFunction::Avg => {
                // Divide by the number of numeric inputs (those pushed into `values`),
                // not by `count`: `count` also includes bound values that failed to parse
                // and never reached `sum`, which would drag the mean toward zero.
                let numeric = self.values.len();
                let avg = if numeric > 0 {
                    self.sum / numeric as f64
                } else {
                    0.0
                };
                Term::from(Literal::new(avg.to_string()))
            }
            AggregateFunction::Min => self
                .min_value
                .clone()
                .unwrap_or_else(|| Term::from(Literal::new(""))),
            AggregateFunction::Max => self
                .max_value
                .clone()
                .unwrap_or_else(|| Term::from(Literal::new(""))),
            AggregateFunction::GroupConcat { separator } => {
                let concatenated = self.concat_values.join(separator);
                Term::from(Literal::new(concatenated))
            }
            AggregateFunction::Sample => self
                .sample_value
                .clone()
                .unwrap_or_else(|| Term::from(Literal::new(""))),
            AggregateFunction::Median => {
                let result = compute_median(&self.values);
                Term::from(Literal::new(result.to_string()))
            }
            AggregateFunction::Variance => {
                let result = compute_variance(&self.values);
                Term::from(Literal::new(result.to_string()))
            }
            AggregateFunction::StdDev => {
                let variance = compute_variance(&self.values);
                let stddev = variance.sqrt();
                Term::from(Literal::new(stddev.to_string()))
            }
            AggregateFunction::Percentile { percentile } => {
                let result = compute_percentile(&self.values, *percentile);
                Term::from(Literal::new(result.to_string()))
            }
        }
    }
}
/// Extracts the aggregate projections from the `SELECT` clause of a SPARQL query.
///
/// Scans the text between the first `SELECT` and the first `WHERE` after it (both
/// matched ASCII case-insensitively) for parenthesised projections of the form
/// `(FUNC(args) AS ?alias)`, where `FUNC` is one of `COUNT`, `SUM`, `AVG`, `MIN` or
/// `MAX`. Other parenthesised expressions are skipped.
///
/// For each aggregate found:
/// - `variable` is `None` for `*`; otherwise it is the argument text with a leading
///   `?`/`$` removed (non-variable arguments are kept verbatim);
/// - `distinct` is set when the argument list starts with the `DISTINCT` keyword;
/// - `alias` is the variable after `AS`, or `"aggregate"` when there is no `AS`.
///
/// This is a lightweight textual scan, not a SPARQL parser. Queries without a `WHERE`
/// keyword (e.g. `SELECT ... { }`) yield no aggregates.
///
/// # Errors
/// Currently never returns `Err`; the `Result` keeps room for stricter parsing.
pub fn extract_aggregates(sparql: &str) -> Result<Vec<AggregateExpression>> {
    let mut aggregates = Vec::new();
    // ASCII upper-casing preserves byte offsets, so positions found in `upper` are valid
    // for slicing `sparql`. Full Unicode upper-casing can change the byte length of
    // non-ASCII characters and misalign (or panic) the slices below.
    let upper = sparql.to_ascii_uppercase();
    let Some(select_start) = upper.find("SELECT") else {
        return Ok(aggregates);
    };
    let clause_start = select_start + "SELECT".len();
    // Look for WHERE only after SELECT, so the clause range can never be inverted.
    let Some(where_offset) = upper[clause_start..].find("WHERE") else {
        return Ok(aggregates);
    };
    let select_clause = &sparql[clause_start..clause_start + where_offset];

    let mut pos = 0;
    while let Some(open_offset) = select_clause[pos..].find('(') {
        let open = pos + open_offset;
        let Some(close_offset) = find_matching_paren(&select_clause[open..]) else {
            break;
        };
        let end = open + close_offset + 1;
        // `get` instead of indexing: never panic on an offset that is not a char boundary.
        let Some(expr) = select_clause.get(open..end) else {
            break;
        };
        if let Some(aggregate) = parse_aggregate_projection(expr) {
            aggregates.push(aggregate);
        }
        // Always move past the expression, whether or not it was an aggregate.
        pos = end;
    }
    Ok(aggregates)
}

/// Parses one parenthesised projection such as `(COUNT(DISTINCT ?x) AS ?n)`.
///
/// Returns `None` when `expr` is not a plain call to a supported aggregate. The
/// function name must match exactly (ASCII case-insensitive), so `MINUTES(?d)` is not
/// mistaken for `MIN`.
fn parse_aggregate_projection(expr: &str) -> Option<AggregateExpression> {
    // Drop the outer parentheses; the caller passes text that starts with '(' and
    // ends with the matching ')'.
    let inner = expr.get(1..expr.len().checked_sub(1)?)?.trim_start();
    let call_open = inner.find('(')?;
    let function = match inner[..call_open].trim_end().to_ascii_uppercase().as_str() {
        "COUNT" => AggregateFunction::Count,
        "SUM" => AggregateFunction::Sum,
        "AVG" => AggregateFunction::Avg,
        "MIN" => AggregateFunction::Min,
        "MAX" => AggregateFunction::Max,
        _ => return None,
    };
    let call = &inner[call_open..];
    let call_close = find_matching_paren(call)?;
    let args = call.get(1..call_close)?.trim();
    let rest = call.get(call_close + 1..)?.trim_start();

    // `DISTINCT` is a keyword inside the argument list, not part of the variable.
    let (distinct, args) = match args.get(..8) {
        Some(keyword) if keyword.eq_ignore_ascii_case("DISTINCT") => {
            (true, args[8..].trim_start())
        }
        _ => (false, args),
    };
    let variable = if args == "*" {
        None
    } else {
        Some(strip_var_sigil(args).unwrap_or(args).to_string())
    };

    let alias = if rest.is_empty() {
        String::from("aggregate")
    } else {
        // Anything after the call other than `AS ?alias` means a compound expression
        // (e.g. `MAX(?a) - MIN(?b)`); skip it rather than misreport a plain aggregate.
        let keyword = rest.get(..2)?;
        if !keyword.eq_ignore_ascii_case("AS") {
            return None;
        }
        rest[2..]
            .split_whitespace()
            .next()
            .and_then(strip_var_sigil)
            .map(|name| name.trim_end_matches(')').to_string())
            .unwrap_or_else(|| String::from("aggregate"))
    };

    Some(AggregateExpression {
        function,
        variable,
        alias,
        distinct,
    })
}

/// Strips a SPARQL variable sigil (`?` or `$`) from `token`, if present.
fn strip_var_sigil(token: &str) -> Option<&str> {
    token.strip_prefix('?').or_else(|| token.strip_prefix('$'))
}
/// Returns the byte offset of the `)` that closes the parenthesis opening `text`.
///
/// The first character of `text` is taken to be the opening `(` and is not
/// inspected. Nested parentheses are balanced. Returns `None` when the parenthesis is
/// never closed (including for empty `text`).
///
/// The result is a byte offset, so `&text[..=offset]` is always a valid slice ending on
/// the closing parenthesis, even when `text` contains multi-byte UTF-8 characters.
/// For ASCII input it equals the character index.
pub fn find_matching_paren(text: &str) -> Option<usize> {
    let mut depth = 1usize;
    for (offset, ch) in text.char_indices().skip(1) {
        match ch {
            '(' => depth += 1,
            ')' => {
                depth -= 1;
                if depth == 0 {
                    return Some(offset);
                }
            }
            _ => {}
        }
    }
    None
}
/// Evaluates `aggregates` over all of `results` as one implicit group (no `GROUP BY`).
///
/// On success the result holds a single row binding every aggregate alias, plus the
/// list of aliases in declaration order.
///
/// # Errors
/// Returns `OxirsError::Query` when `aggregates` is empty.
pub fn apply_aggregates(
    results: Vec<VariableBinding>,
    aggregates: &[AggregateExpression],
) -> Result<(Vec<VariableBinding>, Vec<String>)> {
    match aggregates {
        [] => Err(OxirsError::Query("No aggregates to apply".to_string())),
        _ => apply_aggregates_no_grouping(results, aggregates),
    }
}
pub fn apply_aggregates_with_grouping(
results: Vec<VariableBinding>,
aggregates: &[AggregateExpression],
group_by: &GroupBySpec,
) -> Result<(Vec<VariableBinding>, Vec<String>)> {
if aggregates.is_empty() {
return Err(OxirsError::Query("No aggregates to apply".to_string()));
}
let mut groups: AHashMap<GroupKey, Vec<VariableBinding>> = AHashMap::new();
for binding in results {
let key = extract_group_key(&binding, &group_by.variables);
match groups.entry(key) {
Entry::Occupied(mut entry) => {
entry.get_mut().push(binding);
}
Entry::Vacant(entry) => {
entry.insert(vec![binding]);
}
}
}
#[cfg(feature = "parallel")]
let group_results: Vec<_> = {
let groups_vec: Vec<_> = groups.into_iter().collect();
if groups_vec.len() > 10 {
groups_vec
.into_par_iter()
.map(|(key, group_bindings)| {
process_group(key, group_bindings, aggregates, &group_by.variables)
})
.collect::<Result<Vec<_>>>()?
} else {
groups_vec
.into_iter()
.map(|(key, group_bindings)| {
process_group(key, group_bindings, aggregates, &group_by.variables)
})
.collect::<Result<Vec<_>>>()?
}
};
#[cfg(not(feature = "parallel"))]
let group_results: Vec<_> = groups
.into_iter()
.map(|(key, group_bindings)| {
process_group(key, group_bindings, aggregates, &group_by.variables)
})
.collect::<Result<Vec<_>>>()?;
let mut result_variables = group_by.variables.clone();
for agg_expr in aggregates {
result_variables.push(agg_expr.alias.clone());
}
Ok((group_results, result_variables))
}
/// Evaluates every aggregate over all of `results` as one implicit group.
///
/// Always returns exactly one row, even for empty `results`. The row binds each
/// aggregate's alias; the aliases are also returned in declaration order.
fn apply_aggregates_no_grouping(
    results: Vec<VariableBinding>,
    aggregates: &[AggregateExpression],
) -> Result<(Vec<VariableBinding>, Vec<String>)> {
    let mut accumulators: Vec<AggregateAccumulator> = aggregates
        .iter()
        .map(|agg| AggregateAccumulator::new(agg.function.clone(), agg.distinct))
        .collect();
    // Value fed to aggregates without a variable (`COUNT(*)`): one constant term per
    // solution. Built once rather than once per solution and aggregate.
    let row_marker = Term::from(Literal::new("1"));
    for binding in &results {
        for (acc, agg_expr) in accumulators.iter_mut().zip(aggregates) {
            let value = match &agg_expr.variable {
                Some(var) => binding.get(var),
                None => Some(&row_marker),
            };
            acc.add_value(value);
        }
    }

    let mut aggregate_binding = VariableBinding::new();
    let mut result_variables = Vec::with_capacity(aggregates.len());
    for (acc, agg_expr) in accumulators.iter().zip(aggregates) {
        aggregate_binding.bind(agg_expr.alias.clone(), acc.finalize());
        result_variables.push(agg_expr.alias.clone());
    }
    Ok((vec![aggregate_binding], result_variables))
}
/// Builds the grouping key of `binding` from its values for `group_vars`, in order.
///
/// A variable missing from `binding` contributes `TermHash::Unbound`, so solutions
/// lacking a grouping variable are grouped together rather than dropped.
fn extract_group_key(binding: &VariableBinding, group_vars: &[String]) -> GroupKey {
    let mut key_terms = Vec::with_capacity(group_vars.len());
    for var in group_vars {
        let term_hash = match binding.get(var) {
            Some(term) => TermHash::from(term),
            None => TermHash::Unbound,
        };
        key_terms.push(term_hash);
    }
    GroupKey(key_terms)
}
/// Computes the output row for one group.
///
/// The row binds each grouping variable to its value in the group's first solution
/// (the group was formed on these values), then binds each aggregate alias to the
/// aggregate finalized over all solutions in the group. `_key` is accepted but not
/// read.
///
/// Never returns `Err` today; the `Result` leaves room for fallible aggregates.
fn process_group(
    _key: GroupKey,
    group_bindings: Vec<VariableBinding>,
    aggregates: &[AggregateExpression],
    group_vars: &[String],
) -> Result<VariableBinding> {
    let mut result_binding = VariableBinding::new();
    if let Some(first_binding) = group_bindings.first() {
        for var in group_vars {
            if let Some(value) = first_binding.get(var) {
                result_binding.bind(var.clone(), value.clone());
            }
        }
    }

    let mut accumulators: Vec<AggregateAccumulator> = aggregates
        .iter()
        .map(|agg| AggregateAccumulator::new(agg.function.clone(), agg.distinct))
        .collect();
    // Value fed to aggregates without a variable (`COUNT(*)`): one constant term per
    // solution. Built once rather than once per solution and aggregate.
    let row_marker = Term::from(Literal::new("1"));
    for binding in &group_bindings {
        for (acc, agg_expr) in accumulators.iter_mut().zip(aggregates) {
            let value = match &agg_expr.variable {
                Some(var) => binding.get(var),
                None => Some(&row_marker),
            };
            acc.add_value(value);
        }
    }

    for (acc, agg_expr) in accumulators.iter().zip(aggregates) {
        result_binding.bind(agg_expr.alias.clone(), acc.finalize());
    }
    Ok(result_binding)
}
/// Median of the numeric literals in `values`.
///
/// Non-literal terms and literals whose lexical value does not parse as `f64` are
/// ignored. For an even count, the two middle values are averaged. Returns `0.0` when
/// no numeric value remains.
fn compute_median(values: &[Term]) -> f64 {
    let mut nums: Vec<f64> = values
        .iter()
        .filter_map(|term| match term {
            Term::Literal(lit) => lit.value().parse::<f64>().ok(),
            _ => None,
        })
        .collect();
    if nums.is_empty() {
        return 0.0;
    }
    // "NaN" parses as an f64. A `partial_cmp`-based comparator is then not a total
    // order, and std's sort may panic on such comparators; `total_cmp` is always total.
    nums.sort_unstable_by(f64::total_cmp);
    let mid = nums.len() / 2;
    if nums.len() % 2 == 0 {
        (nums[mid - 1] + nums[mid]) / 2.0
    } else {
        nums[mid]
    }
}
/// Sample variance (divisor `n - 1`) of the numeric literals in `values`.
///
/// Non-literal terms and literals whose lexical value does not parse as `f64` are
/// ignored. Returns `0.0` when fewer than two numeric values remain.
fn compute_variance(values: &[Term]) -> f64 {
    if values.len() < 2 {
        return 0.0;
    }
    let samples: Vec<f64> = values
        .iter()
        .filter_map(|term| match term {
            Term::Literal(lit) => lit.value().parse::<f64>().ok(),
            _ => None,
        })
        .collect();
    let n = samples.len();
    if n < 2 {
        return 0.0;
    }
    let mean = samples.iter().sum::<f64>() / n as f64;
    let sum_of_squares = samples
        .iter()
        .map(|sample| (sample - mean).powi(2))
        .sum::<f64>();
    sum_of_squares / (n - 1) as f64
}
/// Linearly interpolated `percentile` of the numeric literals in `values`.
///
/// Non-literal terms and literals whose lexical value does not parse as `f64` are
/// ignored. Rank `0` is the minimum and rank `len - 1` the maximum, so percentiles 0
/// and 100 return the extremes exactly. Returns `0.0` when `percentile > 100` or no
/// numeric value remains.
fn compute_percentile(values: &[Term], percentile: u8) -> f64 {
    if percentile > 100 {
        return 0.0;
    }
    let mut nums: Vec<f64> = values
        .iter()
        .filter_map(|term| match term {
            Term::Literal(lit) => lit.value().parse::<f64>().ok(),
            _ => None,
        })
        .collect();
    if nums.is_empty() {
        return 0.0;
    }
    // "NaN" parses as an f64. A `partial_cmp`-based comparator is then not a total
    // order, and std's sort may panic on such comparators; `total_cmp` is always total.
    nums.sort_unstable_by(f64::total_cmp);

    let rank = (f64::from(percentile) / 100.0) * (nums.len() - 1) as f64;
    let lower_index = rank.floor() as usize;
    let upper_index = rank.ceil() as usize;
    if lower_index == upper_index {
        // Exact rank (includes percentiles 0 and 100): no interpolation, which also
        // avoids `0 * (inf - inf)` producing NaN for infinite values.
        nums[lower_index]
    } else {
        let lower_value = nums[lower_index];
        let upper_value = nums[upper_index];
        let fraction = rank - lower_index as f64;
        lower_value + fraction * (upper_value - lower_value)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a binding whose variables are bound to literals of the given numbers.
    fn create_test_binding(values: Vec<(&str, f64)>) -> VariableBinding {
        let mut binding = VariableBinding::new();
        for (var, val) in values {
            binding.bind(var.to_string(), Term::from(Literal::new(val.to_string())));
        }
        binding
    }

    /// Builds a binding from `(variable, lexical value)` pairs of literals.
    fn literal_binding(pairs: &[(&str, &str)]) -> VariableBinding {
        let mut binding = VariableBinding::new();
        for &(var, value) in pairs {
            binding.bind(var.to_string(), Term::from(Literal::new(value)));
        }
        binding
    }

    /// Builds an aggregate over `?var` whose result is bound to `alias`.
    fn aggregate_over(
        var: &str,
        function: AggregateFunction,
        alias: &str,
        distinct: bool,
    ) -> AggregateExpression {
        AggregateExpression {
            function,
            variable: Some(var.to_string()),
            alias: alias.to_string(),
            distinct,
        }
    }

    /// Runs one ungrouped aggregate over `?x` (alias `result`) and returns the only row.
    fn run_ungrouped(
        results: Vec<VariableBinding>,
        function: AggregateFunction,
        distinct: bool,
    ) -> VariableBinding {
        let agg = aggregate_over("x", function, "result", distinct);
        let (mut rows, vars) =
            apply_aggregates(results, &[agg]).expect("aggregate operation should succeed");
        assert_eq!(vars, vec!["result"]);
        assert_eq!(rows.len(), 1, "ungrouped aggregation yields exactly one row");
        rows.remove(0)
    }

    /// Like [`run_ungrouped`], with `?x` bound to each of `values` in turn.
    fn aggregate_x(values: &[f64], function: AggregateFunction, distinct: bool) -> VariableBinding {
        let results = values
            .iter()
            .map(|&v| create_test_binding(vec![("x", v)]))
            .collect();
        run_ungrouped(results, function, distinct)
    }

    /// Lexical value of the literal bound to `var`; panics if unbound or not a literal.
    fn literal_value(binding: &VariableBinding, var: &str) -> String {
        match binding.get(var).expect("binding should exist") {
            Term::Literal(lit) => lit.value().to_string(),
            _ => panic!("Expected literal for {var}"),
        }
    }

    /// Numeric value of the literal bound to `var`.
    fn numeric_value(binding: &VariableBinding, var: &str) -> f64 {
        literal_value(binding, var)
            .parse()
            .expect("parse should succeed")
    }

    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-4,
            "expected {expected}, got {actual}"
        );
    }

    /// Maps each row's `category` literal to the numeric value bound to `alias`.
    fn values_by_category(rows: &[VariableBinding], alias: &str) -> HashMap<String, f64> {
        rows.iter()
            .map(|row| (literal_value(row, "category"), numeric_value(row, alias)))
            .collect()
    }

    #[test]
    fn test_count_aggregate() {
        let row = aggregate_x(&[1.0, 2.0, 3.0], AggregateFunction::Count, false);
        assert_eq!(literal_value(&row, "result"), "3");
    }

    #[test]
    fn test_sum_aggregate() {
        let row = aggregate_x(&[10.0, 20.0, 30.0], AggregateFunction::Sum, false);
        assert_close(numeric_value(&row, "result"), 60.0);
    }

    #[test]
    fn test_avg_aggregate() {
        let row = aggregate_x(&[10.0, 20.0, 30.0], AggregateFunction::Avg, false);
        assert_close(numeric_value(&row, "result"), 20.0);
    }

    #[test]
    fn test_count_distinct() {
        // The repeated 1.0 must be counted once.
        let row = aggregate_x(&[1.0, 2.0, 1.0, 3.0], AggregateFunction::Count, true);
        assert_eq!(literal_value(&row, "result"), "3");
    }

    #[test]
    fn test_group_concat() {
        let results = ["apple", "banana", "cherry"]
            .iter()
            .map(|&v| literal_binding(&[("x", v)]))
            .collect();
        let function = AggregateFunction::GroupConcat {
            separator: ", ".to_string(),
        };
        let row = run_ungrouped(results, function, false);
        assert_eq!(literal_value(&row, "result"), "apple, banana, cherry");
    }

    #[test]
    fn test_sample_aggregate() {
        let row = aggregate_x(&[10.0, 20.0, 30.0], AggregateFunction::Sample, false);
        let sample = literal_value(&row, "result");
        assert!(
            ["10", "20", "30"].contains(&sample.as_str()),
            "unexpected sample {sample}"
        );
    }

    #[test]
    fn test_group_by_hash_based() {
        let results = vec![
            literal_binding(&[("category", "A"), ("value", "10")]),
            literal_binding(&[("category", "A"), ("value", "20")]),
            literal_binding(&[("category", "B"), ("value", "30")]),
        ];
        let agg = aggregate_over("value", AggregateFunction::Sum, "total", false);
        let group_by = GroupBySpec {
            variables: vec!["category".to_string()],
        };
        let (rows, vars) = apply_aggregates_with_grouping(results, &[agg], &group_by)
            .expect("aggregate operation should succeed");
        assert_eq!(vars, vec!["category", "total"]);
        let totals = values_by_category(&rows, "total");
        assert_eq!(totals.len(), 2);
        assert_close(totals["A"], 30.0);
        assert_close(totals["B"], 30.0);
    }

    #[test]
    fn test_multiple_aggregates() {
        let results = [10.0, 20.0, 30.0]
            .iter()
            .map(|&v| create_test_binding(vec![("x", v)]))
            .collect();
        let aggregates = vec![
            aggregate_over("x", AggregateFunction::Count, "count", false),
            aggregate_over("x", AggregateFunction::Sum, "sum", false),
            aggregate_over("x", AggregateFunction::Avg, "avg", false),
        ];
        let (rows, vars) =
            apply_aggregates(results, &aggregates).expect("aggregate operation should succeed");
        assert_eq!(rows.len(), 1);
        assert_eq!(vars, vec!["count", "sum", "avg"]);
        assert_eq!(literal_value(&rows[0], "count"), "3");
        assert_close(numeric_value(&rows[0], "sum"), 60.0);
        assert_close(numeric_value(&rows[0], "avg"), 20.0);
    }

    #[test]
    fn test_median_aggregate() {
        let odd = aggregate_x(&[1.0, 3.0, 5.0, 7.0, 9.0], AggregateFunction::Median, false);
        assert_close(numeric_value(&odd, "result"), 5.0);
        // Even count: mean of the two middle values (4 and 6).
        let even = aggregate_x(&[2.0, 4.0, 6.0, 8.0], AggregateFunction::Median, false);
        assert_close(numeric_value(&even, "result"), 5.0);
    }

    #[test]
    fn test_variance_aggregate() {
        let row = aggregate_x(&[2.0, 4.0, 6.0, 8.0], AggregateFunction::Variance, false);
        assert_close(numeric_value(&row, "result"), 6.666666666666667);
    }

    #[test]
    fn test_stddev_aggregate() {
        let row = aggregate_x(&[2.0, 4.0, 6.0, 8.0], AggregateFunction::StdDev, false);
        assert_close(numeric_value(&row, "result"), 2.581988897471611);
    }

    #[test]
    fn test_percentile_aggregate() {
        let values: Vec<f64> = (1..=10).map(f64::from).collect();
        for (percentile, expected) in [(50u8, 5.5), (95, 9.55), (25, 3.25)] {
            let row = aggregate_x(&values, AggregateFunction::Percentile { percentile }, false);
            assert_close(numeric_value(&row, "result"), expected);
        }
    }

    #[test]
    fn test_statistical_aggregates_with_grouping() {
        let results = vec![
            literal_binding(&[("category", "A"), ("value", "10")]),
            literal_binding(&[("category", "A"), ("value", "20")]),
            literal_binding(&[("category", "A"), ("value", "30")]),
            literal_binding(&[("category", "B"), ("value", "5")]),
            literal_binding(&[("category", "B"), ("value", "15")]),
        ];
        let agg = aggregate_over("value", AggregateFunction::Median, "median", false);
        let group_by = GroupBySpec {
            variables: vec!["category".to_string()],
        };
        let (rows, _) = apply_aggregates_with_grouping(results, &[agg], &group_by)
            .expect("aggregate operation should succeed");
        let medians = values_by_category(&rows, "median");
        assert_eq!(medians.len(), 2);
        assert_close(medians["A"], 20.0);
        assert_close(medians["B"], 10.0);
    }

    #[test]
    fn test_statistical_aggregate_edge_cases() {
        // No input at all: the median falls back to 0.
        let row = aggregate_x(&[], AggregateFunction::Median, false);
        assert_eq!(numeric_value(&row, "result"), 0.0);
        // A single value has no sample variance; 0 is reported.
        let row = aggregate_x(&[5.0], AggregateFunction::Variance, false);
        assert_eq!(numeric_value(&row, "result"), 0.0);
    }

    #[test]
    fn test_extract_aggregates_basic() {
        let aggs =
            extract_aggregates("SELECT (COUNT(?x) AS ?c) (SUM(?y) AS ?s) WHERE { ?x ?p ?y }")
                .expect("extraction should succeed");
        assert_eq!(aggs.len(), 2);
        assert_eq!(aggs[0].function, AggregateFunction::Count);
        assert_eq!(aggs[0].variable.as_deref(), Some("x"));
        assert_eq!(aggs[0].alias, "c");
        assert!(!aggs[0].distinct);
        assert_eq!(aggs[1].function, AggregateFunction::Sum);
        assert_eq!(aggs[1].variable.as_deref(), Some("y"));
        assert_eq!(aggs[1].alias, "s");

        let star = extract_aggregates("SELECT (COUNT(*) AS ?n) WHERE { ?s ?p ?o }")
            .expect("extraction should succeed");
        assert_eq!(star.len(), 1);
        assert_eq!(star[0].variable, None);
        assert_eq!(star[0].alias, "n");
    }

    #[test]
    fn test_find_matching_paren_nested() {
        assert_eq!(find_matching_paren("(a(b)c) d"), Some(6));
        assert_eq!(find_matching_paren("(unclosed"), None);
    }
}