1use std::convert::Infallible;
4use std::fmt;
5use std::str::FromStr;
6use std::time::Duration;
7
8use axum::extract::{Query, State};
9use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
10use futures_util::stream::{Stream, StreamExt};
11use serde::Deserialize;
12use serde::de::{self, Deserializer};
13use tokio_stream::wrappers::BroadcastStream;
14use uuid::Uuid;
15
16use ironflow_auth::extractor::Authenticated;
17use ironflow_engine::notify::Event;
18
19use crate::state::AppState;
20
/// Kinds of events a client may ask the SSE endpoint to deliver.
///
/// Each variant is parsed from a snake_case name by the `FromStr` impl and is
/// mapped to an `Event::*` type-name constant by [`EventKind::as_str`].
/// With the `openapi` feature enabled the enum is also serialized in
/// snake_case and exposed as an OpenAPI schema.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "openapi", derive(serde::Serialize, utoipa::ToSchema))]
#[cfg_attr(feature = "openapi", serde(rename_all = "snake_case"))]
pub enum EventKind {
    /// Parsed from `run_created`.
    RunCreated,
    /// Parsed from `run_status_changed`.
    RunStatusChanged,
    /// Parsed from `run_failed`.
    RunFailed,
    /// Parsed from `step_completed`.
    StepCompleted,
    /// Parsed from `step_failed`.
    StepFailed,
    /// Parsed from `approval_requested`.
    ApprovalRequested,
    /// Parsed from `approval_granted`.
    ApprovalGranted,
    /// Parsed from `approval_rejected`.
    ApprovalRejected,
    /// Parsed from `user_signed_in`.
    UserSignedIn,
    /// Parsed from `user_signed_up`.
    UserSignedUp,
    /// Parsed from `user_signed_out`.
    UserSignedOut,
}
62
63impl EventKind {
64 pub fn as_str(self) -> &'static str {
66 match self {
67 Self::RunCreated => Event::RUN_CREATED,
68 Self::RunStatusChanged => Event::RUN_STATUS_CHANGED,
69 Self::RunFailed => Event::RUN_FAILED,
70 Self::StepCompleted => Event::STEP_COMPLETED,
71 Self::StepFailed => Event::STEP_FAILED,
72 Self::ApprovalRequested => Event::APPROVAL_REQUESTED,
73 Self::ApprovalGranted => Event::APPROVAL_GRANTED,
74 Self::ApprovalRejected => Event::APPROVAL_REJECTED,
75 Self::UserSignedIn => Event::USER_SIGNED_IN,
76 Self::UserSignedUp => Event::USER_SIGNED_UP,
77 Self::UserSignedOut => Event::USER_SIGNED_OUT,
78 }
79 }
80}
81
82impl fmt::Display for EventKind {
83 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84 f.write_str(self.as_str())
85 }
86}
87
88impl FromStr for EventKind {
89 type Err = InvalidEventKind;
90
91 fn from_str(s: &str) -> Result<Self, Self::Err> {
92 match s {
93 "run_created" => Ok(Self::RunCreated),
94 "run_status_changed" => Ok(Self::RunStatusChanged),
95 "run_failed" => Ok(Self::RunFailed),
96 "step_completed" => Ok(Self::StepCompleted),
97 "step_failed" => Ok(Self::StepFailed),
98 "approval_requested" => Ok(Self::ApprovalRequested),
99 "approval_granted" => Ok(Self::ApprovalGranted),
100 "approval_rejected" => Ok(Self::ApprovalRejected),
101 "user_signed_in" => Ok(Self::UserSignedIn),
102 "user_signed_up" => Ok(Self::UserSignedUp),
103 "user_signed_out" => Ok(Self::UserSignedOut),
104 _ => Err(InvalidEventKind(s.to_string())),
105 }
106 }
107}
108
/// Error produced when a string does not name a known event kind.
///
/// The wrapped `String` is the rejected input, kept verbatim so it can be
/// echoed back in the error message.
#[derive(Debug, Clone)]
pub struct InvalidEventKind(pub String);

