1use regex::Regex;
30use serde::Deserialize;
31
32use crate::error::CoreError;
33
/// Tool name used when no `signal.tool` is configured (serde default and fallback in
/// the generated explicit-mode instruction).
const DEFAULT_SIGNAL_TOOL: &str = "nika:complete";

/// Confidence threshold used when `confidence.threshold` is not specified.
const DEFAULT_CONFIDENCE_THRESHOLD: f64 = 0.7;

/// Serde default for `confidence.on_low.max_retries`.
const DEFAULT_MAX_RETRIES: u32 = 2;
46
/// Completion-detection settings: how the model is told to signal that it is done and
/// how that signal is recognised.
///
/// Every field is optional in YAML; an empty document yields `mode: explicit` with no
/// other settings.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct CompletionConfig {
    /// Detection strategy (defaults to [`CompletionMode::Explicit`]).
    #[serde(default)]
    pub mode: CompletionMode,

    /// Signal-tool settings read when generating the `explicit`-mode instruction.
    #[serde(default)]
    pub signal: Option<SignalConfig>,

    /// Output patterns; only consulted in `pattern` mode, which requires at least one.
    #[serde(default)]
    pub patterns: Vec<PatternConfig>,

    /// Optional confidence threshold and low-confidence handling.
    #[serde(default)]
    pub confidence: Option<ConfidenceConfig>,

    /// Tone and language of the generated system instruction.
    #[serde(default)]
    pub instruction: Option<InstructionConfig>,
}
77
78impl CompletionConfig {
79 pub fn generate_system_instruction(&self) -> String {
84 match self.mode {
85 CompletionMode::Explicit => self.generate_explicit_instruction(),
86 CompletionMode::Natural => String::new(), CompletionMode::Pattern => self.generate_pattern_instruction(),
88 }
89 }
90
91 fn generate_explicit_instruction(&self) -> String {
92 let signal = self
93 .signal
94 .as_ref()
95 .map(|s| &s.tool)
96 .map(String::as_str)
97 .unwrap_or(DEFAULT_SIGNAL_TOOL);
98
99 let fields = self.signal.as_ref().map(|s| &s.fields);
100
101 let tone = self
102 .instruction
103 .as_ref()
104 .map(|i| &i.tone)
105 .unwrap_or(&InstructionTone::Concise);
106
107 let lang = self
108 .instruction
109 .as_ref()
110 .and_then(|i| i.lang.as_ref())
111 .map(String::as_str)
112 .unwrap_or("en");
113
114 match (tone, lang) {
115 (InstructionTone::Concise, "fr") => {
116 let mut instruction =
117 format!("Quand tu as terminé, appelle l'outil {} avec:\n", signal);
118 if let Some(f) = fields {
119 for field in &f.required {
120 instruction.push_str(&format!("• {} (requis)\n", field));
121 }
122 for field in &f.optional {
123 instruction.push_str(&format!("• {} (optionnel)\n", field));
124 }
125 } else {
126 instruction.push_str("• result (requis)\n");
127 }
128 if let Some(conf) = &self.confidence {
129 instruction.push_str(&format!(
130 "\nConfidence minimum acceptée: {}\n",
131 conf.threshold
132 ));
133 }
134 instruction
135 }
136 (InstructionTone::Concise, _) => {
137 let mut instruction = format!("When complete, call {} with:\n", signal);
138 if let Some(f) = fields {
139 for field in &f.required {
140 instruction.push_str(&format!("• {} (required)\n", field));
141 }
142 for field in &f.optional {
143 instruction.push_str(&format!("• {} (optional)\n", field));
144 }
145 } else {
146 instruction.push_str("• result (required)\n");
147 }
148 if let Some(conf) = &self.confidence {
149 instruction.push_str(&format!(
150 "\nMinimum accepted confidence: {}\n",
151 conf.threshold
152 ));
153 }
154 instruction
155 }
156 (InstructionTone::Detailed, "fr") => {
157 let mut instruction = format!(
158 "INSTRUCTIONS DE COMPLÉTION:\n\
159 Quand vous avez terminé votre tâche, vous DEVEZ appeler l'outil {} \
160 pour signaler la complétion.\n\n\
161 Paramètres:\n",
162 signal
163 );
164 if let Some(f) = fields {
165 for field in &f.required {
166 instruction
167 .push_str(&format!("• {} (REQUIS): Valeur obligatoire\n", field));
168 }
169 for field in &f.optional {
170 instruction
171 .push_str(&format!("• {} (optionnel): Valeur recommandée\n", field));
172 }
173 }
174 instruction
175 }
176 (InstructionTone::Detailed, _) => {
177 let mut instruction = format!(
178 "COMPLETION INSTRUCTIONS:\n\
179 When you have completed your task, you MUST call the {} tool \
180 to signal completion.\n\n\
181 Parameters:\n",
182 signal
183 );
184 if let Some(f) = fields {
185 for field in &f.required {
186 instruction.push_str(&format!("• {} (REQUIRED): Mandatory value\n", field));
187 }
188 for field in &f.optional {
189 instruction
190 .push_str(&format!("• {} (optional): Recommended value\n", field));
191 }
192 }
193 instruction
194 }
195 }
196 }
197
198 fn generate_pattern_instruction(&self) -> String {
199 if self.patterns.is_empty() {
200 return String::new();
201 }
202
203 let lang = self
204 .instruction
205 .as_ref()
206 .and_then(|i| i.lang.as_ref())
207 .map(String::as_str)
208 .unwrap_or("en");
209
210 let patterns: Vec<&str> = self
211 .patterns
212 .iter()
213 .filter(|p| p.pattern_type != PatternType::Regex)
214 .map(|p| p.value.as_str())
215 .collect();
216
217 if patterns.is_empty() {
218 return String::new();
219 }
220
221 match lang {
222 "fr" => format!(
223 "Quand tu as terminé, termine ta réponse avec: {}\n",
224 patterns.join(" ou ")
225 ),
226 _ => format!(
227 "When complete, end your response with: {}\n",
228 patterns.join(" or ")
229 ),
230 }
231 }
232
233 pub fn check_pattern_match(&self, output: &str) -> bool {
238 if self.mode != CompletionMode::Pattern {
239 return false;
240 }
241
242 for pattern in &self.patterns {
243 if pattern.matches(output) {
244 return true;
245 }
246 }
247 false
248 }
249
250 pub fn effective_mode(&self) -> CompletionMode {
252 self.mode.clone()
253 }
254
255 pub fn validate(&self) -> Result<(), CoreError> {
257 if self.mode == CompletionMode::Pattern && self.patterns.is_empty() {
259 return Err(CoreError::ValidationError {
260 reason: "completion.mode: pattern requires at least one pattern definition".into(),
261 });
262 }
263
264 if let Some(conf) = &self.confidence {
266 if conf.threshold < 0.0 || conf.threshold > 1.0 {
267 return Err(CoreError::ValidationError {
268 reason: format!(
269 "confidence.threshold must be between 0.0 and 1.0, got {}",
270 conf.threshold
271 ),
272 });
273 }
274 }
275
276 for pattern in &self.patterns {
278 if pattern.pattern_type == PatternType::Regex && Regex::new(&pattern.value).is_err() {
279 return Err(CoreError::ValidationError {
280 reason: format!("Invalid regex pattern: {}", pattern.value),
281 });
282 }
283 }
284
285 Ok(())
286 }
287}
288
/// Strategy used to decide that the model has finished its task.
///
/// Deserialized from lowercase names: `explicit`, `natural`, `pattern`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CompletionMode {
    /// The model is instructed to call a signal tool (the default).
    #[default]
    Explicit,

    /// No completion instruction is generated; presumably the model's own end of turn
    /// is taken as completion — TODO confirm with the consumer.
    Natural,

    /// Completion is detected by matching the output against the configured patterns.
    Pattern,
}
307
/// Settings for the tool the model is told to call in `explicit` mode.
#[derive(Debug, Clone, Deserialize)]
pub struct SignalConfig {
    /// Tool name (defaults to `nika:complete`).
    #[serde(default = "default_signal_tool")]
    pub tool: String,

