use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::Value;
use ai_agents_core::{ChatMessage, LLMProvider, Result};
use super::config::{CompareOp, ContextMatcher, GuardConditions, Transition, TransitionGuard};
/// Snapshot of a single conversational turn, handed to transition evaluators.
pub struct TransitionContext {
    /// The user's message for this turn.
    pub user_message: String,
    /// The assistant's reply for this turn (may be empty).
    pub assistant_response: String,
    /// Name of the state the machine is currently in.
    pub current_state: String,
    /// Key/value data read by guard expressions and context matchers.
    pub context: HashMap<String, Value>,
}

impl TransitionContext {
    /// Creates a context for one turn with an empty context map.
    pub fn new(user_message: &str, assistant_response: &str, current_state: &str) -> Self {
        let (user_message, assistant_response, current_state) = (
            user_message.to_owned(),
            assistant_response.to_owned(),
            current_state.to_owned(),
        );
        Self {
            user_message,
            assistant_response,
            current_state,
            context: HashMap::default(),
        }
    }

    /// Builder-style setter: replaces the context map and returns the updated value.
    pub fn with_context(self, context: HashMap<String, Value>) -> Self {
        Self { context, ..self }
    }
}
/// Strategy for choosing which outgoing transition, if any, to take after a turn.
#[async_trait]
pub trait TransitionEvaluator: Send + Sync {
/// Returns the index into `transitions` of the selected transition, or `None` when none applies.
///
/// # Errors
/// Implementations may fail, e.g. when a backing model call errors.
async fn select_transition(
&self,
transitions: &[Transition],
context: &TransitionContext,
) -> Result<Option<usize>>;
}
/// Transition evaluator that tries guards and resolved intents first and falls back to
/// asking an LLM which natural-language `when` condition is met.
pub struct LLMTransitionEvaluator {
    /// Model queried when no guard or resolved intent selects a transition.
    llm: Arc<dyn LLMProvider>,
}

impl LLMTransitionEvaluator {
    /// Wraps the given LLM provider.
    pub fn new(llm: Arc<dyn LLMProvider>) -> Self {
        LLMTransitionEvaluator { llm }
    }
}
pub fn evaluate_guard(guard: &TransitionGuard, ctx: &TransitionContext) -> bool {
match guard {
TransitionGuard::Expression(expr) => evaluate_expression(expr, ctx),
TransitionGuard::Conditions(conditions) => evaluate_conditions(conditions, ctx),
}
}
/// Evaluates a string guard expression.
///
/// - Text containing `{{` is treated as a template. The surrounding `{{` / `}}` are stripped
///   and the inner expression is evaluated by `evaluate_simple_expression`.
/// - Any other text is a literal. It is false when empty (after trimming) or when it is
///   `false` in any ASCII case. Every other non-empty literal is true.
pub fn evaluate_expression(expr: &str, ctx: &TransitionContext) -> bool {
    let expr = expr.trim();
    if !expr.contains("{{") {
        // A literal `false` must not count as true just because it is non-empty.
        return !expr.is_empty() && !expr.eq_ignore_ascii_case("false");
    }
    let inner = expr.trim_start_matches("{{").trim_end_matches("}}").trim();
    evaluate_simple_expression(inner, ctx)
}
/// Evaluates the text found between `{{` and `}}` in a guard expression.
///
/// Supported forms; anything else evaluates to `false`:
/// - The literals `true` and `false`.
/// - `<operand> <op> <literal>`, with `<op>` one of `==`, `!=`, `>=`, `<=`, `>`, `<`.
///   `<operand>` is a `context.<path>` or `state.<field>` reference resolved by
///   `resolve_value`.
///   - Ordering operators need a numeric value and a numeric literal.
///   - `==` / `!=` accept a quoted string, `true` / `false`, a number, or a bare word
///     (compared as a string). Numbers are compared numerically, so an integer `5` equals
///     the literal `5`.
///   - If the operand cannot be resolved, the comparison is `false`, including for `!=`.
/// - Bare `context.<path>`: `true` when the value exists and is neither `null` nor `false`.
/// - Bare `state.<field>`: delegated to `evaluate_state_expression`.
fn evaluate_simple_expression(expr: &str, ctx: &TransitionContext) -> bool {
    match expr {
        "true" => return true,
        "false" => return false,
        _ => {}
    }
    // Comparisons must be detected before the bare-path forms. `context.count > 3` also
    // starts with `context.`, and looking it up as a path would always fail.
    if let Some((left, op, right)) = split_comparison(expr) {
        return evaluate_comparison(left.trim(), op, right.trim(), ctx);
    }
    if let Some(path) = expr.strip_prefix("context.") {
        // An explicit `false` or `null` flag must not satisfy the guard.
        return !matches!(
            get_context_value(path, &ctx.context),
            None | Some(Value::Null) | Some(Value::Bool(false))
        );
    }
    if let Some(field) = expr.strip_prefix("state.") {
        return evaluate_state_expression(field, ctx);
    }
    false
}

