use std::sync::Arc;
use cognis_core::documents::Document;
use cognis_core::error::{CognisError, Result};
use cognis_core::language_models::chat_model::BaseChatModel;
use cognis_core::messages::{HumanMessage, Message};
/// Default template for the first LLM call. The `{text}` and `{question}`
/// placeholders are replaced verbatim by `RefineChain::format_initial_prompt`.
const DEFAULT_INITIAL_PROMPT: &str = "Answer the question based on the text.\n\n\
Text: {text}\n\
Question: {question}\n\
Answer:";
/// Default template for each subsequent (refining) LLM call. The `{text}`,
/// `{question}` and `{existing_answer}` placeholders are replaced verbatim
/// by `RefineChain::format_refine_prompt`.
const DEFAULT_REFINE_PROMPT: &str = "Refine the answer using additional context.\n\n\
Existing answer: {existing_answer}\n\
New context: {text}\n\
Question: {question}\n\
Refined answer:";
/// Question-answering chain that processes documents sequentially:
/// it produces an initial answer from the first document, then refines
/// that answer once per remaining document.
pub struct RefineChain {
    // Chat model used for both the initial and every refine invocation.
    llm: Arc<dyn BaseChatModel>,
    // Template with `{text}` and `{question}` placeholders.
    initial_prompt: String,
    // Template with `{text}`, `{question}` and `{existing_answer}` placeholders.
    refine_prompt: String,
}
impl RefineChain {
    /// Creates a chain using the default initial and refine prompt templates.
    pub fn new(llm: Arc<dyn BaseChatModel>) -> Self {
        Self {
            llm,
            initial_prompt: DEFAULT_INITIAL_PROMPT.to_string(),
            refine_prompt: DEFAULT_REFINE_PROMPT.to_string(),
        }
    }

    /// Replaces the initial prompt template. The template should contain
    /// `{text}` and `{question}` placeholders.
    pub fn with_initial_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.initial_prompt = prompt.into();
        self
    }

    /// Replaces the refine prompt template. The template should contain
    /// `{text}`, `{question}` and `{existing_answer}` placeholders.
    pub fn with_refine_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.refine_prompt = prompt.into();
        self
    }

    /// Substitutes `{text}` and `{question}` into the initial prompt template.
    fn format_initial_prompt(&self, text: &str, question: &str) -> String {
        self.initial_prompt
            .replace("{text}", text)
            .replace("{question}", question)
    }

    /// Substitutes `{text}`, `{question}` and `{existing_answer}` into the
    /// refine prompt template.
    fn format_refine_prompt(&self, text: &str, question: &str, existing_answer: &str) -> String {
        self.refine_prompt
            .replace("{text}", text)
            .replace("{question}", question)
            .replace("{existing_answer}", existing_answer)
    }

    /// Runs the refine loop: answers `question` from the first document,
    /// then refines the answer once per remaining document, returning the
    /// final answer text.
    ///
    /// # Errors
    ///
    /// Returns an error if `documents` is empty, or propagates any error
    /// from a model invocation.
    pub async fn call(&self, question: &str, documents: &[Document]) -> Result<String> {
        if documents.is_empty() {
            return Err(CognisError::Other(
                "RefineChain requires at least one document".into(),
            ));
        }
        // Initial pass over the first document.
        let first_prompt = self.format_initial_prompt(&documents[0].page_content, question);
        let messages = vec![Message::Human(HumanMessage::new(&first_prompt))];
        let ai_msg = self.llm.invoke_messages(&messages, None).await?;
        let mut current_answer = ai_msg.base.content.text();
        // One refine pass per remaining document, feeding the previous
        // answer back into the prompt.
        for doc in &documents[1..] {
            // BUG FIX: the source contained the mojibake token `¤t_answer`
            // (a corrupted `&current_answer`), which did not compile.
            let refine_prompt =
                self.format_refine_prompt(&doc.page_content, question, &current_answer);
            let messages = vec![Message::Human(HumanMessage::new(&refine_prompt))];
            let ai_msg = self.llm.invoke_messages(&messages, None).await?;
            current_answer = ai_msg.base.content.text();
        }
        Ok(current_answer)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use cognis_core::language_models::fake::{FakeListChatModel, ParrotFakeChatModel};

    /// Builds a fake chat model that replies with the given canned
    /// responses, one per invocation, in order.
    fn fake_model(responses: Vec<&str>) -> Arc<dyn BaseChatModel> {
        let canned: Vec<String> = responses.into_iter().map(String::from).collect();
        Arc::new(FakeListChatModel::new(canned))
    }

    /// Wraps raw text in a `Document`.
    fn make_doc(content: &str) -> Document {
        Document::new(content)
    }

    #[tokio::test]
    async fn test_single_doc_uses_initial_prompt_only() {
        // The parrot model echoes the prompt back, so the result must
        // contain the initial-prompt template pieces.
        let model: Arc<dyn BaseChatModel> = Arc::new(ParrotFakeChatModel::new());
        let chain = RefineChain::new(model);
        let documents = vec![make_doc("Rust is a language")];
        let answer = chain.call("What is Rust?", &documents).await.unwrap();
        assert!(answer.contains("Answer the question based on the text"));
        assert!(answer.contains("Rust is a language"));
        assert!(answer.contains("What is Rust?"));
    }

    #[tokio::test]
    async fn test_multiple_docs_refine_iteratively() {
        // Three docs -> one initial call plus two refine calls; the final
        // canned response must win.
        let chain = RefineChain::new(fake_model(vec![
            "initial-answer",
            "refined-once",
            "refined-twice",
        ]));
        let documents = vec![make_doc("Doc 1"), make_doc("Doc 2"), make_doc("Doc 3")];
        let answer = chain.call("question?", &documents).await.unwrap();
        assert_eq!(answer, "refined-twice");
    }

    #[tokio::test]
    async fn test_end_to_end_call() {
        let chain = RefineChain::new(fake_model(vec!["answer-v1", "answer-v2"]));
        let documents = vec![make_doc("Context A"), make_doc("Context B")];
        let answer = chain.call("What is the answer?", &documents).await.unwrap();
        assert_eq!(answer, "answer-v2");
    }

    #[tokio::test]
    async fn test_custom_prompts() {
        // With a parrot model, the echoed refine prompt must carry the
        // custom template and all substituted values.
        let model: Arc<dyn BaseChatModel> = Arc::new(ParrotFakeChatModel::new());
        let chain = RefineChain::new(model)
            .with_initial_prompt("INIT: {text} Q: {question}")
            .with_refine_prompt("REFINE: {existing_answer} + {text} Q: {question}");
        let documents = vec![make_doc("alpha"), make_doc("beta")];
        let answer = chain.call("test?", &documents).await.unwrap();
        assert!(answer.contains("REFINE:"));
        assert!(answer.contains("beta"));
        assert!(answer.contains("test?"));
    }

    #[tokio::test]
    async fn test_empty_documents_returns_error() {
        let chain = RefineChain::new(fake_model(vec!["unused"]));
        let outcome = chain.call("question?", &[]).await;
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("at least one document"));
    }

    #[tokio::test]
    async fn test_refine_prompt_includes_existing_answer() {
        // Exercises the private formatter directly; the model is never called.
        let model: Arc<dyn BaseChatModel> = Arc::new(ParrotFakeChatModel::new());
        let chain = RefineChain::new(model);
        let formatted = chain.format_refine_prompt("new text", "my question", "old answer");
        assert!(formatted.contains("Existing answer: old answer"));
        assert!(formatted.contains("New context: new text"));
        assert!(formatted.contains("Question: my question"));
    }
}