use crate::{
Document, DocumentChunk, DocumentChunker, Embedding, EmbeddingService, RetrievalService,
RragError, RragResult, SearchResult, StorageService,
};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
/// Mutable state threaded through one pipeline run: the payload being
/// transformed plus bookkeeping about how it got there.
#[derive(Debug, Clone)]
pub struct PipelineContext {
    /// Unique id for this run (UUID v4, generated at construction).
    pub execution_id: String,
    /// Current payload; each step consumes and replaces this.
    pub data: PipelineData,
    /// Free-form metadata attached by the caller or by steps.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Per-step execution records, appended in run order.
    pub execution_history: Vec<StepExecution>,
    /// Configuration the run was started with.
    pub config: PipelineConfig,
}
/// The payload flowing between pipeline steps. Each step declares (via
/// `input_types`/`output_type`) which variants it accepts and produces.
#[derive(Debug, Clone)]
pub enum PipelineData {
    /// Raw text.
    Text(String),
    /// A single document.
    Document(Document),
    /// A batch of documents.
    Documents(Vec<Document>),
    /// Chunks produced by a chunking step.
    Chunks(Vec<DocumentChunk>),
    /// Embedding vectors produced by an embedding step.
    Embeddings(Vec<Embedding>),
    /// Results returned by a retrieval step.
    SearchResults(Vec<SearchResult>),
    /// Arbitrary JSON payload for custom steps.
    Json(serde_json::Value),
}
/// Execution settings for a [`Pipeline`] run.
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    /// Wall-clock budget in seconds (compared against elapsed time
    /// before each step in `Pipeline::execute`).
    pub max_execution_time: u64,
    /// When true, step failures are recorded and execution proceeds;
    /// when false, the first failure aborts the run.
    pub continue_on_error: bool,
    /// Whether parallel step execution is allowed.
    /// NOTE(review): `Pipeline::execute` currently runs steps strictly
    /// sequentially, so this flag is not yet consulted.
    pub enable_parallelism: bool,
    /// Upper bound on concurrently running steps (see note above).
    pub max_parallel_steps: usize,
    /// Whether step-level caching is enabled (not consulted in this file).
    pub enable_caching: bool,
    /// Extra implementation-specific settings.
    pub custom_config: HashMap<String, serde_json::Value>,
}
impl Default for PipelineConfig {
fn default() -> Self {
Self {
max_execution_time: 300, continue_on_error: false,
enable_parallelism: true,
max_parallel_steps: 4,
enable_caching: false,
custom_config: HashMap::new(),
}
}
}
/// Record of a single step's execution, stored in
/// `PipelineContext::execution_history`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepExecution {
    /// The step's `name()` at the time it ran.
    pub step_id: String,
    /// UTC timestamp taken just before the step started.
    pub start_time: chrono::DateTime<chrono::Utc>,
    /// How long the step ran, in milliseconds (0 for skipped/failed
    /// validation entries recorded by the pipeline itself).
    pub duration_ms: u64,
    /// Whether the step completed without error.
    pub success: bool,
    /// Error text when `success` is false.
    pub error_message: Option<String>,
    /// Step-specific metadata (currently always empty in this file).
    pub metadata: HashMap<String, serde_json::Value>,
}
impl PipelineContext {
    /// Creates a context with a fresh UUID execution id and default config.
    pub fn new(data: PipelineData) -> Self {
        // Delegate instead of duplicating the constructor body.
        Self::with_config(data, PipelineConfig::default())
    }

    /// Creates a context with a fresh UUID execution id and the given config.
    pub fn with_config(data: PipelineData, config: PipelineConfig) -> Self {
        Self {
            execution_id: uuid::Uuid::new_v4().to_string(),
            data,
            metadata: HashMap::new(),
            execution_history: Vec::new(),
            config,
        }
    }

