1use chrono::{DateTime, Utc};
2use chrono::{NaiveDateTime, TimeZone};
3use rust_decimal::Decimal;
4use rust_decimal::prelude::FromPrimitive;
5use serde::{Deserialize, Serialize};
6use serde_json::{Number, Value, json};
7use std::collections::{BTreeMap, HashSet};
8use std::fmt;
9use std::str::FromStr;
10
11use crate::{GatewayRequestCondition, normalize_column_name};
12
13#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
15pub struct SortOptions {
16 pub column: String,
18 pub ascending: bool,
20}
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
24#[serde(rename_all = "snake_case")]
25pub enum TimeGranularity {
26 Day,
27 Hour,
28 Minute,
29}
30
31impl FromStr for TimeGranularity {
32 type Err = String;
33
34 fn from_str(value: &str) -> Result<Self, Self::Err> {
35 match value.to_ascii_lowercase().as_str() {
36 "day" => Ok(Self::Day),
37 "hour" => Ok(Self::Hour),
38 "minute" => Ok(Self::Minute),
39 other => Err(format!("Unsupported time_granularity '{other}'")),
40 }
41 }
42}
43
44impl TimeGranularity {
45 pub fn format_label(&self, datetime: DateTime<Utc>) -> String {
47 match self {
48 Self::Day => datetime.format("%Y-%m-%d").to_string(),
49 Self::Hour => datetime.format("%Y-%m-%d %H:00").to_string(),
50 Self::Minute => datetime.format("%Y-%m-%d %H:%M").to_string(),
51 }
52 }
53}
54
55#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
57#[serde(rename_all = "snake_case")]
58pub enum AggregationStrategy {
59 CumulativeSum,
60}
61
62impl FromStr for AggregationStrategy {
63 type Err = String;
64
65 fn from_str(value: &str) -> Result<Self, Self::Err> {
66 match value.to_ascii_lowercase().as_str() {
67 "cumulative_sum" => Ok(Self::CumulativeSum),
68 other => Err(format!("Unsupported aggregation_strategy '{other}'")),
69 }
70 }
71}
72
73#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
75pub struct PostProcessingConfig {
76 pub group_by: Option<String>,
78 pub time_granularity: Option<TimeGranularity>,
80 pub aggregation_column: Option<String>,
82 pub aggregation_strategy: Option<AggregationStrategy>,
84 pub dedup_aggregation: bool,
86}
87
88impl PostProcessingConfig {
89 pub fn from_body(body: Option<&Value>, force_snake: bool) -> Self {
91 let normalize = |value: &str| normalize_column_name(value, force_snake);
92 let group_by = body
93 .and_then(|b| b.get("group_by"))
94 .and_then(Value::as_str)
95 .map(normalize);
96 let aggregation_column = body
97 .and_then(|b| b.get("aggregation_column"))
98 .and_then(Value::as_str)
99 .map(normalize);
100 let time_granularity = body
101 .and_then(|b| b.get("time_granularity"))
102 .and_then(Value::as_str)
103 .and_then(|s| s.parse::<TimeGranularity>().ok());
104 let aggregation_strategy = body
105 .and_then(|b| b.get("aggregation_strategy"))
106 .and_then(Value::as_str)
107 .and_then(|s| s.parse::<AggregationStrategy>().ok());
108 let dedup_aggregation = body
109 .and_then(|b| b.get("aggregation_dedup"))
110 .and_then(Value::as_bool)
111 .unwrap_or(false);
112
113 Self {
114 group_by,
115 time_granularity,
116 aggregation_column,
117 aggregation_strategy,
118 dedup_aggregation,
119 }
120 }
121}
122
/// Error raised when a gateway fetch condition cannot be accepted.
///
/// Carries a human-readable message that is also its [`fmt::Display`] output.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GatewayFetchConditionError {
    message: String,
}

impl GatewayFetchConditionError {
    /// Creates an error carrying `message`.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }

    /// Returns the message this error was created with.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }
}

impl fmt::Display for GatewayFetchConditionError {
    /// Writes the bare message; width and alignment flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.message)
    }
}

