1use std::collections::HashMap;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::time::Instant;
8
9use ai_agents_core::{AgentError, ChatMessage, LLMProvider, Result};
10use ai_agents_llm::LLMRegistry;
11
12use super::config::*;
13
/// Payload handed from stage to stage while a processing chain runs.
#[derive(Debug, Clone, Default)]
pub struct ProcessData {
    /// Current text; stages such as normalize, sanitize, transform and format overwrite it.
    pub content: String,
    /// Copy of the text taken when the value was built with `ProcessData::new`.
    pub original: String,
    /// JSON values keyed by name; detect, extract and enrich stages write their results here.
    pub context: HashMap<String, serde_json::Value>,
    /// Bookkeeping about the run: executed stages, timings, warnings and rejection state.
    pub metadata: ProcessMetadata,
}
21
/// Run bookkeeping attached to every `ProcessData`.
#[derive(Debug, Clone, Default)]
pub struct ProcessMetadata {
    /// Names (stage id, or the stage type name when no id is set) of stages that completed.
    pub stages_executed: Vec<String>,
    /// Elapsed milliseconds per stage name; filled only when timing is enabled in the settings.
    pub timing: HashMap<String, u64>,
    /// Non-fatal problems collected while the stages ran.
    pub warnings: Vec<String>,
    /// Set by a validation failure with a reject action; the stage loops stop once it is true.
    pub rejected: bool,
    /// Explanation recorded together with `rejected`.
    pub rejection_reason: Option<String>,
}
30
31impl ProcessData {
32 pub fn new(content: impl Into<String>) -> Self {
33 let content = content.into();
34 Self {
35 original: content.clone(),
36 content,
37 context: HashMap::new(),
38 metadata: ProcessMetadata::default(),
39 }
40 }
41
42 pub fn with_context(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
43 self.context.insert(key.into(), value);
44 self
45 }
46}
47
/// Coarse label describing what a stage (or stage list) is for.
///
/// The label is handed to a `ProcessStageObserver` together with the stage future.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProcessPurposeHint {
    /// A detect stage.
    Detect,
    /// An extract stage.
    Extract,
    /// A validate stage.
    Validate,
    /// A sanitize or transform stage.
    Transform,
    /// Any stage without a more specific label.
    Other,
}
57
/// Boxed, pinned, `Send` future that resolves to the outcome of one stage.
pub type ProcessStageFuture<'a> = Pin<Box<dyn Future<Output = Result<ProcessData>> + Send + 'a>>;
60
/// Hook that can wrap the execution of every stage.
pub trait ProcessStageObserver: Send + Sync {
    /// Receives the purpose `hint` of a stage and the `future` that runs it, and
    /// returns the future the processor will await in its place.
    fn observe<'a>(
        &'a self,
        hint: ProcessPurposeHint,
        future: ProcessStageFuture<'a>,
    ) -> ProcessStageFuture<'a>;
}
70
/// Runs the input and output stage chains described by a `ProcessConfig`.
pub struct ProcessProcessor {
    /// Stage lists and settings that drive the processor.
    config: ProcessConfig,
    /// Source of LLM providers; stages that call an LLM fail with a config error when absent.
    llm_registry: Option<Arc<LLMRegistry>>,
    /// Optional hook that wraps each stage future before it is awaited.
    stage_observer: Option<Arc<dyn ProcessStageObserver>>,
}
77
78impl std::fmt::Debug for ProcessProcessor {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 f.debug_struct("ProcessProcessor")
81 .field("config", &self.config)
82 .field("has_llm_registry", &self.llm_registry.is_some())
83 .field("has_stage_observer", &self.stage_observer.is_some())
84 .finish()
85 }
86}
87
88impl Default for ProcessProcessor {
89 fn default() -> Self {
90 Self::new(ProcessConfig::default())
91 }
92}
93
94impl ProcessProcessor {
95 pub fn new(config: ProcessConfig) -> Self {
96 Self {
97 config,
98 llm_registry: None,
99 stage_observer: None,
100 }
101 }
102
103 pub fn with_llm_registry(mut self, registry: Arc<LLMRegistry>) -> Self {
104 self.llm_registry = Some(registry);
105 self
106 }
107
108 pub fn with_stage_observer(mut self, observer: Arc<dyn ProcessStageObserver>) -> Self {
110 self.stage_observer = Some(observer);
111 self
112 }
113
114 pub fn input_purpose_hint(&self) -> ProcessPurposeHint {
116 purpose_hint_for_stages(&self.config.input)
117 }
118
119 pub fn output_purpose_hint(&self) -> ProcessPurposeHint {
121 purpose_hint_for_stages(&self.config.output)
122 }
123
124 pub async fn process_input(&self, input: &str) -> Result<ProcessData> {
125 let mut data = ProcessData::new(input);
126
127 for stage in &self.config.input {
128 data = self.execute_stage(stage, data).await?;
129 if data.metadata.rejected {
130 break;
131 }
132 }
133
134 Ok(data)
135 }
136
137 pub async fn process_output(
138 &self,
139 output: &str,
140 input_context: &HashMap<String, serde_json::Value>,
141 ) -> Result<ProcessData> {
142 let mut data = ProcessData::new(output);
143 data.context = input_context.clone();
144
145 for stage in &self.config.output {
146 data = self.execute_stage(stage, data).await?;
147 if data.metadata.rejected {
148 break;
149 }
150 }
151
152 Ok(data)
153 }
154
155 fn execute_stage<'a>(
156 &'a self,
157 stage: &'a ProcessStage,
158 data: ProcessData,
159 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<ProcessData>> + Send + 'a>> {
160 Box::pin(async move {
161 let start = Instant::now();
162 let stage_name = stage
163 .id()
164 .map(String::from)
165 .unwrap_or_else(|| self.get_stage_type_name(stage));
166
167 if let Some(condition) = stage.condition() {
169 if !self.evaluate_condition_expr(condition, &data) {
170 if self.config.settings.debug.log_stages {
171 tracing::debug!(
172 "[Process] Stage skipped (condition not met): {}",
173 stage_name
174 );
175 }
176 return Ok(data);
177 }
178 }
179
180 if self.config.settings.debug.log_stages {
181 tracing::debug!("[Process] Executing stage: {}", stage_name);
182 }
183
184 let data_clone = data.clone();
185 let hint = process_purpose_hint_for_stage(stage);
186 let run_stage: ProcessStageFuture<'a> = Box::pin(async move {
187 match stage {
188 ProcessStage::Normalize(s) => self.execute_normalize(&s.config, data).await,
189 ProcessStage::Detect(s) => self.execute_detect(&s.config, data).await,
190 ProcessStage::Extract(s) => self.execute_extract(&s.config, data).await,
191 ProcessStage::Sanitize(s) => self.execute_sanitize(&s.config, data).await,
192 ProcessStage::Transform(s) => self.execute_transform(&s.config, data).await,
193 ProcessStage::Validate(s) => self.execute_validate(&s.config, data).await,
194 ProcessStage::Format(s) => self.execute_format(&s.config, data).await,
195 ProcessStage::Enrich(s) => self.execute_enrich(&s.config, data).await,
196 ProcessStage::Conditional(s) => self.execute_conditional(&s.config, data).await,
197 }
198 });
199 let result = if let Some(observer) = self.stage_observer.as_ref() {
200 observer.observe(hint, run_stage).await
201 } else {
202 run_stage.await
203 };
204
205 match result {
206 Ok(mut d) => {
207 d.metadata.stages_executed.push(stage_name.clone());
208 if self.config.settings.debug.include_timing {
209 d.metadata
210 .timing
211 .insert(stage_name, start.elapsed().as_millis() as u64);
212 }
213 Ok(d)
214 }
215 Err(e) => {
216 let mut fallback_data = data_clone;
217 match self.config.settings.on_stage_error.default {
218 StageErrorAction::Stop => Err(e),
219 StageErrorAction::Continue => {
220 fallback_data
221 .metadata
222 .warnings
223 .push(format!("Stage {} failed: {}", stage_name, e));
224 Ok(fallback_data)
225 }
226 StageErrorAction::Retry => {
227 if let Some(retry_config) = &self.config.settings.on_stage_error.retry {
228 for _ in 0..retry_config.max_retries {
229 tokio::time::sleep(std::time::Duration::from_millis(
230 retry_config.backoff_ms,
231 ))
232 .await;
233 }
234 }
235 fallback_data
236 .metadata
237 .warnings
238 .push(format!("Stage {} failed after retries: {}", stage_name, e));
239 Ok(fallback_data)
240 }
241 }
242 }
243 }
244 })
245 }
246
247 fn get_stage_type_name(&self, stage: &ProcessStage) -> String {
248 match stage {
249 ProcessStage::Normalize(_) => "normalize".to_string(),
250 ProcessStage::Detect(_) => "detect".to_string(),
251 ProcessStage::Extract(_) => "extract".to_string(),
252 ProcessStage::Sanitize(_) => "sanitize".to_string(),
253 ProcessStage::Transform(_) => "transform".to_string(),
254 ProcessStage::Validate(_) => "validate".to_string(),
255 ProcessStage::Format(_) => "format".to_string(),
256 ProcessStage::Enrich(_) => "enrich".to_string(),
257 ProcessStage::Conditional(_) => "conditional".to_string(),
258 }
259 }
260
261 async fn execute_normalize(
262 &self,
263 config: &NormalizeConfig,
264 mut data: ProcessData,
265 ) -> Result<ProcessData> {
266 let mut content = data.content.clone();
267
268 if config.trim {
269 content = content.trim().to_string();
270 }
271
272 if config.collapse_whitespace {
273 content = content.split_whitespace().collect::<Vec<_>>().join(" ");
274 }
275
276 if config.lowercase {
277 content = content.to_lowercase();
278 }
279
280 data.content = content;
284 Ok(data)
285 }
286
287 async fn execute_detect(
288 &self,
289 config: &DetectConfig,
290 mut data: ProcessData,
291 ) -> Result<ProcessData> {
292 let llm = self.get_llm(config.llm.as_deref())?;
293
294 let detection_types: Vec<&str> = config
295 .detect
296 .iter()
297 .map(|d| match d {
298 DetectionType::Language => "language (ISO 639-1 code)",
299 DetectionType::Sentiment => "sentiment (positive, negative, neutral)",
300 DetectionType::Intent => "intent",
301 DetectionType::Topic => "topic",
302 DetectionType::Formality => "formality (formal, informal)",
303 DetectionType::Urgency => "urgency (low, medium, high, critical)",
304 })
305 .collect();
306
307 let intents_desc = if !config.intents.is_empty() {
308 let intents: Vec<String> = config
309 .intents
310 .iter()
311 .map(|i| format!("- {}: {}", i.id, i.description))
312 .collect();
313 format!("\n\nAvailable intents:\n{}", intents.join("\n"))
314 } else {
315 String::new()
316 };
317
318 let prompt = format!(
319 "Analyze the following text and detect: {}\n{}\n\n\
320 Respond with JSON only: {{\"language\": \"...\", \"sentiment\": \"...\", \"intent\": \"...\", ...}}\n\n\
321 Text: {}",
322 detection_types.join(", "),
323 intents_desc,
324 data.content
325 );
326
327 let messages = vec![ChatMessage::user(&prompt)];
328 let response = llm
329 .complete(&messages, None)
330 .await
331 .map_err(|e| AgentError::LLM(e.to_string()))?;
332
333 if let Ok(result) =
334 serde_json::from_str::<serde_json::Value>(&extract_json(&response.content))
335 {
336 for (key, context_path) in &config.store_in_context {
337 if let Some(value) = result.get(key) {
338 data.context.insert(context_path.clone(), value.clone());
339 }
340 }
341 data.context.insert("detection".to_string(), result);
342 }
343
344 Ok(data)
345 }
346
347 async fn execute_extract(
348 &self,
349 config: &ExtractConfig,
350 mut data: ProcessData,
351 ) -> Result<ProcessData> {
352 let llm = self.get_llm(config.llm.as_deref())?;
353
354 let schema_desc: Vec<String> = config
355 .schema
356 .iter()
357 .map(|(name, schema)| {
358 let type_str = format!("{:?}", schema.field_type).to_lowercase();
359 let desc = schema.description.as_deref().unwrap_or("");
360 let values = if !schema.values.is_empty() {
361 format!(" (values: {})", schema.values.join(", "))
362 } else {
363 String::new()
364 };
365 let required = if schema.required { " [required]" } else { "" };
366 format!("- {}: {} - {}{}{}", name, type_str, desc, values, required)
367 })
368 .collect();
369
370 let prompt = format!(
371 "Extract the following fields from the text:\n{}\n\n\
372 Respond with JSON only. Use null for fields not found.\n\n\
373 Text: {}",
374 schema_desc.join("\n"),
375 data.content
376 );
377
378 let messages = vec![ChatMessage::user(&prompt)];
379 let response = llm
380 .complete(&messages, None)
381 .await
382 .map_err(|e| AgentError::LLM(e.to_string()))?;
383
384 if let Ok(result) =
385 serde_json::from_str::<serde_json::Value>(&extract_json(&response.content))
386 {
387 if let Some(context_path) = &config.store_in_context {
388 data.context.insert(context_path.clone(), result.clone());
389 }
390 data.context.insert("extracted".to_string(), result);
391 }
392
393 Ok(data)
394 }
395
396 async fn execute_sanitize(
397 &self,
398 config: &SanitizeConfig,
399 mut data: ProcessData,
400 ) -> Result<ProcessData> {
401 let llm = self.get_llm(config.llm.as_deref())?;
402
403 let mut instructions = Vec::new();
404
405 if let Some(pii_config) = &config.pii {
406 if !pii_config.types.is_empty() {
407 let pii_types: Vec<String> = pii_config
408 .types
409 .iter()
410 .map(|t| format!("{:?}", t).to_lowercase())
411 .collect();
412 let action = match pii_config.action {
413 PIIAction::Mask => format!("replace with '{}'", pii_config.mask_char.repeat(4)),
414 PIIAction::Remove => "remove completely".to_string(),
415 PIIAction::Flag => "wrap with [PII: type]".to_string(),
416 };
417 instructions.push(format!("PII types to {}: {}", action, pii_types.join(", ")));
418 }
419 }
420
421 if let Some(harmful_config) = &config.harmful {
422 if !harmful_config.detect.is_empty() {
423 let types: Vec<String> = harmful_config
424 .detect
425 .iter()
426 .map(|t| format!("{:?}", t).to_lowercase())
427 .collect();
428 instructions.push(format!("Detect harmful content: {}", types.join(", ")));
429 }
430 }
431
432 if !config.remove.is_empty() {
433 instructions.push(format!(
434 "Remove any mentions of: {}",
435 config.remove.join(", ")
436 ));
437 }
438
439 if instructions.is_empty() {
440 return Ok(data);
441 }
442
443 let prompt = format!(
444 "Sanitize the following text according to these rules:\n{}\n\n\
445 Return only the sanitized text, nothing else.\n\n\
446 Text: {}",
447 instructions.join("\n"),
448 data.content
449 );
450
451 let messages = vec![ChatMessage::user(&prompt)];
452 let response = llm
453 .complete(&messages, None)
454 .await
455 .map_err(|e| AgentError::LLM(e.to_string()))?;
456
457 data.content = response.content.trim().to_string();
458 Ok(data)
459 }
460
461 async fn execute_transform(
462 &self,
463 config: &TransformConfig,
464 mut data: ProcessData,
465 ) -> Result<ProcessData> {
466 let prompt = match &config.prompt {
467 Some(p) => p.clone(),
468 None => return Ok(data),
469 };
470
471 let llm = self.get_llm(config.llm.as_deref())?;
472
473 let full_prompt = format!("{}\n\nOriginal text:\n{}", prompt, data.content);
474
475 let messages = vec![ChatMessage::user(&full_prompt)];
476 let response = llm
477 .complete(&messages, None)
478 .await
479 .map_err(|e| AgentError::LLM(e.to_string()))?;
480
481 data.content = response.content.trim().to_string();
482 Ok(data)
483 }
484
485 async fn execute_validate(
486 &self,
487 config: &ValidateConfig,
488 mut data: ProcessData,
489 ) -> Result<ProcessData> {
490 for rule in &config.rules {
492 match rule {
493 ValidationRule::MinLength {
494 min_length,
495 on_fail,
496 } => {
497 if data.content.len() < *min_length {
498 match on_fail.action {
499 ValidationActionType::Reject => {
500 data.metadata.rejected = true;
501 data.metadata.rejection_reason = Some(format!(
502 "Content too short: {} < {} characters",
503 data.content.len(),
504 min_length
505 ));
506 return Ok(data);
507 }
508 ValidationActionType::Warn => {
509 data.metadata.warnings.push(format!(
510 "Content shorter than {} characters",
511 min_length
512 ));
513 }
514 ValidationActionType::Truncate => {} }
516 }
517 }
518 ValidationRule::MaxLength {
519 max_length,
520 on_fail,
521 } => {
522 if data.content.len() > *max_length {
523 match on_fail.action {
524 ValidationActionType::Truncate => {
525 data.content = data.content.chars().take(*max_length).collect();
526 }
527 ValidationActionType::Reject => {
528 data.metadata.rejected = true;
529 data.metadata.rejection_reason = Some(format!(
530 "Content too long: {} > {} characters",
531 data.content.len(),
532 max_length
533 ));
534 return Ok(data);
535 }
536 ValidationActionType::Warn => {
537 data.metadata
538 .warnings
539 .push(format!("Content longer than {} characters", max_length));
540 }
541 }
542 }
543 }
544 ValidationRule::Pattern { pattern, on_fail } => {
545 if let Ok(re) = regex::Regex::new(pattern) {
546 if !re.is_match(&data.content) {
547 match on_fail.action {
548 ValidationActionType::Reject => {
549 data.metadata.rejected = true;
550 data.metadata.rejection_reason =
551 Some("Content does not match required pattern".to_string());
552 return Ok(data);
553 }
554 ValidationActionType::Warn => {
555 data.metadata.warnings.push(
556 "Content does not match expected pattern".to_string(),
557 );
558 }
559 ValidationActionType::Truncate => {} }
561 }
562 }
563 }
564 }
565 }
566
567 if !config.criteria.is_empty() {
569 let llm = self.get_llm(config.llm.as_deref())?;
570
571 let criteria_list = config
572 .criteria
573 .iter()
574 .enumerate()
575 .map(|(i, c)| format!("{}. {}", i + 1, c))
576 .collect::<Vec<_>>()
577 .join("\n");
578
579 let prompt = format!(
580 "Evaluate if the following content meets these criteria:\n{}\n\n\
581 Respond with JSON: {{\"passes\": true/false, \"score\": 0.0-1.0, \"issues\": [\"...\"]}}\n\n\
582 Content: {}",
583 criteria_list, data.content
584 );
585
586 let messages = vec![ChatMessage::user(&prompt)];
587 let response = llm
588 .complete(&messages, None)
589 .await
590 .map_err(|e| AgentError::LLM(e.to_string()))?;
591
592 if let Ok(result) =
593 serde_json::from_str::<serde_json::Value>(&extract_json(&response.content))
594 {
595 let score = result.get("score").and_then(|s| s.as_f64()).unwrap_or(1.0) as f32;
596 let passes = result
597 .get("passes")
598 .and_then(|p| p.as_bool())
599 .unwrap_or(true);
600
601 if !passes || score < config.threshold {
602 match config.on_fail.action {
603 ValidationFailType::Reject => {
604 data.metadata.rejected = true;
605 let issues = result
606 .get("issues")
607 .and_then(|i| i.as_array())
608 .map(|arr| {
609 arr.iter()
610 .filter_map(|v| v.as_str())
611 .collect::<Vec<_>>()
612 .join(", ")
613 })
614 .unwrap_or_else(|| "Validation failed".to_string());
615 data.metadata.rejection_reason = Some(issues);
616 return Ok(data);
617 }
618 ValidationFailType::Regenerate => {
619 data.metadata
620 .warnings
621 .push("Content may need regeneration".to_string());
622 }
623 ValidationFailType::Warn => {
624 if let Some(issues) = result.get("issues").and_then(|i| i.as_array()) {
625 for issue in issues {
626 if let Some(s) = issue.as_str() {
627 data.metadata.warnings.push(s.to_string());
628 }
629 }
630 }
631 }
632 }
633 }
634 }
635 }
636
637 Ok(data)
638 }
639
640 async fn execute_format(
641 &self,
642 config: &FormatConfig,
643 mut data: ProcessData,
644 ) -> Result<ProcessData> {
645 let template = if let Some(channel) = &config.channel {
646 config
647 .channels
648 .get(channel)
649 .and_then(|c| c.template.as_ref())
650 .or(config.template.as_ref())
651 } else {
652 config.template.as_ref()
653 };
654
655 if let Some(tmpl) = template {
656 let mut result = tmpl.clone();
658 result = result.replace("{{ response }}", &data.content);
659 result = result.replace("{{response}}", &data.content);
660
661 for (key, value) in &data.context {
663 let placeholder = format!("{{{{ context.{} }}}}", key);
664 let placeholder_no_space = format!("{{{{context.{}}}}}", key);
665 let value_str = match value {
666 serde_json::Value::String(s) => s.clone(),
667 _ => value.to_string(),
668 };
669 result = result.replace(&placeholder, &value_str);
670 result = result.replace(&placeholder_no_space, &value_str);
671 }
672
673 data.content = result;
674 }
675
676 if let Some(channel) = &config.channel {
678 if let Some(channel_config) = config.channels.get(channel) {
679 if let Some(max_len) = channel_config.max_length {
680 if data.content.len() > max_len {
681 data.content = data.content.chars().take(max_len).collect();
682 }
683 }
684 }
685 }
686
687 Ok(data)
688 }
689
690 async fn execute_enrich(
691 &self,
692 config: &EnrichConfig,
693 mut data: ProcessData,
694 ) -> Result<ProcessData> {
695 let result = match &config.source {
696 EnrichSource::None => return Ok(data),
697 EnrichSource::Api {
698 url,
699 method: _,
700 headers: _,
701 body: _,
702 extract: _,
703 } => {
704 data.metadata
707 .warnings
708 .push(format!("API enrichment not yet implemented: {}", url));
709 return Ok(data);
710 }
711 EnrichSource::File { path, format } => {
712 match std::fs::read_to_string(path) {
714 Ok(content) => match format.as_deref() {
715 Some("json") => serde_json::from_str(&content).ok(),
716 Some("yaml") => serde_yaml::from_str(&content).ok(),
717 _ => Some(serde_json::Value::String(content)),
718 },
719 Err(e) => match config.on_error {
720 EnrichErrorAction::Stop => return Err(AgentError::IoError(e)),
721 EnrichErrorAction::Continue | EnrichErrorAction::Warn => {
722 data.metadata
723 .warnings
724 .push(format!("File read failed: {}", e));
725 return Ok(data);
726 }
727 },
728 }
729 }
730 EnrichSource::Tool { tool, args: _ } => {
731 data.metadata
733 .warnings
734 .push(format!("Tool enrichment not yet implemented: {}", tool));
735 return Ok(data);
736 }
737 };
738
739 if let Some(value) = result {
740 if let Some(context_path) = &config.store_in_context {
741 data.context.insert(context_path.clone(), value);
742 }
743 }
744
745 Ok(data)
746 }
747
748 async fn execute_conditional(
749 &self,
750 config: &ConditionalConfig,
751 data: ProcessData,
752 ) -> Result<ProcessData> {
753 let condition_met = self.evaluate_condition(&config.condition, &data);
754
755 let stages = if condition_met {
756 &config.then_stages
757 } else {
758 &config.else_stages
759 };
760
761 let mut result = data;
762 for stage in stages {
763 result = self.execute_stage(stage, result).await?;
764 if result.metadata.rejected {
765 break;
766 }
767 }
768
769 Ok(result)
770 }
771
772 fn evaluate_condition(&self, condition: &Option<ConditionExpr>, data: &ProcessData) -> bool {
773 match condition {
774 None => true,
775 Some(expr) => self.evaluate_condition_expr(expr, data),
776 }
777 }
778
779 fn evaluate_condition_expr(&self, condition: &ConditionExpr, data: &ProcessData) -> bool {
780 match condition {
781 ConditionExpr::All { all } => all.iter().all(|c| self.evaluate_condition_expr(c, data)),
782 ConditionExpr::Any { any } => any.iter().any(|c| self.evaluate_condition_expr(c, data)),
783 ConditionExpr::Simple(map) => self.evaluate_simple_condition(map, data),
784 }
785 }
786
787 fn evaluate_simple_condition(
788 &self,
789 map: &std::collections::HashMap<String, serde_json::Value>,
790 data: &ProcessData,
791 ) -> bool {
792 for (path, expected) in map {
793 let actual = self.get_nested_value(&data.context, path);
794
795 if let Some(obj) = expected.as_object() {
797 if let Some(exists_val) = obj.get("exists") {
798 let should_exist = exists_val.as_bool().unwrap_or(true);
799 let does_exist =
800 actual.is_some() && !matches!(actual, Some(serde_json::Value::Null));
801 if does_exist != should_exist {
802 return false;
803 }
804 continue;
805 }
806 }
807
808 match (actual, expected) {
810 (Some(a), e) if a == e => continue,
811 (None, serde_json::Value::Null) => continue,
812 _ => return false,
813 }
814 }
815 true
816 }
817
818 fn get_nested_value<'a>(
819 &self,
820 context: &'a std::collections::HashMap<String, serde_json::Value>,
821 path: &str,
822 ) -> Option<&'a serde_json::Value> {
823 let parts: Vec<&str> = path.split('.').collect();
824 if parts.is_empty() {
825 return None;
826 }
827
828 let mut current: Option<&serde_json::Value> = context.get(parts[0]);
829
830 for part in &parts[1..] {
831 current = current.and_then(|v| {
832 if let serde_json::Value::Object(obj) = v {
833 obj.get(*part)
834 } else {
835 None
836 }
837 });
838 }
839
840 current
841 }
842
843 fn get_llm(&self, alias: Option<&str>) -> Result<Arc<dyn LLMProvider>> {
844 let registry = self
845 .llm_registry
846 .as_ref()
847 .ok_or_else(|| AgentError::Config("LLM registry not configured for process".into()))?;
848
849 match alias {
850 Some(name) => registry
851 .get(name)
852 .map_err(|e| AgentError::LLM(e.to_string())),
853 None => registry
854 .router()
855 .or_else(|_| registry.default())
856 .map_err(|e| AgentError::LLM(e.to_string())),
857 }
858 }
859}
860
861fn purpose_hint_for_stages(stages: &[ProcessStage]) -> ProcessPurposeHint {
862 for stage in stages {
863 let hint = process_purpose_hint_for_stage(stage);
864 if hint != ProcessPurposeHint::Other {
865 return hint;
866 }
867 }
868 ProcessPurposeHint::Other
869}
870
871fn process_purpose_hint_for_stage(stage: &ProcessStage) -> ProcessPurposeHint {
872 match stage {
873 ProcessStage::Detect(_) => ProcessPurposeHint::Detect,
874 ProcessStage::Extract(_) => ProcessPurposeHint::Extract,
875 ProcessStage::Validate(_) => ProcessPurposeHint::Validate,
876 ProcessStage::Sanitize(_) | ProcessStage::Transform(_) => ProcessPurposeHint::Transform,
877 ProcessStage::Conditional(config) => {
878 let then_hint = purpose_hint_for_stages(&config.config.then_stages);
879 if then_hint != ProcessPurposeHint::Other {
880 then_hint
881 } else {
882 purpose_hint_for_stages(&config.config.else_stages)
883 }
884 }
885 _ => ProcessPurposeHint::Other,
886 }
887}
888
/// Pulls the JSON payload out of an LLM reply.
///
/// Tried in order:
/// 1. a leading ```` ```json ```` fence: the text up to the closing fence;
/// 2. a leading bare ```` ``` ```` fence: the text up to the closing fence;
/// 3. the span from the first `{` to the last `}`;
/// 4. otherwise the trimmed reply itself.
///
/// The result is not validated as JSON; callers parse it and handle failure.
/// A reply whose last `}` comes before its first `{` (for example `"} {"`)
/// falls through to step 4; slicing that backwards range would panic.
fn extract_json(response: &str) -> String {
    let trimmed = response.trim();

    for fence in &["```json", "```"] {
        if let Some(rest) = trimmed.strip_prefix(*fence) {
            if let Some(end) = rest.find("```") {
                return rest[..end].trim().to_string();
            }
        }
    }

    if let (Some(start), Some(end)) = (trimmed.find('{'), trimmed.rfind('}')) {
        // Only slice when the braces are in order.
        if start < end {
            return trimmed[start..=end].to_string();
        }
    }

    trimmed.to_string()
}
912
913#[cfg(test)]
914mod tests {
915 use super::*;
916
#[test]
fn test_process_data_new() {
    // A fresh payload mirrors the text in both fields and has no context.
    let fresh = ProcessData::new("test content");
    assert_eq!(fresh.original, "test content");
    assert_eq!(fresh.content, "test content");
    assert_eq!(fresh.context.len(), 0);
}
924
#[test]
fn test_process_data_with_context() {
    // The builder stores the value under the given key.
    let value = serde_json::json!("value");
    let payload = ProcessData::new("test").with_context("key", value);
    assert!(payload.context.get("key").is_some());
}
930
931 #[tokio::test]
932 async fn test_normalize_trim() {
933 let processor = ProcessProcessor::default();
934 let config = NormalizeConfig {
935 trim: true,
936 ..Default::default()
937 };
938 let data = ProcessData::new(" hello world ");
939 let result = processor.execute_normalize(&config, data).await.unwrap();
940 assert_eq!(result.content, "hello world");
941 }
942
943 #[tokio::test]
944 async fn test_normalize_collapse_whitespace() {
945 let processor = ProcessProcessor::default();
946 let config = NormalizeConfig {
947 trim: true,
948 collapse_whitespace: true,
949 ..Default::default()
950 };
951 let data = ProcessData::new("hello world\n\ntest");
952 let result = processor.execute_normalize(&config, data).await.unwrap();
953 assert_eq!(result.content, "hello world test");
954 }
955
956 #[tokio::test]
957 async fn test_normalize_lowercase() {
958 let processor = ProcessProcessor::default();
959 let config = NormalizeConfig {
960 lowercase: true,
961 ..Default::default()
962 };
963 let data = ProcessData::new("Hello World");
964 let result = processor.execute_normalize(&config, data).await.unwrap();
965 assert_eq!(result.content, "hello world");
966 }
967
968 #[tokio::test]
969 async fn test_validate_min_length_reject() {
970 let processor = ProcessProcessor::default();
971 let config = ValidateConfig {
972 rules: vec![ValidationRule::MinLength {
973 min_length: 10,
974 on_fail: ValidationAction {
975 action: ValidationActionType::Reject,
976 message: None,
977 },
978 }],
979 ..Default::default()
980 };
981 let data = ProcessData::new("short");
982 let result = processor.execute_validate(&config, data).await.unwrap();
983 assert!(result.metadata.rejected);
984 }
985
986 #[tokio::test]
987 async fn test_validate_max_length_truncate() {
988 let processor = ProcessProcessor::default();
989 let config = ValidateConfig {
990 rules: vec![ValidationRule::MaxLength {
991 max_length: 5,
992 on_fail: ValidationAction {
993 action: ValidationActionType::Truncate,
994 message: None,
995 },
996 }],
997 ..Default::default()
998 };
999 let data = ProcessData::new("hello world");
1000 let result = processor.execute_validate(&config, data).await.unwrap();
1001 assert_eq!(result.content, "hello");
1002 assert!(!result.metadata.rejected);
1003 }
1004
1005 #[tokio::test]
1006 async fn test_format_simple_template() {
1007 let processor = ProcessProcessor::default();
1008 let config = FormatConfig {
1009 template: Some("Response: {{ response }}".to_string()),
1010 ..Default::default()
1011 };
1012 let data = ProcessData::new("Hello!");
1013 let result = processor.execute_format(&config, data).await.unwrap();
1014 assert_eq!(result.content, "Response: Hello!");
1015 }
1016
#[test]
fn test_extract_json() {
    // Bare JSON, fenced JSON and JSON embedded in prose all yield the same object text.
    let expected = r#"{"key": 1}"#;
    let bare = extract_json(r#"{"key": 1}"#);
    let fenced = extract_json("```json\n{\"key\": 1}\n```");
    let embedded = extract_json("Some text {\"key\": 1} more");
    assert_eq!(bare, expected);
    assert_eq!(fenced, expected);
    assert_eq!(embedded, expected);
}
1023
#[test]
fn test_evaluate_condition_empty() {
    // No condition at all counts as satisfied.
    let payload = ProcessData::new("test");
    let processor = ProcessProcessor::default();
    let satisfied = processor.evaluate_condition(&None, &payload);
    assert!(satisfied);
}
1030
#[test]
fn test_evaluate_condition_simple_exists_true() {
    // `exists: true` holds when the nested path resolves to a value.
    let payload = ProcessData::new("test")
        .with_context("session", serde_json::json!({ "user_name": "Alice" }));

    let checks: std::collections::HashMap<String, serde_json::Value> = std::iter::once((
        "session.user_name".to_string(),
        serde_json::json!({ "exists": true }),
    ))
    .collect();
    let condition = ConditionExpr::Simple(checks);

    let processor = ProcessProcessor::default();
    assert!(processor.evaluate_condition_expr(&condition, &payload));
}
1049
#[test]
fn test_evaluate_condition_simple_exists_false() {
    // `exists: false` holds when the path is absent from the context.
    let payload = ProcessData::new("test");

    let checks: std::collections::HashMap<String, serde_json::Value> = std::iter::once((
        "session.user_name".to_string(),
        serde_json::json!({ "exists": false }),
    ))
    .collect();
    let condition = ConditionExpr::Simple(checks);

    let processor = ProcessProcessor::default();
    assert!(processor.evaluate_condition_expr(&condition, &payload));
}
1064
#[test]
fn test_evaluate_condition_all() {
    // `All` holds when every child condition holds.
    let payload = ProcessData::new("test").with_context(
        "session",
        serde_json::json!({ "user_name": "Alice", "language": "en" }),
    );

    let children: Vec<ConditionExpr> = ["session.user_name", "session.language"]
        .iter()
        .map(|path| {
            let mut checks = std::collections::HashMap::new();
            checks.insert(path.to_string(), serde_json::json!({ "exists": true }));
            ConditionExpr::Simple(checks)
        })
        .collect();
    let condition = ConditionExpr::All { all: children };

    let processor = ProcessProcessor::default();
    assert!(processor.evaluate_condition_expr(&condition, &payload));
}
1091
#[test]
fn test_evaluate_condition_any() {
    // `Any` holds when at least one child condition matches.
    let payload = ProcessData::new("test")
        .with_context("session", serde_json::json!({ "tier": "premium" }));

    let children: Vec<ConditionExpr> = ["premium", "enterprise"]
        .iter()
        .map(|tier| {
            let mut checks = std::collections::HashMap::new();
            checks.insert("session.tier".to_string(), serde_json::json!(tier));
            ConditionExpr::Simple(checks)
        })
        .collect();
    let condition = ConditionExpr::Any { any: children };

    let processor = ProcessProcessor::default();
    assert!(processor.evaluate_condition_expr(&condition, &payload));
}
1112
#[test]
fn test_evaluate_condition_value_match() {
    // A plain expected value must equal the value found at the path.
    let payload = ProcessData::new("test")
        .with_context("input", serde_json::json!({ "sentiment": "negative" }));

    let checks: std::collections::HashMap<String, serde_json::Value> = std::iter::once((
        "input.sentiment".to_string(),
        serde_json::json!("negative"),
    ))
    .collect();
    let condition = ConditionExpr::Simple(checks);

    let processor = ProcessProcessor::default();
    assert!(processor.evaluate_condition_expr(&condition, &payload));
}
1128
#[test]
fn test_get_nested_value() {
    // Dotted paths walk into nested objects; unknown keys resolve to nothing.
    let context: std::collections::HashMap<String, serde_json::Value> = std::iter::once((
        "session".to_string(),
        serde_json::json!({ "user": { "name": "Alice" } }),
    ))
    .collect();
    let processor = ProcessProcessor::default();

    let expected = serde_json::json!("Alice");
    let found = processor.get_nested_value(&context, "session.user.name");
    assert_eq!(found, Some(&expected));

    let missing = processor.get_nested_value(&context, "session.nonexistent");
    assert_eq!(missing, None);
}
1144
1145 fn create_mock_registry(response: &str) -> Arc<ai_agents_llm::LLMRegistry> {
1149 use ai_agents_llm::mock::MockLLMProvider;
1150 let mut mock = MockLLMProvider::new("test");
1151 mock.set_response(response);
1152 let mut registry = ai_agents_llm::LLMRegistry::new();
1153 registry.register("default", std::sync::Arc::new(mock));
1154 registry.set_default("default");
1155 std::sync::Arc::new(registry)
1156 }
1157
1158 fn create_mock_registry_multi(responses: Vec<&str>) -> Arc<ai_agents_llm::LLMRegistry> {
1159 use ai_agents_llm::mock::MockLLMProvider;
1160 let mut mock = MockLLMProvider::new("test");
1161 mock.set_responses(responses.into_iter().map(String::from).collect(), true);
1162 let mut registry = ai_agents_llm::LLMRegistry::new();
1163 registry.register("default", std::sync::Arc::new(mock));
1164 registry.set_default("default");
1165 std::sync::Arc::new(registry)
1166 }
1167
1168 #[tokio::test]
1169 async fn test_detect_stage_language_sentiment() {
1170 let registry = create_mock_registry(
1171 r#"{"language": "ko", "sentiment": "positive", "intent": "greeting"}"#,
1172 );
1173 let config = ProcessConfig {
1174 input: vec![ProcessStage::Detect(DetectStage {
1175 id: Some("detect_test".to_string()),
1176 condition: None,
1177 config: DetectConfig {
1178 llm: None,
1179 detect: vec![DetectionType::Language, DetectionType::Sentiment],
1180 intents: vec![IntentDefinition {
1181 id: "greeting".to_string(),
1182 description: "User says hello".to_string(),
1183 }],
1184 store_in_context: {
1185 let mut m = std::collections::HashMap::new();
1186 m.insert("language".to_string(), "input.language".to_string());
1187 m.insert("sentiment".to_string(), "input.sentiment".to_string());
1188 m
1189 },
1190 },
1191 })],
1192 ..Default::default()
1193 };
1194 let processor = ProcessProcessor::new(config).with_llm_registry(registry);
1195 let result = processor.process_input("안녕하세요!").await.unwrap();
1196
1197 assert_eq!(
1198 result.context.get("input.language"),
1199 Some(&serde_json::json!("ko"))
1200 );
1201 assert_eq!(
1202 result.context.get("input.sentiment"),
1203 Some(&serde_json::json!("positive"))
1204 );
1205 assert!(
1206 result
1207 .metadata
1208 .stages_executed
1209 .contains(&"detect_test".to_string())
1210 );
1211 }
1212
1213 #[tokio::test]
1214 async fn test_extract_stage_entities() {
1215 let registry = create_mock_registry(r#"{"order_number": "ORD-12345", "urgency": "high"}"#);
1216 let config = ProcessConfig {
1217 input: vec![ProcessStage::Extract(ExtractStage {
1218 id: Some("extract_test".to_string()),
1219 condition: None,
1220 config: ExtractConfig {
1221 llm: None,
1222 schema: {
1223 let mut m = std::collections::HashMap::new();
1224 m.insert(
1225 "order_number".to_string(),
1226 FieldSchema {
1227 field_type: FieldType::String,
1228 description: Some("Order number".to_string()),
1229 required: true,
1230 values: vec![],
1231 },
1232 );
1233 m.insert(
1234 "urgency".to_string(),
1235 FieldSchema {
1236 field_type: FieldType::Enum,
1237 description: Some("Urgency level".to_string()),
1238 required: false,
1239 values: vec![
1240 "low".to_string(),
1241 "medium".to_string(),
1242 "high".to_string(),
1243 ],
1244 },
1245 );
1246 m
1247 },
1248 store_in_context: Some("extracted".to_string()),
1249 },
1250 })],
1251 ..Default::default()
1252 };
1253 let processor = ProcessProcessor::new(config).with_llm_registry(registry);
1254 let result = processor
1255 .process_input("My order ORD-12345 is urgent!")
1256 .await
1257 .unwrap();
1258
1259 let extracted = result.context.get("extracted").unwrap();
1260 assert_eq!(extracted["order_number"], "ORD-12345");
1261 assert_eq!(extracted["urgency"], "high");
1262 }
1263
1264 #[tokio::test]
1265 async fn test_sanitize_stage_pii_masking() {
1266 let registry = create_mock_registry("Call me at ****-****-**** or email at ****@****.com");
1267 let config = ProcessConfig {
1268 input: vec![ProcessStage::Sanitize(SanitizeStage {
1269 id: Some("sanitize_test".to_string()),
1270 condition: None,
1271 config: SanitizeConfig {
1272 llm: None,
1273 pii: Some(PIISanitizeConfig {
1274 action: PIIAction::Mask,
1275 types: vec![PIIType::Phone, PIIType::Email],
1276 mask_char: "*".to_string(),
1277 }),
1278 harmful: None,
1279 remove: vec![],
1280 },
1281 })],
1282 ..Default::default()
1283 };
1284 let processor = ProcessProcessor::new(config).with_llm_registry(registry);
1285 let result = processor
1286 .process_input("Call me at 010-1234-5678 or email at user@example.com")
1287 .await
1288 .unwrap();
1289
1290 assert!(result.content.contains("****"));
1292 assert!(!result.content.contains("010-1234-5678"));
1293 assert!(!result.content.contains("user@example.com"));
1294 }
1295
1296 #[tokio::test]
1297 async fn test_transform_stage_tone_adjustment() {
1298 let registry = create_mock_registry(
1299 "I understand your frustration. Let me help you resolve this issue right away.",
1300 );
1301 let config = ProcessConfig {
1302 output: vec![ProcessStage::Transform(TransformStage {
1303 id: Some("tone_test".to_string()),
1304 condition: None,
1305 config: TransformConfig {
1306 llm: None,
1307 prompt: Some("Rewrite to be more empathetic.".to_string()),
1308 max_output_tokens: None,
1309 },
1310 })],
1311 ..Default::default()
1312 };
1313 let processor = ProcessProcessor::new(config).with_llm_registry(registry);
1314
1315 let input_context = std::collections::HashMap::new();
1316 let result = processor
1317 .process_output("Your request is being processed.", &input_context)
1318 .await
1319 .unwrap();
1320
1321 assert!(result.content.contains("understand"));
1322 }
1323
1324 #[tokio::test]
1325 async fn test_validate_stage_llm_criteria() {
1326 let registry = create_mock_registry(
1327 r#"{"passes": false, "score": 0.3, "issues": ["Response is too vague"]}"#,
1328 );
1329 let config = ProcessConfig {
1330 output: vec![ProcessStage::Validate(ValidateStage {
1331 id: Some("quality_test".to_string()),
1332 condition: None,
1333 config: ValidateConfig {
1334 rules: vec![],
1335 llm: None,
1336 criteria: vec!["Response is specific and actionable".to_string()],
1337 threshold: 0.7,
1338 on_fail: ValidationFailAction {
1339 action: ValidationFailType::Warn,
1340 ..Default::default()
1341 },
1342 },
1343 })],
1344 ..Default::default()
1345 };
1346 let processor = ProcessProcessor::new(config).with_llm_registry(registry);
1347
1348 let input_context = std::collections::HashMap::new();
1349 let result = processor
1350 .process_output("It depends.", &input_context)
1351 .await
1352 .unwrap();
1353
1354 assert!(
1356 result.metadata.warnings.iter().any(|w| w.contains("vague")),
1357 "Expected warning about vague response, got: {:?}",
1358 result.metadata.warnings
1359 );
1360 }
1361
1362 #[tokio::test]
1363 async fn test_validate_stage_llm_criteria_reject() {
1364 let registry = create_mock_registry(
1365 r#"{"passes": false, "score": 0.2, "issues": ["Contains harmful content"]}"#,
1366 );
1367 let config = ProcessConfig {
1368 output: vec![ProcessStage::Validate(ValidateStage {
1369 id: Some("reject_test".to_string()),
1370 condition: None,
1371 config: ValidateConfig {
1372 rules: vec![],
1373 llm: None,
1374 criteria: vec!["Response is safe".to_string()],
1375 threshold: 0.7,
1376 on_fail: ValidationFailAction {
1377 action: ValidationFailType::Reject,
1378 ..Default::default()
1379 },
1380 },
1381 })],
1382 ..Default::default()
1383 };
1384 let processor = ProcessProcessor::new(config).with_llm_registry(registry);
1385
1386 let input_context = std::collections::HashMap::new();
1387 let result = processor
1388 .process_output("Dangerous content here.", &input_context)
1389 .await
1390 .unwrap();
1391
1392 assert!(result.metadata.rejected);
1393 assert!(
1394 result
1395 .metadata
1396 .rejection_reason
1397 .as_ref()
1398 .unwrap()
1399 .contains("harmful")
1400 );
1401 }
1402
1403 #[tokio::test]
1404 async fn test_full_input_pipeline_chain() {
1405 let registry = create_mock_registry_multi(vec![
1407 r#"{"language": "en", "sentiment": "neutral"}"#,
1409 r#"{"user_name": "Alice", "topic": "billing"}"#,
1411 ]);
1412 let config = ProcessConfig {
1413 input: vec![
1414 ProcessStage::Normalize(NormalizeStage {
1415 id: Some("norm".to_string()),
1416 condition: None,
1417 config: NormalizeConfig {
1418 trim: true,
1419 collapse_whitespace: true,
1420 ..Default::default()
1421 },
1422 }),
1423 ProcessStage::Detect(DetectStage {
1424 id: Some("detect".to_string()),
1425 condition: None,
1426 config: DetectConfig {
1427 llm: None,
1428 detect: vec![DetectionType::Language, DetectionType::Sentiment],
1429 intents: vec![],
1430 store_in_context: {
1431 let mut m = std::collections::HashMap::new();
1432 m.insert("language".to_string(), "input.language".to_string());
1433 m
1434 },
1435 },
1436 }),
1437 ProcessStage::Extract(ExtractStage {
1438 id: Some("extract".to_string()),
1439 condition: None,
1440 config: ExtractConfig {
1441 llm: None,
1442 schema: {
1443 let mut m = std::collections::HashMap::new();
1444 m.insert(
1445 "user_name".to_string(),
1446 FieldSchema {
1447 field_type: FieldType::String,
1448 description: Some("User name".to_string()),
1449 ..Default::default()
1450 },
1451 );
1452 m
1453 },
1454 store_in_context: Some("entities".to_string()),
1455 },
1456 }),
1457 ],
1458 ..Default::default()
1459 };
1460 let processor = ProcessProcessor::new(config).with_llm_registry(registry);
1461 let result = processor
1462 .process_input(" Hi, I'm Alice and I have a billing question ")
1463 .await
1464 .unwrap();
1465
1466 assert_eq!(
1468 result.content,
1469 "Hi, I'm Alice and I have a billing question"
1470 );
1471
1472 assert_eq!(
1474 result.context.get("input.language"),
1475 Some(&serde_json::json!("en"))
1476 );
1477
1478 let entities = result.context.get("entities").unwrap();
1480 assert_eq!(entities["user_name"], "Alice");
1481
1482 assert_eq!(
1484 result.metadata.stages_executed,
1485 vec!["norm", "detect", "extract"]
1486 );
1487 }
1488
1489 #[tokio::test]
1490 async fn test_conditional_stage_skips_on_false() {
1491 let registry = create_mock_registry(r#"{"language": "en"}"#);
1492 let config = ProcessConfig {
1493 input: vec![ProcessStage::Detect(DetectStage {
1494 id: Some("should_skip".to_string()),
1495 condition: Some(ConditionExpr::Simple({
1496 let mut map = std::collections::HashMap::new();
1497 map.insert("needs_detection".to_string(), serde_json::json!(true));
1498 map
1499 })),
1500 config: DetectConfig {
1501 llm: None,
1502 detect: vec![DetectionType::Language],
1503 ..Default::default()
1504 },
1505 })],
1506 ..Default::default()
1507 };
1508 let processor = ProcessProcessor::new(config).with_llm_registry(registry);
1509 let result = processor.process_input("Hello").await.unwrap();
1510
1511 assert!(
1513 !result
1514 .metadata
1515 .stages_executed
1516 .contains(&"should_skip".to_string()),
1517 "Stage should have been skipped"
1518 );
1519 }
1520
1521 #[tokio::test]
1522 async fn test_stage_skipped_when_condition_false() {
1523 let config = ProcessConfig {
1524 input: vec![ProcessStage::Extract(ExtractStage {
1525 id: Some("skip_me".to_string()),
1526 condition: Some(ConditionExpr::Simple({
1527 let mut map = std::collections::HashMap::new();
1528 map.insert(
1529 "session.user".to_string(),
1530 serde_json::json!({ "exists": false }),
1531 );
1532 map
1533 })),
1534 config: ExtractConfig::default(),
1535 })],
1536 settings: ProcessSettings {
1537 debug: ProcessDebugConfig {
1538 log_stages: true,
1539 ..Default::default()
1540 },
1541 ..Default::default()
1542 },
1543 ..Default::default()
1544 };
1545 let processor = ProcessProcessor::new(config);
1546
1547 let mut data = ProcessData::new("test");
1548 data.context.insert(
1549 "session".to_string(),
1550 serde_json::json!({ "user": "Alice" }),
1551 );
1552
1553 let result = processor.process_input("test").await.unwrap();
1554 assert!(
1555 !result
1556 .metadata
1557 .stages_executed
1558 .contains(&"skip_me".to_string())
1559 );
1560 }
1561}