    /// Builder-style helper that attaches one metadata entry.
    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.metadata.insert(key.into(), value);
        self
    }

    /// Appends a step record to the execution history.
    pub fn record_step(&mut self, step_execution: StepExecution) {
        self.execution_history.push(step_execution);
    }

    /// Sum of recorded step durations in milliseconds.
    ///
    /// Note: this is the sum of per-step times, not wall-clock time, so
    /// it would overcount if steps ever ran in parallel.
    pub fn total_execution_time(&self) -> u64 {
        self.execution_history
            .iter()
            .map(|step| step.duration_ms)
            .sum()
    }

    /// True if any recorded step failed.
    pub fn has_failures(&self) -> bool {
        self.execution_history.iter().any(|step| !step.success)
    }
}
/// A single stage in a [`Pipeline`]. Implementations transform the
/// context's `PipelineData` payload and record their own execution stats.
#[async_trait]
pub trait PipelineStep: Send + Sync {
    /// Stable identifier, used as `step_id` in execution records.
    fn name(&self) -> &str;
    /// Human-readable summary of what the step does.
    fn description(&self) -> &str;
    /// Names of the `PipelineData` variants this step accepts
    /// (informational; surfaced via `Pipeline::get_step_info`).
    fn input_types(&self) -> Vec<&'static str>;
    /// Name of the `PipelineData` variant this step produces
    /// (informational; surfaced via `Pipeline::get_step_info`).
    fn output_type(&self) -> &'static str;
    /// Consumes the context, transforms `context.data`, appends a
    /// `StepExecution` record, and returns the updated context.
    async fn execute(&self, context: PipelineContext) -> RragResult<PipelineContext>;
    /// Pre-execution input check; the default accepts anything.
    fn validate_input(&self, _data: &PipelineData) -> RragResult<()> {
        Ok(())
    }
    /// Whether this step may run concurrently with others.
    /// NOTE(review): `Pipeline::execute` currently runs steps strictly
    /// sequentially, so this flag is informational only.
    fn is_parallelizable(&self) -> bool {
        true
    }
    /// Names of steps that must run before this one; default: none.
    fn dependencies(&self) -> Vec<&str> {
        Vec::new()
    }
}
/// Pipeline step that normalizes raw text via an ordered list of operations.
pub struct TextPreprocessingStep {
    operations: Vec<TextOperation>,
}

/// One text-normalization operation; operations run in declaration order.
#[derive(Debug, Clone)]
pub enum TextOperation {
    /// Lowercase the whole string.
    ToLowercase,
    /// Collapse all whitespace runs to single spaces and trim the ends.
    NormalizeWhitespace,
    /// Keep only alphanumeric and whitespace characters.
    RemoveSpecialChars,
    /// NOTE(review): despite the name, this is implemented as a literal
    /// substring replacement (`str::replace`), not a regex match —
    /// confirm intent or pull in the `regex` crate.
    RegexReplace {
        pattern: String,
        replacement: String,
    },
}

impl TextPreprocessingStep {
    /// Creates a step that applies `operations` left to right.
    pub fn new(operations: Vec<TextOperation>) -> Self {
        Self { operations }
    }

