1use axum::{
30 extract::State,
31 response::sse::{Event, KeepAlive, Sse},
32};
33use futures::stream::Stream;
34use serde::{Deserialize, Serialize};
35use std::{
36 convert::Infallible,
37 pin::Pin,
38 sync::{
39 atomic::{AtomicU64, Ordering},
40 Arc,
41 },
42 task::{Context, Poll},
43 time::Duration,
44};
45use tokio::sync::broadcast::{self, Receiver, Sender};
46use tokio::time::{interval, Interval};
47use tracing::{debug, info, instrument, warn};
48
49#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
55#[serde(tag = "type", content = "data")]
56pub enum FeedEvent {
57 #[serde(rename = "capture_received")]
59 CaptureReceived(CaptureReceivedData),
60
61 #[serde(rename = "processing_complete")]
63 ProcessingComplete(ProcessingCompleteData),
64
65 #[serde(rename = "error")]
67 Error(ErrorData),
68
69 #[serde(rename = "heartbeat")]
71 Heartbeat(HeartbeatData),
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
76pub struct CaptureReceivedData {
77 pub capture_id: String,
79 pub url: String,
81 pub timestamp: u64,
83 pub capture_type: String,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
89pub struct ProcessingCompleteData {
90 pub capture_id: String,
92 pub duration_ms: u64,
94 pub size_bytes: u64,
96 #[serde(skip_serializing_if = "Option::is_none")]
98 pub summary: Option<String>,
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
103pub struct ErrorData {
104 #[serde(skip_serializing_if = "Option::is_none")]
106 pub capture_id: Option<String>,
107 pub code: String,
109 pub message: String,
111 pub recoverable: bool,
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
117pub struct HeartbeatData {
118 pub timestamp: u64,
120 pub connected_clients: u64,
122 pub uptime_seconds: u64,
124}
125
126impl FeedEvent {
127 pub fn event_type(&self) -> &'static str {
129 match self {
130 FeedEvent::CaptureReceived(_) => "capture_received",
131 FeedEvent::ProcessingComplete(_) => "processing_complete",
132 FeedEvent::Error(_) => "error",
133 FeedEvent::Heartbeat(_) => "heartbeat",
134 }
135 }
136
137 pub fn to_sse_event(&self) -> Result<Event, serde_json::Error> {
139 let data = serde_json::to_string(self)?;
140 Ok(Event::default().event(self.event_type()).data(data))
141 }
142}
143
144pub struct FeedState {
153 sender: Sender<FeedEvent>,
155 connected_clients: AtomicU64,
157 start_time: std::time::Instant,
159 capacity: usize,
161}
162
163impl FeedState {
164 pub fn new(capacity: usize) -> Self {
171 let (sender, _) = broadcast::channel(capacity);
172 Self {
173 sender,
174 connected_clients: AtomicU64::new(0),
175 start_time: std::time::Instant::now(),
176 capacity,
177 }
178 }
179
180 pub fn subscribe(&self) -> Receiver<FeedEvent> {
184 self.sender.subscribe()
185 }
186
187 #[instrument(skip(self, event), fields(event_type = event.event_type()))]
192 pub fn publish(&self, event: FeedEvent) -> usize {
193 match self.sender.send(event) {
194 Ok(count) => {
195 debug!("Published event to {} clients", count);
196 count
197 }
198 Err(_) => {
199 debug!("No clients connected, event dropped");
200 0
201 }
202 }
203 }
204
205 pub fn publish_capture_received(
207 &self,
208 capture_id: impl Into<String>,
209 url: impl Into<String>,
210 capture_type: impl Into<String>,
211 ) -> usize {
212 self.publish(FeedEvent::CaptureReceived(CaptureReceivedData {
213 capture_id: capture_id.into(),
214 url: url.into(),
215 timestamp: current_timestamp_ms(),
216 capture_type: capture_type.into(),
217 }))
218 }
219
220 pub fn publish_processing_complete(
222 &self,
223 capture_id: impl Into<String>,
224 duration_ms: u64,
225 size_bytes: u64,
226 summary: Option<String>,
227 ) -> usize {
228 self.publish(FeedEvent::ProcessingComplete(ProcessingCompleteData {
229 capture_id: capture_id.into(),
230 duration_ms,
231 size_bytes,
232 summary,
233 }))
234 }
235
236 pub fn publish_error(
238 &self,
239 capture_id: Option<String>,
240 code: impl Into<String>,
241 message: impl Into<String>,
242 recoverable: bool,
243 ) -> usize {
244 self.publish(FeedEvent::Error(ErrorData {
245 capture_id,
246 code: code.into(),
247 message: message.into(),
248 recoverable,
249 }))
250 }
251
252 pub fn connected_clients(&self) -> u64 {
254 self.connected_clients.load(Ordering::Relaxed)
255 }
256
257 pub fn uptime_seconds(&self) -> u64 {
259 self.start_time.elapsed().as_secs()
260 }
261
262 pub fn capacity(&self) -> usize {
264 self.capacity
265 }
266
267 fn client_connected(&self) -> u64 {
269 let count = self.connected_clients.fetch_add(1, Ordering::Relaxed) + 1;
270 info!("Client connected, total: {}", count);
271 count
272 }
273
274 fn client_disconnected(&self) -> u64 {
276 let count = self.connected_clients.fetch_sub(1, Ordering::Relaxed) - 1;
277 info!("Client disconnected, total: {}", count);
278 count
279 }
280}
281
282impl Default for FeedState {
283 fn default() -> Self {
284 Self::new(1024)
285 }
286}
287
288struct ClientGuard {
296 state: Arc<FeedState>,
297}
298
299impl ClientGuard {
300 fn new(state: Arc<FeedState>) -> Self {
301 state.client_connected();
302 Self { state }
303 }
304}
305
306impl Drop for ClientGuard {
307 fn drop(&mut self) {
308 self.state.client_disconnected();
309 }
310}
311
312pub struct FeedStream {
321 receiver: Receiver<FeedEvent>,
323 heartbeat_interval: Interval,
325 state: Arc<FeedState>,
327 _guard: ClientGuard,
329 #[allow(dead_code)]
331 stream_id: u64,
332}
333
334impl FeedStream {
335 pub fn new(state: Arc<FeedState>, heartbeat_interval_secs: u64) -> Self {
342 static STREAM_COUNTER: AtomicU64 = AtomicU64::new(0);
343 let stream_id = STREAM_COUNTER.fetch_add(1, Ordering::Relaxed);
344
345 let receiver = state.subscribe();
346 let heartbeat_interval = interval(Duration::from_secs(heartbeat_interval_secs));
347 let guard = ClientGuard::new(Arc::clone(&state));
348
349 debug!("Created FeedStream {}", stream_id);
350
351 Self {
352 receiver,
353 heartbeat_interval,
354 state,
355 _guard: guard,
356 stream_id,
357 }
358 }
359
360 fn generate_heartbeat(&self) -> FeedEvent {
362 FeedEvent::Heartbeat(HeartbeatData {
363 timestamp: current_timestamp_ms(),
364 connected_clients: self.state.connected_clients(),
365 uptime_seconds: self.state.uptime_seconds(),
366 })
367 }
368}
369
370impl Stream for FeedStream {
371 type Item = Result<Event, Infallible>;
372
373 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
374 if self.heartbeat_interval.poll_tick(cx).is_ready() {
376 let heartbeat = self.generate_heartbeat();
377 match heartbeat.to_sse_event() {
378 Ok(event) => return Poll::Ready(Some(Ok(event))),
379 Err(e) => {
380 warn!("Failed to serialize heartbeat: {}", e);
381 }
383 }
384 }
385
386 match self.receiver.try_recv() {
388 Ok(feed_event) => match feed_event.to_sse_event() {
389 Ok(event) => Poll::Ready(Some(Ok(event))),
390 Err(e) => {
391 warn!("Failed to serialize event: {}", e);
392 cx.waker().wake_by_ref();
394 Poll::Pending
395 }
396 },
397 Err(broadcast::error::TryRecvError::Empty) => {
398 cx.waker().wake_by_ref();
400 Poll::Pending
401 }
402 Err(broadcast::error::TryRecvError::Lagged(count)) => {
403 warn!("Client lagged behind by {} events", count);
405 cx.waker().wake_by_ref();
406 Poll::Pending
407 }
408 Err(broadcast::error::TryRecvError::Closed) => {
409 debug!("Broadcast channel closed, ending stream");
411 Poll::Ready(None)
412 }
413 }
414 }
415}
416
417#[instrument(skip(state))]
438pub async fn feed_handler(
439 State(state): State<Arc<FeedState>>,
440) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
441 info!("New SSE client connected to /feed");
442
443 let stream = FeedStream::new(state, 30); Sse::new(stream).keep_alive(
446 KeepAlive::new()
447 .interval(Duration::from_secs(15))
448 .text("keep-alive"),
449 )
450}
451
452#[instrument(skip(state))]
454pub async fn feed_handler_with_interval(
455 State(state): State<Arc<FeedState>>,
456 heartbeat_secs: u64,
457) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
458 info!(
459 "New SSE client connected to /feed (heartbeat: {}s)",
460 heartbeat_secs
461 );
462
463 let stream = FeedStream::new(state, heartbeat_secs);
464
465 Sse::new(stream).keep_alive(
466 KeepAlive::new()
467 .interval(Duration::from_secs(heartbeat_secs / 2))
468 .text("keep-alive"),
469 )
470}
471
/// Milliseconds since the UNIX epoch according to the system clock.
///
/// A clock set before 1970 yields `0` rather than an error. A value too large
/// for `u64` saturates at `u64::MAX` instead of being silently truncated by an
/// `as` cast.
fn current_timestamp_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let millis = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis();
    u64::try_from(millis).unwrap_or(u64::MAX)
}
484
485pub fn build_feed_router(state: Arc<FeedState>) -> axum::Router {
501 use axum::routing::get;
502
503 axum::Router::new()
504 .route("/feed", get(feed_handler))
505 .with_state(state)
506}
507
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use tokio::time::sleep;

    /// Pulls the next SSE frame off `stream`, failing the test if it has ended.
    async fn next_event(stream: &mut FeedStream) -> Event {
        match stream.next().await {
            Some(Ok(event)) => event,
            Some(Err(never)) => match never {},
            None => panic!("feed stream ended unexpectedly"),
        }
    }

    #[test]
    fn test_feed_event_serialization() {
        let event = FeedEvent::CaptureReceived(CaptureReceivedData {
            capture_id: "test-123".to_string(),
            url: "https://example.com".to_string(),
            timestamp: 1704067200000,
            capture_type: "screenshot".to_string(),
        });

        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("capture_received"));
        assert!(json.contains("test-123"));
        assert!(json.contains("https://example.com"));
    }

    #[test]
    fn test_feed_event_deserialization() {
        let json = r#"{"type":"capture_received","data":{"capture_id":"abc","url":"https://test.com","timestamp":1000,"capture_type":"pdf"}}"#;
        let event: FeedEvent = serde_json::from_str(json).unwrap();

        match event {
            FeedEvent::CaptureReceived(data) => {
                assert_eq!(data.capture_id, "abc");
                assert_eq!(data.url, "https://test.com");
                assert_eq!(data.capture_type, "pdf");
            }
            _ => panic!("Expected CaptureReceived"),
        }
    }

    #[test]
    fn test_feed_event_roundtrip_all_variants() {
        let events = vec![
            FeedEvent::CaptureReceived(CaptureReceivedData {
                capture_id: "c-1".to_string(),
                url: "https://example.com".to_string(),
                timestamp: 42,
                capture_type: "html".to_string(),
            }),
            FeedEvent::ProcessingComplete(ProcessingCompleteData {
                capture_id: "p-1".to_string(),
                duration_ms: 150,
                size_bytes: 1024,
                summary: Some("title".to_string()),
            }),
            FeedEvent::ProcessingComplete(ProcessingCompleteData {
                capture_id: "p-2".to_string(),
                duration_ms: 0,
                size_bytes: 0,
                summary: None,
            }),
            FeedEvent::Error(ErrorData {
                capture_id: None,
                code: "E_TIMEOUT".to_string(),
                message: "timed out".to_string(),
                recoverable: true,
            }),
            FeedEvent::Heartbeat(HeartbeatData {
                timestamp: 1,
                connected_clients: 2,
                uptime_seconds: 3,
            }),
        ];

        for event in events {
            let json = serde_json::to_string(&event).unwrap();
            let decoded: FeedEvent = serde_json::from_str(&json).unwrap();
            assert_eq!(decoded, event);
            // The serde tag and the SSE event name must agree.
            assert!(json.contains(event.event_type()));
        }
    }

    #[test]
    fn test_none_optionals_are_omitted() {
        let error = FeedEvent::Error(ErrorData {
            capture_id: None,
            code: "E".to_string(),
            message: "m".to_string(),
            recoverable: false,
        });
        assert!(!serde_json::to_string(&error).unwrap().contains("capture_id"));

        let done = FeedEvent::ProcessingComplete(ProcessingCompleteData {
            capture_id: "x".to_string(),
            duration_ms: 1,
            size_bytes: 1,
            summary: None,
        });
        assert!(!serde_json::to_string(&done).unwrap().contains("summary"));
    }

    #[test]
    fn test_feed_event_type() {
        assert_eq!(
            FeedEvent::CaptureReceived(CaptureReceivedData {
                capture_id: String::new(),
                url: String::new(),
                timestamp: 0,
                capture_type: String::new(),
            })
            .event_type(),
            "capture_received"
        );
        assert_eq!(
            FeedEvent::ProcessingComplete(ProcessingCompleteData {
                capture_id: String::new(),
                duration_ms: 0,
                size_bytes: 0,
                summary: None,
            })
            .event_type(),
            "processing_complete"
        );
        assert_eq!(
            FeedEvent::Error(ErrorData {
                capture_id: None,
                code: String::new(),
                message: String::new(),
                recoverable: false,
            })
            .event_type(),
            "error"
        );
        assert_eq!(
            FeedEvent::Heartbeat(HeartbeatData {
                timestamp: 0,
                connected_clients: 0,
                uptime_seconds: 0,
            })
            .event_type(),
            "heartbeat"
        );
    }

    #[test]
    fn test_feed_state_new() {
        let state = FeedState::new(100);
        assert_eq!(state.capacity(), 100);
        assert_eq!(state.connected_clients(), 0);
    }

    #[test]
    fn test_feed_state_default_capacity() {
        assert_eq!(FeedState::default().capacity(), 1024);
    }

    #[tokio::test]
    async fn test_feed_state_publish_no_subscribers() {
        let state = FeedState::new(10);
        let count = state.publish_capture_received("test", "https://test.com", "screenshot");
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn test_feed_state_publish_with_subscriber() {
        let state = Arc::new(FeedState::new(10));
        let mut receiver = state.subscribe();

        let count = state.publish_capture_received("test", "https://test.com", "screenshot");
        assert_eq!(count, 1);

        let event = receiver.recv().await.unwrap();
        match event {
            FeedEvent::CaptureReceived(data) => {
                assert_eq!(data.capture_id, "test");
                assert_eq!(data.url, "https://test.com");
            }
            _ => panic!("Expected CaptureReceived"),
        }
    }

    #[tokio::test]
    async fn test_feed_state_client_tracking() {
        let state = Arc::new(FeedState::new(10));
        assert_eq!(state.connected_clients(), 0);

        {
            let _guard = ClientGuard::new(Arc::clone(&state));
            assert_eq!(state.connected_clients(), 1);

            {
                let _guard2 = ClientGuard::new(Arc::clone(&state));
                assert_eq!(state.connected_clients(), 2);
            }

            assert_eq!(state.connected_clients(), 1);
        }

        assert_eq!(state.connected_clients(), 0);
    }

    #[tokio::test]
    async fn test_feed_state_uptime() {
        let state = FeedState::new(10);
        let uptime1 = state.uptime_seconds();

        sleep(Duration::from_millis(100)).await;

        let uptime2 = state.uptime_seconds();
        assert!(uptime2 >= uptime1);
    }

    #[test]
    fn test_error_event() {
        let state = FeedState::new(10);
        let _receiver = state.subscribe();

        let count = state.publish_error(
            Some("capture-123".to_string()),
            "E_TIMEOUT",
            "Operation timed out",
            true,
        );
        assert_eq!(count, 1);
    }

    #[test]
    fn test_processing_complete_event() {
        let state = FeedState::new(10);
        let _receiver = state.subscribe();

        let count = state.publish_processing_complete(
            "capture-456",
            150,
            1024,
            Some("Page title extracted".to_string()),
        );
        assert_eq!(count, 1);
    }

    #[test]
    fn test_to_sse_event() {
        let event = FeedEvent::Heartbeat(HeartbeatData {
            timestamp: 1704067200000,
            connected_clients: 5,
            uptime_seconds: 3600,
        });

        let sse_event = event.to_sse_event().unwrap();
        assert!(format!("{:?}", sse_event).contains("heartbeat"));
    }

    #[tokio::test]
    async fn test_feed_stream_creation() {
        let state = Arc::new(FeedState::new(10));

        let _stream = FeedStream::new(Arc::clone(&state), 30);
        assert_eq!(state.connected_clients(), 1);
    }

    #[tokio::test]
    async fn test_feed_stream_drop_releases_client() {
        let state = Arc::new(FeedState::new(10));

        let stream = FeedStream::new(Arc::clone(&state), 30);
        assert_eq!(state.connected_clients(), 1);

        drop(stream);
        assert_eq!(state.connected_clients(), 0);
    }

    #[tokio::test]
    async fn test_feed_stream_yields_heartbeat_then_published_event() {
        let state = Arc::new(FeedState::new(10));
        // Long period: only the interval's immediate first tick fires in this test.
        let mut stream = FeedStream::new(Arc::clone(&state), 3600);
        state.publish_capture_received("stream-evt-1", "https://example.com", "html");

        let first = next_event(&mut stream).await;
        assert!(format!("{:?}", first).contains("heartbeat"));

        let second = next_event(&mut stream).await;
        assert!(format!("{:?}", second).contains("stream-evt-1"));
    }

    #[tokio::test]
    async fn test_feed_stream_recovers_after_lag() {
        // Capacity 2: publishing three events drops the oldest for this receiver.
        let state = Arc::new(FeedState::new(2));
        let mut stream = FeedStream::new(Arc::clone(&state), 3600);
        for id in ["msg-one", "msg-two", "msg-three"] {
            state.publish_capture_received(id, "https://example.com", "html");
        }

        let _heartbeat = next_event(&mut stream).await;
        let after_lag = next_event(&mut stream).await;
        assert!(format!("{:?}", after_lag).contains("msg-two"));
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let state = Arc::new(FeedState::new(10));

        let mut rx1 = state.subscribe();
        let mut rx2 = state.subscribe();
        let mut rx3 = state.subscribe();

        let count = state.publish_capture_received("multi-test", "https://example.com", "html");
        assert_eq!(count, 3);

        assert!(rx1.recv().await.is_ok());
        assert!(rx2.recv().await.is_ok());
        assert!(rx3.recv().await.is_ok());
    }

    #[test]
    fn test_current_timestamp_ms() {
        let ts = current_timestamp_ms();
        assert!(ts > 1704067200000);
    }
}