Skip to main content

athena_gateway/
fetch.rs

1use chrono::{DateTime, Utc};
2use chrono::{NaiveDateTime, TimeZone};
3use rust_decimal::Decimal;
4use rust_decimal::prelude::FromPrimitive;
5use serde::{Deserialize, Serialize};
6use serde_json::{Number, Value, json};
7use std::collections::{BTreeMap, HashSet};
8use std::fmt;
9use std::str::FromStr;
10
11use crate::{GatewayRequestCondition, normalize_column_name};
12
13/// Sort options for gateway fetch requests.
14#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
15pub struct SortOptions {
16    /// Column name, optionally normalized to `snake_case`.
17    pub column: String,
18    /// When true, sort ascending; otherwise descending.
19    pub ascending: bool,
20}
21
22/// Granularity options used for grouping timestamp rows during post-processing.
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
24#[serde(rename_all = "snake_case")]
25pub enum TimeGranularity {
26    Day,
27    Hour,
28    Minute,
29}
30
31impl FromStr for TimeGranularity {
32    type Err = String;
33
34    fn from_str(value: &str) -> Result<Self, Self::Err> {
35        match value.to_ascii_lowercase().as_str() {
36            "day" => Ok(Self::Day),
37            "hour" => Ok(Self::Hour),
38            "minute" => Ok(Self::Minute),
39            other => Err(format!("Unsupported time_granularity '{other}'")),
40        }
41    }
42}
43
44impl TimeGranularity {
45    /// Formats a UTC timestamp using the configured granularity.
46    pub fn format_label(&self, datetime: DateTime<Utc>) -> String {
47        match self {
48            Self::Day => datetime.format("%Y-%m-%d").to_string(),
49            Self::Hour => datetime.format("%Y-%m-%d %H:00").to_string(),
50            Self::Minute => datetime.format("%Y-%m-%d %H:%M").to_string(),
51        }
52    }
53}
54
55/// Supported aggregation strategies for grouped fetch results.
56#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
57#[serde(rename_all = "snake_case")]
58pub enum AggregationStrategy {
59    CumulativeSum,
60}
61
62impl FromStr for AggregationStrategy {
63    type Err = String;
64
65    fn from_str(value: &str) -> Result<Self, Self::Err> {
66        match value.to_ascii_lowercase().as_str() {
67            "cumulative_sum" => Ok(Self::CumulativeSum),
68            other => Err(format!("Unsupported aggregation_strategy '{other}'")),
69        }
70    }
71}
72
73/// Gateway fetch post-processing configuration parsed from request bodies.
74#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
75pub struct PostProcessingConfig {
76    /// Optional column used for grouping rows.
77    pub group_by: Option<String>,
78    /// Optional timestamp label formatter for grouped rows.
79    pub time_granularity: Option<TimeGranularity>,
80    /// Optional column used to compute aggregation values.
81    pub aggregation_column: Option<String>,
82    /// Optional accumulation mode for aggregation values.
83    pub aggregation_strategy: Option<AggregationStrategy>,
84    /// When true, duplicate aggregation values are ignored per bucket.
85    pub dedup_aggregation: bool,
86}
87
88impl PostProcessingConfig {
89    /// Builds post-processing preferences from a gateway fetch payload.
90    pub fn from_body(body: Option<&Value>, force_snake: bool) -> Self {
91        let normalize = |value: &str| normalize_column_name(value, force_snake);
92        let group_by = body
93            .and_then(|b| b.get("group_by"))
94            .and_then(Value::as_str)
95            .map(normalize);
96        let aggregation_column = body
97            .and_then(|b| b.get("aggregation_column"))
98            .and_then(Value::as_str)
99            .map(normalize);
100        let time_granularity = body
101            .and_then(|b| b.get("time_granularity"))
102            .and_then(Value::as_str)
103            .and_then(|s| s.parse::<TimeGranularity>().ok());
104        let aggregation_strategy = body
105            .and_then(|b| b.get("aggregation_strategy"))
106            .and_then(Value::as_str)
107            .and_then(|s| s.parse::<AggregationStrategy>().ok());
108        let dedup_aggregation = body
109            .and_then(|b| b.get("aggregation_dedup"))
110            .and_then(Value::as_bool)
111            .unwrap_or(false);
112
113        Self {
114            group_by,
115            time_granularity,
116            aggregation_column,
117            aggregation_strategy,
118            dedup_aggregation,
119        }
120    }
121}
122
/// Validation failure raised while turning request `conditions` into gateway conditions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GatewayFetchConditionError {
    message: String,
}

