1use std::convert::Infallible;
4use std::str::FromStr;
5use std::time::Duration;
6
7use axum::extract::{Query, State};
8use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
9use futures_util::stream::{Stream, StreamExt};
10use serde::Deserialize;
11use serde::de::{self, Deserializer};
12use tokio_stream::wrappers::BroadcastStream;
13use uuid::Uuid;
14
15use crate::state::AppState;
16use ironflow_auth::extractor::Authenticated;
17use ironflow_engine::notify::Event;
18
19pub use ironflow_store::entities::EventKind;
20
21fn deserialize_comma_event_kinds<'de, D>(
23 deserializer: D,
24) -> Result<Option<Vec<EventKind>>, D::Error>
25where
26 D: Deserializer<'de>,
27{
28 let opt: Option<String> = Option::deserialize(deserializer)?;
29 match opt {
30 None => Ok(None),
31 Some(raw) => {
32 let kinds: Result<Vec<EventKind>, _> = raw
33 .split(',')
34 .map(|s| s.trim())
35 .filter(|s| !s.is_empty())
36 .map(EventKind::from_str)
37 .collect();
38 kinds.map(Some).map_err(de::Error::custom)
39 }
40 }
41}
42
43#[derive(Debug, Deserialize)]
58pub struct EventsQuery {
59 pub run_id: Option<Uuid>,
61 #[serde(default, deserialize_with = "deserialize_comma_event_kinds")]
63 pub types: Option<Vec<EventKind>>,
64}
65
66fn event_run_id(event: &Event) -> Option<Uuid> {
68 match event {
69 Event::RunCreated { run_id, .. }
70 | Event::RunStatusChanged { run_id, .. }
71 | Event::RunFailed { run_id, .. }
72 | Event::StepCompleted { run_id, .. }
73 | Event::StepFailed { run_id, .. }
74 | Event::ApprovalRequested { run_id, .. }
75 | Event::ApprovalGranted { run_id, .. }
76 | Event::ApprovalRejected { run_id, .. }
77 | Event::LogLine { run_id, .. } => Some(*run_id),
78 Event::UserSignedIn { .. } | Event::UserSignedUp { .. } | Event::UserSignedOut { .. } => {
79 None
80 }
81 }
82}
83
84pub async fn events(
100 _auth: Authenticated,
101 State(state): State<AppState>,
102 Query(query): Query<EventsQuery>,
103) -> Sse<impl Stream<Item = Result<SseEvent, Infallible>>> {
104 let receiver = state.event_sender.subscribe();
105 let type_filter = query.types;
106
107 let stream = BroadcastStream::new(receiver).filter_map(move |result: Result<Event, _>| {
108 let type_filter = type_filter.clone();
109 let run_id_filter = query.run_id;
110 async move {
111 let event = result.ok()?;
112
113 if let Some(ref rid) = run_id_filter
114 && event_run_id(&event) != Some(*rid)
115 {
116 return None;
117 }
118
119 if let Some(ref kinds) = type_filter {
120 let event_type = event.event_type();
121 if !kinds.iter().any(|k| k.as_str() == event_type) {
122 return None;
123 }
124 }
125
126 let data = serde_json::to_string(&event).ok()?;
127 let sse_event = SseEvent::default().event(event.event_type()).data(data);
128
129 Some(Ok::<_, Infallible>(sse_event))
130 }
131 });
132
133 Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(30)))
134}
135
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use axum::Router;
    use axum::routing::get;
    use chrono::Utc;
    use ironflow_auth::jwt::AccessToken;
    use ironflow_core::providers::claude::ClaudeCodeProvider;
    use ironflow_engine::engine::Engine;
    use ironflow_engine::notify::Event;
    use ironflow_store::memory::InMemoryStore;
    use ironflow_store::models::RunStatus;
    use rust_decimal::Decimal;
    use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
    use tokio::net::{TcpListener, TcpStream};
    use tokio::sync::broadcast;
    use tokio::time::timeout;
    use uuid::Uuid;

    use super::events;
    use crate::state::AppState;

    /// Upper bound for any single wait on an SSE connection.
    const READ_TIMEOUT: Duration = Duration::from_secs(5);

    /// Build an `AppState` with an in-memory store and a fresh broadcast channel.
    fn test_state() -> AppState {
        let store = Arc::new(InMemoryStore::new());
        let provider = Arc::new(ClaudeCodeProvider::new());
        let engine = Arc::new(Engine::new(store.clone(), provider));
        let jwt_config = Arc::new(ironflow_auth::jwt::JwtConfig {
            secret: "test-secret".to_string(),
            access_token_ttl_secs: 900,
            refresh_token_ttl_secs: 604800,
            cookie_domain: None,
            cookie_secure: false,
        });
        let (event_sender, _) = broadcast::channel::<Event>(16);
        AppState::new(
            store,
            engine,
            jwt_config,
            "test-worker-token".to_string(),
            event_sender,
        )
    }

    /// A run-scoped event (`run_status_changed`) for `run_id`.
    fn sample_run_event(run_id: Uuid) -> Event {
        Event::RunStatusChanged {
            run_id,
            workflow_name: "deploy".to_string(),
            from: RunStatus::Running,
            to: RunStatus::Completed,
            error: None,
            cost_usd: Decimal::ZERO,
            duration_ms: 1000,
            at: Utc::now(),
        }
    }

    /// An account event (`user_signed_in`) that carries no run id.
    fn sample_user_event() -> Event {
        Event::UserSignedIn {
            user_id: Uuid::now_v7(),
            username: "alice".to_string(),
            at: Utc::now(),
        }
    }

    /// Mint a valid `Authorization` header value for a throwaway user.
    fn make_auth_token(state: &AppState) -> String {
        let user_id = Uuid::now_v7();
        let token = AccessToken::for_user(user_id, "testuser", false, &state.jwt_config).unwrap();
        format!("Bearer {}", token.0)
    }

    /// Serve `/events` on an ephemeral port.
    ///
    /// Returns the bound address, a sender into the state's event channel, and
    /// an auth header value.
    async fn start_sse_server(state: AppState) -> (String, broadcast::Sender<Event>, String) {
        let sender = state.event_sender.clone();
        let auth = make_auth_token(&state);
        let app = Router::new()
            .route("/events", get(events))
            .with_state(state);

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap().to_string();
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        (addr, sender, auth)
    }

    /// Send a raw HTTP/1.1 `GET /events{query}` request, with an
    /// `Authorization` header when `auth` is given, and return a buffered
    /// reader over the connection.
    async fn send_events_request(
        addr: &str,
        query: &str,
        auth: Option<&str>,
    ) -> BufReader<TcpStream> {
        let mut stream = TcpStream::connect(addr).await.unwrap();
        let auth_header = auth
            .map(|value| format!("Authorization: {value}\r\n"))
            .unwrap_or_default();
        let request = format!(
            "GET /events{query} HTTP/1.1\r\nHost: {addr}\r\nAccept: text/event-stream\r\n{auth_header}\r\n"
        );
        stream.write_all(request.as_bytes()).await.unwrap();
        BufReader::new(stream)
    }

    /// Open an authenticated SSE connection and consume the response head.
    ///
    /// The handler subscribes to the broadcast channel before it returns its
    /// response. Once the head has arrived, every event sent afterwards
    /// reaches this subscriber, so no timing-based sleep is needed.
    ///
    /// Returns the reader, positioned at the start of the body, and the raw
    /// response head.
    async fn connect_sse(addr: &str, query: &str, auth: &str) -> (BufReader<TcpStream>, String) {
        let mut reader = send_events_request(addr, query, Some(auth)).await;
        let head = read_until_contains(&mut reader, "\r\n\r\n", READ_TIMEOUT).await;
        assert!(
            head.starts_with("HTTP/1.1 200"),
            "unexpected response head:\n{head}"
        );
        (reader, head)
    }

    /// Read lines until the accumulated text contains `needle`.
    ///
    /// Panics with the data read so far if `dur` elapses first or the peer
    /// closes the connection first.
    async fn read_until_contains(
        reader: &mut BufReader<TcpStream>,
        needle: &str,
        dur: Duration,
    ) -> String {
        let mut accumulated = String::new();
        let outcome = timeout(dur, async {
            let mut line = String::new();
            loop {
                line.clear();
                if reader.read_line(&mut line).await.unwrap() == 0 {
                    // EOF: the needle can no longer appear.
                    return false;
                }
                accumulated.push_str(&line);
                if accumulated.contains(needle) {
                    return true;
                }
            }
        })
        .await;
        match outcome {
            Ok(true) => accumulated,
            Ok(false) => panic!(
                "connection closed before '{needle}' appeared in SSE stream. Data so far:\n{accumulated}"
            ),
            Err(_) => {
                panic!("timeout waiting for '{needle}' in SSE stream. Data so far:\n{accumulated}")
            }
        }
    }

    #[tokio::test]
    async fn sse_stream_receives_events() {
        let (addr, sender, auth) = start_sse_server(test_state()).await;
        let (mut reader, _head) = connect_sse(&addr, "", &auth).await;

        let run_id = Uuid::now_v7();
        sender.send(sample_run_event(run_id)).unwrap();

        let text = read_until_contains(&mut reader, &run_id.to_string(), READ_TIMEOUT).await;

        assert!(text.contains("run_status_changed"));
        assert!(text.contains(&run_id.to_string()));
    }

    #[tokio::test]
    async fn sse_filters_by_run_id() {
        let (addr, sender, auth) = start_sse_server(test_state()).await;

        let target_run = Uuid::now_v7();
        let other_run = Uuid::now_v7();

        let (mut reader, _head) = connect_sse(&addr, &format!("?run_id={target_run}"), &auth).await;

        sender.send(sample_run_event(other_run)).unwrap();
        sender.send(sample_run_event(target_run)).unwrap();

        let text = read_until_contains(&mut reader, &target_run.to_string(), READ_TIMEOUT).await;

        assert!(text.contains(&target_run.to_string()));
        assert!(!text.contains(&other_run.to_string()));
    }

    #[tokio::test]
    async fn sse_filters_by_event_type() {
        let (addr, sender, auth) = start_sse_server(test_state()).await;
        let (mut reader, _head) = connect_sse(&addr, "?types=user_signed_in", &auth).await;

        let run_id = Uuid::now_v7();
        sender.send(sample_run_event(run_id)).unwrap();
        sender.send(sample_user_event()).unwrap();

        let text = read_until_contains(&mut reader, "user_signed_in", READ_TIMEOUT).await;

        assert!(text.contains("user_signed_in"));
        assert!(!text.contains("run_status_changed"));
    }

    #[tokio::test]
    async fn sse_returns_correct_content_type() {
        let (addr, _sender, auth) = start_sse_server(test_state()).await;
        let (_reader, head) = connect_sse(&addr, "", &auth).await;

        assert!(head.contains("text/event-stream"));
    }

    #[tokio::test]
    async fn sse_rejects_unauthenticated() {
        let (addr, _sender, _auth) = start_sse_server(test_state()).await;
        let mut reader = send_events_request(&addr, "", None).await;

        let text = read_until_contains(&mut reader, "401", READ_TIMEOUT).await;

        assert!(text.contains("401"));
    }
}