1use std::convert::Infallible;
4use std::fmt;
5use std::str::FromStr;
6use std::time::Duration;
7
8use axum::extract::{Query, State};
9use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
10use futures_util::stream::{Stream, StreamExt};
11use serde::Deserialize;
12use serde::de::{self, Deserializer};
13use tokio_stream::wrappers::BroadcastStream;
14use uuid::Uuid;
15
16use ironflow_auth::extractor::Authenticated;
17use ironflow_engine::notify::Event;
18
19use crate::state::AppState;
20
/// Event categories a client may select with the `types` query parameter.
///
/// Each variant is parsed from the snake_case name shown on it (see the
/// `FromStr` impl) and, when the `openapi` feature is enabled, serializes
/// under that same snake_case name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(
    feature = "openapi",
    derive(serde::Serialize, utoipa::ToSchema),
    serde(rename_all = "snake_case")
)]
pub enum EventKind {
    /// Parsed from `run_created`.
    RunCreated,
    /// Parsed from `run_status_changed`.
    RunStatusChanged,
    /// Parsed from `run_failed`.
    RunFailed,
    /// Parsed from `step_completed`.
    StepCompleted,
    /// Parsed from `step_failed`.
    StepFailed,
    /// Parsed from `approval_requested`.
    ApprovalRequested,
    /// Parsed from `approval_granted`.
    ApprovalGranted,
    /// Parsed from `approval_rejected`.
    ApprovalRejected,
    /// Parsed from `user_signed_in`.
    UserSignedIn,
    /// Parsed from `user_signed_up`.
    UserSignedUp,
    /// Parsed from `user_signed_out`.
    UserSignedOut,
}
62
63impl EventKind {
64 pub fn as_str(self) -> &'static str {
66 match self {
67 Self::RunCreated => Event::RUN_CREATED,
68 Self::RunStatusChanged => Event::RUN_STATUS_CHANGED,
69 Self::RunFailed => Event::RUN_FAILED,
70 Self::StepCompleted => Event::STEP_COMPLETED,
71 Self::StepFailed => Event::STEP_FAILED,
72 Self::ApprovalRequested => Event::APPROVAL_REQUESTED,
73 Self::ApprovalGranted => Event::APPROVAL_GRANTED,
74 Self::ApprovalRejected => Event::APPROVAL_REJECTED,
75 Self::UserSignedIn => Event::USER_SIGNED_IN,
76 Self::UserSignedUp => Event::USER_SIGNED_UP,
77 Self::UserSignedOut => Event::USER_SIGNED_OUT,
78 }
79 }
80}
81
82impl fmt::Display for EventKind {
83 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84 f.write_str(self.as_str())
85 }
86}
87
88impl FromStr for EventKind {
89 type Err = InvalidEventKind;
90
91 fn from_str(s: &str) -> Result<Self, Self::Err> {
92 match s {
93 "run_created" => Ok(Self::RunCreated),
94 "run_status_changed" => Ok(Self::RunStatusChanged),
95 "run_failed" => Ok(Self::RunFailed),
96 "step_completed" => Ok(Self::StepCompleted),
97 "step_failed" => Ok(Self::StepFailed),
98 "approval_requested" => Ok(Self::ApprovalRequested),
99 "approval_granted" => Ok(Self::ApprovalGranted),
100 "approval_rejected" => Ok(Self::ApprovalRejected),
101 "user_signed_in" => Ok(Self::UserSignedIn),
102 "user_signed_up" => Ok(Self::UserSignedUp),
103 "user_signed_out" => Ok(Self::UserSignedOut),
104 _ => Err(InvalidEventKind(s.to_string())),
105 }
106 }
107}
108
/// Error produced when a string does not name a known [`EventKind`].
///
/// The wrapped `String` is the rejected input, copied verbatim.
#[derive(Debug, Clone)]
pub struct InvalidEventKind(pub String);

impl fmt::Display for InvalidEventKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(rejected) = self;
        write!(f, "unknown event kind: {}", rejected)
    }
}

