//! athena_rs 2.9.1 — Database gateway API documentation.
//!
//! Grouping and aggregation applied to fetched gateway rows.

use chrono::{DateTime, NaiveDateTime, TimeZone, Utc};
use rust_decimal::Decimal;
use rust_decimal::prelude::FromPrimitive;
use serde_json::{Value, json};
use std::collections::{BTreeMap, HashSet};
use std::str::FromStr;

use super::types::{AggregationStrategy, PostProcessingConfig, TimeGranularity};

/// Applies grouping and aggregation rules from `PostProcessingConfig`.
///
/// Returns `Ok(None)` when the config has no `group_by` column (post-processing
/// is a no-op), otherwise `Ok(Some(json))` with a `"grouped"` array of
/// `{label, rows, aggregation}` objects ordered by label. The `aggregation`
/// field is a decimal rendered as a JSON string, or `null` when no
/// `aggregation_column` is configured.
///
/// # Errors
/// Returns `Err` when an `aggregation_strategy` is set without an
/// `aggregation_column`, when a row is missing the group-by or aggregation
/// column, when a time-granularity label cannot be parsed as a datetime, or
/// when aggregation data is non-numeric.
#[doc(hidden)]
pub fn apply_post_processing(
    rows: &[Value],
    config: &PostProcessingConfig,
) -> Result<Option<Value>, String> {
    // No group_by column means there is nothing to do.
    let group_by_column: &str = match config.group_by.as_deref() {
        Some(column) => column,
        None => return Ok(None),
    };

    if config.aggregation_strategy.is_some() && config.aggregation_column.is_none() {
        return Err("aggregation_strategy requires aggregation_column".to_string());
    }

    // BTreeMap keeps buckets sorted by label; this ordering also fixes the
    // order in which the cumulative sum is accumulated.
    let mut buckets: BTreeMap<String, Vec<Value>> = BTreeMap::new();
    for row in rows {
        let label: String =
            extract_group_label(row, group_by_column, config.time_granularity.as_ref())?;
        buckets.entry(label).or_default().push(row.clone());
    }

    let mut grouped: Vec<Value> = Vec::with_capacity(buckets.len());
    let mut running_total: Decimal = Decimal::ZERO;
    for (label, bucket_rows) in buckets {
        // Per-bucket sum of the aggregation column, if one is configured.
        let base_sum: Option<Decimal> = config
            .aggregation_column
            .as_deref()
            .map(|column| sum_bucket(&bucket_rows, column, config.dedup_aggregation))
            .transpose()?;
        let aggregation_value: Option<Decimal> =
            match (config.aggregation_strategy.as_ref(), base_sum) {
                (Some(AggregationStrategy::CumulativeSum), Some(sum)) => {
                    running_total += sum;
                    Some(running_total)
                }
                (_, Some(sum)) => Some(sum),
                _ => None,
            };
        // Serialize the decimal as a string to avoid JSON float precision loss.
        let aggregation_payload: Option<Value> =
            aggregation_value.map(|value| Value::String(value.to_string()));
        grouped.push(json!({
            "label": label,
            "rows": bucket_rows,
            "aggregation": aggregation_payload
        }));
    }

    Ok(Some(json!({ "grouped": grouped })))
}

/// Resolves the bucket label for one row.
///
/// With a time granularity, the column value is parsed as a datetime and
/// formatted via `TimeGranularity::format_label`; otherwise string values are
/// used verbatim and any other JSON value falls back to its JSON rendering.
fn extract_group_label(
    row: &Value,
    column: &str,
    granularity: Option<&TimeGranularity>,
) -> Result<String, String> {
    let cell = row
        .get(column)
        .ok_or_else(|| format!("Missing group_by column '{}'", column))?;

    match granularity {
        Some(gran) => parse_datetime_value(cell).map(|datetime| gran.format_label(datetime)),
        None => Ok(match cell.as_str() {
            Some(text) => text.to_string(),
            None => cell.to_string(),
        }),
    }
}

