wesichain-agent 0.2.1

Rust-native LLM agents & chains with resumable ReAct workflows
Documentation
use std::collections::{BTreeMap, HashSet};
use std::sync::Arc;

use schemars::{schema::RootSchema, JsonSchema};
use serde::de::DeserializeOwned;
use serde::Deserialize;
use serde_json::Value;

use crate::error::ToolDispatchError;
pub use tokio_util::sync::CancellationToken;

/// Alias for `wesichain_core::ToolError`, the error type returned by `TypedTool::run`.
pub type ToolError = wesichain_core::ToolError;

/// Per-call context handed to a tool's `run` method.
///
/// `ToolSet::dispatch` receives it from the caller and forwards it unchanged, by value,
/// to `TypedTool::run`.
#[derive(Clone, Debug)]
pub struct ToolContext {
    /// Identifier for correlating this call with the surrounding run.
    /// NOTE(review): format and scope are chosen by the caller; not established here.
    pub correlation_id: String,
    /// Step counter supplied by the caller (presumably the agent step issuing this call;
    /// confirm against callers).
    pub step_id: u32,
    /// Token a tool may observe to stop early. Nothing in this module checks it, so
    /// honouring cancellation is up to each tool.
    pub cancellation: CancellationToken,
}

/// A strongly-typed tool whose arguments and output travel as JSON.
///
/// At dispatch time the envelope's JSON arguments are decoded into `Args` with
/// `serde_json` before `run` is called. A successful `Output` is encoded back into a
/// `serde_json::Value`. Both types also provide JSON Schemas (via `schemars`), which are
/// published through `ToolSet::schema_catalog`.
// `async fn` in a public trait produces a future with no `Send` bound; the lint is
// silenced deliberately, and the erased dispatcher uses `async_trait(?Send)` so
// such futures are accepted.
#[allow(async_fn_in_trait)]
pub trait TypedTool {
    /// Argument type, deserialized from the call envelope's `args` JSON.
    type Args: DeserializeOwned + JsonSchema;
    /// Result type, serialized to JSON after a successful run.
    type Output: serde::Serialize + JsonSchema;

    /// Name used as the registration key, catalog key and dispatch lookup key.
    /// `ToolSetBuilder::build` rejects empty/whitespace-only and duplicate names.
    const NAME: &'static str;

    /// Executes the tool with already-decoded arguments.
    ///
    /// # Errors
    /// Any `ToolError` returned here is wrapped by the dispatcher into
    /// `ToolDispatchError::Execution`, tagged with the tool name and call id.
    async fn run(&self, args: Self::Args, ctx: ToolContext) -> Result<Self::Output, ToolError>;
}

/// JSON Schemas describing one tool's input and output, generated with `schemars`.
#[derive(Clone, Debug)]
pub struct ToolSchema {
    /// Schema of the tool's `TypedTool::Args` type.
    pub args_schema: RootSchema,
    /// Schema of the tool's `TypedTool::Output` type.
    pub output_schema: RootSchema,
}

/// One untyped tool invocation, deserializable from JSON.
///
/// NOTE(review): presumably built from a model's tool-call output; confirm with callers.
#[derive(Clone, Debug, Deserialize)]
pub struct ToolCallEnvelope {
    /// Name of the tool to invoke; looked up among the `ToolSet`'s dispatchers.
    pub name: String,
    /// Raw JSON arguments, decoded into the tool's `Args` type at dispatch time.
    pub args: Value,
    /// Caller-supplied identifier attached to every dispatch error this module builds.
    pub call_id: String,
}

/// A validated collection of tools, produced by `ToolSet::new()` → register → `build()`.
///
/// `ToolSetBuilder::build` guarantees that every tool name is non-blank and unique.
/// `ToolSet` exposes no mutating methods.
#[derive(Clone)]
pub struct ToolSet {
    /// Registration metadata in registration order; backs `names()`.
    entries: Vec<ToolMetadata>,
    /// Name → schema for every registered tool, including schema-only registrations.
    schema_catalog: BTreeMap<String, ToolSchema>,
    /// Name → type-erased runner; only tools registered with an instance appear here.
    dispatchers: BTreeMap<String, Arc<dyn ErasedToolRunner>>,
}

