use axum::{
body::Body,
extract::{RawQuery, State},
http::{header, Response, StatusCode},
response::{
sse::{Event, KeepAlive, Sse},
IntoResponse,
},
};
use futures_util::stream::Stream;
use pubky_common::crypto::PublicKey;
use serde::Deserialize;
use std::{collections::HashMap, convert::Infallible, time::Instant};
use url::form_urlencoded;
use crate::{
client_server::{extractors::ListQueryParams, AppState},
metrics_server::routes::metrics::Metrics,
persistence::{
files::events::{EventCursor, EventEntity, EventsService, MAX_EVENT_STREAM_USERS},
sql::SqlDb,
},
shared::{webdav::WebDavPath, HttpError, HttpResult},
};
#[derive(Debug, thiserror::Error)]
pub enum EventStreamError {
#[error("User not found")]
UserNotFound,
#[error("{0}")]
InvalidParameter(String),
#[error("Database error: {0}")]
DatabaseError(#[from] sqlx::Error),
#[error("Invalid public key: {0}")]
InvalidPublicKey(String),
}
impl From<EventStreamError> for HttpError {
fn from(error: EventStreamError) -> Self {
match error {
EventStreamError::UserNotFound => HttpError::not_found(),
EventStreamError::DatabaseError(e) => HttpError::from(e),
_ => HttpError::bad_request(error.to_string()),
}
}
}
/// Validated query parameters for the event-stream endpoint.
///
/// Produced from [`RawEventStreamQueryParams`] through its `TryFrom` impl. That impl
/// rejects `live` combined with `reverse`, requires at least one user, caps the number
/// of users at `MAX_EVENT_STREAM_USERS`, and normalizes `path` to start with `/`.
#[derive(Debug, Clone, Deserialize)]
#[serde(try_from = "RawEventStreamQueryParams")]
pub struct EventStreamQueryParams {
    /// Optional cap on the total number of events the stream sends; `None` means no cap.
    pub limit: Option<u16>,
    /// Query events in reverse order (cannot be combined with `live`).
    pub reverse: bool,
    /// After the stored backlog is exhausted, keep streaming newly broadcast events.
    pub live: bool,
    /// Requested users, each with an optional cursor that is still an unparsed string.
    pub user_cursors: Vec<(PublicKey, Option<String>)>,
    /// Optional path-prefix filter, always beginning with `/`.
    pub path: Option<WebDavPath>,
}
/// Unvalidated event-stream query parameters, exactly as collected from the query string.
///
/// `user` may repeat. Each entry is `<pubkey>` or `<pubkey>:<cursor>`.
/// `parse_query_params` fills this struct by hand from the raw query string. Validation
/// lives in the `TryFrom<RawEventStreamQueryParams>` impl for [`EventStreamQueryParams`].
#[derive(Debug, Deserialize)]
struct RawEventStreamQueryParams {
    /// Every `user` value, possibly carrying a `:cursor` suffix; empty values are skipped later.
    #[serde(default)]
    user: Vec<String>,
    /// Raw `limit` value.
    limit: Option<u16>,
    /// Raw `reverse` flag; `false` when absent.
    #[serde(default)]
    reverse: bool,
    /// Raw `live` flag; `false` when absent.
    #[serde(default)]
    live: bool,
    /// Path filter before normalization (may lack the leading `/`).
    path: Option<String>,
}
fn parse_query_params(query: &str) -> Result<EventStreamQueryParams, EventStreamError> {
let mut users = Vec::new();
let mut limit = None;
let mut reverse = false;
let mut live = false;
let mut path = None;
for (key, value) in form_urlencoded::parse(query.as_bytes()) {
match key.as_ref() {
"user" => users.push(value.to_string()),
"limit" => {
limit = Some(value.parse::<u16>().map_err(|_| {
EventStreamError::InvalidParameter(format!("Invalid limit: {}", value))
})?);
}
"reverse" => {
reverse = value == "true" || value == "1";
}
"live" => {
live = value == "true" || value == "1";
}
"path" => {
if !value.is_empty() {
path = Some(value.to_string());
}
}
_ => {} }
}
let raw = RawEventStreamQueryParams {
user: users,
limit,
reverse,
live,
path,
};
raw.try_into()
}
impl TryFrom<RawEventStreamQueryParams> for EventStreamQueryParams {
type Error = EventStreamError;
fn try_from(raw: RawEventStreamQueryParams) -> Result<Self, Self::Error> {
if raw.live && raw.reverse {
return Err(EventStreamError::InvalidParameter(
"Cannot use live mode with reverse ordering".to_string(),
));
}
let mut user_cursors = Vec::new();
for value in raw.user {
if value.is_empty() {
continue;
}
let (pubkey_str, cursor_str) = if let Some((pubkey, cursor)) = value.split_once(':') {
(pubkey, Some(cursor))
} else {
(value.as_str(), None)
};
if PublicKey::is_pubky_prefixed(pubkey_str) {
return Err(EventStreamError::InvalidPublicKey(pubkey_str.to_string()));
}
let pubkey = PublicKey::try_from_z32(pubkey_str)
.map_err(|_| EventStreamError::InvalidPublicKey(pubkey_str.to_string()))?;
user_cursors.push((pubkey, cursor_str.map(|s| s.to_string())));
}
if user_cursors.is_empty() {
return Err(EventStreamError::InvalidParameter(
"user parameter is required".to_string(),
));
}
if user_cursors.len() > MAX_EVENT_STREAM_USERS {
return Err(EventStreamError::InvalidParameter(format!(
"Too many users. Maximum allowed: {}",
MAX_EVENT_STREAM_USERS
)));
}
let path = if let Some(p) = raw.path {
if p.is_empty() {
None
} else {
let normalized_path = if p.starts_with('/') {
p
} else {
format!("/{}", p)
};
Some(WebDavPath::new(&normalized_path).map_err(|_| {
EventStreamError::InvalidParameter(format!("Invalid path: {}", normalized_path))
})?)
}
} else {
None
};
Ok(EventStreamQueryParams {
limit: raw.limit,
reverse: raw.reverse,
live: raw.live,
user_cursors,
path,
})
}
}
/// Formats an event's resource as a `pubky://` URL.
///
/// The URL is the owner's z32 public key immediately followed by the event's path.
fn formatted_event_path(entity: &EventEntity) -> String {
    let owner = entity.user_pubkey.z32();
    let resource = entity.path.path();
    format!("pubky://{}{}", owner, resource)
}
/// Builds the SSE `data` payload for one event.
///
/// The payload is newline-separated:
/// - `pubky://<pubkey><path>`
/// - `cursor: <cursor>`
/// - `content_hash: <base64>`, present only when the event type carries a content hash.
///   The hash uses the standard base64 engine.
fn event_to_sse_data(entity: &EventEntity) -> String {
    let mut fields = vec![
        formatted_event_path(entity),
        format!("cursor: {}", entity.cursor()),
    ];
    let encoded_hash = entity.event_type.content_hash().map(|hash| {
        base64::Engine::encode(&base64::engine::general_purpose::STANDARD, hash.as_bytes())
    });
    if let Some(encoded) = encoded_hash {
        fields.push(format!("content_hash: {}", encoded));
    }
    fields.join("\n")
}
/// Returns one page of events as plain text.
///
/// The cursor defaults to `"0"` when absent. Each event becomes one line of the form
/// `<event_type> pubky://<pubkey><path>`. If the page is non-empty, a final line
/// `cursor: <id of the last event>` is appended so the caller can request the next page.
/// The database query time is recorded in the metrics.
///
/// # Errors
///
/// - A bad-request error with the message `Invalid cursor` if the cursor cannot be parsed.
/// - Errors from the events query are propagated with `?`.
pub async fn feed(
    State(state): State<AppState>,
    params: ListQueryParams,
) -> HttpResult<impl IntoResponse> {
    let raw_cursor = params.cursor.unwrap_or_else(|| "0".to_string());
    let start_cursor = state
        .events_service
        .parse_cursor(raw_cursor.as_str(), &mut state.sql_db.pool().into())
        .await
        .map_err(|_| HttpError::bad_request("Invalid cursor"))?;

    let started_at = Instant::now();
    let events = state
        .events_service
        .get_by_cursor(Some(start_cursor), params.limit, &mut state.sql_db.pool().into())
        .await?;
    state
        .metrics
        .record_events_db_query(started_at.elapsed().as_millis());

    let mut lines: Vec<String> = events
        .iter()
        .map(|event| format!("{} {}", event.event_type, formatted_event_path(event)))
        .collect();
    if let Some(last_event) = events.last() {
        lines.push(format!("cursor: {}", last_event.id));
    }

    let body = Body::from(lines.join("\n"));
    Ok(Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, "text/plain")
        .body(body)
        .unwrap())
}
/// Streams events for one or more users as Server-Sent Events.
///
/// Parameters come from `parse_query_params`: repeated `user=<pubkey>[:<cursor>]`,
/// plus optional `limit`, `reverse`, `live` and `path`. The stream then runs in two phases:
///
/// 1. **Backfill.** Repeatedly queries the database from each user's cursor, advancing
///    that user's cursor after every sent event, until a query returns no events.
/// 2. **Live** (only when `live` is set). Forwards matching events from the events
///    service's broadcast channel, skipping any already covered by a user's cursor.
///
/// The stream ends when any of these happens:
/// - `limit` events have been sent (`limit=0` sends nothing);
/// - backfill finishes and `live` is not set;
/// - a backfill query fails;
/// - the live receiver lags or is closed.
///
/// Ending instead of continuing after a failure or lag lets the client reconnect from
/// its last cursor rather than silently miss events.
///
/// # Errors
///
/// Returns an `HttpError` before streaming starts when the query parameters are
/// invalid, a user does not exist, or a cursor cannot be parsed.
pub async fn feed_stream(
    State(state): State<AppState>,
    raw_query: RawQuery,
) -> HttpResult<Sse<impl Stream<Item = Result<Event, Infallible>>>> {
    let params =
        parse_query_params(raw_query.0.as_deref().unwrap_or("")).map_err(HttpError::from)?;
    let mut user_cursor_map =
        resolve_user_cursors(&params.user_cursors, &state.events_service, &state.sql_db)
            .await
            .map_err(HttpError::from)?;
    let stream = async_stream::stream! {
        let _guard = ConnectionGuard::new(state.metrics.clone());
        let mut total_sent: usize = 0;
        // `limit=0` asks for no events. The per-event check below only runs after a
        // yield, so without this the first event would still be sent.
        if limit_reached(params.limit, total_sent) {
            return;
        }
        // Subscribe before backfilling: a broadcast receiver only sees messages sent
        // after it subscribed.
        let mut rx = state.events_service.subscribe();

        // Phase 1: backfill from the database.
        loop {
            // Discard everything queued so far, including lag notifications. The query
            // below covers those events from the database (NOTE(review): assumes events
            // are committed before they are broadcast). Anything arriving afterwards is
            // de-duplicated by cursor in the live phase. Stopping at the first `Lagged`
            // would leave older messages queued behind it.
            loop {
                match rx.try_recv() {
                    Ok(_) | Err(tokio::sync::broadcast::error::TryRecvError::Lagged(_)) => {}
                    Err(_) => break,
                }
            }
            let current_user_cursors: Vec<(i32, Option<EventCursor>)> =
                user_cursor_map.iter().map(|(k, cursor)| (*k, *cursor)).collect();
            let query_start = Instant::now();
            let events = match state
                .events_service
                .get_by_user_cursors(
                    current_user_cursors,
                    params.reverse,
                    params.path.as_ref().map(|p| p.as_str()),
                    &mut state.sql_db.pool().into(),
                )
                .await
            {
                Ok(events) => events,
                Err(e) => {
                    // End the stream. Falling through to live mode would skip the
                    // part of the backlog that was never delivered, leaving a silent gap.
                    tracing::error!("Database error while fetching events: {}", e);
                    return;
                }
            };
            state.metrics.record_event_stream_db_query(query_start.elapsed().as_millis());
            if events.is_empty() {
                break;
            }
            for event in events {
                user_cursor_map.insert(event.user_id, Some(event.cursor()));
                yield Ok(Event::default()
                    .event(event.event_type.to_string())
                    .data(event_to_sse_data(&event)));
                total_sent += 1;
                if limit_reached(params.limit, total_sent) {
                    return;
                }
            }
        }

        if !params.live {
            return;
        }

        // Phase 2: forward live events from the broadcast channel.
        let user_ids: Vec<i32> = user_cursor_map.keys().copied().collect();
        let half_capacity = state.events_service.channel_capacity() / 2;
        loop {
            match rx.recv().await {
                Ok(event) => {
                    if rx.len() >= half_capacity {
                        state.metrics.record_broadcast_half_full();
                    }
                    if !should_include_live_event(&event, &user_ids, &user_cursor_map, params.path.as_ref()) {
                        continue;
                    }
                    user_cursor_map.insert(event.user_id, Some(event.cursor()));
                    yield Ok(Event::default()
                        .event(event.event_type.to_string())
                        .data(event_to_sse_data(&event)));
                    total_sent += 1;
                    if limit_reached(params.limit, total_sent) {
                        return;
                    }
                }
                Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => {
                    state.metrics.record_broadcast_lagged();
                    tracing::warn!(
                        "Slow client detected: broadcast channel lagged by {} events. Closing connection.",
                        skipped
                    );
                    return;
                }
                Err(_) => break,
            }
        }
    };
    Ok(Sse::new(stream).keep_alive(KeepAlive::default()))
}

