langchainrust 0.2.7

A LangChain-inspired framework for building LLM applications in Rust. Supports OpenAI, Agents, Tools, Memory, Chains, RAG, BM25, Hybrid Retrieval, LangGraph, and native Function Calling.
use async_trait::async_trait;
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::path::PathBuf;
use std::sync::Mutex;

use crate::callbacks::{CallbackHandler, RunTree};
use crate::schema::Message;

/// Output layout used by [`FileCallbackHandler`] when appending an event to
/// its log file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogFormat {
    /// One text line per event: `[timestamp] event - run_name (run_type)`.
    /// The run id and the event payload are not written in this format.
    Plain,
    /// One pretty-printed (multi-line) JSON object per event.
    Json,
    /// One compact JSON object per line.
    JsonLines,
}

/// Callback handler that records run, LLM, chain, tool and retriever events
/// in a log file.
pub struct FileCallbackHandler {
    /// Log file handle; the mutex serializes writes because the callback
    /// hooks only receive `&self`.
    file: Mutex<File>,
    /// Path the log file was opened from (reported by the `Debug` output).
    path: PathBuf,
    /// Layout used for each written entry.
    format: LogFormat,
}

impl FileCallbackHandler {
    /// Opens `path` for appending (creating the file if it does not exist)
    /// and returns a handler that logs in [`LogFormat::JsonLines`] format.
    ///
    /// # Errors
    ///
    /// Returns the [`std::io::Error`] reported when the file cannot be opened.
    pub fn new(path: impl Into<PathBuf>) -> Result<Self, std::io::Error> {
        let path = path.into();
        let file = OpenOptions::new().create(true).append(true).open(&path)?;
        Ok(Self {
            file: Mutex::new(file),
            path,
            format: LogFormat::JsonLines,
        })
    }

    /// Builder-style setter: returns the handler with its entry layout
    /// replaced by `format`.
    #[must_use]
    pub fn with_format(mut self, format: LogFormat) -> Self {
        self.format = format;
        self
    }

    /// Appends one entry describing `event` for `run` to the log file.
    ///
    /// * `LogFormat::JsonLines` writes a compact JSON object on one line.
    /// * `LogFormat::Json` writes the same object pretty-printed.
    /// * `LogFormat::Plain` writes `[timestamp] event - run_name (run_type)`;
    ///   `data` and the run id are not included.
    ///
    /// Logging is best-effort: the callback hooks that call this return
    /// nothing, so write failures are ignored rather than propagated.
    fn write_log(&self, event: &str, run: &RunTree, data: serde_json::Value) {
        // A panic in another thread while it held the lock poisons the mutex.
        // The file handle is still usable, so recover the guard instead of
        // silently dropping every later event.
        let mut file = self
            .file
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        // The timestamp is taken while the lock is held so that the order of
        // entries in the file matches the order of their timestamps.
        let timestamp = chrono::Utc::now().to_rfc3339();

        let mut line = match self.format {
            LogFormat::JsonLines => Self::json_entry(timestamp, event, run, data).to_string(),
            LogFormat::Json => {
                let entry = Self::json_entry(timestamp, event, run, data);
                serde_json::to_string_pretty(&entry).unwrap_or_default()
            }
            LogFormat::Plain => format!(
                "[{}] {} - {} ({})",
                timestamp,
                event,
                run.name,
                run.run_type.as_str()
            ),
        };
        line.push('\n');

        // Render first, then hand the whole entry to the file in a single
        // write call instead of one call per formatted fragment.
        let _ = file.write_all(line.as_bytes());
    }

    /// Builds the JSON object shared by the `Json` and `JsonLines` layouts.
    fn json_entry(
        timestamp: String,
        event: &str,
        run: &RunTree,
        data: serde_json::Value,
    ) -> serde_json::Value {
        serde_json::json!({
            "timestamp": timestamp,
            "event": event,
            "run_id": run.id,
            "run_name": run.name,
            "run_type": run.run_type.as_str(),
            "data": data
        })
    }
}

#[async_trait]
impl CallbackHandler for FileCallbackHandler {
    async fn on_run_start(&self, run: &RunTree) {
        self.write_log("run_start", run, serde_json::json!({"inputs": run.inputs}));
    }

    async fn on_run_end(&self, run: &RunTree) {
        self.write_log("run_end", run, serde_json::json!({"outputs": run.outputs, "duration_ms": run.duration_ms()}));
    }

    async fn on_run_error(&self, run: &RunTree, error: &str) {
        self.write_log("run_error", run, serde_json::json!({"error": error}));
    }

