rig_compose/
coordinator.rs1use std::collections::HashMap;
19use std::sync::Arc;
20
21use async_trait::async_trait;
22
23use crate::{Agent, AgentId, AgentStepResult, InvestigationContext, KernelError};
24
25#[derive(Debug, Clone)]
28pub struct RoutingRule {
29 pub agent_name: String,
30 pub signals: Vec<String>,
31}
32
33impl RoutingRule {
34 pub fn new(
35 agent_name: impl Into<String>,
36 signals: impl IntoIterator<Item = impl Into<String>>,
37 ) -> Self {
38 Self {
39 agent_name: agent_name.into(),
40 signals: signals.into_iter().map(Into::into).collect(),
41 }
42 }
43
44 fn matches(&self, ctx: &InvestigationContext) -> bool {
45 self.signals.iter().any(|s| ctx.has_signal(s))
46 }
47}
48
49pub struct CoordinatorAgent {
51 id: AgentId,
52 name: String,
53 rules: Vec<RoutingRule>,
54 specialists: HashMap<String, Arc<dyn Agent>>,
55 fallback: Option<String>,
57}
58
59impl CoordinatorAgent {
60 pub fn builder(name: impl Into<String>) -> CoordinatorBuilder {
61 CoordinatorBuilder {
62 name: name.into(),
63 rules: Vec::new(),
64 specialists: HashMap::new(),
65 fallback: None,
66 }
67 }
68
69 pub fn route<'a>(&'a self, ctx: &InvestigationContext) -> Option<&'a Arc<dyn Agent>> {
71 for rule in &self.rules {
72 if !rule.matches(ctx) {
73 continue;
74 }
75 if let Some(agent) = self.specialists.get(&rule.agent_name) {
76 return Some(agent);
77 }
78 }
79 self.fallback.as_ref().and_then(|n| self.specialists.get(n))
80 }
81}
82
83#[async_trait]
84impl Agent for CoordinatorAgent {
85 fn id(&self) -> AgentId {
86 self.id
87 }
88 fn name(&self) -> &str {
89 &self.name
90 }
91 async fn step(&self, ctx: &mut InvestigationContext) -> Result<AgentStepResult, KernelError> {
92 let span = tracing::debug_span!(
93 "rig_compose.coordinator.route",
94 entity = %ctx.entity_id,
95 signals = ctx.signals.len(),
96 );
97 let _e = span.enter();
98 let routed = self.route(ctx).cloned();
99 drop(_e);
100 match routed {
101 Some(agent) => agent.step(ctx).await,
102 None => Ok(AgentStepResult {
103 skills_run: Vec::new(),
104 skills_skipped: Vec::new(),
105 confidence: ctx.confidence,
106 concluded: false,
107 }),
108 }
109 }
110}
111
112pub struct CoordinatorBuilder {
113 name: String,
114 rules: Vec<RoutingRule>,
115 specialists: HashMap<String, Arc<dyn Agent>>,
116 fallback: Option<String>,
117}
118
119impl CoordinatorBuilder {
120 pub fn route(mut self, rule: RoutingRule) -> Self {
121 self.rules.push(rule);
122 self
123 }
124
125 pub fn with_specialist(mut self, agent: Arc<dyn Agent>) -> Self {
126 self.specialists.insert(agent.name().to_string(), agent);
127 self
128 }
129
130 pub fn fallback(mut self, name: impl Into<String>) -> Self {
131 self.fallback = Some(name.into());
132 self
133 }
134
135 pub fn build(self) -> CoordinatorAgent {
136 CoordinatorAgent {
137 id: AgentId::new(),
138 name: self.name,
139 rules: self.rules,
140 specialists: self.specialists,
141 fallback: self.fallback,
142 }
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    use async_trait::async_trait;

    use crate::skill::{Skill, SkillOutcome};
    use crate::{GenericAgent, SkillRegistry, ToolRegistry};

    /// Test skill that applies only when `trigger` is present and always
    /// reports a confidence delta of 0.3.
    struct TriggerSkill {
        id: &'static str,
        trigger: &'static str,
    }

    #[async_trait]
    impl Skill for TriggerSkill {
        fn id(&self) -> &str {
            self.id
        }
        fn applies(&self, ctx: &InvestigationContext) -> bool {
            ctx.has_signal(self.trigger)
        }
        async fn execute(
            &self,
            _ctx: &mut InvestigationContext,
            _tools: &ToolRegistry,
        ) -> Result<SkillOutcome, KernelError> {
            Ok(SkillOutcome::default().with_delta(0.3))
        }
    }

    /// Builds a `GenericAgent` named `name` that runs the given skills.
    fn build_specialist(name: &str, skill_ids: &[&str], skills: &SkillRegistry) -> Arc<dyn Agent> {
        let tools = ToolRegistry::new();
        let agent = GenericAgent::builder(name)
            .with_skills(skill_ids.iter().copied())
            .build(skills, &tools)
            .unwrap();
        Arc::new(agent)
    }

    /// Registry with `test.fanout` (triggered by `fanout.high`) and
    /// `test.spray` (triggered by `auth.failure.burst`).
    fn shared_registry() -> SkillRegistry {
        let r = SkillRegistry::new();
        r.register(Arc::new(TriggerSkill {
            id: "test.fanout",
            trigger: "fanout.high",
        }));
        r.register(Arc::new(TriggerSkill {
            id: "test.spray",
            trigger: "auth.failure.burst",
        }));
        r
    }

    #[tokio::test]
    async fn routes_to_first_matching_specialist() {
        let skills = shared_registry();
        let recon = build_specialist("recon", &["test.fanout"], &skills);
        let credential = build_specialist("credential", &["test.spray"], &skills);
        let coord = CoordinatorAgent::builder("coord")
            .with_specialist(recon)
            .with_specialist(credential)
            .route(RoutingRule::new("recon", ["fanout.high"]))
            .route(RoutingRule::new("credential", ["auth.failure.burst"]))
            .build();
        // Both rules match, so only rule order can decide: recon must win and
        // credential's skill must not run.
        let mut ctx = InvestigationContext::new("e", "p")
            .with_signal("fanout.high")
            .with_signal("auth.failure.burst");
        let r = coord.step(&mut ctx).await.unwrap();
        assert!(r.skills_run.iter().any(|s| s == "test.fanout"));
        assert!(!r.skills_run.iter().any(|s| s == "test.spray"));
        assert!(ctx.confidence > 0.0);
    }

    #[tokio::test]
    async fn skips_rule_whose_specialist_is_missing() {
        let skills = shared_registry();
        let recon = build_specialist("recon", &["test.fanout"], &skills);
        let coord = CoordinatorAgent::builder("coord")
            .with_specialist(recon)
            .route(RoutingRule::new("ghost", ["fanout.high"]))
            .route(RoutingRule::new("recon", ["fanout.high"]))
            .build();
        let mut ctx = InvestigationContext::new("e", "p").with_signal("fanout.high");
        let r = coord.step(&mut ctx).await.unwrap();
        assert!(r.skills_run.iter().any(|s| s == "test.fanout"));
    }

    #[tokio::test]
    async fn falls_back_when_no_rule_matches() {
        let skills = shared_registry();
        let general = build_specialist("general", &["test.fanout"], &skills);
        let coord = CoordinatorAgent::builder("coord")
            .with_specialist(general)
            .route(RoutingRule::new("nope", ["never.fires"]))
            .fallback("general")
            .build();
        // The rule cannot fire, but the fallback's skill applies: if the
        // fallback were not used, nothing would run.
        let mut ctx = InvestigationContext::new("e", "p").with_signal("fanout.high");
        let r = coord.step(&mut ctx).await.unwrap();
        assert!(r.skills_run.iter().any(|s| s == "test.fanout"));
        assert!(ctx.confidence > 0.0);
    }

    #[tokio::test]
    async fn unmatched_with_no_fallback_is_noop() {
        let coord = CoordinatorAgent::builder("coord").build();
        let mut ctx = InvestigationContext::new("e", "p");
        let r = coord.step(&mut ctx).await.unwrap();
        assert!(r.skills_run.is_empty());
        assert!(!r.concluded);
    }

    #[tokio::test]
    async fn fallback_to_unregistered_specialist_is_noop() {
        let coord = CoordinatorAgent::builder("coord").fallback("ghost").build();
        let mut ctx = InvestigationContext::new("e", "p").with_signal("fanout.high");
        let r = coord.step(&mut ctx).await.unwrap();
        assert!(r.skills_run.is_empty());
        assert!(!r.concluded);
    }

    #[tokio::test]
    async fn same_skill_instance_works_for_two_agents() {
        let skills = shared_registry();
        let a = build_specialist("a", &["test.fanout"], &skills);
        let b = build_specialist("b", &["test.fanout"], &skills);
        let mut ctx_a = InvestigationContext::new("x", "p").with_signal("fanout.high");
        let mut ctx_b = InvestigationContext::new("y", "p").with_signal("fanout.high");
        let ra = a.step(&mut ctx_a).await.unwrap();
        let rb = b.step(&mut ctx_b).await.unwrap();
        assert_eq!(ra.skills_run, rb.skills_run);
        assert!(ctx_a.confidence > 0.0);
        assert!(ctx_b.confidence > 0.0);
    }
}