Skip to main content

ai_agents_state/
evaluator.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6
7use ai_agents_core::{ChatMessage, LLMProvider, Result};
8
9use super::config::{CompareOp, ContextMatcher, GuardConditions, Transition, TransitionGuard};
10
11pub struct TransitionContext {
12    pub user_message: String,
13    pub assistant_response: String,
14    pub current_state: String,
15    pub context: HashMap<String, Value>,
16}
17
18impl TransitionContext {
19    pub fn new(user_message: &str, assistant_response: &str, current_state: &str) -> Self {
20        Self {
21            user_message: user_message.to_string(),
22            assistant_response: assistant_response.to_string(),
23            current_state: current_state.to_string(),
24            context: HashMap::new(),
25        }
26    }
27
28    pub fn with_context(mut self, context: HashMap<String, Value>) -> Self {
29        self.context = context;
30        self
31    }
32}
33
34#[async_trait]
35pub trait TransitionEvaluator: Send + Sync {
36    async fn select_transition(
37        &self,
38        transitions: &[Transition],
39        context: &TransitionContext,
40    ) -> Result<Option<usize>>;
41}
42
43pub struct LLMTransitionEvaluator {
44    llm: Arc<dyn LLMProvider>,
45}
46
47impl LLMTransitionEvaluator {
48    pub fn new(llm: Arc<dyn LLMProvider>) -> Self {
49        Self { llm }
50    }
51}
52
53// ── Standalone guard evaluation functions ──
54
55pub fn evaluate_guard(guard: &TransitionGuard, ctx: &TransitionContext) -> bool {
56    match guard {
57        TransitionGuard::Expression(expr) => evaluate_expression(expr, ctx),
58        TransitionGuard::Conditions(conditions) => evaluate_conditions(conditions, ctx),
59    }
60}
61
62pub fn evaluate_expression(expr: &str, ctx: &TransitionContext) -> bool {
63    let expr = expr.trim();
64
65    if !expr.contains("{{") {
66        return !expr.is_empty();
67    }
68
69    let inner = expr.trim_start_matches("{{").trim_end_matches("}}").trim();
70
71    evaluate_simple_expression(inner, ctx)
72}
73
74fn evaluate_simple_expression(expr: &str, ctx: &TransitionContext) -> bool {
75    if expr.starts_with("context.") {
76        let path = &expr[8..];
77        return get_context_value(path, &ctx.context).is_some();
78    }
79
80    if expr.starts_with("state.") {
81        let field = &expr[6..];
82        return evaluate_state_expression(field, ctx);
83    }
84
85    if let Some(idx) = expr.find('>') {
86        let (left, right) = expr.split_at(idx);
87        let op = if right.starts_with(">=") { ">=" } else { ">" };
88        let right = right.trim_start_matches(op).trim();
89        let left = left.trim();
90
91        if let (Some(left_val), Ok(right_val)) = (resolve_value(left, ctx), right.parse::<f64>()) {
92            if let Some(left_num) = left_val.as_f64() {
93                return if op == ">=" {
94                    left_num >= right_val
95                } else {
96                    left_num > right_val
97                };
98            }
99        }
100    }
101
102    if let Some(idx) = expr.find('<') {
103        let (left, right) = expr.split_at(idx);
104        let op = if right.starts_with("<=") { "<=" } else { "<" };
105        let right = right.trim_start_matches(op).trim();
106        let left = left.trim();
107
108        if let (Some(left_val), Ok(right_val)) = (resolve_value(left, ctx), right.parse::<f64>()) {
109            if let Some(left_num) = left_val.as_f64() {
110                return if op == "<=" {
111                    left_num <= right_val
112                } else {
113                    left_num < right_val
114                };
115            }
116        }
117    }
118
119    if let Some(idx) = expr.find("==") {
120        let (left, right) = expr.split_at(idx);
121        let right = &right[2..].trim();
122        let left = left.trim();
123
124        if let Some(left_val) = resolve_value(left, ctx) {
125            let right_val: Value = if right.starts_with('"') && right.ends_with('"') {
126                Value::String(right[1..right.len() - 1].to_string())
127            } else if *right == "true" {
128                Value::Bool(true)
129            } else if *right == "false" {
130                Value::Bool(false)
131            } else if let Ok(n) = right.parse::<f64>() {
132                serde_json::json!(n)
133            } else {
134                Value::String(right.to_string())
135            };
136
137            return left_val == right_val;
138        }
139    }
140
141    if let Some(idx) = expr.find("!=") {
142        let (left, right) = expr.split_at(idx);
143        let right = &right[2..].trim();
144        let left = left.trim();
145
146        if let Some(left_val) = resolve_value(left, ctx) {
147            let right_val: Value = if right.starts_with('"') && right.ends_with('"') {
148                Value::String(right[1..right.len() - 1].to_string())
149            } else if *right == "true" {
150                Value::Bool(true)
151            } else if *right == "false" {
152                Value::Bool(false)
153            } else if let Ok(n) = right.parse::<f64>() {
154                serde_json::json!(n)
155            } else {
156                Value::String(right.to_string())
157            };
158
159            return left_val != right_val;
160        }
161    }
162
163    false
164}
165
166fn resolve_value(expr: &str, ctx: &TransitionContext) -> Option<Value> {
167    let expr = expr.trim();
168    if expr.starts_with("context.") {
169        let path = &expr[8..];
170        return get_context_value(path, &ctx.context);
171    }
172    if expr.starts_with("state.") {
173        let field = &expr[6..];
174        return get_state_value(field, ctx);
175    }
176    None
177}
178
179fn evaluate_state_expression(field: &str, _ctx: &TransitionContext) -> bool {
180    match field {
181        "turn_count" => true,
182        _ => false,
183    }
184}
185
186fn get_state_value(field: &str, ctx: &TransitionContext) -> Option<Value> {
187    match field {
188        "current" => Some(Value::String(ctx.current_state.clone())),
189        _ => None,
190    }
191}
192
193pub fn evaluate_conditions(conditions: &GuardConditions, ctx: &TransitionContext) -> bool {
194    match conditions {
195        GuardConditions::All(exprs) => exprs.iter().all(|e| evaluate_expression(e, ctx)),
196        GuardConditions::Any(exprs) => exprs.iter().any(|e| evaluate_expression(e, ctx)),
197        GuardConditions::Context(matchers) => evaluate_context_matchers(matchers, &ctx.context),
198    }
199}
200
201pub fn evaluate_context_matchers(
202    matchers: &HashMap<String, ContextMatcher>,
203    context: &HashMap<String, Value>,
204) -> bool {
205    for (path, matcher) in matchers {
206        let value = get_context_value(path, context);
207        if !match_value(value.as_ref(), matcher) {
208            return false;
209        }
210    }
211    true
212}
213
214pub fn get_context_value(path: &str, context: &HashMap<String, Value>) -> Option<Value> {
215    ai_agents_core::get_dot_path_from_map(context, path)
216}
217
218pub fn match_value(value: Option<&Value>, matcher: &ContextMatcher) -> bool {
219    match matcher {
220        ContextMatcher::Exact(expected) => value.map(|v| v == expected).unwrap_or(false),
221        ContextMatcher::Exists { exists } => {
222            let has_value = value.is_some() && value != Some(&Value::Null);
223            *exists == has_value
224        }
225        ContextMatcher::Compare(op) => {
226            let Some(val) = value else {
227                return false;
228            };
229            compare_value(val, op)
230        }
231    }
232}
233
234/// Compare with type coercion: context extractors store all values as strings, but YAML guards may specify booleans (`eq: true`) or numbers (`eq: 42`).
235/// Coerce the string to the expected type before comparing.
236fn values_equal_coerced(value: &Value, expected: &Value) -> bool {
237    if value == expected {
238        return true;
239    }
240    // String value vs non-string expected: coerce the string
241    if let Some(s) = value.as_str() {
242        match expected {
243            Value::Bool(b) => match s {
244                "true" => return *b,
245                "false" => return !*b,
246                _ => {}
247            },
248            Value::Number(n) => {
249                if let Ok(parsed) = s.parse::<f64>() {
250                    if let Some(expected_f) = n.as_f64() {
251                        return (parsed - expected_f).abs() < f64::EPSILON;
252                    }
253                }
254            }
255            _ => {}
256        }
257    }
258    // Non-string value vs string expected: coerce the other way
259    if let Some(s) = expected.as_str() {
260        match value {
261            Value::Bool(b) => match s {
262                "true" => return *b,
263                "false" => return !*b,
264                _ => {}
265            },
266            Value::Number(n) => {
267                if let Ok(parsed) = s.parse::<f64>() {
268                    if let Some(val_f) = n.as_f64() {
269                        return (parsed - val_f).abs() < f64::EPSILON;
270                    }
271                }
272            }
273            _ => {}
274        }
275    }
276    false
277}
278
279pub fn compare_value(value: &Value, op: &CompareOp) -> bool {
280    match op {
281        CompareOp::Eq(expected) => values_equal_coerced(value, expected),
282        CompareOp::Neq(expected) => !values_equal_coerced(value, expected),
283        CompareOp::Gt(n) => value.as_f64().map(|v| v > *n).unwrap_or(false),
284        CompareOp::Gte(n) => value.as_f64().map(|v| v >= *n).unwrap_or(false),
285        CompareOp::Lt(n) => value.as_f64().map(|v| v < *n).unwrap_or(false),
286        CompareOp::Lte(n) => value.as_f64().map(|v| v <= *n).unwrap_or(false),
287        CompareOp::In(values) => values.contains(value),
288        CompareOp::Contains(s) => value
289            .as_str()
290            .map(|v| v.contains(s))
291            .or_else(|| {
292                value
293                    .as_array()
294                    .map(|arr| arr.iter().any(|v| v.as_str() == Some(s)))
295            })
296            .unwrap_or(false),
297    }
298}
299
300#[async_trait]
301impl TransitionEvaluator for LLMTransitionEvaluator {
302    async fn select_transition(
303        &self,
304        transitions: &[Transition],
305        context: &TransitionContext,
306    ) -> Result<Option<usize>> {
307        if transitions.is_empty() {
308            return Ok(None);
309        }
310
311        // Guard-based transitions (existing, no LLM)
312        for (i, transition) in transitions.iter().enumerate() {
313            if let Some(ref guard) = transition.guard {
314                if evaluate_guard(guard, context) {
315                    return Ok(Some(i));
316                }
317            }
318        }
319
320        // !!NOTE: Resolved-intent short-circuit
321        //
322        // If disambiguation has resolved an intent, try to match it against transitions that declare an `intent` field.
323        // This is DETERMINISTIC - no LLM call.
324        if let Some(resolved) = context.context.get("resolved_intent") {
325            if let Some(resolved_str) = resolved.as_str() {
326                // Skip null values (used to clear stale context)
327                if !resolved_str.is_empty() {
328                    for (i, transition) in transitions.iter().enumerate() {
329                        if let Some(ref intent) = transition.intent {
330                            if intent == resolved_str {
331                                tracing::debug!(
332                                    resolved_intent = resolved_str,
333                                    target = %transition.to,
334                                    "Deterministic routing via resolved_intent"
335                                );
336                                return Ok(Some(i));
337                            }
338                        }
339                    }
340                }
341            }
342        }
343
344        // ── Phase 1: LLM-based evaluation (existing) ──
345        let llm_transitions: Vec<(usize, &Transition)> = transitions
346            .iter()
347            .enumerate()
348            .filter(|(_, t)| !t.when.is_empty() && t.guard.is_none())
349            .collect();
350
351        if llm_transitions.is_empty() {
352            return Ok(None);
353        }
354
355        let conditions: Vec<String> = llm_transitions
356            .iter()
357            .enumerate()
358            .map(|(display_idx, (_, t))| format!("{}. {}", display_idx + 1, t.when))
359            .collect();
360
361        let prompt = format!(
362            r#"Based on the conversation, which condition is met?
363
364Current state: {}
365User message: {}
366Assistant response: {}
367
368Conditions:
369{}
3700. None of the above
371
372Reply with ONLY the number (0-{})."#,
373            context.current_state,
374            context.user_message,
375            context.assistant_response,
376            conditions.join("\n"),
377            llm_transitions.len()
378        );
379
380        let messages = vec![ChatMessage::user(&prompt)];
381        let response = self.llm.complete(&messages, None).await?;
382
383        let choice: usize = response.content.trim().parse().unwrap_or(0);
384
385        if choice == 0 || choice > llm_transitions.len() {
386            Ok(None)
387        } else {
388            Ok(Some(llm_transitions[choice - 1].0))
389        }
390    }
391}
392
393pub struct GuardOnlyEvaluator;
394
395impl GuardOnlyEvaluator {
396    pub fn new() -> Self {
397        Self
398    }
399
400    pub fn evaluate_guard(&self, guard: &TransitionGuard, ctx: &TransitionContext) -> bool {
401        evaluate_guard(guard, ctx)
402    }
403
404    pub fn evaluate_guards(
405        &self,
406        transitions: &[Transition],
407        ctx: &TransitionContext,
408    ) -> Option<usize> {
409        for (i, transition) in transitions.iter().enumerate() {
410            if let Some(ref guard) = transition.guard {
411                if evaluate_guard(guard, ctx) {
412                    return Some(i);
413                }
414            }
415        }
416        None
417    }
418}
419
420impl Default for GuardOnlyEvaluator {
421    fn default() -> Self {
422        Self::new()
423    }
424}
425
#[cfg(test)]
mod tests {
    use super::*;
    use ai_agents_core::{FinishReason, LLMResponse};
    use ai_agents_llm::mock::MockLLMProvider;

