use crate::core::Value;
use crate::functions::{
AggregateFunction, FunctionDataType, FunctionInfo, FunctionSignature, FunctionType,
};
use super::DistinctTracker;
/// A collected value together with the ORDER BY keys supplied when it was
/// accumulated.
#[derive(Clone)]
struct OrderedEntry {
    /// Keys compared position by position when the entries are sorted.
    sort_keys: Vec<Value>,
    /// The value that is written into the output array.
    value: Value,
}
/// Accumulator state for the `ARRAY_AGG` aggregate.
///
/// Values arrive either through `accumulate` (stored in `values`) or through
/// `accumulate_with_sort_key` (stored in `ordered_entries`).
#[derive(Default)]
pub struct ArrayAggFunction {
    /// Set by `set_order_by`; `result` only sorts when this is `true`.
    has_order_by: bool,
    /// One flag per sort key position: `true` sorts ascending, `false` descending.
    /// Positions without a flag are treated as ascending.
    order_directions: Vec<bool>,
    /// Values collected together with their sort keys.
    ordered_entries: Vec<OrderedEntry>,
    /// Values collected without sort keys, in arrival order.
    values: Vec<Value>,
    /// Created on the first DISTINCT accumulation; `None` until then and after `reset`.
    distinct_tracker: Option<DistinctTracker>,
}
impl AggregateFunction for ArrayAggFunction {
    /// SQL name of the aggregate.
    fn name(&self) -> &str {
        "ARRAY_AGG"
    }

    /// Metadata for the function registry: an aggregate over one `Any`
    /// argument that returns JSON.
    fn info(&self) -> FunctionInfo {
        FunctionInfo::new(
            "ARRAY_AGG",
            FunctionType::Aggregate,
            "Collects all values into a JSON array",
            // NOTE(review): the trailing `1, 1` are presumably the minimum and
            // maximum argument counts - confirm against `FunctionSignature::new`.
            FunctionSignature::new(FunctionDataType::Json, vec![FunctionDataType::Any], 1, 1),
        )
    }

    /// `ARRAY_AGG` has no options; the call is accepted and ignored.
    fn configure(&mut self, _options: &[Value]) {}

    /// Stores the per-key sort directions (`true` = ascending) and enables
    /// sorting in [`result`](Self::result).
    fn set_order_by(&mut self, directions: Vec<bool>) {
        self.order_directions = directions;
        self.has_order_by = true;
    }

    /// This aggregate honours an ORDER BY clause.
    fn supports_order_by(&self) -> bool {
        true
    }

    /// Collects `value` in arrival order.
    ///
    /// NULL values are skipped. When `distinct` is set, a value is also
    /// skipped if the distinct tracker's `check_and_add` returns `false`.
    fn accumulate(&mut self, value: &Value, distinct: bool) {
        if value.is_null() {
            return;
        }
        // The tracker is only allocated once a DISTINCT accumulation happens.
        if distinct
            && !self
                .distinct_tracker
                .get_or_insert_with(DistinctTracker::default)
                .check_and_add(value)
        {
            return;
        }
        self.values.push(value.clone());
    }

    /// Collects `value` together with the ORDER BY keys it should be sorted by.
    ///
    /// Applies the same NULL and DISTINCT filtering as
    /// [`accumulate`](Self::accumulate).
    fn accumulate_with_sort_key(&mut self, value: &Value, sort_keys: Vec<Value>, distinct: bool) {
        if value.is_null() {
            return;
        }
        if distinct
            && !self
                .distinct_tracker
                .get_or_insert_with(DistinctTracker::default)
                .check_and_add(value)
        {
            return;
        }
        self.ordered_entries.push(OrderedEntry {
            value: value.clone(),
            sort_keys,
        });
    }

