1use std::collections::HashMap;
16
/// Identity established by a successful authentication.
///
/// Produced by [`AuthSubject::new`] / [`AuthSubject::anonymous`] and refined with the
/// consuming `with_*` builder methods.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct AuthSubject {
    /// Principal name (e.g. the owner of a bearer token, a JWT `sub`, a SASL user name,
    /// or the subject supplied for mTLS).
    pub name: String,
    /// Group memberships attached to the principal, in insertion order.
    pub groups: Vec<String>,
    /// Free-form key/value claims; inserting an existing key replaces its value.
    pub claims: HashMap<String, String>,
}

impl AuthSubject {
    /// Subject used when authentication is disabled; its name is `"anonymous"` and it has
    /// no groups or claims.
    #[must_use]
    pub fn anonymous() -> Self {
        Self::new("anonymous")
    }

    /// Creates a subject called `name` with no groups and no claims.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            ..Self::default()
        }
    }

    /// Builder step: appends group `g` and returns the updated subject.
    #[must_use]
    pub fn with_group(mut self, g: impl Into<String>) -> Self {
        let group: String = g.into();
        self.groups.push(group);
        self
    }

    /// Builder step: sets claim `k` to `v` (overwriting any previous value) and returns the
    /// updated subject.
    #[must_use]
    pub fn with_claim(mut self, k: impl Into<String>, v: impl Into<String>) -> Self {
        let (key, value): (String, String) = (k.into(), v.into());
        self.claims.insert(key, value);
        self
    }
}
64
/// Reasons an authentication attempt can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthError {
    /// The credential the active mode needs was not supplied at all.
    MissingCredentials,
    /// A credential was supplied but could not be parsed; the payload describes what was
    /// wrong with it.
    MalformedCredentials(String),
    /// A well-formed credential was not accepted; the payload gives the reason.
    Rejected(String),
    /// The authentication setup itself is at fault (per the variant name and its display
    /// text); the payload describes the problem.
    Misconfigured(String),
}

impl core::fmt::Display for AuthError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        use AuthError as E;
        // Each message is a fixed prefix, followed by the variant's detail text when it has one.
        match self {
            E::MissingCredentials => write!(f, "missing credentials"),
            E::MalformedCredentials(m) => write!(f, "malformed credentials: {m}"),
            E::Rejected(m) => write!(f, "credentials rejected: {m}"),
            E::Misconfigured(m) => write!(f, "auth misconfigured: {m}"),
        }
    }
}

