1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use chrono::{Datelike, Local, Timelike, Utc};
6use serde::Deserialize;
7use serde_json::Value;
8
9use ai_agents_core::{ChatMessage, LLMProvider, Result};
10use ai_agents_state::{CompareOp, ContextMatcher, StateMatcher, TimeMatcher, ToolCondition};
11
/// Record of a single tool invocation, consulted by `AfterTool` and
/// `ToolResult` conditions.
#[derive(Debug, Clone)]
pub struct ToolCallRecord {
    /// Identifier of the tool that was called; matched against the tool id
    /// named in a condition.
    pub tool_id: String,
    /// JSON result returned by the tool. `ToolResult` conditions only match
    /// when this is a JSON object.
    pub result: Value,
    /// Timestamp of the call (UTC). Not read by the evaluator in this file.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
18
/// Snapshot of agent state that a `ToolCondition` is evaluated against.
#[derive(Debug, Clone)]
pub struct EvaluationContext {
    /// Context variables; `Context` conditions address them by
    /// dot-separated path (e.g. `"user.profile.tier"`).
    pub context: HashMap<String, Value>,
    /// Name of the current state, if any.
    pub state_name: Option<String>,
    /// Turn counter for the current state; compared by `State` conditions.
    pub state_turn_count: u32,
    /// Name of the previous state, if any.
    pub previous_state: Option<String>,
    /// Tool calls made so far. `ToolResult` conditions inspect the last
    /// matching entry in this list.
    pub called_tools: Vec<ToolCallRecord>,
    /// Conversation messages supplied to the LLM for `Semantic` conditions.
    pub recent_messages: Vec<ChatMessage>,
}
28
29impl Default for EvaluationContext {
30 fn default() -> Self {
31 Self {
32 context: HashMap::new(),
33 state_name: None,
34 state_turn_count: 0,
35 previous_state: None,
36 called_tools: Vec::new(),
37 recent_messages: Vec::new(),
38 }
39 }
40}
41
42impl EvaluationContext {
43 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn with_context(mut self, context: HashMap<String, Value>) -> Self {
48 self.context = context;
49 self
50 }
51
52 pub fn with_state(
53 mut self,
54 name: Option<String>,
55 turn_count: u32,
56 previous: Option<String>,
57 ) -> Self {
58 self.state_name = name;
59 self.state_turn_count = turn_count;
60 self.previous_state = previous;
61 self
62 }
63
64 pub fn with_called_tools(mut self, tools: Vec<ToolCallRecord>) -> Self {
65 self.called_tools = tools;
66 self
67 }
68
69 pub fn with_messages(mut self, messages: Vec<ChatMessage>) -> Self {
70 self.recent_messages = messages;
71 self
72 }
73}
74
/// Resolves the LLM alias named in a `Semantic` condition to a provider.
///
/// The trait declares no `async fn`, so `#[async_trait]` has nothing to
/// rewrite here.
#[async_trait]
pub trait LLMGetter: Send + Sync {
    /// Returns the provider registered under `alias`, or `None` if there is
    /// no such alias.
    fn get_llm(&self, alias: &str) -> Option<Arc<dyn LLMProvider>>;
}
79
/// Evaluates `ToolCondition`s against an [`EvaluationContext`], using `G` to
/// look up the LLMs needed by `Semantic` conditions.
pub struct ConditionEvaluator<G: LLMGetter> {
    // Source of LLM providers, keyed by alias.
    llm_getter: G,
}
83
84impl<G: LLMGetter> ConditionEvaluator<G> {
85 pub fn new(llm_getter: G) -> Self {
86 Self { llm_getter }
87 }
88
89 pub async fn evaluate(
90 &self,
91 condition: &ToolCondition,
92 ctx: &EvaluationContext,
93 ) -> Result<bool> {
94 match condition {
95 ToolCondition::Context(matchers) => Ok(self.evaluate_context(matchers, &ctx.context)),
96 ToolCondition::State(matcher) => Ok(self.evaluate_state(matcher, ctx)),
97 ToolCondition::AfterTool(tool_id) => {
98 Ok(ctx.called_tools.iter().any(|t| &t.tool_id == tool_id))
99 }
100 ToolCondition::ToolResult { tool, result } => {
101 Ok(self.evaluate_tool_result(tool, result, &ctx.called_tools))
102 }
103 ToolCondition::Semantic {
104 when,
105 llm,
106 threshold,
107 } => self.evaluate_semantic(when, llm, *threshold, ctx).await,
108 ToolCondition::Time(matcher) => Ok(self.evaluate_time(matcher)),
109 ToolCondition::All(conditions) => {
110 for cond in conditions {
111 if !Box::pin(self.evaluate(cond, ctx)).await? {
112 return Ok(false);
113 }
114 }
115 Ok(true)
116 }
117 ToolCondition::Any(conditions) => {
118 for cond in conditions {
119 if Box::pin(self.evaluate(cond, ctx)).await? {
120 return Ok(true);
121 }
122 }
123 Ok(false)
124 }
125 ToolCondition::Not(inner) => Ok(!Box::pin(self.evaluate(inner, ctx)).await?),
126 }
127 }
128
129 fn evaluate_context(
130 &self,
131 matchers: &HashMap<String, ContextMatcher>,
132 context: &HashMap<String, Value>,
133 ) -> bool {
134 for (path, matcher) in matchers {
135 let value = self.get_context_value(path, context);
136 if !self.match_value(value.as_ref(), matcher) {
137 return false;
138 }
139 }
140 true
141 }
142
143 fn get_context_value(&self, path: &str, context: &HashMap<String, Value>) -> Option<Value> {
144 ai_agents_core::get_dot_path_from_map(context, path)
145 }
146
147 fn match_value(&self, value: Option<&Value>, matcher: &ContextMatcher) -> bool {
148 match matcher {
149 ContextMatcher::Exact(expected) => value.map(|v| v == expected).unwrap_or(false),
150 ContextMatcher::Exists { exists } => {
151 let has_value = value.is_some() && value != Some(&Value::Null);
152 *exists == has_value
153 }
154 ContextMatcher::Compare(op) => {
155 let Some(val) = value else {
156 return false;
157 };
158 self.compare_value(val, op)
159 }
160 }
161 }
162
163 fn compare_value(&self, value: &Value, op: &CompareOp) -> bool {
164 match op {
165 CompareOp::Eq(expected) => value == expected,
166 CompareOp::Neq(expected) => value != expected,
167 CompareOp::Gt(n) => value.as_f64().map(|v| v > *n).unwrap_or(false),
168 CompareOp::Gte(n) => value.as_f64().map(|v| v >= *n).unwrap_or(false),
169 CompareOp::Lt(n) => value.as_f64().map(|v| v < *n).unwrap_or(false),
170 CompareOp::Lte(n) => value.as_f64().map(|v| v <= *n).unwrap_or(false),
171 CompareOp::In(values) => values.contains(value),
172 CompareOp::Contains(s) => value
173 .as_str()
174 .map(|v| v.contains(s))
175 .or_else(|| {
176 value
177 .as_array()
178 .map(|arr| arr.iter().any(|v| v.as_str() == Some(s)))
179 })
180 .unwrap_or(false),
181 }
182 }
183
184 fn evaluate_state(&self, matcher: &StateMatcher, ctx: &EvaluationContext) -> bool {
185 if let Some(ref expected_name) = matcher.name {
186 if ctx.state_name.as_ref() != Some(expected_name) {
187 return false;
188 }
189 }
190
191 if let Some(ref turn_op) = matcher.turn_count {
192 let turn_count = ctx.state_turn_count as f64;
193 if !self.compare_value(
194 &Value::Number(serde_json::Number::from_f64(turn_count).unwrap_or(0.into())),
195 turn_op,
196 ) {
197 return false;
198 }
199 }
200
201 if let Some(ref expected_prev) = matcher.previous {
202 if ctx.previous_state.as_ref() != Some(expected_prev) {
203 return false;
204 }
205 }
206
207 true
208 }
209
210 fn evaluate_tool_result(
211 &self,
212 tool: &str,
213 expected: &HashMap<String, Value>,
214 called_tools: &[ToolCallRecord],
215 ) -> bool {
216 let tool_record = called_tools.iter().rev().find(|t| t.tool_id == tool);
217
218 let Some(record) = tool_record else {
219 return false;
220 };
221
222 let result_obj = match &record.result {
223 Value::Object(obj) => obj,
224 _ => return false,
225 };
226
227 for (key, expected_value) in expected {
228 match result_obj.get(key) {
229 Some(actual) if actual == expected_value => continue,
230 _ => return false,
231 }
232 }
233
234 true
235 }
236
237 fn evaluate_time(&self, matcher: &TimeMatcher) -> bool {
238 let now = if let Some(ref tz) = matcher.timezone {
239 if tz == "utc" || tz == "UTC" {
240 Utc::now().with_timezone(&Utc).naive_local()
241 } else {
242 Local::now().naive_local()
243 }
244 } else {
245 Local::now().naive_local()
246 };
247
248 if let Some(ref hours_op) = matcher.hours {
249 let hour = now.hour() as f64;
250 if !self.compare_value(&serde_json::json!(hour), hours_op) {
251 return false;
252 }
253 }
254
255 if let Some(ref days) = matcher.day_of_week {
256 let day_name = match now.weekday() {
257 chrono::Weekday::Mon => "monday",
258 chrono::Weekday::Tue => "tuesday",
259 chrono::Weekday::Wed => "wednesday",
260 chrono::Weekday::Thu => "thursday",
261 chrono::Weekday::Fri => "friday",
262 chrono::Weekday::Sat => "saturday",
263 chrono::Weekday::Sun => "sunday",
264 };
265
266 if !days.iter().any(|d| d.to_lowercase() == day_name) {
267 return false;
268 }
269 }
270
271 true
272 }
273
274 async fn evaluate_semantic(
275 &self,
276 condition: &str,
277 llm_alias: &str,
278 threshold: f32,
279 ctx: &EvaluationContext,
280 ) -> Result<bool> {
281 let llm = match self.llm_getter.get_llm(llm_alias) {
282 Some(l) => l,
283 None => {
284 tracing::warn!(llm = llm_alias, "LLM not found for semantic evaluation");
285 return Ok(false);
286 }
287 };
288
289 let conversation_summary = ctx
290 .recent_messages
291 .iter()
292 .take(10)
293 .map(|m| format!("{:?}: {}", m.role, m.content))
294 .collect::<Vec<_>>()
295 .join("\n");
296
297 let prompt = format!(
298 r#"Based on the conversation below, evaluate if this condition is TRUE or FALSE.
299
300Condition to evaluate: "{}"
301
302Recent conversation:
303{}
304
305Respond with ONLY a JSON object:
306{{"result": true, "confidence": 0.9, "reason": "brief explanation"}}
307or
308{{"result": false, "confidence": 0.9, "reason": "brief explanation"}}"#,
309 condition, conversation_summary
310 );
311
312 let messages = vec![ChatMessage::user(&prompt)];
313 let response = llm.complete(&messages, None).await?;
314
315 let parsed: SemanticEvalResult =
316 serde_json::from_str(&response.content).unwrap_or(SemanticEvalResult {
317 result: false,
318 confidence: 0.0,
319 reason: "Failed to parse response".to_string(),
320 });
321
322 tracing::debug!(
323 condition = condition,
324 result = parsed.result,
325 confidence = parsed.confidence,
326 threshold = threshold,
327 reason = %parsed.reason,
328 "Semantic evaluation"
329 );
330
331 Ok(parsed.result && parsed.confidence >= threshold)
332 }
333}
334
335#[derive(Debug, Deserialize)]
336struct SemanticEvalResult {
337 result: bool,
338 confidence: f32,
339 reason: String,
340}
341
/// `LLMGetter` backed by an in-memory map from alias to provider.
pub struct SimpleLLMGetter {
    // Registered providers, keyed by alias.
    llms: HashMap<String, Arc<dyn LLMProvider>>,
}
345
346impl SimpleLLMGetter {
347 pub fn new() -> Self {
348 Self {
349 llms: HashMap::new(),
350 }
351 }
352
353 pub fn with_llm(mut self, alias: &str, llm: Arc<dyn LLMProvider>) -> Self {
354 self.llms.insert(alias.to_string(), llm);
355 self
356 }
357}
358
359impl Default for SimpleLLMGetter {
360 fn default() -> Self {
361 Self::new()
362 }
363}
364
365impl LLMGetter for SimpleLLMGetter {
366 fn get_llm(&self, alias: &str) -> Option<Arc<dyn LLMProvider>> {
367 self.llms.get(alias).cloned()
368 }
369}
370
#[cfg(test)]
mod tests {
    // Unit tests for `ConditionEvaluator`. None of them needs an LLM.
    use super::*;

