use langchainrust::callbacks::{RunTree, RunType, CallbackManager, CallbackHandler, StdOutHandler};
use langchainrust::schema::Message;
use async_trait::async_trait;
use std::sync::{Arc, Mutex};
/// Every `RunType` variant maps to its lowercase wire name through `as_str`.
#[test]
fn test_run_type_as_str() {
    let expectations = [
        (RunType::Llm, "llm"),
        (RunType::Chain, "chain"),
        (RunType::Tool, "tool"),
        (RunType::Retriever, "retriever"),
        (RunType::Embedding, "embedding"),
        (RunType::Prompt, "prompt"),
        (RunType::Parser, "parser"),
    ];
    for (run_type, wire_name) in expectations {
        assert_eq!(run_type.as_str(), wire_name);
    }
}
/// Every `RunType` variant has a fixed display emoji.
#[test]
fn test_run_type_emoji() {
    let expectations = [
        (RunType::Llm, "🤖"),
        (RunType::Chain, "🔗"),
        (RunType::Tool, "🔧"),
        (RunType::Retriever, "📚"),
        (RunType::Embedding, "📊"),
        (RunType::Prompt, "📝"),
        (RunType::Parser, "📄"),
    ];
    for (run_type, glyph) in expectations {
        assert_eq!(run_type.emoji(), glyph);
    }
}
/// `Display` for `RunType` renders the same lowercase name as `as_str`.
///
/// The original only checked two variants; every variant is now checked
/// against `as_str()`, so a variant whose `Display` arm drifts is caught.
#[test]
fn test_run_type_display() {
    assert_eq!(RunType::Llm.to_string(), "llm");
    assert_eq!(RunType::Chain.to_string(), "chain");

    let all_variants = [
        RunType::Llm,
        RunType::Chain,
        RunType::Tool,
        RunType::Retriever,
        RunType::Embedding,
        RunType::Prompt,
        RunType::Parser,
    ];
    for kind in all_variants {
        assert_eq!(kind.to_string(), kind.as_str(), "Display 应与 as_str() 一致");
    }
}
/// A freshly constructed root run keeps its name and type.
/// Every optional or collection field starts out empty.
#[test]
fn test_run_tree_new() {
    let inputs = serde_json::json!({"input": "test"});
    let root = RunTree::new("Test Run", RunType::Chain, inputs);

    // Identity fields come straight from the constructor arguments.
    assert_eq!(root.name, "Test Run");
    assert_eq!(root.run_type, RunType::Chain);

    // Lifecycle fields stay unset until the run finishes.
    assert!(root.outputs.is_none(), "创建时 outputs 应为 None");
    assert!(root.error.is_none(), "创建时 error 应为 None");
    assert!(root.end_time.is_none(), "结束前 end_time 应为 None");

    // A root has no parent; tags and metadata default to empty.
    assert!(root.parent_run_id.is_none(), "根运行的 parent_run_id 应为 None");
    assert!(root.tags.is_empty(), "默认 tags 应为空");
    assert!(root.metadata.is_empty(), "默认 metadata 应为空");
}
/// `end()` records the supplied outputs, stamps `end_time`, and makes
/// `duration_ms` available.
#[test]
fn test_run_tree_end() {
    let mut run = RunTree::new("Test", RunType::Llm, serde_json::json!({}));
    assert!(run.end_time.is_none());

    let outputs = serde_json::json!({"output": "result"});
    run.end(outputs.clone());

    // Compare the stored payload itself. An `is_some()` check would also pass
    // if `end()` stored the wrong value.
    assert_eq!(run.outputs, Some(outputs), "end() 后 outputs 应被设置");
    assert!(run.end_time.is_some(), "end() 后 end_time 应被设置");

    // Fetch the duration once instead of calling `is_some()` and then `unwrap()`.
    let elapsed = run.duration_ms().expect("end() 后 duration_ms 应可用");
    // NOTE(review): if `duration_ms` returns an unsigned integer this
    // comparison is vacuous; confirm the return type.
    assert!(elapsed >= 0, "耗时应为非负数");
}
/// `end_with_error()` records the error message and stamps `end_time`.
/// It leaves `outputs` unset.
#[test]
fn test_run_tree_end_with_error() {
    let mut run = RunTree::new("Test", RunType::Tool, serde_json::json!({}));
    run.end_with_error("Something went wrong");

    // `as_deref` checks presence and content in a single assertion.
    // It also avoids moving `error` out of `run`, which `unwrap()` did.
    assert_eq!(
        run.error.as_deref(),
        Some("Something went wrong"),
        "error 应被设置"
    );
    assert!(run.end_time.is_some(), "错误时 end_time 也应被设置");
    assert!(run.outputs.is_none(), "错误时 outputs 应保持 None");
}
/// `with_tag` is chainable, and each call records one more tag.
#[test]
fn test_run_tree_with_tag() {
    let tagged = RunTree::new("Test", RunType::Chain, serde_json::json!({}))
        .with_tag("test-tag")
        .with_tag("another-tag");

    assert_eq!(tagged.tags.len(), 2);
    for expected in ["test-tag", "another-tag"] {
        assert!(tagged.tags.contains(&expected.to_string()));
    }
}
/// `with_metadata` stores JSON values of different kinds under their keys.
#[test]
fn test_run_tree_with_metadata() {
    let annotated = RunTree::new("Test", RunType::Chain, serde_json::json!({}))
        .with_metadata("version", serde_json::json!("1.0"))
        .with_metadata("count", serde_json::json!(42));

    assert_eq!(annotated.metadata.len(), 2);

    let version = annotated.metadata.get("version").unwrap();
    assert_eq!(version, "1.0");

    let count = annotated.metadata.get("count").unwrap();
    assert_eq!(count, 42);
}
/// `with_project` sets the run's project name.
#[test]
fn test_run_tree_with_project() {
    let scoped = RunTree::new("Test", RunType::Chain, serde_json::json!({}))
        .with_project("my-langsmith-project");
    assert_eq!(scoped.project_name.as_deref(), Some("my-langsmith-project"));
}
/// `create_child` links the child to its parent and uses the parent's id as
/// the trace id. It also inherits the parent's project name.
///
/// The parent is given an explicit project so the inheritance check means
/// something. With no project set, both sides are `None` and the assertion
/// could never fail.
#[test]
fn test_run_tree_create_child() {
    let parent = RunTree::new("Parent", RunType::Chain, serde_json::json!({"input": "test"}))
        .with_project("parent-project");
    let child = parent.create_child("Child", RunType::Tool, serde_json::json!({"action": "run"}));

    assert_eq!(child.name, "Child");
    assert_eq!(child.run_type, RunType::Tool);
    assert_eq!(child.parent_run_id, Some(parent.id), "子运行应引用父运行");
    assert_eq!(child.trace_id, Some(parent.id), "trace_id 应为父运行的 id");
    assert_eq!(child.project_name, parent.project_name, "子运行继承 project_name");
    assert_eq!(child.project_name.as_deref(), Some("parent-project"));
}
/// A grandchild points at its immediate parent but shares the root's trace id.
#[test]
fn test_run_tree_nested_children() {
    let root = RunTree::new("Parent", RunType::Chain, serde_json::json!({}));
    let middle = root.create_child("Child1", RunType::Tool, serde_json::json!({}));
    let leaf = middle.create_child("Grandchild", RunType::Llm, serde_json::json!({}));

    let (root_id, middle_id) = (root.id, middle.id);
    assert_eq!(leaf.parent_run_id, Some(middle_id));
    assert_eq!(leaf.trace_id, Some(root_id), "所有后代共享根 trace_id");
}
/// Sibling children of one run all point at that run as parent and as trace root.
#[test]
fn test_run_tree_chain() {
    let chain = RunTree::new("Chain", RunType::Chain, serde_json::json!({"input": "query"}));
    let siblings = [
        chain.create_child("Calculator", RunType::Tool, serde_json::json!({"expr": "1+1"})),
        chain.create_child("LLM", RunType::Llm, serde_json::json!({"prompt": "..."})),
    ];

    for sibling in &siblings {
        assert_eq!(sibling.parent_run_id, Some(chain.id));
        assert_eq!(sibling.trace_id, Some(chain.id));
    }
}
/// `duration_ms` is unavailable before `end()` and available afterwards.
#[test]
fn test_run_tree_duration() {
    let mut timed = RunTree::new("Test", RunType::Llm, serde_json::json!({}));
    assert!(timed.duration_ms().is_none(), "结束前 duration 应为 None");

    timed.end(serde_json::json!({"output": "done"}));
    let elapsed = timed.duration_ms().unwrap();
    assert!(elapsed >= 0, "结束后 duration 应为非负数");
}
/// Run ids are unique per run and, as the test name states, are UUID version 7.
///
/// The original only checked uniqueness, which any random id scheme satisfies.
/// The version nibble is now asserted so the v7 claim is actually tested.
#[test]
fn test_run_tree_uuid_v7() {
    let run1 = RunTree::new("First", RunType::Chain, serde_json::json!({}));
    let run2 = RunTree::new("Second", RunType::Chain, serde_json::json!({}));

    assert_ne!(run1.id, run2.id, "每个运行应有唯一 ID");
    // NOTE(review): assumes `id` is a `uuid::Uuid`; confirm against RunTree.
    assert_eq!(run1.id.get_version_num(), 7, "运行 ID 应为 UUID v7");
    assert_eq!(run2.id.get_version_num(), 7, "运行 ID 应为 UUID v7");
}
/// A `RunTree` survives a JSON round-trip.
///
/// Beyond name and type, this also checks the id, tags and metadata. The test
/// sets tags and metadata on purpose, so losing them in (de)serialization
/// would otherwise go unnoticed.
#[test]
fn test_run_tree_serialization() {
    let run = RunTree::new("Test", RunType::Chain, serde_json::json!({"input": "test"}))
        .with_tag("test-tag")
        .with_metadata("key", serde_json::json!("value"));

    let json = serde_json::to_string(&run).unwrap();
    assert!(json.contains("Test"));
    assert!(json.contains("chain"));

    let deserialized: RunTree = serde_json::from_str(&json).unwrap();
    assert_eq!(deserialized.name, "Test");
    assert_eq!(deserialized.run_type, RunType::Chain);
    assert_eq!(deserialized.id, run.id, "反序列化后 id 应保持不变");
    assert_eq!(deserialized.tags, run.tags, "反序列化后 tags 应保持不变");
    assert_eq!(deserialized.metadata, run.metadata, "反序列化后 metadata 应保持不变");
}
/// Tags and metadata can be mixed in one builder chain, including nested JSON
/// metadata values.
#[test]
fn test_run_tree_tags_and_metadata_combined() {
    let run = RunTree::new("Test", RunType::Chain, serde_json::json!({}))
        .with_tag("production")
        .with_tag("v2")
        .with_metadata("user_id", serde_json::json!("123"))
        .with_metadata("session", serde_json::json!({"id": "abc", "start": 12345}));

    assert_eq!(run.tags.len(), 2);
    assert!(run.tags.contains(&"production".to_string()));
    assert!(run.tags.contains(&"v2".to_string()));

    assert_eq!(run.metadata.len(), 2);
    assert_eq!(run.metadata.get("user_id").unwrap(), "123");
    // The nested object was previously never checked; compare it structurally.
    assert_eq!(
        run.metadata.get("session"),
        Some(&serde_json::json!({"id": "abc", "start": 12345}))
    );
}
/// A new `CallbackManager` has no registered handlers.
#[test]
fn test_callback_manager_new() {
    let empty = CallbackManager::new();
    assert!(empty.is_empty());

    let registered = empty.handlers().len();
    assert_eq!(registered, 0);
}
/// Adding one handler makes the manager non-empty with exactly one entry.
#[test]
fn test_callback_manager_add_handler() {
    let stdout = Arc::new(StdOutHandler::new());
    let manager = CallbackManager::new().add_handler(stdout);

    assert!(!manager.is_empty());
    assert_eq!(manager.handlers().len(), 1);
}
/// Several handlers, with different configurations, can be registered side by side.
#[test]
fn test_callback_manager_multiple_handlers() {
    let verbose = Arc::new(StdOutHandler::new());
    let quiet = Arc::new(StdOutHandler::new().with_verbose(false));

    let manager = CallbackManager::new()
        .add_handler(verbose)
        .add_handler(quiet);

    assert_eq!(manager.handlers().len(), 2);
}
/// Cloning a manager keeps its registered handlers.
#[test]
fn test_callback_manager_clone() {
    let original = CallbackManager::new().add_handler(Arc::new(StdOutHandler::new()));
    let copy = original.clone();
    assert_eq!(copy.handlers().len(), 1);
}
/// The `Debug` output names the type and includes a `handlers_count` field.
#[test]
fn test_callback_manager_debug() {
    let manager = (0..2).fold(CallbackManager::new(), |acc, _| {
        acc.add_handler(Arc::new(StdOutHandler::new()))
    });

    let rendered = format!("{:?}", manager);
    for needle in ["CallbackManager", "handlers_count"] {
        assert!(rendered.contains(needle));
    }
}
/// Invocation counter that the mock and the test body both hold a handle to.
type SharedCount = Arc<Mutex<usize>>;