impl std::fmt::Debug for ToolSet {
    /// Prints the full tool metadata, but only the sizes of the two lookup maps.
    ///
    /// The dispatcher values are `dyn ErasedToolRunner`, which has no `Debug`
    /// supertrait, so they cannot be printed directly.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let catalog_count = self.schema_catalog.len();
        let runner_count = self.dispatchers.len();

        let mut out = f.debug_struct("ToolSet");
        out.field("entries", &self.entries);
        out.field("schema_catalog_len", &catalog_count);
        out.field("dispatchers_len", &runner_count);
        out.finish()
    }
}

impl ToolSet {
    #[allow(
        clippy::new_ret_no_self,
        reason = "ToolSet::new intentionally starts a builder-first registration API"
    )]
    pub fn new() -> ToolSetBuilder {
        ToolSetBuilder {
            entries: Vec::new(),
            dispatchers: Vec::new(),
        }
    }

    pub fn names(&self) -> Vec<&str> {
        self.entries
            .iter()
            .map(|entry| entry.name.as_str())
            .collect()
    }

    pub fn schema_catalog(&self) -> &BTreeMap<String, ToolSchema> {
        &self.schema_catalog
    }

    pub async fn dispatch(
        &self,
        envelope: ToolCallEnvelope,
        ctx: ToolContext,
    ) -> Result<Value, ToolDispatchError> {
        let Some(dispatcher) = self.dispatchers.get(&envelope.name) else {
            return Err(ToolDispatchError::UnknownTool {
                name: envelope.name,
                call_id: envelope.call_id,
            });
        };

        dispatcher
            .dispatch(&envelope.name, envelope.args, envelope.call_id, ctx)
            .await
    }
}

/// Accumulates tool registrations for a [`ToolSet`]; validation is deferred to `build`.
#[derive(Clone, Default)]
pub struct ToolSetBuilder {
    /// One metadata record per `register` / `register_with` call, in call order.
    entries: Vec<ToolMetadata>,
    /// Runners for tools registered through `register_with` only.
    dispatchers: Vec<ToolDispatchMetadata>,
}

impl ToolSetBuilder {
    pub fn register<T>(mut self) -> Self
    where
        T: TypedTool,
    {
        self.entries.push(ToolMetadata {
            name: T::NAME.to_string(),
            schema: ToolSchema {
                args_schema: schemars::schema_for!(T::Args),
                output_schema: schemars::schema_for!(T::Output),
            },
        });
        self
    }

    pub fn register_with<T>(mut self, tool: T) -> Self
    where
        T: TypedTool + Send + Sync + 'static,
    {
        self.entries.push(ToolMetadata {
            name: T::NAME.to_string(),
            schema: ToolSchema {
                args_schema: schemars::schema_for!(T::Args),
                output_schema: schemars::schema_for!(T::Output),
            },
        });
        self.dispatchers.push(ToolDispatchMetadata {
            name: T::NAME.to_string(),
            runner: Arc::new(TypedToolRunner { tool }),
        });
        self
    }

    pub fn build(self) -> Result<ToolSet, ToolSetBuildError> {
        let mut seen = HashSet::new();
        let mut catalog = BTreeMap::new();
        let mut dispatchers = BTreeMap::new();

        for entry in &self.entries {
            if entry.name.trim().is_empty() {
                return Err(ToolSetBuildError::InvalidName {
                    name: entry.name.clone(),
                });
            }

            if !seen.insert(entry.name.clone()) {
                return Err(ToolSetBuildError::DuplicateName {
                    name: entry.name.clone(),
                });
            }

            catalog.insert(entry.name.clone(), entry.schema.clone());
        }

        for dispatch in self.dispatchers {
            dispatchers.insert(dispatch.name, dispatch.runner);
        }

        Ok(ToolSet {
            entries: self.entries,
            schema_catalog: catalog,
            dispatchers,
        })
    }
}

