use crate::{Error, Result};
use std::time::Duration;
use crate::client::signals::SignalsSnapshot;
use crate::error_code::StandardErrorCode;
/// Outcome of the retry/fallback policy for one failed (or about-to-be-sent) attempt.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum Decision {
    /// Try again after waiting `delay`.
    Retry {
        /// Time to wait before the next attempt.
        delay: Duration,
    },
    /// Route to a fallback; the engine only produces this when a fallback is available.
    Fallback,
    /// Neither a retry nor a fallback applies.
    Fail,
}
/// Retry/fallback policy and capability validator built from a protocol manifest.
pub struct PolicyEngine {
    /// Owned copy of the manifest; used to look up declared capabilities.
    manifest: crate::protocol::ProtocolManifest,
    /// Retry budget: `decide` only returns `Retry` while `attempt < max_retries`.
    pub max_retries: u32,
    /// Base backoff delay in milliseconds; doubled for each attempt (0 disables backoff).
    pub min_delay_ms: u32,
    /// Upper bound, in milliseconds, applied to every computed or server-provided delay.
    pub max_delay_ms: u32,
}
impl PolicyEngine {
    /// Builds a policy engine from `manifest`.
    ///
    /// Retry settings come from `manifest.retry_policy`. Missing values default to
    /// `max_retries = 0`, `min_delay_ms = 0` and `max_delay_ms = min_delay_ms`.
    ///
    /// Because of the last default, when only a minimum delay is configured the
    /// exponential backoff is effectively capped at that minimum.
    /// The manifest is cloned so the engine owns its copy.
    pub fn new(manifest: &crate::protocol::ProtocolManifest) -> Self {
        let retry = manifest.retry_policy.as_ref();
        let max_retries = retry.and_then(|p| p.max_retries).unwrap_or(0);
        let min_delay_ms = retry.and_then(|p| p.min_delay_ms).unwrap_or(0);
        // No explicit ceiling: never wait longer than the configured minimum.
        let max_delay_ms = retry.and_then(|p| p.max_delay_ms).unwrap_or(min_delay_ms);
        Self {
            manifest: manifest.clone(),
            max_retries,
            min_delay_ms,
            max_delay_ms,
        }
    }

    /// Checks that `request` only uses features the manifest declares.
    ///
    /// Checks run in this order, and the first failure is returned:
    /// 1. A non-empty `tools` list requires the `tools` capability.
    /// 2. `stream == true` requires `streaming`.
    /// 3. Any message containing an image or audio requires at least one of
    ///    `multimodal`, `vision` or `audio`.
    /// 4. A `response_format` requires `structured_output`.
    /// 5. Any MCP tool requires `mcp_client`. A tool counts as MCP if its type is
    ///    `mcp` (case-insensitive) or its function name starts with `mcp__`.
    ///
    /// # Errors
    ///
    /// Returns a validation error with source `capability_validator`.
    /// Its field path points at the offending part of the request.
    pub fn validate_capabilities(&self, request: &crate::protocol::UnifiedRequest) -> Result<()> {
        let manifest = &self.manifest;

        if request
            .tools
            .as_ref()
            .is_some_and(|tools| !tools.is_empty())
            && !manifest.supports_capability("tools")
        {
            return Err(Error::validation_with_context(
                "Model does not support tool calling",
                crate::ErrorContext::new()
                    .with_field_path("request.tools")
                    .with_source("capability_validator"),
            ));
        }

        if request.stream && !manifest.supports_capability("streaming") {
            return Err(Error::validation_with_context(
                "Model does not support streaming",
                crate::ErrorContext::new()
                    .with_field_path("request.stream")
                    .with_source("capability_validator"),
            ));
        }

        let has_multimodal = request
            .messages
            .iter()
            .any(|m: &crate::types::message::Message| m.contains_image() || m.contains_audio());
        if has_multimodal {
            // Any one of these capability names is accepted as multimodal support.
            let supports_multimodal = manifest.supports_capability("multimodal")
                || manifest.supports_capability("vision")
                || manifest.supports_capability("audio");
            if !supports_multimodal {
                return Err(Error::validation_with_context(
                    "Model does not support multimodal content (images/audio)",
                    crate::ErrorContext::new()
                        .with_field_path("request.messages")
                        .with_source("capability_validator"),
                ));
            }
        }

        if request.response_format.is_some() && !manifest.supports_capability("structured_output") {
            return Err(Error::validation_with_context(
                "Model does not support structured output (JSON mode / response_format)",
                crate::ErrorContext::new()
                    .with_field_path("request.response_format")
                    .with_source("capability_validator")
                    .with_standard_code(StandardErrorCode::InvalidRequest),
            ));
        }

        if let Some(tools) = request.tools.as_ref() {
            let needs_mcp = tools.iter().any(|t| {
                t.tool_type.eq_ignore_ascii_case("mcp") || t.function.name.starts_with("mcp__")
            });
            if needs_mcp && !manifest.supports_capability("mcp_client") {
                return Err(Error::validation_with_context(
                    "Model does not declare mcp_client; MCP tool bridge is not allowed",
                    crate::ErrorContext::new()
                        .with_field_path("request.tools")
                        .with_source("capability_validator")
                        // An undeclared capability makes the request invalid, not oversized;
                        // same code as the structured_output check above.
                        .with_standard_code(StandardErrorCode::InvalidRequest),
                ));
            }
        }

        Ok(())
    }

