1use crate::error::{Result, TextError};
16use std::collections::HashMap;
17
/// Coarse communicative purpose of a dialog turn.
///
/// Each act has a canned system response in `response_template`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum DialogAct {
    /// Opening salutation.
    Greet,
    /// Ask the user for a piece of information (a slot value).
    Request,
    /// Present information or a result to the user.
    Inform,
    /// Ask the user to confirm the collected information.
    Confirm,
    /// Signal that the user's input was declined / did not match.
    Reject,
    /// Close the conversation.
    Goodbye,
    /// The act could not be determined.
    Unknown,
}
40
41impl std::fmt::Display for DialogAct {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 let label = match self {
44 Self::Greet => "GREET",
45 Self::Request => "REQUEST",
46 Self::Inform => "INFORM",
47 Self::Confirm => "CONFIRM",
48 Self::Reject => "REJECT",
49 Self::Goodbye => "GOODBYE",
50 Self::Unknown => "UNKNOWN",
51 };
52 write!(f, "{label}")
53 }
54}
55
56#[derive(Debug, Clone, Default)]
76pub struct DialogState {
77 pub context: Vec<String>,
79 pub entities: HashMap<String, String>,
81 pub slots: HashMap<String, String>,
83 pub current_act: Option<DialogAct>,
85 pub turn_count: usize,
87}
88
89impl DialogState {
90 pub fn new() -> Self {
92 Self::default()
93 }
94
95 pub fn add_utterance(&mut self, utterance: &str) {
97 self.context.push(utterance.to_string());
98 self.turn_count += 1;
99 }
100
101 pub fn set_slot(&mut self, slot: &str, value: &str) {
103 self.slots.insert(slot.to_string(), value.to_string());
104 }
105
106 pub fn get_slot(&self, slot: &str) -> Option<&str> {
108 self.slots.get(slot).map(|s| s.as_str())
109 }
110
111 pub fn set_entity(&mut self, entity_type: &str, value: &str) {
113 self.entities
114 .insert(entity_type.to_string(), value.to_string());
115 }
116
117 pub fn get_entity(&self, entity_type: &str) -> Option<&str> {
119 self.entities.get(entity_type).map(|s| s.as_str())
120 }
121
122 pub fn reset(&mut self) {
124 *self = Self::default();
125 }
126
127 pub fn last_utterance(&self) -> Option<&str> {
129 self.context.last().map(|s| s.as_str())
130 }
131
132 pub fn slots_filled(&self, required: &[&str]) -> bool {
134 required.iter().all(|s| self.slots.contains_key(*s))
135 }
136}
137
/// Keyword-based intent model consumed by `classify_intent`.
///
/// `intents[i]` is the name of the intent whose keyword patterns are `patterns[i]`.
#[derive(Clone, Debug, Default)]
pub struct IntentClassifier {
    /// Registered intent names, in registration order.
    pub intents: Vec<String>,
    /// Keyword patterns per intent (lowercased by `add_intent`), parallel to `intents`.
    pub patterns: Vec<Vec<String>>,
}
170
171impl IntentClassifier {
172 pub fn new() -> Self {
174 Self::default()
175 }
176
177 pub fn add_intent(&mut self, name: &str, patterns: Vec<&str>) {
182 self.intents.push(name.to_string());
183 self.patterns
184 .push(patterns.into_iter().map(|p| p.to_lowercase()).collect());
185 }
186
187 pub fn len(&self) -> usize {
189 self.intents.len()
190 }
191
192 pub fn is_empty(&self) -> bool {
194 self.intents.is_empty()
195 }
196}
197
198pub fn classify_intent(utterance: &str, classifier: &IntentClassifier) -> (String, f64) {
205 if classifier.intents.is_empty() {
206 return ("unknown".to_string(), 0.0);
207 }
208
209 let utt_lower = utterance.to_lowercase();
210 let utt_tokens: Vec<&str> = utt_lower
211 .split(|c: char| !c.is_alphanumeric())
212 .filter(|t| !t.is_empty())
213 .collect();
214
215 let mut best_intent = "unknown".to_string();
216 let mut best_score = 0.0_f64;
217 let mut best_matches = 0usize;
218
219 for (intent_idx, patterns) in classifier.patterns.iter().enumerate() {
220 if patterns.is_empty() {
221 continue;
222 }
223 let total = patterns.len();
224 let matches = patterns
225 .iter()
226 .filter(|pat| {
227 utt_tokens.iter().any(|tok| {
229 *tok == pat.as_str()
230 || tok.starts_with(pat.as_str())
231 || utt_lower.contains(pat.as_str())
232 })
233 })
234 .count();
235
236 let score = matches as f64 / total as f64;
237 if matches > best_matches || (matches == best_matches && score > best_score) {
238 best_matches = matches;
239 best_score = score;
240 best_intent = classifier.intents[intent_idx].clone();
241 }
242 }
243
244 if best_matches == 0 {
245 ("unknown".to_string(), 0.0)
246 } else {
247 (best_intent, best_score)
248 }
249}
250
/// Category assigned to an `ExtractedEntity`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum EntityKind {
    /// A calendar date (numeric like `12/25/2024` or month-name based).
    Date,
    /// A numeric literal, optionally with a decimal part.
    Number,
    /// A person-like name (heuristically: a run of capitalised words).
    Name,
    /// A place name.
    Location,
    /// A caller-defined category carrying its own label.
    Custom(String),
}
269
270impl std::fmt::Display for EntityKind {
271 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272 match self {
273 Self::Date => write!(f, "DATE"),
274 Self::Number => write!(f, "NUMBER"),
275 Self::Name => write!(f, "NAME"),
276 Self::Location => write!(f, "LOCATION"),
277 Self::Custom(s) => write!(f, "CUSTOM({})", s),
278 }
279 }
280}
281
282#[derive(Debug, Clone)]
284pub struct ExtractedEntity {
285 pub text: String,
287 pub kind: EntityKind,
289 pub start: usize,
291 pub end: usize,
293}
294
295#[derive(Debug, Default)]
300pub struct EntityExtractor {
301 gazetteer: Vec<(String, EntityKind)>,
303}
304
305impl EntityExtractor {
306 pub fn new() -> Self {
308 Self::default()
309 }
310
311 pub fn add_gazetteer_entry(&mut self, term: &str, kind: EntityKind) {
313 self.gazetteer.push((term.to_lowercase(), kind));
314 }
315
316 pub fn extract(&self, utterance: &str) -> Vec<ExtractedEntity> {
322 let mut entities: Vec<ExtractedEntity> = Vec::new();
323
324 self.extract_gazetteer(utterance, &mut entities);
325 self.extract_dates(utterance, &mut entities);
326 self.extract_numbers(utterance, &mut entities);
327 self.extract_names(utterance, &mut entities);
328 self.extract_locations(utterance, &mut entities);
329
330 entities.sort_by_key(|e| e.start);
332 entities
333 }
334
335 fn extract_gazetteer(&self, text: &str, out: &mut Vec<ExtractedEntity>) {
337 let text_lower = text.to_lowercase();
338 for (term, kind) in &self.gazetteer {
339 let mut search_start = 0usize;
340 while let Some(offset) = text_lower[search_start..].find(term.as_str()) {
341 let abs_start = search_start + offset;
342 let abs_end = abs_start + term.len();
343 out.push(ExtractedEntity {
344 text: text[abs_start..abs_end].to_string(),
345 kind: kind.clone(),
346 start: abs_start,
347 end: abs_end,
348 });
349 search_start = abs_end;
350 }
351 }
352 }
353
354 fn extract_dates(&self, text: &str, out: &mut Vec<ExtractedEntity>) {
360 let mut i = 0;
362 let bytes = text.as_bytes();
363 let len = bytes.len();
364
365 while i < len {
366 if !bytes[i].is_ascii_digit() {
368 i += 1;
369 continue;
370 }
371 let start = i;
373 while i < len && bytes[i].is_ascii_digit() {
374 i += 1;
375 }
376 if i < len && bytes[i] == b'/' {
378 let slash1 = i;
379 i += 1;
380 let seg2_start = i;
381 while i < len && bytes[i].is_ascii_digit() {
382 i += 1;
383 }
384 if i > seg2_start {
385 let end = if i < len && bytes[i] == b'/' {
386 i += 1; let seg3_start = i;
388 while i < len && bytes[i].is_ascii_digit() {
389 i += 1;
390 }
391 if i > seg3_start {
392 i
393 } else {
394 i = slash1 + 1 + (i - slash1 - 1);
396 slash1
397 }
398 } else {
399 i
400 };
401 let matched = &text[start..end];
402 if matched.contains('/') {
403 out.push(ExtractedEntity {
404 text: matched.to_string(),
405 kind: EntityKind::Date,
406 start,
407 end,
408 });
409 }
410 }
411 continue;
412 }
413 }
414
415 let months = [
417 "january",
418 "february",
419 "march",
420 "april",
421 "may",
422 "june",
423 "july",
424 "august",
425 "september",
426 "october",
427 "november",
428 "december",
429 "jan",
430 "feb",
431 "mar",
432 "apr",
433 "jun",
434 "jul",
435 "aug",
436 "sep",
437 "oct",
438 "nov",
439 "dec",
440 ];
441 let text_lower = text.to_lowercase();
442 for month in &months {
443 let mut search_pos = 0usize;
444 while let Some(offset) = text_lower[search_pos..].find(month) {
445 let abs_start = search_pos + offset;
446 let abs_end = abs_start + month.len();
447
448 let before_ok =
450 abs_start == 0 || !text.as_bytes()[abs_start - 1].is_ascii_alphanumeric();
451 let after_ok =
452 abs_end >= text.len() || !text.as_bytes()[abs_end].is_ascii_alphanumeric();
453
454 if before_ok && after_ok {
455 let mut end = abs_end;
457 let rest = &text[abs_end..];
458 let after_space: &str = rest.trim_start_matches(' ');
459 let day_len: usize = after_space
460 .chars()
461 .take_while(|c| c.is_ascii_digit())
462 .map(|c| c.len_utf8())
463 .sum();
464 if day_len > 0 {
465 let spaces = rest.len() - after_space.len();
466 end += spaces + day_len;
467 }
468
469 let rest2 = &text[end..];
471 let after_space2: &str = rest2.trim_start_matches(' ');
472 let year_candidate: String = after_space2
473 .chars()
474 .take_while(|c| c.is_ascii_digit())
475 .collect();
476 if year_candidate.len() == 4 {
477 let spaces2 = rest2.len() - after_space2.len();
478 end += spaces2 + 4;
479 }
480
481 out.push(ExtractedEntity {
482 text: text[abs_start..end].to_string(),
483 kind: EntityKind::Date,
484 start: abs_start,
485 end,
486 });
487 }
488
489 search_pos = abs_end;
490 }
491 }
492 }
493
494 fn extract_numbers(&self, text: &str, out: &mut Vec<ExtractedEntity>) {
496 let mut i = 0;
497 let bytes = text.as_bytes();
498 let len = bytes.len();
499
500 while i < len {
501 if !bytes[i].is_ascii_digit() {
502 i += 1;
503 continue;
504 }
505 let start = i;
506 while i < len && bytes[i].is_ascii_digit() {
507 i += 1;
508 }
509 if i < len && (bytes[i] == b'.' || bytes[i] == b',') {
511 let sep = i;
512 i += 1;
513 let frac_start = i;
514 while i < len && bytes[i].is_ascii_digit() {
515 i += 1;
516 }
517 if i == frac_start {
518 i = sep;
520 }
521 }
522 let end = i;
525 let candidate = &text[start..end];
526 let already_date = out
527 .iter()
528 .any(|e| e.kind == EntityKind::Date && e.start <= start && e.end >= end);
529 if !already_date {
530 out.push(ExtractedEntity {
531 text: candidate.to_string(),
532 kind: EntityKind::Number,
533 start,
534 end,
535 });
536 }
537 }
538 }
539
540 fn extract_names(&self, text: &str, out: &mut Vec<ExtractedEntity>) {
544 let mut word_spans: Vec<(usize, usize, &str)> = Vec::new();
546 let mut pos = 0usize;
547 for word in text.split_ascii_whitespace() {
548 if let Some(offset) = text[pos..].find(word) {
550 let start = pos + offset;
551 let end = start + word.len();
552 word_spans.push((start, end, word));
553 pos = end;
554 }
555 }
556
557 let mut i = 0usize;
559 while i < word_spans.len() {
560 let (start, _, word) = word_spans[i];
561 let first_alpha = word.chars().find(|c| c.is_alphabetic());
563 let is_cap = first_alpha.map(|c| c.is_uppercase()).unwrap_or(false);
564
565 if !is_cap {
566 i += 1;
567 continue;
568 }
569
570 let run_start = start;
572 let mut j = i;
573 while j < word_spans.len() {
574 let (_, _, w) = word_spans[j];
575 let fc = w.chars().find(|c| c.is_alphabetic());
576 if fc.map(|c| c.is_uppercase()).unwrap_or(false) {
577 j += 1;
578 } else {
579 break;
580 }
581 }
582
583 if j - i >= 2 {
585 let (_, run_end, _) = word_spans[j - 1];
586 let name_text = &text[run_start..run_end];
587 out.push(ExtractedEntity {
588 text: name_text.to_string(),
589 kind: EntityKind::Name,
590 start: run_start,
591 end: run_end,
592 });
593 i = j;
594 } else {
595 i += 1;
596 }
597 }
598 }
599
600 fn extract_locations(&self, text: &str, out: &mut Vec<ExtractedEntity>) {
603 let location_triggers = ["in ", "to ", "from ", "at ", "near ", "between "];
604 for trigger in &location_triggers {
605 let text_lower = text.to_lowercase();
606 let mut search_pos = 0usize;
607 while let Some(offset) = text_lower[search_pos..].find(trigger) {
608 let abs_trigger_start = search_pos + offset;
609 let candidate_start = abs_trigger_start + trigger.len();
610 if candidate_start >= text.len() {
611 break;
612 }
613
614 let rest = &text[candidate_start..];
616 let mut loc_end = candidate_start;
617 for word in rest.split_ascii_whitespace() {
618 let first_char = word
619 .trim_matches(|c: char| !c.is_alphabetic())
620 .chars()
621 .next();
622 if first_char.map(|c| c.is_uppercase()).unwrap_or(false) {
623 loc_end += word.len() + 1; } else {
625 break;
626 }
627 }
628 let loc_end = loc_end.min(text.len());
630 if loc_end > candidate_start {
631 let loc_text = text[candidate_start..loc_end].trim().to_string();
632 if !loc_text.is_empty() {
633 let actual_end = candidate_start + loc_text.len();
634 out.push(ExtractedEntity {
635 text: loc_text,
636 kind: EntityKind::Location,
637 start: candidate_start,
638 end: actual_end,
639 });
640 }
641 }
642
643 search_pos = candidate_start;
644 }
645 }
646 }
647}
648
/// Stateless helper that fills `{slot}` placeholders of a template from an utterance.
#[derive(Clone, Debug, Default)]
pub struct SlotFiller;
672
673impl SlotFiller {
674 pub fn new() -> Self {
676 Self
677 }
678
679 pub fn fill(&self, utterance: &str, template: &str) -> Result<HashMap<String, String>> {
686 let parts = parse_template(template)?;
688 let mut slots: HashMap<String, String> = HashMap::new();
689
690 let utt_lower = utterance.to_lowercase();
692 let mut search_pos = 0usize;
693
694 let n = parts.len();
695 let mut pi = 0usize;
696
697 while pi < n {
698 match &parts[pi] {
699 TemplatePart::Literal(lit) => {
700 let lit_lower = lit.to_lowercase();
701 if lit_lower.is_empty() {
702 pi += 1;
703 continue;
704 }
705 if let Some(offset) = utt_lower[search_pos..].find(lit_lower.as_str()) {
706 search_pos += offset + lit.len();
707 pi += 1;
708 } else {
709 break;
711 }
712 }
713 TemplatePart::Slot(slot_name) => {
714 let next_literal: Option<&str> = parts[pi + 1..].iter().find_map(|p| {
716 if let TemplatePart::Literal(s) = p {
717 if !s.is_empty() {
718 Some(s.as_str())
719 } else {
720 None
721 }
722 } else {
723 None
724 }
725 });
726
727 let value_end = if let Some(next_lit) = next_literal {
728 let next_lit_lower = next_lit.to_lowercase();
729 utt_lower[search_pos..]
730 .find(next_lit_lower.as_str())
731 .map(|o| search_pos + o)
732 .unwrap_or(utt_lower.len())
733 } else {
734 utt_lower.len()
735 };
736
737 let raw_value = utterance[search_pos..value_end].trim().to_string();
738 if !raw_value.is_empty() {
739 slots.insert(slot_name.clone(), raw_value);
740 }
741 search_pos = value_end;
742 pi += 1;
743 }
744 }
745 }
746
747 Ok(slots)
748 }
749}
750
/// One piece of a template parsed by `parse_template`.
#[derive(Debug)]
enum TemplatePart {
    /// Text expected to appear in the utterance.
    Literal(String),
    /// Name of a `{name}` placeholder.
    Slot(String),
}
757
758fn parse_template(template: &str) -> Result<Vec<TemplatePart>> {
760 let mut parts: Vec<TemplatePart> = Vec::new();
761 let mut chars = template.char_indices().peekable();
762 let mut buf = String::new();
763
764 while let Some((_, ch)) = chars.next() {
765 if ch == '{' {
766 if !buf.is_empty() {
768 parts.push(TemplatePart::Literal(std::mem::take(&mut buf)));
769 }
770 let mut slot_name = String::new();
772 let mut closed = false;
773 for (_, sc) in chars.by_ref() {
774 if sc == '}' {
775 closed = true;
776 break;
777 }
778 slot_name.push(sc);
779 }
780 if !closed {
781 return Err(TextError::InvalidInput(
782 "Unclosed '{' in slot template".to_string(),
783 ));
784 }
785 if slot_name.is_empty() {
786 return Err(TextError::InvalidInput(
787 "Empty slot name '{}' in template".to_string(),
788 ));
789 }
790 parts.push(TemplatePart::Slot(slot_name));
791 } else {
792 buf.push(ch);
793 }
794 }
795
796 if !buf.is_empty() {
797 parts.push(TemplatePart::Literal(buf));
798 }
799
800 Ok(parts)
801}
802
/// States of the `DialogPolicy` state machine.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum PolicyState {
    /// Nothing has been said yet.
    Initial,
    /// The greeting has been issued.
    Greeted,
    /// Missing required slots are being requested.
    SlotCollection,
    /// All slots are filled; waiting for the user to confirm them.
    PendingConfirmation,
    /// The user confirmed; the result has been presented.
    Confirmed,
    /// The conversation is over.
    Ended,
}
823
824impl std::fmt::Display for PolicyState {
825 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
826 let s = match self {
827 Self::Initial => "INITIAL",
828 Self::Greeted => "GREETED",
829 Self::SlotCollection => "SLOT_COLLECTION",
830 Self::PendingConfirmation => "PENDING_CONFIRMATION",
831 Self::Confirmed => "CONFIRMED",
832 Self::Ended => "ENDED",
833 };
834 write!(f, "{s}")
835 }
836}
837
838#[derive(Debug, Clone)]
840pub struct PolicyAction {
841 pub act: DialogAct,
843 pub request_slot: Option<String>,
845 pub confirm_slots: Vec<String>,
847}
848
849pub struct DialogPolicy {
866 required_slots: Vec<String>,
868 policy_state: PolicyState,
870}
871
872impl DialogPolicy {
873 pub fn new(required_slots: Vec<String>) -> Self {
875 Self {
876 required_slots,
877 policy_state: PolicyState::Initial,
878 }
879 }
880
881 pub fn state(&self) -> &PolicyState {
883 &self.policy_state
884 }
885
886 pub fn next_action(&mut self, dialog_state: &DialogState) -> PolicyAction {
890 match self.policy_state {
891 PolicyState::Initial => {
892 self.policy_state = PolicyState::Greeted;
893 PolicyAction {
894 act: DialogAct::Greet,
895 request_slot: None,
896 confirm_slots: Vec::new(),
897 }
898 }
899 PolicyState::Greeted | PolicyState::SlotCollection => {
900 let missing = self
902 .required_slots
903 .iter()
904 .find(|s| !dialog_state.slots.contains_key(*s))
905 .cloned();
906
907 if let Some(slot) = missing {
908 self.policy_state = PolicyState::SlotCollection;
909 PolicyAction {
910 act: DialogAct::Request,
911 request_slot: Some(slot),
912 confirm_slots: Vec::new(),
913 }
914 } else {
915 self.policy_state = PolicyState::PendingConfirmation;
917 PolicyAction {
918 act: DialogAct::Confirm,
919 request_slot: None,
920 confirm_slots: self.required_slots.clone(),
921 }
922 }
923 }
924 PolicyState::PendingConfirmation => {
925 let confirmed = dialog_state
927 .last_utterance()
928 .map(|u| {
929 let ul = u.to_lowercase();
930 ul.contains("yes")
931 || ul.contains("correct")
932 || ul.contains("right")
933 || ul.contains("confirm")
934 })
935 .unwrap_or(false);
936
937 if confirmed {
938 self.policy_state = PolicyState::Confirmed;
939 PolicyAction {
940 act: DialogAct::Inform,
941 request_slot: None,
942 confirm_slots: Vec::new(),
943 }
944 } else {
945 self.policy_state = PolicyState::SlotCollection;
947 PolicyAction {
948 act: DialogAct::Reject,
949 request_slot: None,
950 confirm_slots: Vec::new(),
951 }
952 }
953 }
954 PolicyState::Confirmed => {
955 self.policy_state = PolicyState::Ended;
956 PolicyAction {
957 act: DialogAct::Goodbye,
958 request_slot: None,
959 confirm_slots: Vec::new(),
960 }
961 }
962 PolicyState::Ended => PolicyAction {
963 act: DialogAct::Goodbye,
964 request_slot: None,
965 confirm_slots: Vec::new(),
966 },
967 }
968 }
969
970 pub fn reset(&mut self) {
972 self.policy_state = PolicyState::Initial;
973 }
974}
975
976pub fn response_template(act: DialogAct, slots: &HashMap<String, String>) -> String {
998 let template = match act {
999 DialogAct::Greet => "Hello! How can I help you today?".to_string(),
1000 DialogAct::Request => {
1001 let slot_hint = slots
1004 .keys()
1005 .next()
1006 .map(|s| s.as_str())
1007 .unwrap_or("information");
1008 format!("Could you please provide the {slot_hint}?")
1009 }
1010 DialogAct::Inform => {
1011 if slots.is_empty() {
1012 "I have processed your request successfully.".to_string()
1013 } else {
1014 let details: Vec<String> = slots.iter().map(|(k, v)| format!("{k}: {v}")).collect();
1015 format!("Here is the information: {}.", details.join(", "))
1016 }
1017 }
1018 DialogAct::Confirm => {
1019 if slots.is_empty() {
1020 "Can you please confirm your request?".to_string()
1021 } else {
1022 let details: Vec<String> =
1023 slots.iter().map(|(k, v)| format!("{k} = {v}")).collect();
1024 format!(
1025 "Just to confirm, you would like to proceed with {}. Is that correct?",
1026 details.join(", ")
1027 )
1028 }
1029 }
1030 DialogAct::Reject => {
1031 "I'm sorry, that does not match what we have. Let's try again.".to_string()
1032 }
1033 DialogAct::Goodbye => "Thank you for using our service. Goodbye!".to_string(),
1034 DialogAct::Unknown => {
1035 "I'm sorry, I didn't understand that. Could you rephrase?".to_string()
1036 }
1037 };
1038
1039 let mut result = template;
1041 for (key, value) in slots {
1042 let placeholder = format!("{{{key}}}");
1043 result = result.replace(&placeholder, value);
1044 }
1045 result
1046}
1047
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------- DialogState

    #[test]
    fn test_dialog_state_slots() {
        let mut state = DialogState::new();
        state.set_slot("destination", "Paris");
        assert_eq!(state.get_slot("destination"), Some("Paris"));
        assert_eq!(state.get_slot("origin"), None);
    }

    #[test]
    fn test_dialog_state_entities() {
        let mut state = DialogState::new();
        state.set_entity("DATE", "January 15");
        assert_eq!(state.get_entity("DATE"), Some("January 15"));
    }

    #[test]
    fn test_dialog_state_utterances() {
        let mut state = DialogState::new();
        assert!(state.last_utterance().is_none());
        state.add_utterance("Hello");
        assert_eq!(state.last_utterance(), Some("Hello"));
        state.add_utterance("Goodbye");
        assert_eq!(state.last_utterance(), Some("Goodbye"));
        assert_eq!(state.turn_count, 2);
    }

    #[test]
    fn test_dialog_state_slots_filled() {
        let mut state = DialogState::new();
        state.set_slot("a", "1");
        state.set_slot("b", "2");
        assert!(state.slots_filled(&["a", "b"]));
        assert!(!state.slots_filled(&["a", "b", "c"]));
    }

    #[test]
    fn test_dialog_state_reset() {
        let mut state = DialogState::new();
        state.set_slot("x", "y");
        state.add_utterance("hello");
        state.reset();
        assert!(state.slots.is_empty());
        assert!(state.context.is_empty());
        assert_eq!(state.turn_count, 0);
    }

    // ---------------------------------------------------------- Intent classifier

    #[test]
    fn test_classify_intent_basic() {
        let mut clf = IntentClassifier::new();
        clf.add_intent("book_flight", vec!["book", "flight", "fly", "ticket"]);
        clf.add_intent("cancel", vec!["cancel", "undo", "delete"]);

        let (intent, conf) = classify_intent("I want to book a flight", &clf);
        assert_eq!(intent, "book_flight");
        assert!(conf > 0.0);
    }

    #[test]
    fn test_classify_intent_unknown() {
        let clf = IntentClassifier::new();
        let (intent, conf) = classify_intent("hello", &clf);
        assert_eq!(intent, "unknown");
        assert_eq!(conf, 0.0);
    }

    #[test]
    fn test_classify_intent_no_match() {
        let mut clf = IntentClassifier::new();
        clf.add_intent("book_flight", vec!["book", "flight"]);
        let (intent, conf) = classify_intent("tell me the weather", &clf);
        assert_eq!(intent, "unknown");
        assert_eq!(conf, 0.0);
    }

    #[test]
    fn test_classify_intent_case_insensitive() {
        let mut clf = IntentClassifier::new();
        clf.add_intent("greet", vec!["hello", "hi", "hey"]);
        let (intent, _conf) = classify_intent("HELLO there", &clf);
        assert_eq!(intent, "greet");
    }

    // ---------------------------------------------------------- Entity extraction

    #[test]
    fn test_extract_numbers() {
        let ext = EntityExtractor::new();
        let entities = ext.extract("I need 3 tickets and 12.5 kg baggage");
        let numbers: Vec<&str> = entities
            .iter()
            .filter(|e| e.kind == EntityKind::Number)
            .map(|e| e.text.as_str())
            .collect();
        assert!(numbers.contains(&"3"), "Missing '3': {:?}", numbers);
        assert!(numbers.contains(&"12.5"), "Missing '12.5': {:?}", numbers);
    }

    #[test]
    fn test_extract_date_month_name() {
        let ext = EntityExtractor::new();
        let entities = ext.extract("The flight is on January 15");
        let dates: Vec<&str> = entities
            .iter()
            .filter(|e| e.kind == EntityKind::Date)
            .map(|e| e.text.as_str())
            .collect();
        assert!(!dates.is_empty(), "Expected at least one date entity");
        assert!(
            dates.iter().any(|d| d.contains("January")),
            "Expected 'January' in dates: {:?}",
            dates
        );
    }

    #[test]
    fn test_extract_gazetteer() {
        let mut ext = EntityExtractor::new();
        ext.add_gazetteer_entry("london", EntityKind::Location);
        let entities = ext.extract("I want to travel to London");
        let locs: Vec<&str> = entities
            .iter()
            .filter(|e| e.kind == EntityKind::Location)
            .map(|e| e.text.as_str())
            .collect();
        assert!(!locs.is_empty(), "Expected location entity");
        assert!(locs.contains(&"London"), "Expected 'London' in {:?}", locs);
    }

    #[test]
    fn test_entity_kind_display() {
        assert_eq!(EntityKind::Date.to_string(), "DATE");
        assert_eq!(EntityKind::Location.to_string(), "LOCATION");
        assert_eq!(EntityKind::Custom("city".to_string()).to_string(), "CUSTOM(city)");
    }

    // ---------------------------------------------------------------- Slot filler

    #[test]
    fn test_slot_filler_basic() {
        let sf = SlotFiller::new();
        let slots = sf
            .fill(
                "book a flight from London to Paris",
                "flight from {origin} to {destination}",
            )
            .expect("fill should succeed");
        assert_eq!(slots.get("origin").map(|s| s.as_str()), Some("London"));
        assert_eq!(slots.get("destination").map(|s| s.as_str()), Some("Paris"));
    }

    #[test]
    fn test_slot_filler_single_slot() {
        let sf = SlotFiller::new();
        let slots = sf
            .fill("my name is Alice", "my name is {name}")
            .expect("fill should succeed");
        assert_eq!(slots.get("name").map(|s| s.as_str()), Some("Alice"));
    }

    #[test]
    fn test_slot_filler_unclosed_brace_error() {
        let sf = SlotFiller::new();
        let result = sf.fill("hello world", "hello {world");
        assert!(result.is_err(), "Expected error for unclosed brace");
    }

    #[test]
    fn test_slot_filler_empty_slot_name_error() {
        let sf = SlotFiller::new();
        assert!(sf.fill("a b", "a {} b").is_err(), "Expected error for empty slot name");
    }

    #[test]
    fn test_slot_filler_no_match() {
        let sf = SlotFiller::new();
        let slots = sf
            .fill(
                "completely different text",
                "flight from {origin} to {destination}",
            )
            .expect("should not error");
        assert!(
            !slots.contains_key("origin") && !slots.contains_key("destination"),
            "Expected no slots when template does not match"
        );
    }

    // --------------------------------------------------------------- Dialog policy

    #[test]
    fn test_policy_initial_greet() {
        let mut policy = DialogPolicy::new(vec!["origin".to_string(), "destination".to_string()]);
        let state = DialogState::new();
        let action = policy.next_action(&state);
        assert_eq!(action.act, DialogAct::Greet);
    }

    #[test]
    fn test_policy_requests_missing_slot() {
        let mut policy = DialogPolicy::new(vec!["origin".to_string(), "destination".to_string()]);
        let state = DialogState::new();
        policy.next_action(&state); // greet
        let action = policy.next_action(&state);
        assert_eq!(action.act, DialogAct::Request);
        assert_eq!(action.request_slot.as_deref(), Some("origin"));
    }

    #[test]
    fn test_policy_confirms_when_slots_filled() {
        let mut policy = DialogPolicy::new(vec!["origin".to_string(), "destination".to_string()]);
        let mut state = DialogState::new();
        policy.next_action(&state); // greet
        state.set_slot("origin", "London");
        state.set_slot("destination", "Paris");
        let action = policy.next_action(&state);
        assert_eq!(action.act, DialogAct::Confirm);
    }

    #[test]
    fn test_policy_informs_after_confirmation() {
        let mut policy = DialogPolicy::new(vec!["origin".to_string()]);
        let mut state = DialogState::new();
        policy.next_action(&state); // greet
        state.set_slot("origin", "London");
        policy.next_action(&state); // confirm
        state.add_utterance("yes");
        let action = policy.next_action(&state);
        assert_eq!(action.act, DialogAct::Inform);
    }

    #[test]
    fn test_policy_goodbye_at_end() {
        let mut policy = DialogPolicy::new(vec!["origin".to_string()]);
        let mut state = DialogState::new();
        policy.next_action(&state); // greet
        state.set_slot("origin", "London");
        policy.next_action(&state); // confirm
        state.add_utterance("yes");
        policy.next_action(&state); // inform
        let action = policy.next_action(&state);
        assert_eq!(action.act, DialogAct::Goodbye);
    }

    #[test]
    fn test_policy_reset() {
        let mut policy = DialogPolicy::new(vec!["slot_a".to_string()]);
        let state = DialogState::new();
        policy.next_action(&state);
        assert_ne!(*policy.state(), PolicyState::Initial);
        policy.reset();
        assert_eq!(*policy.state(), PolicyState::Initial);
    }

    #[test]
    fn test_policy_state_display() {
        assert_eq!(PolicyState::SlotCollection.to_string(), "SLOT_COLLECTION");
        assert_eq!(PolicyState::PendingConfirmation.to_string(), "PENDING_CONFIRMATION");
    }

    // ------------------------------------------------------------------ Responses

    #[test]
    fn test_response_greet() {
        let slots: HashMap<String, String> = HashMap::new();
        let response = response_template(DialogAct::Greet, &slots);
        assert!(!response.is_empty());
        let lower = response.to_lowercase();
        assert!(
            lower.contains("hello") || lower.contains("hi") || lower.contains("help"),
            "Greet response should be a greeting: '{response}'"
        );
    }

    #[test]
    fn test_response_inform_with_slots() {
        let mut slots: HashMap<String, String> = HashMap::new();
        slots.insert("destination".to_string(), "Paris".to_string());
        let response = response_template(DialogAct::Inform, &slots);
        assert!(
            response.contains("Paris"),
            "Response should contain 'Paris': '{response}'"
        );
    }

    #[test]
    fn test_response_goodbye() {
        let slots: HashMap<String, String> = HashMap::new();
        let response = response_template(DialogAct::Goodbye, &slots);
        let lower = response.to_lowercase();
        assert!(
            lower.contains("goodbye") || lower.contains("bye") || lower.contains("thank"),
            "Goodbye response unexpected: '{response}'"
        );
    }

    #[test]
    fn test_response_confirm_with_slots() {
        let mut slots: HashMap<String, String> = HashMap::new();
        slots.insert("origin".to_string(), "London".to_string());
        slots.insert("destination".to_string(), "Tokyo".to_string());
        let response = response_template(DialogAct::Confirm, &slots);
        assert!(!response.is_empty());
        assert!(
            response.contains("London") && response.contains("Tokyo"),
            "Confirm response should list the slot values: '{response}'"
        );
    }

    #[test]
    fn test_response_unknown() {
        let slots: HashMap<String, String> = HashMap::new();
        let response = response_template(DialogAct::Unknown, &slots);
        assert!(!response.is_empty());
    }

    // ----------------------------------------------------------------- Dialog acts

    #[test]
    fn test_dialog_act_display() {
        assert_eq!(DialogAct::Greet.to_string(), "GREET");
        assert_eq!(DialogAct::Goodbye.to_string(), "GOODBYE");
        assert_eq!(DialogAct::Unknown.to_string(), "UNKNOWN");
    }
}