// cerebro 1.1.3
//
// A high-performance semantic memory engine for AI agents, now featuring SwarmForge for built-in multi-agent orchestration.
// Documentation
//! # Agent Definitions
//!
//! Core types for defining agents within a Cerebro swarm.
//! An agent is a specialized unit with its own system prompt, LLM provider,
//! available tools, and handoff rules.

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;

/// Configuration for a specialized agent in the swarm.
///
/// This is the static, serializable description of an agent. Only `id`,
/// `name`, `system_prompt` and `model` are required when deserializing;
/// the remaining fields fall back to the defaults noted on each field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
    /// Unique identifier for this agent.
    pub id: String,
    /// Human-readable name (e.g., "Security Reviewer").
    pub name: String,
    /// The system prompt that defines this agent's personality and expertise.
    pub system_prompt: String,
    /// Which LLM provider and model to use.
    pub model: LlmProvider,
    /// Tools this agent can invoke via function-calling.
    /// Defaults to an empty list when omitted.
    #[serde(default)]
    pub tools: Vec<ToolDefinition>,
    /// Agent IDs this agent is allowed to hand off to.
    /// Defaults to an empty list when omitted.
    #[serde(default)]
    pub handoff_targets: Vec<String>,
    /// Maximum steps before circuit-breaker terminates this agent.
    /// Defaults to the value returned by `default_max_steps` when omitted.
    #[serde(default = "default_max_steps")]
    pub max_steps: usize,
}

/// Serde fallback for `AgentConfig::max_steps` when the field is omitted.
fn default_max_steps() -> usize {
    const FALLBACK_STEP_LIMIT: usize = 10;
    FALLBACK_STEP_LIMIT
}

/// Supported LLM providers for agent inference.
///
/// Covers all major providers natively, plus an `OpenAICompatible` catch-all
/// for any service that implements the OpenAI chat completions format
/// (Groq, Together, Mistral, DeepSeek, LM Studio, vLLM, etc.).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "provider", rename_all = "snake_case")]
pub enum LlmProvider {
    /// Local Ollama instance.
    Ollama {
        model: String,
        #[serde(default = "default_ollama_url")]
        base_url: String,
    },
    /// OpenAI API (GPT-4o, GPT-4, o3, etc.).
    OpenAI {
        model: String,
        api_key: String,
    },
    /// Google Gemini API.
    Gemini {
        model: String,
        api_key: String,
    },
    /// Anthropic Claude API (Claude 4, Sonnet, Haiku, etc.).
    Anthropic {
        model: String,
        api_key: String,
        /// Max tokens for Claude responses.
        #[serde(default = "default_anthropic_max_tokens")]
        max_tokens: usize,
    },
    /// Any OpenAI-compatible API endpoint.
    /// Works with: Groq, Together, Mistral, DeepSeek, Fireworks,
    /// LM Studio, vLLM, text-generation-inference, Anyscale, etc.
    OpenAICompatible {
        model: String,
        api_key: String,
        /// The base URL (e.g., "https://api.groq.com/openai/v1").
        base_url: String,
        /// Optional custom provider label for display/tracing.
        #[serde(default)]
        provider_name: Option<String>,
    },
}

/// Serde fallback for the Ollama `base_url`: the local server address
/// used when the field is omitted.
fn default_ollama_url() -> String {
    String::from("http://localhost:11434")
}

/// Serde fallback for the Anthropic `max_tokens` field when it is omitted.
fn default_anthropic_max_tokens() -> usize {
    const FALLBACK_MAX_TOKENS: usize = 4_096;
    FALLBACK_MAX_TOKENS
}

impl fmt::Display for LlmProvider {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            LlmProvider::Ollama { model, .. } => write!(f, "ollama/{}", model),
            LlmProvider::OpenAI { model, .. } => write!(f, "openai/{}", model),
            LlmProvider::Gemini { model, .. } => write!(f, "gemini/{}", model),
            LlmProvider::Anthropic { model, .. } => write!(f, "anthropic/{}", model),
            LlmProvider::OpenAICompatible { model, provider_name, .. } => {
                let name = provider_name.as_deref().unwrap_or("custom");
                write!(f, "{}/{}", name, model)
            }
        }
    }
}

/// A tool that an agent can invoke during execution.
///
/// This type only describes the tool (name, description, input schema);
/// it carries no executable code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name (e.g., "web_search", "run_code").
    pub name: String,
    /// Human-readable description of what this tool does.
    pub description: String,
    /// JSON Schema for the tool's input parameters, stored as an arbitrary
    /// JSON value (no validation is performed by this type).
    pub parameters_schema: serde_json::Value,
}