/// Comparison operators recognised inside guard expressions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ExprOp {
    Eq,
    Ne,
    Ge,
    Le,
    Gt,
    Lt,
}

/// Operator spellings. The two-character forms come first so that `>=` wins over `>`
/// when both start at the same position.
const EXPR_OPERATORS: [(&str, ExprOp); 6] = [
    ("==", ExprOp::Eq),
    ("!=", ExprOp::Ne),
    (">=", ExprOp::Ge),
    ("<=", ExprOp::Le),
    (">", ExprOp::Gt),
    ("<", ExprOp::Lt),
];

/// Splits `expr` around its earliest comparison operator, returning `(left, op, right)`.
fn split_comparison(expr: &str) -> Option<(&str, ExprOp, &str)> {
    // `char_indices` keeps every slice on a UTF-8 boundary.
    expr.char_indices().find_map(|(pos, _)| {
        let rest = &expr[pos..];
        EXPR_OPERATORS
            .iter()
            .find(|(spelling, _)| rest.starts_with(*spelling))
            .map(|&(spelling, op)| (&expr[..pos], op, &rest[spelling.len()..]))
    })
}

/// Evaluates `left op right`, where `left` is a `context.` / `state.` reference and
/// `right` is a literal.
fn evaluate_comparison(left: &str, op: ExprOp, right: &str, ctx: &TransitionContext) -> bool {
    let Some(actual) = resolve_value(left, ctx) else {
        return false;
    };
    let numbers = || -> Option<(f64, f64)> {
        Some((actual.as_f64()?, right.parse::<f64>().ok()?))
    };
    match op {
        ExprOp::Eq => literal_equals(&actual, right),
        ExprOp::Ne => !literal_equals(&actual, right),
        ExprOp::Gt => matches!(numbers(), Some((lhs, rhs)) if lhs > rhs),
        ExprOp::Ge => matches!(numbers(), Some((lhs, rhs)) if lhs >= rhs),
        ExprOp::Lt => matches!(numbers(), Some((lhs, rhs)) if lhs < rhs),
        ExprOp::Le => matches!(numbers(), Some((lhs, rhs)) if lhs <= rhs),
    }
}

