Skip to main content

ai_agents_process/
config.rs

1//! Process configuration types for input/output transformation
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6#[derive(Debug, Clone, Serialize, Deserialize, Default)]
7pub struct ProcessConfig {
8    #[serde(default)]
9    pub input: Vec<ProcessStage>,
10    #[serde(default)]
11    pub output: Vec<ProcessStage>,
12    #[serde(default)]
13    pub settings: ProcessSettings,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17#[serde(tag = "type", rename_all = "snake_case")]
18pub enum ProcessStage {
19    Normalize(NormalizeStage),
20    Detect(DetectStage),
21    Extract(ExtractStage),
22    Sanitize(SanitizeStage),
23    Transform(TransformStage),
24    Validate(ValidateStage),
25    Format(FormatStage),
26    Enrich(EnrichStage),
27    Conditional(ConditionalStage),
28}
29
30impl ProcessStage {
31    pub fn condition(&self) -> Option<&ConditionExpr> {
32        match self {
33            ProcessStage::Normalize(s) => s.condition.as_ref(),
34            ProcessStage::Detect(s) => s.condition.as_ref(),
35            ProcessStage::Extract(s) => s.condition.as_ref(),
36            ProcessStage::Sanitize(s) => s.condition.as_ref(),
37            ProcessStage::Transform(s) => s.condition.as_ref(),
38            ProcessStage::Validate(s) => s.condition.as_ref(),
39            ProcessStage::Format(s) => s.condition.as_ref(),
40            ProcessStage::Enrich(s) => s.condition.as_ref(),
41            ProcessStage::Conditional(_) => None,
42        }
43    }
44
45    pub fn id(&self) -> Option<&str> {
46        match self {
47            ProcessStage::Normalize(s) => s.id.as_deref(),
48            ProcessStage::Detect(s) => s.id.as_deref(),
49            ProcessStage::Extract(s) => s.id.as_deref(),
50            ProcessStage::Sanitize(s) => s.id.as_deref(),
51            ProcessStage::Transform(s) => s.id.as_deref(),
52            ProcessStage::Validate(s) => s.id.as_deref(),
53            ProcessStage::Format(s) => s.id.as_deref(),
54            ProcessStage::Enrich(s) => s.id.as_deref(),
55            ProcessStage::Conditional(s) => s.id.as_deref(),
56        }
57    }
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize, Default)]
61pub struct NormalizeStage {
62    #[serde(default)]
63    pub id: Option<String>,
64    #[serde(default)]
65    pub condition: Option<ConditionExpr>,
66    #[serde(default)]
67    pub config: NormalizeConfig,
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct NormalizeConfig {
72    #[serde(default = "default_true")]
73    pub trim: bool,
74    #[serde(default)]
75    pub unicode: Option<UnicodeNormalization>,
76    #[serde(default)]
77    pub collapse_whitespace: bool,
78    #[serde(default)]
79    pub lowercase: bool,
80}
81
82impl Default for NormalizeConfig {
83    fn default() -> Self {
84        Self {
85            trim: true,
86            unicode: None,
87            collapse_whitespace: false,
88            lowercase: false,
89        }
90    }
91}
92
93#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
94#[serde(rename_all = "lowercase")]
95pub enum UnicodeNormalization {
96    Nfc,
97    Nfd,
98    Nfkc,
99    Nfkd,
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize, Default)]
103pub struct DetectStage {
104    #[serde(default)]
105    pub id: Option<String>,
106    #[serde(default)]
107    pub condition: Option<ConditionExpr>,
108    #[serde(default)]
109    pub config: DetectConfig,
110}
111
112#[derive(Debug, Clone, Serialize, Deserialize, Default)]
113pub struct DetectConfig {
114    #[serde(default)]
115    pub llm: Option<String>,
116    #[serde(default)]
117    pub detect: Vec<DetectionType>,
118    #[serde(default)]
119    pub intents: Vec<IntentDefinition>,
120    #[serde(default)]
121    pub store_in_context: HashMap<String, String>,
122}
123
124#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
125#[serde(rename_all = "snake_case")]
126pub enum DetectionType {
127    Language,
128    Sentiment,
129    Intent,
130    Topic,
131    Formality,
132    Urgency,
133}
134
135#[derive(Debug, Clone, Serialize, Deserialize)]
136pub struct IntentDefinition {
137    pub id: String,
138    pub description: String,
139}
140
141#[derive(Debug, Clone, Serialize, Deserialize, Default)]
142pub struct ExtractStage {
143    #[serde(default)]
144    pub id: Option<String>,
145    #[serde(default)]
146    pub condition: Option<ConditionExpr>,
147    #[serde(default)]
148    pub config: ExtractConfig,
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize, Default)]
152pub struct ExtractConfig {
153    #[serde(default)]
154    pub llm: Option<String>,
155    #[serde(default)]
156    pub schema: HashMap<String, FieldSchema>,
157    #[serde(default)]
158    pub store_in_context: Option<String>,
159}
160
161#[derive(Debug, Clone, Serialize, Deserialize, Default)]
162pub struct FieldSchema {
163    #[serde(rename = "type", default)]
164    pub field_type: FieldType,
165    #[serde(default)]
166    pub description: Option<String>,
167    #[serde(default)]
168    pub required: bool,
169    #[serde(default)]
170    pub values: Vec<String>,
171}
172
173#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
174#[serde(rename_all = "snake_case")]
175pub enum FieldType {
176    #[default]
177    String,
178    Number,
179    Integer,
180    Boolean,
181    Date,
182    Enum,
183    Array,
184    Object,
185}
186
187#[derive(Debug, Clone, Serialize, Deserialize, Default)]
188pub struct SanitizeStage {
189    #[serde(default)]
190    pub id: Option<String>,
191    #[serde(default)]
192    pub condition: Option<ConditionExpr>,
193    #[serde(default)]
194    pub config: SanitizeConfig,
195}
196
197#[derive(Debug, Clone, Serialize, Deserialize, Default)]
198pub struct SanitizeConfig {
199    #[serde(default)]
200    pub llm: Option<String>,
201    #[serde(default)]
202    pub pii: Option<PIISanitizeConfig>,
203    #[serde(default)]
204    pub harmful: Option<HarmfulContentConfig>,
205    #[serde(default)]
206    pub remove: Vec<String>,
207}
208
209#[derive(Debug, Clone, Serialize, Deserialize)]
210pub struct PIISanitizeConfig {
211    #[serde(default)]
212    pub action: PIIAction,
213    #[serde(default)]
214    pub types: Vec<PIIType>,
215    #[serde(default = "default_mask_char")]
216    pub mask_char: String,
217}
218
219impl Default for PIISanitizeConfig {
220    fn default() -> Self {
221        Self {
222            action: PIIAction::Mask,
223            types: Vec::new(),
224            mask_char: default_mask_char(),
225        }
226    }
227}
228
229#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
230#[serde(rename_all = "snake_case")]
231pub enum PIIAction {
232    #[default]
233    Mask,
234    Remove,
235    Flag,
236}
237
238#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
239#[serde(rename_all = "snake_case")]
240pub enum PIIType {
241    Email,
242    Phone,
243    CreditCard,
244    Ssn,
245    IpAddress,
246    Name,
247    Address,
248}
249
250#[derive(Debug, Clone, Serialize, Deserialize, Default)]
251pub struct HarmfulContentConfig {
252    #[serde(default)]
253    pub detect: Vec<HarmfulContentType>,
254    #[serde(default)]
255    pub action: HarmfulAction,
256}
257
258#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
259#[serde(rename_all = "snake_case")]
260pub enum HarmfulContentType {
261    HateSpeech,
262    Violence,
263    SexualContent,
264    Harassment,
265    SelfHarm,
266    IllegalActivity,
267}
268
269#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
270#[serde(rename_all = "snake_case")]
271pub enum HarmfulAction {
272    #[default]
273    Flag,
274    Block,
275    Remove,
276}
277
278#[derive(Debug, Clone, Serialize, Deserialize, Default)]
279pub struct TransformStage {
280    #[serde(default)]
281    pub id: Option<String>,
282    #[serde(default)]
283    pub condition: Option<ConditionExpr>,
284    #[serde(default)]
285    pub config: TransformConfig,
286}
287
288#[derive(Debug, Clone, Serialize, Deserialize, Default)]
289pub struct TransformConfig {
290    #[serde(default)]
291    pub llm: Option<String>,
292    #[serde(default)]
293    pub prompt: Option<String>,
294    #[serde(default)]
295    pub max_output_tokens: Option<u32>,
296}
297
298#[derive(Debug, Clone, Serialize, Deserialize, Default)]
299pub struct ValidateStage {
300    #[serde(default)]
301    pub id: Option<String>,
302    #[serde(default)]
303    pub condition: Option<ConditionExpr>,
304    #[serde(default)]
305    pub config: ValidateConfig,
306}
307
308#[derive(Debug, Clone, Serialize, Deserialize, Default)]
309pub struct ValidateConfig {
310    #[serde(default)]
311    pub rules: Vec<ValidationRule>,
312    #[serde(default)]
313    pub llm: Option<String>,
314    #[serde(default)]
315    pub criteria: Vec<String>,
316    #[serde(default = "default_threshold")]
317    pub threshold: f32,
318    #[serde(default)]
319    pub on_fail: ValidationFailAction,
320}
321
322#[derive(Debug, Clone, Serialize, Deserialize)]
323#[serde(untagged)]
324pub enum ValidationRule {
325    MinLength {
326        min_length: usize,
327        #[serde(default)]
328        on_fail: ValidationAction,
329    },
330    MaxLength {
331        max_length: usize,
332        #[serde(default)]
333        on_fail: ValidationAction,
334    },
335    Pattern {
336        pattern: String,
337        #[serde(default)]
338        on_fail: ValidationAction,
339    },
340}
341
342#[derive(Debug, Clone, Serialize, Deserialize, Default)]
343pub struct ValidationAction {
344    #[serde(default)]
345    pub action: ValidationActionType,
346    #[serde(default)]
347    pub message: Option<HashMap<String, String>>,
348}
349
350#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
351#[serde(rename_all = "snake_case")]
352pub enum ValidationActionType {
353    #[default]
354    Reject,
355    Truncate,
356    Warn,
357}
358
359#[derive(Debug, Clone, Serialize, Deserialize, Default)]
360pub struct ValidationFailAction {
361    #[serde(default)]
362    pub action: ValidationFailType,
363    #[serde(default)]
364    pub max_retries: Option<u32>,
365    #[serde(default)]
366    pub feedback_to_agent: bool,
367}
368
369#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
370#[serde(rename_all = "snake_case")]
371pub enum ValidationFailType {
372    #[default]
373    Reject,
374    Regenerate,
375    Warn,
376}
377
378#[derive(Debug, Clone, Serialize, Deserialize, Default)]
379pub struct FormatStage {
380    #[serde(default)]
381    pub id: Option<String>,
382    #[serde(default)]
383    pub condition: Option<ConditionExpr>,
384    #[serde(default)]
385    pub config: FormatConfig,
386}
387
388#[derive(Debug, Clone, Serialize, Deserialize, Default)]
389pub struct FormatConfig {
390    #[serde(default)]
391    pub template: Option<String>,
392    #[serde(default)]
393    pub channels: HashMap<String, ChannelFormat>,
394    #[serde(default)]
395    pub channel: Option<String>,
396}
397
398#[derive(Debug, Clone, Serialize, Deserialize, Default)]
399pub struct ChannelFormat {
400    #[serde(default)]
401    pub template: Option<String>,
402    #[serde(default)]
403    pub format: Option<OutputFormat>,
404    #[serde(default)]
405    pub max_length: Option<usize>,
406    #[serde(default)]
407    pub markdown: bool,
408}
409
410#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
411#[serde(rename_all = "snake_case")]
412pub enum OutputFormat {
413    #[default]
414    Text,
415    Html,
416    Json,
417    Markdown,
418}
419
420#[derive(Debug, Clone, Serialize, Deserialize, Default)]
421pub struct EnrichStage {
422    #[serde(default)]
423    pub id: Option<String>,
424    #[serde(default)]
425    pub condition: Option<ConditionExpr>,
426    #[serde(default)]
427    pub config: EnrichConfig,
428}
429
430#[derive(Debug, Clone, Serialize, Deserialize, Default)]
431pub struct EnrichConfig {
432    #[serde(default)]
433    pub source: EnrichSource,
434    #[serde(default)]
435    pub store_in_context: Option<String>,
436    #[serde(default)]
437    pub on_error: EnrichErrorAction,
438}
439
440#[derive(Debug, Clone, Serialize, Deserialize, Default)]
441#[serde(tag = "source", rename_all = "snake_case")]
442pub enum EnrichSource {
443    #[default]
444    None,
445    Api {
446        url: String,
447        #[serde(default = "default_method")]
448        method: String,
449        #[serde(default)]
450        headers: HashMap<String, String>,
451        #[serde(default)]
452        body: Option<serde_json::Value>,
453        #[serde(default)]
454        extract: HashMap<String, String>,
455    },
456    File {
457        path: String,
458        #[serde(default)]
459        format: Option<String>,
460    },
461    Tool {
462        tool: String,
463        #[serde(default)]
464        args: serde_json::Value,
465    },
466}
467
468#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
469#[serde(rename_all = "snake_case")]
470pub enum EnrichErrorAction {
471    #[default]
472    Continue,
473    Stop,
474    Warn,
475}
476
477#[derive(Debug, Clone, Serialize, Deserialize, Default)]
478pub struct ConditionalStage {
479    #[serde(default)]
480    pub id: Option<String>,
481    #[serde(default)]
482    pub config: ConditionalConfig,
483}
484
485#[derive(Debug, Clone, Serialize, Deserialize, Default)]
486pub struct ConditionalConfig {
487    #[serde(default)]
488    pub condition: Option<ConditionExpr>,
489    #[serde(default, rename = "then")]
490    pub then_stages: Vec<ProcessStage>,
491    #[serde(default, rename = "else")]
492    pub else_stages: Vec<ProcessStage>,
493}
494
495#[derive(Debug, Clone, Serialize, Deserialize)]
496#[serde(untagged)]
497pub enum ConditionExpr {
498    All { all: Vec<ConditionExpr> },
499    Any { any: Vec<ConditionExpr> },
500    Simple(HashMap<String, serde_json::Value>),
501}
502
503impl Default for ConditionExpr {
504    fn default() -> Self {
505        ConditionExpr::Simple(HashMap::new())
506    }
507}
508
509#[derive(Debug, Clone, Serialize, Deserialize)]
510pub struct ProcessSettings {
511    #[serde(default)]
512    pub on_stage_error: StageErrorConfig,
513    #[serde(default = "default_timeout")]
514    pub timeout_ms: u64,
515    #[serde(default)]
516    pub cache: ProcessCacheConfig,
517    #[serde(default)]
518    pub debug: ProcessDebugConfig,
519}
520
521impl Default for ProcessSettings {
522    fn default() -> Self {
523        Self {
524            on_stage_error: StageErrorConfig::default(),
525            timeout_ms: default_timeout(),
526            cache: ProcessCacheConfig::default(),
527            debug: ProcessDebugConfig::default(),
528        }
529    }
530}
531
532#[derive(Debug, Clone, Serialize, Deserialize, Default)]
533pub struct StageErrorConfig {
534    #[serde(default)]
535    pub default: StageErrorAction,
536    #[serde(default)]
537    pub retry: Option<StageRetryConfig>,
538}
539
540#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
541#[serde(rename_all = "snake_case")]
542pub enum StageErrorAction {
543    #[default]
544    Continue,
545    Stop,
546    Retry,
547}
548
549#[derive(Debug, Clone, Serialize, Deserialize)]
550pub struct StageRetryConfig {
551    #[serde(default = "default_retry")]
552    pub max_retries: u32,
553    #[serde(default = "default_backoff")]
554    pub backoff_ms: u64,
555}
556
557impl Default for StageRetryConfig {
558    fn default() -> Self {
559        Self {
560            max_retries: default_retry(),
561            backoff_ms: default_backoff(),
562        }
563    }
564}
565
566#[derive(Debug, Clone, Serialize, Deserialize, Default)]
567pub struct ProcessCacheConfig {
568    #[serde(default)]
569    pub enabled: bool,
570    #[serde(default)]
571    pub stages: Vec<String>,
572    #[serde(default = "default_cache_ttl")]
573    pub ttl_seconds: u64,
574}
575
576#[derive(Debug, Clone, Serialize, Deserialize, Default)]
577pub struct ProcessDebugConfig {
578    #[serde(default)]
579    pub log_stages: bool,
580    #[serde(default)]
581    pub include_timing: bool,
582}
583
/// Serde default helper for boolean fields that should be on unless the
/// configuration says otherwise.
fn default_true() -> bool {
    !false
}
587
/// Serde default helper: the mask string used when `mask_char` is omitted.
fn default_mask_char() -> String {
    String::from("*")
}
591
/// Serde default helper: the threshold used when `threshold` is omitted.
fn default_threshold() -> f32 {
    0.7_f32
}
595
/// Serde default helper: the value used when `timeout_ms` is omitted.
fn default_timeout() -> u64 {
    5_000
}
599
/// Serde default helper: the value used when `max_retries` is omitted from
/// the stage retry settings.
fn default_retry() -> u32 {
    2_u32
}
603
/// Serde default helper: the value used when `backoff_ms` is omitted.
fn default_backoff() -> u64 {
    100_u64
}
607
/// Serde default helper: the value used when `ttl_seconds` is omitted.
fn default_cache_ttl() -> u64 {
    300_u64
}
611
/// Serde default helper: the method string used when `method` is omitted
/// from an `api` enrich source.
fn default_method() -> String {
    String::from("GET")
}
615
#[cfg(test)]
mod tests {
    use super::*;

