use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use super::{AgentType, ClaudeFlowAgent, ClaudeFlowTask};
use crate::error::{Result, RuvLLMError};
/// Claude model tier selector.
///
/// Pricing, latency, and context-size characteristics for each tier are
/// provided by the inherent methods on this enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ClaudeModel {
    /// Cheapest and fastest tier (lowest per-token cost, lowest typical TTFT).
    Haiku,
    /// Mid-range tier; also the `Default`.
    Sonnet,
    /// Most capable and most expensive tier.
    Opus,
}
impl ClaudeModel {
pub fn name(&self) -> &'static str {
match self {
Self::Haiku => "haiku",
Self::Sonnet => "sonnet",
Self::Opus => "opus",
}
}
pub fn model_id(&self) -> &'static str {
match self {
Self::Haiku => "claude-3-5-haiku-20241022",
Self::Sonnet => "claude-sonnet-4-20250514",
Self::Opus => "claude-opus-4-20250514",
}
}
pub fn input_cost_per_1k(&self) -> f64 {
match self {
Self::Haiku => 0.00025,
Self::Sonnet => 0.003,
Self::Opus => 0.015,
}
}
pub fn output_cost_per_1k(&self) -> f64 {
match self {
Self::Haiku => 0.00125,
Self::Sonnet => 0.015,
Self::Opus => 0.075,
}
}
pub fn typical_ttft_ms(&self) -> u64 {
match self {
Self::Haiku => 200,
Self::Sonnet => 500,
Self::Opus => 1500,
}
}
pub fn max_context_tokens(&self) -> usize {
match self {
Self::Haiku => 200_000,
Self::Sonnet => 200_000,
Self::Opus => 200_000,
}
}
}
impl Default for ClaudeModel {
fn default() -> Self {
Self::Sonnet
}
}
/// Role of a message author, serialized lowercase ("user", "assistant",
/// "system") for the wire format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    User,
    Assistant,
    System,
}
/// One block of message content. Serialized with an internal `"type"` tag
/// in snake_case (e.g. `{"type": "tool_use", ...}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text content.
    Text { text: String },
    /// A tool invocation requested by the model.
    ToolUse {
        // Identifier correlating this call with its later ToolResult.
        id: String,
        // Name of the tool being invoked.
        name: String,
        // Tool arguments as arbitrary JSON.
        input: serde_json::Value,
    },
    /// The textual result of a prior tool invocation.
    ToolResult {
        // `id` of the ToolUse this result answers.
        tool_use_id: String,
        // Result payload as text.
        content: String,
    },
}
/// A chat message: an author role plus one or more content blocks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who authored the message.
    pub role: MessageRole,
    /// The message body as a sequence of content blocks.
    pub content: Vec<ContentBlock>,
}
impl Message {
    /// Build a message containing a single text block.
    pub fn text(role: MessageRole, text: impl Into<String>) -> Self {
        let block = ContentBlock::Text { text: text.into() };
        Self {
            role,
            content: vec![block],
        }
    }

    /// Convenience constructor for a user-role text message.
    pub fn user(text: impl Into<String>) -> Self {
        Self::text(MessageRole::User, text)
    }

    /// Convenience constructor for an assistant-role text message.
    pub fn assistant(text: impl Into<String>) -> Self {
        Self::text(MessageRole::Assistant, text)
    }

    /// Rough token estimate: ~4 bytes per token, plus fixed framing
    /// overhead for tool-use (50) and tool-result (20) blocks.
    pub fn estimate_tokens(&self) -> usize {
        let mut total = 0;
        for block in &self.content {
            total += match block {
                ContentBlock::Text { text } => text.len() / 4,
                ContentBlock::ToolUse { input, .. } => input.to_string().len() / 4 + 50,
                ContentBlock::ToolResult { content, .. } => content.len() / 4 + 20,
            };
        }
        total
    }
}
/// Serializable request body for a Claude messages call.
///
/// Optional fields are omitted from the serialized JSON when `None`.
#[derive(Debug, Clone, Serialize)]
pub struct ClaudeRequest {
    /// Versioned model identifier (see `ClaudeModel::model_id`).
    pub model: String,
    /// Conversation messages.
    pub messages: Vec<Message>,
    /// Upper bound on generated tokens.
    pub max_tokens: usize,
    /// Optional system prompt.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// Optional sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Whether to request a streaming response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}
/// Deserialized response body from a Claude messages call.
#[derive(Debug, Clone, Deserialize)]
pub struct ClaudeResponse {
    /// Response identifier.
    pub id: String,
    /// Model that produced the response.
    pub model: String,
    /// Generated content blocks.
    pub content: Vec<ContentBlock>,
    /// Why generation stopped, when reported.
    pub stop_reason: Option<String>,
    /// Token accounting for this response.
    pub usage: UsageStats,
}
/// Token counts for one request/response pair (or an accumulated total).
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct UsageStats {
    /// Tokens consumed by the prompt.
    pub input_tokens: usize,
    /// Tokens produced in the response.
    pub output_tokens: usize,
}
impl UsageStats {
    /// Dollar cost of this usage at the given model's per-1k-token rates.
    pub fn calculate_cost(&self, model: ClaudeModel) -> f64 {
        let thousands_in = self.input_tokens as f64 / 1000.0;
        let thousands_out = self.output_tokens as f64 / 1000.0;
        thousands_in * model.input_cost_per_1k() + thousands_out * model.output_cost_per_1k()
    }
}
/// A single streamed token with its position and timing.
#[derive(Debug, Clone)]
pub struct StreamToken {
    /// Token text fragment.
    pub text: String,
    /// Zero-based position in the stream.
    pub index: usize,
    /// Milliseconds elapsed since the stream started.
    pub latency_ms: u64,
    /// Optional per-token quality score.
    pub quality_score: Option<f32>,
}
/// Events emitted over the streaming channel while a response is produced.
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// Stream opened for a request.
    Start {
        request_id: String,
        model: ClaudeModel,
    },
    /// One generated token.
    Token(StreamToken),
    /// A full content block finished assembling.
    ContentBlockComplete { index: usize, content: ContentBlock },
    /// Terminal success event with final accounting.
    Complete {
        usage: UsageStats,
        stop_reason: String,
        total_latency_ms: u64,
    },
    /// Terminal error; `is_retryable` indicates whether retrying may help.
    Error { message: String, is_retryable: bool },
}
/// Tracks per-token quality scores during streaming and decides whether
/// generation quality remains acceptable.
#[derive(Debug, Clone)]
pub struct QualityMonitor {
    /// Minimum acceptable average quality score.
    pub min_quality: f32,
    /// Number of recorded tokens between quality checks.
    pub check_interval: usize,
    // All scores recorded so far.
    recorded: Vec<f32>,
    // Tokens recorded since the last `reset_check`.
    since_last_check: usize,
}