    /// Feeds `text` through every configured operation, each one
    /// consuming the previous operation's output.
    fn process_text(&self, text: &str) -> String {
        self.operations.iter().fold(text.to_string(), |acc, op| match op {
            TextOperation::ToLowercase => acc.to_lowercase(),
            TextOperation::NormalizeWhitespace => {
                acc.split_whitespace().collect::<Vec<_>>().join(" ")
            }
            TextOperation::RemoveSpecialChars => acc
                .chars()
                .filter(|c| c.is_alphanumeric() || c.is_whitespace())
                .collect(),
            // Literal (non-regex) replacement; see NOTE on the variant.
            TextOperation::RegexReplace {
                pattern,
                replacement,
            } => acc.replace(pattern, replacement),
        })
    }
}
#[async_trait]
impl PipelineStep for TextPreprocessingStep {
fn name(&self) -> &str {
"text_preprocessing"
}
fn description(&self) -> &str {
"Preprocesses text data with various normalization operations"
}
fn input_types(&self) -> Vec<&'static str> {
vec!["Text", "Document", "Documents"]
}
fn output_type(&self) -> &'static str {
"Text|Document|Documents"
}
async fn execute(&self, mut context: PipelineContext) -> RragResult<PipelineContext> {
let start_time = Instant::now();
let step_start = chrono::Utc::now();
let processed_data = match &context.data {
PipelineData::Text(text) => PipelineData::Text(self.process_text(text)),
PipelineData::Document(doc) => {
let processed_content = self.process_text(doc.content_str());
let mut new_doc = Document::new(processed_content);
new_doc.id = doc.id.clone();
new_doc.metadata = doc.metadata.clone();
new_doc.content_hash = doc.content_hash.clone();
new_doc.created_at = doc.created_at;
PipelineData::Document(new_doc)
}
PipelineData::Documents(docs) => {
let processed_docs: Vec<Document> = docs
.iter()
.map(|doc| {
let processed_content = self.process_text(doc.content_str());
let mut new_doc = Document::new(processed_content);
new_doc.id = doc.id.clone();
new_doc.metadata = doc.metadata.clone();
new_doc.content_hash = doc.content_hash.clone();
new_doc.created_at = doc.created_at;
new_doc
})
.collect();
PipelineData::Documents(processed_docs)
}
_ => {
let error = "Input must be Text, Document, or Documents";
context.record_step(StepExecution {
step_id: self.name().to_string(),
start_time: step_start,
duration_ms: start_time.elapsed().as_millis() as u64,
success: false,
error_message: Some(error.to_string()),
metadata: HashMap::new(),
});
return Err(RragError::document_processing(error));
}
};
context.data = processed_data;
context.record_step(StepExecution {
step_id: self.name().to_string(),
start_time: step_start,
duration_ms: start_time.elapsed().as_millis() as u64,
success: true,
error_message: None,
metadata: HashMap::new(),
});
Ok(context)
}
}
/// Pipeline step that splits documents into chunks with a [`DocumentChunker`].
pub struct DocumentChunkingStep {
    chunker: DocumentChunker,
}

impl DocumentChunkingStep {
    /// Wraps the given chunker as a pipeline step.
    pub fn new(chunker: DocumentChunker) -> Self {
        DocumentChunkingStep { chunker }
    }
}
#[async_trait]
impl PipelineStep for DocumentChunkingStep {
fn name(&self) -> &str {
"document_chunking"
}
fn description(&self) -> &str {
"Splits documents into smaller chunks for processing"
}
fn input_types(&self) -> Vec<&'static str> {
vec!["Document", "Documents"]
}
fn output_type(&self) -> &'static str {
"Chunks"
}
async fn execute(&self, mut context: PipelineContext) -> RragResult<PipelineContext> {
let start_time = Instant::now();
let step_start = chrono::Utc::now();
let chunks = match &context.data {
PipelineData::Document(doc) => self.chunker.chunk_document(doc)?,
PipelineData::Documents(docs) => {
let mut all_chunks = Vec::new();
for doc in docs {
all_chunks.extend(self.chunker.chunk_document(doc)?);
}
all_chunks
}
_ => {
let error = "Input must be Document or Documents";
context.record_step(StepExecution {
step_id: self.name().to_string(),
start_time: step_start,
duration_ms: start_time.elapsed().as_millis() as u64,
success: false,
error_message: Some(error.to_string()),
metadata: HashMap::new(),
});
return Err(RragError::document_processing(error));
}
};
context.data = PipelineData::Chunks(chunks);
context.record_step(StepExecution {
step_id: self.name().to_string(),
start_time: step_start,
duration_ms: start_time.elapsed().as_millis() as u64,
success: true,
error_message: None,
metadata: HashMap::new(),
});
Ok(context)
}
}
/// Pipeline step that generates embeddings through an [`EmbeddingService`].
pub struct EmbeddingStep {
    embedding_service: Arc<EmbeddingService>,
}

