use crate::{
chunk::{Chunk, Chunker, RecursiveChunker},
embed::{Embedder, MockEmbedder},
fusion::FusionStrategy,
index::{BM25Index, VectorStore},
rerank::{NoOpReranker, Reranker},
retrieve::{HybridRetriever, HybridRetrieverConfig, RetrievalResult},
Document, DocumentId, Error, Result,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Embedding vector width used by [`RagPipelineConfig::default`].
const DEFAULT_EMBEDDING_DIM: usize = 384;
/// Source reference for a chunk placed into an [`AssembledContext`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Citation {
    /// 1-based citation number, assigned in insertion order by
    /// [`AssembledContext::add_citation`].
    pub id: usize,
    /// Document the cited chunk belongs to.
    pub document_id: DocumentId,
    /// Identifier of the cited chunk.
    pub chunk_id: crate::ChunkId,
    /// Title copied from the chunk's metadata, if any.
    pub title: Option<String>,
    /// Source URL. `add_citation` currently always sets this to `None`.
    pub url: Option<String>,
    /// Page number copied from the chunk's metadata, if any.
    pub page: Option<usize>,
}
/// One retrieved chunk selected for inclusion in an [`AssembledContext`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextChunk {
    /// Chunk text, copied from the retrieval result.
    pub content: String,
    /// Citation number for this chunk. `ContextAssembler` uses `0` when
    /// citations are disabled; real citation ids start at 1.
    pub citation_id: usize,
    /// Value of `RetrievalResult::best_score` for this chunk.
    pub retrieval_score: f32,
    /// Copied from `RetrievalResult::rerank_score`; presumably `None` when no
    /// reranker scored the result (TODO confirm).
    pub rerank_score: Option<f32>,
}
/// Retrieved chunks packed under a token budget, together with their citations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssembledContext {
    /// Chunks in the order they were added.
    pub chunks: Vec<ContextChunk>,
    /// Sum of per-chunk token estimates (content byte length / 4, rounded down).
    pub total_tokens: usize,
    /// Citations in ascending id order; empty when no citations were added.
    pub citations: Vec<Citation>,
}
impl AssembledContext {
    /// Creates an empty context: no chunks, no citations, zero tokens.
    #[must_use]
    pub fn new() -> Self {
        Self { chunks: Vec::new(), total_tokens: 0, citations: Vec::new() }
    }

    /// Appends the chunk carried by `result`, tagged with `citation_id`.
    ///
    /// Pass `0` as `citation_id` for an uncited chunk. Ids returned by
    /// [`Self::add_citation`] always start at 1.
    ///
    /// The running token count uses a rough heuristic of one token per four
    /// bytes of content.
    pub fn add_chunk(&mut self, result: &RetrievalResult, citation_id: usize) {
        let content = &result.chunk.content;
        self.total_tokens += content.len() / 4;
        self.chunks.push(ContextChunk {
            content: content.clone(),
            citation_id,
            retrieval_score: result.best_score(),
            rerank_score: result.rerank_score,
        });
    }

    /// Records a citation for `result` and returns its 1-based id.
    pub fn add_citation(&mut self, result: &RetrievalResult) -> usize {
        let id = self.citations.len() + 1;
        let citation = Citation {
            id,
            document_id: result.chunk.document_id,
            chunk_id: result.chunk.id,
            title: result.chunk.metadata.title.clone(),
            url: None,
            page: result.chunk.metadata.page,
        };
        self.citations.push(citation);
        id
    }

    /// Renders chunks separated by blank lines, each followed by ` [id]`.
    ///
    /// Chunks with citation id `0` (added without a citation) are rendered
    /// without a marker.
    #[must_use]
    pub fn format_with_citations(&self) -> String {
        self.chunks
            .iter()
            .map(|c| {
                if c.citation_id == 0 {
                    // `0` means "no citation"; a `[0]` marker would reference a
                    // citation that does not exist.
                    c.content.clone()
                } else {
                    format!("{} [{}]", c.content, c.citation_id)
                }
            })
            .collect::<Vec<_>>()
            .join("\n\n")
    }

    /// Renders chunk contents separated by blank lines, without citation markers.
    #[must_use]
    pub fn format_plain(&self) -> String {
        self.chunks.iter().map(|c| c.content.as_str()).collect::<Vec<_>>().join("\n\n")
    }

    /// Renders one `[id] title` line per citation, using `Untitled` when the
    /// title is missing.
    #[must_use]
    pub fn citation_list(&self) -> String {
        self.citations
            .iter()
            .map(|c| {
                let title = c.title.as_deref().unwrap_or("Untitled");
                format!("[{}] {}", c.id, title)
            })
            .collect::<Vec<_>>()
            .join("\n")
    }

