// Source: wesichain-graph 0.3.0 documentation.
//
// Rust-native LLM agents & chains with resumable ReAct workflows.
#![allow(deprecated)]
use std::sync::Arc;

use futures::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::json;
use wesichain_core::{
    HasFinalOutput, HasUserInput, LlmRequest, LlmResponse, ReActStep, Runnable, ScratchpadState,
    Tool, ToolCall, ToolCallingLlm, ToolError, Value, WesichainError,
};
use wesichain_graph::{ExecutionOptions, GraphBuilder, GraphState, ReActAgentNode, StateSchema};

/// Minimal graph state used to drive a `ReActAgentNode` in this test.
///
/// Holds the user's prompt, the ReAct scratchpad, the agent's final answer
/// (once recorded), and an iteration counter.
#[derive(Debug, Clone, Default, PartialEq)]
#[derive(Serialize, Deserialize)]
struct DemoState {
    /// Prompt handed to the agent; exposed through `HasUserInput`.
    input: String,
    /// ReAct steps; exposed through `ScratchpadState`.
    scratchpad: Vec<ReActStep>,
    /// Final answer; `None` until `HasFinalOutput::set_final_output` is called.
    final_output: Option<String>,
    /// Counter read and bumped through `ScratchpadState`.
    iterations: u32,
}

/// Whole-state replacement: an update is the complete next state.
impl StateSchema for DemoState {
    type Update = DemoState;

    fn apply(_current: &Self, update: Self) -> Self {
        // The current state is intentionally ignored.
        update
    }
}

/// Exposes the scratchpad and the iteration counter to the agent node.
impl ScratchpadState for DemoState {
    fn scratchpad(&self) -> &Vec<ReActStep> {
        let Self { scratchpad, .. } = self;
        scratchpad
    }

    fn scratchpad_mut(&mut self) -> &mut Vec<ReActStep> {
        let Self { scratchpad, .. } = self;
        scratchpad
    }

    fn iteration_count(&self) -> u32 {
        self.iterations
    }

    fn increment_iteration(&mut self) {
        let Self { iterations, .. } = self;
        *iterations += 1;
    }
}

/// Supplies the prompt stored in `input`.
impl HasUserInput for DemoState {
    fn user_input(&self) -> &str {
        self.input.as_str()
    }
}

/// Reads and records the agent's final answer in `final_output`.
impl HasFinalOutput for DemoState {
    fn final_output(&self) -> Option<&str> {
        Option::as_deref(&self.final_output)
    }

    fn set_final_output(&mut self, value: String) {
        // Any previously recorded answer is discarded.
        let _ = self.final_output.replace(value);
    }
}

/// Stub tool named `calculator` that ignores its arguments and returns `4`.
struct MockTool;

#[async_trait::async_trait]
impl Tool for MockTool {
    // Must match the tool name that `MockLlm` requests.
    fn name(&self) -> &str {
        "calculator"
    }

    fn description(&self) -> &str {
        "math"
    }

    // Declares an object-typed argument without constraining any properties.
    fn schema(&self) -> Value {
        json!({ "type": "object" })
    }

    async fn invoke(&self, _input: Value) -> Result<Value, ToolError> {
        let answer = json!(4);
        Ok(answer)
    }
}

struct MockLlm;

#[async_trait::async_trait]
impl Runnable<LlmRequest, LlmResponse> for MockLlm {
    async fn invoke(&self, _request: LlmRequest) -> Result<LlmResponse, WesichainError> {
        Ok(LlmResponse {
            content: "".to_string(),
            tool_calls: vec![ToolCall {
                id: "c1".to_string(),
                name: "calculator".to_string(),
                args: json!({"expression": "2+2"}),
            }],
            usage: None,
            model: String::new(),
        })
    }

    fn stream(
        &self,
        _input: LlmRequest,
    ) -> futures::stream::BoxStream<'_, Result<wesichain_core::StreamEvent, WesichainError>> {
        futures::stream::empty().boxed()
    }
}

impl ToolCallingLlm for MockLlm {}

/// End-to-end check of a single-node graph wrapping a `ReActAgentNode`.
///
/// `MockLlm` always requests the `calculator` tool, and `MockTool` always
/// answers `4`. The test asserts two things:
/// - the run records at least one `ReActStep::Observation` in the scratchpad;
/// - the final output equals `""`, the mock LLM's text content.
#[tokio::test]
async fn react_agent_executes_tool_and_finishes() {
    let llm = Arc::new(MockLlm);
    let tool = Arc::new(MockTool);
    // `vec![tool]` and `.llm(llm)` are written inline so any conversion to the
    // builder's expected element/argument types happens at the call site.
    let node = ReActAgentNode::builder()
        .llm(llm)
        .tools(vec![tool])
        .build()
        .expect("ReActAgentNode builder should accept an LLM plus one tool");

    // Single-node graph whose entry point is the agent.
    let graph = GraphBuilder::new()
        .add_node("agent", node)
        .set_entry("agent")
        .build();
    let state = GraphState::new(DemoState {
        input: "2+2".to_string(),
        ..Default::default()
    });
    let out = graph
        .invoke_with_options(state, ExecutionOptions::default())
        .await
        .expect("graph execution should succeed with the mock LLM and tool");

    assert_eq!(
        out.data.final_output.as_deref(),
        Some(""),
        "final output should be the mock LLM's empty content; state: {:?}",
        out.data
    );
    assert!(
        out.data
            .scratchpad
            .iter()
            .any(|step| matches!(step, ReActStep::Observation(_))),
        "scratchpad should contain an Observation from the calculator tool: {:?}",
        out.data.scratchpad
    );
}