use chrono::{DateTime, Utc};
use chrono::{NaiveDateTime, TimeZone};
use rust_decimal::Decimal;
use rust_decimal::prelude::FromPrimitive;
use serde::{Deserialize, Serialize};
use serde_json::{Number, Value, json};
use std::collections::{BTreeMap, HashSet};
use std::fmt;
use std::str::FromStr;
use crate::{GatewayRequestCondition, normalize_column_name};
/// Requested ordering for a result set: the column to sort on and the direction.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SortOptions {
/// Column to sort by. `parse_sort_options_from_body` stores the normalized column name here.
pub column: String,
/// `true` for ascending order; `false` otherwise.
pub ascending: bool,
}
/// Bucket width used when grouping rows by a timestamp column.
///
/// Serialized and deserialized with snake_case variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TimeGranularity {
/// One bucket per calendar day (label format `%Y-%m-%d`).
Day,
/// One bucket per hour (label format `%Y-%m-%d %H:00`).
Hour,
/// One bucket per minute (label format `%Y-%m-%d %H:%M`).
Minute,
}
/// Parses a granularity name, ignoring ASCII case.
///
/// Recognizes `day`, `hour` and `minute`. Any other input produces an error
/// message that echoes the lowercased input.
impl FromStr for TimeGranularity {
    type Err = String;

    fn from_str(value: &str) -> Result<Self, Self::Err> {
        let lowered = value.to_ascii_lowercase();
        let granularity = match lowered.as_str() {
            "day" => Self::Day,
            "hour" => Self::Hour,
            "minute" => Self::Minute,
            _ => return Err(format!("Unsupported time_granularity '{lowered}'")),
        };
        Ok(granularity)
    }
}
impl TimeGranularity {
    /// Renders the bucket label for `datetime` at this granularity.
    ///
    /// The label keeps only the components that matter for the bucket: the date
    /// for `Day`, date plus hour (minutes fixed to `00`) for `Hour`, and date plus
    /// hour and minute for `Minute`.
    pub fn format_label(&self, datetime: DateTime<Utc>) -> String {
        let pattern = match self {
            Self::Day => "%Y-%m-%d",
            Self::Hour => "%Y-%m-%d %H:00",
            Self::Minute => "%Y-%m-%d %H:%M",
        };
        datetime.format(pattern).to_string()
    }
}
/// How per-group sums are combined when post-processing grouped rows.
///
/// Serialized and deserialized with snake_case variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AggregationStrategy {
/// Each group reports the running total of its own sum plus the sums of all earlier groups.
CumulativeSum,
}
/// Parses an aggregation strategy name, ignoring ASCII case.
///
/// Only `cumulative_sum` is recognized. Any other input produces an error
/// message that echoes the lowercased input.
impl FromStr for AggregationStrategy {
    type Err = String;

