1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6
7use ai_agents_core::{ChatMessage, LLMProvider, Result};
8
9use super::config::{CompareOp, ContextMatcher, GuardConditions, Transition, TransitionGuard};
10
11pub struct TransitionContext {
12 pub user_message: String,
13 pub assistant_response: String,
14 pub current_state: String,
15 pub context: HashMap<String, Value>,
16}
17
18impl TransitionContext {
19 pub fn new(user_message: &str, assistant_response: &str, current_state: &str) -> Self {
20 Self {
21 user_message: user_message.to_string(),
22 assistant_response: assistant_response.to_string(),
23 current_state: current_state.to_string(),
24 context: HashMap::new(),
25 }
26 }
27
28 pub fn with_context(mut self, context: HashMap<String, Value>) -> Self {
29 self.context = context;
30 self
31 }
32}
33
34#[async_trait]
35pub trait TransitionEvaluator: Send + Sync {
36 async fn select_transition(
37 &self,
38 transitions: &[Transition],
39 context: &TransitionContext,
40 ) -> Result<Option<usize>>;
41}
42
43pub struct LLMTransitionEvaluator {
44 llm: Arc<dyn LLMProvider>,
45}
46
47impl LLMTransitionEvaluator {
48 pub fn new(llm: Arc<dyn LLMProvider>) -> Self {
49 Self { llm }
50 }
51}
52
53pub fn evaluate_guard(guard: &TransitionGuard, ctx: &TransitionContext) -> bool {
56 match guard {
57 TransitionGuard::Expression(expr) => evaluate_expression(expr, ctx),
58 TransitionGuard::Conditions(conditions) => evaluate_conditions(conditions, ctx),
59 }
60}
61
62pub fn evaluate_expression(expr: &str, ctx: &TransitionContext) -> bool {
63 let expr = expr.trim();
64
65 if !expr.contains("{{") {
66 return !expr.is_empty();
67 }
68
69 let inner = expr.trim_start_matches("{{").trim_end_matches("}}").trim();
70
71 evaluate_simple_expression(inner, ctx)
72}
73
74fn evaluate_simple_expression(expr: &str, ctx: &TransitionContext) -> bool {
75 if expr.starts_with("context.") {
76 let path = &expr[8..];
77 return get_context_value(path, &ctx.context).is_some();
78 }
79
80 if expr.starts_with("state.") {
81 let field = &expr[6..];
82 return evaluate_state_expression(field, ctx);
83 }
84
85 if let Some(idx) = expr.find('>') {
86 let (left, right) = expr.split_at(idx);
87 let op = if right.starts_with(">=") { ">=" } else { ">" };
88 let right = right.trim_start_matches(op).trim();
89 let left = left.trim();
90
91 if let (Some(left_val), Ok(right_val)) = (resolve_value(left, ctx), right.parse::<f64>()) {
92 if let Some(left_num) = left_val.as_f64() {
93 return if op == ">=" {
94 left_num >= right_val
95 } else {
96 left_num > right_val
97 };
98 }
99 }
100 }
101
102 if let Some(idx) = expr.find('<') {
103 let (left, right) = expr.split_at(idx);
104 let op = if right.starts_with("<=") { "<=" } else { "<" };
105 let right = right.trim_start_matches(op).trim();
106 let left = left.trim();
107
108 if let (Some(left_val), Ok(right_val)) = (resolve_value(left, ctx), right.parse::<f64>()) {
109 if let Some(left_num) = left_val.as_f64() {
110 return if op == "<=" {
111 left_num <= right_val
112 } else {
113 left_num < right_val
114 };
115 }
116 }
117 }
118
119 if let Some(idx) = expr.find("==") {
120 let (left, right) = expr.split_at(idx);
121 let right = &right[2..].trim();
122 let left = left.trim();
123
124 if let Some(left_val) = resolve_value(left, ctx) {
125 let right_val: Value = if right.starts_with('"') && right.ends_with('"') {
126 Value::String(right[1..right.len() - 1].to_string())
127 } else if *right == "true" {
128 Value::Bool(true)
129 } else if *right == "false" {
130 Value::Bool(false)
131 } else if let Ok(n) = right.parse::<f64>() {
132 serde_json::json!(n)
133 } else {
134 Value::String(right.to_string())
135 };
136
137 return left_val == right_val;
138 }
139 }
140
141 if let Some(idx) = expr.find("!=") {
142 let (left, right) = expr.split_at(idx);
143 let right = &right[2..].trim();
144 let left = left.trim();
145
146 if let Some(left_val) = resolve_value(left, ctx) {
147 let right_val: Value = if right.starts_with('"') && right.ends_with('"') {
148 Value::String(right[1..right.len() - 1].to_string())
149 } else if *right == "true" {
150 Value::Bool(true)
151 } else if *right == "false" {
152 Value::Bool(false)
153 } else if let Ok(n) = right.parse::<f64>() {
154 serde_json::json!(n)
155 } else {
156 Value::String(right.to_string())
157 };
158
159 return left_val != right_val;
160 }
161 }
162
163 false
164}
165
166fn resolve_value(expr: &str, ctx: &TransitionContext) -> Option<Value> {
167 let expr = expr.trim();
168 if expr.starts_with("context.") {
169 let path = &expr[8..];
170 return get_context_value(path, &ctx.context);
171 }
172 if expr.starts_with("state.") {
173 let field = &expr[6..];
174 return get_state_value(field, ctx);
175 }
176 None
177}
178
179fn evaluate_state_expression(field: &str, _ctx: &TransitionContext) -> bool {
180 match field {
181 "turn_count" => true,
182 _ => false,
183 }
184}
185
186fn get_state_value(field: &str, ctx: &TransitionContext) -> Option<Value> {
187 match field {
188 "current" => Some(Value::String(ctx.current_state.clone())),
189 _ => None,
190 }
191}
192
193pub fn evaluate_conditions(conditions: &GuardConditions, ctx: &TransitionContext) -> bool {
194 match conditions {
195 GuardConditions::All(exprs) => exprs.iter().all(|e| evaluate_expression(e, ctx)),
196 GuardConditions::Any(exprs) => exprs.iter().any(|e| evaluate_expression(e, ctx)),
197 GuardConditions::Context(matchers) => evaluate_context_matchers(matchers, &ctx.context),
198 }
199}
200
201pub fn evaluate_context_matchers(
202 matchers: &HashMap<String, ContextMatcher>,
203 context: &HashMap<String, Value>,
204) -> bool {
205 for (path, matcher) in matchers {
206 let value = get_context_value(path, context);
207 if !match_value(value.as_ref(), matcher) {
208 return false;
209 }
210 }
211 true
212}
213
214pub fn get_context_value(path: &str, context: &HashMap<String, Value>) -> Option<Value> {
215 ai_agents_core::get_dot_path_from_map(context, path)
216}
217
218pub fn match_value(value: Option<&Value>, matcher: &ContextMatcher) -> bool {
219 match matcher {
220 ContextMatcher::Exact(expected) => value.map(|v| v == expected).unwrap_or(false),
221 ContextMatcher::Exists { exists } => {
222 let has_value = value.is_some() && value != Some(&Value::Null);
223 *exists == has_value
224 }
225 ContextMatcher::Compare(op) => {
226 let Some(val) = value else {
227 return false;
228 };
229 compare_value(val, op)
230 }
231 }
232}
233
234fn values_equal_coerced(value: &Value, expected: &Value) -> bool {
237 if value == expected {
238 return true;
239 }
240 if let Some(s) = value.as_str() {
242 match expected {
243 Value::Bool(b) => match s {
244 "true" => return *b,
245 "false" => return !*b,
246 _ => {}
247 },
248 Value::Number(n) => {
249 if let Ok(parsed) = s.parse::<f64>() {
250 if let Some(expected_f) = n.as_f64() {
251 return (parsed - expected_f).abs() < f64::EPSILON;
252 }
253 }
254 }
255 _ => {}
256 }
257 }
258 if let Some(s) = expected.as_str() {
260 match value {
261 Value::Bool(b) => match s {
262 "true" => return *b,
263 "false" => return !*b,
264 _ => {}
265 },
266 Value::Number(n) => {
267 if let Ok(parsed) = s.parse::<f64>() {
268 if let Some(val_f) = n.as_f64() {
269 return (parsed - val_f).abs() < f64::EPSILON;
270 }
271 }
272 }
273 _ => {}
274 }
275 }
276 false
277}
278
279pub fn compare_value(value: &Value, op: &CompareOp) -> bool {
280 match op {
281 CompareOp::Eq(expected) => values_equal_coerced(value, expected),
282 CompareOp::Neq(expected) => !values_equal_coerced(value, expected),
283 CompareOp::Gt(n) => value.as_f64().map(|v| v > *n).unwrap_or(false),
284 CompareOp::Gte(n) => value.as_f64().map(|v| v >= *n).unwrap_or(false),
285 CompareOp::Lt(n) => value.as_f64().map(|v| v < *n).unwrap_or(false),
286 CompareOp::Lte(n) => value.as_f64().map(|v| v <= *n).unwrap_or(false),
287 CompareOp::In(values) => values.contains(value),
288 CompareOp::Contains(s) => value
289 .as_str()
290 .map(|v| v.contains(s))
291 .or_else(|| {
292 value
293 .as_array()
294 .map(|arr| arr.iter().any(|v| v.as_str() == Some(s)))
295 })
296 .unwrap_or(false),
297 }
298}
299
300#[async_trait]
301impl TransitionEvaluator for LLMTransitionEvaluator {
302 async fn select_transition(
303 &self,
304 transitions: &[Transition],
305 context: &TransitionContext,
306 ) -> Result<Option<usize>> {
307 if transitions.is_empty() {
308 return Ok(None);
309 }
310
311 for (i, transition) in transitions.iter().enumerate() {
313 if let Some(ref guard) = transition.guard {
314 if evaluate_guard(guard, context) {
315 return Ok(Some(i));
316 }
317 }
318 }
319
320 if let Some(resolved) = context.context.get("resolved_intent") {
325 if let Some(resolved_str) = resolved.as_str() {
326 if !resolved_str.is_empty() {
328 for (i, transition) in transitions.iter().enumerate() {
329 if let Some(ref intent) = transition.intent {
330 if intent == resolved_str {
331 tracing::debug!(
332 resolved_intent = resolved_str,
333 target = %transition.to,
334 "Deterministic routing via resolved_intent"
335 );
336 return Ok(Some(i));
337 }
338 }
339 }
340 }
341 }
342 }
343
344 let llm_transitions: Vec<(usize, &Transition)> = transitions
346 .iter()
347 .enumerate()
348 .filter(|(_, t)| !t.when.is_empty() && t.guard.is_none())
349 .collect();
350
351 if llm_transitions.is_empty() {
352 return Ok(None);
353 }
354
355 let conditions: Vec<String> = llm_transitions
356 .iter()
357 .enumerate()
358 .map(|(display_idx, (_, t))| format!("{}. {}", display_idx + 1, t.when))
359 .collect();
360
361 let prompt = format!(
362 r#"Based on the conversation, which condition is met?
363
364Current state: {}
365User message: {}
366Assistant response: {}
367
368Conditions:
369{}
3700. None of the above
371
372Reply with ONLY the number (0-{})."#,
373 context.current_state,
374 context.user_message,
375 context.assistant_response,
376 conditions.join("\n"),
377 llm_transitions.len()
378 );
379
380 let messages = vec![ChatMessage::user(&prompt)];
381 let response = self.llm.complete(&messages, None).await?;
382
383 let choice: usize = response.content.trim().parse().unwrap_or(0);
384
385 if choice == 0 || choice > llm_transitions.len() {
386 Ok(None)
387 } else {
388 Ok(Some(llm_transitions[choice - 1].0))
389 }
390 }
391}
392
393pub struct GuardOnlyEvaluator;
394
395impl GuardOnlyEvaluator {
396 pub fn new() -> Self {
397 Self
398 }
399
400 pub fn evaluate_guard(&self, guard: &TransitionGuard, ctx: &TransitionContext) -> bool {
401 evaluate_guard(guard, ctx)
402 }
403
404 pub fn evaluate_guards(
405 &self,
406 transitions: &[Transition],
407 ctx: &TransitionContext,
408 ) -> Option<usize> {
409 for (i, transition) in transitions.iter().enumerate() {
410 if let Some(ref guard) = transition.guard {
411 if evaluate_guard(guard, ctx) {
412 return Some(i);
413 }
414 }
415 }
416 None
417 }
418}
419
420impl Default for GuardOnlyEvaluator {
421 fn default() -> Self {
422 Self::new()
423 }
424}
425
#[cfg(test)]
mod tests {
    use super::super::config::TransitionTiming;
    use super::*;
    use ai_agents_core::{FinishReason, LLMResponse};
    use ai_agents_llm::mock::MockLLMProvider;

