use crate::error::{SwarmError, SwarmResult};
use crate::types::{Agent, Instructions, Message, RuntimeLimits, SwarmConfig};
use serde_json::Value;
use std::time::Instant;
use url::Url;
/// Performs cheap, local sanity checks on a request before it is sent.
///
/// Checks run in this order and the first failure is returned:
/// 1. `max_turns` is non-zero;
/// 2. an explicitly supplied `model` name is not blank (empty or whitespace only);
/// 3. the agent's name is not blank;
/// 4. text instructions are not blank (function-based instructions are not inspected);
/// 5. `messages` is non-empty, and every message passes `Message::validate`.
///
/// # Errors
/// Returns `SwarmError::ValidationError` describing the first failed check; an
/// error from `Message::validate` is propagated via `?`.
pub fn validate_api_request(
    agent: &Agent,
    messages: &[Message],
    model: &Option<String>,
    max_turns: usize,
) -> SwarmResult<()> {
    // Every local rejection below is built the same way.
    let reject = |reason: &str| -> SwarmResult<()> {
        Err(SwarmError::ValidationError(reason.to_string()))
    };

    if max_turns == 0 {
        return reject("max_turns must be greater than 0");
    }

    // `None` means "use the default model" and is fine; only a blank override is rejected.
    if matches!(model.as_deref(), Some(name) if name.trim().is_empty()) {
        return reject("Model name cannot be empty");
    }

    if agent.name().trim().is_empty() {
        return reject("Agent name cannot be empty");
    }

    match agent.instructions() {
        Instructions::Text(text) if text.trim().is_empty() => {
            return reject("Agent instructions cannot be empty");
        }
        // Non-blank text and function-produced instructions are accepted as-is.
        Instructions::Text(_) | Instructions::Function(_) => {}
    }

    if messages.is_empty() {
        return reject("Message history cannot be empty");
    }

    for entry in messages.iter() {
        entry.validate()?;
    }

    Ok(())
}
/// Describes which runtime limit was hit, together with the observed value and the limit.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BudgetExhausted {
    /// Cumulative token usage reached the configured token budget.
    TokenBudget { used: u32, limit: u32 },
    /// A single request used more tokens than allowed.
    /// NOTE(review): not emitted by `BudgetEnforcer::check`; presumably raised elsewhere — confirm.
    TokensPerRequest { used: u32, limit: u32 },
    /// Whole seconds elapsed reached the wall-time limit (in seconds).
    WallTime { elapsed_secs: u64, limit: u64 },
    /// The number of tool calls reached the configured quota.
    ToolCallQuota { used: u32, limit: u32 },
    /// Nesting depth went past the configured maximum.
    MaxDepth { depth: u32, limit: u32 },
}

impl std::fmt::Display for BudgetExhausted {
    /// Renders a one-line, human-readable description of the violated limit.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // All payloads are plain integers, so they are copied out of `*self`.
        match *self {
            BudgetExhausted::TokenBudget { used, limit } => {
                write!(f, "token budget exhausted: {used} / {limit} tokens used")
            }
            BudgetExhausted::TokensPerRequest { used, limit } => {
                write!(f, "per-request token limit exceeded: {used} / {limit} tokens used")
            }
            BudgetExhausted::WallTime { elapsed_secs, limit } => {
                write!(f, "wall-time limit exceeded: {elapsed_secs}s elapsed, limit {limit}s")
            }
            BudgetExhausted::ToolCallQuota { used, limit } => {
                write!(f, "tool call quota exhausted: {used} / {limit} calls used")
            }
            BudgetExhausted::MaxDepth { depth, limit } => {
                write!(f, "max recursion depth exceeded: depth {depth} / limit {limit}")
            }
        }
    }
}
/// Folds a budget violation into the crate-wide error type.
///
/// The structured violation is flattened into `SwarmError::Other`, carrying its
/// `Display` text (e.g. "token budget exhausted: 10 / 5 tokens used").
impl From<BudgetExhausted> for SwarmError {
    fn from(exhausted: BudgetExhausted) -> Self {
        let message = exhausted.to_string();
        SwarmError::Other(message)
    }
}
/// Tracks resource consumption for a run and reports the first configured limit it breaks.
///
/// Counters are public so callers can inspect them. The helper methods below
/// advance them with saturating arithmetic, and [`BudgetEnforcer::check`]
/// compares them against the `RuntimeLimits` supplied at construction. Wall
/// time is measured from the moment the enforcer is created.
pub struct BudgetEnforcer {
    limits: RuntimeLimits,
    // Creation instant; the wall-time limit is measured relative to this.
    start: Instant,
    // Advanced by `increment_iterations`; not consulted by `check`.
    pub iterations: u32,
    pub total_tokens: u32,
    pub tool_calls: u32,
    // Current nesting level, raised/lowered by `increment_depth`/`decrement_depth`.
    pub depth: u32,
}