/// Marker impl: `AuthError` has no underlying `source`.
impl std::error::Error for AuthError {}
90
91#[derive(Debug, Clone)]
93pub enum AuthMode {
94 None,
96 Bearer {
99 tokens: HashMap<String, AuthSubject>,
101 },
102 Jwt {
105 pkcs1_pubkey_der: Vec<u8>,
107 expected_issuer: Option<String>,
109 },
110 Mtls,
114 SaslPlain {
117 users: HashMap<String, String>,
119 },
120}
121
122#[derive(Debug, Clone, Default)]
124pub struct AuthInput<'a> {
125 pub authorization_header: Option<&'a str>,
127 pub sasl_plain_blob: Option<&'a [u8]>,
129 pub mtls_subject: Option<AuthSubject>,
132}
133
134impl AuthMode {
135 pub fn validate(&self, input: &AuthInput<'_>) -> Result<AuthSubject, AuthError> {
140 match self {
141 Self::None => Ok(AuthSubject::anonymous()),
142 Self::Bearer { tokens } => {
143 let hdr = input
144 .authorization_header
145 .ok_or(AuthError::MissingCredentials)?;
146 let token = strip_bearer(hdr)?;
147 tokens
148 .get(token)
149 .cloned()
150 .ok_or_else(|| AuthError::Rejected("unknown bearer token".to_string()))
151 }
152 Self::Jwt {
153 pkcs1_pubkey_der,
154 expected_issuer,
155 } => {
156 let hdr = input
157 .authorization_header
158 .ok_or(AuthError::MissingCredentials)?;
159 let token = strip_bearer(hdr)?;
160 validate_jwt_rs256(token, pkcs1_pubkey_der, expected_issuer.as_deref())
161 }
162 Self::Mtls => input
163 .mtls_subject
164 .clone()
165 .ok_or_else(|| AuthError::Rejected("mTLS expected client cert".to_string())),
166 Self::SaslPlain { users } => {
167 let blob = input.sasl_plain_blob.ok_or(AuthError::MissingCredentials)?;
168 let (user, pass) = parse_sasl_plain(blob)?;
169 let stored = users
170 .get(user)
171 .ok_or_else(|| AuthError::Rejected("unknown user".to_string()))?;
172 if stored == pass {
173 Ok(AuthSubject::new(user))
174 } else {
175 Err(AuthError::Rejected("password mismatch".to_string()))
176 }
177 }
178 }
179 }
180}
181
182fn strip_bearer(hdr: &str) -> Result<&str, AuthError> {
183 let trimmed = hdr.trim();
184 let prefix = "Bearer ";
185 if trimmed.len() < prefix.len()
186 || !trimmed
187 .get(..prefix.len())
188 .is_some_and(|p| p.eq_ignore_ascii_case(prefix))
189 {
190 return Err(AuthError::MalformedCredentials(
191 "expected `Bearer …`".to_string(),
192 ));
193 }
194 Ok(trimmed[prefix.len()..].trim())
195}
196
197fn parse_sasl_plain(blob: &[u8]) -> Result<(&str, &str), AuthError> {
198 let mut parts = blob.splitn(3, |b| *b == 0);
200 let _authzid = parts
201 .next()
202 .ok_or(AuthError::MalformedCredentials("sasl-plain empty".into()))?;
203 let authcid = parts
204 .next()
205 .ok_or(AuthError::MalformedCredentials("sasl-plain no user".into()))?;
206 let passwd = parts
207 .next()
208 .ok_or(AuthError::MalformedCredentials("sasl-plain no pass".into()))?;
209 let user = core::str::from_utf8(authcid)
210 .map_err(|_| AuthError::MalformedCredentials("sasl-plain user utf8".into()))?;
211 let pass = core::str::from_utf8(passwd)
212 .map_err(|_| AuthError::MalformedCredentials("sasl-plain pass utf8".into()))?;
213 if user.is_empty() {
214 return Err(AuthError::MalformedCredentials(
215 "sasl-plain empty user".into(),
216 ));
217 }
218 Ok((user, pass))
219}
220
221fn validate_jwt_rs256(
224 token: &str,
225 pkcs1_pubkey_der: &[u8],
226 expected_issuer: Option<&str>,
227) -> Result<AuthSubject, AuthError> {
228 use base64::Engine as _;
229 let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
230
231 let mut parts = token.split('.');
232 let h_b64 = parts
233 .next()
234 .ok_or_else(|| AuthError::MalformedCredentials("jwt: no header".into()))?;
235 let p_b64 = parts
236 .next()
237 .ok_or_else(|| AuthError::MalformedCredentials("jwt: no payload".into()))?;
238 let s_b64 = parts
239 .next()
240 .ok_or_else(|| AuthError::MalformedCredentials("jwt: no sig".into()))?;
241 if parts.next().is_some() {
242 return Err(AuthError::MalformedCredentials(
243 "jwt: too many segments".into(),
244 ));
245 }
246
247 let header_bytes = engine
248 .decode(h_b64)
249 .map_err(|e| AuthError::MalformedCredentials(format!("jwt header b64: {e}")))?;
250 let payload_bytes = engine
251 .decode(p_b64)
252 .map_err(|e| AuthError::MalformedCredentials(format!("jwt payload b64: {e}")))?;
253 let sig_bytes = engine
254 .decode(s_b64)
255 .map_err(|e| AuthError::MalformedCredentials(format!("jwt sig b64: {e}")))?;
256
257 let header_str = core::str::from_utf8(&header_bytes)
259 .map_err(|_| AuthError::MalformedCredentials("jwt header utf8".into()))?;
260 if !json_field_eq(header_str, "alg", "RS256") {
261 return Err(AuthError::Rejected("jwt: alg must be RS256".into()));
262 }
263
264 let signed = {
266 let mut v = Vec::with_capacity(h_b64.len() + 1 + p_b64.len());
267 v.extend_from_slice(h_b64.as_bytes());
268 v.push(b'.');
269 v.extend_from_slice(p_b64.as_bytes());
270 v
271 };
272 let pubkey = ring::signature::UnparsedPublicKey::new(
273 &ring::signature::RSA_PKCS1_2048_8192_SHA256,
274 pkcs1_pubkey_der,
275 );
276 pubkey
277 .verify(&signed, &sig_bytes)
278 .map_err(|_| AuthError::Rejected("jwt: signature invalid".into()))?;
279
280 let payload_str = core::str::from_utf8(&payload_bytes)
282 .map_err(|_| AuthError::MalformedCredentials("jwt payload utf8".into()))?;
283 let sub = json_field(payload_str, "sub")
284 .ok_or_else(|| AuthError::Rejected("jwt: no sub claim".into()))?;
285
286 if let Some(expected) = expected_issuer {
287 let iss = json_field(payload_str, "iss")
288 .ok_or_else(|| AuthError::Rejected("jwt: no iss claim".into()))?;
289 if iss != expected {
290 return Err(AuthError::Rejected(format!("jwt: iss != {expected}")));
291 }
292 }
293
294 let mut subj = AuthSubject::new(sub);
295 if let Some(groups_raw) = json_array(payload_str, "groups") {
296 for g in groups_raw {
297 subj.groups.push(g);
298 }
299 }
300 Ok(subj)
301}
302
/// Returns the value of member `key` of the top-level JSON object in `src`.
///
/// String values are returned with JSON escapes decoded. Any other value (number, `true`,
/// `false`, `null`, nested object or array) is returned as its raw JSON text.
///
/// Only members of the outermost object are considered. A key nested inside a sub-object
/// (such as the `sub` inside an RFC 8693 `act` claim) never matches, and neither does text
/// that merely looks like `"key"` inside a string value. When the key occurs more than
/// once, the lexically last occurrence wins, as RFC 7515 §4 and RFC 7519 §4 permit.
///
/// Returns `None` if `src` is not a single JSON object or has no such member. This is a
/// minimal scanner: it checks string syntax and bracket balance but does not validate
/// number grammar or bracket kinds inside nested values.
fn json_field(src: &str, key: &str) -> Option<String> {
    let raw = json_member_raw(src, key)?;
    if raw.starts_with('"') {
        json_scan_string(raw.as_bytes(), 0).map(|(text, _)| text)
    } else {
        Some(raw.to_string())
    }
}

