// wesichain-langsmith 0.3.0
//
// LangSmith observability for Wesichain
// Documentation
use std::sync::Arc;
use std::time::Duration;

use chrono::{DateTime, Utc};
use dashmap::DashMap;
use regex::Regex;
use serde_json::{json, Value};
use uuid::Uuid;
use wesichain_core::{CallbackHandler, LlmInput, LlmResult, RunContext, RunType as CoreRunType};

use crate::{
    ensure_object, sanitize_value, truncate_value, FlushError, FlushStats, LangSmithConfig,
    LangSmithExporter, ProbabilitySampler, RunContextStore, RunEvent, RunType, Sampler,
};

/// Default value for the handler's `max_bytes` field, which is the limit passed
/// to `truncate_value` whenever a payload or error is sanitized.
const DEFAULT_MAX_BYTES: usize = 100_000;

/// Callback handler that turns run callbacks into `RunEvent`s and hands them to a
/// `LangSmithExporter`, subject to a per-trace sampling decision.
pub struct LangSmithCallbackHandler {
    /// Destination for every enqueued `RunEvent`; also backs `flush` and
    /// `dropped_events`.
    exporter: LangSmithExporter,
    /// Consulted for a trace's sampling decision when none is cached yet.
    sampler: Arc<dyn Sampler>,
    /// Cached sampling decision per trace id. The entry is removed when a run
    /// without a parent ends (successfully or with an error).
    trace_sampling: DashMap<Uuid, bool>,
    /// Optional pattern passed to `sanitize_value` when sanitizing payloads.
    redact_regex: Option<Regex>,
    /// Limit passed to `truncate_value` when sanitizing payloads.
    max_bytes: usize,
    /// Taken from the config's `project_name`; attached to every start event.
    session_name: String,
}

impl LangSmithCallbackHandler {
    pub fn new(config: LangSmithConfig) -> Self {
        let sampler = Arc::new(ProbabilitySampler {
            rate: config.sampling_rate,
        });
        Self::with_sampler(config, sampler)
    }

    /// Creates a handler that uses the supplied `sampler` for sampling decisions.
    ///
    /// A default [`RunContextStore`] is created and passed, together with
    /// `config`, to a new [`LangSmithExporter`]. The handler keeps its own copy of
    /// the config's redaction regex and project name (used as the session name on
    /// start events), starts with an empty sampling cache, and limits payloads to
    /// `DEFAULT_MAX_BYTES`.
    pub fn with_sampler(config: LangSmithConfig, sampler: Arc<dyn Sampler>) -> Self {
        // Copy out only the two fields the handler keeps, so that `config` can be
        // moved into the exporter instead of being cloned in full.
        let redact_regex = config.redact_regex.clone();
        let session_name = config.project_name.clone();

        let store = Arc::new(RunContextStore::default());
        let exporter = LangSmithExporter::new(config, store);

        Self {
            exporter,
            sampler,
            trace_sampling: DashMap::new(),
            redact_regex,
            max_bytes: DEFAULT_MAX_BYTES,
            session_name,
        }
    }

    /// Forwards to the exporter's `flush` with the given `timeout` and returns its
    /// result unchanged.
    pub async fn flush(&self, timeout: Duration) -> Result<FlushStats, FlushError> {
        let exporter = &self.exporter;
        exporter.flush(timeout).await
    }

    /// Returns the dropped-event count reported by the exporter.
    pub fn dropped_events(&self) -> usize {
        let exporter = &self.exporter;
        exporter.dropped_events()
    }

    /// Returns the sampling decision for `trace_id`, asking the sampler only when
    /// no decision is cached for that trace yet.
    ///
    /// Lookup and insertion go through a single `entry` call, so two concurrent
    /// callbacks for the same trace cannot both miss the cache, each consult the
    /// sampler, and then act on different answers.
    fn should_sample(&self, trace_id: Uuid) -> bool {
        // The sampler closure runs while the map entry is held, so it must not
        // access `trace_sampling` itself.
        let decision = *self
            .trace_sampling
            .entry(trace_id)
            .or_insert_with(|| self.sampler.should_sample(trace_id));
        decision
    }

