Skip to main content

ai_agents_tools/
condition.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use chrono::{Datelike, Local, Timelike, Utc};
6use serde::Deserialize;
7use serde_json::Value;
8
9use ai_agents_core::{ChatMessage, LLMProvider, Result};
10use ai_agents_state::{CompareOp, ContextMatcher, StateMatcher, TimeMatcher, ToolCondition};
11
12#[derive(Debug, Clone)]
13pub struct ToolCallRecord {
14    pub tool_id: String,
15    pub result: Value,
16    pub timestamp: chrono::DateTime<chrono::Utc>,
17}
18
19#[derive(Debug, Clone)]
20pub struct EvaluationContext {
21    pub context: HashMap<String, Value>,
22    pub state_name: Option<String>,
23    pub state_turn_count: u32,
24    pub previous_state: Option<String>,
25    pub called_tools: Vec<ToolCallRecord>,
26    pub recent_messages: Vec<ChatMessage>,
27}
28
29impl Default for EvaluationContext {
30    fn default() -> Self {
31        Self {
32            context: HashMap::new(),
33            state_name: None,
34            state_turn_count: 0,
35            previous_state: None,
36            called_tools: Vec::new(),
37            recent_messages: Vec::new(),
38        }
39    }
40}
41
42impl EvaluationContext {
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    pub fn with_context(mut self, context: HashMap<String, Value>) -> Self {
48        self.context = context;
49        self
50    }
51
52    pub fn with_state(
53        mut self,
54        name: Option<String>,
55        turn_count: u32,
56        previous: Option<String>,
57    ) -> Self {
58        self.state_name = name;
59        self.state_turn_count = turn_count;
60        self.previous_state = previous;
61        self
62    }
63
64    pub fn with_called_tools(mut self, tools: Vec<ToolCallRecord>) -> Self {
65        self.called_tools = tools;
66        self
67    }
68
69    pub fn with_messages(mut self, messages: Vec<ChatMessage>) -> Self {
70        self.recent_messages = messages;
71        self
72    }
73}
74
75#[async_trait]
76pub trait LLMGetter: Send + Sync {
77    fn get_llm(&self, alias: &str) -> Option<Arc<dyn LLMProvider>>;
78}
79
/// Evaluates tool-availability conditions against an evaluation context.
///
/// `G` supplies LLM providers for semantic conditions. The `LLMGetter` bound is placed on
/// the `impl` blocks that need it rather than on the struct, so the type itself can be
/// named without the bound.
pub struct ConditionEvaluator<G> {
    /// Source of LLM providers, looked up by alias.
    llm_getter: G,
}
83
84impl<G: LLMGetter> ConditionEvaluator<G> {
85    pub fn new(llm_getter: G) -> Self {
86        Self { llm_getter }
87    }
88
89    pub async fn evaluate(
90        &self,
91        condition: &ToolCondition,
92        ctx: &EvaluationContext,
93    ) -> Result<bool> {
94        match condition {
95            ToolCondition::Context(matchers) => Ok(self.evaluate_context(matchers, &ctx.context)),
96            ToolCondition::State(matcher) => Ok(self.evaluate_state(matcher, ctx)),
97            ToolCondition::AfterTool(tool_id) => {
98                Ok(ctx.called_tools.iter().any(|t| &t.tool_id == tool_id))
99            }
100            ToolCondition::ToolResult { tool, result } => {
101                Ok(self.evaluate_tool_result(tool, result, &ctx.called_tools))
102            }
103            ToolCondition::Semantic {
104                when,
105                llm,
106                threshold,
107            } => self.evaluate_semantic(when, llm, *threshold, ctx).await,
108            ToolCondition::Time(matcher) => Ok(self.evaluate_time(matcher)),
109            ToolCondition::All(conditions) => {
110                for cond in conditions {
111                    if !Box::pin(self.evaluate(cond, ctx)).await? {
112                        return Ok(false);
113                    }
114                }
115                Ok(true)
116            }
117            ToolCondition::Any(conditions) => {
118                for cond in conditions {
119                    if Box::pin(self.evaluate(cond, ctx)).await? {
120                        return Ok(true);
121                    }
122                }
123                Ok(false)
124            }
125            ToolCondition::Not(inner) => Ok(!Box::pin(self.evaluate(inner, ctx)).await?),
126        }
127    }
128
129    fn evaluate_context(
130        &self,
131        matchers: &HashMap<String, ContextMatcher>,
132        context: &HashMap<String, Value>,
133    ) -> bool {
134        for (path, matcher) in matchers {
135            let value = self.get_context_value(path, context);
136            if !self.match_value(value.as_ref(), matcher) {
137                return false;
138            }
139        }
140        true
141    }
142
143    fn get_context_value(&self, path: &str, context: &HashMap<String, Value>) -> Option<Value> {
144        ai_agents_core::get_dot_path_from_map(context, path)
145    }
146
147    fn match_value(&self, value: Option<&Value>, matcher: &ContextMatcher) -> bool {
148        match matcher {
149            ContextMatcher::Exact(expected) => value.map(|v| v == expected).unwrap_or(false),
150            ContextMatcher::Exists { exists } => {
151                let has_value = value.is_some() && value != Some(&Value::Null);
152                *exists == has_value
153            }
154            ContextMatcher::Compare(op) => {
155                let Some(val) = value else {
156                    return false;
157                };
158                self.compare_value(val, op)
159            }
160        }
161    }
162
163    fn compare_value(&self, value: &Value, op: &CompareOp) -> bool {
164        match op {
165            CompareOp::Eq(expected) => value == expected,
166            CompareOp::Neq(expected) => value != expected,
167            CompareOp::Gt(n) => value.as_f64().map(|v| v > *n).unwrap_or(false),
168            CompareOp::Gte(n) => value.as_f64().map(|v| v >= *n).unwrap_or(false),
169            CompareOp::Lt(n) => value.as_f64().map(|v| v < *n).unwrap_or(false),
170            CompareOp::Lte(n) => value.as_f64().map(|v| v <= *n).unwrap_or(false),
171            CompareOp::In(values) => values.contains(value),
172            CompareOp::Contains(s) => value
173                .as_str()
174                .map(|v| v.contains(s))
175                .or_else(|| {
176                    value
177                        .as_array()
178                        .map(|arr| arr.iter().any(|v| v.as_str() == Some(s)))
179                })
180                .unwrap_or(false),
181        }
182    }
183
184    fn evaluate_state(&self, matcher: &StateMatcher, ctx: &EvaluationContext) -> bool {
185        if let Some(ref expected_name) = matcher.name {
186            if ctx.state_name.as_ref() != Some(expected_name) {
187                return false;
188            }
189        }
190
191        if let Some(ref turn_op) = matcher.turn_count {
192            let turn_count = ctx.state_turn_count as f64;
193            if !self.compare_value(
194                &Value::Number(serde_json::Number::from_f64(turn_count).unwrap_or(0.into())),
195                turn_op,
196            ) {
197                return false;
198            }
199        }
200
201        if let Some(ref expected_prev) = matcher.previous {
202            if ctx.previous_state.as_ref() != Some(expected_prev) {
203                return false;
204            }
205        }
206
207        true
208    }
209
210    fn evaluate_tool_result(
211        &self,
212        tool: &str,
213        expected: &HashMap<String, Value>,
214        called_tools: &[ToolCallRecord],
215    ) -> bool {
216        let tool_record = called_tools.iter().rev().find(|t| t.tool_id == tool);
217
218        let Some(record) = tool_record else {
219            return false;
220        };
221
222        let result_obj = match &record.result {
223            Value::Object(obj) => obj,
224            _ => return false,
225        };
226
227        for (key, expected_value) in expected {
228            match result_obj.get(key) {
229                Some(actual) if actual == expected_value => continue,
230                _ => return false,
231            }
232        }
233
234        true
235    }
236
237    fn evaluate_time(&self, matcher: &TimeMatcher) -> bool {
238        let now = if let Some(ref tz) = matcher.timezone {
239            if tz == "utc" || tz == "UTC" {
240                Utc::now().with_timezone(&Utc).naive_local()
241            } else {
242                Local::now().naive_local()
243            }
244        } else {
245            Local::now().naive_local()
246        };
247
248        if let Some(ref hours_op) = matcher.hours {
249            let hour = now.hour() as f64;
250            if !self.compare_value(&serde_json::json!(hour), hours_op) {
251                return false;
252            }
253        }
254
255        if let Some(ref days) = matcher.day_of_week {
256            let day_name = match now.weekday() {
257                chrono::Weekday::Mon => "monday",
258                chrono::Weekday::Tue => "tuesday",
259                chrono::Weekday::Wed => "wednesday",
260                chrono::Weekday::Thu => "thursday",
261                chrono::Weekday::Fri => "friday",
262                chrono::Weekday::Sat => "saturday",
263                chrono::Weekday::Sun => "sunday",
264            };
265
266            if !days.iter().any(|d| d.to_lowercase() == day_name) {
267                return false;
268            }
269        }
270
271        true
272    }
273
274    async fn evaluate_semantic(
275        &self,
276        condition: &str,
277        llm_alias: &str,
278        threshold: f32,
279        ctx: &EvaluationContext,
280    ) -> Result<bool> {
281        let llm = match self.llm_getter.get_llm(llm_alias) {
282            Some(l) => l,
283            None => {
284                tracing::warn!(llm = llm_alias, "LLM not found for semantic evaluation");
285                return Ok(false);
286            }
287        };
288
289        let conversation_summary = ctx
290            .recent_messages
291            .iter()
292            .take(10)
293            .map(|m| format!("{:?}: {}", m.role, m.content))
294            .collect::<Vec<_>>()
295            .join("\n");
296
297        let prompt = format!(
298            r#"Based on the conversation below, evaluate if this condition is TRUE or FALSE.
299
300Condition to evaluate: "{}"
301
302Recent conversation:
303{}
304
305Respond with ONLY a JSON object:
306{{"result": true, "confidence": 0.9, "reason": "brief explanation"}}
307or
308{{"result": false, "confidence": 0.9, "reason": "brief explanation"}}"#,
309            condition, conversation_summary
310        );
311
312        let messages = vec![ChatMessage::user(&prompt)];
313        let response = llm.complete(&messages, None).await?;
314
315        let parsed: SemanticEvalResult =
316            serde_json::from_str(&response.content).unwrap_or(SemanticEvalResult {
317                result: false,
318                confidence: 0.0,
319                reason: "Failed to parse response".to_string(),
320            });
321
322        tracing::debug!(
323            condition = condition,
324            result = parsed.result,
325            confidence = parsed.confidence,
326            threshold = threshold,
327            reason = %parsed.reason,
328            "Semantic evaluation"
329        );
330
331        Ok(parsed.result && parsed.confidence >= threshold)
332    }
333}
334
335#[derive(Debug, Deserialize)]
336struct SemanticEvalResult {
337    result: bool,
338    confidence: f32,
339    reason: String,
340}
341
342pub struct SimpleLLMGetter {
343    llms: HashMap<String, Arc<dyn LLMProvider>>,
344}
345
346impl SimpleLLMGetter {
347    pub fn new() -> Self {
348        Self {
349            llms: HashMap::new(),
350        }
351    }
352
353    pub fn with_llm(mut self, alias: &str, llm: Arc<dyn LLMProvider>) -> Self {
354        self.llms.insert(alias.to_string(), llm);
355        self
356    }
357}
358
359impl Default for SimpleLLMGetter {
360    fn default() -> Self {
361        Self::new()
362    }
363}
364
365impl LLMGetter for SimpleLLMGetter {
366    fn get_llm(&self, alias: &str) -> Option<Arc<dyn LLMProvider>> {
367        self.llms.get(alias).cloned()
368    }
369}
370
#[cfg(test)]
mod tests {
    // Unit tests for `ConditionEvaluator`. None of them needs an LLM, so a getter that
    // never resolves an alias is sufficient.
    use super::*;

