use crate::types::{BudgetId, TimestampMs};
use serde::{Deserialize, Serialize};
use std::collections::BTreeSet;
/// Byte limit published for use elsewhere in the crate; nothing in this file reads it.
/// NOTE(review): the name suggests a cap on capability payload size — confirm at the use sites.
pub const CAPS_MAX_BYTES: usize = 4 * 1024;
/// Count limit published for use elsewhere in the crate; nothing in this file reads it.
/// NOTE(review): presumably a cap on the number of capability tokens — confirm at the use sites.
pub const CAPS_MAX_TOKENS: usize = 256;
/// Retry configuration for an execution.
///
/// Every field has a serde default, so deserializing `{}` yields the same value
/// as [`RetryPolicy::default`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RetryPolicy {
    /// Retry limit; `3` when omitted.
    #[serde(default = "default_max_retries")]
    pub max_retries: u32,
    /// Delay strategy between attempts; [`BackoffStrategy::default`] when omitted.
    #[serde(default)]
    pub backoff: BackoffStrategy,
    /// Categories eligible for retry; empty when omitted. Always serialized
    /// (as `[]` when empty).
    /// NOTE(review): whether an empty list means "retry nothing" or "retry
    /// everything" is decided by the consumer, not here — confirm.
    #[serde(default)]
    pub retryable_categories: Vec<String>,
}
/// Serde default for `RetryPolicy::max_retries`; also used by `RetryPolicy::default`.
fn default_max_retries() -> u32 {
    const MAX_RETRIES: u32 = 3;
    MAX_RETRIES
}
impl Default for RetryPolicy {
    /// Builds the same value that deserializing `{}` produces: `max_retries`
    /// from `default_max_retries`, the default backoff strategy, and no
    /// retryable categories.
    fn default() -> Self {
        let backoff = BackoffStrategy::default();
        Self {
            backoff,
            retryable_categories: Vec::default(),
            max_retries: default_max_retries(),
        }
    }
}
/// Delay strategy used between retry attempts.
///
/// Serialized as an internally tagged object: the `"type"` key holds the
/// snake_case variant name (`"fixed"` or `"exponential"`) and the variant's
/// fields sit alongside it, e.g. `{"type":"fixed","delay_ms":5000}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum BackoffStrategy {
    /// Constant delay.
    Fixed {
        /// Delay value (`_ms` suffix: milliseconds).
        delay_ms: u64,
    },
    /// Growing delay described by a starting value, a cap and a multiplier.
    /// NOTE(review): the exact growth formula and jitter algorithm live in the
    /// consumer, not in this file — confirm there.
    Exponential {
        initial_delay_ms: u64,
        max_delay_ms: u64,
        multiplier: f64,
        /// Jitter flag; `false` when omitted.
        #[serde(default)]
        jitter: bool,
    },
}
impl Default for BackoffStrategy {
    /// Exponential variant with `initial_delay_ms` 1000, `max_delay_ms` 60000,
    /// `multiplier` 2.0 and jitter disabled.
    fn default() -> Self {
        const INITIAL_DELAY_MS: u64 = 1_000;
        const MAX_DELAY_MS: u64 = 60 * 1_000;
        Self::Exponential {
            jitter: false,
            multiplier: 2.0,
            max_delay_ms: MAX_DELAY_MS,
            initial_delay_ms: INITIAL_DELAY_MS,
        }
    }
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct TimeoutPolicy {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub attempt_timeout_ms: Option<u64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub execution_deadline_ms: Option<u64>,
#[serde(default = "default_max_reclaim_count")]
pub max_reclaim_count: u32,
}
/// Serde default for the `max_reclaim_count` fields of both `TimeoutPolicy`
/// and `ExecutionPolicy`; also used by `ExecutionPolicy::default`.
fn default_max_reclaim_count() -> u32 {
    100_u32
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SuspensionPolicy {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub default_timeout_ms: Option<u64>,
#[serde(default = "default_timeout_behavior")]
pub timeout_behavior: String,
}
/// Serde default for `SuspensionPolicy::timeout_behavior`.
fn default_timeout_behavior() -> String {
    String::from("fail")
}
/// Ordered list of alternative provider/model tiers.
///
/// `tiers` has no serde default, so the key is required whenever a
/// `FallbackPolicy` object is present.
/// NOTE(review): tiers are presumably tried in list order — confirm with the consumer.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FallbackPolicy {
    pub tiers: Vec<FallbackTier>,
}
/// One provider/model pair in a [`FallbackPolicy`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FallbackTier {
    /// Provider identifier (required).
    pub provider: String,
    /// Model identifier (required).
    pub model: String,
    /// Optional per-tier timeout (`_ms` suffix: milliseconds); skipped on
    /// serialization when `None`.
    /// NOTE(review): presumably overrides a policy-level timeout — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timeout_ms: Option<u64>,
}
/// Constraints on where an execution may be routed.
///
/// The derived `Default` (empty capability set, no locality, no isolation
/// level) matches what deserializing `{}` produces.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct RoutingRequirements {
    /// Required capability names. A `BTreeSet`, so entries are deduplicated and
    /// serialized in sorted order; always emitted (as `[]` when empty).
    #[serde(default)]
    pub required_capabilities: BTreeSet<String>,
    /// Optional locality preference; skipped on serialization when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub preferred_locality: Option<String>,
    /// Optional isolation level name; skipped on serialization when `None`.
    /// Allowed values are not validated in this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub isolation_level: Option<String>,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct StreamPolicy {
