batuta/agent/
runtime_helpers.rs1use std::time::Duration;
7
8use tokio::sync::mpsc;
9use tracing::{instrument, warn};
10
11use super::driver::{CompletionRequest, CompletionResponse, LlmDriver, Message, StreamEvent};
12use super::result::AgentError;
13use crate::serve::context::ContextManager;
14
/// Number of extra attempts `call_with_retry` makes after the initial call
/// fails with a retryable driver error (so up to `MAX_RETRIES + 1` calls total).
const MAX_RETRIES: u32 = 3;
/// Base backoff delay in milliseconds; after zero-based attempt `n` fails,
/// `call_with_retry` sleeps `RETRY_BASE_MS * 2^n` ms before retrying.
const RETRY_BASE_MS: u64 = 1_000;
19
20pub(super) fn truncate_messages(
22 messages: &[Message],
23 context: &ContextManager,
24) -> Result<Vec<Message>, AgentError> {
25 let chat_msgs: Vec<_> = messages.iter().map(Message::to_chat_message).collect();
26
27 if context.fits(&chat_msgs) {
28 return Ok(messages.to_vec());
29 }
30
31 let truncated = context.truncate(&chat_msgs).map_err(
32 |crate::serve::context::ContextError::ExceedsLimit { tokens, limit }| {
33 AgentError::ContextOverflow { required: tokens, available: limit }
34 },
35 )?;
36
37 let mut result = Vec::with_capacity(truncated.len());
41 let mut msg_idx = messages.len();
42 for chat_msg in truncated.iter().rev() {
43 while msg_idx > 0 {
44 msg_idx -= 1;
45 if messages[msg_idx].to_chat_message().content == chat_msg.content {
46 result.push(messages[msg_idx].clone());
47 break;
48 }
49 }
50 }
51 result.reverse();
52 Ok(result)
53}
54
55#[instrument(skip_all)]
57pub(super) async fn call_with_retry(
58 driver: &dyn LlmDriver,
59 request: &CompletionRequest,
60) -> Result<CompletionResponse, AgentError> {
61 let mut last_err = None;
62 for attempt in 0..=MAX_RETRIES {
63 match driver.complete(request.clone()).await {
64 Ok(response) => return Ok(response),
65 Err(AgentError::Driver(ref e)) if e.is_retryable() => {
66 warn!(
67 attempt = attempt + 1,
68 max = MAX_RETRIES,
69 error = %e,
70 "retryable driver error"
71 );
72 last_err = Some(AgentError::Driver(e.clone()));
73 if attempt < MAX_RETRIES {
74 let delay = RETRY_BASE_MS * 2u64.pow(attempt);
75 tokio::time::sleep(Duration::from_millis(delay)).await;
76 }
77 }
78 Err(e) => return Err(e),
79 }
80 }
81 Err(last_err.unwrap_or_else(|| AgentError::CircuitBreak("retry loop exhausted".into())))
82}
83
84pub(super) async fn emit(tx: Option<&mpsc::Sender<StreamEvent>>, event: StreamEvent) {
85 if let Some(tx) = tx {
86 let _ = tx.send(event).await;
87 }
88}
89
/// Reject network-based MCP transports for agents at the `Sovereign` privacy tier.
///
/// Manifests at any other privacy tier pass unconditionally. For sovereign
/// manifests, the first MCP server using an `Sse` or `WebSocket` transport
/// causes an [`AgentError::CircuitBreak`] naming that server.
#[cfg(feature = "agents-mcp")]
pub(super) fn validate_mcp_privacy(
    manifest: &super::manifest::AgentManifest,
) -> Result<(), AgentError> {
    use crate::agent::manifest::McpTransport;
    use crate::serve::backends::PrivacyTier;

    if manifest.privacy != PrivacyTier::Sovereign {
        return Ok(());
    }

    let offending = manifest
        .mcp_servers
        .iter()
        .find(|srv| matches!(srv.transport, McpTransport::Sse | McpTransport::WebSocket));

    match offending {
        Some(srv) => Err(AgentError::CircuitBreak(format!(
            "sovereign privacy blocks network MCP transport for '{}'",
            srv.name,
        ))),
        None => Ok(()),
    }
}