use std::sync::Arc;
use oris_runtime::{
chain::{Chain, ConversationalRetrieverChainBuilder},
llm::{OpenAI, OpenAIModel},
memory::SimpleMemory,
prompt_args,
retrievers::{MergeStrategy, MergerRetriever, WikipediaRetriever},
schemas::Retriever as RetrieverTrait,
};
/// Answers one question with a conversational retrieval chain backed by Wikipedia.
///
/// Pipeline assembled here:
/// 1. A `WikipediaRetriever` limited to 3 documents (`with_max_docs(3)`).
/// 2. A `MergerRetriever` wrapping it, configured with the
///    `ReciprocalRankFusion { k: 60.0 }` strategy and `top_k = 5`. Only one
///    source is registered in this example; append more retrievers to the
///    `vec!` to actually fuse results from several sources.
/// 3. A `ConversationalRetrieverChain` using an OpenAI `Gpt35` model, a
///    `SimpleMemory`, and question rephrasing enabled.
///
/// The chain's answer is printed to stdout.
///
/// # Errors
/// Propagates any error returned by `build()` on the chain builder or by
/// `invoke(...)` (retrieval or LLM request failures).
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // The merger takes ownership of the only handle to the Wikipedia retriever,
    // so the `Arc` is moved in rather than cloned.
    let wikipedia_retriever = Arc::new(WikipediaRetriever::new().with_max_docs(3));
    let mut merger_retriever = MergerRetriever::new(vec![wikipedia_retriever]);
    merger_retriever.config.strategy = MergeStrategy::ReciprocalRankFusion { k: 60.0 };
    merger_retriever.config.top_k = 5;
    // The chain builder accepts the retriever as a boxed trait object.
    let retriever: Box<dyn RetrieverTrait> = Box::new(merger_retriever);

    let llm = OpenAI::default().with_model(OpenAIModel::Gpt35.to_string());

    let chain = ConversationalRetrieverChainBuilder::new()
        .llm(llm)
        .retriever(retriever)
        .memory(SimpleMemory::new().into())
        .rephrase_question(true)
        .build()?;

    let result = chain
        .invoke(prompt_args! {
            "question" => "What is machine learning?",
        })
        .await?;
    println!("Answer: {}", result);
    Ok(())
}