poem_middleware/
param_verify.rs

1use base64::{engine::general_purpose, Engine};
2use chrono::Utc;
3use hmac::{Hmac, Mac};
4use poem::{middleware::Middleware, Endpoint, IntoResponse, Request, Response, Result};
5
6use sha2::Sha256;
7
/// HMAC keyed hash over SHA-256; the endpoint uses it to check `apiSig`, the tests to produce it.
type HmacSha256 = Hmac<Sha256>;
9
/// Poem middleware that rejects requests whose `apiSig` header is not a valid
/// HMAC-SHA256 signature of the request, or whose `timestamp` header is too far
/// from the server clock.
///
/// Note: the derived [`Default`] uses an empty secret key and a zero-second window;
/// construct real instances with [`SignVerifyMiddleware::new`].
#[derive(Default)]
pub struct SignVerifyMiddleware {
    /// Shared secret used as the HMAC key.
    secret_key: String,
    /// Maximum accepted distance, in seconds, between the request `timestamp` and now.
    allowed_time_window: i64,
}
16
17impl SignVerifyMiddleware {
18    #[must_use]
19    pub fn new(secret: &str, allowed_time: i64) -> SignVerifyMiddleware {
20        Self {
21            secret_key: secret.to_string(),
22            allowed_time_window: allowed_time,
23        }
24    }
25}
26
27impl<E: Endpoint> Middleware<E> for SignVerifyMiddleware {
28    type Output = SignVerifyEndpoint<E>;
29
30    fn transform(&self, ep: E) -> Self::Output {
31        SignVerifyEndpoint {
32            ep,
33            secret_key: self.secret_key.clone(),
34            allowed_time_window: self.allowed_time_window,
35        }
36    }
37}
38
/// Endpoint produced by [`SignVerifyMiddleware`]: verifies the request signature
/// and timestamp, then forwards the request to the wrapped endpoint `ep`.
pub struct SignVerifyEndpoint<E> {
    /// The wrapped endpoint that receives verified requests.
    ep: E,
    /// Shared secret used as the HMAC key.
    secret_key: String,
    /// Maximum accepted distance, in seconds, between the request `timestamp` and now.
    allowed_time_window: i64,
}
45
46impl<E: Endpoint> Endpoint for SignVerifyEndpoint<E> {
47    type Output = Response;
48
49    async fn call(&self, mut req: Request) -> Result<Self::Output> {
50        let sign = req
51            .header("apiSig")
52            .ok_or_else(|| {
53                poem::Error::from_string(
54                    "missing header apiSig",
55                    poem::http::StatusCode::BAD_REQUEST,
56                )
57            })?
58            .to_string();
59
60        let timestamp = req
61            .header("timestamp")
62            .ok_or_else(|| {
63                poem::Error::from_string(
64                    "missing header timestamp",
65                    poem::http::StatusCode::BAD_REQUEST,
66                )
67            })?
68            .parse::<i64>()
69            .map_err(|_| {
70                poem::Error::from_string(
71                    "timestamp parse error",
72                    poem::http::StatusCode::BAD_REQUEST,
73                )
74            })?;
75        let now = Utc::now().naive_utc().and_utc().timestamp();
76        if (timestamp - now).abs() > self.allowed_time_window {
77            return Err(poem::Error::from_string(
78                "request timeout",
79                poem::http::StatusCode::UNAUTHORIZED,
80            ));
81        }
82
83        let uri = req.uri().clone();
84
85        let method = req.method().clone();
86        let mut mac = HmacSha256::new_from_slice(self.secret_key.as_bytes())
87            .expect("HMAC can take key of any size");
88        let mut string_to_sign = String::new();
89        string_to_sign.push_str(&uri.to_string().split('?').last().unwrap());
90
91        let body = req.take_body().into_bytes().await?;
92        let body_str = String::from_utf8(body.to_vec())
93            .map_err(|_| {
94                poem::Error::from_string("body parse error", poem::http::StatusCode::BAD_REQUEST)
95            })?
96            .clone();
97
98        if method != poem::http::Method::GET {
99            string_to_sign.push_str(&body_str);
100        }
101
102        mac.update(string_to_sign.as_bytes());
103
104        let sign_decode = general_purpose::STANDARD
105            .decode(sign.as_bytes())
106            .map_err(|_| {
107                poem::Error::from_string(
108                    "base64 decode signature error",
109                    poem::http::StatusCode::BAD_REQUEST,
110                )
111            })
112            .unwrap();
113        let flag = mac.verify_slice(&sign_decode[..]).is_ok();
114        if !flag {
115            return Err(poem::Error::from_string(
116                "api signature verify error",
117                poem::http::StatusCode::UNAUTHORIZED,
118            ));
119        }
120        req.set_body(body);
121
122        let response = self.ep.call(req).await?.into_response();
123        Ok(response)
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::{HmacSha256, SignVerifyMiddleware};
    use base64::{engine::general_purpose, Engine};
    use chrono::Utc;
    use hmac::Mac;
    use poem::{endpoint::make_sync, http::StatusCode, test::TestClient, EndpointExt};

    /// Secret shared by the signing helpers and the middleware under test.
    const SECRET_KEY: &str = "your_secret_key";

    /// Allowed clock skew, in seconds, for the end-to-end tests.
    const WINDOW: i64 = 20;

    /// Signature of the query string `address=init&linkType=0` under `SECRET_KEY`.
    const QUERY_SIG: &str = "kEU67gzX2pYgGlhsHXDxg0YtM7z8YYG6cQI8rl22eF4=";

    #[test]
    fn test_encode() {
        let mut mac = HmacSha256::new_from_slice(SECRET_KEY.as_bytes())
            .expect("HMAC can take key of any size");
        mac.update(b"address=init&linkType=0");
        let result = general_purpose::STANDARD.encode(mac.finalize().into_bytes());
        assert_eq!(QUERY_SIG, result);
    }

    #[test]
    fn test_decode() {
        let input = "OWvqzTbt3GhtPZUIQs9Z8g6KS/FroM7a4EUVWocFWP4=";
        let decode_bytes = general_purpose::STANDARD.decode(input.as_bytes()).unwrap();
        let mut mac = HmacSha256::new_from_slice(SECRET_KEY.as_bytes())
            .expect("HMAC can take key of any size");
        mac.update(b"/api/available-code?address=init&linkType=0");
        assert!(mac.verify_slice(&decode_bytes).is_ok());
    }

    #[tokio::test]
    async fn test_check() {
        let ep = make_sync(|_| "hello").with(SignVerifyMiddleware::new(SECRET_KEY, WINDOW));
        let cli = TestClient::new(ep);

        let now = Utc::now().timestamp();
        let resp = cli
            .get("/api/available-code")
            .query("address", &"init")
            .query("linkType", &0)
            .header("apiSig", QUERY_SIG)
            .header("timestamp", now)
            .send()
            .await;

        resp.assert_status_is_ok();
    }

    #[tokio::test]
    async fn test_missing_signature_header_is_bad_request() {
        let ep = make_sync(|_| "hello").with(SignVerifyMiddleware::new(SECRET_KEY, WINDOW));
        let cli = TestClient::new(ep);

        let resp = cli
            .get("/api/available-code")
            .header("timestamp", Utc::now().timestamp())
            .send()
            .await;

        resp.assert_status(StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_stale_timestamp_is_unauthorized() {
        let ep = make_sync(|_| "hello").with(SignVerifyMiddleware::new(SECRET_KEY, WINDOW));
        let cli = TestClient::new(ep);

        let stale = Utc::now().timestamp() - 10 * WINDOW;
        let resp = cli
            .get("/api/available-code")
            .query("address", &"init")
            .query("linkType", &0)
            .header("apiSig", QUERY_SIG)
            .header("timestamp", stale)
            .send()
            .await;

        resp.assert_status(StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_wrong_signature_is_unauthorized() {
        let ep = make_sync(|_| "hello").with(SignVerifyMiddleware::new(SECRET_KEY, WINDOW));
        let cli = TestClient::new(ep);

        // Valid base64, but not the HMAC of this request.
        let resp = cli
            .get("/api/available-code")
            .query("address", &"init")
            .query("linkType", &0)
            .header("apiSig", "AAAA")
            .header("timestamp", Utc::now().timestamp())
            .send()
            .await;

        resp.assert_status(StatusCode::UNAUTHORIZED);
    }
}