/// Parses a JSON value into a UTC datetime.
///
/// Accepts RFC3339 strings, `"%Y-%m-%d %H:%M:%S"` strings (interpreted as
/// UTC), and numeric unix timestamps (integer or fractional seconds).
///
/// # Errors
/// Returns a descriptive message when the value is neither a parseable
/// datetime string nor a representable unix timestamp.
fn parse_datetime_value(value: &Value) -> Result<DateTime<Utc>, String> {
    match value {
        Value::String(text) => DateTime::parse_from_rfc3339(text)
            .map(|dt| dt.with_timezone(&Utc))
            .or_else(|_| {
                // Fall back to a bare datetime format, assumed to be UTC.
                // `from_utc_datetime` replaces the deprecated
                // `DateTime::from_utc`, so the allow(deprecated) is gone.
                NaiveDateTime::parse_from_str(text, "%Y-%m-%d %H:%M:%S")
                    .map(|naive| Utc.from_utc_datetime(&naive))
            })
            .map_err(|_| format!("Failed to parse datetime string: '{}'", text)),
        Value::Number(number) => {
            // Exact path for integer seconds: round-tripping through f64
            // loses precision above 2^53.
            if let Some(secs) = number.as_i64() {
                return Utc
                    .timestamp_opt(secs, 0)
                    .single()
                    .ok_or_else(|| format!("Invalid unix timestamp: {}", secs));
            }
            let timestamp: f64 = number
                .as_f64()
                .ok_or_else(|| "Numeric timestamp must be integer or float".to_string())?;

            // Split on the *floor* so negative fractional timestamps keep a
            // non-negative nanosecond component. The previous trunc/fract
            // split produced a negative fraction for pre-epoch values, which
            // the u32 cast saturated to 0 — silently dropping the fraction.
            let secs: i64 = timestamp.floor() as i64;
            let nanos: u32 = ((timestamp - timestamp.floor()) * 1_000_000_000f64) as u32;
            Utc.timestamp_opt(secs, nanos)
                .single()
                .ok_or_else(|| format!("Invalid unix timestamp: {}", timestamp))
        }
        _ => {
            Err("Value must be a string (RFC3339/datetime) or number (unix timestamp)".to_string())
        }
    }
}

/// Sums the named column over every row in a bucket.
///
/// With `dedup` enabled, a value that has already been counted is skipped,
/// so each distinct decimal contributes to the total at most once.
///
/// # Errors
/// Fails when a row lacks the column or holds a non-numeric value.
fn sum_bucket(rows: &[Value], column: &str, dedup: bool) -> Result<Decimal, String> {
    let mut seen: HashSet<Decimal> = HashSet::new();
    let mut total = Decimal::ZERO;

    for row in rows {
        let cell = row
            .get(column)
            .ok_or_else(|| format!("Missing aggregation column '{}'", column))?;
        let amount = value_to_decimal(cell)
            .ok_or_else(|| format!("Aggregation column '{}' contains non-numeric data", column))?;

        // `insert` reports whether the value is new; when deduping, repeats
        // are dropped here. Short-circuiting keeps `seen` untouched otherwise.
        if !dedup || seen.insert(amount) {
            total += amount;
        }
    }

    Ok(total)
}

/// Converts a JSON number or numeric string into a `Decimal`.
///
/// Integer representations are tried first (i64, then u64) before falling
/// back to f64; any other JSON value yields `None`.
fn value_to_decimal(value: &Value) -> Option<Decimal> {
    match value {
        Value::Number(num) => num
            .as_i64()
            .map(Decimal::from)
            .or_else(|| num.as_u64().map(Decimal::from))
            .or_else(|| num.as_f64().and_then(Decimal::from_f64)),
        Value::String(text) => Decimal::from_str(text).ok(),
        _ => None,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A strategy without a target column must be rejected up front.
    #[test]
    fn apply_post_processing_requires_aggregation_column_when_strategy_set() {
        let input = vec![json!({"created_at": "2024-01-01T00:00:00Z", "value": 1})];
        let config = PostProcessingConfig {
            group_by: Some("created_at".to_string()),
            time_granularity: None,
            aggregation_column: None,
            aggregation_strategy: Some(AggregationStrategy::CumulativeSum),
            dedup_aggregation: false,
        };
        assert!(apply_post_processing(&input, &config).is_err());
    }
}