    // Getter that never resolves an alias. It is enough for every test below,
    // because none of them evaluates a `Semantic` condition.
    struct NoOpLLMGetter;

    impl LLMGetter for NoOpLLMGetter {
        fn get_llm(&self, _alias: &str) -> Option<Arc<dyn LLMProvider>> {
            None
        }
    }

    // `Exact` matches a nested boolean addressed by a dot path.
    #[tokio::test]
    async fn test_context_condition_exact() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let mut context = HashMap::new();
        context.insert(
            "user".to_string(),
            serde_json::json!({
                "verified": true,
                "tier": "premium"
            }),
        );

        let ctx = EvaluationContext::new().with_context(context);

        let mut matchers = HashMap::new();
        matchers.insert(
            "user.verified".to_string(),
            ContextMatcher::Exact(Value::Bool(true)),
        );

        let condition = ToolCondition::Context(matchers);
        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    // `Exists` holds for a present key (`exists: true`) and for a missing key
    // (`exists: false`), combined in one matcher map.
    #[tokio::test]
    async fn test_context_condition_exists() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let mut context = HashMap::new();
        context.insert("name".to_string(), Value::String("Alice".into()));

        let ctx = EvaluationContext::new().with_context(context);

        let mut matchers = HashMap::new();
        matchers.insert("name".to_string(), ContextMatcher::Exists { exists: true });
        matchers.insert(
            "email".to_string(),
            ContextMatcher::Exists { exists: false },
        );

