Skip to main content

ironflow_api/routes/
events.rs

1//! SSE endpoint for real-time event streaming.
2
3use std::convert::Infallible;
4use std::fmt;
5use std::str::FromStr;
6use std::time::Duration;
7
8use axum::extract::{Query, State};
9use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
10use futures_util::stream::{Stream, StreamExt};
11use serde::Deserialize;
12use serde::de::{self, Deserializer};
13use tokio_stream::wrappers::BroadcastStream;
14use uuid::Uuid;
15
16use ironflow_auth::extractor::Authenticated;
17use ironflow_engine::notify::Event;
18
19use crate::state::AppState;
20
21/// Strongly-typed event kind matching [`Event`] variants.
22///
23/// Serializes to/from the same `snake_case` strings as
24/// [`Event::event_type()`].
25///
26/// # Examples
27///
28/// ```
29/// use ironflow_api::routes::events::EventKind;
30///
31/// let kind: EventKind = "run_status_changed".parse().unwrap();
32/// assert_eq!(kind, EventKind::RunStatusChanged);
33/// assert_eq!(kind.as_str(), "run_status_changed");
34/// ```
35#[derive(Debug, Clone, Copy, PartialEq, Eq)]
36#[cfg_attr(feature = "openapi", derive(serde::Serialize, utoipa::ToSchema))]
37#[cfg_attr(feature = "openapi", serde(rename_all = "snake_case"))]
38pub enum EventKind {
39    /// [`Event::RunCreated`]
40    RunCreated,
41    /// [`Event::RunStatusChanged`]
42    RunStatusChanged,
43    /// [`Event::RunFailed`]
44    RunFailed,
45    /// [`Event::StepCompleted`]
46    StepCompleted,
47    /// [`Event::StepFailed`]
48    StepFailed,
49    /// [`Event::ApprovalRequested`]
50    ApprovalRequested,
51    /// [`Event::ApprovalGranted`]
52    ApprovalGranted,
53    /// [`Event::ApprovalRejected`]
54    ApprovalRejected,
55    /// [`Event::UserSignedIn`]
56    UserSignedIn,
57    /// [`Event::UserSignedUp`]
58    UserSignedUp,
59    /// [`Event::UserSignedOut`]
60    UserSignedOut,
61}
62
63impl EventKind {
64    /// Returns the wire-format string for this kind.
65    pub fn as_str(self) -> &'static str {
66        match self {
67            Self::RunCreated => Event::RUN_CREATED,
68            Self::RunStatusChanged => Event::RUN_STATUS_CHANGED,
69            Self::RunFailed => Event::RUN_FAILED,
70            Self::StepCompleted => Event::STEP_COMPLETED,
71            Self::StepFailed => Event::STEP_FAILED,
72            Self::ApprovalRequested => Event::APPROVAL_REQUESTED,
73            Self::ApprovalGranted => Event::APPROVAL_GRANTED,
74            Self::ApprovalRejected => Event::APPROVAL_REJECTED,
75            Self::UserSignedIn => Event::USER_SIGNED_IN,
76            Self::UserSignedUp => Event::USER_SIGNED_UP,
77            Self::UserSignedOut => Event::USER_SIGNED_OUT,
78        }
79    }
80}
81
82impl fmt::Display for EventKind {
83    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84        f.write_str(self.as_str())
85    }
86}
87
88impl FromStr for EventKind {
89    type Err = InvalidEventKind;
90
91    fn from_str(s: &str) -> Result<Self, Self::Err> {
92        match s {
93            "run_created" => Ok(Self::RunCreated),
94            "run_status_changed" => Ok(Self::RunStatusChanged),
95            "run_failed" => Ok(Self::RunFailed),
96            "step_completed" => Ok(Self::StepCompleted),
97            "step_failed" => Ok(Self::StepFailed),
98            "approval_requested" => Ok(Self::ApprovalRequested),
99            "approval_granted" => Ok(Self::ApprovalGranted),
100            "approval_rejected" => Ok(Self::ApprovalRejected),
101            "user_signed_in" => Ok(Self::UserSignedIn),
102            "user_signed_up" => Ok(Self::UserSignedUp),
103            "user_signed_out" => Ok(Self::UserSignedOut),
104            _ => Err(InvalidEventKind(s.to_string())),
105        }
106    }
107}
108
109/// Error returned when parsing an unknown event kind string.
110#[derive(Debug, Clone)]
111pub struct InvalidEventKind(pub String);
112
113impl fmt::Display for InvalidEventKind {
114    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115        write!(f, "unknown event kind: {}", self.0)
116    }
117}
118
119impl std::error::Error for InvalidEventKind {}
120
121/// Deserialize a comma-separated string into `Option<Vec<EventKind>>`.
122fn deserialize_comma_event_kinds<'de, D>(
123    deserializer: D,
124) -> Result<Option<Vec<EventKind>>, D::Error>
125where
126    D: Deserializer<'de>,
127{
128    let opt: Option<String> = Option::deserialize(deserializer)?;
129    match opt {
130        None => Ok(None),
131        Some(raw) => {
132            let kinds: Result<Vec<EventKind>, _> = raw
133                .split(',')
134                .map(|s| s.trim())
135                .filter(|s| !s.is_empty())
136                .map(EventKind::from_str)
137                .collect();
138            kinds.map(Some).map_err(de::Error::custom)
139        }
140    }
141}
142
143/// Query parameters for the SSE events endpoint.
144///
145/// Both fields are optional. When set, only matching events are streamed.
146///
147/// # Examples
148///
149/// ```
150/// use ironflow_api::routes::events::EventsQuery;
151///
152/// let query = EventsQuery {
153///     run_id: None,
154///     types: None,
155/// };
156/// ```
157#[derive(Debug, Deserialize)]
158pub struct EventsQuery {
159    /// Only stream events related to this run.
160    pub run_id: Option<Uuid>,
161    /// Comma-separated list of event types to include (e.g. `?types=run_status_changed,step_completed`).
162    #[serde(default, deserialize_with = "deserialize_comma_event_kinds")]
163    pub types: Option<Vec<EventKind>>,
164}
165
166/// Extract the `run_id` from an event, if the variant carries one.
167fn event_run_id(event: &Event) -> Option<Uuid> {
168    match event {
169        Event::RunCreated { run_id, .. }
170        | Event::RunStatusChanged { run_id, .. }
171        | Event::RunFailed { run_id, .. }
172        | Event::StepCompleted { run_id, .. }
173        | Event::StepFailed { run_id, .. }
174        | Event::ApprovalRequested { run_id, .. }
175        | Event::ApprovalGranted { run_id, .. }
176        | Event::ApprovalRejected { run_id, .. } => Some(*run_id),
177        Event::UserSignedIn { .. } | Event::UserSignedUp { .. } | Event::UserSignedOut { .. } => {
178            None
179        }
180    }
181}
182
183/// `GET /api/v1/events` -- Server-Sent Events stream.
184///
185/// Streams domain events in real time. Supports optional filtering:
186/// - `?run_id=<uuid>` -- only events for that run
187/// - `?types=run_status_changed,step_completed` -- only those event types
188///
189/// Each SSE message has:
190/// - `event:` set to the event type (e.g. `run_status_changed`)
191/// - `data:` JSON-serialized event payload
192///
193/// A keep-alive comment is sent every 30 seconds.
194///
195/// # Errors
196///
197/// Returns 401 if the request is not authenticated.
198pub async fn events(
199    _auth: Authenticated,
200    State(state): State<AppState>,
201    Query(query): Query<EventsQuery>,
202) -> Sse<impl Stream<Item = Result<SseEvent, Infallible>>> {
203    let receiver = state.event_sender.subscribe();
204    let type_filter = query.types;
205
206    let stream = BroadcastStream::new(receiver).filter_map(move |result: Result<Event, _>| {
207        let type_filter = type_filter.clone();
208        let run_id_filter = query.run_id;
209        async move {
210            let event = result.ok()?;
211
212            if let Some(ref rid) = run_id_filter
213                && event_run_id(&event) != Some(*rid)
214            {
215                return None;
216            }
217
218            if let Some(ref kinds) = type_filter {
219                let event_type = event.event_type();
220                if !kinds.iter().any(|k| k.as_str() == event_type) {
221                    return None;
222                }
223            }
224
225            let data = serde_json::to_string(&event).ok()?;
226            let sse_event = SseEvent::default().event(event.event_type()).data(data);
227
228            Some(Ok::<_, Infallible>(sse_event))
229        }
230    });
231
232    Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(30)))
233}
234
235#[cfg(test)]
236mod tests {
237    use std::sync::Arc;
238    use std::time::Duration;
239
240    use axum::Router;
241    use axum::routing::get;
242    use chrono::Utc;
243    use ironflow_auth::jwt::AccessToken;
244    use ironflow_core::providers::claude::ClaudeCodeProvider;
245    use ironflow_engine::engine::Engine;
246    use ironflow_engine::notify::Event;
247    use ironflow_store::memory::InMemoryStore;
248    use ironflow_store::models::RunStatus;
249    use rust_decimal::Decimal;
250    use tokio::io::AsyncBufReadExt;
251    use tokio::io::BufReader;
252    use tokio::net::TcpListener;
253    use tokio::sync::broadcast;
254    use tokio::time::{sleep, timeout};
255    use uuid::Uuid;
256
257    use super::events;
258    use crate::state::AppState;
259
260    fn test_state() -> AppState {
261        let store = Arc::new(InMemoryStore::new());
262        let provider = Arc::new(ClaudeCodeProvider::new());
263        let engine = Arc::new(Engine::new(store.clone(), provider));
264        let jwt_config = Arc::new(ironflow_auth::jwt::JwtConfig {
265            secret: "test-secret".to_string(),
266            access_token_ttl_secs: 900,
267            refresh_token_ttl_secs: 604800,
268            cookie_domain: None,
269            cookie_secure: false,
270        });
271        let (event_sender, _) = broadcast::channel::<Event>(16);
272        AppState::new(
273            store,
274            engine,
275            jwt_config,
276            "test-worker-token".to_string(),
277            event_sender,
278        )
279    }
280
281    fn sample_run_event(run_id: Uuid) -> Event {
282        Event::RunStatusChanged {
283            run_id,
284            workflow_name: "deploy".to_string(),
285            from: RunStatus::Running,
286            to: RunStatus::Completed,
287            error: None,
288            cost_usd: Decimal::ZERO,
289            duration_ms: 1000,
290            at: Utc::now(),
291        }
292    }
293
294    fn sample_user_event() -> Event {
295        Event::UserSignedIn {
296            user_id: Uuid::now_v7(),
297            username: "alice".to_string(),
298            at: Utc::now(),
299        }
300    }
301
302    fn make_auth_token(state: &AppState) -> String {
303        let user_id = Uuid::now_v7();
304        let token = AccessToken::for_user(user_id, "testuser", false, &state.jwt_config).unwrap();
305        format!("Bearer {}", token.0)
306    }
307
308    /// Start a real TCP server and return (address, sender, auth header).
309    async fn start_sse_server(state: AppState) -> (String, broadcast::Sender<Event>, String) {
310        let sender = state.event_sender.clone();
311        let auth = make_auth_token(&state);
312        let app = Router::new()
313            .route("/events", get(events))
314            .with_state(state);
315
316        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
317        let addr = listener.local_addr().unwrap().to_string();
318        tokio::spawn(async move {
319            axum::serve(listener, app).await.unwrap();
320        });
321        (addr, sender, auth)
322    }
323
324    /// Connect to the SSE endpoint with auth and return a line reader.
325    async fn connect_sse(addr: &str, query: &str, auth: &str) -> BufReader<tokio::net::TcpStream> {
326        let stream = tokio::net::TcpStream::connect(addr).await.unwrap();
327        let (reader, mut writer) = stream.into_split();
328
329        use tokio::io::AsyncWriteExt;
330        writer
331            .write_all(
332                format!(
333                    "GET /events{query} HTTP/1.1\r\nHost: {addr}\r\nAccept: text/event-stream\r\nAuthorization: {auth}\r\n\r\n"
334                )
335                .as_bytes(),
336            )
337            .await
338            .unwrap();
339
340        BufReader::new(reader.reunite(writer).unwrap())
341    }
342
343    /// Read all available data from the SSE stream until `needle` is found
344    /// in the accumulated text, or timeout.
345    async fn read_until_contains(
346        reader: &mut BufReader<tokio::net::TcpStream>,
347        needle: &str,
348        dur: Duration,
349    ) -> String {
350        let mut accumulated = String::new();
351        let result = timeout(dur, async {
352            loop {
353                let mut line = String::new();
354                let n = reader.read_line(&mut line).await.unwrap();
355                if n == 0 {
356                    break;
357                }
358                accumulated.push_str(&line);
359                if accumulated.contains(needle) {
360                    break;
361                }
362            }
363        })
364        .await;
365        if result.is_err() {
366            panic!("timeout waiting for '{needle}' in SSE stream. Data so far:\n{accumulated}");
367        }
368        accumulated
369    }
370
371    #[tokio::test]
372    async fn sse_stream_receives_events() {
373        let state = test_state();
374        let (addr, sender, auth) = start_sse_server(state).await;
375        let mut reader = connect_sse(&addr, "", &auth).await;
376
377        sleep(Duration::from_millis(50)).await;
378
379        let run_id = Uuid::now_v7();
380        sender.send(sample_run_event(run_id)).unwrap();
381
382        let text =
383            read_until_contains(&mut reader, &run_id.to_string(), Duration::from_secs(5)).await;
384
385        assert!(text.contains("run_status_changed"));
386        assert!(text.contains(&run_id.to_string()));
387    }
388
389    #[tokio::test]
390    async fn sse_filters_by_run_id() {
391        let state = test_state();
392        let (addr, sender, auth) = start_sse_server(state).await;
393
394        let target_run = Uuid::now_v7();
395        let other_run = Uuid::now_v7();
396
397        let mut reader = connect_sse(&addr, &format!("?run_id={target_run}"), &auth).await;
398        sleep(Duration::from_millis(50)).await;
399
400        sender.send(sample_run_event(other_run)).unwrap();
401        sender.send(sample_run_event(target_run)).unwrap();
402
403        let text =
404            read_until_contains(&mut reader, &target_run.to_string(), Duration::from_secs(5)).await;
405
406        assert!(text.contains(&target_run.to_string()));
407        assert!(!text.contains(&other_run.to_string()));
408    }
409
410    #[tokio::test]
411    async fn sse_filters_by_event_type() {
412        let state = test_state();
413        let (addr, sender, auth) = start_sse_server(state).await;
414
415        let mut reader = connect_sse(&addr, "?types=user_signed_in", &auth).await;
416        sleep(Duration::from_millis(50)).await;
417
418        let run_id = Uuid::now_v7();
419        sender.send(sample_run_event(run_id)).unwrap();
420        sender.send(sample_user_event()).unwrap();
421
422        let text = read_until_contains(&mut reader, "user_signed_in", Duration::from_secs(5)).await;
423
424        assert!(text.contains("user_signed_in"));
425        assert!(!text.contains("run_status_changed"));
426    }
427
428    #[tokio::test]
429    async fn sse_returns_correct_content_type() {
430        let state = test_state();
431        let (addr, _sender, auth) = start_sse_server(state).await;
432        let mut reader = connect_sse(&addr, "", &auth).await;
433
434        let text =
435            read_until_contains(&mut reader, "text/event-stream", Duration::from_secs(5)).await;
436
437        assert!(text.contains("text/event-stream"));
438    }
439
440    #[tokio::test]
441    async fn sse_rejects_unauthenticated() {
442        let state = test_state();
443        let (addr, _sender, _auth) = start_sse_server(state).await;
444        // Connect without auth header
445        let stream = tokio::net::TcpStream::connect(&addr).await.unwrap();
446        let (reader, mut writer) = stream.into_split();
447
448        use tokio::io::AsyncWriteExt;
449        writer
450            .write_all(
451                format!(
452                    "GET /events HTTP/1.1\r\nHost: {addr}\r\nAccept: text/event-stream\r\n\r\n"
453                )
454                .as_bytes(),
455            )
456            .await
457            .unwrap();
458
459        let mut buf_reader = BufReader::new(reader.reunite(writer).unwrap());
460        let text = read_until_contains(&mut buf_reader, "401", Duration::from_secs(5)).await;
461
462        assert!(text.contains("401"));
463    }
464}