use std::sync::Arc;
use comp_cat_rs::effect::io::Io;
use crate::embedding::{Embedding, EmbeddingModel, EmbeddingRequest};
use crate::error::Error;
use crate::model::CompletionModel;
use crate::vector_store::{SearchResult, VectorStoreIndex};
#[allow(clippy::needless_pass_by_value)] pub fn rag<M, E, V>(
query: String,
model: Arc<M>,
embedder: Arc<E>,
store: Arc<V>,
preamble: Option<String>,
top_k: usize,
) -> Io<Error, String>
where
M: CompletionModel + Send + Sync + 'static,
E: EmbeddingModel + Send + Sync + 'static,
V: VectorStoreIndex + Send + Sync + 'static,
{
let embed_io = embedder.embed(EmbeddingRequest::single(query.clone()));
embed_io.flat_map(move |embeddings| {
let query_embedding: Embedding = embeddings.into_iter()
.next()
.unwrap_or_else(|| Embedding::new(Vec::new()));
store.search(&query_embedding, top_k).flat_map(move |results| {
let context = format_context(&results);
let augmented_prompt = format!(
"Context:\n{context}\n\nQuestion: {query}\n\nAnswer based on the context above."
);
let messages = preamble.iter()
.map(|p| crate::model::Message::system(p.clone()))
.chain(std::iter::once(crate::model::Message::user(augmented_prompt)))
.collect();
let request = crate::model::CompletionRequest::new(messages);
model.complete(request).map(|r| r.content().to_owned())
})
})
}
/// Renders search results as a numbered list for inclusion in a prompt.
///
/// Each result becomes one entry of the form `[n] (score: s) content`:
/// - `n` is the 1-based position in `results`;
/// - `s` is the score printed with three decimal places;
/// - `content` is the document's content, inserted verbatim.
///
/// Entries are separated by `\n` with no trailing newline. An empty `results` slice
/// yields an empty string.
fn format_context(results: &[SearchResult]) -> String {
    use std::fmt::Write as _;

    let mut rendered = String::new();
    for (index, hit) in results.iter().enumerate() {
        // The separator goes between entries only, never before the first one.
        if index > 0 {
            rendered.push('\n');
        }
        // Formatting into a `String` cannot fail, so the `fmt::Result` is ignored.
        let _ = write!(
            rendered,
            "[{}] (score: {:.3}) {}",
            index + 1,
            hit.score(),
            hit.document().content()
        );
    }
    rendered
}