impl fmt::Display for InvalidEventKind {
    /// Renders as `unknown event kind: <input>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(rejected) = self;
        write!(f, "unknown event kind: {rejected}")
    }
}

// No source error to expose: the default `Error` methods are sufficient.
impl std::error::Error for InvalidEventKind {}
120
121fn deserialize_comma_event_kinds<'de, D>(
123 deserializer: D,
124) -> Result<Option<Vec<EventKind>>, D::Error>
125where
126 D: Deserializer<'de>,
127{
128 let opt: Option<String> = Option::deserialize(deserializer)?;
129 match opt {
130 None => Ok(None),
131 Some(raw) => {
132 let kinds: Result<Vec<EventKind>, _> = raw
133 .split(',')
134 .map(|s| s.trim())
135 .filter(|s| !s.is_empty())
136 .map(EventKind::from_str)
137 .collect();
138 kinds.map(Some).map_err(de::Error::custom)
139 }
140 }
141}
142
/// Query-string parameters accepted by the `events` SSE handler.
#[derive(Debug, Deserialize)]
pub struct EventsQuery {
    /// When set, only events whose run id equals this value are delivered.
    /// Events that carry no run id (the user sign-in/up/out events) are
    /// dropped while this filter is active.
    pub run_id: Option<Uuid>,
    /// Comma-separated event kind names, e.g. `types=run_failed,step_failed`.
    ///
    /// Absent means no type filtering. Names are trimmed, empty entries are
    /// skipped, and an unknown name fails deserialization. Note that a value
    /// with no non-empty entries becomes `Some(vec![])`, which the handler
    /// treats as "match no event".
    #[serde(default, deserialize_with = "deserialize_comma_event_kinds")]
    pub types: Option<Vec<EventKind>>,
}
165
166fn event_run_id(event: &Event) -> Option<Uuid> {
168 match event {
169 Event::RunCreated { run_id, .. }
170 | Event::RunStatusChanged { run_id, .. }
171 | Event::RunFailed { run_id, .. }
172 | Event::StepCompleted { run_id, .. }
173 | Event::StepFailed { run_id, .. }
174 | Event::ApprovalRequested { run_id, .. }
175 | Event::ApprovalGranted { run_id, .. }
176 | Event::ApprovalRejected { run_id, .. } => Some(*run_id),
177 Event::UserSignedIn { .. } | Event::UserSignedUp { .. } | Event::UserSignedOut { .. } => {
178 None
179 }
180 }
181}
182
183pub async fn events(
199 _auth: Authenticated,
200 State(state): State<AppState>,
201 Query(query): Query<EventsQuery>,
202) -> Sse<impl Stream<Item = Result<SseEvent, Infallible>>> {
203 let receiver = state.event_sender.subscribe();
204 let type_filter = query.types;
205
206 let stream = BroadcastStream::new(receiver).filter_map(move |result: Result<Event, _>| {
207 let type_filter = type_filter.clone();
208 let run_id_filter = query.run_id;
209 async move {
210 let event = result.ok()?;
211
212 if let Some(ref rid) = run_id_filter
213 && event_run_id(&event) != Some(*rid)
214 {
215 return None;
216 }
217
218 if let Some(ref kinds) = type_filter {
219 let event_type = event.event_type();
220 if !kinds.iter().any(|k| k.as_str() == event_type) {
221 return None;
222 }
223 }
224
225 let data = serde_json::to_string(&event).ok()?;
226 let sse_event = SseEvent::default().event(event.event_type()).data(data);
227
228 Some(Ok::<_, Infallible>(sse_event))
229 }
230 });
231
232 Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(30)))
233}
234
#[cfg(test)]
mod tests {
    // End-to-end tests: a real axum server is bound to an ephemeral port and
    // spoken to over a raw TCP socket so the SSE wire format can be inspected.

    use std::sync::Arc;
    use std::time::Duration;

