// agent-base 0.1.0
//
// A lightweight Agent Runtime Kernel for building AI agents in Rust.
// Documentation
use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use serde_json::{json, Value};
use tokio::sync::broadcast;

use crate::llm::LlmClient;
use crate::types::{AgentResult, AgentError, AgentEvent, SessionId};
use crate::engine::SessionStore;

pub mod mcp;
pub mod policy;
pub mod subagent;

pub use mcp::{McpClient, McpToolInfo, McpToolRegistry};
pub use subagent::{SubAgentSessionPolicy, SubAgentTool};

pub use policy::ToolPolicy;

/// Result produced by a single tool invocation.
///
/// The blanket [`Tool`] implementation for [`TypedTool`] in this module fills
/// `summary` from [`TypedTool::format_output`] and `raw` from the JSON
/// serialization of the typed output.
#[derive(Clone, Debug, Default)]
pub struct ToolOutput {
    /// Textual rendering of the tool result.
    pub summary: String,
    /// Structured JSON form of the result, when one is available.
    pub raw: Option<Value>,
    /// Control-flow signal attached to this result.
    // NOTE(review): how the consumer interprets this signal is not visible in this file — confirm against the engine.
    pub control_flow: ToolControlFlow,
    /// Whether the output was truncated; `false` by default.
    // NOTE(review): presumably set by whatever code shortens the output — confirm.
    pub truncated: bool,
}

/// Control-flow signal a tool attaches to its output.
///
/// `Break` is the default variant. Because this is a fieldless enum it is
/// `Copy` and can be compared with `==`, so callers can branch on it without
/// cloning or writing a `match` just to test for one variant.
// NOTE(review): the exact meaning of each variant is decided by the code that
// consumes the tool output, which is not visible in this file — confirm there.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ToolControlFlow {
    /// Default signal.
    #[default]
    Break,
    /// Alternative signal.
    Continue,
}

/// Per-call context handed to a tool alongside its arguments.
///
/// Cloning is cheap for the optional handles, which are `Arc`s.
#[derive(Clone)]
pub struct ToolContext {
    /// Identifier of the session the tool call belongs to.
    pub session_id: SessionId,
    /// Sending half of a broadcast channel carrying [`AgentEvent`]s.
    pub event_bus: broadcast::Sender<AgentEvent>,
    /// Shared LLM client, if one is provided to the tool.
    pub llm_client: Option<Arc<dyn LlmClient>>,
    /// Shared session store, if one is provided to the tool.
    pub session_store: Option<Arc<dyn SessionStore>>,
}

/// Dynamically dispatched interface for a callable tool.
///
/// Implementations are stored as `Arc<dyn Tool>` in [`ToolRegistry`]. Any
/// [`TypedTool`] automatically implements this trait through the blanket
/// implementation in this module.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Name of the tool; [`ToolRegistry`] uses it as the lookup key.
    fn name(&self) -> &'static str;
    /// JSON description of the tool.
    fn definition(&self) -> Value;
    /// Invokes the tool with JSON `args` and the per-call context `ctx`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports through
    /// [`AgentResult`].
    async fn call(&self, args: &Value, ctx: &ToolContext) -> AgentResult<ToolOutput>;
}

/// Strongly typed flavour of [`Tool`].
///
/// Implementors work with concrete argument and output types; the blanket
/// `impl Tool for T` in this module takes care of converting between those
/// types and JSON.
#[async_trait]
pub trait TypedTool: Send + Sync {
    /// Argument type, deserialized from the JSON arguments of a call.
    type Args: serde::de::DeserializeOwned;
    /// Output type, serialized to JSON after a successful call.
    type Output: serde::Serialize;

    /// Name of the tool.
    fn name(&self) -> &'static str;

    /// Human-readable description placed in the tool definition.
    fn description(&self) -> &'static str;

    /// JSON value placed under `"parameters"` in the tool definition.
    fn parameters_schema(&self) -> Value;

    /// Runs the tool with already-deserialized arguments.
    async fn call_typed(&self, args: Self::Args, ctx: &ToolContext) -> AgentResult<Self::Output>;

    /// Control-flow signal attached to every output of this tool.
    ///
    /// Defaults to [`ToolControlFlow::Break`]; override to change it.
    fn control_flow() -> ToolControlFlow
    where
        Self: Sized,
    {
        ToolControlFlow::Break
    }

    /// Renders the typed output as the textual summary.
    ///
    /// The default renders compact JSON and deliberately falls back to an
    /// empty string if serialization fails.
    fn format_output(&self, output: Self::Output) -> String {
        match serde_json::to_string(&output) {
            Ok(rendered) => rendered,
            Err(_) => String::new(),
        }
    }
}

#[async_trait]
impl<T: TypedTool + Send + Sync + 'static> Tool for T {
    fn name(&self) -> &'static str {
        TypedTool::name(self)
    }

    fn definition(&self) -> Value {
        json!({
            "type": "function",
            "function": {
                "name": self.name(),
                "description": self.description(),
                "parameters": self.parameters_schema(),
            }
        })
    }

    async fn call(&self, args: &Value, ctx: &ToolContext) -> AgentResult<ToolOutput> {
        let typed_args: T::Args = serde_json::from_value(args.clone())
            .map_err(|_| AgentError::ToolArgsInvalid {
                name: self.name().to_string(),
                raw: args.to_string(),
            })?;
        let output = self.call_typed(typed_args, ctx).await?;
        let output_json = serde_json::to_value(&output).ok();
        let summary = self.format_output(output);
        Ok(ToolOutput {
            summary,
            raw: output_json,
            control_flow: T::control_flow(),
            truncated: false,
        })
    }
}

/// Shared, type-erased handle to a [`Tool`], as stored in [`ToolRegistry`].
pub(crate) type ToolRef = Arc<dyn Tool>;

/// Collection of tools keyed by tool name.
///
/// Cloning copies the map but shares the tools themselves, since each entry is
/// an `Arc`. Create an empty registry with `ToolRegistry::default()`.
#[derive(Clone, Default)]
pub struct ToolRegistry {
    /// Registered tools, keyed by the string form of [`Tool::name`].
    tools: HashMap<String, ToolRef>,
}

impl ToolRegistry {
    /// Registers `tool` under its own name.
    ///
    /// A previously registered tool with the same name is replaced.
    pub fn register(&mut self, tool: impl Tool + 'static) {
        self.register_arc(Arc::new(tool));
    }

    /// Registers an already shared tool under its own name.
    ///
    /// A previously registered tool with the same name is replaced.
    pub fn register_arc(&mut self, tool: Arc<dyn Tool>) {
        self.tools.insert(tool.name().to_string(), tool);
    }

    /// Returns a shared handle to the tool registered as `name`, if any.
    pub fn get(&self, name: &str) -> Option<ToolRef> {
        self.tools.get(name).cloned()
    }

    /// Returns the definition of every registered tool, ordered by tool name.
    ///
    /// `HashMap` iteration order is arbitrary and differs between runs, so the
    /// entries are sorted first to give callers a deterministic list.
    pub fn definitions(&self) -> Vec<Value> {
        let mut entries: Vec<(&String, &ToolRef)> = self.tools.iter().collect();
        // Keys are unique, so an unstable sort cannot reorder equal elements.
        entries.sort_unstable_by(|left, right| left.0.cmp(right.0));
        entries
            .into_iter()
            .map(|(_, tool)| tool.definition())
            .collect()
    }

    /// Number of registered tools.
    pub fn len(&self) -> usize {
        self.tools.len()
    }

    /// Returns `true` when no tool is registered.
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }
}