/// Returns whether top-level member `key` of `src` exists and equals `expected` after
/// decoding (see `json_field`).
fn json_field_eq(src: &str, key: &str, expected: &str) -> bool {
    json_field(src, key).is_some_and(|v| v == expected)
}

/// Returns the elements of the array stored under top-level member `key` of `src`.
///
/// String elements are decoded and other elements are returned as raw JSON text. Empty
/// strings are skipped. Returns `None` if the member is missing, is not an array, or is
/// malformed.
fn json_array(src: &str, key: &str) -> Option<Vec<String>> {
    let raw = json_member_raw(src, key)?;
    let bytes = raw.as_bytes();
    if bytes.first() != Some(&b'[') {
        return None;
    }
    let mut items = Vec::new();
    let mut i = json_skip_ws(bytes, 1);
    if bytes.get(i) == Some(&b']') {
        return Some(items);
    }
    loop {
        let end = json_skip_value(bytes, i)?;
        let item = if bytes.get(i) == Some(&b'"') {
            json_scan_string(bytes, i)?.0
        } else {
            raw.get(i..end)?.to_string()
        };
        if !item.is_empty() {
            items.push(item);
        }
        i = json_skip_ws(bytes, end);
        match bytes.get(i)? {
            b',' => i = json_skip_ws(bytes, i + 1),
            b']' => return Some(items),
            _ => return None,
        }
    }
}

/// Walks every member of the top-level object in `src` and returns the raw JSON text of
/// the last value whose key equals `key`. The whole document is scanned, so trailing
/// garbage after the closing `}` is rejected.
fn json_member_raw<'s>(src: &'s str, key: &str) -> Option<&'s str> {
    let bytes = src.as_bytes();
    let mut i = json_skip_ws(bytes, 0);
    if bytes.get(i) != Some(&b'{') {
        return None;
    }
    i = json_skip_ws(bytes, i + 1);
    // An empty object has no members, so there is nothing to find.
    if bytes.get(i) == Some(&b'}') {
        return None;
    }
    let mut found = None;
    loop {
        let (name, after_name) = json_scan_string(bytes, i)?;
        i = json_skip_ws(bytes, after_name);
        if bytes.get(i) != Some(&b':') {
            return None;
        }
        let start = json_skip_ws(bytes, i + 1);
        let end = json_skip_value(bytes, start)?;
        if name == key {
            found = Some(start..end);
        }
        i = json_skip_ws(bytes, end);
        match bytes.get(i)? {
            b',' => i = json_skip_ws(bytes, i + 1),
            b'}' => break,
            _ => return None,
        }
    }
    if json_skip_ws(bytes, i + 1) != bytes.len() {
        return None;
    }
    src.get(found?)
}

