use crate::types::{JmapSetError, Principal, PushKeys, PushSubscription};
use crate::web_push::{WebPushClient, WebPushError};
use base64::Engine as _;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
/// Shared push-subscription store, keyed by subscription id.
pub type PushRegistry = Arc<DashMap<String, PushSubscription>>;

/// State backing the `PushSubscription/*` handlers in this module.
pub struct PushState {
    /// Every stored subscription, keyed by its id.
    pub registry: PushRegistry,
    /// Client used to deliver pushes, including the verification push sent on create.
    pub client: Arc<WebPushClient>,
}

/// Process-wide slot filled by [`init_push_state`].
static PUSH_STATE: OnceLock<Arc<PushState>> = OnceLock::new();

/// Installs the global push state.
///
/// Only the first call has an effect; a state passed to any later call is dropped.
pub fn init_push_state(state: Arc<PushState>) {
    // `OnceLock::set` hands the value back when the slot is already filled.
    let _rejected: Result<(), Arc<PushState>> = PUSH_STATE.set(state);
}

/// Returns the installed push state, or `None` before [`init_push_state`] has run.
pub fn push_state() -> Option<&'static Arc<PushState>> {
    PUSH_STATE.get()
}
/// Arguments of a `PushSubscription/get` call.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PushSubscriptionGetRequest {
/// Ids to fetch. An absent property and an explicit `null` both deserialize to `None`,
/// which the get handler treats as "every subscription owned by the caller".
#[serde(default)]
pub ids: Option<Vec<String>>,
}
/// Result of a `PushSubscription/get` call (fields serialize as `list` / `notFound`).
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PushSubscriptionGetResponse {
/// Subscriptions that were found and are owned by the caller.
pub list: Vec<PushSubscriptionView>,
/// Requested ids that are unknown or belong to another principal.
pub not_found: Vec<String>,
}
/// Client-visible projection of a [`PushSubscription`], as returned by `PushSubscription/get`.
///
/// RFC 8620 §7.2.1: because `url` and `keys` may contain data private to a device, their
/// values MUST NOT be returned by `/get`. Both fields are kept for in-process callers but
/// are never serialized. The stored verification code and verified flag are not part of
/// the view at all.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PushSubscriptionView {
    /// Server-assigned subscription id.
    pub id: String,
    /// `deviceClientId` as supplied at creation.
    pub device_client_id: String,
    /// Push endpoint. Never serialized (RFC 8620 §7.2.1).
    #[serde(skip_serializing)]
    pub url: String,
    /// Client-supplied keys. Never serialized (RFC 8620 §7.2.1).
    #[serde(skip_serializing)]
    pub keys: Option<PushKeys>,
    /// Expiry time; omitted from the JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires: Option<chrono::DateTime<chrono::Utc>>,
    /// Data types this subscription asked to be pushed.
    pub types: Vec<String>,
}

