wesichain-agent 0.2.0

Rust-native LLM agents & chains with resumable ReAct workflows
Documentation
#![allow(deprecated)]

use async_trait::async_trait;
use futures::stream::StreamExt;
use std::sync::{Arc, Mutex};
use wesichain_agent::{Tool, ToolCallingAgent, ToolRegistry};
use wesichain_core::{Runnable, StreamEvent, ToolError, Value, WesichainError};
use wesichain_llm::{LlmRequest, LlmResponse, Message, Role, ToolCall};

/// Scripted LLM double for the agent tests.
///
/// A request carrying exactly one message is answered with a single `echo`
/// tool call; any other message count is answered with the final text "done".
struct MockLlm;

#[async_trait]
impl Runnable<LlmRequest, LlmResponse> for MockLlm {
    /// Selects the canned reply from the number of messages in `input`.
    async fn invoke(&self, input: LlmRequest) -> Result<LlmResponse, WesichainError> {
        let reply = match input.messages.len() {
            // Exactly one message: ask for the `echo` tool with argument "hi".
            1 => LlmResponse {
                content: String::new(),
                tool_calls: vec![ToolCall {
                    id: String::from("1"),
                    name: String::from("echo"),
                    args: Value::from("hi"),
                }],
            },
            // Any other count: plain answer, no further tool calls.
            _ => LlmResponse {
                content: String::from("done"),
                tool_calls: Vec::new(),
            },
        };
        Ok(reply)
    }

    /// Returns an empty stream; this double emits no stream events.
    fn stream(
        &self,
        _input: LlmRequest,
    ) -> futures::stream::BoxStream<'_, Result<StreamEvent, WesichainError>> {
        let no_events = futures::stream::empty::<Result<StreamEvent, WesichainError>>();
        no_events.boxed()
    }
}

/// Minimal tool for the tests: named "echo", it hands its input straight back.
struct EchoTool;

#[async_trait]
impl Tool for EchoTool {
    /// Name reported by this tool.
    fn name(&self) -> &str {
        "echo"
    }

    /// Description text; the tests reuse the tool name here.
    fn description(&self) -> &str {
        "echo"
    }

    /// Placeholder schema value (the literal string "schema").
    fn schema(&self) -> Value {
        "schema".into()
    }

    /// Returns `input` unchanged and never fails.
    async fn invoke(&self, input: Value) -> Result<Value, ToolError> {
        Ok(input)
    }
}

/// Happy path: with `MockLlm` and a registered `EchoTool`, invoking the agent
/// with "hi" is expected to produce the final answer "done".
#[tokio::test]
async fn agent_calls_tool_then_finishes() {
    // Registry containing only the `echo` tool that `MockLlm` requests.
    let mut registry = ToolRegistry::new();
    registry.register(Box::new(EchoTool));

    let agent = ToolCallingAgent::new(MockLlm, registry, String::from("mock")).max_steps(3);

    let outcome = agent.invoke(String::from("hi")).await;
    let answer = outcome.unwrap();
    assert_eq!(answer, "done");
}

/// An LLM that asks for a tool call on every turn never yields a final answer;
/// with `max_steps(2)` the test expects `invoke` to fail with
/// `WesichainError::Custom`.
#[tokio::test]
async fn agent_stops_after_max_steps() {
    /// LLM double whose every reply requests another `echo` tool call.
    struct LoopLlm;

    #[async_trait]
    impl Runnable<LlmRequest, LlmResponse> for LoopLlm {
        /// Ignores the request and always asks for `echo("hi")`.
        async fn invoke(&self, _input: LlmRequest) -> Result<LlmResponse, WesichainError> {
            let echo_request = ToolCall {
                id: String::from("1"),
                name: String::from("echo"),
                args: Value::from("hi"),
            };
            Ok(LlmResponse {
                content: String::new(),
                tool_calls: vec![echo_request],
            })
        }

        /// Returns an empty stream; this double emits no stream events.
        fn stream(
            &self,
            _input: LlmRequest,
        ) -> futures::stream::BoxStream<'_, Result<StreamEvent, WesichainError>> {
            let no_events = futures::stream::empty::<Result<StreamEvent, WesichainError>>();
            no_events.boxed()
        }
    }

    let mut registry = ToolRegistry::new();
    registry.register(Box::new(EchoTool));

    let agent = ToolCallingAgent::new(LoopLlm, registry, String::from("mock")).max_steps(2);

    let outcome = agent.invoke(String::from("hi")).await;
    let failure = outcome.unwrap_err();
    assert!(matches!(failure, WesichainError::Custom(_)));
}