#[serde(default = "default_durability_mode")]
pub durability_mode: String,
#[serde(default = "default_retention_maxlen")]
pub retention_maxlen: u64,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub retention_ttl_ms: Option<u64>,
}
/// Serde default for `StreamPolicy::durability_mode`.
fn default_durability_mode() -> String {
    String::from("buffered")
}
/// Serde default for `StreamPolicy::retention_maxlen`.
fn default_retention_maxlen() -> u64 {
    10_000_u64
}
/// Top-level policy attached to an execution.
///
/// Every field has a serde default, so deserializing `{}` yields the same value
/// as [`ExecutionPolicy::default`]. `Option` fields are skipped when `None`
/// rather than serialized as `null`; `budget_ids` is always emitted (as `[]`
/// when empty).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ExecutionPolicy {
    /// Scheduling priority; `0` when omitted.
    /// NOTE(review): whether larger values run earlier is decided by the consumer — confirm.
    #[serde(default)]
    pub priority: i32,
    /// Optional delay timestamp.
    /// NOTE(review): presumably the earliest time the execution may start — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub delay_until: Option<TimestampMs>,
    /// Optional retry settings; unset when omitted.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub retry_policy: Option<RetryPolicy>,
    /// Optional timeout settings; unset when omitted.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timeout_policy: Option<TimeoutPolicy>,
    /// Reclaim limit; `100` when omitted.
    /// NOTE(review): `TimeoutPolicy::max_reclaim_count` duplicates this field with
    /// the same default; precedence between them is not decided in this file.
    #[serde(default = "default_max_reclaim_count")]
    pub max_reclaim_count: u32,
    /// Optional suspension settings; unset when omitted.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub suspension_policy: Option<SuspensionPolicy>,
    /// Optional provider/model fallback tiers; unset when omitted.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub fallback_policy: Option<FallbackPolicy>,
    /// Replay limit; `10` when omitted.
    #[serde(default = "default_max_replay_count")]
    pub max_replay_count: u32,
    /// Associated budget identifiers; empty when omitted.
    #[serde(default)]
    pub budget_ids: Vec<BudgetId>,
    /// Optional routing constraints; unset when omitted.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub routing_requirements: Option<RoutingRequirements>,
    /// Optional deduplication window (`_ms` suffix: milliseconds); unset when omitted.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dedup_window_ms: Option<u64>,
    /// Optional stream durability/retention settings; unset when omitted.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stream_policy: Option<StreamPolicy>,
    /// Signal limit per execution; `10_000` when omitted.
    #[serde(default = "default_max_signals")]
    pub max_signals_per_execution: u32,
}
impl Default for ExecutionPolicy {
    /// Returns the same value that deserializing `{}` yields: priority 0, no
    /// optional sub-policies, no budgets, and the module's default limits for
    /// reclaims, replays and signals.
    fn default() -> Self {
        let max_reclaim_count = default_max_reclaim_count();
        let max_replay_count = default_max_replay_count();
        let max_signals_per_execution = default_max_signals();
        Self {
            priority: i32::default(),
            delay_until: None,
            retry_policy: None,
            timeout_policy: None,
            max_reclaim_count,
            suspension_policy: None,
            fallback_policy: None,
            max_replay_count,
            budget_ids: Vec::default(),
            routing_requirements: None,
            dedup_window_ms: None,
            stream_policy: None,
            max_signals_per_execution,
        }
    }
}
/// Serde default for `ExecutionPolicy::max_replay_count`; also used by
/// `ExecutionPolicy::default`.
fn default_max_replay_count() -> u32 {
    10_u32
}
/// Serde default for `ExecutionPolicy::max_signals_per_execution`; also used
/// by `ExecutionPolicy::default`.
fn default_max_signals() -> u32 {
    10_000_u32
}
#[cfg(test)]
mod tests {
    //! Serde contract tests for the execution-policy types.
    //!
    //! Beyond round-trips, these pin two invariants: serialized output carries
    //! no explicit `null`s, and the serde field defaults agree with the
    //! hand-written `Default` impls (the two are maintained separately).
    use super::*;