/// Returns the index of the first non-whitespace byte at or after `i`, using the JSON
/// whitespace set (space, tab, LF, CR).
fn json_skip_ws(bytes: &[u8], mut i: usize) -> usize {
    while matches!(bytes.get(i), Some(b' ' | b'\t' | b'\n' | b'\r')) {
        i += 1;
    }
    i
}

/// Returns the index just past the JSON value that starts at `bytes[start]`.
fn json_skip_value(bytes: &[u8], start: usize) -> Option<usize> {
    match *bytes.get(start)? {
        b'"' => json_scan_string(bytes, start).map(|(_, end)| end),
        b'{' | b'[' => {
            // Track nesting depth. Strings are skipped whole, so brackets inside them do
            // not count.
            let mut depth = 0usize;
            let mut i = start;
            loop {
                match *bytes.get(i)? {
                    b'"' => {
                        i = json_scan_string(bytes, i)?.1;
                        continue;
                    }
                    b'{' | b'[' => depth += 1,
                    b'}' | b']' => {
                        depth -= 1;
                        if depth == 0 {
                            return Some(i + 1);
                        }
                    }
                    _ => {}
                }
                i += 1;
            }
        }
        _ => {
            // A scalar (number, true, false, null) runs until the next delimiter or
            // whitespace.
            let len = bytes[start..]
                .iter()
                .position(|&c| matches!(c, b',' | b'}' | b']' | b' ' | b'\t' | b'\n' | b'\r'))
                .unwrap_or(bytes.len() - start);
            (len > 0).then_some(start + len)
        }
    }
}

/// Decodes the JSON string literal whose opening quote is at `bytes[start]`.
///
/// Returns the decoded text and the index just past the closing quote. `bytes` must come
/// from a `&str`; unescaped bytes are copied through verbatim, so the output stays valid
/// UTF-8.
fn json_scan_string(bytes: &[u8], start: usize) -> Option<(String, usize)> {
    if bytes.get(start) != Some(&b'"') {
        return None;
    }
    let mut out = Vec::new();
    let mut i = start + 1;
    loop {
        let c = *bytes.get(i)?;
        i += 1;
        match c {
            b'"' => return String::from_utf8(out).ok().map(|s| (s, i)),
            b'\\' => {
                let esc = *bytes.get(i)?;
                i += 1;
                let ch = match esc {
                    b'"' => '"',
                    b'\\' => '\\',
                    b'/' => '/',
                    b'b' => '\u{8}',
                    b'f' => '\u{c}',
                    b'n' => '\n',
                    b'r' => '\r',
                    b't' => '\t',
                    b'u' => {
                        let hi = json_hex4(bytes, i)?;
                        i += 4;
                        if (0xD800..0xDC00).contains(&hi) {
                            // A high surrogate must be followed by an escaped low surrogate.
                            if bytes.get(i) != Some(&b'\\') || bytes.get(i + 1) != Some(&b'u') {
                                return None;
                            }
                            let lo = json_hex4(bytes, i + 2)?;
                            if !(0xDC00..0xE000).contains(&lo) {
                                return None;
                            }
                            i += 6;
                            char::from_u32(0x1_0000 + ((hi - 0xD800) << 10) + (lo - 0xDC00))?
                        } else {
                            // `from_u32` rejects lone low surrogates.
                            char::from_u32(hi)?
                        }
                    }
                    _ => return None,
                };
                out.extend_from_slice(ch.encode_utf8(&mut [0u8; 4]).as_bytes());
            }
            // JSON forbids unescaped control characters inside strings.
            0x00..=0x1F => return None,
            _ => out.push(c),
        }
    }
}

