1use chrono::{DateTime, Utc};
12use serde::{Deserialize, Serialize};
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15use tokio::sync::RwLock;
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
19#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
20pub enum AgentEventType {
21 AgentCreated,
22 AgentActivated,
23 AgentCompleted,
24 AgentError,
25 TaskAssigned,
26 TaskCompleted,
27 TaskFailed,
28 TaskStatusChanged,
29 MessageSent,
30 ReportSubmitted,
31 WorkspaceUpdated,
32}
33
34impl AgentEventType {
35 pub fn as_str(&self) -> &'static str {
36 match self {
37 Self::AgentCreated => "AGENT_CREATED",
38 Self::AgentActivated => "AGENT_ACTIVATED",
39 Self::AgentCompleted => "AGENT_COMPLETED",
40 Self::AgentError => "AGENT_ERROR",
41 Self::TaskAssigned => "TASK_ASSIGNED",
42 Self::TaskCompleted => "TASK_COMPLETED",
43 Self::TaskFailed => "TASK_FAILED",
44 Self::TaskStatusChanged => "TASK_STATUS_CHANGED",
45 Self::MessageSent => "MESSAGE_SENT",
46 Self::ReportSubmitted => "REPORT_SUBMITTED",
47 Self::WorkspaceUpdated => "WORKSPACE_UPDATED",
48 }
49 }
50
51 #[allow(clippy::should_implement_trait)]
52 pub fn from_str(s: &str) -> Option<Self> {
53 match s.to_uppercase().as_str() {
54 "AGENT_CREATED" => Some(Self::AgentCreated),
55 "AGENT_ACTIVATED" => Some(Self::AgentActivated),
56 "AGENT_COMPLETED" => Some(Self::AgentCompleted),
57 "AGENT_ERROR" => Some(Self::AgentError),
58 "TASK_ASSIGNED" => Some(Self::TaskAssigned),
59 "TASK_COMPLETED" => Some(Self::TaskCompleted),
60 "TASK_FAILED" => Some(Self::TaskFailed),
61 "TASK_STATUS_CHANGED" => Some(Self::TaskStatusChanged),
62 "MESSAGE_SENT" => Some(Self::MessageSent),
63 "REPORT_SUBMITTED" => Some(Self::ReportSubmitted),
64 "WORKSPACE_UPDATED" => Some(Self::WorkspaceUpdated),
65 _ => None,
66 }
67 }
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
72#[serde(rename_all = "camelCase")]
73pub struct AgentEvent {
74 #[serde(rename = "type")]
75 pub event_type: AgentEventType,
76 pub agent_id: String,
77 pub workspace_id: String,
78 pub data: serde_json::Value,
79 pub timestamp: DateTime<Utc>,
80}
81
82#[derive(Debug, Clone)]
84pub struct EventSubscription {
85 pub id: String,
86 pub agent_id: String,
87 pub agent_name: String,
88 pub event_types: Vec<AgentEventType>,
89 pub exclude_self: bool,
90 pub one_shot: bool,
92 pub wait_group_id: Option<String>,
94 pub priority: i32,
96}
97
/// A set of agents whose completion is being awaited (presumably by the
/// agent named in `parent_agent_id` — confirm with callers).
///
/// An expected agent counts as finished once it emits `AgentCompleted` or
/// `ReportSubmitted`; the bus discards the group once every expected agent
/// has finished.
#[derive(Clone, Debug)]
pub struct WaitGroup {
    /// Key under which the bus stores this group.
    pub id: String,
    /// Agent that owns the group.
    pub parent_agent_id: String,
    /// Agents the group waits for.
    pub expected_agent_ids: Vec<String>,
    /// Expected agents that have already finished.
    pub completed_agent_ids: HashSet<String>,
}
106
107type EventHandler = Arc<dyn Fn(AgentEvent) + Send + Sync>;
108
109struct EventBusInner {
111 handlers: HashMap<String, EventHandler>,
112 subscriptions: HashMap<String, EventSubscription>,
113 pending_events: HashMap<String, Vec<AgentEvent>>,
114 wait_groups: HashMap<String, WaitGroup>,
115}
116
117#[derive(Clone)]
119pub struct EventBus {
120 inner: Arc<RwLock<EventBusInner>>,
121}
122
123impl Default for EventBus {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
129impl EventBus {
130 pub fn new() -> Self {
131 Self {
132 inner: Arc::new(RwLock::new(EventBusInner {
133 handlers: HashMap::new(),
134 subscriptions: HashMap::new(),
135 pending_events: HashMap::new(),
136 wait_groups: HashMap::new(),
137 })),
138 }
139 }
140
141 pub async fn on<F>(&self, key: &str, handler: F)
145 where
146 F: Fn(AgentEvent) + Send + Sync + 'static,
147 {
148 let mut inner = self.inner.write().await;
149 inner.handlers.insert(key.to_string(), Arc::new(handler));
150 }
151
152 pub async fn off(&self, key: &str) {
154 let mut inner = self.inner.write().await;
155 inner.handlers.remove(key);
156 }
157
158 pub async fn emit(&self, event: AgentEvent) {
162 let mut inner = self.inner.write().await;
163
164 for handler in inner.handlers.values() {
166 let handler = handler.clone();
167 let event = event.clone();
168 tokio::spawn(async move {
170 handler(event);
171 });
172 }
173
174 let mut sorted_subs: Vec<_> = inner.subscriptions.values().cloned().collect();
176 sorted_subs.sort_by(|a, b| b.priority.cmp(&a.priority));
177
178 let mut one_shot_to_remove: Vec<String> = Vec::new();
179
180 for sub in &sorted_subs {
181 if sub.exclude_self && event.agent_id == sub.agent_id {
182 continue;
183 }
184 if !sub.event_types.contains(&event.event_type) {
185 continue;
186 }
187
188 let pending = inner
189 .pending_events
190 .entry(sub.agent_id.clone())
191 .or_default();
192 pending.push(event.clone());
193
194 if sub.one_shot {
196 one_shot_to_remove.push(sub.id.clone());
197 }
198 }
199
200 for sub_id in one_shot_to_remove {
202 inner.subscriptions.remove(&sub_id);
203 }
204
205 if matches!(
207 event.event_type,
208 AgentEventType::AgentCompleted | AgentEventType::ReportSubmitted
209 ) {
210 Self::check_wait_groups_inner(&mut inner, &event.agent_id);
211 }
212 }
213
214 pub async fn subscribe(&self, subscription: EventSubscription) {
218 let mut inner = self.inner.write().await;
219 inner
220 .subscriptions
221 .insert(subscription.id.clone(), subscription);
222 }
223
224 pub async fn unsubscribe(&self, subscription_id: &str) -> bool {
226 let mut inner = self.inner.write().await;
227 inner.subscriptions.remove(subscription_id).is_some()
228 }
229
230 pub async fn drain_pending_events(&self, agent_id: &str) -> Vec<AgentEvent> {
232 let mut inner = self.inner.write().await;
233 inner.pending_events.remove(agent_id).unwrap_or_default()
234 }
235
236 pub async fn create_wait_group(
240 &self,
241 id: String,
242 parent_agent_id: String,
243 expected_agent_ids: Vec<String>,
244 ) {
245 let mut inner = self.inner.write().await;
246 inner.wait_groups.insert(
247 id.clone(),
248 WaitGroup {
249 id,
250 parent_agent_id,
251 expected_agent_ids,
252 completed_agent_ids: HashSet::new(),
253 },
254 );
255 }
256
257 pub async fn add_to_wait_group(&self, group_id: &str, agent_id: &str) {
259 let mut inner = self.inner.write().await;
260 if let Some(group) = inner.wait_groups.get_mut(group_id) {
261 if !group.expected_agent_ids.contains(&agent_id.to_string()) {
262 group.expected_agent_ids.push(agent_id.to_string());
263 }
264 }
265 }
266
267 pub async fn get_wait_group(&self, group_id: &str) -> Option<WaitGroup> {
269 let inner = self.inner.read().await;
270 inner.wait_groups.get(group_id).cloned()
271 }
272
273 pub async fn remove_wait_group(&self, group_id: &str) {
275 let mut inner = self.inner.write().await;
276 inner.wait_groups.remove(group_id);
277 }
278
279 fn check_wait_groups_inner(inner: &mut EventBusInner, completed_agent_id: &str) {
281 let mut completed_groups: Vec<String> = Vec::new();
282
283 for (group_id, group) in inner.wait_groups.iter_mut() {
284 if group
285 .expected_agent_ids
286 .contains(&completed_agent_id.to_string())
287 {
288 group
289 .completed_agent_ids
290 .insert(completed_agent_id.to_string());
291
292 tracing::info!(
293 "[EventBus] Wait group {}: {}/{} completed",
294 group_id,
295 group.completed_agent_ids.len(),
296 group.expected_agent_ids.len()
297 );
298
299 if group.completed_agent_ids.len() >= group.expected_agent_ids.len() {
300 tracing::info!("[EventBus] Wait group {} complete", group_id);
301 completed_groups.push(group_id.clone());
302 }
303 }
304 }
305
306 for group_id in completed_groups {
308 inner.wait_groups.remove(&group_id);
309 }
310 }
311
312 pub fn all_event_types() -> Vec<&'static str> {
314 vec![
315 "AGENT_CREATED",
316 "AGENT_ACTIVATED",
317 "AGENT_COMPLETED",
318 "AGENT_ERROR",
319 "TASK_ASSIGNED",
320 "TASK_COMPLETED",
321 "TASK_FAILED",
322 "TASK_STATUS_CHANGED",
323 "MESSAGE_SENT",
324 "REPORT_SUBMITTED",
325 "WORKSPACE_UPDATED",
326 ]
327 }
328}