langextract_rust/
pipeline.rs

1//! Pipeline processing for multi-step information extraction.
2//!
3//! This module provides a pipeline system for processing documents through
4//! multiple extraction steps, creating nested hierarchical structures from text.
5
6use crate::{
7    data::{ExampleData, Extraction, CharInterval},
8    exceptions::{LangExtractError, LangExtractResult},
9    extract, ExtractConfig,
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use futures::future::join_all;
14
15/// A single step in a processing pipeline
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct PipelineStep {
18    /// Unique identifier for this step
19    pub id: String,
20
21    /// Human-readable name for this step
22    pub name: String,
23
24    /// Description of what this step extracts
25    pub description: String,
26
27    /// Examples for this extraction step
28    pub examples: Vec<ExampleData>,
29
30    /// Extraction prompt/description
31    pub prompt: String,
32
33    /// Output field name for the results of this step
34    pub output_field: String,
35
36    /// Optional filter to only process certain extractions from previous steps
37    pub filter: Option<PipelineFilter>,
38
39    /// Dependencies - this step depends on output from these step IDs
40    pub depends_on: Vec<String>,
41}
42
43/// Filter configuration for processing specific extractions
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct PipelineFilter {
46    /// Filter by extraction class
47    pub class_filter: Option<String>,
48
49    /// Filter by regex pattern on extraction text
50    pub text_pattern: Option<String>,
51
52    /// Maximum number of items to process
53    pub max_items: Option<usize>,
54}
55
56/// Configuration for the entire pipeline
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct PipelineConfig {
59    /// Pipeline name
60    pub name: String,
61
62    /// Pipeline description
63    pub description: String,
64
65    /// Pipeline version
66    pub version: String,
67
68    /// All processing steps
69    pub steps: Vec<PipelineStep>,
70
71    /// Global configuration that applies to all steps
72    pub global_config: ExtractConfig,
73
74    /// Enable parallel execution of independent steps (default: false)
75    #[serde(default)]
76    pub enable_parallel_execution: bool,
77}
78
79/// Results from a single pipeline step
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct StepResult {
82    /// Step ID
83    pub step_id: String,
84
85    /// Step name
86    pub step_name: String,
87
88    /// Extractions produced by this step
89    pub extractions: Vec<Extraction>,
90
91    /// Processing time in milliseconds
92    pub processing_time_ms: u64,
93
94    /// Number of input items processed
95    pub input_count: usize,
96
97    /// Success status
98    pub success: bool,
99
100    /// Error message if failed
101    pub error_message: Option<String>,
102}
103
104/// Complete pipeline execution result
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct PipelineResult {
107    /// Pipeline configuration used
108    pub config: PipelineConfig,
109
110    /// Results from each step
111    pub step_results: Vec<StepResult>,
112
113    /// Final nested output structure
114    pub nested_output: serde_json::Value,
115
116    /// Total processing time
117    pub total_time_ms: u64,
118
119    /// Overall success status
120    pub success: bool,
121
122    /// Error message if pipeline failed
123    pub error_message: Option<String>,
124}
125
126/// Pipeline executor
127pub struct PipelineExecutor {
128    config: PipelineConfig,
129}
130
/// One unit of text handed to a step, together with where it came from.
///
/// For a step without dependencies this is the whole original document;
/// for a dependent step it is the text of one extraction produced by a
/// dependency ("parent") step.
#[derive(Debug, Clone)]
struct StepInputItem {
    /// Text the step will run extraction on
    text: String,
    /// Start offset of `text` in the original document, when the parent
    /// extraction carried one (`Some(0)` for the original document)
    parent_start: Option<usize>,
    /// End offset of `text` in the original document, when the parent
    /// extraction carried one
    parent_end: Option<usize>,
    /// Id of the dependency step that produced `text`; `None` for the
    /// original document
    parent_step_id: Option<String>,
    /// Class of the parent extraction; `None` for the original document
    parent_class: Option<String>,
    /// Text of the parent extraction; `None` for the original document
    parent_text: Option<String>,
}
147
148impl PipelineExecutor {
149    /// Create a new pipeline executor
150    pub fn new(config: PipelineConfig) -> Self {
151        Self { config }
152    }
153
154    /// Load pipeline configuration from YAML file
155    pub fn from_yaml_file(path: &std::path::Path) -> LangExtractResult<Self> {
156        let content = std::fs::read_to_string(path)
157            .map_err(|e| LangExtractError::configuration(format!("Failed to read pipeline file: {}", e)))?;
158
159        let config: PipelineConfig = serde_yaml::from_str(&content)
160            .map_err(|e| LangExtractError::configuration(format!("Failed to parse pipeline YAML: {}", e)))?;
161
162        Ok(Self::new(config))
163    }
164
165    /// Execute the entire pipeline
166    pub async fn execute(&self, input_text: &str) -> LangExtractResult<PipelineResult> {
167        let start_time = std::time::Instant::now();
168
169        println!("🚀 Starting pipeline execution: {}", self.config.name);
170        println!("📝 Description: {}", self.config.description);
171        
172        if self.config.enable_parallel_execution {
173            println!("⚡ Parallel execution enabled - independent steps will run concurrently");
174        } else {
175            println!("🔄 Sequential execution - steps will run one after another");
176        }
177
178        if self.config.enable_parallel_execution {
179            self.execute_parallel(input_text, start_time).await
180        } else {
181            self.execute_sequential(input_text, start_time).await
182        }
183    }
184
185    /// Execute pipeline sequentially (original behavior)
186    async fn execute_sequential(&self, input_text: &str, start_time: std::time::Instant) -> LangExtractResult<PipelineResult> {
187        let mut step_results = Vec::new();
188        let mut context_data = HashMap::new();
189
190        // Execute steps in dependency order
191        let execution_order = self.resolve_execution_order()?;
192
193        for step_id in execution_order {
194            let step_result = self.execute_step(&step_id, input_text, &context_data).await?;
195            step_results.push(step_result.clone());
196
197            // Store results for dependent steps
198            if step_result.success {
199                context_data.insert(step_id, step_result.extractions.clone());
200            } else {
201                return Err(LangExtractError::configuration(format!(
202                    "Step '{}' failed: {}",
203                    step_id,
204                    step_result.error_message.unwrap_or("Unknown error".to_string())
205                )));
206            }
207        }
208
209        // Build nested output structure
210        let nested_output = self.build_nested_output(&step_results)?;
211
212        let total_time = start_time.elapsed().as_millis() as u64;
213
214        println!("✅ Pipeline execution completed in {}ms", total_time);
215
216        Ok(PipelineResult {
217            config: self.config.clone(),
218            step_results,
219            nested_output,
220            total_time_ms: total_time,
221            success: true,
222            error_message: None,
223        })
224    }
225
226    /// Execute pipeline with parallel execution of independent steps
227    async fn execute_parallel(&self, input_text: &str, start_time: std::time::Instant) -> LangExtractResult<PipelineResult> {
228        let mut all_step_results = Vec::new();
229        let mut context_data = HashMap::new();
230        
231        // Group steps by dependency level
232        let execution_waves = self.resolve_execution_waves()?;
233        
234        for (wave_index, wave_steps) in execution_waves.iter().enumerate() {
235            println!("🌊 Executing wave {} with {} steps", wave_index + 1, wave_steps.len());
236            
237            if wave_steps.len() == 1 {
238                // Single step - execute normally
239                let step_id = &wave_steps[0];
240                let step_result = self.execute_step(step_id, input_text, &context_data).await?;
241                
242                if step_result.success {
243                    context_data.insert(step_id.clone(), step_result.extractions.clone());
244                    all_step_results.push(step_result);
245                } else {
246                    return Err(LangExtractError::configuration(format!(
247                        "Step '{}' failed: {}",
248                        step_id,
249                        step_result.error_message.unwrap_or("Unknown error".to_string())
250                    )));
251                }
252            } else {
253                // Multiple independent steps - execute in parallel
254                println!("⚡ Running {} steps in parallel", wave_steps.len());
255                
256                let parallel_futures: Vec<_> = wave_steps.iter()
257                    .map(|step_id| self.execute_step(step_id, input_text, &context_data))
258                    .collect();
259                
260                let wave_results = join_all(parallel_futures).await;
261                
262                // Process results
263                for (i, result) in wave_results.into_iter().enumerate() {
264                    let step_result = result?;
265                    let step_id = &wave_steps[i];
266                    
267                    if step_result.success {
268                        context_data.insert(step_id.clone(), step_result.extractions.clone());
269                        all_step_results.push(step_result);
270                    } else {
271                        return Err(LangExtractError::configuration(format!(
272                            "Step '{}' failed: {}",
273                            step_id,
274                            step_result.error_message.unwrap_or("Unknown error".to_string())
275                        )));
276                    }
277                }
278            }
279        }
280
281        // Build nested output structure
282        let nested_output = self.build_nested_output(&all_step_results)?;
283
284        let total_time = start_time.elapsed().as_millis() as u64;
285
286        println!("✅ Pipeline execution completed in {}ms", total_time);
287
288        Ok(PipelineResult {
289            config: self.config.clone(),
290            step_results: all_step_results,
291            nested_output,
292            total_time_ms: total_time,
293            success: true,
294            error_message: None,
295        })
296    }
297
298    /// Resolve the execution order based on dependencies
299    fn resolve_execution_order(&self) -> LangExtractResult<Vec<String>> {
300        let mut order = Vec::new();
301        let mut visited = std::collections::HashSet::new();
302        let mut visiting = std::collections::HashSet::new();
303
304        for step in &self.config.steps {
305            self.resolve_step_dependencies(&step.id, &mut order, &mut visited, &mut visiting)?;
306        }
307
308        Ok(order)
309    }
310
311    /// Resolve execution waves for parallel processing
312    /// Groups steps by dependency level - steps in the same wave can run in parallel
313    fn resolve_execution_waves(&self) -> LangExtractResult<Vec<Vec<String>>> {
314        let mut waves = Vec::new();
315        let mut completed_steps = std::collections::HashSet::new();
316        let mut remaining_steps: std::collections::HashSet<String> = 
317            self.config.steps.iter().map(|s| s.id.clone()).collect();
318
319        while !remaining_steps.is_empty() {
320            let mut current_wave = Vec::new();
321            
322            // Find all steps whose dependencies are satisfied
323            for step in &self.config.steps {
324                if remaining_steps.contains(&step.id) {
325                    let dependencies_satisfied = step.depends_on.iter()
326                        .all(|dep| completed_steps.contains(dep));
327                    
328                    if dependencies_satisfied {
329                        current_wave.push(step.id.clone());
330                    }
331                }
332            }
333            
334            if current_wave.is_empty() {
335                // This shouldn't happen if there are no circular dependencies
336                return Err(LangExtractError::configuration(
337                    "Unable to resolve execution waves - possible circular dependency".to_string()
338                ));
339            }
340            
341            // Remove steps from remaining and add to completed
342            for step_id in &current_wave {
343                remaining_steps.remove(step_id);
344                completed_steps.insert(step_id.clone());
345            }
346            
347            waves.push(current_wave);
348        }
349
350        Ok(waves)
351    }
352
353    /// Recursive function to resolve dependencies
354    fn resolve_step_dependencies(
355        &self,
356        step_id: &str,
357        order: &mut Vec<String>,
358        visited: &mut std::collections::HashSet<String>,
359        visiting: &mut std::collections::HashSet<String>,
360    ) -> LangExtractResult<()> {
361        if visited.contains(step_id) {
362            return Ok(());
363        }
364
365        if visiting.contains(step_id) {
366            return Err(LangExtractError::configuration(format!(
367                "Circular dependency detected involving step: {}", step_id
368            )));
369        }
370
371        visiting.insert(step_id.to_string());
372
373        // Find the step and process its dependencies
374        if let Some(step) = self.config.steps.iter().find(|s| s.id == step_id) {
375            for dep in &step.depends_on {
376                self.resolve_step_dependencies(dep, order, visited, visiting)?;
377            }
378        }
379
380        visiting.remove(step_id);
381        visited.insert(step_id.to_string());
382        order.push(step_id.to_string());
383
384        Ok(())
385    }
386
387    /// Execute a single pipeline step
388    async fn execute_step(
389        &self,
390        step_id: &str,
391        input_text: &str,
392        context_data: &HashMap<String, Vec<Extraction>>,
393    ) -> LangExtractResult<StepResult> {
394        let step = self.config.steps.iter().find(|s| s.id == step_id)
395            .ok_or_else(|| LangExtractError::configuration(format!("Step '{}' not found", step_id)))?;
396
397        let step_start = std::time::Instant::now();
398
399        println!("🔄 Executing step: {} ({})", step.name, step.id);
400
401        // Determine input text for this step with mapping context
402        let step_input = self.prepare_step_input(step, input_text, context_data)?;
403        let input_count = step_input.len();
404
405        println!("📥 Processing {} input items", input_count);
406
407        let mut all_extractions = Vec::new();
408
409        // Process each input item
410        for (i, input_item) in step_input.iter().enumerate() {
411            println!("  📄 Processing item {}/{}", i + 1, input_count);
412
413            // Create extraction config for this step
414            let step_config = self.config.global_config.clone();
415            // Use step-specific examples if provided, otherwise use global
416            let examples = if step.examples.is_empty() {
417                vec![] // Will need to be provided externally
418            } else {
419                step.examples.clone()
420            };
421
422            match extract(
423                &input_item.text,
424                Some(&step.prompt),
425                &examples,
426                step_config,
427            ).await {
428                Ok(result) => {
429                    if let Some(extractions) = result.extractions {
430                        for mut ex in extractions {
431                            // For dependent steps, transform local intervals to absolute using parent start
432                            if !step.depends_on.is_empty() {
433                                if let Some(parent_start) = input_item.parent_start {
434                                    let mut abs_interval: Option<CharInterval> = None;
435
436                                    // If model returned local positions relative to subtext, map them
437                                    if let Some(ci) = &ex.char_interval {
438                                        if let (Some(ls), Some(le)) = (ci.start_pos, ci.end_pos) {
439                                            abs_interval = Some(CharInterval::new(Some(parent_start + ls), Some(parent_start + le)));
440                                        }
441                                    }
442
443                                    // Fallback: exact substring match within subtext
444                                    if abs_interval.is_none() {
445                                        if let Some(found) = input_item.text.find(&ex.extraction_text) {
446                                            let start = parent_start + found;
447                                            let end = start + ex.extraction_text.len();
448                                            abs_interval = Some(CharInterval::new(Some(start), Some(end)));
449                                        }
450                                    }
451
452                                    if let Some(ai) = abs_interval {
453                                        ex.char_interval = Some(ai);
454                                    }
455
456                                    // Annotate with parent metadata for downstream linkage
457                                    if let Some(parent_step_id) = &input_item.parent_step_id {
458                                        let mut attrs = ex.attributes.take().unwrap_or_default();
459                                        attrs.insert(
460                                            "parent_step_id".to_string(),
461                                            serde_json::Value::String(parent_step_id.clone()),
462                                        );
463                                        if let Some(ps) = input_item.parent_start {
464                                            attrs.insert(
465                                                "parent_start".to_string(),
466                                                serde_json::Value::Number(serde_json::Number::from(ps as u64)),
467                                            );
468                                        }
469                                        if let Some(pe) = input_item.parent_end {
470                                            attrs.insert(
471                                                "parent_end".to_string(),
472                                                serde_json::Value::Number(serde_json::Number::from(pe as u64)),
473                                            );
474                                        }
475                                        if let Some(pc) = &input_item.parent_class {
476                                            attrs.insert(
477                                                "parent_class".to_string(),
478                                                serde_json::Value::String(pc.clone()),
479                                            );
480                                        }
481                                        if let Some(pt) = &input_item.parent_text {
482                                            attrs.insert(
483                                                "parent_text".to_string(),
484                                                serde_json::Value::String(pt.clone()),
485                                            );
486                                        }
487                                        ex.attributes = Some(attrs);
488                                    }
489                                }
490                            }
491                            all_extractions.push(ex);
492                        }
493                    }
494                }
495                Err(e) => {
496                    println!("  ❌ Step '{}' failed on item {}/{}: {}", step.id, i + 1, input_count, e);
497                    return Ok(StepResult {
498                        step_id: step.id.clone(),
499                        step_name: step.name.clone(),
500                        extractions: Vec::new(),
501                        processing_time_ms: step_start.elapsed().as_millis() as u64,
502                        input_count,
503                        success: false,
504                        error_message: Some(e.to_string()),
505                    });
506                }
507            }
508        }
509
510        let processing_time = step_start.elapsed().as_millis() as u64;
511
512        println!("  ✅ Step '{}' completed: {} extractions in {}ms",
513                step.name, all_extractions.len(), processing_time);
514
515        Ok(StepResult {
516            step_id: step.id.clone(),
517            step_name: step.name.clone(),
518            extractions: all_extractions,
519            processing_time_ms: processing_time,
520            input_count,
521            success: true,
522            error_message: None,
523        })
524    }
525
526    /// Prepare input text for a step based on its configuration
527    fn prepare_step_input(
528        &self,
529        step: &PipelineStep,
530        original_text: &str,
531        context_data: &HashMap<String, Vec<Extraction>>,
532    ) -> LangExtractResult<Vec<StepInputItem>> {
533        // If step has dependencies, use extractions from dependent steps
534        if !step.depends_on.is_empty() {
535            let mut inputs: Vec<StepInputItem> = Vec::new();
536
537            for dep_id in &step.depends_on {
538                if let Some(extractions) = context_data.get(dep_id) {
539                    // Apply filter if specified
540                    let filtered_extractions = self.apply_filter(extractions, &step.filter);
541
542                    for extraction in filtered_extractions {
543                        let parent_start = extraction.char_interval.as_ref().and_then(|ci| ci.start_pos);
544                        let parent_end = extraction.char_interval.as_ref().and_then(|ci| ci.end_pos);
545                        inputs.push(StepInputItem {
546                            text: extraction.extraction_text.clone(),
547                            parent_start,
548                            parent_end,
549                            parent_step_id: Some(dep_id.clone()),
550                            parent_class: Some(extraction.extraction_class.clone()),
551                            parent_text: Some(extraction.extraction_text.clone()),
552                        });
553                    }
554                }
555            }
556
557            Ok(inputs)
558        } else {
559            // First step - use original text
560            Ok(vec![StepInputItem {
561                text: original_text.to_string(),
562                parent_start: Some(0),
563                parent_end: Some(original_text.len()),
564                parent_step_id: None,
565                parent_class: None,
566                parent_text: None,
567            }])
568        }
569    }
570
571    /// Apply filter to extractions
572    fn apply_filter<'a>(
573        &self,
574        extractions: &'a [Extraction],
575        filter: &Option<PipelineFilter>,
576    ) -> Vec<&'a Extraction> {
577        if let Some(f) = filter {
578            extractions.iter()
579                .filter(|e| {
580                    // Check class filter
581                    if let Some(class) = &f.class_filter {
582                        if e.extraction_class != *class {
583                            return false;
584                        }
585                    }
586
587                    // Check text pattern filter
588                    if let Some(pattern) = &f.text_pattern {
589                        if let Ok(regex) = regex::Regex::new(pattern) {
590                            if !regex.is_match(&e.extraction_text) {
591                                return false;
592                            }
593                        }
594                    }
595
596                    true
597                })
598                .take(f.max_items.unwrap_or(usize::MAX))
599                .collect()
600        } else {
601            extractions.iter().collect()
602        }
603    }
604
605    /// Build the final nested output structure
606    fn build_nested_output(&self, step_results: &[StepResult]) -> LangExtractResult<serde_json::Value> {
607        let mut output = serde_json::Map::new();
608
609        // Group results by step
610        for result in step_results {
611            if result.success {
612                let mut step_output = serde_json::Map::new();
613
614                // Convert extractions to JSON
615                let extractions_json: Vec<serde_json::Value> = result.extractions.iter()
616                    .map(|e| {
617                        let mut obj = serde_json::Map::new();
618                        obj.insert("class".to_string(), serde_json::Value::String(e.extraction_class.clone()));
619                        obj.insert("text".to_string(), serde_json::Value::String(e.extraction_text.clone()));
620                        if let Some(interval) = &e.char_interval {
621                            obj.insert("start".to_string(), serde_json::json!(interval.start_pos));
622                            obj.insert("end".to_string(), serde_json::json!(interval.end_pos));
623                        }
624                        serde_json::Value::Object(obj)
625                    })
626                    .collect();
627
628                step_output.insert("extractions".to_string(), serde_json::Value::Array(extractions_json));
629                step_output.insert("count".to_string(), serde_json::json!(result.extractions.len()));
630                step_output.insert("processing_time_ms".to_string(), serde_json::json!(result.processing_time_ms));
631
632                output.insert(result.step_id.clone(), serde_json::Value::Object(step_output));
633            }
634        }
635
636        Ok(serde_json::Value::Object(output))
637    }
638}
639
640/// Utility functions for pipeline management
641pub mod utils {
642    use super::*;
643
644    /// Create a sample pipeline configuration for requirements extraction
645    pub fn create_requirements_pipeline() -> PipelineConfig {
646        PipelineConfig {
647            name: "Requirements Extraction Pipeline".to_string(),
648            description: "Extract requirements and sub-divide into values, units, and specifications".to_string(),
649            version: "1.0.0".to_string(),
650            enable_parallel_execution: false,
651            global_config: ExtractConfig {
652                model_id: "gemini-2.5-flash".to_string(),
653                api_key: None,
654                format_type: crate::data::FormatType::Json,
655                max_char_buffer: 8000,
656                temperature: 0.3,
657                fence_output: None,
658                use_schema_constraints: true,
659                batch_length: 4,
660                max_workers: 6,
661                additional_context: None,
662                resolver_params: std::collections::HashMap::new(),
663                language_model_params: std::collections::HashMap::new(),
664                debug: false,
665                model_url: None,
666                extraction_passes: 1,
667                enable_multipass: false,
668                multipass_min_extractions: 1,
669                multipass_quality_threshold: 0.3,
670                progress_handler: None,
671            },
672            steps: vec![
673                PipelineStep {
674                    id: "extract_requirements".to_string(),
675                    name: "Extract Requirements".to_string(),
676                    description: "Extract all 'shall' statements and requirements from the document".to_string(),
677                    examples: vec![
678                        ExampleData::new(
679                            "The system shall process 100 transactions per second and maintain 99.9% uptime.".to_string(),
680                            vec![
681                                Extraction::new("requirement".to_string(),
682                                    "The system shall process 100 transactions per second and maintain 99.9% uptime.".to_string()),
683                            ],
684                        )
685                    ],
686                    prompt: "Extract all requirements, 'shall' statements, and specifications from the text. Include the complete statement.".to_string(),
687                    output_field: "requirements".to_string(),
688                    filter: None,
689                    depends_on: vec![],
690                },
691                PipelineStep {
692                    id: "extract_values".to_string(),
693                    name: "Extract Values".to_string(),
694                    description: "Extract numeric values, units, and specifications from requirements".to_string(),
695                    examples: vec![
696                        ExampleData::new(
697                            "The system shall process 100 transactions per second and maintain 99.9% uptime.".to_string(),
698                            vec![
699                                Extraction::new("value".to_string(), "100".to_string()),
700                                Extraction::new("unit".to_string(), "transactions per second".to_string()),
701                                Extraction::new("value".to_string(), "99.9".to_string()),
702                                Extraction::new("unit".to_string(), "%".to_string()),
703                            ],
704                        )
705                    ],
706                    prompt: "From this requirement, extract all numeric values and their associated units or specifications.".to_string(),
707                    output_field: "values".to_string(),
708                    filter: Some(PipelineFilter {
709                        class_filter: Some("requirement".to_string()),
710                        text_pattern: None,
711                        max_items: None,
712                    }),
713                    depends_on: vec!["extract_requirements".to_string()],
714                },
715                PipelineStep {
716                    id: "extract_specifications".to_string(),
717                    name: "Extract Specifications".to_string(),
718                    description: "Extract detailed specifications and constraints from requirements".to_string(),
719                    examples: vec![
720                        ExampleData::new(
721                            "The system shall process 100 transactions per second and maintain 99.9% uptime.".to_string(),
722                            vec![
723                                Extraction::new("specification".to_string(), "process 100 transactions per second".to_string()),
724                                Extraction::new("constraint".to_string(), "maintain 99.9% uptime".to_string()),
725                            ],
726                        )
727                    ],
728                    prompt: "Extract detailed specifications, constraints, and performance requirements from this text.".to_string(),
729                    output_field: "specifications".to_string(),
730                    filter: Some(PipelineFilter {
731                        class_filter: Some("requirement".to_string()),
732                        text_pattern: None,
733                        max_items: None,
734                    }),
735                    depends_on: vec!["extract_requirements".to_string()],
736                },
737            ],
738        }
739    }
740
741    /// Save pipeline configuration to YAML file
742    pub fn save_pipeline_to_file(config: &PipelineConfig, path: &std::path::Path) -> LangExtractResult<()> {
743        let yaml_content = serde_yaml::to_string(config)
744            .map_err(|e| LangExtractError::configuration(format!("Failed to serialize pipeline: {}", e)))?;
745
746        std::fs::write(path, yaml_content)
747            .map_err(|e| LangExtractError::configuration(format!("Failed to write pipeline file: {}", e)))?;
748
749        Ok(())
750    }
751
752    /// Load pipeline configuration from YAML file
753    pub fn load_pipeline_from_file(path: &std::path::Path) -> LangExtractResult<PipelineConfig> {
754        let content = std::fs::read_to_string(path)
755            .map_err(|e| LangExtractError::configuration(format!("Failed to read pipeline file: {}", e)))?;
756
757        serde_yaml::from_str(&content)
758            .map_err(|e| LangExtractError::configuration(format!("Failed to parse pipeline YAML: {}", e)))
759    }
760}
761
762#[cfg(test)]
763mod tests {
764    use super::*;
765
766    #[test]
767    fn test_pipeline_config_serialization() {
768        let config = utils::create_requirements_pipeline();
769        let yaml = serde_yaml::to_string(&config).unwrap();
770        let deserialized: PipelineConfig = serde_yaml::from_str(&yaml).unwrap();
771
772        assert_eq!(config.name, deserialized.name);
773        assert_eq!(config.steps.len(), deserialized.steps.len());
774    }
775
776    #[test]
777    fn test_dependency_resolution() {
778        let config = utils::create_requirements_pipeline();
779        let executor = PipelineExecutor::new(config);
780
781        let order = executor.resolve_execution_order().unwrap();
782
783        // Should start with step that has no dependencies
784        assert_eq!(order[0], "extract_requirements");
785        // Should include all steps
786        assert_eq!(order.len(), 3);
787    }
788
789    #[test]
790    fn test_filter_application() {
791        let executor = PipelineExecutor::new(utils::create_requirements_pipeline());
792
793        let extractions = vec![
794            Extraction::new("requirement".to_string(), "Test requirement".to_string()),
795            Extraction::new("other".to_string(), "Other text".to_string()),
796        ];
797
798        let filter = PipelineFilter {
799            class_filter: Some("requirement".to_string()),
800            text_pattern: None,
801            max_items: None,
802        };
803
804        let filtered = executor.apply_filter(&extractions, &Some(filter));
805        assert_eq!(filtered.len(), 1);
806        assert_eq!(filtered[0].extraction_class, "requirement");
807    }
808}