1use std::collections::BTreeMap;
4use std::fmt;
5use std::time::{Duration, SystemTime, UNIX_EPOCH};
6
7use base64::Engine;
8use hmac::{Hmac, KeyInit, Mac};
9use secrecy::{ExposeSecret, SecretString};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use sha2::Sha256;
13use subtle::ConstantTimeEq;
14
15use crate::error::{Error, Result, WebhookVerificationError};
16
/// HMAC instantiated with SHA-256; used by `compute_signature` to sign webhook payloads.
type HmacSha256 = Hmac<Sha256>;
18
/// Read-only access to request headers, abstracting over the container type
/// (`http::HeaderMap`, `BTreeMap<String, String>`, or a fixed array of pairs)
/// that the caller received the webhook headers in.
pub trait HeaderLookup {
    /// Returns an owned copy of the value stored under `header_name`, or
    /// `None` when this container has no usable value for it.
    fn get_header(&self, header_name: &str) -> Option<String>;
}
24
25impl HeaderLookup for http::HeaderMap {
26 fn get_header(&self, name: &str) -> Option<String> {
27 self.get(name)
28 .and_then(|value| value.to_str().ok())
29 .map(str::to_owned)
30 }
31}
32
33impl HeaderLookup for BTreeMap<String, String> {
34 fn get_header(&self, name: &str) -> Option<String> {
35 self.get(name).cloned()
36 }
37}
38
39impl<const N: usize> HeaderLookup for [(&str, &str); N] {
40 fn get_header(&self, name: &str) -> Option<String> {
41 self.iter()
42 .find_map(|(key, value)| (*key == name).then(|| (*value).to_owned()))
43 }
44}
45
/// Envelope of a webhook event, decoded from a verified JSON payload.
///
/// Top-level JSON fields not covered by the named fields are collected into
/// `extra` (via `#[serde(flatten)]`), so they are not lost on deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebhookEvent {
    /// Event identifier (e.g. `evt_...`).
    pub id: String,
    /// Object type tag, if the payload carries one.
    pub object: Option<String>,
    /// Creation time as an integer; presumably Unix seconds — TODO confirm against the API docs.
    pub created_at: i64,
    /// Event type name (e.g. `response.completed`), carried under the JSON key `type`.
    #[serde(rename = "type")]
    pub event_type: String,
    /// Event-specific body, kept as untyped JSON.
    pub data: Value,
    /// All remaining top-level fields not matched above.
    #[serde(flatten)]
    pub extra: BTreeMap<String, Value>,
}
64
/// Verifies webhook signatures and decodes verified payloads.
///
/// `Debug` is implemented by hand so the stored secret is shown only as
/// `<redacted>`.
#[derive(Clone)]
pub struct WebhookVerifier {
    /// Default secret, used when a call does not pass one explicitly.
    secret: Option<SecretString>,
}
70
71impl fmt::Debug for WebhookVerifier {
72 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73 f.debug_struct("WebhookVerifier")
74 .field("secret", &self.secret.as_ref().map(|_| "<redacted>"))
75 .finish()
76 }
77}
78
79impl WebhookVerifier {
80 pub fn new(secret: Option<SecretString>) -> Self {
82 Self { secret }
83 }
84
85 pub fn verify_signature<H>(
91 &self,
92 payload: &str,
93 headers: &H,
94 secret: Option<&str>,
95 tolerance: Duration,
96 ) -> Result<()>
97 where
98 H: HeaderLookup,
99 {
100 let secret = secret
101 .map(str::to_owned)
102 .or_else(|| {
103 self.secret
104 .as_ref()
105 .map(|value| value.expose_secret().to_owned())
106 })
107 .ok_or_else(|| {
108 Error::WebhookVerification(WebhookVerificationError::new("Webhook secret 未配置"))
109 })?;
110
111 let signature_header = required_header(headers, "webhook-signature")?;
112 let timestamp = required_header(headers, "webhook-timestamp")?;
113 let webhook_id = required_header(headers, "webhook-id")?;
114
115 let timestamp = timestamp.parse::<u64>().map_err(|_| {
116 Error::WebhookVerification(WebhookVerificationError::new(
117 "Invalid webhook timestamp format",
118 ))
119 })?;
120 validate_timestamp(timestamp, tolerance)?;
121
122 let signed_payload = format!("{webhook_id}.{timestamp}.{payload}");
123 let expected = compute_signature(&secret, signed_payload.as_bytes())?;
124
125 let valid = signature_header.split(' ').any(|part| {
126 let signature = part.strip_prefix("v1,").unwrap_or(part);
127 base64::engine::general_purpose::STANDARD
128 .decode(signature)
129 .ok()
130 .is_some_and(|candidate| candidate.ct_eq(&expected).into())
131 });
132
133 if !valid {
134 return Err(Error::WebhookVerification(WebhookVerificationError::new(
135 "The given webhook signature does not match the expected signature",
136 )));
137 }
138
139 Ok(())
140 }
141
142 pub fn unwrap<H, T>(
148 &self,
149 payload: &str,
150 headers: &H,
151 secret: Option<&str>,
152 tolerance: Duration,
153 ) -> Result<T>
154 where
155 H: HeaderLookup,
156 T: serde::de::DeserializeOwned,
157 {
158 self.verify_signature(payload, headers, secret, tolerance)?;
159 serde_json::from_str(payload).map_err(|error| {
160 Error::Serialization(crate::SerializationError::new(format!(
161 "Webhook 负载解析失败: {error}"
162 )))
163 })
164 }
165}
166
167fn required_header<H>(headers: &H, name: &str) -> Result<String>
168where
169 H: HeaderLookup,
170{
171 headers.get_header(name).ok_or_else(|| {
172 Error::WebhookVerification(WebhookVerificationError::new(format!(
173 "Missing required header: {name}"
174 )))
175 })
176}
177
178fn validate_timestamp(timestamp: u64, tolerance: Duration) -> Result<()> {
179 let now = SystemTime::now()
180 .duration_since(UNIX_EPOCH)
181 .map_err(|error| {
182 Error::WebhookVerification(WebhookVerificationError::new(error.to_string()))
183 })?
184 .as_secs();
185 let tolerance = tolerance.as_secs();
186
187 if now.saturating_sub(timestamp) > tolerance {
188 return Err(Error::WebhookVerification(WebhookVerificationError::new(
189 "Webhook timestamp is too old",
190 )));
191 }
192
193 if timestamp > now.saturating_add(tolerance) {
194 return Err(Error::WebhookVerification(WebhookVerificationError::new(
195 "Webhook timestamp is too new",
196 )));
197 }
198
199 Ok(())
200}
201
202fn compute_signature(secret: &str, payload: &[u8]) -> Result<Vec<u8>> {
203 let key = if let Some(secret) = secret.strip_prefix("whsec_") {
204 base64::engine::general_purpose::STANDARD
205 .decode(secret)
206 .map_err(|error| {
207 Error::WebhookVerification(WebhookVerificationError::new(format!(
208 "Webhook secret 非法: {error}"
209 )))
210 })?
211 } else {
212 secret.as_bytes().to_vec()
213 };
214
215 let mut mac = HmacSha256::new_from_slice(&key).map_err(|error| {
216 Error::WebhookVerification(WebhookVerificationError::new(format!(
217 "创建 HMAC 失败: {error}"
218 )))
219 })?;
220 mac.update(payload);
221 Ok(mac.finalize().into_bytes().to_vec())
222}
223
#[cfg(test)]
mod tests {
    use super::WebhookVerifier;
    use std::collections::BTreeMap;
    use std::time::Duration;