/// Compares a resolved value against the literal on the right of `==` / `!=`.
fn literal_equals(actual: &Value, literal: &str) -> bool {
    // `strip_prefix` + `strip_suffix` cannot panic on a lone `"`, unlike manual slicing.
    if let Some(text) = literal.strip_prefix('"').and_then(|s| s.strip_suffix('"')) {
        return actual.as_str() == Some(text);
    }
    match literal {
        "true" => return *actual == Value::Bool(true),
        "false" => return *actual == Value::Bool(false),
        _ => {}
    }
    if let Ok(number) = literal.parse::<f64>() {
        // Compare numerically: serde_json treats integer 5 and float 5.0 as unequal.
        return actual.as_f64() == Some(number);
    }
    actual.as_str() == Some(literal)
}
/// Resolves a `context.<path>` or `state.<field>` reference to a value.
///
/// Surrounding whitespace is ignored. Returns `None` for any other form or when the
/// referenced value is unavailable.
fn resolve_value(expr: &str, ctx: &TransitionContext) -> Option<Value> {
    let reference = expr.trim();
    if let Some(path) = reference.strip_prefix("context.") {
        get_context_value(path, &ctx.context)
    } else if let Some(field) = reference.strip_prefix("state.") {
        get_state_value(field, ctx)
    } else {
        None
    }
}
/// Truth value of a bare `state.<field>` expression.
///
/// Only `turn_count` is recognised, and it is unconditionally true.
/// NOTE(review): `ctx` is unused, so this looks like a placeholder — confirm intended semantics.
fn evaluate_state_expression(field: &str, _ctx: &TransitionContext) -> bool {
    matches!(field, "turn_count")
}
/// Value of a `state.<field>` reference.
///
/// Only `current` (the current state name, as a string) is supported.
fn get_state_value(field: &str, ctx: &TransitionContext) -> Option<Value> {
    (field == "current").then(|| Value::String(ctx.current_state.clone()))
}
/// Evaluates grouped guard conditions.
///
/// - `Context`: every context matcher must match.
/// - `All`: every expression must hold (vacuously true when empty).
/// - `Any`: at least one expression must hold (false when empty).
pub fn evaluate_conditions(conditions: &GuardConditions, ctx: &TransitionContext) -> bool {
    match conditions {
        GuardConditions::Context(matchers) => evaluate_context_matchers(matchers, &ctx.context),
        GuardConditions::All(expressions) => expressions
            .iter()
            .all(|expression| evaluate_expression(expression, ctx)),
        GuardConditions::Any(expressions) => expressions
            .iter()
            .any(|expression| evaluate_expression(expression, ctx)),
    }
}
/// Returns `true` when every `(path, matcher)` pair matches the value found at that path.
///
/// An empty matcher map is vacuously true. Evaluation stops at the first failing matcher.
pub fn evaluate_context_matchers(
    matchers: &HashMap<String, ContextMatcher>,
    context: &HashMap<String, Value>,
) -> bool {
    matchers.iter().all(|(path, matcher)| {
        let found = get_context_value(path, context);
        match_value(found.as_ref(), matcher)
    })
}
pub fn get_context_value(path: &str, context: &HashMap<String, Value>) -> Option<Value> {
ai_agents_core::get_dot_path_from_map(context, path)
}
/// Tests an optional context value against a matcher.
///
/// - `Exact`: the value is present and structurally equal to the expected value.
/// - `Exists { exists }`: presence must equal `exists`; a `null` value counts as absent.
/// - `Compare(op)`: the value is present and `compare_value` accepts it.
pub fn match_value(value: Option<&Value>, matcher: &ContextMatcher) -> bool {
    match (matcher, value) {
        (ContextMatcher::Exact(expected), found) => found == Some(expected),
        (ContextMatcher::Exists { exists }, found) => {
            let present = matches!(found, Some(v) if !v.is_null());
            present == *exists
        }
        (ContextMatcher::Compare(op), Some(found)) => compare_value(found, op),
        (ContextMatcher::Compare(_), None) => false,
    }
}
/// Equality used by `CompareOp::Eq` / `CompareOp::Neq`, tolerant of common type mismatches.
///
/// Two values are equal when any of these hold:
/// - they are structurally equal (`==`);
/// - both are numbers whose `f64` values differ by less than `f64::EPSILON`, so an integer
///   `5` equals a float `5.0` (serde_json's own `==` says they differ);
/// - one is a string and the other a bool, and the string is exactly `"true"` or `"false"`
///   matching the bool;
/// - one is a string that parses as `f64` and the other a number with the same value
///   (within `f64::EPSILON`).
fn values_equal_coerced(value: &Value, expected: &Value) -> bool {
    if value == expected {
        return true;
    }
    match (value, expected) {
        (Value::Number(lhs), Value::Number(rhs)) => match (lhs.as_f64(), rhs.as_f64()) {
            (Some(lhs), Some(rhs)) => nearly_equal_f64(lhs, rhs),
            _ => false,
        },
        // String coercion is symmetric: either side may be the string.
        (Value::String(text), other) | (other, Value::String(text)) => {
            string_coerces_to(text, other)
        }
        _ => false,
    }
}

