tirea-agentos 0.4.0

Agent runtime with streaming LLM integration, sub-agent orchestration, and context window management
Documentation
use super::sorted_registry_ids;
use super::traits::{ToolRegistry, ToolRegistryError};
use crate::contracts::runtime::tool_call::Tool;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

#[derive(Clone, Default)]
pub struct InMemoryToolRegistry {
    tools: HashMap<String, Arc<dyn Tool>>,
}

impl std::fmt::Debug for InMemoryToolRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("InMemoryToolRegistry")
            .field("len", &self.tools.len())
            .finish()
    }
}

impl InMemoryToolRegistry {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn len(&self) -> usize {
        self.tools.len()
    }

    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }

    pub fn get(&self, id: &str) -> Option<Arc<dyn Tool>> {
        self.tools.get(id).cloned()
    }

    pub fn ids(&self) -> impl Iterator<Item = &String> {
        self.tools.keys()
    }

    pub fn register(&mut self, tool: Arc<dyn Tool>) -> Result<(), ToolRegistryError> {
        let id = tool.descriptor().id;
        if self.tools.contains_key(&id) {
            return Err(ToolRegistryError::ToolIdConflict(id));
        }
        self.tools.insert(id, tool);
        Ok(())
    }

    pub fn register_named(
        &mut self,
        id: impl Into<String>,
        tool: Arc<dyn Tool>,
    ) -> Result<(), ToolRegistryError> {
        let key = id.into();
        let descriptor_id = tool.descriptor().id;
        if key != descriptor_id {
            return Err(ToolRegistryError::ToolIdMismatch { key, descriptor_id });
        }
        if self.tools.contains_key(&key) {
            return Err(ToolRegistryError::ToolIdConflict(key));
        }
        self.tools.insert(key, tool);
        Ok(())
    }

    pub fn extend_named(
        &mut self,
        tools: HashMap<String, Arc<dyn Tool>>,
    ) -> Result<(), ToolRegistryError> {
        for (key, tool) in tools {
            self.register_named(key, tool)?;
        }
        Ok(())
    }

    pub fn extend_registry(&mut self, other: &dyn ToolRegistry) -> Result<(), ToolRegistryError> {
        self.extend_named(other.snapshot())
    }

    pub fn merge_many(
        regs: impl IntoIterator<Item = InMemoryToolRegistry>,
    ) -> Result<InMemoryToolRegistry, ToolRegistryError> {
        let mut out = InMemoryToolRegistry::new();
        for r in regs {
            out.extend_named(r.into_map())?;
        }
        Ok(out)
    }

    pub fn into_map(self) -> HashMap<String, Arc<dyn Tool>> {
        self.tools
    }

    pub fn to_map(&self) -> HashMap<String, Arc<dyn Tool>> {
        self.tools.clone()
    }
}

impl ToolRegistry for InMemoryToolRegistry {
    fn len(&self) -> usize {
        self.len()
    }

    fn get(&self, id: &str) -> Option<Arc<dyn Tool>> {
        self.get(id)
    }

    fn ids(&self) -> Vec<String> {
        sorted_registry_ids(&self.tools)
    }

    fn snapshot(&self) -> HashMap<String, Arc<dyn Tool>> {
        self.tools.clone()
    }
}

#[derive(Clone, Default)]
pub struct CompositeToolRegistry {
    registries: Vec<Arc<dyn ToolRegistry>>,
    cached_snapshot: Arc<RwLock<HashMap<String, Arc<dyn Tool>>>>,
}

impl std::fmt::Debug for CompositeToolRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let snapshot = match self.cached_snapshot.read() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        f.debug_struct("CompositeToolRegistry")
            .field("registries", &self.registries.len())
            .field("len", &snapshot.len())
            .finish()
    }
}

impl CompositeToolRegistry {
    pub fn try_new(
        regs: impl IntoIterator<Item = Arc<dyn ToolRegistry>>,
    ) -> Result<Self, ToolRegistryError> {
        let registries: Vec<Arc<dyn ToolRegistry>> = regs.into_iter().collect();
        let merged = Self::merge_snapshots(&registries)?;
        Ok(Self {
            registries,
            cached_snapshot: Arc::new(RwLock::new(merged)),
        })
    }

    fn merge_snapshots(
        registries: &[Arc<dyn ToolRegistry>],
    ) -> Result<HashMap<String, Arc<dyn Tool>>, ToolRegistryError> {
        let mut merged = InMemoryToolRegistry::new();
        for reg in registries {
            merged.extend_registry(reg.as_ref())?;
        }
        Ok(merged.into_map())
    }

    fn refresh_snapshot(&self) -> Result<HashMap<String, Arc<dyn Tool>>, ToolRegistryError> {
        Self::merge_snapshots(&self.registries)
    }

    fn read_cached_snapshot(&self) -> HashMap<String, Arc<dyn Tool>> {
        match self.cached_snapshot.read() {
            Ok(guard) => guard.clone(),
            Err(poisoned) => poisoned.into_inner().clone(),
        }
    }

