use crate::{
data::{ExampleData, Extraction, CharInterval},
exceptions::{LangExtractError, LangExtractResult},
extract, ExtractConfig,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use futures::future::join_all;
/// One extraction stage of a pipeline.
///
/// A step whose `depends_on` is empty runs on the full input text; otherwise it
/// runs once for every (filtered) extraction produced by each step it depends on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineStep {
/// Identifier that other steps reference in `depends_on`; also the key under
/// which this step's results appear in the nested output.
pub id: String,
/// Display name; used in log messages and copied into `StepResult::step_name`.
pub name: String,
/// Free-form description. Not read by the executor code in this file.
pub description: String,
/// Examples passed to `extract` for every input item of this step.
pub examples: Vec<ExampleData>,
/// Prompt passed to `extract` for every input item of this step.
pub prompt: String,
/// Output field name for this step.
/// NOTE(review): not read by the executor code in this file; the nested output
/// is keyed by `id` instead — confirm the intended use.
pub output_field: String,
/// Optional filter applied to each dependency's extractions before they
/// become inputs of this step.
pub filter: Option<PipelineFilter>,
/// Ids of the steps whose extractions feed this step.
pub depends_on: Vec<String>,
}
/// Selects which extractions of a dependency step are forwarded to a dependent step.
///
/// All configured criteria must hold for an extraction to be kept.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineFilter {
/// Keep only extractions whose `extraction_class` equals this value.
pub class_filter: Option<String>,
/// Regular expression that `extraction_text` must match. A pattern that
/// fails to compile does not filter anything out.
pub text_pattern: Option<String>,
/// Maximum number of extractions kept after the other criteria are applied.
pub max_items: Option<usize>,
}
/// Complete, serializable description of a multi-step extraction pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineConfig {
/// Pipeline name; logged when execution starts.
pub name: String,
/// Pipeline description; logged when execution starts.
pub description: String,
/// Version string. Not read by the executor code in this file.
pub version: String,
/// The steps of the pipeline; execution order is derived from `depends_on`.
pub steps: Vec<PipelineStep>,
/// Extraction settings cloned for every `extract` call of every step.
pub global_config: ExtractConfig,
/// When `true`, steps are grouped into dependency waves and a wave with more
/// than one step is run concurrently. Defaults to `false` when the field is
/// missing from the serialized config.
#[serde(default)]
pub enable_parallel_execution: bool,
}
/// Outcome of running a single pipeline step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepResult {
/// Id of the step that produced this result.
pub step_id: String,
/// Display name of the step that produced this result.
pub step_name: String,
/// Extractions collected from all input items; empty when the step failed.
pub extractions: Vec<Extraction>,
/// Elapsed time of the step in milliseconds.
pub processing_time_ms: u64,
/// Number of input items that were prepared for the step.
pub input_count: usize,
/// `false` when an `extract` call for one of the input items returned an error.
pub success: bool,
/// Text of the error that made the step fail; `None` on success.
pub error_message: Option<String>,
}
/// Outcome of running a whole pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineResult {
/// Copy of the configuration the pipeline was executed with.
pub config: PipelineConfig,
/// Results of the executed steps, in the order they were recorded.
pub step_results: Vec<StepResult>,
/// JSON object keyed by step id holding each successful step's extractions,
/// their count and the step's processing time.
pub nested_output: serde_json::Value,
/// Elapsed time of the whole run in milliseconds.
pub total_time_ms: u64,
/// Always `true` for results built by the executor in this file; a failed
/// step makes execution return an error instead of a `PipelineResult`.
pub success: bool,
/// Always `None` for results built by the executor in this file.
pub error_message: Option<String>,
}
/// Runs the steps of a [`PipelineConfig`] against an input text, either
/// sequentially or in dependency waves depending on the configuration.
pub struct PipelineExecutor {
// Configuration this executor was created with; it is never mutated.
config: PipelineConfig,
}
/// A single piece of text handed to `extract` for a step, together with the
/// description of the parent extraction it was derived from (if any).
///
/// NOTE(review): for the root item `parent_end` is `original_text.len()`, a byte
/// length — confirm that `CharInterval` positions are byte offsets as well.
#[derive(Debug, Clone)]
struct StepInputItem {
// Text that is passed to `extract`.
text: String,
// Start offset of the parent extraction (0 for the original text); it is
// added to the offsets of child extractions to make them absolute.
parent_start: Option<usize>,
// End offset of the parent extraction (length of the original text for the root item).
parent_end: Option<usize>,
// Id of the step that produced the parent extraction; `None` for the original text.
parent_step_id: Option<String>,
// Class of the parent extraction; `None` for the original text.
parent_class: Option<String>,
// Text of the parent extraction; `None` for the original text.
parent_text: Option<String>,
}
impl PipelineExecutor {
pub fn new(config: PipelineConfig) -> Self {
Self { config }
}
/// Reads a YAML pipeline definition from `path` and wraps it in an executor.
///
/// # Errors
/// Returns a configuration error when the file cannot be read or when its
/// content does not deserialize into a [`PipelineConfig`].
pub fn from_yaml_file(path: &std::path::Path) -> LangExtractResult<Self> {
    let yaml_text = match std::fs::read_to_string(path) {
        Ok(text) => text,
        Err(io_err) => {
            return Err(LangExtractError::configuration(format!(
                "Failed to read pipeline file: {}",
                io_err
            )));
        }
    };
    match serde_yaml::from_str::<PipelineConfig>(&yaml_text) {
        Ok(parsed) => Ok(Self::new(parsed)),
        Err(parse_err) => Err(LangExtractError::configuration(format!(
            "Failed to parse pipeline YAML: {}",
            parse_err
        ))),
    }
}
/// Runs the whole pipeline on `input_text`.
///
/// Logs the pipeline name, description and execution mode, then delegates to
/// the parallel or the sequential driver depending on
/// `enable_parallel_execution`.
///
/// # Errors
/// Propagates whatever error the selected driver returns.
pub async fn execute(&self, input_text: &str) -> LangExtractResult<PipelineResult> {
    let started = std::time::Instant::now();
    log::info!("[pipeline] starting: {}", self.config.name);
    log::info!("[pipeline] {}", self.config.description);
    if self.config.enable_parallel_execution {
        log::info!("[pipeline] mode: parallel (independent steps run concurrently)");
        self.execute_parallel(input_text, started).await
    } else {
        log::info!("[pipeline] mode: sequential");
        self.execute_sequential(input_text, started).await
    }
}
/// Runs every step one after another in dependency order.
///
/// The extractions of each successful step are stored under the step id so
/// that dependent steps can use them as input.
///
/// # Errors
/// Returns a configuration error when the execution order cannot be resolved,
/// when a step cannot be executed, or as soon as a step reports failure.
async fn execute_sequential(&self, input_text: &str, start_time: std::time::Instant) -> LangExtractResult<PipelineResult> {
    let execution_order = self.resolve_execution_order()?;
    let mut step_results = Vec::with_capacity(execution_order.len());
    let mut context_data = HashMap::new();
    for step_id in execution_order {
        let step_result = self.execute_step(&step_id, input_text, &context_data).await?;
        if !step_result.success {
            return Err(LangExtractError::configuration(format!(
                "Step '{}' failed: {}",
                step_id,
                step_result.error_message.unwrap_or_else(|| "Unknown error".to_string())
            )));
        }
        // Only the extractions need to be duplicated (the context and the
        // result list both own them); the result itself is moved, not cloned.
        context_data.insert(step_id, step_result.extractions.clone());
        step_results.push(step_result);
    }
    let nested_output = self.build_nested_output(&step_results)?;
    let total_time = start_time.elapsed().as_millis() as u64;
    log::info!("[pipeline] done in {}ms", total_time);
    Ok(PipelineResult {
        config: self.config.clone(),
        step_results,
        nested_output,
        total_time_ms: total_time,
        success: true,
        error_message: None,
    })
}
/// Runs the pipeline wave by wave.
///
/// A wave holding exactly one step is awaited directly; a wave with any other
/// number of steps is awaited through `join_all`. Results are recorded in the
/// order the steps appear in the wave, and the extractions of every successful
/// step are stored under the step id for later waves.
///
/// # Errors
/// Returns a configuration error when the waves cannot be resolved, when a step
/// cannot be executed, or as soon as a recorded step reports failure.
async fn execute_parallel(&self, input_text: &str, start_time: std::time::Instant) -> LangExtractResult<PipelineResult> {
    let mut all_step_results = Vec::new();
    let mut context_data = HashMap::new();
    let execution_waves = self.resolve_execution_waves()?;
    for (wave_index, wave_steps) in execution_waves.iter().enumerate() {
        log::debug!("[pipeline] wave {}: {} steps", wave_index + 1, wave_steps.len());
        // Gather the outcomes of the wave first; the context is only updated
        // once every future of the wave has completed.
        let outcomes = if wave_steps.len() == 1 {
            vec![self.execute_step(&wave_steps[0], input_text, &context_data).await]
        } else {
            log::debug!("[pipeline] running {} steps in parallel", wave_steps.len());
            let pending: Vec<_> = wave_steps
                .iter()
                .map(|id| self.execute_step(id, input_text, &context_data))
                .collect();
            join_all(pending).await
        };
        for (step_id, outcome) in wave_steps.iter().zip(outcomes) {
            let step_result = outcome?;
            if !step_result.success {
                return Err(LangExtractError::configuration(format!(
                    "Step '{}' failed: {}",
                    step_id,
                    step_result.error_message.unwrap_or("Unknown error".to_string())
                )));
            }
            context_data.insert(step_id.clone(), step_result.extractions.clone());
            all_step_results.push(step_result);
        }
    }
    let nested_output = self.build_nested_output(&all_step_results)?;
    let total_time = start_time.elapsed().as_millis() as u64;
    log::info!("[pipeline] done in {}ms", total_time);
    Ok(PipelineResult {
        config: self.config.clone(),
        step_results: all_step_results,
        nested_output,
        total_time_ms: total_time,
        success: true,
        error_message: None,
    })
}
/// Computes a sequential execution order in which every step comes after the
/// steps it depends on.
///
/// Steps are visited in configuration order; the dependencies of each step are
/// emitted before the step itself.
///
/// # Errors
/// Returns a configuration error when a circular dependency is detected.
fn resolve_execution_order(&self) -> LangExtractResult<Vec<String>> {
    use std::collections::HashSet;

    let mut ordered: Vec<String> = Vec::new();
    let mut finished: HashSet<String> = HashSet::new();
    let mut in_progress: HashSet<String> = HashSet::new();
    self.config.steps.iter().try_for_each(|step| {
        self.resolve_step_dependencies(&step.id, &mut ordered, &mut finished, &mut in_progress)
    })?;
    Ok(ordered)
}
/// Groups the steps into waves for parallel execution.
///
/// Each wave contains, in configuration order, every not-yet-scheduled step
/// whose dependencies were all scheduled in earlier waves.
///
/// # Errors
/// Returns a configuration error when no further step can be scheduled: either
/// a remaining step depends on an id that is not a step of this pipeline, or
/// the remaining steps depend on each other in a cycle.
fn resolve_execution_waves(&self) -> LangExtractResult<Vec<Vec<String>>> {
    let mut waves = Vec::new();
    let mut completed_steps = std::collections::HashSet::new();
    let mut remaining_steps: std::collections::HashSet<String> =
        self.config.steps.iter().map(|s| s.id.clone()).collect();
    while !remaining_steps.is_empty() {
        let current_wave: Vec<String> = self
            .config
            .steps
            .iter()
            .filter(|step| {
                remaining_steps.contains(&step.id)
                    && step.depends_on.iter().all(|dep| completed_steps.contains(dep))
            })
            .map(|step| step.id.clone())
            .collect();
        if current_wave.is_empty() {
            // Nothing can be scheduled. Report an unknown dependency precisely
            // instead of blaming a cycle that may not exist.
            for step in self.config.steps.iter().filter(|s| remaining_steps.contains(&s.id)) {
                for dep in &step.depends_on {
                    let is_known = self.config.steps.iter().any(|s| &s.id == dep);
                    if !is_known {
                        return Err(LangExtractError::configuration(format!(
                            "Unable to resolve execution waves - step '{}' depends on unknown step '{}'",
                            step.id, dep
                        )));
                    }
                }
            }
            return Err(LangExtractError::configuration(
                "Unable to resolve execution waves - possible circular dependency".to_string()
            ));
        }
        for step_id in &current_wave {
            remaining_steps.remove(step_id);
            completed_steps.insert(step_id.clone());
        }
        waves.push(current_wave);
    }
    Ok(waves)
}
/// Depth-first visit used to build the sequential execution order.
///
/// Appends the dependencies of `step_id` to `order`, followed by `step_id`
/// itself. `visited` holds the ids that are already in `order`; `visiting`
/// holds the ids on the current recursion path and is used to detect cycles.
/// An id that does not name a configured step has no dependencies to follow
/// and is still appended to `order`.
///
/// # Errors
/// Returns a configuration error when `step_id` is reached again while it is
/// still being visited.
fn resolve_step_dependencies(
    &self,
    step_id: &str,
    order: &mut Vec<String>,
    visited: &mut std::collections::HashSet<String>,
    visiting: &mut std::collections::HashSet<String>,
) -> LangExtractResult<()> {
    if visited.contains(step_id) {
        return Ok(());
    }
    // `insert` returns false when the id is already on the current path.
    if !visiting.insert(step_id.to_string()) {
        return Err(LangExtractError::configuration(format!(
            "Circular dependency detected involving step: {}", step_id
        )));
    }
    let declared = self.config.steps.iter().find(|candidate| candidate.id == step_id);
    for dep in declared.into_iter().flat_map(|step| step.depends_on.iter()) {
        self.resolve_step_dependencies(dep, order, visited, visiting)?;
    }
    visiting.remove(step_id);
    let owned_id = step_id.to_string();
    visited.insert(owned_id.clone());
    order.push(owned_id);
    Ok(())
}
async fn execute_step(
&self,
step_id: &str,
input_text: &str,
context_data: &HashMap<String, Vec<Extraction>>,
) -> LangExtractResult<StepResult> {
let step = self.config.steps.iter().find(|s| s.id == step_id)
.ok_or_else(|| LangExtractError::configuration(format!("Step '{}' not found", step_id)))?;
let step_start = std::time::Instant::now();
log::info!("[pipeline] step: {} ({})", step.name, step.id);
let step_input = self.prepare_step_input(step, input_text, context_data)?;
let input_count = step_input.len();
log::debug!("[pipeline] processing {} input items", input_count);
let mut all_extractions = Vec::new();
for (i, input_item) in step_input.iter().enumerate() {
log::debug!("[pipeline] item {}/{}", i + 1, input_count);
let step_config = self.config.global_config.clone();
let examples = if step.examples.is_empty() {
vec![] } else {
step.examples.clone()
};
match extract(
&input_item.text,
Some(&step.prompt),
&examples,
step_config,
).await {
Ok(result) => {
if let Some(extractions) = result.extractions {
for mut ex in extractions {
if !step.depends_on.is_empty() {
if let Some(parent_start) = input_item.parent_start {
let mut abs_interval: Option<CharInterval> = None;
if let Some(ci) = &ex.char_interval {
if let (Some(ls), Some(le)) = (ci.start_pos, ci.end_pos) {
abs_interval = Some(CharInterval::new(Some(parent_start + ls), Some(parent_start + le)));
}
}
if abs_interval.is_none() {
if let Some(found) = input_item.text.find(&ex.extraction_text) {
let start = parent_start + found;
let end = start + ex.extraction_text.len();
abs_interval = Some(CharInterval::new(Some(start), Some(end)));
}
}
if let Some(ai) = abs_interval {
ex.char_interval = Some(ai);
}
if let Some(parent_step_id) = &input_item.parent_step_id {
let mut attrs = ex.attributes.take().unwrap_or_default();
attrs.insert(
"parent_step_id".to_string(),
serde_json::Value::String(parent_step_id.clone()),
);
if let Some(ps) = input_item.parent_start {
attrs.insert(
"parent_start".to_string(),
serde_json::Value::Number(serde_json::Number::from(ps as u64)),
);
}
if let Some(pe) = input_item.parent_end {
attrs.insert(
"parent_end".to_string(),
serde_json::Value::Number(serde_json::Number::from(pe as u64)),
);
}
if let Some(pc) = &input_item.parent_class {
attrs.insert(
"parent_class".to_string(),
serde_json::Value::String(pc.clone()),
);
}
if let Some(pt) = &input_item.parent_text {
attrs.insert(
"parent_text".to_string(),
serde_json::Value::String(pt.clone()),
);
}
ex.attributes = Some(attrs);
}
}
}
all_extractions.push(ex);
}
}
}
Err(e) => {
log::warn!("[pipeline] step '{}' failed on item {}/{}: {}", step.id, i + 1, input_count, e);
return Ok(StepResult {
step_id: step.id.clone(),
step_name: step.name.clone(),
extractions: Vec::new(),
processing_time_ms: step_start.elapsed().as_millis() as u64,
input_count,
success: false,
error_message: Some(e.to_string()),
});
}
}
}
let processing_time = step_start.elapsed().as_millis() as u64;
log::info!("[pipeline] step '{}' done: {} extractions in {}ms",
step.name, all_extractions.len(), processing_time);
Ok(StepResult {
step_id: step.id.clone(),
step_name: step.name.clone(),
extractions: all_extractions,
processing_time_ms: processing_time,
input_count,
success: true,
error_message: None,
})
}
/// Builds the list of texts a step has to process.
///
/// A step without dependencies gets a single item holding the whole original
/// text (offsets `0` to `original_text.len()`, no parent details). A step with
/// dependencies gets one item per extraction of each dependency that passes
/// the step's filter; dependencies that have no entry in `context_data` are
/// skipped.
fn prepare_step_input(
    &self,
    step: &PipelineStep,
    original_text: &str,
    context_data: &HashMap<String, Vec<Extraction>>,
) -> LangExtractResult<Vec<StepInputItem>> {
    if step.depends_on.is_empty() {
        let whole_document = StepInputItem {
            text: original_text.to_string(),
            parent_start: Some(0),
            parent_end: Some(original_text.len()),
            parent_step_id: None,
            parent_class: None,
            parent_text: None,
        };
        return Ok(vec![whole_document]);
    }

    let mut inputs: Vec<StepInputItem> = Vec::new();
    for dep_id in &step.depends_on {
        let upstream = match context_data.get(dep_id) {
            Some(found) => found,
            None => continue,
        };
        for parent in self.apply_filter(upstream, &step.filter) {
            let span = parent.char_interval.as_ref();
            inputs.push(StepInputItem {
                text: parent.extraction_text.clone(),
                parent_start: span.and_then(|ci| ci.start_pos),
                parent_end: span.and_then(|ci| ci.end_pos),
                parent_step_id: Some(dep_id.clone()),
                parent_class: Some(parent.extraction_class.clone()),
                parent_text: Some(parent.extraction_text.clone()),
            });
        }
    }
    Ok(inputs)
}
fn apply_filter<'a>(
&self,
extractions: &'a [Extraction],
filter: &Option<PipelineFilter>,
) -> Vec<&'a Extraction> {
if let Some(f) = filter {
extractions.iter()
.filter(|e| {
if let Some(class) = &f.class_filter {
if e.extraction_class != *class {
return false;
}
}
if let Some(pattern) = &f.text_pattern {
if let Ok(regex) = regex::Regex::new(pattern) {
if !regex.is_match(&e.extraction_text) {
return false;
}
}
}
true
})
.take(f.max_items.unwrap_or(usize::MAX))
.collect()
} else {
extractions.iter().collect()
}
}
fn build_nested_output(&self, step_results: &[StepResult]) -> LangExtractResult<serde_json::Value> {
let mut output = serde_json::Map::new();
for result in step_results {
if result.success {
let mut step_output = serde_json::Map::new();
let extractions_json: Vec<serde_json::Value> = result.extractions.iter()
.map(|e| {
let mut obj = serde_json::Map::new();
obj.insert("class".to_string(), serde_json::Value::String(e.extraction_class.clone()));
obj.insert("text".to_string(), serde_json::Value::String(e.extraction_text.clone()));
if let Some(interval) = &e.char_interval {
obj.insert("start".to_string(), serde_json::json!(interval.start_pos));
obj.insert("end".to_string(), serde_json::json!(interval.end_pos));
}
serde_json::Value::Object(obj)
})
.collect();
step_output.insert("extractions".to_string(), serde_json::Value::Array(extractions_json));
step_output.insert("count".to_string(), serde_json::json!(result.extractions.len()));
step_output.insert("processing_time_ms".to_string(), serde_json::json!(result.processing_time_ms));
output.insert(result.step_id.clone(), serde_json::Value::Object(step_output));
}
}
Ok(serde_json::Value::Object(output))
}
}
pub mod utils {
use super::*;
/// Builds a ready-made three-step pipeline for requirement documents.
///
/// The first step extracts requirement statements from the whole text; the
/// other two both depend on it, only accept extractions of class
/// `requirement`, and extract values/units and specifications/constraints
/// respectively. Parallel execution is disabled.
pub fn create_requirements_pipeline() -> PipelineConfig {
    // All three steps share the same example sentence.
    let sample_sentence =
        || "The system shall process 100 transactions per second and maintain 99.9% uptime.".to_string();
    let labelled = |class: &str, text: &str| Extraction::new(class.to_string(), text.to_string());
    // Both dependent steps only look at extractions of class "requirement".
    let requirements_only = || {
        Some(PipelineFilter {
            class_filter: Some("requirement".to_string()),
            text_pattern: None,
            max_items: None,
        })
    };

    let global_config = ExtractConfig {
        model_id: "gemini-2.5-flash".to_string(),
        api_key: None,
        format_type: crate::data::FormatType::Json,
        max_char_buffer: 8000,
        temperature: 0.3,
        fence_output: None,
        use_schema_constraints: true,
        batch_length: 4,
        max_workers: 6,
        additional_context: None,
        resolver_params: std::collections::HashMap::new(),
        language_model_params: std::collections::HashMap::new(),
        debug: false,
        model_url: None,
        enable_multipass: false,
        multipass_max_passes: 2,
        multipass_min_extractions: 1,
        multipass_quality_threshold: 0.3,
        progress_handler: None,
    };

    let requirements_step = PipelineStep {
        id: "extract_requirements".to_string(),
        name: "Extract Requirements".to_string(),
        description: "Extract all 'shall' statements and requirements from the document".to_string(),
        examples: vec![ExampleData::new(
            sample_sentence(),
            vec![Extraction::new("requirement".to_string(), sample_sentence())],
        )],
        prompt: "Extract all requirements, 'shall' statements, and specifications from the text. Include the complete statement.".to_string(),
        output_field: "requirements".to_string(),
        filter: None,
        depends_on: vec![],
    };

    let values_step = PipelineStep {
        id: "extract_values".to_string(),
        name: "Extract Values".to_string(),
        description: "Extract numeric values, units, and specifications from requirements".to_string(),
        examples: vec![ExampleData::new(
            sample_sentence(),
            vec![
                labelled("value", "100"),
                labelled("unit", "transactions per second"),
                labelled("value", "99.9"),
                labelled("unit", "%"),
            ],
        )],
        prompt: "From this requirement, extract all numeric values and their associated units or specifications.".to_string(),
        output_field: "values".to_string(),
        filter: requirements_only(),
        depends_on: vec!["extract_requirements".to_string()],
    };

    let specifications_step = PipelineStep {
        id: "extract_specifications".to_string(),
        name: "Extract Specifications".to_string(),
        description: "Extract detailed specifications and constraints from requirements".to_string(),
        examples: vec![ExampleData::new(
            sample_sentence(),
            vec![
                labelled("specification", "process 100 transactions per second"),
                labelled("constraint", "maintain 99.9% uptime"),
            ],
        )],
        prompt: "Extract detailed specifications, constraints, and performance requirements from this text.".to_string(),
        output_field: "specifications".to_string(),
        filter: requirements_only(),
        depends_on: vec!["extract_requirements".to_string()],
    };

    PipelineConfig {
        name: "Requirements Extraction Pipeline".to_string(),
        description: "Extract requirements and sub-divide into values, units, and specifications".to_string(),
        version: "1.0.0".to_string(),
        steps: vec![requirements_step, values_step, specifications_step],
        global_config,
        enable_parallel_execution: false,
    }
}
/// Serializes `config` to YAML and writes it to `path`.
///
/// # Errors
/// Returns a configuration error when serialization fails or when the file
/// cannot be written.
pub fn save_pipeline_to_file(config: &PipelineConfig, path: &std::path::Path) -> LangExtractResult<()> {
    let yaml_content = match serde_yaml::to_string(config) {
        Ok(rendered) => rendered,
        Err(ser_err) => {
            return Err(LangExtractError::configuration(format!(
                "Failed to serialize pipeline: {}",
                ser_err
            )));
        }
    };
    match std::fs::write(path, yaml_content) {
        Ok(()) => Ok(()),
        Err(io_err) => Err(LangExtractError::configuration(format!(
            "Failed to write pipeline file: {}",
            io_err
        ))),
    }
}
/// Reads the file at `path` and deserializes its YAML content into a pipeline
/// configuration.
///
/// # Errors
/// Returns a configuration error when the file cannot be read or when its
/// content is not a valid pipeline definition.
pub fn load_pipeline_from_file(path: &std::path::Path) -> LangExtractResult<PipelineConfig> {
    std::fs::read_to_string(path)
        .map_err(|io_err| {
            LangExtractError::configuration(format!("Failed to read pipeline file: {}", io_err))
        })
        .and_then(|yaml_text| {
            serde_yaml::from_str::<PipelineConfig>(&yaml_text).map_err(|parse_err| {
                LangExtractError::configuration(format!("Failed to parse pipeline YAML: {}", parse_err))
            })
        })
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A configuration survives a YAML round trip with its name and step count intact.
    #[test]
    fn test_pipeline_config_serialization() {
        let original = utils::create_requirements_pipeline();
        let rendered = serde_yaml::to_string(&original).unwrap();
        let restored: PipelineConfig = serde_yaml::from_str(&rendered).unwrap();
        assert_eq!(original.name, restored.name);
        assert_eq!(original.steps.len(), restored.steps.len());
    }

    /// The step without dependencies is scheduled first and all three steps are ordered.
    #[test]
    fn test_dependency_resolution() {
        let executor = PipelineExecutor::new(utils::create_requirements_pipeline());
        let order = executor.resolve_execution_order().unwrap();
        assert_eq!(order[0], "extract_requirements");
        assert_eq!(order.len(), 3);
    }

    /// A class filter keeps only the extractions of the requested class.
    #[test]
    fn test_filter_application() {
        let executor = PipelineExecutor::new(utils::create_requirements_pipeline());
        let candidates = vec![
            Extraction::new("requirement".to_string(), "Test requirement".to_string()),
            Extraction::new("other".to_string(), "Other text".to_string()),
        ];
        let class_only = Some(PipelineFilter {
            class_filter: Some("requirement".to_string()),
            text_pattern: None,
            max_items: None,
        });
        let kept = executor.apply_filter(&candidates, &class_only);
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0].extraction_class, "requirement");
    }
}