    /// Builds an auto-firing `Transition` without a cooldown; only the fields
    /// that vary between tests are spelled out at the call site.
    macro_rules! transition {
        ($to:expr, $when:expr, guard: $guard:expr, intent: $intent:expr, priority: $priority:expr) => {
            Transition {
                to: $to.into(),
                when: $when.into(),
                guard: $guard,
                intent: $intent,
                auto: true,
                priority: $priority,
                cooldown_turns: None,
            }
        };
    }

    /// Evaluator backed by a mock provider with `replies` queued in order.
    fn evaluator_with_replies(name: &str, replies: &[&str]) -> LLMTransitionEvaluator {
        let mut mock = MockLLMProvider::new(name);
        for reply in replies {
            mock.add_response(LLMResponse::new(*reply, FinishReason::Stop));
        }
        LLMTransitionEvaluator::new(Arc::new(mock))
    }

    /// Turns `(key, value)` pairs into a context map.
    fn context_of(entries: Vec<(&str, Value)>) -> HashMap<String, Value> {
        entries
            .into_iter()
            .map(|(key, value)| (key.to_string(), value))
            .collect()
    }

    /// The context most guard tests share: a fixed exchange in state `start`.
    fn start_ctx(entries: Vec<(&str, Value)>) -> TransitionContext {
        TransitionContext::new("hi", "hello", "start").with_context(context_of(entries))
    }

