Skip to main content

batuta/agent/
runtime_helpers.rs

1//! Helper functions for the agent runtime loop.
2//!
3//! Extracted from runtime.rs to keep module under 500-line threshold.
4//! Contains: context truncation, retry logic, event emission, MCP validation.
5
6use std::time::Duration;
7
8use tokio::sync::mpsc;
9use tracing::{instrument, warn};
10
11use super::driver::{CompletionRequest, CompletionResponse, LlmDriver, Message, StreamEvent};
12use super::result::AgentError;
13use crate::serve::context::ContextManager;
14
/// How many times a retryable driver error may be retried after the initial attempt.
const MAX_RETRIES: u32 = 3;
/// Backoff delay, in milliseconds, applied after the first failed attempt.
/// Later delays are this value scaled by a power of two.
const RETRY_BASE_MS: u64 = 1_000;
19
20/// Truncate agent messages to fit within context window.
21pub(super) fn truncate_messages(
22    messages: &[Message],
23    context: &ContextManager,
24) -> Result<Vec<Message>, AgentError> {
25    let chat_msgs: Vec<_> = messages.iter().map(Message::to_chat_message).collect();
26
27    if context.fits(&chat_msgs) {
28        return Ok(messages.to_vec());
29    }
30
31    let truncated = context.truncate(&chat_msgs).map_err(
32        |crate::serve::context::ContextError::ExceedsLimit { tokens, limit }| {
33            AgentError::ContextOverflow { required: tokens, available: limit }
34        },
35    )?;
36
37    // Map truncated ChatMessages back to original Messages
38    // by matching content. SlidingWindow keeps most recent,
39    // so iterate from end of original list.
40    let mut result = Vec::with_capacity(truncated.len());
41    let mut msg_idx = messages.len();
42    for chat_msg in truncated.iter().rev() {
43        while msg_idx > 0 {
44            msg_idx -= 1;
45            if messages[msg_idx].to_chat_message().content == chat_msg.content {
46                result.push(messages[msg_idx].clone());
47                break;
48            }
49        }
50    }
51    result.reverse();
52    Ok(result)
53}
54
55/// Retry `driver.complete()` with exponential backoff for retryable errors.
56#[instrument(skip_all)]
57pub(super) async fn call_with_retry(
58    driver: &dyn LlmDriver,
59    request: &CompletionRequest,
60) -> Result<CompletionResponse, AgentError> {
61    let mut last_err = None;
62    for attempt in 0..=MAX_RETRIES {
63        match driver.complete(request.clone()).await {
64            Ok(response) => return Ok(response),
65            Err(AgentError::Driver(ref e)) if e.is_retryable() => {
66                warn!(
67                    attempt = attempt + 1,
68                    max = MAX_RETRIES,
69                    error = %e,
70                    "retryable driver error"
71                );
72                last_err = Some(AgentError::Driver(e.clone()));
73                if attempt < MAX_RETRIES {
74                    let delay = RETRY_BASE_MS * 2u64.pow(attempt);
75                    tokio::time::sleep(Duration::from_millis(delay)).await;
76                }
77            }
78            Err(e) => return Err(e),
79        }
80    }
81    Err(last_err.unwrap_or_else(|| AgentError::CircuitBreak("retry loop exhausted".into())))
82}
83
84pub(super) async fn emit(tx: Option<&mpsc::Sender<StreamEvent>>, event: StreamEvent) {
85    if let Some(tx) = tx {
86        let _ = tx.send(event).await;
87    }
88}
89
/// Check the manifest's MCP transports against its privacy tier (Poka-Yoke).
///
/// Acts as defense-in-depth: network transports (SSE, WebSocket) are rejected
/// under the Sovereign tier even if `manifest.validate()` was skipped.
/// Manifests with any other privacy tier pass unconditionally.
///
/// # Errors
///
/// Returns `AgentError::CircuitBreak` naming the first configured MCP server
/// that uses an SSE or WebSocket transport while the tier is Sovereign.
#[cfg(feature = "agents-mcp")]
pub(super) fn validate_mcp_privacy(
    manifest: &super::manifest::AgentManifest,
) -> Result<(), AgentError> {
    use crate::agent::manifest::McpTransport;
    use crate::serve::backends::PrivacyTier;

    // Only the Sovereign tier restricts transports.
    if manifest.privacy != PrivacyTier::Sovereign {
        return Ok(());
    }

    let offender = manifest
        .mcp_servers
        .iter()
        .find(|server| matches!(server.transport, McpTransport::Sse | McpTransport::WebSocket));

    match offender {
        Some(server) => Err(AgentError::CircuitBreak(format!(
            "sovereign privacy blocks network MCP transport for '{}'",
            server.name,
        ))),
        None => Ok(()),
    }
}