use std::sync::Arc;
use tokio::sync::broadcast;
use crate::engine::{ProgressSink, StepProgress};
/// SSE `event:` field value attached to every frame produced by [`progress_to_sse`].
pub const SSE_EVENT_NAME: &str = "step_progress";
/// Fan-out hub that broadcasts [`StepProgress`] updates to every currently attached subscriber.
///
/// Backed by a tokio [`broadcast`] channel. The hub keeps only the sending half; receivers are
/// created on demand with [`ProgressHub::subscribe`], and producers obtain a
/// [`ProgressSink`] with [`ProgressHub::sink`].
pub struct ProgressHub {
    /// Sending half of the broadcast channel; every sink handed out holds its own clone.
    tx: broadcast::Sender<StepProgress>,
}

impl ProgressHub {
    /// Creates a hub whose broadcast buffer holds up to `capacity` updates.
    ///
    /// A `capacity` of zero is raised to one, because tokio's `broadcast::channel` panics when
    /// asked for a zero-capacity channel.
    #[must_use]
    pub fn new(capacity: usize) -> Self {
        let buffer_len = std::cmp::max(capacity, 1);
        let (tx, initial_rx) = broadcast::channel(buffer_len);
        // The receiver created alongside the sender is not needed; subscribers come from
        // `subscribe` instead.
        drop(initial_rx);
        Self { tx }
    }

    /// Returns a progress sink that forwards every update it is given into this hub.
    ///
    /// Updates emitted while nobody is subscribed are discarded. The returned value is always
    /// `Some`.
    #[must_use]
    pub fn sink(&self) -> ProgressSink {
        let sender = self.tx.clone();
        let forward = move |update: StepProgress| {
            // `send` only errors when there are no receivers; dropping the update is intended.
            let _ = sender.send(update);
        };
        Some(Arc::new(forward))
    }

    /// Registers a new subscriber.
    ///
    /// The subscriber observes updates sent after this call, subject to the channel's buffer
    /// capacity.
    #[must_use]
    pub fn subscribe(&self) -> broadcast::Receiver<StepProgress> {
        broadcast::Sender::subscribe(&self.tx)
    }

    /// Returns how many receivers are currently attached to the hub's channel.
    #[must_use]
    pub fn subscriber_count(&self) -> usize {
        broadcast::Sender::receiver_count(&self.tx)
    }
}

impl std::fmt::Debug for ProgressHub {
    /// Shows only the live subscriber count; the channel itself has no useful debug view.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let subscribers = self.subscriber_count();
        f.debug_struct("ProgressHub")
            .field("subscribers", &subscribers)
            .finish()
    }
}
/// Renders a [`StepProgress`] update as one complete Server-Sent Events frame.
///
/// The frame carries [`SSE_EVENT_NAME`] as its event name, the update's `step_id` as its event
/// id, and the whole update serialized to JSON as its payload. If serialization fails, the
/// payload degrades to an empty JSON object (`{}`) instead of panicking.
#[must_use]
pub fn progress_to_sse(progress: &StepProgress) -> String {
    let payload = serde_json::to_string(progress).unwrap_or_else(|_| String::from("{}"));
    let event_id: &str = &progress.step_id;
    sse_frame(Some(SSE_EVENT_NAME), Some(event_id), &payload)
}
/// Formats a single Server-Sent Events frame.
///
/// The frame is built as follows:
/// - an `event:` line when `event` is `Some`;
/// - an `id:` line when `id` is `Some`;
/// - one `data:` line per line of `data`;
/// - the blank line that terminates the frame.
///
/// The SSE wire format treats `\r\n`, `\n`, and a lone `\r` as line terminators. Therefore:
/// - `data` is split on all three, giving one `data:` line per logical line. A trailing
///   terminator yields a final empty `data:` line, so the client reassembles `data` verbatim
///   (modulo terminator normalization to `\n`).
/// - `event` and `id` are single-line fields. Any `\r`/`\n` inside them is removed, because an
///   embedded terminator would otherwise inject extra fields or end the frame early.
#[must_use]
pub fn sse_frame(event: Option<&str>, id: Option<&str>, data: &str) -> String {
    let mut out = String::with_capacity(data.len() + 32);
    if let Some(event) = event {
        push_sse_field(&mut out, "event", event);
    }
    if let Some(id) = id {
        push_sse_field(&mut out, "id", id);
    }
    let mut rest = data;
    while let Some(pos) = rest.find(is_sse_line_break) {
        push_sse_field(&mut out, "data", &rest[..pos]);
        // A CRLF pair is a single terminator, not two. Both bytes are ASCII, so the slice
        // boundaries always fall on char boundaries.
        let terminator_len = if rest[pos..].starts_with("\r\n") { 2 } else { 1 };
        rest = &rest[pos + terminator_len..];
    }
    // The final (possibly empty) line after the last terminator.
    push_sse_field(&mut out, "data", rest);
    out.push('\n');
    out
}