    /// Renders the collected values as a JSON array wrapped in `Value::text`.
    ///
    /// Source selection, in priority order:
    /// 1. ORDER BY configured and keyed entries present: the keyed entries,
    ///    stably sorted by their keys.
    /// 2. Otherwise the plain values, in arrival order.
    /// 3. Otherwise the keyed entries, in arrival order.
    ///
    /// Returns `Value::null_unknown()` when nothing was collected.
    fn result(&self) -> Value {
        let values_to_output: Vec<&Value> = if self.has_order_by
            && !self.ordered_entries.is_empty()
        {
            let mut entries: Vec<&OrderedEntry> = self.ordered_entries.iter().collect();
            let directions = &self.order_directions;
            entries.sort_by(|a, b| {
                // The first key position that differs decides the order.
                for (i, (key_a, key_b)) in a.sort_keys.iter().zip(b.sort_keys.iter()).enumerate() {
                    let cmp = compare_values(key_a, key_b);
                    if cmp != std::cmp::Ordering::Equal {
                        // A missing direction flag defaults to ascending.
                        let is_asc = directions.get(i).copied().unwrap_or(true);
                        return if is_asc { cmp } else { cmp.reverse() };
                    }
                }
                std::cmp::Ordering::Equal
            });
            entries.into_iter().map(|e| &e.value).collect()
        } else if !self.values.is_empty() {
            self.values.iter().collect()
        } else if !self.ordered_entries.is_empty() {
            self.ordered_entries.iter().map(|e| &e.value).collect()
        } else {
            return Value::null_unknown();
        };

        // Every branch above yields at least one element, so no further
        // emptiness check is needed here.
        let json_elements: Vec<String> = values_to_output
            .into_iter()
            .map(render_json_element)
            .collect();
        Value::text(format!("[{}]", json_elements.join(",")))
    }

    /// Drops all collected values and the distinct tracker.
    ///
    /// The ORDER BY configuration set through `set_order_by` is kept.
    fn reset(&mut self) {
        self.values.clear();
        self.ordered_entries.clear();
        self.distinct_tracker = None;
    }

    /// Returns a fresh default instance; neither the collected values nor the
    /// ORDER BY configuration are copied.
    fn clone_box(&self) -> Box<dyn AggregateFunction> {
        Box::new(ArrayAggFunction::default())
    }
}

/// Renders a single value as one JSON array element.
fn render_json_element(value: &Value) -> String {
    match value {
        Value::Text(s) => json_quote(s),
        Value::Integer(i) => i.to_string(),
        Value::Float(f) if f.is_finite() => f.to_string(),
        // JSON has no token for NaN or infinities; emit `null` so the array
        // stays parseable.
        Value::Float(_) => "null".to_string(),
        Value::Boolean(b) => b.to_string(),
        Value::Null(_) => "null".to_string(),
        Value::Timestamp(t) => format!("\"{}\"", t),
        // JSON extension payloads (first byte is the type tag) are embedded
        // verbatim.
        Value::Extension(data)
            if data.first() == Some(&(crate::core::DataType::Json as u8)) =>
        {
            match std::str::from_utf8(&data[1..]) {
                Ok(json) if !json.trim().is_empty() => json.to_string(),
                // An empty or non-UTF-8 payload would leave a hole in the
                // array (`[1,,2]`); fall back to `null`.
                _ => "null".to_string(),
            }
        }
        // Any other extension is emitted as a JSON string of its display form.
        Value::Extension(_) => json_quote(&value.to_string()),
    }
}