    fn from_str(value: &str) -> Result<Self, Self::Err> {
        let lowered = value.to_ascii_lowercase();
        if lowered == "cumulative_sum" {
            Ok(Self::CumulativeSum)
        } else {
            Err(format!("Unsupported aggregation_strategy '{lowered}'"))
        }
    }
}
/// Options controlling how fetched rows are grouped and aggregated after retrieval.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PostProcessingConfig {
/// Column whose value determines each row's group. When `None`, `apply_post_processing` does nothing.
pub group_by: Option<String>,
/// When set, the `group_by` value is parsed as a datetime and bucketed at this granularity.
pub time_granularity: Option<TimeGranularity>,
/// Column whose numeric values are summed within each group.
pub aggregation_column: Option<String>,
/// How per-group sums are combined across groups; requires `aggregation_column`.
pub aggregation_strategy: Option<AggregationStrategy>,
/// When `true`, repeated values of the aggregation column within a group are summed only once.
pub dedup_aggregation: bool,
}
impl PostProcessingConfig {
    /// Builds a configuration from an optional JSON request body.
    ///
    /// Reads the string keys `group_by`, `aggregation_column`, `time_granularity`
    /// and `aggregation_strategy`, plus the boolean key `aggregation_dedup`.
    /// Column names go through `normalize_column_name` with `force_snake`.
    /// Missing keys, keys of the wrong JSON type, and unrecognized
    /// granularity/strategy names all resolve to `None` (or `false` for the
    /// dedup flag) rather than an error.
    pub fn from_body(body: Option<&Value>, force_snake: bool) -> Self {
        /// Looks up `key` in the body and returns it only if it is a JSON string.
        fn text_field<'a>(body: Option<&'a Value>, key: &str) -> Option<&'a str> {
            body?.get(key)?.as_str()
        }

        let column_field = |key: &str| {
            text_field(body, key).map(|raw| normalize_column_name(raw, force_snake))
        };

        let group_by = column_field("group_by");
        let aggregation_column = column_field("aggregation_column");
        let time_granularity = text_field(body, "time_granularity")
            .and_then(|raw| raw.parse::<TimeGranularity>().ok());
        let aggregation_strategy = text_field(body, "aggregation_strategy")
            .and_then(|raw| raw.parse::<AggregationStrategy>().ok());
        let dedup_aggregation = body
            .and_then(|payload| payload.get("aggregation_dedup"))
            .and_then(Value::as_bool)
            .unwrap_or(false);

        Self {
            group_by,
            time_granularity,
            aggregation_column,
            aggregation_strategy,
            dedup_aggregation,
        }
    }
}
/// Error describing an invalid fetch condition in a gateway request.
///
/// Wraps a human-readable message, which is also what `Display` prints.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GatewayFetchConditionError {
    message: String,
}

impl GatewayFetchConditionError {
    /// Creates an error from anything convertible into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }

    /// Returns the error message as a string slice.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }
}

impl fmt::Display for GatewayFetchConditionError {
    /// Writes the raw message; formatter padding and width flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.message())
    }
}