impl From<&PushSubscription> for PushSubscriptionView {
    /// Copies the publicly describable fields of a stored subscription.
    fn from(s: &PushSubscription) -> Self {
        Self {
            id: s.id.clone(),
            device_client_id: s.device_client_id.clone(),
            url: s.url.clone(),
            keys: s.keys.clone(),
            expires: s.expires,
            types: s.types.clone(),
        }
    }
}
/// Arguments of a `PushSubscription/set` call.
///
/// Each operation group is optional. An absent property and an explicit `null` both mean
/// "no operations of this kind".
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PushSubscriptionSetRequest {
/// Subscriptions to create, keyed by the client's creation id.
#[serde(default)]
pub create: Option<HashMap<String, PushSubscriptionCreate>>,
/// Patches to apply, keyed by existing subscription id.
#[serde(default)]
pub update: Option<HashMap<String, PushSubscriptionUpdate>>,
/// Ids of subscriptions to delete.
#[serde(default)]
pub destroy: Option<Vec<String>>,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PushSubscriptionCreate {
pub device_client_id: String,
pub url: String,
#[serde(default)]
pub keys: Option<PushKeys>,
#[serde(default)]
pub expires: Option<chrono::DateTime<chrono::Utc>>,
#[serde(default)]
pub types: Vec<String>,
}
/// Patch applied to an existing subscription via `PushSubscription/set` `update`.
///
/// NOTE(review): every field is `Option` + `#[serde(default)]`, so an explicit `null` is
/// indistinguishable from an absent property. `null` is therefore treated as "leave
/// unchanged"; in RFC 8620, `types: null` means "all types". Confirm this is intended.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PushSubscriptionUpdate {
/// Code echoed back by the client; on a match with the stored code, the subscription is
/// marked verified. On a mismatch, the whole update is rejected.
#[serde(default)]
pub verification_code: Option<String>,
/// Replacement list of data types to push.
#[serde(default)]
pub types: Option<Vec<String>>,
/// New expiry time.
#[serde(default)]
pub expires: Option<chrono::DateTime<chrono::Utc>>,
}
/// Result of a `PushSubscription/set` call.
///
/// The set handlers in this module leave each map/list `None` when it would be empty, and
/// `None` fields are omitted from the JSON.
#[derive(Debug, Clone, Serialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct PushSubscriptionSetResponse {
/// Successful creates, keyed by the client's creation id.
#[serde(skip_serializing_if = "Option::is_none")]
pub created: Option<HashMap<String, PushSubscriptionCreated>>,
/// Successful updates, keyed by subscription id. Each value is `None` (serialized as
/// `null`): no server-set properties are reported back.
#[serde(skip_serializing_if = "Option::is_none")]
pub updated: Option<HashMap<String, Option<serde_json::Value>>>,
/// Ids that were deleted.
#[serde(skip_serializing_if = "Option::is_none")]
pub destroyed: Option<Vec<String>>,
/// Failed creates, keyed by the client's creation id.
#[serde(skip_serializing_if = "Option::is_none")]
pub not_created: Option<HashMap<String, JmapSetError>>,
/// Failed updates, keyed by subscription id.
#[serde(skip_serializing_if = "Option::is_none")]
pub not_updated: Option<HashMap<String, JmapSetError>>,
/// Failed deletions, keyed by subscription id.
#[serde(skip_serializing_if = "Option::is_none")]
pub not_destroyed: Option<HashMap<String, JmapSetError>>,
}
/// Server-assigned data returned for each successful create.
///
/// NOTE(review): in RFC 8620 §7.2.2, the verification code reaches the device only inside
/// the `PushVerification` push, and the RFC's example create response does not include it.
/// Returning it here lets the creator verify a subscription without ever receiving that
/// push. Confirm this is intended before treating `verified` as proof of endpoint control.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PushSubscriptionCreated {
/// Server-assigned subscription id.
pub id: String,
/// Code that must be echoed back in an update to mark the subscription verified.
pub verification_code: String,
}
pub async fn push_subscription_get(
request: PushSubscriptionGetRequest,
principal: &Principal,
) -> anyhow::Result<PushSubscriptionGetResponse> {
let state = match push_state() {
Some(s) => s,
None => {
return Ok(PushSubscriptionGetResponse {
list: vec![],
not_found: vec![],
});
}
};
let mut list = Vec::new();
let mut not_found = Vec::new();
match request.ids {
None => {
for entry in state.registry.iter() {
if entry.value().principal_id == principal.account_id {
list.push(PushSubscriptionView::from(entry.value()));
}
}
}
Some(ids) => {
for id in ids {
match state.registry.get(&id) {
Some(entry) if entry.value().principal_id == principal.account_id => {
list.push(PushSubscriptionView::from(entry.value()));
}
Some(_) => {
not_found.push(id);
}
None => {
not_found.push(id);
}
}
}
}
}
Ok(PushSubscriptionGetResponse { list, not_found })
}
pub async fn push_subscription_set(
request: PushSubscriptionSetRequest,
principal: &Principal,
) -> anyhow::Result<PushSubscriptionSetResponse> {
let state = match push_state() {
Some(s) => s,
None => {
return Err(anyhow::anyhow!(
"Push subsystem not initialised; call init_push_state() at server startup"
));
}
};
let mut response = PushSubscriptionSetResponse::default();
if let Some(creates) = request.create {
let mut created = HashMap::new();
let mut not_created = HashMap::new();
for (client_id, create) in creates {
match create_subscription(state, create, principal).await {
Ok(result) => {
created.insert(client_id, result);
}
Err(e) => {
not_created.insert(
client_id,
JmapSetError {
error_type: "serverFail".to_string(),
description: Some(e.to_string()),
},
);
}
}
}
if !created.is_empty() {
response.created = Some(created);
}
if !not_created.is_empty() {
response.not_created = Some(not_created);
}
}
if let Some(updates) = request.update {
let mut updated = HashMap::new();
let mut not_updated = HashMap::new();
for (id, patch) in updates {
match update_subscription(state, &id, patch, principal) {
Ok(()) => {
updated.insert(id, None);
}
Err(e) => {
not_updated.insert(
id,
JmapSetError {
error_type: "serverFail".to_string(),
description: Some(e.to_string()),
},
);
}
}
}
if !updated.is_empty() {
response.updated = Some(updated);
}
if !not_updated.is_empty() {
response.not_updated = Some(not_updated);
}
}
if let Some(destroy_ids) = request.destroy {
let mut destroyed = Vec::new();
let mut not_destroyed = HashMap::new();
for id in destroy_ids {
match destroy_subscription(state, &id, principal) {
Ok(()) => {
destroyed.push(id);
}
Err(e) => {
not_destroyed.insert(
id,
JmapSetError {
error_type: "serverFail".to_string(),
description: Some(e.to_string()),
},
);
}
}
}
if !destroyed.is_empty() {
response.destroyed = Some(destroyed);
}
if !not_destroyed.is_empty() {
response.not_destroyed = Some(not_destroyed);
}
}
Ok(response)
}
pub async fn push_subscription_set_with_state(
request: PushSubscriptionSetRequest,
principal: &Principal,
state: &Arc<PushState>,
) -> anyhow::Result<PushSubscriptionSetResponse> {
let mut response = PushSubscriptionSetResponse::default();
if let Some(creates) = request.create {
let mut created = HashMap::new();
let mut not_created = HashMap::new();
for (client_id, create) in creates {
match create_subscription(state, create, principal).await {
Ok(result) => {
created.insert(client_id, result);
}
Err(e) => {
not_created.insert(
client_id,
JmapSetError {
error_type: "serverFail".to_string(),
description: Some(e.to_string()),
},
);
}
}
}
if !created.is_empty() {
response.created = Some(created);
}
if !not_created.is_empty() {
response.not_created = Some(not_created);
}
}
if let Some(updates) = request.update {
let mut updated = HashMap::new();
let mut not_updated = HashMap::new();
for (id, patch) in updates {
match update_subscription(state, &id, patch, principal) {
Ok(()) => {
updated.insert(id, None);
}
Err(e) => {
not_updated.insert(
id,
JmapSetError {
error_type: "serverFail".to_string(),
description: Some(e.to_string()),
},
);
}
}
}
if !updated.is_empty() {
response.updated = Some(updated);
}
if !not_updated.is_empty() {
response.not_updated = Some(not_updated);
}
}
if let Some(destroy_ids) = request.destroy {
let mut destroyed = Vec::new();
let mut not_destroyed = HashMap::new();
for id in destroy_ids {
match destroy_subscription(state, &id, principal) {
Ok(()) => {
destroyed.push(id);
}
Err(e) => {
not_destroyed.insert(
id,
JmapSetError {
error_type: "serverFail".to_string(),
description: Some(e.to_string()),
},
);
}
}
}
if !destroyed.is_empty() {
response.destroyed = Some(destroyed);
}
if !not_destroyed.is_empty() {
response.not_destroyed = Some(not_destroyed);
}
}
Ok(response)
}
/// Produces a fresh verification code: 32 random bytes from `getrandom`, encoded as
/// URL-safe base64 without padding.
///
/// # Errors
/// Fails if the random source cannot be read.
fn generate_verification_code() -> Result<String, anyhow::Error> {
    let mut entropy = [0u8; 32];
    if let Err(e) = getrandom::fill(&mut entropy) {
        return Err(anyhow::anyhow!(
            "RNG failure during verification code generation: {e}"
        ));
    }
    let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(entropy);
    Ok(encoded)
}
/// Checks that a push-subscription URL uses an allowed scheme.
///
/// `https://` is always accepted. `http://` is accepted only when the `test-push-http`
/// feature is enabled. The scheme is matched ASCII case-insensitively, because URI schemes
/// are case-insensitive (RFC 3986 §3.1). Something must follow the scheme, so a bare
/// `"https://"` is rejected.
///
/// # Errors
/// Returns an error naming the offending URL when no allowed scheme matches.
fn validate_push_url(url: &str) -> Result<(), anyhow::Error> {
    if has_scheme_prefix(url, "https://") {
        return Ok(());
    }
    #[cfg(feature = "test-push-http")]
    if has_scheme_prefix(url, "http://") {
        return Ok(());
    }
    Err(anyhow::anyhow!(
        "Push subscription URL must use HTTPS, got: {url}"
    ))
}