    /// Tolerance wide enough that the fixed fixture timestamp never expires.
    ///
    /// A finite window (previously ten years) around the 2025 fixture timestamp
    /// would make the positive tests start failing once the wall clock passes
    /// the end of that window.
    const UNBOUNDED_TOLERANCE: Duration = Duration::MAX;

    const VALID_SIGNATURE: &str = "v1,gUAg4R2hWouRZqRQG4uJypNS8YK885G838+EHb4nKBY=";
    const TIMESTAMP: &str = "1750861210";
    const WEBHOOK_ID: &str = "wh_685c059ae39c8190af8c71ed1022a24d";

    fn test_payload() -> &'static str {
        r#"{"id": "evt_685c059ae3a481909bdc86819b066fb6", "object": "event", "created_at": 1750861210, "type": "response.completed", "data": {"id": "resp_123"}}"#
    }

    fn test_headers() -> BTreeMap<String, String> {
        BTreeMap::from([
            ("webhook-signature".into(), VALID_SIGNATURE.into()),
            ("webhook-timestamp".into(), TIMESTAMP.into()),
            ("webhook-id".into(), WEBHOOK_ID.into()),
        ])
    }

    fn test_secret() -> &'static str {
        "whsec_RdvaYFYUXuIFuEbvZHwMfYFhUf7aMYjYcmM24+Aj40c="
    }

    #[test]
    fn test_should_verify_valid_signature() {
        let verifier = WebhookVerifier::new(None);
        verifier
            .verify_signature(
                test_payload(),
                &test_headers(),
                Some(test_secret()),
                UNBOUNDED_TOLERANCE,
            )
            .unwrap();
    }

    #[test]
    fn test_should_verify_array_headers() {
        let verifier = WebhookVerifier::new(None);
        let headers = [
            ("webhook-signature", VALID_SIGNATURE),
            ("webhook-timestamp", TIMESTAMP),
            ("webhook-id", WEBHOOK_ID),
        ];
        verifier
            .verify_signature(test_payload(), &headers, Some(test_secret()), UNBOUNDED_TOLERANCE)
            .unwrap();
    }

    #[test]
    fn test_should_accept_any_matching_signature_in_header() {
        let verifier = WebhookVerifier::new(None);
        let mut headers = test_headers();
        headers.insert(
            "webhook-signature".into(),
            format!("v1,AAAA {VALID_SIGNATURE}"),
        );
        verifier
            .verify_signature(test_payload(), &headers, Some(test_secret()), UNBOUNDED_TOLERANCE)
            .unwrap();
    }

    #[test]
    fn test_should_reject_invalid_signature() {
        let verifier = WebhookVerifier::new(None);
        let error = verifier
            .verify_signature(
                test_payload(),
                &test_headers(),
                Some("whsec_Zm9v"),
                UNBOUNDED_TOLERANCE,
            )
            .unwrap_err();
        assert!(matches!(error, crate::Error::WebhookVerification(_)));
    }

    #[test]
    fn test_should_reject_missing_header() {
        let verifier = WebhookVerifier::new(None);
        let mut headers = test_headers();
        headers.remove("webhook-id");
        let error = verifier
            .verify_signature(test_payload(), &headers, Some(test_secret()), UNBOUNDED_TOLERANCE)
            .unwrap_err();
        assert!(matches!(error, crate::Error::WebhookVerification(_)));
    }

    #[test]
    fn test_should_reject_stale_timestamp() {
        // The fixture timestamp is from 2025, far outside a five-minute window.
        let verifier = WebhookVerifier::new(None);
        let error = verifier
            .verify_signature(
                test_payload(),
                &test_headers(),
                Some(test_secret()),
                Duration::from_secs(300),
            )
            .unwrap_err();
        assert!(matches!(error, crate::Error::WebhookVerification(_)));
    }

    #[test]
    fn test_should_reject_when_no_secret_available() {
        let verifier = WebhookVerifier::new(None);
        let error = verifier
            .verify_signature(test_payload(), &test_headers(), None, UNBOUNDED_TOLERANCE)
            .unwrap_err();
        assert!(matches!(error, crate::Error::WebhookVerification(_)));
    }

    #[test]
    fn test_should_unwrap_payload_after_verification() {
        let verifier = WebhookVerifier::new(None);
        let event: crate::webhooks::WebhookEvent = verifier
            .unwrap(
                test_payload(),
                &test_headers(),
                Some(test_secret()),
                UNBOUNDED_TOLERANCE,
            )
            .unwrap();
        assert_eq!(event.id, "evt_685c059ae3a481909bdc86819b066fb6");
        assert_eq!(event.event_type, "response.completed");
    }
}