http_srv/handler/
auth.rs

1use core::str;
2use std::{
3    collections::HashMap,
4    fs::File,
5    io::{BufRead, BufReader},
6    sync::Arc,
7};
8
9use super::RequestHandler;
10use crate::{HttpRequest, Result, err};
11
/// Authentication configuration.
///
/// Helps wrap a [request handler](RequestHandler) behind HTTP `Basic`
/// authentication. Cloning is cheap: the user tables are shared through [`Arc`].
///
/// # Example
/// ```
/// use http::*;
/// use http_srv::handler::*;
///
/// let auth = AuthConfig::of_list(&[("user", "passwd")]);
///
/// let mut handler = Handler::default();
/// let func = |req: &mut HttpRequest| {
///     req.respond_str("Super secret message")
/// };
/// handler.get("/secret", auth.apply(func));
/// ```
#[derive(Clone)]
pub struct AuthConfig {
    /// User name -> plain-text password.
    users: Arc<HashMap<String, String>>,
    /// If non-empty, only these users are let through (they must still be in `users`).
    required_users: Arc<Vec<String>>,
}
33
/// Builder for [`AuthConfig`].
///
/// Obtain one with [`AuthConfig::builder`] or [`Default::default`]; both start
/// with no users and no required users.
#[derive(Default)]
pub struct AuthConfigBuilder {
    /// User name -> plain-text password.
    users: HashMap<String, String>,
    /// Users allowed through; empty means every known user is allowed.
    required_users: Vec<String>,
}
38
39impl AuthConfigBuilder {
40    pub fn require_user(mut self, user: &str) -> Self {
41        self.required_users.push(user.to_owned());
42        self
43    }
44    pub fn build(self) -> AuthConfig {
45        AuthConfig {
46            users: Arc::new(self.users),
47            required_users: Arc::new(self.required_users),
48        }
49    }
50}
51
52impl AuthConfig {
53    #[must_use]
54    pub fn builder() -> AuthConfigBuilder {
55        AuthConfigBuilder {
56            users: HashMap::new(),
57            required_users: Vec::new(),
58        }
59    }
60    pub fn of_file(filename: &str) -> crate::Result<Self> {
61        let f = File::open(filename)?;
62        let f = BufReader::new(f);
63        let mut users = HashMap::new();
64        let mut lines = f.lines();
65        while let Some(Ok(l)) = lines.next() {
66            let mut l = l.split_whitespace();
67            let u = l.next().ok_or("Malformatted file")?.to_owned();
68            let p = l.next().ok_or("Malformatted file")?.to_owned();
69            users.insert(u, p);
70        }
71        users.shrink_to_fit();
72        Ok(Self {
73            users: Arc::new(users),
74            required_users: Arc::new(Vec::new()),
75        })
76    }
77    #[must_use]
78    pub fn of_list(list: &[(&str, &str)]) -> Self {
79        let mut users = HashMap::new();
80        for e in list {
81            users.insert(e.0.to_owned(), e.1.to_owned());
82        }
83        Self {
84            users: Arc::new(users),
85            required_users: Arc::new(Vec::new()),
86        }
87    }
88    pub fn apply<H: RequestHandler>(&self, f: H) -> AuthedRequest<H> {
89        AuthedRequest {
90            f,
91            users: Arc::clone(&self.users),
92            required_users: Arc::clone(&self.required_users),
93        }
94    }
95}
96
/// A request handler wrapped behind authentication; created by
/// [`AuthConfig::apply`].
///
/// The `RequestHandler` bound lives on the trait impl rather than on the
/// struct, so the type itself places no constraints on `H`.
pub struct AuthedRequest<H> {
    /// The wrapped handler, invoked only for authorized requests.
    f: H,
    /// User name -> plain-text password, shared with the originating config.
    users: Arc<HashMap<String, String>>,
    /// If non-empty, only these users are let through.
    required_users: Arc<Vec<String>>,
}
102
103impl<H: RequestHandler> RequestHandler for AuthedRequest<H> {
104    fn handle(&self, req: &mut HttpRequest) -> Result<()> {
105        let Some(auth) = req.header("Authorization") else {
106            req.set_header("WWW-Authenticate", "Basic");
107            return req.unauthorized();
108        };
109        let auth = HttpAuth::parse(auth)?;
110        if auth.check(&self.required_users, &self.users) {
111            self.f.handle(req)
112        } else {
113            req.unauthorized()
114        }
115    }
116}
117
/// Credentials extracted from an `Authorization` request header.
#[derive(Clone, PartialEq, Eq)]
enum HttpAuth {
    /// `Basic` scheme: user name and password, in that order.
    Basic(String, String),
}

