Skip to main content

ironflow_api/routes/
events.rs

1//! SSE endpoint for real-time event streaming.
2
3use std::convert::Infallible;
4use std::str::FromStr;
5use std::time::Duration;
6
7use axum::extract::{Query, State};
8use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
9use futures_util::stream::{Stream, StreamExt};
10use serde::Deserialize;
11use serde::de::{self, Deserializer};
12use tokio_stream::wrappers::BroadcastStream;
13use uuid::Uuid;
14
15use crate::state::AppState;
16use ironflow_auth::extractor::Authenticated;
17use ironflow_engine::notify::Event;
18
19pub use ironflow_store::entities::EventKind;
20
21/// Deserialize a comma-separated string into `Option<Vec<EventKind>>`.
22fn deserialize_comma_event_kinds<'de, D>(
23    deserializer: D,
24) -> Result<Option<Vec<EventKind>>, D::Error>
25where
26    D: Deserializer<'de>,
27{
28    let opt: Option<String> = Option::deserialize(deserializer)?;
29    match opt {
30        None => Ok(None),
31        Some(raw) => {
32            let kinds: Result<Vec<EventKind>, _> = raw
33                .split(',')
34                .map(|s| s.trim())
35                .filter(|s| !s.is_empty())
36                .map(EventKind::from_str)
37                .collect();
38            kinds.map(Some).map_err(de::Error::custom)
39        }
40    }
41}
42
43/// Query parameters for the SSE events endpoint.
44///
45/// Both fields are optional. When set, only matching events are streamed.
46///
47/// # Examples
48///
49/// ```
50/// use ironflow_api::routes::events::EventsQuery;
51///
52/// let query = EventsQuery {
53///     run_id: None,
54///     types: None,
55/// };
56/// ```
57#[derive(Debug, Deserialize)]
58pub struct EventsQuery {
59    /// Only stream events related to this run.
60    pub run_id: Option<Uuid>,
61    /// Comma-separated list of event types to include (e.g. `?types=run_status_changed,step_completed`).
62    #[serde(default, deserialize_with = "deserialize_comma_event_kinds")]
63    pub types: Option<Vec<EventKind>>,
64}
65
66/// Extract the `run_id` from an event, if the variant carries one.
67fn event_run_id(event: &Event) -> Option<Uuid> {
68    match event {
69        Event::RunCreated { run_id, .. }
70        | Event::RunStatusChanged { run_id, .. }
71        | Event::RunFailed { run_id, .. }
72        | Event::StepCompleted { run_id, .. }
73        | Event::StepFailed { run_id, .. }
74        | Event::ApprovalRequested { run_id, .. }
75        | Event::ApprovalGranted { run_id, .. }
76        | Event::ApprovalRejected { run_id, .. }
77        | Event::LogLine { run_id, .. } => Some(*run_id),
78        Event::UserSignedIn { .. } | Event::UserSignedUp { .. } | Event::UserSignedOut { .. } => {
79            None
80        }
81    }
82}
83
84/// `GET /api/v1/events` -- Server-Sent Events stream.
85///
86/// Streams domain events in real time. Supports optional filtering:
87/// - `?run_id=<uuid>` -- only events for that run
88/// - `?types=run_status_changed,step_completed` -- only those event types
89///
90/// Each SSE message has:
91/// - `event:` set to the event type (e.g. `run_status_changed`)
92/// - `data:` JSON-serialized event payload
93///
94/// A keep-alive comment is sent every 30 seconds.
95///
96/// # Errors
97///
98/// Returns 401 if the request is not authenticated.
99pub async fn events(
100    _auth: Authenticated,
101    State(state): State<AppState>,
102    Query(query): Query<EventsQuery>,
103) -> Sse<impl Stream<Item = Result<SseEvent, Infallible>>> {
104    let receiver = state.event_sender.subscribe();
105    let type_filter = query.types;
106
107    let stream = BroadcastStream::new(receiver).filter_map(move |result: Result<Event, _>| {
108        let type_filter = type_filter.clone();
109        let run_id_filter = query.run_id;
110        async move {
111            let event = result.ok()?;
112
113            if let Some(ref rid) = run_id_filter
114                && event_run_id(&event) != Some(*rid)
115            {
116                return None;
117            }
118
119            if let Some(ref kinds) = type_filter {
120                let event_type = event.event_type();
121                if !kinds.iter().any(|k| k.as_str() == event_type) {
122                    return None;
123                }
124            }
125
126            let data = serde_json::to_string(&event).ok()?;
127            let sse_event = SseEvent::default().event(event.event_type()).data(data);
128
129            Some(Ok::<_, Infallible>(sse_event))
130        }
131    });
132
133    Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(30)))
134}
135
#[cfg(test)]
mod tests {
    // End-to-end tests: each test starts a real axum server on an ephemeral
    // port and speaks raw HTTP/1.1 to it, so the SSE framing and response
    // headers are exercised exactly as a client would see them.

