reasonkit_web/handlers/
feed.rs

//! Server-Sent Events (SSE) feed handler for real-time event streaming
//!
//! This module provides a `/feed` endpoint that streams events to connected clients
//! using the SSE protocol. Events include capture notifications, processing status,
//! errors, and heartbeat keep-alives.
//!
//! # Architecture
//!
//! ```text
//! Capture Event ──▶ Broadcast Channel ──▶ SSE Stream ──▶ Connected Clients
//!                          │
//!                          ▼
//!                   Multiple Subscribers
//! ```
//!
//! # Example
//!
//! ```rust,ignore
//! use reasonkit_web::handlers::feed::{FeedState, feed_handler};
//! use axum::{Router, routing::get};
//! use std::sync::Arc;
//!
//! let state = Arc::new(FeedState::new(1024));
//! let app = Router::new()
//!     .route("/feed", get(feed_handler))
//!     .with_state(state);
//! ```
28
29use axum::{
30    extract::State,
31    response::sse::{Event, KeepAlive, Sse},
32};
33use futures::stream::Stream;
34use serde::{Deserialize, Serialize};
35use std::{
36    convert::Infallible,
37    pin::Pin,
38    sync::{
39        atomic::{AtomicU64, Ordering},
40        Arc,
41    },
42    task::{Context, Poll},
43    time::Duration,
44};
45use tokio::sync::broadcast::{self, Receiver, Sender};
46use tokio::time::{interval, Interval};
47use tracing::{debug, info, instrument, warn};
48
49// ============================================================================
50// Feed Event Types
51// ============================================================================
52
53/// Types of events that can be sent through the SSE feed
54#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
55#[serde(tag = "type", content = "data")]
56pub enum FeedEvent {
57    /// A new capture has been received and is being processed
58    #[serde(rename = "capture_received")]
59    CaptureReceived(CaptureReceivedData),
60
61    /// Capture processing has completed successfully
62    #[serde(rename = "processing_complete")]
63    ProcessingComplete(ProcessingCompleteData),
64
65    /// An error occurred during processing
66    #[serde(rename = "error")]
67    Error(ErrorData),
68
69    /// Keep-alive heartbeat to maintain connection
70    #[serde(rename = "heartbeat")]
71    Heartbeat(HeartbeatData),
72}
73
74/// Data for capture_received event
75#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
76pub struct CaptureReceivedData {
77    /// Unique capture ID
78    pub capture_id: String,
79    /// URL of the captured page
80    pub url: String,
81    /// Timestamp when capture was received (Unix ms)
82    pub timestamp: u64,
83    /// Capture type (screenshot, pdf, html, etc.)
84    pub capture_type: String,
85}
86
87/// Data for processing_complete event
88#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
89pub struct ProcessingCompleteData {
90    /// Capture ID that was processed
91    pub capture_id: String,
92    /// Processing duration in milliseconds
93    pub duration_ms: u64,
94    /// Size of processed content in bytes
95    pub size_bytes: u64,
96    /// Summary of extracted content (if any)
97    #[serde(skip_serializing_if = "Option::is_none")]
98    pub summary: Option<String>,
99}
100
101/// Data for error event
102#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
103pub struct ErrorData {
104    /// Capture ID associated with the error (if any)
105    #[serde(skip_serializing_if = "Option::is_none")]
106    pub capture_id: Option<String>,
107    /// Error code
108    pub code: String,
109    /// Human-readable error message
110    pub message: String,
111    /// Whether the error is recoverable
112    pub recoverable: bool,
113}
114
115/// Data for heartbeat event
116#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
117pub struct HeartbeatData {
118    /// Server timestamp (Unix ms)
119    pub timestamp: u64,
120    /// Number of currently connected clients
121    pub connected_clients: u64,
122    /// Server uptime in seconds
123    pub uptime_seconds: u64,
124}
125
126impl FeedEvent {
127    /// Get the event type name for SSE
128    pub fn event_type(&self) -> &'static str {
129        match self {
130            FeedEvent::CaptureReceived(_) => "capture_received",
131            FeedEvent::ProcessingComplete(_) => "processing_complete",
132            FeedEvent::Error(_) => "error",
133            FeedEvent::Heartbeat(_) => "heartbeat",
134        }
135    }
136
137    /// Convert to SSE Event
138    pub fn to_sse_event(&self) -> Result<Event, serde_json::Error> {
139        let data = serde_json::to_string(self)?;
140        Ok(Event::default().event(self.event_type()).data(data))
141    }
142}
143
144// ============================================================================
145// Feed State (Shared Application State)
146// ============================================================================
147
148/// Shared state for the feed system
149///
150/// This struct manages the broadcast channel and client tracking.
151/// It should be wrapped in `Arc` and shared across handlers.
152pub struct FeedState {
153    /// Broadcast sender for publishing events
154    sender: Sender<FeedEvent>,
155    /// Counter for connected clients
156    connected_clients: AtomicU64,
157    /// Server start time (for uptime calculation)
158    start_time: std::time::Instant,
159    /// Maximum channel capacity
160    capacity: usize,
161}
162
163impl FeedState {
164    /// Create a new FeedState with the specified channel capacity
165    ///
166    /// # Arguments
167    ///
168    /// * `capacity` - Maximum number of events to buffer in the channel.
169    ///   Clients that fall behind will miss events.
170    pub fn new(capacity: usize) -> Self {
171        let (sender, _) = broadcast::channel(capacity);
172        Self {
173            sender,
174            connected_clients: AtomicU64::new(0),
175            start_time: std::time::Instant::now(),
176            capacity,
177        }
178    }
179
180    /// Subscribe to the event feed
181    ///
182    /// Returns a receiver that will receive all events published after subscription.
183    pub fn subscribe(&self) -> Receiver<FeedEvent> {
184        self.sender.subscribe()
185    }
186
187    /// Publish an event to all connected clients
188    ///
189    /// Returns the number of clients that received the event.
190    /// Returns 0 if no clients are connected.
191    #[instrument(skip(self, event), fields(event_type = event.event_type()))]
192    pub fn publish(&self, event: FeedEvent) -> usize {
193        match self.sender.send(event) {
194            Ok(count) => {
195                debug!("Published event to {} clients", count);
196                count
197            }
198            Err(_) => {
199                debug!("No clients connected, event dropped");
200                0
201            }
202        }
203    }
204
205    /// Publish a capture_received event
206    pub fn publish_capture_received(
207        &self,
208        capture_id: impl Into<String>,
209        url: impl Into<String>,
210        capture_type: impl Into<String>,
211    ) -> usize {
212        self.publish(FeedEvent::CaptureReceived(CaptureReceivedData {
213            capture_id: capture_id.into(),
214            url: url.into(),
215            timestamp: current_timestamp_ms(),
216            capture_type: capture_type.into(),
217        }))
218    }
219
220    /// Publish a processing_complete event
221    pub fn publish_processing_complete(
222        &self,
223        capture_id: impl Into<String>,
224        duration_ms: u64,
225        size_bytes: u64,
226        summary: Option<String>,
227    ) -> usize {
228        self.publish(FeedEvent::ProcessingComplete(ProcessingCompleteData {
229            capture_id: capture_id.into(),
230            duration_ms,
231            size_bytes,
232            summary,
233        }))
234    }
235
236    /// Publish an error event
237    pub fn publish_error(
238        &self,
239        capture_id: Option<String>,
240        code: impl Into<String>,
241        message: impl Into<String>,
242        recoverable: bool,
243    ) -> usize {
244        self.publish(FeedEvent::Error(ErrorData {
245            capture_id,
246            code: code.into(),
247            message: message.into(),
248            recoverable,
249        }))
250    }
251
252    /// Get the number of connected clients
253    pub fn connected_clients(&self) -> u64 {
254        self.connected_clients.load(Ordering::Relaxed)
255    }
256
257    /// Get the server uptime in seconds
258    pub fn uptime_seconds(&self) -> u64 {
259        self.start_time.elapsed().as_secs()
260    }
261
262    /// Get the channel capacity
263    pub fn capacity(&self) -> usize {
264        self.capacity
265    }
266
267    /// Increment connected client count
268    fn client_connected(&self) -> u64 {
269        let count = self.connected_clients.fetch_add(1, Ordering::Relaxed) + 1;
270        info!("Client connected, total: {}", count);
271        count
272    }
273
274    /// Decrement connected client count
275    fn client_disconnected(&self) -> u64 {
276        let count = self.connected_clients.fetch_sub(1, Ordering::Relaxed) - 1;
277        info!("Client disconnected, total: {}", count);
278        count
279    }
280}
281
282impl Default for FeedState {
283    fn default() -> Self {
284        Self::new(1024)
285    }
286}
287
288// ============================================================================
289// Client Connection Tracking
290// ============================================================================
291
292/// Guard that tracks client connection lifetime
293///
294/// Automatically decrements the connected client count when dropped.
295struct ClientGuard {
296    state: Arc<FeedState>,
297}
298
299impl ClientGuard {
300    fn new(state: Arc<FeedState>) -> Self {
301        state.client_connected();
302        Self { state }
303    }
304}
305
306impl Drop for ClientGuard {
307    fn drop(&mut self) {
308        self.state.client_disconnected();
309    }
310}
311
312// ============================================================================
313// SSE Stream Implementation
314// ============================================================================
315
316/// SSE stream that combines event receiver with heartbeat
317///
318/// This stream yields events from the broadcast channel and also
319/// generates heartbeat events at regular intervals.
320pub struct FeedStream {
321    /// Event receiver
322    receiver: Receiver<FeedEvent>,
323    /// Heartbeat interval timer
324    heartbeat_interval: Interval,
325    /// Reference to state for heartbeat data
326    state: Arc<FeedState>,
327    /// Client guard for connection tracking
328    _guard: ClientGuard,
329    /// Stream ID for debugging
330    #[allow(dead_code)]
331    stream_id: u64,
332}
333
334impl FeedStream {
335    /// Create a new feed stream
336    ///
337    /// # Arguments
338    ///
339    /// * `state` - Shared feed state
340    /// * `heartbeat_interval_secs` - Interval between heartbeats in seconds
341    pub fn new(state: Arc<FeedState>, heartbeat_interval_secs: u64) -> Self {
342        static STREAM_COUNTER: AtomicU64 = AtomicU64::new(0);
343        let stream_id = STREAM_COUNTER.fetch_add(1, Ordering::Relaxed);
344
345        let receiver = state.subscribe();
346        let heartbeat_interval = interval(Duration::from_secs(heartbeat_interval_secs));
347        let guard = ClientGuard::new(Arc::clone(&state));
348
349        debug!("Created FeedStream {}", stream_id);
350
351        Self {
352            receiver,
353            heartbeat_interval,
354            state,
355            _guard: guard,
356            stream_id,
357        }
358    }
359
360    /// Generate a heartbeat event
361    fn generate_heartbeat(&self) -> FeedEvent {
362        FeedEvent::Heartbeat(HeartbeatData {
363            timestamp: current_timestamp_ms(),
364            connected_clients: self.state.connected_clients(),
365            uptime_seconds: self.state.uptime_seconds(),
366        })
367    }
368}
369
370impl Stream for FeedStream {
371    type Item = Result<Event, Infallible>;
372
373    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
374        // First, check for heartbeat
375        if self.heartbeat_interval.poll_tick(cx).is_ready() {
376            let heartbeat = self.generate_heartbeat();
377            match heartbeat.to_sse_event() {
378                Ok(event) => return Poll::Ready(Some(Ok(event))),
379                Err(e) => {
380                    warn!("Failed to serialize heartbeat: {}", e);
381                    // Continue to check for other events
382                }
383            }
384        }
385
386        // Then, check for broadcast events
387        match self.receiver.try_recv() {
388            Ok(feed_event) => match feed_event.to_sse_event() {
389                Ok(event) => Poll::Ready(Some(Ok(event))),
390                Err(e) => {
391                    warn!("Failed to serialize event: {}", e);
392                    // Wake up to try again
393                    cx.waker().wake_by_ref();
394                    Poll::Pending
395                }
396            },
397            Err(broadcast::error::TryRecvError::Empty) => {
398                // No events available, register waker and wait
399                cx.waker().wake_by_ref();
400                Poll::Pending
401            }
402            Err(broadcast::error::TryRecvError::Lagged(count)) => {
403                // Client fell behind, log and continue
404                warn!("Client lagged behind by {} events", count);
405                cx.waker().wake_by_ref();
406                Poll::Pending
407            }
408            Err(broadcast::error::TryRecvError::Closed) => {
409                // Channel closed, end stream
410                debug!("Broadcast channel closed, ending stream");
411                Poll::Ready(None)
412            }
413        }
414    }
415}
416
417// ============================================================================
418// Axum Handler
419// ============================================================================
420
421/// SSE feed handler for the `/feed` endpoint
422///
423/// This handler creates an SSE stream that:
424/// - Sends all published events (captures, processing status, errors)
425/// - Sends heartbeat events every 30 seconds
426/// - Handles client disconnection gracefully
427///
428/// # Example Response
429///
430/// ```text
431/// event: capture_received
432/// data: {"type":"capture_received","data":{"capture_id":"abc123","url":"https://example.com","timestamp":1704067200000,"capture_type":"screenshot"}}
433///
434/// event: heartbeat
435/// data: {"type":"heartbeat","data":{"timestamp":1704067230000,"connected_clients":5,"uptime_seconds":3600}}
436/// ```
437#[instrument(skip(state))]
438pub async fn feed_handler(
439    State(state): State<Arc<FeedState>>,
440) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
441    info!("New SSE client connected to /feed");
442
443    let stream = FeedStream::new(state, 30); // 30 second heartbeat
444
445    Sse::new(stream).keep_alive(
446        KeepAlive::new()
447            .interval(Duration::from_secs(15))
448            .text("keep-alive"),
449    )
450}
451
452/// Alternative handler with configurable heartbeat interval
453#[instrument(skip(state))]
454pub async fn feed_handler_with_interval(
455    State(state): State<Arc<FeedState>>,
456    heartbeat_secs: u64,
457) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
458    info!(
459        "New SSE client connected to /feed (heartbeat: {}s)",
460        heartbeat_secs
461    );
462
463    let stream = FeedStream::new(state, heartbeat_secs);
464
465    Sse::new(stream).keep_alive(
466        KeepAlive::new()
467            .interval(Duration::from_secs(heartbeat_secs / 2))
468            .text("keep-alive"),
469    )
470}
471
472// ============================================================================
473// Utility Functions
474// ============================================================================
475
/// Get the current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch, and
/// saturates at `u64::MAX` instead of silently truncating should the
/// millisecond count ever exceed 64 bits.
fn current_timestamp_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();

    // `as_millis` returns u128; convert explicitly rather than with `as`.
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
484
485// ============================================================================
486// Router Builder
487// ============================================================================
488
489/// Build a router with the feed endpoint
490///
491/// # Example
492///
493/// ```rust,no_run
494/// use reasonkit_web::handlers::feed::{FeedState, build_feed_router};
495/// use std::sync::Arc;
496///
497/// let state = Arc::new(FeedState::new(1024));
498/// let router = build_feed_router(state);
499/// ```
500pub fn build_feed_router(state: Arc<FeedState>) -> axum::Router {
501    use axum::routing::get;
502
503    axum::Router::new()
504        .route("/feed", get(feed_handler))
505        .with_state(state)
506}
507
508// ============================================================================
509// Tests
510// ============================================================================
511
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::sleep;

    /// A serialized capture event carries its tag and payload values.
    #[test]
    fn test_feed_event_serialization() {
        let payload = CaptureReceivedData {
            capture_id: "test-123".to_string(),
            url: "https://example.com".to_string(),
            timestamp: 1704067200000,
            capture_type: "screenshot".to_string(),
        };

        let json = serde_json::to_string(&FeedEvent::CaptureReceived(payload)).unwrap();

        for needle in &["capture_received", "test-123", "https://example.com"] {
            assert!(json.contains(needle));
        }
    }

    /// The tagged JSON representation parses back into the right variant.
    #[test]
    fn test_feed_event_deserialization() {
        let json = r#"{"type":"capture_received","data":{"capture_id":"abc","url":"https://test.com","timestamp":1000,"capture_type":"pdf"}}"#;
        let event: FeedEvent = serde_json::from_str(json).unwrap();

        let data = match event {
            FeedEvent::CaptureReceived(data) => data,
            _ => panic!("Expected CaptureReceived"),
        };
        assert_eq!(data.capture_id, "abc");
        assert_eq!(data.url, "https://test.com");
        assert_eq!(data.capture_type, "pdf");
    }

    /// Every variant reports its SSE event name.
    #[test]
    fn test_feed_event_type() {
        let cases = [
            (
                FeedEvent::CaptureReceived(CaptureReceivedData {
                    capture_id: String::new(),
                    url: String::new(),
                    timestamp: 0,
                    capture_type: String::new(),
                }),
                "capture_received",
            ),
            (
                FeedEvent::ProcessingComplete(ProcessingCompleteData {
                    capture_id: String::new(),
                    duration_ms: 0,
                    size_bytes: 0,
                    summary: None,
                }),
                "processing_complete",
            ),
            (
                FeedEvent::Error(ErrorData {
                    capture_id: None,
                    code: String::new(),
                    message: String::new(),
                    recoverable: false,
                }),
                "error",
            ),
            (
                FeedEvent::Heartbeat(HeartbeatData {
                    timestamp: 0,
                    connected_clients: 0,
                    uptime_seconds: 0,
                }),
                "heartbeat",
            ),
        ];

        for (event, expected) in &cases {
            assert_eq!(event.event_type(), *expected);
        }
    }

    /// A fresh state reports its capacity and no clients.
    #[test]
    fn test_feed_state_new() {
        let state = FeedState::new(100);
        assert_eq!(state.capacity(), 100);
        assert_eq!(state.connected_clients(), 0);
    }

    /// Publishing without subscribers reaches nobody.
    #[tokio::test]
    async fn test_feed_state_publish_no_subscribers() {
        let state = FeedState::new(10);
        let delivered = state.publish_capture_received("test", "https://test.com", "screenshot");
        assert_eq!(delivered, 0); // No subscribers
    }

    /// A single subscriber receives the published capture event.
    #[tokio::test]
    async fn test_feed_state_publish_with_subscriber() {
        let state = Arc::new(FeedState::new(10));
        let mut receiver = state.subscribe();

        let delivered = state.publish_capture_received("test", "https://test.com", "screenshot");
        assert_eq!(delivered, 1);

        let data = match receiver.recv().await.unwrap() {
            FeedEvent::CaptureReceived(data) => data,
            _ => panic!("Expected CaptureReceived"),
        };
        assert_eq!(data.capture_id, "test");
        assert_eq!(data.url, "https://test.com");
    }

    /// Guards raise the client count on creation and lower it on drop.
    #[tokio::test]
    async fn test_feed_state_client_tracking() {
        let state = Arc::new(FeedState::new(10));
        assert_eq!(state.connected_clients(), 0);

        let outer = ClientGuard::new(Arc::clone(&state));
        assert_eq!(state.connected_clients(), 1);

        let inner = ClientGuard::new(Arc::clone(&state));
        assert_eq!(state.connected_clients(), 2);

        drop(inner);
        assert_eq!(state.connected_clients(), 1);

        drop(outer);
        assert_eq!(state.connected_clients(), 0);
    }

    /// Uptime never goes backwards.
    #[tokio::test]
    async fn test_feed_state_uptime() {
        let state = FeedState::new(10);
        let before = state.uptime_seconds();

        sleep(Duration::from_millis(100)).await;

        let after = state.uptime_seconds();
        // Uptime should be the same or slightly higher (within 1 second)
        assert!(after >= before);
    }

    /// Error events reach the one subscriber.
    #[test]
    fn test_error_event() {
        let state = FeedState::new(10);
        let _receiver = state.subscribe();

        let delivered = state.publish_error(
            Some("capture-123".to_string()),
            "E_TIMEOUT",
            "Operation timed out",
            true,
        );
        assert_eq!(delivered, 1);
    }

    /// Processing-complete events reach the one subscriber.
    #[test]
    fn test_processing_complete_event() {
        let state = FeedState::new(10);
        let _receiver = state.subscribe();

        let delivered = state.publish_processing_complete(
            "capture-456",
            150,
            1024,
            Some("Page title extracted".to_string()),
        );
        assert_eq!(delivered, 1);
    }

    /// Conversion to an SSE event succeeds and mentions the event name.
    #[test]
    fn test_to_sse_event() {
        let heartbeat = FeedEvent::Heartbeat(HeartbeatData {
            timestamp: 1704067200000,
            connected_clients: 5,
            uptime_seconds: 3600,
        });

        let sse_event = heartbeat.to_sse_event().unwrap();
        // SSE Event is opaque, but we can verify it was created
        let rendered = format!("{:?}", sse_event);
        assert!(rendered.contains("heartbeat"));
    }

    /// Constructing a stream counts as a connected client.
    #[tokio::test]
    async fn test_feed_stream_creation() {
        let state = Arc::new(FeedState::new(10));

        // Create stream - should increment client count
        let _stream = FeedStream::new(Arc::clone(&state), 30);
        assert_eq!(state.connected_clients(), 1);
    }

    /// Every subscriber gets its own copy of a published event.
    #[tokio::test]
    async fn test_multiple_subscribers() {
        let state = Arc::new(FeedState::new(10));

        let mut rx1 = state.subscribe();
        let mut rx2 = state.subscribe();
        let mut rx3 = state.subscribe();

        let delivered =
            state.publish_capture_received("multi-test", "https://example.com", "html");
        assert_eq!(delivered, 3);

        // All receivers should get the event
        assert!(rx1.recv().await.is_ok());
        assert!(rx2.recv().await.is_ok());
        assert!(rx3.recv().await.is_ok());
    }

    /// The timestamp helper returns a present-day value.
    #[test]
    fn test_current_timestamp_ms() {
        let now_ms = current_timestamp_ms();
        // Should be a reasonable timestamp (after year 2024)
        assert!(now_ms > 1704067200000);
    }
}
727}