//! A conversational retrieval chain: follow-up questions are condensed
//! against the chat history into standalone questions, relevant documents
//! are retrieved, and the answer is generated from that context.

use std::sync::Arc;

use async_trait::async_trait;
use serde_json::{json, Value};
use tokio::sync::RwLock;

use cognis_core::documents::Document;
use cognis_core::error::Result;
use cognis_core::language_models::chat_model::BaseChatModel;
use cognis_core::messages::{HumanMessage, Message};
use cognis_core::retrievers::BaseRetriever;
use cognis_core::runnables::base::Runnable;
use cognis_core::runnables::config::RunnableConfig;

/// Default prompt for rewriting a follow-up question into a standalone
/// question. Must contain the `{chat_history}` and `{question}` placeholders.
const DEFAULT_CONDENSE_PROMPT: &str = "Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.\n\nChat History:\n{chat_history}\nFollow Up Input: {question}\nStandalone question:";

/// Default prompt for answering a question from retrieved context. Must
/// contain the `{context}` and `{question}` placeholders.
const DEFAULT_QA_PROMPT: &str = "Use the following context to answer the question.\n\nContext:\n{context}\n\nQuestion: {question}\n\nAnswer:";

/// The output of one conversational retrieval turn.
#[derive(Debug, Clone)]
pub struct ConversationalRetrievalResult {
    /// The model's final answer.
    pub answer: String,
    /// The documents that were retrieved and passed to the model as context.
    pub source_documents: Vec<Document>,
    /// The standalone question that was actually used for retrieval.
    pub condensed_question: String,
}
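
/// A retrieval chain that owns its chat history. When the history is empty
/// the question is used verbatim; follow-ups are first condensed into
/// standalone questions so the retriever sees the full context.
///
/// A minimal usage sketch; `retriever` and `llm` stand for any
/// `Arc<dyn BaseRetriever>` and `Arc<dyn BaseChatModel>` values:
///
/// ```ignore
/// let chain = ConversationalRetrievalChain::new(retriever, llm).with_k(6);
/// let first = chain.call("What is Rust?").await?;
/// let follow = chain.call_with_sources("When was it released?").await?;
/// println!("{} (condensed: {})", follow.answer, follow.condensed_question);
/// ```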
pub struct ConversationalRetrievalChain {
    retriever: Arc<dyn BaseRetriever>,
    llm: Arc<dyn BaseChatModel>,
    /// Optional separate model for question condensation; falls back to `llm`.
    condense_llm: Option<Arc<dyn BaseChatModel>>,
    chat_history: Arc<RwLock<Vec<Message>>>,
    /// Maximum number of retrieved documents to use as context.
    k: usize,
    condense_prompt: String,
    qa_prompt: String,
}
impl ConversationalRetrievalChain {
    /// Creates a chain with default settings: `k = 4` and the default prompts.
    pub fn new(retriever: Arc<dyn BaseRetriever>, llm: Arc<dyn BaseChatModel>) -> Self {
Self {
retriever,
llm,
condense_llm: None,
chat_history: Arc::new(RwLock::new(Vec::new())),
k: 4,
condense_prompt: DEFAULT_CONDENSE_PROMPT.to_string(),
qa_prompt: DEFAULT_QA_PROMPT.to_string(),
}
}
    /// Uses a separate (often smaller or cheaper) model for question
    /// condensation instead of the main answering model.
    pub fn with_condense_llm(mut self, llm: Arc<dyn BaseChatModel>) -> Self {
self.condense_llm = Some(llm);
self
}
    /// Sets the maximum number of retrieved documents passed as context.
    pub fn with_k(mut self, k: usize) -> Self {
self.k = k;
self
}
    /// Overrides the condensation prompt. The template should contain the
    /// `{chat_history}` and `{question}` placeholders.
    pub fn with_condense_prompt(mut self, prompt: impl Into<String>) -> Self {
self.condense_prompt = prompt.into();
self
}
    /// Overrides the QA prompt. The template should contain the `{context}`
    /// and `{question}` placeholders.
    pub fn with_qa_prompt(mut self, prompt: impl Into<String>) -> Self {
self.qa_prompt = prompt.into();
self
}
    /// Rewrites `question` into a standalone question using the chat history.
    /// Returns the question unchanged when there is no history yet.
    pub async fn condense_question(&self, question: &str) -> Result<String> {
let history = self.chat_history.read().await;
if history.is_empty() {
return Ok(question.to_string());
}
        // Render the history as "Role: text" lines for the condense prompt.
        let history_text: String = history
.iter()
.map(|m| {
let role = match m.message_type() {
cognis_core::messages::MessageType::Human => "Human",
cognis_core::messages::MessageType::Ai => "AI",
_ => "Other",
};
format!("{}: {}", role, m.content().text())
})
.collect::<Vec<_>>()
.join("\n");
        // Release the read lock before the potentially slow model call.
        drop(history);
let prompt = self
.condense_prompt
.replace("{chat_history}", &history_text)
.replace("{question}", question);
        // Fall back to the main model when no dedicated condense model is set.
        let condense_model = self.condense_llm.as_ref().unwrap_or(&self.llm);
let messages = vec![Message::Human(HumanMessage::new(&prompt))];
let ai_msg = condense_model.invoke_messages(&messages, None).await?;
Ok(ai_msg.base.content.text())
}
    /// Runs one turn and returns only the answer.
    pub async fn call(&self, question: &str) -> Result<String> {
let result = self.call_with_sources(question).await?;
Ok(result.answer)
}
    /// Runs one full turn: condense the question, retrieve documents, answer
    /// from the retrieved context, and record the exchange in the history.
    pub async fn call_with_sources(&self, question: &str) -> Result<ConversationalRetrievalResult> {
let condensed = self.condense_question(question).await?;
let mut docs = self.retriever.get_relevant_documents(&condensed).await?;
docs.truncate(self.k);
let context = docs
.iter()
.map(|d| d.page_content.as_str())
.collect::<Vec<_>>()
.join("\n\n");
let qa_input = self
.qa_prompt
.replace("{context}", &context)
.replace("{question}", &condensed);
let messages = vec![Message::Human(HumanMessage::new(&qa_input))];
let ai_msg = self.llm.invoke_messages(&messages, None).await?;
let answer = ai_msg.base.content.text();
        // Record the original question (not the condensed one) and the answer.
        {
let mut history = self.chat_history.write().await;
history.push(Message::human(question));
history.push(Message::ai(&answer));
}
Ok(ConversationalRetrievalResult {
answer,
source_documents: docs,
condensed_question: condensed,
})
}
    /// Clears the accumulated chat history.
    pub async fn clear_history(&self) {
let mut history = self.chat_history.write().await;
history.clear();
}
}
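
// The chain can also be driven through the generic `Runnable` interface: it
// accepts either a bare JSON string or `{"input": "..."}` and returns an
// object with `answer`, `source_documents`, and `condensed_question`.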
#[async_trait]
impl Runnable for ConversationalRetrievalChain {
fn name(&self) -> &str {
"ConversationalRetrievalChain"
}
    async fn invoke(&self, input: Value, _config: Option<&RunnableConfig>) -> Result<Value> {
        // Accept either `{"input": "..."}` or a bare JSON string.
        let question = input
            .get("input")
            .and_then(|v| v.as_str())
            .or_else(|| input.as_str())
            .ok_or_else(|| cognis_core::error::CognisError::TypeMismatch {
                expected: "Object with 'input' string key or a plain string".into(),
                got: input.to_string(),
            })?
            .to_string();
let result = self.call_with_sources(&question).await?;
        let source_docs: Vec<Value> = result
            .source_documents
            .iter()
            // Documents that fail to serialize degrade to JSON null.
            .map(|d| serde_json::to_value(d).unwrap_or(Value::Null))
            .collect();
Ok(json!({
"answer": result.answer,
"source_documents": source_docs,
"condensed_question": result.condensed_question,
}))
}
}
#[cfg(test)]
mod tests {
use super::*;
use cognis_core::documents::Document;
use cognis_core::error::Result;
use cognis_core::language_models::fake::FakeListChatModel;
use cognis_core::retrievers::BaseRetriever;
    /// A retriever that returns a fixed set of documents for any query.
    struct MockRetriever {
docs: Vec<Document>,
}
#[async_trait]
impl BaseRetriever for MockRetriever {
async fn get_relevant_documents(&self, _query: &str) -> Result<Vec<Document>> {
Ok(self.docs.clone())
}
}
fn make_retriever(docs: Vec<Document>) -> Arc<dyn BaseRetriever> {
Arc::new(MockRetriever { docs })
}
fn make_docs() -> Vec<Document> {
vec![
Document::new("Rust is a systems programming language."),
Document::new("Rust was first released in 2015."),
]
}
    /// Wraps `FakeListChatModel`, which replays the given responses in order,
    /// one per model invocation.
    fn fake_model(responses: Vec<&str>) -> Arc<dyn BaseChatModel> {
Arc::new(FakeListChatModel::new(
responses.into_iter().map(String::from).collect(),
))
}
#[tokio::test]
async fn test_first_question_skips_condensation() {
let chain = ConversationalRetrievalChain::new(
make_retriever(make_docs()),
fake_model(vec!["answer about Rust"]),
);
let result = chain.call_with_sources("What is Rust?").await.unwrap();
assert_eq!(result.condensed_question, "What is Rust?");
assert_eq!(result.answer, "answer about Rust");
}
#[tokio::test]
async fn test_followup_condenses_with_history() {
        let llm = fake_model(vec![
            // Answer to the first question (history is empty, no condensation).
            "Rust is great",
            // Condensed form of the follow-up question.
            "What year was Rust released?",
            // Answer to the follow-up.
            "Rust was released in 2015",
        ]);
let chain = ConversationalRetrievalChain::new(make_retriever(make_docs()), llm);
let r1 = chain.call_with_sources("Tell me about Rust").await.unwrap();
assert_eq!(r1.condensed_question, "Tell me about Rust");
assert_eq!(r1.answer, "Rust is great");
let r2 = chain
.call_with_sources("When was it released?")
.await
.unwrap();
assert_eq!(r2.condensed_question, "What year was Rust released?");
assert_eq!(r2.answer, "Rust was released in 2015");
}
#[tokio::test]
async fn test_call_returns_answer_and_updates_history() {
let chain = ConversationalRetrievalChain::new(
make_retriever(make_docs()),
fake_model(vec!["the answer"]),
);
let answer = chain.call("my question").await.unwrap();
assert_eq!(answer, "the answer");
let history = chain.chat_history.read().await;
assert_eq!(history.len(), 2);
assert_eq!(history[0].content().text(), "my question");
assert_eq!(history[1].content().text(), "the answer");
}
#[tokio::test]
async fn test_call_with_sources_returns_docs_and_condensed() {
let docs = vec![
Document::new("doc1 content").with_id("d1"),
Document::new("doc2 content").with_id("d2"),
];
let chain = ConversationalRetrievalChain::new(
make_retriever(docs.clone()),
fake_model(vec!["sourced answer"]),
);
let result = chain.call_with_sources("question?").await.unwrap();
assert_eq!(result.answer, "sourced answer");
assert_eq!(result.source_documents.len(), 2);
assert_eq!(result.source_documents[0].page_content, "doc1 content");
assert_eq!(result.source_documents[1].page_content, "doc2 content");
assert_eq!(result.condensed_question, "question?");
}
#[tokio::test]
async fn test_clear_history_resets() {
let chain = ConversationalRetrievalChain::new(
make_retriever(make_docs()),
fake_model(vec!["a1", "condensed q", "a2"]),
);
chain.call("first question").await.unwrap();
{
let history = chain.chat_history.read().await;
assert_eq!(history.len(), 2);
}
chain.clear_history().await;
{
let history = chain.chat_history.read().await;
assert!(history.is_empty());
}
        // With the history cleared, the new question is used verbatim.
        let r = chain.call_with_sources("new question").await.unwrap();
        assert_eq!(r.condensed_question, "new question");
}
#[tokio::test]
async fn test_with_condense_llm() {
let main_llm = fake_model(vec!["main answer 1", "main answer 2"]);
let condense_llm = fake_model(vec!["standalone question from condense llm"]);
let chain = ConversationalRetrievalChain::new(make_retriever(make_docs()), main_llm)
.with_condense_llm(condense_llm);
chain.call("first q").await.unwrap();
let r = chain.call_with_sources("follow up").await.unwrap();
assert_eq!(
r.condensed_question,
"standalone question from condense llm"
);
assert_eq!(r.answer, "main answer 2");
}
#[tokio::test]
async fn test_k_limits_documents() {
let docs = vec![
Document::new("doc1"),
Document::new("doc2"),
Document::new("doc3"),
Document::new("doc4"),
Document::new("doc5"),
];
let chain =
ConversationalRetrievalChain::new(make_retriever(docs), fake_model(vec!["answer"]))
.with_k(2);
let result = chain.call_with_sources("q").await.unwrap();
assert_eq!(result.source_documents.len(), 2);
}
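
    // A sketch of a test for the `Runnable` entry point defined above; it
    // assumes only the in-order replay behavior of `FakeListChatModel` that
    // the other tests already rely on.
    #[tokio::test]
    async fn test_runnable_invoke_accepts_string_and_object() {
        let chain = ConversationalRetrievalChain::new(
            make_retriever(make_docs()),
            fake_model(vec!["first answer", "condensed follow-up", "second answer"]),
        );
        // A bare JSON string is accepted as the question.
        let out = chain.invoke(json!("What is Rust?"), None).await.unwrap();
        assert_eq!(out["answer"], "first answer");
        assert_eq!(out["condensed_question"], "What is Rust?");
        // So is an object with an "input" key; this follow-up is condensed
        // because the first call populated the history.
        let out = chain
            .invoke(json!({ "input": "When was it released?" }), None)
            .await
            .unwrap();
        assert_eq!(out["answer"], "second answer");
        assert_eq!(out["condensed_question"], "condensed follow-up");
        // Anything else is rejected with a type mismatch.
        assert!(chain.invoke(json!(42), None).await.is_err());
    }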
}