impl GatewayFetchConditionError {
    /// Wraps `message` as a condition-parsing error.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }

    /// Borrows the human-readable validation message.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }
}

impl fmt::Display for GatewayFetchConditionError {
    /// Writes the bare message (formatter width/fill flags are not applied).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.message)
    }
}

impl std::error::Error for GatewayFetchConditionError {}
150
151/// Applies grouping and aggregation rules to fetched gateway rows.
152pub fn apply_post_processing(
153    rows: &[Value],
154    config: &PostProcessingConfig,
155) -> Result<Option<Value>, String> {
156    let Some(group_by_column) = &config.group_by else {
157        return Ok(None);
158    };
159
160    if config.aggregation_strategy.is_some() && config.aggregation_column.is_none() {
161        return Err("aggregation_strategy requires aggregation_column".to_string());
162    }
163
164    let mut buckets: BTreeMap<String, Vec<Value>> = BTreeMap::new();
165    for row in rows {
166        let label = extract_group_label(row, group_by_column, config.time_granularity.as_ref())?;
167        buckets.entry(label).or_default().push(row.clone());
168    }
169
170    let mut grouped = Vec::new();
171    let mut running_total = Decimal::ZERO;
172    for (label, bucket_rows) in buckets {
173        let base_sum = if let Some(column) = &config.aggregation_column {
174            Some(sum_bucket(&bucket_rows, column, config.dedup_aggregation)?)
175        } else {
176            None
177        };
178        let aggregation_value = match (config.aggregation_strategy.as_ref(), base_sum) {
179            (Some(AggregationStrategy::CumulativeSum), Some(sum)) => {
180                running_total += sum;
181                Some(running_total)
182            }
183            (_, Some(sum)) => Some(sum),
184            _ => None,
185        };
186        let aggregation_payload = aggregation_value.map(|value| Value::String(value.to_string()));
187        grouped.push(json!({
188            "label": label,
189            "rows": bucket_rows,
190            "aggregation": aggregation_payload
191        }));
192    }
193
194    Ok(Some(json!({ "grouped": grouped })))
195}
196
197/// Parses optional `sortBy` or `sort_by` from a fetch request body.
198pub fn parse_sort_options_from_body(
199    body: Option<&Value>,
200    force_snake: bool,
201) -> Option<SortOptions> {
202    let obj = body?
203        .get("sort_by")
204        .or_else(|| body?.get("sortBy"))
205        .and_then(Value::as_object)?;
206    let field = obj
207        .get("field")
208        .or_else(|| obj.get("column"))
209        .and_then(Value::as_str)?;
210    let normalized = normalize_column_name(field, force_snake);
211    if normalized.is_empty() {
212        return None;
213    }
214    let ascending = obj
215        .get("direction")
216        .and_then(Value::as_str)
217        .map(|s| matches!(s.to_ascii_lowercase().as_str(), "asc" | "ascending"))
218        .unwrap_or(true);
219    Some(SortOptions {
220        column: normalized,
221        ascending,
222    })
223}
224
225fn extract_group_label(
226    row: &Value,
227    column: &str,
228    granularity: Option<&TimeGranularity>,
229) -> Result<String, String> {
230    let value = row
231        .get(column)
232        .ok_or_else(|| format!("Missing group_by column '{column}'"))?;
233
234    if let Some(granularity) = granularity {
235        let datetime = parse_datetime_value(value)?;
236        return Ok(granularity.format_label(datetime));
237    }
238
239    if let Some(text) = value.as_str() {
240        return Ok(text.to_string());
241    }
242
243    Ok(value.to_string())
244}
245
246#[allow(deprecated)]
247fn parse_datetime_value(value: &Value) -> Result<DateTime<Utc>, String> {
248    match value {
249        Value::String(text) => DateTime::parse_from_rfc3339(text)
250            .map(|dt| dt.with_timezone(&Utc))
251            .or_else(|_| {
252                NaiveDateTime::parse_from_str(text, "%Y-%m-%d %H:%M:%S")
253                    .map(|naive| DateTime::<Utc>::from_utc(naive, Utc))
254            })
255            .map_err(|_| format!("Failed to parse datetime string: '{text}'")),
256        Value::Number(number) => {
257            let timestamp = if let Some(i) = number.as_i64() {
258                i as f64
259            } else if let Some(f) = number.as_f64() {
260                f
261            } else {
262                return Err("Numeric timestamp must be integer or float".to_string());
263            };
264
265            let secs = timestamp.trunc() as i64;
266            let nanos = ((timestamp.fract()) * 1_000_000_000f64) as u32;
267            Utc.timestamp_opt(secs, nanos)
268                .single()
269                .ok_or_else(|| format!("Invalid unix timestamp: {timestamp}"))
270        }
271        _ => {
272            Err("Value must be a string (RFC3339/datetime) or number (unix timestamp)".to_string())
273        }
274    }
275}
276
277/// Builds the hashed cache key used for POST gateway fetch requests.
278pub fn build_fetch_hashed_cache_key(
279    table_name: &str,
280    conditions: &[GatewayRequestCondition],
281    columns_vec: &[String],
282    limit: i64,
283    strip_nulls: bool,
284    client_name: &str,
285    sort_options: Option<&SortOptions>,
286) -> String {
287    let input = GatewayFetchCacheKeyInput {
288        table_name,
289        conditions,
290        columns_vec,
291        limit,
292        strip_nulls,
293        client_name,
294        sort_options,
295    };
296    build_fetch_hashed_cache_key_with_hash_len(&input, 16)
297}
298
299/// Builds the legacy 8-character hashed cache key used during fetch cache cutovers.
300pub fn build_fetch_hashed_cache_key_legacy8(
301    table_name: &str,
302    conditions: &[GatewayRequestCondition],
303    columns_vec: &[String],
304    limit: i64,
305    strip_nulls: bool,
306    client_name: &str,
307    sort_options: Option<&SortOptions>,
308) -> String {
309    let input = GatewayFetchCacheKeyInput {
310        table_name,
311        conditions,
312        columns_vec,
313        limit,
314        strip_nulls,
315        client_name,
316        sort_options,
317    };
318    build_fetch_hashed_cache_key_with_hash_len(&input, 8)
319}
320
321struct GatewayFetchCacheKeyInput<'a> {
322    table_name: &'a str,
323    conditions: &'a [GatewayRequestCondition],
324    columns_vec: &'a [String],
325    limit: i64,
326    strip_nulls: bool,
327    client_name: &'a str,
328    sort_options: Option<&'a SortOptions>,
329}
330
331fn build_fetch_hashed_cache_key_with_hash_len(
332    input: &GatewayFetchCacheKeyInput<'_>,
333    hash_len: usize,
334) -> String {
335    let mut normalized_conditions: Vec<(String, Value, String)> = input
336        .conditions
337        .iter()
338        .map(|condition| {
339            let serialized_value = serde_json::to_string(&condition.eq_value).unwrap_or_default();
340            (
341                condition.eq_column.clone(),
342                condition.eq_value.clone(),
343                serialized_value,
344            )
345        })
346        .collect();
347    normalized_conditions.sort_by(|a, b| a.0.cmp(&b.0).then(a.2.cmp(&b.2)));
348
349    let first_eq_column = normalized_conditions
350        .first()
351        .map_or("_", |(column, _, _)| column.as_str());
352    let hash_input = json!({
353        "columns": input.columns_vec,
354        "conditions": normalized_conditions.iter().map(|(eq_column, eq_value, _)| json!({
355            "eq_column": eq_column,
356            "eq_value": eq_value.clone()
357        })).collect::<Vec<_>>(),
358        "limit": input.limit,
359        "strip_nulls": input.strip_nulls,
360        "client": input.client_name,
361        "sort": input.sort_options.map(|s| json!({"column": s.column, "ascending": s.ascending})),
362    });
363    let hash_str = sha256::digest(serde_json::to_string(&hash_input).unwrap_or_default());
364    let short_hash = &hash_str[..hash_len.min(hash_str.len())];
365
366    format!(
367        "{}:{first_eq_column}:{}:{}:{}:{short_hash}",
368        input.table_name,
369        input.columns_vec.join(","),
370        input.limit,
371        input.strip_nulls
372    )
373}
374
375fn sum_bucket(rows: &[Value], column: &str, dedup: bool) -> Result<Decimal, String> {
376    let mut total = Decimal::ZERO;
377    let mut seen = HashSet::new();
378
379    for row in rows {
380        let value = row
381            .get(column)
382            .ok_or_else(|| format!("Missing aggregation column '{column}'"))?;
383        let decimal = value_to_decimal(value)
384            .ok_or_else(|| format!("Aggregation column '{column}' contains non-numeric data"))?;
385
386        if dedup {
387            if seen.contains(&decimal) {
388                continue;
389            }
390            seen.insert(decimal);
391        }
392
393        total += decimal;
394    }
395
396    Ok(total)
397}
398
399fn value_to_decimal(value: &Value) -> Option<Decimal> {
400    match value {
401        Value::Number(num) => {
402            if let Some(i) = num.as_i64() {
403                Some(Decimal::from(i))
404            } else if let Some(u) = num.as_u64() {
405                Some(Decimal::from(u))
406            } else if let Some(f) = num.as_f64() {
407                Decimal::from_f64(f)
408            } else {
409                None
410            }
411        }
412        Value::String(text) => Decimal::from_str(text).ok(),
413        _ => None,
414    }
415}
416
417/// Parses `eq_value` for a `room_id` condition.
418pub fn parse_room_id_value(value: &Value) -> Result<i64, GatewayFetchConditionError> {
419    match value {
420        Value::Number(num) => num
421            .as_i64()
422            .ok_or_else(|| GatewayFetchConditionError::new("room_id must be an integer")),
423        Value::String(text) => {
424            let trimmed = text.trim();
425            if trimmed == "*" {
426                return Err(GatewayFetchConditionError::new(
427                    "room_id wildcard '*' is not allowed",
428                ));
429            }
430            if trimmed.is_empty() {
431                return Err(GatewayFetchConditionError::new("room_id must not be empty"));
432            }
433            trimmed
434                .parse::<i64>()
435                .map_err(|_| GatewayFetchConditionError::new("room_id must be numeric"))
436        }
437        _ => Err(GatewayFetchConditionError::new("room_id must be numeric")),
438    }
439}
440
441/// Coerces `eq_value` to a JSON number when the column is `room_id` / `roomId`.
442pub fn coerce_room_id_eq_value(eq_value_raw: &Value) -> Result<Value, GatewayFetchConditionError> {
443    parse_room_id_value(eq_value_raw).map(|id| Value::Number(Number::from(id)))
444}
445
446/// Parses `conditions` from a gateway fetch payload.
447pub fn parse_gateway_fetch_conditions(
448    json_body: &Value,
449    force_camel_case_to_snake_case: bool,
450) -> Result<Vec<GatewayRequestCondition>, GatewayFetchConditionError> {
451    let mut conditions = Vec::new();
452    let Some(additional_conditions) = json_body.get("conditions").and_then(Value::as_array) else {
453        return Ok(conditions);
454    };
455
456    for condition in additional_conditions {
457        if let Some(eq_column) = condition.get("eq_column").and_then(Value::as_str) {
458            let eq_column_str = eq_column.to_string();
459            let normalized_for_validation =
460                normalize_column_name(eq_column, force_camel_case_to_snake_case);
461
462            let eq_value_raw = match condition.get("eq_value") {
463                Some(value) => value.clone(),
464                None => {
465                    if normalized_for_validation == "room_id" || eq_column_str == "roomId" {
466                        return Err(GatewayFetchConditionError::new(
467                            "room_id is required and must be numeric",
468                        ));
469                    }
470                    continue;
471                }
472            };
473
474            let eq_value = if normalized_for_validation == "room_id" || eq_column_str == "roomId" {
475                coerce_room_id_eq_value(&eq_value_raw)?
476            } else {
477                eq_value_raw
478            };
479            conditions.push(GatewayRequestCondition::new(eq_column_str, eq_value));
480        }
481    }
482
483    Ok(conditions)
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Cache key for a `users` fetch with fixed columns, limit and client, so individual
    /// tests vary only the conditions.
    fn users_cache_key(conditions: &[GatewayRequestCondition]) -> String {
        build_fetch_hashed_cache_key(
            "users",
            conditions,
            &["id".into(), "email".into()],
            10,
            false,
            "supabase",
            None,
        )
    }

    #[test]
    fn hashed_cache_key_stable_for_same_inputs() {
        let conditions = vec![GatewayRequestCondition::new(
            "workspace_id".into(),
            json!("abc"),
        )];

        let first = users_cache_key(&conditions);
        let second = users_cache_key(&conditions);
        assert_eq!(first, second);
    }

    #[test]
    fn hashed_cache_key_stable_for_reordered_conditions() {
        let workspace = || GatewayRequestCondition::new("workspace_id".into(), json!("abc"));
        let room = || GatewayRequestCondition::new("room_id".into(), json!(123));

        assert_eq!(
            users_cache_key(&[workspace(), room()]),
            users_cache_key(&[room(), workspace()])
        );
    }

    #[test]
    fn parse_sort_options_normalizes_requested_column() {
        let body = json!({
            "sortBy": {
                "field": "createdAt",
                "direction": "desc"
            }
        });

        let sort = parse_sort_options_from_body(Some(&body), true).expect("sort should parse");

        assert_eq!(
            sort,
            SortOptions {
                column: "created_at".to_string(),
                ascending: false,
            }
        );
    }

    #[test]
    fn parse_gateway_fetch_conditions_coerces_room_id() {
        let body = json!({
            "conditions": [
                { "eq_column": "roomId", "eq_value": "42" }
            ]
        });

        let conditions =
            parse_gateway_fetch_conditions(&body, true).expect("conditions should parse");

        assert_eq!(
            conditions,
            vec![GatewayRequestCondition::new("roomId".into(), json!(42))]
        );
    }

    #[test]
    fn parse_gateway_fetch_conditions_rejects_missing_room_id_value() {
        let body = json!({ "conditions": [{ "eq_column": "room_id" }] });

        let err = parse_gateway_fetch_conditions(&body, false)
            .expect_err("room_id without value should fail");

        assert_eq!(err.message(), "room_id is required and must be numeric");
    }

    #[test]
    fn post_processing_config_normalizes_columns() {
        let body = json!({
            "group_by": "createdAt",
            "aggregation_column": "requestCount",
            "time_granularity": "hour",
            "aggregation_strategy": "cumulative_sum",
            "aggregation_dedup": true
        });

        let config = PostProcessingConfig::from_body(Some(&body), true);

        assert_eq!(config.group_by.as_deref(), Some("created_at"));
        assert_eq!(config.aggregation_column.as_deref(), Some("request_count"));
        assert_eq!(config.time_granularity, Some(TimeGranularity::Hour));
        assert_eq!(
            config.aggregation_strategy,
            Some(AggregationStrategy::CumulativeSum)
        );
        assert!(config.dedup_aggregation);
    }

    #[test]
    fn apply_post_processing_requires_aggregation_column_when_strategy_set() {
        let rows = [json!({"created_at": "2024-01-01T00:00:00Z", "value": 1})];
        let config = PostProcessingConfig {
            group_by: Some("created_at".to_string()),
            time_granularity: None,
            aggregation_column: None,
            aggregation_strategy: Some(AggregationStrategy::CumulativeSum),
            dedup_aggregation: false,
        };

        assert!(apply_post_processing(&rows, &config).is_err());
    }

    #[test]
    fn apply_post_processing_groups_and_accumulates_rows() {
        let rows = [
            json!({"created_at": "2024-01-01T10:05:00Z", "value": 2}),
            json!({"created_at": "2024-01-01T10:15:00Z", "value": 3}),
            json!({"created_at": "2024-01-01T11:00:00Z", "value": 5}),
        ];
        let config = PostProcessingConfig {
            group_by: Some("created_at".to_string()),
            time_granularity: Some(TimeGranularity::Hour),
            aggregation_column: Some("value".to_string()),
            aggregation_strategy: Some(AggregationStrategy::CumulativeSum),
            dedup_aggregation: false,
        };
        let expected = json!({
            "grouped": [
                {
                    "label": "2024-01-01 10:00",
                    "rows": [
                        {"created_at": "2024-01-01T10:05:00Z", "value": 2},
                        {"created_at": "2024-01-01T10:15:00Z", "value": 3}
                    ],
                    "aggregation": "5"
                },
                {
                    "label": "2024-01-01 11:00",
                    "rows": [
                        {"created_at": "2024-01-01T11:00:00Z", "value": 5}
                    ],
                    "aggregation": "10"
                }
            ]
        });

        let result = apply_post_processing(&rows, &config)
            .expect("post-processing should succeed")
            .expect("grouping should be present");

        assert_eq!(result, expected);
    }
}