1use crate::{
7 data::{ExampleData, Extraction, CharInterval},
8 exceptions::{LangExtractError, LangExtractResult},
9 extract, ExtractConfig,
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use futures::future::join_all;
14
/// A single extraction step within a multi-step pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineStep {
    /// Unique identifier; referenced by other steps via `depends_on`.
    pub id: String,

    /// Human-readable step name (used in progress logging).
    pub name: String,

    /// Free-form description of what the step extracts.
    pub description: String,

    /// Few-shot examples passed to the extraction call for this step.
    pub examples: Vec<ExampleData>,

    /// Prompt sent to the language model for this step.
    pub prompt: String,

    /// Intended output grouping key — NOTE(review): not referenced anywhere
    /// in this file (`build_nested_output` keys by step id); confirm usage.
    pub output_field: String,

    /// Optional filter applied to parent-step extractions before they are
    /// fed into this step (see `PipelineExecutor::apply_filter`).
    pub filter: Option<PipelineFilter>,

    /// Ids of steps whose outputs this step consumes; empty for root steps
    /// that run directly on the original document.
    pub depends_on: Vec<String>,
}
42
/// Criteria for selecting which parent extractions feed a dependent step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineFilter {
    /// Keep only extractions whose class equals this value exactly.
    pub class_filter: Option<String>,

    /// Keep only extractions whose text matches this regex pattern.
    pub text_pattern: Option<String>,

    /// Upper bound on the number of extractions passed through.
    pub max_items: Option<usize>,
}
55
/// Declarative description of a whole extraction pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineConfig {
    /// Pipeline display name.
    pub name: String,

    /// What the pipeline does, for humans.
    pub description: String,

    /// Version string (informational only in this file).
    pub version: String,

    /// The steps; actual execution order is derived from each step's
    /// `depends_on`, not from the order of this list.
    pub steps: Vec<PipelineStep>,

    /// Extraction settings shared by every step.
    pub global_config: ExtractConfig,

    /// When true, steps within the same dependency wave run concurrently.
    /// Defaults to false when absent from a serialized config.
    #[serde(default)]
    pub enable_parallel_execution: bool,
}
78
/// Outcome of executing one pipeline step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepResult {
    /// Id of the step that produced this result.
    pub step_id: String,

    /// Display name of the step.
    pub step_name: String,

    /// All extractions gathered across the step's input items
    /// (empty when the step failed).
    pub extractions: Vec<Extraction>,

    /// Wall-clock duration of the step in milliseconds.
    pub processing_time_ms: u64,

    /// Number of input items the step processed.
    pub input_count: usize,

    /// False when any input item's extraction call failed.
    pub success: bool,

    /// Failure details when `success` is false.
    pub error_message: Option<String>,
}
103
/// Aggregate outcome of a full pipeline run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineResult {
    /// The configuration that was executed (cloned into the result).
    pub config: PipelineConfig,

    /// Per-step results in completion order.
    pub step_results: Vec<StepResult>,

    /// JSON object keyed by step id; see `build_nested_output`.
    pub nested_output: serde_json::Value,

    /// End-to-end wall-clock time in milliseconds.
    pub total_time_ms: u64,

    /// True when every step succeeded (execution aborts on first failure).
    pub success: bool,

    /// Failure details when `success` is false.
    pub error_message: Option<String>,
}
125
/// Executes a `PipelineConfig` against an input document.
pub struct PipelineExecutor {
    /// The pipeline definition driving execution.
    config: PipelineConfig,
}
130
/// One unit of text fed into a step, with provenance back to the parent
/// extraction it came from (when the step has dependencies).
#[derive(Debug, Clone)]
struct StepInputItem {
    /// Text the step runs extraction on.
    text: String,
    /// Start offset of the parent extraction — used as an absolute offset
    /// into the original document when remapping child intervals in
    /// `execute_step`; TODO confirm parent intervals are document-absolute.
    parent_start: Option<usize>,
    /// End offset of the parent extraction.
    parent_end: Option<usize>,
    /// Id of the dependency step that produced the parent extraction.
    parent_step_id: Option<String>,
    /// Class of the parent extraction.
    parent_class: Option<String>,
    /// Full text of the parent extraction.
    parent_text: Option<String>,
}
147
148impl PipelineExecutor {
149 pub fn new(config: PipelineConfig) -> Self {
151 Self { config }
152 }
153
154 pub fn from_yaml_file(path: &std::path::Path) -> LangExtractResult<Self> {
156 let content = std::fs::read_to_string(path)
157 .map_err(|e| LangExtractError::configuration(format!("Failed to read pipeline file: {}", e)))?;
158
159 let config: PipelineConfig = serde_yaml::from_str(&content)
160 .map_err(|e| LangExtractError::configuration(format!("Failed to parse pipeline YAML: {}", e)))?;
161
162 Ok(Self::new(config))
163 }
164
165 pub async fn execute(&self, input_text: &str) -> LangExtractResult<PipelineResult> {
167 let start_time = std::time::Instant::now();
168
169 println!("🚀 Starting pipeline execution: {}", self.config.name);
170 println!("📝 Description: {}", self.config.description);
171
172 if self.config.enable_parallel_execution {
173 println!("⚡ Parallel execution enabled - independent steps will run concurrently");
174 } else {
175 println!("🔄 Sequential execution - steps will run one after another");
176 }
177
178 if self.config.enable_parallel_execution {
179 self.execute_parallel(input_text, start_time).await
180 } else {
181 self.execute_sequential(input_text, start_time).await
182 }
183 }
184
185 async fn execute_sequential(&self, input_text: &str, start_time: std::time::Instant) -> LangExtractResult<PipelineResult> {
187 let mut step_results = Vec::new();
188 let mut context_data = HashMap::new();
189
190 let execution_order = self.resolve_execution_order()?;
192
193 for step_id in execution_order {
194 let step_result = self.execute_step(&step_id, input_text, &context_data).await?;
195 step_results.push(step_result.clone());
196
197 if step_result.success {
199 context_data.insert(step_id, step_result.extractions.clone());
200 } else {
201 return Err(LangExtractError::configuration(format!(
202 "Step '{}' failed: {}",
203 step_id,
204 step_result.error_message.unwrap_or("Unknown error".to_string())
205 )));
206 }
207 }
208
209 let nested_output = self.build_nested_output(&step_results)?;
211
212 let total_time = start_time.elapsed().as_millis() as u64;
213
214 println!("✅ Pipeline execution completed in {}ms", total_time);
215
216 Ok(PipelineResult {
217 config: self.config.clone(),
218 step_results,
219 nested_output,
220 total_time_ms: total_time,
221 success: true,
222 error_message: None,
223 })
224 }
225
226 async fn execute_parallel(&self, input_text: &str, start_time: std::time::Instant) -> LangExtractResult<PipelineResult> {
228 let mut all_step_results = Vec::new();
229 let mut context_data = HashMap::new();
230
231 let execution_waves = self.resolve_execution_waves()?;
233
234 for (wave_index, wave_steps) in execution_waves.iter().enumerate() {
235 println!("🌊 Executing wave {} with {} steps", wave_index + 1, wave_steps.len());
236
237 if wave_steps.len() == 1 {
238 let step_id = &wave_steps[0];
240 let step_result = self.execute_step(step_id, input_text, &context_data).await?;
241
242 if step_result.success {
243 context_data.insert(step_id.clone(), step_result.extractions.clone());
244 all_step_results.push(step_result);
245 } else {
246 return Err(LangExtractError::configuration(format!(
247 "Step '{}' failed: {}",
248 step_id,
249 step_result.error_message.unwrap_or("Unknown error".to_string())
250 )));
251 }
252 } else {
253 println!("⚡ Running {} steps in parallel", wave_steps.len());
255
256 let parallel_futures: Vec<_> = wave_steps.iter()
257 .map(|step_id| self.execute_step(step_id, input_text, &context_data))
258 .collect();
259
260 let wave_results = join_all(parallel_futures).await;
261
262 for (i, result) in wave_results.into_iter().enumerate() {
264 let step_result = result?;
265 let step_id = &wave_steps[i];
266
267 if step_result.success {
268 context_data.insert(step_id.clone(), step_result.extractions.clone());
269 all_step_results.push(step_result);
270 } else {
271 return Err(LangExtractError::configuration(format!(
272 "Step '{}' failed: {}",
273 step_id,
274 step_result.error_message.unwrap_or("Unknown error".to_string())
275 )));
276 }
277 }
278 }
279 }
280
281 let nested_output = self.build_nested_output(&all_step_results)?;
283
284 let total_time = start_time.elapsed().as_millis() as u64;
285
286 println!("✅ Pipeline execution completed in {}ms", total_time);
287
288 Ok(PipelineResult {
289 config: self.config.clone(),
290 step_results: all_step_results,
291 nested_output,
292 total_time_ms: total_time,
293 success: true,
294 error_message: None,
295 })
296 }
297
298 fn resolve_execution_order(&self) -> LangExtractResult<Vec<String>> {
300 let mut order = Vec::new();
301 let mut visited = std::collections::HashSet::new();
302 let mut visiting = std::collections::HashSet::new();
303
304 for step in &self.config.steps {
305 self.resolve_step_dependencies(&step.id, &mut order, &mut visited, &mut visiting)?;
306 }
307
308 Ok(order)
309 }
310
311 fn resolve_execution_waves(&self) -> LangExtractResult<Vec<Vec<String>>> {
314 let mut waves = Vec::new();
315 let mut completed_steps = std::collections::HashSet::new();
316 let mut remaining_steps: std::collections::HashSet<String> =
317 self.config.steps.iter().map(|s| s.id.clone()).collect();
318
319 while !remaining_steps.is_empty() {
320 let mut current_wave = Vec::new();
321
322 for step in &self.config.steps {
324 if remaining_steps.contains(&step.id) {
325 let dependencies_satisfied = step.depends_on.iter()
326 .all(|dep| completed_steps.contains(dep));
327
328 if dependencies_satisfied {
329 current_wave.push(step.id.clone());
330 }
331 }
332 }
333
334 if current_wave.is_empty() {
335 return Err(LangExtractError::configuration(
337 "Unable to resolve execution waves - possible circular dependency".to_string()
338 ));
339 }
340
341 for step_id in ¤t_wave {
343 remaining_steps.remove(step_id);
344 completed_steps.insert(step_id.clone());
345 }
346
347 waves.push(current_wave);
348 }
349
350 Ok(waves)
351 }
352
353 fn resolve_step_dependencies(
355 &self,
356 step_id: &str,
357 order: &mut Vec<String>,
358 visited: &mut std::collections::HashSet<String>,
359 visiting: &mut std::collections::HashSet<String>,
360 ) -> LangExtractResult<()> {
361 if visited.contains(step_id) {
362 return Ok(());
363 }
364
365 if visiting.contains(step_id) {
366 return Err(LangExtractError::configuration(format!(
367 "Circular dependency detected involving step: {}", step_id
368 )));
369 }
370
371 visiting.insert(step_id.to_string());
372
373 if let Some(step) = self.config.steps.iter().find(|s| s.id == step_id) {
375 for dep in &step.depends_on {
376 self.resolve_step_dependencies(dep, order, visited, visiting)?;
377 }
378 }
379
380 visiting.remove(step_id);
381 visited.insert(step_id.to_string());
382 order.push(step_id.to_string());
383
384 Ok(())
385 }
386
    /// Runs a single step: prepares its inputs from `context_data`, calls
    /// `extract` once per input item, and for dependent steps remaps each
    /// child extraction's interval into parent-relative absolute coordinates
    /// and attaches provenance attributes.
    ///
    /// On an extraction failure this returns `Ok` with a failed `StepResult`
    /// (not `Err`), leaving the caller to decide how to surface it.
    async fn execute_step(
        &self,
        step_id: &str,
        input_text: &str,
        context_data: &HashMap<String, Vec<Extraction>>,
    ) -> LangExtractResult<StepResult> {
        // Unknown ids are a configuration error.
        let step = self.config.steps.iter().find(|s| s.id == step_id)
            .ok_or_else(|| LangExtractError::configuration(format!("Step '{}' not found", step_id)))?;

        let step_start = std::time::Instant::now();

        println!("🔄 Executing step: {} ({})", step.name, step.id);

        // Root steps get the whole document as one item; dependent steps get
        // one item per filtered parent extraction.
        let step_input = self.prepare_step_input(step, input_text, context_data)?;
        let input_count = step_input.len();

        println!("📥 Processing {} input items", input_count);

        let mut all_extractions = Vec::new();

        for (i, input_item) in step_input.iter().enumerate() {
            println!(" 📄 Processing item {}/{}", i + 1, input_count);

            // Every step currently shares the global extraction config.
            let step_config = self.config.global_config.clone();
            let examples = if step.examples.is_empty() {
                vec![] } else {
                step.examples.clone()
            };

            match extract(
                &input_item.text,
                Some(&step.prompt),
                &examples,
                step_config,
            ).await {
                Ok(result) => {
                    if let Some(extractions) = result.extractions {
                        for mut ex in extractions {
                            // Only dependent steps need interval remapping and
                            // provenance attributes.
                            if !step.depends_on.is_empty() {
                                if let Some(parent_start) = input_item.parent_start {
                                    let mut abs_interval: Option<CharInterval> = None;

                                    // Prefer the interval reported by the model,
                                    // shifted by the parent's start offset.
                                    if let Some(ci) = &ex.char_interval {
                                        if let (Some(ls), Some(le)) = (ci.start_pos, ci.end_pos) {
                                            abs_interval = Some(CharInterval::new(Some(parent_start + ls), Some(parent_start + le)));
                                        }
                                    }

                                    // Fallback: locate the extraction text inside the
                                    // parent text (first occurrence only — repeated
                                    // substrings map to the earliest match).
                                    if abs_interval.is_none() {
                                        if let Some(found) = input_item.text.find(&ex.extraction_text) {
                                            let start = parent_start + found;
                                            let end = start + ex.extraction_text.len();
                                            abs_interval = Some(CharInterval::new(Some(start), Some(end)));
                                        }
                                    }

                                    if let Some(ai) = abs_interval {
                                        ex.char_interval = Some(ai);
                                    }

                                    // Attach provenance so consumers can trace each
                                    // child extraction back to its parent.
                                    if let Some(parent_step_id) = &input_item.parent_step_id {
                                        let mut attrs = ex.attributes.take().unwrap_or_default();
                                        attrs.insert(
                                            "parent_step_id".to_string(),
                                            serde_json::Value::String(parent_step_id.clone()),
                                        );
                                        if let Some(ps) = input_item.parent_start {
                                            attrs.insert(
                                                "parent_start".to_string(),
                                                serde_json::Value::Number(serde_json::Number::from(ps as u64)),
                                            );
                                        }
                                        if let Some(pe) = input_item.parent_end {
                                            attrs.insert(
                                                "parent_end".to_string(),
                                                serde_json::Value::Number(serde_json::Number::from(pe as u64)),
                                            );
                                        }
                                        if let Some(pc) = &input_item.parent_class {
                                            attrs.insert(
                                                "parent_class".to_string(),
                                                serde_json::Value::String(pc.clone()),
                                            );
                                        }
                                        if let Some(pt) = &input_item.parent_text {
                                            attrs.insert(
                                                "parent_text".to_string(),
                                                serde_json::Value::String(pt.clone()),
                                            );
                                        }
                                        ex.attributes = Some(attrs);
                                    }
                                }
                            }
                            all_extractions.push(ex);
                        }
                    }
                }
                Err(e) => {
                    println!(" ❌ Step '{}' failed on item {}/{}: {}", step.id, i + 1, input_count, e);
                    // A failing item aborts the whole step; extractions already
                    // gathered from earlier items are discarded.
                    return Ok(StepResult {
                        step_id: step.id.clone(),
                        step_name: step.name.clone(),
                        extractions: Vec::new(),
                        processing_time_ms: step_start.elapsed().as_millis() as u64,
                        input_count,
                        success: false,
                        error_message: Some(e.to_string()),
                    });
                }
            }
        }

        let processing_time = step_start.elapsed().as_millis() as u64;

        println!(" ✅ Step '{}' completed: {} extractions in {}ms",
            step.name, all_extractions.len(), processing_time);

        Ok(StepResult {
            step_id: step.id.clone(),
            step_name: step.name.clone(),
            extractions: all_extractions,
            processing_time_ms: processing_time,
            input_count,
            success: true,
            error_message: None,
        })
    }
525
526 fn prepare_step_input(
528 &self,
529 step: &PipelineStep,
530 original_text: &str,
531 context_data: &HashMap<String, Vec<Extraction>>,
532 ) -> LangExtractResult<Vec<StepInputItem>> {
533 if !step.depends_on.is_empty() {
535 let mut inputs: Vec<StepInputItem> = Vec::new();
536
537 for dep_id in &step.depends_on {
538 if let Some(extractions) = context_data.get(dep_id) {
539 let filtered_extractions = self.apply_filter(extractions, &step.filter);
541
542 for extraction in filtered_extractions {
543 let parent_start = extraction.char_interval.as_ref().and_then(|ci| ci.start_pos);
544 let parent_end = extraction.char_interval.as_ref().and_then(|ci| ci.end_pos);
545 inputs.push(StepInputItem {
546 text: extraction.extraction_text.clone(),
547 parent_start,
548 parent_end,
549 parent_step_id: Some(dep_id.clone()),
550 parent_class: Some(extraction.extraction_class.clone()),
551 parent_text: Some(extraction.extraction_text.clone()),
552 });
553 }
554 }
555 }
556
557 Ok(inputs)
558 } else {
559 Ok(vec![StepInputItem {
561 text: original_text.to_string(),
562 parent_start: Some(0),
563 parent_end: Some(original_text.len()),
564 parent_step_id: None,
565 parent_class: None,
566 parent_text: None,
567 }])
568 }
569 }
570
571 fn apply_filter<'a>(
573 &self,
574 extractions: &'a [Extraction],
575 filter: &Option<PipelineFilter>,
576 ) -> Vec<&'a Extraction> {
577 if let Some(f) = filter {
578 extractions.iter()
579 .filter(|e| {
580 if let Some(class) = &f.class_filter {
582 if e.extraction_class != *class {
583 return false;
584 }
585 }
586
587 if let Some(pattern) = &f.text_pattern {
589 if let Ok(regex) = regex::Regex::new(pattern) {
590 if !regex.is_match(&e.extraction_text) {
591 return false;
592 }
593 }
594 }
595
596 true
597 })
598 .take(f.max_items.unwrap_or(usize::MAX))
599 .collect()
600 } else {
601 extractions.iter().collect()
602 }
603 }
604
605 fn build_nested_output(&self, step_results: &[StepResult]) -> LangExtractResult<serde_json::Value> {
607 let mut output = serde_json::Map::new();
608
609 for result in step_results {
611 if result.success {
612 let mut step_output = serde_json::Map::new();
613
614 let extractions_json: Vec<serde_json::Value> = result.extractions.iter()
616 .map(|e| {
617 let mut obj = serde_json::Map::new();
618 obj.insert("class".to_string(), serde_json::Value::String(e.extraction_class.clone()));
619 obj.insert("text".to_string(), serde_json::Value::String(e.extraction_text.clone()));
620 if let Some(interval) = &e.char_interval {
621 obj.insert("start".to_string(), serde_json::json!(interval.start_pos));
622 obj.insert("end".to_string(), serde_json::json!(interval.end_pos));
623 }
624 serde_json::Value::Object(obj)
625 })
626 .collect();
627
628 step_output.insert("extractions".to_string(), serde_json::Value::Array(extractions_json));
629 step_output.insert("count".to_string(), serde_json::json!(result.extractions.len()));
630 step_output.insert("processing_time_ms".to_string(), serde_json::json!(result.processing_time_ms));
631
632 output.insert(result.step_id.clone(), serde_json::Value::Object(step_output));
633 }
634 }
635
636 Ok(serde_json::Value::Object(output))
637 }
638}
639
640pub mod utils {
642 use super::*;
643
644 pub fn create_requirements_pipeline() -> PipelineConfig {
646 PipelineConfig {
647 name: "Requirements Extraction Pipeline".to_string(),
648 description: "Extract requirements and sub-divide into values, units, and specifications".to_string(),
649 version: "1.0.0".to_string(),
650 enable_parallel_execution: false,
651 global_config: ExtractConfig {
652 model_id: "gemini-2.5-flash".to_string(),
653 api_key: None,
654 format_type: crate::data::FormatType::Json,
655 max_char_buffer: 8000,
656 temperature: 0.3,
657 fence_output: None,
658 use_schema_constraints: true,
659 batch_length: 4,
660 max_workers: 6,
661 additional_context: None,
662 resolver_params: std::collections::HashMap::new(),
663 language_model_params: std::collections::HashMap::new(),
664 debug: false,
665 model_url: None,
666 extraction_passes: 1,
667 enable_multipass: false,
668 multipass_min_extractions: 1,
669 multipass_quality_threshold: 0.3,
670 progress_handler: None,
671 },
672 steps: vec![
673 PipelineStep {
674 id: "extract_requirements".to_string(),
675 name: "Extract Requirements".to_string(),
676 description: "Extract all 'shall' statements and requirements from the document".to_string(),
677 examples: vec![
678 ExampleData::new(
679 "The system shall process 100 transactions per second and maintain 99.9% uptime.".to_string(),
680 vec![
681 Extraction::new("requirement".to_string(),
682 "The system shall process 100 transactions per second and maintain 99.9% uptime.".to_string()),
683 ],
684 )
685 ],
686 prompt: "Extract all requirements, 'shall' statements, and specifications from the text. Include the complete statement.".to_string(),
687 output_field: "requirements".to_string(),
688 filter: None,
689 depends_on: vec![],
690 },
691 PipelineStep {
692 id: "extract_values".to_string(),
693 name: "Extract Values".to_string(),
694 description: "Extract numeric values, units, and specifications from requirements".to_string(),
695 examples: vec![
696 ExampleData::new(
697 "The system shall process 100 transactions per second and maintain 99.9% uptime.".to_string(),
698 vec![
699 Extraction::new("value".to_string(), "100".to_string()),
700 Extraction::new("unit".to_string(), "transactions per second".to_string()),
701 Extraction::new("value".to_string(), "99.9".to_string()),
702 Extraction::new("unit".to_string(), "%".to_string()),
703 ],
704 )
705 ],
706 prompt: "From this requirement, extract all numeric values and their associated units or specifications.".to_string(),
707 output_field: "values".to_string(),
708 filter: Some(PipelineFilter {
709 class_filter: Some("requirement".to_string()),
710 text_pattern: None,
711 max_items: None,
712 }),
713 depends_on: vec!["extract_requirements".to_string()],
714 },
715 PipelineStep {
716 id: "extract_specifications".to_string(),
717 name: "Extract Specifications".to_string(),
718 description: "Extract detailed specifications and constraints from requirements".to_string(),
719 examples: vec![
720 ExampleData::new(
721 "The system shall process 100 transactions per second and maintain 99.9% uptime.".to_string(),
722 vec![
723 Extraction::new("specification".to_string(), "process 100 transactions per second".to_string()),
724 Extraction::new("constraint".to_string(), "maintain 99.9% uptime".to_string()),
725 ],
726 )
727 ],
728 prompt: "Extract detailed specifications, constraints, and performance requirements from this text.".to_string(),
729 output_field: "specifications".to_string(),
730 filter: Some(PipelineFilter {
731 class_filter: Some("requirement".to_string()),
732 text_pattern: None,
733 max_items: None,
734 }),
735 depends_on: vec!["extract_requirements".to_string()],
736 },
737 ],
738 }
739 }
740
741 pub fn save_pipeline_to_file(config: &PipelineConfig, path: &std::path::Path) -> LangExtractResult<()> {
743 let yaml_content = serde_yaml::to_string(config)
744 .map_err(|e| LangExtractError::configuration(format!("Failed to serialize pipeline: {}", e)))?;
745
746 std::fs::write(path, yaml_content)
747 .map_err(|e| LangExtractError::configuration(format!("Failed to write pipeline file: {}", e)))?;
748
749 Ok(())
750 }
751
752 pub fn load_pipeline_from_file(path: &std::path::Path) -> LangExtractResult<PipelineConfig> {
754 let content = std::fs::read_to_string(path)
755 .map_err(|e| LangExtractError::configuration(format!("Failed to read pipeline file: {}", e)))?;
756
757 serde_yaml::from_str(&content)
758 .map_err(|e| LangExtractError::configuration(format!("Failed to parse pipeline YAML: {}", e)))
759 }
760}
761
#[cfg(test)]
mod tests {
    use super::*;

