1use crate::billing::{BudgetTracker, CostEstimate};
2use crate::protocol::{ErrorCode, ErrorDetails};
3
/// Per-invocation settings consulted by [`ExecutionPolicy`] checks.
///
/// `ExecutionContext::default()` is an interactive, untraced context with no
/// cost or budget limits and dry-run disabled.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ExecutionContext {
    /// When `true`, [`ExecutionPolicy::check_interaction_required`] returns an
    /// error instead of allowing an interactive step.
    pub non_interactive: bool,
    /// Optional trace identifier. Not read in this module; presumably used to
    /// correlate logs/requests — confirm with callers.
    pub trace_id: Option<String>,
    /// Per-operation ceiling in credits. [`ExecutionPolicy::check_max_cost`]
    /// rejects costs strictly above it; `None` disables that check.
    pub max_cost_credits: Option<u32>,
    /// Daily credit limit. It is only reported, as `dailyLimit`, by
    /// [`ExecutionPolicy::check_daily_budget`]; the pass/fail decision there
    /// is made by the `BudgetTracker`.
    pub budget_daily_credits: Option<u32>,
    /// Dry-run flag. Not read in this module; presumably requests a simulated
    /// run without side effects — confirm with callers.
    pub dry_run: bool,
}

impl ExecutionContext {
    /// Creates a context whose fields are set from the arguments of the same
    /// name, in declaration order.
    ///
    /// `new(false, None, None, None, false)` equals
    /// [`ExecutionContext::default`].
    pub fn new(
        non_interactive: bool,
        trace_id: Option<String>,
        max_cost_credits: Option<u32>,
        budget_daily_credits: Option<u32>,
        dry_run: bool,
    ) -> Self {
        Self {
            non_interactive,
            trace_id,
            max_cost_credits,
            budget_daily_credits,
            dry_run,
        }
    }
}
37
38#[derive(Debug, Clone, Default)]
40pub struct ExecutionPolicy;
41
42impl ExecutionPolicy {
43 pub fn new() -> Self {
45 Self
46 }
47
48 pub fn check_interaction_required(
69 &self,
70 ctx: &ExecutionContext,
71 message: impl Into<String>,
72 next_steps: Vec<String>,
73 ) -> Option<ErrorDetails> {
74 if ctx.non_interactive {
75 Some(ErrorDetails::auth_required(message, next_steps))
76 } else {
77 None
78 }
79 }
80
81 pub fn check_max_cost(
83 &self,
84 ctx: &ExecutionContext,
85 cost: &CostEstimate,
86 ) -> Option<ErrorDetails> {
87 if let Some(max) = ctx.max_cost_credits {
88 if cost.credits > max {
89 let mut details = std::collections::HashMap::new();
90 details.insert("cost".to_string(), serde_json::json!(cost.credits));
91 details.insert("limit".to_string(), serde_json::json!(max));
92 return Some(ErrorDetails::with_details(
93 ErrorCode::CostLimitExceeded,
94 format!(
95 "Operation cost {} credits exceeds maximum {} credits",
96 cost.credits, max
97 ),
98 details,
99 ));
100 }
101 }
102 None
103 }
104
105 pub fn check_daily_budget(
107 &self,
108 ctx: &ExecutionContext,
109 cost: &CostEstimate,
110 tracker: &BudgetTracker,
111 ) -> Option<ErrorDetails> {
112 if tracker.check_budget(cost.credits).is_err() {
113 let mut details = std::collections::HashMap::new();
114 details.insert("cost".to_string(), serde_json::json!(cost.credits));
115 details.insert(
116 "todayUsage".to_string(),
117 serde_json::json!(tracker.today_usage()),
118 );
119 if let Some(limit) = ctx.budget_daily_credits {
120 details.insert("dailyLimit".to_string(), serde_json::json!(limit));
121 }
122 return Some(ErrorDetails::with_details(
123 ErrorCode::DailyBudgetExceeded,
124 "Daily budget exceeded".to_string(),
125 details,
126 ));
127 }
128 None
129 }
130}
131
#[cfg(test)]
mod tests {
    // `super::*` already brings `ErrorCode` (and every other parent import)
    // into scope, so no separate `crate::protocol` import is needed.
    use super::*;