    /// A default-constructed config has no stages and the default timeout.
    #[test]
    fn test_default_config() {
        let cfg = ProcessConfig::default();
        assert!(cfg.input.is_empty());
        assert!(cfg.output.is_empty());
        assert_eq!(cfg.settings.timeout_ms, 5000);
    }

    /// A full document with input stages, output stages and settings parses.
    #[test]
    fn test_yaml_parsing() {
        let document = r#"
input:
  - type: normalize
    id: basic_normalize
    config:
      trim: true
      collapse_whitespace: true
  - type: detect
    id: detect_language
    config:
      llm: fast
      detect:
        - language
        - sentiment
      intents:
        - id: greeting
          description: "User is saying hello"
  - type: extract
    config:
      llm: fast
      schema:
        user_name:
          type: string
          description: "User's name if mentioned"
output:
  - type: validate
    config:
      llm: fast
      criteria:
        - "Response is helpful"
      threshold: 0.8
settings:
  timeout_ms: 3000
"#;
        let parsed: ProcessConfig = serde_yaml::from_str(document).unwrap();
        assert_eq!(parsed.input.len(), 3);
        assert_eq!(parsed.output.len(), 1);
        assert_eq!(parsed.settings.timeout_ms, 3000);
    }

    /// The hand-written normalize defaults: trim on, whitespace collapsing off.
    #[test]
    fn test_normalize_config() {
        let normalize = NormalizeConfig::default();
        assert!(normalize.trim);
        assert!(!normalize.collapse_whitespace);
    }