    /// A round-trip through YAML must preserve the pipeline definition.
    #[test]
    fn test_pipeline_config_serialization() {
        let original = utils::create_requirements_pipeline();
        let yaml = serde_yaml::to_string(&original).unwrap();
        let restored: PipelineConfig = serde_yaml::from_str(&yaml).unwrap();

        assert_eq!(original.name, restored.name);
        assert_eq!(original.steps.len(), restored.steps.len());
    }

    /// The root step must be scheduled first and every step must appear.
    #[test]
    fn test_dependency_resolution() {
        let executor = PipelineExecutor::new(utils::create_requirements_pipeline());

        let order = executor.resolve_execution_order().unwrap();

        assert_eq!(order.first().map(String::as_str), Some("extract_requirements"));
        assert_eq!(order.len(), 3);
    }

    /// A class filter must keep only extractions of the matching class.
    #[test]
    fn test_filter_application() {
        let executor = PipelineExecutor::new(utils::create_requirements_pipeline());

        let extractions = vec![
            Extraction::new("requirement".to_string(), "Test requirement".to_string()),
            Extraction::new("other".to_string(), "Other text".to_string()),
        ];

        let filter = Some(PipelineFilter {
            class_filter: Some("requirement".to_string()),
            text_pattern: None,
            max_items: None,
        });

        let filtered = executor.apply_filter(&extractions, &filter);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].extraction_class, "requirement");
    }
}