axum_github_webhook_extract/
lib.rs

1//! A library to secure [GitHub Webhooks][github-webhooks] and extract JSON
2//! event payloads in [Axum][axum].
3//!
//! The library provides an [Extractor][axum-extractor], `GithubEvent`, which
//! reads the required [Secret Token][github-secret-token] from
//! [State][axum-state] as a `GithubToken`.
7//!
8//! Usage looks like:
9//! ```
10//! # use axum::response::IntoResponse;
11//! # use axum::routing::post;
12//! # use axum::Router;
13//! # use serde::Deserialize;
14//! # use std::sync::Arc;
15//! use axum_github_webhook_extract::{GithubToken, GithubEvent};
16//!
17//! #[derive(Debug, Deserialize)]
18//! struct Event {
19//!     action: String,
20//! }
21//!
22//! async fn echo(GithubEvent(e): GithubEvent<Event>) -> impl IntoResponse {
23//!     e.action
24//! }
25//!
26//! fn app() -> Router {
27//!     let token = String::from("d4705034dd0777ee9e1e3078a12a06985151b76f");
28//!     Router::new()
29//!         .route("/", post(echo))
30//!         .with_state(GithubToken(Arc::new(token)))
31//! }
32//! ```
33//!
//! You will usually load the token from your environment or configuration.
//! The shape of the event payload is up to you; just make sure the webhook is
//! configured to use the [JSON][github-json] content type.
37//!
38//! [github-webhooks]: https://docs.github.com/en/webhooks-and-events/webhooks/securing-your-webhooks
39//! [axum]: https://docs.rs/axum/latest/axum/
40//! [axum-extractor]: https://docs.rs/axum/latest/axum/#extractors
41//! [axum-state]: https://docs.rs/axum/latest/axum/#sharing-state-with-handlers
42//! [github-secret-token]: https://docs.github.com/en/webhooks-and-events/webhooks/securing-your-webhooks#setting-your-secret-token
43//! [github-json]: https://docs.github.com/en/webhooks-and-events/webhooks/creating-webhooks#content-type
44
45use axum::body::Bytes;
46use axum::extract::{FromRef, FromRequest, Request};
47use axum::http::StatusCode;
48use hmac_sha256::HMAC;
49use serde::de::DeserializeOwned;
50use std::fmt::Display;
51use std::sync::Arc;
52use subtle::ConstantTimeEq;
53
/// Shared secret used to verify GitHub webhook signatures.
///
/// Make it available as router state (anything `GithubToken` can be obtained
/// from via [`FromRef`]) so the [`GithubEvent`] extractor can check the
/// HMAC-SHA256 signature of each delivery.
#[derive(Clone, Debug)]
pub struct GithubToken(pub Arc<String>);
57
/// Extractor that verifies a GitHub webhook delivery and yields its JSON
/// payload deserialized as `T`.
///
/// Requires a [`GithubToken`] reachable from the router state; requests whose
/// signature is missing or invalid are rejected before `T` is deserialized.
#[must_use]
#[derive(Clone, Copy, Debug, Default)]
pub struct GithubEvent<T>(pub T);
62
63fn err(m: impl Display) -> (StatusCode, String) {
64    tracing::error!("{m}");
65    (StatusCode::BAD_REQUEST, m.to_string())
66}
67
68impl<T, S> FromRequest<S> for GithubEvent<T>
69where
70    GithubToken: FromRef<S>,
71    T: DeserializeOwned,
72    S: Send + Sync,
73{
74    type Rejection = (StatusCode, String);
75
76    fn from_request(
77        req: Request,
78        state: &S,
79    ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
80        async {
81            let token = GithubToken::from_ref(state);
82            let signature_sha256 = req
83                .headers()
84                .get("X-Hub-Signature-256")
85                .and_then(|v| v.to_str().ok())
86                .ok_or_else(|| err("signature missing"))?
87                .strip_prefix("sha256=")
88                .ok_or_else(|| err("signature prefix missing"))?;
89            let signature =
90                hex::decode(signature_sha256).map_err(|_| err("signature malformed"))?;
91            let body = Bytes::from_request(req, state)
92                .await
93                .map_err(|_| err("error reading body"))?;
94            let mac = HMAC::mac(&body, token.0.as_bytes());
95            if mac.ct_ne(&signature).into() {
96                return Err(err("signature mismatch"));
97            }
98            let deserializer = &mut serde_json::Deserializer::from_slice(&body);
99            let value = serde_path_to_error::deserialize(deserializer).map_err(err)?;
100            Ok(GithubEvent(value))
101        }
102    }
103}
104
#[cfg(test)]
mod tests {
    use axum::body::Body;
    use axum::extract::Request;
    use axum::http::StatusCode;
    use axum::response::IntoResponse;
    use axum::routing::post;
    use axum::Router;
    use http_body_util::BodyExt;
    use serde::Deserialize;
    use std::sync::Arc;
    use tower::ServiceExt;

    use super::{GithubEvent, GithubToken};

    /// Minimal payload: only the `action` field is required.
    #[derive(Debug, Deserialize)]
    struct Event {
        action: String,
    }

    /// Responds with the `action` of the verified event.
    async fn echo_action(GithubEvent(event): GithubEvent<Event>) -> impl IntoResponse {
        event.action
    }

    /// Single-route router whose webhook secret is `"42"`.
    fn router() -> Router {
        let secret = GithubToken(Arc::new(String::from("42")));
        Router::new().route("/", post(echo_action)).with_state(secret)
    }

    /// Drains a response body and decodes it as UTF-8 (lossily).
    async fn read_body(body: Body) -> String {
        let collected = body.collect().await.unwrap();
        String::from_utf8_lossy(&collected.to_bytes()).into_owned()
    }

    /// POST request with an empty body and the given signature header value.
    fn signed_request(signature: &'static str) -> Request {
        Request::builder()
            .method("POST")
            .header("X-Hub-Signature-256", signature)
            .body(Body::empty())
            .unwrap()
    }

    /// Routes `req` and asserts a 400 response whose body is exactly `expected`.
    async fn assert_rejected(req: Request, expected: &str) {
        let res = router().oneshot(req).await.unwrap();
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        assert_eq!(read_body(res.into_body()).await, expected);
    }

    #[tokio::test]
    async fn signature_missing() {
        let unsigned = Request::builder()
            .method("POST")
            .body(Body::empty())
            .unwrap();
        assert_rejected(unsigned, "signature missing").await;
    }

    #[tokio::test]
    async fn signature_prefix_missing() {
        assert_rejected(signed_request("x"), "signature prefix missing").await;
    }

    #[tokio::test]
    async fn signature_malformed() {
        assert_rejected(signed_request("sha256=x"), "signature malformed").await;
    }

    #[tokio::test]
    async fn signature_mismatch() {
        assert_rejected(signed_request("sha256=01"), "signature mismatch").await;
    }

    #[tokio::test]
    async fn signature_valid() {
        // Signature of the body below under the router's secret "42".
        let signature =
            "sha256=8b99afd7996c3e3c291a0b54399bacb72016bdb088071de42d1d7156a6a4273d";
        let req: Request = Request::builder()
            .method("POST")
            .header("X-Hub-Signature-256", signature)
            .body(r#"{"action":"hello world"}"#.into())
            .unwrap();
        let res = router().oneshot(req).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(read_body(res.into_body()).await, "hello world");
    }
}