1use crate::types::WorkflowDefinition;
16use distri_types::WorkflowTrigger;
17use std::collections::HashMap;
18
/// One trigger attached to one entry point of a registered workflow.
///
/// A binding is produced for every `(entry point, trigger)` pair found in a
/// workflow definition at registration time.
#[derive(Debug, Clone, PartialEq)]
pub struct TriggerBinding {
    /// Identifier the workflow was registered under.
    pub agent_id: String,
    /// Workspace passed to `register`, if one was given.
    pub workspace_id: Option<String>,
    /// `id` of the entry point that declares this trigger.
    pub entry_point_id: String,
    /// The trigger itself, cloned from the workflow definition.
    pub trigger: WorkflowTrigger,
}
32
/// Lookup table from trigger kinds (webhook, tool, event, schedule) to the
/// workflow entry points that declare them.
///
/// Implementations must be `Send + Sync`; every operation is async and
/// fallible.
#[async_trait::async_trait]
pub trait WorkflowTriggerRegistry: Send + Sync {
    /// Registers the triggers declared on the entry points of `def` under
    /// `agent_id`, tagging each binding with `workspace_id`.
    ///
    /// The in-memory implementation in this file replaces whatever was
    /// previously stored for the same `agent_id`.
    async fn register(
        &self,
        agent_id: &str,
        workspace_id: Option<&str>,
        def: &WorkflowDefinition,
    ) -> anyhow::Result<()>;

    /// Removes the bindings registered for `agent_id`.
    ///
    /// In the in-memory implementation an unknown `agent_id` is a no-op that
    /// still returns `Ok(())`.
    async fn unregister(&self, agent_id: &str) -> anyhow::Result<()>;

    /// Returns a binding whose trigger is a webhook with exactly this `path`,
    /// or `Ok(None)` when no such binding exists.
    async fn find_webhook(&self, path: &str) -> anyhow::Result<Option<TriggerBinding>>;

    /// Returns a binding whose trigger is a tool named exactly `tool_name`,
    /// or `Ok(None)` when no such binding exists.
    async fn find_tool(&self, tool_name: &str) -> anyhow::Result<Option<TriggerBinding>>;

    /// Returns every binding whose trigger is an event on exactly this
    /// `topic`; the result is empty when nothing matches.
    async fn find_event(&self, topic: &str) -> anyhow::Result<Vec<TriggerBinding>>;