    /// Arguments listed in the generated instruction (defaults to required `result`).
    #[serde(default)]
    pub fields: SignalFields,
}
323
324impl Default for SignalConfig {
325 fn default() -> Self {
326 Self {
327 tool: DEFAULT_SIGNAL_TOOL.to_string(),
328 fields: SignalFields::default(),
329 }
330 }
331}
332
333fn default_signal_tool() -> String {
334 DEFAULT_SIGNAL_TOOL.to_string()
335}
336
/// Names of the arguments the signal tool call is expected to carry.
#[derive(Debug, Clone, Deserialize)]
pub struct SignalFields {
    /// Arguments listed as required (defaults to `["result"]`).
    #[serde(default = "default_required_fields")]
    pub required: Vec<String>,

    /// Arguments listed as optional (defaults to none).
    #[serde(default)]
    pub optional: Vec<String>,
}
348
349impl Default for SignalFields {
350 fn default() -> Self {
351 Self {
352 required: default_required_fields(),
353 optional: Vec::new(),
354 }
355 }
356}
357
/// Serde default for `SignalFields::required`: the single field `result`.
fn default_required_fields() -> Vec<String> {
    vec![String::from("result")]
}
361
/// A single completion pattern checked against the model's output in `pattern` mode.
#[derive(Debug, Clone, Deserialize)]
pub struct PatternConfig {
    /// Literal text or regex source, depending on `pattern_type`.
    pub value: String,