        let condition = ToolCondition::Context(matchers);
        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    // Numeric `Gte` comparison on a top-level context value.
    #[tokio::test]
    async fn test_context_condition_compare() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let mut context = HashMap::new();
        context.insert("balance".to_string(), serde_json::json!(150.0));

        let ctx = EvaluationContext::new().with_context(context);

        let mut matchers = HashMap::new();
        matchers.insert(
            "balance".to_string(),
            ContextMatcher::Compare(CompareOp::Gte(100.0)),
        );

        let condition = ToolCondition::Context(matchers);
        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    // A `State` matcher checking name, turn count and previous state together.
    #[tokio::test]
    async fn test_state_condition() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let ctx = EvaluationContext::new().with_state(
            Some("checkout".to_string()),
            5,
            Some("browsing".to_string()),
        );

        let condition = ToolCondition::State(StateMatcher {
            name: Some("checkout".to_string()),
            turn_count: Some(CompareOp::Gte(3.0)),
            previous: Some("browsing".to_string()),
        });

        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    // `AfterTool` is true for a recorded tool id and false for any other id.
    #[tokio::test]
    async fn test_after_tool_condition() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let ctx = EvaluationContext::new().with_called_tools(vec![ToolCallRecord {
            tool_id: "search".to_string(),
            result: serde_json::json!({"found": true}),
            timestamp: Utc::now(),
        }]);

