// lago_api/routes/events.rs

1use std::convert::Infallible;
2use std::sync::Arc;
3use std::time::Duration;
4
5use axum::extract::{Path, Query, State};
6use axum::http::HeaderMap;
7use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
8use futures::StreamExt;
9use serde::Deserialize;
10use tracing::debug;
11
12use lago_core::id::{BranchId, SeqNo, SessionId};
13
14use crate::error::ApiError;
15use crate::sse::format::{SseFormat, SseFrame};
16use crate::sse::{anthropic, lago, openai, vercel};
17use crate::state::AppState;
18
// --- Query params
20
21#[derive(Deserialize, Default)]
22pub struct EventStreamQuery {
23    /// Output format: openai, anthropic, vercel, or lago (default).
24    #[serde(default = "default_format")]
25    pub format: String,
26    /// Only return events after this sequence number.
27    pub after_seq: Option<SeqNo>,
28    /// Branch name (default: "main").
29    #[serde(default = "default_branch")]
30    pub branch: String,
31}
32
/// Serde default for the `format` query parameter: the native `lago` format.
fn default_format() -> String {
    String::from("lago")
}
36
/// Serde default for the `branch` query parameter: the `main` branch.
fn default_branch() -> String {
    "main".to_owned()
}
40
41/// Resolve the SSE format adapter from the query parameter string.
42fn resolve_format(name: &str) -> Result<Arc<dyn SseFormat>, ApiError> {
43    match name {
44        "openai" => Ok(Arc::new(openai::OpenAiFormat)),
45        "anthropic" => Ok(Arc::new(anthropic::AnthropicFormat)),
46        "vercel" => Ok(Arc::new(vercel::VercelFormat)),
47        "lago" | "" => Ok(Arc::new(lago::LagoFormat)),
48        other => Err(ApiError::BadRequest(format!(
49            "unknown format: {other}. Supported: openai, anthropic, vercel, lago"
50        ))),
51    }
52}
53
54/// Parse the `Last-Event-ID` header to determine where to resume streaming.
55fn parse_last_event_id(headers: &HeaderMap) -> Option<SeqNo> {
56    headers
57        .get("Last-Event-ID")
58        .and_then(|v| v.to_str().ok())
59        .and_then(|s| s.parse::<SeqNo>().ok())
60}
61
62/// Convert an `SseFrame` into an axum `SseEvent`.
63fn frame_to_sse_event(frame: SseFrame) -> SseEvent {
64    let mut event = SseEvent::default().data(frame.data);
65    if let Some(name) = frame.event {
66        event = event.event(name);
67    }
68    if let Some(id) = frame.id {
69        event = event.id(id);
70    }
71    event
72}
73
74/// GET /v1/sessions/:id/events
75///
76/// Streams events for a session in the requested format using Server-Sent Events.
77/// Supports reconnection via the `Last-Event-ID` header and keep-alive pings
78/// every 15 seconds.
79pub async fn stream_events(
80    State(state): State<Arc<AppState>>,
81    Path(session_id): Path<String>,
82    Query(query): Query<EventStreamQuery>,
83    headers: HeaderMap,
84) -> Result<Sse<impl futures::Stream<Item = Result<SseEvent, Infallible>>>, ApiError> {
85    let session_id = SessionId::from_string(session_id);
86    let branch_id = BranchId::from_string(query.branch.clone());
87    let format = resolve_format(&query.format)?;
88
89    // Determine the starting sequence number. The `Last-Event-ID` header takes
90    // precedence, falling back to the `after_seq` query parameter, and finally
91    // defaulting to 0 (stream from the beginning).
92    let after_seq = parse_last_event_id(&headers)
93        .or(query.after_seq)
94        .unwrap_or(0);
95
96    debug!(
97        session = %session_id,
98        branch = %branch_id,
99        after_seq = after_seq,
100        format = format.name(),
101        "starting SSE event stream"
102    );
103
104    // Verify session exists
105    state
106        .journal
107        .get_session(&session_id)
108        .await?
109        .ok_or_else(|| ApiError::NotFound(format!("session not found: {session_id}")))?;
110
111    // Open a tailing event stream from the journal
112    let event_stream = state
113        .journal
114        .stream(session_id, branch_id, after_seq)
115        .await?;
116
117    // Map journal events through the format adapter, producing SSE frames.
118    // The format Arc is cloned for each item so the closure is 'static + Send.
119    let sse_stream = event_stream.filter_map(move |result| {
120        let format = Arc::clone(&format);
121        async move {
122            match result {
123                Ok(envelope) => {
124                    let frames = format.format(&envelope);
125                    if frames.is_empty() {
126                        None
127                    } else {
128                        let events: Vec<SseEvent> =
129                            frames.into_iter().map(frame_to_sse_event).collect();
130                        Some(events)
131                    }
132                }
133                Err(e) => {
134                    tracing::warn!(error = %e, "error reading event from journal stream");
135                    None
136                }
137            }
138        }
139    });
140
141    // Flatten: each envelope may produce multiple SSE events
142    let flat_stream = sse_stream
143        .flat_map(futures::stream::iter)
144        .map(Ok::<_, Infallible>);
145
146    Ok(Sse::new(flat_stream).keep_alive(
147        KeepAlive::new()
148            .interval(Duration::from_secs(15))
149            .text("ping"),
150    ))
151}