impl std::error::Error for GatewayFetchConditionError {}
150
151pub fn apply_post_processing(
153 rows: &[Value],
154 config: &PostProcessingConfig,
155) -> Result<Option<Value>, String> {
156 let Some(group_by_column) = &config.group_by else {
157 return Ok(None);
158 };
159
160 if config.aggregation_strategy.is_some() && config.aggregation_column.is_none() {
161 return Err("aggregation_strategy requires aggregation_column".to_string());
162 }
163
164 let mut buckets: BTreeMap<String, Vec<Value>> = BTreeMap::new();
165 for row in rows {
166 let label = extract_group_label(row, group_by_column, config.time_granularity.as_ref())?;
167 buckets.entry(label).or_default().push(row.clone());
168 }
169
170 let mut grouped = Vec::new();
171 let mut running_total = Decimal::ZERO;
172 for (label, bucket_rows) in buckets {
173 let base_sum = if let Some(column) = &config.aggregation_column {
174 Some(sum_bucket(&bucket_rows, column, config.dedup_aggregation)?)
175 } else {
176 None
177 };
178 let aggregation_value = match (config.aggregation_strategy.as_ref(), base_sum) {
179 (Some(AggregationStrategy::CumulativeSum), Some(sum)) => {
180 running_total += sum;
181 Some(running_total)
182 }
183 (_, Some(sum)) => Some(sum),
184 _ => None,
185 };
186 let aggregation_payload = aggregation_value.map(|value| Value::String(value.to_string()));
187 grouped.push(json!({
188 "label": label,
189 "rows": bucket_rows,
190 "aggregation": aggregation_payload
191 }));
192 }
193
194 Ok(Some(json!({ "grouped": grouped })))
195}
196
197pub fn parse_sort_options_from_body(
199 body: Option<&Value>,
200 force_snake: bool,
201) -> Option<SortOptions> {
202 let obj = body?
203 .get("sort_by")
204 .or_else(|| body?.get("sortBy"))
205 .and_then(Value::as_object)?;
206 let field = obj
207 .get("field")
208 .or_else(|| obj.get("column"))
209 .and_then(Value::as_str)?;
210 let normalized = normalize_column_name(field, force_snake);
211 if normalized.is_empty() {
212 return None;
213 }
214 let ascending = obj
215 .get("direction")
216 .and_then(Value::as_str)
217 .map(|s| matches!(s.to_ascii_lowercase().as_str(), "asc" | "ascending"))
218 .unwrap_or(true);
219 Some(SortOptions {
220 column: normalized,
221 ascending,
222 })
223}
224
225fn extract_group_label(
226 row: &Value,
227 column: &str,
228 granularity: Option<&TimeGranularity>,
229) -> Result<String, String> {
230 let value = row
231 .get(column)
232 .ok_or_else(|| format!("Missing group_by column '{column}'"))?;
233
234 if let Some(granularity) = granularity {
235 let datetime = parse_datetime_value(value)?;
236 return Ok(granularity.format_label(datetime));
237 }
238
239 if let Some(text) = value.as_str() {
240 return Ok(text.to_string());
241 }
242
243 Ok(value.to_string())
244}
245
246#[allow(deprecated)]
247fn parse_datetime_value(value: &Value) -> Result<DateTime<Utc>, String> {
248 match value {
249 Value::String(text) => DateTime::parse_from_rfc3339(text)
250 .map(|dt| dt.with_timezone(&Utc))
251 .or_else(|_| {
252 NaiveDateTime::parse_from_str(text, "%Y-%m-%d %H:%M:%S")
253 .map(|naive| DateTime::<Utc>::from_utc(naive, Utc))
254 })
255 .map_err(|_| format!("Failed to parse datetime string: '{text}'")),
256 Value::Number(number) => {
257 let timestamp = if let Some(i) = number.as_i64() {
258 i as f64
259 } else if let Some(f) = number.as_f64() {
260 f
261 } else {
262 return Err("Numeric timestamp must be integer or float".to_string());
263 };
264
265 let secs = timestamp.trunc() as i64;
266 let nanos = ((timestamp.fract()) * 1_000_000_000f64) as u32;
267 Utc.timestamp_opt(secs, nanos)
268 .single()
269 .ok_or_else(|| format!("Invalid unix timestamp: {timestamp}"))
270 }
271 _ => {
272 Err("Value must be a string (RFC3339/datetime) or number (unix timestamp)".to_string())
273 }
274 }
275}
276
277pub fn build_fetch_hashed_cache_key(
279 table_name: &str,
280 conditions: &[GatewayRequestCondition],
281 columns_vec: &[String],
282 limit: i64,
283 strip_nulls: bool,
284 client_name: &str,
285 sort_options: Option<&SortOptions>,
286) -> String {
287 let input = GatewayFetchCacheKeyInput {
288 table_name,
289 conditions,
290 columns_vec,
291 limit,
292 strip_nulls,
293 client_name,
294 sort_options,
295 };
296 build_fetch_hashed_cache_key_with_hash_len(&input, 16)
297}
298
299pub fn build_fetch_hashed_cache_key_legacy8(
301 table_name: &str,
302 conditions: &[GatewayRequestCondition],
303 columns_vec: &[String],
304 limit: i64,
305 strip_nulls: bool,
306 client_name: &str,
307 sort_options: Option<&SortOptions>,
308) -> String {
309 let input = GatewayFetchCacheKeyInput {
310 table_name,
311 conditions,
312 columns_vec,
313 limit,
314 strip_nulls,
315 client_name,
316 sort_options,
317 };
318 build_fetch_hashed_cache_key_with_hash_len(&input, 8)
319}
320
321struct GatewayFetchCacheKeyInput<'a> {
322 table_name: &'a str,
323 conditions: &'a [GatewayRequestCondition],
324 columns_vec: &'a [String],
325 limit: i64,
326 strip_nulls: bool,
327 client_name: &'a str,
328 sort_options: Option<&'a SortOptions>,
329}
330
331fn build_fetch_hashed_cache_key_with_hash_len(
332 input: &GatewayFetchCacheKeyInput<'_>,
333 hash_len: usize,
334) -> String {
335 let mut normalized_conditions: Vec<(String, Value, String)> = input
336 .conditions
337 .iter()
338 .map(|condition| {
339 let serialized_value = serde_json::to_string(&condition.eq_value).unwrap_or_default();
340 (
341 condition.eq_column.clone(),
342 condition.eq_value.clone(),
343 serialized_value,
344 )
345 })
346 .collect();
347 normalized_conditions.sort_by(|a, b| a.0.cmp(&b.0).then(a.2.cmp(&b.2)));
348
349 let first_eq_column = normalized_conditions
350 .first()
351 .map_or("_", |(column, _, _)| column.as_str());
352 let hash_input = json!({
353 "columns": input.columns_vec,
354 "conditions": normalized_conditions.iter().map(|(eq_column, eq_value, _)| json!({
355 "eq_column": eq_column,
356 "eq_value": eq_value.clone()
357 })).collect::<Vec<_>>(),
358 "limit": input.limit,
359 "strip_nulls": input.strip_nulls,
360 "client": input.client_name,
361 "sort": input.sort_options.map(|s| json!({"column": s.column, "ascending": s.ascending})),
362 });
363 let hash_str = sha256::digest(serde_json::to_string(&hash_input).unwrap_or_default());
364 let short_hash = &hash_str[..hash_len.min(hash_str.len())];
365
366 format!(
367 "{}:{first_eq_column}:{}:{}:{}:{short_hash}",
368 input.table_name,
369 input.columns_vec.join(","),
370 input.limit,
371 input.strip_nulls
372 )
373}
374
375fn sum_bucket(rows: &[Value], column: &str, dedup: bool) -> Result<Decimal, String> {
376 let mut total = Decimal::ZERO;
377 let mut seen = HashSet::new();
378
379 for row in rows {
380 let value = row
381 .get(column)
382 .ok_or_else(|| format!("Missing aggregation column '{column}'"))?;
383 let decimal = value_to_decimal(value)
384 .ok_or_else(|| format!("Aggregation column '{column}' contains non-numeric data"))?;
385
386 if dedup {
387 if seen.contains(&decimal) {
388 continue;
389 }
390 seen.insert(decimal);
391 }
392
393 total += decimal;
394 }
395
396 Ok(total)
397}
398
399fn value_to_decimal(value: &Value) -> Option<Decimal> {
400 match value {
401 Value::Number(num) => {
402 if let Some(i) = num.as_i64() {
403 Some(Decimal::from(i))
404 } else if let Some(u) = num.as_u64() {
405 Some(Decimal::from(u))
406 } else if let Some(f) = num.as_f64() {
407 Decimal::from_f64(f)
408 } else {
409 None
410 }
411 }
412 Value::String(text) => Decimal::from_str(text).ok(),
413 _ => None,
414 }
415}
416
417pub fn parse_room_id_value(value: &Value) -> Result<i64, GatewayFetchConditionError> {
419 match value {
420 Value::Number(num) => num
421 .as_i64()
422 .ok_or_else(|| GatewayFetchConditionError::new("room_id must be an integer")),
423 Value::String(text) => {
424 let trimmed = text.trim();
425 if trimmed == "*" {
426 return Err(GatewayFetchConditionError::new(
427 "room_id wildcard '*' is not allowed",
428 ));
429 }
430 if trimmed.is_empty() {
431 return Err(GatewayFetchConditionError::new("room_id must not be empty"));
432 }
433 trimmed
434 .parse::<i64>()
435 .map_err(|_| GatewayFetchConditionError::new("room_id must be numeric"))
436 }
437 _ => Err(GatewayFetchConditionError::new("room_id must be numeric")),
438 }
439}
440
441pub fn coerce_room_id_eq_value(eq_value_raw: &Value) -> Result<Value, GatewayFetchConditionError> {
443 parse_room_id_value(eq_value_raw).map(|id| Value::Number(Number::from(id)))
444}
445
446pub fn parse_gateway_fetch_conditions(
448 json_body: &Value,
449 force_camel_case_to_snake_case: bool,
450) -> Result<Vec<GatewayRequestCondition>, GatewayFetchConditionError> {
451 let mut conditions = Vec::new();
452 let Some(additional_conditions) = json_body.get("conditions").and_then(Value::as_array) else {
453 return Ok(conditions);
454 };
455
456 for condition in additional_conditions {
457 if let Some(eq_column) = condition.get("eq_column").and_then(Value::as_str) {
458 let eq_column_str = eq_column.to_string();
459 let normalized_for_validation =
460 normalize_column_name(eq_column, force_camel_case_to_snake_case);
461
462 let eq_value_raw = match condition.get("eq_value") {
463 Some(value) => value.clone(),
464 None => {
465 if normalized_for_validation == "room_id" || eq_column_str == "roomId" {
466 return Err(GatewayFetchConditionError::new(
467 "room_id is required and must be numeric",
468 ));
469 }
470 continue;
471 }
472 };
473
474 let eq_value = if normalized_for_validation == "room_id" || eq_column_str == "roomId" {
475 coerce_room_id_eq_value(&eq_value_raw)?
476 } else {
477 eq_value_raw
478 };
479 conditions.push(GatewayRequestCondition::new(eq_column_str, eq_value));
480 }
481 }
482
483 Ok(conditions)
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Cache key for the `users` table with everything but the conditions fixed.
    fn users_cache_key(conditions: &[GatewayRequestCondition]) -> String {
        let columns = vec![String::from("id"), String::from("email")];
        build_fetch_hashed_cache_key("users", conditions, &columns, 10, false, "supabase", None)
    }

    fn workspace_condition() -> GatewayRequestCondition {
        GatewayRequestCondition::new("workspace_id".into(), json!("abc"))
    }

    fn room_condition() -> GatewayRequestCondition {
        GatewayRequestCondition::new("room_id".into(), json!(123))
    }

    #[test]
    fn hashed_cache_key_stable_for_same_inputs() {
        let conditions = vec![workspace_condition()];

        let first = users_cache_key(&conditions);
        let second = users_cache_key(&conditions);

        assert_eq!(first, second);
    }

    #[test]
    fn hashed_cache_key_stable_for_reordered_conditions() {
        let workspace_first = vec![workspace_condition(), room_condition()];
        let room_first = vec![room_condition(), workspace_condition()];

        assert_eq!(
            users_cache_key(&workspace_first),
            users_cache_key(&room_first)
        );
    }

    #[test]
    fn parse_sort_options_normalizes_requested_column() {
        let body = json!({
            "sortBy": {
                "field": "createdAt",
                "direction": "desc"
            }
        });

        let sort = parse_sort_options_from_body(Some(&body), true).expect("sort should parse");

        let expected = SortOptions {
            column: "created_at".to_string(),
            ascending: false,
        };
        assert_eq!(sort, expected);
    }

    #[test]
    fn parse_gateway_fetch_conditions_coerces_room_id() {
        let body = json!({
            "conditions": [
                {
                    "eq_column": "roomId",
                    "eq_value": "42"
                }
            ]
        });

        let conditions =
            parse_gateway_fetch_conditions(&body, true).expect("conditions should parse");

        let expected = vec![GatewayRequestCondition::new("roomId".into(), json!(42))];
        assert_eq!(conditions, expected);
    }

    #[test]
    fn parse_gateway_fetch_conditions_rejects_missing_room_id_value() {
        let body = json!({
            "conditions": [
                {
                    "eq_column": "room_id"
                }
            ]
        });

        let err = parse_gateway_fetch_conditions(&body, false)
            .expect_err("room_id without value should fail");

        assert_eq!(err.message(), "room_id is required and must be numeric");
    }

    #[test]
    fn post_processing_config_normalizes_columns() {
        let body = json!({
            "group_by": "createdAt",
            "aggregation_column": "requestCount",
            "time_granularity": "hour",
            "aggregation_strategy": "cumulative_sum",
            "aggregation_dedup": true
        });

        let config = PostProcessingConfig::from_body(Some(&body), true);

        assert_eq!(config.group_by.as_deref(), Some("created_at"));
        assert_eq!(config.aggregation_column.as_deref(), Some("request_count"));
        assert_eq!(config.time_granularity, Some(TimeGranularity::Hour));
        assert_eq!(
            config.aggregation_strategy,
            Some(AggregationStrategy::CumulativeSum)
        );
        assert!(config.dedup_aggregation);
    }

    #[test]
    fn apply_post_processing_requires_aggregation_column_when_strategy_set() {
        let rows = vec![json!({"created_at": "2024-01-01T00:00:00Z", "value": 1})];
        let config = PostProcessingConfig {
            group_by: Some("created_at".to_string()),
            time_granularity: None,
            aggregation_column: None,
            aggregation_strategy: Some(AggregationStrategy::CumulativeSum),
            dedup_aggregation: false,
        };

        assert!(apply_post_processing(&rows, &config).is_err());
    }

    #[test]
    fn apply_post_processing_groups_and_accumulates_rows() {
        let rows = vec![
            json!({"created_at": "2024-01-01T10:05:00Z", "value": 2}),
            json!({"created_at": "2024-01-01T10:15:00Z", "value": 3}),
            json!({"created_at": "2024-01-01T11:00:00Z", "value": 5}),
        ];
        let config = PostProcessingConfig {
            group_by: Some("created_at".to_string()),
            time_granularity: Some(TimeGranularity::Hour),
            aggregation_column: Some("value".to_string()),
            aggregation_strategy: Some(AggregationStrategy::CumulativeSum),
            dedup_aggregation: false,
        };

        let result = apply_post_processing(&rows, &config)
            .expect("post-processing should succeed")
            .expect("grouping should be present");

        let expected = json!({
            "grouped": [
                {
                    "label": "2024-01-01 10:00",
                    "rows": [
                        {"created_at": "2024-01-01T10:05:00Z", "value": 2},
                        {"created_at": "2024-01-01T10:15:00Z", "value": 3}
                    ],
                    "aggregation": "5"
                },
                {
                    "label": "2024-01-01 11:00",
                    "rows": [
                        {"created_at": "2024-01-01T11:00:00Z", "value": 5}
                    ],
                    "aggregation": "10"
                }
            ]
        });
        assert_eq!(result, expected);
    }
}