Skip to main content

oris_kernel/kernel/
policy.rs

1//! Policy: governance layer (authorize, retry, budget).
2//!
3//! Must exist even as a minimal implementation so Oris is not a "run any tool" demo.
4//!
5//! **Retry loop:** The driver calls `retry_strategy_attempt` on executor `Err` and only stops when
6//! the policy returns `Fail`. Implementations must eventually return `Fail` or the loop would not
7//! terminate; `RetryWithBackoffPolicy` does so after `max_retries` attempts.
8
9use std::collections::HashSet;
10
11use crate::kernel::action::{Action, ActionError, ActionErrorKind};
12use crate::kernel::identity::RunId;
13use crate::kernel::KernelError;
14
/// Context passed to policy (e.g. caller identity, run metadata).
#[derive(Clone, Debug, Default)]
pub struct PolicyCtx {
    // Identity of the caller, if known; available to `authorize` for per-user decisions.
    pub user_id: Option<String>,
    // Free-form key/value metadata about the run; semantics are caller-defined.
    pub metadata: std::collections::HashMap<String, String>,
}
21
/// Decision after an action failure (retry, backoff, or fail).
#[derive(Clone, Debug)]
pub enum RetryDecision {
    /// Retry immediately, with no delay.
    Retry,
    /// Retry after waiting the given number of milliseconds.
    RetryAfterMs(u64),
    /// Stop retrying and surface the failure to the driver.
    Fail,
}
29
/// Optional budget rules (cost, token limits, etc.).
#[derive(Clone, Debug, Default)]
pub struct BudgetRules {
    /// Maximum number of tool calls for a run; `None` = no limit (the default).
    pub max_tool_calls: Option<u64>,
    /// Maximum total LLM tokens for a run; `None` = no limit (the default).
    pub max_llm_tokens: Option<u64>,
}
36
37/// Policy: authorize actions, decide retries, optional budget.
38pub trait Policy: Send + Sync {
39    /// Whether the action is allowed for this run and context.
40    fn authorize(
41        &self,
42        run_id: &RunId,
43        action: &Action,
44        ctx: &PolicyCtx,
45    ) -> Result<(), KernelError>;
46
47    /// Whether to retry after an error (and optionally after a delay).
48    fn retry_strategy(&self, err: &dyn std::fmt::Display, _action: &Action) -> RetryDecision {
49        let _ = err;
50        RetryDecision::Fail
51    }
52
53    /// Retry strategy with attempt count and structured error. Default uses kind: Permanent => Fail,
54    /// others may be retried by implementations. Applies only to executor `Err`; `ActionResult::Failure` is not retried.
55    ///
56    /// `attempt` is the 0-based count of failures so far. Return `Fail` when no more retries are desired.
57    fn retry_strategy_attempt(
58        &self,
59        err: &ActionError,
60        action: &Action,
61        attempt: u32,
62    ) -> RetryDecision {
63        let _ = (action, attempt);
64        match &err.kind {
65            ActionErrorKind::Permanent => RetryDecision::Fail,
66            ActionErrorKind::Transient | ActionErrorKind::RateLimited => {
67                // Default: no retry unless overridden
68                if let ActionErrorKind::RateLimited = &err.kind {
69                    if let Some(ms) = err.retry_after_ms {
70                        return RetryDecision::RetryAfterMs(ms);
71                    }
72                }
73                RetryDecision::Fail
74            }
75        }
76    }
77
78    /// Optional budget; default is no limits.
79    fn budget(&self) -> BudgetRules {
80        BudgetRules::default()
81    }
82}
83
/// Policy that allows only actions whose tool/provider is in the given sets.
/// **Empty set = no tools or providers allowed** for that category. Sleep and WaitSignal are
/// always allowed. To allow all tools/providers use `AllowAllPolicy`, or populate the sets explicitly.
pub struct AllowListPolicy {
    // Tool names permitted for `Action::CallTool`.
    pub allowed_tools: HashSet<String>,
    // Provider names permitted for `Action::CallLLM`.
    pub allowed_providers: HashSet<String>,
}
91
92impl AllowListPolicy {
93    pub fn new(allowed_tools: HashSet<String>, allowed_providers: HashSet<String>) -> Self {
94        Self {
95            allowed_tools,
96            allowed_providers,
97        }
98    }
99
100    pub fn tools_only(tools: impl IntoIterator<Item = String>) -> Self {
101        Self {
102            allowed_tools: tools.into_iter().collect(),
103            allowed_providers: HashSet::new(),
104        }
105    }
106
107    pub fn providers_only(providers: impl IntoIterator<Item = String>) -> Self {
108        Self {
109            allowed_tools: HashSet::new(),
110            allowed_providers: providers.into_iter().collect(),
111        }
112    }
113}
114
115impl Policy for AllowListPolicy {
116    fn authorize(
117        &self,
118        _run_id: &RunId,
119        action: &Action,
120        _ctx: &PolicyCtx,
121    ) -> Result<(), KernelError> {
122        match action {
123            Action::CallTool { tool, .. } => {
124                if self.allowed_tools.contains(tool) {
125                    Ok(())
126                } else {
127                    Err(KernelError::Policy(format!("tool not allowed: {}", tool)))
128                }
129            }
130            Action::CallLLM { provider, .. } => {
131                if self.allowed_providers.contains(provider) {
132                    Ok(())
133                } else {
134                    Err(KernelError::Policy(format!(
135                        "provider not allowed: {}",
136                        provider
137                    )))
138                }
139            }
140            Action::Sleep { .. } | Action::WaitSignal { .. } => Ok(()),
141        }
142    }
143}
144
/// Policy that returns RetryAfterMs with exponential backoff (and optional jitter) for the first
/// max_retries attempts, then Fail. For RateLimited errors with retry_after_ms, uses that value when set.
pub struct RetryWithBackoffPolicy<P> {
    /// Wrapped policy; `authorize` and `budget` are delegated to it unchanged.
    pub inner: P,
    /// Number of retries before giving up; attempts numbered >= this fail.
    pub max_retries: u32,
    /// Base delay in ms; actual delay = min(cap, backoff_base_ms * 2^attempt) + jitter.
    pub backoff_base_ms: u64,
    /// Max delay cap in ms (optional).
    pub backoff_cap_ms: Option<u64>,
    /// Jitter ratio in [0.0, 1.0]; added randomness to avoid thundering herd.
    pub jitter_ratio: f64,
}
157
158impl<P: Policy> RetryWithBackoffPolicy<P> {
159    /// New with fixed backoff (no exponent, no jitter). Preserves legacy behavior when backoff_base_ms is the only delay.
160    pub fn new(inner: P, max_retries: u32, backoff_ms: u64) -> Self {
161        Self {
162            inner,
163            max_retries,
164            backoff_base_ms: backoff_ms,
165            backoff_cap_ms: None,
166            jitter_ratio: 0.0,
167        }
168    }
169
170    /// Exponential backoff: base * 2^attempt, capped, plus jitter.
171    pub fn with_exponential_backoff(
172        inner: P,
173        max_retries: u32,
174        backoff_base_ms: u64,
175        backoff_cap_ms: Option<u64>,
176        jitter_ratio: f64,
177    ) -> Self {
178        Self {
179            inner,
180            max_retries,
181            backoff_base_ms,
182            backoff_cap_ms,
183            jitter_ratio,
184        }
185    }
186
187    fn delay_ms(&self, err: &ActionError, attempt: u32) -> u64 {
188        if matches!(err.kind, ActionErrorKind::RateLimited) && err.retry_after_ms.is_some() {
189            return err.retry_after_ms.unwrap();
190        }
191        let exp = self
192            .backoff_base_ms
193            .saturating_mul(2_u64.saturating_pow(attempt));
194        let capped = match self.backoff_cap_ms {
195            Some(cap) => std::cmp::min(exp, cap),
196            None => exp,
197        };
198        if self.jitter_ratio <= 0.0 {
199            return capped;
200        }
201        // Deterministic jitter from attempt to avoid thundering herd without adding rand dep.
202        let jitter_factor = ((attempt.wrapping_mul(31)) % 100) as f64 / 100.0;
203        let jitter = capped as f64 * self.jitter_ratio * jitter_factor;
204        (capped as f64 + jitter) as u64
205    }
206}
207
208impl<P: Policy> Policy for RetryWithBackoffPolicy<P> {
209    fn authorize(
210        &self,
211        run_id: &RunId,
212        action: &Action,
213        ctx: &PolicyCtx,
214    ) -> Result<(), KernelError> {
215        self.inner.authorize(run_id, action, ctx)
216    }
217
218    fn retry_strategy_attempt(
219        &self,
220        err: &ActionError,
221        action: &Action,
222        attempt: u32,
223    ) -> RetryDecision {
224        if matches!(err.kind, ActionErrorKind::Permanent) {
225            return RetryDecision::Fail;
226        }
227        if attempt < self.max_retries {
228            RetryDecision::RetryAfterMs(self.delay_ms(err, attempt))
229        } else {
230            let _ = action;
231            RetryDecision::Fail
232        }
233    }
234
235    fn budget(&self) -> BudgetRules {
236        self.inner.budget()
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::action::ActionError;

    /// Delay carried by a decision, or None when it is not a timed retry.
    fn delay_of(d: &RetryDecision) -> Option<u64> {
        match d {
            RetryDecision::RetryAfterMs(ms) => Some(*ms),
            _ => None,
        }
    }

    /// Minimal tool-call action with a null payload.
    fn tool_action(tool: &str) -> Action {
        Action::CallTool {
            tool: tool.into(),
            input: serde_json::json!(null),
        }
    }

    #[test]
    fn permanent_error_returns_fail() {
        let policy = AllowListPolicy::tools_only(std::iter::once("t1".to_string()));
        let decision = policy.retry_strategy_attempt(
            &ActionError::permanent("bad request"),
            &tool_action("t1"),
            0,
        );
        assert!(matches!(decision, RetryDecision::Fail));
    }

    #[test]
    fn retry_with_backoff_transient_retries_then_fails() {
        let policy = RetryWithBackoffPolicy::new(
            AllowListPolicy::tools_only(std::iter::once("t1".to_string())),
            2,
            10,
        );
        let err = ActionError::transient("timeout");
        let action = tool_action("t1");
        assert_eq!(
            delay_of(&policy.retry_strategy_attempt(&err, &action, 0)),
            Some(10)
        );
        assert!(delay_of(&policy.retry_strategy_attempt(&err, &action, 1)).is_some());
        assert!(matches!(
            policy.retry_strategy_attempt(&err, &action, 2),
            RetryDecision::Fail
        ));
    }

    #[test]
    fn retry_with_backoff_rate_limited_uses_retry_after_ms() {
        let policy =
            RetryWithBackoffPolicy::new(AllowListPolicy::tools_only(std::iter::empty()), 3, 100);
        let err = ActionError::rate_limited("429", 2500);
        let action = Action::CallLLM {
            provider: "p1".into(),
            input: serde_json::json!(null),
        };
        // The server-provided hint overrides the computed backoff entirely.
        assert_eq!(
            delay_of(&policy.retry_strategy_attempt(&err, &action, 0)),
            Some(2500)
        );
    }

    #[test]
    fn retry_with_backoff_exponential_increases() {
        let policy = RetryWithBackoffPolicy::with_exponential_backoff(
            AllowListPolicy::tools_only(std::iter::once("t1".to_string())),
            5,
            50,
            Some(500),
            0.0,
        );
        let err = ActionError::transient("timeout");
        let action = tool_action("t1");
        let delays: Vec<u64> = (0..3)
            .map(|a| delay_of(&policy.retry_strategy_attempt(&err, &action, a)).unwrap_or(0))
            .collect();
        // base 50, doubling per attempt, well under the 500 cap.
        assert_eq!(delays, vec![50, 100, 200]);
    }
}