Skip to main content

lago_api/routes/
events.rs

1use std::convert::Infallible;
2use std::sync::Arc;
3use std::time::Duration;
4
5use axum::Json;
6use axum::extract::{Path, Query, State};
7use axum::http::HeaderMap;
8use axum::http::{HeaderName, HeaderValue};
9use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
10use axum::response::{IntoResponse, Response};
11use futures::StreamExt;
12use serde::{Deserialize, Serialize};
13use tracing::{debug, instrument};
14
15use lago_core::EventQuery;
16use lago_core::event::EventEnvelope;
17use lago_core::id::{BranchId, SeqNo, SessionId};
18
19use crate::error::ApiError;
20use crate::sse::format::{SseFormat, SseFrame};
21use crate::sse::{anthropic, lago, openai, vercel};
22use crate::state::AppState;
23
24// --- Query params
25
26#[derive(Deserialize, Default)]
27pub struct EventStreamQuery {
28    /// Output format: openai, anthropic, vercel, or lago (default).
29    #[serde(default = "default_format")]
30    pub format: String,
31    /// Only return events after this sequence number.
32    pub after_seq: Option<SeqNo>,
33    /// Branch name (default: "main").
34    #[serde(default = "default_branch")]
35    pub branch: String,
36}
37
/// Serde fallback for a missing `format` query parameter: the native
/// "lago" format.
fn default_format() -> String {
    String::from("lago")
}
41
/// Serde fallback for a missing `branch` query parameter: the "main" branch.
fn default_branch() -> String {
    String::from("main")
}
45
46/// Resolve the SSE format adapter from the query parameter string.
47fn resolve_format(name: &str) -> Result<Arc<dyn SseFormat>, ApiError> {
48    match name {
49        "openai" => Ok(Arc::new(openai::OpenAiFormat)),
50        "anthropic" => Ok(Arc::new(anthropic::AnthropicFormat)),
51        "vercel" => Ok(Arc::new(vercel::VercelFormat)),
52        "lago" | "" => Ok(Arc::new(lago::LagoFormat)),
53        other => Err(ApiError::BadRequest(format!(
54            "unknown format: {other}. Supported: openai, anthropic, vercel, lago"
55        ))),
56    }
57}
58
59/// Parse the `Last-Event-ID` header to determine where to resume streaming.
60fn parse_last_event_id(headers: &HeaderMap) -> Option<SeqNo> {
61    headers
62        .get("Last-Event-ID")
63        .and_then(|v| v.to_str().ok())
64        .and_then(|s| s.parse::<SeqNo>().ok())
65}
66
67/// Convert an `SseFrame` into an axum `SseEvent`.
68fn frame_to_sse_event(frame: SseFrame) -> SseEvent {
69    let mut event = SseEvent::default().data(frame.data);
70    if let Some(name) = frame.event {
71        event = event.event(name);
72    }
73    if let Some(id) = frame.id {
74        event = event.id(id);
75    }
76    event
77}
78
79// ─── Request / response types for write endpoints ─────────────────────────
80
81#[derive(Deserialize)]
82pub struct AppendEventRequest {
83    pub event: EventEnvelope,
84}
85
86#[derive(Serialize)]
87pub struct AppendEventResponse {
88    pub seq: SeqNo,
89}
90
91#[derive(Deserialize, Default)]
92pub struct ReadEventsQuery {
93    #[serde(default = "default_branch")]
94    pub branch: String,
95    #[serde(default)]
96    pub after_seq: SeqNo,
97    pub limit: Option<usize>,
98}
99
100#[derive(Deserialize, Default)]
101pub struct HeadQuery {
102    #[serde(default = "default_branch")]
103    pub branch: String,
104}
105
106#[derive(Serialize)]
107pub struct HeadSeqResponse {
108    pub seq: SeqNo,
109}
110
111// ─── POST /v1/sessions/:id/events ─────────────────────────────────────────
112
113/// POST /v1/sessions/:id/events
114///
115/// Append a single event to the journal. The `seq` field in the request body
116/// is ignored — the journal assigns a monotonically increasing sequence number.
117/// Returns `{ seq }` with the assigned sequence.
118pub async fn append_event(
119    State(state): State<Arc<AppState>>,
120    Path(session_id): Path<String>,
121    Json(body): Json<AppendEventRequest>,
122) -> Result<Json<AppendEventResponse>, ApiError> {
123    let mut event = body.event;
124    // Ensure session_id on the envelope matches the path parameter.
125    event.session_id = SessionId::from_string(session_id);
126    let seq = state.journal.append(event).await?;
127    Ok(Json(AppendEventResponse { seq }))
128}
129
130// ─── GET /v1/sessions/:id/events/read ─────────────────────────────────────
131
132/// GET /v1/sessions/:id/events/read?branch=main&after_seq=0&limit=100
133///
134/// Batch-read events from the journal. Unlike the SSE stream endpoint this
135/// returns immediately with the current events — it does not tail.
136pub async fn read_events(
137    State(state): State<Arc<AppState>>,
138    Path(session_id): Path<String>,
139    Query(query): Query<ReadEventsQuery>,
140) -> Result<Json<Vec<EventEnvelope>>, ApiError> {
141    let session_id = SessionId::from_string(session_id);
142    let branch_id = BranchId::from_string(query.branch);
143
144    let mut q = EventQuery::new()
145        .session(session_id)
146        .branch(branch_id)
147        .after(query.after_seq.saturating_sub(1));
148    if let Some(limit) = query.limit {
149        q = q.limit(limit);
150    }
151
152    let events = state.journal.read(q).await?;
153    Ok(Json(events))
154}
155
156// ─── GET /v1/sessions/:id/events/head ─────────────────────────────────────
157
158/// GET /v1/sessions/:id/events/head?branch=main
159///
160/// Returns the current head sequence number for a session+branch.
161/// Returns `{ seq: 0 }` if the session has no events yet.
162pub async fn head_seq(
163    State(state): State<Arc<AppState>>,
164    Path(session_id): Path<String>,
165    Query(query): Query<HeadQuery>,
166) -> Result<Json<HeadSeqResponse>, ApiError> {
167    let session_id = SessionId::from_string(session_id);
168    let branch_id = BranchId::from_string(query.branch);
169    let seq = state.journal.head_seq(&session_id, &branch_id).await?;
170    Ok(Json(HeadSeqResponse { seq }))
171}
172
173// ─── SSE stream ───────────────────────────────────────────────────────────
174
175/// GET /v1/sessions/:id/events
176///
177/// Streams events for a session in the requested format using Server-Sent Events.
178/// Supports reconnection via the `Last-Event-ID` header and keep-alive pings
179/// every 15 seconds.
180#[instrument(skip(state, query, headers), fields(lago.stream_id = %session_id))]
181pub async fn stream_events(
182    State(state): State<Arc<AppState>>,
183    Path(session_id): Path<String>,
184    Query(query): Query<EventStreamQuery>,
185    headers: HeaderMap,
186) -> Result<Response, ApiError> {
187    let session_id = SessionId::from_string(session_id);
188    let branch_id = BranchId::from_string(query.branch.clone());
189    let format = resolve_format(&query.format)?;
190
191    // Determine the starting sequence number. The `Last-Event-ID` header takes
192    // precedence, falling back to the `after_seq` query parameter, and finally
193    // defaulting to 0 (stream from the beginning).
194    let after_seq = parse_last_event_id(&headers)
195        .or(query.after_seq)
196        .unwrap_or(0);
197
198    debug!(
199        session = %session_id,
200        branch = %branch_id,
201        after_seq = after_seq,
202        format = format.name(),
203        "starting SSE event stream"
204    );
205
206    // Verify session exists
207    state
208        .journal
209        .get_session(&session_id)
210        .await?
211        .ok_or_else(|| ApiError::NotFound(format!("session not found: {session_id}")))?;
212
213    // Open a tailing event stream from the journal
214    let event_stream = state
215        .journal
216        .stream(session_id, branch_id, after_seq)
217        .await?;
218
219    // Map journal events through the format adapter, producing SSE frames.
220    // The format Arc is cloned for each item so the closure is 'static + Send.
221    let format_for_stream = Arc::clone(&format);
222    let sse_stream = event_stream.filter_map(move |result| {
223        let format = Arc::clone(&format_for_stream);
224        async move {
225            match result {
226                Ok(envelope) => {
227                    let frames = format.format(&envelope);
228                    if frames.is_empty() {
229                        None
230                    } else {
231                        let events: Vec<SseEvent> =
232                            frames.into_iter().map(frame_to_sse_event).collect();
233                        Some(events)
234                    }
235                }
236                Err(e) => {
237                    tracing::warn!(error = %e, "error reading event from journal stream");
238                    None
239                }
240            }
241        }
242    });
243
244    // Flatten: each envelope may produce multiple SSE events
245    let flat_stream = sse_stream
246        .flat_map(futures::stream::iter)
247        .map(Ok::<_, Infallible>);
248
249    let combined_stream: futures::stream::BoxStream<'static, Result<SseEvent, Infallible>> =
250        if let Some(done_frame) = format.done_frame() {
251            flat_stream
252                .chain(futures::stream::once(async move {
253                    Ok::<_, Infallible>(frame_to_sse_event(done_frame))
254                }))
255                .boxed()
256        } else {
257            flat_stream.boxed()
258        };
259
260    let sse = Sse::new(combined_stream).keep_alive(
261        KeepAlive::new()
262            .interval(Duration::from_secs(15))
263            .text("ping"),
264    );
265    let mut response = sse.into_response();
266    for (name, value) in format.extra_headers() {
267        if let (Ok(header_name), Ok(header_value)) = (
268            HeaderName::from_bytes(name.as_bytes()),
269            HeaderValue::from_str(&value),
270        ) {
271            response.headers_mut().insert(header_name, header_value);
272        }
273    }
274
275    Ok(response)
276}