langchainrust 0.2.13

A LangChain-inspired framework for building LLM applications in Rust. Supports OpenAI, Agents, Tools, Memory, Chains, RAG, BM25, Hybrid Retrieval, LangGraph, HyDE, Reranking, MultiQuery, and native Function Calling.
//! ConversationChain 单元测试 (unit tests; the final section also contains live LLM tests that call `predict`)

#[path = "../common/mod.rs"]
mod common;

use common::TestConfig;
use langchainrust::{ConversationChain, ConversationBufferMemory, BaseChain};
use langchainrust::memory::BaseMemory;
use langchainrust::schema::{Message, MessageType};
use std::collections::HashMap;

/// A chain built with `ConversationChain::new` exposes the default
/// `input`/`output` keys and reports its name as `conversation_chain`.
#[test]
fn test_conversation_chain_new() {
    let chain = ConversationChain::new(
        TestConfig::get().openai_chat(),
        ConversationBufferMemory::new(),
    );

    assert_eq!(chain.name(), "conversation_chain");
    assert_eq!(chain.input_keys(), vec!["input"]);
    assert_eq!(chain.output_keys(), vec!["output"]);
}

/// The `with_*` setters on `ConversationChain` override the default
/// input/output keys.
#[test]
fn test_conversation_chain_with_options() {
    let chain = ConversationChain::new(
        TestConfig::get().openai_chat(),
        ConversationBufferMemory::new(),
    )
    .with_system_prompt("你是一个友好的助手")
    .with_input_key("question")
    .with_output_key("answer")
    .with_verbose(true);

    let (inputs, outputs) = (chain.input_keys(), chain.output_keys());
    assert_eq!(inputs, vec!["question"]);
    assert_eq!(outputs, vec!["answer"]);
}

/// `ConversationChain::builder` applies the configured input/output keys
/// (no memory is passed explicitly here).
#[test]
fn test_conversation_chain_builder() {
    let builder = ConversationChain::builder(TestConfig::get().openai_chat())
        .system_prompt("你是一个 Rust 专家")
        .input_key("query")
        .output_key("response")
        .verbose(true);
    let chain = builder.build();

    assert_eq!(chain.input_keys(), vec!["query"]);
    assert_eq!(chain.output_keys(), vec!["response"]);
}

#[tokio::test]
async fn test_clear_memory() {
    let llm = TestConfig::get().openai_chat();
    let memory = ConversationBufferMemory::new();
    let chain = ConversationChain::new(llm, memory);
    
    chain.clear_memory().await.unwrap();
    
    let mem = chain.memory().lock().await;
    assert_eq!(mem.chat_memory().len(), 0);
}

/// Saving one exchange into a `ConversationBufferMemory` makes both the human
/// and the AI turn appear in the rendered `history` variable.
#[tokio::test]
async fn test_memory_integration() {
    let mut buffer = ConversationBufferMemory::new();

    let user_turn = HashMap::from([("input".to_string(), "你好".to_string())]);
    let ai_turn = HashMap::from([("output".to_string(), "你好!".to_string())]);
    buffer.save_context(&user_turn, &ai_turn).await.unwrap();

    let variables = buffer.load_memory_variables(&HashMap::new()).await.unwrap();
    let rendered = variables.get("history").unwrap().as_str().unwrap();

    for fragment in ["Human: 你好", "AI: 你好!"] {
        assert!(rendered.contains(fragment));
    }
}

/// When a system prompt is set, `prepare_messages` returns
/// `[system, ...history, current human input]` in that order, with contents
/// passed through verbatim.
#[test]
fn test_prepare_messages_structure() {
    let chain = ConversationChain::new(
        TestConfig::get().openai_chat(),
        ConversationBufferMemory::new(),
    )
    .with_system_prompt("测试系统提示");

    let history = [Message::human("第一轮问题"), Message::ai("第一轮回答")];
    let messages = chain.prepare_messages("第二轮问题", &history);

    assert_eq!(messages.len(), 4);

    // System prompt comes first.
    assert!(matches!(messages[0].message_type, MessageType::System));
    assert_eq!(messages[0].content, "测试系统提示");

    // Prior turns are replayed in order.
    assert!(matches!(messages[1].message_type, MessageType::Human));
    assert_eq!(messages[1].content, "第一轮问题");
    assert!(matches!(messages[2].message_type, MessageType::AI));
    assert_eq!(messages[2].content, "第一轮回答");

    // The current input is appended last.
    assert!(matches!(messages[3].message_type, MessageType::Human));
    assert_eq!(messages[3].content, "第二轮问题");
}

/// Without a system prompt, `prepare_messages` emits only the history followed
/// by the current input. Contents are checked as well as roles, matching
/// `test_prepare_messages_structure`.
#[test]
fn test_prepare_messages_without_system_prompt() {
    let llm = TestConfig::get().openai_chat();
    let memory = ConversationBufferMemory::new();
    let chain = ConversationChain::new(llm, memory);

    let history = vec![
        Message::human("历史问题"),
        Message::ai("历史回答"),
    ];

    let messages = chain.prepare_messages("当前问题", &history);

    assert_eq!(messages.len(), 3);
    assert!(matches!(messages[0].message_type, MessageType::Human));
    assert!(matches!(messages[1].message_type, MessageType::AI));
    assert!(matches!(messages[2].message_type, MessageType::Human));

    // Role checks alone would not catch reordered or rewritten content.
    assert_eq!(messages[0].content, "历史问题");
    assert_eq!(messages[1].content, "历史回答");
    assert_eq!(messages[2].content, "当前问题");
}

/// With an empty history, `prepare_messages` yields just the system prompt and
/// the current input. Both contents are checked, not only the roles.
#[test]
fn test_prepare_messages_empty_history() {
    let llm = TestConfig::get().openai_chat();
    let memory = ConversationBufferMemory::new();
    let chain = ConversationChain::new(llm, memory)
        .with_system_prompt("你是一个助手");

    let messages = chain.prepare_messages("第一个问题", &[]);

    assert_eq!(messages.len(), 2);
    assert!(matches!(messages[0].message_type, MessageType::System));
    assert!(matches!(messages[1].message_type, MessageType::Human));

    assert_eq!(messages[0].content, "你是一个助手");
    assert_eq!(messages[1].content, "第一个问题");
}

/// A freshly constructed chain starts with empty memory.
///
/// Despite the name, this test does not invoke the chain or call the LLM.
#[tokio::test]
async fn test_conversation_chain_invoke_structure() {
    let chain = ConversationChain::new(
        TestConfig::get().openai_chat(),
        ConversationBufferMemory::new(),
    );

    // Read the length inside a block so the lock guard is released right away.
    let count_before = {
        let guard = chain.memory().lock().await;
        guard.chat_memory().len()
    };

    assert_eq!(count_before, 0);
}

/// Two saved exchanges produce four stored messages, and every turn's text
/// shows up in the rendered `history` variable.
#[tokio::test]
async fn test_memory_with_multiple_contexts() {
    let mut memory = ConversationBufferMemory::new();

    for (question, answer) in [("第一轮", "回答1"), ("第二轮", "回答2")] {
        let inputs = HashMap::from([("input".to_string(), question.to_string())]);
        let outputs = HashMap::from([("output".to_string(), answer.to_string())]);
        memory.save_context(&inputs, &outputs).await.unwrap();
    }

    assert_eq!(memory.chat_memory().len(), 4);

    let vars = memory.load_memory_variables(&HashMap::new()).await.unwrap();
    let history = vars.get("history").unwrap().as_str().unwrap();

    for needle in ["第一轮", "回答1", "第二轮", "回答2"] {
        assert!(history.contains(needle));
    }
}

/// A builder given no explicit memory still produces a chain with the default
/// `input`/`output` keys.
#[test]
fn test_conversation_chain_default_memory() {
    let chain = ConversationChain::builder(TestConfig::get().openai_chat()).build();

    let expected_in = vec!["input"];
    let expected_out = vec!["output"];
    assert_eq!(chain.input_keys(), expected_in);
    assert_eq!(chain.output_keys(), expected_out);
}

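// ============================================================================
// 真实 LLM 调用测试
// (Live LLM tests: each test below calls `chain.predict`, i.e. makes real
// requests through the model returned by `TestConfig::get().openai_chat()`.)
// ============================================================================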

/// Live-LLM smoke test. One `predict` call returns a non-empty reply, and the
/// exchange (one human and one AI message) is recorded in memory.
#[tokio::test]
async fn test_llm_single_call() {
    let chain = ConversationChain::new(
        TestConfig::get().openai_chat(),
        ConversationBufferMemory::new(),
    )
    .with_system_prompt("你是一个友好的助手,请简短回答")
    .with_verbose(true);

    println!("\n=== 单轮对话测试 ===");
    let reply = chain.predict("你好,介绍一下自己").await.unwrap();

    println!("用户: 你好,介绍一下自己");
    println!("AI: {}", reply);
    assert!(!reply.is_empty());

    let stored = chain.memory().lock().await.chat_memory().len();
    assert_eq!(stored, 2);
}

/// Live-LLM test. Two turns run through the same chain. The stored transcript
/// must still contain the facts from turn one and hold two exchanges
/// (four messages).
#[tokio::test]
async fn test_llm_multi_turn_memory() {
    let chain = ConversationChain::new(
        TestConfig::get().openai_chat(),
        ConversationBufferMemory::new(),
    )
    .with_system_prompt("你是一个友好的助手。请记住用户告诉你的信息。")
    .with_verbose(true);

    println!("\n=== 多轮对话记忆测试 ===");

    println!("\n--- 第一轮 ---");
    let first_reply = chain.predict("你好,我叫张三,我喜欢编程").await.unwrap();
    println!("用户: 你好,我叫张三,我喜欢编程");
    println!("AI: {}", first_reply);
    assert!(!first_reply.is_empty());

    println!("\n--- 第二轮 ---");
    // This turn asks for information that was given only in the first turn.
    let second_reply = chain.predict("我叫什么名字?我喜欢什么?").await.unwrap();
    println!("用户: 我叫什么名字?我喜欢什么?");
    println!("AI: {}", second_reply);
    assert!(!second_reply.is_empty());

    println!("\n--- 验证记忆 ---");
    let guard = chain.memory().lock().await;
    let transcript = guard.chat_memory().to_string();
    println!("历史记录:\n{}", transcript);

    assert!(transcript.contains("张三"), "记忆应包含名字");
    assert!(transcript.contains("编程"), "记忆应包含爱好");
    assert_eq!(guard.chat_memory().len(), 4);
}

/// Live-LLM test: `clear_memory` really resets the conversation.
///
/// The test checks three things:
/// - the first exchange was recorded before clearing;
/// - memory is empty right after `clear_memory`;
/// - after a second turn, only that turn is stored, with no trace of the first.
///
/// The original version never checked the state before or after the reset, so
/// a no-op or partial `clear_memory` would still have passed.
#[tokio::test]
async fn test_llm_clear_memory() {
    let llm = TestConfig::get().openai_chat();
    let memory = ConversationBufferMemory::new();

    let chain = ConversationChain::new(llm, memory)
        .with_verbose(true);

    println!("\n=== 清空记忆测试 ===");

    println!("\n--- 第一轮 ---");
    let result1 = chain.predict("我叫李四").await.unwrap();
    println!("用户: 我叫李四");
    println!("AI: {}", result1);
    assert!(!result1.is_empty());
    {
        // The first exchange must be stored, otherwise clearing proves nothing.
        let mem = chain.memory().lock().await;
        assert_eq!(mem.chat_memory().len(), 2);
    }

    println!("\n--- 清空记忆 ---");
    chain.clear_memory().await.unwrap();
    {
        // Scoped so the guard is released before the next `predict` call.
        let mem = chain.memory().lock().await;
        assert_eq!(mem.chat_memory().len(), 0);
    }
    println!("记忆已清空");

    println!("\n--- 第二轮(清空后) ---");
    let result2 = chain.predict("我叫什么名字?").await.unwrap();
    println!("用户: 我叫什么名字?");
    println!("AI: {}", result2);
    println!("注意: AI 不应知道 '李四'");
    assert!(!result2.is_empty());

    // Only the post-reset exchange may remain. The first turn's human input
    // ("我叫李四") must not have survived the reset.
    let mem = chain.memory().lock().await;
    assert_eq!(mem.chat_memory().len(), 2);
    assert!(!mem.chat_memory().to_string().contains("我叫李四"));
}