use tracing::debug;
use crate::chunker::ChunkConfig;
use crate::embedding::Embedder;
use crate::error::RagError;
use crate::retriever::{Retriever, RetrieverConfig};
/// Settings that drive a [`RagPipeline`]: how documents are chunked, how the
/// retriever is configured, and how retrieved text is turned into a prompt.
///
/// Marked `#[non_exhaustive]`, so code outside this crate should start from
/// [`RagConfig::default`] and adjust it through the `with_*` builder methods.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct RagConfig {
    /// Chunking settings handed to the retriever whenever a document is indexed.
    pub chunk_config: ChunkConfig,
    /// Settings used to construct the underlying [`Retriever`].
    pub retriever_config: RetrieverConfig,
    /// Size budget for the assembled context (chunk text plus separators).
    ///
    /// The first retrieved chunk is always kept, even when it alone exceeds
    /// this budget.
    pub max_context_chars: usize,
    /// Text inserted between consecutive chunks in the assembled context.
    pub context_separator: String,
    /// Prompt template; `{context}` and `{query}` are substituted when a
    /// prompt is built.
    pub prompt_template: String,
}
impl Default for RagConfig {
    /// Builds the stock configuration: default chunking and retrieval
    /// settings, a 4096 context budget, a `"\n---\n"` chunk separator and a
    /// simple "context / Question / Answer" prompt template.
    fn default() -> Self {
        let chunk_config = ChunkConfig::default();
        let retriever_config = RetrieverConfig::default();
        Self {
            chunk_config,
            retriever_config,
            max_context_chars: 4096,
            context_separator: String::from("\n---\n"),
            prompt_template: String::from("{context}\n\nQuestion: {query}\n\nAnswer:"),
        }
    }
}
/// Consuming builder methods: each one replaces a single field and hands the
/// configuration back so calls can be chained.
impl RagConfig {
    /// Returns the configuration with `chunk_config` replaced.
    #[must_use]
    pub fn with_chunk_config(self, chunk_config: ChunkConfig) -> Self {
        Self {
            chunk_config,
            ..self
        }
    }

    /// Returns the configuration with `retriever_config` replaced.
    #[must_use]
    pub fn with_retriever_config(self, retriever_config: RetrieverConfig) -> Self {
        Self {
            retriever_config,
            ..self
        }
    }

    /// Returns the configuration with the context size budget replaced.
    #[must_use]
    pub fn with_max_context_chars(self, max_context_chars: usize) -> Self {
        Self {
            max_context_chars,
            ..self
        }
    }

    /// Returns the configuration with the chunk separator replaced.
    ///
    /// Accepts anything convertible into a `String`.
    #[must_use]
    pub fn with_context_separator(self, context_separator: impl Into<String>) -> Self {
        Self {
            context_separator: context_separator.into(),
            ..self
        }
    }

    /// Returns the configuration with the prompt template replaced.
    ///
    /// Accepts anything convertible into a `String`.
    #[must_use]
    pub fn with_prompt_template(self, prompt_template: impl Into<String>) -> Self {
        Self {
            prompt_template: prompt_template.into(),
            ..self
        }
    }
}
/// Point-in-time snapshot of a pipeline's index, as produced by
/// `RagPipeline::stats`.
///
/// Derives `PartialEq`/`Eq` so two snapshots can be compared directly (for
/// example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PipelineStats {
    /// Document count reported by the retriever.
    pub documents_indexed: usize,
    /// Chunk count reported by the retriever.
    pub chunks_indexed: usize,
    /// Embedding dimension reported by the embedder.
    pub embedding_dim: usize,
    /// Memory usage, in bytes, reported by the retriever's store.
    pub store_memory_bytes: usize,
}
/// Retrieval-augmented generation pipeline: indexes documents through a
/// [`Retriever`] and turns retrieval results into context strings and prompts.
///
/// `E` is the embedder the retriever is built around.
pub struct RagPipeline<E: Embedder> {
    /// Retriever that performs indexing and retrieval.
    retriever: Retriever<E>,
    /// Configuration used for chunking, context assembly and prompt building.
    config: RagConfig,
}
impl<E: Embedder> RagPipeline<E> {
    /// Creates a pipeline around `embedder`.
    ///
    /// The retriever is constructed from a clone of `config.retriever_config`;
    /// the full `config` is kept for chunking, context assembly and prompt
    /// building.
    #[must_use]
    pub fn new(embedder: E, config: RagConfig) -> Self {
        let retriever = Retriever::new(embedder, config.retriever_config.clone());
        Self { retriever, config }
    }

    /// Indexes a single document using the configured chunking settings.
    ///
    /// Returns whatever count the retriever reports for the document
    /// (presumably the number of chunks added — confirm against
    /// `Retriever::add_document`).
    ///
    /// # Errors
    ///
    /// Propagates any [`RagError`] raised by the retriever.
    pub fn index_document(&mut self, text: &str) -> Result<usize, RagError> {
        self.retriever.add_document(text, &self.config.chunk_config)
    }