    /// Number of chunks in the context.
    #[must_use]
    pub fn len(&self) -> usize {
        self.chunks.len()
    }

    /// Returns `true` when no chunks have been added.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.chunks.is_empty()
    }
}
impl Default for AssembledContext {
fn default() -> Self {
Self::new()
}
}
/// Ordering strategy used by `ContextAssembler::assemble`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub enum AssemblyStrategy {
    /// Pack results in the order they are given (the default).
    #[default]
    Sequential,
    /// Pack results grouped by their source document.
    DocumentGrouped,
    /// Currently handled exactly like `Sequential` by `ContextAssembler`.
    Interleaved,
}
/// Settings for `ContextAssembler`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextAssemblerConfig {
    /// Token budget, measured with the estimate of content byte length / 4 per chunk.
    pub max_tokens: usize,
    /// How results are ordered before packing.
    pub strategy: AssemblyStrategy,
    /// When `false`, no citations are recorded and chunks get citation id `0`.
    pub include_citations: bool,
}
impl Default for ContextAssemblerConfig {
    /// 4096-token budget, sequential ordering, citations enabled.
    fn default() -> Self {
        let max_tokens = 4096;
        let include_citations = true;
        Self { max_tokens, strategy: AssemblyStrategy::Sequential, include_citations }
    }
}
/// Packs retrieval results into an [`AssembledContext`] under a token budget.
#[derive(Debug, Clone)]
pub struct ContextAssembler {
    /// Budget, ordering strategy, and citation settings.
    config: ContextAssemblerConfig,
}
impl ContextAssembler {
    /// Creates an assembler with the given configuration.
    #[must_use]
    pub fn new(config: ContextAssemblerConfig) -> Self {
        Self { config }
    }

    /// Creates an assembler with a `max_tokens` budget and otherwise default settings.
    #[must_use]
    pub fn with_max_tokens(max_tokens: usize) -> Self {
        Self::new(ContextAssemblerConfig { max_tokens, ..Default::default() })
    }

    /// Packs `results` into an [`AssembledContext`] using the configured strategy.
    #[must_use]
    pub fn assemble(&self, results: &[RetrievalResult]) -> AssembledContext {
        match self.config.strategy {
            AssemblyStrategy::Sequential => self.assemble_sequential(results),
            AssemblyStrategy::DocumentGrouped => self.assemble_grouped(results),
            AssemblyStrategy::Interleaved => self.assemble_interleaved(results),
        }
    }

    /// Rough token estimate: one token per four bytes of chunk content.
    fn estimate_tokens(result: &RetrievalResult) -> usize {
        result.chunk.content.len() / 4
    }

    /// Adds `result` to `context` and charges `remaining_tokens`, if the chunk fits.
    ///
    /// Returns `false` and changes nothing when the chunk exceeds the remaining budget.
    fn try_push(
        &self,
        context: &mut AssembledContext,
        result: &RetrievalResult,
        remaining_tokens: &mut usize,
    ) -> bool {
        let chunk_tokens = Self::estimate_tokens(result);
        if chunk_tokens > *remaining_tokens {
            return false;
        }
        let citation_id =
            if self.config.include_citations { context.add_citation(result) } else { 0 };
        context.add_chunk(result, citation_id);
        // Cannot underflow: checked against the budget above.
        *remaining_tokens -= chunk_tokens;
        true
    }

    /// Packs results in the given order, stopping at the first chunk that does not fit.
    fn assemble_sequential(&self, results: &[RetrievalResult]) -> AssembledContext {
        let mut context = AssembledContext::new();
        let mut remaining_tokens = self.config.max_tokens;
        for result in results {
            if !self.try_push(&mut context, result, &mut remaining_tokens) {
                break;
            }
        }
        context
    }

    /// Packs results grouped by document.
    ///
    /// Documents appear in order of first appearance in `results`, and chunks
    /// keep their relative order within a document. Within one document, packing
    /// stops at the first chunk that does not fit; later documents are still tried.
    fn assemble_grouped(&self, results: &[RetrievalResult]) -> AssembledContext {
        // Iterating a HashMap directly would make document order random between
        // runs, so the map only stores each document's index into `groups`.
        let mut group_index: HashMap<DocumentId, usize> = HashMap::new();
        let mut groups: Vec<Vec<&RetrievalResult>> = Vec::new();
        for result in results {
            let idx = *group_index.entry(result.chunk.document_id).or_insert_with(|| {
                groups.push(Vec::new());
                groups.len() - 1
            });
            groups[idx].push(result);
        }

        let mut context = AssembledContext::new();
        let mut remaining_tokens = self.config.max_tokens;
        for group in groups {
            for result in group {
                if !self.try_push(&mut context, result, &mut remaining_tokens) {
                    break;
                }
            }
        }
        context
    }

