use super::index::{EmbeddedDocument, VectorIndex, VectorSearchResult};
use crate::chunking::{chunk, ChunkingStrategy};
use crate::error::ForgeError;
use crate::types::EmbeddingRequest;
use crate::AsyncForgeClient;
/// Settings that control how a [`RagPipeline`] chunks, embeds and retrieves text.
#[derive(Debug, Clone)]
pub struct RagConfig {
    /// Model name sent with every embedding request.
    pub embedding_model: String,
    /// Strategy handed to `chunk` when a document is split.
    pub chunking_strategy: ChunkingStrategy,
    /// Size argument handed to `chunk`.
    // NOTE(review): the unit (characters, tokens, ...) is defined by the chunking module — confirm.
    pub chunk_size: usize,
    /// Overlap argument handed to `chunk`.
    pub chunk_overlap: usize,
    /// Result count requested from the vector index on each search.
    pub top_k: usize,
    /// When `Some`, searches go through the index's threshold search with this value;
    /// when `None`, the plain search is used.
    pub min_score: Option<f32>,
}
impl Default for RagConfig {
fn default() -> Self {
Self {
embedding_model: "text-embedding-ada-002".to_string(),
chunking_strategy: ChunkingStrategy::Sentence,
chunk_size: 512,
chunk_overlap: 50,
top_k: 5,
min_score: None,
}
}
}
/// One chunk returned by [`RagPipeline::retrieve`].
#[derive(Debug, Clone)]
pub struct RetrievalResult {
    /// Text content of the matched index entry.
    pub content: String,
    /// Value of the entry's `"source_id"` metadata, or the entry's own id when that
    /// key is missing or not a string.
    pub document_id: String,
    /// Value of the entry's `"chunk_index"` metadata, or `0` when that key is missing
    /// or not an unsigned integer.
    pub chunk_index: usize,
    /// Score reported by the vector index for this match.
    pub score: f32,
    /// Full metadata map stored with the matched entry.
    pub metadata: std::collections::HashMap<String, serde_json::Value>,
}
/// Step-by-step constructor for a [`RagPipeline`].
pub struct RagPipelineBuilder {
    /// Configuration accumulated by the setter methods.
    config: RagConfig,
    /// Client used for embedding requests; `build` fails while this is `None`.
    client: Option<AsyncForgeClient>,
}
impl RagPipelineBuilder {
pub fn new() -> Self {
Self {
config: RagConfig::default(),
client: None,
}
}
pub fn client(mut self, client: AsyncForgeClient) -> Self {
self.client = Some(client);
self
}
pub fn embedding_model(mut self, model: impl Into<String>) -> Self {
self.config.embedding_model = model.into();
self
}
pub fn chunking_strategy(mut self, strategy: ChunkingStrategy) -> Self {
self.config.chunking_strategy = strategy;
self
}
pub fn chunk_size(mut self, size: usize) -> Self {
self.config.chunk_size = size;
self
}
pub fn chunk_overlap(mut self, overlap: usize) -> Self {
self.config.chunk_overlap = overlap;
self
}
pub fn top_k(mut self, k: usize) -> Self {
self.config.top_k = k;
self
}
pub fn min_score(mut self, score: f32) -> Self {
self.config.min_score = Some(score);
self
}
pub fn build(self) -> Result<RagPipeline, ForgeError> {
let client = self
.client
.ok_or_else(|| ForgeError::config("RAG pipeline requires a LiteForge client for embeddings"))?;
Ok(RagPipeline {
config: self.config,
client,
index: VectorIndex::new(),
})
}
}
impl Default for RagPipelineBuilder {
fn default() -> Self {
Self::new()
}
}
/// Retrieval pipeline that chunks documents, embeds the chunks through a client and
/// stores them in a [`VectorIndex`] owned by the pipeline.
pub struct RagPipeline {
    /// Chunking, embedding and search settings.
    config: RagConfig,
    /// Client used to issue embedding requests.
    client: AsyncForgeClient,
    /// Index holding every embedded chunk added to this pipeline.
    index: VectorIndex,
}
impl RagPipeline {
    /// Returns a [`RagPipelineBuilder`] for configuring a new pipeline.
    pub fn builder() -> RagPipelineBuilder {
        RagPipelineBuilder::new()
    }

