1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
//! This crate implements "Generic Event Delivery Using Http Push" (web-push)
//! according to [RFC8030](https://www.rfc-editor.org/rfc/rfc8030).
//!
//! # Example
//!
//! This example shows how to use the [`WebPushBuilder`] to create an HTTP push
//! request to a single hard-coded client.
//!
//! For most projects you will need to implement some form of state management
//! to send messages to all of your clients. You are expected to create one
//! [`WebPushBuilder`] for each client you want to send messages to, but can
//! reuse the same builder for multiple push requests to the same client.
//!
//! Please see the
//! [`/example`](https://github.com/leotaku/web-push-native/tree/master/example)
//! directory on GitHub for a more fully featured example that shows how to
//! set up an [`axum`] web server in combination with this library to expose a
//! simple HTTP API for sending web-push notifications.
//!
//! ```
//! use base64ct::{Base64UrlUnpadded, Encoding};
//! use web_push_native::{
//!     jwt_simple::algorithms::ES256KeyPair, p256::PublicKey, Auth, Error, WebPushBuilder,
//! };
//!
//! // Placeholders for variables provided by individual clients. In most cases,
//! // these will be retrieved in-browser by calling `pushManager.subscribe` on
//! // a service worker registration object.
//! const ENDPOINT: &str = "";
//! const P256DH: &str = "";
//! const AUTH: &str = "";
//!
//! // Placeholder for your private VAPID key. Keep this private and out of your
//! // source tree in real projects!
//! const VAPID: &str = "";
//!
//! async fn push(content: Vec<u8>) -> Result<http::Request<Vec<u8>>, Box<dyn std::error::Error>> {
//!     let key_pair = ES256KeyPair::from_bytes(&Base64UrlUnpadded::decode_vec(VAPID)?)?;
//!     let builder = WebPushBuilder::new(
//!         ENDPOINT.parse()?,
//!         PublicKey::from_sec1_bytes(&Base64UrlUnpadded::decode_vec(P256DH)?)?,
//!         Auth::clone_from_slice(&Base64UrlUnpadded::decode_vec(AUTH)?),
//!     )
//!     .with_vapid(&key_pair, "mailto:john.doe@example.com");
//!
//!     Ok(builder.build(content)?)
//! }
//! ```
//!
//! [`axum`]: https://docs.rs/axum

#[cfg(feature = "serialization")]
mod serde_;
#[cfg(test)]
mod tests;
#[cfg(feature = "vapid")]
mod vapid;

#[cfg(feature = "vapid")]
pub use jwt_simple;
pub use p256;

use aes_gcm::aead::{
    generic_array::{typenum::U16, GenericArray},
    rand_core::RngCore,
    OsRng,
};
use hkdf::Hkdf;
use http::{self, header, Request, Uri};
use p256::elliptic_curve::sec1::ToEncodedPoint;
use sha2::Sha256;
use std::time::Duration;

/// Failure modes that can occur while building or decrypting an HTTP push
/// request.
#[derive(Debug)]
pub enum Error {
    /// The underlying encrypted-content-encoding (ECE) layer reported an error.
    ECE(ece_native::Error),
    /// An HTTP authorization provider (see [`WebPushBuilder::with_vapid`])
    /// failed while contributing its headers.
    Extension(Box<dyn std::error::Error + Send + Sync + 'static>),
}

impl std::error::Error for Error {}

impl std::fmt::Display for Error {
    /// Renders the error as `"<kind>: <inner error>"`, where `<kind>` is
    /// `ece` or `extension`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail) = match self {
            Self::ECE(inner) => ("ece", inner as &dyn std::fmt::Display),
            Self::Extension(inner) => ("extension", inner as &dyn std::fmt::Display),
        };
        write!(f, "{}: {}", label, detail)
    }
}

/// HTTP push authentication secret
///
/// The user agent's 16-byte `auth` secret. During encryption it is used as
/// the HKDF salt when deriving the input keying material.
pub type Auth = GenericArray<u8, U16>;