    fn write_cached_snapshot(&self, snapshot: HashMap<String, Arc<dyn Tool>>) {
        match self.cached_snapshot.write() {
            Ok(mut guard) => *guard = snapshot,
            Err(poisoned) => *poisoned.into_inner() = snapshot,
        };
    }
}

impl ToolRegistry for CompositeToolRegistry {
    fn len(&self) -> usize {
        self.snapshot().len()
    }

    fn get(&self, id: &str) -> Option<Arc<dyn Tool>> {
        self.snapshot().get(id).cloned()
    }

    fn ids(&self) -> Vec<String> {
        let snapshot = self.snapshot();
        sorted_registry_ids(&snapshot)
    }

    fn snapshot(&self) -> HashMap<String, Arc<dyn Tool>> {
        match self.refresh_snapshot() {
            Ok(snapshot) => {
                self.write_cached_snapshot(snapshot.clone());
                snapshot
            }
            Err(_) => self.read_cached_snapshot(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::contracts::runtime::tool_call::{ToolDescriptor, ToolError, ToolResult};
    use crate::contracts::ToolCallContext;
    use serde_json::json;

    struct StaticTool {
        descriptor: ToolDescriptor,
    }

    impl StaticTool {
        fn new(id: &str) -> Self {
            Self {
                descriptor: ToolDescriptor::new(id, id, "test tool"),
            }
        }
    }

    #[async_trait::async_trait]
    impl Tool for StaticTool {
        fn descriptor(&self) -> ToolDescriptor {
            self.descriptor.clone()
        }

        async fn execute(
            &self,
            _args: serde_json::Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success(
                self.descriptor.id.clone(),
                json!({"ok": true}),
            ))
        }
    }

    #[derive(Default)]
    struct MutableToolRegistry {
        tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
    }

    impl MutableToolRegistry {
        fn replace_ids(&self, ids: &[&str]) {
            let mut map = HashMap::new();
            for id in ids {
                map.insert(
                    (*id).to_string(),
                    Arc::new(StaticTool::new(id)) as Arc<dyn Tool>,
                );
            }
            match self.tools.write() {
                Ok(mut guard) => *guard = map,
                Err(poisoned) => *poisoned.into_inner() = map,
            }
        }
    }

    impl ToolRegistry for MutableToolRegistry {
        fn len(&self) -> usize {
            self.snapshot().len()
        }

        fn get(&self, id: &str) -> Option<Arc<dyn Tool>> {
            self.snapshot().get(id).cloned()
        }

        fn ids(&self) -> Vec<String> {
            let mut ids: Vec<String> = self.snapshot().keys().cloned().collect();
            ids.sort();
            ids
        }

        fn snapshot(&self) -> HashMap<String, Arc<dyn Tool>> {
            match self.tools.read() {
                Ok(guard) => guard.clone(),
                Err(poisoned) => poisoned.into_inner().clone(),
            }
        }
    }

    #[test]
    fn composite_tool_registry_reads_live_updates_from_source_registries() {
        let dynamic = Arc::new(MutableToolRegistry::default());
        dynamic.replace_ids(&["dynamic_a"]);

        let mut static_registry = InMemoryToolRegistry::new();
        static_registry
            .register_named("static_tool", Arc::new(StaticTool::new("static_tool")))
            .expect("register static tool");

        let composite = CompositeToolRegistry::try_new(vec![
            dynamic.clone() as Arc<dyn ToolRegistry>,
            Arc::new(static_registry) as Arc<dyn ToolRegistry>,
        ])
        .expect("compose registries");

        assert!(composite.ids().contains(&"dynamic_a".to_string()));
        assert!(composite.ids().contains(&"static_tool".to_string()));

        dynamic.replace_ids(&["dynamic_a", "dynamic_b"]);

        let ids = composite.ids();
        assert!(ids.contains(&"dynamic_a".to_string()));
        assert!(ids.contains(&"dynamic_b".to_string()));
        assert!(ids.contains(&"static_tool".to_string()));
    }

    #[test]
    fn composite_tool_registry_keeps_last_good_snapshot_on_runtime_conflict() {
        let reg_a = Arc::new(MutableToolRegistry::default());
        reg_a.replace_ids(&["tool_a"]);

        let reg_b = Arc::new(MutableToolRegistry::default());
        reg_b.replace_ids(&["tool_b"]);

        let composite = CompositeToolRegistry::try_new(vec![
            reg_a.clone() as Arc<dyn ToolRegistry>,
            reg_b.clone() as Arc<dyn ToolRegistry>,
        ])
        .expect("compose registries");

        let initial_ids = composite.ids();
        assert_eq!(
            initial_ids,
            vec!["tool_a".to_string(), "tool_b".to_string()]
        );

        // Introduce a conflict at runtime. Composite should fall back to last good snapshot.
        reg_b.replace_ids(&["tool_a"]);

        assert_eq!(composite.ids(), initial_ids);
        assert!(composite.get("tool_b").is_some());
    }
}