1use std::fmt::{Display, Formatter};
2
3use ct_codecs::{Base64UrlSafeNoPadding, Decoder};
4use http::uri::Uri;
5
6use crate::{
7 error::WebPushError,
8 http_ece::{ContentEncoding, HttpEce},
9 vapid::VapidSignature,
10};
11
/// Client-side encryption keys belonging to a push subscription.
///
/// Both values are expected to be URL-safe Base64 without padding:
/// `WebPushMessageBuilder::build` decodes them with `Base64UrlSafeNoPadding`
/// (only when a payload is set) and fails with
/// `WebPushError::InvalidCryptoKeys` if decoding fails.
#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq, Ord, PartialOrd, Default, Hash)]
pub struct SubscriptionKeys {
    /// Subscriber public key; its decoded bytes are passed to `HttpEce::new`.
    pub p256dh: String,
    /// Subscriber authentication secret; its decoded bytes are passed to `HttpEce::new`.
    pub auth: String,
}
20
/// Address and keys of a single push subscription.
///
/// `WebPushMessageBuilder` reads the endpoint to address the message and the
/// keys to encrypt a payload.
#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq, Ord, PartialOrd, Default, Hash)]
pub struct SubscriptionInfo {
    /// Push service URL; `WebPushMessageBuilder::build` parses it into an `http::Uri`.
    pub endpoint: String,
    /// Encryption keys for this subscription.
    pub keys: SubscriptionKeys,
}
32
33impl SubscriptionInfo {
34 pub fn new<S>(endpoint: S, p256dh: S, auth: S) -> SubscriptionInfo
37 where
38 S: Into<String>,
39 {
40 SubscriptionInfo {
41 endpoint: endpoint.into(),
42 keys: SubscriptionKeys {
43 p256dh: p256dh.into(),
44 auth: auth.into(),
45 },
46 }
47 }
48}
49
/// Message body as returned by `HttpEce::encrypt`, which
/// `WebPushMessageBuilder::build` calls when a payload was set.
#[derive(Debug, PartialEq)]
pub struct WebPushPayload {
    /// Body bytes produced by the encryptor.
    pub content: Vec<u8>,
    /// `(name, value)` pairs accompanying the body — presumably HTTP headers
    /// required by the content encoding; TODO confirm against `http_ece`.
    pub crypto_headers: Vec<(&'static str, String)>,
    /// Content encoding associated with `content`.
    pub content_encoding: ContentEncoding,
}
60
/// Urgency level attachable to a push message.
///
/// Serialized in kebab-case (`very-low`, `low`, `normal`, `high`). Variant
/// order defines the derived `Ord` (`VeryLow` < `Low` < `Normal` < `High`),
/// and `Normal` is the `Default`.
#[derive(Debug, Deserialize, Serialize, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Default, Hash)]
#[serde(rename_all = "kebab-case")]
pub enum Urgency {
    VeryLow,
    Low,
    #[default]
    Normal,
    High,
}
70
71impl Display for Urgency {
72 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
73 let str = match self {
74 Urgency::VeryLow => "very-low",
75 Urgency::Low => "low",
76 Urgency::Normal => "normal",
77 Urgency::High => "high",
78 };
79
80 f.write_str(str)
81 }
82}
83
/// A push message ready to be sent, normally produced by
/// `WebPushMessageBuilder::build`.
#[derive(Debug)]
pub struct WebPushMessage {
    /// Subscription endpoint, parsed from `SubscriptionInfo::endpoint` by the builder.
    pub endpoint: Uri,
    /// Time-to-live; the builder defaults it to `2_419_200` (28 days if read as seconds).
    pub ttl: u32,
    /// Urgency, if one was set on the builder.
    pub urgency: Option<Urgency>,
    /// Topic; when built by the builder it is at most 32 characters from the
    /// URL-safe Base64 alphabet.
    pub topic: Option<String>,
    /// Encrypted payload, or `None` when no payload was set on the builder.
    pub payload: Option<WebPushPayload>,
}
99
100struct WebPushPayloadBuilder<'a> {
101 pub content: &'a [u8],
102 pub encoding: ContentEncoding,
103}
104
/// Builder for a [`WebPushMessage`] addressed to one [`SubscriptionInfo`].
///
/// Setters store values unvalidated; `build` parses the endpoint, validates
/// the topic and encrypts the payload, if one was set.
pub struct WebPushMessageBuilder<'a> {
    /// Subscription providing the endpoint and the encryption keys.
    subscription_info: &'a SubscriptionInfo,
    /// Unencrypted payload and its encoding, if one was set.
    payload: Option<WebPushPayloadBuilder<'a>>,
    /// Time-to-live copied into the built message.
    ttl: u32,
    /// Urgency copied into the built message.
    urgency: Option<Urgency>,
    /// Topic as given by the caller; only validated in `build`.
    topic: Option<String>,
    /// Signature handed to `HttpEce::new` when a payload is encrypted.
    vapid_signature: Option<VapidSignature>,
}
114
115impl<'a> WebPushMessageBuilder<'a> {
116 pub fn new(subscription_info: &'a SubscriptionInfo) -> WebPushMessageBuilder<'a> {
121 WebPushMessageBuilder {
122 subscription_info,
123 ttl: 2_419_200,
124 urgency: None,
125 topic: None,
126 payload: None,
127 vapid_signature: None,
128 }
129 }
130
131 pub fn set_ttl(&mut self, ttl: u32) {
135 self.ttl = ttl;
136 }
137
138 pub fn set_urgency(&mut self, urgency: Urgency) {
144 self.urgency = Some(urgency);
145 }
146
147 pub fn set_topic(&mut self, topic: String) {
156 self.topic = Some(topic);
157 }
158
159 pub fn set_vapid_signature(&mut self, vapid_signature: VapidSignature) {
162 self.vapid_signature = Some(vapid_signature);
163 }
164
165 pub fn set_payload(&mut self, encoding: ContentEncoding, content: &'a [u8]) {
170 self.payload = Some(WebPushPayloadBuilder { content, encoding });
171 }
172
173 pub fn build(self) -> Result<WebPushMessage, WebPushError> {
175 let endpoint: Uri = self.subscription_info.endpoint.parse()?;
176 let topic: Option<String> = self
177 .topic
178 .map(|topic| {
179 if topic.len() > 32 {
180 Err(WebPushError::InvalidTopic)
181 } else if topic.chars().all(is_base64url_char) {
182 Ok(topic)
183 } else {
184 Err(WebPushError::InvalidTopic)
185 }
186 })
187 .transpose()?;
188
189 if let Some(payload) = self.payload {
190 let p256dh = Base64UrlSafeNoPadding::decode_to_vec(&self.subscription_info.keys.p256dh, None)
191 .map_err(|_| WebPushError::InvalidCryptoKeys)?;
192 let auth = Base64UrlSafeNoPadding::decode_to_vec(&self.subscription_info.keys.auth, None)
193 .map_err(|_| WebPushError::InvalidCryptoKeys)?;
194
195 let http_ece = HttpEce::new(payload.encoding, &p256dh, &auth, self.vapid_signature);
196
197 Ok(WebPushMessage {
198 endpoint,
199 ttl: self.ttl,
200 urgency: self.urgency,
201 topic,
202 payload: Some(http_ece.encrypt(payload.content)?),
203 })
204 } else {
205 Ok(WebPushMessage {
206 endpoint,
207 ttl: self.ttl,
208 urgency: self.urgency,
209 topic,
210 payload: None,
211 })
212 }
213 }
214}
215
/// Reports whether `c` belongs to the URL- and filename-safe Base64 alphabet
/// (RFC 4648 §5): ASCII letters, ASCII digits, `-` and `_`.
fn is_base64url_char(c: char) -> bool {
    matches!(c, 'A'..='Z' | 'a'..='z' | '0'..='9' | '-' | '_')
}