/// Parses the four hex digits of a `\uXXXX` escape starting at `bytes[at]`.
fn json_hex4(bytes: &[u8], at: usize) -> Option<u32> {
    bytes
        .get(at..at.checked_add(4)?)?
        .iter()
        .try_fold(0u32, |acc, &d| Some(acc * 16 + char::from(d).to_digit(16)?))
}
347
#[cfg(test)]
// Tests are expected to panic on an unexpected `Ok`/`Err`; that is how they fail.
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
    use super::*;

    /// JWT with header `{"alg":"RS256"}`, payload `{"sub":"a"}` and a 3-byte dummy signature.
    const RS256_TOKEN: &str = "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJhIn0.AAAA";
    /// Same payload and signature, but the header says `{"alg":"HS256"}`.
    const HS256_TOKEN: &str = "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJhIn0.AAAA";

    /// Input carrying only an `Authorization` header value.
    fn with_header(value: &str) -> AuthInput<'_> {
        AuthInput {
            authorization_header: Some(value),
            ..AuthInput::default()
        }
    }

    /// Input carrying only a SASL PLAIN message.
    fn with_sasl(blob: &[u8]) -> AuthInput<'_> {
        AuthInput {
            sasl_plain_blob: Some(blob),
            ..AuthInput::default()
        }
    }

    /// Bearer mode that knows no tokens at all.
    fn empty_bearer() -> AuthMode {
        AuthMode::Bearer {
            tokens: HashMap::new(),
        }
    }

    /// SASL PLAIN mode with the single account `alice` / `wonderland`.
    fn sasl_alice() -> AuthMode {
        AuthMode::SaslPlain {
            users: HashMap::from([("alice".to_string(), "wonderland".to_string())]),
        }
    }

    /// JWT mode whose key is 32 zero bytes. These do not parse as an RSA public key, so no
    /// signature can verify against it.
    fn jwt_with_junk_key() -> AuthMode {
        AuthMode::Jwt {
            pkcs1_pubkey_der: vec![0u8; 32],
            expected_issuer: None,
        }
    }

    #[test]
    fn none_mode_yields_anonymous() {
        let subject = AuthMode::None.validate(&AuthInput::default()).unwrap();
        assert_eq!(subject.name, "anonymous");
    }

    #[test]
    fn bearer_valid_token_accepted() {
        let mode = AuthMode::Bearer {
            tokens: HashMap::from([("secret123".to_string(), AuthSubject::new("alice"))]),
        };
        let subject = mode.validate(&with_header("Bearer secret123")).unwrap();
        assert_eq!(subject.name, "alice");
    }

    #[test]
    fn bearer_invalid_token_rejected() {
        let err = empty_bearer()
            .validate(&with_header("Bearer wrong"))
            .unwrap_err();
        assert!(matches!(err, AuthError::Rejected(_)));
    }

    #[test]
    fn bearer_missing_header_returns_missing() {
        let err = empty_bearer().validate(&AuthInput::default()).unwrap_err();
        assert!(matches!(err, AuthError::MissingCredentials));
    }

    #[test]
    fn bearer_malformed_header_returns_malformed() {
        let err = empty_bearer().validate(&with_header("Basic xx")).unwrap_err();
        assert!(matches!(err, AuthError::MalformedCredentials(_)));
    }

    #[test]
    fn mtls_with_subject_accepted() {
        let input = AuthInput {
            mtls_subject: Some(AuthSubject::new("CN=alice")),
            ..AuthInput::default()
        };
        let subject = AuthMode::Mtls.validate(&input).unwrap();
        assert_eq!(subject.name, "CN=alice");
    }

    #[test]
    fn mtls_without_subject_rejected() {
        let err = AuthMode::Mtls.validate(&AuthInput::default()).unwrap_err();
        assert!(matches!(err, AuthError::Rejected(_)));
    }

    #[test]
    fn sasl_plain_valid_pair_accepted() {
        let subject = sasl_alice()
            .validate(&with_sasl(b"\0alice\0wonderland"))
            .unwrap();
        assert_eq!(subject.name, "alice");
    }

    #[test]
    fn sasl_plain_wrong_password_rejected() {
        let err = sasl_alice()
            .validate(&with_sasl(b"\0alice\0wrong"))
            .unwrap_err();
        assert!(matches!(err, AuthError::Rejected(_)));
    }

    #[test]
    fn json_field_extracts_string() {
        let header = r#"{"alg":"RS256","typ":"JWT"}"#;
        assert_eq!(json_field(header, "alg").as_deref(), Some("RS256"));
    }

    #[test]
    fn json_array_extracts_groups() {
        let payload = r#"{"sub":"a","groups":["eng","ops"]}"#;
        let groups = json_array(payload, "groups").unwrap();
        assert_eq!(groups, vec!["eng".to_string(), "ops".to_string()]);
    }

    #[test]
    fn jwt_invalid_signature_rejected() {
        let header = format!("Bearer {RS256_TOKEN}");
        let err = jwt_with_junk_key()
            .validate(&with_header(&header))
            .unwrap_err();
        assert!(matches!(err, AuthError::Rejected(_)));
    }

    #[test]
    fn jwt_wrong_alg_rejected() {
        let header = format!("Bearer {HS256_TOKEN}");
        let err = jwt_with_junk_key()
            .validate(&with_header(&header))
            .unwrap_err();
        assert!(matches!(err, AuthError::Rejected(_)));
    }
}