use anyhow::Result;
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use serde_json::Value;
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
use tokio::time::{interval, Duration};
use tracing::{debug, error, info};
use crate::infrastructure::tool::ToolExecutor;
use crate::domain::tool::ToolCallContext;
use crate::domain::TriggerCondition;
/// Event broadcast by [`PollingManager`] when a polled tool result satisfies the
/// trigger condition of the rule that produced it.
#[derive(Debug, Clone)]
pub struct PollingTick {
    /// Id of the [`PollingRule`] that fired.
    pub rule_id: String,
    /// Agent id copied from the firing rule's `target_agent_id`.
    pub target_agent_id: String,
    /// Value returned by the polled tool, unmodified.
    pub tool_result: Value,
    /// Time at which the polling pass that produced this tick began.
    pub triggered_at: DateTime<Utc>,
}
/// A periodically executed tool call, plus the condition under which its result
/// should be forwarded to an agent as a [`PollingTick`].
#[derive(Debug, Clone)]
pub struct PollingRule {
    /// Rule identifier; [`PollingManager`] stores rules keyed by this value.
    pub id: String,
    /// Tool passed to the `ToolExecutor` on every poll.
    pub tool_id: String,
    /// JSON parameters handed to the tool on every poll.
    pub params: Value,
    /// Minimum number of whole seconds between two polls of this rule.
    pub interval_secs: u64,
    /// Condition evaluated against each successful tool result.
    pub condition: TriggerCondition,
    /// Agent id copied into every tick this rule emits.
    pub target_agent_id: String,
    /// Disabled rules are never considered due.
    pub enabled: bool,
    /// When the tool was last run (successfully or not); `None` if never.
    pub last_polled: Option<DateTime<Utc>>,
    /// Most recent successful tool result, if any.
    pub last_result: Option<Value>,
    /// Free-form labels; not consulted by the polling logic in this module.
    pub tags: Vec<String>,
}
impl PollingRule {
    /// Creates an enabled rule that has never been polled and carries no tags.
    pub fn new(
        id: impl Into<String>,
        tool_id: impl Into<String>,
        params: Value,
        interval_secs: u64,
        condition: TriggerCondition,
        target_agent_id: impl Into<String>,
    ) -> Self {
        Self {
            id: id.into(),
            tool_id: tool_id.into(),
            params,
            interval_secs,
            condition,
            target_agent_id: target_agent_id.into(),
            enabled: true,
            last_polled: None,
            last_result: None,
            tags: vec![],
        }
    }

