1use crate::{
7 data::{ExampleData, Extraction},
8 exceptions::{LangExtractError, LangExtractResult},
9 extract, ExtractConfig,
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct PipelineStep {
17 pub id: String,
19
20 pub name: String,
22
23 pub description: String,
25
26 pub examples: Vec<ExampleData>,
28
29 pub prompt: String,
31
32 pub output_field: String,
34
35 pub filter: Option<PipelineFilter>,
37
38 pub depends_on: Vec<String>,
40}
41
/// Selection criteria applied to the extractions of an upstream step before
/// they are used as inputs of a dependent step.
///
/// All criteria that are set must hold for an extraction to be kept.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineFilter {
    /// When set, keep only extractions whose `extraction_class` equals this value.
    pub class_filter: Option<String>,

    /// When set, a regular expression (compiled with `regex::Regex`) that the
    /// extraction text must match. A pattern that fails to compile is not applied.
    pub text_pattern: Option<String>,

    /// When set, the maximum number of matching extractions kept per filtered list.
    pub max_items: Option<usize>,
}
54
/// Complete, serializable description of a multi-step extraction pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineConfig {
    /// Pipeline name; printed when execution starts.
    pub name: String,

    /// Pipeline description; printed when execution starts.
    pub description: String,

    /// Version string of the pipeline definition. Not interpreted by the code in this file.
    pub version: String,

    /// The steps of the pipeline. Execution order is derived from each step's
    /// `depends_on`, not only from the order in this list.
    pub steps: Vec<PipelineStep>,

    /// Extraction configuration; a clone of it is passed to every `extract` call.
    pub global_config: ExtractConfig,
}
73
/// Outcome of running a single pipeline step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepResult {
    /// Id of the step that produced this result.
    pub step_id: String,

    /// Name of the step that produced this result.
    pub step_name: String,

    /// Extractions collected across all input items of the step; empty when the step failed.
    pub extractions: Vec<Extraction>,

    /// Elapsed time of the step in milliseconds.
    pub processing_time_ms: u64,

    /// Number of input items prepared for the step.
    pub input_count: usize,

    /// `true` when every input item was extracted without error.
    pub success: bool,

    /// Text of the extraction error when `success` is `false`, otherwise `None`.
    pub error_message: Option<String>,
}
98
/// Outcome of running a whole pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineResult {
    /// Copy of the configuration the pipeline was run with.
    pub config: PipelineConfig,

    /// Per-step results, in the order the steps were executed.
    pub step_results: Vec<StepResult>,

    /// JSON object with one entry per successful step, keyed by step id.
    pub nested_output: serde_json::Value,

    /// Elapsed time of the whole run in milliseconds.
    pub total_time_ms: u64,

    /// Whether the pipeline succeeded.
    pub success: bool,

    /// Error description for a failed run, otherwise `None`.
    pub error_message: Option<String>,
}
120
/// Runs the steps of a [`PipelineConfig`] in dependency order.
pub struct PipelineExecutor {
    /// The pipeline definition this executor runs.
    config: PipelineConfig,
}
125
126impl PipelineExecutor {
127 pub fn new(config: PipelineConfig) -> Self {
129 Self { config }
130 }
131
132 pub fn from_yaml_file(path: &std::path::Path) -> LangExtractResult<Self> {
134 let content = std::fs::read_to_string(path)
135 .map_err(|e| LangExtractError::configuration(format!("Failed to read pipeline file: {}", e)))?;
136
137 let config: PipelineConfig = serde_yaml::from_str(&content)
138 .map_err(|e| LangExtractError::configuration(format!("Failed to parse pipeline YAML: {}", e)))?;
139
140 Ok(Self::new(config))
141 }
142
143 pub async fn execute(&self, input_text: &str) -> LangExtractResult<PipelineResult> {
145 let start_time = std::time::Instant::now();
146
147 println!("🚀 Starting pipeline execution: {}", self.config.name);
148 println!("📝 Description: {}", self.config.description);
149
150 let mut step_results = Vec::new();
151 let mut context_data = HashMap::new();
152
153 let execution_order = self.resolve_execution_order()?;
155
156 for step_id in execution_order {
157 let step_result = self.execute_step(&step_id, input_text, &context_data).await?;
158 step_results.push(step_result.clone());
159
160 if step_result.success {
162 context_data.insert(step_id, step_result.extractions.clone());
163 } else {
164 return Err(LangExtractError::configuration(format!(
165 "Step '{}' failed: {}",
166 step_id,
167 step_result.error_message.unwrap_or("Unknown error".to_string())
168 )));
169 }
170 }
171
172 let nested_output = self.build_nested_output(&step_results)?;
174
175 let total_time = start_time.elapsed().as_millis() as u64;
176
177 println!("✅ Pipeline execution completed in {}ms", total_time);
178
179 Ok(PipelineResult {
180 config: self.config.clone(),
181 step_results,
182 nested_output,
183 total_time_ms: total_time,
184 success: true,
185 error_message: None,
186 })
187 }
188
189 fn resolve_execution_order(&self) -> LangExtractResult<Vec<String>> {
191 let mut order = Vec::new();
192 let mut visited = std::collections::HashSet::new();
193 let mut visiting = std::collections::HashSet::new();
194
195 for step in &self.config.steps {
196 self.resolve_step_dependencies(&step.id, &mut order, &mut visited, &mut visiting)?;
197 }
198
199 Ok(order)
200 }
201
202 fn resolve_step_dependencies(
204 &self,
205 step_id: &str,
206 order: &mut Vec<String>,
207 visited: &mut std::collections::HashSet<String>,
208 visiting: &mut std::collections::HashSet<String>,
209 ) -> LangExtractResult<()> {
210 if visited.contains(step_id) {
211 return Ok(());
212 }
213
214 if visiting.contains(step_id) {
215 return Err(LangExtractError::configuration(format!(
216 "Circular dependency detected involving step: {}", step_id
217 )));
218 }
219
220 visiting.insert(step_id.to_string());
221
222 if let Some(step) = self.config.steps.iter().find(|s| s.id == step_id) {
224 for dep in &step.depends_on {
225 self.resolve_step_dependencies(dep, order, visited, visiting)?;
226 }
227 }
228
229 visiting.remove(step_id);
230 visited.insert(step_id.to_string());
231 order.push(step_id.to_string());
232
233 Ok(())
234 }
235
236 async fn execute_step(
238 &self,
239 step_id: &str,
240 input_text: &str,
241 context_data: &HashMap<String, Vec<Extraction>>,
242 ) -> LangExtractResult<StepResult> {
243 let step = self.config.steps.iter().find(|s| s.id == step_id)
244 .ok_or_else(|| LangExtractError::configuration(format!("Step '{}' not found", step_id)))?;
245
246 let step_start = std::time::Instant::now();
247
248 println!("🔄 Executing step: {} ({})", step.name, step.id);
249
250 let step_input = self.prepare_step_input(step, input_text, context_data)?;
252 let input_count = step_input.len();
253
254 println!("📥 Processing {} input items", input_count);
255
256 let mut all_extractions = Vec::new();
257
258 for (i, input_item) in step_input.iter().enumerate() {
260 println!(" 📄 Processing item {}/{}", i + 1, input_count);
261
262 let step_config = self.config.global_config.clone();
264 let examples = if step.examples.is_empty() {
266 vec![] } else {
268 step.examples.clone()
269 };
270
271 match extract(
272 input_item,
273 Some(&step.prompt),
274 &examples,
275 step_config,
276 ).await {
277 Ok(result) => {
278 if let Some(extractions) = result.extractions {
279 all_extractions.extend(extractions);
280 }
281 }
282 Err(e) => {
283 println!(" ❌ Step '{}' failed on item {}/{}: {}", step.id, i + 1, input_count, e);
284 return Ok(StepResult {
285 step_id: step.id.clone(),
286 step_name: step.name.clone(),
287 extractions: Vec::new(),
288 processing_time_ms: step_start.elapsed().as_millis() as u64,
289 input_count,
290 success: false,
291 error_message: Some(e.to_string()),
292 });
293 }
294 }
295 }
296
297 let processing_time = step_start.elapsed().as_millis() as u64;
298
299 println!(" ✅ Step '{}' completed: {} extractions in {}ms",
300 step.name, all_extractions.len(), processing_time);
301
302 Ok(StepResult {
303 step_id: step.id.clone(),
304 step_name: step.name.clone(),
305 extractions: all_extractions,
306 processing_time_ms: processing_time,
307 input_count,
308 success: true,
309 error_message: None,
310 })
311 }
312
313 fn prepare_step_input(
315 &self,
316 step: &PipelineStep,
317 original_text: &str,
318 context_data: &HashMap<String, Vec<Extraction>>,
319 ) -> LangExtractResult<Vec<String>> {
320 if !step.depends_on.is_empty() {
322 let mut inputs = Vec::new();
323
324 for dep_id in &step.depends_on {
325 if let Some(extractions) = context_data.get(dep_id) {
326 let filtered_extractions = self.apply_filter(extractions, &step.filter);
328
329 for extraction in filtered_extractions {
330 inputs.push(extraction.extraction_text.clone());
331 }
332 }
333 }
334
335 Ok(inputs)
336 } else {
337 Ok(vec![original_text.to_string()])
339 }
340 }
341
342 fn apply_filter<'a>(
344 &self,
345 extractions: &'a [Extraction],
346 filter: &Option<PipelineFilter>,
347 ) -> Vec<&'a Extraction> {
348 if let Some(f) = filter {
349 extractions.iter()
350 .filter(|e| {
351 if let Some(class) = &f.class_filter {
353 if e.extraction_class != *class {
354 return false;
355 }
356 }
357
358 if let Some(pattern) = &f.text_pattern {
360 if let Ok(regex) = regex::Regex::new(pattern) {
361 if !regex.is_match(&e.extraction_text) {
362 return false;
363 }
364 }
365 }
366
367 true
368 })
369 .take(f.max_items.unwrap_or(usize::MAX))
370 .collect()
371 } else {
372 extractions.iter().collect()
373 }
374 }
375
376 fn build_nested_output(&self, step_results: &[StepResult]) -> LangExtractResult<serde_json::Value> {
378 let mut output = serde_json::Map::new();
379
380 for result in step_results {
382 if result.success {
383 let mut step_output = serde_json::Map::new();
384
385 let extractions_json: Vec<serde_json::Value> = result.extractions.iter()
387 .map(|e| {
388 let mut obj = serde_json::Map::new();
389 obj.insert("class".to_string(), serde_json::Value::String(e.extraction_class.clone()));
390 obj.insert("text".to_string(), serde_json::Value::String(e.extraction_text.clone()));
391 if let Some(interval) = &e.char_interval {
392 obj.insert("start".to_string(), serde_json::json!(interval.start_pos));
393 obj.insert("end".to_string(), serde_json::json!(interval.end_pos));
394 }
395 serde_json::Value::Object(obj)
396 })
397 .collect();
398
399 step_output.insert("extractions".to_string(), serde_json::Value::Array(extractions_json));
400 step_output.insert("count".to_string(), serde_json::json!(result.extractions.len()));
401 step_output.insert("processing_time_ms".to_string(), serde_json::json!(result.processing_time_ms));
402
403 output.insert(result.step_id.clone(), serde_json::Value::Object(step_output));
404 }
405 }
406
407 Ok(serde_json::Value::Object(output))
408 }
409}
410
411pub mod utils {
413 use super::*;
414
415 pub fn create_requirements_pipeline() -> PipelineConfig {
417 PipelineConfig {
418 name: "Requirements Extraction Pipeline".to_string(),
419 description: "Extract requirements and sub-divide into values, units, and specifications".to_string(),
420 version: "1.0.0".to_string(),
421 global_config: ExtractConfig {
422 model_id: "gemini-2.5-flash".to_string(),
423 api_key: None,
424 format_type: crate::data::FormatType::Json,
425 max_char_buffer: 8000,
426 temperature: 0.3,
427 fence_output: None,
428 use_schema_constraints: true,
429 batch_length: 4,
430 max_workers: 6,
431 additional_context: None,
432 resolver_params: std::collections::HashMap::new(),
433 language_model_params: std::collections::HashMap::new(),
434 debug: false,
435 model_url: None,
436 extraction_passes: 1,
437 enable_multipass: false,
438 multipass_min_extractions: 1,
439 multipass_quality_threshold: 0.3,
440 progress_handler: None,
441 },
442 steps: vec![
443 PipelineStep {
444 id: "extract_requirements".to_string(),
445 name: "Extract Requirements".to_string(),
446 description: "Extract all 'shall' statements and requirements from the document".to_string(),
447 examples: vec![
448 ExampleData::new(
449 "The system shall process 100 transactions per second and maintain 99.9% uptime.".to_string(),
450 vec![
451 Extraction::new("requirement".to_string(),
452 "The system shall process 100 transactions per second and maintain 99.9% uptime.".to_string()),
453 ],
454 )
455 ],
456 prompt: "Extract all requirements, 'shall' statements, and specifications from the text. Include the complete statement.".to_string(),
457 output_field: "requirements".to_string(),
458 filter: None,
459 depends_on: vec![],
460 },
461 PipelineStep {
462 id: "extract_values".to_string(),
463 name: "Extract Values".to_string(),
464 description: "Extract numeric values, units, and specifications from requirements".to_string(),
465 examples: vec![
466 ExampleData::new(
467 "The system shall process 100 transactions per second and maintain 99.9% uptime.".to_string(),
468 vec![
469 Extraction::new("value".to_string(), "100".to_string()),
470 Extraction::new("unit".to_string(), "transactions per second".to_string()),
471 Extraction::new("value".to_string(), "99.9".to_string()),
472 Extraction::new("unit".to_string(), "%".to_string()),
473 ],
474 )
475 ],
476 prompt: "From this requirement, extract all numeric values and their associated units or specifications.".to_string(),
477 output_field: "values".to_string(),
478 filter: Some(PipelineFilter {
479 class_filter: Some("requirement".to_string()),
480 text_pattern: None,
481 max_items: None,
482 }),
483 depends_on: vec!["extract_requirements".to_string()],
484 },
485 PipelineStep {
486 id: "extract_specifications".to_string(),
487 name: "Extract Specifications".to_string(),
488 description: "Extract detailed specifications and constraints from requirements".to_string(),
489 examples: vec![
490 ExampleData::new(
491 "The system shall process 100 transactions per second and maintain 99.9% uptime.".to_string(),
492 vec![
493 Extraction::new("specification".to_string(), "process 100 transactions per second".to_string()),
494 Extraction::new("constraint".to_string(), "maintain 99.9% uptime".to_string()),
495 ],
496 )
497 ],
498 prompt: "Extract detailed specifications, constraints, and performance requirements from this text.".to_string(),
499 output_field: "specifications".to_string(),
500 filter: Some(PipelineFilter {
501 class_filter: Some("requirement".to_string()),
502 text_pattern: None,
503 max_items: None,
504 }),
505 depends_on: vec!["extract_requirements".to_string()],
506 },
507 ],
508 }
509 }
510
511 pub fn save_pipeline_to_file(config: &PipelineConfig, path: &std::path::Path) -> LangExtractResult<()> {
513 let yaml_content = serde_yaml::to_string(config)
514 .map_err(|e| LangExtractError::configuration(format!("Failed to serialize pipeline: {}", e)))?;
515
516 std::fs::write(path, yaml_content)
517 .map_err(|e| LangExtractError::configuration(format!("Failed to write pipeline file: {}", e)))?;
518
519 Ok(())
520 }
521
522 pub fn load_pipeline_from_file(path: &std::path::Path) -> LangExtractResult<PipelineConfig> {
524 let content = std::fs::read_to_string(path)
525 .map_err(|e| LangExtractError::configuration(format!("Failed to read pipeline file: {}", e)))?;
526
527 serde_yaml::from_str(&content)
528 .map_err(|e| LangExtractError::configuration(format!("Failed to parse pipeline YAML: {}", e)))
529 }
530}
531
#[cfg(test)]
mod tests {
    use super::*;