    /// Guard made of a single template expression.
    fn expression_guard(expr: &str) -> TransitionGuard {
        TransitionGuard::Expression(expr.into())
    }

    /// Guard made of context matchers keyed by path.
    fn context_guard(matchers: Vec<(&str, ContextMatcher)>) -> TransitionGuard {
        TransitionGuard::Conditions(GuardConditions::Context(
            matchers
                .into_iter()
                .map(|(path, matcher)| (path.to_string(), matcher))
                .collect(),
        ))
    }

    /// The two intent-bearing transitions shared by the LLM-fallback tests.
    fn cancel_order_and_reservation() -> Vec<Transition> {
        vec![
            transition!(
                "cancel_order",
                "User wants to cancel an order",
                guard: None,
                intent: Some("cancel_order".into()),
                priority: 10
            ),
            transition!(
                "cancel_reservation",
                "User wants to cancel a reservation",
                guard: None,
                intent: Some("cancel_reservation".into()),
                priority: 10
            ),
        ]
    }

    #[tokio::test]
    async fn test_select_transition_none() {
        let evaluator = evaluator_with_replies("evaluator_test", &["0"]);
        let transitions = vec![transition!(
            "next",
            "user says goodbye",
            guard: None,
            intent: None,
            priority: 0
        )];
        let context = TransitionContext::new("hello", "hi there", "greeting");

        let result = evaluator.select_transition(&transitions, &context).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_select_transition_match() {
        let evaluator = evaluator_with_replies("evaluator_test", &["1"]);
        let transitions = vec![
            transition!("support", "user needs help", guard: None, intent: None, priority: 10),
            transition!("sales", "user wants to buy", guard: None, intent: None, priority: 5),
        ];
        let context = TransitionContext::new("I need help", "Sure!", "greeting");

        let result = evaluator.select_transition(&transitions, &context).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), Some(0));
    }

    #[tokio::test]
    async fn test_empty_transitions() {
        let evaluator = evaluator_with_replies("evaluator_test", &[]);
        let context = TransitionContext::new("hi", "hello", "start");

        let result = evaluator.select_transition(&[], &context).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn test_guard_expression_simple() {
        let ctx = start_ctx(vec![("has_data", Value::Bool(true))]);

        let guard = expression_guard("{{ context.has_data }}");
        assert!(evaluate_guard(&guard, &ctx));
    }

    #[test]
    fn test_guard_expression_missing() {
        let ctx = start_ctx(vec![]);

        let guard = expression_guard("{{ context.has_data }}");
        assert!(!evaluate_guard(&guard, &ctx));
    }

    #[test]
    fn test_guard_with_nested_context() {
        let ctx = start_ctx(vec![(
            "user",
            serde_json::json!({
                "name": "Alice",
                "verified": true
            }),
        )]);

        let guard = expression_guard("{{ context.user.verified }}");
        assert!(evaluate_guard(&guard, &ctx));
    }

    #[test]
    fn test_guard_conditions_all() {
        let ctx = start_ctx(vec![
            ("has_name", Value::Bool(true)),
            ("has_email", Value::Bool(true)),
        ]);

        let guard = TransitionGuard::Conditions(GuardConditions::All(vec![
            "{{ context.has_name }}".into(),
            "{{ context.has_email }}".into(),
        ]));
        assert!(evaluate_guard(&guard, &ctx));
    }

    #[test]
    fn test_guard_conditions_any() {
        let ctx = start_ctx(vec![("is_vip", Value::Bool(true))]);

        let guard = TransitionGuard::Conditions(GuardConditions::Any(vec![
            "{{ context.is_admin }}".into(),
            "{{ context.is_vip }}".into(),
        ]));
        assert!(evaluate_guard(&guard, &ctx));
    }

    #[test]
    fn test_guard_context_matchers() {
        let ctx = start_ctx(vec![(
            "user",
            serde_json::json!({
                "tier": "premium",
                "balance": 100.0
            }),
        )]);

        let guard = context_guard(vec![
            (
                "user.tier",
                ContextMatcher::Exact(Value::String("premium".into())),
            ),
            (
                "user.balance",
                ContextMatcher::Compare(CompareOp::Gte(50.0)),
            ),
        ]);
        assert!(evaluate_guard(&guard, &ctx));
    }

    #[tokio::test]
    async fn test_guard_priority_over_llm() {
        // No replies are queued: the guard is expected to decide on its own.
        let evaluator = evaluator_with_replies("guard_test", &[]);
        let ctx = start_ctx(vec![("ready", Value::Bool(true))]);

        let transitions = vec![
            transition!(
                "llm_based",
                "user wants to proceed",
                guard: None,
                intent: None,
                priority: 10
            ),
            transition!(
                "guard_based",
                "",
                guard: Some(expression_guard("{{ context.ready }}")),
                intent: None,
                priority: 5
            ),
        ];

        let result = evaluator.select_transition(&transitions, &ctx).await;
        assert_eq!(result.unwrap(), Some(1));
    }

    #[test]
    fn test_guard_only_evaluator() {
        let evaluator = GuardOnlyEvaluator::new();
        let ctx = start_ctx(vec![("ready", Value::Bool(true))]);

        let transitions = vec![
            transition!("no_guard", "some condition", guard: None, intent: None, priority: 10),
            transition!(
                "with_guard",
                "",
                guard: Some(expression_guard("{{ context.ready }}")),
                intent: None,
                priority: 5
            ),
        ];

        let result = evaluator.evaluate_guards(&transitions, &ctx);
        assert_eq!(result, Some(1));
    }

    #[test]
    fn test_context_matcher_exists() {
        let ctx = start_ctx(vec![("name", Value::String("Alice".into()))]);

        let guard = context_guard(vec![
            ("name", ContextMatcher::Exists { exists: true }),
            ("email", ContextMatcher::Exists { exists: false }),
        ]);
        assert!(evaluate_guard(&guard, &ctx));
    }

    #[test]
    fn test_compare_contains() {
        let ctx = start_ctx(vec![
            ("message", Value::String("hello world".into())),
            ("tags", serde_json::json!(["urgent", "support"])),
        ]);

        // Substring match on a string value.
        let guard1 = context_guard(vec![(
            "message",
            ContextMatcher::Compare(CompareOp::Contains("world".into())),
        )]);
        assert!(evaluate_guard(&guard1, &ctx));

        // Element match on an array of strings.
        let guard2 = context_guard(vec![(
            "tags",
            ContextMatcher::Compare(CompareOp::Contains("urgent".into())),
        )]);
        assert!(evaluate_guard(&guard2, &ctx));
    }

    #[test]
    fn test_compare_in() {
        let ctx = start_ctx(vec![("tier", Value::String("premium".into()))]);

        let guard = context_guard(vec![(
            "tier",
            ContextMatcher::Compare(CompareOp::In(vec![
                Value::String("premium".into()),
                Value::String("enterprise".into()),
            ])),
        )]);
        assert!(evaluate_guard(&guard, &ctx));
    }

    /// When `resolved_intent` in the context matches a transition's `intent`
    /// field, routing is deterministic - no LLM call.
    #[tokio::test]
    async fn test_intent_based_routing_deterministic() {
        // The mock has no responses queued, so the test relies on the LLM not
        // being consulted (presumably a call would fail - confirm against the mock).
        let evaluator = evaluator_with_replies("intent_test", &[]);

        let transitions = vec![
            transition!(
                "cancel_order",
                "User wants to cancel an order",
                guard: None,
                intent: Some("cancel_order".into()),
                priority: 10
            ),
            transition!(
                "cancel_reservation",
                "User wants to cancel a reservation",
                guard: None,
                intent: Some("cancel_reservation".into()),
                priority: 10
            ),
            transition!(
                "cancel_subscription",
                "User wants to cancel a subscription",
                guard: None,
                intent: Some("cancel_subscription".into()),
                priority: 10
            ),
        ];

        // Simulate disambiguation having resolved the intent to "cancel_reservation".
        let ctx = TransitionContext::new("あれキャンセルして", "", "greeting").with_context(
            context_of(vec![(
                "resolved_intent",
                Value::String("cancel_reservation".into()),
            )]),
        );

        let result = evaluator
            .select_transition(&transitions, &ctx)
            .await
            .unwrap();
        // Index 1 (cancel_reservation) is picked deterministically.
        assert_eq!(result, Some(1));
    }

    /// When no `resolved_intent` is present in the context, the evaluator
    /// falls back to LLM-based `when` evaluation.
    #[tokio::test]
    async fn test_intent_routing_falls_back_to_llm_when_no_resolved_intent() {
        // The LLM answers "1", i.e. the first LLM-evaluated transition (cancel_order).
        let evaluator = evaluator_with_replies("intent_fallback_test", &["1"]);
        let transitions = cancel_order_and_reservation();

        // No resolved_intent in the context, so LLM evaluation runs.
        let ctx = TransitionContext::new("Cancel order ORD-1042", "", "greeting")
            .with_context(HashMap::new());

        let result = evaluator
            .select_transition(&transitions, &ctx)
            .await
            .unwrap();
        // Answer "1" maps to index 0.
        assert_eq!(result, Some(0));
    }

    /// When `resolved_intent` does not match any transition's `intent` field,
    /// it falls through to LLM evaluation.
    #[tokio::test]
    async fn test_no_routing_when_resolved_intent_doesnt_match() {
        // The LLM answers "0" (none of the above).
        let evaluator = evaluator_with_replies("intent_nomatch_test", &["0"]);
        let transitions = cancel_order_and_reservation();

        // "something_else" matches no transition's intent.
        let ctx = TransitionContext::new("do something", "", "greeting").with_context(
            context_of(vec![(
                "resolved_intent",
                Value::String("something_else".into()),
            )]),
        );

        let result = evaluator
            .select_transition(&transitions, &ctx)
            .await
            .unwrap();
        // Answer "0" selects nothing.
        assert_eq!(result, None);
    }

    /// A null `resolved_intent` (used to clear stale context) should be
    /// ignored - not treated as a valid intent.
    #[tokio::test]
    async fn test_null_resolved_intent_is_ignored() {
        let evaluator = evaluator_with_replies("intent_null_test", &["1"]);

        let transitions = vec![transition!(
            "cancel_order",
            "User wants to cancel an order",
            guard: None,
            intent: Some("cancel_order".into()),
            priority: 10
        )];

        // resolved_intent is Null (simulating clear_disambiguation_context).
        let ctx = TransitionContext::new("Cancel my order", "", "greeting")
            .with_context(context_of(vec![("resolved_intent", Value::Null)]));

        let result = evaluator
            .select_transition(&transitions, &ctx)
            .await
            .unwrap();
        // The null is skipped; the LLM's "1" maps to index 0.
        assert_eq!(result, Some(0));
    }
}