use std::sync::Arc;
use crate::email::{BoxedMailer, Email};
#[cfg(feature = "cache")]
use std::time::Duration;
/// Error naming which delivery channel failed, carrying a textual reason.
///
/// NOTE(review): nothing in this file constructs these variants — `notify`
/// reports per-channel failures through `ChannelOutcome::Failed(String)`
/// instead. Confirm whether code elsewhere uses this type.
#[derive(Debug, thiserror::Error)]
pub enum NotificationError {
    /// The mail channel failed; the payload is the reason.
    #[error("mail channel failed: {0}")]
    Mail(String),
    /// The database channel failed; the payload is the reason.
    #[error("database channel failed: {0}")]
    Database(String),
    /// The broadcast channel failed; the payload is the reason.
    #[error("broadcast channel failed: {0}")]
    Broadcast(String),
}
/// A recipient that notifications can be addressed to.
///
/// Both methods have default implementations that return `None`.
pub trait Notifiable {
    /// Identifier written to the `notifiable_id` column by the database channel.
    ///
    /// When this returns `None` and the database backend is configured,
    /// `notify` marks the database channel as `ChannelOutcome::Skipped`.
    fn notification_id(&self) -> Option<i64> {
        None
    }
    /// Preferred locale for this recipient.
    ///
    /// NOTE(review): not read anywhere in this file; presumably consumed by
    /// `Notification::build` implementations — confirm.
    fn notification_locale(&self) -> Option<String> {
        None
    }
}
/// The per-channel payloads a `Notification` produces for one recipient.
///
/// Each field that is `Some` asks `notify` to deliver on that channel. A field
/// left as `None` leaves that channel's outcome at `ChannelOutcome::Skipped`.
#[derive(Debug, Clone, Default)]
pub struct NotificationDispatch {
    /// Message handed to the configured mailer.
    pub email: Option<Email>,
    /// JSON stored in the `data` column of the configured table. Its string
    /// `"type"` field, if present, fills the `type` column; otherwise
    /// `"notification"` is used.
    pub database: Option<serde_json::Value>,
    /// Line emitted through `tracing::info!`.
    pub log: Option<String>,
    /// JSON passed to the configured broadcast callback.
    pub broadcast: Option<serde_json::Value>,
}
impl NotificationDispatch {
    /// Builds a dispatch with every channel field set to `None`.
    #[must_use]
    pub fn none() -> Self {
        Default::default()
    }

    /// Builds a dispatch that carries only `email`; all other channels are `None`.
    #[must_use]
    pub fn email_only(email: Email) -> Self {
        let mut dispatch = Self::none();
        dispatch.email = Some(email);
        dispatch
    }
}
/// A notification kind that can be rendered for recipients of type `N`.
pub trait Notification<N: Notifiable> {
    /// Produces the channel payloads for `recipient`.
    ///
    /// `notify` calls this once per invocation, before attempting any channel.
    fn build(&self, recipient: &N) -> NotificationDispatch;
}
/// Type-erased broadcast callback stored in `NotificationContext`.
///
/// It receives the dispatch's broadcast JSON payload by value. It returns a
/// boxed `Send` future that resolves to `Ok(())` on success or to an error
/// message on failure. `NotificationContext::with_broadcast` wraps an ordinary
/// async closure into this shape.
pub type BroadcastFn = Arc<
    dyn Fn(serde_json::Value) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), String>> + Send>>
        + Send
        + Sync,
>;
/// Delivery backends available to `notify`; every backend is optional.
///
/// Build one with `NotificationContext::new` plus the `with_*` builder methods.
/// When a dispatch targets a channel whose backend is absent, `notify` records
/// a failure for that channel. The database channel needs both a pool and a
/// table name.
#[derive(Default, Clone)]
pub struct NotificationContext {
    mailer: Option<BoxedMailer>,
    database_pool: Option<crate::sql::sqlx::PgPool>,
    database_table: Option<String>,
    broadcast: Option<BroadcastFn>,
}

