1use std::collections::HashMap;
45use std::sync::RwLock;
46
47use serde::{Deserialize, Serialize};
48use totp_rs::{Algorithm, TOTP};
49
50#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
56pub struct MfaSecret {
57 pub secret_base32: String,
60 pub serial: String,
65}
66
/// In-memory MFA Delete configuration that can be shared between threads.
///
/// Each map sits behind its own `RwLock`. All accessors go through
/// `crate::lock_recovery::recover_read` / `recover_write`, which (per the
/// poison test in this module) keep working after a lock has been poisoned.
#[derive(Debug, Default)]
pub struct MfaDeleteManager {
    // Fallback secret used when a bucket has no entry in `by_bucket`.
    default_secret: RwLock<Option<MfaSecret>>,
    // Per-bucket secret overrides, keyed by bucket name.
    by_bucket: RwLock<HashMap<String, MfaSecret>>,
    // Per-bucket MFA Delete state; buckets absent from the map are treated as disabled.
    enabled: RwLock<HashMap<String, bool>>,
}
84
/// Plain (lock-free) copy of a `MfaDeleteManager`'s state, used as the
/// serialization format for `to_json` / `from_json`.
///
/// The field names are the JSON keys, so renaming them changes the format and
/// breaks round-tripping of JSON produced by earlier builds.
#[derive(Debug, Default, Serialize, Deserialize)]
struct MfaSnapshot {
    // Mirrors `MfaDeleteManager::default_secret`.
    default_secret: Option<MfaSecret>,
    // Mirrors `MfaDeleteManager::by_bucket`.
    by_bucket: HashMap<String, MfaSecret>,
    // Mirrors `MfaDeleteManager::enabled`.
    enabled: HashMap<String, bool>,
}
93
94impl MfaDeleteManager {
95 #[must_use]
98 pub fn new() -> Self {
99 Self::default()
100 }
101
102 pub fn set_default_secret(&self, secret: MfaSecret) {
106 *crate::lock_recovery::recover_write(&self.default_secret, "mfa.default_secret") =
107 Some(secret);
108 }
109
110 pub fn set_bucket_secret(&self, bucket: &str, secret: MfaSecret) {
112 crate::lock_recovery::recover_write(&self.by_bucket, "mfa.by_bucket")
113 .insert(bucket.to_owned(), secret);
114 }
115
116 pub fn set_bucket_state(&self, bucket: &str, enabled: bool) {
121 crate::lock_recovery::recover_write(&self.enabled, "mfa.enabled")
122 .insert(bucket.to_owned(), enabled);
123 }
124
125 #[must_use]
128 pub fn is_enabled(&self, bucket: &str) -> bool {
129 crate::lock_recovery::recover_read(&self.enabled, "mfa.enabled")
130 .get(bucket)
131 .copied()
132 .unwrap_or(false)
133 }
134
135 #[must_use]
139 pub fn lookup_secret(&self, bucket: &str) -> Option<MfaSecret> {
140 if let Some(s) = crate::lock_recovery::recover_read(&self.by_bucket, "mfa.by_bucket")
141 .get(bucket)
142 .cloned()
143 {
144 return Some(s);
145 }
146 crate::lock_recovery::recover_read(&self.default_secret, "mfa.default_secret").clone()
147 }
148
149 pub fn to_json(&self) -> Result<String, serde_json::Error> {
152 let snap = MfaSnapshot {
153 default_secret: crate::lock_recovery::recover_read(
154 &self.default_secret,
155 "mfa.default_secret",
156 )
157 .clone(),
158 by_bucket: crate::lock_recovery::recover_read(&self.by_bucket, "mfa.by_bucket").clone(),
159 enabled: crate::lock_recovery::recover_read(&self.enabled, "mfa.enabled").clone(),
160 };
161 serde_json::to_string(&snap)
162 }
163
164 pub fn from_json(s: &str) -> Result<Self, serde_json::Error> {
166 let snap: MfaSnapshot = serde_json::from_str(s)?;
167 Ok(Self {
168 default_secret: RwLock::new(snap.default_secret),
169 by_bucket: RwLock::new(snap.by_bucket),
170 enabled: RwLock::new(snap.enabled),
171 })
172 }
173}
174
/// Reasons `check_mfa` can reject a request on an MFA-Delete-enabled bucket.
#[derive(Debug, thiserror::Error)]
pub enum MfaError {
    /// MFA Delete is enabled for the bucket but no `x-amz-mfa` header was sent.
    #[error("missing x-amz-mfa header (MFA Delete is Enabled on this bucket)")]
    Missing,
    /// The header is not `"<serial> <six ASCII digits>"`.
    #[error("malformed x-amz-mfa header")]
    Malformed,
    /// The serial in the header differs from the configured device's serial.
    #[error("MFA serial does not match configured device")]
    SerialMismatch,
    /// The TOTP code failed verification, or no secret (per-bucket or default)
    /// is configured for the bucket.
    #[error("invalid MFA code")]
    InvalidCode,
}
187
188pub fn parse_mfa_header(value: &str) -> Result<(String, String), MfaError> {
195 let mut parts = value.splitn(2, ' ');
196 let serial = parts.next().ok_or(MfaError::Malformed)?;
197 let code = parts.next().ok_or(MfaError::Malformed)?;
198 if serial.is_empty() || code.is_empty() {
199 return Err(MfaError::Malformed);
200 }
201 if value.split(' ').count() != 2 {
203 return Err(MfaError::Malformed);
204 }
205 if code.len() != 6 || !code.chars().all(|c| c.is_ascii_digit()) {
206 return Err(MfaError::Malformed);
207 }
208 Ok((serial.to_owned(), code.to_owned()))
209}
210
211#[must_use]
218pub fn verify_totp(secret_base32: &str, code: &str, now_unix_secs: u64) -> bool {
219 let Some(raw) = base32::decode(base32::Alphabet::Rfc4648 { padding: false }, secret_base32)
220 else {
221 return false;
222 };
223 let Ok(totp) = TOTP::new(Algorithm::SHA1, 6, 1, 30, raw) else {
224 return false;
225 };
226 totp.check(code, now_unix_secs)
227}
228
229pub fn check_mfa(
235 bucket: &str,
236 header_value: Option<&str>,
237 manager: &MfaDeleteManager,
238 now_unix_secs: u64,
239) -> Result<(), MfaError> {
240 if !manager.is_enabled(bucket) {
241 return Ok(());
242 }
243 let header = header_value.ok_or(MfaError::Missing)?;
244 let (serial, code) = parse_mfa_header(header)?;
245 let secret = manager.lookup_secret(bucket).ok_or(MfaError::InvalidCode)?;
246 if serial != secret.serial {
247 return Err(MfaError::SerialMismatch);
248 }
249 if !verify_totp(&secret.secret_base32, &code, now_unix_secs) {
250 return Err(MfaError::InvalidCode);
251 }
252 Ok(())
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared TOTP seed for every test. 32 base32 characters decode to 20 bytes
    /// (160 bits), long enough for `TOTP::new` to accept. By contrast, the
    /// 8-character secret in `verify_totp_short_secret_rejected` is refused.
    const TEST_SECRET_B32: &str = "JBSWY3DPEHPK3PXPJBSWY3DPEHPK3PXP";

    /// Fixed "now" used by time-dependent tests (mid-window for a 30 s step).
    const NOW: u64 = 1_700_000_000;

    /// Decodes `TEST_SECRET_B32` into raw key bytes.
    fn raw_secret() -> Vec<u8> {
        base32::decode(
            base32::Alphabet::Rfc4648 { padding: false },
            TEST_SECRET_B32,
        )
        .expect("decode test secret")
    }

    /// Reference code for `time`, generated with the same parameters
    /// `verify_totp` uses (SHA-1, 6 digits, skew 1, 30 s step).
    fn totp_at(time: u64) -> String {
        let totp = TOTP::new(Algorithm::SHA1, 6, 1, 30, raw_secret()).expect("totp");
        totp.generate(time)
    }

    /// Manager whose default secret is `TEST_SECRET_B32` under `serial`, with
    /// MFA Delete enabled on bucket `"b"`.
    fn manager_with_default(serial: &str) -> MfaDeleteManager {
        let m = MfaDeleteManager::new();
        m.set_default_secret(MfaSecret {
            secret_base32: TEST_SECRET_B32.to_owned(),
            serial: serial.to_owned(),
        });
        m.set_bucket_state("b", true);
        m
    }

    #[test]
    fn parse_mfa_header_happy_path() {
        let (serial, code) = parse_mfa_header("SERIAL 123456").expect("parse");
        assert_eq!(serial, "SERIAL");
        assert_eq!(code, "123456");
    }

    #[test]
    fn parse_mfa_header_rejects_no_space() {
        let err = parse_mfa_header("SERIAL123456").expect_err("must fail");
        assert!(matches!(err, MfaError::Malformed));
    }

    #[test]
    fn parse_mfa_header_rejects_extra_token() {
        let err = parse_mfa_header("SERIAL 123456 trailing").expect_err("must fail");
        assert!(matches!(err, MfaError::Malformed));
    }

    #[test]
    fn parse_mfa_header_rejects_double_space() {
        let err = parse_mfa_header("SERIAL  123456").expect_err("must fail");
        assert!(matches!(err, MfaError::Malformed));
    }

    #[test]
    fn parse_mfa_header_rejects_non_digit_code() {
        let err = parse_mfa_header("SERIAL 12345A").expect_err("must fail");
        assert!(matches!(err, MfaError::Malformed));
    }

    #[test]
    fn parse_mfa_header_rejects_wrong_length_code() {
        for bad in ["SERIAL 12345", "SERIAL 1234567"] {
            let err = parse_mfa_header(bad).expect_err("must fail");
            assert!(matches!(err, MfaError::Malformed));
        }
    }

    #[test]
    fn parse_mfa_header_rejects_empty_serial_or_code() {
        let err = parse_mfa_header(" 123456").expect_err("empty serial");
        assert!(matches!(err, MfaError::Malformed));
        let err = parse_mfa_header("SERIAL ").expect_err("empty code");
        assert!(matches!(err, MfaError::Malformed));
    }

    #[test]
    fn verify_totp_happy_path() {
        let code = totp_at(NOW);
        assert!(verify_totp(TEST_SECRET_B32, &code, NOW));
    }

    #[test]
    fn verify_totp_clock_skew_within_one_step_ok() {
        // skew = 1 step: the immediately adjacent windows must both verify.
        let code_prev = totp_at(NOW - 30);
        assert!(
            verify_totp(TEST_SECRET_B32, &code_prev, NOW),
            "previous 30s window must validate"
        );
        let code_next = totp_at(NOW + 30);
        assert!(
            verify_totp(TEST_SECRET_B32, &code_next, NOW),
            "next 30s window must validate"
        );
    }

    #[test]
    fn verify_totp_clock_skew_beyond_window_fails() {
        // Three steps back is outside the +/-1 step tolerance.
        let code_old = totp_at(NOW - 90);
        assert!(!verify_totp(TEST_SECRET_B32, &code_old, NOW));
    }

    #[test]
    fn verify_totp_wrong_code_fails() {
        assert!(!verify_totp(TEST_SECRET_B32, "000000", NOW));
    }

    #[test]
    fn verify_totp_short_secret_rejected() {
        // 8 base32 chars decode to only 5 bytes; TOTP::new refuses such a key.
        let short_b32 = "JBSWY3DP";
        assert!(!verify_totp(short_b32, "000000", NOW));
    }

    #[test]
    fn verify_totp_invalid_or_empty_secret_rejected() {
        let code = totp_at(NOW);
        assert!(!verify_totp("not base32!", &code, NOW));
        assert!(!verify_totp("", &code, NOW));
    }

    #[test]
    fn check_mfa_disabled_bucket_is_noop() {
        let m = MfaDeleteManager::new();
        assert!(check_mfa("b", None, &m, 0).is_ok());
        // The header is not even parsed when MFA Delete is off.
        assert!(check_mfa("b", Some("garbage"), &m, 0).is_ok());
    }

    #[test]
    fn check_mfa_state_is_per_bucket() {
        let m = manager_with_default("SERIAL-A");
        assert!(check_mfa("other", None, &m, 0).is_ok());
    }

    #[test]
    fn check_mfa_disabled_after_enable_is_noop() {
        let m = manager_with_default("SERIAL-A");
        m.set_bucket_state("b", false);
        assert!(check_mfa("b", None, &m, 0).is_ok());
    }

    #[test]
    fn check_mfa_enabled_correct_code_ok() {
        let m = manager_with_default("SERIAL-A");
        let header = format!("SERIAL-A {}", totp_at(NOW));
        assert!(check_mfa("b", Some(&header), &m, NOW).is_ok());
    }

    #[test]
    fn check_mfa_enabled_wrong_code_fails() {
        let m = manager_with_default("SERIAL-A");
        let err = check_mfa("b", Some("SERIAL-A 000000"), &m, NOW).expect_err("must fail");
        assert!(matches!(err, MfaError::InvalidCode), "got {err:?}");
    }

    #[test]
    fn check_mfa_enabled_missing_header_fails() {
        let m = manager_with_default("SERIAL-A");
        let err = check_mfa("b", None, &m, 0).expect_err("must fail");
        assert!(matches!(err, MfaError::Missing), "got {err:?}");
    }

    #[test]
    fn check_mfa_enabled_serial_mismatch_fails() {
        let m = manager_with_default("SERIAL-A");
        let header = format!("SERIAL-OTHER {}", totp_at(NOW));
        let err = check_mfa("b", Some(&header), &m, NOW).expect_err("must fail");
        assert!(matches!(err, MfaError::SerialMismatch), "got {err:?}");
    }

    #[test]
    fn check_mfa_enabled_without_secret_fails_closed() {
        // Enabled but nothing configured: must reject, never allow.
        let m = MfaDeleteManager::new();
        m.set_bucket_state("b", true);
        let header = format!("SERIAL-A {}", totp_at(NOW));
        let err = check_mfa("b", Some(&header), &m, NOW).expect_err("must fail");
        assert!(matches!(err, MfaError::InvalidCode), "got {err:?}");
    }

    #[test]
    fn check_mfa_per_bucket_override_takes_precedence() {
        let m = manager_with_default("DEFAULT");
        m.set_bucket_secret(
            "b",
            MfaSecret {
                secret_base32: TEST_SECRET_B32.to_owned(),
                serial: "BUCKET-OVERRIDE".to_owned(),
            },
        );
        let code = totp_at(NOW);
        // Same seed, so only the serial distinguishes the two devices.
        let header_default = format!("DEFAULT {code}");
        assert!(matches!(
            check_mfa("b", Some(&header_default), &m, NOW).expect_err("must fail"),
            MfaError::SerialMismatch
        ));
        let header_override = format!("BUCKET-OVERRIDE {code}");
        assert!(check_mfa("b", Some(&header_override), &m, NOW).is_ok());
    }

    #[test]
    fn snapshot_roundtrip() {
        let m = MfaDeleteManager::new();
        m.set_default_secret(MfaSecret {
            secret_base32: TEST_SECRET_B32.to_owned(),
            serial: "DEFAULT".to_owned(),
        });
        m.set_bucket_secret(
            "b1",
            MfaSecret {
                secret_base32: TEST_SECRET_B32.to_owned(),
                serial: "B1-OVR".to_owned(),
            },
        );
        m.set_bucket_state("b1", true);
        m.set_bucket_state("b2", false);
        let json = m.to_json().expect("to_json");
        let m2 = MfaDeleteManager::from_json(&json).expect("from_json");
        assert!(m2.is_enabled("b1"));
        assert!(!m2.is_enabled("b2"));
        let s = m2.lookup_secret("b1").expect("override survives");
        assert_eq!(s.serial, "B1-OVR");
        // A bucket with no override falls back to the restored default.
        let s = m2.lookup_secret("other").expect("default survives");
        assert_eq!(s.serial, "DEFAULT");
    }

    #[test]
    fn snapshot_of_empty_manager_roundtrips() {
        let json = MfaDeleteManager::new().to_json().expect("to_json");
        let m = MfaDeleteManager::from_json(&json).expect("from_json");
        assert!(!m.is_enabled("b"));
        assert!(m.lookup_secret("b").is_none());
    }

    /// A panic while holding the `enabled` write lock poisons it; `to_json`
    /// must still succeed (via `lock_recovery`) and keep previously stored data.
    #[test]
    fn mfa_to_json_after_panic_recovers_via_poison() {
        let m = std::sync::Arc::new(MfaDeleteManager::new());
        m.set_default_secret(MfaSecret {
            secret_base32: TEST_SECRET_B32.to_owned(),
            serial: "DEFAULT".to_owned(),
        });
        m.set_bucket_state("b", true);
        let m_cl = std::sync::Arc::clone(&m);
        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let mut g = m_cl.enabled.write().expect("clean lock");
            g.insert("b2".into(), true);
            panic!("force-poison");
        }));
        assert!(
            m.enabled.is_poisoned(),
            "write panic must poison enabled lock"
        );
        let json = m.to_json().expect("to_json after poison must succeed");
        let m2 = MfaDeleteManager::from_json(&json).expect("from_json");
        assert!(m2.is_enabled("b"), "recovered snapshot keeps enabled flag");
        let secret = m2.lookup_secret("b").expect("default secret survives");
        assert_eq!(secret.serial, "DEFAULT");
    }
}