use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use async_trait::async_trait;
use irontide_core::Id20;
use parking_lot::Mutex;
use thiserror::Error;
use tokio::sync::{broadcast, oneshot};
use tracing::{debug, warn};
use crate::alert::{Alert, AlertKind};
use crate::settings::Settings;
/// One notification captured by [`InMemorySink`]: the summary line and the body
/// text exactly as they were handed to `show`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct NotificationRecord {
    /// Title line of the notification.
    pub summary: String,
    /// Body text of the notification.
    pub body: String,
}
/// Failures a [`NotificationSink`] can report back to the dispatcher.
#[derive(Debug, Error)]
pub enum NotificationError {
    /// The `notify-rust` backend returned an error while showing the notification.
    #[error("notify-rust failed: {0}")]
    Backend(String),
    /// The `spawn_blocking` task running the backend call could not be joined.
    #[error("spawn_blocking join error: {0}")]
    JoinError(String),
    /// A failure deliberately injected by a test sink.
    #[error("injected test failure: {0}")]
    Test(String),
}
/// Destination for the desktop notifications produced by the dispatcher.
///
/// Implementors must be `Send + Sync + 'static` so a boxed sink can be moved
/// into the spawned dispatcher task.
#[async_trait]
pub trait NotificationSink: Send + Sync + 'static {
    /// Delivers one notification made of a `summary` line and a `body` text.
    ///
    /// # Errors
    /// Returns a [`NotificationError`] when the notification could not be delivered.
    async fn show(&self, summary: &str, body: &str) -> Result<(), NotificationError>;
}
/// Production sink that shows notifications through the `notify-rust` crate.
///
/// It is a stateless unit struct; all the work happens in its
/// [`NotificationSink::show`] implementation.
#[derive(Clone, Copy, Debug, Default)]
pub struct LibNotifySink;

