use std::collections::HashSet;
use crate::kernel::action::{Action, ActionError, ActionErrorKind};
use crate::kernel::identity::RunId;
use crate::kernel::KernelError;
#[derive(Clone, Debug, Default)]
pub struct PolicyCtx {
pub user_id: Option<String>,
pub metadata: std::collections::HashMap<String, String>,
}
#[derive(Clone, Debug)]
pub enum RetryDecision {
Retry,
RetryAfterMs(u64),
Fail,
}
#[derive(Clone, Debug, Default)]
pub struct BudgetRules {
pub max_tool_calls: Option<u64>,
pub max_llm_tokens: Option<u64>,
}
pub trait Policy: Send + Sync {
fn authorize(
&self,
run_id: &RunId,
action: &Action,
ctx: &PolicyCtx,
) -> Result<(), KernelError>;
fn retry_strategy(&self, err: &dyn std::fmt::Display, _action: &Action) -> RetryDecision {
let _ = err;
RetryDecision::Fail
}
fn retry_strategy_attempt(
&self,
err: &ActionError,
action: &Action,
attempt: u32,
) -> RetryDecision {
let _ = (action, attempt);
match &err.kind {
ActionErrorKind::Permanent => RetryDecision::Fail,
ActionErrorKind::Transient | ActionErrorKind::RateLimited => {
if let ActionErrorKind::RateLimited = &err.kind {
if let Some(ms) = err.retry_after_ms {
return RetryDecision::RetryAfterMs(ms);
}
}
RetryDecision::Fail
}
}
}
fn budget(&self) -> BudgetRules {
BudgetRules::default()
}
}
pub struct AllowListPolicy {
pub allowed_tools: HashSet<String>,
pub allowed_providers: HashSet<String>,
}
impl AllowListPolicy {
pub fn new(allowed_tools: HashSet<String>, allowed_providers: HashSet<String>) -> Self {
Self {
allowed_tools,
allowed_providers,
}
}
pub fn tools_only(tools: impl IntoIterator<Item = String>) -> Self {
Self {
allowed_tools: tools.into_iter().collect(),
allowed_providers: HashSet::new(),
}
}
pub fn providers_only(providers: impl IntoIterator<Item = String>) -> Self {
Self {
allowed_tools: HashSet::new(),
allowed_providers: providers.into_iter().collect(),
}
}
}
impl Policy for AllowListPolicy {
fn authorize(
&self,
_run_id: &RunId,
action: &Action,
_ctx: &PolicyCtx,
) -> Result<(), KernelError> {
match action {
Action::CallTool { tool, .. } => {
if self.allowed_tools.contains(tool) {
Ok(())
} else {
Err(KernelError::Policy(format!("tool not allowed: {}", tool)))
}
}
Action::CallLLM { provider, .. } => {
if self.allowed_providers.contains(provider) {
Ok(())
} else {
Err(KernelError::Policy(format!(
"provider not allowed: {}",
provider
)))
}
}
Action::Sleep { .. } | Action::WaitSignal { .. } => Ok(()),
}
}
}
pub struct RetryWithBackoffPolicy<P> {
pub inner: P,
pub max_retries: u32,
pub backoff_base_ms: u64,
pub backoff_cap_ms: Option<u64>,
pub jitter_ratio: f64,
}
impl<P: Policy> RetryWithBackoffPolicy<P> {
pub fn new(inner: P, max_retries: u32, backoff_ms: u64) -> Self {
Self {
inner,
max_retries,
backoff_base_ms: backoff_ms,
backoff_cap_ms: None,
jitter_ratio: 0.0,
}
}
pub fn with_exponential_backoff(
inner: P,
max_retries: u32,
backoff_base_ms: u64,
backoff_cap_ms: Option<u64>,
jitter_ratio: f64,
) -> Self {
Self {
inner,
max_retries,
backoff_base_ms,
backoff_cap_ms,
jitter_ratio,
}
}
fn delay_ms(&self, err: &ActionError, attempt: u32) -> u64 {
if matches!(err.kind, ActionErrorKind::RateLimited) && err.retry_after_ms.is_some() {
return err.retry_after_ms.unwrap();
}
let exp = self
.backoff_base_ms
.saturating_mul(2_u64.saturating_pow(attempt));
let capped = match self.backoff_cap_ms {
Some(cap) => std::cmp::min(exp, cap),
None => exp,
};
if self.jitter_ratio <= 0.0 {
return capped;
}
let jitter_factor = ((attempt.wrapping_mul(31)) % 100) as f64 / 100.0;
let jitter = capped as f64 * self.jitter_ratio * jitter_factor;
(capped as f64 + jitter) as u64
}
}
impl<P: Policy> Policy for RetryWithBackoffPolicy<P> {
fn authorize(
&self,
run_id: &RunId,
action: &Action,
ctx: &PolicyCtx,
) -> Result<(), KernelError> {
self.inner.authorize(run_id, action, ctx)
}
fn retry_strategy_attempt(
&self,
err: &ActionError,
action: &Action,
attempt: u32,
) -> RetryDecision {
if matches!(err.kind, ActionErrorKind::Permanent) {
return RetryDecision::Fail;
}
if attempt < self.max_retries {
RetryDecision::RetryAfterMs(self.delay_ms(err, attempt))
} else {
let _ = action;
RetryDecision::Fail
}
}
fn budget(&self) -> BudgetRules {
self.inner.budget()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::kernel::action::ActionError;
#[test]
fn permanent_error_returns_fail() {
let err = ActionError::permanent("bad request");
let policy = AllowListPolicy::tools_only(std::iter::once("t1".to_string()));
let action = Action::CallTool {
tool: "t1".into(),
input: serde_json::json!(null),
};
let d = policy.retry_strategy_attempt(&err, &action, 0);
assert!(matches!(d, RetryDecision::Fail));
}
#[test]
fn retry_with_backoff_transient_retries_then_fails() {
let inner = AllowListPolicy::tools_only(std::iter::once("t1".to_string()));
let policy = RetryWithBackoffPolicy::new(inner, 2, 10);
let err = ActionError::transient("timeout");
let action = Action::CallTool {
tool: "t1".into(),
input: serde_json::json!(null),
};
assert!(matches!(
policy.retry_strategy_attempt(&err, &action, 0),
RetryDecision::RetryAfterMs(10)
));
assert!(matches!(
policy.retry_strategy_attempt(&err, &action, 1),
RetryDecision::RetryAfterMs(_)
));
assert!(matches!(
policy.retry_strategy_attempt(&err, &action, 2),
RetryDecision::Fail
));
}
#[test]
fn retry_with_backoff_rate_limited_uses_retry_after_ms() {
let inner = AllowListPolicy::tools_only(std::iter::empty());
let policy = RetryWithBackoffPolicy::new(inner, 3, 100);
let err = ActionError::rate_limited("429", 2500);
let action = Action::CallLLM {
provider: "p1".into(),
input: serde_json::json!(null),
};
let d = policy.retry_strategy_attempt(&err, &action, 0);
assert!(matches!(d, RetryDecision::RetryAfterMs(2500)));
}
#[test]
fn retry_with_backoff_exponential_increases() {
let inner = AllowListPolicy::tools_only(std::iter::once("t1".to_string()));
let policy = RetryWithBackoffPolicy::with_exponential_backoff(inner, 5, 50, Some(500), 0.0);
let err = ActionError::transient("timeout");
let action = Action::CallTool {
tool: "t1".into(),
input: serde_json::json!(null),
};
let d0 = policy.retry_strategy_attempt(&err, &action, 0);
let d1 = policy.retry_strategy_attempt(&err, &action, 1);
let d2 = policy.retry_strategy_attempt(&err, &action, 2);
let ms0 = match &d0 {
RetryDecision::RetryAfterMs(m) => *m,
_ => 0,
};
let ms1 = match &d1 {
RetryDecision::RetryAfterMs(m) => *m,
_ => 0,
};
let ms2 = match &d2 {
RetryDecision::RetryAfterMs(m) => *m,
_ => 0,
};
assert!(ms0 == 50);
assert!(ms1 == 100);
assert!(ms2 == 200);
}
}