    /// Chunks, embeds and indexes `content` under `id` without extra metadata.
    ///
    /// Shorthand for [`Self::index_document_with_metadata`] with an empty metadata map.
    ///
    /// # Errors
    ///
    /// Propagates any error from the embedding request.
    pub async fn index_document(&mut self, id: &str, content: &str) -> Result<usize, ForgeError> {
        self.index_document_with_metadata(id, content, std::collections::HashMap::new())
            .await
    }

    /// Chunks, embeds and indexes `content` under `id`, attaching `metadata` to every chunk.
    ///
    /// Each chunk is stored with the id `"{id}#{n}"`, where `n` is the chunk's position.
    /// The keys `"source_id"` and `"chunk_index"` are written into each chunk's metadata
    /// and overwrite caller-supplied entries with the same names.
    ///
    /// Returns the number of chunks indexed; `0` when chunking produced nothing (no
    /// embedding request is made in that case).
    ///
    /// # Errors
    ///
    /// Returns an error when the embedding request fails or when the number of returned
    /// embeddings differs from the number of chunks. Nothing is added to the index then.
    // NOTE(review): chunks left over from an earlier call with the same `id` are not
    // removed here; callers re-indexing a document may want `remove_document` first.
    pub async fn index_document_with_metadata(
        &mut self,
        id: &str,
        content: &str,
        metadata: std::collections::HashMap<String, serde_json::Value>,
    ) -> Result<usize, ForgeError> {
        let chunks = chunk(
            content,
            self.config.chunk_size,
            self.config.chunk_overlap,
            self.config.chunking_strategy,
        );
        if chunks.is_empty() {
            return Ok(0);
        }

        let chunk_texts: Vec<String> = chunks.iter().map(|c| c.text.clone()).collect();
        // `embed_batch` guarantees one embedding per chunk, so the `zip` below cannot
        // silently drop chunks and the returned count matches what was indexed.
        let embeddings = self.embed_batch(&chunk_texts).await?;

        for (i, (chunk_obj, embedding)) in chunks.iter().zip(embeddings).enumerate() {
            let chunk_id = format!("{}#{}", id, i);
            let mut doc = EmbeddedDocument::new(&chunk_id, &chunk_obj.text, embedding);
            doc.metadata = metadata.clone();
            doc.metadata
                .insert("source_id".to_string(), serde_json::json!(id));
            doc.metadata
                .insert("chunk_index".to_string(), serde_json::json!(i));
            self.index.add(doc);
        }
        Ok(chunks.len())
    }

    /// Removes every chunk that was indexed for document `id` and returns how many
    /// entries were removed.
    ///
    /// An entry belongs to `id` when its id is exactly `"{id}#<digits>"`, the form
    /// produced by [`Self::index_document_with_metadata`]. Requiring a numeric suffix
    /// keeps a document whose own id merely starts with `"{id}#"` (for example `"a#b"`
    /// when removing `"a"`) from being deleted by accident.
    pub fn remove_document(&mut self, id: &str) -> usize {
        let prefix = format!("{}#", id);
        let ids_to_remove: Vec<String> = self
            .index
            .ids()
            .iter()
            .filter(|doc_id| {
                doc_id
                    .strip_prefix(prefix.as_str())
                    .map_or(false, |suffix| {
                        !suffix.is_empty() && suffix.bytes().all(|b| b.is_ascii_digit())
                    })
            })
            .map(|s| s.to_string())
            .collect();

        let count = ids_to_remove.len();
        for doc_id in ids_to_remove {
            self.index.remove(&doc_id);
        }
        count
    }

