Skip to main content

snapcast_server/
auth.rs

1//! Streaming client authentication.
2//!
3//! Implement [`AuthValidator`] for custom authentication (database, LDAP, etc.)
4//! or use [`StaticAuthValidator`] for config-file-based users/roles.
5//!
6//! Implement [`ClientFilter`] for connection-level filtering (MAC allowlist, etc.)
7
8use std::collections::HashMap;
9use std::sync::Arc;
10
11use subtle::ConstantTimeEq;
12
13use snapcast_proto::message::hello::Hello;
14
15/// Constant-time byte comparison to prevent timing attacks.
16fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
17    a.ct_eq(b).into()
18}
19
20/// Connection-level client filter — called after Hello, before authentication.
21///
22/// Use this to implement MAC allowlists, IP-based filtering, or rate limiting.
23/// Return `true` to accept the client, `false` to disconnect immediately.
24///
25/// # Example: MAC allowlist
26///
27/// ```
28/// use snapcast_server::auth::ClientFilter;
29/// use snapcast_server::Hello;
30///
31/// struct MacAllowlist(Vec<String>);
32///
33/// impl ClientFilter for MacAllowlist {
34///     fn accept(&self, hello: &Hello) -> bool {
35///         self.0.is_empty() || self.0.iter().any(|m| m.eq_ignore_ascii_case(&hello.mac))
36///     }
37/// }
38/// ```
39pub trait ClientFilter: Send + Sync {
40    /// Decide whether to accept a client based on its Hello message.
41    fn accept(&self, hello: &Hello) -> bool;
42}
43
/// Outcome of a successful credential check.
#[derive(Clone, Debug)]
pub struct AuthResult {
    /// Name of the user that authenticated.
    pub username: String,
    /// Permissions granted to that user (e.g. "Streaming", "Control").
    pub permissions: Vec<String>,
}
52
/// Reasons a client can be rejected during authentication.
///
/// Each variant carries a human-readable message and maps to an HTTP-style
/// status code via [`AuthError::code`].
#[derive(Clone, Debug)]
pub enum AuthError {
    /// 401 — the credentials were missing or invalid.
    Unauthorized(String),
    /// 403 — the credentials were valid but the required permission is absent.
    Forbidden(String),
}

impl std::fmt::Display for AuthError {
    /// Renders the error as `"<Variant>: <message>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Forbidden(reason) => write!(f, "Forbidden: {reason}"),
            Self::Unauthorized(reason) => write!(f, "Unauthorized: {reason}"),
        }
    }
}

impl std::error::Error for AuthError {}

impl AuthError {
    /// HTTP-style status code: 401 for [`AuthError::Unauthorized`],
    /// 403 for [`AuthError::Forbidden`].
    pub fn code(&self) -> i32 {
        match self {
            Self::Forbidden(_) => 403,
            Self::Unauthorized(_) => 401,
        }
    }