impl QualityMonitor {
    /// Create a monitor with the given quality floor and check cadence.
    pub fn new(min_quality: f32, check_interval: usize) -> Self {
        Self {
            min_quality,
            check_interval,
            recorded: Vec::new(),
            since_last_check: 0,
        }
    }

    /// Record one token's quality score.
    pub fn record(&mut self, score: f32) {
        self.since_last_check += 1;
        self.recorded.push(score);
    }

    /// True while the running average stays at or above the floor.
    /// With no samples yet, generation is allowed to continue.
    pub fn should_continue(&self) -> bool {
        match self.recorded.len() {
            0 => true,
            n => self.recorded.iter().sum::<f32>() / n as f32 >= self.min_quality,
        }
    }

    /// True once `check_interval` tokens have arrived since the last reset.
    pub fn should_check(&self) -> bool {
        self.since_last_check >= self.check_interval
    }

    /// Restart the interval counter after a quality check.
    pub fn reset_check(&mut self) {
        self.since_last_check = 0;
    }

    /// Mean of all recorded scores; 1.0 when nothing has been recorded.
    pub fn average_quality(&self) -> f32 {
        let n = self.recorded.len();
        if n == 0 {
            return 1.0;
        }
        self.recorded.iter().sum::<f32>() / n as f32
    }
}
/// Drives one streaming response: forwards events over an mpsc channel
/// while accumulating text and tracking token timing and quality.
pub struct ResponseStreamer {
    /// Identifier of the request being streamed.
    pub request_id: String,
    /// Model producing the response.
    pub model: ClaudeModel,
    // Set at construction; basis for per-token latency and throughput.
    start_time: Instant,
    // Number of tokens processed so far.
    token_count: usize,
    // Quality floor 0.6, checked every 20 tokens (set in `new`).
    quality_monitor: QualityMonitor,
    // Channel on which StreamEvents are emitted.
    sender: mpsc::Sender<StreamEvent>,
    // Concatenation of all token text received so far.
    accumulated_text: String,
    // Set by `complete`; further tokens are rejected once true.
    is_complete: bool,
}
impl ResponseStreamer {
    /// Create a streamer that forwards events for `request_id` over `sender`.
    ///
    /// Quality monitoring defaults to a 0.6 floor checked every 20 tokens.
    pub fn new(request_id: String, model: ClaudeModel, sender: mpsc::Sender<StreamEvent>) -> Self {
        Self {
            // The owned String can be moved straight into the field; the
            // previous `.clone()` here was redundant.
            request_id,
            model,
            start_time: Instant::now(),
            token_count: 0,
            quality_monitor: QualityMonitor::new(0.6, 20),
            sender,
            accumulated_text: String::new(),
            is_complete: false,
        }
    }

    /// Record one generated token and emit it as a `StreamEvent::Token`.
    ///
    /// # Errors
    /// Returns `InvalidOperation` if the stream has already completed, or if
    /// the receiving side of the channel has been dropped.
    pub async fn process_token(&mut self, text: String, quality_score: Option<f32>) -> Result<()> {
        if self.is_complete {
            return Err(RuvLLMError::InvalidOperation(
                "Stream already complete".to_string(),
            ));
        }
        if let Some(score) = quality_score {
            self.quality_monitor.record(score);
        }
        // Append to the transcript first so the owned `text` can then be
        // moved into the token, avoiding the previous extra clone.
        self.accumulated_text.push_str(&text);
        let token = StreamToken {
            text,
            index: self.token_count,
            latency_ms: self.start_time.elapsed().as_millis() as u64,
            quality_score,
        };
        self.token_count += 1;
        self.sender
            .send(StreamEvent::Token(token))
            .await
            .map_err(|e| RuvLLMError::InvalidOperation(format!("Failed to send token: {}", e)))?;
        Ok(())
    }