    /// Replaces the rule's tags (builder style).
    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
        self.tags = tags;
        self
    }

    /// Returns `true` when the rule is enabled and either has never been polled or at
    /// least `interval_secs` whole seconds have elapsed since `last_polled`.
    ///
    /// If `now` lies one second or more *before* `last_polled` (the wall clock moved
    /// backwards), the rule is treated as due. Polling then resumes immediately and
    /// `last_polled` is re-anchored, instead of stalling until the clock catches up.
    pub fn should_poll(&self, now: DateTime<Utc>) -> bool {
        if !self.enabled {
            return false;
        }
        match self.last_polled {
            None => true,
            Some(last) => {
                // `num_seconds` truncates toward zero, so sub-second skew yields 0 here.
                let elapsed_secs = now.signed_duration_since(last).num_seconds();
                if elapsed_secs < 0 {
                    // Clock went backwards: poll now rather than wrapping the negative
                    // value through an unsigned cast.
                    true
                } else {
                    // Non-negative, so the cast is lossless.
                    elapsed_secs as u64 >= self.interval_secs
                }
            }
        }
    }

    /// Evaluates the rule's trigger condition against a tool `result`.
    ///
    /// - `NumericRange`: `result` must be a JSON number within `[min, max]` (inclusive).
    /// - `StringContains`: matches if `result` is a string containing `content`.
    ///   Otherwise it matches if `result` is an object with a top-level string value, or
    ///   a string element of a top-level array value, containing `content`. As a last
    ///   resort it matches if the serialized JSON text of `result` contains `content`,
    ///   so keys and non-string values can match too.
    /// - `StatusMatches`: `result` must be a JSON string equal to `expected_status`.
    /// - `CustomExpression`, `ScheduleInterval`, `ScheduleCron`: never match here.
    pub fn matches_condition(&self, result: &Value) -> bool {
        match &self.condition {
            TriggerCondition::NumericRange { min, max } => {
                matches!(result.as_f64(), Some(value) if (*min..=*max).contains(&value))
            }
            TriggerCondition::StringContains { content } => {
                let needle = content.as_str();
                if let Some(text) = result.as_str() {
                    return text.contains(needle);
                }
                if let Value::Object(fields) = result {
                    let field_hit = fields.values().any(|field| match field {
                        Value::String(text) => text.contains(needle),
                        Value::Array(items) => items
                            .iter()
                            .filter_map(Value::as_str)
                            .any(|text| text.contains(needle)),
                        _ => false,
                    });
                    if field_hit {
                        return true;
                    }
                }
                // Broad fallback over the serialized JSON (matches keys, numbers, nesting).
                result.to_string().contains(needle)
            }
            TriggerCondition::StatusMatches { expected_status } => {
                matches!(result.as_str(), Some(status) if status == expected_status)
            }
            TriggerCondition::CustomExpression { .. }
            | TriggerCondition::ScheduleInterval { .. }
            | TriggerCondition::ScheduleCron { .. } => false,
        }
    }

    /// Records a successful poll: stamps `last_polled` with the current time and stores
    /// `result` as `last_result`.
    pub fn mark_polled(&mut self, result: Value) {
        self.last_polled = Some(Utc::now());
        self.last_result = Some(result);
    }
}
pub struct PollingManager {
rules: DashMap<String, PollingRule>,
tool_executor: Arc<dyn ToolExecutor>,
tx: broadcast::Sender<PollingTick>,
running: Arc<RwLock<bool>>,
}
impl PollingManager {
pub fn new(tool_executor: Arc<dyn ToolExecutor>) -> Self {
let (tx, _) = broadcast::channel(100);
Self {
rules: DashMap::new(),
tool_executor,
tx,
running: Arc::new(RwLock::new(false)),
}
}
pub fn register_rule(&self, rule: PollingRule) -> Result<()> {
let rule_id = rule.id.clone();
self.rules.insert(rule.id.clone(), rule);
info!("Registered polling rule: {}", rule_id);
Ok(())
}
pub fn remove_rule(&self, rule_id: &str) -> Option<PollingRule> {
self.rules.remove(rule_id).map(|(_, rule)| rule)
}
pub fn get_rule(&self, rule_id: &str) -> Option<PollingRule> {
self.rules.get(rule_id).map(|r| r.clone())
}
pub fn has_rule(&self, rule_id: &str) -> bool {
self.rules.contains_key(rule_id)
}
pub fn list_rules(&self) -> Vec<PollingRule> {
self.rules.iter().map(|r| r.clone()).collect()
}
pub fn set_rule_enabled(&self, rule_id: &str, enabled: bool) -> bool {
if let Some(mut rule) = self.rules.get_mut(rule_id) {
rule.enabled = enabled;
info!(
"Polling rule {} is now {}",
rule_id,
if enabled { "enabled" } else { "disabled" }
);
true
} else {
false
}
}
pub fn subscribe(&self) -> broadcast::Receiver<PollingTick> {
self.tx.subscribe()
}
pub fn start(&self) {
let rules = self.rules.clone();
let tool_executor = self.tool_executor.clone();
let tx = self.tx.clone();
let running = self.running.clone();
tokio::spawn(async move {
*running.write().await = true;
info!("Polling manager started");
let mut check_interval = interval(Duration::from_secs(1));
loop {
check_interval.tick().await;
if !*running.read().await {
break;
}
let now = Utc::now();
let mut polled_count = 0;
for mut rule_ref in rules.iter_mut() {
let rule = rule_ref.value_mut();
if rule.should_poll(now) {
let context = ToolCallContext::new("polling_manager".to_string());
match tool_executor
.execute(&rule.tool_id, rule.params.clone(), &context)
.await
{
Ok(result) => {
debug!(
"Polling rule {} tool {} executed successfully",
rule.id, rule.tool_id
);
if rule.matches_condition(&result) {
let tick = PollingTick {
rule_id: rule.id.clone(),
target_agent_id: rule.target_agent_id.clone(),
tool_result: result.clone(),
triggered_at: now,
};
if let Err(e) = tx.send(tick) {
error!("Failed to send polling tick: {}", e);
} else {
debug!("Polling rule triggered: {}", rule.id);
polled_count += 1;
}
}
rule.mark_polled(result);
}
Err(e) => {
error!(
"Polling rule {} tool {} execution failed: {}",
rule.id, rule.tool_id, e
);
rule.last_polled = Some(Utc::now());
}
}
}
}
if polled_count > 0 {
debug!("Triggered {} polling rules", polled_count);
}
}
info!("Polling manager stopped");
});
}
pub async fn stop(&self) {
*self.running.write().await = false;
info!("Polling manager stopping...");
}
pub async fn is_running(&self) -> bool {
*self.running.read().await
}
}
impl Default for PollingManager {
fn default() -> Self {
unimplemented!("Default implementation not available. Use new() instead.")
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// `ToolExecutor` stub that answers every call with the same canned value.
    struct MockToolExecutor {
        return_value: Value,
    }

    impl MockToolExecutor {
        fn new(return_value: Value) -> Self {
            MockToolExecutor { return_value }
        }
    }

    #[async_trait::async_trait]
    impl ToolExecutor for MockToolExecutor {
        async fn execute(
            &self,
            _tool_id: &str,
            _params: Value,
            _context: &ToolCallContext,
        ) -> Result<Value> {
            Ok(self.return_value.clone())
        }

        fn can_execute(&self, _tool_id: &str) -> bool {
            true
        }

        fn supported_tools(&self) -> Vec<String> {
            vec![String::from("mock_tool")]
        }
    }

    /// Builds a `mock_tool` rule with empty params that fires for numbers in `[min, max]`.
    fn numeric_rule(id: &str, interval_secs: u64, min: f64, max: f64, agent: &str) -> PollingRule {
        PollingRule::new(
            id,
            "mock_tool",
            json!({}),
            interval_secs,
            TriggerCondition::NumericRange { min, max },
            agent,
        )
    }

    #[test]
    fn test_polling_rule_should_poll() {
        let mut rule = numeric_rule("test", 60, 0.0, 100.0, "agent1");

        // Never polled: due immediately.
        assert!(rule.should_poll(Utc::now()));

        // Just polled with a 60 s interval: not due yet.
        rule.last_polled = Some(Utc::now());
        assert!(!rule.should_poll(Utc::now()));

        // Disabled rules are never due.
        rule.enabled = false;
        assert!(!rule.should_poll(Utc::now()));
    }

    #[test]
    fn test_polling_rule_matches_condition() {
        let rule = numeric_rule("test", 60, 20.0, 30.0, "agent1");

        // Both bounds are inclusive.
        for inside in [25.0, 20.0, 30.0] {
            assert!(rule.matches_condition(&json!(inside)), "{} should match", inside);
        }
        for outside in [10.0, 40.0] {
            assert!(!rule.matches_condition(&json!(outside)), "{} should not match", outside);
        }
    }

    #[test]
    fn test_polling_rule_string_contains() {
        let rule = PollingRule::new(
            "test",
            "mock_tool",
            json!({}),
            60,
            TriggerCondition::StringContains {
                content: String::from("hello"),
            },
            "agent1",
        );

        for hit in ["hello world", "say hello"] {
            assert!(rule.matches_condition(&json!(hit)), "{:?} should match", hit);
        }
        assert!(!rule.matches_condition(&json!("goodbye")));
    }

    #[tokio::test]
    async fn test_polling_manager() {
        let manager = PollingManager::new(Arc::new(MockToolExecutor::new(json!(25.0))));
        manager
            .register_rule(numeric_rule("test_poll", 1, 20.0, 30.0, "test_agent"))
            .unwrap();

        // Subscribe before starting so the first tick cannot be missed.
        let mut ticks = manager.subscribe();
        manager.start();

        let tick = tokio::time::timeout(Duration::from_secs(3), ticks.recv())
            .await
            .expect("Timeout waiting for polling tick")
            .expect("Channel closed");
        assert_eq!(tick.rule_id, "test_poll");
        assert_eq!(tick.target_agent_id, "test_agent");

        manager.stop().await;
    }
}