/// Whether `text` is the textual spelling of the bool or number `other`.
fn string_coerces_to(text: &str, other: &Value) -> bool {
    match other {
        Value::Bool(flag) => match text {
            "true" => *flag,
            "false" => !*flag,
            _ => false,
        },
        Value::Number(number) => match (text.parse::<f64>(), number.as_f64()) {
            (Ok(parsed), Some(number)) => nearly_equal_f64(parsed, number),
            _ => false,
        },
        _ => false,
    }
}

/// Absolute-tolerance float equality (difference below `f64::EPSILON`).
fn nearly_equal_f64(a: f64, b: f64) -> bool {
    (a - b).abs() < f64::EPSILON
}
/// Applies a structured comparison operator to a context value.
///
/// - `Eq` / `Neq`: coerced equality via `values_equal_coerced`.
/// - `Gt` / `Gte` / `Lt` / `Lte`: numeric comparison; non-numeric values never match.
/// - `In`: true when the value equals any candidate under the same coerced equality as `Eq`.
///   So `in: [a, b]` behaves exactly like `eq: a` or `eq: b`.
/// - `Contains`: substring test for string values, or membership of the string in an
///   array of strings; anything else never matches.
pub fn compare_value(value: &Value, op: &CompareOp) -> bool {
    let number = value.as_f64();
    match op {
        CompareOp::Eq(expected) => values_equal_coerced(value, expected),
        CompareOp::Neq(expected) => !values_equal_coerced(value, expected),
        CompareOp::Gt(n) => matches!(number, Some(v) if v > *n),
        CompareOp::Gte(n) => matches!(number, Some(v) if v >= *n),
        CompareOp::Lt(n) => matches!(number, Some(v) if v < *n),
        CompareOp::Lte(n) => matches!(number, Some(v) if v <= *n),
        // Same equality as `Eq`, so list membership cannot disagree with a single `eq`.
        CompareOp::In(values) => values
            .iter()
            .any(|candidate| values_equal_coerced(value, candidate)),
        CompareOp::Contains(s) => value
            .as_str()
            .map(|v| v.contains(s))
            .or_else(|| {
                value
                    .as_array()
                    .map(|arr| arr.iter().any(|v| v.as_str() == Some(s)))
            })
            .unwrap_or(false),
    }
}
#[async_trait]
impl TransitionEvaluator for LLMTransitionEvaluator {
async fn select_transition(
&self,
transitions: &[Transition],
context: &TransitionContext,
) -> Result<Option<usize>> {
if transitions.is_empty() {
return Ok(None);
}
for (i, transition) in transitions.iter().enumerate() {
if let Some(ref guard) = transition.guard {
if evaluate_guard(guard, context) {
return Ok(Some(i));
}
}
}
if let Some(resolved) = context.context.get("resolved_intent") {
if let Some(resolved_str) = resolved.as_str() {
if !resolved_str.is_empty() {
for (i, transition) in transitions.iter().enumerate() {
if let Some(ref intent) = transition.intent {
if intent == resolved_str {
tracing::debug!(
resolved_intent = resolved_str,
target = %transition.to,
"Deterministic routing via resolved_intent"
);
return Ok(Some(i));
}
}
}
}
}
}
let llm_transitions: Vec<(usize, &Transition)> = transitions
.iter()
.enumerate()
.filter(|(_, t)| !t.when.is_empty() && t.guard.is_none())
.collect();
if llm_transitions.is_empty() {
return Ok(None);
}
let conditions: Vec<String> = llm_transitions
.iter()
.enumerate()
.map(|(display_idx, (_, t))| format!("{}. {}", display_idx + 1, t.when))
.collect();
let prompt = format!(
r#"Based on the conversation, which condition is met?
Current state: {}
User message: {}
Assistant response: {}
Conditions:
{}
0. None of the above
Reply with ONLY the number (0-{})."#,
context.current_state,
context.user_message,
context.assistant_response,
conditions.join("\n"),
llm_transitions.len()
);
let messages = vec![ChatMessage::user(&prompt)];
let response = self.llm.complete(&messages, None).await?;
let choice: usize = response.content.trim().parse().unwrap_or(0);
if choice == 0 || choice > llm_transitions.len() {
Ok(None)
} else {
Ok(Some(llm_transitions[choice - 1].0))
}
}
}
pub struct GuardOnlyEvaluator;
impl GuardOnlyEvaluator {
pub fn new() -> Self {
Self
}
pub fn evaluate_guard(&self, guard: &TransitionGuard, ctx: &TransitionContext) -> bool {
evaluate_guard(guard, ctx)
}
pub fn evaluate_guards(
&self,
transitions: &[Transition],
ctx: &TransitionContext,
) -> Option<usize> {
for (i, transition) in transitions.iter().enumerate() {
if let Some(ref guard) = transition.guard {
if evaluate_guard(guard, ctx) {
return Some(i);
}
}
}
None
}
}
impl Default for GuardOnlyEvaluator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::super::config::TransitionTiming;
    use super::*;
    use ai_agents_core::{FinishReason, LLMResponse};
    use ai_agents_llm::mock::MockLLMProvider;

    /// Baseline transition: auto, post-response, priority 0, no guard/intent/cooldown.
    /// Tests override individual fields with struct-update syntax.
    fn transition(to: &str, when: &str) -> Transition {
        Transition {
            to: to.into(),
            when: when.into(),
            guard: None,
            intent: None,
            auto: true,
            priority: 0,
            cooldown_turns: None,
            timing: TransitionTiming::PostResponse,
            requires_response: false,
            run_extractors: false,
        }
    }

    /// Priority-10 transition routed by `intent`.
    fn intent_transition(to: &str, when: &str, intent: &str) -> Transition {
        Transition {
            intent: Some(intent.into()),
            priority: 10,
            ..transition(to, when)
        }
    }

    /// Priority-5, guard-only transition (empty `when`) with a template expression guard.
    fn guarded_transition(to: &str, expression: &str) -> Transition {
        Transition {
            guard: Some(TransitionGuard::Expression(expression.into())),
            priority: 5,
            ..transition(to, "")
        }
    }

    /// Context for a `"hi"` / `"hello"` turn in state `"start"` with the given entries.
    fn start_ctx(entries: Vec<(&str, Value)>) -> TransitionContext {
        let context = entries
            .into_iter()
            .map(|(key, value)| (key.to_string(), value))
            .collect();
        TransitionContext::new("hi", "hello", "start").with_context(context)
    }

    /// Guard made of context matchers keyed by dot path.
    fn context_guard(matchers: Vec<(&str, ContextMatcher)>) -> TransitionGuard {
        let matchers = matchers
            .into_iter()
            .map(|(path, matcher)| (path.to_string(), matcher))
            .collect();
        TransitionGuard::Conditions(GuardConditions::Context(matchers))
    }

    /// Evaluator backed by a mock LLM that returns `replies` in order.
    fn evaluator_with(name: &str, replies: &[&str]) -> LLMTransitionEvaluator {
        let mut mock = MockLLMProvider::new(name);
        for reply in replies {
            mock.add_response(LLMResponse::new(*reply, FinishReason::Stop));
        }
        LLMTransitionEvaluator::new(Arc::new(mock))
    }

    #[tokio::test]
    async fn test_select_transition_none() {
        let evaluator = evaluator_with("evaluator_test", &["0"]);
        let transitions = vec![transition("next", "user says goodbye")];
        let context = TransitionContext::new("hello", "hi there", "greeting");
        let result = evaluator.select_transition(&transitions, &context).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_select_transition_match() {
        let evaluator = evaluator_with("evaluator_test", &["1"]);
        let transitions = vec![
            Transition {
                priority: 10,
                ..transition("support", "user needs help")
            },
            Transition {
                priority: 5,
                ..transition("sales", "user wants to buy")
            },
        ];
        let context = TransitionContext::new("I need help", "Sure!", "greeting");
        let result = evaluator.select_transition(&transitions, &context).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), Some(0));
    }

    #[tokio::test]
    async fn test_empty_transitions() {
        let evaluator = evaluator_with("evaluator_test", &[]);
        let context = TransitionContext::new("hi", "hello", "start");
        let result = evaluator.select_transition(&[], &context).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn test_guard_expression_simple() {
        let ctx = start_ctx(vec![("has_data", Value::Bool(true))]);
        let guard = TransitionGuard::Expression("{{ context.has_data }}".into());
        assert!(evaluate_guard(&guard, &ctx));
    }

    #[test]
    fn test_guard_expression_missing() {
        let ctx = start_ctx(vec![]);
        let guard = TransitionGuard::Expression("{{ context.has_data }}".into());
        assert!(!evaluate_guard(&guard, &ctx));
    }

    #[test]
    fn test_guard_with_nested_context() {
        let ctx = start_ctx(vec![(
            "user",
            serde_json::json!({
                "name": "Alice",
                "verified": true
            }),
        )]);
        let guard = TransitionGuard::Expression("{{ context.user.verified }}".into());
        assert!(evaluate_guard(&guard, &ctx));
    }

    #[test]
    fn test_guard_conditions_all() {
        let ctx = start_ctx(vec![
            ("has_name", Value::Bool(true)),
            ("has_email", Value::Bool(true)),
        ]);
        let guard = TransitionGuard::Conditions(GuardConditions::All(vec![
            "{{ context.has_name }}".into(),
            "{{ context.has_email }}".into(),
        ]));
        assert!(evaluate_guard(&guard, &ctx));
    }

    #[test]
    fn test_guard_conditions_any() {
        let ctx = start_ctx(vec![("is_vip", Value::Bool(true))]);
        let guard = TransitionGuard::Conditions(GuardConditions::Any(vec![
            "{{ context.is_admin }}".into(),
            "{{ context.is_vip }}".into(),
        ]));
        assert!(evaluate_guard(&guard, &ctx));
    }

    #[test]
    fn test_guard_context_matchers() {
        let ctx = start_ctx(vec![(
            "user",
            serde_json::json!({
                "tier": "premium",
                "balance": 100.0
            }),
        )]);
        let guard = context_guard(vec![
            (
                "user.tier",
                ContextMatcher::Exact(Value::String("premium".into())),
            ),
            (
                "user.balance",
                ContextMatcher::Compare(CompareOp::Gte(50.0)),
            ),
        ]);
        assert!(evaluate_guard(&guard, &ctx));
    }

    #[tokio::test]
    async fn test_guard_priority_over_llm() {
        let evaluator = evaluator_with("guard_test", &[]);
        let ctx = start_ctx(vec![("ready", Value::Bool(true))]);
        let transitions = vec![
            Transition {
                priority: 10,
                ..transition("llm_based", "user wants to proceed")
            },
            guarded_transition("guard_based", "{{ context.ready }}"),
        ];
        let result = evaluator.select_transition(&transitions, &ctx).await;
        assert_eq!(result.unwrap(), Some(1));
    }

    #[test]
    fn test_guard_only_evaluator() {
        let evaluator = GuardOnlyEvaluator::new();
        let ctx = start_ctx(vec![("ready", Value::Bool(true))]);
        let transitions = vec![
            Transition {
                priority: 10,
                ..transition("no_guard", "some condition")
            },
            guarded_transition("with_guard", "{{ context.ready }}"),
        ];
        let result = evaluator.evaluate_guards(&transitions, &ctx);
        assert_eq!(result, Some(1));
    }

    #[test]
    fn test_context_matcher_exists() {
        let ctx = start_ctx(vec![("name", Value::String("Alice".into()))]);
        let guard = context_guard(vec![
            ("name", ContextMatcher::Exists { exists: true }),
            ("email", ContextMatcher::Exists { exists: false }),
        ]);
        assert!(evaluate_guard(&guard, &ctx));
    }

    #[test]
    fn test_compare_contains() {
        let ctx = start_ctx(vec![
            ("message", Value::String("hello world".into())),
            ("tags", serde_json::json!(["urgent", "support"])),
        ]);
        let substring_guard = context_guard(vec![(
            "message",
            ContextMatcher::Compare(CompareOp::Contains("world".into())),
        )]);
        assert!(evaluate_guard(&substring_guard, &ctx));
        let membership_guard = context_guard(vec![(
            "tags",
            ContextMatcher::Compare(CompareOp::Contains("urgent".into())),
        )]);
        assert!(evaluate_guard(&membership_guard, &ctx));
    }

    #[test]
    fn test_compare_in() {
        let ctx = start_ctx(vec![("tier", Value::String("premium".into()))]);
        let guard = context_guard(vec![(
            "tier",
            ContextMatcher::Compare(CompareOp::In(vec![
                Value::String("premium".into()),
                Value::String("enterprise".into()),
            ])),
        )]);
        assert!(evaluate_guard(&guard, &ctx));
    }

    #[tokio::test]
    async fn test_intent_based_routing_deterministic() {
        let evaluator = evaluator_with("intent_test", &[]);
        let transitions = vec![
            intent_transition(
                "cancel_order",
                "User wants to cancel an order",
                "cancel_order",
            ),
            intent_transition(
                "cancel_reservation",
                "User wants to cancel a reservation",
                "cancel_reservation",
            ),
            intent_transition(
                "cancel_subscription",
                "User wants to cancel a subscription",
                "cancel_subscription",
            ),
        ];
        let ctx = TransitionContext::new("あれキャンセルして", "", "greeting").with_context(
            HashMap::from([(
                "resolved_intent".to_string(),
                Value::String("cancel_reservation".into()),
            )]),
        );
        let result = evaluator
            .select_transition(&transitions, &ctx)
            .await
            .unwrap();
        assert_eq!(result, Some(1));
    }

    #[tokio::test]
    async fn test_intent_routing_falls_back_to_llm_when_no_resolved_intent() {
        let evaluator = evaluator_with("intent_fallback_test", &["1"]);
        let transitions = vec![
            intent_transition(
                "cancel_order",
                "User wants to cancel an order",
                "cancel_order",
            ),
            intent_transition(
                "cancel_reservation",
                "User wants to cancel a reservation",
                "cancel_reservation",
            ),
        ];
        let ctx = TransitionContext::new("Cancel order ORD-1042", "", "greeting")
            .with_context(HashMap::new());
        let result = evaluator
            .select_transition(&transitions, &ctx)
            .await
            .unwrap();
        assert_eq!(result, Some(0));
    }

    #[tokio::test]
    async fn test_no_routing_when_resolved_intent_doesnt_match() {
        let evaluator = evaluator_with("intent_nomatch_test", &["0"]);
        let transitions = vec![
            intent_transition(
                "cancel_order",
                "User wants to cancel an order",
                "cancel_order",
            ),
            intent_transition(
                "cancel_reservation",
                "User wants to cancel a reservation",
                "cancel_reservation",
            ),
        ];
        let ctx = TransitionContext::new("do something", "", "greeting").with_context(
            HashMap::from([(
                "resolved_intent".to_string(),
                Value::String("something_else".into()),
            )]),
        );
        let result = evaluator
            .select_transition(&transitions, &ctx)
            .await
            .unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_null_resolved_intent_is_ignored() {
        let evaluator = evaluator_with("intent_null_test", &["1"]);
        let transitions = vec![intent_transition(
            "cancel_order",
            "User wants to cancel an order",
            "cancel_order",
        )];
        let ctx = TransitionContext::new("Cancel my order", "", "greeting")
            .with_context(HashMap::from([("resolved_intent".to_string(), Value::Null)]));
        let result = evaluator
            .select_transition(&transitions, &ctx)
            .await
            .unwrap();
        assert_eq!(result, Some(0));
    }
}