    /// How `value` is compared with the output (YAML key `type`; defaults to `contains`).
    #[serde(default, rename = "type")]
    pub pattern_type: PatternType,

    /// Lazily compiled regex for `PatternType::Regex`; holds `None` when `value` does not
    /// compile. Skipped by serde, so it starts empty and is filled on the first match.
    #[serde(skip)]
    compiled_regex: std::sync::OnceLock<Option<Regex>>,
}
380
381impl PatternConfig {
382 pub fn new(value: impl Into<String>, pattern_type: PatternType) -> Self {
384 Self {
385 value: value.into(),
386 pattern_type,
387 compiled_regex: std::sync::OnceLock::new(),
388 }
389 }
390
391 pub fn matches(&self, output: &str) -> bool {
393 match self.pattern_type {
394 PatternType::Exact => output == self.value,
395 PatternType::Contains => output.contains(&self.value),
396 PatternType::Regex => {
397 let regex = self
398 .compiled_regex
399 .get_or_init(|| Regex::new(&self.value).ok());
400 regex
401 .as_ref()
402 .map(|re| re.is_match(output))
403 .unwrap_or(false)
404 }
405 }
406 }
407}
408
/// How a [`PatternConfig`] value is compared with the output.
///
/// Deserialized from lowercase names: `exact`, `contains`, `regex`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PatternType {
    /// The whole output must equal the value.
    Exact,

    /// The output must contain the value as a substring (the default).
    #[default]
    Contains,

    /// The value is a regular expression that must match somewhere in the output.
    Regex,
}
423
/// Confidence requirements attached to the completion signal.
#[derive(Debug, Clone, Deserialize)]
pub struct ConfidenceConfig {
    /// Minimum accepted confidence (defaults to 0.7). `CompletionConfig::validate` rejects
    /// values outside `[0.0, 1.0]`.
    #[serde(default = "default_confidence_threshold")]
    pub threshold: f64,

    /// Handling for low-confidence results; the consumer that applies it is not in this
    /// file.
    #[serde(default)]
    pub on_low: OnLowConfidenceConfig,

    /// Optional three-band (high / medium / low) routing.
    #[serde(default)]
    pub routing: Option<ConfidenceRouting>,
}
443
444impl Default for ConfidenceConfig {
445 fn default() -> Self {
446 Self {
447 threshold: DEFAULT_CONFIDENCE_THRESHOLD,
448 on_low: OnLowConfidenceConfig::default(),
449 routing: None,
450 }
451 }
452}
453
454fn default_confidence_threshold() -> f64 {
455 DEFAULT_CONFIDENCE_THRESHOLD
456}
457
458#[derive(Debug, Clone, Default, Deserialize)]
460pub struct OnLowConfidenceConfig {
461 #[serde(default)]
463 pub action: LowConfidenceAction,
464
465 #[serde(default = "default_max_retries")]
467 pub max_retries: u32,
468
469 #[serde(default)]
471 pub feedback: Option<String>,
472}
473
474fn default_max_retries() -> u32 {
475 DEFAULT_MAX_RETRIES
476}
477
/// Action configured for low-confidence completions.
///
/// Deserialized from lowercase names: `retry`, `escalate`, `accept`. The variants are
/// interpreted by the consumer of this config, not in this file.
#[derive(Debug, Clone, Default, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LowConfidenceAction {
    /// Retry the task (the default).
    #[default]
    Retry,

    /// Escalate the result.
    Escalate,

    /// Accept the result despite the low confidence.
    Accept,
}
492
/// Routing of results into three confidence bands. All three bands are required when a
/// `routing` section is present (no serde defaults).
#[derive(Debug, Clone, Deserialize)]
pub struct ConfidenceRouting {
    /// Route for the high-confidence band.
    pub high: ConfidenceRoute,