impl EmbeddingStep {
    /// Wraps the given embedding service as a pipeline step.
    pub fn new(embedding_service: Arc<EmbeddingService>) -> Self {
        EmbeddingStep { embedding_service }
    }
}
#[async_trait]
impl PipelineStep for EmbeddingStep {
fn name(&self) -> &str {
"embedding_generation"
}
fn description(&self) -> &str {
"Generates embeddings for documents or chunks"
}
fn input_types(&self) -> Vec<&'static str> {
vec!["Document", "Documents", "Chunks"]
}
fn output_type(&self) -> &'static str {
"Embeddings"
}
async fn execute(&self, mut context: PipelineContext) -> RragResult<PipelineContext> {
let start_time = Instant::now();
let step_start = chrono::Utc::now();
let embeddings = match &context.data {
PipelineData::Document(doc) => {
vec![self.embedding_service.embed_document(doc).await?]
}
PipelineData::Documents(docs) => self.embedding_service.embed_documents(docs).await?,
PipelineData::Chunks(chunks) => self.embedding_service.embed_chunks(chunks).await?,
_ => {
let error = "Input must be Document, Documents, or Chunks";
context.record_step(StepExecution {
step_id: self.name().to_string(),
start_time: step_start,
duration_ms: start_time.elapsed().as_millis() as u64,
success: false,
error_message: Some(error.to_string()),
metadata: HashMap::new(),
});
return Err(RragError::embedding("pipeline", error));
}
};
context.data = PipelineData::Embeddings(embeddings);
context.record_step(StepExecution {
step_id: self.name().to_string(),
start_time: step_start,
duration_ms: start_time.elapsed().as_millis() as u64,
success: true,
error_message: None,
metadata: HashMap::new(),
});
Ok(context)
}
}
/// Pipeline step that performs similarity search against a retrieval service.
pub struct RetrievalStep {
    /// Backend that executes the similarity search.
    retrieval_service: Arc<RetrievalService>,
    /// Search tuning parameters (see `SearchStepConfig`).
    search_config: SearchStepConfig,
}
/// Tuning knobs for a [`RetrievalStep`].
#[derive(Debug, Clone)]
pub struct SearchStepConfig {
    /// Maximum number of results to request from the retriever.
    pub limit: usize,
    /// Minimum score threshold.
    /// NOTE(review): not applied anywhere in `RetrievalStep::execute` —
    /// confirm whether filtering was intended.
    pub min_score: f32,
    /// Optional query text override.
    /// NOTE(review): also unused by `RetrievalStep::execute` — confirm.
    pub query_text: Option<String>,
}

impl Default for SearchStepConfig {
    /// Ten results, no score floor, no query text.
    fn default() -> Self {
        SearchStepConfig {
            limit: 10,
            min_score: 0.0,
            query_text: None,
        }
    }
}
impl RetrievalStep {
    /// Creates a retrieval step with default search settings.
    pub fn new(retrieval_service: Arc<RetrievalService>) -> Self {
        Self::with_config(retrieval_service, SearchStepConfig::default())
    }

    /// Creates a retrieval step with explicit search settings.
    pub fn with_config(retrieval_service: Arc<RetrievalService>, config: SearchStepConfig) -> Self {
        Self {
            retrieval_service,
            search_config: config,
        }
    }
}
#[async_trait]
impl PipelineStep for RetrievalStep {
    fn name(&self) -> &str {
        "similarity_retrieval"
    }

    fn description(&self) -> &str {
        "Performs similarity search using embeddings"
    }

