use std::{
marker::PhantomData,
time::{SystemTime, UNIX_EPOCH},
};
use rocket::{
Request, async_trait,
data::{FromData, Outcome, ToByteUnit},
http::{HeaderMap, Status},
outcome::try_outcome,
serde::{DeserializeOwned, json::serde_json},
};
use crate::{RocketWebhook, WebhookError, webhooks::Webhook};
pub struct WebhookPayload<'r, T, W, M = W> {
pub data: T,
pub headers: &'r HeaderMap<'r>,
_webhook: PhantomData<W>,
_marker: PhantomData<M>,
}
#[async_trait]
impl<'r, T, W, M> FromData<'r> for WebhookPayload<'r, T, W, M>
where
T: DeserializeOwned,
W: Webhook + Send + Sync + 'static,
M: Send + Sync + 'static,
{
type Error = WebhookError;
async fn from_data(
req: &'r Request<'_>,
data: rocket::Data<'r>,
) -> Outcome<'r, Self, Self::Error> {
let config: &RocketWebhook<W, M> = try_outcome!(get_webhook_from_state(req));
let body = data.open(config.max_body_size.bytes());
let time_bounds = get_timestamp_bounds(config.timestamp_tolerance);
let validated_body =
try_outcome!(config.webhook.validate_body(req, body, time_bounds).await);
match serde_json::from_slice(&validated_body) {
Ok(data) => Outcome::Success(Self {
data,
headers: req.headers(),
_webhook: PhantomData,
_marker: PhantomData,
}),
Err(e) => Outcome::Error((Status::BadRequest, WebhookError::Deserialize(e))),
}
}
}
/// Rocket data guard that validates an incoming webhook request and keeps its body as raw bytes.
///
/// Performs the same validation as [`WebhookPayload`] but skips JSON deserialization: `data`
/// holds exactly the bytes returned by the webhook's `validate_body`. `M` is the extra marker
/// type used in the `RocketWebhook<W, M>` state lookup (it defaults to `W`).
pub struct WebhookPayloadRaw<'r, W, M = W> {
    /// The validated request body, unparsed.
    pub data: Vec<u8>,
    /// Headers of the request this payload arrived with.
    pub headers: &'r HeaderMap<'r>,
    _webhook: PhantomData<W>,
    _marker: PhantomData<M>,
}

#[async_trait]
impl<'r, W, M> FromData<'r> for WebhookPayloadRaw<'r, W, M>
where
    W: Webhook + Send + Sync + 'static,
    M: Send + Sync + 'static,
{
    type Error = WebhookError;

    /// Opens the body with the configured size limit and validates it against the current
    /// timestamp window, returning the validated bytes unchanged.
    ///
    /// Any failure from `get_webhook_from_state` or `validate_body` is propagated as-is.
    async fn from_data(
        req: &'r Request<'_>,
        data: rocket::Data<'r>,
    ) -> Outcome<'r, Self, Self::Error> {
        let webhook_config: &RocketWebhook<W, M> = try_outcome!(get_webhook_from_state(req));
        let size_limit = webhook_config.max_body_size.bytes();
        let stream = data.open(size_limit);
        let window = get_timestamp_bounds(webhook_config.timestamp_tolerance);
        let body_bytes =
            try_outcome!(webhook_config.webhook.validate_body(req, stream, window).await);

        let headers = req.headers();
        Outcome::Success(Self {
            data: body_bytes,
            headers,
            _webhook: PhantomData,
            _marker: PhantomData,
        })
    }
}
fn get_webhook_from_state<'r, W, M>(
req: &'r Request,
) -> Outcome<'r, &'r RocketWebhook<W, M>, WebhookError>
where
W: Webhook + Send + Sync + 'static,
M: Send + Sync + 'static,
{
match req.rocket().state::<RocketWebhook<W, M>>() {
Some(config) => Outcome::Success(config),
None => {
return Outcome::Error((Status::InternalServerError, WebhookError::NotAttached));
}
}
}
/// Computes the `(lower_bound, upper_bound)` window of acceptable timestamps, in Unix seconds.
///
/// `past_secs` is how far before the current time a timestamp may lie, and `future_secs` how
/// far after it. The result is `(now - past_secs, now + future_secs)`, with both ends clamped to
/// the `u32` range instead of wrapping or panicking on overflow.
fn get_timestamp_bounds((past_secs, future_secs): (u32, u32)) -> (u32, u32) {
    // A system clock set before the Unix epoch is treated as time 0 instead of panicking
    // inside a request guard.
    let elapsed_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    // Clamp rather than truncate: a bare `as u32` would silently wrap after the year 2106.
    let unix_time = elapsed_secs.min(u64::from(u32::MAX)) as u32;
    // Saturating arithmetic keeps large tolerances (e.g. `u32::MAX` meaning "unbounded")
    // from overflowing, which previously panicked in debug builds and wrapped in release.
    let lower_bound = unix_time.saturating_sub(past_secs);
    let upper_bound = unix_time.saturating_add(future_secs);
    (lower_bound, upper_bound)
}