1use std::{
2 marker::PhantomData,
3 time::{SystemTime, UNIX_EPOCH},
4};
5
6use rocket::{
7 Request, async_trait,
8 data::{FromData, Outcome, ToByteUnit},
9 http::{HeaderMap, Status},
10 outcome::try_outcome,
11 serde::{DeserializeOwned, json::serde_json},
12};
13
14use crate::{RocketWebhook, WebhookError, webhooks::Webhook};
15
/// Request-body data guard that validates an incoming webhook request and deserializes
/// the validated body as JSON.
///
/// Type parameters:
/// - `T`: the type the validated body is deserialized into with `serde_json`.
/// - `W`: the webhook implementation whose `validate_body` checks the request.
/// - `M`: part of the managed-state type `RocketWebhook<W, M>` that the guard looks up;
///   defaults to `W`.
pub struct WebhookPayload<'r, T, W, M = W> {
    /// The request body after webhook validation, deserialized as JSON into `T`.
    pub data: T,
    /// Headers of the request this payload was read from.
    pub headers: &'r HeaderMap<'r>,
    // Zero-sized markers tying the payload to its webhook type and state marker.
    _webhook: PhantomData<W>,
    _marker: PhantomData<M>,
}
49
50#[async_trait]
51impl<'r, T, W, M> FromData<'r> for WebhookPayload<'r, T, W, M>
52where
53 T: DeserializeOwned,
54 W: Webhook + Send + Sync + 'static,
55 M: Send + Sync + 'static,
56{
57 type Error = WebhookError;
58
59 async fn from_data(
60 req: &'r Request<'_>,
61 data: rocket::Data<'r>,
62 ) -> Outcome<'r, Self, Self::Error> {
63 let config: &RocketWebhook<W, M> = try_outcome!(get_webhook_from_state(req));
64 let body = data.open(config.max_body_size.bytes());
65 let time_bounds = get_timestamp_bounds(config.timestamp_tolerance);
66 let validated_body =
67 try_outcome!(config.webhook.validate_body(req, body, time_bounds).await);
68
69 match serde_json::from_slice(&validated_body) {
70 Ok(data) => Outcome::Success(Self {
71 data,
72 headers: req.headers(),
73 _webhook: PhantomData,
74 _marker: PhantomData,
75 }),
76 Err(e) => Outcome::Error((Status::BadRequest, WebhookError::Deserialize(e))),
77 }
78 }
79}
80
/// Request-body data guard that validates an incoming webhook request and exposes the
/// validated body as raw bytes, without deserializing it.
///
/// Type parameters:
/// - `W`: the webhook implementation whose `validate_body` checks the request.
/// - `M`: part of the managed-state type `RocketWebhook<W, M>` that the guard looks up;
///   defaults to `W`.
pub struct WebhookPayloadRaw<'r, W, M = W> {
    /// The request body bytes as returned by webhook validation.
    pub data: Vec<u8>,
    /// Headers of the request this payload was read from.
    pub headers: &'r HeaderMap<'r>,
    // Zero-sized markers tying the payload to its webhook type and state marker.
    _webhook: PhantomData<W>,
    _marker: PhantomData<M>,
}
109
110#[async_trait]
111impl<'r, W, M> FromData<'r> for WebhookPayloadRaw<'r, W, M>
112where
113 W: Webhook + Send + Sync + 'static,
114 M: Send + Sync + 'static,
115{
116 type Error = WebhookError;
117
118 async fn from_data(
119 req: &'r Request<'_>,
120 data: rocket::Data<'r>,
121 ) -> Outcome<'r, Self, Self::Error> {
122 let config: &RocketWebhook<W, M> = try_outcome!(get_webhook_from_state(req));
123 let body = data.open(config.max_body_size.bytes());
124 let time_bounds = get_timestamp_bounds(config.timestamp_tolerance);
125 let validated_body =
126 try_outcome!(config.webhook.validate_body(req, body, time_bounds).await);
127
128 Outcome::Success(Self {
129 data: validated_body,
130 headers: req.headers(),
131 _webhook: PhantomData,
132 _marker: PhantomData,
133 })
134 }
135}
136
137fn get_webhook_from_state<'r, W, M>(
138 req: &'r Request,
139) -> Outcome<'r, &'r RocketWebhook<W, M>, WebhookError>
140where
141 W: Webhook + Send + Sync + 'static,
142 M: Send + Sync + 'static,
143{
144 match req.rocket().state::<RocketWebhook<W, M>>() {
145 Some(config) => Outcome::Success(config),
146 None => {
147 return Outcome::Error((Status::InternalServerError, WebhookError::NotAttached));
148 }
149 }
150}
151
/// Computes the `(lower, upper)` bounds for acceptable webhook timestamps, in Unix seconds.
///
/// `past_secs` and `future_secs` are how far before and after the current system time a
/// timestamp may lie. Both bounds saturate at `0` and `u32::MAX` instead of overflowing,
/// so very large tolerances (e.g. `u32::MAX`) are safe.
///
/// If the system clock reports a time before the Unix epoch, "now" is treated as `0`.
/// If it is beyond what `u32` can hold, "now" is clamped to `u32::MAX`.
fn get_timestamp_bounds((past_secs, future_secs): (u32, u32)) -> (u32, u32) {
    let now_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or(0);
    // Clamp instead of `as u32`, which would silently wrap once the time exceeds u32::MAX.
    let now = u32::try_from(now_secs).unwrap_or(u32::MAX);
    timestamp_bounds_at(now, past_secs, future_secs)
}

/// Returns `(now - past_secs, now + future_secs)`, saturating at `0` and `u32::MAX`.
fn timestamp_bounds_at(now: u32, past_secs: u32, future_secs: u32) -> (u32, u32) {
    (now.saturating_sub(past_secs), now.saturating_add(future_secs))
}