axum_standardwebhooks/
lib.rs1#![forbid(unsafe_code)]
46
47use axum::body::Body;
48use axum::extract::FromRef;
49use axum::extract::rejection::FailedToBufferBody;
50use axum::http::StatusCode;
51use axum::{
52 extract::FromRequest,
53 http::{Request, Response},
54 response::IntoResponse,
55};
56use bytes::Bytes;
57pub use standardwebhooks::Webhook;
58pub use standardwebhooks::WebhookError;
59use std::ops::Deref;
60use std::sync::Arc;
61
62#[derive(Clone)]
67pub struct SharedWebhook(Arc<Webhook>);
68
69impl Deref for SharedWebhook {
70 type Target = Webhook;
71
72 fn deref(&self) -> &Self::Target {
73 &self.0
74 }
75}
76
77impl SharedWebhook {
78 pub fn new(webhook: Webhook) -> Self {
97 Self(Arc::new(webhook))
98 }
99}
100
101#[derive(Debug)]
107pub enum StandardWebhookRejection<E> {
108 FailedToBufferBody(FailedToBufferBody),
110 FailedToVerifyWebhook(WebhookError),
112 FailedToExtractBody(E),
114}
115
/// Extractor that verifies a Standard Webhooks signature, then runs the inner
/// extractor `T` on the verified request and wraps its output.
///
/// A [`SharedWebhook`] must be obtainable from the router state via [`FromRef`].
/// Example: `StandardWebhook(body): StandardWebhook<Json<Value>>`.
#[must_use]
#[derive(Clone, Copy, Debug, Default)]
pub struct StandardWebhook<T>(pub T);
127
128impl<E> IntoResponse for StandardWebhookRejection<E>
129where
130 E: IntoResponse,
131{
132 fn into_response(self) -> Response<Body> {
133 match self {
134 Self::FailedToBufferBody(e) => e.into_response(),
135 Self::FailedToVerifyWebhook(e) => {
136 (StatusCode::BAD_REQUEST, e.to_string()).into_response()
137 }
138 Self::FailedToExtractBody(e) => e.into_response(),
139 }
140 }
141}
142
143impl<S, T> FromRequest<S> for StandardWebhook<T>
144where
145 T: FromRequest<S>,
146 S: Send + Sync,
147 SharedWebhook: FromRef<S>,
148{
149 type Rejection = StandardWebhookRejection<T::Rejection>;
150
151 async fn from_request(mut req: Request<Body>, state: &S) -> Result<Self, Self::Rejection> {
165 let body = std::mem::replace(req.body_mut(), Body::empty());
170
171 let fake_req = Request::new(body);
172 let bytes = Bytes::from_request(fake_req, state).await.unwrap();
173
174 let verifier = SharedWebhook::from_ref(state);
175 verifier
176 .verify(&bytes, req.headers())
177 .map_err(StandardWebhookRejection::FailedToVerifyWebhook)?;
178
179 let body = bytes.into();
180 *req.body_mut() = body;
181
182 Ok(StandardWebhook(
183 T::from_request(req, state)
184 .await
185 .map_err(StandardWebhookRejection::FailedToExtractBody)?,
186 ))
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::header;
    use axum::{Json, Router, routing::post};
    use http_body_util::BodyExt;
    use serde_json::Value;
    use standardwebhooks::{HEADER_WEBHOOK_ID, HEADER_WEBHOOK_SIGNATURE, HEADER_WEBHOOK_TIMESTAMP};
    use time::OffsetDateTime;
    use tower::ServiceExt;

    const SECRET: &str = "whsec_C2FVsBQIhrscChlQIMV+b5sSYspob7oD";
    const MSG_ID: &str = "msg_27UH4WbU6Z5A5EzD8u03UvzRbpk";
    const PAYLOAD: &[u8] = br#"{"email":"test@example.com","username":"test_user"}"#;

    /// Handler under test: echoes the `username` field of the verified JSON payload.
    async fn echo(StandardWebhook(body): StandardWebhook<Json<Value>>) -> impl IntoResponse {
        body["username"].as_str().unwrap().to_string()
    }

    /// Collects a response body into a (lossily decoded) UTF-8 string.
    async fn body_string(body: Body) -> String {
        String::from_utf8_lossy(&body.collect().await.unwrap().to_bytes()).into()
    }

    /// Signs `PAYLOAD` for `MSG_ID` at `timestamp` with the test secret.
    fn sign_payload(timestamp: i64) -> String {
        Webhook::new(SECRET)
            .unwrap()
            .sign(MSG_ID, timestamp, PAYLOAD)
            .unwrap()
    }

    /// Builds a webhook POST request with the given id, signature and timestamp.
    ///
    /// The timestamp is passed in rather than read from the clock again. It must
    /// be the value used for signing, otherwise the test fails spuriously whenever
    /// the clock crosses a second boundary between signing and building.
    fn with_headers(
        msg_id: &str,
        signature: &str,
        timestamp: i64,
        body: &'static [u8],
    ) -> Request<Body> {
        Request::builder()
            .method("POST")
            .header(HEADER_WEBHOOK_ID, msg_id)
            .header(HEADER_WEBHOOK_SIGNATURE, signature)
            .header(HEADER_WEBHOOK_TIMESTAMP, timestamp.to_string())
            .header(header::CONTENT_TYPE, "application/json")
            .body(body.into())
            .unwrap()
    }

    fn app() -> Router {
        Router::new()
            .route("/", post(echo))
            .with_state(SharedWebhook::new(Webhook::new(SECRET).unwrap()))
    }

    #[tokio::test]
    async fn header_missing() {
        let req = Request::builder()
            .method("POST")
            .body(Body::empty())
            .unwrap();
        let res = app().oneshot(req).await.unwrap();
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        assert_eq!(body_string(res.into_body()).await, "missing header id");
    }

    #[tokio::test]
    async fn valid_signature() {
        let timestamp = OffsetDateTime::now_utc().unix_timestamp();
        let signature = sign_payload(timestamp);

        let req = with_headers(MSG_ID, &signature, timestamp, PAYLOAD);
        let res = app().oneshot(req).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(body_string(res.into_body()).await, "test_user");
    }

    #[tokio::test]
    async fn invalid_signature() {
        let timestamp = OffsetDateTime::now_utc().unix_timestamp();
        let mut signature = sign_payload(timestamp);
        // Corrupt the signature by dropping its last character.
        signature.pop().unwrap();

        let req = with_headers(MSG_ID, &signature, timestamp, PAYLOAD);
        let res = app().oneshot(req).await.unwrap();
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        assert_eq!(body_string(res.into_body()).await, "signature invalid");
    }
}