    // ---- fixtures -------------------------------------------------------

    /// Transition with the defaults shared by every test: no guard, no
    /// intent, `auto`, priority 0, post-response timing. Tests override
    /// individual fields with struct-update syntax.
    fn transition(to: &str, when: &str) -> Transition {
        Transition {
            to: to.into(),
            when: when.into(),
            guard: None,
            intent: None,
            auto: true,
            priority: 0,
            cooldown_turns: None,
            timing: TransitionTiming::PostResponse,
            requires_response: false,
            run_extractors: false,
        }
    }

    /// Priority-10 transition whose target state and intent share a name.
    fn intent_transition(intent: &str, when: &str) -> Transition {
        Transition {
            intent: Some(intent.into()),
            priority: 10,
            ..transition(intent, when)
        }
    }

    /// Priority-5 transition with an expression guard and an empty `when`.
    fn guarded_transition(to: &str, expression: &str) -> Transition {
        Transition {
            guard: Some(TransitionGuard::Expression(expression.into())),
            priority: 5,
            ..transition(to, "")
        }
    }

    /// Builds a context map from `(key, value)` pairs.
    fn context_map(entries: &[(&str, Value)]) -> HashMap<String, Value> {
        entries
            .iter()
            .map(|(key, value)| (key.to_string(), value.clone()))
            .collect()
    }

