rocket_webhook/
guard.rs

1use std::{
2    marker::PhantomData,
3    time::{SystemTime, UNIX_EPOCH},
4};
5
6use rocket::{
7    Request, async_trait,
8    data::{FromData, Outcome, ToByteUnit},
9    http::{HeaderMap, Status},
10    outcome::try_outcome,
11    serde::{DeserializeOwned, json::serde_json},
12};
13
14use crate::{RocketWebhook, WebhookError, webhooks::Webhook};
15
16/**
17 Data guard to validate and deserialize the JSON body of webhook type `W` into the `T` type.
18 The `W` webhook configuration must be in Rocket state using [RocketWebhook].
19```
20use rocket::{post, serde::{Serialize, Deserialize}};
21use rocket_webhook::{WebhookPayload, webhooks::built_in::{GitHubWebhook}};
22
23/// Payload to deserialize
24#[derive(Debug, Serialize, Deserialize)]
25struct GithubPayload {
26    action: String,
27}
28
29// Use in a route handler as the data guard, passing in the payload and webhook type
30#[post("/api/webhooks/github", data = "<payload>")]
31async fn github_route(
32    payload: WebhookPayload<'_, GithubPayload, GitHubWebhook>,
33) -> &'static str {
34    payload.data; // access the validated webhook payload
35    payload.headers; // access the webhook headers
36
37    "OK"
38}
39```
40*/
41pub struct WebhookPayload<'r, T, W, M = W> {
42    /// The deserialized payload data
43    pub data: T,
44    /// The headers sent with the webhook request
45    pub headers: &'r HeaderMap<'r>,
46    _webhook: PhantomData<W>,
47    _marker: PhantomData<M>,
48}
49
50#[async_trait]
51impl<'r, T, W, M> FromData<'r> for WebhookPayload<'r, T, W, M>
52where
53    T: DeserializeOwned,
54    W: Webhook + Send + Sync + 'static,
55    M: Send + Sync + 'static,
56{
57    type Error = WebhookError;
58
59    async fn from_data(
60        req: &'r Request<'_>,
61        data: rocket::Data<'r>,
62    ) -> Outcome<'r, Self, Self::Error> {
63        let config: &RocketWebhook<W, M> = try_outcome!(get_webhook_from_state(req));
64        let body = data.open(config.max_body_size.bytes());
65        let time_bounds = get_timestamp_bounds(config.timestamp_tolerance);
66        let validated_body =
67            try_outcome!(config.webhook.validate_body(req, body, time_bounds).await);
68
69        match serde_json::from_slice(&validated_body) {
70            Ok(data) => Outcome::Success(Self {
71                data,
72                headers: req.headers(),
73                _webhook: PhantomData,
74                _marker: PhantomData,
75            }),
76            Err(e) => Outcome::Error((Status::BadRequest, WebhookError::Deserialize(e))),
77        }
78    }
79}
80
81/**
82Data guard to validate a webhook and get the raw body.
83The `W` webhook configuration must be in Rocket state using [RocketWebhook].
84```
85use rocket::{post, serde::{Serialize, Deserialize}};
86use rocket_webhook::{WebhookPayloadRaw, webhooks::built_in::{GitHubWebhook}};
87
88
89// Use in a route handler as the data guard, passing in the webhook type
90#[post("/api/webhooks/github", data = "<payload>")]
91async fn github_route(
92    payload: WebhookPayloadRaw<'_, GitHubWebhook>,
93) -> &'static str {
94    payload.data; // access the raw webhook payload (Vec<u8>)
95    payload.headers; // access the webhook headers
96
97    "OK"
98}
99```
100*/
101pub struct WebhookPayloadRaw<'r, W, M = W> {
102    /// The raw payload data
103    pub data: Vec<u8>,
104    /// The headers sent with the webhook request
105    pub headers: &'r HeaderMap<'r>,
106    _webhook: PhantomData<W>,
107    _marker: PhantomData<M>,
108}
109
110#[async_trait]
111impl<'r, W, M> FromData<'r> for WebhookPayloadRaw<'r, W, M>
112where
113    W: Webhook + Send + Sync + 'static,
114    M: Send + Sync + 'static,
115{
116    type Error = WebhookError;
117
118    async fn from_data(
119        req: &'r Request<'_>,
120        data: rocket::Data<'r>,
121    ) -> Outcome<'r, Self, Self::Error> {
122        let config: &RocketWebhook<W, M> = try_outcome!(get_webhook_from_state(req));
123        let body = data.open(config.max_body_size.bytes());
124        let time_bounds = get_timestamp_bounds(config.timestamp_tolerance);
125        let validated_body =
126            try_outcome!(config.webhook.validate_body(req, body, time_bounds).await);
127
128        Outcome::Success(Self {
129            data: validated_body,
130            headers: req.headers(),
131            _webhook: PhantomData,
132            _marker: PhantomData,
133        })
134    }
135}
136
137fn get_webhook_from_state<'r, W, M>(
138    req: &'r Request,
139) -> Outcome<'r, &'r RocketWebhook<W, M>, WebhookError>
140where
141    W: Webhook + Send + Sync + 'static,
142    M: Send + Sync + 'static,
143{
144    match req.rocket().state::<RocketWebhook<W, M>>() {
145        Some(config) => Outcome::Success(config),
146        None => {
147            return Outcome::Error((Status::InternalServerError, WebhookError::NotAttached));
148        }
149    }
150}
151
/// Get the timestamp bounds based on the current unix epoch time in seconds.
///
/// Given `(past_secs, future_secs)`, returns `(lower, upper)` where `lower` is `past_secs`
/// before now and `upper` is `future_secs` after now. Both bounds saturate at the limits of
/// `u32` instead of overflowing, so an overly large tolerance widens the window to `0` or
/// `u32::MAX` rather than wrapping around (or panicking in debug builds).
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch.
fn get_timestamp_bounds((past_secs, future_secs): (u32, u32)) -> (u32, u32) {
    let elapsed_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock is set before the Unix epoch")
        .as_secs();
    // Seconds fit in u32 until 2106; clamp rather than silently truncating after that.
    let unix_time = u32::try_from(elapsed_secs).unwrap_or(u32::MAX);

    let lower_bound = unix_time.saturating_sub(past_secs);
    let upper_bound = unix_time.saturating_add(future_secs);

    (lower_bound, upper_bound)
}