Skip to main content

ironflow_api/routes/
events.rs

1//! SSE endpoint for real-time event streaming.
2
3use std::convert::Infallible;
4use std::fmt;
5use std::str::FromStr;
6use std::time::Duration;
7
8use axum::extract::{Query, State};
9use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
10use futures_util::stream::{Stream, StreamExt};
11use serde::Deserialize;
12use serde::de::{self, Deserializer};
13use tokio_stream::wrappers::BroadcastStream;
14use uuid::Uuid;
15
16use ironflow_auth::extractor::Authenticated;
17use ironflow_engine::notify::Event;
18
19use crate::state::AppState;
20
/// Strongly-typed event kind matching [`Event`] variants.
///
/// Each kind corresponds to the `snake_case` string that [`Event::event_type()`]
/// reports for the matching variant. Convert to that string with
/// [`EventKind::as_str`] (or `Display`) and back with [`str::parse`] (via
/// `FromStr`). With the `openapi` feature enabled, the type additionally
/// derives `serde::Serialize` (with `snake_case` variant names) and
/// `utoipa::ToSchema`.
///
/// # Examples
///
/// ```
/// use ironflow_api::routes::events::EventKind;
///
/// let kind: EventKind = "run_status_changed".parse().unwrap();
/// assert_eq!(kind, EventKind::RunStatusChanged);
/// assert_eq!(kind.as_str(), "run_status_changed");
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "openapi", derive(serde::Serialize, utoipa::ToSchema))]
#[cfg_attr(feature = "openapi", serde(rename_all = "snake_case"))]
pub enum EventKind {
    /// [`Event::RunCreated`]
    RunCreated,
    /// [`Event::RunStatusChanged`]
    RunStatusChanged,
    /// [`Event::RunFailed`]
    RunFailed,
    /// [`Event::StepCompleted`]
    StepCompleted,
    /// [`Event::StepFailed`]
    StepFailed,
    /// [`Event::ApprovalRequested`]
    ApprovalRequested,
    /// [`Event::ApprovalGranted`]
    ApprovalGranted,
    /// [`Event::ApprovalRejected`]
    ApprovalRejected,
    /// [`Event::UserSignedIn`]
    UserSignedIn,
    /// [`Event::UserSignedUp`]
    UserSignedUp,
    /// [`Event::UserSignedOut`]
    UserSignedOut,
}
62
63impl EventKind {
64    /// Returns the wire-format string for this kind.
65    pub fn as_str(self) -> &'static str {
66        match self {
67            Self::RunCreated => Event::RUN_CREATED,
68            Self::RunStatusChanged => Event::RUN_STATUS_CHANGED,
69            Self::RunFailed => Event::RUN_FAILED,
70            Self::StepCompleted => Event::STEP_COMPLETED,
71            Self::StepFailed => Event::STEP_FAILED,
72            Self::ApprovalRequested => Event::APPROVAL_REQUESTED,
73            Self::ApprovalGranted => Event::APPROVAL_GRANTED,
74            Self::ApprovalRejected => Event::APPROVAL_REJECTED,
75            Self::UserSignedIn => Event::USER_SIGNED_IN,
76            Self::UserSignedUp => Event::USER_SIGNED_UP,
77            Self::UserSignedOut => Event::USER_SIGNED_OUT,
78        }
79    }
80}
81
82impl fmt::Display for EventKind {
83    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84        f.write_str(self.as_str())
85    }
86}
87
88impl FromStr for EventKind {
89    type Err = InvalidEventKind;
90
91    fn from_str(s: &str) -> Result<Self, Self::Err> {
92        match s {
93            "run_created" => Ok(Self::RunCreated),
94            "run_status_changed" => Ok(Self::RunStatusChanged),
95            "run_failed" => Ok(Self::RunFailed),
96            "step_completed" => Ok(Self::StepCompleted),
97            "step_failed" => Ok(Self::StepFailed),
98            "approval_requested" => Ok(Self::ApprovalRequested),
99            "approval_granted" => Ok(Self::ApprovalGranted),
100            "approval_rejected" => Ok(Self::ApprovalRejected),
101            "user_signed_in" => Ok(Self::UserSignedIn),
102            "user_signed_up" => Ok(Self::UserSignedUp),
103            "user_signed_out" => Ok(Self::UserSignedOut),
104            _ => Err(InvalidEventKind(s.to_string())),
105        }
106    }
107}
108
/// Error produced when a string does not name any [`EventKind`].
///
/// The rejected input is kept verbatim in the public tuple field.
#[derive(Debug, Clone)]
pub struct InvalidEventKind(pub String);