    /// The message stored in the variant, without the variant-name prefix
    /// that `Display` adds.
    pub fn message(&self) -> &str {
        // Both variants wrap a single `String`, so one irrefutable
        // or-pattern covers the whole enum.
        let (Self::Unauthorized(text) | Self::Forbidden(text)) = self;
        text.as_str()
    }
}
89
90/// Trait for validating streaming client credentials.
91///
92/// The server calls [`validate`](AuthValidator::validate) after receiving a Hello message.
93/// Return [`AuthResult`] on success or [`AuthError`] on failure.
94///
95/// # Example: Custom validator
96///
97/// ```
98/// use snapcast_server::auth::{AuthValidator, AuthResult, AuthError};
99///
100/// struct MyValidator;
101///
102/// impl AuthValidator for MyValidator {
103///     fn validate(&self, scheme: &str, param: &str) -> Result<AuthResult, AuthError> {
104///         // Look up in database, LDAP, etc.
105///         Ok(AuthResult {
106///             username: "user".into(),
107///             permissions: vec!["Streaming".into()],
108///         })
109///     }
110/// }
111/// ```
112/// Trait for validating streaming client credentials.
113pub trait AuthValidator: Send + Sync {
114    /// Validate credentials from the Hello message's auth field.
115    fn validate(&self, scheme: &str, param: &str) -> Result<AuthResult, AuthError>;
116}
117
/// Name of the permission a streaming client is required to hold.
pub const PERM_STREAMING: &str = "Streaming";
120
/// A named bundle of permissions that users can be assigned to.
#[derive(Clone, Debug)]
pub struct Role {
    /// Name users refer to in their `role` field.
    pub name: String,
    /// Permissions every user with this role receives.
    pub permissions: Vec<String>,
}
129
/// A user with credentials and role assignment.
///
/// `Debug` is implemented by hand so that the plaintext password never ends
/// up in logs or panic messages; it is printed as `<redacted>`.
#[derive(Clone)]
pub struct User {
    /// Username.
    pub name: String,
    /// Password (plaintext — hashing is the deployer's responsibility).
    pub password: String,
    /// Role name.
    pub role: String,
}