    /// Mark the stream finished and emit the terminal `Complete` event.
    ///
    /// # Errors
    /// Returns `InvalidOperation` if the channel receiver has been dropped.
    /// Note the stream is marked complete even if the send fails.
    pub async fn complete(&mut self, usage: UsageStats, stop_reason: String) -> Result<()> {
        self.is_complete = true;
        self.sender
            .send(StreamEvent::Complete {
                usage,
                stop_reason,
                total_latency_ms: self.start_time.elapsed().as_millis() as u64,
            })
            .await
            .map_err(|e| {
                RuvLLMError::InvalidOperation(format!("Failed to send complete: {}", e))
            })?;
        Ok(())
    }

    /// Snapshot of throughput and quality metrics so far.
    pub fn stats(&self) -> StreamStats {
        let elapsed = self.start_time.elapsed();
        let secs = elapsed.as_secs_f64();
        StreamStats {
            token_count: self.token_count,
            elapsed_ms: elapsed.as_millis() as u64,
            tokens_per_second: if secs > 0.0 {
                self.token_count as f64 / secs
            } else {
                0.0
            },
            average_quality: self.quality_monitor.average_quality(),
            is_complete: self.is_complete,
        }
    }

    /// Full text received so far.
    pub fn accumulated_text(&self) -> &str {
        &self.accumulated_text
    }

    /// Whether the average quality is still at or above the monitor's floor.
    pub fn quality_acceptable(&self) -> bool {
        self.quality_monitor.should_continue()
    }
}
/// Point-in-time metrics snapshot for a `ResponseStreamer`.
#[derive(Debug, Clone)]
pub struct StreamStats {
    /// Tokens processed so far.
    pub token_count: usize,
    /// Milliseconds since the stream started.
    pub elapsed_ms: u64,
    /// Tokens per wall-clock second (0.0 if no time has elapsed).
    pub tokens_per_second: f64,
    /// Mean of recorded quality scores (1.0 if none recorded).
    pub average_quality: f32,
    /// Whether `complete` has been called.
    pub is_complete: bool,
}
/// A token-budgeted message history for one agent.
///
/// Token counts are rough estimates (~4 bytes per token, see
/// `Message::estimate_tokens`); compression triggers once usage passes
/// `compression_threshold` of `max_tokens`.
#[derive(Debug, Clone)]
pub struct ContextWindow {
    // Messages currently retained in the window.
    messages: Vec<Message>,
    // Optional system prompt, counted toward the token budget.
    system_prompt: Option<String>,
    // Hard cap on estimated tokens.
    max_tokens: usize,
    // Current estimated usage (messages + system prompt).
    current_tokens: usize,
    // Fraction of max_tokens at which compression triggers (0.8).
    compression_threshold: f32,
}
impl ContextWindow {
    /// Create an empty window capped at `max_tokens` estimated tokens.
    pub fn new(max_tokens: usize) -> Self {
        Self {
            messages: Vec::new(),
            system_prompt: None,
            max_tokens,
            current_tokens: 0,
            compression_threshold: 0.8,
        }
    }

    /// Install or replace the system prompt, adjusting the token estimate.
    pub fn set_system(&mut self, prompt: impl Into<String>) {
        let prompt = prompt.into();
        if let Some(old) = &self.system_prompt {
            self.current_tokens -= old.len() / 4;
        }
        self.current_tokens += prompt.len() / 4;
        self.system_prompt = Some(prompt);
    }

    /// Append a message; compresses automatically once past the threshold.
    pub fn add_message(&mut self, message: Message) {
        self.current_tokens += message.estimate_tokens();
        self.messages.push(message);
        if self.needs_compression() {
            self.compress();
        }
    }

    /// True once estimated usage exceeds `compression_threshold` of capacity.
    pub fn needs_compression(&self) -> bool {
        (self.current_tokens as f32) > self.max_tokens as f32 * self.compression_threshold
    }

    /// Fraction of capacity currently in use.
    pub fn utilization(&self) -> f32 {
        self.current_tokens as f32 / self.max_tokens as f32
    }

    /// Evict middle messages until usage drops to ~60% of capacity,
    /// always pinning the first message and protecting up to three of the
    /// most recent ones (shrinking that tail if eviction alone can't reach
    /// the target). No-op for histories of four messages or fewer.
    pub fn compress(&mut self) {
        if self.messages.len() <= 4 {
            return;
        }
        let target = (self.max_tokens as f32 * 0.6) as usize;
        let keep_first = 1;
        let mut keep_last = 3;
        while self.current_tokens > target && keep_last > 1 {
            if self.messages.len() > keep_first + keep_last {
                // Drop the oldest non-pinned message (index 1).
                let evicted = self.messages.remove(keep_first);
                self.current_tokens -= evicted.estimate_tokens();
            } else {
                // Nothing left in the middle: shrink the protected tail.
                keep_last -= 1;
            }
        }
    }

    /// Resize the cap from task complexity in [0, 1]: 50%-100% of the
    /// model's maximum context length.
    pub fn expand_for_task(&mut self, task_complexity: f32, model: ClaudeModel) {
        let base = model.max_context_tokens() as f32;
        let factor = 0.5 + task_complexity * 0.5;
        self.max_tokens = (base * factor) as usize;
    }

