use std::cell::RefCell;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
/// A single progress update.
///
/// Built with [`ProgressEvent::new`] plus the optional `with_total` / `with_message`
/// builders. Serializable with serde; `total` and `message` are omitted from the
/// serialized form when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProgressEvent {
/// Amount of work completed so far.
/// NOTE(review): appears to be an absolute amount in the same units as `total`
/// (e.g. `1.0` of `2.0`) rather than a 0..1 fraction — confirm with callers.
pub progress: f64,
/// Total amount of work, if known. Skipped during serialization when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub total: Option<f64>,
/// Optional free-form message attached to this update. Skipped during
/// serialization when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub message: Option<String>,
}
impl ProgressEvent {
    /// Creates an event reporting `progress`, with no `total` and no `message`.
    #[must_use]
    pub fn new(progress: f64) -> Self {
        Self {
            progress,
            total: None,
            message: None,
        }
    }

    /// Returns this event with `total` set to `Some(total)`.
    ///
    /// Consumes and returns `self` so it can be chained onto [`ProgressEvent::new`].
    /// Marked `#[must_use]` because discarding the result silently loses the update.
    #[must_use]
    pub fn with_total(mut self, total: f64) -> Self {
        self.total = Some(total);
        self
    }

    /// Returns this event with `message` set; accepts anything convertible into a `String`.
    ///
    /// Marked `#[must_use]` because discarding the result silently loses the update.
    #[must_use]
    pub fn with_message(mut self, message: impl Into<String>) -> Self {
        self.message = Some(message.into());
        self
    }
}
/// A destination that receives [`ProgressEvent`]s.
///
/// Implementors must be thread-safe (`Sync + Send`) so that one sink can be shared as an
/// `Arc<dyn ProgressSink>`.
pub trait ProgressSink: Sync + Send {
    /// Delivers a single progress update to this sink.
    fn emit(&self, event: ProgressEvent);
}
/// A [`ProgressSink`] that discards every event.
///
/// It carries no state, so it is `Copy`. This also lets it be cloned like the other
/// sinks in this module.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoopProgressSink;

impl ProgressSink for NoopProgressSink {
    /// Ignores `_event`.
    fn emit(&self, _event: ProgressEvent) {}
}
/// A [`ProgressSink`] that forwards every event into an unbounded Tokio MPSC channel.
///
/// Create one with [`ChannelProgressSink::channel`], which also returns the receiving end.
#[derive(Debug, Clone)]
pub struct ChannelProgressSink {
    sender: mpsc::UnboundedSender<ProgressEvent>,
}

impl ChannelProgressSink {
    /// Creates a connected `(sink, receiver)` pair.
    ///
    /// Events passed to the sink's `emit` can be read back from the receiver.
    pub fn channel() -> (Self, mpsc::UnboundedReceiver<ProgressEvent>) {
        let (sender, receiver) = mpsc::unbounded_channel();
        let sink = Self { sender };
        (sink, receiver)
    }
}

impl ProgressSink for ChannelProgressSink {
    /// Sends `event` down the channel.
    ///
    /// Delivery is best-effort. If the send fails (e.g. the receiver was dropped), the
    /// event is discarded and no error reaches the caller.
    fn emit(&self, event: ProgressEvent) {
        // Deliberately ignore a failed send; progress reporting must not fail the caller.
        self.sender.send(event).ok();
    }
}
// Per-thread slot holding the sink that the free function `emit` forwards events to.
// The initial value is `None`, which makes `emit` a no-op. Each thread has its own slot,
// so `set_sink` / `clear_sink` only affect the thread that calls them.
thread_local! {
static ACTIVE_SINK: RefCell<Option<Arc<dyn ProgressSink>>> = const { RefCell::new(None) };
}
/// Installs `sink` as the progress sink for the current thread, replacing any previous one.
///
/// The sink is stored in a thread-local slot, so only [`emit`] calls made on the calling
/// OS thread are routed to it.
///
/// The previously installed sink (if any) is dropped only after the slot's `RefCell`
/// borrow has ended. A sink whose `Drop` implementation touches the slot (for example
/// by calling [`emit`]) therefore does not trigger a borrow panic.
pub fn set_sink(sink: Arc<dyn ProgressSink>) {
    let previous = ACTIVE_SINK.with(|slot| slot.replace(Some(sink)));
    // Run the old sink's destructor outside the RefCell borrow.
    drop(previous);
}
/// Removes the current thread's progress sink, so subsequent [`emit`] calls are no-ops.
///
/// Calling this when no sink is installed does nothing. The removed sink is dropped
/// only after the slot's `RefCell` borrow has ended. A sink whose `Drop` implementation
/// touches the slot therefore does not trigger a borrow panic.
pub fn clear_sink() {
    let previous = ACTIVE_SINK.with(|slot| slot.take());
    // Run the old sink's destructor outside the RefCell borrow.
    drop(previous);
}
/// Forwards `event` to the sink installed on the current thread via [`set_sink`].
///
/// This is a best-effort operation and never panics on its own account. The event is
/// silently dropped in three cases:
/// - no sink is installed;
/// - the thread's locals are being torn down;
/// - the slot is mutably borrowed further up the stack.
///
/// The sink handle is cloned out of the slot before it is invoked, so the `RefCell`
/// borrow has ended by the time the sink runs. A sink may therefore call [`set_sink`],
/// [`clear_sink`], or [`emit`] from inside its own `emit` without a borrow panic.
pub fn emit(event: ProgressEvent) {
    let sink = ACTIVE_SINK
        .try_with(|slot| slot.try_borrow().ok().and_then(|current| current.clone()))
        .ok()
        .flatten();
    if let Some(sink) = sink {
        sink.emit(event);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With no sink installed on this thread, `emit` must silently do nothing.
    #[test]
    fn emit_without_sink_is_noop() {
        emit(ProgressEvent::new(1.0));
    }

    /// Every event sent through a `ChannelProgressSink` arrives, in order, on its receiver.
    #[test]
    fn channel_sink_collects_events() {
        let (sink, mut rx) = ChannelProgressSink::channel();
        for (step, label) in [(0.0, "a"), (1.0, "b"), (2.0, "c")] {
            sink.emit(ProgressEvent::new(step).with_total(2.0).with_message(label));
        }
        drop(sink);

        let received: Vec<ProgressEvent> = std::iter::from_fn(|| rx.try_recv().ok()).collect();
        assert_eq!(received.len(), 3);
        assert_eq!(received[0].progress, 0.0);
        assert_eq!(received[2].message.as_deref(), Some("c"));
    }

    /// `emit` reaches the thread-local sink only while it is installed.
    #[test]
    fn thread_local_sink_routes_emit() {
        let (sink, mut rx) = ChannelProgressSink::channel();
        set_sink(Arc::new(sink));
        emit(ProgressEvent::new(7.0));
        clear_sink();
        // Sent after the sink was removed, so it must never reach the receiver.
        emit(ProgressEvent::new(99.0));

        let first_progress = rx.try_recv().expect("first event").progress;
        assert_eq!(first_progress, 7.0);
        assert!(rx.try_recv().is_err(), "no second event after clear_sink");
    }
}