// `Debug` cannot be derived because `BroadcastFn` is a `dyn Fn` trait object.
// This impl reports which backends are configured (plus the table name)
// without requiring the backends themselves to implement `Debug`.
impl std::fmt::Debug for NotificationContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NotificationContext")
            .field("mailer", &self.mailer.is_some())
            .field("database_pool", &self.database_pool.is_some())
            .field("database_table", &self.database_table)
            .field("broadcast", &self.broadcast.is_some())
            .finish()
    }
}
impl NotificationContext {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn with_mailer(mut self, mailer: BoxedMailer) -> Self {
self.mailer = Some(mailer);
self
}
#[must_use]
pub fn with_database(
mut self,
pool: crate::sql::sqlx::PgPool,
table: impl Into<String>,
) -> Self {
self.database_pool = Some(pool);
self.database_table = Some(table.into());
self
}
#[must_use]
pub fn with_broadcast<F, Fut>(mut self, callback: F) -> Self
where
F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = Result<(), String>> + Send + 'static,
{
self.broadcast = Some(Arc::new(move |v| Box::pin(callback(v))));
self
}
}
/// Per-channel outcome of a single `notify` call.
///
/// The `Default` value has every channel at `ChannelOutcome::Skipped`.
#[derive(Debug, Clone, Default)]
pub struct NotificationResult {
    /// Outcome of the email channel.
    pub mail: ChannelOutcome,
    /// Outcome of the database channel.
    pub database: ChannelOutcome,
    /// Outcome of the `tracing` log channel.
    pub log: ChannelOutcome,
    /// Outcome of the broadcast channel.
    pub broadcast: ChannelOutcome,
}
/// What happened to one delivery channel during `notify`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum ChannelOutcome {
    /// The channel was not attempted. Either the dispatch had no payload for
    /// it, or (database channel only) the recipient had no id. This is the
    /// default.
    #[default]
    Skipped,
    /// The payload was handed off successfully. The log channel always
    /// reports this when it has a line.
    Delivered,
    /// The backend returned an error or was not configured; carries the reason.
    Failed(String),
}
impl NotificationResult {
#[must_use]
pub fn delivered_count(&self) -> usize {
[&self.mail, &self.database, &self.log, &self.broadcast]
.iter()
.filter(|o| matches!(o, ChannelOutcome::Delivered))
.count()
}
#[must_use]
pub fn any_delivered(&self) -> bool {
self.delivered_count() > 0
}
#[must_use]
pub fn no_failures(&self) -> bool {
![&self.mail, &self.database, &self.log, &self.broadcast]
.iter()
.any(|o| matches!(o, ChannelOutcome::Failed(_)))
}
}
pub async fn notify<N: Notifiable, T: Notification<N>>(
recipient: &N,
notification: &T,
ctx: &NotificationContext,
) -> NotificationResult {
let dispatch = notification.build(recipient);
let mut result = NotificationResult::default();
if let Some(email) = &dispatch.email {
result.mail = match &ctx.mailer {
Some(m) => match m.send(email).await {
Ok(()) => ChannelOutcome::Delivered,
Err(e) => ChannelOutcome::Failed(e.to_string()),
},
None => ChannelOutcome::Failed("no mailer configured".into()),
};
}
if let Some(payload) = &dispatch.database {
result.database = match (&ctx.database_pool, &ctx.database_table, recipient.notification_id()) {
(Some(pool), Some(table), Some(id)) => {
let kind = payload
.get("type")
.and_then(|v| v.as_str())
.unwrap_or("notification")
.to_owned();
match insert_database_notification(pool, table, id, &kind, payload).await {
Ok(()) => ChannelOutcome::Delivered,
Err(e) => ChannelOutcome::Failed(e),
}
}
(None, _, _) | (_, None, _) => {
ChannelOutcome::Failed("database channel not configured".into())
}
(_, _, None) => ChannelOutcome::Skipped, };
}
if let Some(line) = &dispatch.log {
tracing::info!(notification = %line, "notifications");
result.log = ChannelOutcome::Delivered;
}
if let Some(payload) = &dispatch.broadcast {
result.broadcast = match &ctx.broadcast {
Some(callback) => match callback(payload.clone()).await {
Ok(()) => ChannelOutcome::Delivered,
Err(e) => ChannelOutcome::Failed(e),
},
None => ChannelOutcome::Failed("broadcast channel not configured".into()),
};
}
result
}
/// Runs `notify` for each recipient in turn and collects the results.
///
/// Recipients are processed sequentially, not concurrently. The returned
/// vector has one entry per recipient, in the same order as `recipients`.
pub async fn notify_many<N: Notifiable, T: Notification<N>>(
    recipients: &[&N],
    notification: &T,
    ctx: &NotificationContext,
) -> Vec<NotificationResult> {
    let mut results = Vec::with_capacity(recipients.len());
    for &recipient in recipients {
        let outcome = notify(recipient, notification, ctx).await;
        results.push(outcome);
    }
    results
}
/// Inserts one notification row into `table`.
///
/// The table name cannot be bound as a query parameter, so it is validated
/// with `validate_table_name` and spliced into the SQL as a double-quoted
/// identifier. The id, type and JSON payload are bound as `$1`..`$3`.
///
/// # Errors
/// Returns the validation message, or the stringified sqlx error if the
/// statement fails.
async fn insert_database_notification(
    pool: &crate::sql::sqlx::PgPool,
    table: &str,
    notifiable_id: i64,
    kind: &str,
    payload: &serde_json::Value,
) -> Result<(), String> {
    validate_table_name(table)?;
    let statement = format!(
        r#"INSERT INTO "{table}" ("notifiable_id", "type", "data") VALUES ($1, $2, $3)"#
    );
    let outcome = crate::sql::sqlx::query(&statement)
        .bind(notifiable_id)
        .bind(kind)
        .bind(payload)
        .execute(pool)
        .await;
    match outcome {
        Ok(_) => Ok(()),
        Err(err) => Err(err.to_string()),
    }
}
/// Checks that `name` can be embedded in a double-quoted SQL identifier.
///
/// Rejecting `"` keeps the name from closing the surrounding quotes early.
///
/// # Errors
/// Returns a message when `name` is empty, or when it contains `"`, `\`, `;`,
/// a space, or any Unicode control character (which covers NUL, `\n`, `\r`).
fn validate_table_name(name: &str) -> Result<(), String> {
    // NUL, '\n' and '\r' need no entry here: `char::is_control` already rejects them.
    const FORBIDDEN: [char; 4] = ['"', '\\', ';', ' '];

    if name.is_empty() {
        return Err("table name is empty".into());
    }
    let is_rejected = |c: char| c.is_control() || FORBIDDEN.contains(&c);
    if name.chars().any(is_rejected) {
        Err(format!("table name `{name}` contains forbidden characters"))
    } else {
        Ok(())
    }
}
/// Returns `true` when `key` is not yet present in `cache`, and records it for `ttl`.
/// Returns `false` when the key is already present.
///
/// Cache errors never suppress a send. A failed existence check counts as "not
/// present", and a failed write is ignored.
///
/// NOTE(review): the existence check and the write are two separate cache
/// calls, so concurrent callers can both see the key as absent and both get
/// `true`. An atomic set-if-absent on the cache would close that window.
#[cfg(feature = "cache")]
pub async fn should_send_throttled(
    cache: &dyn crate::cache::Cache,
    key: &str,
    ttl: Duration,
) -> bool {
    let already_sent = cache.exists(key).await.unwrap_or(false);
    if !already_sent {
        let _ = cache.set(key, "1", Some(ttl)).await;
    }
    !already_sent
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::email::{Email, InMemoryMailer, Mailer};
    use std::sync::Mutex;

    /// Minimal recipient with a database id and an email address.
    struct TestUser { id: i64, email: String }
    impl Notifiable for TestUser {
        fn notification_id(&self) -> Option<i64> { Some(self.id) }
    }

    /// Notification that targets only the mail and log channels.
    struct WelcomeEmail;
    impl Notification<TestUser> for WelcomeEmail {
        fn build(&self, user: &TestUser) -> NotificationDispatch {
            NotificationDispatch {
                email: Some(
                    Email::new()
                        .to(&user.email)
                        .from("noreply@app")
                        .subject("Welcome")
                        .body("Hi"),
                ),
                log: Some(format!("welcomed user {}", user.id)),
                ..NotificationDispatch::default()
            }
        }
    }

    #[tokio::test]
    async fn dispatch_sends_to_configured_channels() {
        // Keep a handle to the in-memory mailer so the send count can be checked afterwards.
        let mailer = Arc::new(InMemoryMailer::new());
        let mailer_clone = mailer.clone();
        let ctx = NotificationContext::new().with_mailer(mailer_clone as _);
        let u = TestUser { id: 1, email: "a@x.com".into() };
        let r = notify(&u, &WelcomeEmail, &ctx).await;
        assert_eq!(r.mail, ChannelOutcome::Delivered);
        assert_eq!(r.log, ChannelOutcome::Delivered);
        // WelcomeEmail supplies no database or broadcast payload.
        assert_eq!(r.database, ChannelOutcome::Skipped);
        assert_eq!(r.broadcast, ChannelOutcome::Skipped);
        assert_eq!(mailer.count(), 1);
    }

    #[tokio::test]
    async fn missing_mailer_records_failure_without_aborting_other_channels() {
        // No backends configured at all.
        let ctx = NotificationContext::new();
        let u = TestUser { id: 1, email: "a@x.com".into() };
        let r = notify(&u, &WelcomeEmail, &ctx).await;
        assert!(matches!(r.mail, ChannelOutcome::Failed(_)));
        assert_eq!(r.log, ChannelOutcome::Delivered, "log channel should still fire");
        assert!(!r.no_failures());
        assert!(r.any_delivered());
    }

    #[tokio::test]
    async fn empty_dispatch_skips_everything() {
        // Notification whose dispatch targets no channel.
        struct NoOp;
        impl Notification<TestUser> for NoOp {
            fn build(&self, _user: &TestUser) -> NotificationDispatch {
                NotificationDispatch::none()
            }
        }
        let ctx = NotificationContext::new();
        let u = TestUser { id: 1, email: "a@x.com".into() };
        let r = notify(&u, &NoOp, &ctx).await;
        assert_eq!(r.mail, ChannelOutcome::Skipped);
        assert_eq!(r.log, ChannelOutcome::Skipped);
        assert_eq!(r.database, ChannelOutcome::Skipped);
        assert_eq!(r.broadcast, ChannelOutcome::Skipped);
        // Skipped channels are neither failures nor deliveries.
        assert!(r.no_failures());
        assert_eq!(r.delivered_count(), 0);
    }

    #[tokio::test]
    async fn broadcast_callback_fires() {
        // The callback records every payload it receives so the test can inspect them.
        let captured: Arc<Mutex<Vec<serde_json::Value>>> = Arc::new(Mutex::new(Vec::new()));
        let cap = captured.clone();
        let ctx = NotificationContext::new().with_broadcast(move |payload| {
            let cap = cap.clone();
            async move {
                cap.lock().unwrap().push(payload);
                Ok(())
            }
        });
        // Notification that targets only the broadcast channel.
        struct WithBroadcast;
        impl Notification<TestUser> for WithBroadcast {
            fn build(&self, _user: &TestUser) -> NotificationDispatch {
                NotificationDispatch {
                    broadcast: Some(serde_json::json!({"event": "ping"})),
                    ..NotificationDispatch::default()
                }
            }
        }
        let u = TestUser { id: 1, email: "a@x.com".into() };
        let r = notify(&u, &WithBroadcast, &ctx).await;
        assert_eq!(r.broadcast, ChannelOutcome::Delivered);
        assert_eq!(captured.lock().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn notify_many_returns_one_result_per_recipient() {
        // No mailer is configured, so mail fails; the log channel still delivers for each user.
        let ctx = NotificationContext::new();
        let users = vec![
            TestUser { id: 1, email: "a@x".into() },
            TestUser { id: 2, email: "b@x".into() },
            TestUser { id: 3, email: "c@x".into() },
        ];
        let refs: Vec<&TestUser> = users.iter().collect();
        let results = notify_many(&refs, &WelcomeEmail, &ctx).await;
        assert_eq!(results.len(), 3);
        for r in &results {
            assert_eq!(r.log, ChannelOutcome::Delivered);
        }
    }

    #[tokio::test]
    async fn delivered_count_matches_actual_deliveries() {
        let mailer: Arc<dyn Mailer> = Arc::new(InMemoryMailer::new());
        let ctx = NotificationContext::new().with_mailer(mailer);
        let u = TestUser { id: 1, email: "a@x".into() };
        let r = notify(&u, &WelcomeEmail, &ctx).await;
        // Mail and log deliver; database and broadcast are skipped.
        assert_eq!(r.delivered_count(), 2);
    }

    #[cfg(feature = "cache")]
    #[tokio::test]
    async fn throttle_helper_allows_first_send_blocks_repeat() {
        use crate::cache::InMemoryCache;
        let cache = InMemoryCache::new();
        // The first call marks the key; the second, within the TTL, is throttled.
        assert!(should_send_throttled(&cache, "k", Duration::from_secs(60)).await);
        assert!(!should_send_throttled(&cache, "k", Duration::from_secs(60)).await);
    }
}