/// Wraps `text` in double quotes, escaping everything JSON requires: the
/// quote, the backslash and all control characters below U+0020.
fn json_quote(text: &str) -> String {
    let mut out = String::with_capacity(text.len() + 2);
    out.push('"');
    for ch in text.chars() {
        match ch {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            '\u{08}' => out.push_str("\\b"),
            '\u{0C}' => out.push_str("\\f"),
            c if c < '\u{20}' => out.push_str(&format!("\\u{:04x}", c as u32)),
            c => out.push(c),
        }
    }
    out.push('"');
    out
}
/// Orders two sort-key values.
///
/// * NULL compares greater than every non-NULL value and equal to NULL.
/// * Integers and floats compare numerically; an integer is converted to
///   `f64` when compared against a float.
/// * NaN compares greater than every other number and equal to NaN. Treating
///   it as equal to everything would make the comparison inconsistent, and
///   `slice::sort_by` may panic or return an unspecified order in that case.
/// * Text, booleans and timestamps use their natural ordering.
/// * Any other pairing, including values of different non-numeric types,
///   compares as `Equal`.
fn compare_values(a: &Value, b: &Value) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    match (a, b) {
        (Value::Null(_), Value::Null(_)) => Ordering::Equal,
        (Value::Null(_), _) => Ordering::Greater,
        (_, Value::Null(_)) => Ordering::Less,
        (Value::Integer(a), Value::Integer(b)) => a.cmp(b),
        (Value::Float(a), Value::Float(b)) => compare_floats(*a, *b),
        (Value::Integer(a), Value::Float(b)) => compare_floats(*a as f64, *b),
        (Value::Float(a), Value::Integer(b)) => compare_floats(*a, *b as f64),
        (Value::Text(a), Value::Text(b)) => a.cmp(b),
        (Value::Boolean(a), Value::Boolean(b)) => a.cmp(b),
        (Value::Timestamp(a), Value::Timestamp(b)) => a.cmp(b),
        _ => Ordering::Equal,
    }
}

/// Compares two floats, placing NaN after every other value.
///
/// `-0.0` and `0.0` still compare as equal.
fn compare_floats(x: f64, y: f64) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    x.partial_cmp(&y).unwrap_or_else(|| {
        // `partial_cmp` only fails when at least one side is NaN.
        match (x.is_nan(), y.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            _ => Ordering::Less,
        }
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `inputs` through a fresh aggregate using `accumulate` and returns
    /// its result.
    fn aggregate(inputs: &[Value], distinct: bool) -> Value {
        let mut agg = ArrayAggFunction::default();
        for input in inputs {
            agg.accumulate(input, distinct);
        }
        agg.result()
    }

    #[test]
    fn test_array_agg_basic() {
        let inputs = [Value::Integer(1), Value::Integer(2), Value::Integer(3)];
        assert_eq!(aggregate(&inputs, false), Value::text("[1,2,3]"));
    }

    #[test]
    fn test_array_agg_strings() {
        let inputs = [Value::text("a"), Value::text("b"), Value::text("c")];
        assert_eq!(
            aggregate(&inputs, false),
            Value::text("[\"a\",\"b\",\"c\"]")
        );
    }

    #[test]
    fn test_array_agg_ignores_null() {
        // The NULL in the middle must not show up in the output.
        let inputs = [Value::Integer(1), Value::null_unknown(), Value::Integer(3)];
        assert_eq!(aggregate(&inputs, false), Value::text("[1,3]"));
    }

    #[test]
    fn test_array_agg_distinct() {
        // The second `1` is a duplicate and is dropped.
        let inputs = [
            Value::Integer(1),
            Value::Integer(2),
            Value::Integer(1),
            Value::Integer(3),
        ];
        assert_eq!(aggregate(&inputs, true), Value::text("[1,2,3]"));
    }

    #[test]
    fn test_array_agg_empty() {
        assert!(aggregate(&[], false).is_null());
    }

    #[test]
    fn test_array_agg_mixed_types() {
        let inputs = [
            Value::text("str"),
            Value::Integer(42),
            Value::Float(3.5),
            Value::Boolean(true),
        ];
        assert_eq!(
            aggregate(&inputs, false),
            Value::text("[\"str\",42,3.5,true]")
        );
    }

    #[test]
    fn test_array_agg_reset() {
        let mut agg = ArrayAggFunction::default();
        for n in [1, 2] {
            agg.accumulate(&Value::Integer(n), false);
        }
        agg.reset();
        // After a reset the aggregate behaves as if nothing was collected.
        assert!(agg.result().is_null());
    }
}