impl BudgetEnforcer {
    /// Creates an enforcer with all counters at zero and the wall clock started now.
    pub fn new(limits: RuntimeLimits) -> Self {
        BudgetEnforcer {
            start: Instant::now(),
            limits,
            depth: 0,
            tool_calls: 0,
            total_tokens: 0,
            iterations: 0,
        }
    }

    /// Compares the counters against the configured limits.
    ///
    /// Limits left as `None` are skipped. They are evaluated in a fixed order
    /// (token budget, wall time, tool-call quota, depth) and the first violation
    /// is returned. Token, wall-time and tool-call limits trip once usage
    /// *reaches* the limit. The depth limit is inclusive and trips only once
    /// `depth` *exceeds* it.
    pub fn check(&self) -> Result<(), BudgetExhausted> {
        let limits = &self.limits;

        if let Some(limit) = limits.token_budget {
            let used = self.total_tokens;
            if used >= limit {
                return Err(BudgetExhausted::TokenBudget { used, limit });
            }
        }

        if let Some(limit) = limits.max_wall_time_secs {
            // Whole seconds only; for an integer limit this is equivalent to
            // comparing the exact elapsed duration.
            let elapsed_secs = self.start.elapsed().as_secs();
            if elapsed_secs >= limit {
                return Err(BudgetExhausted::WallTime { elapsed_secs, limit });
            }
        }

        if let Some(limit) = limits.max_tool_calls {
            let used = self.tool_calls;
            if used >= limit {
                return Err(BudgetExhausted::ToolCallQuota { used, limit });
            }
        }

        if let Some(limit) = limits.max_depth {
            let depth = self.depth;
            if depth > limit {
                return Err(BudgetExhausted::MaxDepth { depth, limit });
            }
        }

        Ok(())
    }

    /// Adds `count` to the running token total, saturating at `u32::MAX`.
    pub fn add_tokens(&mut self, count: u32) {
        self.total_tokens = u32::saturating_add(self.total_tokens, count);
    }

    /// Records one more loop iteration, saturating at `u32::MAX`.
    pub fn increment_iterations(&mut self) {
        self.iterations = u32::saturating_add(self.iterations, 1);
    }

    /// Records one more tool call, saturating at `u32::MAX`.
    pub fn increment_tool_calls(&mut self) {
        self.tool_calls = u32::saturating_add(self.tool_calls, 1);
    }

    /// Enters one nesting level, saturating at `u32::MAX`.
    pub fn increment_depth(&mut self) {
        self.depth = u32::saturating_add(self.depth, 1);
    }