impl fmt::Display for InvalidEventKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(raw) = self;
        f.write_str("unknown event kind: ")?;
        f.write_str(raw)
    }
}

impl std::error::Error for InvalidEventKind {}
120
121/// Deserialize a comma-separated string into `Option<Vec<EventKind>>`.
122fn deserialize_comma_event_kinds<'de, D>(
123    deserializer: D,
124) -> Result<Option<Vec<EventKind>>, D::Error>
125where
126    D: Deserializer<'de>,
127{
128    let opt: Option<String> = Option::deserialize(deserializer)?;
129    match opt {
130        None => Ok(None),
131        Some(raw) => {
132            let kinds: Result<Vec<EventKind>, _> = raw
133                .split(',')
134                .map(|s| s.trim())
135                .filter(|s| !s.is_empty())
136                .map(EventKind::from_str)
137                .collect();
138            kinds.map(Some).map_err(de::Error::custom)
139        }
140    }
141}
142
143/// Query parameters for the SSE events endpoint.
144///
145/// Both fields are optional. When set, only matching events are streamed.
146///
147/// # Examples
148///
149/// ```
150/// use ironflow_api::routes::events::EventsQuery;
151///
152/// let query = EventsQuery {
153///     run_id: None,
154///     types: None,
155/// };
156/// ```
157#[derive(Debug, Deserialize)]
158pub struct EventsQuery {
159    /// Only stream events related to this run.
160    pub run_id: Option<Uuid>,
161    /// Comma-separated list of event types to include (e.g. `?types=run_status_changed,step_completed`).
162    #[serde(default, deserialize_with = "deserialize_comma_event_kinds")]
163    pub types: Option<Vec<EventKind>>,
164}
165
166/// Extract the `run_id` from an event, if the variant carries one.
167fn event_run_id(event: &Event) -> Option<Uuid> {
168    match event {
169        Event::RunCreated { run_id, .. }
170        | Event::RunStatusChanged { run_id, .. }
171        | Event::RunFailed { run_id, .. }
172        | Event::StepCompleted { run_id, .. }
173        | Event::StepFailed { run_id, .. }
174        | Event::ApprovalRequested { run_id, .. }
175        | Event::ApprovalGranted { run_id, .. }
176        | Event::ApprovalRejected { run_id, .. } => Some(*run_id),
177        Event::UserSignedIn { .. } | Event::UserSignedUp { .. } | Event::UserSignedOut { .. } => {
178            None
179        }
180    }
181}
182
183/// `GET /api/v1/events` -- Server-Sent Events stream.
184///
185/// Streams domain events in real time. Supports optional filtering:
186/// - `?run_id=<uuid>` -- only events for that run
187/// - `?types=run_status_changed,step_completed` -- only those event types
188///
189/// Each SSE message has:
190/// - `event:` set to the event type (e.g. `run_status_changed`)
191/// - `data:` JSON-serialized event payload
192///
193/// A keep-alive comment is sent every 30 seconds.
194///
195/// # Errors
196///
197/// Returns 401 if the request is not authenticated.
198pub async fn events(
199    _auth: Authenticated,
200    State(state): State<AppState>,
201    Query(query): Query<EventsQuery>,
202) -> Sse<impl Stream<Item = Result<SseEvent, Infallible>>> {
203    let receiver = state.event_sender.subscribe();
204    let type_filter = query.types;
205
206    let stream = BroadcastStream::new(receiver).filter_map(move |result: Result<Event, _>| {
207        let type_filter = type_filter.clone();
208        let run_id_filter = query.run_id;
209        async move {
210            let event = result.ok()?;
211
212            if let Some(ref rid) = run_id_filter
213                && event_run_id(&event) != Some(*rid)
214            {
215                return None;
216            }
217
218            if let Some(ref kinds) = type_filter {
219                let event_type = event.event_type();
220                if !kinds.iter().any(|k| k.as_str() == event_type) {
221                    return None;
222                }
223            }
224
225            let data = serde_json::to_string(&event).ok()?;
226            let sse_event = SseEvent::default().event(event.event_type()).data(data);
227
228            Some(Ok::<_, Infallible>(sse_event))
229        }
230    });
231
232    Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(30)))
233}
234
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use axum::Router;
    use axum::routing::get;
    use chrono::Utc;
    use ironflow_auth::jwt::AccessToken;
    use ironflow_core::providers::claude::ClaudeCodeProvider;
    use ironflow_engine::engine::Engine;
    use ironflow_engine::notify::Event;
    use ironflow_store::api_key_store::ApiKeyStore;
    use ironflow_store::memory::InMemoryStore;
    use ironflow_store::models::RunStatus;
    use ironflow_store::user_store::UserStore;
    use rust_decimal::Decimal;
    use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
    use tokio::net::{TcpListener, TcpStream};
    use tokio::sync::broadcast;
    use tokio::time::timeout;
    use uuid::Uuid;

    use super::{EventKind, events};
    use crate::state::AppState;

    /// Content type of the SSE response; seeing it means the handler has run.
    const SSE_CONTENT_TYPE: &str = "text/event-stream";
    /// Upper bound for any single wait on the SSE connection.
    const READ_TIMEOUT: Duration = Duration::from_secs(5);

    fn test_state() -> AppState {
        let store = Arc::new(InMemoryStore::new());
        let user_store: Arc<dyn UserStore> = Arc::new(InMemoryStore::new());
        let api_key_store: Arc<dyn ApiKeyStore> = Arc::new(InMemoryStore::new());
        let provider = Arc::new(ClaudeCodeProvider::new());
        let engine = Arc::new(Engine::new(store.clone(), provider));
        let jwt_config = Arc::new(ironflow_auth::jwt::JwtConfig {
            secret: "test-secret".to_string(),
            access_token_ttl_secs: 900,
            refresh_token_ttl_secs: 604800,
            cookie_domain: None,
            cookie_secure: false,
        });
        let (event_sender, _) = broadcast::channel::<Event>(16);
        AppState::new(
            store,
            user_store,
            api_key_store,
            engine,
            jwt_config,
            "test-worker-token".to_string(),
            event_sender,
        )
    }

    fn sample_run_event(run_id: Uuid) -> Event {
        Event::RunStatusChanged {
            run_id,
            workflow_name: "deploy".to_string(),
            from: RunStatus::Running,
            to: RunStatus::Completed,
            error: None,
            cost_usd: Decimal::ZERO,
            duration_ms: 1000,
            at: Utc::now(),
        }
    }

    fn sample_user_event() -> Event {
        Event::UserSignedIn {
            user_id: Uuid::now_v7(),
            username: "alice".to_string(),
            at: Utc::now(),
        }
    }

    fn make_auth_token(state: &AppState) -> String {
        let user_id = Uuid::now_v7();
        let token = AccessToken::for_user(user_id, "testuser", false, &state.jwt_config).unwrap();
        format!("Bearer {}", token.0)
    }

    /// Start a real TCP server and return (address, sender, auth header).
    async fn start_sse_server(state: AppState) -> (String, broadcast::Sender<Event>, String) {
        let sender = state.event_sender.clone();
        let auth = make_auth_token(&state);
        let app = Router::new()
            .route("/events", get(events))
            .with_state(state);

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap().to_string();
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        (addr, sender, auth)
    }

    /// Open a raw HTTP/1.1 connection and send `GET /events{query}`.
    ///
    /// When `auth` is `Some`, it is sent verbatim as the `Authorization` header.
    async fn send_request(addr: &str, query: &str, auth: Option<&str>) -> BufReader<TcpStream> {
        let mut stream = TcpStream::connect(addr).await.unwrap();
        let auth_header = auth
            .map(|value| format!("Authorization: {value}\r\n"))
            .unwrap_or_default();
        let request = format!(
            "GET /events{query} HTTP/1.1\r\nHost: {addr}\r\nAccept: text/event-stream\r\n{auth_header}\r\n"
        );
        stream.write_all(request.as_bytes()).await.unwrap();
        BufReader::new(stream)
    }

    /// Connect with auth and wait until the SSE response headers arrive.
    ///
    /// The handler subscribes to the broadcast channel before it returns the
    /// `Sse` response, so once the content-type header has been read, every
    /// event sent afterwards reaches this connection. This replaces a fixed
    /// sleep, which could race with a slow server.
    async fn connect_sse(addr: &str, query: &str, auth: &str) -> BufReader<TcpStream> {
        let mut reader = send_request(addr, query, Some(auth)).await;
        read_until_contains(&mut reader, SSE_CONTENT_TYPE, READ_TIMEOUT).await;
        reader
    }

    /// Read lines until the accumulated text contains `needle`, then return it.
    ///
    /// # Panics
    ///
    /// Panics on timeout, on an I/O error, or if the connection closes before
    /// `needle` shows up. The data read so far is included in the message.
    async fn read_until_contains(
        reader: &mut BufReader<TcpStream>,
        needle: &str,
        dur: Duration,
    ) -> String {
        let mut accumulated = String::new();
        let outcome = timeout(dur, async {
            loop {
                // `read_line` appends, so no per-line buffer is needed.
                if reader.read_line(&mut accumulated).await.unwrap() == 0 {
                    return false;
                }
                if accumulated.contains(needle) {
                    return true;
                }
            }
        })
        .await;
        match outcome {
            Ok(true) => accumulated,
            Ok(false) => panic!(
                "connection closed before '{needle}' appeared in SSE stream. Data so far:\n{accumulated}"
            ),
            Err(_) => {
                panic!("timeout waiting for '{needle}' in SSE stream. Data so far:\n{accumulated}")
            }
        }
    }

    #[test]
    fn event_kind_round_trips_through_str() {
        let all = [
            EventKind::RunCreated,
            EventKind::RunStatusChanged,
            EventKind::RunFailed,
            EventKind::StepCompleted,
            EventKind::StepFailed,
            EventKind::ApprovalRequested,
            EventKind::ApprovalGranted,
            EventKind::ApprovalRejected,
            EventKind::UserSignedIn,
            EventKind::UserSignedUp,
            EventKind::UserSignedOut,
        ];
        for kind in all {
            let parsed: EventKind = kind.as_str().parse().unwrap();
            assert_eq!(parsed, kind);
            assert_eq!(kind.to_string(), kind.as_str());
        }
        assert!("not_a_kind".parse::<EventKind>().is_err());
    }

    #[test]
    fn event_kind_matches_event_type() {
        assert!(
            EventKind::RunStatusChanged.as_str() == sample_run_event(Uuid::now_v7()).event_type()
        );
        assert!(EventKind::UserSignedIn.as_str() == sample_user_event().event_type());
    }

    #[tokio::test]
    async fn sse_stream_receives_events() {
        let (addr, sender, auth) = start_sse_server(test_state()).await;
        let mut reader = connect_sse(&addr, "", &auth).await;

        let run_id = Uuid::now_v7();
        sender.send(sample_run_event(run_id)).unwrap();

        let text = read_until_contains(&mut reader, &run_id.to_string(), READ_TIMEOUT).await;

        assert!(text.contains("run_status_changed"));
        assert!(text.contains(&run_id.to_string()));
    }

    #[tokio::test]
    async fn sse_filters_by_run_id() {
        let (addr, sender, auth) = start_sse_server(test_state()).await;

        let target_run = Uuid::now_v7();
        let other_run = Uuid::now_v7();

        let mut reader = connect_sse(&addr, &format!("?run_id={target_run}"), &auth).await;

        // The filtered-out event is sent first, so it would precede the target
        // in the stream if filtering were broken.
        sender.send(sample_run_event(other_run)).unwrap();
        sender.send(sample_run_event(target_run)).unwrap();

        let text = read_until_contains(&mut reader, &target_run.to_string(), READ_TIMEOUT).await;

        assert!(text.contains(&target_run.to_string()));
        assert!(!text.contains(&other_run.to_string()));
    }

    #[tokio::test]
    async fn sse_filters_by_event_type() {
        let (addr, sender, auth) = start_sse_server(test_state()).await;
        let mut reader = connect_sse(&addr, "?types=user_signed_in", &auth).await;

        sender.send(sample_run_event(Uuid::now_v7())).unwrap();
        sender.send(sample_user_event()).unwrap();

        let text = read_until_contains(&mut reader, "user_signed_in", READ_TIMEOUT).await;

        assert!(text.contains("user_signed_in"));
        assert!(!text.contains("run_status_changed"));
    }

    #[tokio::test]
    async fn sse_returns_correct_content_type() {
        let (addr, _sender, auth) = start_sse_server(test_state()).await;
        let mut reader = send_request(&addr, "", Some(&auth)).await;

        let text = read_until_contains(&mut reader, SSE_CONTENT_TYPE, READ_TIMEOUT).await;

        assert!(text.contains(SSE_CONTENT_TYPE));
    }

    #[tokio::test]
    async fn sse_rejects_unauthenticated() {
        let (addr, _sender, _auth) = start_sse_server(test_state()).await;
        // No Authorization header at all.
        let mut reader = send_request(&addr, "", None).await;

        let text = read_until_contains(&mut reader, "401", READ_TIMEOUT).await;

        assert!(text.contains("401"));
    }
}