use crate::types::PushSubscription;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine as _;
use jsonwebtoken::{Algorithm, EncodingKey, Header};
use p256::ecdsa::SigningKey;
use p256::pkcs8::{DecodePrivateKey, EncodePrivateKey};
use serde::{Deserialize, Serialize};
use std::time::{SystemTime, UNIX_EPOCH};
use thiserror::Error;
/// Errors produced while configuring the VAPID key or delivering a push message.
#[derive(Debug, Error)]
pub enum WebPushError {
    /// The push endpoint answered HTTP 410.
    #[error("push endpoint returned 410 Gone")]
    Gone,
    /// The request could not be built or sent, or the HTTP client failed to initialize.
    #[error("HTTP transport error: {0}")]
    Http(#[from] reqwest::Error),
    /// Signing the VAPID JWT failed.
    #[error("VAPID JWT signing error: {0}")]
    JwtSigning(#[from] jsonwebtoken::errors::Error),
    /// Loading, generating, exporting or persisting the VAPID key failed.
    ///
    /// This module also uses it for a system clock before the Unix epoch and for
    /// endpoint URLs whose origin cannot be determined.
    #[error("VAPID key error: {0}")]
    KeyError(String),
    /// The push endpoint answered with a non-2xx status other than 410.
    #[error("push endpoint returned unexpected status {0}")]
    UnexpectedStatus(u16),
}
/// JWT claim set signed into every VAPID `Authorization` header.
///
/// Field names are serialized verbatim as the JWT claim names.
#[derive(Debug, Serialize, Deserialize)]
struct VapidClaims {
    /// Audience: the origin (scheme + authority) of the push endpoint URL.
    aud: String,
    /// Subject: the configured admin contact, e.g. a `mailto:` URI.
    sub: String,
    /// Expiry, in whole seconds since the Unix epoch.
    exp: u64,
}
/// Client that delivers Web Push requests authenticated with a VAPID key pair.
///
/// The signing key sits behind an `Arc`, so cloning the client shares the key
/// instead of duplicating the secret.
#[derive(Clone)]
pub struct WebPushClient {
    /// HTTP client used to POST to push endpoints.
    http: reqwest::Client,
    /// P-256 private key that signs the VAPID JWTs.
    vapid_key: std::sync::Arc<SigningKey>,
    /// Public half of `vapid_key`, base64url-encoded without padding.
    vapid_pubkey_base64url: String,
    /// Value placed in the JWT `sub` claim.
    admin_sub: String,
}
/// Hand-written `Debug`: prints only the public key and the subject.
///
/// The signing key and the HTTP client are deliberately left out, which keeps the
/// private key out of debug output; `finish_non_exhaustive` marks the omission.
impl std::fmt::Debug for WebPushClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("WebPushClient");
        out.field("vapid_pubkey_base64url", &self.vapid_pubkey_base64url);
        out.field("admin_sub", &self.admin_sub);
        out.finish_non_exhaustive()
    }
}
impl WebPushClient {
pub fn new(vapid_pem: Option<&[u8]>, admin_email: &str) -> Result<Self, WebPushError> {
let signing_key = match vapid_pem {
Some(pem_bytes) => {
let pem_str = std::str::from_utf8(pem_bytes)
.map_err(|e| WebPushError::KeyError(format!("PEM is not valid UTF-8: {e}")))?;
SigningKey::from_pkcs8_pem(pem_str)
.map_err(|e| WebPushError::KeyError(format!("Failed to load VAPID key: {e}")))?
}
None => {
let mut rng_buf = [0u8; 32];
getrandom::fill(&mut rng_buf)
.map_err(|e| WebPushError::KeyError(format!("RNG failure: {e}")))?;
SigningKey::from_slice(&rng_buf)
.map_err(|e| WebPushError::KeyError(format!("Key generation failed: {e}")))?
}
};
let pubkey_bytes = p256::ecdsa::VerifyingKey::from(&signing_key)
.to_encoded_point(false)
.as_bytes()
.to_vec();
let vapid_pubkey_base64url = URL_SAFE_NO_PAD.encode(&pubkey_bytes);
let admin_sub = if admin_email.contains('@') {
format!("mailto:{admin_email}")
} else {
admin_email.to_string()
};
Ok(Self {
http: reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.map_err(WebPushError::Http)?,
vapid_key: std::sync::Arc::new(signing_key),
vapid_pubkey_base64url,
admin_sub,
})
}
pub fn new_with_persistence(
key_path: Option<&std::path::Path>,
admin_email: &str,
) -> Result<Self, WebPushError> {
match key_path {
None => Self::new(None, admin_email),
Some(path) if path.exists() => {
let pem = std::fs::read(path).map_err(|e| {
WebPushError::KeyError(format!("Cannot read VAPID key file: {e}"))
})?;
Self::new(Some(&pem), admin_email)
}
Some(path) => {
let client = Self::new(None, admin_email)?;
let pem = client
.vapid_key
.to_pkcs8_pem(Default::default())
.map_err(|e| {
WebPushError::KeyError(format!("PEM serialization failed: {e}"))
})?;
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent).map_err(|e| {
WebPushError::KeyError(format!("Cannot create VAPID key dir: {e}"))
})?;
}
std::fs::write(path, pem.as_bytes())
.map_err(|e| WebPushError::KeyError(format!("Cannot write VAPID key: {e}")))?;
Ok(client)
}
}
}
pub fn vapid_pubkey_base64url(&self) -> &str {
&self.vapid_pubkey_base64url
}
pub(crate) fn build_vapid_jwt(&self, endpoint_origin: &str) -> Result<String, WebPushError> {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| WebPushError::KeyError(format!("System clock error: {e}")))?;
let exp = now.as_secs() + 86_400;
let claims = VapidClaims {
aud: endpoint_origin.to_string(),
sub: self.admin_sub.clone(),
exp,
};
let header = Header::new(Algorithm::ES256);
let der = self
.vapid_key
.to_pkcs8_der()
.map_err(|e| WebPushError::KeyError(format!("Key DER export failed: {e}")))?;
let encoding_key = EncodingKey::from_ec_der(der.as_bytes());
let token = jsonwebtoken::encode(&header, &claims, &encoding_key)?;
Ok(token)
}
pub async fn send(
&self,
subscription: &PushSubscription,
payload: &[u8],
) -> Result<(), WebPushError> {
let endpoint_origin = extract_origin(&subscription.url).ok_or_else(|| {
WebPushError::KeyError(format!(
"Cannot determine origin from URL: {}",
subscription.url
))
})?;
let jwt = self.build_vapid_jwt(&endpoint_origin)?;
let authorization = format!("vapid t={},k={}", jwt, self.vapid_pubkey_base64url);
const TTL_SECONDS: u32 = 86_400;
let body = if subscription.keys.is_none() || payload.is_empty() {
bytes::Bytes::new()
} else {
bytes::Bytes::new()
};
let response = self
.http
.post(&subscription.url)
.header("Authorization", authorization)
.header("TTL", TTL_SECONDS.to_string())
.header("Content-Type", "application/octet-stream")
.body(body)
.send()
.await?;
let status = response.status().as_u16();
match status {
200..=299 => Ok(()),
410 => Err(WebPushError::Gone),
other => Err(WebPushError::UnexpectedStatus(other)),
}
}
}
/// Returns the origin (`scheme://host[:port]`) of `url`, used as the VAPID `aud` claim.
///
/// The authority ends at the first `/`, `?` or `#`. Any `user:pass@` userinfo is
/// dropped. Scheme and host are lowercased, and the default port (443 for https, 80 for
/// http) is omitted, following RFC 6454 origin serialization.
///
/// Returns `None` when `url` has no `://` separator, an empty scheme, or an empty host.
fn extract_origin(url: &str) -> Option<String> {
    let (scheme, rest) = url.split_once("://")?;
    if scheme.is_empty() {
        return None;
    }
    // The authority runs up to the first path, query, or fragment delimiter.
    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());
    let authority = &rest[..authority_end];
    // Userinfo is not part of an origin; the host follows the last '@'.
    let host_port = authority.rsplit_once('@').map_or(authority, |(_, hp)| hp);

    let scheme = scheme.to_ascii_lowercase();
    let host_port_lower = host_port.to_ascii_lowercase();
    let default_port_suffix = match scheme.as_str() {
        "https" => Some(":443"),
        "http" => Some(":80"),
        _ => None,
    };
    let host_port = default_port_suffix
        .and_then(|suffix| host_port_lower.strip_suffix(suffix))
        .unwrap_or(&host_port_lower);
    // Reject an empty host, with or without an explicit port.
    if host_port.is_empty() || host_port.starts_with(':') {
        return None;
    }
    Some(format!("{scheme}://{host_port}"))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client with a freshly generated, in-memory VAPID key.
    fn ephemeral_client() -> WebPushClient {
        WebPushClient::new(None, "admin@example.com").expect("ephemeral client must build")
    }

    #[test]
    fn test_extract_origin_https() {
        let origin = extract_origin("https://push.example.com/v1/subscriptions/abc123");
        assert_eq!(origin.as_deref(), Some("https://push.example.com"));
    }

    #[test]
    fn test_extract_origin_with_port() {
        let origin = extract_origin("https://push.example.com:8443/endpoint");
        assert_eq!(origin.as_deref(), Some("https://push.example.com:8443"));
    }

    #[test]
    fn test_vapid_client_ephemeral_key() {
        let client = ephemeral_client();
        assert!(!client.vapid_pubkey_base64url().is_empty());
        assert!(client.admin_sub.starts_with("mailto:"));
    }

    #[test]
    fn test_build_vapid_jwt() {
        let jwt = ephemeral_client()
            .build_vapid_jwt("https://push.example.com")
            .expect("JWT signing must succeed");
        let segment_count = jwt.split('.').count();
        assert_eq!(segment_count, 3, "JWT must have header.payload.signature");
    }
}