    /// Leaves one nesting level; never goes below zero.
    pub fn decrement_depth(&mut self) {
        self.depth = u32::saturating_sub(self.depth, 1);
    }
}
pub fn verify_tool_arguments(args: &Value, schema: &Value) -> SwarmResult<()> {
let required = match schema.get("required").and_then(|v| v.as_array()) {
Some(r) => r,
None => return Ok(()), };
for field in required {
let field_name = field.as_str().unwrap_or_default();
let present = args
.as_object()
.map(|m| m.contains_key(field_name))
.unwrap_or(false);
if !present {
return Err(SwarmError::ValidationError(format!(
"tool argument '{}' is required but missing",
field_name
)));
}
}
Ok(())
}
/// Ensures `response` is a JSON object that contains every key in `expected_fields`.
///
/// Only key presence is checked. Values, including `null`, are not inspected.
///
/// # Errors
/// Returns `SwarmError::ValidationError` if `response` is not a JSON object, or
/// naming the first entry of `expected_fields` that is absent.
pub fn verify_structured_response(response: &Value, expected_fields: &[&str]) -> SwarmResult<()> {
    let object = match response.as_object() {
        Some(object) => object,
        None => {
            return Err(SwarmError::ValidationError(
                "structured response must be a JSON object".to_string(),
            ))
        }
    };

    let first_missing = expected_fields
        .iter()
        .copied()
        .find(|name| !object.contains_key(*name));

    match first_missing {
        None => Ok(()),
        Some(name) => Err(SwarmError::ValidationError(format!(
            "structured response missing required field '{}'",
            name
        ))),
    }
}
/// Validates that `url` is a well-formed API endpoint permitted by `config`.
///
/// A URL whose parsed host is exactly `localhost` is always accepted, with any
/// scheme or port. Otherwise `url` must start with one of
/// `config.valid_api_url_prefixes()`.
///
/// When a prefix ends inside the URL authority (e.g. `https://api.example.com`,
/// with no path), the text after the prefix must start a new URL component:
/// `/`, `?`, `#`, a numeric `:port`, or end of input. Without this, a raw
/// string prefix test would accept `https://api.example.com.evil.net/` or
/// `https://api.example.com@evil.net/`, whose real host is attacker-chosen.
/// Scheme-only prefixes (e.g. `https` or `https://`) and prefixes that already
/// include a path behave exactly as a plain prefix test.
///
/// # Errors
/// Returns `SwarmError::ValidationError` if `url` is blank, cannot be parsed,
/// or matches no allowed prefix.
pub fn validate_api_url(url: &str, config: &SwarmConfig) -> SwarmResult<()> {
    if url.trim().is_empty() {
        return Err(SwarmError::ValidationError(
            "API URL cannot be empty".to_string(),
        ));
    }
    let parsed_url = Url::parse(url)
        .map_err(|e| SwarmError::ValidationError(format!("Invalid API URL format: {}", e)))?;
    if parsed_url.host_str() == Some("localhost") {
        return Ok(());
    }

    // Fetch the allowlist once; it is used both for matching and for the error message.
    let prefixes = config.valid_api_url_prefixes();
    if !prefixes
        .iter()
        .any(|prefix| url_matches_prefix(url, prefix.as_str()))
    {
        return Err(SwarmError::ValidationError(format!(
            "API URL must start with one of: {}",
            prefixes
                .iter()
                .map(|prefix| prefix.as_str())
                .collect::<Vec<_>>()
                .join(", ")
        )));
    }
    Ok(())
}

/// Returns `true` if `url` starts with `prefix` without extending the prefix's host.
fn url_matches_prefix(url: &str, prefix: &str) -> bool {
    let rest = match url.strip_prefix(prefix) {
        Some(rest) => rest,
        None => return false,
    };
    // Only a prefix that ends inside a non-empty authority (no '/' after "://")
    // can be extended into a different host; all other prefixes keep the
    // plain starts-with semantics.
    let ends_in_authority = match prefix.find("://") {
        Some(idx) => {
            let authority = &prefix[idx + 3..];
            !authority.is_empty() && !authority.contains('/')
        }
        None => false,
    };
    !ends_in_authority || is_authority_boundary(rest)
}

/// Returns `true` if `rest` (the text after an authority-ending prefix) cannot continue the host.
fn is_authority_boundary(rest: &str) -> bool {
    match rest.as_bytes().first() {
        None | Some(b'/') | Some(b'?') | Some(b'#') => true,
        // An explicit port is allowed only if it is purely numeric and itself ends
        // at a boundary, so "host:secret@evil.net" is rejected.
        Some(b':') => {
            let after = &rest[1..];
            let digits = after.bytes().take_while(u8::is_ascii_digit).count();
            digits > 0
                && matches!(
                    after.as_bytes().get(digits),
                    None | Some(b'/') | Some(b'?') | Some(b'#')
                )
        }
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Depth limits are inclusive: reaching `max_depth` is fine, going past it is not.
    #[test]
    fn test_max_depth_allows_exact_budget_then_exhausts() {
        let limits = RuntimeLimits {
            max_depth: Some(1),
            ..RuntimeLimits::default()
        };
        let mut budget = BudgetEnforcer::new(limits);

        // Depth 0: well within the limit.
        assert_eq!(budget.check(), Ok(()));

        // Depth 1: exactly at the limit, still allowed.
        budget.increment_depth();
        assert_eq!(budget.check(), Ok(()));

        // Depth 2: one past the limit, now exhausted.
        budget.increment_depth();
        assert_eq!(
            budget.check(),
            Err(BudgetExhausted::MaxDepth { depth: 2, limit: 1 })
        );
    }
}