/// A single message in an agent's conversation history (episodic memory).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// The role of the message sender.
    pub role: Role,
    /// The text content of the message.
    pub content: String,
    /// ISO 8601 timestamp. `ChatMessage::new` fills this with the current
    /// UTC time formatted as RFC 3339.
    pub timestamp: String,
    /// Optional metadata: tool_call_id, agent_id, handoff info, etc.
    /// Defaults to an empty map when omitted during deserialization.
    #[serde(default)]
    pub metadata: HashMap<String, String>,
}

impl ChatMessage {
    /// Builds a message from `role` and `content`, stamped with the current
    /// UTC time (RFC 3339) and carrying no metadata.
    pub fn new(role: Role, content: impl Into<String>) -> Self {
        let content = content.into();
        let timestamp = chrono::Utc::now().to_rfc3339();
        let metadata = HashMap::new();
        Self {
            role,
            content,
            timestamp,
            metadata,
        }
    }

    /// Builder-style helper: records one `key`/`value` metadata entry and
    /// returns the message. A later call with the same key replaces the
    /// earlier value.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (entry_key, entry_value) = (key.into(), value.into());
        self.metadata.insert(entry_key, entry_value);
        self
    }
}

/// Message roles in a conversation.
///
/// Serialized in snake_case, i.e. as `system`, `user`, `assistant` and
/// `tool`, which matches the strings produced by the `Display` impl.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Role {
    /// System-level instructions; `AgentRuntime::new` seeds the history with
    /// one message of this role holding the agent's system prompt.
    System,
    /// A message from the user side of the conversation.
    User,
    /// A message attributed to the assistant (presumably LLM output).
    Assistant,
    /// A message attributed to a tool (presumably a tool-call result).
    Tool,
}

impl fmt::Display for Role {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Role::System => write!(f, "system"),
            Role::User => write!(f, "user"),
            Role::Assistant => write!(f, "assistant"),
            Role::Tool => write!(f, "tool"),
        }
    }
}

/// Runtime state of an agent during an active swarm execution.
/// This is NOT persisted — it lives only for the duration of a swarm run
/// (the type deliberately does not implement `Serialize`/`Deserialize`).
///
/// `Debug` and `Clone` are derived so a runtime can be logged or
/// snapshotted; both are available because every field type already
/// implements them.
#[derive(Debug, Clone)]
pub struct AgentRuntime {
    /// The agent's static configuration.
    pub config: AgentConfig,
    /// Accumulated conversation messages (episodic memory for this run).
    pub messages: Vec<ChatMessage>,
    /// Current step counter, compared against `config.max_steps`.
    pub step: usize,
}

impl AgentRuntime {
    /// Create a new runtime from a config.
    ///
    /// The message history starts with a single `Role::System` message
    /// holding `config.system_prompt`, and the step counter starts at 0.
    pub fn new(config: AgentConfig) -> Self {
        let system_msg = ChatMessage::new(Role::System, config.system_prompt.as_str());
        Self {
            config,
            messages: vec![system_msg],
            step: 0,
        }
    }

    /// Append a message to the conversation history.
    ///
    /// This does NOT touch the step counter; call [`Self::advance_step`]
    /// to count a step.
    pub fn push_message(&mut self, msg: ChatMessage) {
        self.messages.push(msg);
    }

    /// Advance the step counter by one.
    ///
    /// Returns `true` once the counter has reached (or passed)
    /// `config.max_steps`, i.e. when the circuit breaker should trip.
    pub fn advance_step(&mut self) -> bool {
        // Saturate rather than overflow: `step` is a public field, and a
        // wrap-around to 0 would silently re-arm the circuit breaker.
        self.step = self.step.saturating_add(1);
        self.is_exhausted()
    }

