1use core::str;
2use std::{
3 collections::HashMap,
4 fs::File,
5 io::{BufRead, BufReader},
6 sync::Arc,
7};
8
9use super::RequestHandler;
10use crate::{HttpRequest, Result, err};
11
/// Shared Basic-authentication configuration.
///
/// Holds the credential table and an optional allow-list of users. Both live
/// behind [`Arc`]s so that [`AuthConfig::apply`] can hand the same data to many
/// wrapped handlers without copying it. For the same reason, cloning an
/// `AuthConfig` is cheap: it only bumps two reference counts.
///
/// `Debug` is intentionally not derived, because the table stores passwords verbatim.
#[derive(Clone)]
pub struct AuthConfig {
    /// Map from user name to the password expected for that user.
    users: Arc<HashMap<String, String>>,
    /// If non-empty, only these users may authenticate. If empty, any user
    /// present in `users` is accepted.
    required_users: Arc<Vec<String>>,
}
33
/// Incremental builder for [`AuthConfig`], usually obtained from
/// [`AuthConfig::builder`].
///
/// `Default` yields the empty builder: no users and no required users.
#[derive(Default)]
pub struct AuthConfigBuilder {
    /// Credential table under construction (user name -> password).
    users: HashMap<String, String>,
    /// Allow-list under construction. An empty list means any known user is accepted.
    required_users: Vec<String>,
}
38
39impl AuthConfigBuilder {
40 pub fn require_user(mut self, user: &str) -> Self {
41 self.required_users.push(user.to_owned());
42 self
43 }
44 pub fn build(self) -> AuthConfig {
45 AuthConfig {
46 users: Arc::new(self.users),
47 required_users: Arc::new(self.required_users),
48 }
49 }
50}
51
52impl AuthConfig {
53 #[must_use]
54 pub fn builder() -> AuthConfigBuilder {
55 AuthConfigBuilder {
56 users: HashMap::new(),
57 required_users: Vec::new(),
58 }
59 }
60 pub fn of_file(filename: &str) -> crate::Result<Self> {
61 let f = File::open(filename)?;
62 let f = BufReader::new(f);
63 let mut users = HashMap::new();
64 let mut lines = f.lines();
65 while let Some(Ok(l)) = lines.next() {
66 let mut l = l.split_whitespace();
67 let u = l.next().ok_or("Malformatted file")?.to_owned();
68 let p = l.next().ok_or("Malformatted file")?.to_owned();
69 users.insert(u, p);
70 }
71 users.shrink_to_fit();
72 Ok(Self {
73 users: Arc::new(users),
74 required_users: Arc::new(Vec::new()),
75 })
76 }
77 #[must_use]
78 pub fn of_list(list: &[(&str, &str)]) -> Self {
79 let mut users = HashMap::new();
80 for e in list {
81 users.insert(e.0.to_owned(), e.1.to_owned());
82 }
83 Self {
84 users: Arc::new(users),
85 required_users: Arc::new(Vec::new()),
86 }
87 }
88 pub fn apply<H: RequestHandler>(&self, f: H) -> AuthedRequest<H> {
89 AuthedRequest {
90 f,
91 users: Arc::clone(&self.users),
92 required_users: Arc::clone(&self.required_users),
93 }
94 }
95}
96
/// Handler wrapper produced by [`AuthConfig::apply`].
///
/// It forwards a request to `f` only when the request carries acceptable
/// Basic credentials. The `RequestHandler` bound is required only by the
/// trait impl, not by the data itself, so the struct definition leaves `H`
/// unconstrained.
pub struct AuthedRequest<H> {
    /// The wrapped handler that serves authenticated requests.
    f: H,
    /// Credential table (user name -> password), shared with the originating config.
    users: Arc<HashMap<String, String>>,
    /// Allow-list shared with the originating config. If empty, any known user is accepted.
    required_users: Arc<Vec<String>>,
}
102
103impl<H: RequestHandler> RequestHandler for AuthedRequest<H> {
104 fn handle(&self, req: &mut HttpRequest) -> Result<()> {
105 let Some(auth) = req.header("Authorization") else {
106 req.set_header("WWW-Authenticate", "Basic");
107 return req.unauthorized();
108 };
109 let auth = HttpAuth::parse(auth)?;
110 if auth.check(&self.required_users, &self.users) {
111 self.f.handle(req)
112 } else {
113 req.unauthorized()
114 }
115 }
116}
117
/// Credentials extracted from an `Authorization` request header.
///
/// `Debug` is implemented by hand so that formatting a value never prints the
/// password, for example in a log line or a failed assertion.
#[derive(Clone, PartialEq, Eq)]
enum HttpAuth {
    /// `Basic` scheme, holding the decoded `(user, password)`.
    Basic(String, String),
}

