wesichain-langsmith 0.3.0

LangSmith observability for Wesichain
Documentation
#![allow(deprecated)]
use std::sync::Arc;
use std::time::Duration;

use futures::StreamExt;
use secrecy::SecretString;
use serde::{Deserialize, Serialize};
use serde_json::json;
use uuid::Uuid;
use wiremock::matchers::{method, path_regex};
use wiremock::{Mock, MockServer, ResponseTemplate};

use wesichain_core::{
    HasFinalOutput, HasUserInput, LlmRequest, LlmResponse, ReActStep, Runnable, ScratchpadState,
    Tool, ToolCall, ToolCallingLlm, ToolError, Value, WesichainError,
};
use wesichain_graph::{ExecutionOptions, GraphBuilder, GraphState, ReActAgentNode, StateSchema};
use wesichain_langsmith::{LangSmithConfig, LangSmithObserver};

/// Minimal graph state that drives the ReAct agent in this test.
///
/// The field order is significant for the derived serde layout and is kept as-is.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
struct DemoState {
    /// Prompt handed to the agent (exposed through `HasUserInput`).
    input: String,
    /// ReAct steps accumulated by the agent (exposed through `ScratchpadState`).
    scratchpad: Vec<ReActStep>,
    /// Answer set by the agent, if any (exposed through `HasFinalOutput`).
    final_output: Option<String>,
    /// Iteration counter bumped via `ScratchpadState::increment_iteration`.
    iterations: u32,
}

/// State updates are full snapshots: applying an update discards the current
/// state and keeps the update verbatim.
impl StateSchema for DemoState {
    type Update = Self;

    fn apply(_current: &Self, next: Self) -> Self {
        next
    }
}

/// Lets the ReAct agent read and extend the scratchpad and iteration counter
/// stored directly on [`DemoState`].
impl ScratchpadState for DemoState {
    /// Read-only view of the recorded ReAct steps.
    fn scratchpad(&self) -> &Vec<ReActStep> {
        &self.scratchpad
    }

    /// Mutable access so new steps can be appended in place.
    fn scratchpad_mut(&mut self) -> &mut Vec<ReActStep> {
        &mut self.scratchpad
    }

    /// Current value of the iteration counter.
    fn iteration_count(&self) -> u32 {
        self.iterations
    }

    /// Advances the iteration counter by one.
    fn increment_iteration(&mut self) {
        self.iterations += 1;
    }
}

/// Surfaces the `input` field as the user's prompt.
impl HasUserInput for DemoState {
    fn user_input(&self) -> &str {
        self.input.as_str()
    }
}

/// Stores and exposes the agent's final answer via the `final_output` field.
impl HasFinalOutput for DemoState {
    /// Borrowed view of the answer, or `None` if none has been set yet.
    fn final_output(&self) -> Option<&str> {
        self.final_output.as_deref()
    }

    /// Records `answer` as the final output, replacing any previous value.
    fn set_final_output(&mut self, answer: String) {
        self.final_output = Some(answer);
    }
}

/// Stub tool registered as `calculator`; it ignores its arguments and always
/// answers with the JSON number `4`.
struct MockTool;

#[async_trait::async_trait]
impl Tool for MockTool {
    fn name(&self) -> &str {
        "calculator"
    }

    fn description(&self) -> &str {
        "math"
    }

    /// Accepts any JSON object; no properties are declared.
    fn schema(&self) -> Value {
        json!({"type": "object"})
    }

    async fn invoke(&self, _input: Value) -> Result<Value, ToolError> {
        let answer = json!(4);
        Ok(answer)
    }
}

/// Stub LLM: every `invoke` returns empty content plus a single `calculator`
/// tool call (with a freshly generated id), so the agent always takes a tool step.
struct MockLlm;