    /// Computes the wait before the retry that follows attempt number `attempt`.
    ///
    /// The base delay is `min_delay_ms * 2^attempt`, saturating at `u32::MAX`.
    /// It is 0 whenever `min_delay_ms` is 0.
    /// A server-provided `retry_after_ms` replaces the base delay.
    /// In both cases the result is capped at `max_delay_ms`.
    fn backoff_delay(&self, attempt: u32, retry_after_ms: Option<u32>) -> Duration {
        let base = if self.min_delay_ms == 0 {
            0
        } else {
            // `checked_shl` fails for shifts >= 32; treat that as "huge" and saturate.
            let factor = 1u32.checked_shl(attempt).unwrap_or(u32::MAX);
            self.min_delay_ms.saturating_mul(factor)
        };
        let chosen = retry_after_ms.unwrap_or(base).min(self.max_delay_ms);
        Duration::from_millis(u64::from(chosen))
    }

    /// Makes a pre-emptive decision before a request is sent.
    ///
    /// Returns `Some(Decision::Fallback)` only when `has_fallback` is true and the signal
    /// snapshot has an in-flight entry reporting `available == 0`. Otherwise returns `None`.
    pub fn pre_decide(&self, signals: &SignalsSnapshot, has_fallback: bool) -> Option<Decision> {
        let saturated = signals
            .inflight
            .as_ref()
            .is_some_and(|inflight| inflight.available == 0);
        (has_fallback && saturated).then_some(Decision::Fallback)
    }

    /// Decides how to react to `err` raised by attempt number `attempt`.
    ///
    /// The error is first classified into `(retryable, fallbackable, retry_after_ms)`:
    /// - `Error::Remote`: uses its own flags and `retry_after_ms`.
    /// - `Error::Transport`: retryable and fallbackable.
    /// - `Error::Runtime`: matched on its lowercased message.
    ///   - "circuit breaker open": fallbackable only.
    ///   - "timeout": retryable and fallbackable.
    ///   - Anything else: neither.
    /// - Any other error: neither.
    ///
    /// Explicit `retryable` / `fallbackable` values on the error's context then
    /// override the classification.
    ///
    /// The result is chosen in this order:
    /// - `Retry`, if the error is retryable and `attempt < max_retries`. The delay comes
    ///   from the exponential backoff above.
    /// - Otherwise `Fallback`, if the error is fallbackable and `has_fallback`.
    /// - Otherwise `Fail`.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    pub fn decide(&self, err: &Error, attempt: u32, has_fallback: bool) -> Result<Decision> {
        let (mut retryable, mut fallbackable, retry_after_ms) = match err {
            Error::Remote {
                retryable,
                fallbackable,
                retry_after_ms,
                ..
            } => (*retryable, *fallbackable, *retry_after_ms),
            Error::Transport(_) => (true, true, None),
            Error::Runtime { message: msg, .. } => {
                // Runtime errors carry no structured flags, so classify by message text.
                let m = msg.to_lowercase();
                if m.contains("circuit breaker open") {
                    (false, true, None)
                } else if m.contains("timeout") {
                    (true, true, None)
                } else {
                    (false, false, None)
                }
            }
            _ => (false, false, None),
        };

        // Context flags, when set, take precedence over the variant-based defaults.
        if let Some(ctx) = err.context() {
            if let Some(r) = ctx.retryable {
                retryable = r;
            }
            if let Some(f) = ctx.fallbackable {
                fallbackable = f;
            }
        }

        if retryable && attempt < self.max_retries {
            return Ok(Decision::Retry {
                delay: self.backoff_delay(attempt, retry_after_ms),
            });
        }
        if fallbackable && has_fallback {
            return Ok(Decision::Fallback);
        }
        Ok(Decision::Fail)
    }
}