rig_compose/
coordinator.rs1use std::collections::HashMap;
11use std::sync::Arc;
12
13use async_trait::async_trait;
14
15use crate::{Agent, AgentId, AgentStepResult, InvestigationContext, KernelError};
16
/// A single routing rule: send the investigation to the specialist named
/// `agent_name` when the context carries any of `signals`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RoutingRule {
    /// Name of the target specialist, matched against the name each specialist
    /// was registered under.
    pub agent_name: String,
    /// Signals that trigger this rule; any one of them is sufficient.
    pub signals: Vec<String>,
}
24
25impl RoutingRule {
26 pub fn new(
27 agent_name: impl Into<String>,
28 signals: impl IntoIterator<Item = impl Into<String>>,
29 ) -> Self {
30 Self {
31 agent_name: agent_name.into(),
32 signals: signals.into_iter().map(Into::into).collect(),
33 }
34 }
35
36 fn matches(&self, ctx: &InvestigationContext) -> bool {
37 self.signals.iter().any(|s| ctx.has_signal(s))
38 }
39}
40
41pub struct CoordinatorAgent {
43 id: AgentId,
44 name: String,
45 rules: Vec<RoutingRule>,
46 specialists: HashMap<String, Arc<dyn Agent>>,
47 fallback: Option<String>,
49}
50
51impl CoordinatorAgent {
52 pub fn builder(name: impl Into<String>) -> CoordinatorBuilder {
53 CoordinatorBuilder {
54 name: name.into(),
55 rules: Vec::new(),
56 specialists: HashMap::new(),
57 fallback: None,
58 }
59 }
60
61 pub fn route<'a>(&'a self, ctx: &InvestigationContext) -> Option<&'a Arc<dyn Agent>> {
63 for rule in &self.rules {
64 if !rule.matches(ctx) {
65 continue;
66 }
67 if let Some(agent) = self.specialists.get(&rule.agent_name) {
68 return Some(agent);
69 }
70 }
71 self.fallback.as_ref().and_then(|n| self.specialists.get(n))
72 }
73}
74
75#[async_trait]
76impl Agent for CoordinatorAgent {
77 fn id(&self) -> AgentId {
78 self.id
79 }
80 fn name(&self) -> &str {
81 &self.name
82 }
83 async fn step(&self, ctx: &mut InvestigationContext) -> Result<AgentStepResult, KernelError> {
84 let span = tracing::debug_span!(
85 "azreal.coordinator.route",
86 entity = %ctx.entity_id,
87 signals = ctx.signals.len(),
88 );
89 let _e = span.enter();
90 let routed = self.route(ctx).cloned();
91 drop(_e);
92 match routed {
93 Some(agent) => agent.step(ctx).await,
94 None => Ok(AgentStepResult {
95 skills_run: Vec::new(),
96 skills_skipped: Vec::new(),
97 confidence: ctx.confidence,
98 concluded: false,
99 }),
100 }
101 }
102}
103
104pub struct CoordinatorBuilder {
105 name: String,
106 rules: Vec<RoutingRule>,
107 specialists: HashMap<String, Arc<dyn Agent>>,
108 fallback: Option<String>,
109}
110
111impl CoordinatorBuilder {
112 pub fn route(mut self, rule: RoutingRule) -> Self {
113 self.rules.push(rule);
114 self
115 }
116
117 pub fn with_specialist(mut self, agent: Arc<dyn Agent>) -> Self {
118 self.specialists.insert(agent.name().to_string(), agent);
119 self
120 }
121
122 pub fn fallback(mut self, name: impl Into<String>) -> Self {
123 self.fallback = Some(name.into());
124 self
125 }
126
127 pub fn build(self) -> CoordinatorAgent {
128 CoordinatorAgent {
129 id: AgentId::new(),
130 name: self.name,
131 rules: self.rules,
132 specialists: self.specialists,
133 fallback: self.fallback,
134 }
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    use async_trait::async_trait;

    use crate::skill::{Skill, SkillOutcome};
    use crate::{GenericAgent, SkillRegistry, ToolRegistry};

    /// Test skill that applies only when `trigger` is present on the context and
    /// always returns an outcome with a 0.3 delta.
    struct TriggerSkill {
        id: &'static str,
        trigger: &'static str,
    }

    #[async_trait]
    impl Skill for TriggerSkill {
        fn id(&self) -> &str {
            self.id
        }
        fn applies(&self, ctx: &InvestigationContext) -> bool {
            ctx.has_signal(self.trigger)
        }
        async fn execute(
            &self,
            _ctx: &mut InvestigationContext,
            _tools: &ToolRegistry,
        ) -> Result<SkillOutcome, KernelError> {
            Ok(SkillOutcome::default().with_delta(0.3))
        }
    }

    /// Builds a `GenericAgent` called `name` that uses `skill_ids` from `skills`.
    fn build_specialist(name: &str, skill_ids: &[&str], skills: &SkillRegistry) -> Arc<dyn Agent> {
        let tools = ToolRegistry::new();
        let agent = GenericAgent::builder(name)
            .with_skills(skill_ids.iter().copied())
            .build(skills, &tools)
            .unwrap();
        Arc::new(agent)
    }

    /// Registry holding `test.fanout` (triggered by `fanout.high`) and
    /// `test.spray` (triggered by `auth.failure.burst`).
    fn shared_registry() -> SkillRegistry {
        let r = SkillRegistry::new();
        r.register(Arc::new(TriggerSkill {
            id: "test.fanout",
            trigger: "fanout.high",
        }));
        r.register(Arc::new(TriggerSkill {
            id: "test.spray",
            trigger: "auth.failure.burst",
        }));
        r
    }

    /// Name of the specialist `coord` routes `ctx` to, if any.
    fn routed_name<'a>(coord: &'a CoordinatorAgent, ctx: &InvestigationContext) -> Option<&'a str> {
        coord.route(ctx).map(|agent| agent.name())
    }

    /// Coordinator with `recon` (on `fanout.high`) routed before `credential`
    /// (on `auth.failure.burst`).
    fn recon_then_credential(skills: &SkillRegistry) -> CoordinatorAgent {
        let recon = build_specialist("recon", &["test.fanout"], skills);
        let credential = build_specialist("credential", &["test.spray"], skills);
        CoordinatorAgent::builder("coord")
            .with_specialist(recon)
            .with_specialist(credential)
            .route(RoutingRule::new("recon", ["fanout.high"]))
            .route(RoutingRule::new("credential", ["auth.failure.burst"]))
            .build()
    }

    #[tokio::test]
    async fn routes_to_first_matching_specialist() {
        let skills = shared_registry();
        let coord = recon_then_credential(&skills);
        let mut ctx = InvestigationContext::new("e", "p").with_signal("fanout.high");
        assert_eq!(routed_name(&coord, &ctx), Some("recon"));
        let r = coord.step(&mut ctx).await.unwrap();
        assert!(r.skills_run.iter().any(|s| s == "test.fanout"));
        assert!(ctx.confidence > 0.0);
    }

    #[tokio::test]
    async fn earlier_rule_wins_when_several_match() {
        let skills = shared_registry();
        let coord = recon_then_credential(&skills);
        let mut ctx = InvestigationContext::new("e", "p")
            .with_signal("fanout.high")
            .with_signal("auth.failure.burst");
        assert_eq!(routed_name(&coord, &ctx), Some("recon"));
        let r = coord.step(&mut ctx).await.unwrap();
        assert!(r.skills_run.iter().any(|s| s == "test.fanout"));
        // `credential` would have run `test.spray`, since its signal is present.
        assert!(!r.skills_run.iter().any(|s| s == "test.spray"));
    }

    #[tokio::test]
    async fn skips_matching_rule_without_registered_specialist() {
        let skills = shared_registry();
        let credential = build_specialist("credential", &["test.spray"], &skills);
        let coord = CoordinatorAgent::builder("coord")
            .with_specialist(credential)
            .route(RoutingRule::new("ghost", ["auth.failure.burst"]))
            .route(RoutingRule::new("credential", ["auth.failure.burst"]))
            .build();
        let mut ctx = InvestigationContext::new("e", "p").with_signal("auth.failure.burst");
        assert_eq!(routed_name(&coord, &ctx), Some("credential"));
        let r = coord.step(&mut ctx).await.unwrap();
        assert!(r.skills_run.iter().any(|s| s == "test.spray"));
    }

    #[tokio::test]
    async fn falls_back_when_no_rule_matches() {
        let skills = shared_registry();
        let general = build_specialist("general", &["test.fanout"], &skills);
        let coord = CoordinatorAgent::builder("coord")
            .with_specialist(general)
            .route(RoutingRule::new("nope", ["never.fires"]))
            .fallback("general")
            .build();

        // `!r.concluded` alone also holds on the no-op path, so assert the routing
        // decision itself.
        let mut ctx = InvestigationContext::new("e", "p");
        assert_eq!(routed_name(&coord, &ctx), Some("general"));
        let r = coord.step(&mut ctx).await.unwrap();
        assert!(!r.concluded);

        // With its trigger present, the fallback specialist's skill actually runs.
        let mut fired = InvestigationContext::new("e", "p").with_signal("fanout.high");
        let r = coord.step(&mut fired).await.unwrap();
        assert!(r.skills_run.iter().any(|s| s == "test.fanout"));
    }

    #[tokio::test]
    async fn unmatched_with_no_fallback_is_noop() {
        let coord = CoordinatorAgent::builder("coord").build();
        let mut ctx = InvestigationContext::new("e", "p");
        assert!(coord.route(&ctx).is_none());
        let r = coord.step(&mut ctx).await.unwrap();
        assert!(r.skills_run.is_empty());
        assert!(!r.concluded);
    }

    #[tokio::test]
    async fn unregistered_fallback_is_noop() {
        let coord = CoordinatorAgent::builder("coord").fallback("missing").build();
        let mut ctx = InvestigationContext::new("e", "p");
        assert!(coord.route(&ctx).is_none());
        let r = coord.step(&mut ctx).await.unwrap();
        assert!(r.skills_run.is_empty());
        assert!(r.skills_skipped.is_empty());
        assert!(!r.concluded);
    }

    #[tokio::test]
    async fn same_skill_instance_works_for_two_agents() {
        let skills = shared_registry();
        let a = build_specialist("a", &["test.fanout"], &skills);
        let b = build_specialist("b", &["test.fanout"], &skills);
        let mut ctx_a = InvestigationContext::new("x", "p").with_signal("fanout.high");
        let mut ctx_b = InvestigationContext::new("y", "p").with_signal("fanout.high");
        let ra = a.step(&mut ctx_a).await.unwrap();
        let rb = b.step(&mut ctx_b).await.unwrap();
        assert_eq!(ra.skills_run, rb.skills_run);
        assert!(ctx_a.confidence > 0.0);
        assert!(ctx_b.confidence > 0.0);
    }
}