1use std::collections::HashSet;
10
11use crate::kernel::action::{Action, ActionError, ActionErrorKind};
12use crate::kernel::identity::RunId;
13use crate::kernel::KernelError;
14
/// Caller-supplied context passed to [`Policy::authorize`] alongside the
/// action being authorized.
#[derive(Clone, Debug, Default)]
pub struct PolicyCtx {
    /// Identity of the user on whose behalf the run executes, if known.
    pub user_id: Option<String>,
    /// Free-form key/value metadata attached to the authorization request.
    pub metadata: std::collections::HashMap<String, String>,
}
21
/// Outcome of a retry-policy query after an action has failed.
#[derive(Clone, Debug)]
pub enum RetryDecision {
    /// Retry immediately, with no delay.
    Retry,
    /// Retry after waiting at least the given number of milliseconds.
    RetryAfterMs(u64),
    /// Give up and surface the error to the caller.
    Fail,
}
29
/// Resource limits a policy may impose on a run.
///
/// `None` means "unlimited" for that dimension (the `Default` impl leaves
/// both fields unset).
#[derive(Clone, Debug, Default)]
pub struct BudgetRules {
    /// Maximum number of tool calls permitted for the run, if any.
    pub max_tool_calls: Option<u64>,
    /// Maximum number of LLM tokens permitted for the run, if any.
    pub max_llm_tokens: Option<u64>,
}
36
37pub trait Policy: Send + Sync {
39 fn authorize(
41 &self,
42 run_id: &RunId,
43 action: &Action,
44 ctx: &PolicyCtx,
45 ) -> Result<(), KernelError>;
46
47 fn retry_strategy(&self, err: &dyn std::fmt::Display, _action: &Action) -> RetryDecision {
49 let _ = err;
50 RetryDecision::Fail
51 }
52
53 fn retry_strategy_attempt(
58 &self,
59 err: &ActionError,
60 action: &Action,
61 attempt: u32,
62 ) -> RetryDecision {
63 let _ = (action, attempt);
64 match &err.kind {
65 ActionErrorKind::Permanent => RetryDecision::Fail,
66 ActionErrorKind::Transient | ActionErrorKind::RateLimited => {
67 if let ActionErrorKind::RateLimited = &err.kind {
69 if let Some(ms) = err.retry_after_ms {
70 return RetryDecision::RetryAfterMs(ms);
71 }
72 }
73 RetryDecision::Fail
74 }
75 }
76 }
77
78 fn budget(&self) -> BudgetRules {
80 BudgetRules::default()
81 }
82}
83
/// Policy that permits only explicitly listed tools and LLM providers.
///
/// Actions not on the relevant allow-list are rejected by
/// [`Policy::authorize`]; passive actions (`Sleep`, `WaitSignal`) are
/// always permitted.
pub struct AllowListPolicy {
    /// Tool names permitted for `Action::CallTool`.
    pub allowed_tools: HashSet<String>,
    /// Provider names permitted for `Action::CallLLM`.
    pub allowed_providers: HashSet<String>,
}
91
92impl AllowListPolicy {
93 pub fn new(allowed_tools: HashSet<String>, allowed_providers: HashSet<String>) -> Self {
94 Self {
95 allowed_tools,
96 allowed_providers,
97 }
98 }
99
100 pub fn tools_only(tools: impl IntoIterator<Item = String>) -> Self {
101 Self {
102 allowed_tools: tools.into_iter().collect(),
103 allowed_providers: HashSet::new(),
104 }
105 }
106
107 pub fn providers_only(providers: impl IntoIterator<Item = String>) -> Self {
108 Self {
109 allowed_tools: HashSet::new(),
110 allowed_providers: providers.into_iter().collect(),
111 }
112 }
113}
114
115impl Policy for AllowListPolicy {
116 fn authorize(
117 &self,
118 _run_id: &RunId,
119 action: &Action,
120 _ctx: &PolicyCtx,
121 ) -> Result<(), KernelError> {
122 match action {
123 Action::CallTool { tool, .. } => {
124 if self.allowed_tools.contains(tool) {
125 Ok(())
126 } else {
127 Err(KernelError::Policy(format!("tool not allowed: {}", tool)))
128 }
129 }
130 Action::CallLLM { provider, .. } => {
131 if self.allowed_providers.contains(provider) {
132 Ok(())
133 } else {
134 Err(KernelError::Policy(format!(
135 "provider not allowed: {}",
136 provider
137 )))
138 }
139 }
140 Action::Sleep { .. } | Action::WaitSignal { .. } => Ok(()),
141 }
142 }
143}
144
/// Decorator that adds bounded, exponential-backoff retry behavior around
/// an inner [`Policy`]. Authorization and budget queries are delegated to
/// `inner` unchanged.
pub struct RetryWithBackoffPolicy<P> {
    /// Wrapped policy that handles authorization and budgets.
    pub inner: P,
    /// Maximum number of retries before failing.
    pub max_retries: u32,
    /// Base delay in milliseconds; doubled for each subsequent attempt.
    pub backoff_base_ms: u64,
    /// Optional upper bound applied to the computed exponential delay.
    pub backoff_cap_ms: Option<u64>,
    /// Fraction (0.0–1.0) of the capped delay added as deterministic
    /// jitter; 0.0 disables jitter entirely.
    pub jitter_ratio: f64,
}
157
158impl<P: Policy> RetryWithBackoffPolicy<P> {
159 pub fn new(inner: P, max_retries: u32, backoff_ms: u64) -> Self {
161 Self {
162 inner,
163 max_retries,
164 backoff_base_ms: backoff_ms,
165 backoff_cap_ms: None,
166 jitter_ratio: 0.0,
167 }
168 }
169
170 pub fn with_exponential_backoff(
172 inner: P,
173 max_retries: u32,
174 backoff_base_ms: u64,
175 backoff_cap_ms: Option<u64>,
176 jitter_ratio: f64,
177 ) -> Self {
178 Self {
179 inner,
180 max_retries,
181 backoff_base_ms,
182 backoff_cap_ms,
183 jitter_ratio,
184 }
185 }
186
187 fn delay_ms(&self, err: &ActionError, attempt: u32) -> u64 {
188 if matches!(err.kind, ActionErrorKind::RateLimited) && err.retry_after_ms.is_some() {
189 return err.retry_after_ms.unwrap();
190 }
191 let exp = self
192 .backoff_base_ms
193 .saturating_mul(2_u64.saturating_pow(attempt));
194 let capped = match self.backoff_cap_ms {
195 Some(cap) => std::cmp::min(exp, cap),
196 None => exp,
197 };
198 if self.jitter_ratio <= 0.0 {
199 return capped;
200 }
201 let jitter_factor = ((attempt.wrapping_mul(31)) % 100) as f64 / 100.0;
203 let jitter = capped as f64 * self.jitter_ratio * jitter_factor;
204 (capped as f64 + jitter) as u64
205 }
206}
207
208impl<P: Policy> Policy for RetryWithBackoffPolicy<P> {
209 fn authorize(
210 &self,
211 run_id: &RunId,
212 action: &Action,
213 ctx: &PolicyCtx,
214 ) -> Result<(), KernelError> {
215 self.inner.authorize(run_id, action, ctx)
216 }
217
218 fn retry_strategy_attempt(
219 &self,
220 err: &ActionError,
221 action: &Action,
222 attempt: u32,
223 ) -> RetryDecision {
224 if matches!(err.kind, ActionErrorKind::Permanent) {
225 return RetryDecision::Fail;
226 }
227 if attempt < self.max_retries {
228 RetryDecision::RetryAfterMs(self.delay_ms(err, attempt))
229 } else {
230 let _ = action;
231 RetryDecision::Fail
232 }
233 }
234
235 fn budget(&self) -> BudgetRules {
236 self.inner.budget()
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::action::ActionError;

    /// A permanent error is never retried, even on the first attempt
    /// (default `Policy::retry_strategy_attempt`).
    #[test]
    fn permanent_error_returns_fail() {
        let err = ActionError::permanent("bad request");
        let policy = AllowListPolicy::tools_only(std::iter::once("t1".to_string()));
        let action = Action::CallTool {
            tool: "t1".into(),
            input: serde_json::json!(null),
        };
        let d = policy.retry_strategy_attempt(&err, &action, 0);
        assert!(matches!(d, RetryDecision::Fail));
    }

    /// Transient errors retry with exponential delay (10 ms, then 20 ms)
    /// until `max_retries` (2) attempts are exhausted.
    #[test]
    fn retry_with_backoff_transient_retries_then_fails() {
        let inner = AllowListPolicy::tools_only(std::iter::once("t1".to_string()));
        let policy = RetryWithBackoffPolicy::new(inner, 2, 10);
        let err = ActionError::transient("timeout");
        let action = Action::CallTool {
            tool: "t1".into(),
            input: serde_json::json!(null),
        };
        // attempt 0: base * 2^0 = 10 ms
        assert!(matches!(
            policy.retry_strategy_attempt(&err, &action, 0),
            RetryDecision::RetryAfterMs(10)
        ));
        // attempt 1: still under max_retries, so another delayed retry
        assert!(matches!(
            policy.retry_strategy_attempt(&err, &action, 1),
            RetryDecision::RetryAfterMs(_)
        ));
        // attempt 2 == max_retries: give up
        assert!(matches!(
            policy.retry_strategy_attempt(&err, &action, 2),
            RetryDecision::Fail
        ));
    }

    /// A rate-limited error's explicit `retry_after_ms` hint (2500 ms)
    /// overrides the computed backoff schedule.
    #[test]
    fn retry_with_backoff_rate_limited_uses_retry_after_ms() {
        let inner = AllowListPolicy::tools_only(std::iter::empty());
        let policy = RetryWithBackoffPolicy::new(inner, 3, 100);
        let err = ActionError::rate_limited("429", 2500);
        let action = Action::CallLLM {
            provider: "p1".into(),
            input: serde_json::json!(null),
        };
        let d = policy.retry_strategy_attempt(&err, &action, 0);
        assert!(matches!(d, RetryDecision::RetryAfterMs(2500)));
    }

    /// With jitter disabled (0.0) the delays double each attempt:
    /// 50 → 100 → 200 ms (cap of 500 not yet reached).
    #[test]
    fn retry_with_backoff_exponential_increases() {
        let inner = AllowListPolicy::tools_only(std::iter::once("t1".to_string()));
        let policy = RetryWithBackoffPolicy::with_exponential_backoff(inner, 5, 50, Some(500), 0.0);
        let err = ActionError::transient("timeout");
        let action = Action::CallTool {
            tool: "t1".into(),
            input: serde_json::json!(null),
        };
        let d0 = policy.retry_strategy_attempt(&err, &action, 0);
        let d1 = policy.retry_strategy_attempt(&err, &action, 1);
        let d2 = policy.retry_strategy_attempt(&err, &action, 2);
        // Extract the delay from each decision (0 if not a delayed retry,
        // which would then fail the equality assertions below).
        let ms0 = match &d0 {
            RetryDecision::RetryAfterMs(m) => *m,
            _ => 0,
        };
        let ms1 = match &d1 {
            RetryDecision::RetryAfterMs(m) => *m,
            _ => 0,
        };
        let ms2 = match &d2 {
            RetryDecision::RetryAfterMs(m) => *m,
            _ => 0,
        };
        assert!(ms0 == 50);
        assert!(ms1 == 100);
        assert!(ms2 == 200);
    }
}