    use axum::Router;
    use axum::routing::get;
    use chrono::Utc;
    use ironflow_auth::jwt::AccessToken;
    use ironflow_core::providers::claude::ClaudeCodeProvider;
    use ironflow_engine::engine::Engine;
    use ironflow_engine::notify::Event;
    use ironflow_store::api_key_store::ApiKeyStore;
    use ironflow_store::memory::InMemoryStore;
    use ironflow_store::models::RunStatus;
    use ironflow_store::user_store::UserStore;
    use rust_decimal::Decimal;
    use tokio::io::AsyncBufReadExt;
    use tokio::io::BufReader;
    use tokio::net::TcpListener;
    use tokio::sync::broadcast;
    use tokio::time::{sleep, timeout};
    use uuid::Uuid;

    use super::events;
    use crate::state::AppState;

    /// Builds an `AppState` backed by in-memory stores and a fresh broadcast
    /// channel of capacity 16.
    fn test_state() -> AppState {
        let store = Arc::new(InMemoryStore::new());
        let user_store: Arc<dyn UserStore> = Arc::new(InMemoryStore::new());
        let api_key_store: Arc<dyn ApiKeyStore> = Arc::new(InMemoryStore::new());

        let provider = Arc::new(ClaudeCodeProvider::new());
        let engine = Arc::new(Engine::new(store.clone(), provider));

        let jwt_config = Arc::new(ironflow_auth::jwt::JwtConfig {
            secret: "test-secret".to_string(),
            access_token_ttl_secs: 900,
            refresh_token_ttl_secs: 604800,
            cookie_domain: None,
            cookie_secure: false,
        });

        // Only the sender is kept; handlers subscribe on demand.
        let (event_sender, _) = broadcast::channel::<Event>(16);

        AppState::new(
            store,
            user_store,
            api_key_store,
            engine,
            jwt_config,
            "test-worker-token".to_string(),
            event_sender,
        )
    }

    /// A `RunStatusChanged` event (Running -> Completed) for `run_id`.
    fn sample_run_event(run_id: Uuid) -> Event {
        Event::RunStatusChanged {
            run_id,
            workflow_name: "deploy".to_string(),
            from: RunStatus::Running,
            to: RunStatus::Completed,
            error: None,
            cost_usd: Decimal::ZERO,
            duration_ms: 1000,
            at: Utc::now(),
        }
    }

    /// A `UserSignedIn` event for a freshly generated user id.
    fn sample_user_event() -> Event {
        Event::UserSignedIn {
            user_id: Uuid::now_v7(),
            username: "alice".to_string(),
            at: Utc::now(),
        }
    }

    /// Issues an access token for a throwaway user and formats it as an
    /// `Authorization` header value.
    fn make_auth_token(state: &AppState) -> String {
        let token =
            AccessToken::for_user(Uuid::now_v7(), "testuser", false, &state.jwt_config).unwrap();
        format!("Bearer {}", token.0)
    }