    /// The `type` key maps onto `field_type` and `values` is collected.
    #[test]
    fn test_field_type() {
        let document = r#"
type: enum
values:
  - low
  - medium
  - high
description: "Priority level"
"#;
        let field: FieldSchema = serde_yaml::from_str(document).unwrap();
        assert_eq!(field.field_type, FieldType::Enum);
        assert_eq!(field.values.len(), 3);
    }

    /// A plain map condition is accepted on a stage.
    #[test]
    fn test_condition_parsing_simple() {
        let document = r#"
type: extract
condition:
  context.session.user_name:
    exists: false
config:
  schema:
    user_name:
      type: string
"#;
        let extract: ExtractStage = serde_yaml::from_str(document).unwrap();
        assert!(extract.condition.is_some());
    }

    /// An `all` list parses into `ConditionExpr::All` with every entry kept.
    #[test]
    fn test_condition_parsing_all() {
        let document = r#"
type: enrich
condition:
  all:
    - context.input.extracted.user_name:
        exists: true
    - context.session.user_profile:
        exists: false
config: {}
"#;
        let enrich: EnrichStage = serde_yaml::from_str(document).unwrap();
        assert!(enrich.condition.is_some());
        if let ConditionExpr::All { all } = enrich.condition.unwrap() {
            assert_eq!(all.len(), 2);
        } else {
            panic!("Expected All condition");
        }
    }