/// Test double that counts calls to the generic run lifecycle hooks.
///
/// Each counter is behind an `Arc`, so a test can clone a handle before moving
/// the mock into a `CallbackManager`, then read the count afterwards.
struct MockCallbackHandler {
    /// Number of `on_run_start` calls.
    start_count: SharedCount,
    /// Number of `on_run_end` calls.
    end_count: SharedCount,
    /// Number of `on_run_error` calls.
    error_count: SharedCount,
}
impl MockCallbackHandler {
    /// Creates a mock whose three counters all start at zero and are
    /// independent of each other.
    fn new() -> Self {
        let zeroed = || Arc::new(Mutex::new(0usize));
        Self {
            start_count: zeroed(),
            end_count: zeroed(),
            error_count: zeroed(),
        }
    }
}
// Overrides only the three generic run hooks; each one increments its own counter.
// NOTE(review): the LLM/tool/retriever tests below expect these counters to
// move as well, which presumably means the trait's default specialised hooks
// forward here. Confirm against the `CallbackHandler` definition.
#[async_trait]
impl CallbackHandler for MockCallbackHandler {
    async fn on_run_start(&self, _run: &RunTree) {
        // The guard is a temporary and is dropped at the end of the statement,
        // so the lock is never held across an await point.
        *self.start_count.lock().unwrap() += 1;
    }

