Skip to main content

lago_api/routes/
events.rs

1use std::convert::Infallible;
2use std::sync::Arc;
3use std::time::Duration;
4
5use axum::extract::{Path, Query, State};
6use axum::http::HeaderMap;
7use axum::http::{HeaderName, HeaderValue};
8use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
9use axum::response::{IntoResponse, Response};
10use futures::StreamExt;
11use serde::Deserialize;
12use tracing::debug;
13
14use lago_core::id::{BranchId, SeqNo, SessionId};
15
16use crate::error::ApiError;
17use crate::sse::format::{SseFormat, SseFrame};
18use crate::sse::{anthropic, lago, openai, vercel};
19use crate::state::AppState;
20
// --- Query params
22
23#[derive(Deserialize, Default)]
24pub struct EventStreamQuery {
25    /// Output format: openai, anthropic, vercel, or lago (default).
26    #[serde(default = "default_format")]
27    pub format: String,
28    /// Only return events after this sequence number.
29    pub after_seq: Option<SeqNo>,
30    /// Branch name (default: "main").
31    #[serde(default = "default_branch")]
32    pub branch: String,
33}
34
/// Serde default for `EventStreamQuery::format`: selects the `lago` format.
fn default_format() -> String {
    String::from("lago")
}
38
/// Serde default for `EventStreamQuery::branch`: selects the `main` branch.
fn default_branch() -> String {
    String::from("main")
}
42
43/// Resolve the SSE format adapter from the query parameter string.
44fn resolve_format(name: &str) -> Result<Arc<dyn SseFormat>, ApiError> {
45    match name {
46        "openai" => Ok(Arc::new(openai::OpenAiFormat)),
47        "anthropic" => Ok(Arc::new(anthropic::AnthropicFormat)),
48        "vercel" => Ok(Arc::new(vercel::VercelFormat)),
49        "lago" | "" => Ok(Arc::new(lago::LagoFormat)),
50        other => Err(ApiError::BadRequest(format!(
51            "unknown format: {other}. Supported: openai, anthropic, vercel, lago"
52        ))),
53    }
54}
55
56/// Parse the `Last-Event-ID` header to determine where to resume streaming.
57fn parse_last_event_id(headers: &HeaderMap) -> Option<SeqNo> {
58    headers
59        .get("Last-Event-ID")
60        .and_then(|v| v.to_str().ok())
61        .and_then(|s| s.parse::<SeqNo>().ok())
62}
63
64/// Convert an `SseFrame` into an axum `SseEvent`.
65fn frame_to_sse_event(frame: SseFrame) -> SseEvent {
66    let mut event = SseEvent::default().data(frame.data);
67    if let Some(name) = frame.event {
68        event = event.event(name);
69    }
70    if let Some(id) = frame.id {
71        event = event.id(id);
72    }
73    event
74}
75
76/// GET /v1/sessions/:id/events
77///
78/// Streams events for a session in the requested format using Server-Sent Events.
79/// Supports reconnection via the `Last-Event-ID` header and keep-alive pings
80/// every 15 seconds.
81pub async fn stream_events(
82    State(state): State<Arc<AppState>>,
83    Path(session_id): Path<String>,
84    Query(query): Query<EventStreamQuery>,
85    headers: HeaderMap,
86) -> Result<Response, ApiError> {
87    let session_id = SessionId::from_string(session_id);
88    let branch_id = BranchId::from_string(query.branch.clone());
89    let format = resolve_format(&query.format)?;
90
91    // Determine the starting sequence number. The `Last-Event-ID` header takes
92    // precedence, falling back to the `after_seq` query parameter, and finally
93    // defaulting to 0 (stream from the beginning).
94    let after_seq = parse_last_event_id(&headers)
95        .or(query.after_seq)
96        .unwrap_or(0);
97
98    debug!(
99        session = %session_id,
100        branch = %branch_id,
101        after_seq = after_seq,
102        format = format.name(),
103        "starting SSE event stream"
104    );
105
106    // Verify session exists
107    state
108        .journal
109        .get_session(&session_id)
110        .await?
111        .ok_or_else(|| ApiError::NotFound(format!("session not found: {session_id}")))?;
112
113    // Open a tailing event stream from the journal
114    let event_stream = state
115        .journal
116        .stream(session_id, branch_id, after_seq)
117        .await?;
118
119    // Map journal events through the format adapter, producing SSE frames.
120    // The format Arc is cloned for each item so the closure is 'static + Send.
121    let format_for_stream = Arc::clone(&format);
122    let sse_stream = event_stream.filter_map(move |result| {
123        let format = Arc::clone(&format_for_stream);
124        async move {
125            match result {
126                Ok(envelope) => {
127                    let frames = format.format(&envelope);
128                    if frames.is_empty() {
129                        None
130                    } else {
131                        let events: Vec<SseEvent> =
132                            frames.into_iter().map(frame_to_sse_event).collect();
133                        Some(events)
134                    }
135                }
136                Err(e) => {
137                    tracing::warn!(error = %e, "error reading event from journal stream");
138                    None
139                }
140            }
141        }
142    });
143
144    // Flatten: each envelope may produce multiple SSE events
145    let flat_stream = sse_stream
146        .flat_map(futures::stream::iter)
147        .map(Ok::<_, Infallible>);
148
149    let combined_stream: futures::stream::BoxStream<'static, Result<SseEvent, Infallible>> =
150        if let Some(done_frame) = format.done_frame() {
151            flat_stream
152                .chain(futures::stream::once(async move {
153                    Ok::<_, Infallible>(frame_to_sse_event(done_frame))
154                }))
155                .boxed()
156        } else {
157            flat_stream.boxed()
158        };
159
160    let sse = Sse::new(combined_stream).keep_alive(
161        KeepAlive::new()
162            .interval(Duration::from_secs(15))
163            .text("ping"),
164    );
165    let mut response = sse.into_response();
166    for (name, value) in format.extra_headers() {
167        if let (Ok(header_name), Ok(header_value)) = (
168            HeaderName::from_bytes(name.as_bytes()),
169            HeaderValue::from_str(&value),
170        ) {
171            response.headers_mut().insert(header_name, header_value);
172        }
173    }
174
175    Ok(response)
176}