    /// Messages currently retained in the window.
    pub fn get_messages(&self) -> &[Message] {
        &self.messages
    }

    /// Current system prompt, if any.
    pub fn get_system(&self) -> Option<&str> {
        self.system_prompt.as_deref()
    }

    /// Estimated tokens currently used (messages + system prompt).
    pub fn token_count(&self) -> usize {
        self.current_tokens
    }

    /// Estimated tokens still available before hitting the cap.
    pub fn remaining_capacity(&self) -> usize {
        self.max_tokens.saturating_sub(self.current_tokens)
    }

    /// Drop all messages; the system prompt (and its tokens) remains.
    pub fn clear(&mut self) {
        self.messages.clear();
        self.current_tokens = match &self.system_prompt {
            Some(p) => p.len() / 4,
            None => 0,
        };
    }
}
/// Owns one `ContextWindow` per agent, keyed by agent id.
pub struct ContextManager {
    // Per-agent context windows.
    windows: HashMap<String, ContextWindow>,
    // Token cap applied to newly created windows.
    default_max_tokens: usize,
}
impl ContextManager {
    /// Create a manager; new windows are created with `default_max_tokens`.
    pub fn new(default_max_tokens: usize) -> Self {
        Self {
            windows: HashMap::new(),
            default_max_tokens,
        }
    }

    /// Fetch the agent's window, creating it on first access.
    pub fn get_window(&mut self, agent_id: &str) -> &mut ContextWindow {
        // Entry API: a single map lookup instead of the previous
        // contains_key / insert / get_mut + unwrap sequence.
        // Copy the default first so the closure doesn't borrow `self`.
        let default_max_tokens = self.default_max_tokens;
        self.windows
            .entry(agent_id.to_string())
            .or_insert_with(|| ContextWindow::new(default_max_tokens))
    }

    /// Drop the window for an agent, if present.
    pub fn remove_window(&mut self, agent_id: &str) {
        self.windows.remove(agent_id);
    }

    /// Sum of estimated tokens across all windows.
    pub fn total_tokens(&self) -> usize {
        self.windows.values().map(|w| w.token_count()).sum()
    }

    /// Number of live windows.
    pub fn window_count(&self) -> usize {
        self.windows.len()
    }
}
/// Lifecycle state of a coordinated agent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentState {
    /// Registered but not yet started.
    Idle,
    /// Currently executing.
    Running,
    /// Paused (set via `AgentContext::block`).
    Blocked,
    /// Finished successfully.
    Completed,
    /// Finished with an error.
    Failed,
}
/// Bookkeeping for one agent: identity, lifecycle state, and accumulated
/// token/cost accounting.
#[derive(Debug, Clone)]
pub struct AgentContext {
    /// Unique agent identifier.
    pub agent_id: String,
    /// Role of the agent (coder, researcher, ...).
    pub agent_type: AgentType,
    /// Model this agent is billed against.
    pub model: ClaudeModel,
    /// Current lifecycle state.
    pub state: AgentState,
    // NOTE(review): never written by the code in this file — confirm a
    // caller elsewhere maintains it.
    pub context_tokens: usize,
    /// Total input + output tokens accumulated across completions.
    pub total_tokens_used: usize,
    /// Total dollar cost accumulated across completions.
    pub total_cost: f64,
    /// Set by `start`.
    pub started_at: Option<Instant>,
    /// Set by `complete` or `fail`.
    pub completed_at: Option<Instant>,
    /// Error description, set by `fail`.
    pub error: Option<String>,
}
impl AgentContext {
    /// Fresh idle context for a newly spawned agent.
    pub fn new(agent_id: String, agent_type: AgentType, model: ClaudeModel) -> Self {
        Self {
            agent_id,
            agent_type,
            model,
            state: AgentState::Idle,
            context_tokens: 0,
            total_tokens_used: 0,
            total_cost: 0.0,
            started_at: None,
            completed_at: None,
            error: None,
        }
    }

    /// Transition to `Running` and stamp the start time.
    pub fn start(&mut self) {
        self.started_at = Some(Instant::now());
        self.state = AgentState::Running;
    }

    /// Mark the agent as blocked.
    pub fn block(&mut self) {
        self.state = AgentState::Blocked;
    }

    /// Transition to `Completed`, folding the usage into the running token
    /// and cost totals.
    pub fn complete(&mut self, usage: &UsageStats) {
        self.completed_at = Some(Instant::now());
        self.state = AgentState::Completed;
        self.total_tokens_used += usage.input_tokens + usage.output_tokens;
        self.total_cost += usage.calculate_cost(self.model);
    }

    /// Transition to `Failed`, recording the error description.
    pub fn fail(&mut self, error: String) {
        self.completed_at = Some(Instant::now());
        self.error = Some(error);
        self.state = AgentState::Failed;
    }