impl std::error::Error for GatewayFetchConditionError {}
/// Groups `rows` by the configured column and optionally aggregates each group.
///
/// Returns `Ok(None)` when `config.group_by` is unset. Otherwise returns
/// `{"grouped": [...]}`, where each entry holds the group `label`, its `rows`,
/// and an `aggregation` that is the decimal rendered as a string, or `null`
/// when no aggregation column is configured. Groups are emitted in ascending
/// label order (string ordering of the label).
///
/// With `AggregationStrategy::CumulativeSum`, each group's aggregation is the
/// running total of all group sums up to and including that group.
///
/// # Errors
///
/// * an aggregation strategy is set without an aggregation column;
/// * a row lacks the group-by or aggregation column, or holds a value that
///   cannot be parsed as a datetime / number;
/// * the cumulative total overflows the decimal range (reported as an error
///   instead of panicking).
pub fn apply_post_processing(
    rows: &[Value],
    config: &PostProcessingConfig,
) -> Result<Option<Value>, String> {
    let Some(group_by_column) = config.group_by.as_deref() else {
        return Ok(None);
    };
    if config.aggregation_strategy.is_some() && config.aggregation_column.is_none() {
        return Err("aggregation_strategy requires aggregation_column".to_string());
    }

    // A BTreeMap keeps groups ordered by label, which the cumulative sum relies on.
    let mut buckets: BTreeMap<String, Vec<Value>> = BTreeMap::new();
    for row in rows {
        let label = extract_group_label(row, group_by_column, config.time_granularity.as_ref())?;
        buckets.entry(label).or_default().push(row.clone());
    }

    let mut grouped = Vec::with_capacity(buckets.len());
    let mut running_total = Decimal::ZERO;
    for (label, bucket_rows) in buckets {
        let base_sum = config
            .aggregation_column
            .as_deref()
            .map(|column| sum_bucket(&bucket_rows, column, config.dedup_aggregation))
            .transpose()?;

        let aggregation_value = match (config.aggregation_strategy, base_sum) {
            (Some(AggregationStrategy::CumulativeSum), Some(sum)) => {
                // `Decimal`'s `+=` panics on overflow; the data is request-derived,
                // so surface overflow as an error instead.
                running_total = running_total
                    .checked_add(sum)
                    .ok_or_else(|| format!("Cumulative sum overflowed at group '{label}'"))?;
                Some(running_total)
            }
            (None, Some(sum)) => Some(sum),
            (_, None) => None,
        };

        // Decimals are emitted as strings to avoid precision loss in JSON numbers.
        let aggregation_payload = aggregation_value.map(|value| Value::String(value.to_string()));
        grouped.push(json!({
            "label": label,
            "rows": bucket_rows,
            "aggregation": aggregation_payload
        }));
    }

    Ok(Some(json!({ "grouped": grouped })))
}
/// Extracts sort options from an optional JSON request body.
///
/// Looks for an object under `sort_by` (falling back to `sortBy` only when
/// `sort_by` is absent). Inside it, the column name is read from `field`
/// (falling back to `column` only when `field` is absent) and normalized with
/// `normalize_column_name`. The optional `direction` string is compared
/// case-insensitively: `asc` / `ascending` mean ascending, any other string
/// means descending, and a missing or non-string direction defaults to
/// ascending.
///
/// Returns `None` when the body, the sort object, or a string column name is
/// missing, or when the normalized column name is empty.
pub fn parse_sort_options_from_body(
    body: Option<&Value>,
    force_snake: bool,
) -> Option<SortOptions> {
    let payload = body?;
    let sort_spec = match payload.get("sort_by") {
        Some(found) => found,
        None => payload.get("sortBy")?,
    }
    .as_object()?;

    let requested = match sort_spec.get("field") {
        Some(found) => found,
        None => sort_spec.get("column")?,
    }
    .as_str()?;

    let column = normalize_column_name(requested, force_snake);
    if column.is_empty() {
        return None;
    }

    let ascending = match sort_spec.get("direction").and_then(Value::as_str) {
        Some(direction) => {
            direction.eq_ignore_ascii_case("asc") || direction.eq_ignore_ascii_case("ascending")
        }
        None => true,
    };

    Some(SortOptions { column, ascending })
}
fn extract_group_label(
row: &Value,
column: &str,
granularity: Option<&TimeGranularity>,
) -> Result<String, String> {
let value = row
.get(column)
.ok_or_else(|| format!("Missing group_by column '{column}'"))?;
if let Some(granularity) = granularity {
let datetime = parse_datetime_value(value)?;
return Ok(granularity.format_label(datetime));
}
if let Some(text) = value.as_str() {
return Ok(text.to_string());
}
Ok(value.to_string())
}
/// Converts a JSON value into a UTC timestamp.
///
/// Accepted inputs:
/// * strings in RFC 3339 form, or in the naive `%Y-%m-%d %H:%M:%S` form, which
///   is interpreted as UTC;
/// * numbers, interpreted as seconds since the Unix epoch. Integers are used
///   exactly; floats contribute their fractional part as sub-second precision.
///
/// # Errors
///
/// Returns a message when the string matches neither format, the number is not
/// finite or falls outside the representable date range, or the value is
/// neither a string nor a number.
fn parse_datetime_value(value: &Value) -> Result<DateTime<Utc>, String> {
    match value {
        Value::String(text) => DateTime::parse_from_rfc3339(text)
            .map(|dt| dt.with_timezone(&Utc))
            .or_else(|_| {
                NaiveDateTime::parse_from_str(text, "%Y-%m-%d %H:%M:%S")
                    .map(|naive| Utc.from_utc_datetime(&naive))
            })
            .map_err(|_| format!("Failed to parse datetime string: '{text}'")),
        Value::Number(number) => {
            // Integers need no float round-trip, which would lose precision.
            if let Some(secs) = number.as_i64() {
                return Utc
                    .timestamp_opt(secs, 0)
                    .single()
                    .ok_or_else(|| format!("Invalid unix timestamp: {secs}"));
            }
            let Some(timestamp) = number.as_f64() else {
                return Err("Numeric timestamp must be integer or float".to_string());
            };
            if !timestamp.is_finite() {
                return Err(format!("Invalid unix timestamp: {timestamp}"));
            }
            // Split with `floor` so the sub-second part is always non-negative:
            // -1.5 becomes (-2 s, +0.5 s). Truncating toward zero would yield a
            // negative fraction that the unsigned cast silently clamps to 0.
            let whole = timestamp.floor();
            // Out-of-range floats saturate in the cast and are then rejected by
            // `timestamp_opt`.
            let secs = whole as i64;
            // Guard against float rounding pushing the fraction up to a full second.
            let nanos = (((timestamp - whole) * 1_000_000_000f64) as u32).min(999_999_999);
            Utc.timestamp_opt(secs, nanos)
                .single()
                .ok_or_else(|| format!("Invalid unix timestamp: {timestamp}"))
        }
        _ => {
            Err("Value must be a string (RFC3339/datetime) or number (unix timestamp)".to_string())
        }
    }
}
/// Builds the cache key for a fetch request, keeping the first 16 characters
/// of the request digest as the key's hash suffix.
///
/// All arguments are forwarded unchanged to the shared key builder.
pub fn build_fetch_hashed_cache_key(
    table_name: &str,
    conditions: &[GatewayRequestCondition],
    columns_vec: &[String],
    limit: i64,
    strip_nulls: bool,
    client_name: &str,
    sort_options: Option<&SortOptions>,
) -> String {
    const HASH_LEN: usize = 16;
    build_fetch_hashed_cache_key_with_hash_len(
        &GatewayFetchCacheKeyInput {
            table_name,
            conditions,
            columns_vec,
            limit,
            strip_nulls,
            client_name,
            sort_options,
        },
        HASH_LEN,
    )
}
/// Builds the cache key for a fetch request in the legacy layout, keeping only
/// the first 8 characters of the request digest as the key's hash suffix.
///
/// All arguments are forwarded unchanged to the shared key builder.
pub fn build_fetch_hashed_cache_key_legacy8(
    table_name: &str,
    conditions: &[GatewayRequestCondition],
    columns_vec: &[String],
    limit: i64,
    strip_nulls: bool,
    client_name: &str,
    sort_options: Option<&SortOptions>,
) -> String {
    const LEGACY_HASH_LEN: usize = 8;
    build_fetch_hashed_cache_key_with_hash_len(
        &GatewayFetchCacheKeyInput {
            table_name,
            conditions,
            columns_vec,
            limit,
            strip_nulls,
            client_name,
            sort_options,
        },
        LEGACY_HASH_LEN,
    )
}
/// Borrowed bundle of everything that feeds into a fetch cache key.
struct GatewayFetchCacheKeyInput<'a> {
/// Table being fetched; used as the key prefix.
table_name: &'a str,
/// Equality conditions of the request; hashed in a sorted order.
conditions: &'a [GatewayRequestCondition],
/// Requested columns; joined with `,` in the key and included in the hash.
columns_vec: &'a [String],
/// Row limit; appears in the key and in the hash.
limit: i64,
/// Strip-nulls flag of the request; appears in the key and in the hash.
strip_nulls: bool,
/// Client name; included in the hash only.
client_name: &'a str,
/// Optional sort options; included in the hash only.
sort_options: Option<&'a SortOptions>,
}
/// Builds a cache key of the form
/// `table:first_eq_column:columns:limit:strip_nulls:hash`.
///
/// Conditions are ordered by column name and then by the JSON text of their
/// value, so the key does not depend on the order in which the caller listed
/// them. `first_eq_column` is the first column after that ordering, or `_`
/// when there are no conditions. `hash` is the first `hash_len` characters of
/// the SHA-256 digest of a JSON document describing the columns, ordered
/// conditions, limit, strip-nulls flag, client name and sort options.
fn build_fetch_hashed_cache_key_with_hash_len(
    input: &GatewayFetchCacheKeyInput<'_>,
    hash_len: usize,
) -> String {
    // Keep the serialized value next to each condition: it is the tie-breaker
    // for conditions that share a column name.
    let mut ordered: Vec<(&str, &Value, String)> = input
        .conditions
        .iter()
        .map(|condition| {
            let serialized = serde_json::to_string(&condition.eq_value).unwrap_or_default();
            (condition.eq_column.as_str(), &condition.eq_value, serialized)
        })
        .collect();
    ordered.sort_by(|left, right| left.0.cmp(right.0).then_with(|| left.2.cmp(&right.2)));

    let first_eq_column = ordered.first().map_or("_", |entry| entry.0);

    let condition_entries: Vec<Value> = ordered
        .iter()
        .map(|(eq_column, eq_value, _)| {
            json!({
                "eq_column": eq_column,
                "eq_value": eq_value
            })
        })
        .collect();
    let sort_entry = input
        .sort_options
        .map(|s| json!({"column": s.column, "ascending": s.ascending}));

    let hash_input = json!({
        "columns": input.columns_vec,
        "conditions": condition_entries,
        "limit": input.limit,
        "strip_nulls": input.strip_nulls,
        "client": input.client_name,
        "sort": sort_entry,
    });

    let digest = sha256::digest(serde_json::to_string(&hash_input).unwrap_or_default());
    let keep = hash_len.min(digest.len());
    let short_hash = &digest[..keep];

    let table = input.table_name;
    let columns = input.columns_vec.join(",");
    let limit = input.limit;
    let strip_nulls = input.strip_nulls;
    format!("{table}:{first_eq_column}:{columns}:{limit}:{strip_nulls}:{short_hash}")
}
/// Sums the `column` values of `rows` as decimals.
///
/// When `dedup` is `true`, a value that already occurred earlier in the same
/// bucket is skipped, so each distinct value contributes once.
///
/// # Errors
///
/// Fails when a row lacks `column`, when a value cannot be converted to a
/// decimal, or when the sum overflows the decimal range (reported as an error
/// instead of panicking).
fn sum_bucket(rows: &[Value], column: &str, dedup: bool) -> Result<Decimal, String> {
    let mut total = Decimal::ZERO;
    let mut seen = HashSet::new();
    for row in rows {
        let value = row
            .get(column)
            .ok_or_else(|| format!("Missing aggregation column '{column}'"))?;
        let decimal = value_to_decimal(value)
            .ok_or_else(|| format!("Aggregation column '{column}' contains non-numeric data"))?;
        // `insert` returns `false` when the value was already present.
        if dedup && !seen.insert(decimal) {
            continue;
        }
        // `Decimal`'s `+=` panics on overflow; the data is request-derived, so
        // surface overflow as an error instead.
        total = total
            .checked_add(decimal)
            .ok_or_else(|| format!("Aggregation column '{column}' overflowed while summing"))?;
    }
    Ok(total)
}
/// Converts a JSON number or numeric string into a `Decimal`.
///
/// Numbers are tried as `i64`, then `u64`, then `f64`; strings are parsed with
/// `Decimal::from_str`. Returns `None` for any other JSON type or when the
/// conversion fails.
fn value_to_decimal(value: &Value) -> Option<Decimal> {
    match value {
        Value::Number(num) => num
            .as_i64()
            .map(Decimal::from)
            .or_else(|| num.as_u64().map(Decimal::from))
            .or_else(|| num.as_f64().and_then(Decimal::from_f64)),
        Value::String(text) => Decimal::from_str(text).ok(),
        _ => None,
    }
}
/// Parses a room id from a JSON value.
///
/// JSON numbers must be representable as `i64`. JSON strings are trimmed and
/// then parsed as `i64`.
///
/// # Errors
///
/// * a number that is not an `i64` → "room_id must be an integer";
/// * the string `*` (after trimming) → wildcard rejection;
/// * an empty or whitespace-only string → "room_id must not be empty";
/// * any other unparsable string or any other JSON type → "room_id must be numeric".
pub fn parse_room_id_value(value: &Value) -> Result<i64, GatewayFetchConditionError> {
    let candidate = match value {
        Value::Number(num) => {
            return num
                .as_i64()
                .ok_or_else(|| GatewayFetchConditionError::new("room_id must be an integer"));
        }
        Value::String(text) => text.trim(),
        _ => return Err(GatewayFetchConditionError::new("room_id must be numeric")),
    };

    match candidate {
        "*" => Err(GatewayFetchConditionError::new(
            "room_id wildcard '*' is not allowed",
        )),
        "" => Err(GatewayFetchConditionError::new("room_id must not be empty")),
        digits => digits
            .parse::<i64>()
            .map_err(|_| GatewayFetchConditionError::new("room_id must be numeric")),
    }
}
/// Validates a raw room-id condition value and returns it as a JSON integer.
///
/// # Errors
///
/// Propagates the error from `parse_room_id_value` when the value is not a
/// valid room id.
pub fn coerce_room_id_eq_value(eq_value_raw: &Value) -> Result<Value, GatewayFetchConditionError> {
    let room_id = parse_room_id_value(eq_value_raw)?;
    Ok(Value::Number(Number::from(room_id)))
}
/// Extracts equality conditions from the `conditions` array of a request body.
///
/// Entries without a string `eq_column` are skipped, as are entries without an
/// `eq_value`, unless the column is the room id. A column counts as the room id
/// when it normalizes to `room_id` or is literally `roomId`; such a condition
/// must carry a value, which is coerced to a JSON integer. The column name is
/// stored as it was supplied, not in its normalized form.
///
/// Returns an empty list when `conditions` is missing or not an array.
///
/// # Errors
///
/// Fails when a room-id condition has no value or its value is not a valid
/// room id.
pub fn parse_gateway_fetch_conditions(
    json_body: &Value,
    force_camel_case_to_snake_case: bool,
) -> Result<Vec<GatewayRequestCondition>, GatewayFetchConditionError> {
    let Some(raw_conditions) = json_body.get("conditions").and_then(Value::as_array) else {
        return Ok(Vec::new());
    };

    let mut parsed = Vec::new();
    for entry in raw_conditions {
        let Some(column) = entry.get("eq_column").and_then(Value::as_str) else {
            continue;
        };
        let normalized = normalize_column_name(column, force_camel_case_to_snake_case);
        let targets_room_id = normalized == "room_id" || column == "roomId";

        let eq_value = match entry.get("eq_value") {
            Some(raw) if targets_room_id => coerce_room_id_eq_value(raw)?,
            Some(raw) => raw.clone(),
            None if targets_room_id => {
                return Err(GatewayFetchConditionError::new(
                    "room_id is required and must be numeric",
                ));
            }
            None => continue,
        };
        parsed.push(GatewayRequestCondition::new(column.to_string(), eq_value));
    }
    Ok(parsed)
}
// Unit tests for cache-key construction, sort/condition parsing and post-processing.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// Building the key twice from identical inputs must give the same key.
#[test]
fn hashed_cache_key_stable_for_same_inputs() {
let conditions = vec![GatewayRequestCondition::new(
"workspace_id".into(),
json!("abc"),
)];
let k1 = build_fetch_hashed_cache_key(
"users",
&conditions,
&["id".into(), "email".into()],
10,
false,
"supabase",
None,
);
let k2 = build_fetch_hashed_cache_key(
"users",
&conditions,
&["id".into(), "email".into()],
10,
false,
"supabase",
None,
);
assert_eq!(k1, k2);
}
// The key must not depend on the order in which conditions are supplied.
#[test]
fn hashed_cache_key_stable_for_reordered_conditions() {
let conditions_a = vec![
GatewayRequestCondition::new("workspace_id".into(), json!("abc")),
GatewayRequestCondition::new("room_id".into(), json!(123)),
];
let conditions_b = vec![
GatewayRequestCondition::new("room_id".into(), json!(123)),
GatewayRequestCondition::new("workspace_id".into(), json!("abc")),
];
let k1 = build_fetch_hashed_cache_key(
"users",
&conditions_a,
&["id".into(), "email".into()],
10,
false,
"supabase",
None,
);
let k2 = build_fetch_hashed_cache_key(
"users",
&conditions_b,
&["id".into(), "email".into()],
10,
false,
"supabase",
None,
);
assert_eq!(k1, k2);
}
// A camelCase `sortBy.field` is normalized to snake_case and "desc" maps to descending.
#[test]
fn parse_sort_options_normalizes_requested_column() {
let sort = parse_sort_options_from_body(
Some(&json!({
"sortBy": {
"field": "createdAt",
"direction": "desc"
}
})),
true,
)
.expect("sort should parse");
assert_eq!(
sort,
SortOptions {
column: "created_at".to_string(),
ascending: false,
}
);
}
// A string room id is coerced to a JSON integer; the column name is kept as supplied.
#[test]
fn parse_gateway_fetch_conditions_coerces_room_id() {
let conditions = parse_gateway_fetch_conditions(
&json!({
"conditions": [
{
"eq_column": "roomId",
"eq_value": "42"
}
]
}),
true,
)
.expect("conditions should parse");
assert_eq!(
conditions,
vec![GatewayRequestCondition::new("roomId".into(), json!(42))]
);
}
// A room_id condition without a value is an error rather than being skipped.
#[test]
fn parse_gateway_fetch_conditions_rejects_missing_room_id_value() {
let err = parse_gateway_fetch_conditions(
&json!({
"conditions": [
{
"eq_column": "room_id"
}
]
}),
false,
)
.expect_err("room_id without value should fail");
assert_eq!(err.message(), "room_id is required and must be numeric");
}
// Column names are normalized and the enum/boolean options are parsed from the body.
#[test]
fn post_processing_config_normalizes_columns() {
let config = PostProcessingConfig::from_body(
Some(&json!({
"group_by": "createdAt",
"aggregation_column": "requestCount",
"time_granularity": "hour",
"aggregation_strategy": "cumulative_sum",
"aggregation_dedup": true
})),
true,
);
assert_eq!(config.group_by.as_deref(), Some("created_at"));
assert_eq!(config.aggregation_column.as_deref(), Some("request_count"));
assert_eq!(config.time_granularity, Some(TimeGranularity::Hour));
assert_eq!(
config.aggregation_strategy,
Some(AggregationStrategy::CumulativeSum)
);
assert!(config.dedup_aggregation);
}
// Setting a strategy without an aggregation column must be rejected.
#[test]
fn apply_post_processing_requires_aggregation_column_when_strategy_set() {
let rows = vec![json!({"created_at": "2024-01-01T00:00:00Z", "value": 1})];
let config = PostProcessingConfig {
group_by: Some("created_at".to_string()),
time_granularity: None,
aggregation_column: None,
aggregation_strategy: Some(AggregationStrategy::CumulativeSum),
dedup_aggregation: false,
};
let result = apply_post_processing(&rows, &config);
assert!(result.is_err());
}
// Rows are bucketed per hour and the cumulative sum carries over between buckets (5, then 5 + 5).
#[test]
fn apply_post_processing_groups_and_accumulates_rows() {
let rows = vec![
json!({"created_at": "2024-01-01T10:05:00Z", "value": 2}),
json!({"created_at": "2024-01-01T10:15:00Z", "value": 3}),
json!({"created_at": "2024-01-01T11:00:00Z", "value": 5}),
];
let config = PostProcessingConfig {
group_by: Some("created_at".to_string()),
time_granularity: Some(TimeGranularity::Hour),
aggregation_column: Some("value".to_string()),
aggregation_strategy: Some(AggregationStrategy::CumulativeSum),
dedup_aggregation: false,
};
let result = apply_post_processing(&rows, &config)
.expect("post-processing should succeed")
.expect("grouping should be present");
assert_eq!(
result,
json!({
"grouped": [
{
"label": "2024-01-01 10:00",
"rows": [
{"created_at": "2024-01-01T10:05:00Z", "value": 2},
{"created_at": "2024-01-01T10:15:00Z", "value": 3}
],
"aggregation": "5"
},
{
"label": "2024-01-01 11:00",
"rows": [
{"created_at": "2024-01-01T11:00:00Z", "value": 5}
],
"aggregation": "10"
}
]
})
);
}
}