Skip to main content

rig_compose/
coordinator.rs

//! [`CoordinatorAgent`] — routes an investigation to the specialist named by
//! the first routing rule whose signal tags match the context.
//!
//! The coordinator is itself an [`Agent`] so workflows can dispatch to it
//! uniformly. It owns no skills of its own; on `step` it inspects
//! `ctx.signals`, walks its routing rules in order, and delegates to the
//! specialist named by the first matching rule (or to the optional fallback
//! when no rule applies). Specialists are stored as `Arc<dyn Agent>`, so any
//! kernel agent (including future custom impls) can be registered.
9
10use std::collections::HashMap;
11use std::sync::Arc;
12
13use async_trait::async_trait;
14
15use crate::{Agent, AgentId, AgentStepResult, InvestigationContext, KernelError};
16
17/// One routing rule: any of `signals` matching the context routes to
18/// `agent_name`. First-match wins.
19#[derive(Debug, Clone)]
20pub struct RoutingRule {
21    pub agent_name: String,
22    pub signals: Vec<String>,
23}
24
25impl RoutingRule {
26    pub fn new(
27        agent_name: impl Into<String>,
28        signals: impl IntoIterator<Item = impl Into<String>>,
29    ) -> Self {
30        Self {
31            agent_name: agent_name.into(),
32            signals: signals.into_iter().map(Into::into).collect(),
33        }
34    }
35
36    fn matches(&self, ctx: &InvestigationContext) -> bool {
37        self.signals.iter().any(|s| ctx.has_signal(s))
38    }
39}
40
41/// Routes investigations across a fixed set of specialist agents.
42pub struct CoordinatorAgent {
43    id: AgentId,
44    name: String,
45    rules: Vec<RoutingRule>,
46    specialists: HashMap<String, Arc<dyn Agent>>,
47    /// Optional fallback agent name. Used when no rule matches.
48    fallback: Option<String>,
49}
50
51impl CoordinatorAgent {
52    pub fn builder(name: impl Into<String>) -> CoordinatorBuilder {
53        CoordinatorBuilder {
54            name: name.into(),
55            rules: Vec::new(),
56            specialists: HashMap::new(),
57            fallback: None,
58        }
59    }
60
61    /// Resolve the specialist that should handle `ctx`, if any.
62    pub fn route<'a>(&'a self, ctx: &InvestigationContext) -> Option<&'a Arc<dyn Agent>> {
63        for rule in &self.rules {
64            if !rule.matches(ctx) {
65                continue;
66            }
67            if let Some(agent) = self.specialists.get(&rule.agent_name) {
68                return Some(agent);
69            }
70        }
71        self.fallback.as_ref().and_then(|n| self.specialists.get(n))
72    }
73}
74
75#[async_trait]
76impl Agent for CoordinatorAgent {
77    fn id(&self) -> AgentId {
78        self.id
79    }
80    fn name(&self) -> &str {
81        &self.name
82    }
83    async fn step(&self, ctx: &mut InvestigationContext) -> Result<AgentStepResult, KernelError> {
84        let span = tracing::debug_span!(
85            "azreal.coordinator.route",
86            entity = %ctx.entity_id,
87            signals = ctx.signals.len(),
88        );
89        let _e = span.enter();
90        let routed = self.route(ctx).cloned();
91        drop(_e);
92        match routed {
93            Some(agent) => agent.step(ctx).await,
94            None => Ok(AgentStepResult {
95                skills_run: Vec::new(),
96                skills_skipped: Vec::new(),
97                confidence: ctx.confidence,
98                concluded: false,
99            }),
100        }
101    }
102}
103
104pub struct CoordinatorBuilder {
105    name: String,
106    rules: Vec<RoutingRule>,
107    specialists: HashMap<String, Arc<dyn Agent>>,
108    fallback: Option<String>,
109}
110
111impl CoordinatorBuilder {
112    pub fn route(mut self, rule: RoutingRule) -> Self {
113        self.rules.push(rule);
114        self
115    }
116
117    pub fn with_specialist(mut self, agent: Arc<dyn Agent>) -> Self {
118        self.specialists.insert(agent.name().to_string(), agent);
119        self
120    }
121
122    pub fn fallback(mut self, name: impl Into<String>) -> Self {
123        self.fallback = Some(name.into());
124        self
125    }
126
127    pub fn build(self) -> CoordinatorAgent {
128        CoordinatorAgent {
129            id: AgentId::new(),
130            name: self.name,
131            rules: self.rules,
132            specialists: self.specialists,
133            fallback: self.fallback,
134        }
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    use async_trait::async_trait;

    use crate::skill::{Skill, SkillOutcome};
    use crate::{GenericAgent, SkillRegistry, ToolRegistry};

    /// Tiny test skill: lifts confidence by 0.3 when its trigger signal is present.
    struct TriggerSkill {
        id: &'static str,
        trigger: &'static str,
    }

    #[async_trait]
    impl Skill for TriggerSkill {
        fn id(&self) -> &str {
            self.id
        }
        fn applies(&self, ctx: &InvestigationContext) -> bool {
            ctx.has_signal(self.trigger)
        }
        async fn execute(
            &self,
            _ctx: &mut InvestigationContext,
            _tools: &ToolRegistry,
        ) -> Result<SkillOutcome, KernelError> {
            Ok(SkillOutcome::default().with_delta(0.3))
        }
    }

    fn build_specialist(name: &str, skill_ids: &[&str], skills: &SkillRegistry) -> Arc<dyn Agent> {
        let tools = ToolRegistry::new();
        let agent = GenericAgent::builder(name)
            .with_skills(skill_ids.iter().copied())
            .build(skills, &tools)
            .unwrap();
        Arc::new(agent)
    }

    fn shared_registry() -> SkillRegistry {
        let r = SkillRegistry::new();
        r.register(Arc::new(TriggerSkill {
            id: "test.fanout",
            trigger: "fanout.high",
        }));
        r.register(Arc::new(TriggerSkill {
            id: "test.spray",
            trigger: "auth.failure.burst",
        }));
        r
    }

    #[tokio::test]
    async fn routes_to_first_matching_specialist() {
        let skills = shared_registry();
        let recon = build_specialist("recon", &["test.fanout"], &skills);
        let credential = build_specialist("credential", &["test.spray"], &skills);
        let coord = CoordinatorAgent::builder("coord")
            .with_specialist(recon)
            .with_specialist(credential)
            .route(RoutingRule::new("recon", ["fanout.high"]))
            .route(RoutingRule::new("credential", ["auth.failure.burst"]))
            .build();
        let mut ctx = InvestigationContext::new("e", "p").with_signal("fanout.high");
        let r = coord.step(&mut ctx).await.unwrap();
        assert!(r.skills_run.iter().any(|s| s == "test.fanout"));
        assert!(ctx.confidence > 0.0);
    }

    #[test]
    fn earlier_rule_wins_when_several_match() {
        let skills = shared_registry();
        let coord = CoordinatorAgent::builder("coord")
            .with_specialist(build_specialist("recon", &["test.fanout"], &skills))
            .with_specialist(build_specialist("credential", &["test.spray"], &skills))
            .route(RoutingRule::new("recon", ["fanout.high"]))
            .route(RoutingRule::new("credential", ["auth.failure.burst"]))
            .build();
        // Both rules match; rule order (not signal order) decides.
        let ctx = InvestigationContext::new("e", "p")
            .with_signal("auth.failure.burst")
            .with_signal("fanout.high");
        assert_eq!(coord.route(&ctx).map(|a| a.name()), Some("recon"));
    }

    #[test]
    fn skips_rule_whose_specialist_is_not_registered() {
        let skills = shared_registry();
        let coord = CoordinatorAgent::builder("coord")
            .with_specialist(build_specialist("recon", &["test.fanout"], &skills))
            .route(RoutingRule::new("ghost", ["fanout.high"]))
            .route(RoutingRule::new("recon", ["fanout.high"]))
            .build();
        let ctx = InvestigationContext::new("e", "p").with_signal("fanout.high");
        assert_eq!(coord.route(&ctx).map(|a| a.name()), Some("recon"));
    }

    #[tokio::test]
    async fn falls_back_when_no_rule_matches() {
        let skills = shared_registry();
        let credential = build_specialist("credential", &["test.spray"], &skills);
        let general = build_specialist("general", &["test.fanout"], &skills);
        let coord = CoordinatorAgent::builder("coord")
            .with_specialist(credential)
            .with_specialist(general)
            .route(RoutingRule::new("credential", ["never.fires"]))
            .fallback("general")
            .build();
        // `fanout.high` matches no rule but does trigger the fallback's skill,
        // so a run of `test.fanout` proves the step reached `general` rather
        // than the no-op path (which runs no skills at all).
        let mut ctx = InvestigationContext::new("e", "p").with_signal("fanout.high");
        assert_eq!(coord.route(&ctx).map(|a| a.name()), Some("general"));
        let r = coord.step(&mut ctx).await.unwrap();
        assert!(r.skills_run.iter().any(|s| s == "test.fanout"));
        assert!(ctx.confidence > 0.0);
    }

    #[tokio::test]
    async fn unmatched_with_no_fallback_is_noop() {
        let coord = CoordinatorAgent::builder("coord").build();
        let mut ctx = InvestigationContext::new("e", "p");
        assert!(coord.route(&ctx).is_none());
        let r = coord.step(&mut ctx).await.unwrap();
        assert!(r.skills_run.is_empty());
        assert!(!r.concluded);
    }

    #[tokio::test]
    async fn same_skill_instance_works_for_two_agents() {
        let skills = shared_registry();
        let a = build_specialist("a", &["test.fanout"], &skills);
        let b = build_specialist("b", &["test.fanout"], &skills);
        let mut ctx_a = InvestigationContext::new("x", "p").with_signal("fanout.high");
        let mut ctx_b = InvestigationContext::new("y", "p").with_signal("fanout.high");
        let ra = a.step(&mut ctx_a).await.unwrap();
        let rb = b.step(&mut ctx_b).await.unwrap();
        assert_eq!(ra.skills_run, rb.skills_run);
        assert!(ctx_a.confidence > 0.0);
        assert!(ctx_b.confidence > 0.0);
    }
}