impl std::error::Error for InvalidEventKind {}
120
121fn deserialize_comma_event_kinds<'de, D>(
123 deserializer: D,
124) -> Result<Option<Vec<EventKind>>, D::Error>
125where
126 D: Deserializer<'de>,
127{
128 let opt: Option<String> = Option::deserialize(deserializer)?;
129 match opt {
130 None => Ok(None),
131 Some(raw) => {
132 let kinds: Result<Vec<EventKind>, _> = raw
133 .split(',')
134 .map(|s| s.trim())
135 .filter(|s| !s.is_empty())
136 .map(EventKind::from_str)
137 .collect();
138 kinds.map(Some).map_err(de::Error::custom)
139 }
140 }
141}
142
143#[derive(Debug, Deserialize)]
158pub struct EventsQuery {
159 pub run_id: Option<Uuid>,
161 #[serde(default, deserialize_with = "deserialize_comma_event_kinds")]
163 pub types: Option<Vec<EventKind>>,
164}
165
166fn event_run_id(event: &Event) -> Option<Uuid> {
168 match event {
169 Event::RunCreated { run_id, .. }
170 | Event::RunStatusChanged { run_id, .. }
171 | Event::RunFailed { run_id, .. }
172 | Event::StepCompleted { run_id, .. }
173 | Event::StepFailed { run_id, .. }
174 | Event::ApprovalRequested { run_id, .. }
175 | Event::ApprovalGranted { run_id, .. }
176 | Event::ApprovalRejected { run_id, .. } => Some(*run_id),
177 Event::UserSignedIn { .. } | Event::UserSignedUp { .. } | Event::UserSignedOut { .. } => {
178 None
179 }
180 }
181}
182
183pub async fn events(
199 _auth: Authenticated,
200 State(state): State<AppState>,
201 Query(query): Query<EventsQuery>,
202) -> Sse<impl Stream<Item = Result<SseEvent, Infallible>>> {
203 let receiver = state.event_sender.subscribe();
204 let type_filter = query.types;
205
206 let stream = BroadcastStream::new(receiver).filter_map(move |result: Result<Event, _>| {
207 let type_filter = type_filter.clone();
208 let run_id_filter = query.run_id;
209 async move {
210 let event = result.ok()?;
211
212 if let Some(ref rid) = run_id_filter
213 && event_run_id(&event) != Some(*rid)
214 {
215 return None;
216 }
217
218 if let Some(ref kinds) = type_filter {
219 let event_type = event.event_type();
220 if !kinds.iter().any(|k| k.as_str() == event_type) {
221 return None;
222 }
223 }
224
225 let data = serde_json::to_string(&event).ok()?;
226 let sse_event = SseEvent::default().event(event.event_type()).data(data);
227
228 Some(Ok::<_, Infallible>(sse_event))
229 }
230 });
231
232 Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(30)))
233}
234
#[cfg(test)]
mod tests {
    // End-to-end tests for the `events` SSE handler. Each test starts a real
    // axum server on an ephemeral local port and talks to it over a raw TCP
    // socket, so the assertions run against the literal HTTP response text.

    use std::sync::Arc;
    use std::time::Duration;

    use axum::Router;
    use axum::routing::get;
    use chrono::Utc;
    use ironflow_auth::jwt::AccessToken;
    use ironflow_core::providers::claude::ClaudeCodeProvider;
    use ironflow_engine::engine::Engine;
    use ironflow_engine::notify::Event;
    use ironflow_store::memory::InMemoryStore;
    use ironflow_store::models::RunStatus;
    use rust_decimal::Decimal;
    use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
    use tokio::net::{TcpListener, TcpStream};
    use tokio::sync::broadcast;
    use tokio::time::{sleep, timeout};
    use uuid::Uuid;

    use super::events;
    use crate::state::AppState;

    /// Upper bound on how long a test waits for expected text.
    const READ_TIMEOUT: Duration = Duration::from_secs(5);

    /// Pause between sending the request and publishing events; presumably
    /// long enough for the handler to subscribe to the broadcast channel
    /// before the test calls `send`.
    const SUBSCRIBE_GRACE: Duration = Duration::from_millis(50);

    /// Builds an `AppState` backed by an in-memory store and a fresh
    /// broadcast channel with capacity 16.
    fn test_state() -> AppState {
        let store = Arc::new(InMemoryStore::new());
        let provider = Arc::new(ClaudeCodeProvider::new());
        let engine = Arc::new(Engine::new(store.clone(), provider));
        let jwt_config = Arc::new(ironflow_auth::jwt::JwtConfig {
            secret: "test-secret".to_string(),
            access_token_ttl_secs: 900,
            refresh_token_ttl_secs: 604800,
            cookie_domain: None,
            cookie_secure: false,
        });
        let (event_sender, _) = broadcast::channel::<Event>(16);
        AppState::new(
            store,
            engine,
            jwt_config,
            "test-worker-token".to_string(),
            event_sender,
        )
    }

    /// A `RunStatusChanged` event (Running -> Completed) for `run_id`.
    fn sample_run_event(run_id: Uuid) -> Event {
        Event::RunStatusChanged {
            run_id,
            workflow_name: "deploy".to_string(),
            from: RunStatus::Running,
            to: RunStatus::Completed,
            error: None,
            cost_usd: Decimal::ZERO,
            duration_ms: 1000,
            at: Utc::now(),
        }
    }

    /// A `UserSignedIn` event, which carries no run id.
    fn sample_user_event() -> Event {
        Event::UserSignedIn {
            user_id: Uuid::now_v7(),
            username: "alice".to_string(),
            at: Utc::now(),
        }
    }

    /// Returns a ready-to-use `Authorization` header value (`Bearer <token>`).
    fn make_auth_token(state: &AppState) -> String {
        let user_id = Uuid::now_v7();
        let token = AccessToken::for_user(user_id, "testuser", false, &state.jwt_config).unwrap();
        format!("Bearer {}", token.0)
    }

    /// Serves `/events` on an ephemeral port in a background task.
    ///
    /// Returns the bound address, a sender for publishing events into the
    /// server's broadcast channel, and a valid `Authorization` header value.
    async fn start_sse_server(state: AppState) -> (String, broadcast::Sender<Event>, String) {
        let sender = state.event_sender.clone();
        let auth = make_auth_token(&state);
        let app = Router::new()
            .route("/events", get(events))
            .with_state(state);

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap().to_string();
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        (addr, sender, auth)
    }

