1use std::collections::HashMap;
9use std::sync::Arc;
10
11use subtle::ConstantTimeEq;
12
13use snapcast_proto::message::hello::Hello;
14
15fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
17 a.ct_eq(b).into()
18}
19
/// Hook that is shown a client's [`Hello`] message and returns a yes/no verdict.
///
/// Implementations must be `Send + Sync`.
pub trait ClientFilter: Send + Sync {
    /// Inspects `hello` and returns the verdict for that client.
    ///
    /// NOTE(review): presumably `true` admits the client and `false` rejects
    /// it — confirm against the call site, which is not visible here.
    fn accept(&self, hello: &Hello) -> bool;
}
43
/// Outcome of a successful [`AuthValidator::validate`] call.
#[derive(Debug, Clone)]
pub struct AuthResult {
    /// Name of the user that was authenticated.
    pub username: String,
    /// Permission names granted to that user. [`StaticAuthValidator`] fills
    /// this with a copy of the permissions of the user's role.
    pub permissions: Vec<String>,
}
52
/// Failure returned by an [`AuthValidator`].
///
/// Each variant carries a human-readable message; `code()` maps the variant
/// to a numeric status and `message()` exposes the text.
#[derive(Debug, Clone)]
pub enum AuthError {
    /// Reported with code 401 by `code()`.
    Unauthorized(String),
    /// Reported with code 403 by `code()`.
    Forbidden(String),
}
61
62impl std::fmt::Display for AuthError {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 match self {
65 Self::Unauthorized(msg) => write!(f, "Unauthorized: {msg}"),
66 Self::Forbidden(msg) => write!(f, "Forbidden: {msg}"),
67 }
68 }
69}
70
// Empty impl: `AuthError` wraps no underlying error, so every provided method
// of `std::error::Error` keeps its default.
impl std::error::Error for AuthError {}
72
73impl AuthError {
74 pub fn code(&self) -> i32 {
76 match self {
77 Self::Unauthorized(_) => 401,
78 Self::Forbidden(_) => 403,
79 }
80 }
81
82 pub fn message(&self) -> &str {
84 match self {
85 Self::Unauthorized(msg) | Self::Forbidden(msg) => msg,
86 }
87 }
88}
89
/// Validates credentials presented as an auth scheme plus a scheme-specific
/// parameter.
///
/// Implementations must be `Send + Sync`.
pub trait AuthValidator: Send + Sync {
    /// Checks `param` under the given `scheme`.
    ///
    /// Returns an [`AuthResult`] describing the authenticated user on
    /// success, or an [`AuthError`] on failure. As an example, the
    /// [`StaticAuthValidator`] in this file accepts the scheme `basic`
    /// (case-insensitive) with `param` being base64 of `user:password`.
    fn validate(&self, scheme: &str, param: &str) -> Result<AuthResult, AuthError>;
}
117
/// Permission name `"Streaming"`.
///
/// NOTE(review): presumably the permission a client needs in order to
/// stream — the place where it is checked is not visible here.
pub const PERM_STREAMING: &str = "Streaming";
120
/// A named set of permissions.
#[derive(Debug, Clone)]
pub struct Role {
    /// Role name; [`StaticAuthValidator::new`] matches it against `User::role`.
    pub name: String,
    /// Permission names granted by this role (e.g. [`PERM_STREAMING`]).
    pub permissions: Vec<String>,
}
129
/// A configured user account consumed by [`StaticAuthValidator::new`].
///
/// `Debug` is implemented by hand rather than derived so that the plaintext
/// password never ends up in logs or panic messages.
#[derive(Clone)]
pub struct User {
    /// Login name; used as the lookup key during validation.
    pub name: String,
    /// Password, compared byte-for-byte against the one the client supplies.
    pub password: String,
    /// Name of the [`Role`] whose permissions this user receives.
    pub role: String,
}