    /// Route for the medium-confidence band.
    pub medium: ConfidenceRoute,

    /// Route for the low-confidence band.
    pub low: ConfidenceRoute,
}
505
/// One band of [`ConfidenceRouting`].
#[derive(Debug, Clone, Deserialize)]
pub struct ConfidenceRoute {
    /// Optional band boundary; presumably the band's lower confidence bound (not
    /// interpreted in this file — TODO confirm with the consumer).
    #[serde(default)]
    pub min: Option<f64>,

    /// Action for results in this band (required).
    pub action: RouteAction,

    /// Optional escalation target (e.g. `human`), presumably used with
    /// [`RouteAction::Escalate`].
    #[serde(default)]
    pub escalate_to: Option<String>,
}
520
/// Action for a confidence-routing band.
///
/// Deserialized from snake_case names: `accept`, `accept_with_flag`, `retry`, `escalate`.
/// The variants are interpreted by the consumer of this config, not in this file.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RouteAction {
    /// Accept the result.
    Accept,

    /// Accept the result but flag it.
    AcceptWithFlag,

    /// Retry the task.
    Retry,

    /// Escalate the result (see `ConfidenceRoute::escalate_to`).
    Escalate,
}
537
/// Presentation options for the generated completion instruction.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct InstructionConfig {
    /// Verbosity of the instruction (defaults to concise).
    #[serde(default)]
    pub tone: InstructionTone,

    /// Language code: `"fr"` selects French text; any other value, or none, selects English.
    #[serde(default)]
    pub lang: Option<String>,
}
553
/// Verbosity of the generated instruction; deserialized from `concise` / `detailed`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum InstructionTone {
    /// Short bullet list of the signal fields (the default).
    #[default]
    Concise,

    /// Longer instruction stating that the signal tool MUST be called.
    Detailed,
}
565
// Unit tests: YAML parsing, instruction generation, validation and pattern matching.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::serde_yaml;

    // --- Mode parsing -------------------------------------------------------

    #[test]
    fn parse_completion_mode_explicit() {
        let yaml = r#"
mode: explicit
"#;
        let config: CompletionConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.mode, CompletionMode::Explicit);
    }

    #[test]
    fn parse_completion_mode_natural() {
        let yaml = r#"
mode: natural
"#;
        let config: CompletionConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.mode, CompletionMode::Natural);
    }

    #[test]
    fn parse_completion_mode_pattern() {
        let yaml = r#"
mode: pattern
patterns:
  - value: "COMPLETE"
    type: exact
  - value: "DONE"
    type: contains
"#;
        let config: CompletionConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.mode, CompletionMode::Pattern);
        assert_eq!(config.patterns.len(), 2);
        assert_eq!(config.patterns[0].value, "COMPLETE");
        assert_eq!(config.patterns[0].pattern_type, PatternType::Exact);
        assert_eq!(config.patterns[1].pattern_type, PatternType::Contains);
    }

    // An empty document must fall back to every default.
    #[test]
    fn parse_completion_mode_default_is_explicit() {
        let yaml = "";
        let config: CompletionConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.mode, CompletionMode::Explicit);
    }

    // --- Signal config parsing ---------------------------------------------

    #[test]
    fn parse_signal_config_full() {
        let yaml = r#"
mode: explicit
signal:
  tool: nika:complete
  fields:
    required:
      - result
    optional:
      - confidence
      - reason
      - sources
"#;
        let config: CompletionConfig = serde_yaml::from_str(yaml).unwrap();
        let signal = config.signal.unwrap();
        assert_eq!(signal.tool, "nika:complete");
        assert_eq!(signal.fields.required, vec!["result"]);
        assert_eq!(
            signal.fields.optional,
            vec!["confidence", "reason", "sources"]
        );
    }

    // `signal: {}` must pick up the serde defaults for tool and required fields.
    #[test]
    fn parse_signal_config_defaults() {
        let yaml = r#"
mode: explicit
signal: {}
"#;
        let config: CompletionConfig = serde_yaml::from_str(yaml).unwrap();
        let signal = config.signal.unwrap();
        assert_eq!(signal.tool, "nika:complete");
        assert_eq!(signal.fields.required, vec!["result"]);
    }

    // --- PatternConfig::matches --------------------------------------------

    #[test]
    fn pattern_matches_exact() {
        let pattern = PatternConfig::new("DONE", PatternType::Exact);
        assert!(pattern.matches("DONE"));
        assert!(!pattern.matches("DONE!"));
        assert!(!pattern.matches("Task is DONE"));
    }

    #[test]
    fn pattern_matches_contains() {
        let pattern = PatternConfig::new("DONE", PatternType::Contains);
        assert!(pattern.matches("DONE"));
        assert!(pattern.matches("Task is DONE!"));
        assert!(!pattern.matches("Task is complete"));
    }

    #[test]
    fn pattern_matches_regex() {
        let pattern = PatternConfig::new(r"\[DONE:\w+\]", PatternType::Regex);
        assert!(pattern.matches("[DONE:SUCCESS]"));
        assert!(pattern.matches("Result: [DONE:COMPLETE]"));
        assert!(!pattern.matches("[DONE:]"));
        assert!(!pattern.matches("DONE"));
    }

    // --- Confidence config parsing -----------------------------------------

    #[test]
    fn parse_confidence_config() {
        let yaml = r#"
mode: explicit
confidence:
  threshold: 0.8
  on_low:
    action: retry
    max_retries: 3
    feedback: "Please verify your sources"
"#;
        let config: CompletionConfig = serde_yaml::from_str(yaml).unwrap();
        let conf = config.confidence.unwrap();
        assert_eq!(conf.threshold, 0.8);
        assert_eq!(conf.on_low.action, LowConfidenceAction::Retry);
        assert_eq!(conf.on_low.max_retries, 3);
        assert_eq!(
            conf.on_low.feedback,
            Some("Please verify your sources".to_string())
        );
    }

    #[test]
    fn parse_confidence_routing() {
        let yaml = r#"
confidence:
  threshold: 0.7
  routing:
    high:
      min: 0.85
      action: accept
    medium:
      min: 0.7
      action: accept_with_flag
    low:
      action: escalate
      escalate_to: human
"#;
        let config: CompletionConfig = serde_yaml::from_str(yaml).unwrap();
        let routing = config.confidence.unwrap().routing.unwrap();
        assert_eq!(routing.high.min, Some(0.85));
        assert_eq!(routing.high.action, RouteAction::Accept);
        assert_eq!(routing.medium.action, RouteAction::AcceptWithFlag);
        assert_eq!(routing.low.action, RouteAction::Escalate);
        assert_eq!(routing.low.escalate_to, Some("human".to_string()));
    }

    // --- Instruction config parsing ----------------------------------------

    #[test]
    fn parse_instruction_config() {
        let yaml = r#"
mode: explicit
instruction:
  tone: detailed
  lang: fr
"#;
        let config: CompletionConfig = serde_yaml::from_str(yaml).unwrap();
        let instr = config.instruction.unwrap();
        assert_eq!(instr.tone, InstructionTone::Detailed);
        assert_eq!(instr.lang, Some("fr".to_string()));
    }

    // --- System instruction generation -------------------------------------

    #[test]
    fn generate_instruction_explicit_concise_en() {
        let config = CompletionConfig {
            mode: CompletionMode::Explicit,
            signal: Some(SignalConfig {
                tool: "nika:complete".to_string(),
                fields: SignalFields {
                    required: vec!["result".to_string()],
                    optional: vec!["confidence".to_string()],
                },
            }),
            instruction: Some(InstructionConfig {
                tone: InstructionTone::Concise,
                lang: Some("en".to_string()),
            }),
            ..Default::default()
        };

        let instruction = config.generate_system_instruction();
        assert!(instruction.contains("nika:complete"));
        assert!(instruction.contains("result"));
        assert!(instruction.contains("required"));
        assert!(instruction.contains("confidence"));
        assert!(instruction.contains("optional"));
    }

    #[test]
    fn generate_instruction_explicit_concise_fr() {
        let config = CompletionConfig {
            mode: CompletionMode::Explicit,
            signal: Some(SignalConfig::default()),
            instruction: Some(InstructionConfig {
                tone: InstructionTone::Concise,
                lang: Some("fr".to_string()),
            }),
            ..Default::default()
        };

        let instruction = config.generate_system_instruction();
        assert!(instruction.contains("Quand tu as terminé"));
        assert!(instruction.contains("nika:complete"));
        assert!(instruction.contains("requis"));
    }

    #[test]
    fn generate_instruction_natural_is_empty() {
        let config = CompletionConfig {
            mode: CompletionMode::Natural,
            ..Default::default()
        };

        let instruction = config.generate_system_instruction();
        assert!(instruction.is_empty());
    }

    #[test]
    fn generate_instruction_pattern() {
        let config = CompletionConfig {
            mode: CompletionMode::Pattern,
            patterns: vec![
                PatternConfig::new("COMPLETE", PatternType::Contains),
                PatternConfig::new("DONE", PatternType::Contains),
            ],
            ..Default::default()
        };

        let instruction = config.generate_system_instruction();
        assert!(instruction.contains("COMPLETE"));
        assert!(instruction.contains("DONE"));
    }

    // --- Validation ----------------------------------------------------------

    #[test]
    fn validate_confidence_threshold_valid() {
        let config = CompletionConfig {
            confidence: Some(ConfidenceConfig {
                threshold: 0.7,
                ..Default::default()
            }),
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn validate_confidence_threshold_too_high() {
        let config = CompletionConfig {
            confidence: Some(ConfidenceConfig {
                threshold: 1.5,
                ..Default::default()
            }),
            ..Default::default()
        };
        let err = config.validate().unwrap_err();
        assert!(err.to_string().contains("confidence.threshold"));
    }

    #[test]
    fn validate_confidence_threshold_negative() {
        let config = CompletionConfig {
            confidence: Some(ConfidenceConfig {
                threshold: -0.1,
                ..Default::default()
            }),
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn validate_invalid_regex() {
        let config = CompletionConfig {
            mode: CompletionMode::Pattern,
            patterns: vec![PatternConfig::new("[invalid(", PatternType::Regex)],
            ..Default::default()
        };
        let err = config.validate().unwrap_err();
        assert!(err.to_string().contains("Invalid regex"));
    }

    // --- CompletionConfig::check_pattern_match -------------------------------

    // Patterns are ignored outside `pattern` mode, even when they would match.
    #[test]
    fn check_pattern_match_explicit_mode_always_false() {
        let config = CompletionConfig {
            mode: CompletionMode::Explicit,
            patterns: vec![PatternConfig::new("DONE", PatternType::Contains)],
            ..Default::default()
        };
        assert!(!config.check_pattern_match("DONE"));
    }

    #[test]
    fn check_pattern_match_pattern_mode() {
        let config = CompletionConfig {
            mode: CompletionMode::Pattern,
            patterns: vec![
                PatternConfig::new("DONE", PatternType::Contains),
                PatternConfig::new(r"\[COMPLETE\]", PatternType::Regex),
            ],
            ..Default::default()
        };
        assert!(config.check_pattern_match("Task is DONE"));
        assert!(config.check_pattern_match("[COMPLETE]"));
        assert!(!config.check_pattern_match("Still working"));
    }

    // --- End-to-end parse + validate ----------------------------------------

    #[test]
    fn parse_full_completion_config() {
        let yaml = r#"
mode: explicit
signal:
  tool: nika:complete
  fields:
    required: [result]
    optional: [confidence, reason, sources]
confidence:
  threshold: 0.7
  on_low:
    action: retry
    max_retries: 2
    feedback: "Confidence too low"
instruction:
  tone: concise
  lang: en
"#;
        let config: CompletionConfig = serde_yaml::from_str(yaml).unwrap();

        assert_eq!(config.mode, CompletionMode::Explicit);

        let signal = config.signal.clone().unwrap();
        assert_eq!(signal.tool, "nika:complete");
        assert_eq!(signal.fields.required, vec!["result"]);
        assert_eq!(signal.fields.optional.len(), 3);

        let conf = config.confidence.clone().unwrap();
        assert_eq!(conf.threshold, 0.7);
        assert_eq!(conf.on_low.action, LowConfidenceAction::Retry);

        let instr = config.instruction.clone().unwrap();
        assert_eq!(instr.tone, InstructionTone::Concise);

        assert!(config.validate().is_ok());
    }
}