    #[test]
    fn execution_policy_defaults() {
        let policy = ExecutionPolicy::default();
        assert_eq!(policy.priority, 0);
        assert_eq!(policy.max_reclaim_count, 100);
        assert_eq!(policy.max_replay_count, 10);
        assert_eq!(policy.max_signals_per_execution, 10_000);
        assert!(policy.retry_policy.is_none());
        assert!(policy.timeout_policy.is_none());
    }

    #[test]
    fn retry_policy_serde() {
        let policy = RetryPolicy {
            max_retries: 3,
            backoff: BackoffStrategy::Exponential {
                initial_delay_ms: 100,
                max_delay_ms: 30_000,
                multiplier: 2.0,
                jitter: true,
            },
            retryable_categories: vec!["timeout".into(), "provider_error".into()],
        };
        let json = serde_json::to_string(&policy).unwrap();
        let parsed: RetryPolicy = serde_json::from_str(&json).unwrap();
        assert_eq!(policy, parsed);
    }

    #[test]
    fn timeout_policy_defaults() {
        let json = r#"{"attempt_timeout_ms": 30000}"#;
        let policy: TimeoutPolicy = serde_json::from_str(json).unwrap();
        assert_eq!(policy.attempt_timeout_ms, Some(30_000));
        assert_eq!(policy.max_reclaim_count, 100);
    }