    async fn on_run_end(&self, _run: &RunTree) {
        *self.end_count.lock().unwrap() += 1;
    }

    async fn on_run_error(&self, _run: &RunTree, _error: &str) {
        *self.error_count.lock().unwrap() += 1;
    }
}
#[tokio::test]
async fn test_callback_handler_calls() {
let handler = Arc::new(MockCallbackHandler::new());
let start_count = Arc::clone(&handler.start_count);
let end_count = Arc::clone(&handler.end_count);
let manager = CallbackManager::new().add_handler(handler);
let run = RunTree::new("Test", RunType::Chain, serde_json::json!({}));
for h in manager.handlers() {
h.on_run_start(&run).await;
}
assert_eq!(*start_count.lock().unwrap(), 1);
for h in manager.handlers() {
h.on_run_end(&run).await;
}
assert_eq!(*end_count.lock().unwrap(), 1);
}
/// Dispatching `on_run_error` increments the mock's error counter once.
#[tokio::test]
async fn test_callback_handler_error() {
    let mock = Arc::new(MockCallbackHandler::new());
    let errors = Arc::clone(&mock.error_count);
    let manager = CallbackManager::new().add_handler(mock);
    let failing_run = RunTree::new("Test", RunType::Chain, serde_json::json!({}));

    for handler in manager.handlers() {
        handler.on_run_error(&failing_run, "test error").await;
    }

    let observed = *errors.lock().unwrap();
    assert_eq!(observed, 1);
}
#[tokio::test]
async fn test_llm_callbacks() {
let handler = Arc::new(MockCallbackHandler::new());
let start_count = Arc::clone(&handler.start_count);
let end_count = Arc::clone(&handler.end_count);
let manager = CallbackManager::new().add_handler(handler);
let run = RunTree::new("LLM", RunType::Llm, serde_json::json!({}));
let messages = vec![Message::human("test")];
for h in manager.handlers() {
h.on_llm_start(&run, &messages).await;
}
assert_eq!(*start_count.lock().unwrap(), 1);
for h in manager.handlers() {
h.on_llm_end(&run, "response").await;
}
assert_eq!(*end_count.lock().unwrap(), 1);
}
#[tokio::test]
async fn test_tool_callbacks() {
let handler = Arc::new(MockCallbackHandler::new());
let start_count = Arc::clone(&handler.start_count);
let end_count = Arc::clone(&handler.end_count);
let manager = CallbackManager::new().add_handler(handler);
let run = RunTree::new("Tool", RunType::Tool, serde_json::json!({}));
for h in manager.handlers() {
h.on_tool_start(&run, "Calculator", "1 + 1").await;
}
assert_eq!(*start_count.lock().unwrap(), 1);
for h in manager.handlers() {
h.on_tool_end(&run, "2").await;
}
assert_eq!(*end_count.lock().unwrap(), 1);
}
#[tokio::test]
async fn test_retriever_callbacks() {
let handler = Arc::new(MockCallbackHandler::new());
let start_count = Arc::clone(&handler.start_count);
let end_count = Arc::clone(&handler.end_count);
let manager = CallbackManager::new().add_handler(handler);
let run = RunTree::new("Retriever", RunType::Retriever, serde_json::json!({}));
for h in manager.handlers() {
h.on_retriever_start(&run, "query").await;
}
assert_eq!(*start_count.lock().unwrap(), 1);
for h in manager.handlers() {
h.on_retriever_end(&run, &[serde_json::json!({"doc": "result"})]).await;
}
assert_eq!(*end_count.lock().unwrap(), 1);
}