    /// Wall-clock run time: start→finish once finished, start→now while
    /// running, `None` if the agent never started.
    pub fn duration(&self) -> Option<Duration> {
        let start = self.started_at?;
        Some(match self.completed_at {
            Some(end) => end.duration_since(start),
            None => start.elapsed(),
        })
    }
}
/// One unit of work in a workflow's dependency graph.
#[derive(Debug, Clone)]
pub struct WorkflowStep {
    /// Unique step identifier within the workflow.
    pub step_id: String,
    /// Kind of agent that should run this step.
    pub agent_type: AgentType,
    /// Task description for the agent.
    pub task: String,
    /// `step_id`s that must complete before this step becomes runnable.
    pub dependencies: Vec<String>,
    /// Model override; falls back to the coordinator's default when `None`.
    pub required_model: Option<ClaudeModel>,
    // NOTE(review): not consumed by `execute_workflow` in this file —
    // confirm whether retries are handled elsewhere.
    pub max_retries: u32,
}
/// Outcome of an entire workflow execution.
#[derive(Debug, Clone)]
pub struct WorkflowResult {
    /// Identifier of the executed workflow.
    pub workflow_id: String,
    /// Per-step outcomes keyed by `step_id`.
    pub step_results: HashMap<String, StepResult>,
    /// Wall-clock duration of the whole workflow.
    pub total_duration: Duration,
    /// Sum of tokens across all steps.
    pub total_tokens: usize,
    /// Sum of dollar cost across all steps.
    pub total_cost: f64,
    /// Whether every step succeeded.
    pub success: bool,
    /// Workflow-level error, if any.
    pub error: Option<String>,
}
/// Outcome of one workflow step.
#[derive(Debug, Clone)]
pub struct StepResult {
    /// Step this result belongs to.
    pub step_id: String,
    /// Agent that executed the step.
    pub agent_id: String,
    /// Model used for the step.
    pub model: ClaudeModel,
    /// Response text, when the step produced one.
    pub response: Option<String>,
    /// Wall-clock duration of the step.
    pub duration: Duration,
    /// Tokens consumed by the step.
    pub tokens_used: usize,
    /// Dollar cost of the step.
    pub cost: f64,
    /// Whether the step succeeded.
    pub success: bool,
    /// Step-level error, if any.
    pub error: Option<String>,
}
/// Coordinates a pool of agents: spawning, state updates, workflow
/// execution, and aggregate accounting.
pub struct AgentCoordinator {
    // Registered agents keyed by id; shared behind a parking_lot RwLock.
    agents: Arc<RwLock<HashMap<String, AgentContext>>>,
    // Per-agent context windows (100k-token default, see `new`).
    context_manager: Arc<RwLock<ContextManager>>,
    // Model used when a workflow step requests none.
    default_model: ClaudeModel,
    // Upper bound on simultaneously registered agents.
    max_concurrent: usize,
    // Lifetime count of workflows run through `execute_workflow`.
    workflows_executed: u64,
    // Lifetime accumulated workflow cost in dollars.
    total_cost: f64,
}
impl AgentCoordinator {
    /// Create a coordinator using `default_model` for agents that do not
    /// request a specific model, allowing at most `max_concurrent` agents to
    /// be registered at once. Context windows default to 100k tokens.
    pub fn new(default_model: ClaudeModel, max_concurrent: usize) -> Self {
        Self {
            agents: Arc::new(RwLock::new(HashMap::new())),
            context_manager: Arc::new(RwLock::new(ContextManager::new(100_000))),
            default_model,
            max_concurrent,
            workflows_executed: 0,
            total_cost: 0.0,
        }
    }

    /// Register a new agent context (idle, using the default model).
    ///
    /// # Errors
    /// - `OutOfMemory` when `max_concurrent` agents are already registered.
    /// - `InvalidOperation` when `agent_id` is already taken.
    pub fn spawn_agent(&self, agent_id: String, agent_type: AgentType) -> Result<()> {
        let mut agents = self.agents.write();
        if agents.len() >= self.max_concurrent {
            return Err(RuvLLMError::OutOfMemory(format!(
                "Maximum concurrent agents ({}) reached",
                self.max_concurrent
            )));
        }
        if agents.contains_key(&agent_id) {
            return Err(RuvLLMError::InvalidOperation(format!(
                "Agent {} already exists",
                agent_id
            )));
        }
        let context = AgentContext::new(agent_id.clone(), agent_type, self.default_model);
        agents.insert(agent_id, context);
        Ok(())
    }

    /// Cloned snapshot of an agent's context, if it exists.
    pub fn get_agent(&self, agent_id: &str) -> Option<AgentContext> {
        self.agents.read().get(agent_id).cloned()
    }

    /// Apply a mutation to an agent's context under the write lock.
    ///
    /// # Errors
    /// `NotFound` when no agent with `agent_id` is registered.
    pub fn update_agent<F>(&self, agent_id: &str, f: F) -> Result<()>
    where
        F: FnOnce(&mut AgentContext),
    {
        let mut agents = self.agents.write();
        let agent = agents
            .get_mut(agent_id)
            .ok_or_else(|| RuvLLMError::NotFound(format!("Agent {} not found", agent_id)))?;
        f(agent);
        Ok(())
    }

    /// Remove an agent and its context window.
    ///
    /// # Errors
    /// `NotFound` when no agent with `agent_id` is registered.
    pub fn terminate_agent(&self, agent_id: &str) -> Result<()> {
        let mut agents = self.agents.write();
        agents
            .remove(agent_id)
            .ok_or_else(|| RuvLLMError::NotFound(format!("Agent {} not found", agent_id)))?;
        self.context_manager.write().remove_window(agent_id);
        Ok(())
    }