    /// Indexes several documents using the configured chunking settings.
    ///
    /// Returns the retriever's per-call result vector (presumably one count
    /// per input document — confirm against `Retriever::add_documents`).
    ///
    /// # Errors
    ///
    /// Propagates any [`RagError`] raised by the retriever.
    pub fn index_documents(&mut self, texts: &[&str]) -> Result<Vec<usize>, RagError> {
        self.retriever
            .add_documents(texts, &self.config.chunk_config)
    }

    /// Retrieves chunks for `query` and joins them into one context string.
    ///
    /// Chunks are taken in the order the retriever returns them and joined
    /// with `context_separator`. Assembly stops before the first chunk that
    /// would push the total (chunk text plus separators) past
    /// `max_context_chars`. Sizes are counted in `char`s, not UTF-8 bytes, so
    /// non-ASCII text is not over-counted. The first chunk is always included,
    /// even if it alone exceeds the budget.
    ///
    /// # Errors
    ///
    /// Returns [`RagError::EmptyQuery`] when `query` is empty or only
    /// whitespace, and propagates any error from the retriever.
    pub fn retrieve_context(&self, query: &str) -> Result<String, RagError> {
        if query.trim().is_empty() {
            return Err(RagError::EmptyQuery);
        }
        let results = self.retriever.retrieve(query)?;

        let sep = self.config.context_separator.as_str();
        let sep_chars = sep.chars().count();
        let budget = self.config.max_context_chars;

        let mut parts: Vec<&str> = Vec::with_capacity(results.len());
        let mut total_chars = 0usize;
        for result in &results {
            let text: &str = &result.chunk.text;
            // `str::len` would count bytes; the budget is expressed in chars.
            let text_chars = text.chars().count();
            // No separator precedes the very first chunk.
            let sep_cost = if parts.is_empty() { 0 } else { sep_chars };
            let needed = total_chars
                .saturating_add(sep_cost)
                .saturating_add(text_chars);
            // The `!parts.is_empty()` guard keeps the first chunk even when it
            // is larger than the whole budget, so the context is never empty
            // just because the first hit is long.
            if needed > budget && !parts.is_empty() {
                break;
            }
            total_chars = needed;
            parts.push(text);
        }
        debug!(
            chunks_used = parts.len(),
            context_chars = total_chars,
            "context assembled"
        );
        Ok(parts.join(sep))
    }

    /// Builds the final prompt for `query` from the configured template.
    ///
    /// `{context}` is replaced by the assembled context and `{query}` by the
    /// query itself. When nothing has been indexed yet
    /// ([`RagError::NoDocumentsIndexed`]) the context is left empty instead of
    /// failing.
    ///
    /// Substitution is done in a single pass over the template, so placeholder
    /// text that happens to occur inside retrieved documents or inside the
    /// query is emitted verbatim and never re-expanded.
    ///
    /// # Errors
    ///
    /// Returns [`RagError::EmptyQuery`] for a blank query and propagates every
    /// retrieval error other than `NoDocumentsIndexed`.
    pub fn build_prompt(&self, query: &str) -> Result<String, RagError> {
        if query.trim().is_empty() {
            return Err(RagError::EmptyQuery);
        }
        let context = match self.retrieve_context(query) {
            Ok(ctx) => ctx,
            Err(RagError::NoDocumentsIndexed) => String::new(),
            Err(e) => return Err(e),
        };
        Ok(Self::render_template(
            &self.config.prompt_template,
            &context,
            query,
        ))
    }

    /// Expands `{context}` and `{query}` in `template` in one left-to-right
    /// pass; any other `{` is copied through unchanged.
    fn render_template(template: &str, context: &str, query: &str) -> String {
        let mut out = String::with_capacity(template.len() + context.len() + query.len());
        let mut rest = template;
        while let Some(pos) = rest.find('{') {
            out.push_str(&rest[..pos]);
            let tail = &rest[pos..];
            if let Some(after) = tail.strip_prefix("{context}") {
                out.push_str(context);
                rest = after;
            } else if let Some(after) = tail.strip_prefix("{query}") {
                out.push_str(query);
                rest = after;
            } else {
                // Not a known placeholder: keep the brace and continue after
                // it. `{` is one byte, so slicing at 1 is a char boundary.
                out.push('{');
                rest = &tail[1..];
            }
        }
        out.push_str(rest);
        out
    }

    /// Returns a snapshot of index statistics gathered from the retriever,
    /// its embedder and its store.
    #[must_use]
    pub fn stats(&self) -> PipelineStats {
        PipelineStats {
            documents_indexed: self.retriever.document_count(),
            chunks_indexed: self.retriever.chunk_count(),
            embedding_dim: self.retriever.embedder().embedding_dim(),
            store_memory_bytes: self.retriever.store().memory_usage_bytes(),
        }
    }

    /// Borrows the underlying retriever.
    #[must_use]
    pub fn retriever(&self) -> &Retriever<E> {
        &self.retriever
    }
}