    /// Check if the agent has hit its circuit breaker, i.e. whether
    /// `step >= config.max_steps`.
    pub fn is_exhausted(&self) -> bool {
        self.step >= self.config.max_steps
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a representative agent config shared by the tests below.
    fn sample_config() -> AgentConfig {
        AgentConfig {
            id: "security-agent".into(),
            name: "Security Reviewer".into(),
            system_prompt: "You are a security expert. Analyze code for vulnerabilities.".into(),
            model: LlmProvider::Ollama {
                model: "llama3".into(),
                base_url: default_ollama_url(),
            },
            tools: vec![],
            handoff_targets: vec!["perf-agent".into()],
            max_steps: 5,
        }
    }

    /// A config survives a JSON round trip with its key fields intact.
    #[test]
    fn test_agent_config_serialization() {
        let config = sample_config();
        let json = serde_json::to_string_pretty(&config).unwrap();
        let deser: AgentConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deser.id, "security-agent");
        assert_eq!(deser.name, "Security Reviewer");
        assert_eq!(deser.max_steps, 5);
    }

    /// Optional fields that are left out of the JSON fall back to their
    /// serde defaults.
    #[test]
    fn test_agent_config_defaults() {
        let json = r#"{
            "id": "minimal",
            "name": "Minimal",
            "system_prompt": "hi",
            "model": { "provider": "ollama", "model": "llama3" }
        }"#;
        let cfg: AgentConfig = serde_json::from_str(json).unwrap();
        assert!(cfg.tools.is_empty());
        assert!(cfg.handoff_targets.is_empty());
        assert_eq!(cfg.max_steps, default_max_steps());
        match cfg.model {
            LlmProvider::Ollama { model, base_url } => {
                assert_eq!(model, "llama3");
                assert_eq!(base_url, default_ollama_url());
            }
            other => panic!("expected Ollama provider, got {}", other),
        }
    }

    /// `Display` renders `<provider>/<model>` for built-in and custom providers.
    #[test]
    fn test_llm_provider_display() {
        let ollama = LlmProvider::Ollama {
            model: "llama3".into(),
            base_url: default_ollama_url(),
        };
        assert_eq!(ollama.to_string(), "ollama/llama3");

        let anthropic = LlmProvider::Anthropic {
            model: "claude-sonnet-4-20250514".into(),
            api_key: "sk-test".into(),
            max_tokens: 4096,
        };
        assert_eq!(anthropic.to_string(), "anthropic/claude-sonnet-4-20250514");

        let groq = LlmProvider::OpenAICompatible {
            model: "llama-3.3-70b-versatile".into(),
            api_key: "gsk-test".into(),
            base_url: "https://api.groq.com/openai/v1".into(),
            provider_name: Some("groq".into()),
        };
        assert_eq!(groq.to_string(), "groq/llama-3.3-70b-versatile");

        // Without a provider label the generic "custom" prefix is used.
        let custom = LlmProvider::OpenAICompatible {
            model: "my-model".into(),
            api_key: "key".into(),
            base_url: "http://localhost:8080/v1".into(),
            provider_name: None,
        };
        assert_eq!(custom.to_string(), "custom/my-model");
    }

    /// The builder stores every metadata entry and stamps a timestamp.
    #[test]
    fn test_chat_message_builder() {
        let msg = ChatMessage::new(Role::User, "Hello")
            .with_metadata("agent_id", "sec-01")
            .with_metadata("run_id", "run-abc");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, "Hello");
        assert_eq!(msg.metadata.get("agent_id").unwrap(), "sec-01");
        assert_eq!(msg.metadata.get("run_id").unwrap(), "run-abc");
        assert!(!msg.timestamp.is_empty());
    }

    /// Walks a runtime from creation to the circuit breaker tripping.
    #[test]
    fn test_agent_runtime_lifecycle() {
        let config = sample_config();
        let mut rt = AgentRuntime::new(config);
        assert_eq!(rt.step, 0);
        assert_eq!(rt.messages.len(), 1); // system prompt
        assert_eq!(rt.messages[0].role, Role::System);

        // Pushing a message grows the history but does not count as a step.
        rt.push_message(ChatMessage::new(Role::User, "Review this code"));
        assert_eq!(rt.messages.len(), 2);
        assert_eq!(rt.step, 0);

        // Advance 4 steps (max is 5, so after step 4 the agent is not exhausted)
        for _ in 0..4 {
            assert!(!rt.advance_step());
        }
        assert!(!rt.is_exhausted());

        // Step 5 hits the circuit breaker
        assert!(rt.advance_step());
        assert!(rt.is_exhausted());
    }

    /// Every role renders as its lowercase name.
    #[test]
    fn test_role_display() {
        assert_eq!(Role::System.to_string(), "system");
        assert_eq!(Role::User.to_string(), "user");
        assert_eq!(Role::Assistant.to_string(), "assistant");
        assert_eq!(Role::Tool.to_string(), "tool");
    }
}