/// Returns `true` once `sent` events satisfy the optional `limit` (`None` never does).
fn limit_reached(limit: Option<u16>, sent: usize) -> bool {
    matches!(limit, Some(max) if sent >= usize::from(max))
}
async fn resolve_user_cursors(
user_cursors: &[(PublicKey, Option<String>)],
events_service: &EventsService,
sql_db: &SqlDb,
) -> Result<HashMap<i32, Option<EventCursor>>, EventStreamError> {
use crate::persistence::sql::user::UserRepository;
let mut user_cursor_map: HashMap<i32, Option<EventCursor>> = HashMap::new();
for (user_pubkey, cursor_str_opt) in user_cursors {
let user_id = UserRepository::get_id(user_pubkey, &mut sql_db.pool().into())
.await
.map_err(|e| match e {
sqlx::Error::RowNotFound => EventStreamError::UserNotFound,
e => EventStreamError::DatabaseError(e),
})?;
let cursor = if let Some(cursor_str) = cursor_str_opt {
Some(
events_service
.parse_cursor(cursor_str, &mut sql_db.pool().into())
.await
.map_err(|_| {
EventStreamError::InvalidParameter(format!(
"Invalid cursor: {}",
cursor_str
))
})?,
)
} else {
None
};
user_cursor_map.insert(user_id, cursor);
}
Ok(user_cursor_map)
}
/// Decides whether a broadcast event should be forwarded to this live stream.
///
/// An event is forwarded only if all of the following hold:
/// - it belongs to one of `user_ids`;
/// - it is strictly after that user's stored cursor (or the user has no cursor yet);
/// - its path starts with `path_filter`, when a filter is given (a plain string-prefix match).
fn should_include_live_event(
    event: &EventEntity,
    user_ids: &[i32],
    user_cursor_map: &HashMap<i32, Option<EventCursor>>,
    path_filter: Option<&WebDavPath>,
) -> bool {
    if !user_ids.contains(&event.user_id) {
        return false;
    }
    let already_delivered = matches!(
        user_cursor_map.get(&event.user_id),
        Some(Some(cursor)) if event.cursor() <= *cursor
    );
    !already_delivered
        && path_filter.map_or(true, |prefix| {
            event.path.path().as_str().starts_with(prefix.as_str())
        })
}
/// Guard that accounts for one open event-stream connection in the metrics.
///
/// Creating it increments the active-connection count. Dropping it decrements the count
/// and records how long the guard was alive, in milliseconds. Because this happens in
/// `Drop`, the bookkeeping also runs on early returns.
struct ConnectionGuard {
    metrics: Metrics,
    start: Instant,
}