    /// Number of agents currently in the `Running` state.
    pub fn active_agent_count(&self) -> usize {
        self.agents
            .read()
            .values()
            .filter(|a| a.state == AgentState::Running)
            .count()
    }

    /// Total number of registered agents, in any state.
    pub fn total_agent_count(&self) -> usize {
        self.agents.read().len()
    }

    /// Execute workflow steps in dependency order.
    ///
    /// Each round runs every step whose dependencies have all completed; a
    /// round with no runnable steps means the graph has a cycle (or depends
    /// on a step that is not part of this workflow) and the workflow is
    /// rejected. Agents are spawned per step and terminated afterwards.
    ///
    /// NOTE(review): step execution is simulated — the response text is
    /// synthesized and token/cost figures are fixed placeholders (500
    /// tokens, $0.001 per step); confirm before relying on reported totals.
    ///
    /// # Errors
    /// `InvalidOperation` for circular/unsatisfiable dependencies, plus any
    /// error from spawning, updating, or terminating the per-step agents.
    pub async fn execute_workflow(
        &mut self,
        workflow_id: String,
        steps: Vec<WorkflowStep>,
    ) -> Result<WorkflowResult> {
        let start_time = Instant::now();
        let mut step_results: HashMap<String, StepResult> = HashMap::new();
        let mut completed_steps: std::collections::HashSet<String> =
            std::collections::HashSet::new();
        let mut pending_steps: Vec<&WorkflowStep> = steps.iter().collect();
        while !pending_steps.is_empty() {
            // A step is runnable once every dependency has completed.
            let ready_steps: Vec<_> = pending_steps
                .iter()
                .filter(|step| {
                    step.dependencies
                        .iter()
                        .all(|dep| completed_steps.contains(dep))
                })
                .cloned()
                .collect();
            // The loop condition already guarantees `pending_steps` is
            // non-empty here, so no runnable step means progress is stuck.
            if ready_steps.is_empty() {
                return Err(RuvLLMError::InvalidOperation(
                    "Workflow has circular dependencies".to_string(),
                ));
            }
            for step in ready_steps {
                let agent_id = format!("{}-{}", workflow_id, step.step_id);
                let model = step.required_model.unwrap_or(self.default_model);
                self.spawn_agent(agent_id.clone(), step.agent_type)?;
                self.update_agent(&agent_id, |a| a.start())?;
                let step_start = Instant::now();
                // Placeholder execution (see NOTE above).
                let result = StepResult {
                    step_id: step.step_id.clone(),
                    agent_id: agent_id.clone(),
                    model,
                    response: Some(format!("Completed: {}", step.task)),
                    duration: step_start.elapsed(),
                    tokens_used: 500,
                    cost: 0.001,
                    success: true,
                    error: None,
                };
                self.update_agent(&agent_id, |a| {
                    let usage = UsageStats {
                        input_tokens: 250,
                        output_tokens: 250,
                    };
                    a.complete(&usage);
                })?;
                step_results.insert(step.step_id.clone(), result);
                completed_steps.insert(step.step_id.clone());
                self.terminate_agent(&agent_id)?;
            }
            pending_steps.retain(|step| !completed_steps.contains(&step.step_id));
        }
        let total_tokens: usize = step_results.values().map(|r| r.tokens_used).sum();
        let total_cost: f64 = step_results.values().map(|r| r.cost).sum();
        self.workflows_executed += 1;
        self.total_cost += total_cost;
        Ok(WorkflowResult {
            workflow_id,
            step_results,
            total_duration: start_time.elapsed(),
            total_tokens,
            total_cost,
            success: true,
            error: None,
        })
    }

    /// Aggregate per-state counts over currently registered agents, plus the
    /// coordinator's lifetime workflow and cost counters.
    pub fn stats(&self) -> CoordinatorStats {
        let agents = self.agents.read();
        CoordinatorStats {
            total_agents: agents.len(),
            active_agents: agents
                .values()
                .filter(|a| a.state == AgentState::Running)
                .count(),
            blocked_agents: agents
                .values()
                .filter(|a| a.state == AgentState::Blocked)
                .count(),
            completed_agents: agents
                .values()
                .filter(|a| a.state == AgentState::Completed)
                .count(),
            failed_agents: agents
                .values()
                .filter(|a| a.state == AgentState::Failed)
                .count(),
            workflows_executed: self.workflows_executed,
            total_tokens_used: agents.values().map(|a| a.total_tokens_used).sum(),
            total_cost: self.total_cost,
        }
    }
}
/// Aggregate snapshot produced by `AgentCoordinator::stats`.
#[derive(Debug, Clone)]
pub struct CoordinatorStats {
    /// Registered agents, in any state.
    pub total_agents: usize,
    /// Agents in the `Running` state.
    pub active_agents: usize,
    /// Agents in the `Blocked` state.
    pub blocked_agents: usize,
    /// Agents in the `Completed` state.
    pub completed_agents: usize,
    /// Agents in the `Failed` state.
    pub failed_agents: usize,
    /// Lifetime count of executed workflows.
    pub workflows_executed: u64,
    /// Tokens summed over currently registered agents.
    pub total_tokens_used: usize,
    /// Lifetime accumulated workflow cost in dollars.
    pub total_cost: f64,
}
/// Accumulates token usage per model and prices it with that model's rates.
pub struct CostEstimator {
    // Running usage totals keyed by model.
    usage_by_model: HashMap<ClaudeModel, UsageStats>,
}
impl CostEstimator {
    /// Create an estimator with no recorded usage.
    pub fn new() -> Self {
        Self {
            usage_by_model: HashMap::new(),
        }
    }