    /// The sample pipeline keeps its name and step count across a YAML round trip.
    #[test]
    fn test_pipeline_config_serialization() {
        let original = utils::create_requirements_pipeline();

        let yaml = serde_yaml::to_string(&original).unwrap();
        let restored: PipelineConfig = serde_yaml::from_str(&yaml).unwrap();

        assert_eq!(original.name, restored.name);
        assert_eq!(original.steps.len(), restored.steps.len());
    }

    /// The step without dependencies is scheduled first and all three steps are ordered.
    #[test]
    fn test_dependency_resolution() {
        let executor = PipelineExecutor::new(utils::create_requirements_pipeline());

        let order = executor.resolve_execution_order().unwrap();

        assert_eq!(order.len(), 3);
        assert_eq!(order[0], "extract_requirements");
    }

    /// A class filter keeps only the extractions of the requested class.
    #[test]
    fn test_filter_application() {
        let executor = PipelineExecutor::new(utils::create_requirements_pipeline());

        let extractions = vec![
            Extraction::new("requirement".to_string(), "Test requirement".to_string()),
            Extraction::new("other".to_string(), "Other text".to_string()),
        ];
        let filter = Some(PipelineFilter {
            class_filter: Some("requirement".to_string()),
            text_pattern: None,
            max_items: None,
        });

        let kept = executor.apply_filter(&extractions, &filter);

        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0].extraction_class, "requirement");
    }
}