langextract_rust/
pipeline.rs

1//! Pipeline processing for multi-step information extraction.
2//!
3//! This module provides a pipeline system for processing documents through
4//! multiple extraction steps, creating nested hierarchical structures from text.
5
6use crate::{
7    data::{ExampleData, Extraction},
8    exceptions::{LangExtractError, LangExtractResult},
9    extract, ExtractConfig,
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
14/// A single step in a processing pipeline
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct PipelineStep {
17    /// Unique identifier for this step
18    pub id: String,
19
20    /// Human-readable name for this step
21    pub name: String,
22
23    /// Description of what this step extracts
24    pub description: String,
25
26    /// Examples for this extraction step
27    pub examples: Vec<ExampleData>,
28
29    /// Extraction prompt/description
30    pub prompt: String,
31
32    /// Output field name for the results of this step
33    pub output_field: String,
34
35    /// Optional filter to only process certain extractions from previous steps
36    pub filter: Option<PipelineFilter>,
37
38    /// Dependencies - this step depends on output from these step IDs
39    pub depends_on: Vec<String>,
40}
41
42/// Filter configuration for processing specific extractions
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct PipelineFilter {
45    /// Filter by extraction class
46    pub class_filter: Option<String>,
47
48    /// Filter by regex pattern on extraction text
49    pub text_pattern: Option<String>,
50
51    /// Maximum number of items to process
52    pub max_items: Option<usize>,
53}
54
55/// Configuration for the entire pipeline
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct PipelineConfig {
58    /// Pipeline name
59    pub name: String,
60
61    /// Pipeline description
62    pub description: String,
63
64    /// Pipeline version
65    pub version: String,
66
67    /// All processing steps
68    pub steps: Vec<PipelineStep>,
69
70    /// Global configuration that applies to all steps
71    pub global_config: ExtractConfig,
72}
73
74/// Results from a single pipeline step
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct StepResult {
77    /// Step ID
78    pub step_id: String,
79
80    /// Step name
81    pub step_name: String,
82
83    /// Extractions produced by this step
84    pub extractions: Vec<Extraction>,
85
86    /// Processing time in milliseconds
87    pub processing_time_ms: u64,
88
89    /// Number of input items processed
90    pub input_count: usize,
91
92    /// Success status
93    pub success: bool,
94
95    /// Error message if failed
96    pub error_message: Option<String>,
97}
98
99/// Complete pipeline execution result
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct PipelineResult {
102    /// Pipeline configuration used
103    pub config: PipelineConfig,
104
105    /// Results from each step
106    pub step_results: Vec<StepResult>,
107
108    /// Final nested output structure
109    pub nested_output: serde_json::Value,
110
111    /// Total processing time
112    pub total_time_ms: u64,
113
114    /// Overall success status
115    pub success: bool,
116
117    /// Error message if pipeline failed
118    pub error_message: Option<String>,
119}
120
121/// Pipeline executor
122pub struct PipelineExecutor {
123    config: PipelineConfig,
124}
125
126impl PipelineExecutor {
127    /// Create a new pipeline executor
128    pub fn new(config: PipelineConfig) -> Self {
129        Self { config }
130    }
131
132    /// Load pipeline configuration from YAML file
133    pub fn from_yaml_file(path: &std::path::Path) -> LangExtractResult<Self> {
134        let content = std::fs::read_to_string(path)
135            .map_err(|e| LangExtractError::configuration(format!("Failed to read pipeline file: {}", e)))?;
136
137        let config: PipelineConfig = serde_yaml::from_str(&content)
138            .map_err(|e| LangExtractError::configuration(format!("Failed to parse pipeline YAML: {}", e)))?;
139
140        Ok(Self::new(config))
141    }
142
143    /// Execute the entire pipeline
144    pub async fn execute(&self, input_text: &str) -> LangExtractResult<PipelineResult> {
145        let start_time = std::time::Instant::now();
146
147        println!("🚀 Starting pipeline execution: {}", self.config.name);
148        println!("📝 Description: {}", self.config.description);
149
150        let mut step_results = Vec::new();
151        let mut context_data = HashMap::new();
152
153        // Execute steps in dependency order
154        let execution_order = self.resolve_execution_order()?;
155
156        for step_id in execution_order {
157            let step_result = self.execute_step(&step_id, input_text, &context_data).await?;
158            step_results.push(step_result.clone());
159
160            // Store results for dependent steps
161            if step_result.success {
162                context_data.insert(step_id, step_result.extractions.clone());
163            } else {
164                return Err(LangExtractError::configuration(format!(
165                    "Step '{}' failed: {}",
166                    step_id,
167                    step_result.error_message.unwrap_or("Unknown error".to_string())
168                )));
169            }
170        }
171
172        // Build nested output structure
173        let nested_output = self.build_nested_output(&step_results)?;
174
175        let total_time = start_time.elapsed().as_millis() as u64;
176
177        println!("✅ Pipeline execution completed in {}ms", total_time);
178
179        Ok(PipelineResult {
180            config: self.config.clone(),
181            step_results,
182            nested_output,
183            total_time_ms: total_time,
184            success: true,
185            error_message: None,
186        })
187    }
188
189    /// Resolve the execution order based on dependencies
190    fn resolve_execution_order(&self) -> LangExtractResult<Vec<String>> {
191        let mut order = Vec::new();
192        let mut visited = std::collections::HashSet::new();
193        let mut visiting = std::collections::HashSet::new();
194
195        for step in &self.config.steps {
196            self.resolve_step_dependencies(&step.id, &mut order, &mut visited, &mut visiting)?;
197        }
198
199        Ok(order)
200    }
201
202    /// Recursive function to resolve dependencies
203    fn resolve_step_dependencies(
204        &self,
205        step_id: &str,
206        order: &mut Vec<String>,
207        visited: &mut std::collections::HashSet<String>,
208        visiting: &mut std::collections::HashSet<String>,
209    ) -> LangExtractResult<()> {
210        if visited.contains(step_id) {
211            return Ok(());
212        }
213
214        if visiting.contains(step_id) {
215            return Err(LangExtractError::configuration(format!(
216                "Circular dependency detected involving step: {}", step_id
217            )));
218        }
219
220        visiting.insert(step_id.to_string());
221
222        // Find the step and process its dependencies
223        if let Some(step) = self.config.steps.iter().find(|s| s.id == step_id) {
224            for dep in &step.depends_on {
225                self.resolve_step_dependencies(dep, order, visited, visiting)?;
226            }
227        }
228
229        visiting.remove(step_id);
230        visited.insert(step_id.to_string());
231        order.push(step_id.to_string());
232
233        Ok(())
234    }
235
236    /// Execute a single pipeline step
237    async fn execute_step(
238        &self,
239        step_id: &str,
240        input_text: &str,
241        context_data: &HashMap<String, Vec<Extraction>>,
242    ) -> LangExtractResult<StepResult> {
243        let step = self.config.steps.iter().find(|s| s.id == step_id)
244            .ok_or_else(|| LangExtractError::configuration(format!("Step '{}' not found", step_id)))?;
245
246        let step_start = std::time::Instant::now();
247
248        println!("🔄 Executing step: {} ({})", step.name, step.id);
249
250        // Determine input text for this step
251        let step_input = self.prepare_step_input(step, input_text, context_data)?;
252        let input_count = step_input.len();
253
254        println!("📥 Processing {} input items", input_count);
255
256        let mut all_extractions = Vec::new();
257
258        // Process each input item
259        for (i, input_item) in step_input.iter().enumerate() {
260            println!("  📄 Processing item {}/{}", i + 1, input_count);
261
262            // Create extraction config for this step
263            let step_config = self.config.global_config.clone();
264            // Use step-specific examples if provided, otherwise use global
265            let examples = if step.examples.is_empty() {
266                vec![] // Will need to be provided externally
267            } else {
268                step.examples.clone()
269            };
270
271            match extract(
272                input_item,
273                Some(&step.prompt),
274                &examples,
275                step_config,
276            ).await {
277                Ok(result) => {
278                    if let Some(extractions) = result.extractions {
279                        all_extractions.extend(extractions);
280                    }
281                }
282                Err(e) => {
283                    println!("  ❌ Step '{}' failed on item {}/{}: {}", step.id, i + 1, input_count, e);
284                    return Ok(StepResult {
285                        step_id: step.id.clone(),
286                        step_name: step.name.clone(),
287                        extractions: Vec::new(),
288                        processing_time_ms: step_start.elapsed().as_millis() as u64,
289                        input_count,
290                        success: false,
291                        error_message: Some(e.to_string()),
292                    });
293                }
294            }
295        }
296
297        let processing_time = step_start.elapsed().as_millis() as u64;
298
299        println!("  ✅ Step '{}' completed: {} extractions in {}ms",
300                step.name, all_extractions.len(), processing_time);
301
302        Ok(StepResult {
303            step_id: step.id.clone(),
304            step_name: step.name.clone(),
305            extractions: all_extractions,
306            processing_time_ms: processing_time,
307            input_count,
308            success: true,
309            error_message: None,
310        })
311    }
312
313    /// Prepare input text for a step based on its configuration
314    fn prepare_step_input(
315        &self,
316        step: &PipelineStep,
317        original_text: &str,
318        context_data: &HashMap<String, Vec<Extraction>>,
319    ) -> LangExtractResult<Vec<String>> {
320        // If step has dependencies, use extractions from dependent steps
321        if !step.depends_on.is_empty() {
322            let mut inputs = Vec::new();
323
324            for dep_id in &step.depends_on {
325                if let Some(extractions) = context_data.get(dep_id) {
326                    // Apply filter if specified
327                    let filtered_extractions = self.apply_filter(extractions, &step.filter);
328
329                    for extraction in filtered_extractions {
330                        inputs.push(extraction.extraction_text.clone());
331                    }
332                }
333            }
334
335            Ok(inputs)
336        } else {
337            // First step - use original text
338            Ok(vec![original_text.to_string()])
339        }
340    }
341
342    /// Apply filter to extractions
343    fn apply_filter<'a>(
344        &self,
345        extractions: &'a [Extraction],
346        filter: &Option<PipelineFilter>,
347    ) -> Vec<&'a Extraction> {
348        if let Some(f) = filter {
349            extractions.iter()
350                .filter(|e| {
351                    // Check class filter
352                    if let Some(class) = &f.class_filter {
353                        if e.extraction_class != *class {
354                            return false;
355                        }
356                    }
357
358                    // Check text pattern filter
359                    if let Some(pattern) = &f.text_pattern {
360                        if let Ok(regex) = regex::Regex::new(pattern) {
361                            if !regex.is_match(&e.extraction_text) {
362                                return false;
363                            }
364                        }
365                    }
366
367                    true
368                })
369                .take(f.max_items.unwrap_or(usize::MAX))
370                .collect()
371        } else {
372            extractions.iter().collect()
373        }
374    }
375
376    /// Build the final nested output structure
377    fn build_nested_output(&self, step_results: &[StepResult]) -> LangExtractResult<serde_json::Value> {
378        let mut output = serde_json::Map::new();
379
380        // Group results by step
381        for result in step_results {
382            if result.success {
383                let mut step_output = serde_json::Map::new();
384
385                // Convert extractions to JSON
386                let extractions_json: Vec<serde_json::Value> = result.extractions.iter()
387                    .map(|e| {
388                        let mut obj = serde_json::Map::new();
389                        obj.insert("class".to_string(), serde_json::Value::String(e.extraction_class.clone()));
390                        obj.insert("text".to_string(), serde_json::Value::String(e.extraction_text.clone()));
391                        if let Some(interval) = &e.char_interval {
392                            obj.insert("start".to_string(), serde_json::json!(interval.start_pos));
393                            obj.insert("end".to_string(), serde_json::json!(interval.end_pos));
394                        }
395                        serde_json::Value::Object(obj)
396                    })
397                    .collect();
398
399                step_output.insert("extractions".to_string(), serde_json::Value::Array(extractions_json));
400                step_output.insert("count".to_string(), serde_json::json!(result.extractions.len()));
401                step_output.insert("processing_time_ms".to_string(), serde_json::json!(result.processing_time_ms));
402
403                output.insert(result.step_id.clone(), serde_json::Value::Object(step_output));
404            }
405        }
406
407        Ok(serde_json::Value::Object(output))
408    }
409}
410
411/// Utility functions for pipeline management
412pub mod utils {
413    use super::*;
414
415    /// Create a sample pipeline configuration for requirements extraction
416    pub fn create_requirements_pipeline() -> PipelineConfig {
417        PipelineConfig {
418            name: "Requirements Extraction Pipeline".to_string(),
419            description: "Extract requirements and sub-divide into values, units, and specifications".to_string(),
420            version: "1.0.0".to_string(),
421            global_config: ExtractConfig {
422                model_id: "gemini-2.5-flash".to_string(),
423                api_key: None,
424                format_type: crate::data::FormatType::Json,
425                max_char_buffer: 8000,
426                temperature: 0.3,
427                fence_output: None,
428                use_schema_constraints: true,
429                batch_length: 4,
430                max_workers: 6,
431                additional_context: None,
432                resolver_params: std::collections::HashMap::new(),
433                language_model_params: std::collections::HashMap::new(),
434                debug: false,
435                model_url: None,
436                extraction_passes: 1,
437                enable_multipass: false,
438                multipass_min_extractions: 1,
439                multipass_quality_threshold: 0.3,
440                progress_handler: None,
441            },
442            steps: vec![
443                PipelineStep {
444                    id: "extract_requirements".to_string(),
445                    name: "Extract Requirements".to_string(),
446                    description: "Extract all 'shall' statements and requirements from the document".to_string(),
447                    examples: vec![
448                        ExampleData::new(
449                            "The system shall process 100 transactions per second and maintain 99.9% uptime.".to_string(),
450                            vec![
451                                Extraction::new("requirement".to_string(),
452                                    "The system shall process 100 transactions per second and maintain 99.9% uptime.".to_string()),
453                            ],
454                        )
455                    ],
456                    prompt: "Extract all requirements, 'shall' statements, and specifications from the text. Include the complete statement.".to_string(),
457                    output_field: "requirements".to_string(),
458                    filter: None,
459                    depends_on: vec![],
460                },
461                PipelineStep {
462                    id: "extract_values".to_string(),
463                    name: "Extract Values".to_string(),
464                    description: "Extract numeric values, units, and specifications from requirements".to_string(),
465                    examples: vec![
466                        ExampleData::new(
467                            "The system shall process 100 transactions per second and maintain 99.9% uptime.".to_string(),
468                            vec![
469                                Extraction::new("value".to_string(), "100".to_string()),
470                                Extraction::new("unit".to_string(), "transactions per second".to_string()),
471                                Extraction::new("value".to_string(), "99.9".to_string()),
472                                Extraction::new("unit".to_string(), "%".to_string()),
473                            ],
474                        )
475                    ],
476                    prompt: "From this requirement, extract all numeric values and their associated units or specifications.".to_string(),
477                    output_field: "values".to_string(),
478                    filter: Some(PipelineFilter {
479                        class_filter: Some("requirement".to_string()),
480                        text_pattern: None,
481                        max_items: None,
482                    }),
483                    depends_on: vec!["extract_requirements".to_string()],
484                },
485                PipelineStep {
486                    id: "extract_specifications".to_string(),
487                    name: "Extract Specifications".to_string(),
488                    description: "Extract detailed specifications and constraints from requirements".to_string(),
489                    examples: vec![
490                        ExampleData::new(
491                            "The system shall process 100 transactions per second and maintain 99.9% uptime.".to_string(),
492                            vec![
493                                Extraction::new("specification".to_string(), "process 100 transactions per second".to_string()),
494                                Extraction::new("constraint".to_string(), "maintain 99.9% uptime".to_string()),
495                            ],
496                        )
497                    ],
498                    prompt: "Extract detailed specifications, constraints, and performance requirements from this text.".to_string(),
499                    output_field: "specifications".to_string(),
500                    filter: Some(PipelineFilter {
501                        class_filter: Some("requirement".to_string()),
502                        text_pattern: None,
503                        max_items: None,
504                    }),
505                    depends_on: vec!["extract_requirements".to_string()],
506                },
507            ],
508        }
509    }
510
511    /// Save pipeline configuration to YAML file
512    pub fn save_pipeline_to_file(config: &PipelineConfig, path: &std::path::Path) -> LangExtractResult<()> {
513        let yaml_content = serde_yaml::to_string(config)
514            .map_err(|e| LangExtractError::configuration(format!("Failed to serialize pipeline: {}", e)))?;
515
516        std::fs::write(path, yaml_content)
517            .map_err(|e| LangExtractError::configuration(format!("Failed to write pipeline file: {}", e)))?;
518
519        Ok(())
520    }
521
522    /// Load pipeline configuration from YAML file
523    pub fn load_pipeline_from_file(path: &std::path::Path) -> LangExtractResult<PipelineConfig> {
524        let content = std::fs::read_to_string(path)
525            .map_err(|e| LangExtractError::configuration(format!("Failed to read pipeline file: {}", e)))?;
526
527        serde_yaml::from_str(&content)
528            .map_err(|e| LangExtractError::configuration(format!("Failed to parse pipeline YAML: {}", e)))
529    }
530}
531
#[cfg(test)]
mod tests {
    use super::*;