    /// Interleaved packing is not implemented yet; falls back to sequential packing.
    fn assemble_interleaved(&self, results: &[RetrievalResult]) -> AssembledContext {
        self.assemble_sequential(results)
    }
}
impl Default for ContextAssembler {
    /// An assembler using `ContextAssemblerConfig::default()`.
    fn default() -> Self {
        let config = ContextAssemblerConfig::default();
        Self { config }
    }
}
/// Aggregate settings for a RAG pipeline.
#[derive(Debug, Clone)]
pub struct RagPipelineConfig {
    /// Target chunk size (units depend on the chunker — TODO confirm).
    pub chunk_size: usize,
    /// Overlap between consecutive chunks, in the same units as `chunk_size`.
    pub chunk_overlap: usize,
    /// Embedding vector width.
    pub embedding_dimension: usize,
    /// Hybrid retriever settings.
    pub retrieval: HybridRetrieverConfig,
    /// Context assembly settings.
    pub context: ContextAssemblerConfig,
}
impl Default for RagPipelineConfig {
    /// 512/50 chunking, `DEFAULT_EMBEDDING_DIM`-wide embeddings, and default
    /// retrieval and context settings.
    fn default() -> Self {
        let retrieval = HybridRetrieverConfig::default();
        let context = ContextAssemblerConfig::default();
        Self {
            chunk_size: 512,
            chunk_overlap: 50,
            embedding_dimension: DEFAULT_EMBEDDING_DIM,
            retrieval,
            context,
        }
    }
}
/// End-to-end pipeline: chunking, embedding, hybrid retrieval, reranking,
/// and context assembly.
pub struct RagPipeline<E: Embedder, R: Reranker> {
    /// Splits documents into chunks.
    chunker: Box<dyn Chunker>,
    /// Embeds chunks at index time (the builder gives `retriever` a clone).
    embedder: E,
    /// Hybrid dense/sparse retriever holding the indexed chunks.
    retriever: HybridRetriever<E>,
    /// Reorders and truncates retrieved candidates.
    reranker: R,
    /// Packs results into a token-bounded context.
    assembler: ContextAssembler,
    /// Number of documents successfully processed by `index_document`.
    document_count: usize,
}
impl<E: Embedder + Clone, R: Reranker> RagPipeline<E, R> {
    /// Chunks, embeds, and indexes `document`, returning the embedded chunks.
    ///
    /// # Errors
    ///
    /// Propagates chunking, embedding, or indexing errors. Chunks indexed before
    /// a failing `index` call remain indexed, and the document is not counted.
    pub fn index_document(&mut self, document: &Document) -> Result<Vec<Chunk>> {
        let mut chunks = self.chunker.chunk(document)?;
        self.embedder.embed_chunks(&mut chunks)?;
        for chunk in &chunks {
            // The retriever takes ownership, but the caller also gets the chunks back.
            self.retriever.index(chunk.clone())?;
        }
        self.document_count += 1;
        Ok(chunks)
    }

    /// Indexes each document in order and returns the total number of chunks created.
    ///
    /// # Errors
    ///
    /// Stops at the first failing document; earlier documents stay indexed.
    pub fn index_documents(&mut self, documents: &[Document]) -> Result<usize> {
        let mut total_chunks = 0;
        for doc in documents {
            total_chunks += self.index_document(doc)?.len();
        }
        Ok(total_chunks)
    }

    /// Number of documents indexed so far.
    #[must_use]
    pub fn document_count(&self) -> usize {
        self.document_count
    }

    /// Number of chunks held by the retriever.
    #[must_use]
    pub fn chunk_count(&self) -> usize {
        self.retriever.len()
    }

    /// Retrieves `2 * k` candidates for `query` and reranks them down to `k`.
    ///
    /// # Errors
    ///
    /// Propagates retrieval or reranking errors.
    pub fn query(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
        // Over-fetch so the reranker has candidates to reorder. Saturate because a
        // huge `k` would otherwise panic in debug builds or wrap in release builds.
        let candidates = self.retriever.retrieve(query, k.saturating_mul(2))?;
        let reranked = self.reranker.rerank(query, &candidates, k)?;
        Ok(reranked)
    }

    /// Runs [`Self::query`] and assembles the results into a context.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::query`].
    pub fn query_with_context(
        &self,
        query: &str,
        k: usize,
    ) -> Result<(Vec<RetrievalResult>, AssembledContext)> {
        let results = self.query(query, k)?;
        let context = self.assembler.assemble(&results);
        Ok((results, context))
    }