    /// Getter with no providers; semantic conditions are not exercised here.
    struct NoOpLLMGetter;

    impl LLMGetter for NoOpLLMGetter {
        fn get_llm(&self, _alias: &str) -> Option<Arc<dyn LLMProvider>> {
            None
        }
    }

    #[tokio::test]
    async fn test_context_condition_exact() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let mut context = HashMap::new();
        context.insert(
            "user".to_string(),
            serde_json::json!({
                "verified": true,
                "tier": "premium"
            }),
        );

        let ctx = EvaluationContext::new().with_context(context);

        let mut matchers = HashMap::new();
        matchers.insert(
            "user.verified".to_string(),
            ContextMatcher::Exact(Value::Bool(true)),
        );

        let condition = ToolCondition::Context(matchers);
        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    #[tokio::test]
    async fn test_context_condition_exact_mismatch() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let mut context = HashMap::new();
        context.insert("verified".to_string(), Value::Bool(false));

        let ctx = EvaluationContext::new().with_context(context);

        let mut matchers = HashMap::new();
        matchers.insert(
            "verified".to_string(),
            ContextMatcher::Exact(Value::Bool(true)),
        );

        let condition = ToolCondition::Context(matchers);
        assert!(!evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    #[tokio::test]
    async fn test_context_condition_exists() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let mut context = HashMap::new();
        context.insert("name".to_string(), Value::String("Alice".into()));

        let ctx = EvaluationContext::new().with_context(context);

        let mut matchers = HashMap::new();
        matchers.insert("name".to_string(), ContextMatcher::Exists { exists: true });
        matchers.insert(
            "email".to_string(),
            ContextMatcher::Exists { exists: false },
        );

        let condition = ToolCondition::Context(matchers);
        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    #[tokio::test]
    async fn test_context_exists_treats_null_as_missing() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let mut context = HashMap::new();
        context.insert("email".to_string(), Value::Null);

        let ctx = EvaluationContext::new().with_context(context);

        let mut matchers = HashMap::new();
        matchers.insert(
            "email".to_string(),
            ContextMatcher::Exists { exists: false },
        );

        let condition = ToolCondition::Context(matchers);
        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    #[tokio::test]
    async fn test_context_condition_compare() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let mut context = HashMap::new();
        context.insert("balance".to_string(), serde_json::json!(150.0));

        let ctx = EvaluationContext::new().with_context(context);

        let mut matchers = HashMap::new();
        matchers.insert(
            "balance".to_string(),
            ContextMatcher::Compare(CompareOp::Gte(100.0)),
        );

        let condition = ToolCondition::Context(matchers);
        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    #[tokio::test]
    async fn test_context_compare_missing_key_is_false() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);
        let ctx = EvaluationContext::new();

