wesichain-graph 0.3.0

Rust-native LLM agents & chains with resumable ReAct workflows
Documentation
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};
use wesichain_core::Tool;
use wesichain_core::{ToolError, Value};
use wesichain_graph::{GraphState, HasToolCalls, StateSchema, ToolNode};
use wesichain_llm::{Message, Role, ToolCall};

#[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq)]
struct AgentState {
    tool_calls: Vec<ToolCall>,
    tool_results: Vec<Message>,
}

impl StateSchema for AgentState {
    type Update = Self;
    fn apply(_: &Self, update: Self) -> Self {
        update
    }
}

impl HasToolCalls for AgentState {
    fn tool_calls(&self) -> &Vec<ToolCall> {
        &self.tool_calls
    }

    fn push_tool_result(&mut self, message: Message) {
        self.tool_results.push(message);
    }
}

#[derive(Default)]
struct MockTool {
    calls: Arc<Mutex<Vec<Value>>>,
}

#[async_trait::async_trait]
impl Tool for MockTool {
    fn name(&self) -> &str {
        "echo"
    }

    fn description(&self) -> &str {
        "echo"
    }

    fn schema(&self) -> Value {
        serde_json::json!({"type": "object"})
    }

    async fn invoke(&self, input: Value) -> Result<Value, ToolError> {
        self.calls.lock().unwrap().push(input.clone());
        Ok(input)
    }
}

#[tokio::test]
async fn tool_node_executes_calls_and_appends_results() {
    let calls = vec![ToolCall {
        id: "1".into(),
        name: "echo".into(),
        args: serde_json::json!({"text": "hi"}),
    }];
    let state = GraphState::new(AgentState {
        tool_calls: calls,
        tool_results: Vec::new(),
    });
    let tool = Arc::new(MockTool::default());
    let calls_log = tool.calls.clone();
    let node = ToolNode::new(vec![tool]);
    let update = node.invoke(state).await.unwrap();
    assert_eq!(calls_log.lock().unwrap().len(), 1);
    assert_eq!(update.data.tool_results.len(), 1);
    assert_eq!(update.data.tool_results[0].role, Role::Tool);
    assert_eq!(
        update.data.tool_results[0].tool_call_id,
        Some("1".to_string())
    );
}