/// Returns whether `url` starts with `prefix` (ASCII case-insensitive) and has a
/// non-empty remainder after it.
fn has_scheme_prefix(url: &str, prefix: &str) -> bool {
    url.len() > prefix.len()
        && url
            .get(..prefix.len())
            .is_some_and(|head| head.eq_ignore_ascii_case(prefix))
}
async fn create_subscription(
state: &PushState,
create: PushSubscriptionCreate,
principal: &Principal,
) -> anyhow::Result<PushSubscriptionCreated> {
validate_push_url(&create.url)?;
let id = uuid::Uuid::new_v4().to_string();
let verification_code = generate_verification_code()?;
let sub = PushSubscription {
id: id.clone(),
device_client_id: create.device_client_id,
url: create.url,
keys: create.keys,
verification_code: Some(verification_code.clone()),
expires: create.expires,
types: create.types,
verified: false,
principal_id: principal.account_id.clone(),
};
match state.client.send(&sub, b"").await {
Ok(()) => {}
Err(WebPushError::Gone) => {
return Err(anyhow::anyhow!(
"Push endpoint returned 410 Gone during verification"
));
}
Err(e) => {
return Err(anyhow::anyhow!("Failed to send verification push: {e}"));
}
}
state.registry.insert(id.clone(), sub);
Ok(PushSubscriptionCreated {
id,
verification_code,
})
}
/// Applies an update patch to one of `principal`'s subscriptions.
///
/// A supplied verification code must equal the stored code; on a match the subscription is
/// marked verified. The code is checked before anything is written, so a mismatch leaves
/// the subscription untouched.
///
/// # Errors
/// - "Subscription not found" for unknown ids and for ids owned by another principal.
///   This matches `push_subscription_get`, so callers cannot probe foreign ids.
/// - A mismatch error when the verification code does not match.
fn update_subscription(
    state: &PushState,
    id: &str,
    patch: PushSubscriptionUpdate,
    principal: &Principal,
) -> anyhow::Result<()> {
    let mut entry = state
        .registry
        .get_mut(id)
        .filter(|found| found.value().principal_id == principal.account_id)
        .ok_or_else(|| anyhow::anyhow!("Subscription not found: {id}"))?;
    let sub = entry.value_mut();
    if let Some(code) = patch.verification_code {
        let matches = sub
            .verification_code
            .as_deref()
            .is_some_and(|expected| codes_match(expected.as_bytes(), code.as_bytes()));
        if !matches {
            return Err(anyhow::anyhow!(
                "Verification code mismatch for subscription {id}"
            ));
        }
        sub.verified = true;
    }
    if let Some(types) = patch.types {
        sub.types = types;
    }
    if let Some(expires) = patch.expires {
        sub.expires = Some(expires);
    }
    Ok(())
}

/// Compares two byte strings without exiting early at the first differing byte. This
/// avoids leaking, through timing, how much of a secret code a guess got right.
fn codes_match(expected: &[u8], supplied: &[u8]) -> bool {
    expected.len() == supplied.len()
        && expected
            .iter()
            .zip(supplied)
            .fold(0u8, |diff, (a, b)| diff | (a ^ b))
            == 0
}
/// Deletes one of `principal`'s subscriptions.
///
/// The ownership check and the removal happen in a single DashMap `remove_if` call, so no
/// window exists between them. Ids owned by another principal are reported as not found,
/// matching `push_subscription_get`.
///
/// # Errors
/// Returns "Subscription not found" when the id is unknown or not owned by `principal`.
fn destroy_subscription(state: &PushState, id: &str, principal: &Principal) -> anyhow::Result<()> {
    state
        .registry
        .remove_if(id, |_, sub| sub.principal_id == principal.account_id)
        .map(|_removed| ())
        .ok_or_else(|| anyhow::anyhow!("Subscription not found: {id}"))
}