        let mut matchers = HashMap::new();
        matchers.insert(
            "balance".to_string(),
            ContextMatcher::Compare(CompareOp::Gte(0.0)),
        );

        let condition = ToolCondition::Context(matchers);
        assert!(!evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    #[tokio::test]
    async fn test_state_condition() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let ctx = EvaluationContext::new().with_state(
            Some("checkout".to_string()),
            5,
            Some("browsing".to_string()),
        );

        let condition = ToolCondition::State(StateMatcher {
            name: Some("checkout".to_string()),
            turn_count: Some(CompareOp::Gte(3.0)),
            previous: Some("browsing".to_string()),
        });

        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    #[tokio::test]
    async fn test_state_condition_name_mismatch() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let ctx = EvaluationContext::new().with_state(Some("browsing".to_string()), 1, None);

        let condition = ToolCondition::State(StateMatcher {
            name: Some("checkout".to_string()),
            turn_count: None,
            previous: None,
        });

        assert!(!evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    #[tokio::test]
    async fn test_after_tool_condition() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let ctx = EvaluationContext::new().with_called_tools(vec![ToolCallRecord {
            tool_id: "search".to_string(),
            result: serde_json::json!({"found": true}),
            timestamp: Utc::now(),
        }]);

        let condition = ToolCondition::AfterTool("search".to_string());
        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());

        let condition2 = ToolCondition::AfterTool("calculate".to_string());
        assert!(!evaluator.evaluate(&condition2, &ctx).await.unwrap());
    }

