use axum::{
extract::State,
response::sse::{Event, KeepAlive, Sse},
};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::{
convert::Infallible,
pin::Pin,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
task::{Context, Poll},
time::Duration,
};
use tokio::sync::broadcast::{self, Receiver, Sender};
use tokio::time::{interval, Interval};
use tracing::{debug, info, instrument, warn};
/// Events published on the live activity feed and delivered to SSE clients.
///
/// Serialized as adjacently tagged JSON of the form
/// `{"type": "<name>", "data": {...}}` via the serde `tag`/`content` attributes,
/// with each variant renamed to its snake_case wire name.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", content = "data")]
pub enum FeedEvent {
    /// A capture request was received by the server.
    #[serde(rename = "capture_received")]
    CaptureReceived(CaptureReceivedData),
    /// Processing of a previously received capture finished.
    #[serde(rename = "processing_complete")]
    ProcessingComplete(ProcessingCompleteData),
    /// An error occurred, optionally tied to a specific capture.
    #[serde(rename = "error")]
    Error(ErrorData),
    /// Periodic liveness event generated by each stream.
    #[serde(rename = "heartbeat")]
    Heartbeat(HeartbeatData),
}
/// Payload for [`FeedEvent::CaptureReceived`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct CaptureReceivedData {
    /// Identifier of the capture.
    pub capture_id: String,
    /// URL being captured.
    pub url: String,
    /// Unix epoch milliseconds when the event was published.
    pub timestamp: u64,
    /// Free-form capture kind (e.g. "screenshot", "pdf" in the tests below).
    pub capture_type: String,
}
/// Payload for [`FeedEvent::ProcessingComplete`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ProcessingCompleteData {
    /// Identifier of the capture that finished processing.
    pub capture_id: String,
    /// Wall-clock processing duration in milliseconds.
    pub duration_ms: u64,
    /// Size of the processed result in bytes.
    pub size_bytes: u64,
    /// Optional human-readable summary; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub summary: Option<String>,
}
/// Payload for [`FeedEvent::Error`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ErrorData {
    /// Capture the error relates to, if any; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub capture_id: Option<String>,
    /// Machine-readable error code (e.g. "E_TIMEOUT" in the tests below).
    pub code: String,
    /// Human-readable error description.
    pub message: String,
    /// Whether the operation can be retried.
    pub recoverable: bool,
}
/// Payload for [`FeedEvent::Heartbeat`], generated per-stream by `FeedStream`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct HeartbeatData {
    /// Unix epoch milliseconds when the heartbeat was generated.
    pub timestamp: u64,
    /// Number of SSE clients connected at generation time.
    pub connected_clients: u64,
    /// Seconds elapsed since the `FeedState` was created.
    pub uptime_seconds: u64,
}
impl FeedEvent {
    /// Returns the SSE event-name string for this variant.
    ///
    /// Must stay in sync with the `#[serde(rename = ...)]` attributes on the
    /// enum, since clients match on both the SSE event name and the JSON tag.
    pub fn event_type(&self) -> &'static str {
        match self {
            Self::CaptureReceived(_) => "capture_received",
            Self::ProcessingComplete(_) => "processing_complete",
            Self::Error(_) => "error",
            Self::Heartbeat(_) => "heartbeat",
        }
    }

    /// Converts this event into an axum SSE [`Event`] whose name is
    /// [`Self::event_type`] and whose data is the JSON-serialized event.
    ///
    /// # Errors
    /// Returns the underlying `serde_json` error if serialization fails.
    pub fn to_sse_event(&self) -> Result<Event, serde_json::Error> {
        let payload = serde_json::to_string(self)?;
        let event = Event::default().event(self.event_type()).data(payload);
        Ok(event)
    }
}
/// Shared hub for the live feed: owns the broadcast channel events are
/// published on, tracks the connected-client count, and records the start
/// time used for uptime reporting.
pub struct FeedState {
    /// Sending side of the broadcast channel; each client subscribes for a `Receiver`.
    sender: Sender<FeedEvent>,
    /// Current number of connected SSE clients, maintained by `ClientGuard`.
    connected_clients: AtomicU64,
    /// Creation instant; basis for `uptime_seconds`.
    start_time: std::time::Instant,
    /// Broadcast channel capacity supplied at construction (exposed via `capacity()`).
    capacity: usize,
}
impl FeedState {
    /// Creates a feed hub whose broadcast channel buffers up to `capacity`
    /// events per lagging subscriber before they start seeing `Lagged` errors.
    pub fn new(capacity: usize) -> Self {
        let (sender, _) = broadcast::channel(capacity);
        Self {
            sender,
            connected_clients: AtomicU64::new(0),
            start_time: std::time::Instant::now(),
            capacity,
        }
    }