    /// Serves the `events` handler on an ephemeral local port.
    ///
    /// Returns the bound address, a sender for injecting events, and a valid
    /// `Authorization` header value.
    async fn start_sse_server(state: AppState) -> (String, broadcast::Sender<Event>, String) {
        let auth = make_auth_token(&state);
        let sender = state.event_sender.clone();
        let app = Router::new()
            .route("/events", get(events))
            .with_state(state);

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap().to_string();
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });

        (addr, sender, auth)
    }

    /// Connects to `addr`, writes the raw HTTP `request`, and returns a
    /// buffered reader over the same socket.
    async fn open_request(addr: &str, request: &str) -> BufReader<tokio::net::TcpStream> {
        use tokio::io::AsyncWriteExt;

        let mut stream = tokio::net::TcpStream::connect(addr).await.unwrap();
        stream.write_all(request.as_bytes()).await.unwrap();
        BufReader::new(stream)
    }

    /// Opens an authenticated SSE connection to `/events{query}`.
    async fn connect_sse(addr: &str, query: &str, auth: &str) -> BufReader<tokio::net::TcpStream> {
        let request = format!(
            "GET /events{query} HTTP/1.1\r\nHost: {addr}\r\nAccept: text/event-stream\r\nAuthorization: {auth}\r\n\r\n"
        );
        open_request(addr, &request).await
    }

    /// Reads lines until the accumulated text contains `needle` or the peer
    /// closes the connection, and returns everything read.
    ///
    /// Panics if `dur` elapses first.
    async fn read_until_contains(
        reader: &mut BufReader<tokio::net::TcpStream>,
        needle: &str,
        dur: Duration,
    ) -> String {
        let mut accumulated = String::new();
        let outcome = timeout(dur, async {
            loop {
                // A fresh per-line buffer keeps `accumulated` intact if the
                // timeout cancels a read midway.
                let mut line = String::new();
                let bytes_read = reader.read_line(&mut line).await.unwrap();
                if bytes_read == 0 {
                    break;
                }
                accumulated.push_str(&line);
                if accumulated.contains(needle) {
                    break;
                }
            }
        })
        .await;

        assert!(
            outcome.is_ok(),
            "timeout waiting for '{needle}' in SSE stream. Data so far:\n{accumulated}"
        );
        accumulated
    }

    #[tokio::test]
    async fn sse_stream_receives_events() {
        let (addr, sender, auth) = start_sse_server(test_state()).await;
        let mut sse = connect_sse(&addr, "", &auth).await;

        // Give the handler time to subscribe before broadcasting.
        sleep(Duration::from_millis(50)).await;

        let run_id = Uuid::now_v7();
        sender.send(sample_run_event(run_id)).unwrap();

        let received =
            read_until_contains(&mut sse, &run_id.to_string(), Duration::from_secs(5)).await;

        assert!(received.contains("run_status_changed"));
        assert!(received.contains(&run_id.to_string()));
    }

    #[tokio::test]
    async fn sse_filters_by_run_id() {
        let (addr, sender, auth) = start_sse_server(test_state()).await;

        let target_run = Uuid::now_v7();
        let other_run = Uuid::now_v7();

        let mut sse = connect_sse(&addr, &format!("?run_id={target_run}"), &auth).await;
        sleep(Duration::from_millis(50)).await;

        // The unwanted event goes first so it would show up if not filtered.
        sender.send(sample_run_event(other_run)).unwrap();
        sender.send(sample_run_event(target_run)).unwrap();

        let received =
            read_until_contains(&mut sse, &target_run.to_string(), Duration::from_secs(5)).await;

        assert!(received.contains(&target_run.to_string()));
        assert!(!received.contains(&other_run.to_string()));
    }

    #[tokio::test]
    async fn sse_filters_by_event_type() {
        let (addr, sender, auth) = start_sse_server(test_state()).await;

        let mut sse = connect_sse(&addr, "?types=user_signed_in", &auth).await;
        sleep(Duration::from_millis(50)).await;

        // The run event is sent first and must be filtered out.
        let run_id = Uuid::now_v7();
        sender.send(sample_run_event(run_id)).unwrap();
        sender.send(sample_user_event()).unwrap();

        let received =
            read_until_contains(&mut sse, "user_signed_in", Duration::from_secs(5)).await;

        assert!(received.contains("user_signed_in"));
        assert!(!received.contains("run_status_changed"));
    }

    #[tokio::test]
    async fn sse_returns_correct_content_type() {
        let (addr, _sender, auth) = start_sse_server(test_state()).await;
        let mut sse = connect_sse(&addr, "", &auth).await;

        let received =
            read_until_contains(&mut sse, "text/event-stream", Duration::from_secs(5)).await;

        assert!(received.contains("text/event-stream"));
    }

    #[tokio::test]
    async fn sse_rejects_unauthenticated() {
        let (addr, _sender, _auth) = start_sse_server(test_state()).await;

        // Same request as `connect_sse`, minus the Authorization header.
        let request =
            format!("GET /events HTTP/1.1\r\nHost: {addr}\r\nAccept: text/event-stream\r\n\r\n");
        let mut response = open_request(&addr, &request).await;

        let received = read_until_contains(&mut response, "401", Duration::from_secs(5)).await;

        assert!(received.contains("401"));
    }
}