    /// Returns every binding whose trigger is a schedule.
    async fn list_schedules(&self) -> anyhow::Result<Vec<TriggerBinding>>;
}
68
/// [`WorkflowTriggerRegistry`] that keeps every binding in process memory.
///
/// Bindings are grouped by agent id and guarded by a `std::sync::Mutex`.
#[derive(Default)]
pub struct InMemoryWorkflowTriggerRegistry {
    /// Agent id -> all trigger bindings registered for that agent.
    bindings: std::sync::Mutex<HashMap<String, Vec<TriggerBinding>>>,
}
76
77impl InMemoryWorkflowTriggerRegistry {
78 pub fn new() -> Self {
79 Self::default()
80 }
81
82 fn collect_bindings(
83 agent_id: &str,
84 workspace_id: Option<&str>,
85 def: &WorkflowDefinition,
86 ) -> Vec<TriggerBinding> {
87 let mut out = Vec::new();
88 for ep in &def.entry_points {
89 for trigger in &ep.triggers {
90 out.push(TriggerBinding {
91 agent_id: agent_id.to_string(),
92 workspace_id: workspace_id.map(|s| s.to_string()),
93 entry_point_id: ep.id.clone(),
94 trigger: trigger.clone(),
95 });
96 }
97 }
98 out
99 }
100}
101
102#[async_trait::async_trait]
103impl WorkflowTriggerRegistry for InMemoryWorkflowTriggerRegistry {
104 async fn register(
105 &self,
106 agent_id: &str,
107 workspace_id: Option<&str>,
108 def: &WorkflowDefinition,
109 ) -> anyhow::Result<()> {
110 let mut guard = self
111 .bindings
112 .lock()
113 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
114 guard.insert(
115 agent_id.to_string(),
116 Self::collect_bindings(agent_id, workspace_id, def),
117 );
118 Ok(())
119 }
120
121 async fn unregister(&self, agent_id: &str) -> anyhow::Result<()> {
122 let mut guard = self
123 .bindings
124 .lock()
125 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
126 guard.remove(agent_id);
127 Ok(())
128 }
129
130 async fn find_webhook(&self, path: &str) -> anyhow::Result<Option<TriggerBinding>> {
131 let guard = self
132 .bindings
133 .lock()
134 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
135 for entries in guard.values() {
136 for binding in entries {
137 if let WorkflowTrigger::Webhook { path: p, .. } = &binding.trigger {
138 if p == path {
139 return Ok(Some(binding.clone()));
140 }
141 }
142 }
143 }
144 Ok(None)
145 }
146
147 async fn find_tool(&self, tool_name: &str) -> anyhow::Result<Option<TriggerBinding>> {
148 let guard = self
149 .bindings
150 .lock()
151 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
152 for entries in guard.values() {
153 for binding in entries {
154 if let WorkflowTrigger::Tool { name, .. } = &binding.trigger {
155 if name == tool_name {
156 return Ok(Some(binding.clone()));
157 }
158 }
159 }
160 }
161 Ok(None)
162 }
163
164 async fn find_event(&self, topic: &str) -> anyhow::Result<Vec<TriggerBinding>> {
165 let guard = self
166 .bindings
167 .lock()
168 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
169 let mut out = Vec::new();
170 for entries in guard.values() {
171 for binding in entries {
172 if let WorkflowTrigger::Event { topic: t, .. } = &binding.trigger {
173 if t == topic {
174 out.push(binding.clone());
175 }
176 }
177 }
178 }
179 Ok(out)
180 }
181
182 async fn list_schedules(&self) -> anyhow::Result<Vec<TriggerBinding>> {
183 let guard = self
184 .bindings
185 .lock()
186 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
187 let mut out = Vec::new();
188 for entries in guard.values() {
189 for binding in entries {
190 if matches!(&binding.trigger, WorkflowTrigger::Schedule { .. }) {
191 out.push(binding.clone());
192 }
193 }
194 }
195 Ok(out)
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{EntryPoint, WorkflowDefinition, WorkflowStep};
    use distri_types::workflow_triggers::WebhookAuth;

    /// Builds a one-step workflow whose single entry point (`main`) declares
    /// the given `triggers`.
    fn def_with(triggers: Vec<WorkflowTrigger>) -> WorkflowDefinition {
        let entry = EntryPoint {
            id: "main".into(),
            label: "Main".into(),
            description: None,
            starts_at: "s".into(),
            preset_results: Default::default(),
            required_inputs: vec![],
            triggers,
        };
        let steps = vec![WorkflowStep::checkpoint("s", "S", "ok")];
        WorkflowDefinition::new(steps).with_entry_points(vec![entry])
    }

    /// Unauthenticated webhook trigger on `path` with an empty method list.
    fn open_webhook(path: &str) -> WorkflowTrigger {
        WorkflowTrigger::Webhook {
            path: path.into(),
            methods: vec![],
            auth: WebhookAuth::None,
            response: Default::default(),
        }
    }

    #[tokio::test]
    async fn register_then_find_webhook() {
        let registry = InMemoryWorkflowTriggerRegistry::new();
        let github = WorkflowTrigger::Webhook {
            path: "github".into(),
            methods: vec!["POST".into()],
            auth: WebhookAuth::None,
            response: Default::default(),
        };
        let workflow = def_with(vec![github]);
        registry.register("agent-1", None, &workflow).await.unwrap();

        let found = registry.find_webhook("github").await.unwrap().unwrap();
        assert_eq!(found.agent_id, "agent-1");
        assert_eq!(found.entry_point_id, "main");

        let absent = registry.find_webhook("missing").await.unwrap();
        assert!(absent.is_none());
    }

    #[tokio::test]
    async fn register_then_find_tool() {
        let registry = InMemoryWorkflowTriggerRegistry::new();
        let summarize = WorkflowTrigger::Tool {
            name: "summarize".into(),
            description: "summarize a document".into(),
            input_schema: None,
        };
        let workflow = def_with(vec![summarize]);
        registry
            .register("wf-summarize", None, &workflow)
            .await
            .unwrap();

        let found = registry.find_tool("summarize").await.unwrap().unwrap();
        assert_eq!(found.agent_id, "wf-summarize");

        let absent = registry.find_tool("nope").await.unwrap();
        assert!(absent.is_none());
    }

    #[tokio::test]
    async fn find_event_fans_out() {
        let registry = InMemoryWorkflowTriggerRegistry::new();
        // Two different agents subscribe to the same topic.
        let first = def_with(vec![WorkflowTrigger::Event {
            topic: "user.signup".into(),
            filter: None,
        }]);
        let second = def_with(vec![WorkflowTrigger::Event {
            topic: "user.signup".into(),
            filter: None,
        }]);
        registry.register("agent-a", None, &first).await.unwrap();
        registry.register("agent-b", None, &second).await.unwrap();

        let subscribers = registry.find_event("user.signup").await.unwrap();
        assert_eq!(subscribers.len(), 2);
    }

    #[tokio::test]
    async fn list_schedules_returns_only_schedule_triggers() {
        let registry = InMemoryWorkflowTriggerRegistry::new();
        let hourly = WorkflowTrigger::Schedule {
            cron: "0 * * * *".into(),
            timezone: None,
            enabled: true,
            input: None,
        };
        // The manual trigger must not show up in the schedule listing.
        let workflow = def_with(vec![hourly, WorkflowTrigger::Manual]);
        registry.register("nightly", None, &workflow).await.unwrap();

        let schedules = registry.list_schedules().await.unwrap();
        assert_eq!(schedules.len(), 1);
        assert!(matches!(
            schedules[0].trigger,
            WorkflowTrigger::Schedule { .. }
        ));
    }

    #[tokio::test]
    async fn unregister_clears_bindings() {
        let registry = InMemoryWorkflowTriggerRegistry::new();
        let workflow = def_with(vec![open_webhook("stripe")]);
        registry.register("billing", None, &workflow).await.unwrap();
        let before = registry.find_webhook("stripe").await.unwrap();
        assert!(before.is_some());

        registry.unregister("billing").await.unwrap();
        let after = registry.find_webhook("stripe").await.unwrap();
        assert!(after.is_none());
    }

    #[tokio::test]
    async fn register_overwrites_previous_bindings_for_agent() {
        let registry = InMemoryWorkflowTriggerRegistry::new();
        let first_version = def_with(vec![open_webhook("v1")]);
        registry.register("api", None, &first_version).await.unwrap();
        assert!(registry.find_webhook("v1").await.unwrap().is_some());

        // Registering again under the same agent id replaces the old bindings.
        let second_version = def_with(vec![open_webhook("v2")]);
        registry
            .register("api", None, &second_version)
            .await
            .unwrap();
        assert!(registry.find_webhook("v1").await.unwrap().is_none());
        assert!(registry.find_webhook("v2").await.unwrap().is_some());
    }
}