    /// Returns a fresh receiver that observes events published after this call.
    pub fn subscribe(&self) -> Receiver<FeedEvent> {
        self.sender.subscribe()
    }

    /// Publishes `event` to all current subscribers.
    ///
    /// Returns the number of subscribers that received it; 0 when nobody is
    /// listening (the event is dropped, which is not an error here).
    #[instrument(skip(self, event), fields(event_type = event.event_type()))]
    pub fn publish(&self, event: FeedEvent) -> usize {
        match self.sender.send(event) {
            Ok(count) => {
                debug!("Published event to {} clients", count);
                count
            }
            Err(_) => {
                // broadcast::Sender::send only errs when there are no receivers.
                debug!("No clients connected, event dropped");
                0
            }
        }
    }

    /// Convenience wrapper publishing a `CaptureReceived` event stamped with
    /// the current time. Returns the subscriber count (see [`Self::publish`]).
    pub fn publish_capture_received(
        &self,
        capture_id: impl Into<String>,
        url: impl Into<String>,
        capture_type: impl Into<String>,
    ) -> usize {
        self.publish(FeedEvent::CaptureReceived(CaptureReceivedData {
            capture_id: capture_id.into(),
            url: url.into(),
            timestamp: current_timestamp_ms(),
            capture_type: capture_type.into(),
        }))
    }

    /// Convenience wrapper publishing a `ProcessingComplete` event.
    pub fn publish_processing_complete(
        &self,
        capture_id: impl Into<String>,
        duration_ms: u64,
        size_bytes: u64,
        summary: Option<String>,
    ) -> usize {
        self.publish(FeedEvent::ProcessingComplete(ProcessingCompleteData {
            capture_id: capture_id.into(),
            duration_ms,
            size_bytes,
            summary,
        }))
    }

    /// Convenience wrapper publishing an `Error` event.
    pub fn publish_error(
        &self,
        capture_id: Option<String>,
        code: impl Into<String>,
        message: impl Into<String>,
        recoverable: bool,
    ) -> usize {
        self.publish(FeedEvent::Error(ErrorData {
            capture_id,
            code: code.into(),
            message: message.into(),
            recoverable,
        }))
    }

    /// Current number of connected SSE clients.
    pub fn connected_clients(&self) -> u64 {
        self.connected_clients.load(Ordering::Relaxed)
    }

    /// Seconds elapsed since this state was created.
    pub fn uptime_seconds(&self) -> u64 {
        self.start_time.elapsed().as_secs()
    }

    /// Broadcast channel capacity this hub was created with.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Increments the client counter; called by `ClientGuard::new`.
    /// Returns the count after the increment.
    fn client_connected(&self) -> u64 {
        let count = self.connected_clients.fetch_add(1, Ordering::Relaxed) + 1;
        info!("Client connected, total: {}", count);
        count
    }

    /// Decrements the client counter; called by `ClientGuard::drop`.
    /// Returns the count after the decrement.
    fn client_disconnected(&self) -> u64 {
        // A plain fetch_sub would wrap to u64::MAX if this were ever called
        // without a matching client_connected (e.g. a future refactor breaks
        // the guard pairing); saturate at zero instead so the published
        // client count can never be garbage.
        let prev = self
            .connected_clients
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |c| {
                Some(c.saturating_sub(1))
            })
            .expect("fetch_update closure always returns Some");
        let count = prev.saturating_sub(1);
        info!("Client disconnected, total: {}", count);
        count
    }
}
impl Default for FeedState {
    /// Default hub with a 1024-event broadcast buffer.
    fn default() -> Self {
        const DEFAULT_CAPACITY: usize = 1024;
        Self::new(DEFAULT_CAPACITY)
    }
}
/// RAII guard tying the connected-client count to a stream's lifetime:
/// increments on creation, decrements on drop, so the count stays correct
/// on every disconnect path (including panics and task cancellation).
struct ClientGuard {
    state: Arc<FeedState>,
}

