// lago_api/routes/events.rs

use std::borrow::Cow;
use std::convert::Infallible;
use std::sync::Arc;
use std::time::Duration;

use axum::extract::{Path, Query, State};
use axum::http::HeaderMap;
use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
use futures::StreamExt;
use serde::Deserialize;
use tracing::debug;

use lago_core::id::{BranchId, SeqNo, SessionId};

use crate::error::ApiError;
use crate::sse::format::{SseFormat, SseFrame};
use crate::sse::{anthropic, lago, openai, vercel};
use crate::state::AppState;
18
19#[derive(Deserialize, Default)]
22pub struct EventStreamQuery {
23 #[serde(default = "default_format")]
25 pub format: String,
26 pub after_seq: Option<SeqNo>,
28 #[serde(default = "default_branch")]
30 pub branch: String,
31}
32
33fn default_format() -> String {
34 "lago".to_string()
35}
36
37fn default_branch() -> String {
38 "main".to_string()
39}
40
41fn resolve_format(name: &str) -> Result<Arc<dyn SseFormat>, ApiError> {
43 match name {
44 "openai" => Ok(Arc::new(openai::OpenAiFormat)),
45 "anthropic" => Ok(Arc::new(anthropic::AnthropicFormat)),
46 "vercel" => Ok(Arc::new(vercel::VercelFormat)),
47 "lago" | "" => Ok(Arc::new(lago::LagoFormat)),
48 other => Err(ApiError::BadRequest(format!(
49 "unknown format: {other}. Supported: openai, anthropic, vercel, lago"
50 ))),
51 }
52}
53
54fn parse_last_event_id(headers: &HeaderMap) -> Option<SeqNo> {
56 headers
57 .get("Last-Event-ID")
58 .and_then(|v| v.to_str().ok())
59 .and_then(|s| s.parse::<SeqNo>().ok())
60}
61
62fn frame_to_sse_event(frame: SseFrame) -> SseEvent {
64 let mut event = SseEvent::default().data(frame.data);
65 if let Some(name) = frame.event {
66 event = event.event(name);
67 }
68 if let Some(id) = frame.id {
69 event = event.id(id);
70 }
71 event
72}
73
74pub async fn stream_events(
80 State(state): State<Arc<AppState>>,
81 Path(session_id): Path<String>,
82 Query(query): Query<EventStreamQuery>,
83 headers: HeaderMap,
84) -> Result<Sse<impl futures::Stream<Item = Result<SseEvent, Infallible>>>, ApiError> {
85 let session_id = SessionId::from_string(session_id);
86 let branch_id = BranchId::from_string(query.branch.clone());
87 let format = resolve_format(&query.format)?;
88
89 let after_seq = parse_last_event_id(&headers)
93 .or(query.after_seq)
94 .unwrap_or(0);
95
96 debug!(
97 session = %session_id,
98 branch = %branch_id,
99 after_seq = after_seq,
100 format = format.name(),
101 "starting SSE event stream"
102 );
103
104 state
106 .journal
107 .get_session(&session_id)
108 .await?
109 .ok_or_else(|| ApiError::NotFound(format!("session not found: {session_id}")))?;
110
111 let event_stream = state
113 .journal
114 .stream(session_id, branch_id, after_seq)
115 .await?;
116
117 let sse_stream = event_stream.filter_map(move |result| {
120 let format = Arc::clone(&format);
121 async move {
122 match result {
123 Ok(envelope) => {
124 let frames = format.format(&envelope);
125 if frames.is_empty() {
126 None
127 } else {
128 let events: Vec<SseEvent> =
129 frames.into_iter().map(frame_to_sse_event).collect();
130 Some(events)
131 }
132 }
133 Err(e) => {
134 tracing::warn!(error = %e, "error reading event from journal stream");
135 None
136 }
137 }
138 }
139 });
140
141 let flat_stream = sse_stream
143 .flat_map(futures::stream::iter)
144 .map(Ok::<_, Infallible>);
145
146 Ok(Sse::new(flat_stream).keep_alive(
147 KeepAlive::new()
148 .interval(Duration::from_secs(15))
149 .text("ping"),
150 ))
151}