    #[test]
    fn retry_policy_defaults() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_retries, 3);
        assert_eq!(
            policy.backoff,
            BackoffStrategy::Exponential {
                initial_delay_ms: 1000,
                max_delay_ms: 60_000,
                multiplier: 2.0,
                jitter: false,
            }
        );
        assert!(policy.retryable_categories.is_empty());
    }

    #[test]
    fn retry_policy_lua_compatible_json() {
        let policy = RetryPolicy::default();
        let json = serde_json::to_value(&policy).unwrap();
        assert_eq!(json["max_retries"], 3);
        let backoff = &json["backoff"];
        assert_eq!(backoff["type"], "exponential");
        assert_eq!(backoff["initial_delay_ms"], 1000);
        assert_eq!(backoff["max_delay_ms"], 60_000);
        assert_eq!(backoff["multiplier"], 2.0);
        let fixed = RetryPolicy {
            max_retries: 1,
            backoff: BackoffStrategy::Fixed { delay_ms: 5000 },
            retryable_categories: vec![],
        };
        let json = serde_json::to_value(&fixed).unwrap();
        assert_eq!(json["backoff"]["type"], "fixed");
        assert_eq!(json["backoff"]["delay_ms"], 5000);
    }

    #[test]
    fn retry_policy_deserialize_minimal() {
        let json = r#"{"max_retries": 5}"#;
        let policy: RetryPolicy = serde_json::from_str(json).unwrap();
        assert_eq!(policy.max_retries, 5);
        assert_eq!(policy.backoff, BackoffStrategy::default());
    }

    #[test]
    fn default_execution_policy_has_no_nulls() {
        let policy = ExecutionPolicy::default();
        let json = serde_json::to_value(&policy).unwrap();
        let obj = json.as_object().expect("top-level object");
        for (key, value) in obj {
            assert!(
                !value.is_null(),
                "default ExecutionPolicy must not emit null field `{key}` — \
                Lua policy validation rejects null for optional table fields"
            );
        }
        assert_eq!(obj.get("priority"), Some(&serde_json::json!(0)));
        assert_eq!(obj.get("max_reclaim_count"), Some(&serde_json::json!(100)));
    }

    #[test]
    fn partial_execution_policy_omits_unset_options() {
        let policy = ExecutionPolicy {
            retry_policy: Some(RetryPolicy::default()),
            ..Default::default()
        };
        let json = serde_json::to_value(&policy).unwrap();
        let obj = json.as_object().expect("top-level object");
        for field in [
            "delay_until",
            "timeout_policy",
            "suspension_policy",
            "fallback_policy",
            "routing_requirements",
            "dedup_window_ms",
            "stream_policy",
        ] {
            assert!(
                !obj.contains_key(field),
                "field `{field}` must be absent when unset, not `null`"
            );
        }
        assert!(obj.contains_key("retry_policy"));
    }

    #[test]
    fn populated_execution_policy_round_trip() {
        let policy = ExecutionPolicy {
            priority: 7,
            delay_until: Some(TimestampMs(123_456)),
            retry_policy: Some(RetryPolicy::default()),
            timeout_policy: Some(TimeoutPolicy {
                attempt_timeout_ms: Some(30_000),
                execution_deadline_ms: Some(300_000),
                max_reclaim_count: 5,
            }),
            suspension_policy: Some(SuspensionPolicy {
                default_timeout_ms: Some(60_000),
                timeout_behavior: "cancel".into(),
            }),
            fallback_policy: Some(FallbackPolicy {
                tiers: vec![FallbackTier {
                    provider: "anthropic".into(),
                    model: "claude-opus".into(),
                    timeout_ms: Some(45_000),
                }],
            }),
            routing_requirements: Some(RoutingRequirements {
                required_capabilities: BTreeSet::from(["gpu".to_owned()]),
                preferred_locality: Some("us-west-2".into()),
                isolation_level: Some("strict".into()),
            }),
            dedup_window_ms: Some(86_400_000),
            stream_policy: Some(StreamPolicy {
                durability_mode: "durable".into(),
                retention_maxlen: 5000,
                retention_ttl_ms: Some(3_600_000),
            }),
            ..Default::default()
        };
        let json = serde_json::to_string(&policy).unwrap();
        assert!(
            !json.contains(":null"),
            "populated policy must not contain null fields: {json}"
        );
        let parsed: ExecutionPolicy = serde_json::from_str(&json).unwrap();
        assert_eq!(policy, parsed);
    }

    #[test]
    fn full_execution_policy_serde() {
        let policy = ExecutionPolicy {
            priority: 10,
            retry_policy: Some(RetryPolicy {
                max_retries: 5,
                backoff: BackoffStrategy::Fixed { delay_ms: 1000 },
                retryable_categories: vec![],
            }),
            ..Default::default()
        };
        let json = serde_json::to_string(&policy).unwrap();
        let parsed: ExecutionPolicy = serde_json::from_str(&json).unwrap();
        assert_eq!(policy, parsed);
    }

    /// The serde field defaults and `impl Default for ExecutionPolicy` are
    /// written separately; pin them together so neither can drift.
    #[test]
    fn execution_policy_empty_object_matches_default() {
        let parsed: ExecutionPolicy = serde_json::from_str("{}").unwrap();
        assert_eq!(parsed, ExecutionPolicy::default());
    }

    /// Same drift guard for `RetryPolicy`.
    #[test]
    fn retry_policy_empty_object_matches_default() {
        let parsed: RetryPolicy = serde_json::from_str("{}").unwrap();
        assert_eq!(parsed, RetryPolicy::default());
    }

    #[test]
    fn sub_policy_serde_defaults() {
        let timeout: TimeoutPolicy = serde_json::from_str("{}").unwrap();
        assert_eq!(timeout.attempt_timeout_ms, None);
        assert_eq!(timeout.execution_deadline_ms, None);
        assert_eq!(timeout.max_reclaim_count, 100);
        // Unset options must be omitted, not serialized as `null`.
        assert_eq!(
            serde_json::to_value(&timeout).unwrap(),
            serde_json::json!({ "max_reclaim_count": 100 })
        );

        let suspension: SuspensionPolicy = serde_json::from_str("{}").unwrap();
        assert_eq!(suspension.default_timeout_ms, None);
        assert_eq!(suspension.timeout_behavior, "fail");

        let stream: StreamPolicy = serde_json::from_str("{}").unwrap();
        assert_eq!(stream.durability_mode, "buffered");
        assert_eq!(stream.retention_maxlen, 10_000);
        assert_eq!(stream.retention_ttl_ms, None);
    }

    #[test]
    fn exponential_backoff_jitter_defaults_to_false() {
        let json =
            r#"{"type":"exponential","initial_delay_ms":10,"max_delay_ms":100,"multiplier":1.5}"#;
        let parsed: BackoffStrategy = serde_json::from_str(json).unwrap();
        assert_eq!(
            parsed,
            BackoffStrategy::Exponential {
                initial_delay_ms: 10,
                max_delay_ms: 100,
                multiplier: 1.5,
                jitter: false,
            }
        );
    }

    #[test]
    fn unknown_backoff_type_is_rejected() {
        let json = r#"{"type":"linear","delay_ms":100}"#;
        assert!(serde_json::from_str::<BackoffStrategy>(json).is_err());
    }

    #[test]
    fn fallback_policy_requires_tiers() {
        // `tiers` has no serde default, so an empty object must not parse.
        assert!(serde_json::from_str::<FallbackPolicy>("{}").is_err());
    }
}