use crate::adapters::database::DbPool;
use crate::adapters::database::push_token_repo::PushTokenRepository;
use crate::adapters::push::{PushError, PushProvider};
use crate::adapters::redis::NotificationRepository;
use crate::config::NotificationConfig;
use opentelemetry::{KeyValue, global, metrics::Counter};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Semaphore, mpsc};
use tracing::Instrument;
use uuid::Uuid;
/// OpenTelemetry counters recorded by the push notification worker.
///
/// The struct is cheap to clone and a copy is moved into every spawned dispatch task.
#[derive(Clone, Debug)]
struct Metrics {
    /// Pushes the provider accepted.
    sent: Counter<u64>,
    /// Delivery failures; callers attach a `reason` attribute when incrementing.
    errors: Counter<u64>,
    /// Tokens the provider reported as unregistered.
    invalidated_tokens: Counter<u64>,
}

impl Metrics {
    /// Registers the worker's counters on the global `obscura-server` meter.
    fn new() -> Self {
        let meter = global::meter("obscura-server");
        // Every counter is built the same way; only its name and description differ.
        let counter = |name: &'static str, description: &'static str| {
            meter.u64_counter(name).with_description(description).build()
        };
        Self {
            sent: counter(
                "obscura_push_notifications_sent_total",
                "Total number of push notifications successfully sent",
            ),
            errors: counter(
                "obscura_push_notification_errors_total",
                "Total number of push notification delivery errors",
            ),
            invalidated_tokens: counter(
                "obscura_push_invalid_tokens_total",
                "Total number of push tokens removed due to being unregistered",
            ),
        }
    }
}
#[derive(Debug)]
pub struct PushNotificationWorker {
pool: DbPool,
repo: Arc<NotificationRepository>,
provider: Arc<dyn PushProvider>,
token_repo: PushTokenRepository,
interval_secs: u64,
visibility_timeout_secs: u64,
invalid_token_cleanup_interval_secs: u64,
invalid_token_cleanup_batch_size: usize,
invalid_token_cleanup_channel_capacity: usize,
semaphore: Arc<Semaphore>,
metrics: Metrics,
}
impl PushNotificationWorker {
pub fn new(
pool: DbPool,
repo: Arc<NotificationRepository>,
provider: Arc<dyn PushProvider>,
token_repo: PushTokenRepository,
config: &NotificationConfig,
) -> Self {
Self {
pool,
repo,
provider,
token_repo,
interval_secs: config.worker_interval_secs,
visibility_timeout_secs: config.visibility_timeout_secs,
invalid_token_cleanup_interval_secs: config.invalid_token_cleanup_interval_secs,
invalid_token_cleanup_batch_size: config.invalid_token_cleanup_batch_size,
invalid_token_cleanup_channel_capacity: config.invalid_token_cleanup_channel_capacity,
semaphore: Arc::new(Semaphore::new(config.worker_concurrency)),
metrics: Metrics::new(),
}
}
pub async fn run(self, mut shutdown: tokio::sync::watch::Receiver<bool>) {
let mut interval = tokio::time::interval(Duration::from_secs(self.interval_secs));
let mut cleanup_interval = tokio::time::interval(Duration::from_secs(self.invalid_token_cleanup_interval_secs));
let (invalid_token_tx, mut invalid_token_rx) =
mpsc::channel::<String>(self.invalid_token_cleanup_channel_capacity);
let mut cleanup_batch = Vec::new();
tracing::info!("Push notification worker started");
while !*shutdown.borrow() {
tokio::select! {
_ = shutdown.changed() => break,
_ = interval.tick() => {
if let Err(e) = self.process_due_jobs(invalid_token_tx.clone())
.instrument(tracing::debug_span!("process_push_jobs"))
.await
{
tracing::error!(error = %e, "Failed to process due notification jobs");
}
}
_ = cleanup_interval.tick() => {
async {
Self::flush_invalid_tokens(&self.pool, &self.token_repo, &mut cleanup_batch).await;
}
.instrument(tracing::debug_span!("flush_invalid_tokens"))
.await;
}
res = invalid_token_rx.recv() => {
if let Some(token) = res {
cleanup_batch.push(token);
if cleanup_batch.len() >= self.invalid_token_cleanup_batch_size {
Self::flush_invalid_tokens(&self.pool, &self.token_repo, &mut cleanup_batch).await;
}
}
}
}
}
tracing::info!("Push notification worker shutting down...");
Self::flush_invalid_tokens(&self.pool, &self.token_repo, &mut cleanup_batch).await;
}
#[tracing::instrument(level = "debug", skip(pool, repo, batch))]
async fn flush_invalid_tokens(pool: &DbPool, repo: &PushTokenRepository, batch: &mut Vec<String>) {
if batch.is_empty() {
return;
}
let count = batch.len();
match pool.acquire().await {
Ok(mut conn) => {
if let Err(e) = repo.delete_tokens_batch(&mut conn, batch).await {
tracing::error!(error = %e, "Failed to delete invalid token batch");
} else {
tracing::info!(count, "Successfully deleted invalid tokens in batch");
batch.clear();
}
}
Err(e) => tracing::error!(error = %e, "Failed to acquire connection for cleanup"),
}
}
#[tracing::instrument(level = "debug", skip(self, invalid_token_tx), name = "process_due_jobs", err)]
pub async fn process_due_jobs(&self, invalid_token_tx: mpsc::Sender<String>) -> anyhow::Result<()> {
let available = self.semaphore.available_permits();
if available == 0 {
return Ok(());
}
let limit = available.cast_signed();
let device_ids = self.repo.lease_due_jobs(limit, self.visibility_timeout_secs).await?;
if device_ids.is_empty() {
tracing::debug!("No due push notification jobs found");
return Ok(());
}
tracing::info!(count = device_ids.len(), "Processing leased push notifications");
let device_token_pairs = {
let mut conn = self.pool.acquire().await?;
self.token_repo.find_tokens_for_devices(&mut conn, &device_ids).await?
};
let devices_with_tokens: std::collections::HashSet<Uuid> =
device_token_pairs.iter().map(|(id, _)| *id).collect();
for device_id in &device_ids {
if !devices_with_tokens.contains(device_id) {
tracing::info!(%device_id, "Device has no registered push token, removing job");
let _ = self.repo.delete_job(*device_id).await;
}
}
for (device_id, token) in device_token_pairs {
let provider = Arc::clone(&self.provider);
let repo = Arc::clone(&self.repo);
let metrics = self.metrics.clone();
let tx = invalid_token_tx.clone();
let permit = Arc::clone(&self.semaphore)
.acquire_owned()
.await
.map_err(|e| anyhow::anyhow!("Semaphore closed: {e}"))?;
tokio::spawn(
async move {
let _permit = permit;
match provider.send_push(&token).await {
Ok(()) => {
tracing::debug!("Push notification sent successfully");
metrics.sent.add(1, &[]);
let _ = repo.delete_job(device_id).await;
}
Err(PushError::Unregistered) => {
tracing::info!("Token unregistered, reporting to invalid token cleanup");
metrics.invalidated_tokens.add(1, &[]);
let _ = repo.delete_job(device_id).await;
let _ = tx.send(token).await;
}
Err(PushError::QuotaExceeded) => {
tracing::warn!("Push quota exceeded, allowing visibility timeout to trigger retry");
metrics.errors.add(1, &[KeyValue::new("reason", "quota_exceeded")]);
}
Err(PushError::Other(e)) => {
tracing::error!(error = %e, "Failed to send push notification, will retry");
metrics.errors.add(1, &[KeyValue::new("reason", "other")]);
}
}
}
.instrument(tracing::debug_span!("dispatch_push", %device_id)),
);
}
Ok(())
}
}