    /// The pipeline's context assembler.
    #[must_use]
    pub fn assembler(&self) -> &ContextAssembler {
        &self.assembler
    }

    /// Assembles arbitrary `results` using the pipeline's assembler.
    #[must_use]
    pub fn assemble_context(&self, results: &[RetrievalResult]) -> AssembledContext {
        self.assembler.assemble(results)
    }

    /// The pipeline's chunker.
    #[must_use]
    pub fn chunker(&self) -> &dyn Chunker {
        self.chunker.as_ref()
    }

    /// The pipeline's embedder.
    #[must_use]
    pub fn embedder(&self) -> &E {
        &self.embedder
    }
}
/// Builder for [`RagPipeline`]. Components left unset are defaulted in `build`,
/// except the embedder and reranker, which are required.
pub struct RagPipelineBuilder<E: Embedder, R: Reranker> {
    /// Custom chunker; `build` falls back to `RecursiveChunker::new(512, 50)`.
    chunker: Option<Box<dyn Chunker>>,
    /// Required embedder.
    embedder: Option<E>,
    /// Custom dense store; `build` falls back to one sized by the embedder's dimension.
    vector_store: Option<VectorStore>,
    /// Custom sparse index; `build` falls back to `BM25Index::default()`.
    sparse_index: Option<BM25Index>,
    /// Required reranker.
    reranker: Option<R>,
    /// Fusion strategy passed to the retriever config.
    fusion: FusionStrategy,
    /// Settings for the pipeline's context assembler.
    assembler_config: ContextAssemblerConfig,
}
impl<E: Embedder + Clone, R: Reranker> RagPipelineBuilder<E, R> {
#[must_use]
pub fn new() -> Self {
Self {
chunker: None,
embedder: None,
vector_store: None,
sparse_index: None,
reranker: None,
fusion: FusionStrategy::default(),
assembler_config: ContextAssemblerConfig::default(),
}
}
#[must_use]
pub fn chunker(mut self, chunker: impl Chunker + 'static) -> Self {
self.chunker = Some(Box::new(chunker));
self
}
#[must_use]
pub fn embedder(mut self, embedder: E) -> Self {
self.embedder = Some(embedder);
self
}
#[must_use]
pub fn vector_store(mut self, store: VectorStore) -> Self {
self.vector_store = Some(store);
self
}
#[must_use]
pub fn sparse_index(mut self, index: BM25Index) -> Self {
self.sparse_index = Some(index);
self
}
#[must_use]
pub fn reranker(mut self, reranker: R) -> Self {
self.reranker = Some(reranker);
self
}
#[must_use]
pub fn fusion(mut self, fusion: FusionStrategy) -> Self {
self.fusion = fusion;
self
}
#[must_use]
pub fn max_context_tokens(mut self, max_tokens: usize) -> Self {
self.assembler_config.max_tokens = max_tokens;
self
}
pub fn build(self) -> Result<RagPipeline<E, R>> {
let embedder =
self.embedder.ok_or_else(|| Error::InvalidConfig("embedder required".to_string()))?;
let reranker =
self.reranker.ok_or_else(|| Error::InvalidConfig("reranker required".to_string()))?;
let chunker = self.chunker.unwrap_or_else(|| Box::new(RecursiveChunker::new(512, 50)));
let vector_store =
self.vector_store.unwrap_or_else(|| VectorStore::with_dimension(embedder.dimension()));
let sparse_index = self.sparse_index.unwrap_or_default();
let retrieval_config = HybridRetrieverConfig { fusion: self.fusion, ..Default::default() };
let retriever = HybridRetriever::new(vector_store, sparse_index, embedder.clone())
.with_config(retrieval_config);
let assembler = ContextAssembler::new(self.assembler_config);
Ok(RagPipeline { chunker, embedder, retriever, reranker, assembler, document_count: 0 })
}
}
impl<E: Embedder + Clone, R: Reranker> Default for RagPipelineBuilder<E, R> {
fn default() -> Self {
Self::new()
}
}
/// Returns a builder whose type parameters are fixed to `MockEmbedder` and
/// `NoOpReranker`.
///
/// The embedder and reranker instances must still be supplied before `build`.
#[must_use]
pub fn pipeline_builder() -> RagPipelineBuilder<MockEmbedder, NoOpReranker> {
    RagPipelineBuilder::<MockEmbedder, NoOpReranker>::new()
}
// Unit tests are declared in a separate module file and compiled only for test builds.
#[cfg(test)]
mod tests;