    use std::sync::Arc;
    use std::time::Duration;

    use axum::Router;
    use axum::routing::get;
    use chrono::Utc;
    use ironflow_auth::jwt::AccessToken;
    use ironflow_core::providers::claude::ClaudeCodeProvider;
    use ironflow_engine::engine::Engine;
    use ironflow_engine::notify::Event;
    use ironflow_store::memory::InMemoryStore;
    use ironflow_store::models::RunStatus;
    use rust_decimal::Decimal;
    use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
    use tokio::net::{TcpListener, TcpStream};
    use tokio::sync::broadcast;
    use tokio::time::timeout;
    use uuid::Uuid;

    use super::events;
    use crate::state::AppState;

    /// Upper bound on any single wait for data from the server.
    const READ_TIMEOUT: Duration = Duration::from_secs(5);

    fn test_state() -> AppState {
        let store = Arc::new(InMemoryStore::new());
        let provider = Arc::new(ClaudeCodeProvider::new());
        let engine = Arc::new(Engine::new(store.clone(), provider));
        let jwt_config = Arc::new(ironflow_auth::jwt::JwtConfig {
            secret: "test-secret".to_string(),
            access_token_ttl_secs: 900,
            refresh_token_ttl_secs: 604800,
            cookie_domain: None,
            cookie_secure: false,
        });
        let (event_sender, _) = broadcast::channel::<Event>(16);
        AppState::new(
            store,
            engine,
            jwt_config,
            "test-worker-token".to_string(),
            event_sender,
        )
    }

    fn sample_run_event(run_id: Uuid) -> Event {
        Event::RunStatusChanged {
            run_id,
            workflow_name: "deploy".to_string(),
            from: RunStatus::Running,
            to: RunStatus::Completed,
            error: None,
            cost_usd: Decimal::ZERO,
            duration_ms: 1000,
            at: Utc::now(),
        }
    }

    fn sample_user_event() -> Event {
        Event::UserSignedIn {
            user_id: Uuid::now_v7(),
            username: "alice".to_string(),
            at: Utc::now(),
        }
    }

    fn make_auth_token(state: &AppState) -> String {
        let user_id = Uuid::now_v7();
        let token = AccessToken::for_user(user_id, "testuser", false, &state.jwt_config).unwrap();
        format!("Bearer {}", token.0)
    }

    /// Start a real TCP server and return (address, sender, auth header).
    async fn start_sse_server(state: AppState) -> (String, broadcast::Sender<Event>, String) {
        let sender = state.event_sender.clone();
        let auth = make_auth_token(&state);
        let app = Router::new()
            .route("/events", get(events))
            .with_state(state);

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap().to_string();
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        (addr, sender, auth)
    }

    /// Send a raw `GET /events{query}` request and return a line reader over the response.
    ///
    /// `extra_headers` is inserted verbatim after the fixed headers. Each line in it
    /// must end with `\r\n`; pass `""` for none.
    async fn send_events_request(
        addr: &str,
        query: &str,
        extra_headers: &str,
    ) -> BufReader<TcpStream> {
        let mut stream = TcpStream::connect(addr).await.unwrap();
        let request = format!(
            "GET /events{query} HTTP/1.1\r\nHost: {addr}\r\nAccept: text/event-stream\r\n{extra_headers}\r\n"
        );
        stream.write_all(request.as_bytes()).await.unwrap();
        BufReader::new(stream)
    }