    #[tokio::test]
    async fn test_tool_result_condition() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let ctx = EvaluationContext::new().with_called_tools(vec![ToolCallRecord {
            tool_id: "verify_purchase".to_string(),
            result: serde_json::json!({
                "valid": true,
                "refundable": true
            }),
            timestamp: Utc::now(),
        }]);

        let mut expected = HashMap::new();
        expected.insert("valid".to_string(), Value::Bool(true));
        expected.insert("refundable".to_string(), Value::Bool(true));

        let condition = ToolCondition::ToolResult {
            tool: "verify_purchase".to_string(),
            result: expected,
        };

        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    #[tokio::test]
    async fn test_tool_result_uses_latest_call() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let ctx = EvaluationContext::new().with_called_tools(vec![
            ToolCallRecord {
                tool_id: "verify_purchase".to_string(),
                result: serde_json::json!({"valid": true}),
                timestamp: Utc::now(),
            },
            ToolCallRecord {
                tool_id: "verify_purchase".to_string(),
                result: serde_json::json!({"valid": false}),
                timestamp: Utc::now(),
            },
        ]);

        let mut expected = HashMap::new();
        expected.insert("valid".to_string(), Value::Bool(true));

        let condition = ToolCondition::ToolResult {
            tool: "verify_purchase".to_string(),
            result: expected,
        };