impl std::fmt::Debug for HttpAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            HttpAuth::Basic(user, _) => f
                .debug_tuple("Basic")
                .field(user)
                .field(&format_args!("<redacted>"))
                .finish(),
        }
    }
}
122
123impl HttpAuth {
124 fn parse(header: &str) -> Result<Self> {
125 let mut auth = header.split_whitespace();
126 let t = auth.next().ok_or("Malfromatted Authentication header")?;
127 let payload = auth.next().ok_or("Malfromatted Authentication header")?;
128
129 match t {
130 "Basic" => parse_basic(payload),
131 _ => err!("Unknown authentication method {t}"),
132 }
133 }
134 fn check(&self, users: &[String], passwds: &HashMap<String, String>) -> bool {
135 match self {
136 HttpAuth::Basic(user, pass) => {
137 if users.is_empty() || users.contains(user) {
138 if let Some(p) = passwds.get(user).as_ref() {
139 *p == pass
140 } else {
141 false
142 }
143 } else {
144 false
145 }
146 }
147 }
148 }
149}
150
151fn parse_basic(payload: &str) -> Result<HttpAuth> {
152 let decoded = base64::decode(payload)?;
153 let decoded = str::from_utf8(&decoded)?;
154 let mut decoded = decoded.splitn(2, ':');
155 let user = decoded.next().unwrap_or("");
156 let passwd = decoded.next().unwrap_or("");
157 let user = url::decode(user)?.into_owned();
158 let passwd = url::decode(passwd)?.into_owned();
159 Ok(HttpAuth::Basic(user, passwd))
160}
161
#[cfg(test)]
mod test {
    #![allow(clippy::expect_used)]
    use super::*;

    /// Credential table with a single known user, `user` / `passwd`.
    fn table() -> HashMap<String, String> {
        let mut passwds = HashMap::new();
        passwds.insert("user".to_owned(), "passwd".to_owned());
        passwds
    }

    #[test]
    fn parses_basic_credentials() {
        // "dXNlcjpwYXNzd2Q=" is base64 for "user:passwd".
        let auth = HttpAuth::parse("Basic dXNlcjpwYXNzd2Q=").expect("Expected correct parsing");
        match auth {
            HttpAuth::Basic(user, passwd) => {
                assert_eq!(user, "user");
                assert_eq!(passwd, "passwd");
            }
        }
    }

    #[test]
    fn parse_rejects_missing_parts() {
        assert!(HttpAuth::parse("").is_err());
        assert!(HttpAuth::parse("Basic").is_err());
    }

    #[test]
    fn parse_rejects_unknown_scheme() {
        assert!(HttpAuth::parse("Bearer dXNlcjpwYXNzd2Q=").is_err());
    }

    #[test]
    fn check_requires_matching_password() {
        let passwds = table();
        let good = HttpAuth::Basic("user".to_owned(), "passwd".to_owned());
        let wrong = HttpAuth::Basic("user".to_owned(), "wrong".to_owned());
        let unknown = HttpAuth::Basic("nobody".to_owned(), "passwd".to_owned());
        assert!(good.check(&[], &passwds));
        assert!(!wrong.check(&[], &passwds));
        assert!(!unknown.check(&[], &passwds));
    }

    #[test]
    fn check_honours_required_users() {
        let passwds = table();
        let good = HttpAuth::Basic("user".to_owned(), "passwd".to_owned());
        assert!(good.check(&["user".to_owned()], &passwds));
        assert!(!good.check(&["admin".to_owned()], &passwds));
    }
}