/// Returns `true` for the characters the SSE format treats as line terminators.
fn is_sse_line_break(c: char) -> bool {
    c == '\r' || c == '\n'
}

/// Appends `name: value\n` to `out`, dropping any line terminators from `value` so the field
/// stays on a single line.
fn push_sse_field(out: &mut String, name: &str, value: &str) {
    out.push_str(name);
    out.push_str(": ");
    for piece in value.split(is_sse_line_break) {
        out.push_str(piece);
    }
    out.push('\n');
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a `StepProgress` named `name` whose `step_id` is `"<name>-id"`.
    fn make_progress(name: &str, data: serde_json::Value) -> StepProgress {
        let step_id = format!("{name}-id");
        StepProgress {
            step_id,
            step_name: String::from(name),
            data,
        }
    }

    #[test]
    fn sse_frame_basic() {
        let expected = "event: step_progress\nid: s1\ndata: {\"x\":1}\n\n";
        assert_eq!(sse_frame(Some("step_progress"), Some("s1"), "{\"x\":1}"), expected);
    }

    #[test]
    fn sse_frame_multiline_data() {
        let expected = ["data: a\n", "data: b\n", "data: c\n", "\n"].concat();
        assert_eq!(sse_frame(None, None, "a\nb\nc"), expected);
    }

    #[test]
    fn sse_frame_omits_absent_fields() {
        let frame = sse_frame(None, Some("9"), "hi");
        assert_eq!(frame.as_str(), "id: 9\ndata: hi\n\n");
    }

    #[test]
    fn progress_to_sse_shape() {
        let frame = progress_to_sse(&make_progress("upload", json!({"percent": 50})));
        assert!(frame.starts_with("event: step_progress\n"));
        for needle in ["id: upload-id\n", "\"percent\":50"] {
            assert!(frame.contains(needle), "missing {needle:?} in {frame:?}");
        }
        assert!(frame.ends_with("\n\n"));
    }

    #[tokio::test]
    async fn hub_delivers_to_subscriber() {
        let hub = ProgressHub::new(16);
        let mut receiver = hub.subscribe();
        let emit = hub.sink().expect("ProgressHub::sink always returns Some");
        for (name, pct) in [("a", 10), ("b", 20)] {
            emit(make_progress(name, json!({ "percent": pct })));
        }
        let first = receiver.recv().await.expect("first update");
        let second = receiver.recv().await.expect("second update");
        assert_eq!(first.step_name, "a");
        assert_eq!(second.data["percent"], 20);
    }

    #[tokio::test]
    async fn hub_fans_out_to_multiple_subscribers() {
        let hub = ProgressHub::new(16);
        let mut receivers = [hub.subscribe(), hub.subscribe()];
        assert_eq!(hub.subscriber_count(), 2);
        let emit = hub.sink().unwrap();
        emit(make_progress("x", json!(null)));
        for rx in &mut receivers {
            let update = rx.recv().await.unwrap();
            assert_eq!(update.step_name, "x");
        }
    }

    #[tokio::test]
    async fn hub_send_without_subscribers_is_ok() {
        let hub = ProgressHub::new(4);
        let emit = hub.sink().unwrap();
        // Sending with zero receivers must not panic.
        emit(make_progress("orphan", json!(null)));
        assert_eq!(hub.subscriber_count(), 0);
    }

    #[test]
    fn hub_capacity_clamped() {
        // A zero capacity must be clamped rather than panicking inside `broadcast::channel`.
        let hub = ProgressHub::new(0);
        let _receiver = hub.subscribe();
        assert_eq!(hub.subscriber_count(), 1);
    }
}