    /// Interactive, untraced, non-dry-run context with the given limits.
    fn limited_ctx(max_cost: Option<u32>, daily_budget: Option<u32>) -> ExecutionContext {
        ExecutionContext::new(false, None, max_cost, daily_budget, false)
    }

    #[test]
    fn test_context_creation() {
        let ctx = ExecutionContext::new(true, Some("trace-123".to_string()), None, None, false);
        assert!(ctx.non_interactive);
        assert_eq!(ctx.trace_id.as_deref(), Some("trace-123"));
        assert_eq!(ctx.max_cost_credits, None);
        assert_eq!(ctx.budget_daily_credits, None);
        assert!(!ctx.dry_run);
    }

    #[test]
    fn test_check_interaction_required_non_interactive() {
        let ctx = ExecutionContext::new(true, None, None, None, false);
        let err = ExecutionPolicy::new()
            .check_interaction_required(
                &ctx,
                "Auth required",
                vec!["Run login command".to_string()],
            )
            .expect("non-interactive context must refuse interaction");
        assert_eq!(err.code, ErrorCode::AuthRequired);
        assert!(!err.is_retryable);
        assert!(err.details.is_some());
    }

    #[test]
    fn test_check_interaction_required_interactive() {
        let ctx = limited_ctx(None, None);
        let outcome = ExecutionPolicy::new().check_interaction_required(
            &ctx,
            "Auth required",
            vec!["Run login command".to_string()],
        );
        assert!(outcome.is_none());
    }

    #[test]
    fn test_check_max_cost_within_limit() {
        let ctx = limited_ctx(Some(100), None);
        let cost = CostEstimate::new(50, 0.05);
        assert!(ExecutionPolicy::new().check_max_cost(&ctx, &cost).is_none());
    }

    #[test]
    fn test_check_max_cost_at_limit() {
        // The comparison is strict: spending exactly the maximum is allowed.
        let ctx = limited_ctx(Some(100), None);
        let cost = CostEstimate::new(100, 0.1);
        assert!(ExecutionPolicy::new().check_max_cost(&ctx, &cost).is_none());
    }

    #[test]
    fn test_check_max_cost_without_limit() {
        // No configured maximum means any cost passes.
        let ctx = limited_ctx(None, None);
        let cost = CostEstimate::new(1_000_000, 1000.0);
        assert!(ExecutionPolicy::new().check_max_cost(&ctx, &cost).is_none());
    }

    #[test]
    fn test_check_max_cost_exceeds_limit() {
        let ctx = limited_ctx(Some(100), None);
        let cost = CostEstimate::new(101, 0.101);
        let err = ExecutionPolicy::new()
            .check_max_cost(&ctx, &cost)
            .expect("cost above the maximum must be rejected");
        assert_eq!(err.code, ErrorCode::CostLimitExceeded);
    }

    #[test]
    fn test_check_daily_budget_within_limit() {
        let ctx = limited_ctx(None, Some(100));
        let tracker = BudgetTracker::new(Some(100));
        let cost = CostEstimate::new(50, 0.05);
        assert!(ExecutionPolicy::new()
            .check_daily_budget(&ctx, &cost, &tracker)
            .is_none());
    }

    #[test]
    fn test_check_daily_budget_exceeds_limit() {
        let ctx = limited_ctx(None, Some(100));
        let tracker = BudgetTracker::new(Some(100));
        let cost = CostEstimate::new(101, 0.101);
        let err = ExecutionPolicy::new()
            .check_daily_budget(&ctx, &cost, &tracker)
            .expect("cost above the daily budget must be rejected");
        assert_eq!(err.code, ErrorCode::DailyBudgetExceeded);
    }
}