axum_standardwebhooks/
lib.rs

1//! Integration of the [standardwebhooks](https://crates.io/crates/standardwebhooks) crate with the
2//! [Axum](https://github.com/tokio-rs/axum) web framework.
3//!
4//! This crate provides an extractor for Axum that verifies webhook requests according to the
5//! [Standard Webhooks specification](https://github.com/standard-webhooks/standard-webhooks).
6//!
7//! # Example
8//!
9//! ```rust,no_run
10//! use axum::{Router, routing::post, Json};
11//! use axum_standardwebhooks::{StandardWebhook, SharedWebhook, Webhook};
12//! use serde_json::Value;
13//! use std::sync::Arc;
14//! use axum::extract::FromRef;
15//!
16//! async fn webhook_handler(StandardWebhook(Json(payload)): StandardWebhook<Json<Value>>) -> String {
17//!     // The webhook signature has been verified, and we can safely use the payload
18//!     format!("Received webhook: {}", payload)
19//! }
20//!
21//! #[derive(Clone)]
22//! struct AppState {
23//!     webhook: SharedWebhook,
24//! }
25//!
26//! impl FromRef<AppState> for SharedWebhook {
27//!     fn from_ref(state: &AppState) -> Self {
28//!         state.webhook.clone()
29//!     }
30//! }
31//!
32//! #[tokio::main]
33//! async fn main() {
34//!     let app = Router::new()
35//!         .route("/webhooks", post(webhook_handler))
36//!         .with_state(AppState {
37//!             webhook: SharedWebhook::new(Webhook::new("whsec_C2FVsBQIhrscChlQIMV+b5sSYspob7oD").unwrap()),
38//!         });
39//!
40//!     let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
41//!     axum::serve(listener, app).await.unwrap();
42//! }
43//! ```
44
45#![forbid(unsafe_code)]
46
47use axum::body::Body;
48use axum::extract::FromRef;
49use axum::extract::rejection::FailedToBufferBody;
50use axum::http::StatusCode;
51use axum::{
52    extract::FromRequest,
53    http::{Request, Response},
54    response::IntoResponse,
55};
56use bytes::Bytes;
57pub use standardwebhooks::Webhook;
58pub use standardwebhooks::WebhookError;
59use std::ops::Deref;
60use std::sync::Arc;
61
62/// A thread-safe wrapper around [`Webhook`] to make it shareable between Axum handlers.
63///
64/// This type provides a convenient way to share the webhook verifier across multiple
65/// request handlers without needing to clone the underlying `Webhook` for each request.
66#[derive(Clone)]
67pub struct SharedWebhook(Arc<Webhook>);
68
69impl Deref for SharedWebhook {
70    type Target = Webhook;
71
72    fn deref(&self) -> &Self::Target {
73        &self.0
74    }
75}
76
77impl SharedWebhook {
78    /// Creates a new `SharedWebhook` from a `Webhook`
79    /// for use during verification.
80    ///
81    /// # Arguments
82    ///
83    /// * `webhook` - The `Webhook` to wrap
84    ///
85    /// # Returns
86    ///
87    /// A new `SharedWebhook` wrapping the provided `Webhook`
88    ///
89    /// # Example
90    ///
91    /// ```rust
92    /// use axum_standardwebhooks::{SharedWebhook, Webhook};
93    ///
94    /// let shared_webhook = SharedWebhook::new(Webhook::new("whsec_C2FVsBQIhrscChlQIMV+b5sSYspob7oD").unwrap());
95    /// ```
96    pub fn new(webhook: Webhook) -> Self {
97        Self(Arc::new(webhook))
98    }
99}
100
101/// Represents the ways in which webhook verification and extraction can fail.
102/// Represents the ways in which webhook verification and extraction can fail.
103///
104/// This enum combines errors from body buffering, webhook verification, and
105/// the extraction of the inner type.
106#[derive(Debug)]
107pub enum StandardWebhookRejection<E> {
108    /// The request body could not be buffered.
109    FailedToBufferBody(FailedToBufferBody),
110    /// The webhook signature could not be verified.
111    FailedToVerifyWebhook(WebhookError),
112    /// The request body could not be extracted into the desired type.
113    FailedToExtractBody(E),
114}
115
/// An extractor that verifies a webhook request and extracts the inner payload.
///
/// `StandardWebhook<T>` wraps another extractor `T` and ensures that the webhook
/// signature is valid before `T` runs. The body is buffered once and checked against
/// the signature headers. The same bytes are then handed to `T`, so a handler only
/// ever sees payloads whose signature was verified.
///
/// The inner extractor `T` must implement [`FromRequest`] (i.e. consume the request
/// body). Examples are [`Json`](axum::Json), [`Form`](axum::Form), [`String`] or [`Bytes`].
/// Extractors that only implement [`FromRequestParts`](axum::extract::FromRequestParts),
/// such as [`Query`](axum::extract::Query), do not satisfy this bound; take them as
/// separate handler arguments instead.
///
/// The router state must provide a [`SharedWebhook`] through [`FromRef`].
#[derive(Debug, Clone, Copy, Default)]
#[must_use]
pub struct StandardWebhook<T>(pub T);
127
128impl<E> IntoResponse for StandardWebhookRejection<E>
129where
130    E: IntoResponse,
131{
132    fn into_response(self) -> Response<Body> {
133        match self {
134            Self::FailedToBufferBody(e) => e.into_response(),
135            Self::FailedToVerifyWebhook(e) => {
136                (StatusCode::BAD_REQUEST, e.to_string()).into_response()
137            }
138            Self::FailedToExtractBody(e) => e.into_response(),
139        }
140    }
141}
142
143impl<S, T> FromRequest<S> for StandardWebhook<T>
144where
145    T: FromRequest<S>,
146    S: Send + Sync,
147    SharedWebhook: FromRef<S>,
148{
149    type Rejection = StandardWebhookRejection<T::Rejection>;
150
151    /// Extracts a `StandardWebhook<T>` from the request.
152    ///
153    /// This method:
154    /// 1. Buffers the request body
155    /// 2. Verifies the webhook signature
156    /// 3. Extracts the inner type `T` from the request
157    ///
158    /// # Errors
159    ///
160    /// Returns an error if:
161    /// - The request body could not be buffered
162    /// - The webhook signature is invalid
163    /// - The inner extractor `T` fails
164    async fn from_request(mut req: Request<Body>, state: &S) -> Result<Self, Self::Rejection> {
165        // we want to avoid copying the entire request object,
166        // so we take the original request's body,
167        // create a fake request with the body, perform the buffering,
168        // and then replace the original request's body with the buffered one
169        let body = std::mem::replace(req.body_mut(), Body::empty());
170
171        let fake_req = Request::new(body);
172        let bytes = Bytes::from_request(fake_req, state).await.unwrap();
173
174        let verifier = SharedWebhook::from_ref(state);
175        verifier
176            .verify(&bytes, req.headers())
177            .map_err(StandardWebhookRejection::FailedToVerifyWebhook)?;
178
179        let body = bytes.into();
180        *req.body_mut() = body;
181
182        Ok(StandardWebhook(
183            T::from_request(req, state)
184                .await
185                .map_err(StandardWebhookRejection::FailedToExtractBody)?,
186        ))
187    }
188}
189
190#[cfg(test)]
191mod tests {
192    use super::*;
193    use axum::http::header;
194    use axum::{Json, Router, routing::post};
195    use http_body_util::BodyExt;
196    use serde_json::Value;
197    use standardwebhooks::{HEADER_WEBHOOK_ID, HEADER_WEBHOOK_SIGNATURE, HEADER_WEBHOOK_TIMESTAMP};
198    use std::sync::Arc;
199    use time::OffsetDateTime;
200    use tower::ServiceExt;
201
202    const SECRET: &str = "whsec_C2FVsBQIhrscChlQIMV+b5sSYspob7oD";
203    const MSG_ID: &str = "msg_27UH4WbU6Z5A5EzD8u03UvzRbpk";
204    const PAYLOAD: &[u8] = br#"{"email":"test@example.com","username":"test_user"}"#;
205
206    async fn echo(StandardWebhook(body): StandardWebhook<Json<Value>>) -> impl IntoResponse {
207        body["username"].as_str().unwrap().to_string()
208    }
209
210    async fn body_string(body: Body) -> String {
211        String::from_utf8_lossy(&body.collect().await.unwrap().to_bytes()).into()
212    }
213
214    fn with_headers(msg_id: &str, signature: &str, body: &'static [u8]) -> Request<Body> {
215        Request::builder()
216            .method("POST")
217            .header(HEADER_WEBHOOK_ID, msg_id)
218            .header(HEADER_WEBHOOK_SIGNATURE, signature)
219            .header(
220                HEADER_WEBHOOK_TIMESTAMP,
221                OffsetDateTime::now_utc().unix_timestamp().to_string(),
222            )
223            .header(header::CONTENT_TYPE, "application/json")
224            .body(body.into())
225            .unwrap()
226    }
227
228    fn app() -> Router {
229        Router::new()
230            .route("/", post(echo))
231            .with_state(SharedWebhook(Arc::new(Webhook::new(SECRET).unwrap())))
232    }
233
234    #[tokio::test]
235    async fn header_missing() {
236        let req = Request::builder()
237            .method("POST")
238            .body(Body::empty())
239            .unwrap();
240        let res = app().oneshot(req).await.unwrap();
241        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
242        assert_eq!(body_string(res.into_body()).await, "missing header id");
243    }
244
245    #[tokio::test]
246    async fn valid_signature() {
247        let wh = Webhook::new(SECRET).unwrap();
248        let signature = wh
249            .sign(MSG_ID, OffsetDateTime::now_utc().unix_timestamp(), PAYLOAD)
250            .unwrap();
251
252        let req = with_headers(MSG_ID, &signature, PAYLOAD);
253        let res = app().oneshot(req).await.unwrap();
254        assert_eq!(res.status(), StatusCode::OK);
255        assert_eq!(body_string(res.into_body()).await, "test_user");
256    }
257
258    #[tokio::test]
259    async fn invalid_signature() {
260        let wh = Webhook::new(SECRET).unwrap();
261        let mut signature = wh
262            .sign(MSG_ID, OffsetDateTime::now_utc().unix_timestamp(), PAYLOAD)
263            .unwrap();
264        signature.pop().unwrap();
265
266        let req = with_headers(MSG_ID, &signature, PAYLOAD);
267        let res = app().oneshot(req).await.unwrap();
268        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
269        assert_eq!(body_string(res.into_body()).await, "signature invalid");
270    }
271}