// spa_rs/auth.rs

use std::{collections::VecDeque, fmt::Display, future::Future, pin::Pin, sync::Arc};

use anyhow::Result;
use async_trait::async_trait;
use axum::{
    extract::Request,
    http::StatusCode,
    response::{IntoResponse, Response},
};
use axum_help::filter::{drain_body, AsyncPredicate};
use headers::{authorization::Basic, Authorization, HeaderMapExt};
use parking_lot::Mutex;

use self::digest::unauthorized;

15#[async_trait]
16pub trait AuthCheckPredicate {
17    type CheckInfo: Clone + Send + Sync + 'static;
18
19    async fn check(
20        &self,
21        username: impl Into<String> + Send,
22        password: impl Into<String> + Send,
23    ) -> Result<Self::CheckInfo>;
24
25    fn username(&self) -> &str;
26    fn password(&self) -> &str;
27}
28
29#[derive(Clone)]
30pub struct AsyncBasicAuth<T>(T, String)
31where
32    T: AuthCheckPredicate + Clone + Send;
33
34impl<T> AsyncBasicAuth<T>
35where
36    T: AuthCheckPredicate + Clone + Send,
37{
38    pub fn new(p: T) -> Self {
39        Self(p, "Need basic authenticate".to_string())
40    }
41
42    pub fn err_msg(mut self, msg: impl Into<String>) -> Self {
43        self.1 = msg.into();
44        self
45    }
46}
47
48impl<T> AsyncPredicate<Request> for AsyncBasicAuth<T>
49where
50    T: AuthCheckPredicate + Clone + Send + Sync + 'static,
51{
52    type Request = Request;
53    type Response = Response;
54    type Future = Pin<Box<dyn Future<Output = Result<Self::Request, Self::Response>> + Send>>;
55
56    fn check(&mut self, mut request: Request) -> Self::Future {
57        let mut err = self.1.clone();
58        let auth = self.0.clone();
59        Box::pin(async move {
60            if let Some(authorization) = request.headers().typed_get::<Authorization<Basic>>() {
61                match auth
62                    .check(authorization.username(), authorization.password())
63                    .await
64                {
65                    Err(e) => err = format!("check authorization error: {:?}", e),
66                    Ok(ci) => {
67                        request.extensions_mut().insert(ci);
68                        return Ok(request);
69                    }
70                }
71            }
72
73            drain_body(request).await;
74            Err((
75                StatusCode::UNAUTHORIZED,
76                [("WWW-Authenticate", "Basic"); 1],
77                err,
78            )
79                .into_response())
80        })
81    }
82}
83
84#[derive(Clone)]
85pub struct AsyncDigestAuth<T>
86where
87    T: AuthCheckPredicate + Clone + Send,
88{
89    inner: T,
90    err: String,
91    srv_name: String,
92    nonces: Arc<Mutex<VecDeque<(String, String)>>>,
93}
94
95impl<T> AsyncDigestAuth<T>
96where
97    T: AuthCheckPredicate + Clone + Send,
98{
99    pub fn new(p: T) -> Self {
100        Self {
101            inner: p,
102            srv_name: env!("CARGO_PKG_NAME").to_owned(),
103            err: "Need digest authenticate".to_string(),
104            nonces: Arc::new(Mutex::new(VecDeque::new())),
105        }
106    }
107
108    pub fn srv_name(mut self, name: impl Into<String>) -> Self {
109        self.srv_name = name.into();
110        self
111    }
112
113    pub fn err_msg(mut self, msg: impl Into<String>) -> Self {
114        self.err = msg.into();
115        self
116    }
117}
118
119impl<T> AsyncPredicate<Request> for AsyncDigestAuth<T>
120where
121    T: AuthCheckPredicate + Clone + Send + Sync + 'static,
122{
123    type Request = Request;
124    type Response = Response;
125    type Future = Pin<Box<dyn Future<Output = Result<Self::Request, Self::Response>> + Send>>;
126
127    fn check(&mut self, request: Request) -> Self::Future {
128        let err = self.err.clone();
129        let inner = self.inner.clone();
130        let srv_name = self.srv_name.clone();
131        let nonces = self.nonces.clone();
132        Box::pin(async move {
133            if let Some(auth_header) = request.headers().get("Authorization") {
134                let auth = digest::Authorization::from_header(
135                    auth_header.to_str().map_err(bad_request)?,
136                )
137                .map_err(bad_request)?;
138
139                return auth.check(
140                    inner.username(),
141                    inner.password(),
142                    nonces,
143                    request,
144                    srv_name,
145                );
146            }
147
148            drain_body(request).await;
149            Err(unauthorized(nonces, err, srv_name))
150        })
151    }
152}
153
154fn bad_request(e: impl Display) -> Response {
155    (
156        StatusCode::BAD_REQUEST,
157        format!("Bad request in header Authorization: {}", e),
158    )
159        .into_response()
160}
161
162mod digest {
163    use anyhow::{anyhow, bail, Result};
164    use axum::{
165        extract::Request,
166        http::StatusCode,
167        response::{IntoResponse, Response},
168    };
169    use parking_lot::Mutex;
170    use rand::{distributions::Alphanumeric, thread_rng, Rng};
171    use std::{collections::VecDeque, fmt::Debug, sync::Arc};
172
173    #[derive(Default, Debug)]
174    pub(super) struct Authorization {
175        pub(super) username: String,
176        pub(super) realm: String,
177        pub(super) nonce: String,
178        pub(super) uri: String,
179        pub(super) qop: String,
180        pub(super) nc: String,
181        pub(super) cnonce: String,
182        pub(super) response: String,
183        pub(super) opaque: String,
184    }
185
186    impl Authorization {
187        pub(super) fn check(
188            &self,
189            username: impl AsRef<str>,
190            password: impl AsRef<str>,
191            nonces: Arc<Mutex<VecDeque<(String, String)>>>,
192            request: Request,
193            srv_name: impl AsRef<str>,
194        ) -> Result<Request, Response> {
195            let mut found_nonce = false;
196            {
197                let mut nonce_list = nonces.lock();
198                let mut index = nonce_list.len().saturating_sub(1);
199
200                for (nonce, opaque) in nonce_list.iter().rev() {
201                    if nonce == &self.nonce || opaque == &self.opaque {
202                        found_nonce = true;
203                        nonce_list.remove(index);
204                        break;
205                    }
206
207                    index = index.saturating_sub(1);
208                }
209            }
210
211            if !found_nonce {
212                return Err(unauthorized(nonces, "invalid nonce or opaque", srv_name));
213            }
214
215            log::debug!("digest request: {:?}", request);
216            let ha1 = md5::compute(format!(
217                "{}:{}:{}",
218                username.as_ref(),
219                self.realm,
220                password.as_ref()
221            ));
222            let ha2 = md5::compute(format!("{}:{}", request.method(), self.uri));
223            let password = md5::compute(format!(
224                "{:x}:{}:{}:{}:{}:{:x}",
225                ha1, self.nonce, self.nc, self.cnonce, self.qop, ha2
226            ));
227
228            if format!("{:x}", password) != self.response {
229                return Err(unauthorized(
230                    nonces,
231                    "invalid username or password",
232                    srv_name,
233                ));
234            }
235
236            Ok(request)
237        }
238
239        const DIGEST_MARK: &'static str = "Digest";
240        pub(super) fn from_header(auth: impl AsRef<str>) -> Result<Self> {
241            let auth = auth.as_ref();
242            let (mark, content) = auth.split_at(Self::DIGEST_MARK.len());
243            let content = content.trim();
244            if mark != Self::DIGEST_MARK {
245                bail!("only support digest authorization");
246            }
247
248            let mut result = Authorization::default();
249            for c in content.split(',') {
250                let c = c.trim();
251                let (k, v) = c
252                    .split_once('=')
253                    .ok_or_else(|| anyhow!("invalid part of authorization: {}", c))?;
254                let v = v.trim_matches('"');
255                match k {
256                    "username" => result.username = v.to_string(),
257                    "realm" => result.realm = v.to_string(),
258                    "nonce" => result.nonce = v.to_string(),
259                    "uri" => result.uri = v.to_string(),
260                    "qop" => result.qop = v.to_string(),
261                    "nc" => result.nc = v.to_string(),
262                    "cnonce" => result.cnonce = v.to_string(),
263                    "response" => result.response = v.to_string(),
264                    "opaque" => result.opaque = v.to_string(),
265                    _ => {
266                        log::warn!("unknown authorization part: {}", c);
267                        continue;
268                    }
269                }
270            }
271
272            log::debug!("digest auth: {:?}", result);
273            Ok(result)
274        }
275    }
276
277    pub(super) fn unauthorized(
278        nonces: Arc<Mutex<VecDeque<(String, String)>>>,
279        msg: impl Into<String>,
280        srv_name: impl AsRef<str>,
281    ) -> Response {
282        let realm = format!("Login to {}", srv_name.as_ref());
283        let nonce = rand_string(32);
284        let opaque = rand_string(32);
285
286        let www_authenticate = format!(
287            r#"Digest realm="{}",qop="auth",nonce="{}",opaque="{}""#,
288            realm, nonce, opaque
289        );
290
291        {
292            let mut nonce_list = nonces.lock();
293            while nonce_list.len() >= 256 {
294                nonce_list.pop_front();
295            }
296
297            nonce_list.push_back((nonce, opaque));
298        }
299
300        (
301            StatusCode::UNAUTHORIZED,
302            [("WWW-Authenticate", www_authenticate); 1],
303            msg.into(),
304        )
305            .into_response()
306    }
307
308    fn rand_string(count: usize) -> String {
309        thread_rng()
310            .sample_iter(Alphanumeric)
311            .take(count)
312            .map(char::from)
313            .collect()
314    }
315}