    /// Build three `requirement` extractions, two of which start with "shall".
    fn sample_requirements() -> Vec<Extraction> {
        vec![
            Extraction::new("requirement".to_string(), "shall respond in 10 ms".to_string()),
            Extraction::new("requirement".to_string(), "should be pleasant".to_string()),
            Extraction::new("requirement".to_string(), "shall log errors".to_string()),
        ]
    }

    #[test]
    fn test_pipeline_config_serialization() {
        let config = utils::create_requirements_pipeline();
        let yaml = serde_yaml::to_string(&config).unwrap();
        let deserialized: PipelineConfig = serde_yaml::from_str(&yaml).unwrap();

        assert_eq!(config.name, deserialized.name);
        assert_eq!(config.steps.len(), deserialized.steps.len());
    }

    #[test]
    fn test_dependency_resolution() {
        let config = utils::create_requirements_pipeline();
        let executor = PipelineExecutor::new(config);

        let order = executor.resolve_execution_order().unwrap();

        // Should start with the step that has no dependencies.
        assert_eq!(order[0], "extract_requirements");
        // Should include all steps.
        assert_eq!(order.len(), 3);
    }

    #[test]
    fn test_circular_dependency_is_rejected() {
        let mut config = utils::create_requirements_pipeline();
        // extract_requirements <-> extract_values now depend on each other.
        config.steps[0].depends_on = vec!["extract_values".to_string()];
        let executor = PipelineExecutor::new(config);

        assert!(executor.resolve_execution_order().is_err());
    }