impl ConnectionGuard {
    /// Registers a new connection, then starts its timer.
    fn new(metrics: Metrics) -> Self {
        metrics.increment_active_connections();
        let start = Instant::now();
        ConnectionGuard { metrics, start }
    }
}

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        let metrics = &self.metrics;
        metrics.decrement_active_connections();
        metrics.record_connection_closed(self.start.elapsed().as_millis());
    }
}
// Tests for `ConnectionGuard`'s metric bookkeeping.
#[cfg(test)]
mod tests {
    use super::*;
    // Checks that the guard's `Drop` runs when the function owning it returns early:
    // the active-connection count returns to zero and a duration sample is recorded.
    #[test]
    fn connection_guard_drops_on_early_return() {
        let metrics = Metrics::new().expect("Failed to create metrics");
        fn early_return_fn(metrics: Metrics) -> Result<(), &'static str> {
            let _guard = ConnectionGuard::new(metrics.clone());
            return Err("early exit");
            // Never executed; it only gives the function an explicit tail expression
            // after the early `return`.
            #[allow(unreachable_code)]
            {
                Ok(())
            }
        }
        let result = early_return_fn(metrics.clone());
        assert!(result.is_err(), "Should have returned early");
        let output = metrics.render().expect("Failed to render metrics");
        // NOTE(review): `contains("} 0")` matches any labelled series whose value is 0,
        // not specifically the active-connections one. Consider asserting on the full
        // `event_stream_active_connections...` line instead.
        assert!(
            output.contains("event_stream_active_connections") && output.contains("} 0"),
            "Should have 0 active connections after early return: {}",
            output
        );
        assert!(
            output.contains("event_stream_connection_duration_ms_count"),
            "Should have recorded connection duration: {}",
            output
        );
    }
    // Five tasks each hold a guard for 0, 10, 20, 30 and 40 ms respectively. Mid-way
    // the test only checks that the metric exists, because the exact count is
    // timing-dependent. After all tasks finish it expects zero active connections and
    // five recorded durations.
    #[tokio::test]
    async fn connection_guard_concurrent() {
        let metrics = Metrics::new().expect("Failed to create metrics");
        let handles: Vec<_> = (0..5)
            .map(|i| {
                let metrics_clone = metrics.clone();
                tokio::spawn(async move {
                    let _guard = ConnectionGuard::new(metrics_clone);
                    tokio::time::sleep(tokio::time::Duration::from_millis(10 * i)).await;
                })
            })
            .collect();
        tokio::time::sleep(tokio::time::Duration::from_millis(20)).await;
        let output = metrics.render().expect("Failed to render metrics");
        assert!(
            output.contains("event_stream_active_connections"),
            "Should have active connections metric: {}",
            output
        );
        for handle in handles {
            handle.await.unwrap();
        }
        // NOTE(review): each guard is dropped before its task completes, so this extra
        // sleep appears unnecessary once all handles are joined. Confirm before removing.
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
        let output = metrics.render().expect("Failed to render metrics");
        // NOTE(review): like the test above, `} 0` and `} 5` match any series with
        // those values, so these assertions are weaker than their messages suggest.
        assert!(
            output.contains("event_stream_active_connections") && output.contains("} 0"),
            "Should have 0 active connections after all concurrent guards dropped: {}",
            output
        );
        assert!(
            output.contains("event_stream_connection_duration_ms_count") && output.contains("} 5"),
            "Should have recorded 5 connection durations: {}",
            output
        );
    }
}