    /// Removes the cached sampling decision for `ctx.trace_id` when the run has no
    /// parent; runs that do have a parent leave the cache untouched.
    fn maybe_clear_trace(&self, ctx: &RunContext) {
        let is_root = ctx.parent_run_id.is_none();
        if is_root {
            self.trace_sampling.remove(&ctx.trace_id);
        }
    }

    /// Runs `value` through `sanitize_value` (with the optional redaction regex),
    /// then `truncate_value` (with `max_bytes`), then `ensure_object`, in that
    /// order, and returns the result.
    fn sanitize_object(&self, value: Value) -> Value {
        let pattern = self.redact_regex.as_ref();
        ensure_object(truncate_value(
            sanitize_value(value, pattern),
            self.max_bytes,
        ))
    }

    /// Sanitizes and truncates an error value and returns it as text.
    ///
    /// The value goes through `sanitize_value` and then `truncate_value`. A JSON
    /// string result is returned as-is; any other JSON value is rendered with its
    /// `to_string` form.
    fn sanitize_error(&self, value: Value) -> String {
        let cleaned = truncate_value(
            sanitize_value(value, self.redact_regex.as_ref()),
            self.max_bytes,
        );
        match cleaned {
            Value::String(message) => message,
            non_string => non_string.to_string(),
        }
    }

    /// Translates a core run type into this crate's [`RunType`].
    ///
    /// `Tool`, `Llm`, `Agent` and `Graph` map to the variant of the same name;
    /// `Chain`, `Retriever` and `Runnable` all map to `RunType::Chain`.
    fn map_run_type(run_type: &CoreRunType) -> RunType {
        match run_type {
            CoreRunType::Tool => RunType::Tool,
            CoreRunType::Llm => RunType::Llm,
            CoreRunType::Agent => RunType::Agent,
            CoreRunType::Graph => RunType::Graph,
            // Retriever and Runnable are reported as chains alongside Chain itself.
            CoreRunType::Chain | CoreRunType::Retriever | CoreRunType::Runnable => RunType::Chain,
        }
    }

    fn prepare_llm_inputs(&self, input: &LlmInput) -> Value {
        let mut invocation_params = serde_json::Map::new();
        if let Some(temp) = input.temperature {
            invocation_params.insert("temperature".to_string(), json!(temp));
        }
        if let Some(max_tokens) = input.max_tokens {
            invocation_params.insert("max_tokens".to_string(), json!(max_tokens));
        }
        if !input.stop_sequences.is_empty() {
            invocation_params.insert("stop".to_string(), json!(input.stop_sequences));
        }

        json!({
            "model": input.model,
            "prompt": self.sanitize_object(Value::String(input.prompt.clone())),
            "invocation_params": invocation_params,
        })
    }

    /// Builds the sanitized JSON outputs payload for an LLM run.
    ///
    /// The object always carries `generations` and `model`. When the result
    /// includes token usage, a `token_usage` object with `prompt_tokens`,
    /// `completion_tokens` and `total_tokens` is added. The whole object is then
    /// passed through `sanitize_object`.
    fn prepare_llm_outputs(&self, result: &LlmResult) -> Value {
        let mut payload = serde_json::Map::new();
        payload.insert("generations".to_string(), json!(result.generations));
        payload.insert("model".to_string(), json!(result.model));

        if let Some(usage) = result.token_usage.as_ref() {
            let usage_json = json!({
                "prompt_tokens": usage.prompt_tokens,
                "completion_tokens": usage.completion_tokens,
                "total_tokens": usage.total_tokens,
            });
            payload.insert("token_usage".to_string(), usage_json);
        }

        let raw = Value::Object(payload);
        self.sanitize_object(raw)
    }
}