impl std::fmt::Debug for User {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("User")
            .field("name", &self.name)
            // Never print the secret itself.
            .field("password", &"<redacted>")
            .field("role", &self.role)
            .finish()
    }
}
140
141#[derive(Debug, Clone)]
145pub struct StaticAuthValidator {
146 users: HashMap<String, (String, Arc<Role>)>, }
148
149impl StaticAuthValidator {
150 pub fn new(users: Vec<User>, roles: Vec<Role>) -> Self {
152 let role_map: HashMap<String, Arc<Role>> = roles
153 .into_iter()
154 .map(|r| (r.name.clone(), Arc::new(r)))
155 .collect();
156 let empty_role = Arc::new(Role {
157 name: String::new(),
158 permissions: vec![],
159 });
160 let user_map = users
161 .into_iter()
162 .map(|u| {
163 let role = role_map
164 .get(&u.role)
165 .cloned()
166 .unwrap_or_else(|| empty_role.clone());
167 (u.name, (u.password, role))
168 })
169 .collect();
170 Self { users: user_map }
171 }
172}
173
174impl AuthValidator for StaticAuthValidator {
175 fn validate(&self, scheme: &str, param: &str) -> Result<AuthResult, AuthError> {
176 if !scheme.eq_ignore_ascii_case("basic") {
177 return Err(AuthError::Unauthorized(format!(
178 "Unsupported auth scheme: {scheme}"
179 )));
180 }
181
182 use base64::Engine;
184 let decoded = base64::engine::general_purpose::STANDARD
185 .decode(param)
186 .map_err(|_| AuthError::Unauthorized("Invalid base64".into()))?;
187 let decoded = String::from_utf8(decoded)
188 .map_err(|_| AuthError::Unauthorized("Invalid UTF-8".into()))?;
189 let (username, password) = decoded
190 .split_once(':')
191 .ok_or_else(|| AuthError::Unauthorized("Expected user:password".into()))?;
192
193 let (stored_pw, role) = self
194 .users
195 .get(username)
196 .ok_or_else(|| AuthError::Unauthorized("Unknown user".into()))?;
197
198 if !constant_time_eq(stored_pw.as_bytes(), password.as_bytes()) {
199 return Err(AuthError::Unauthorized("Wrong password".into()));
200 }
201
202 Ok(AuthResult {
203 username: username.to_string(),
204 permissions: role.permissions.clone(),
205 })
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a [`Role`] fixture.
    fn role(name: &str, permissions: &[&str]) -> Role {
        Role {
            name: name.into(),
            permissions: permissions.iter().map(|p| p.to_string()).collect(),
        }
    }

    /// Shorthand for building a [`User`] fixture.
    fn user(name: &str, password: &str, role: &str) -> User {
        User {
            name: name.into(),
            password: password.into(),
            role: role.into(),
        }
    }

    /// Validator with an `admin` (full access) and a `player` (streaming only).
    fn test_validator() -> StaticAuthValidator {
        let users = vec![
            user("admin", "secret", "full"),
            user("player", "play", "streaming"),
        ];
        let roles = vec![
            role("full", &["Streaming", "Control"]),
            role("streaming", &["Streaming"]),
        ];
        StaticAuthValidator::new(users, roles)
    }

    /// Base64-encodes `user:pass` as expected for the `Basic` scheme.
    fn basic(user: &str, pass: &str) -> String {
        use base64::Engine;
        let plain = format!("{user}:{pass}");
        base64::engine::general_purpose::STANDARD.encode(plain)
    }

    /// True when `result` grants the permission named `wanted`.
    fn grants(result: &AuthResult, wanted: &str) -> bool {
        result.permissions.iter().any(|p| p == wanted)
    }

    /// Asserts that validation fails and reports code 401.
    fn assert_rejected_401(scheme: &str, param: &str) {
        let err = test_validator().validate(scheme, param).unwrap_err();
        assert_eq!(err.code(), 401);
    }

    #[test]
    fn valid_credentials() {
        let outcome = test_validator()
            .validate("Basic", &basic("admin", "secret"))
            .unwrap();
        assert_eq!(outcome.username, "admin");
        assert!(grants(&outcome, "Streaming"));
        assert!(grants(&outcome, "Control"));
    }

    #[test]
    fn wrong_password() {
        assert_rejected_401("Basic", &basic("admin", "wrong"));
    }

    #[test]
    fn unknown_user() {
        assert_rejected_401("Basic", &basic("nobody", "x"));
    }

    #[test]
    fn unsupported_scheme() {
        assert_rejected_401("Bearer", "token123");
    }

    #[test]
    fn streaming_only_user() {
        let outcome = test_validator()
            .validate("Basic", &basic("player", "play"))
            .unwrap();
        assert_eq!(outcome.username, "player");
        assert!(grants(&outcome, "Streaming"));
        assert!(!grants(&outcome, "Control"));
    }
}