        // The most recent call returned `valid: false`, so the condition must not hold.
        assert!(!evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    #[tokio::test]
    async fn test_tool_result_requires_object() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let ctx = EvaluationContext::new().with_called_tools(vec![ToolCallRecord {
            tool_id: "ping".to_string(),
            result: serde_json::json!("ok"),
            timestamp: Utc::now(),
        }]);

        let condition = ToolCondition::ToolResult {
            tool: "ping".to_string(),
            result: HashMap::new(),
        };

        assert!(!evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    #[tokio::test]
    async fn test_all_condition() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let mut context = HashMap::new();
        context.insert("verified".to_string(), Value::Bool(true));
        context.insert("balance".to_string(), serde_json::json!(100.0));

        let ctx = EvaluationContext::new().with_context(context);

        let mut m1 = HashMap::new();
        m1.insert(
            "verified".to_string(),
            ContextMatcher::Exact(Value::Bool(true)),
        );

        let mut m2 = HashMap::new();
        m2.insert(
            "balance".to_string(),
            ContextMatcher::Compare(CompareOp::Gte(50.0)),
        );

        let condition =
            ToolCondition::All(vec![ToolCondition::Context(m1), ToolCondition::Context(m2)]);

        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    #[tokio::test]
    async fn test_any_condition() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let mut context = HashMap::new();
        context.insert("tier".to_string(), Value::String("basic".into()));

        let ctx = EvaluationContext::new().with_context(context);

        let mut m1 = HashMap::new();
        m1.insert(
            "tier".to_string(),
            ContextMatcher::Exact(Value::String("premium".into())),
        );

        let mut m2 = HashMap::new();
        m2.insert(
            "tier".to_string(),
            ContextMatcher::Exact(Value::String("basic".into())),
        );

        let condition =
            ToolCondition::Any(vec![ToolCondition::Context(m1), ToolCondition::Context(m2)]);

        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    #[tokio::test]
    async fn test_empty_all_and_any() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);
        let ctx = EvaluationContext::new();

        let all = ToolCondition::All(vec![]);
        assert!(evaluator.evaluate(&all, &ctx).await.unwrap());

        let any = ToolCondition::Any(vec![]);
        assert!(!evaluator.evaluate(&any, &ctx).await.unwrap());
    }