// Written by hand instead of derived so that a stray `{:?}` in a log line or
// panic message never prints the plain-text password.
impl std::fmt::Debug for HttpAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            HttpAuth::Basic(user, _) => f
                .debug_tuple("Basic")
                .field(user)
                .field(&"<redacted>")
                .finish(),
        }
    }
}
122
123impl HttpAuth {
124    fn parse(header: &str) -> Result<Self> {
125        let mut auth = header.split_whitespace();
126        let t = auth.next().ok_or("Malfromatted Authentication header")?;
127        let payload = auth.next().ok_or("Malfromatted Authentication header")?;
128
129        match t {
130            "Basic" => parse_basic(payload),
131            _ => err!("Unknown authentication method {t}"),
132        }
133    }
134    fn check(&self, users: &[String], passwds: &HashMap<String, String>) -> bool {
135        match self {
136            HttpAuth::Basic(user, pass) => {
137                if users.is_empty() || users.contains(user) {
138                    if let Some(p) = passwds.get(user).as_ref() {
139                        *p == pass
140                    } else {
141                        false
142                    }
143                } else {
144                    false
145                }
146            }
147        }
148    }
149}
150
151fn parse_basic(payload: &str) -> Result<HttpAuth> {
152    let decoded = base64::decode(payload)?;
153    let decoded = str::from_utf8(&decoded)?;
154    let mut decoded = decoded.splitn(2, ':');
155    let user = decoded.next().unwrap_or("");
156    let passwd = decoded.next().unwrap_or("");
157    let user = url::decode(user)?.into_owned();
158    let passwd = url::decode(passwd)?.into_owned();
159    Ok(HttpAuth::Basic(user, passwd))
160}
161
#[cfg(test)]
mod test {
    #![allow(clippy::expect_used)]
    use super::*;
    use std::collections::HashMap;

    /// Builds a user -> password table from literal pairs.
    fn creds(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(u, p)| (u.to_owned(), p.to_owned()))
            .collect()
    }

    #[test]
    fn test() {
        let auth = HttpAuth::parse("Basic dXNlcjpwYXNzd2Q=").expect("Expected correct parsing");
        assert_eq!(auth, HttpAuth::Basic("user".to_owned(), "passwd".to_owned()));
    }

    #[test]
    fn password_may_contain_colons() {
        // base64("user:pa:ss")
        let auth = HttpAuth::parse("Basic dXNlcjpwYTpzcw==").expect("Expected correct parsing");
        assert_eq!(auth, HttpAuth::Basic("user".to_owned(), "pa:ss".to_owned()));
    }

    #[test]
    fn rejects_malformed_headers() {
        assert!(HttpAuth::parse("").is_err());
        assert!(HttpAuth::parse("Basic").is_err());
        assert!(HttpAuth::parse("Digest abc").is_err());
    }

    #[test]
    fn check_compares_passwords() {
        let passwds = creds(&[("user", "passwd")]);
        let ok = HttpAuth::Basic("user".to_owned(), "passwd".to_owned());
        assert!(ok.check(&[], &passwds));
        for wrong in ["wrong!", "passw", "passwdx", ""] {
            let bad = HttpAuth::Basic("user".to_owned(), wrong.to_owned());
            assert!(!bad.check(&[], &passwds));
        }
        let unknown = HttpAuth::Basic("eve".to_owned(), "passwd".to_owned());
        assert!(!unknown.check(&[], &passwds));
    }

    #[test]
    fn check_respects_required_users() {
        let passwds = creds(&[("alice", "a"), ("bob", "b")]);
        let required = vec!["alice".to_owned()];
        assert!(HttpAuth::Basic("alice".to_owned(), "a".to_owned()).check(&required, &passwds));
        assert!(!HttpAuth::Basic("bob".to_owned(), "b".to_owned()).check(&required, &passwds));
    }
}