impl ClientGuard {
    /// Registers one client connection against `state`.
    fn new(state: Arc<FeedState>) -> Self {
        state.client_connected();
        Self { state }
    }
}

impl Drop for ClientGuard {
    fn drop(&mut self) {
        self.state.client_disconnected();
    }
}
/// SSE stream for one client: yields broadcast feed events interleaved with
/// periodic per-stream heartbeats, and keeps the client count accurate via
/// its embedded `ClientGuard`.
pub struct FeedStream {
    /// Subscription to the shared broadcast channel.
    receiver: Receiver<FeedEvent>,
    /// Timer driving heartbeat emission.
    heartbeat_interval: Interval,
    /// Shared hub, read for heartbeat statistics.
    state: Arc<FeedState>,
    /// Held only for its Drop side effect (decrements the client count).
    _guard: ClientGuard,
    /// Monotonically increasing id, currently used only in debug logging.
    #[allow(dead_code)]
    stream_id: u64,
}
impl FeedStream {
    /// Subscribes to `state`'s broadcast channel and arms a heartbeat timer
    /// that fires every `heartbeat_interval_secs` seconds. Registering the
    /// stream bumps the connected-client count until the stream is dropped.
    pub fn new(state: Arc<FeedState>, heartbeat_interval_secs: u64) -> Self {
        // Process-wide counter giving each stream a unique id for logging.
        static STREAM_COUNTER: AtomicU64 = AtomicU64::new(0);
        let stream_id = STREAM_COUNTER.fetch_add(1, Ordering::Relaxed);
        let guard = ClientGuard::new(Arc::clone(&state));
        debug!("Created FeedStream {}", stream_id);
        Self {
            receiver: state.subscribe(),
            heartbeat_interval: interval(Duration::from_secs(heartbeat_interval_secs)),
            state,
            _guard: guard,
            stream_id,
        }
    }

    /// Builds a heartbeat event stamped with the current time and the hub's
    /// current client count and uptime.
    fn generate_heartbeat(&self) -> FeedEvent {
        let data = HeartbeatData {
            timestamp: current_timestamp_ms(),
            connected_clients: self.state.connected_clients(),
            uptime_seconds: self.state.uptime_seconds(),
        };
        FeedEvent::Heartbeat(data)
    }
}
impl Stream for FeedStream {
    type Item = Result<Event, Infallible>;

