Skip to main content

ai_lib_core/client/
policy.rs

1use crate::{Error, Result};
2use std::time::Duration;
3
4use crate::client::signals::SignalsSnapshot;
5use crate::error_code::StandardErrorCode;
6
/// Outcome chosen by the policy engine after (or, for pre-checks, before) an attempt.
///
/// Internal type: it tells the calling loop whether to try the same candidate again,
/// switch to another candidate, or give up.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Decision {
    /// Try the same candidate again once `delay` has elapsed.
    Retry {
        /// Time to wait before the next attempt.
        delay: Duration,
    },
    /// Move on to the next (fallback) candidate.
    Fallback,
    /// Neither retrying nor falling back applies; stop here.
    Fail,
}
14
15/// Internal policy engine that unifies retry / fallback behavior.
16///
17/// Important constraints:
18/// - Keep this internal (no public API commitments yet).
19/// - Prefer deterministic, explainable behavior over clever heuristics.
20pub struct PolicyEngine {
21    manifest: crate::protocol::ProtocolManifest,
22    pub max_retries: u32,
23    pub min_delay_ms: u32,
24    pub max_delay_ms: u32,
25}
26
27impl PolicyEngine {
28    pub fn new(manifest: &crate::protocol::ProtocolManifest) -> Self {
29        let retry = manifest.retry_policy.as_ref();
30        let max_retries = retry.and_then(|p| p.max_retries).unwrap_or(0);
31        let min_delay_ms = retry.and_then(|p| p.min_delay_ms).unwrap_or(0);
32        let max_delay_ms = retry.and_then(|p| p.max_delay_ms).unwrap_or(min_delay_ms);
33        Self {
34            manifest: manifest.clone(),
35            max_retries,
36            min_delay_ms,
37            max_delay_ms,
38        }
39    }
40
41    /// Validates if the manifest supports all capabilities required by the request.
42    ///
43    /// This is a pre-flight guard that validates user intent against protocol capabilities
44    /// before making any network requests, saving latency and cost.
45    pub fn validate_capabilities(&self, request: &crate::protocol::UnifiedRequest) -> Result<()> {
46        let manifest = &self.manifest;
47
48        // Check for Tooling support
49        if request
50            .tools
51            .as_ref()
52            .is_some_and(|tools| !tools.is_empty())
53            && !manifest.supports_capability("tools")
54        {
55            return Err(Error::validation_with_context(
56                "Model does not support tool calling",
57                crate::ErrorContext::new()
58                    .with_field_path("request.tools")
59                    .with_source("capability_validator"),
60            ));
61        }
62
63        // Check for Streaming support
64        if request.stream && !manifest.supports_capability("streaming") {
65            return Err(Error::validation_with_context(
66                "Model does not support streaming",
67                crate::ErrorContext::new()
68                    .with_field_path("request.stream")
69                    .with_source("capability_validator"),
70            ));
71        }
72
73        // Check for Multimodal support (Vision/Audio)
74        let has_multimodal = request
75            .messages
76            .iter()
77            .any(|m: &crate::types::message::Message| m.contains_image() || m.contains_audio());
78        if has_multimodal {
79            let supports_multimodal = manifest.supports_capability("multimodal")
80                || manifest.supports_capability("vision")
81                || manifest.supports_capability("audio");
82
83            if !supports_multimodal {
84                return Err(Error::validation_with_context(
85                    "Model does not support multimodal content (images/audio)",
86                    crate::ErrorContext::new()
87                        .with_field_path("request.messages")
88                        .with_source("capability_validator"),
89                ));
90            }
91        }
92
93        if request.response_format.is_some() && !manifest.supports_capability("structured_output") {
94            return Err(Error::validation_with_context(
95                "Model does not support structured output (JSON mode / response_format)",
96                crate::ErrorContext::new()
97                    .with_field_path("request.response_format")
98                    .with_source("capability_validator")
99                    .with_standard_code(StandardErrorCode::InvalidRequest),
100            ));
101        }
102
103        if let Some(tools) = request.tools.as_ref() {
104            let needs_mcp = tools.iter().any(|t| {
105                t.tool_type.eq_ignore_ascii_case("mcp") || t.function.name.starts_with("mcp__")
106            });
107            if needs_mcp && !manifest.supports_capability("mcp_client") {
108                return Err(Error::validation_with_context(
109                    "Model does not declare mcp_client; MCP tool bridge is not allowed",
110                    crate::ErrorContext::new()
111                        .with_field_path("request.tools")
112                        .with_source("capability_validator")
113                        .with_standard_code(StandardErrorCode::RequestTooLarge),
114                ));
115            }
116        }
117
118        // Parameter range validation (pre-flight guard to avoid invalid requests)
119        // Note: Currently, parameter constraints are not defined in the protocol manifest.
120        // This is a placeholder for future enhancement when capabilities include constraints.
121        // For now, we rely on provider APIs to reject invalid parameters.
122
123        Ok(())
124    }
125
126    fn backoff_delay(&self, attempt: u32, retry_after_ms: Option<u32>) -> Duration {
127        let base = if self.min_delay_ms == 0 {
128            0
129        } else {
130            // exponential backoff: min_delay * 2^attempt
131            let factor = 1u32.checked_shl(attempt).unwrap_or(u32::MAX);
132            self.min_delay_ms.saturating_mul(factor)
133        };
134        let chosen = retry_after_ms.unwrap_or(base).min(self.max_delay_ms);
135        Duration::from_millis(chosen as u64)
136    }
137
138    /// Optional pre-decision based on current runtime signals (facts), before attempting a call.
139    ///
140    /// Keep this conservative: only skip work that is *known* to fail right now.
141    pub fn pre_decide(&self, signals: &SignalsSnapshot, has_fallback: bool) -> Option<Decision> {
142        if !has_fallback {
143            return None;
144        }
145
146        // If this candidate is currently saturated (no inflight permits),
147        // prefer trying a fallback candidate rather than waiting here.
148        if let Some(inflight) = signals.inflight.as_ref() {
149            if inflight.available == 0 {
150                return Some(Decision::Fallback);
151            }
152        }
153
154        None
155    }
156
157    /// Decide what to do next after an attempt failed.
158    ///
159    /// - `attempt` is 0-based (first failure => attempt=0).
160    /// - `has_fallback` indicates there is another candidate to try.
161    pub fn decide(&self, err: &Error, attempt: u32, has_fallback: bool) -> Result<Decision> {
162        let (mut retryable, mut fallbackable, retry_after_ms) = match err {
163            Error::Remote {
164                retryable,
165                fallbackable,
166                retry_after_ms,
167                ..
168            } => (*retryable, *fallbackable, *retry_after_ms),
169            Error::Transport(_) => (true, true, None),
170            Error::Runtime { message: msg, .. } => {
171                // Preflight and guard errors are policy-relevant.
172                // Keep these rules simple and explainable:
173                // - circuit breaker open => try fallback if available
174                // - attempt timeout => retry and/or fallback
175                let m = msg.to_lowercase();
176                if m.contains("circuit breaker open") {
177                    (false, true, None)
178                } else if m.contains("timeout") {
179                    (true, true, None)
180                } else {
181                    (false, false, None)
182                }
183            }
184            _ => (false, false, None),
185        };
186
187        // Prefer ErrorContext 2.0 flags when present
188        if let Some(ctx) = err.context() {
189            if let Some(r) = ctx.retryable {
190                retryable = r;
191            }
192            if let Some(f) = ctx.fallbackable {
193                fallbackable = f;
194            }
195        }
196
197        if retryable && attempt < self.max_retries {
198            return Ok(Decision::Retry {
199                delay: self.backoff_delay(attempt, retry_after_ms),
200            });
201        }
202
203        if fallbackable && has_fallback {
204            return Ok(Decision::Fallback);
205        }
206
207        Ok(Decision::Fail)
208    }
209}