Skip to main content

adk_rag/
pipeline.rs

//! RAG pipeline orchestrator.
//!
//! The [`RagPipeline`] coordinates the full ingest-and-query workflow by
//! composing an [`EmbeddingProvider`], a [`VectorStore`], a [`Chunker`],
//! and an optional [`Reranker`].
//!
//! # Example
//!
//! ```rust,ignore
//! use std::sync::Arc;
//!
//! use adk_rag::{RagPipeline, RagConfig, InMemoryVectorStore, FixedSizeChunker};
//!
//! let pipeline = RagPipeline::builder()
//!     .config(RagConfig::default())
//!     .embedding_provider(Arc::new(my_embedder))
//!     .vector_store(Arc::new(InMemoryVectorStore::new()))
//!     .chunker(Arc::new(FixedSizeChunker::new(512, 100)))
//!     .build()?;
//!
//! pipeline.create_collection("docs").await?;
//! pipeline.ingest("docs", &document).await?;
//! let results = pipeline.query("docs", "search query").await?;
//! ```
23
24use std::sync::Arc;
25
26use tracing::{error, info};
27
28use crate::chunking::Chunker;
29use crate::config::RagConfig;
30use crate::document::{Chunk, Document, SearchResult};
31use crate::embedding::EmbeddingProvider;
32use crate::error::{RagError, Result};
33use crate::reranker::Reranker;
34use crate::vectorstore::VectorStore;
35
36/// The RAG pipeline orchestrator.
37///
38/// Coordinates document ingestion (chunk → embed → store) and query
39/// execution (embed → search → rerank → filter). Construct one via
40/// [`RagPipeline::builder()`].
41pub struct RagPipeline {
42    config: RagConfig,
43    embedding_provider: Arc<dyn EmbeddingProvider>,
44    vector_store: Arc<dyn VectorStore>,
45    chunker: Arc<dyn Chunker>,
46    reranker: Option<Arc<dyn Reranker>>,
47}
48
49impl RagPipeline {
50    /// Create a new [`RagPipelineBuilder`].
51    pub fn builder() -> RagPipelineBuilder {
52        RagPipelineBuilder::default()
53    }
54
55    /// Return a reference to the pipeline configuration.
56    pub fn config(&self) -> &RagConfig {
57        &self.config
58    }
59
60    /// Return a reference to the embedding provider.
61    pub fn embedding_provider(&self) -> &Arc<dyn EmbeddingProvider> {
62        &self.embedding_provider
63    }
64
65    /// Return a reference to the vector store.
66    pub fn vector_store(&self) -> &Arc<dyn VectorStore> {
67        &self.vector_store
68    }
69
70    /// Create a named collection in the vector store.
71    ///
72    /// The collection is created with the dimensionality reported by the
73    /// configured [`EmbeddingProvider`].
74    ///
75    /// # Errors
76    ///
77    /// Returns [`RagError::PipelineError`] if the vector store operation fails.
78    pub async fn create_collection(&self, name: &str) -> Result<()> {
79        let dimensions = self.embedding_provider.dimensions();
80        self.vector_store.create_collection(name, dimensions).await.map_err(|e| {
81            error!(collection = name, error = %e, "failed to create collection");
82            RagError::PipelineError(format!("failed to create collection '{name}': {e}"))
83        })
84    }
85
86    /// Delete a named collection from the vector store.
87    ///
88    /// # Errors
89    ///
90    /// Returns [`RagError::PipelineError`] if the vector store operation fails.
91    pub async fn delete_collection(&self, name: &str) -> Result<()> {
92        self.vector_store.delete_collection(name).await.map_err(|e| {
93            error!(collection = name, error = %e, "failed to delete collection");
94            RagError::PipelineError(format!("failed to delete collection '{name}': {e}"))
95        })
96    }
97
98    /// Ingest a single document: chunk → embed → store.
99    ///
100    /// Returns the chunks that were stored (with embeddings attached).
101    ///
102    /// # Errors
103    ///
104    /// Returns [`RagError::PipelineError`] if embedding or storage fails,
105    /// including the document ID in the error message.
106    pub async fn ingest(&self, collection: &str, document: &Document) -> Result<Vec<Chunk>> {
107        // 1. Chunk the document
108        let mut chunks = self.chunker.chunk(document);
109        if chunks.is_empty() {
110            info!(document.id = %document.id, chunk_count = 0, "ingested document (empty)");
111            return Ok(chunks);
112        }
113
114        // 2. Collect chunk texts for batch embedding
115        let texts: Vec<&str> = chunks.iter().map(|c| c.text.as_str()).collect();
116
117        // 3. Generate embeddings
118        let embeddings = self.embedding_provider.embed_batch(&texts).await.map_err(|e| {
119            error!(document.id = %document.id, error = %e, "embedding failed during ingestion");
120            RagError::PipelineError(format!("embedding failed for document '{}': {e}", document.id))
121        })?;
122
123        // 4. Attach embeddings to chunks
124        for (chunk, embedding) in chunks.iter_mut().zip(embeddings) {
125            chunk.embedding = embedding;
126        }
127
128        // 5. Upsert into vector store
129        self.vector_store.upsert(collection, &chunks).await.map_err(|e| {
130            error!(document.id = %document.id, error = %e, "upsert failed during ingestion");
131            RagError::PipelineError(format!("upsert failed for document '{}': {e}", document.id))
132        })?;
133
134        let chunk_count = chunks.len();
135        info!(document.id = %document.id, chunk_count, "ingested document");
136
137        Ok(chunks)
138    }
139
140    /// Ingest multiple documents through the chunk → embed → store workflow.
141    ///
142    /// Returns all chunks that were stored across all documents.
143    ///
144    /// # Errors
145    ///
146    /// Returns [`RagError::PipelineError`] on the first document that fails,
147    /// including the document ID in the error message.
148    pub async fn ingest_batch(
149        &self,
150        collection: &str,
151        documents: &[Document],
152    ) -> Result<Vec<Chunk>> {
153        let mut all_chunks = Vec::new();
154        for document in documents {
155            let chunks = self.ingest(collection, document).await?;
156            all_chunks.extend(chunks);
157        }
158        Ok(all_chunks)
159    }
160
161    /// Query the pipeline: embed → search → rerank → filter by threshold.
162    ///
163    /// Returns search results ordered by descending relevance score. Results
164    /// below the configured `similarity_threshold` are filtered out.
165    ///
166    /// # Errors
167    ///
168    /// Returns [`RagError::PipelineError`] if embedding or search fails.
169    pub async fn query(&self, collection: &str, query: &str) -> Result<Vec<SearchResult>> {
170        // 1. Embed the query
171        let query_embedding = self.embedding_provider.embed(query).await.map_err(|e| {
172            error!(error = %e, "embedding failed during query");
173            RagError::PipelineError(format!("query embedding failed: {e}"))
174        })?;
175
176        // 2. Search the vector store
177        let results = self
178            .vector_store
179            .search(collection, &query_embedding, self.config.top_k)
180            .await
181            .map_err(|e| {
182                error!(collection, error = %e, "vector store search failed");
183                RagError::PipelineError(format!("search failed in collection '{collection}': {e}"))
184            })?;
185
186        // 3. Rerank if a reranker is configured
187        let results = if let Some(reranker) = &self.reranker {
188            reranker.rerank(query, results).await.map_err(|e| {
189                error!(error = %e, "reranking failed");
190                RagError::PipelineError(format!("reranking failed: {e}"))
191            })?
192        } else {
193            results
194        };
195
196        // 4. Filter by similarity threshold
197        let threshold = self.config.similarity_threshold;
198        let filtered: Vec<SearchResult> =
199            results.into_iter().filter(|r| r.score >= threshold).collect();
200
201        info!(result_count = filtered.len(), "query completed");
202
203        Ok(filtered)
204    }
205}
206
207/// Builder for constructing a [`RagPipeline`].
208///
209/// All fields except `reranker` are required. Call [`build()`](RagPipelineBuilder::build)
210/// to validate and produce the pipeline.
211///
212/// # Example
213///
214/// ```rust,ignore
215/// let pipeline = RagPipeline::builder()
216///     .config(RagConfig::default())
217///     .embedding_provider(Arc::new(embedder))
218///     .vector_store(Arc::new(store))
219///     .chunker(Arc::new(chunker))
220///     .reranker(Arc::new(reranker))  // optional
221///     .build()?;
222/// ```
223#[derive(Default)]
224pub struct RagPipelineBuilder {
225    config: Option<RagConfig>,
226    embedding_provider: Option<Arc<dyn EmbeddingProvider>>,
227    vector_store: Option<Arc<dyn VectorStore>>,
228    chunker: Option<Arc<dyn Chunker>>,
229    reranker: Option<Arc<dyn Reranker>>,
230}
231
232impl RagPipelineBuilder {
233    /// Set the pipeline configuration.
234    pub fn config(mut self, config: RagConfig) -> Self {
235        self.config = Some(config);
236        self
237    }
238
239    /// Set the embedding provider.
240    pub fn embedding_provider(mut self, provider: Arc<dyn EmbeddingProvider>) -> Self {
241        self.embedding_provider = Some(provider);
242        self
243    }
244
245    /// Set the vector store backend.
246    pub fn vector_store(mut self, store: Arc<dyn VectorStore>) -> Self {
247        self.vector_store = Some(store);
248        self
249    }
250
251    /// Set the document chunker.
252    pub fn chunker(mut self, chunker: Arc<dyn Chunker>) -> Self {
253        self.chunker = Some(chunker);
254        self
255    }
256
257    /// Set an optional reranker for post-search result reordering.
258    pub fn reranker(mut self, reranker: Arc<dyn Reranker>) -> Self {
259        self.reranker = Some(reranker);
260        self
261    }
262
263    /// Build the [`RagPipeline`], validating that all required fields are set.
264    ///
265    /// # Errors
266    ///
267    /// Returns [`RagError::ConfigError`] if any required field is missing.
268    pub fn build(self) -> Result<RagPipeline> {
269        let config =
270            self.config.ok_or_else(|| RagError::ConfigError("config is required".to_string()))?;
271        let embedding_provider = self
272            .embedding_provider
273            .ok_or_else(|| RagError::ConfigError("embedding_provider is required".to_string()))?;
274        let vector_store = self
275            .vector_store
276            .ok_or_else(|| RagError::ConfigError("vector_store is required".to_string()))?;
277        let chunker =
278            self.chunker.ok_or_else(|| RagError::ConfigError("chunker is required".to_string()))?;
279
280        Ok(RagPipeline {
281            config,
282            embedding_provider,
283            vector_store,
284            chunker,
285            reranker: self.reranker,
286        })
287    }
288}