    /// Predicted dollar cost of one request, given its token counts.
    pub fn estimate_request_cost(
        &self,
        model: ClaudeModel,
        input_tokens: usize,
        expected_output_tokens: usize,
    ) -> f64 {
        let input = input_tokens as f64 / 1000.0 * model.input_cost_per_1k();
        let output = expected_output_tokens as f64 / 1000.0 * model.output_cost_per_1k();
        input + output
    }

    /// Fold actual usage into the running total for `model`.
    pub fn record_usage(&mut self, model: ClaudeModel, usage: &UsageStats) {
        let totals = self.usage_by_model.entry(model).or_default();
        totals.input_tokens += usage.input_tokens;
        totals.output_tokens += usage.output_tokens;
    }

    /// Total recorded cost across all models, in dollars.
    pub fn total_cost(&self) -> f64 {
        let mut total = 0.0;
        for (model, usage) in &self.usage_by_model {
            total += usage.calculate_cost(*model);
        }
        total
    }

    /// Recorded cost per model, in dollars.
    pub fn cost_breakdown(&self) -> HashMap<ClaudeModel, f64> {
        let mut breakdown = HashMap::new();
        for (model, usage) in &self.usage_by_model {
            breakdown.insert(*model, usage.calculate_cost(*model));
        }
        breakdown
    }

    /// Raw accumulated usage keyed by model.
    pub fn usage_by_model(&self) -> &HashMap<ClaudeModel, UsageStats> {
        &self.usage_by_model
    }
}
impl Default for CostEstimator {
fn default() -> Self {
Self::new()
}
}
/// Keeps a bounded window of recent latency samples per model.
pub struct LatencyTracker {
    // Samples per model, oldest first; capped at `max_samples` each.
    samples: HashMap<ClaudeModel, Vec<LatencySample>>,
    // Per-model cap; oldest samples are evicted beyond this.
    max_samples: usize,
}
/// Latency measurement for one request.
#[derive(Debug, Clone)]
pub struct LatencySample {
    /// Time to first token, in milliseconds.
    pub ttft_ms: u64,
    /// Total response time, in milliseconds (expected >= `ttft_ms`).
    pub total_ms: u64,
    /// Prompt token count for the request.
    pub input_tokens: usize,
    /// Generated token count for the request.
    pub output_tokens: usize,
    // NOTE(review): not read by any aggregation in this file — confirm it
    // is consumed elsewhere before removing.
    pub timestamp: Instant,
}
impl LatencyTracker {
    /// Track at most `max_samples` recent samples per model.
    pub fn new(max_samples: usize) -> Self {
        Self {
            samples: HashMap::new(),
            max_samples,
        }
    }

    /// Record a sample, evicting the oldest once over capacity.
    pub fn record(&mut self, model: ClaudeModel, sample: LatencySample) {
        let samples = self.samples.entry(model).or_default();
        samples.push(sample);
        if samples.len() > self.max_samples {
            // O(n) front removal is fine: n is bounded by max_samples.
            samples.remove(0);
        }
    }

    /// Mean time-to-first-token in ms. `None` if the model has never been
    /// recorded; 0.0 if its sample list is empty.
    pub fn average_ttft(&self, model: ClaudeModel) -> Option<f64> {
        self.samples.get(&model).map(|samples| {
            if samples.is_empty() {
                return 0.0;
            }
            let sum: u64 = samples.iter().map(|s| s.ttft_ms).sum();
            sum as f64 / samples.len() as f64
        })
    }

    /// 95th-percentile time-to-first-token in ms, if any samples exist.
    pub fn p95_ttft(&self, model: ClaudeModel) -> Option<u64> {
        self.samples.get(&model).and_then(|samples| {
            if samples.is_empty() {
                return None;
            }
            let mut ttfts: Vec<u64> = samples.iter().map(|s| s.ttft_ms).collect();
            // Stability is irrelevant for plain integers; sort_unstable is
            // faster and avoids the allocation a stable sort performs.
            ttfts.sort_unstable();
            let idx = (ttfts.len() as f64 * 0.95) as usize;
            ttfts.get(idx.min(ttfts.len() - 1)).copied()
        })
    }

    /// Mean generation throughput (output tokens/second), excluding TTFT.
    /// `None` if the model has never been recorded; 0.0 when no time has
    /// accumulated.
    pub fn average_tokens_per_second(&self, model: ClaudeModel) -> Option<f64> {
        self.samples.get(&model).map(|samples| {
            if samples.is_empty() {
                return 0.0;
            }
            let total_tokens: usize = samples.iter().map(|s| s.output_tokens).sum();
            // saturating_sub guards against malformed samples where
            // ttft_ms > total_ms, which would otherwise underflow and panic
            // in debug builds (or wrap to a huge bogus duration in release).
            let total_time_ms: u64 = samples
                .iter()
                .map(|s| s.total_ms.saturating_sub(s.ttft_ms))
                .sum();
            if total_time_ms == 0 {
                return 0.0;
            }
            total_tokens as f64 / (total_time_ms as f64 / 1000.0)
        })
    }