/// Checks the conversation sent on the second LLM call: it must contain the
/// assistant message "need tool" from the first reply and hold three messages
/// in total.
#[tokio::test]
async fn agent_includes_assistant_message_before_tool_results() {
    /// LLM double that records the message list of every request it receives.
    struct RecordingLlm {
        calls: Arc<Mutex<Vec<Vec<Message>>>>,
    }

    #[async_trait]
    impl Runnable<LlmRequest, LlmResponse> for RecordingLlm {
        /// Logs the incoming messages, then answers like a two-turn script:
        /// a tool request when there is exactly one message, "done" otherwise.
        async fn invoke(&self, input: LlmRequest) -> Result<LlmResponse, WesichainError> {
            // The lock guard is a temporary and is released when this statement ends.
            self.calls.lock().unwrap().push(input.messages.clone());

            let reply = match input.messages.len() {
                1 => LlmResponse {
                    content: String::from("need tool"),
                    tool_calls: vec![ToolCall {
                        id: String::from("1"),
                        name: String::from("echo"),
                        args: Value::from("hi"),
                    }],
                },
                _ => LlmResponse {
                    content: String::from("done"),
                    tool_calls: Vec::new(),
                },
            };
            Ok(reply)
        }

        /// Returns an empty stream; this double emits no stream events.
        fn stream(
            &self,
            _input: LlmRequest,
        ) -> futures::stream::BoxStream<'_, Result<StreamEvent, WesichainError>> {
            let no_events = futures::stream::empty::<Result<StreamEvent, WesichainError>>();
            no_events.boxed()
        }
    }

    // Shared log: one handle goes into the LLM double, the other stays here for inspection.
    let history: Arc<Mutex<Vec<Vec<Message>>>> = Arc::new(Mutex::new(Vec::new()));
    let llm = RecordingLlm {
        calls: Arc::clone(&history),
    };

    let mut registry = ToolRegistry::new();
    registry.register(Box::new(EchoTool));

    let agent = ToolCallingAgent::new(llm, registry, String::from("mock")).max_steps(3);
    let outcome = agent.invoke(String::from("hi")).await;
    let answer = outcome.unwrap();
    assert_eq!(answer, "done");

    let recorded = history.lock().unwrap();
    assert_eq!(recorded.len(), 2);

    let second_call = &recorded[1];
    let carries_assistant_turn = second_call
        .iter()
        .any(|message| message.role == Role::Assistant && message.content == "need tool");
    assert!(carries_assistant_turn);
    assert_eq!(second_call.len(), 3);
}

/// When the LLM requests a tool named "missing" and the registry is empty, the
/// test expects `invoke` to fail with `WesichainError::ToolCallFailed`.
#[tokio::test]
async fn agent_returns_tool_call_failed_for_missing_tool() {
    /// LLM double that always asks for a tool called "missing".
    struct MissingToolLlm;

    #[async_trait]
    impl Runnable<LlmRequest, LlmResponse> for MissingToolLlm {
        /// Ignores the request and always asks for `missing("hi")`.
        async fn invoke(&self, _input: LlmRequest) -> Result<LlmResponse, WesichainError> {
            let unknown_request = ToolCall {
                id: String::from("1"),
                name: String::from("missing"),
                args: Value::from("hi"),
            };
            Ok(LlmResponse {
                content: String::new(),
                tool_calls: vec![unknown_request],
            })
        }

        /// Returns an empty stream; this double emits no stream events.
        fn stream(
            &self,
            _input: LlmRequest,
        ) -> futures::stream::BoxStream<'_, Result<StreamEvent, WesichainError>> {
            let no_events = futures::stream::empty::<Result<StreamEvent, WesichainError>>();
            no_events.boxed()
        }
    }

    // Deliberately empty: nothing is registered under the requested name.
    let registry = ToolRegistry::new();
    let agent =
        ToolCallingAgent::new(MissingToolLlm, registry, String::from("mock")).max_steps(1);

    let outcome = agent.invoke(String::from("hi")).await;
    let failure = outcome.unwrap_err();
    assert!(matches!(failure, WesichainError::ToolCallFailed { .. }));
}

/// The first item produced by the agent's `stream` is expected to be
/// `Err(WesichainError::Custom(..))` carrying the text "stream not implemented".
#[tokio::test]
async fn agent_stream_returns_not_implemented_error() {
    let mut registry = ToolRegistry::new();
    registry.register(Box::new(EchoTool));

    let agent = ToolCallingAgent::new(MockLlm, registry, String::from("mock"));
    let mut events = agent.stream(String::from("hi"));

    // Panics if the stream ends without yielding an item.
    let first_item = events.next().await.unwrap();

    let is_not_implemented = match first_item {
        Err(WesichainError::Custom(message)) => message == "stream not implemented",
        _ => false,
    };
    assert!(is_not_implemented);
}