axum_github_webhook_extract/
lib.rs1use axum::body::Bytes;
46use axum::extract::{FromRef, FromRequest, Request};
47use axum::http::StatusCode;
48use hmac_sha256::HMAC;
49use serde::de::DeserializeOwned;
50use std::fmt::Display;
51use std::sync::Arc;
52use subtle::ConstantTimeEq;
53
/// Webhook secret used to verify the `X-Hub-Signature-256` header of incoming requests.
///
/// It must be obtainable from the router state via [`FromRef`] for the
/// [`GithubEvent`] extractor to be usable.
///
/// `Debug` is implemented by hand and redacts the secret, so debug-printing the
/// application state cannot leak the key into logs.
#[derive(Clone)]
pub struct GithubToken(pub Arc<String>);

impl std::fmt::Debug for GithubToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the key itself; a derived impl would write it out verbatim.
        f.debug_tuple("GithubToken")
            .field(&format_args!("<redacted>"))
            .finish()
    }
}
57
/// Extractor producing a webhook payload of type `T` whose signature was verified.
///
/// Destructure it directly in a handler argument, for example
/// `GithubEvent(event): GithubEvent<MyEvent>`.
#[must_use]
#[derive(Clone, Copy, Debug, Default)]
pub struct GithubEvent<T>(pub T);
62
63fn err(m: impl Display) -> (StatusCode, String) {
64 tracing::error!("{m}");
65 (StatusCode::BAD_REQUEST, m.to_string())
66}
67
68impl<T, S> FromRequest<S> for GithubEvent<T>
69where
70 GithubToken: FromRef<S>,
71 T: DeserializeOwned,
72 S: Send + Sync,
73{
74 type Rejection = (StatusCode, String);
75
76 fn from_request(
77 req: Request,
78 state: &S,
79 ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
80 async {
81 let token = GithubToken::from_ref(state);
82 let signature_sha256 = req
83 .headers()
84 .get("X-Hub-Signature-256")
85 .and_then(|v| v.to_str().ok())
86 .ok_or_else(|| err("signature missing"))?
87 .strip_prefix("sha256=")
88 .ok_or_else(|| err("signature prefix missing"))?;
89 let signature =
90 hex::decode(signature_sha256).map_err(|_| err("signature malformed"))?;
91 let body = Bytes::from_request(req, state)
92 .await
93 .map_err(|_| err("error reading body"))?;
94 let mac = HMAC::mac(&body, token.0.as_bytes());
95 if mac.ct_ne(&signature).into() {
96 return Err(err("signature mismatch"));
97 }
98 let deserializer = &mut serde_json::Deserializer::from_slice(&body);
99 let value = serde_path_to_error::deserialize(deserializer).map_err(err)?;
100 Ok(GithubEvent(value))
101 }
102 }
103}
104
#[cfg(test)]
mod tests {
    use axum::body::Body;
    use axum::extract::Request;
    use axum::http::StatusCode;
    use axum::response::IntoResponse;
    use axum::routing::post;
    use axum::Router;
    use hmac_sha256::HMAC;
    use http_body_util::BodyExt;
    use serde::Deserialize;
    use std::sync::Arc;
    use tower::ServiceExt;

    use super::{GithubEvent, GithubToken};

    /// Webhook secret shared by the test router and the signing helper.
    const SECRET: &str = "42";

    /// Minimal payload; the handler echoes `action` back as the response body.
    #[derive(Debug, Deserialize)]
    struct Event {
        action: String,
    }

    async fn echo(GithubEvent(e): GithubEvent<Event>) -> impl IntoResponse {
        e.action
    }

    fn app() -> Router {
        Router::new()
            .route("/", post(echo))
            .with_state(GithubToken(Arc::new(String::from(SECRET))))
    }

    async fn body_string(body: Body) -> String {
        String::from_utf8_lossy(&body.collect().await.unwrap().to_bytes()).into()
    }

    /// Hex HMAC-SHA256 of `payload` keyed with `SECRET`, as GitHub would send it.
    fn signature_for(payload: &str) -> String {
        hex::encode(HMAC::mac(payload.as_bytes(), SECRET.as_bytes()))
    }

    fn with_header(v: &'static str) -> Request {
        Request::builder()
            .method("POST")
            .header("X-Hub-Signature-256", v)
            .body(Body::empty())
            .unwrap()
    }

    /// POST whose body is `body` but whose signature is computed over `signed_over`.
    fn signed(signed_over: &str, body: &'static str) -> Request {
        Request::builder()
            .method("POST")
            .header(
                "X-Hub-Signature-256",
                format!("sha256={}", signature_for(signed_over)),
            )
            .body(body.into())
            .unwrap()
    }

    #[tokio::test]
    async fn signature_missing() {
        let req = Request::builder()
            .method("POST")
            .body(Body::empty())
            .unwrap();
        let res = app().oneshot(req).await.unwrap();
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        assert_eq!(body_string(res.into_body()).await, "signature missing");
    }

    #[tokio::test]
    async fn signature_prefix_missing() {
        let res = app().oneshot(with_header("x")).await.unwrap();
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        assert_eq!(
            body_string(res.into_body()).await,
            "signature prefix missing"
        );
    }

    #[tokio::test]
    async fn signature_malformed() {
        let res = app().oneshot(with_header("sha256=x")).await.unwrap();
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        assert_eq!(body_string(res.into_body()).await, "signature malformed");
    }

    #[tokio::test]
    async fn signature_mismatch() {
        let res = app().oneshot(with_header("sha256=01")).await.unwrap();
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        assert_eq!(body_string(res.into_body()).await, "signature mismatch");
    }

    /// A full-length, well-formed signature over a different payload must be
    /// rejected; this exercises the equal-length constant-time comparison.
    #[tokio::test]
    async fn signature_mismatch_same_length() {
        let req = signed(r#"{"action":"a"}"#, r#"{"action":"b"}"#);
        let res = app().oneshot(req).await.unwrap();
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        assert_eq!(body_string(res.into_body()).await, "signature mismatch");
    }

    #[tokio::test]
    async fn signature_valid() {
        let req: Request = Request::builder()
            .method("POST")
            .header(
                "X-Hub-Signature-256",
                "sha256=8b99afd7996c3e3c291a0b54399bacb72016bdb088071de42d1d7156a6a4273d",
            )
            .body(r#"{"action":"hello world"}"#.into())
            .unwrap();
        let res = app().oneshot(req).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(body_string(res.into_body()).await, "hello world");
    }

    /// The signing helper must agree with the known-good fixture above.
    #[test]
    fn computed_signature_matches_fixture() {
        assert_eq!(
            signature_for(r#"{"action":"hello world"}"#),
            "8b99afd7996c3e3c291a0b54399bacb72016bdb088071de42d1d7156a6a4273d"
        );
    }

    /// A correctly signed body that does not match `Event` is rejected with the
    /// serde_path_to_error message, which names the offending field.
    #[tokio::test]
    async fn payload_invalid() {
        let payload = r#"{"action":1}"#;
        let res = app().oneshot(signed(payload, payload)).await.unwrap();
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        assert!(body_string(res.into_body()).await.starts_with("action: "));
    }
}