    /// Sends a raw `GET /events{query}` request and returns a buffered reader
    /// over the connection. The `Authorization` header is included only when
    /// `auth` is `Some`.
    async fn send_events_request(
        addr: &str,
        query: &str,
        auth: Option<&str>,
    ) -> BufReader<TcpStream> {
        let mut stream = TcpStream::connect(addr).await.unwrap();

        let auth_header = auth
            .map(|value| format!("Authorization: {value}\r\n"))
            .unwrap_or_default();
        let request = format!(
            "GET /events{query} HTTP/1.1\r\nHost: {addr}\r\nAccept: text/event-stream\r\n{auth_header}\r\n"
        );
        stream.write_all(request.as_bytes()).await.unwrap();

        BufReader::new(stream)
    }

    /// Opens an authenticated SSE connection to `/events{query}`.
    async fn connect_sse(addr: &str, query: &str, auth: &str) -> BufReader<TcpStream> {
        send_events_request(addr, query, Some(auth)).await
    }

    /// Reads lines until the accumulated text contains `needle`, then returns
    /// everything read so far.
    ///
    /// Panics, including the data received so far, if `dur` elapses or the
    /// peer closes the connection before `needle` shows up.
    async fn read_until_contains(
        reader: &mut BufReader<TcpStream>,
        needle: &str,
        dur: Duration,
    ) -> String {
        let mut accumulated = String::new();
        let outcome = timeout(dur, async {
            let mut line = String::new();
            loop {
                line.clear();
                let read = reader.read_line(&mut line).await.unwrap();
                if read == 0 {
                    // EOF: the needle can no longer arrive.
                    return false;
                }
                accumulated.push_str(&line);
                if accumulated.contains(needle) {
                    return true;
                }
            }
        })
        .await;

        match outcome {
            Ok(true) => accumulated,
            Ok(false) => panic!(
                "connection closed before '{needle}' appeared in SSE stream. Data so far:\n{accumulated}"
            ),
            Err(_) => {
                panic!("timeout waiting for '{needle}' in SSE stream. Data so far:\n{accumulated}")
            }
        }
    }

    #[tokio::test]
    async fn sse_stream_receives_events() {
        let (addr, sender, auth) = start_sse_server(test_state()).await;
        let mut reader = connect_sse(&addr, "", &auth).await;
        sleep(SUBSCRIBE_GRACE).await;

        let run_id = Uuid::now_v7();
        sender.send(sample_run_event(run_id)).unwrap();

        let text = read_until_contains(&mut reader, &run_id.to_string(), READ_TIMEOUT).await;

        assert!(text.contains("run_status_changed"));
        assert!(text.contains(&run_id.to_string()));
    }

    #[tokio::test]
    async fn sse_filters_by_run_id() {
        let (addr, sender, auth) = start_sse_server(test_state()).await;

        let target_run = Uuid::now_v7();
        let other_run = Uuid::now_v7();

        let mut reader = connect_sse(&addr, &format!("?run_id={target_run}"), &auth).await;
        sleep(SUBSCRIBE_GRACE).await;

        // The non-matching event goes first, so it would show up in `text`
        // if the filter let it through.
        sender.send(sample_run_event(other_run)).unwrap();
        sender.send(sample_run_event(target_run)).unwrap();

        let text = read_until_contains(&mut reader, &target_run.to_string(), READ_TIMEOUT).await;

        assert!(text.contains(&target_run.to_string()));
        assert!(!text.contains(&other_run.to_string()));
    }

    #[tokio::test]
    async fn sse_filters_by_event_type() {
        let (addr, sender, auth) = start_sse_server(test_state()).await;

        let mut reader = connect_sse(&addr, "?types=user_signed_in", &auth).await;
        sleep(SUBSCRIBE_GRACE).await;

        let run_id = Uuid::now_v7();
        sender.send(sample_run_event(run_id)).unwrap();
        sender.send(sample_user_event()).unwrap();

        let text = read_until_contains(&mut reader, "user_signed_in", READ_TIMEOUT).await;

        assert!(text.contains("user_signed_in"));
        assert!(!text.contains("run_status_changed"));
    }

    #[tokio::test]
    async fn sse_returns_correct_content_type() {
        let (addr, _sender, auth) = start_sse_server(test_state()).await;
        let mut reader = connect_sse(&addr, "", &auth).await;

        let text = read_until_contains(&mut reader, "text/event-stream", READ_TIMEOUT).await;

        assert!(text.contains("text/event-stream"));
    }

    #[tokio::test]
    async fn sse_rejects_unauthenticated() {
        let (addr, _sender, _auth) = start_sse_server(test_state()).await;

        // Same request as `connect_sse`, minus the `Authorization` header.
        let mut reader = send_events_request(&addr, "", None).await;
        let text = read_until_contains(&mut reader, "401", READ_TIMEOUT).await;

        assert!(text.contains("401"));
    }
}