    /// Aggregate statistics for one model, if it has ever been recorded.
    pub fn get_stats(&self, model: ClaudeModel) -> Option<LatencyStats> {
        self.samples.get(&model).map(|samples| LatencyStats {
            sample_count: samples.len(),
            avg_ttft_ms: self.average_ttft(model).unwrap_or(0.0),
            p95_ttft_ms: self.p95_ttft(model).unwrap_or(0),
            avg_tokens_per_second: self.average_tokens_per_second(model).unwrap_or(0.0),
        })
    }
}
/// Aggregated latency figures for one model (see `LatencyTracker::get_stats`).
#[derive(Debug, Clone)]
pub struct LatencyStats {
    /// Number of samples the aggregates were computed from.
    pub sample_count: usize,
    /// Mean time-to-first-token, in milliseconds.
    pub avg_ttft_ms: f64,
    /// 95th-percentile time-to-first-token, in milliseconds.
    pub p95_ttft_ms: u64,
    /// Mean output throughput, in tokens per second.
    pub avg_tokens_per_second: f64,
}
#[cfg(test)]
mod tests {
    use super::*;

    // Cost must increase with model tier for identical usage.
    #[test]
    fn test_claude_model_costs() {
        let usage = UsageStats {
            input_tokens: 1000,
            output_tokens: 500,
        };
        let haiku_cost = usage.calculate_cost(ClaudeModel::Haiku);
        let sonnet_cost = usage.calculate_cost(ClaudeModel::Sonnet);
        let opus_cost = usage.calculate_cost(ClaudeModel::Opus);
        assert!(haiku_cost < sonnet_cost);
        assert!(sonnet_cost < opus_cost);
    }

    // Adding messages must never leave the window over its token cap.
    #[test]
    fn test_context_window_compression() {
        let mut window = ContextWindow::new(1000);
        for i in 0..20 {
            window.add_message(Message::user(format!(
                "Message {} with some content to add tokens",
                i
            )));
        }
        assert!(window.token_count() <= 1000);
    }

    // Sanity bounds on the ~4-bytes-per-token estimate.
    #[test]
    fn test_message_token_estimation() {
        let msg = Message::user("Hello, this is a test message with some content.");
        let tokens = msg.estimate_tokens();
        assert!(tokens > 0);
        assert!(tokens < 100);
    }

    // High average quality keeps generation going; low average stops it.
    #[test]
    fn test_quality_monitor() {
        let mut monitor = QualityMonitor::new(0.6, 10);
        for _ in 0..5 {
            monitor.record(0.8);
        }
        assert!(monitor.should_continue());
        let mut bad_monitor = QualityMonitor::new(0.6, 10);
        for _ in 0..5 {
            bad_monitor.record(0.3);
        }
        assert!(!bad_monitor.should_continue());
    }

    // Spawn / update / terminate round-trip through the coordinator.
    #[test]
    fn test_agent_coordinator() {
        let coordinator = AgentCoordinator::new(ClaudeModel::Sonnet, 10);
        coordinator
            .spawn_agent("agent-1".to_string(), AgentType::Coder)
            .unwrap();
        coordinator
            .spawn_agent("agent-2".to_string(), AgentType::Researcher)
            .unwrap();
        assert_eq!(coordinator.total_agent_count(), 2);
        coordinator.update_agent("agent-1", |a| a.start()).unwrap();
        assert_eq!(coordinator.active_agent_count(), 1);
        coordinator.terminate_agent("agent-1").unwrap();
        assert_eq!(coordinator.total_agent_count(), 1);
    }

    // Recorded usage yields a positive total and a per-model breakdown.
    #[test]
    fn test_cost_estimator() {
        let mut estimator = CostEstimator::new();
        let usage = UsageStats {
            input_tokens: 1000,
            output_tokens: 500,
        };
        estimator.record_usage(ClaudeModel::Sonnet, &usage);
        estimator.record_usage(ClaudeModel::Haiku, &usage);
        let total = estimator.total_cost();
        assert!(total > 0.0);
        let breakdown = estimator.cost_breakdown();
        assert!(breakdown.contains_key(&ClaudeModel::Sonnet));
        assert!(breakdown.contains_key(&ClaudeModel::Haiku));
    }

    // Latency stats aggregate correctly over recorded samples.
    #[test]
    fn test_latency_tracker() {
        let mut tracker = LatencyTracker::new(100);
        for i in 0..10 {
            tracker.record(
                ClaudeModel::Sonnet,
                LatencySample {
                    ttft_ms: 400 + i * 10,
                    total_ms: 1000 + i * 100,
                    input_tokens: 500,
                    output_tokens: 200,
                    timestamp: Instant::now(),
                },
            );
        }
        let stats = tracker.get_stats(ClaudeModel::Sonnet).unwrap();
        assert_eq!(stats.sample_count, 10);
        assert!(stats.avg_ttft_ms > 400.0);
        assert!(stats.avg_tokens_per_second > 0.0);
    }
}