    #[tokio::test]
    async fn test_not_condition() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let mut context = HashMap::new();
        context.insert("blocked".to_string(), Value::Bool(false));

        let ctx = EvaluationContext::new().with_context(context);

        let mut matchers = HashMap::new();
        matchers.insert(
            "blocked".to_string(),
            ContextMatcher::Exact(Value::Bool(true)),
        );

        let condition = ToolCondition::Not(Box::new(ToolCondition::Context(matchers)));
        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    #[tokio::test]
    async fn test_time_condition_day_of_week() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);
        let ctx = EvaluationContext::new();

        let all_days = vec![
            "monday".to_string(),
            "tuesday".to_string(),
            "wednesday".to_string(),
            "thursday".to_string(),
            "friday".to_string(),
            "saturday".to_string(),
            "sunday".to_string(),
        ];

        let condition = ToolCondition::Time(TimeMatcher {
            hours: None,
            day_of_week: Some(all_days),
            timezone: None,
        });

        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    #[tokio::test]
    async fn test_time_condition_utc_hours() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);
        let ctx = EvaluationContext::new();

        // Every hour of the day is >= 0, so this holds regardless of when the test runs.
        let condition = ToolCondition::Time(TimeMatcher {
            hours: Some(CompareOp::Gte(0.0)),
            day_of_week: None,
            timezone: Some("UTC".to_string()),
        });

        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    #[tokio::test]
    async fn test_compare_in() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let mut context = HashMap::new();
        context.insert("status".to_string(), Value::String("active".into()));

        let ctx = EvaluationContext::new().with_context(context);

        let mut matchers = HashMap::new();
        matchers.insert(
            "status".to_string(),
            ContextMatcher::Compare(CompareOp::In(vec![
                Value::String("active".into()),
                Value::String("pending".into()),
            ])),
        );

        let condition = ToolCondition::Context(matchers);
        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    #[tokio::test]
    async fn test_compare_contains_string() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let mut context = HashMap::new();
        context.insert(
            "email".to_string(),
            Value::String("user@example.com".into()),
        );

        let ctx = EvaluationContext::new().with_context(context);

        let mut matchers = HashMap::new();
        matchers.insert(
            "email".to_string(),
            ContextMatcher::Compare(CompareOp::Contains("@example.com".into())),
        );

        let condition = ToolCondition::Context(matchers);
        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    #[tokio::test]
    async fn test_compare_contains_array() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let mut context = HashMap::new();
        context.insert("tags".to_string(), serde_json::json!(["vip", "beta"]));

        let ctx = EvaluationContext::new().with_context(context);

        let mut matchers = HashMap::new();
        matchers.insert(
            "tags".to_string(),
            ContextMatcher::Compare(CompareOp::Contains("beta".into())),
        );

        let condition = ToolCondition::Context(matchers);
        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    #[tokio::test]
    async fn test_nested_context_path() {
        let evaluator = ConditionEvaluator::new(NoOpLLMGetter);

        let mut context = HashMap::new();
        context.insert(
            "user".to_string(),
            serde_json::json!({
                "profile": {
                    "settings": {
                        "notifications": true
                    }
                }
            }),
        );

        let ctx = EvaluationContext::new().with_context(context);

        let mut matchers = HashMap::new();
        matchers.insert(
            "user.profile.settings.notifications".to_string(),
            ContextMatcher::Exact(Value::Bool(true)),
        );

        let condition = ToolCondition::Context(matchers);
        assert!(evaluator.evaluate(&condition, &ctx).await.unwrap());
    }

    #[test]
    fn test_simple_llm_getter() {
        let getter = SimpleLLMGetter::new();
        assert!(getter.get_llm("nonexistent").is_none());
    }
}