    /// Context for the generic "hi" / "hello" turn in state `start`.
    fn start_ctx(entries: &[(&str, Value)]) -> TransitionContext {
        TransitionContext::new("hi", "hello", "start").with_context(context_map(entries))
    }

    /// Wraps context matchers into a guard.
    fn matcher_guard(matchers: Vec<(&str, ContextMatcher)>) -> TransitionGuard {
        let matchers: HashMap<String, ContextMatcher> = matchers
            .into_iter()
            .map(|(path, matcher)| (path.to_string(), matcher))
            .collect();
        TransitionGuard::Conditions(GuardConditions::Context(matchers))
    }

    /// Evaluator backed by a mock provider that returns `replies` in order.
    fn evaluator_replying(
        name: &'static str,
        replies: &[&'static str],
    ) -> LLMTransitionEvaluator {
        let mut mock = MockLLMProvider::new(name);
        for reply in replies {
            mock.add_response(LLMResponse::new(*reply, FinishReason::Stop));
        }
        LLMTransitionEvaluator::new(Arc::new(mock))
    }

    // ---- LLM selection --------------------------------------------------

    #[tokio::test]
    async fn test_select_transition_none() {
        let evaluator = evaluator_replying("evaluator_test", &["0"]);
        let transitions = vec![transition("next", "user says goodbye")];
        let context = TransitionContext::new("hello", "hi there", "greeting");

        let result = evaluator.select_transition(&transitions, &context).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_select_transition_match() {
        let evaluator = evaluator_replying("evaluator_test", &["1"]);
        let transitions = vec![
            Transition {
                priority: 10,
                ..transition("support", "user needs help")
            },
            Transition {
                priority: 5,
                ..transition("sales", "user wants to buy")
            },
        ];
        let context = TransitionContext::new("I need help", "Sure!", "greeting");

        let result = evaluator.select_transition(&transitions, &context).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), Some(0));
    }