    /// Open an authenticated SSE connection and wait for the response headers.
    ///
    /// Returns the reader (positioned at the start of the body) and the raw header
    /// text. The handler subscribes to the broadcast channel before it returns the
    /// response. So once the headers have arrived, any event sent afterwards is
    /// guaranteed to reach this connection, and no timing sleep is needed.
    async fn connect_sse(addr: &str, query: &str, auth: &str) -> (BufReader<TcpStream>, String) {
        let mut reader =
            send_events_request(addr, query, &format!("Authorization: {auth}\r\n")).await;
        let headers = read_until_contains(&mut reader, "\r\n\r\n", READ_TIMEOUT).await;
        assert!(
            headers.starts_with("HTTP/1.1 200"),
            "unexpected SSE response headers:\n{headers}"
        );
        (reader, headers)
    }

    /// Read lines from the stream until `needle` appears in the accumulated text.
    ///
    /// # Panics
    ///
    /// Panics, printing everything read so far, if `dur` elapses or the
    /// connection closes before `needle` is seen.
    async fn read_until_contains(
        reader: &mut BufReader<TcpStream>,
        needle: &str,
        dur: Duration,
    ) -> String {
        let mut accumulated = String::new();
        let outcome = timeout(dur, async {
            let mut line = String::new();
            loop {
                line.clear();
                if reader.read_line(&mut line).await.unwrap() == 0 {
                    return false; // EOF before the needle showed up.
                }
                accumulated.push_str(&line);
                if accumulated.contains(needle) {
                    return true;
                }
            }
        })
        .await;

        match outcome {
            Ok(true) => accumulated,
            Ok(false) => panic!(
                "connection closed before '{needle}' appeared in SSE stream. Data so far:\n{accumulated}"
            ),
            Err(_) => {
                panic!("timeout waiting for '{needle}' in SSE stream. Data so far:\n{accumulated}")
            }
        }
    }

    #[tokio::test]
    async fn sse_stream_receives_events() {
        let (addr, sender, auth) = start_sse_server(test_state()).await;
        let (mut reader, _headers) = connect_sse(&addr, "", &auth).await;

        let run_id = Uuid::now_v7();
        sender.send(sample_run_event(run_id)).unwrap();

        let text = read_until_contains(&mut reader, &run_id.to_string(), READ_TIMEOUT).await;

        assert!(text.contains("run_status_changed"));
        assert!(text.contains(&run_id.to_string()));
    }

    #[tokio::test]
    async fn sse_filters_by_run_id() {
        let (addr, sender, auth) = start_sse_server(test_state()).await;

        let target_run = Uuid::now_v7();
        let other_run = Uuid::now_v7();

        let (mut reader, _headers) =
            connect_sse(&addr, &format!("?run_id={target_run}"), &auth).await;

        // The non-matching event is sent first, so if it leaked through it would
        // appear in the text read before the target event.
        sender.send(sample_run_event(other_run)).unwrap();
        sender.send(sample_run_event(target_run)).unwrap();

        let text = read_until_contains(&mut reader, &target_run.to_string(), READ_TIMEOUT).await;

        assert!(text.contains(&target_run.to_string()));
        assert!(!text.contains(&other_run.to_string()));
    }

    #[tokio::test]
    async fn sse_filters_by_event_type() {
        let (addr, sender, auth) = start_sse_server(test_state()).await;
        let (mut reader, _headers) = connect_sse(&addr, "?types=user_signed_in", &auth).await;

        sender.send(sample_run_event(Uuid::now_v7())).unwrap();
        sender.send(sample_user_event()).unwrap();

        let text = read_until_contains(&mut reader, "user_signed_in", READ_TIMEOUT).await;

        assert!(text.contains("user_signed_in"));
        assert!(!text.contains("run_status_changed"));
    }

    #[tokio::test]
    async fn sse_returns_correct_content_type() {
        let (addr, _sender, auth) = start_sse_server(test_state()).await;
        let (_reader, headers) = connect_sse(&addr, "", &auth).await;

        assert!(headers.contains("text/event-stream"));
    }

    #[tokio::test]
    async fn sse_rejects_unauthenticated() {
        let (addr, _sender, _auth) = start_sse_server(test_state()).await;
        // Deliberately no Authorization header.
        let mut reader = send_events_request(&addr, "", "").await;

        let headers = read_until_contains(&mut reader, "\r\n\r\n", READ_TIMEOUT).await;

        assert!(
            headers.starts_with("HTTP/1.1 401"),
            "expected a 401 response, got:\n{headers}"
        );
    }
}