Skip to main content

ai_agents_process/
processor.rs

//! Process processor: executes the configured input and output transformation pipelines.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Instant;
6
7use ai_agents_core::{AgentError, ChatMessage, LLMProvider, Result};
8use ai_agents_llm::LLMRegistry;
9
10use super::config::*;
11
12#[derive(Debug, Clone, Default)]
13pub struct ProcessData {
14    pub content: String,
15    pub original: String,
16    pub context: HashMap<String, serde_json::Value>,
17    pub metadata: ProcessMetadata,
18}
19
/// Bookkeeping accumulated while [`ProcessData`] moves through the stages.
#[derive(Debug, Clone, Default)]
pub struct ProcessMetadata {
    /// Names of stages that completed (the stage id, or its type name when no id is set), in order.
    pub stages_executed: Vec<String>,
    /// Elapsed milliseconds per stage name; filled only when timing is enabled in debug settings.
    pub timing: HashMap<String, u64>,
    /// Non-fatal problems reported by stages (failed stages, soft validation failures, ...).
    pub warnings: Vec<String>,
    /// Set by a stage that rejects the content; pipelines stop running further stages once set.
    pub rejected: bool,
    /// Human-readable reason accompanying `rejected`, when one was provided.
    pub rejection_reason: Option<String>,
}
28
29impl ProcessData {
30    pub fn new(content: impl Into<String>) -> Self {
31        let content = content.into();
32        Self {
33            original: content.clone(),
34            content,
35            context: HashMap::new(),
36            metadata: ProcessMetadata::default(),
37        }
38    }
39
40    pub fn with_context(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
41        self.context.insert(key.into(), value);
42        self
43    }
44}
45
46#[derive(Debug)]
47pub struct ProcessProcessor {
48    config: ProcessConfig,
49    llm_registry: Option<Arc<LLMRegistry>>,
50}
51
52impl Default for ProcessProcessor {
53    fn default() -> Self {
54        Self::new(ProcessConfig::default())
55    }
56}
57
58impl ProcessProcessor {
59    pub fn new(config: ProcessConfig) -> Self {
60        Self {
61            config,
62            llm_registry: None,
63        }
64    }
65
66    pub fn with_llm_registry(mut self, registry: Arc<LLMRegistry>) -> Self {
67        self.llm_registry = Some(registry);
68        self
69    }
70
71    pub async fn process_input(&self, input: &str) -> Result<ProcessData> {
72        let mut data = ProcessData::new(input);
73
74        for stage in &self.config.input {
75            data = self.execute_stage(stage, data).await?;
76            if data.metadata.rejected {
77                break;
78            }
79        }
80
81        Ok(data)
82    }
83
84    pub async fn process_output(
85        &self,
86        output: &str,
87        input_context: &HashMap<String, serde_json::Value>,
88    ) -> Result<ProcessData> {
89        let mut data = ProcessData::new(output);
90        data.context = input_context.clone();
91
92        for stage in &self.config.output {
93            data = self.execute_stage(stage, data).await?;
94            if data.metadata.rejected {
95                break;
96            }
97        }
98
99        Ok(data)
100    }
101
102    fn execute_stage<'a>(
103        &'a self,
104        stage: &'a ProcessStage,
105        data: ProcessData,
106    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<ProcessData>> + Send + 'a>> {
107        Box::pin(async move {
108            let start = Instant::now();
109            let stage_name = stage
110                .id()
111                .map(String::from)
112                .unwrap_or_else(|| self.get_stage_type_name(stage));
113
114            // Check condition before executing stage
115            if let Some(condition) = stage.condition() {
116                if !self.evaluate_condition_expr(condition, &data) {
117                    if self.config.settings.debug.log_stages {
118                        tracing::debug!(
119                            "[Process] Stage skipped (condition not met): {}",
120                            stage_name
121                        );
122                    }
123                    return Ok(data);
124                }
125            }
126
127            if self.config.settings.debug.log_stages {
128                tracing::debug!("[Process] Executing stage: {}", stage_name);
129            }
130
131            let data_clone = data.clone();
132            let result = match stage {
133                ProcessStage::Normalize(s) => self.execute_normalize(&s.config, data).await,
134                ProcessStage::Detect(s) => self.execute_detect(&s.config, data).await,
135                ProcessStage::Extract(s) => self.execute_extract(&s.config, data).await,
136                ProcessStage::Sanitize(s) => self.execute_sanitize(&s.config, data).await,
137                ProcessStage::Transform(s) => self.execute_transform(&s.config, data).await,
138                ProcessStage::Validate(s) => self.execute_validate(&s.config, data).await,
139                ProcessStage::Format(s) => self.execute_format(&s.config, data).await,
140                ProcessStage::Enrich(s) => self.execute_enrich(&s.config, data).await,
141                ProcessStage::Conditional(s) => self.execute_conditional(&s.config, data).await,
142            };
143
144            match result {
145                Ok(mut d) => {
146                    d.metadata.stages_executed.push(stage_name.clone());
147                    if self.config.settings.debug.include_timing {
148                        d.metadata
149                            .timing
150                            .insert(stage_name, start.elapsed().as_millis() as u64);
151                    }
152                    Ok(d)
153                }
154                Err(e) => {
155                    let mut fallback_data = data_clone;
156                    match self.config.settings.on_stage_error.default {
157                        StageErrorAction::Stop => Err(e),
158                        StageErrorAction::Continue => {
159                            fallback_data
160                                .metadata
161                                .warnings
162                                .push(format!("Stage {} failed: {}", stage_name, e));
163                            Ok(fallback_data)
164                        }
165                        StageErrorAction::Retry => {
166                            if let Some(retry_config) = &self.config.settings.on_stage_error.retry {
167                                for _ in 0..retry_config.max_retries {
168                                    tokio::time::sleep(std::time::Duration::from_millis(
169                                        retry_config.backoff_ms,
170                                    ))
171                                    .await;
172                                }
173                            }
174                            fallback_data
175                                .metadata
176                                .warnings
177                                .push(format!("Stage {} failed after retries: {}", stage_name, e));
178                            Ok(fallback_data)
179                        }
180                    }
181                }
182            }
183        })
184    }
185
186    fn get_stage_type_name(&self, stage: &ProcessStage) -> String {
187        match stage {
188            ProcessStage::Normalize(_) => "normalize".to_string(),
189            ProcessStage::Detect(_) => "detect".to_string(),
190            ProcessStage::Extract(_) => "extract".to_string(),
191            ProcessStage::Sanitize(_) => "sanitize".to_string(),
192            ProcessStage::Transform(_) => "transform".to_string(),
193            ProcessStage::Validate(_) => "validate".to_string(),
194            ProcessStage::Format(_) => "format".to_string(),
195            ProcessStage::Enrich(_) => "enrich".to_string(),
196            ProcessStage::Conditional(_) => "conditional".to_string(),
197        }
198    }
199
200    async fn execute_normalize(
201        &self,
202        config: &NormalizeConfig,
203        mut data: ProcessData,
204    ) -> Result<ProcessData> {
205        let mut content = data.content.clone();
206
207        if config.trim {
208            content = content.trim().to_string();
209        }
210
211        if config.collapse_whitespace {
212            content = content.split_whitespace().collect::<Vec<_>>().join(" ");
213        }
214
215        if config.lowercase {
216            content = content.to_lowercase();
217        }
218
219        // Unicode normalization would require unicode-normalization crate
220        // For now, we skip it as it's optional
221
222        data.content = content;
223        Ok(data)
224    }
225
226    async fn execute_detect(
227        &self,
228        config: &DetectConfig,
229        mut data: ProcessData,
230    ) -> Result<ProcessData> {
231        let llm = self.get_llm(config.llm.as_deref())?;
232
233        let detection_types: Vec<&str> = config
234            .detect
235            .iter()
236            .map(|d| match d {
237                DetectionType::Language => "language (ISO 639-1 code)",
238                DetectionType::Sentiment => "sentiment (positive, negative, neutral)",
239                DetectionType::Intent => "intent",
240                DetectionType::Topic => "topic",
241                DetectionType::Formality => "formality (formal, informal)",
242                DetectionType::Urgency => "urgency (low, medium, high, critical)",
243            })
244            .collect();
245
246        let intents_desc = if !config.intents.is_empty() {
247            let intents: Vec<String> = config
248                .intents
249                .iter()
250                .map(|i| format!("- {}: {}", i.id, i.description))
251                .collect();
252            format!("\n\nAvailable intents:\n{}", intents.join("\n"))
253        } else {
254            String::new()
255        };
256
257        let prompt = format!(
258            "Analyze the following text and detect: {}\n{}\n\n\
259             Respond with JSON only: {{\"language\": \"...\", \"sentiment\": \"...\", \"intent\": \"...\", ...}}\n\n\
260             Text: {}",
261            detection_types.join(", "),
262            intents_desc,
263            data.content
264        );
265
266        let messages = vec![ChatMessage::user(&prompt)];
267        let response = llm
268            .complete(&messages, None)
269            .await
270            .map_err(|e| AgentError::LLM(e.to_string()))?;
271
272        if let Ok(result) =
273            serde_json::from_str::<serde_json::Value>(&extract_json(&response.content))
274        {
275            for (key, context_path) in &config.store_in_context {
276                if let Some(value) = result.get(key) {
277                    data.context.insert(context_path.clone(), value.clone());
278                }
279            }
280            data.context.insert("detection".to_string(), result);
281        }
282
283        Ok(data)
284    }
285
286    async fn execute_extract(
287        &self,
288        config: &ExtractConfig,
289        mut data: ProcessData,
290    ) -> Result<ProcessData> {
291        let llm = self.get_llm(config.llm.as_deref())?;
292
293        let schema_desc: Vec<String> = config
294            .schema
295            .iter()
296            .map(|(name, schema)| {
297                let type_str = format!("{:?}", schema.field_type).to_lowercase();
298                let desc = schema.description.as_deref().unwrap_or("");
299                let values = if !schema.values.is_empty() {
300                    format!(" (values: {})", schema.values.join(", "))
301                } else {
302                    String::new()
303                };
304                let required = if schema.required { " [required]" } else { "" };
305                format!("- {}: {} - {}{}{}", name, type_str, desc, values, required)
306            })
307            .collect();
308
309        let prompt = format!(
310            "Extract the following fields from the text:\n{}\n\n\
311             Respond with JSON only. Use null for fields not found.\n\n\
312             Text: {}",
313            schema_desc.join("\n"),
314            data.content
315        );
316
317        let messages = vec![ChatMessage::user(&prompt)];
318        let response = llm
319            .complete(&messages, None)
320            .await
321            .map_err(|e| AgentError::LLM(e.to_string()))?;
322
323        if let Ok(result) =
324            serde_json::from_str::<serde_json::Value>(&extract_json(&response.content))
325        {
326            if let Some(context_path) = &config.store_in_context {
327                data.context.insert(context_path.clone(), result.clone());
328            }
329            data.context.insert("extracted".to_string(), result);
330        }
331
332        Ok(data)
333    }
334
335    async fn execute_sanitize(
336        &self,
337        config: &SanitizeConfig,
338        mut data: ProcessData,
339    ) -> Result<ProcessData> {
340        let llm = self.get_llm(config.llm.as_deref())?;
341
342        let mut instructions = Vec::new();
343
344        if let Some(pii_config) = &config.pii {
345            if !pii_config.types.is_empty() {
346                let pii_types: Vec<String> = pii_config
347                    .types
348                    .iter()
349                    .map(|t| format!("{:?}", t).to_lowercase())
350                    .collect();
351                let action = match pii_config.action {
352                    PIIAction::Mask => format!("replace with '{}'", pii_config.mask_char.repeat(4)),
353                    PIIAction::Remove => "remove completely".to_string(),
354                    PIIAction::Flag => "wrap with [PII: type]".to_string(),
355                };
356                instructions.push(format!("PII types to {}: {}", action, pii_types.join(", ")));
357            }
358        }
359
360        if let Some(harmful_config) = &config.harmful {
361            if !harmful_config.detect.is_empty() {
362                let types: Vec<String> = harmful_config
363                    .detect
364                    .iter()
365                    .map(|t| format!("{:?}", t).to_lowercase())
366                    .collect();
367                instructions.push(format!("Detect harmful content: {}", types.join(", ")));
368            }
369        }
370
371        if !config.remove.is_empty() {
372            instructions.push(format!(
373                "Remove any mentions of: {}",
374                config.remove.join(", ")
375            ));
376        }
377
378        if instructions.is_empty() {
379            return Ok(data);
380        }
381
382        let prompt = format!(
383            "Sanitize the following text according to these rules:\n{}\n\n\
384             Return only the sanitized text, nothing else.\n\n\
385             Text: {}",
386            instructions.join("\n"),
387            data.content
388        );
389
390        let messages = vec![ChatMessage::user(&prompt)];
391        let response = llm
392            .complete(&messages, None)
393            .await
394            .map_err(|e| AgentError::LLM(e.to_string()))?;
395
396        data.content = response.content.trim().to_string();
397        Ok(data)
398    }
399
400    async fn execute_transform(
401        &self,
402        config: &TransformConfig,
403        mut data: ProcessData,
404    ) -> Result<ProcessData> {
405        let prompt = match &config.prompt {
406            Some(p) => p.clone(),
407            None => return Ok(data),
408        };
409
410        let llm = self.get_llm(config.llm.as_deref())?;
411
412        let full_prompt = format!("{}\n\nOriginal text:\n{}", prompt, data.content);
413
414        let messages = vec![ChatMessage::user(&full_prompt)];
415        let response = llm
416            .complete(&messages, None)
417            .await
418            .map_err(|e| AgentError::LLM(e.to_string()))?;
419
420        data.content = response.content.trim().to_string();
421        Ok(data)
422    }
423
424    async fn execute_validate(
425        &self,
426        config: &ValidateConfig,
427        mut data: ProcessData,
428    ) -> Result<ProcessData> {
429        // Rule-based validation
430        for rule in &config.rules {
431            match rule {
432                ValidationRule::MinLength {
433                    min_length,
434                    on_fail,
435                } => {
436                    if data.content.len() < *min_length {
437                        match on_fail.action {
438                            ValidationActionType::Reject => {
439                                data.metadata.rejected = true;
440                                data.metadata.rejection_reason = Some(format!(
441                                    "Content too short: {} < {} characters",
442                                    data.content.len(),
443                                    min_length
444                                ));
445                                return Ok(data);
446                            }
447                            ValidationActionType::Warn => {
448                                data.metadata.warnings.push(format!(
449                                    "Content shorter than {} characters",
450                                    min_length
451                                ));
452                            }
453                            ValidationActionType::Truncate => {} // N/A for min_length
454                        }
455                    }
456                }
457                ValidationRule::MaxLength {
458                    max_length,
459                    on_fail,
460                } => {
461                    if data.content.len() > *max_length {
462                        match on_fail.action {
463                            ValidationActionType::Truncate => {
464                                data.content = data.content.chars().take(*max_length).collect();
465                            }
466                            ValidationActionType::Reject => {
467                                data.metadata.rejected = true;
468                                data.metadata.rejection_reason = Some(format!(
469                                    "Content too long: {} > {} characters",
470                                    data.content.len(),
471                                    max_length
472                                ));
473                                return Ok(data);
474                            }
475                            ValidationActionType::Warn => {
476                                data.metadata
477                                    .warnings
478                                    .push(format!("Content longer than {} characters", max_length));
479                            }
480                        }
481                    }
482                }
483                ValidationRule::Pattern { pattern, on_fail } => {
484                    if let Ok(re) = regex::Regex::new(pattern) {
485                        if !re.is_match(&data.content) {
486                            match on_fail.action {
487                                ValidationActionType::Reject => {
488                                    data.metadata.rejected = true;
489                                    data.metadata.rejection_reason =
490                                        Some("Content does not match required pattern".to_string());
491                                    return Ok(data);
492                                }
493                                ValidationActionType::Warn => {
494                                    data.metadata.warnings.push(
495                                        "Content does not match expected pattern".to_string(),
496                                    );
497                                }
498                                ValidationActionType::Truncate => {} // N/A for pattern
499                            }
500                        }
501                    }
502                }
503            }
504        }
505
506        // LLM-based validation
507        if !config.criteria.is_empty() {
508            let llm = self.get_llm(config.llm.as_deref())?;
509
510            let criteria_list = config
511                .criteria
512                .iter()
513                .enumerate()
514                .map(|(i, c)| format!("{}. {}", i + 1, c))
515                .collect::<Vec<_>>()
516                .join("\n");
517
518            let prompt = format!(
519                "Evaluate if the following content meets these criteria:\n{}\n\n\
520                 Respond with JSON: {{\"passes\": true/false, \"score\": 0.0-1.0, \"issues\": [\"...\"]}}\n\n\
521                 Content: {}",
522                criteria_list, data.content
523            );
524
525            let messages = vec![ChatMessage::user(&prompt)];
526            let response = llm
527                .complete(&messages, None)
528                .await
529                .map_err(|e| AgentError::LLM(e.to_string()))?;
530
531            if let Ok(result) =
532                serde_json::from_str::<serde_json::Value>(&extract_json(&response.content))
533            {
534                let score = result.get("score").and_then(|s| s.as_f64()).unwrap_or(1.0) as f32;
535                let passes = result
536                    .get("passes")
537                    .and_then(|p| p.as_bool())
538                    .unwrap_or(true);
539
540                if !passes || score < config.threshold {
541                    match config.on_fail.action {
542                        ValidationFailType::Reject => {
543                            data.metadata.rejected = true;
544                            let issues = result
545                                .get("issues")
546                                .and_then(|i| i.as_array())
547                                .map(|arr| {
548                                    arr.iter()
549                                        .filter_map(|v| v.as_str())
550                                        .collect::<Vec<_>>()
551                                        .join(", ")
552                                })
553                                .unwrap_or_else(|| "Validation failed".to_string());
554                            data.metadata.rejection_reason = Some(issues);
555                            return Ok(data);
556                        }
557                        ValidationFailType::Regenerate => {
558                            data.metadata
559                                .warnings
560                                .push("Content may need regeneration".to_string());
561                        }
562                        ValidationFailType::Warn => {
563                            if let Some(issues) = result.get("issues").and_then(|i| i.as_array()) {
564                                for issue in issues {
565                                    if let Some(s) = issue.as_str() {
566                                        data.metadata.warnings.push(s.to_string());
567                                    }
568                                }
569                            }
570                        }
571                    }
572                }
573            }
574        }
575
576        Ok(data)
577    }
578
579    async fn execute_format(
580        &self,
581        config: &FormatConfig,
582        mut data: ProcessData,
583    ) -> Result<ProcessData> {
584        let template = if let Some(channel) = &config.channel {
585            config
586                .channels
587                .get(channel)
588                .and_then(|c| c.template.as_ref())
589                .or(config.template.as_ref())
590        } else {
591            config.template.as_ref()
592        };
593
594        if let Some(tmpl) = template {
595            // Simple template substitution
596            let mut result = tmpl.clone();
597            result = result.replace("{{ response }}", &data.content);
598            result = result.replace("{{response}}", &data.content);
599
600            // Replace context variables
601            for (key, value) in &data.context {
602                let placeholder = format!("{{{{ context.{} }}}}", key);
603                let placeholder_no_space = format!("{{{{context.{}}}}}", key);
604                let value_str = match value {
605                    serde_json::Value::String(s) => s.clone(),
606                    _ => value.to_string(),
607                };
608                result = result.replace(&placeholder, &value_str);
609                result = result.replace(&placeholder_no_space, &value_str);
610            }
611
612            data.content = result;
613        }
614
615        // Apply channel-specific max_length
616        if let Some(channel) = &config.channel {
617            if let Some(channel_config) = config.channels.get(channel) {
618                if let Some(max_len) = channel_config.max_length {
619                    if data.content.len() > max_len {
620                        data.content = data.content.chars().take(max_len).collect();
621                    }
622                }
623            }
624        }
625
626        Ok(data)
627    }
628
629    async fn execute_enrich(
630        &self,
631        config: &EnrichConfig,
632        mut data: ProcessData,
633    ) -> Result<ProcessData> {
634        let result = match &config.source {
635            EnrichSource::None => return Ok(data),
636            EnrichSource::Api {
637                url,
638                method: _,
639                headers: _,
640                body: _,
641                extract: _,
642            } => {
643                // API enrichment would require HTTP client
644                // For now, add a warning
645                data.metadata
646                    .warnings
647                    .push(format!("API enrichment not yet implemented: {}", url));
648                return Ok(data);
649            }
650            EnrichSource::File { path, format } => {
651                // File enrichment
652                match std::fs::read_to_string(path) {
653                    Ok(content) => match format.as_deref() {
654                        Some("json") => serde_json::from_str(&content).ok(),
655                        Some("yaml") => serde_yaml::from_str(&content).ok(),
656                        _ => Some(serde_json::Value::String(content)),
657                    },
658                    Err(e) => match config.on_error {
659                        EnrichErrorAction::Stop => return Err(AgentError::IoError(e)),
660                        EnrichErrorAction::Continue | EnrichErrorAction::Warn => {
661                            data.metadata
662                                .warnings
663                                .push(format!("File read failed: {}", e));
664                            return Ok(data);
665                        }
666                    },
667                }
668            }
669            EnrichSource::Tool { tool, args: _ } => {
670                // Tool execution would need tool registry access
671                data.metadata
672                    .warnings
673                    .push(format!("Tool enrichment not yet implemented: {}", tool));
674                return Ok(data);
675            }
676        };
677
678        if let Some(value) = result {
679            if let Some(context_path) = &config.store_in_context {
680                data.context.insert(context_path.clone(), value);
681            }
682        }
683
684        Ok(data)
685    }
686
687    async fn execute_conditional(
688        &self,
689        config: &ConditionalConfig,
690        data: ProcessData,
691    ) -> Result<ProcessData> {
692        let condition_met = self.evaluate_condition(&config.condition, &data);
693
694        let stages = if condition_met {
695            &config.then_stages
696        } else {
697            &config.else_stages
698        };
699
700        let mut result = data;
701        for stage in stages {
702            result = self.execute_stage(stage, result).await?;
703            if result.metadata.rejected {
704                break;
705            }
706        }
707
708        Ok(result)
709    }
710
711    fn evaluate_condition(&self, condition: &Option<ConditionExpr>, data: &ProcessData) -> bool {
712        match condition {
713            None => true,
714            Some(expr) => self.evaluate_condition_expr(expr, data),
715        }
716    }
717
718    fn evaluate_condition_expr(&self, condition: &ConditionExpr, data: &ProcessData) -> bool {
719        match condition {
720            ConditionExpr::All { all } => all.iter().all(|c| self.evaluate_condition_expr(c, data)),
721            ConditionExpr::Any { any } => any.iter().any(|c| self.evaluate_condition_expr(c, data)),
722            ConditionExpr::Simple(map) => self.evaluate_simple_condition(map, data),
723        }
724    }
725
726    fn evaluate_simple_condition(
727        &self,
728        map: &std::collections::HashMap<String, serde_json::Value>,
729        data: &ProcessData,
730    ) -> bool {
731        for (path, expected) in map {
732            let actual = self.get_nested_value(&data.context, path);
733
734            // Handle { exists: true/false }
735            if let Some(obj) = expected.as_object() {
736                if let Some(exists_val) = obj.get("exists") {
737                    let should_exist = exists_val.as_bool().unwrap_or(true);
738                    let does_exist =
739                        actual.is_some() && !matches!(actual, Some(serde_json::Value::Null));
740                    if does_exist != should_exist {
741                        return false;
742                    }
743                    continue;
744                }
745            }
746
747            // Direct value comparison
748            match (actual, expected) {
749                (Some(a), e) if a == e => continue,
750                (None, serde_json::Value::Null) => continue,
751                _ => return false,
752            }
753        }
754        true
755    }
756
757    fn get_nested_value<'a>(
758        &self,
759        context: &'a std::collections::HashMap<String, serde_json::Value>,
760        path: &str,
761    ) -> Option<&'a serde_json::Value> {
762        let parts: Vec<&str> = path.split('.').collect();
763        if parts.is_empty() {
764            return None;
765        }
766
767        let mut current: Option<&serde_json::Value> = context.get(parts[0]);
768
769        for part in &parts[1..] {
770            current = current.and_then(|v| {
771                if let serde_json::Value::Object(obj) = v {
772                    obj.get(*part)
773                } else {
774                    None
775                }
776            });
777        }
778
779        current
780    }
781
782    fn get_llm(&self, alias: Option<&str>) -> Result<Arc<dyn LLMProvider>> {
783        let registry = self
784            .llm_registry
785            .as_ref()
786            .ok_or_else(|| AgentError::Config("LLM registry not configured for process".into()))?;
787
788        match alias {
789            Some(name) => registry
790                .get(name)
791                .map_err(|e| AgentError::LLM(e.to_string())),
792            None => registry
793                .router()
794                .or_else(|_| registry.default())
795                .map_err(|e| AgentError::LLM(e.to_string())),
796        }
797    }
798}
799
/// Pulls the JSON payload out of an LLM reply.
///
/// Tried in order:
/// 1. A Markdown code fence: the fenced text is returned, without any info string
///    (`json`, `JSON`, `json5`, ...) on the opening fence.
/// 2. The span from the first `{` to the last `}`, when the `{` comes first.
/// 3. Otherwise, the trimmed reply unchanged, so the caller's parse reports the failure.
///
/// The result is not validated here; callers parse it with `serde_json`.
fn extract_json(response: &str) -> String {
    let trimmed = response.trim();

    if let Some(after_fence) = trimmed.strip_prefix("```") {
        if let Some(end) = after_fence.find("```") {
            // Drop the info string that may follow the opening fence, whatever its case.
            let body = after_fence[..end]
                .trim_start_matches(|c: char| c.is_ascii_alphanumeric() || c == '-' || c == '_');
            return body.trim().to_string();
        }
    }

    if let (Some(start), Some(end)) = (trimmed.find('{'), trimmed.rfind('}')) {
        // A `}` that precedes every `{` (e.g. "} ... {") would make an inverted range and panic.
        if start < end {
            return trimmed[start..=end].to_string();
        }
    }

    trimmed.to_string()
}
823
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_process_data_new() {
        let data = ProcessData::new("test content");
        assert_eq!(data.content, "test content");
        assert_eq!(data.original, "test content");
        assert!(data.context.is_empty());
    }

    #[test]
    fn test_process_data_with_context() {
        let data = ProcessData::new("test").with_context("key", serde_json::json!("value"));
        assert!(data.context.contains_key("key"));
    }

    #[tokio::test]
    async fn test_normalize_trim() {
        let processor = ProcessProcessor::default();
        let config = NormalizeConfig {
            trim: true,
            ..Default::default()
        };
        let data = ProcessData::new("  hello world  ");
        let result = processor.execute_normalize(&config, data).await.unwrap();
        assert_eq!(result.content, "hello world");
    }

    #[tokio::test]
    async fn test_normalize_collapse_whitespace() {
        let processor = ProcessProcessor::default();
        let config = NormalizeConfig {
            trim: true,
            collapse_whitespace: true,
            ..Default::default()
        };
        let data = ProcessData::new("hello    world\n\ntest");
        let result = processor.execute_normalize(&config, data).await.unwrap();
        assert_eq!(result.content, "hello world test");
    }

    #[tokio::test]
    async fn test_normalize_lowercase() {
        let processor = ProcessProcessor::default();
        let config = NormalizeConfig {
            lowercase: true,
            ..Default::default()
        };
        let data = ProcessData::new("Hello World");
        let result = processor.execute_normalize(&config, data).await.unwrap();
        assert_eq!(result.content, "hello world");
    }

    #[tokio::test]
    async fn test_validate_min_length_reject() {
        let processor = ProcessProcessor::default();
        let config = ValidateConfig {
            rules: vec![ValidationRule::MinLength {
                min_length: 10,
                on_fail: ValidationAction {
                    action: ValidationActionType::Reject,
                    message: None,
                },
            }],
            ..Default::default()
        };
        let data = ProcessData::new("short");
        let result = processor.execute_validate(&config, data).await.unwrap();
        assert!(result.metadata.rejected);
    }

    #[tokio::test]
    async fn test_validate_max_length_truncate() {
        let processor = ProcessProcessor::default();
        let config = ValidateConfig {
            rules: vec![ValidationRule::MaxLength {
                max_length: 5,
                on_fail: ValidationAction {
                    action: ValidationActionType::Truncate,
                    message: None,
                },
            }],
            ..Default::default()
        };
        let data = ProcessData::new("hello world");
        let result = processor.execute_validate(&config, data).await.unwrap();
        assert_eq!(result.content, "hello");
        assert!(!result.metadata.rejected);
    }

    #[tokio::test]
    async fn test_format_simple_template() {
        let processor = ProcessProcessor::default();
        let config = FormatConfig {
            template: Some("Response: {{ response }}".to_string()),
            ..Default::default()
        };
        let data = ProcessData::new("Hello!");
        let result = processor.execute_format(&config, data).await.unwrap();
        assert_eq!(result.content, "Response: Hello!");
    }

    #[test]
    fn test_extract_json() {
        assert_eq!(extract_json(r#"{"key": 1}"#), r#"{"key": 1}"#);
        assert_eq!(extract_json("```json\n{\"key\": 1}\n```"), r#"{"key": 1}"#);
        assert_eq!(extract_json("Some text {\"key\": 1} more"), r#"{"key": 1}"#);
    }

    #[test]
    fn test_evaluate_condition_empty() {
        let processor = ProcessProcessor::default();
        let data = ProcessData::new("test");
        assert!(processor.evaluate_condition(&None, &data));
    }

    #[test]
    fn test_evaluate_condition_simple_exists_true() {
        let processor = ProcessProcessor::default();
        let mut data = ProcessData::new("test");
        data.context.insert(
            "session".to_string(),
            serde_json::json!({ "user_name": "Alice" }),
        );

        let mut map = std::collections::HashMap::new();
        map.insert(
            "session.user_name".to_string(),
            serde_json::json!({ "exists": true }),
        );
        let condition = ConditionExpr::Simple(map);

        assert!(processor.evaluate_condition_expr(&condition, &data));
    }

    #[test]
    fn test_evaluate_condition_simple_exists_false() {
        let processor = ProcessProcessor::default();
        let data = ProcessData::new("test");

        let mut map = std::collections::HashMap::new();
        map.insert(
            "session.user_name".to_string(),
            serde_json::json!({ "exists": false }),
        );
        let condition = ConditionExpr::Simple(map);

        assert!(processor.evaluate_condition_expr(&condition, &data));
    }

    #[test]
    fn test_evaluate_condition_all() {
        let processor = ProcessProcessor::default();
        let mut data = ProcessData::new("test");
        data.context.insert(
            "session".to_string(),
            serde_json::json!({ "user_name": "Alice", "language": "en" }),
        );

        let mut map1 = std::collections::HashMap::new();
        map1.insert(
            "session.user_name".to_string(),
            serde_json::json!({ "exists": true }),
        );
        let mut map2 = std::collections::HashMap::new();
        map2.insert(
            "session.language".to_string(),
            serde_json::json!({ "exists": true }),
        );

        let condition = ConditionExpr::All {
            all: vec![ConditionExpr::Simple(map1), ConditionExpr::Simple(map2)],
        };

        assert!(processor.evaluate_condition_expr(&condition, &data));
    }

    #[test]
    fn test_evaluate_condition_any() {
        let processor = ProcessProcessor::default();
        let mut data = ProcessData::new("test");
        data.context.insert(
            "session".to_string(),
            serde_json::json!({ "tier": "premium" }),
        );

        let mut map1 = std::collections::HashMap::new();
        map1.insert("session.tier".to_string(), serde_json::json!("premium"));
        let mut map2 = std::collections::HashMap::new();
        map2.insert("session.tier".to_string(), serde_json::json!("enterprise"));

        let condition = ConditionExpr::Any {
            any: vec![ConditionExpr::Simple(map1), ConditionExpr::Simple(map2)],
        };

        assert!(processor.evaluate_condition_expr(&condition, &data));
    }

    #[test]
    fn test_evaluate_condition_value_match() {
        let processor = ProcessProcessor::default();
        let mut data = ProcessData::new("test");
        data.context.insert(
            "input".to_string(),
            serde_json::json!({ "sentiment": "negative" }),
        );

        let mut map = std::collections::HashMap::new();
        map.insert("input.sentiment".to_string(), serde_json::json!("negative"));
        let condition = ConditionExpr::Simple(map);

        assert!(processor.evaluate_condition_expr(&condition, &data));
    }

    #[test]
    fn test_get_nested_value() {
        let processor = ProcessProcessor::default();
        let mut context = std::collections::HashMap::new();
        context.insert(
            "session".to_string(),
            serde_json::json!({ "user": { "name": "Alice" } }),
        );

        let result = processor.get_nested_value(&context, "session.user.name");
        assert_eq!(result, Some(&serde_json::json!("Alice")));

        let result = processor.get_nested_value(&context, "session.nonexistent");
        assert!(result.is_none());
    }

    //
    // LLM-based stage tests (detect, extract, sanitize, transform, validate)
    //
    fn create_mock_registry(response: &str) -> Arc<ai_agents_llm::LLMRegistry> {
        use ai_agents_llm::mock::MockLLMProvider;
        let mut mock = MockLLMProvider::new("test");
        mock.set_response(response);
        let mut registry = ai_agents_llm::LLMRegistry::new();
        registry.register("default", std::sync::Arc::new(mock));
        registry.set_default("default");
        std::sync::Arc::new(registry)
    }

    fn create_mock_registry_multi(responses: Vec<&str>) -> Arc<ai_agents_llm::LLMRegistry> {
        use ai_agents_llm::mock::MockLLMProvider;
        let mut mock = MockLLMProvider::new("test");
        mock.set_responses(responses.into_iter().map(String::from).collect(), true);
        let mut registry = ai_agents_llm::LLMRegistry::new();
        registry.register("default", std::sync::Arc::new(mock));
        registry.set_default("default");
        std::sync::Arc::new(registry)
    }

    #[tokio::test]
    async fn test_detect_stage_language_sentiment() {
        let registry = create_mock_registry(
            r#"{"language": "ko", "sentiment": "positive", "intent": "greeting"}"#,
        );
        let config = ProcessConfig {
            input: vec![ProcessStage::Detect(DetectStage {
                id: Some("detect_test".to_string()),
                condition: None,
                config: DetectConfig {
                    llm: None,
                    detect: vec![DetectionType::Language, DetectionType::Sentiment],
                    intents: vec![IntentDefinition {
                        id: "greeting".to_string(),
                        description: "User says hello".to_string(),
                    }],
                    store_in_context: {
                        let mut m = std::collections::HashMap::new();
                        m.insert("language".to_string(), "input.language".to_string());
                        m.insert("sentiment".to_string(), "input.sentiment".to_string());
                        m
                    },
                },
            })],
            ..Default::default()
        };
        let processor = ProcessProcessor::new(config).with_llm_registry(registry);
        let result = processor.process_input("안녕하세요!").await.unwrap();

        assert_eq!(
            result.context.get("input.language"),
            Some(&serde_json::json!("ko"))
        );
        assert_eq!(
            result.context.get("input.sentiment"),
            Some(&serde_json::json!("positive"))
        );
        assert!(
            result
                .metadata
                .stages_executed
                .contains(&"detect_test".to_string())
        );
    }

    #[tokio::test]
    async fn test_extract_stage_entities() {
        let registry = create_mock_registry(r#"{"order_number": "ORD-12345", "urgency": "high"}"#);
        let config = ProcessConfig {
            input: vec![ProcessStage::Extract(ExtractStage {
                id: Some("extract_test".to_string()),
                condition: None,
                config: ExtractConfig {
                    llm: None,
                    schema: {
                        let mut m = std::collections::HashMap::new();
                        m.insert(
                            "order_number".to_string(),
                            FieldSchema {
                                field_type: FieldType::String,
                                description: Some("Order number".to_string()),
                                required: true,
                                values: vec![],
                            },
                        );
                        m.insert(
                            "urgency".to_string(),
                            FieldSchema {
                                field_type: FieldType::Enum,
                                description: Some("Urgency level".to_string()),
                                required: false,
                                values: vec![
                                    "low".to_string(),
                                    "medium".to_string(),
                                    "high".to_string(),
                                ],
                            },
                        );
                        m
                    },
                    store_in_context: Some("extracted".to_string()),
                },
            })],
            ..Default::default()
        };
        let processor = ProcessProcessor::new(config).with_llm_registry(registry);
        let result = processor
            .process_input("My order ORD-12345 is urgent!")
            .await
            .unwrap();

        let extracted = result.context.get("extracted").unwrap();
        assert_eq!(extracted["order_number"], "ORD-12345");
        assert_eq!(extracted["urgency"], "high");
    }

    #[tokio::test]
    async fn test_sanitize_stage_pii_masking() {
        let registry = create_mock_registry("Call me at ****-****-**** or email at ****@****.com");
        let config = ProcessConfig {
            input: vec![ProcessStage::Sanitize(SanitizeStage {
                id: Some("sanitize_test".to_string()),
                condition: None,
                config: SanitizeConfig {
                    llm: None,
                    pii: Some(PIISanitizeConfig {
                        action: PIIAction::Mask,
                        types: vec![PIIType::Phone, PIIType::Email],
                        mask_char: "*".to_string(),
                    }),
                    harmful: None,
                    remove: vec![],
                },
            })],
            ..Default::default()
        };
        let processor = ProcessProcessor::new(config).with_llm_registry(registry);
        let result = processor
            .process_input("Call me at 010-1234-5678 or email at user@example.com")
            .await
            .unwrap();

        // LLM returns sanitized text
        assert!(result.content.contains("****"));
        assert!(!result.content.contains("010-1234-5678"));
        assert!(!result.content.contains("user@example.com"));
    }

    #[tokio::test]
    async fn test_transform_stage_tone_adjustment() {
        let registry = create_mock_registry(
            "I understand your frustration. Let me help you resolve this issue right away.",
        );
        let config = ProcessConfig {
            output: vec![ProcessStage::Transform(TransformStage {
                id: Some("tone_test".to_string()),
                condition: None,
                config: TransformConfig {
                    llm: None,
                    prompt: Some("Rewrite to be more empathetic.".to_string()),
                    max_output_tokens: None,
                },
            })],
            ..Default::default()
        };
        let processor = ProcessProcessor::new(config).with_llm_registry(registry);

        let input_context = std::collections::HashMap::new();
        let result = processor
            .process_output("Your request is being processed.", &input_context)
            .await
            .unwrap();

        assert!(result.content.contains("understand"));
    }

    #[tokio::test]
    async fn test_validate_stage_llm_criteria() {
        let registry = create_mock_registry(
            r#"{"passes": false, "score": 0.3, "issues": ["Response is too vague"]}"#,
        );
        let config = ProcessConfig {
            output: vec![ProcessStage::Validate(ValidateStage {
                id: Some("quality_test".to_string()),
                condition: None,
                config: ValidateConfig {
                    rules: vec![],
                    llm: None,
                    criteria: vec!["Response is specific and actionable".to_string()],
                    threshold: 0.7,
                    on_fail: ValidationFailAction {
                        action: ValidationFailType::Warn,
                        ..Default::default()
                    },
                },
            })],
            ..Default::default()
        };
        let processor = ProcessProcessor::new(config).with_llm_registry(registry);

        let input_context = std::collections::HashMap::new();
        let result = processor
            .process_output("It depends.", &input_context)
            .await
            .unwrap();

        // Should have a warning because score (0.3) < threshold (0.7)
        assert!(
            result.metadata.warnings.iter().any(|w| w.contains("vague")),
            "Expected warning about vague response, got: {:?}",
            result.metadata.warnings
        );
    }

    #[tokio::test]
    async fn test_validate_stage_llm_criteria_reject() {
        let registry = create_mock_registry(
            r#"{"passes": false, "score": 0.2, "issues": ["Contains harmful content"]}"#,
        );
        let config = ProcessConfig {
            output: vec![ProcessStage::Validate(ValidateStage {
                id: Some("reject_test".to_string()),
                condition: None,
                config: ValidateConfig {
                    rules: vec![],
                    llm: None,
                    criteria: vec!["Response is safe".to_string()],
                    threshold: 0.7,
                    on_fail: ValidationFailAction {
                        action: ValidationFailType::Reject,
                        ..Default::default()
                    },
                },
            })],
            ..Default::default()
        };
        let processor = ProcessProcessor::new(config).with_llm_registry(registry);

        let input_context = std::collections::HashMap::new();
        let result = processor
            .process_output("Dangerous content here.", &input_context)
            .await
            .unwrap();

        assert!(result.metadata.rejected);
        assert!(
            result
                .metadata
                .rejection_reason
                .as_ref()
                .unwrap()
                .contains("harmful")
        );
    }

    #[tokio::test]
    async fn test_full_input_pipeline_chain() {
        // normalize → detect → extract pipeline
        let registry = create_mock_registry_multi(vec![
            // detect response
            r#"{"language": "en", "sentiment": "neutral"}"#,
            // extract response
            r#"{"user_name": "Alice", "topic": "billing"}"#,
        ]);
        let config = ProcessConfig {
            input: vec![
                ProcessStage::Normalize(NormalizeStage {
                    id: Some("norm".to_string()),
                    condition: None,
                    config: NormalizeConfig {
                        trim: true,
                        collapse_whitespace: true,
                        ..Default::default()
                    },
                }),
                ProcessStage::Detect(DetectStage {
                    id: Some("detect".to_string()),
                    condition: None,
                    config: DetectConfig {
                        llm: None,
                        detect: vec![DetectionType::Language, DetectionType::Sentiment],
                        intents: vec![],
                        store_in_context: {
                            let mut m = std::collections::HashMap::new();
                            m.insert("language".to_string(), "input.language".to_string());
                            m
                        },
                    },
                }),
                ProcessStage::Extract(ExtractStage {
                    id: Some("extract".to_string()),
                    condition: None,
                    config: ExtractConfig {
                        llm: None,
                        schema: {
                            let mut m = std::collections::HashMap::new();
                            m.insert(
                                "user_name".to_string(),
                                FieldSchema {
                                    field_type: FieldType::String,
                                    description: Some("User name".to_string()),
                                    ..Default::default()
                                },
                            );
                            m
                        },
                        store_in_context: Some("entities".to_string()),
                    },
                }),
            ],
            ..Default::default()
        };
        let processor = ProcessProcessor::new(config).with_llm_registry(registry);
        let result = processor
            .process_input("  Hi, I'm   Alice and I have a billing question  ")
            .await
            .unwrap();

        // Verify normalize ran
        assert_eq!(
            result.content,
            "Hi, I'm Alice and I have a billing question"
        );

        // Verify detect stored context
        assert_eq!(
            result.context.get("input.language"),
            Some(&serde_json::json!("en"))
        );

        // Verify extract stored context
        let entities = result.context.get("entities").unwrap();
        assert_eq!(entities["user_name"], "Alice");

        // Verify all stages executed in order
        assert_eq!(
            result.metadata.stages_executed,
            vec!["norm", "detect", "extract"]
        );
    }

    #[tokio::test]
    async fn test_conditional_stage_skips_on_false() {
        let registry = create_mock_registry(r#"{"language": "en"}"#);
        let config = ProcessConfig {
            input: vec![ProcessStage::Detect(DetectStage {
                id: Some("should_skip".to_string()),
                condition: Some(ConditionExpr::Simple({
                    let mut map = std::collections::HashMap::new();
                    map.insert("needs_detection".to_string(), serde_json::json!(true));
                    map
                })),
                config: DetectConfig {
                    llm: None,
                    detect: vec![DetectionType::Language],
                    ..Default::default()
                },
            })],
            ..Default::default()
        };
        let processor = ProcessProcessor::new(config).with_llm_registry(registry);
        let result = processor.process_input("Hello").await.unwrap();

        // Stage should be skipped because "needs_detection" is not in context
        assert!(
            !result
                .metadata
                .stages_executed
                .contains(&"should_skip".to_string()),
            "Stage should have been skipped"
        );
    }

    #[tokio::test]
    async fn test_stage_skipped_when_condition_false() {
        // Builds the condition `session.user: { exists: <present> }`.
        fn user_exists(present: bool) -> ConditionExpr {
            let mut map = std::collections::HashMap::new();
            map.insert(
                "session.user".to_string(),
                serde_json::json!({ "exists": present }),
            );
            ConditionExpr::Simple(map)
        }

        // Condition evaluation: with a user in the session, "user must be absent" is false.
        let evaluator = ProcessProcessor::default();
        let mut data = ProcessData::new("test");
        data.context.insert(
            "session".to_string(),
            serde_json::json!({ "user": "Alice" }),
        );
        assert!(!evaluator.evaluate_condition_expr(&user_exists(false), &data));

        // Pipeline: `process_input` starts from an empty context, so a seeded session is
        // not visible to it. Gate the stage on the user existing, which is false for a
        // fresh context, and check that the stage never runs.
        let config = ProcessConfig {
            input: vec![ProcessStage::Extract(ExtractStage {
                id: Some("skip_me".to_string()),
                condition: Some(user_exists(true)),
                config: ExtractConfig::default(),
            })],
            settings: ProcessSettings {
                debug: ProcessDebugConfig {
                    log_stages: true,
                    ..Default::default()
                },
                ..Default::default()
            },
            ..Default::default()
        };
        let processor = ProcessProcessor::new(config);

        let result = processor.process_input("test").await.unwrap();
        assert!(
            !result
                .metadata
                .stages_executed
                .contains(&"skip_me".to_string())
        );
    }
}