    #[tokio::test]
    async fn test_empty_transitions() {
        let evaluator = evaluator_replying("evaluator_test", &[]);
        let context = TransitionContext::new("hi", "hello", "start");

        let result = evaluator.select_transition(&[], &context).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    // ---- guards -----------------------------------------------------------

    #[test]
    fn test_guard_expression_simple() {
        let ctx = start_ctx(&[("has_data", Value::Bool(true))]);

        let guard = TransitionGuard::Expression("{{ context.has_data }}".into());
        assert!(evaluate_guard(&guard, &ctx));
    }

    #[test]
    fn test_guard_expression_missing() {
        let ctx = start_ctx(&[]);

        let guard = TransitionGuard::Expression("{{ context.has_data }}".into());
        assert!(!evaluate_guard(&guard, &ctx));
    }

    #[test]
    fn test_guard_with_nested_context() {
        let ctx = start_ctx(&[(
            "user",
            serde_json::json!({
                "name": "Alice",
                "verified": true
            }),
        )]);

        let guard = TransitionGuard::Expression("{{ context.user.verified }}".into());
        assert!(evaluate_guard(&guard, &ctx));
    }

    #[test]
    fn test_guard_conditions_all() {
        let ctx = start_ctx(&[
            ("has_name", Value::Bool(true)),
            ("has_email", Value::Bool(true)),
        ]);

        let guard = TransitionGuard::Conditions(GuardConditions::All(vec![
            "{{ context.has_name }}".into(),
            "{{ context.has_email }}".into(),
        ]));
        assert!(evaluate_guard(&guard, &ctx));
    }

    #[test]
    fn test_guard_conditions_any() {
        let ctx = start_ctx(&[("is_vip", Value::Bool(true))]);

        let guard = TransitionGuard::Conditions(GuardConditions::Any(vec![
            "{{ context.is_admin }}".into(),
            "{{ context.is_vip }}".into(),
        ]));
        assert!(evaluate_guard(&guard, &ctx));
    }

    #[test]
    fn test_guard_context_matchers() {
        let ctx = start_ctx(&[(
            "user",
            serde_json::json!({
                "tier": "premium",
                "balance": 100.0
            }),
        )]);

        let guard = matcher_guard(vec![
            (
                "user.tier",
                ContextMatcher::Exact(Value::String("premium".into())),
            ),
            (
                "user.balance",
                ContextMatcher::Compare(CompareOp::Gte(50.0)),
            ),
        ]);
        assert!(evaluate_guard(&guard, &ctx));
    }

    #[tokio::test]
    async fn test_guard_priority_over_llm() {
        // No mock replies are queued: the guard must win before any LLM call.
        let evaluator = evaluator_replying("guard_test", &[]);
        let ctx = start_ctx(&[("ready", Value::Bool(true))]);

        let transitions = vec![
            Transition {
                priority: 10,
                ..transition("llm_based", "user wants to proceed")
            },
            guarded_transition("guard_based", "{{ context.ready }}"),
        ];

        let result = evaluator.select_transition(&transitions, &ctx).await;
        assert_eq!(result.unwrap(), Some(1));
    }

    #[test]
    fn test_guard_only_evaluator() {
        let evaluator = GuardOnlyEvaluator::new();
        let ctx = start_ctx(&[("ready", Value::Bool(true))]);

        let transitions = vec![
            Transition {
                priority: 10,
                ..transition("no_guard", "some condition")
            },
            guarded_transition("with_guard", "{{ context.ready }}"),
        ];

        let result = evaluator.evaluate_guards(&transitions, &ctx);
        assert_eq!(result, Some(1));
    }

    // ---- context matchers -------------------------------------------------

    #[test]
    fn test_context_matcher_exists() {
        let ctx = start_ctx(&[("name", Value::String("Alice".into()))]);

        let guard = matcher_guard(vec![
            ("name", ContextMatcher::Exists { exists: true }),
            ("email", ContextMatcher::Exists { exists: false }),
        ]);
        assert!(evaluate_guard(&guard, &ctx));
    }

    #[test]
    fn test_compare_contains() {
        let ctx = start_ctx(&[
            ("message", Value::String("hello world".into())),
            ("tags", serde_json::json!(["urgent", "support"])),
        ]);

        let guard1 = matcher_guard(vec![(
            "message",
            ContextMatcher::Compare(CompareOp::Contains("world".into())),
        )]);
        assert!(evaluate_guard(&guard1, &ctx));

        let guard2 = matcher_guard(vec![(
            "tags",
            ContextMatcher::Compare(CompareOp::Contains("urgent".into())),
        )]);
        assert!(evaluate_guard(&guard2, &ctx));
    }

    #[test]
    fn test_compare_in() {
        let ctx = start_ctx(&[("tier", Value::String("premium".into()))]);

        let guard = matcher_guard(vec![(
            "tier",
            ContextMatcher::Compare(CompareOp::In(vec![
                Value::String("premium".into()),
                Value::String("enterprise".into()),
            ])),
        )]);
        assert!(evaluate_guard(&guard, &ctx));
    }

    // ---- intent routing ---------------------------------------------------

    #[tokio::test]
    async fn test_intent_based_routing_deterministic() {
        // No mock replies are queued: routing must not reach the LLM.
        let evaluator = evaluator_replying("intent_test", &[]);

        let transitions = vec![
            intent_transition("cancel_order", "User wants to cancel an order"),
            intent_transition("cancel_reservation", "User wants to cancel a reservation"),
            intent_transition("cancel_subscription", "User wants to cancel a subscription"),
        ];

        let ctx = TransitionContext::new("あれキャンセルして", "", "greeting").with_context(
            context_map(&[(
                "resolved_intent",
                Value::String("cancel_reservation".into()),
            )]),
        );

        let result = evaluator
            .select_transition(&transitions, &ctx)
            .await
            .unwrap();
        assert_eq!(result, Some(1));
    }

    #[tokio::test]
    async fn test_intent_routing_falls_back_to_llm_when_no_resolved_intent() {
        let evaluator = evaluator_replying("intent_fallback_test", &["1"]);

        let transitions = vec![
            intent_transition("cancel_order", "User wants to cancel an order"),
            intent_transition("cancel_reservation", "User wants to cancel a reservation"),
        ];

        let ctx = TransitionContext::new("Cancel order ORD-1042", "", "greeting")
            .with_context(HashMap::new());

        let result = evaluator
            .select_transition(&transitions, &ctx)
            .await
            .unwrap();
        assert_eq!(result, Some(0));
    }

    #[tokio::test]
    async fn test_no_routing_when_resolved_intent_doesnt_match() {
        let evaluator = evaluator_replying("intent_nomatch_test", &["0"]);

        let transitions = vec![
            intent_transition("cancel_order", "User wants to cancel an order"),
            intent_transition("cancel_reservation", "User wants to cancel a reservation"),
        ];

        let ctx = TransitionContext::new("do something", "", "greeting").with_context(
            context_map(&[("resolved_intent", Value::String("something_else".into()))]),
        );

        let result = evaluator
            .select_transition(&transitions, &ctx)
            .await
            .unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_null_resolved_intent_is_ignored() {
        let evaluator = evaluator_replying("intent_null_test", &["1"]);

        let transitions = vec![intent_transition(
            "cancel_order",
            "User wants to cancel an order",
        )];

        let ctx = TransitionContext::new("Cancel my order", "", "greeting")
            .with_context(context_map(&[("resolved_intent", Value::Null)]));

        let result = evaluator
            .select_transition(&transitions, &ctx)
            .await
            .unwrap();
        assert_eq!(result, Some(0));
    }
}