/// Reusable builder for HTTP push requests
///
/// The type parameter `A` is the HTTP authorization provider. Its
/// [`AddHeaders`] implementation contributes extra request headers in
/// [`WebPushBuilder::build`]. The default `()` adds no headers.
#[derive(Clone, Debug)]
pub struct WebPushBuilder<A = ()> {
    /// Push service URI that generated `POST` requests are addressed to.
    endpoint: Uri,
    /// Lifetime of generated requests; emitted as the `TTL` header in whole
    /// seconds.
    valid_duration: Duration,
    /// The user agent's P-256 public key that payloads are encrypted for.
    ua_public: p256::PublicKey,
    /// The user agent's authentication secret.
    ua_auth: Auth,
    // Authorization provider state. Per the `cfg_attr` below, it is only read
    // when the `vapid` feature is enabled, hence the dead-code allowance.
    #[cfg_attr(not(feature = "vapid"), allow(dead_code))]
    http_auth: A,
}

impl WebPushBuilder {
    /// Creates a new [`WebPushBuilder`] factory for HTTP push requests.
    ///
    /// Requests generated using this factory will have a valid duration of 12
    /// hours and no VAPID signature.
    ///
    /// Most providers accepting HTTP push requests will require a valid VAPID
    /// signature, so you will most likely want to add one using
    /// [`WebPushBuilder::with_vapid`].
    pub fn new(endpoint: Uri, ua_public: p256::PublicKey, ua_auth: Auth) -> Self {
        Self {
            endpoint,
            ua_public,
            ua_auth,
            valid_duration: Duration::from_secs(12 * 60 * 60),
            http_auth: (),
        }
    }

    /// Sets the VAPID signature header for generated HTTP push requests.
    ///
    /// The returned builder keeps this builder's endpoint, user agent keys and
    /// valid duration; only the authorization provider changes.
    #[cfg(feature = "vapid")]
    pub fn with_vapid<'a>(
        self,
        vapid_kp: &'a jwt_simple::algorithms::ES256KeyPair,
        contact: &'a str,
    ) -> WebPushBuilder<vapid::VapidAuthorization<'a>> {
        WebPushBuilder {
            endpoint: self.endpoint,
            valid_duration: self.valid_duration,
            ua_public: self.ua_public,
            ua_auth: self.ua_auth,
            http_auth: vapid::VapidAuthorization::new(vapid_kp, contact),
        }
    }
}

impl<A> WebPushBuilder<A> {
    /// Sets the valid duration for generated HTTP push requests.
    ///
    /// The duration is sent as the `TTL` header, truncated to whole seconds.
    /// This method is available regardless of the authorization provider, so
    /// it may be called before or after [`WebPushBuilder::with_vapid`].
    pub fn with_valid_duration(mut self, valid_duration: Duration) -> Self {
        self.valid_duration = valid_duration;
        self
    }
}

/// Hook through which an HTTP authorization provider adds its own headers to
/// an outgoing push request.
///
/// [`WebPushBuilder::build`] calls [`AddHeaders::add_headers`] after it has
/// set the standard push headers. Any error is converted into
/// [`Error::Extension`].
#[doc(hidden)]
pub trait AddHeaders: Sized {
    /// Error produced when the provider cannot add its headers.
    type Error: Into<Box<dyn std::error::Error + Sync + Send + 'static>>;

    /// Returns `builder` extended with this provider's headers.
    ///
    /// `this` is the full push builder whose `http_auth` field is `Self`.
    fn add_headers(
        this: &WebPushBuilder<Self>,
        builder: http::request::Builder,
    ) -> Result<http::request::Builder, Self::Error>;
}

/// The "no authorization" provider: requests pass through unchanged.
impl AddHeaders for () {
    // Nothing is ever added, so this provider can never fail.
    type Error = std::convert::Infallible;

    fn add_headers(
        _: &WebPushBuilder<Self>,
        request: http::request::Builder,
    ) -> Result<http::request::Builder, Self::Error> {
        Result::Ok(request)
    }
}

impl<A: AddHeaders> WebPushBuilder<A> {
    /// Generates a new HTTP push request according to the
    /// specifications of the builder.
    pub fn build<T: Into<Vec<u8>>>(&self, body: T) -> Result<Request<Vec<u8>>, Error> {
        let body = body.into();

        let payload = encrypt(body, &self.ua_public, &self.ua_auth)?;
        let builder = Request::builder()
            .uri(self.endpoint.clone())
            .method(http::method::Method::POST)
            .header("TTL", self.valid_duration.as_secs())
            .header(header::CONTENT_ENCODING, "aes128gcm")
            .header(header::CONTENT_TYPE, "application/octet-stream")
            .header(header::CONTENT_LENGTH, payload.len());

        let builder =
            AddHeaders::add_headers(self, builder).map_err(|it| Error::Extension(it.into()))?;

        Ok(builder
            .body(payload)
            .expect("builder arguments are always well-defined"))
    }
}