#[async_trait::async_trait]
impl CallbackHandler for LangSmithCallbackHandler {
    async fn on_start(&self, ctx: &RunContext, inputs: &Value) {
        if !self.should_sample(ctx.trace_id) {
            return;
        }

        let inputs = self.sanitize_object(inputs.clone());
        let metadata = serde_json::to_value(&ctx.metadata).unwrap_or(Value::Null);
        let metadata = self.sanitize_object(metadata);
        let event = RunEvent::Start {
            run_id: ctx.run_id,
            parent_run_id: ctx.parent_run_id,
            trace_id: ctx.trace_id,
            name: ctx.name.clone(),
            run_type: Self::map_run_type(&ctx.run_type),
            start_time: DateTime::<Utc>::from(ctx.start_time),
            inputs,
            tags: ctx.tags.clone(),
            metadata,
            session_name: self.session_name.clone(),
        };
        self.exporter.enqueue(event).await;
    }

    /// Enqueues a `RunEvent::Update` carrying the sanitized `outputs`, the current
    /// UTC time as end time and `duration_ms`, unless the trace is not sampled.
    ///
    /// In either case the cached sampling decision is cleared afterwards when the
    /// run has no parent.
    async fn on_end(&self, ctx: &RunContext, outputs: &Value, duration_ms: u128) {
        if self.should_sample(ctx.trace_id) {
            // Sanitize first so the end timestamp is taken right before enqueueing.
            let clean_outputs = self.sanitize_object(outputs.clone());
            let update = RunEvent::Update {
                run_id: ctx.run_id,
                end_time: Some(Utc::now()),
                outputs: Some(clean_outputs),
                error: None,
                duration_ms: Some(duration_ms),
            };
            self.exporter.enqueue(update).await;
        }
        self.maybe_clear_trace(ctx);
    }

    /// Enqueues a `RunEvent::Update` carrying the sanitized error text, the
    /// current UTC time as end time and `duration_ms`, unless the trace is not
    /// sampled. No outputs are attached.
    ///
    /// In either case the cached sampling decision is cleared afterwards when the
    /// run has no parent.
    async fn on_error(&self, ctx: &RunContext, error: &Value, duration_ms: u128) {
        if self.should_sample(ctx.trace_id) {
            // Sanitize first so the end timestamp is taken right before enqueueing.
            let message = self.sanitize_error(error.clone());
            let update = RunEvent::Update {
                run_id: ctx.run_id,
                end_time: Some(Utc::now()),
                outputs: None,
                error: Some(message),
                duration_ms: Some(duration_ms),
            };
            self.exporter.enqueue(update).await;
        }
        self.maybe_clear_trace(ctx);
    }

    async fn on_llm_start(&self, ctx: &RunContext, input: &LlmInput) {
        if !self.should_sample(ctx.trace_id) {
            return;
        }

        let inputs = self.prepare_llm_inputs(input);
        let metadata = serde_json::to_value(&ctx.metadata).unwrap_or(Value::Null);
        let metadata = self.sanitize_object(metadata);

        let event = RunEvent::Start {
            run_id: ctx.run_id,
            parent_run_id: ctx.parent_run_id,
            trace_id: ctx.trace_id,
            name: ctx.name.clone(),
            run_type: RunType::Llm,
            start_time: DateTime::<Utc>::from(ctx.start_time),
            inputs,
            tags: ctx.tags.clone(),
            metadata,
            session_name: self.session_name.clone(),
        };
        self.exporter.enqueue(event).await;
    }

    /// Enqueues a `RunEvent::Update` carrying the outputs built by
    /// `prepare_llm_outputs`, the current UTC time as end time and `duration_ms`,
    /// unless the trace is not sampled.
    ///
    /// In either case the cached sampling decision is cleared afterwards when the
    /// run has no parent.
    async fn on_llm_end(&self, ctx: &RunContext, result: &LlmResult, duration_ms: u128) {
        if self.should_sample(ctx.trace_id) {
            // Build the payload first so the end timestamp is taken right before
            // enqueueing.
            let llm_outputs = self.prepare_llm_outputs(result);
            let update = RunEvent::Update {
                run_id: ctx.run_id,
                end_time: Some(Utc::now()),
                outputs: Some(llm_outputs),
                error: None,
                duration_ms: Some(duration_ms),
            };
            self.exporter.enqueue(update).await;
        }
        self.maybe_clear_trace(ctx);
    }
}