impl std::fmt::Debug for User {
    /// Formats the user like the derived impl would, but with the password
    /// replaced by a fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("User")
            .field("name", &self.name)
            // Never print the secret itself, not even its length.
            .field("password", &format_args!("<redacted>"))
            .field("role", &self.role)
            .finish()
    }
}
140
141/// Config-file-based authentication matching the C++ implementation.
142///
143/// Validates Basic auth (`base64(user:password)`) against a static user/role list.
144#[derive(Debug, Clone)]
145pub struct StaticAuthValidator {
146    users: HashMap<String, (String, Arc<Role>)>, // name → (password, role)
147}
148
149impl StaticAuthValidator {
150    /// Create from user and role lists.
151    pub fn new(users: Vec<User>, roles: Vec<Role>) -> Self {
152        let role_map: HashMap<String, Arc<Role>> = roles
153            .into_iter()
154            .map(|r| (r.name.clone(), Arc::new(r)))
155            .collect();
156        let empty_role = Arc::new(Role {
157            name: String::new(),
158            permissions: vec![],
159        });
160        let user_map = users
161            .into_iter()
162            .map(|u| {
163                let role = role_map
164                    .get(&u.role)
165                    .cloned()
166                    .unwrap_or_else(|| empty_role.clone());
167                (u.name, (u.password, role))
168            })
169            .collect();
170        Self { users: user_map }
171    }
172}
173
174impl AuthValidator for StaticAuthValidator {
175    fn validate(&self, scheme: &str, param: &str) -> Result<AuthResult, AuthError> {
176        if !scheme.eq_ignore_ascii_case("basic") {
177            return Err(AuthError::Unauthorized(format!(
178                "Unsupported auth scheme: {scheme}"
179            )));
180        }
181
182        // Decode base64(user:password)
183        use base64::Engine;
184        let decoded = base64::engine::general_purpose::STANDARD
185            .decode(param)
186            .map_err(|_| AuthError::Unauthorized("Invalid base64".into()))?;
187        let decoded = String::from_utf8(decoded)
188            .map_err(|_| AuthError::Unauthorized("Invalid UTF-8".into()))?;
189        let (username, password) = decoded
190            .split_once(':')
191            .ok_or_else(|| AuthError::Unauthorized("Expected user:password".into()))?;
192
193        let (stored_pw, role) = self
194            .users
195            .get(username)
196            .ok_or_else(|| AuthError::Unauthorized("Unknown user".into()))?;
197
198        if !constant_time_eq(stored_pw.as_bytes(), password.as_bytes()) {
199            return Err(AuthError::Unauthorized("Wrong password".into()));
200        }
201
202        Ok(AuthResult {
203            username: username.to_string(),
204            permissions: role.permissions.clone(),
205        })
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Validator with two users: `admin` (Streaming + Control) and
    /// `player` (Streaming only).
    fn test_validator() -> StaticAuthValidator {
        StaticAuthValidator::new(
            vec![
                User {
                    name: "admin".into(),
                    password: "secret".into(),
                    role: "full".into(),
                },
                User {
                    name: "player".into(),
                    password: "play".into(),
                    role: "streaming".into(),
                },
            ],
            vec![
                Role {
                    name: "full".into(),
                    permissions: vec!["Streaming".into(), "Control".into()],
                },
                Role {
                    name: "streaming".into(),
                    permissions: vec!["Streaming".into()],
                },
            ],
        )
    }

    /// Encodes `user:pass` the way a Basic auth parameter is expected to arrive.
    fn basic(user: &str, pass: &str) -> String {
        use base64::Engine;
        base64::engine::general_purpose::STANDARD.encode(format!("{user}:{pass}"))
    }

    #[test]
    fn valid_credentials() {
        let v = test_validator();
        let result = v.validate("Basic", &basic("admin", "secret")).unwrap();
        assert_eq!(result.username, "admin");
        assert!(result.permissions.contains(&"Streaming".into()));
        assert!(result.permissions.contains(&"Control".into()));
    }

    #[test]
    fn wrong_password() {
        let v = test_validator();
        let err = v.validate("Basic", &basic("admin", "wrong")).unwrap_err();
        assert_eq!(err.code(), 401);
    }

    #[test]
    fn unknown_user() {
        let v = test_validator();
        let err = v.validate("Basic", &basic("nobody", "x")).unwrap_err();
        assert_eq!(err.code(), 401);
    }

    #[test]
    fn unsupported_scheme() {
        let v = test_validator();
        let err = v.validate("Bearer", "token123").unwrap_err();
        assert_eq!(err.code(), 401);
    }

    #[test]
    fn streaming_only_user() {
        let v = test_validator();
        let result = v.validate("Basic", &basic("player", "play")).unwrap();
        assert_eq!(result.username, "player");
        assert!(result.permissions.contains(&"Streaming".into()));
        assert!(!result.permissions.contains(&"Control".into()));
    }

    #[test]
    fn scheme_is_case_insensitive() {
        let v = test_validator();
        let result = v.validate("bAsIc", &basic("admin", "secret")).unwrap();
        assert_eq!(result.username, "admin");
    }

    #[test]
    fn invalid_base64() {
        let v = test_validator();
        // Space and '!' are outside the standard base64 alphabet.
        let err = v.validate("Basic", "not base64!").unwrap_err();
        assert_eq!(err.code(), 401);
        assert_eq!(err.message(), "Invalid base64");
    }

    #[test]
    fn missing_separator() {
        use base64::Engine;
        let v = test_validator();
        let param = base64::engine::general_purpose::STANDARD.encode("admin");
        let err = v.validate("Basic", &param).unwrap_err();
        assert_eq!(err.code(), 401);
        assert_eq!(err.message(), "Expected user:password");
    }

    #[test]
    fn password_may_contain_colon() {
        // Only the first ':' separates user from password.
        let v = StaticAuthValidator::new(
            vec![User {
                name: "u".into(),
                password: "a:b".into(),
                role: "r".into(),
            }],
            vec![Role {
                name: "r".into(),
                permissions: vec![PERM_STREAMING.into()],
            }],
        );
        let result = v.validate("Basic", &basic("u", "a:b")).unwrap();
        assert_eq!(result.username, "u");
        assert_eq!(result.permissions, vec![PERM_STREAMING.to_string()]);
    }

    #[test]
    fn unknown_role_grants_no_permissions() {
        let v = StaticAuthValidator::new(
            vec![User {
                name: "u".into(),
                password: "p".into(),
                role: "missing".into(),
            }],
            vec![],
        );
        let result = v.validate("Basic", &basic("u", "p")).unwrap();
        assert!(result.permissions.is_empty());
    }
}