/// Lower-level encryption used for HTTP push request content
pub fn encrypt(
    message: Vec<u8>,
    ua_public: &p256::PublicKey,
    ua_auth: &Auth,
) -> Result<Vec<u8>, Error> {
    let mut salt = [0u8; 16];
    OsRng.fill_bytes(&mut salt);
    let as_secret = p256::SecretKey::random(&mut OsRng);
    encrypt_predictably(salt, message, &as_secret, ua_public, ua_auth).map_err(Error::ECE)
}

/// Deterministic core of [`encrypt`]. The salt and the application server's
/// secret key are supplied by the caller, so identical inputs always produce
/// identical output.
///
/// The whole message is emitted as a single `aes128gcm` record. The
/// application server's uncompressed public key is used as the keyid.
fn encrypt_predictably(
    salt: [u8; 16],
    message: Vec<u8>,
    as_secret: &p256::SecretKey,
    ua_public: &p256::PublicKey,
    ua_auth: &Auth,
) -> Result<Vec<u8>, ece_native::Error> {
    // One record holds the plaintext plus a 1-byte padding delimiter and the
    // 16-byte AES-GCM authentication tag (RFC 8188), hence the extra 17.
    let record_length = (message.len() + 17)
        .try_into()
        .map_err(|_| ece_native::Error::RecordLengthInvalid)?;

    let as_public = as_secret.public_key();
    let keyid = as_public.as_affine().to_encoded_point(false);

    let shared_secret =
        p256::ecdh::diffie_hellman(as_secret.to_nonzero_scalar(), ua_public.as_affine());
    let input_keying_material = compute_ikm(ua_auth, &shared_secret, ua_public, &as_public);

    ece_native::encrypt(
        input_keying_material,
        salt,
        keyid,
        Some(message).into_iter(),
        record_length,
    )
}

/// Lower-level decryption used for HTTP push request content
pub fn decrypt(
    encrypted_message: Vec<u8>,
    as_secret: &p256::SecretKey,
    ua_auth: &Auth,
) -> Result<Vec<u8>, Error> {
    let keyid = view_keyid(&encrypted_message).map_err(Error::ECE)?;
    let ua_public = p256::PublicKey::from_sec1_bytes(keyid)
        .map_err(|_| ece_native::Error::Aes128Gcm)
        .map_err(Error::ECE)?;
    let shared = p256::ecdh::diffie_hellman(as_secret.to_nonzero_scalar(), ua_public.as_affine());

    let ikm = compute_ikm(ua_auth, &shared, &as_secret.public_key(), &ua_public);

    ece_native::decrypt(ikm, encrypted_message).map_err(Error::ECE)
}

/// Derives the 32-byte input keying material that combines the ECDH shared
/// secret with the user agent's authentication secret (RFC 8291, Section 3.3).
///
/// HKDF-SHA-256 is run with `auth` as the salt and the raw shared secret as
/// the input key. The info string is
/// `"WebPush: info" || 0x00 || ua_public || as_public`, where both keys are
/// in uncompressed SEC1 form.
fn compute_ikm(
    auth: &Auth,
    shared: &p256::ecdh::SharedSecret,
    ua_public: &p256::PublicKey,
    as_public: &p256::PublicKey,
) -> [u8; 32] {
    let ua_point = ua_public.as_affine().to_encoded_point(false);
    let as_point = as_public.as_affine().to_encoded_point(false);
    let key_info = [
        &b"WebPush: info\0"[..],
        ua_point.as_bytes(),
        as_point.as_bytes(),
    ]
    .concat();

    let hkdf = Hkdf::<Sha256>::new(Some(auth), shared.raw_secret_bytes().as_ref());
    let mut ikm = [0u8; 32];
    hkdf.expand(&key_info, &mut ikm)
        .expect("okm length is always 32 bytes, cannot be too large");
    ikm
}

fn view_keyid(encrypted_message: &[u8]) -> Result<&[u8], ece_native::Error> {
    if encrypted_message.len() < 21 {
        return Err(ece_native::Error::HeaderLengthInvalid);
    }

    let idlen: usize = encrypted_message[20].into();
    if encrypted_message[21..].len() < idlen {
        return Err(ece_native::Error::KeyIdLengthInvalid);
    }

    Ok(&encrypted_message[21..21 + idlen])
}