    /// An `any` list parses into `ConditionExpr::Any` with every entry kept.
    #[test]
    fn test_condition_parsing_any() {
        let document = r#"
type: detect
condition:
  any:
    - context.session.language:
        exists: false
    - context.session.force_detect: true
config:
  detect:
    - language
"#;
        let detect: DetectStage = serde_yaml::from_str(document).unwrap();
        assert!(detect.condition.is_some());
        if let ConditionExpr::Any { any } = detect.condition.unwrap() {
            assert_eq!(any.len(), 2);
        } else {
            panic!("Expected Any condition");
        }
    }

    /// `condition()` and `id()` forward to the wrapped stage.
    #[test]
    fn test_process_stage_condition_accessor() {
        let extract = ExtractStage {
            id: Some("test".to_string()),
            condition: Some(ConditionExpr::Simple(HashMap::new())),
            config: ExtractConfig::default(),
        };
        let wrapped = ProcessStage::Extract(extract);
        assert!(wrapped.condition().is_some());
        assert_eq!(wrapped.id(), Some("test"));
    }

    /// A default stage carries no condition.
    #[test]
    fn test_process_stage_no_condition() {
        let wrapped = ProcessStage::Normalize(NormalizeStage::default());
        assert!(wrapped.condition().is_none());
    }
}