use chrono::{DateTime, NaiveDateTime, TimeZone, Utc};
use rust_decimal::Decimal;
use rust_decimal::prelude::FromPrimitive;
use serde_json::{Value, json};
use std::collections::{BTreeMap, HashSet};
use std::str::FromStr;
use super::types::{AggregationStrategy, PostProcessingConfig, TimeGranularity};
/// Groups `rows` by the configured column and optionally aggregates a
/// numeric column per group, producing a `{"grouped": [...]}` payload.
///
/// Returns `Ok(None)` when no `group_by` column is configured, and an
/// error when an aggregation strategy is set without a target column or
/// when a row lacks a required column / holds unparseable data.
#[doc(hidden)]
pub fn apply_post_processing(
    rows: &[Value],
    config: &PostProcessingConfig,
) -> Result<Option<Value>, String> {
    let group_column = match config.group_by.as_ref() {
        Some(column) => column,
        // Nothing to do without a grouping column.
        None => return Ok(None),
    };
    if config.aggregation_strategy.is_some() && config.aggregation_column.is_none() {
        return Err("aggregation_strategy requires aggregation_column".to_string());
    }
    // BTreeMap keeps bucket labels sorted, which the cumulative-sum pass
    // below relies on for a stable running total.
    let mut by_label: BTreeMap<String, Vec<Value>> = BTreeMap::new();
    for row in rows {
        let key = extract_group_label(row, group_column, config.time_granularity.as_ref())?;
        by_label.entry(key).or_default().push(row.clone());
    }
    let mut cumulative = Decimal::ZERO;
    let mut output: Vec<Value> = Vec::with_capacity(by_label.len());
    for (label, members) in by_label {
        // Per-bucket sum is only computed when an aggregation column is set.
        let bucket_sum = match config.aggregation_column.as_ref() {
            Some(column) => Some(sum_bucket(&members, column, config.dedup_aggregation)?),
            None => None,
        };
        let aggregated = bucket_sum.map(|sum| match config.aggregation_strategy.as_ref() {
            Some(AggregationStrategy::CumulativeSum) => {
                cumulative += sum;
                cumulative
            }
            _ => sum,
        });
        // Decimal is serialized as a string so JSON floats can't lose precision.
        output.push(json!({
            "label": label,
            "rows": members,
            "aggregation": aggregated.map(|value| Value::String(value.to_string()))
        }));
    }
    Ok(Some(json!({ "grouped": output })))
}
/// Derives the grouping label for `row` from `column`.
///
/// With a time granularity, the column is parsed as a datetime and
/// formatted by that granularity; otherwise string values are used
/// verbatim and any other JSON value falls back to its textual rendering.
fn extract_group_label(
    row: &Value,
    column: &str,
    granularity: Option<&TimeGranularity>,
) -> Result<String, String> {
    let cell = match row.get(column) {
        Some(value) => value,
        None => return Err(format!("Missing group_by column '{}'", column)),
    };
    match granularity {
        Some(gran) => {
            let parsed: DateTime<Utc> = parse_datetime_value(cell)?;
            Ok(gran.format_label(parsed))
        }
        None => match cell.as_str() {
            Some(text) => Ok(text.to_string()),
            // Non-string scalars (numbers, bools, ...) use their JSON rendering.
            None => Ok(cell.to_string()),
        },
    }
}
/// Parses a JSON value into a UTC timestamp.
///
/// Accepts:
/// - `Value::String`: RFC 3339 first, then the naive `"%Y-%m-%d %H:%M:%S"`
///   format, which is interpreted as UTC.
/// - `Value::Number`: a unix timestamp in seconds. Integers convert
///   exactly; floats keep their sub-second component.
///
/// Returns a descriptive error for any other JSON type or unparseable input.
fn parse_datetime_value(value: &Value) -> Result<DateTime<Utc>, String> {
    match value {
        Value::String(text) => DateTime::parse_from_rfc3339(text)
            .map(|dt| dt.with_timezone(&Utc))
            .or_else(|_| {
                NaiveDateTime::parse_from_str(text, "%Y-%m-%d %H:%M:%S")
                    // `and_utc` replaces the deprecated `DateTime::<Utc>::from_utc`.
                    .map(|naive| naive.and_utc())
            })
            .map_err(|_| format!("Failed to parse datetime string: '{}'", text)),
        Value::Number(number) => {
            // Integer seconds convert directly, avoiding the f64 round-trip
            // that silently loses precision for magnitudes beyond 2^53.
            if let Some(secs) = number.as_i64() {
                return Utc
                    .timestamp_opt(secs, 0)
                    .single()
                    .ok_or_else(|| format!("Invalid unix timestamp: {}", secs));
            }
            let Some(timestamp) = number.as_f64() else {
                return Err("Numeric timestamp must be integer or float".to_string());
            };
            // Split on floor (not trunc) so negative fractional timestamps
            // yield a non-negative nanosecond part; the previous
            // `fract() as u32` saturated negative fractions to 0, shifting
            // e.g. -1.5 to -1.0 seconds.
            let secs: i64 = timestamp.floor() as i64;
            let nanos: u32 =
                (((timestamp - timestamp.floor()) * 1_000_000_000f64) as u32).min(999_999_999);
            Utc.timestamp_opt(secs, nanos)
                .single()
                .ok_or_else(|| format!("Invalid unix timestamp: {}", timestamp))
        }
        _ => {
            Err("Value must be a string (RFC3339/datetime) or number (unix timestamp)".to_string())
        }
    }
}
/// Sums the numeric values of `column` across `rows`.
///
/// When `dedup` is true, each distinct value contributes only once
/// (repeated rows carrying the same amount are counted a single time).
///
/// Errors if any row lacks the column or holds non-numeric data.
fn sum_bucket(rows: &[Value], column: &str, dedup: bool) -> Result<Decimal, String> {
    let mut total: Decimal = Decimal::ZERO;
    // Values already counted; only consulted when `dedup` is set.
    let mut seen: HashSet<Decimal> = HashSet::new();
    for row in rows {
        let value: &Value = row
            .get(column)
            .ok_or_else(|| format!("Missing aggregation column '{}'", column))?;
        let decimal: Decimal = value_to_decimal(value)
            .ok_or_else(|| format!("Aggregation column '{}' contains non-numeric data", column))?;
        // `insert` returns false for duplicates, collapsing the previous
        // contains-then-insert pattern into a single hash lookup.
        if dedup && !seen.insert(decimal) {
            continue;
        }
        total += decimal;
    }
    Ok(total)
}
/// Converts a JSON value to a `Decimal`, if it represents a number.
///
/// Integer JSON numbers (signed, then unsigned for values above
/// `i64::MAX`) convert exactly; floats go through `Decimal::from_f64`;
/// strings are parsed with `Decimal::from_str`. Anything else is `None`.
fn value_to_decimal(value: &Value) -> Option<Decimal> {
    match value {
        Value::Number(num) => num
            .as_i64()
            .map(Decimal::from)
            .or_else(|| num.as_u64().map(Decimal::from))
            .or_else(|| num.as_f64().and_then(Decimal::from_f64)),
        Value::String(text) => Decimal::from_str(text).ok(),
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A strategy without a target column must be rejected up front.
    #[test]
    fn apply_post_processing_requires_aggregation_column_when_strategy_set() {
        let input = vec![json!({"created_at": "2024-01-01T00:00:00Z", "value": 1})];
        let settings = PostProcessingConfig {
            group_by: Some("created_at".to_string()),
            time_granularity: None,
            aggregation_column: None,
            aggregation_strategy: Some(AggregationStrategy::CumulativeSum),
            dedup_aggregation: false,
        };
        assert!(apply_post_processing(&input, &settings).is_err());
    }
}