    async fn on_llm_start(&self, run: &RunTree, messages: &[Message]) {
        self.write_log("llm_start", run, serde_json::json!({
            "messages_count": messages.len(),
            "messages": messages.iter().map(|m| m.content.clone()).collect::<Vec<_>>()
        }));
    }

    async fn on_llm_end(&self, run: &RunTree, response: &str) {
        self.write_log("llm_end", run, serde_json::json!({"response_length": response.len()}));
    }

    async fn on_llm_new_token(&self, run: &RunTree, token: &str) {
        self.write_log("llm_new_token", run, serde_json::json!({"token": token}));
    }

    async fn on_llm_error(&self, run: &RunTree, error: &str) {
        self.write_log("llm_error", run, serde_json::json!({"error": error}));
    }

    async fn on_chain_start(&self, run: &RunTree, inputs: &serde_json::Value) {
        self.write_log("chain_start", run, serde_json::json!({"inputs": inputs}));
    }

    async fn on_chain_end(&self, run: &RunTree, outputs: &serde_json::Value) {
        self.write_log("chain_end", run, serde_json::json!({"outputs": outputs}));
    }

    async fn on_chain_error(&self, run: &RunTree, error: &str) {
        self.write_log("chain_error", run, serde_json::json!({"error": error}));
    }

    async fn on_tool_start(&self, run: &RunTree, tool_name: &str, input: &str) {
        self.write_log("tool_start", run, serde_json::json!({"tool_name": tool_name, "input": input}));
    }

    async fn on_tool_end(&self, run: &RunTree, output: &str) {
        self.write_log("tool_end", run, serde_json::json!({"output": output}));
    }

    async fn on_tool_error(&self, run: &RunTree, error: &str) {
        self.write_log("tool_error", run, serde_json::json!({"error": error}));
    }

    async fn on_retriever_start(&self, run: &RunTree, query: &str) {
        self.write_log("retriever_start", run, serde_json::json!({"query": query}));
    }

    async fn on_retriever_end(&self, run: &RunTree, documents: &[serde_json::Value]) {
        self.write_log("retriever_end", run, serde_json::json!({"documents_count": documents.len()}));
    }

    async fn on_retriever_error(&self, run: &RunTree, error: &str) {
        self.write_log("retriever_error", run, serde_json::json!({"error": error}));
    }
}

impl std::fmt::Debug for FileCallbackHandler {
    /// Renders the handler as a struct with its `path` and the name of its
    /// `format`; the file handle is left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let format_name: &str = match self.format {
            LogFormat::JsonLines => "JsonLines",
            LogFormat::Json => "Json",
            LogFormat::Plain => "Plain",
        };

        let mut out = f.debug_struct("FileCallbackHandler");
        out.field("path", &self.path);
        out.field("format", &format_name);
        out.finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Returns everything currently stored in the temporary log file.
    fn read_log(log: &NamedTempFile) -> String {
        std::fs::read_to_string(log.path()).unwrap()
    }

    /// Creating a handler leaves a file at the recorded path.
    #[test]
    fn test_file_handler_creation() {
        let log = NamedTempFile::new().unwrap();
        let handler = FileCallbackHandler::new(log.path()).unwrap();

        assert!(handler.path.exists());
    }

    /// The default layout writes the event name and a compact `run_name` field.
    #[tokio::test]
    async fn test_write_log_json_lines() {
        let log = NamedTempFile::new().unwrap();
        let handler = FileCallbackHandler::new(log.path()).unwrap();
        let run = RunTree::new(
            "test_run",
            crate::callbacks::RunType::Llm,
            serde_json::json!({"input": "test"}),
        );

        handler.write_log("test_event", &run, serde_json::json!({"data": "value"}));

        let written = read_log(&log);
        assert!(written.contains("test_event"));
        assert!(written.contains("\"run_name\":\"test_run\""));
    }

    /// The plain layout writes the event name and the run name.
    #[tokio::test]
    async fn test_write_log_plain_format() {
        let log = NamedTempFile::new().unwrap();
        let handler = FileCallbackHandler::new(log.path())
            .unwrap()
            .with_format(LogFormat::Plain);
        let run = RunTree::new(
            "test_run",
            crate::callbacks::RunType::Llm,
            serde_json::json!({"input": "test"}),
        );

        handler.write_log("test_event", &run, serde_json::json!({}));

        let written = read_log(&log);
        assert!(written.contains("test_event"));
        assert!(written.contains("test_run"));
    }
}