    /// Embeds `query` and returns the matching chunks from the index.
    ///
    /// The search uses the configured `top_k` and, when set, `min_score`
    /// (see [`Self::search_with_embedding`]).
    ///
    /// # Errors
    ///
    /// Returns an error when the embedding request fails or yields no embedding.
    pub async fn retrieve(&self, query: &str) -> Result<Vec<RetrievalResult>, ForgeError> {
        let query_embedding = self.embed(query).await?;
        let results = self.search_with_embedding(&query_embedding);

        Ok(results
            .into_iter()
            .map(|r| {
                // Prefer the id recorded at indexing time; fall back to the entry's own
                // id for entries added without a `"source_id"` (e.g. via `add_embedded`).
                let document_id = r
                    .document
                    .metadata
                    .get("source_id")
                    .and_then(|v| v.as_str())
                    .unwrap_or(&r.document.id)
                    .to_string();
                let chunk_index = r
                    .document
                    .metadata
                    .get("chunk_index")
                    .and_then(|v| v.as_u64())
                    .unwrap_or(0) as usize;
                RetrievalResult {
                    content: r.document.content,
                    document_id,
                    chunk_index,
                    score: r.score,
                    metadata: r.document.metadata,
                }
            })
            .collect())
    }

    /// Requests a single embedding for `text` using the configured model.
    ///
    /// # Errors
    ///
    /// Returns an error when the request fails or the response carries no embedding.
    async fn embed(&self, text: &str) -> Result<Vec<f32>, ForgeError> {
        let request = EmbeddingRequest::new(&self.config.embedding_model, text);
        let response = self.client.embeddings(request).await?;
        response
            .embedding()
            .map(|e| e.to_vec())
            .ok_or_else(|| ForgeError::internal("No embedding returned"))
    }

    /// Requests embeddings for all `texts` in one call using the configured model.
    ///
    /// Returns one embedding per input text. An empty input returns an empty vector
    /// without issuing a request.
    ///
    /// # Errors
    ///
    /// Returns an error when the request fails or when the response does not contain
    /// exactly one embedding per input text.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, ForgeError> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        let request = EmbeddingRequest::batch(&self.config.embedding_model, texts.to_vec());
        let response = self.client.embeddings(request).await?;
        let embeddings: Vec<Vec<f32>> = response
            .embeddings()
            .into_iter()
            .map(|e| e.to_vec())
            .collect();

        // A short or long response would otherwise pair texts with the wrong (or no)
        // embedding further up the call chain.
        if embeddings.len() != texts.len() {
            return Err(ForgeError::internal(
                "Embedding response count does not match the number of input texts",
            ));
        }
        Ok(embeddings)
    }

    /// Number of entries currently held by the index.
    pub fn chunk_count(&self) -> usize {
        self.index.len()
    }

    /// Removes every entry from the index.
    pub fn clear(&mut self) {
        self.index.clear();
    }

    /// Read-only access to the pipeline configuration.
    pub fn config(&self) -> &RagConfig {
        &self.config
    }

    /// Adds already-embedded documents to the index without chunking or embedding them.
    pub fn add_embedded(&mut self, documents: Vec<EmbeddedDocument>) {
        self.index.add_batch(documents);
    }

    /// Searches the index with a caller-supplied embedding.
    ///
    /// Requests `top_k` results; when `min_score` is configured the index's threshold
    /// search is used with that value, otherwise the plain search.
    pub fn search_with_embedding(&self, embedding: &[f32]) -> Vec<VectorSearchResult> {
        if let Some(min_score) = self.config.min_score {
            self.index
                .search_with_threshold(embedding, self.config.top_k, min_score)
        } else {
            self.index.search(embedding, self.config.top_k)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration requests five results and uses a chunk size of 512.
    #[test]
    fn test_rag_config_default() {
        let RagConfig {
            top_k, chunk_size, ..
        } = RagConfig::default();
        assert_eq!(top_k, 5);
        assert_eq!(chunk_size, 512);
    }

    /// Building without a client must fail.
    #[test]
    fn test_builder_requires_client() {
        let outcome = RagPipelineBuilder::new().build();
        assert!(outcome.is_err());
    }
}