#[async_trait::async_trait]
impl Runnable<LlmRequest, LlmResponse> for MockLlm {
    async fn invoke(&self, _request: LlmRequest) -> Result<LlmResponse, WesichainError> {
        let calculator_call = ToolCall {
            id: Uuid::new_v4().to_string(),
            name: "calculator".to_string(),
            args: json!({"expression": "2+2"}),
        };

        Ok(LlmResponse {
            content: String::new(),
            tool_calls: vec![calculator_call],
            usage: None,
            model: String::new(),
        })
    }

    /// Streaming is not exercised by this test; the stream yields no events.
    fn stream(
        &self,
        _request: LlmRequest,
    ) -> futures::stream::BoxStream<'_, Result<wesichain_core::StreamEvent, WesichainError>> {
        futures::stream::empty().boxed()
    }
}

/// Marker impl so `MockLlm` can be plugged into `ReActAgentNode::builder().llm(..)`.
impl ToolCallingLlm for MockLlm {}

/// End-to-end check that running a ReAct agent graph with a [`LangSmithObserver`]
/// attached sends run traffic to the LangSmith API.
///
/// A wiremock server stands in for LangSmith and answers `200` to `POST` requests
/// matching `/runs` and `PATCH` requests matching `/runs/.*`. After one agent
/// iteration and an explicit flush, the test asserts that:
/// - the flush reported at least one run,
/// - at least one `POST` and one `PATCH` reached the server, and
/// - at least one `POST` body carries a non-null `parent_run_id`.
#[tokio::test]
async fn langsmith_traces_react_agent() {
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path_regex("/runs"))
        .respond_with(ResponseTemplate::new(200))
        .mount(&server)
        .await;
    Mock::given(method("PATCH"))
        .and(path_regex("/runs/.*"))
        .respond_with(ResponseTemplate::new(200))
        .mount(&server)
        .await;

    let config = LangSmithConfig {
        api_key: SecretString::new("key".to_string()),
        api_url: server.uri(),
        project_name: "test".to_string(),
        // Very long interval — presumably so no background flush fires during the
        // test and export happens only via the explicit `flush` below (verify).
        flush_interval: Duration::from_secs(3600),
        max_batch_size: 50,
        queue_capacity: 1000,
        sampling_rate: 1.0,
        redact_regex: None,
    };

    let observer = Arc::new(LangSmithObserver::new(config));
    let options = ExecutionOptions {
        observer: Some(observer.clone()),
        ..Default::default()
    };

    let node = ReActAgentNode::builder()
        .llm(Arc::new(MockLlm))
        .tools(vec![Arc::new(MockTool)])
        .max_iterations(1)
        .build()
        .expect("ReAct agent node should build from a mock LLM and tool");

    let graph = GraphBuilder::new()
        .add_node("agent", node)
        .set_entry("agent")
        .build();

    let state = GraphState::new(DemoState {
        input: "2+2".to_string(),
        ..Default::default()
    });

    let _ = graph
        .invoke_with_options(state, options)
        .await
        .expect("graph invocation with the LangSmith observer should succeed");
    let stats = observer
        .flush(Duration::from_secs(5))
        .await
        .expect("flushing the LangSmith observer should succeed");

    assert!(
        stats.runs_flushed > 0,
        "expected at least one run to be flushed, got {}",
        stats.runs_flushed
    );

    let requests = server
        .received_requests()
        .await
        .expect("wiremock request recording should be enabled");
    assert!(
        requests.iter().any(|req| req.method == "POST"),
        "expected at least one POST (run creation) among {} received requests",
        requests.len()
    );
    assert!(
        requests.iter().any(|req| req.method == "PATCH"),
        "expected at least one PATCH (run update) among {} received requests",
        requests.len()
    );

    // Only POST bodies are inspected; bodies that are not valid JSON are skipped.
    let has_parent = requests
        .iter()
        .filter(|req| req.method == "POST")
        .filter_map(|req| serde_json::from_slice::<serde_json::Value>(&req.body).ok())
        .any(|payload| matches!(payload.get("parent_run_id"), Some(id) if !id.is_null()));
    assert!(
        has_parent,
        "expected at least one POSTed run with a non-null parent_run_id ({} requests received)",
        requests.len()
    );
}