impl LibNotifySink {
    /// Returns a new sink (identical to `LibNotifySink::default()`).
    #[must_use]
    pub fn new() -> Self {
        LibNotifySink
    }
}
#[async_trait]
impl NotificationSink for LibNotifySink {
    /// Shows the notification with app name `"irontide"`.
    ///
    /// The `notify-rust` `show` call is synchronous, so it runs on Tokio's
    /// blocking pool instead of on an async worker thread.
    ///
    /// # Errors
    /// - [`NotificationError::Backend`] when `notify-rust` reports a failure.
    /// - [`NotificationError::JoinError`] when the blocking task cannot be joined.
    async fn show(&self, summary: &str, body: &str) -> Result<(), NotificationError> {
        // `spawn_blocking` needs `'static` captures, so take owned copies.
        let (owned_summary, owned_body) = (summary.to_owned(), body.to_owned());
        let blocking = tokio::task::spawn_blocking(move || {
            let shown = notify_rust::Notification::new()
                .summary(&owned_summary)
                .body(&owned_body)
                .appname("irontide")
                .show();
            match shown {
                Ok(_handle) => Ok(()),
                Err(err) => Err(NotificationError::Backend(err.to_string())),
            }
        });
        match blocking.await {
            Ok(outcome) => outcome,
            Err(join_err) => Err(NotificationError::JoinError(join_err.to_string())),
        }
    }
}
#[derive(Debug, Clone, Default)]
pub struct InMemorySink {
pub records: Arc<Mutex<Vec<NotificationRecord>>>,
fail_with: Option<String>,
}
impl InMemorySink {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn with_failure(message: impl Into<String>) -> Self {
Self {
records: Arc::new(Mutex::new(Vec::new())),
fail_with: Some(message.into()),
}
}
#[must_use]
pub fn snapshot(&self) -> Vec<NotificationRecord> {
self.records.lock().clone()
}
}
#[async_trait]
impl NotificationSink for InMemorySink {
    /// Appends a [`NotificationRecord`] to `records`.
    ///
    /// # Errors
    /// Returns [`NotificationError::Test`] without recording anything when the
    /// sink was built with [`InMemorySink::with_failure`].
    async fn show(&self, summary: &str, body: &str) -> Result<(), NotificationError> {
        match &self.fail_with {
            Some(msg) => Err(NotificationError::Test(msg.clone())),
            None => {
                let record = NotificationRecord {
                    summary: summary.to_owned(),
                    body: body.to_owned(),
                };
                self.records.lock().push(record);
                Ok(())
            }
        }
    }
}
/// Escapes the five HTML/XML metacharacters in `s` so that untrusted text
/// (torrent names, engine error messages) cannot inject markup into a
/// notification body.
///
/// Mapping: `<` → `&lt;`, `>` → `&gt;`, `&` → `&amp;`, `"` → `&quot;`,
/// `'` → `&#39;`. Every other character, including non-ASCII, is copied
/// unchanged.
///
/// The escaping is a single pass over the input. The `&` of an entity this
/// function emits is never re-escaped. An `&` already present in the input is
/// always escaped, so `"&amp;"` becomes `"&amp;amp;"`.
#[must_use]
pub fn sanitize_notification_text(s: &str) -> String {
    // Escaping only grows the string, so the input length is a lower bound.
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '&' => out.push_str("&amp;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&#39;"),
            _ => out.push(c),
        }
    }
    out
}
/// Decides whether the alert `kind` should produce a desktop notification under
/// `settings`, and if so builds its `(summary, body)` pair.
///
/// Only two alerts can produce output:
/// - `TorrentFinished`, gated by `notify_on_complete`.
/// - `TorrentError`, gated by `notify_on_error`.
///
/// Every other alert, and any gated-off one, yields `None`. The torrent name is
/// looked up in `name_cache`; on a miss it falls back to a short hex prefix of the
/// info-hash. Names and error messages pass through
/// [`sanitize_notification_text`] before being embedded in the body.
fn dispatch_one(
    settings: &Settings,
    kind: &AlertKind,
    name_cache: &HashMap<Id20, String>,
) -> Option<(&'static str, String)> {
    // Cached name when known, hex-prefix fallback otherwise; always sanitised.
    let display_name = |hash: &Id20| {
        let raw = name_cache
            .get(hash)
            .cloned()
            .unwrap_or_else(|| info_hash_short_hex(*hash));
        sanitize_notification_text(&raw)
    };

    let body = match kind {
        AlertKind::TorrentFinished { info_hash } if settings.notify_on_complete => {
            let name = display_name(info_hash);
            format!("{name} download complete")
        }
        AlertKind::TorrentError { info_hash, message } if settings.notify_on_error => {
            let name = display_name(info_hash);
            let message = sanitize_notification_text(message);
            format!("{name}: {message}")
        }
        _ => return None,
    };
    Some(("IronTide", body))
}
/// Returns the first eight characters of `hash.to_hex()`. It serves as the
/// display name for torrents whose real name is not cached.
fn info_hash_short_hex(hash: Id20) -> String {
    const PREFIX_LEN: usize = 8;
    hash.to_hex().chars().take(PREFIX_LEN).collect::<String>()
}
/// Everything [`spawn_notification_dispatcher`] needs to run its task.
pub struct DispatcherOptions {
    /// Destination for the notifications the dispatcher produces.
    pub sink: Box<dyn NotificationSink>,
    /// Live settings; the notify toggles are re-read for each candidate alert.
    pub settings_rx: tokio::sync::watch::Receiver<Settings>,
    /// Alert stream; the dispatcher exits when this channel closes.
    pub alerts_rx: broadcast::Receiver<Alert>,
    /// The dispatcher exits once this resolves (a value sent or the sender dropped).
    pub shutdown_rx: oneshot::Receiver<()>,
}
/// Spawns the background task that turns torrent alerts into desktop notifications.
///
/// The task keeps a private `info_hash -> name` cache:
/// - `TorrentAdded` inserts an entry and `TorrentRemoved` evicts it; neither
///   produces a notification.
/// - Every other alert is evaluated by [`dispatch_one`] against the *current*
///   value of `settings_rx`, so toggles take effect live.
/// - Any resulting notification is forwarded to `sink`.
///
/// The task stops in three cases:
/// - `shutdown_rx` resolves, whether a value was sent or the sender was dropped.
/// - The alert broadcast channel closes.
///
/// A lagged alert receiver is logged and the task keeps running. Sink failures
/// never stop the task either: the first is logged at `warn`, later ones at `debug`.
///
/// # Panics
/// Panics if called outside a Tokio runtime, because it uses `tokio::spawn`.
#[must_use]
pub fn spawn_notification_dispatcher(opts: DispatcherOptions) -> tokio::task::JoinHandle<()> {
    let DispatcherOptions {
        sink,
        settings_rx,
        mut alerts_rx,
        mut shutdown_rx,
    } = opts;
    tokio::spawn(async move {
        let mut name_cache: HashMap<Id20, String> = HashMap::new();
        // Set once the first sink failure has been reported at `warn` level.
        let sink_failure_logged = AtomicBool::new(false);
        loop {
            tokio::select! {
                _ = &mut shutdown_rx => {
                    debug!("notification dispatcher: shutdown signal received");
                    break;
                }
                event = alerts_rx.recv() => {
                    let alert = match event {
                        Ok(alert) => alert,
                        Err(broadcast::error::RecvError::Lagged(n)) => {
                            warn!(lagged = n, "notification dispatcher: alert stream lagged");
                            continue;
                        }
                        Err(broadcast::error::RecvError::Closed) => {
                            debug!("notification dispatcher: alert stream closed");
                            break;
                        }
                    };
                    match &alert.kind {
                        AlertKind::TorrentAdded { info_hash, name } => {
                            name_cache.insert(*info_hash, name.clone());
                            continue;
                        }
                        AlertKind::TorrentRemoved { info_hash } => {
                            name_cache.remove(info_hash);
                            continue;
                        }
                        _ => {}
                    }
                    // Evaluate against the live settings without cloning them.
                    // `dispatch_one` returns owned data, and the `watch::Ref` guard
                    // is a temporary of this statement. The guard is therefore
                    // released before the `.await` below; the tokio docs warn that
                    // holding it longer blocks the settings writer.
                    let dispatched =
                        dispatch_one(&settings_rx.borrow(), &alert.kind, &name_cache);
                    let Some((summary, body)) = dispatched else {
                        continue;
                    };
                    if let Err(e) = sink.show(summary, &body).await {
                        if sink_failure_logged.swap(true, Ordering::Relaxed) {
                            // Already warned once; keep later failures visible for debugging.
                            debug!(error = %e, "notification dispatcher: sink failed again");
                        } else {
                            warn!(
                                error = %e,
                                "notification dispatcher: sink failed; degrading silently for the rest of the session"
                            );
                        }
                    }
                }
            }
        }
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alert::Alert;
    use std::time::Duration;

    /// Deterministic info-hash whose 20 bytes are all `byte`.
    fn fake_hash(byte: u8) -> Id20 {
        Id20([byte; 20])
    }

    #[test]
    fn sanitizer_escapes_all_five_html_metacharacters() {
        let input = r#"<b>"hi" & 'bye'</b>"#;
        let out = sanitize_notification_text(input);
        assert_eq!(
            out,
            "&lt;b&gt;&quot;hi&quot; &amp; &#39;bye&#39;&lt;/b&gt;"
        );
    }

    /// An input `&` is escaped exactly once. The `&` inside entities the
    /// sanitizer itself emits (e.g. `&lt;`) is never escaped a second time.
    #[test]
    fn sanitizer_does_not_double_escape_existing_entities() {
        assert_eq!(sanitize_notification_text("AT&T"), "AT&amp;T");
        assert_eq!(
            sanitize_notification_text("<AT&T>"),
            "&lt;AT&amp;T&gt;"
        );
    }

    #[test]
    fn sanitizer_passes_through_non_ascii_unchanged() {
        let input = "Façade — 管理 — 🦀";
        assert_eq!(sanitize_notification_text(input), input);
    }

    #[test]
    fn dispatch_one_skips_when_notify_on_complete_false() {
        let s = Settings {
            notify_on_complete: false,
            ..Default::default()
        };
        let hash = fake_hash(0xAA);
        let cache = HashMap::from([(hash, "test-torrent".to_string())]);
        let result = dispatch_one(
            &s,
            &AlertKind::TorrentFinished { info_hash: hash },
            &cache,
        );
        assert!(result.is_none(), "gate false must skip dispatch");
    }

    #[test]
    fn dispatch_one_emits_when_notify_on_complete_true() {
        let s = Settings {
            notify_on_complete: true,
            ..Default::default()
        };
        let hash = fake_hash(0xBB);
        let cache = HashMap::from([(hash, "my-movie".to_string())]);
        let (summary, body) = dispatch_one(
            &s,
            &AlertKind::TorrentFinished { info_hash: hash },
            &cache,
        )
        .expect("gate true + finished must dispatch");
        assert_eq!(summary, "IronTide");
        assert_eq!(body, "my-movie download complete");
    }

    #[test]
    fn dispatch_one_emits_error_body_with_sanitised_message() {
        let s = Settings {
            notify_on_error: true,
            ..Default::default()
        };
        let hash = fake_hash(0xCC);
        let cache = HashMap::from([(hash, "<evil>".to_string())]);
        let (summary, body) = dispatch_one(
            &s,
            &AlertKind::TorrentError {
                info_hash: hash,
                message: "disk full <again>".to_string(),
            },
            &cache,
        )
        .expect("gate true + error must dispatch");
        assert_eq!(summary, "IronTide");
        assert_eq!(body, "&lt;evil&gt;: disk full &lt;again&gt;");
    }

    #[test]
    fn dispatch_one_falls_back_to_hex_prefix_when_cache_miss() {
        let s = Settings {
            notify_on_complete: true,
            ..Default::default()
        };
        let hash = fake_hash(0xDE);
        let cache = HashMap::new();
        let (_, body) =
            dispatch_one(&s, &AlertKind::TorrentFinished { info_hash: hash }, &cache).unwrap();
        assert!(
            body.starts_with("dededede"),
            "cache miss must fall back to hex prefix, got: {body}"
        );
    }

    #[tokio::test]
    async fn in_memory_sink_records_and_can_inject_failure() {
        let sink = InMemorySink::new();
        sink.show("title", "body").await.unwrap();
        let snap = sink.snapshot();
        assert_eq!(snap.len(), 1);
        assert_eq!(snap[0].summary, "title");
        let failing = InMemorySink::with_failure("boom");
        let err = failing.show("t", "b").await.unwrap_err();
        assert!(matches!(err, NotificationError::Test(_)));
    }

    #[tokio::test]
    async fn dispatcher_emits_completion_notification_with_cached_name() {
        let (alert_tx, alert_rx) = broadcast::channel::<Alert>(16);
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        let settings = Settings {
            notify_on_complete: true,
            ..Default::default()
        };
        let (_settings_tx, settings_rx) = tokio::sync::watch::channel(settings);
        let sink = InMemorySink::new();
        let join = spawn_notification_dispatcher(DispatcherOptions {
            sink: Box::new(sink.clone()),
            settings_rx,
            alerts_rx: alert_rx,
            shutdown_rx,
        });
        let hash = fake_hash(0xEE);
        alert_tx
            .send(Alert::new(AlertKind::TorrentAdded {
                info_hash: hash,
                name: "demo-torrent".to_string(),
            }))
            .unwrap();
        alert_tx
            .send(Alert::new(AlertKind::TorrentFinished { info_hash: hash }))
            .unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;
        let _ = shutdown_tx.send(());
        join.await.unwrap();
        let records = sink.snapshot();
        assert_eq!(records.len(), 1, "exactly one notification expected");
        assert_eq!(records[0].summary, "IronTide");
        assert_eq!(records[0].body, "demo-torrent download complete");
    }

    #[tokio::test]
    async fn dispatcher_handles_sink_failure_without_crashing() {
        let (alert_tx, alert_rx) = broadcast::channel::<Alert>(16);
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        let settings = Settings {
            notify_on_complete: true,
            ..Default::default()
        };
        let (_settings_tx, settings_rx) = tokio::sync::watch::channel(settings);
        let sink = InMemorySink::with_failure("no dbus session");
        let join = spawn_notification_dispatcher(DispatcherOptions {
            sink: Box::new(sink),
            settings_rx,
            alerts_rx: alert_rx,
            shutdown_rx,
        });
        let hash = fake_hash(0xFF);
        alert_tx
            .send(Alert::new(AlertKind::TorrentAdded {
                info_hash: hash,
                name: "fail-test".to_string(),
            }))
            .unwrap();
        alert_tx
            .send(Alert::new(AlertKind::TorrentFinished { info_hash: hash }))
            .unwrap();
        alert_tx
            .send(Alert::new(AlertKind::TorrentFinished { info_hash: hash }))
            .unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;
        let _ = shutdown_tx.send(());
        join.await.expect("dispatcher must not panic on sink failure");
    }

    #[tokio::test]
    async fn dispatcher_respects_live_settings_toggle() {
        let (alert_tx, alert_rx) = broadcast::channel::<Alert>(16);
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        let settings = Settings {
            notify_on_complete: false,
            ..Default::default()
        };
        let (settings_tx, settings_rx) = tokio::sync::watch::channel(settings);
        let sink = InMemorySink::new();
        let join = spawn_notification_dispatcher(DispatcherOptions {
            sink: Box::new(sink.clone()),
            settings_rx,
            alerts_rx: alert_rx,
            shutdown_rx,
        });
        let hash_a = fake_hash(0xA1);
        alert_tx
            .send(Alert::new(AlertKind::TorrentAdded {
                info_hash: hash_a,
                name: "first".to_string(),
            }))
            .unwrap();
        alert_tx
            .send(Alert::new(AlertKind::TorrentFinished { info_hash: hash_a }))
            .unwrap();
        tokio::time::sleep(Duration::from_millis(30)).await;
        settings_tx.send_modify(|s| s.notify_on_complete = true);
        let hash_b = fake_hash(0xB2);
        alert_tx
            .send(Alert::new(AlertKind::TorrentAdded {
                info_hash: hash_b,
                name: "second".to_string(),
            }))
            .unwrap();
        alert_tx
            .send(Alert::new(AlertKind::TorrentFinished { info_hash: hash_b }))
            .unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;
        let _ = shutdown_tx.send(());
        join.await.unwrap();
        let records = sink.snapshot();
        assert_eq!(
            records.len(),
            1,
            "only the second TorrentFinished must emit"
        );
        assert_eq!(records[0].body, "second download complete");
    }

    #[tokio::test]
    async fn dispatcher_evicts_name_cache_on_torrent_removed() {
        let (alert_tx, alert_rx) = broadcast::channel::<Alert>(16);
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        let settings = Settings {
            notify_on_complete: true,
            ..Default::default()
        };
        let (_settings_tx, settings_rx) = tokio::sync::watch::channel(settings);
        let sink = InMemorySink::new();
        let join = spawn_notification_dispatcher(DispatcherOptions {
            sink: Box::new(sink.clone()),
            settings_rx,
            alerts_rx: alert_rx,
            shutdown_rx,
        });
        let hash = fake_hash(0xCA);
        alert_tx
            .send(Alert::new(AlertKind::TorrentAdded {
                info_hash: hash,
                name: "cache-test".to_string(),
            }))
            .unwrap();
        alert_tx
            .send(Alert::new(AlertKind::TorrentRemoved { info_hash: hash }))
            .unwrap();
        alert_tx
            .send(Alert::new(AlertKind::TorrentFinished { info_hash: hash }))
            .unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;
        let _ = shutdown_tx.send(());
        join.await.unwrap();
        let records = sink.snapshot();
        assert_eq!(records.len(), 1);
        assert!(
            !records[0].body.contains("cache-test"),
            "after eviction the cached name must NOT appear; got {}",
            records[0].body
        );
        assert!(
            records[0].body.starts_with("cacacaca"),
            "expected hex-prefix fallback; got {}",
            records[0].body
        );
    }

    #[tokio::test]
    async fn dispatcher_exits_cleanly_when_alert_stream_closes() {
        let (alert_tx, alert_rx) = broadcast::channel::<Alert>(4);
        // Keep the shutdown sender alive: dropping it would also stop the dispatcher.
        let (_shutdown_tx, shutdown_rx) = oneshot::channel();
        let (_settings_tx, settings_rx) = tokio::sync::watch::channel(Settings::default());
        let join = spawn_notification_dispatcher(DispatcherOptions {
            sink: Box::new(InMemorySink::new()),
            settings_rx,
            alerts_rx: alert_rx,
            shutdown_rx,
        });
        drop(alert_tx);
        tokio::time::timeout(Duration::from_secs(1), join)
            .await
            .expect("dispatcher must exit after broadcast Sender drops")
            .unwrap();
    }
}