        let condition = ToolCondition::AfterTool("search".to_string());
        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());

        let condition2 = ToolCondition::AfterTool("calculate".to_string());
        assert!(!evaluator.evaluate(&condition2, &ctx).await.unwrap());
    }

    // `ToolResult` requires every expected field to be present and equal.
    #[tokio::test]
    async fn test_tool_result_condition() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let ctx = EvaluationContext::new().with_called_tools(vec![ToolCallRecord {
            tool_id: "verify_purchase".to_string(),
            result: serde_json::json!({
                "valid": true,
                "refundable": true
            }),
            timestamp: Utc::now(),
        }]);

        let mut expected = HashMap::new();
        expected.insert("valid".to_string(), Value::Bool(true));
        expected.insert("refundable".to_string(), Value::Bool(true));

        let condition = ToolCondition::ToolResult {
            tool: "verify_purchase".to_string(),
            result: expected,
        };

        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    // `All` of two satisfied context conditions is true.
    #[tokio::test]
    async fn test_all_condition() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let mut context = HashMap::new();
        context.insert("verified".to_string(), Value::Bool(true));
        context.insert("balance".to_string(), serde_json::json!(100.0));

        let ctx = EvaluationContext::new().with_context(context);

        let mut m1 = HashMap::new();
        m1.insert(
            "verified".to_string(),
            ContextMatcher::Exact(Value::Bool(true)),
        );

        let mut m2 = HashMap::new();
        m2.insert(
            "balance".to_string(),
            ContextMatcher::Compare(CompareOp::Gte(50.0)),
        );

        let condition =
            ToolCondition::All(vec![ToolCondition::Context(m1), ToolCondition::Context(m2)]);

        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    // `Any` is true when only the second alternative matches.
    #[tokio::test]
    async fn test_any_condition() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let mut context = HashMap::new();
        context.insert("tier".to_string(), Value::String("basic".into()));

        let ctx = EvaluationContext::new().with_context(context);

        let mut m1 = HashMap::new();
        m1.insert(
            "tier".to_string(),
            ContextMatcher::Exact(Value::String("premium".into())),
        );

        let mut m2 = HashMap::new();
        m2.insert(
            "tier".to_string(),
            ContextMatcher::Exact(Value::String("basic".into())),
        );

        let condition =
            ToolCondition::Any(vec![ToolCondition::Context(m1), ToolCondition::Context(m2)]);

        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    // `Not` inverts a failing exact match.
    #[tokio::test]
    async fn test_not_condition() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let mut context = HashMap::new();
        context.insert("blocked".to_string(), Value::Bool(false));

        let ctx = EvaluationContext::new().with_context(context);

        let mut matchers = HashMap::new();
        matchers.insert(
            "blocked".to_string(),
            ContextMatcher::Exact(Value::Bool(true)),
        );

        let condition = ToolCondition::Not(Box::new(ToolCondition::Context(matchers)));
        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    // Every weekday is listed, so this passes regardless of the current day
    // and does not depend on the system clock.
    #[tokio::test]
    async fn test_time_condition_day_of_week() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);
        let ctx = EvaluationContext::new();

        let all_days = vec![
            "monday".to_string(),
            "tuesday".to_string(),
            "wednesday".to_string(),
            "thursday".to_string(),
            "friday".to_string(),
            "saturday".to_string(),
            "sunday".to_string(),
        ];

        let condition = ToolCondition::Time(TimeMatcher {
            hours: None,
            day_of_week: Some(all_days),
            timezone: None,
        });

        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    // `In` matches when the value is one of the listed strings.
    #[tokio::test]
    async fn test_compare_in() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let mut context = HashMap::new();
        context.insert("status".to_string(), Value::String("active".into()));

        let ctx = EvaluationContext::new().with_context(context);

        let mut matchers = HashMap::new();
        matchers.insert(
            "status".to_string(),
            ContextMatcher::Compare(CompareOp::In(vec![
                Value::String("active".into()),
                Value::String("pending".into()),
            ])),
        );

        let condition = ToolCondition::Context(matchers);
        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    // `Contains` on a string value is a substring test.
    #[tokio::test]
    async fn test_compare_contains_string() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let mut context = HashMap::new();
        context.insert(
            "email".to_string(),
            Value::String("user@example.com".into()),
        );

        let ctx = EvaluationContext::new().with_context(context);

        let mut matchers = HashMap::new();
        matchers.insert(
            "email".to_string(),
            ContextMatcher::Compare(CompareOp::Contains("@example.com".into())),
        );

        let condition = ToolCondition::Context(matchers);
        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    // A four-level dot path resolves into nested JSON objects.
    #[tokio::test]
    async fn test_nested_context_path() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let mut context = HashMap::new();
        context.insert(
            "user".to_string(),
            serde_json::json!({
                "profile": {
                    "settings": {
                        "notifications": true
                    }
                }
            }),
        );

        let ctx = EvaluationContext::new().with_context(context);

        let mut matchers = HashMap::new();
        matchers.insert(
            "user.profile.settings.notifications".to_string(),
            ContextMatcher::Exact(Value::Bool(true)),
        );

        let condition = ToolCondition::Context(matchers);
        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    // An empty `SimpleLLMGetter` resolves no aliases.
    #[test]
    fn test_simple_llm_getter() {
        let getter = SimpleLLMGetter::new();
        assert!(getter.get_llm("nonexistent").is_none());
    }
}