1use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
7use http_body_util::{BodyExt, Full};
8use hyper::body::Bytes;
9use hyper_rustls::HttpsConnectorBuilder;
10use hyper_util::client::legacy::Client;
11use hyper_util::rt::TokioExecutor;
12use serde::{Deserialize, Serialize};
13use std::time::{Duration, SystemTime, UNIX_EPOCH};
14
15use crate::prelude::*;
16use cloudillo_types::meta_adapter::PushSubscriptionData;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct NotificationPayload {
21 pub title: String,
23 pub body: String,
25 #[serde(skip_serializing_if = "Option::is_none")]
27 pub path: Option<String>,
28 #[serde(skip_serializing_if = "Option::is_none")]
30 pub image: Option<String>,
31 #[serde(skip_serializing_if = "Option::is_none")]
33 pub tag: Option<String>,
34}
35
/// Outcome of attempting to deliver one push notification to one subscription.
#[derive(Debug)]
pub enum PushResult {
    /// The push service answered with a success (2xx) status.
    Success,
    /// The push service answered `410 Gone` or `404 Not Found` for the endpoint;
    /// `send_to_tenant` reacts to this by deleting the stored subscription.
    SubscriptionGone,
    /// The attempt failed in a way the sender classifies as transient (for example a
    /// network error); the string describes the failure.
    TemporaryError(String),
    /// The attempt failed in a way the sender classifies as non-transient (for example a
    /// key, serialization or encryption problem); the string describes the failure.
    PermanentError(String),
}
48
49pub async fn send_notification(
60 app: &App,
61 tn_id: TnId,
62 subscription: &PushSubscriptionData,
63 payload: &NotificationPayload,
64) -> PushResult {
65 let vapid_keys = match app.auth_adapter.read_vapid_key(tn_id).await {
67 Ok(keys) => keys,
68 Err(e) => {
69 tracing::error!(tn_id = %tn_id.0, error = %e, "Failed to get VAPID keys");
70 return PushResult::PermanentError(format!("VAPID key error: {}", e));
71 }
72 };
73
74 let payload_json = match serde_json::to_string(payload) {
76 Ok(json) => json,
77 Err(e) => return PushResult::PermanentError(format!("Payload serialization error: {}", e)),
78 };
79
80 let encrypted =
82 match encrypt_payload(&payload_json, &subscription.keys.p256dh, &subscription.keys.auth) {
83 Ok(enc) => enc,
84 Err(e) => return PushResult::PermanentError(format!("Encryption error: {}", e)),
85 };
86
87 let id_tag = match app.auth_adapter.read_id_tag(tn_id).await {
89 Ok(tag) => tag,
90 Err(e) => {
91 tracing::error!(tn_id = %tn_id.0, error = %e, "Failed to get tenant id_tag");
92 return PushResult::PermanentError(format!("Tenant lookup error: {}", e));
93 }
94 };
95
96 let vapid_jwt = match create_vapid_jwt(&subscription.endpoint, &id_tag, &vapid_keys.private_key)
98 {
99 Ok(jwt) => jwt,
100 Err(e) => return PushResult::PermanentError(format!("VAPID JWT error: {}", e)),
101 };
102
103 send_push_request(
105 &subscription.endpoint,
106 &encrypted.body,
107 &encrypted.salt,
108 &encrypted.public_key,
109 &vapid_jwt,
110 &vapid_keys.public_key,
111 )
112 .await
113}
114
/// Result of encrypting a notification payload for one subscriber.
struct EncryptedPayload {
    /// Complete encrypted message as produced by the encryption step; this is the request body.
    body: Vec<u8>,
    /// Copy of the first 16 bytes of `body`.
    salt: Vec<u8>,
    /// Copy of the key-id bytes found in `body` starting at offset 21.
    public_key: Vec<u8>,
}
121
122fn encrypt_payload(
124 payload: &str,
125 p256dh_base64: &str,
126 auth_base64: &str,
127) -> Result<EncryptedPayload, String> {
128 let p256dh = URL_SAFE_NO_PAD
130 .decode(p256dh_base64)
131 .map_err(|e| format!("Invalid p256dh: {}", e))?;
132 let auth = URL_SAFE_NO_PAD
133 .decode(auth_base64)
134 .map_err(|e| format!("Invalid auth: {}", e))?;
135
136 let encrypted = ece::encrypt(&p256dh, &auth, payload.as_bytes())
139 .map_err(|e| format!("ECE encryption failed: {:?}", e))?;
140
141 let body = encrypted.to_vec();
144
145 let salt = body.get(0..16).ok_or("Encrypted data too short")?.to_vec();
147
148 let keyid_len = *body.get(20).ok_or("Missing keyid length")? as usize;
150 let public_key = body.get(21..21 + keyid_len).ok_or("Missing public key")?.to_vec();
151
152 Ok(EncryptedPayload { body, salt, public_key })
153}
154
155fn create_vapid_jwt(endpoint: &str, id_tag: &str, private_key_raw: &str) -> Result<String, String> {
160 use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
161 use p256::pkcs8::EncodePrivateKey;
162 use p256::pkcs8::LineEnding;
163
164 let private_key_bytes = URL_SAFE_NO_PAD
166 .decode(private_key_raw)
167 .map_err(|e| format!("Invalid base64url private key: {}", e))?;
168
169 let secret_key = p256::SecretKey::from_bytes(private_key_bytes.as_slice().into())
171 .map_err(|e| format!("Invalid P-256 private key: {:?}", e))?;
172
173 let pem = secret_key
175 .to_pkcs8_pem(LineEnding::LF)
176 .map_err(|e| format!("Failed to encode private key: {:?}", e))?;
177
178 let url = url::Url::parse(endpoint).map_err(|e| format!("Invalid endpoint URL: {}", e))?;
180 let audience = format!("{}://{}", url.scheme(), url.host_str().unwrap_or(""));
181
182 #[derive(Serialize)]
184 struct VapidClaims {
185 aud: String,
186 exp: u64,
187 sub: String,
188 }
189
190 let exp = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO).as_secs()
191 + 12 * 3600; let claims = VapidClaims { aud: audience, exp, sub: format!("mailto:admin@{}", id_tag) };
194
195 let encoding_key = EncodingKey::from_ec_pem(pem.as_bytes())
197 .map_err(|e| format!("Invalid VAPID private key: {}", e))?;
198
199 let header = Header::new(Algorithm::ES256);
200 encode(&header, &claims, &encoding_key).map_err(|e| format!("JWT encoding failed: {}", e))
201}
202
203async fn send_push_request(
205 endpoint: &str,
206 body: &[u8],
207 _salt: &[u8],
208 _public_key: &[u8],
209 vapid_jwt: &str,
210 vapid_public_key: &str,
211) -> PushResult {
212 let connector = match HttpsConnectorBuilder::new()
214 .with_native_roots()
215 .map_err(|e| format!("TLS error: {}", e))
216 {
217 Ok(c) => c.https_only().enable_http2().build(),
218 Err(e) => return PushResult::PermanentError(e),
219 };
220
221 let client: Client<_, Full<Bytes>> =
222 Client::builder(TokioExecutor::new()).http2_only(true).build(connector);
223
224 let request = match hyper::Request::builder()
227 .method(hyper::Method::POST)
228 .uri(endpoint)
229 .header("Content-Type", "application/octet-stream")
230 .header("Content-Encoding", "aes128gcm")
231 .header("TTL", "86400") .header(
233 "Authorization",
234 format!("vapid t={},k={}", vapid_jwt, vapid_public_key),
235 )
236 .body(Full::new(Bytes::copy_from_slice(body)))
237 {
238 Ok(req) => req,
239 Err(e) => return PushResult::PermanentError(format!("Request build error: {}", e)),
240 };
241
242 match client.request(request).await {
244 Ok(response) => {
245 let status = response.status();
246 if status.is_success() {
247 PushResult::Success
248 } else if status == hyper::StatusCode::GONE || status == hyper::StatusCode::NOT_FOUND {
249 PushResult::SubscriptionGone
251 } else if status.is_client_error() {
252 let body_bytes = response.into_body().collect().await.ok().map(|b| b.to_bytes());
254 let body_str =
255 body_bytes.as_ref().and_then(|b| std::str::from_utf8(b).ok()).unwrap_or("");
256 PushResult::PermanentError(format!("HTTP {}: {}", status, body_str))
257 } else {
258 PushResult::TemporaryError(format!("HTTP {}", status))
260 }
261 }
262 Err(e) => PushResult::TemporaryError(format!("Network error: {}", e)),
263 }
264}
265
266pub async fn send_to_tenant(
270 app: &App,
271 tn_id: TnId,
272 payload: &NotificationPayload,
273) -> ClResult<usize> {
274 let subscriptions = app.meta_adapter.list_push_subscriptions(tn_id).await?;
275 let mut success_count = 0;
276
277 for subscription in subscriptions {
278 let result = send_notification(app, tn_id, &subscription.subscription, payload).await;
279
280 match result {
281 PushResult::Success => {
282 success_count += 1;
283 tracing::debug!(
284 tn_id = %tn_id.0,
285 subscription_id = %subscription.id,
286 "Push notification sent successfully"
287 );
288 }
289 PushResult::SubscriptionGone => {
290 tracing::info!(
292 tn_id = %tn_id.0,
293 subscription_id = %subscription.id,
294 "Deleting invalid push subscription"
295 );
296 let _ = app.meta_adapter.delete_push_subscription(tn_id, subscription.id).await;
297 }
298 PushResult::TemporaryError(e) => {
299 tracing::warn!(
300 tn_id = %tn_id.0,
301 subscription_id = %subscription.id,
302 error = %e,
303 "Temporary push notification error"
304 );
305 }
306 PushResult::PermanentError(e) => {
307 tracing::error!(
308 tn_id = %tn_id.0,
309 subscription_id = %subscription.id,
310 error = %e,
311 "Permanent push notification error"
312 );
313 }
314 }
315 }
316
317 Ok(success_count)
318}
319
320