    /// Polls the heartbeat timer first, then drains the broadcast receiver.
    ///
    /// Fix over the previous version: on `Lagged` or on an event that fails
    /// to serialize, we retry `try_recv` immediately in a loop instead of
    /// returning `Poll::Pending` with a self-wake — the receiver is already
    /// repositioned / has more events, so yielding just added a pointless
    /// scheduler round-trip per skipped event.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Heartbeats take priority; poll_tick also registers the timer waker.
        if self.heartbeat_interval.poll_tick(cx).is_ready() {
            let heartbeat = self.generate_heartbeat();
            match heartbeat.to_sse_event() {
                Ok(event) => return Poll::Ready(Some(Ok(event))),
                // On serialization failure, fall through and serve a feed
                // event this poll instead; the next tick retries.
                Err(e) => warn!("Failed to serialize heartbeat: {}", e),
            }
        }
        loop {
            match self.receiver.try_recv() {
                Ok(feed_event) => match feed_event.to_sse_event() {
                    Ok(event) => return Poll::Ready(Some(Ok(event))),
                    Err(e) => {
                        // Skip the bad event and try the next one right away.
                        warn!("Failed to serialize event: {}", e);
                        continue;
                    }
                },
                Err(broadcast::error::TryRecvError::Lagged(count)) => {
                    // Receiver was repositioned past the lost events; the
                    // oldest retained event is available on the next recv.
                    warn!("Client lagged behind by {} events", count);
                    continue;
                }
                Err(broadcast::error::TryRecvError::Empty) => {
                    // NOTE(review): this busy-polls — waking ourselves means
                    // the task is rescheduled immediately even when idle.
                    // A proper fix would await Receiver::recv via a stored
                    // future (or tokio-stream's BroadcastStream); kept as-is
                    // to avoid a new dependency. TODO confirm acceptable.
                    cx.waker().wake_by_ref();
                    return Poll::Pending;
                }
                Err(broadcast::error::TryRecvError::Closed) => {
                    debug!("Broadcast channel closed, ending stream");
                    return Poll::Ready(None);
                }
            }
        }
    }
}
/// Axum handler for `GET /feed`: opens an SSE stream with a 30-second
/// heartbeat and a 15-second protocol-level keep-alive comment.
#[instrument(skip(state))]
pub async fn feed_handler(
    State(state): State<Arc<FeedState>>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    info!("New SSE client connected to /feed");
    let keep_alive = KeepAlive::new()
        .interval(Duration::from_secs(15))
        .text("keep-alive");
    Sse::new(FeedStream::new(state, 30)).keep_alive(keep_alive)
}
/// Like [`feed_handler`] but with a caller-supplied heartbeat period.
///
/// `heartbeat_secs` is clamped to at least 1: `tokio::time::interval` panics
/// on a zero-length period (so `heartbeat_secs == 0` previously panicked
/// inside `FeedStream::new`), and the keep-alive interval `heartbeat_secs / 2`
/// truncated to 0 seconds for values <= 1. Behavior for values >= 2 is
/// unchanged.
#[instrument(skip(state))]
pub async fn feed_handler_with_interval(
    State(state): State<Arc<FeedState>>,
    heartbeat_secs: u64,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    info!(
        "New SSE client connected to /feed (heartbeat: {}s)",
        heartbeat_secs
    );
    let heartbeat_secs = heartbeat_secs.max(1);
    let stream = FeedStream::new(state, heartbeat_secs);
    Sse::new(stream).keep_alive(
        KeepAlive::new()
            .interval(Duration::from_secs((heartbeat_secs / 2).max(1)))
            .text("keep-alive"),
    )
}
/// Current Unix time in milliseconds; returns 0 if the system clock reads
/// before the Unix epoch (same fallback as the previous `unwrap_or_default`).
fn current_timestamp_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
/// Builds an axum router exposing the SSE feed at `GET /feed`, backed by the
/// given shared state.
pub fn build_feed_router(state: Arc<FeedState>) -> axum::Router {
    use axum::routing::get;
    let router = axum::Router::new().route("/feed", get(feed_handler));
    router.with_state(state)
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::sleep;

    // Serializing a CaptureReceived event produces the tagged wire format
    // containing the type name and payload fields.
    #[test]
    fn test_feed_event_serialization() {
        let event = FeedEvent::CaptureReceived(CaptureReceivedData {
            capture_id: "test-123".to_string(),
            url: "https://example.com".to_string(),
            timestamp: 1704067200000,
            capture_type: "screenshot".to_string(),
        });
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("capture_received"));
        assert!(json.contains("test-123"));
        assert!(json.contains("https://example.com"));
    }

    // The tagged JSON form round-trips back into the right variant.
    #[test]
    fn test_feed_event_deserialization() {
        let json = r#"{"type":"capture_received","data":{"capture_id":"abc","url":"https://test.com","timestamp":1000,"capture_type":"pdf"}}"#;
        let event: FeedEvent = serde_json::from_str(json).unwrap();
        match event {
            FeedEvent::CaptureReceived(data) => {
                assert_eq!(data.capture_id, "abc");
                assert_eq!(data.url, "https://test.com");
                assert_eq!(data.capture_type, "pdf");
            }
            _ => panic!("Expected CaptureReceived"),
        }
    }

    // event_type() returns the SSE name matching each variant's serde rename.
    #[test]
    fn test_feed_event_type() {
        assert_eq!(
            FeedEvent::CaptureReceived(CaptureReceivedData {
                capture_id: String::new(),
                url: String::new(),
                timestamp: 0,
                capture_type: String::new(),
            })
            .event_type(),
            "capture_received"
        );
        assert_eq!(
            FeedEvent::ProcessingComplete(ProcessingCompleteData {
                capture_id: String::new(),
                duration_ms: 0,
                size_bytes: 0,
                summary: None,
            })
            .event_type(),
            "processing_complete"
        );
        assert_eq!(
            FeedEvent::Error(ErrorData {
                capture_id: None,
                code: String::new(),
                message: String::new(),
                recoverable: false,
            })
            .event_type(),
            "error"
        );
        assert_eq!(
            FeedEvent::Heartbeat(HeartbeatData {
                timestamp: 0,
                connected_clients: 0,
                uptime_seconds: 0,
            })
            .event_type(),
            "heartbeat"
        );
    }

    // A fresh state records its capacity and starts with zero clients.
    #[test]
    fn test_feed_state_new() {
        let state = FeedState::new(100);
        assert_eq!(state.capacity(), 100);
        assert_eq!(state.connected_clients(), 0);
    }

    // Publishing with no subscribers drops the event and reports 0 receivers.
    #[tokio::test]
    async fn test_feed_state_publish_no_subscribers() {
        let state = FeedState::new(10);
        let count = state.publish_capture_received("test", "https://test.com", "screenshot");
        assert_eq!(count, 0);
    }

    // A subscriber receives the published event with its payload intact.
    #[tokio::test]
    async fn test_feed_state_publish_with_subscriber() {
        let state = Arc::new(FeedState::new(10));
        let mut receiver = state.subscribe();
        let count = state.publish_capture_received("test", "https://test.com", "screenshot");
        assert_eq!(count, 1);
        let event = receiver.recv().await.unwrap();
        match event {
            FeedEvent::CaptureReceived(data) => {
                assert_eq!(data.capture_id, "test");
                assert_eq!(data.url, "https://test.com");
            }
            _ => panic!("Expected CaptureReceived"),
        }
    }

    // ClientGuard increments on creation and decrements on drop, nesting correctly.
    #[tokio::test]
    async fn test_feed_state_client_tracking() {
        let state = Arc::new(FeedState::new(10));
        assert_eq!(state.connected_clients(), 0);
        {
            let _guard = ClientGuard::new(Arc::clone(&state));
            assert_eq!(state.connected_clients(), 1);
            {
                let _guard2 = ClientGuard::new(Arc::clone(&state));
                assert_eq!(state.connected_clients(), 2);
            }
            assert_eq!(state.connected_clients(), 1);
        }
        assert_eq!(state.connected_clients(), 0);
    }

    // Uptime is non-decreasing (second-granularity, so only >= is asserted).
    #[tokio::test]
    async fn test_feed_state_uptime() {
        let state = FeedState::new(10);
        let uptime1 = state.uptime_seconds();
        sleep(Duration::from_millis(100)).await;
        let uptime2 = state.uptime_seconds();
        assert!(uptime2 >= uptime1);
    }

    // publish_error reaches a live subscriber (receiver kept alive for the send).
    #[test]
    fn test_error_event() {
        let state = FeedState::new(10);
        let _receiver = state.subscribe();
        let count = state.publish_error(
            Some("capture-123".to_string()),
            "E_TIMEOUT",
            "Operation timed out",
            true,
        );
        assert_eq!(count, 1);
    }

    // publish_processing_complete reaches a live subscriber.
    #[test]
    fn test_processing_complete_event() {
        let state = FeedState::new(10);
        let _receiver = state.subscribe();
        let count = state.publish_processing_complete(
            "capture-456",
            150,
            1024,
            Some("Page title extracted".to_string()),
        );
        assert_eq!(count, 1);
    }

    // to_sse_event embeds the event-type name in the SSE event (checked via Debug).
    #[test]
    fn test_to_sse_event() {
        let event = FeedEvent::Heartbeat(HeartbeatData {
            timestamp: 1704067200000,
            connected_clients: 5,
            uptime_seconds: 3600,
        });
        let sse_event = event.to_sse_event().unwrap();
        assert!(format!("{:?}", sse_event).contains("heartbeat"));
    }

    // Creating a FeedStream registers it as a connected client.
    // (Needs a tokio runtime: FeedStream::new calls tokio::time::interval.)
    #[tokio::test]
    async fn test_feed_stream_creation() {
        let state = Arc::new(FeedState::new(10));
        let _stream = FeedStream::new(Arc::clone(&state), 30);
        assert_eq!(state.connected_clients(), 1);
    }

    // One publish fans out to every subscriber.
    #[tokio::test]
    async fn test_multiple_subscribers() {
        let state = Arc::new(FeedState::new(10));
        let mut rx1 = state.subscribe();
        let mut rx2 = state.subscribe();
        let mut rx3 = state.subscribe();
        let count = state.publish_capture_received("multi-test", "https://example.com", "html");
        assert_eq!(count, 3);
        assert!(rx1.recv().await.is_ok());
        assert!(rx2.recv().await.is_ok());
        assert!(rx3.recv().await.is_ok());
    }

    // Sanity check: the clock is past 2024-01-01 (1704067200000 ms).
    #[test]
    fn test_current_timestamp_ms() {
        let ts = current_timestamp_ms();
        assert!(ts > 1704067200000);
    }
}