    fn input_types(&self) -> Vec<&'static str> {
        vec!["Embeddings"]
    }

    fn output_type(&self) -> &'static str {
        "SearchResults"
    }

    /// Uses the first embedding in the payload as the query vector and
    /// searches with `search_config.limit`; an empty embedding list
    /// yields empty results. Non-`Embeddings` payloads are recorded as
    /// a failed step and surfaced as a retrieval error.
    /// NOTE(review): `search_config.min_score` and `query_text` are not
    /// consulted here — confirm whether post-filtering is expected.
    async fn execute(&self, mut context: PipelineContext) -> RragResult<PipelineContext> {
        let started = Instant::now();
        let started_at = chrono::Utc::now();
        let results = match &context.data {
            PipelineData::Embeddings(embeddings) => match embeddings.first() {
                None => Vec::new(),
                Some(query) => {
                    self.retrieval_service
                        .search_embedding(query.clone(), Some(self.search_config.limit))
                        .await?
                }
            },
            _ => {
                let error = "Input must be Embeddings";
                context.record_step(StepExecution {
                    step_id: self.name().to_string(),
                    start_time: started_at,
                    duration_ms: started.elapsed().as_millis() as u64,
                    success: false,
                    error_message: Some(error.to_string()),
                    metadata: HashMap::new(),
                });
                return Err(RragError::retrieval(error));
            }
        };
        context.data = PipelineData::SearchResults(results);
        context.record_step(StepExecution {
            step_id: self.name().to_string(),
            start_time: started_at,
            duration_ms: started.elapsed().as_millis() as u64,
            success: true,
            error_message: None,
            metadata: HashMap::new(),
        });
        Ok(context)
    }
}
/// An ordered sequence of [`PipelineStep`]s executed over a `PipelineData`
/// payload.
pub struct Pipeline {
    /// Steps in execution order.
    steps: Vec<Arc<dyn PipelineStep>>,
    /// Settings applied to every execution of this pipeline.
    config: PipelineConfig,
    /// Metadata copied into each execution's context.
    metadata: HashMap<String, serde_json::Value>,
}
impl Pipeline {
pub fn new() -> Self {
Self {
steps: Vec::new(),
config: PipelineConfig::default(),
metadata: HashMap::new(),
}
}
pub fn with_config(config: PipelineConfig) -> Self {
Self {
steps: Vec::new(),
config,
metadata: HashMap::new(),
}
}
pub fn add_step(mut self, step: Arc<dyn PipelineStep>) -> Self {
self.steps.push(step);
self
}
pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
self.metadata.insert(key.into(), value);
self
}
pub async fn execute(&self, initial_data: PipelineData) -> RragResult<PipelineContext> {
let mut context = PipelineContext::with_config(initial_data, self.config.clone());
context.metadata.extend(self.metadata.clone());
let start_time = Instant::now();
for step in &self.steps {
if start_time.elapsed().as_secs() > self.config.max_execution_time {
return Err(RragError::timeout(
"pipeline_execution",
self.config.max_execution_time * 1000,
));
}
if let Err(e) = step.validate_input(&context.data) {
if !self.config.continue_on_error {
return Err(e);
}
context.record_step(StepExecution {
step_id: step.name().to_string(),
start_time: chrono::Utc::now(),
duration_ms: 0,
success: false,
error_message: Some(e.to_string()),
metadata: HashMap::new(),
});
continue;
}
let context_clone = PipelineContext {
execution_id: context.execution_id.clone(),
data: context.data.clone(),
metadata: context.metadata.clone(),
execution_history: context.execution_history.clone(),
config: context.config.clone(),
};
match step.execute(context_clone).await {
Ok(new_context) => {
context = new_context;
}
Err(e) => {
if !self.config.continue_on_error {
return Err(e);
}
context.record_step(StepExecution {
step_id: step.name().to_string(),
start_time: chrono::Utc::now(),
duration_ms: 0,
success: false,
error_message: Some(e.to_string()),
metadata: HashMap::new(),
});
}
}
}
Ok(context)
}
pub fn get_step_info(&self) -> Vec<PipelineStepInfo> {
self.steps
.iter()
.map(|step| PipelineStepInfo {
name: step.name().to_string(),
description: step.description().to_string(),
input_types: step.input_types().iter().map(|s| s.to_string()).collect(),
output_type: step.output_type().to_string(),
is_parallelizable: step.is_parallelizable(),
dependencies: step.dependencies().iter().map(|s| s.to_string()).collect(),
})
.collect()
}
}
impl Default for Pipeline {
fn default() -> Self {
Self::new()
}
}
/// Serializable description of one pipeline step, produced by
/// `Pipeline::get_step_info`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineStepInfo {
    /// The step's `name()`.
    pub name: String,
    /// The step's `description()`.
    pub description: String,
    /// Accepted `PipelineData` variant names.
    pub input_types: Vec<String>,
    /// Produced `PipelineData` variant name.
    pub output_type: String,
    /// Whether the step reports itself as parallelizable.
    pub is_parallelizable: bool,
    /// Names of steps this step depends on.
    pub dependencies: Vec<String>,
}
/// Builder that assembles standard ingestion/query pipelines from the
/// project's core services.
pub struct RagPipelineBuilder {
    /// Required by both `build_ingestion_pipeline` and `build_query_pipeline`.
    embedding_service: Option<Arc<EmbeddingService>>,
    /// Required by `build_query_pipeline` only.
    retrieval_service: Option<Arc<RetrievalService>>,
    /// NOTE(review): accepted via `with_storage_service` but never
    /// consumed by either build method — confirm intent.
    storage_service: Option<Arc<StorageService>>,
    /// Configuration applied to built pipelines.
    config: PipelineConfig,
}
impl RagPipelineBuilder {
pub fn new() -> Self {
Self {
embedding_service: None,
retrieval_service: None,
storage_service: None,
config: PipelineConfig::default(),
}
}
pub fn with_embedding_service(mut self, service: Arc<EmbeddingService>) -> Self {
self.embedding_service = Some(service);
self
}
pub fn with_retrieval_service(mut self, service: Arc<RetrievalService>) -> Self {
self.retrieval_service = Some(service);
self
}
pub fn with_storage_service(mut self, service: Arc<StorageService>) -> Self {
self.storage_service = Some(service);
self
}
pub fn with_config(mut self, config: PipelineConfig) -> Self {
self.config = config;
self
}
pub fn build_ingestion_pipeline(&self) -> RragResult<Pipeline> {
let embedding_service = self
.embedding_service
.as_ref()
.ok_or_else(|| RragError::config("embedding_service", "required", "missing"))?;
let pipeline = Pipeline::with_config(self.config.clone())
.add_step(Arc::new(TextPreprocessingStep::new(vec![
TextOperation::NormalizeWhitespace,
TextOperation::ToLowercase,
])))
.add_step(Arc::new(DocumentChunkingStep::new(DocumentChunker::new())))
.add_step(Arc::new(EmbeddingStep::new(embedding_service.clone())));
Ok(pipeline)
}
pub fn build_query_pipeline(&self) -> RragResult<Pipeline> {
let embedding_service = self
.embedding_service
.as_ref()
.ok_or_else(|| RragError::config("embedding_service", "required", "missing"))?;
let retrieval_service = self
.retrieval_service
.as_ref()
.ok_or_else(|| RragError::config("retrieval_service", "required", "missing"))?;
let pipeline = Pipeline::with_config(self.config.clone())
.add_step(Arc::new(TextPreprocessingStep::new(vec![
TextOperation::NormalizeWhitespace,
])))
.add_step(Arc::new(EmbeddingStep::new(embedding_service.clone())))
.add_step(Arc::new(RetrievalStep::new(retrieval_service.clone())));
Ok(pipeline)
}
}
impl Default for RagPipelineBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Document, EmbeddingService, InMemoryRetriever, LocalEmbeddingProvider};

    // Lowercasing plus whitespace normalization on a raw Text payload,
    // and the success flag recorded in the execution history.
    #[tokio::test]
    async fn test_text_preprocessing_step() {
        let step = TextPreprocessingStep::new(vec![
            TextOperation::ToLowercase,
            TextOperation::NormalizeWhitespace,
        ]);
        let context = PipelineContext::new(PipelineData::Text(" HELLO WORLD ".to_string()));
        let result = step.execute(context).await.unwrap();
        if let PipelineData::Text(processed) = result.data {
            assert_eq!(processed, "hello world");
        } else {
            panic!("Expected Text output");
        }
        assert!(result.execution_history[0].success);
    }

    // A single Document payload should come out as a non-empty Chunks payload.
    #[tokio::test]
    async fn test_document_chunking_step() {
        let step = DocumentChunkingStep::new(DocumentChunker::new());
        let doc = Document::new(
            "This is a test document with some content that should be chunked appropriately.",
        );
        let context = PipelineContext::new(PipelineData::Document(doc));
        let result = step.execute(context).await.unwrap();
        if let PipelineData::Chunks(chunks) = result.data {
            assert!(!chunks.is_empty());
        } else {
            panic!("Expected Chunks output");
        }
    }

    // One document in -> exactly one embedding out, with the provider's
    // configured dimensionality (128).
    #[tokio::test]
    async fn test_embedding_step() {
        let provider = Arc::new(LocalEmbeddingProvider::new("test-model", 128));
        let embedding_service = Arc::new(EmbeddingService::new(provider));
        let step = EmbeddingStep::new(embedding_service);
        let doc = Document::new("Test document for embedding");
        let context = PipelineContext::new(PipelineData::Document(doc));
        let result = step.execute(context).await.unwrap();
        if let PipelineData::Embeddings(embeddings) = result.data {
            assert_eq!(embeddings.len(), 1);
            assert_eq!(embeddings[0].dimensions, 128);
        } else {
            panic!("Expected Embeddings output");
        }
    }

    // End-to-end: two steps run in order, both succeed, and the final
    // payload is the last step's output type.
    #[tokio::test]
    async fn test_pipeline_execution() {
        let provider = Arc::new(LocalEmbeddingProvider::new("test-model", 128));
        let embedding_service = Arc::new(EmbeddingService::new(provider));
        let pipeline = Pipeline::new()
            .add_step(Arc::new(TextPreprocessingStep::new(vec![
                TextOperation::ToLowercase,
            ])))
            .add_step(Arc::new(EmbeddingStep::new(embedding_service)));
        let doc = Document::new("TEST DOCUMENT");
        let result = pipeline.execute(PipelineData::Document(doc)).await.unwrap();
        assert_eq!(result.execution_history.len(), 2);
        assert!(result.execution_history.iter().all(|step| step.success));
        if let PipelineData::Embeddings(embeddings) = result.data {
            assert_eq!(embeddings.len(), 1);
        } else {
            panic!("Expected Embeddings output");
        }
    }

    // The canned ingestion pipeline is preprocess -> chunk -> embed.
    #[tokio::test]
    async fn test_rag_pipeline_builder() {
        let provider = Arc::new(LocalEmbeddingProvider::new("test-model", 128));
        let embedding_service = Arc::new(EmbeddingService::new(provider));
        let builder = RagPipelineBuilder::new().with_embedding_service(embedding_service);
        let pipeline = builder.build_ingestion_pipeline().unwrap();
        let step_info = pipeline.get_step_info();
        assert_eq!(step_info.len(), 3);
        assert_eq!(step_info[0].name, "text_preprocessing");
        assert_eq!(step_info[1].name, "document_chunking");
        assert_eq!(step_info[2].name, "embedding_generation");
    }

    // Metadata builder, step recording, duration summing, and the
    // failure predicate on PipelineContext.
    #[test]
    fn test_pipeline_context() {
        let mut context = PipelineContext::new(PipelineData::Text("test".to_string()))
            .with_metadata(
                "test_key",
                serde_json::Value::String("test_value".to_string()),
            );
        assert_eq!(
            context.metadata.get("test_key").unwrap().as_str().unwrap(),
            "test_value"
        );
        let step_execution = StepExecution {
            step_id: "test_step".to_string(),
            start_time: chrono::Utc::now(),
            duration_ms: 100,
            success: true,
            error_message: None,
            metadata: HashMap::new(),
        };
        context.record_step(step_execution);
        assert_eq!(context.execution_history.len(), 1);
        assert_eq!(context.total_execution_time(), 100);
        assert!(!context.has_failures());
    }
}