1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Instant;
6
7use ai_agents_core::{AgentError, ChatMessage, LLMProvider, Result};
8use ai_agents_llm::LLMRegistry;
9
10use super::config::*;
11
12#[derive(Debug, Clone, Default)]
13pub struct ProcessData {
14 pub content: String,
15 pub original: String,
16 pub context: HashMap<String, serde_json::Value>,
17 pub metadata: ProcessMetadata,
18}
19
/// Bookkeeping accumulated while stages run.
#[derive(Clone, Debug, Default)]
pub struct ProcessMetadata {
    /// Names of the stages that completed successfully, in execution order.
    pub stages_executed: Vec<String>,
    /// Elapsed milliseconds per stage name (filled only when timing is enabled).
    pub timing: HashMap<String, u64>,
    /// Non-fatal problems reported by stages.
    pub warnings: Vec<String>,
    /// Set once a validation rule rejects the content.
    pub rejected: bool,
    /// Explanation recorded together with `rejected`.
    pub rejection_reason: Option<String>,
}
28
29impl ProcessData {
30 pub fn new(content: impl Into<String>) -> Self {
31 let content = content.into();
32 Self {
33 original: content.clone(),
34 content,
35 context: HashMap::new(),
36 metadata: ProcessMetadata::default(),
37 }
38 }
39
40 pub fn with_context(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
41 self.context.insert(key.into(), value);
42 self
43 }
44}
45
46#[derive(Debug)]
47pub struct ProcessProcessor {
48 config: ProcessConfig,
49 llm_registry: Option<Arc<LLMRegistry>>,
50}
51
52impl Default for ProcessProcessor {
53 fn default() -> Self {
54 Self::new(ProcessConfig::default())
55 }
56}
57
58impl ProcessProcessor {
59 pub fn new(config: ProcessConfig) -> Self {
60 Self {
61 config,
62 llm_registry: None,
63 }
64 }
65
66 pub fn with_llm_registry(mut self, registry: Arc<LLMRegistry>) -> Self {
67 self.llm_registry = Some(registry);
68 self
69 }
70
71 pub async fn process_input(&self, input: &str) -> Result<ProcessData> {
72 let mut data = ProcessData::new(input);
73
74 for stage in &self.config.input {
75 data = self.execute_stage(stage, data).await?;
76 if data.metadata.rejected {
77 break;
78 }
79 }
80
81 Ok(data)
82 }
83
84 pub async fn process_output(
85 &self,
86 output: &str,
87 input_context: &HashMap<String, serde_json::Value>,
88 ) -> Result<ProcessData> {
89 let mut data = ProcessData::new(output);
90 data.context = input_context.clone();
91
92 for stage in &self.config.output {
93 data = self.execute_stage(stage, data).await?;
94 if data.metadata.rejected {
95 break;
96 }
97 }
98
99 Ok(data)
100 }
101
102 fn execute_stage<'a>(
103 &'a self,
104 stage: &'a ProcessStage,
105 data: ProcessData,
106 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<ProcessData>> + Send + 'a>> {
107 Box::pin(async move {
108 let start = Instant::now();
109 let stage_name = stage
110 .id()
111 .map(String::from)
112 .unwrap_or_else(|| self.get_stage_type_name(stage));
113
114 if let Some(condition) = stage.condition() {
116 if !self.evaluate_condition_expr(condition, &data) {
117 if self.config.settings.debug.log_stages {
118 tracing::debug!(
119 "[Process] Stage skipped (condition not met): {}",
120 stage_name
121 );
122 }
123 return Ok(data);
124 }
125 }
126
127 if self.config.settings.debug.log_stages {
128 tracing::debug!("[Process] Executing stage: {}", stage_name);
129 }
130
131 let data_clone = data.clone();
132 let result = match stage {
133 ProcessStage::Normalize(s) => self.execute_normalize(&s.config, data).await,
134 ProcessStage::Detect(s) => self.execute_detect(&s.config, data).await,
135 ProcessStage::Extract(s) => self.execute_extract(&s.config, data).await,
136 ProcessStage::Sanitize(s) => self.execute_sanitize(&s.config, data).await,
137 ProcessStage::Transform(s) => self.execute_transform(&s.config, data).await,
138 ProcessStage::Validate(s) => self.execute_validate(&s.config, data).await,
139 ProcessStage::Format(s) => self.execute_format(&s.config, data).await,
140 ProcessStage::Enrich(s) => self.execute_enrich(&s.config, data).await,
141 ProcessStage::Conditional(s) => self.execute_conditional(&s.config, data).await,
142 };
143
144 match result {
145 Ok(mut d) => {
146 d.metadata.stages_executed.push(stage_name.clone());
147 if self.config.settings.debug.include_timing {
148 d.metadata
149 .timing
150 .insert(stage_name, start.elapsed().as_millis() as u64);
151 }
152 Ok(d)
153 }
154 Err(e) => {
155 let mut fallback_data = data_clone;
156 match self.config.settings.on_stage_error.default {
157 StageErrorAction::Stop => Err(e),
158 StageErrorAction::Continue => {
159 fallback_data
160 .metadata
161 .warnings
162 .push(format!("Stage {} failed: {}", stage_name, e));
163 Ok(fallback_data)
164 }
165 StageErrorAction::Retry => {
166 if let Some(retry_config) = &self.config.settings.on_stage_error.retry {
167 for _ in 0..retry_config.max_retries {
168 tokio::time::sleep(std::time::Duration::from_millis(
169 retry_config.backoff_ms,
170 ))
171 .await;
172 }
173 }
174 fallback_data
175 .metadata
176 .warnings
177 .push(format!("Stage {} failed after retries: {}", stage_name, e));
178 Ok(fallback_data)
179 }
180 }
181 }
182 }
183 })
184 }
185
186 fn get_stage_type_name(&self, stage: &ProcessStage) -> String {
187 match stage {
188 ProcessStage::Normalize(_) => "normalize".to_string(),
189 ProcessStage::Detect(_) => "detect".to_string(),
190 ProcessStage::Extract(_) => "extract".to_string(),
191 ProcessStage::Sanitize(_) => "sanitize".to_string(),
192 ProcessStage::Transform(_) => "transform".to_string(),
193 ProcessStage::Validate(_) => "validate".to_string(),
194 ProcessStage::Format(_) => "format".to_string(),
195 ProcessStage::Enrich(_) => "enrich".to_string(),
196 ProcessStage::Conditional(_) => "conditional".to_string(),
197 }
198 }
199
200 async fn execute_normalize(
201 &self,
202 config: &NormalizeConfig,
203 mut data: ProcessData,
204 ) -> Result<ProcessData> {
205 let mut content = data.content.clone();
206
207 if config.trim {
208 content = content.trim().to_string();
209 }
210
211 if config.collapse_whitespace {
212 content = content.split_whitespace().collect::<Vec<_>>().join(" ");
213 }
214
215 if config.lowercase {
216 content = content.to_lowercase();
217 }
218
219 data.content = content;
223 Ok(data)
224 }
225
226 async fn execute_detect(
227 &self,
228 config: &DetectConfig,
229 mut data: ProcessData,
230 ) -> Result<ProcessData> {
231 let llm = self.get_llm(config.llm.as_deref())?;
232
233 let detection_types: Vec<&str> = config
234 .detect
235 .iter()
236 .map(|d| match d {
237 DetectionType::Language => "language (ISO 639-1 code)",
238 DetectionType::Sentiment => "sentiment (positive, negative, neutral)",
239 DetectionType::Intent => "intent",
240 DetectionType::Topic => "topic",
241 DetectionType::Formality => "formality (formal, informal)",
242 DetectionType::Urgency => "urgency (low, medium, high, critical)",
243 })
244 .collect();
245
246 let intents_desc = if !config.intents.is_empty() {
247 let intents: Vec<String> = config
248 .intents
249 .iter()
250 .map(|i| format!("- {}: {}", i.id, i.description))
251 .collect();
252 format!("\n\nAvailable intents:\n{}", intents.join("\n"))
253 } else {
254 String::new()
255 };
256
257 let prompt = format!(
258 "Analyze the following text and detect: {}\n{}\n\n\
259 Respond with JSON only: {{\"language\": \"...\", \"sentiment\": \"...\", \"intent\": \"...\", ...}}\n\n\
260 Text: {}",
261 detection_types.join(", "),
262 intents_desc,
263 data.content
264 );
265
266 let messages = vec![ChatMessage::user(&prompt)];
267 let response = llm
268 .complete(&messages, None)
269 .await
270 .map_err(|e| AgentError::LLM(e.to_string()))?;
271
272 if let Ok(result) =
273 serde_json::from_str::<serde_json::Value>(&extract_json(&response.content))
274 {
275 for (key, context_path) in &config.store_in_context {
276 if let Some(value) = result.get(key) {
277 data.context.insert(context_path.clone(), value.clone());
278 }
279 }
280 data.context.insert("detection".to_string(), result);
281 }
282
283 Ok(data)
284 }
285
286 async fn execute_extract(
287 &self,
288 config: &ExtractConfig,
289 mut data: ProcessData,
290 ) -> Result<ProcessData> {
291 let llm = self.get_llm(config.llm.as_deref())?;
292
293 let schema_desc: Vec<String> = config
294 .schema
295 .iter()
296 .map(|(name, schema)| {
297 let type_str = format!("{:?}", schema.field_type).to_lowercase();
298 let desc = schema.description.as_deref().unwrap_or("");
299 let values = if !schema.values.is_empty() {
300 format!(" (values: {})", schema.values.join(", "))
301 } else {
302 String::new()
303 };
304 let required = if schema.required { " [required]" } else { "" };
305 format!("- {}: {} - {}{}{}", name, type_str, desc, values, required)
306 })
307 .collect();
308
309 let prompt = format!(
310 "Extract the following fields from the text:\n{}\n\n\
311 Respond with JSON only. Use null for fields not found.\n\n\
312 Text: {}",
313 schema_desc.join("\n"),
314 data.content
315 );
316
317 let messages = vec![ChatMessage::user(&prompt)];
318 let response = llm
319 .complete(&messages, None)
320 .await
321 .map_err(|e| AgentError::LLM(e.to_string()))?;
322
323 if let Ok(result) =
324 serde_json::from_str::<serde_json::Value>(&extract_json(&response.content))
325 {
326 if let Some(context_path) = &config.store_in_context {
327 data.context.insert(context_path.clone(), result.clone());
328 }
329 data.context.insert("extracted".to_string(), result);
330 }
331
332 Ok(data)
333 }
334
335 async fn execute_sanitize(
336 &self,
337 config: &SanitizeConfig,
338 mut data: ProcessData,
339 ) -> Result<ProcessData> {
340 let llm = self.get_llm(config.llm.as_deref())?;
341
342 let mut instructions = Vec::new();
343
344 if let Some(pii_config) = &config.pii {
345 if !pii_config.types.is_empty() {
346 let pii_types: Vec<String> = pii_config
347 .types
348 .iter()
349 .map(|t| format!("{:?}", t).to_lowercase())
350 .collect();
351 let action = match pii_config.action {
352 PIIAction::Mask => format!("replace with '{}'", pii_config.mask_char.repeat(4)),
353 PIIAction::Remove => "remove completely".to_string(),
354 PIIAction::Flag => "wrap with [PII: type]".to_string(),
355 };
356 instructions.push(format!("PII types to {}: {}", action, pii_types.join(", ")));
357 }
358 }
359
360 if let Some(harmful_config) = &config.harmful {
361 if !harmful_config.detect.is_empty() {
362 let types: Vec<String> = harmful_config
363 .detect
364 .iter()
365 .map(|t| format!("{:?}", t).to_lowercase())
366 .collect();
367 instructions.push(format!("Detect harmful content: {}", types.join(", ")));
368 }
369 }
370
371 if !config.remove.is_empty() {
372 instructions.push(format!(
373 "Remove any mentions of: {}",
374 config.remove.join(", ")
375 ));
376 }
377
378 if instructions.is_empty() {
379 return Ok(data);
380 }
381
382 let prompt = format!(
383 "Sanitize the following text according to these rules:\n{}\n\n\
384 Return only the sanitized text, nothing else.\n\n\
385 Text: {}",
386 instructions.join("\n"),
387 data.content
388 );
389
390 let messages = vec![ChatMessage::user(&prompt)];
391 let response = llm
392 .complete(&messages, None)
393 .await
394 .map_err(|e| AgentError::LLM(e.to_string()))?;
395
396 data.content = response.content.trim().to_string();
397 Ok(data)
398 }
399
400 async fn execute_transform(
401 &self,
402 config: &TransformConfig,
403 mut data: ProcessData,
404 ) -> Result<ProcessData> {
405 let prompt = match &config.prompt {
406 Some(p) => p.clone(),
407 None => return Ok(data),
408 };
409
410 let llm = self.get_llm(config.llm.as_deref())?;
411
412 let full_prompt = format!("{}\n\nOriginal text:\n{}", prompt, data.content);
413
414 let messages = vec![ChatMessage::user(&full_prompt)];
415 let response = llm
416 .complete(&messages, None)
417 .await
418 .map_err(|e| AgentError::LLM(e.to_string()))?;
419
420 data.content = response.content.trim().to_string();
421 Ok(data)
422 }
423
424 async fn execute_validate(
425 &self,
426 config: &ValidateConfig,
427 mut data: ProcessData,
428 ) -> Result<ProcessData> {
429 for rule in &config.rules {
431 match rule {
432 ValidationRule::MinLength {
433 min_length,
434 on_fail,
435 } => {
436 if data.content.len() < *min_length {
437 match on_fail.action {
438 ValidationActionType::Reject => {
439 data.metadata.rejected = true;
440 data.metadata.rejection_reason = Some(format!(
441 "Content too short: {} < {} characters",
442 data.content.len(),
443 min_length
444 ));
445 return Ok(data);
446 }
447 ValidationActionType::Warn => {
448 data.metadata.warnings.push(format!(
449 "Content shorter than {} characters",
450 min_length
451 ));
452 }
453 ValidationActionType::Truncate => {} }
455 }
456 }
457 ValidationRule::MaxLength {
458 max_length,
459 on_fail,
460 } => {
461 if data.content.len() > *max_length {
462 match on_fail.action {
463 ValidationActionType::Truncate => {
464 data.content = data.content.chars().take(*max_length).collect();
465 }
466 ValidationActionType::Reject => {
467 data.metadata.rejected = true;
468 data.metadata.rejection_reason = Some(format!(
469 "Content too long: {} > {} characters",
470 data.content.len(),
471 max_length
472 ));
473 return Ok(data);
474 }
475 ValidationActionType::Warn => {
476 data.metadata
477 .warnings
478 .push(format!("Content longer than {} characters", max_length));
479 }
480 }
481 }
482 }
483 ValidationRule::Pattern { pattern, on_fail } => {
484 if let Ok(re) = regex::Regex::new(pattern) {
485 if !re.is_match(&data.content) {
486 match on_fail.action {
487 ValidationActionType::Reject => {
488 data.metadata.rejected = true;
489 data.metadata.rejection_reason =
490 Some("Content does not match required pattern".to_string());
491 return Ok(data);
492 }
493 ValidationActionType::Warn => {
494 data.metadata.warnings.push(
495 "Content does not match expected pattern".to_string(),
496 );
497 }
498 ValidationActionType::Truncate => {} }
500 }
501 }
502 }
503 }
504 }
505
506 if !config.criteria.is_empty() {
508 let llm = self.get_llm(config.llm.as_deref())?;
509
510 let criteria_list = config
511 .criteria
512 .iter()
513 .enumerate()
514 .map(|(i, c)| format!("{}. {}", i + 1, c))
515 .collect::<Vec<_>>()
516 .join("\n");
517
518 let prompt = format!(
519 "Evaluate if the following content meets these criteria:\n{}\n\n\
520 Respond with JSON: {{\"passes\": true/false, \"score\": 0.0-1.0, \"issues\": [\"...\"]}}\n\n\
521 Content: {}",
522 criteria_list, data.content
523 );
524
525 let messages = vec![ChatMessage::user(&prompt)];
526 let response = llm
527 .complete(&messages, None)
528 .await
529 .map_err(|e| AgentError::LLM(e.to_string()))?;
530
531 if let Ok(result) =
532 serde_json::from_str::<serde_json::Value>(&extract_json(&response.content))
533 {
534 let score = result.get("score").and_then(|s| s.as_f64()).unwrap_or(1.0) as f32;
535 let passes = result
536 .get("passes")
537 .and_then(|p| p.as_bool())
538 .unwrap_or(true);
539
540 if !passes || score < config.threshold {
541 match config.on_fail.action {
542 ValidationFailType::Reject => {
543 data.metadata.rejected = true;
544 let issues = result
545 .get("issues")
546 .and_then(|i| i.as_array())
547 .map(|arr| {
548 arr.iter()
549 .filter_map(|v| v.as_str())
550 .collect::<Vec<_>>()
551 .join(", ")
552 })
553 .unwrap_or_else(|| "Validation failed".to_string());
554 data.metadata.rejection_reason = Some(issues);
555 return Ok(data);
556 }
557 ValidationFailType::Regenerate => {
558 data.metadata
559 .warnings
560 .push("Content may need regeneration".to_string());
561 }
562 ValidationFailType::Warn => {
563 if let Some(issues) = result.get("issues").and_then(|i| i.as_array()) {
564 for issue in issues {
565 if let Some(s) = issue.as_str() {
566 data.metadata.warnings.push(s.to_string());
567 }
568 }
569 }
570 }
571 }
572 }
573 }
574 }
575
576 Ok(data)
577 }
578
579 async fn execute_format(
580 &self,
581 config: &FormatConfig,
582 mut data: ProcessData,
583 ) -> Result<ProcessData> {
584 let template = if let Some(channel) = &config.channel {
585 config
586 .channels
587 .get(channel)
588 .and_then(|c| c.template.as_ref())
589 .or(config.template.as_ref())
590 } else {
591 config.template.as_ref()
592 };
593
594 if let Some(tmpl) = template {
595 let mut result = tmpl.clone();
597 result = result.replace("{{ response }}", &data.content);
598 result = result.replace("{{response}}", &data.content);
599
600 for (key, value) in &data.context {
602 let placeholder = format!("{{{{ context.{} }}}}", key);
603 let placeholder_no_space = format!("{{{{context.{}}}}}", key);
604 let value_str = match value {
605 serde_json::Value::String(s) => s.clone(),
606 _ => value.to_string(),
607 };
608 result = result.replace(&placeholder, &value_str);
609 result = result.replace(&placeholder_no_space, &value_str);
610 }
611
612 data.content = result;
613 }
614
615 if let Some(channel) = &config.channel {
617 if let Some(channel_config) = config.channels.get(channel) {
618 if let Some(max_len) = channel_config.max_length {
619 if data.content.len() > max_len {
620 data.content = data.content.chars().take(max_len).collect();
621 }
622 }
623 }
624 }
625
626 Ok(data)
627 }
628
629 async fn execute_enrich(
630 &self,
631 config: &EnrichConfig,
632 mut data: ProcessData,
633 ) -> Result<ProcessData> {
634 let result = match &config.source {
635 EnrichSource::None => return Ok(data),
636 EnrichSource::Api {
637 url,
638 method: _,
639 headers: _,
640 body: _,
641 extract: _,
642 } => {
643 data.metadata
646 .warnings
647 .push(format!("API enrichment not yet implemented: {}", url));
648 return Ok(data);
649 }
650 EnrichSource::File { path, format } => {
651 match std::fs::read_to_string(path) {
653 Ok(content) => match format.as_deref() {
654 Some("json") => serde_json::from_str(&content).ok(),
655 Some("yaml") => serde_yaml::from_str(&content).ok(),
656 _ => Some(serde_json::Value::String(content)),
657 },
658 Err(e) => match config.on_error {
659 EnrichErrorAction::Stop => return Err(AgentError::IoError(e)),
660 EnrichErrorAction::Continue | EnrichErrorAction::Warn => {
661 data.metadata
662 .warnings
663 .push(format!("File read failed: {}", e));
664 return Ok(data);
665 }
666 },
667 }
668 }
669 EnrichSource::Tool { tool, args: _ } => {
670 data.metadata
672 .warnings
673 .push(format!("Tool enrichment not yet implemented: {}", tool));
674 return Ok(data);
675 }
676 };
677
678 if let Some(value) = result {
679 if let Some(context_path) = &config.store_in_context {
680 data.context.insert(context_path.clone(), value);
681 }
682 }
683
684 Ok(data)
685 }
686
687 async fn execute_conditional(
688 &self,
689 config: &ConditionalConfig,
690 data: ProcessData,
691 ) -> Result<ProcessData> {
692 let condition_met = self.evaluate_condition(&config.condition, &data);
693
694 let stages = if condition_met {
695 &config.then_stages
696 } else {
697 &config.else_stages
698 };
699
700 let mut result = data;
701 for stage in stages {
702 result = self.execute_stage(stage, result).await?;
703 if result.metadata.rejected {
704 break;
705 }
706 }
707
708 Ok(result)
709 }
710
711 fn evaluate_condition(&self, condition: &Option<ConditionExpr>, data: &ProcessData) -> bool {
712 match condition {
713 None => true,
714 Some(expr) => self.evaluate_condition_expr(expr, data),
715 }
716 }
717
718 fn evaluate_condition_expr(&self, condition: &ConditionExpr, data: &ProcessData) -> bool {
719 match condition {
720 ConditionExpr::All { all } => all.iter().all(|c| self.evaluate_condition_expr(c, data)),
721 ConditionExpr::Any { any } => any.iter().any(|c| self.evaluate_condition_expr(c, data)),
722 ConditionExpr::Simple(map) => self.evaluate_simple_condition(map, data),
723 }
724 }
725
726 fn evaluate_simple_condition(
727 &self,
728 map: &std::collections::HashMap<String, serde_json::Value>,
729 data: &ProcessData,
730 ) -> bool {
731 for (path, expected) in map {
732 let actual = self.get_nested_value(&data.context, path);
733
734 if let Some(obj) = expected.as_object() {
736 if let Some(exists_val) = obj.get("exists") {
737 let should_exist = exists_val.as_bool().unwrap_or(true);
738 let does_exist =
739 actual.is_some() && !matches!(actual, Some(serde_json::Value::Null));
740 if does_exist != should_exist {
741 return false;
742 }
743 continue;
744 }
745 }
746
747 match (actual, expected) {
749 (Some(a), e) if a == e => continue,
750 (None, serde_json::Value::Null) => continue,
751 _ => return false,
752 }
753 }
754 true
755 }
756
757 fn get_nested_value<'a>(
758 &self,
759 context: &'a std::collections::HashMap<String, serde_json::Value>,
760 path: &str,
761 ) -> Option<&'a serde_json::Value> {
762 let parts: Vec<&str> = path.split('.').collect();
763 if parts.is_empty() {
764 return None;
765 }
766
767 let mut current: Option<&serde_json::Value> = context.get(parts[0]);
768
769 for part in &parts[1..] {
770 current = current.and_then(|v| {
771 if let serde_json::Value::Object(obj) = v {
772 obj.get(*part)
773 } else {
774 None
775 }
776 });
777 }
778
779 current
780 }
781
782 fn get_llm(&self, alias: Option<&str>) -> Result<Arc<dyn LLMProvider>> {
783 let registry = self
784 .llm_registry
785 .as_ref()
786 .ok_or_else(|| AgentError::Config("LLM registry not configured for process".into()))?;
787
788 match alias {
789 Some(name) => registry
790 .get(name)
791 .map_err(|e| AgentError::LLM(e.to_string())),
792 None => registry
793 .router()
794 .or_else(|_| registry.default())
795 .map_err(|e| AgentError::LLM(e.to_string())),
796 }
797 }
798}
799
/// Pulls the JSON payload out of an LLM reply.
///
/// Tried in order, on the trimmed reply:
/// 1. a leading ```` ```json ```` fence: the text up to the next ```` ``` ````;
/// 2. a leading bare ```` ``` ```` fence: the text up to the next ```` ``` ````;
/// 3. the span from the first `{` to the last `}`, provided the `{` comes first;
/// 4. otherwise the trimmed reply itself.
///
/// The result is not validated; callers parse it and handle failures.
fn extract_json(response: &str) -> String {
    let trimmed = response.trim();

    for fence in ["```json", "```"].iter() {
        if let Some(rest) = trimmed.strip_prefix(*fence) {
            if let Some(end) = rest.find("```") {
                return rest[..end].trim().to_string();
            }
        }
    }

    if let (Some(start), Some(end)) = (trimmed.find('{'), trimmed.rfind('}')) {
        // A reply such as "} oops {" has its last `}` before its first `{`;
        // slicing that range would panic, so fall through instead.
        if start < end {
            return trimmed[start..=end].to_string();
        }
    }

    trimmed.to_string()
}
823
824#[cfg(test)]
825mod tests {
826 use super::*;
827
/// `ProcessData::new` copies the text into both fields and starts with an empty context.
#[test]
fn test_process_data_new() {
    let fresh = ProcessData::new("test content");
    assert!(fresh.context.is_empty());
    assert_eq!(fresh.original, "test content");
    assert_eq!(fresh.content, "test content");
}
835
/// `with_context` stores the value under the given key.
#[test]
fn test_process_data_with_context() {
    let value = serde_json::json!("value");
    let populated = ProcessData::new("test").with_context("key", value);
    assert!(populated.context.contains_key("key"));
}
841
842 #[tokio::test]
843 async fn test_normalize_trim() {
844 let processor = ProcessProcessor::default();
845 let config = NormalizeConfig {
846 trim: true,
847 ..Default::default()
848 };
849 let data = ProcessData::new(" hello world ");
850 let result = processor.execute_normalize(&config, data).await.unwrap();
851 assert_eq!(result.content, "hello world");
852 }
853
854 #[tokio::test]
855 async fn test_normalize_collapse_whitespace() {
856 let processor = ProcessProcessor::default();
857 let config = NormalizeConfig {
858 trim: true,
859 collapse_whitespace: true,
860 ..Default::default()
861 };
862 let data = ProcessData::new("hello world\n\ntest");
863 let result = processor.execute_normalize(&config, data).await.unwrap();
864 assert_eq!(result.content, "hello world test");
865 }
866
867 #[tokio::test]
868 async fn test_normalize_lowercase() {
869 let processor = ProcessProcessor::default();
870 let config = NormalizeConfig {
871 lowercase: true,
872 ..Default::default()
873 };
874 let data = ProcessData::new("Hello World");
875 let result = processor.execute_normalize(&config, data).await.unwrap();
876 assert_eq!(result.content, "hello world");
877 }
878
879 #[tokio::test]
880 async fn test_validate_min_length_reject() {
881 let processor = ProcessProcessor::default();
882 let config = ValidateConfig {
883 rules: vec![ValidationRule::MinLength {
884 min_length: 10,
885 on_fail: ValidationAction {
886 action: ValidationActionType::Reject,
887 message: None,
888 },
889 }],
890 ..Default::default()
891 };
892 let data = ProcessData::new("short");
893 let result = processor.execute_validate(&config, data).await.unwrap();
894 assert!(result.metadata.rejected);
895 }
896
897 #[tokio::test]
898 async fn test_validate_max_length_truncate() {
899 let processor = ProcessProcessor::default();
900 let config = ValidateConfig {
901 rules: vec![ValidationRule::MaxLength {
902 max_length: 5,
903 on_fail: ValidationAction {
904 action: ValidationActionType::Truncate,
905 message: None,
906 },
907 }],
908 ..Default::default()
909 };
910 let data = ProcessData::new("hello world");
911 let result = processor.execute_validate(&config, data).await.unwrap();
912 assert_eq!(result.content, "hello");
913 assert!(!result.metadata.rejected);
914 }
915
916 #[tokio::test]
917 async fn test_format_simple_template() {
918 let processor = ProcessProcessor::default();
919 let config = FormatConfig {
920 template: Some("Response: {{ response }}".to_string()),
921 ..Default::default()
922 };
923 let data = ProcessData::new("Hello!");
924 let result = processor.execute_format(&config, data).await.unwrap();
925 assert_eq!(result.content, "Response: Hello!");
926 }
927
/// Bare, fenced and embedded JSON all yield the same extracted object text.
#[test]
fn test_extract_json() {
    let expected = r#"{"key": 1}"#;
    let replies = [
        r#"{"key": 1}"#,
        "```json\n{\"key\": 1}\n```",
        "Some text {\"key\": 1} more",
    ];

    for reply in replies.iter() {
        assert_eq!(extract_json(reply), expected);
    }
}
934
/// A missing condition counts as satisfied.
#[test]
fn test_evaluate_condition_empty() {
    let subject = ProcessData::new("test");
    let evaluator = ProcessProcessor::default();

    let satisfied = evaluator.evaluate_condition(&None, &subject);

    assert!(satisfied);
}
941
/// `exists: true` holds for a nested path that is present.
#[test]
fn test_evaluate_condition_simple_exists_true() {
    let subject = ProcessData::new("test")
        .with_context("session", serde_json::json!({ "user_name": "Alice" }));

    let mut expectations = std::collections::HashMap::new();
    expectations.insert(
        "session.user_name".to_string(),
        serde_json::json!({ "exists": true }),
    );
    let condition = ConditionExpr::Simple(expectations);

    let evaluator = ProcessProcessor::default();
    assert!(evaluator.evaluate_condition_expr(&condition, &subject));
}
960
/// `exists: false` holds for a path that is absent from the context.
#[test]
fn test_evaluate_condition_simple_exists_false() {
    let subject = ProcessData::new("test");

    let mut expectations = std::collections::HashMap::new();
    expectations.insert(
        "session.user_name".to_string(),
        serde_json::json!({ "exists": false }),
    );
    let condition = ConditionExpr::Simple(expectations);

    let evaluator = ProcessProcessor::default();
    assert!(evaluator.evaluate_condition_expr(&condition, &subject));
}
975
/// `All` holds when every child condition holds.
#[test]
fn test_evaluate_condition_all() {
    let subject = ProcessData::new("test").with_context(
        "session",
        serde_json::json!({ "user_name": "Alice", "language": "en" }),
    );

    let mut name_exists = std::collections::HashMap::new();
    name_exists.insert(
        "session.user_name".to_string(),
        serde_json::json!({ "exists": true }),
    );

    let mut language_exists = std::collections::HashMap::new();
    language_exists.insert(
        "session.language".to_string(),
        serde_json::json!({ "exists": true }),
    );

    let children = vec![
        ConditionExpr::Simple(name_exists),
        ConditionExpr::Simple(language_exists),
    ];
    let condition = ConditionExpr::All { all: children };

    let evaluator = ProcessProcessor::default();
    assert!(evaluator.evaluate_condition_expr(&condition, &subject));
}
1002
/// `Any` holds when at least one child condition holds.
#[test]
fn test_evaluate_condition_any() {
    let subject = ProcessData::new("test")
        .with_context("session", serde_json::json!({ "tier": "premium" }));

    let mut is_premium = std::collections::HashMap::new();
    is_premium.insert("session.tier".to_string(), serde_json::json!("premium"));

    let mut is_enterprise = std::collections::HashMap::new();
    is_enterprise.insert("session.tier".to_string(), serde_json::json!("enterprise"));

    let children = vec![
        ConditionExpr::Simple(is_premium),
        ConditionExpr::Simple(is_enterprise),
    ];
    let condition = ConditionExpr::Any { any: children };

    let evaluator = ProcessProcessor::default();
    assert!(evaluator.evaluate_condition_expr(&condition, &subject));
}
1023
/// A plain expected value holds when the value at the path equals it.
#[test]
fn test_evaluate_condition_value_match() {
    let subject = ProcessData::new("test")
        .with_context("input", serde_json::json!({ "sentiment": "negative" }));

    let mut expectations = std::collections::HashMap::new();
    expectations.insert("input.sentiment".to_string(), serde_json::json!("negative"));
    let condition = ConditionExpr::Simple(expectations);

    let evaluator = ProcessProcessor::default();
    assert!(evaluator.evaluate_condition_expr(&condition, &subject));
}
1039
/// Dotted paths descend into nested objects; unknown segments give `None`.
#[test]
fn test_get_nested_value() {
    let mut lookup = std::collections::HashMap::new();
    lookup.insert(
        "session".to_string(),
        serde_json::json!({ "user": { "name": "Alice" } }),
    );
    let resolver = ProcessProcessor::default();

    let found = resolver.get_nested_value(&lookup, "session.user.name");
    assert_eq!(found, Some(&serde_json::json!("Alice")));

    let missing = resolver.get_nested_value(&lookup, "session.nonexistent");
    assert!(missing.is_none());
}
1055
1056 fn create_mock_registry(response: &str) -> Arc<ai_agents_llm::LLMRegistry> {
1060 use ai_agents_llm::mock::MockLLMProvider;
1061 let mut mock = MockLLMProvider::new("test");
1062 mock.set_response(response);
1063 let mut registry = ai_agents_llm::LLMRegistry::new();
1064 registry.register("default", std::sync::Arc::new(mock));
1065 registry.set_default("default");
1066 std::sync::Arc::new(registry)
1067 }
1068
1069 fn create_mock_registry_multi(responses: Vec<&str>) -> Arc<ai_agents_llm::LLMRegistry> {
1070 use ai_agents_llm::mock::MockLLMProvider;
1071 let mut mock = MockLLMProvider::new("test");
1072 mock.set_responses(responses.into_iter().map(String::from).collect(), true);
1073 let mut registry = ai_agents_llm::LLMRegistry::new();
1074 registry.register("default", std::sync::Arc::new(mock));
1075 registry.set_default("default");
1076 std::sync::Arc::new(registry)
1077 }
1078
1079 #[tokio::test]
1080 async fn test_detect_stage_language_sentiment() {
1081 let registry = create_mock_registry(
1082 r#"{"language": "ko", "sentiment": "positive", "intent": "greeting"}"#,
1083 );
1084 let config = ProcessConfig {
1085 input: vec![ProcessStage::Detect(DetectStage {
1086 id: Some("detect_test".to_string()),
1087 condition: None,
1088 config: DetectConfig {
1089 llm: None,
1090 detect: vec![DetectionType::Language, DetectionType::Sentiment],
1091 intents: vec![IntentDefinition {
1092 id: "greeting".to_string(),
1093 description: "User says hello".to_string(),
1094 }],
1095 store_in_context: {
1096 let mut m = std::collections::HashMap::new();
1097 m.insert("language".to_string(), "input.language".to_string());
1098 m.insert("sentiment".to_string(), "input.sentiment".to_string());
1099 m
1100 },
1101 },
1102 })],
1103 ..Default::default()
1104 };
1105 let processor = ProcessProcessor::new(config).with_llm_registry(registry);
1106 let result = processor.process_input("안녕하세요!").await.unwrap();
1107
1108 assert_eq!(
1109 result.context.get("input.language"),
1110 Some(&serde_json::json!("ko"))
1111 );
1112 assert_eq!(
1113 result.context.get("input.sentiment"),
1114 Some(&serde_json::json!("positive"))
1115 );
1116 assert!(
1117 result
1118 .metadata
1119 .stages_executed
1120 .contains(&"detect_test".to_string())
1121 );
1122 }
1123
1124 #[tokio::test]
1125 async fn test_extract_stage_entities() {
1126 let registry = create_mock_registry(r#"{"order_number": "ORD-12345", "urgency": "high"}"#);
1127 let config = ProcessConfig {
1128 input: vec![ProcessStage::Extract(ExtractStage {
1129 id: Some("extract_test".to_string()),
1130 condition: None,
1131 config: ExtractConfig {
1132 llm: None,
1133 schema: {
1134 let mut m = std::collections::HashMap::new();
1135 m.insert(
1136 "order_number".to_string(),
1137 FieldSchema {
1138 field_type: FieldType::String,
1139 description: Some("Order number".to_string()),
1140 required: true,
1141 values: vec![],
1142 },
1143 );
1144 m.insert(
1145 "urgency".to_string(),
1146 FieldSchema {
1147 field_type: FieldType::Enum,
1148 description: Some("Urgency level".to_string()),
1149 required: false,
1150 values: vec![
1151 "low".to_string(),
1152 "medium".to_string(),
1153 "high".to_string(),
1154 ],
1155 },
1156 );
1157 m
1158 },
1159 store_in_context: Some("extracted".to_string()),
1160 },
1161 })],
1162 ..Default::default()
1163 };
1164 let processor = ProcessProcessor::new(config).with_llm_registry(registry);
1165 let result = processor
1166 .process_input("My order ORD-12345 is urgent!")
1167 .await
1168 .unwrap();
1169
1170 let extracted = result.context.get("extracted").unwrap();
1171 assert_eq!(extracted["order_number"], "ORD-12345");
1172 assert_eq!(extracted["urgency"], "high");
1173 }
1174
1175 #[tokio::test]
1176 async fn test_sanitize_stage_pii_masking() {
1177 let registry = create_mock_registry("Call me at ****-****-**** or email at ****@****.com");
1178 let config = ProcessConfig {
1179 input: vec![ProcessStage::Sanitize(SanitizeStage {
1180 id: Some("sanitize_test".to_string()),
1181 condition: None,
1182 config: SanitizeConfig {
1183 llm: None,
1184 pii: Some(PIISanitizeConfig {
1185 action: PIIAction::Mask,
1186 types: vec![PIIType::Phone, PIIType::Email],
1187 mask_char: "*".to_string(),
1188 }),
1189 harmful: None,
1190 remove: vec![],
1191 },
1192 })],
1193 ..Default::default()
1194 };
1195 let processor = ProcessProcessor::new(config).with_llm_registry(registry);
1196 let result = processor
1197 .process_input("Call me at 010-1234-5678 or email at user@example.com")
1198 .await
1199 .unwrap();
1200
1201 assert!(result.content.contains("****"));
1203 assert!(!result.content.contains("010-1234-5678"));
1204 assert!(!result.content.contains("user@example.com"));
1205 }
1206
1207 #[tokio::test]
1208 async fn test_transform_stage_tone_adjustment() {
1209 let registry = create_mock_registry(
1210 "I understand your frustration. Let me help you resolve this issue right away.",
1211 );
1212 let config = ProcessConfig {
1213 output: vec![ProcessStage::Transform(TransformStage {
1214 id: Some("tone_test".to_string()),
1215 condition: None,
1216 config: TransformConfig {
1217 llm: None,
1218 prompt: Some("Rewrite to be more empathetic.".to_string()),
1219 max_output_tokens: None,
1220 },
1221 })],
1222 ..Default::default()
1223 };
1224 let processor = ProcessProcessor::new(config).with_llm_registry(registry);
1225
1226 let input_context = std::collections::HashMap::new();
1227 let result = processor
1228 .process_output("Your request is being processed.", &input_context)
1229 .await
1230 .unwrap();
1231
1232 assert!(result.content.contains("understand"));
1233 }
1234
1235 #[tokio::test]
1236 async fn test_validate_stage_llm_criteria() {
1237 let registry = create_mock_registry(
1238 r#"{"passes": false, "score": 0.3, "issues": ["Response is too vague"]}"#,
1239 );
1240 let config = ProcessConfig {
1241 output: vec![ProcessStage::Validate(ValidateStage {
1242 id: Some("quality_test".to_string()),
1243 condition: None,
1244 config: ValidateConfig {
1245 rules: vec![],
1246 llm: None,
1247 criteria: vec!["Response is specific and actionable".to_string()],
1248 threshold: 0.7,
1249 on_fail: ValidationFailAction {
1250 action: ValidationFailType::Warn,
1251 ..Default::default()
1252 },
1253 },
1254 })],
1255 ..Default::default()
1256 };
1257 let processor = ProcessProcessor::new(config).with_llm_registry(registry);
1258
1259 let input_context = std::collections::HashMap::new();
1260 let result = processor
1261 .process_output("It depends.", &input_context)
1262 .await
1263 .unwrap();
1264
1265 assert!(
1267 result.metadata.warnings.iter().any(|w| w.contains("vague")),
1268 "Expected warning about vague response, got: {:?}",
1269 result.metadata.warnings
1270 );
1271 }
1272
1273 #[tokio::test]
1274 async fn test_validate_stage_llm_criteria_reject() {
1275 let registry = create_mock_registry(
1276 r#"{"passes": false, "score": 0.2, "issues": ["Contains harmful content"]}"#,
1277 );
1278 let config = ProcessConfig {
1279 output: vec![ProcessStage::Validate(ValidateStage {
1280 id: Some("reject_test".to_string()),
1281 condition: None,
1282 config: ValidateConfig {
1283 rules: vec![],
1284 llm: None,
1285 criteria: vec!["Response is safe".to_string()],
1286 threshold: 0.7,
1287 on_fail: ValidationFailAction {
1288 action: ValidationFailType::Reject,
1289 ..Default::default()
1290 },
1291 },
1292 })],
1293 ..Default::default()
1294 };
1295 let processor = ProcessProcessor::new(config).with_llm_registry(registry);
1296
1297 let input_context = std::collections::HashMap::new();
1298 let result = processor
1299 .process_output("Dangerous content here.", &input_context)
1300 .await
1301 .unwrap();
1302
1303 assert!(result.metadata.rejected);
1304 assert!(
1305 result
1306 .metadata
1307 .rejection_reason
1308 .as_ref()
1309 .unwrap()
1310 .contains("harmful")
1311 );
1312 }
1313
1314 #[tokio::test]
1315 async fn test_full_input_pipeline_chain() {
1316 let registry = create_mock_registry_multi(vec![
1318 r#"{"language": "en", "sentiment": "neutral"}"#,
1320 r#"{"user_name": "Alice", "topic": "billing"}"#,
1322 ]);
1323 let config = ProcessConfig {
1324 input: vec![
1325 ProcessStage::Normalize(NormalizeStage {
1326 id: Some("norm".to_string()),
1327 condition: None,
1328 config: NormalizeConfig {
1329 trim: true,
1330 collapse_whitespace: true,
1331 ..Default::default()
1332 },
1333 }),
1334 ProcessStage::Detect(DetectStage {
1335 id: Some("detect".to_string()),
1336 condition: None,
1337 config: DetectConfig {
1338 llm: None,
1339 detect: vec![DetectionType::Language, DetectionType::Sentiment],
1340 intents: vec![],
1341 store_in_context: {
1342 let mut m = std::collections::HashMap::new();
1343 m.insert("language".to_string(), "input.language".to_string());
1344 m
1345 },
1346 },
1347 }),
1348 ProcessStage::Extract(ExtractStage {
1349 id: Some("extract".to_string()),
1350 condition: None,
1351 config: ExtractConfig {
1352 llm: None,
1353 schema: {
1354 let mut m = std::collections::HashMap::new();
1355 m.insert(
1356 "user_name".to_string(),
1357 FieldSchema {
1358 field_type: FieldType::String,
1359 description: Some("User name".to_string()),
1360 ..Default::default()
1361 },
1362 );
1363 m
1364 },
1365 store_in_context: Some("entities".to_string()),
1366 },
1367 }),
1368 ],
1369 ..Default::default()
1370 };
1371 let processor = ProcessProcessor::new(config).with_llm_registry(registry);
1372 let result = processor
1373 .process_input(" Hi, I'm Alice and I have a billing question ")
1374 .await
1375 .unwrap();
1376
1377 assert_eq!(
1379 result.content,
1380 "Hi, I'm Alice and I have a billing question"
1381 );
1382
1383 assert_eq!(
1385 result.context.get("input.language"),
1386 Some(&serde_json::json!("en"))
1387 );
1388
1389 let entities = result.context.get("entities").unwrap();
1391 assert_eq!(entities["user_name"], "Alice");
1392
1393 assert_eq!(
1395 result.metadata.stages_executed,
1396 vec!["norm", "detect", "extract"]
1397 );
1398 }
1399
1400 #[tokio::test]
1401 async fn test_conditional_stage_skips_on_false() {
1402 let registry = create_mock_registry(r#"{"language": "en"}"#);
1403 let config = ProcessConfig {
1404 input: vec![ProcessStage::Detect(DetectStage {
1405 id: Some("should_skip".to_string()),
1406 condition: Some(ConditionExpr::Simple({
1407 let mut map = std::collections::HashMap::new();
1408 map.insert("needs_detection".to_string(), serde_json::json!(true));
1409 map
1410 })),
1411 config: DetectConfig {
1412 llm: None,
1413 detect: vec![DetectionType::Language],
1414 ..Default::default()
1415 },
1416 })],
1417 ..Default::default()
1418 };
1419 let processor = ProcessProcessor::new(config).with_llm_registry(registry);
1420 let result = processor.process_input("Hello").await.unwrap();
1421
1422 assert!(
1424 !result
1425 .metadata
1426 .stages_executed
1427 .contains(&"should_skip".to_string()),
1428 "Stage should have been skipped"
1429 );
1430 }
1431
1432 #[tokio::test]
1433 async fn test_stage_skipped_when_condition_false() {
1434 let config = ProcessConfig {
1435 input: vec![ProcessStage::Extract(ExtractStage {
1436 id: Some("skip_me".to_string()),
1437 condition: Some(ConditionExpr::Simple({
1438 let mut map = std::collections::HashMap::new();
1439 map.insert(
1440 "session.user".to_string(),
1441 serde_json::json!({ "exists": false }),
1442 );
1443 map
1444 })),
1445 config: ExtractConfig::default(),
1446 })],
1447 settings: ProcessSettings {
1448 debug: ProcessDebugConfig {
1449 log_stages: true,
1450 ..Default::default()
1451 },
1452 ..Default::default()
1453 },
1454 ..Default::default()
1455 };
1456 let processor = ProcessProcessor::new(config);
1457
1458 let mut data = ProcessData::new("test");
1459 data.context.insert(
1460 "session".to_string(),
1461 serde_json::json!({ "user": "Alice" }),
1462 );
1463
1464 let result = processor.process_input("test").await.unwrap();
1465 assert!(
1466 !result
1467 .metadata
1468 .stages_executed
1469 .contains(&"skip_me".to_string())
1470 );
1471 }
1472}