1use std::collections::HashMap;
10
11use ff_core::partition::{budget_partition, PartitionConfig};
12use ff_core::types::BudgetId;
13
/// Outcome of evaluating one budget's usage against its configured limits.
///
/// The `Default` value has both flags cleared, i.e. "not breached".
/// `PartialEq`/`Eq` are derived so statuses can be compared directly
/// (for example with `assert_eq!`).
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct BudgetStatus {
    /// Set when usage reached a `hard:` limit, or when the usage data carries
    /// a `breached_at` field.
    pub hard_breached: bool,
    /// Set when usage reached a `soft:` limit.
    pub soft_breached: bool,
}
22
/// Looks up budget breach status through a `ferriskey` client and remembers
/// the result for each budget it has already checked.
pub struct BudgetChecker {
    /// Client used to issue the `HGETALL` reads in `check_budget`.
    client: ferriskey::Client,
    /// Partitioning settings handed to `budget_partition` when deriving the
    /// hash tag embedded in a budget's keys.
    partition_config: PartitionConfig,
    /// Statuses already computed, keyed by budget id; emptied by `clear_cache`.
    cycle_cache: HashMap<BudgetId, BudgetStatus>,
}
32
33impl BudgetChecker {
34 pub fn new(client: ferriskey::Client, partition_config: PartitionConfig) -> Self {
35 Self {
36 client,
37 partition_config,
38 cycle_cache: HashMap::new(),
39 }
40 }
41
42 pub async fn check_budget(
44 &mut self,
45 budget_id: &BudgetId,
46 ) -> Result<BudgetStatus, ferriskey::Error> {
47 if let Some(status) = self.cycle_cache.get(budget_id) {
49 return Ok(status.clone());
50 }
51
52 let partition = budget_partition(budget_id, &self.partition_config);
54 let tag = partition.hash_tag();
55
56 let usage_key = format!("ff:budget:{}:{}:usage", tag, budget_id);
57 let limits_key = format!("ff:budget:{}:{}:limits", tag, budget_id);
58
59 let usage_raw: Vec<String> = self
61 .client
62 .cmd("HGETALL")
63 .arg(&usage_key)
64 .execute()
65 .await
66 .unwrap_or_default();
67
68 let limits_raw: Vec<String> = self
69 .client
70 .cmd("HGETALL")
71 .arg(&limits_key)
72 .execute()
73 .await
74 .unwrap_or_default();
75
76 let status = evaluate_budget_status(&usage_raw, &limits_raw);
77
78 self.cycle_cache.insert(budget_id.clone(), status.clone());
79 Ok(status)
80 }
81
82 pub fn clear_cache(&mut self) {
84 self.cycle_cache.clear();
85 }
86
87 pub fn cache_size(&self) -> usize {
89 self.cycle_cache.len()
90 }
91}
92
93fn evaluate_budget_status(usage_raw: &[String], limits_raw: &[String]) -> BudgetStatus {
95 if limits_raw.is_empty() {
96 return BudgetStatus::default(); }
98
99 let mut hard_breached = false;
100 let mut soft_breached = false;
101
102 let mut i = 0;
104 while i + 1 < limits_raw.len() {
105 let field = &limits_raw[i];
106 let limit_val: i64 = limits_raw[i + 1].parse().unwrap_or(i64::MAX);
107 i += 2;
108
109 if let Some(dim) = field.strip_prefix("hard:") {
111 let current = find_usage(usage_raw, dim);
112 if current >= limit_val {
113 hard_breached = true;
114 }
115 } else if let Some(dim) = field.strip_prefix("soft:") {
116 let current = find_usage(usage_raw, dim);
117 if current >= limit_val {
118 soft_breached = true;
119 }
120 }
121 }
122
123 for j in (0..usage_raw.len()).step_by(2) {
125 if usage_raw[j] == "breached_at" {
126 hard_breached = true;
127 break;
128 }
129 }
130
131 BudgetStatus {
132 hard_breached,
133 soft_breached,
134 }
135}
136
/// Looks up the usage recorded for `dimension` in a flat
/// `field, value, field, value, ...` list.
///
/// Only complete field/value pairs are considered and the first pair whose
/// field equals `dimension` decides the result. Returns `0` when there is no
/// such pair or when its value does not parse as an `i64`.
fn find_usage(usage_raw: &[String], dimension: &str) -> i64 {
    usage_raw
        .chunks_exact(2)
        .find(|pair| pair[0] == dimension)
        .and_then(|pair| pair[1].parse().ok())
        .unwrap_or(0)
}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the flat field/value list shape that `evaluate_budget_status` takes.
    fn flat(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[test]
    fn no_limits_means_not_breached() {
        let outcome = evaluate_budget_status(&[], &[]);
        assert!(!outcome.hard_breached);
        assert!(!outcome.soft_breached);
    }

    #[test]
    fn hard_breach_detected() {
        // Usage is above the hard limit.
        let outcome = evaluate_budget_status(
            &flat(&["tokens", "1000"]),
            &flat(&["hard:tokens", "500"]),
        );
        assert!(outcome.hard_breached);
    }

    #[test]
    fn under_limit_not_breached() {
        // Usage is below the hard limit.
        let outcome = evaluate_budget_status(
            &flat(&["tokens", "100"]),
            &flat(&["hard:tokens", "500"]),
        );
        assert!(!outcome.hard_breached);
    }

    #[test]
    fn soft_breach_detected() {
        // Usage sits between the soft and the hard limit.
        let recorded = flat(&["cost_cents", "8000"]);
        let configured = flat(&["hard:cost_cents", "10000", "soft:cost_cents", "7500"]);
        let outcome = evaluate_budget_status(&recorded, &configured);
        assert!(!outcome.hard_breached);
        assert!(outcome.soft_breached);
    }
}