use std::sync::Arc;
use async_trait::async_trait;
use super::TextAgent;
use crate::error::AgentError;
use crate::state::State;
/// A single routing entry: a predicate over [`State`] paired with the agent
/// that should handle any state the predicate accepts.
pub struct RouteRule {
/// Test evaluated against the current state; returning `true` selects `agent`.
predicate: Box<dyn Fn(&State) -> bool + Send + Sync>,
/// Agent that is run when `predicate` returns `true`.
agent: Arc<dyn TextAgent>,
}
impl RouteRule {
    /// Builds a rule that sends a state to `agent` whenever `predicate`
    /// returns `true` for it.
    ///
    /// The predicate is moved onto the heap so that rules with different
    /// closure types can be stored side by side.
    pub fn new(
        predicate: impl Fn(&State) -> bool + Send + Sync + 'static,
        agent: Arc<dyn TextAgent>,
    ) -> Self {
        // Erase the concrete closure type before storing it.
        let predicate: Box<dyn Fn(&State) -> bool + Send + Sync> = Box::new(predicate);
        Self { predicate, agent }
    }
}
/// A [`TextAgent`] that forwards each run to one of several other agents,
/// chosen by evaluating `rules` against the state.
pub struct RouteTextAgent {
/// Name reported by this agent's `name()` method.
name: String,
/// Candidate routes, tried in order; the first rule whose predicate accepts the state wins.
rules: Vec<RouteRule>,
/// Agent used when none of the rules accept the state.
default: Arc<dyn TextAgent>,
}
impl RouteTextAgent {
    /// Creates a routing agent.
    ///
    /// * `name` - identifier returned from `name()`.
    /// * `rules` - routes to try, in the order given.
    /// * `default` - agent to run when no rule matches.
    pub fn new(
        name: impl Into<String>,
        rules: Vec<RouteRule>,
        default: Arc<dyn TextAgent>,
    ) -> Self {
        let name: String = name.into();
        Self {
            name,
            rules,
            default,
        }
    }
}
#[async_trait]
impl TextAgent for RouteTextAgent {
    /// Returns the name this router was constructed with.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Runs the agent of the first rule whose predicate accepts `state`,
    /// or the default agent when no rule does, and returns that agent's
    /// result unchanged.
    ///
    /// Predicates are evaluated in rule order and evaluation stops at the
    /// first match.
    async fn run(&self, state: &State) -> Result<String, AgentError> {
        let first_match = self.rules.iter().find(|rule| (rule.predicate)(state));

        // Pick the delegate up front so there is a single await point.
        let target = match first_match {
            Some(rule) => &rule.agent,
            None => &self.default,
        };

        target.run(state).await
    }
}