    #[test]
    fn test_filter_application() {
        let executor = PipelineExecutor::new(utils::create_requirements_pipeline());

        let extractions = vec![
            Extraction::new("requirement".to_string(), "Test requirement".to_string()),
            Extraction::new("other".to_string(), "Other text".to_string()),
        ];

        let filter = PipelineFilter {
            class_filter: Some("requirement".to_string()),
            text_pattern: None,
            max_items: None,
        };

        let filtered = executor.apply_filter(&extractions, &Some(filter));
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].extraction_class, "requirement");
    }

    #[test]
    fn test_filter_text_pattern_and_max_items() {
        let executor = PipelineExecutor::new(utils::create_requirements_pipeline());
        let extractions = sample_requirements();

        let filter = PipelineFilter {
            class_filter: None,
            text_pattern: Some(r"^shall\b".to_string()),
            max_items: Some(1),
        };

        let filtered = executor.apply_filter(&extractions, &Some(filter));
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].extraction_text, "shall respond in 10 ms");
    }

    #[test]
    fn test_invalid_text_pattern_is_ignored() {
        let executor = PipelineExecutor::new(utils::create_requirements_pipeline());
        let extractions = sample_requirements();

        let filter = PipelineFilter {
            class_filter: Some("requirement".to_string()),
            text_pattern: Some("(".to_string()),
            max_items: None,
        };

        // An uncompilable pattern disables text filtering instead of dropping everything.
        let filtered = executor.apply_filter(&extractions, &Some(filter));
        assert_eq!(filtered.len(), extractions.len());
    }

    #[test]
    fn test_no_filter_keeps_everything() {
        let executor = PipelineExecutor::new(utils::create_requirements_pipeline());
        let extractions = sample_requirements();

        let filtered = executor.apply_filter(&extractions, &None);
        assert_eq!(filtered.len(), extractions.len());
    }
}