/// Name and JSON Schemas recorded for one registered tool.
#[derive(Clone, Debug)]
struct ToolMetadata {
    /// The tool's `TypedTool::NAME`.
    name: String,
    /// Schemas for the tool's arguments and output.
    schema: ToolSchema,
}

/// Pairs a tool name with its type-erased runner until `ToolSetBuilder::build`
/// moves it into the name-keyed dispatcher map.
#[derive(Clone)]
struct ToolDispatchMetadata {
    /// The tool's `TypedTool::NAME`; becomes the dispatcher map key.
    name: String,
    /// Shared handle to the runner wrapping the concrete tool.
    runner: Arc<dyn ErasedToolRunner>,
}

/// Object-safe, JSON-in / JSON-out view of a [`TypedTool`].
///
/// This lets tools with different `Args`/`Output` types share one map.
// `?Send`: the boxed futures are not required to be `Send`, matching the futures
// produced by `TypedTool::run`, which carry no `Send` bound.
#[async_trait::async_trait(?Send)]
trait ErasedToolRunner: Send + Sync {
    /// Decodes `args`, runs the underlying tool, and encodes its output as JSON.
    ///
    /// In `TypedToolRunner`, `name` and `call_id` are used only to label errors.
    async fn dispatch(
        &self,
        name: &str,
        args: Value,
        call_id: String,
        ctx: ToolContext,
    ) -> Result<Value, ToolDispatchError>;
}

/// Adapter that owns a concrete tool and implements [`ErasedToolRunner`] for it.
#[derive(Clone)]
struct TypedToolRunner<T> {
    /// The wrapped tool instance.
    tool: T,
}

#[async_trait::async_trait(?Send)]
impl<T> ErasedToolRunner for TypedToolRunner<T>
where
    T: TypedTool + Send + Sync,
{
    /// Runs the wrapped tool in three stages: decode args, execute, encode output.
    ///
    /// Each stage maps its failure to a distinct `ToolDispatchError` variant
    /// (`InvalidArgs`, `Execution`, `Serialization`). Every variant carries the tool
    /// `name` and the originating `call_id`.
    async fn dispatch(
        &self,
        name: &str,
        args: Value,
        call_id: String,
        ctx: ToolContext,
    ) -> Result<Value, ToolDispatchError> {
        // Stage 1: raw JSON -> strongly-typed arguments.
        let decoded: T::Args = match serde_json::from_value(args) {
            Ok(parsed) => parsed,
            Err(source) => {
                return Err(ToolDispatchError::InvalidArgs {
                    name: name.to_owned(),
                    call_id,
                    source,
                });
            }
        };

        // Stage 2: run the tool itself.
        let produced = match self.tool.run(decoded, ctx).await {
            Ok(value) => value,
            Err(source) => {
                return Err(ToolDispatchError::Execution {
                    name: name.to_owned(),
                    call_id,
                    source,
                });
            }
        };

        // Stage 3: typed output -> JSON value.
        match serde_json::to_value(produced) {
            Ok(json) => Ok(json),
            Err(source) => Err(ToolDispatchError::Serialization {
                name: name.to_owned(),
                call_id,
                source,
            }),
        }
    }
}

/// Reasons a `ToolSetBuilder` can refuse to build a tool set.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ToolSetBuildError {
    /// A tool name was empty or made up only of whitespace.
    InvalidName { name: String },
    /// More than one registration used the same tool name.
    DuplicateName { name: String },
}

impl std::fmt::Display for ToolSetBuildError {
    /// Renders a one-line, human-readable message.
    ///
    /// Invalid names are shown in `Debug` form (quoted and escaped) so that blank or
    /// whitespace-only names remain visible.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidName { name } => write!(
                f,
                "tool name must not be empty or whitespace: {:?}",
                name
            ),
            Self::DuplicateName { name } => write!(f, "duplicate tool name: {}", name),
        }
    }
}

/// Plain error with no underlying `source`.
impl std::error::Error for ToolSetBuildError {}