1use argon2::{Algorithm, Argon2, Params, Version};
2use std::time::{Duration, SystemTime};
3use log::info;
4use serde::{Deserialize, Serialize};
5use std::time::{Duration as StdDuration, Instant};
6
/// Default salt length in bytes (128 bits) for per-account salts.
pub const DEFAULT_SALT_LEN: usize = 128 / 8;
/// Default pepper length in bytes (128 bits).
pub const DEFAULT_PEPPER_LEN: usize = 128 / 8;
/// Default Argon2 output (password hash) length in bytes (256 bits).
pub const DEFAULT_HASH_LEN: usize = 256 / 8;
/// Default session ID length in bytes (256 bits).
pub const DEFAULT_SESSION_LEN: usize = 256 / 8;
11
/// Argon2id cost parameters plus the secret pepper used for password hashing.
///
/// Produced by `HashConfig::benchmark` and consumed by `AuthManager::new`, which passes
/// `memory_kib`, `time_cost` and `lanes` to `argon2::Params::new` in that order.
///
/// NOTE(review): the derived `Serialize` writes `pepper` out as hex, so every serialized copy
/// of this config contains the secret and must be stored accordingly.
// Field order is part of the derived serde format (positional formats depend on it).
#[derive(Clone, Serialize, Deserialize)]
pub struct HashConfig<const PEPPER_LEN: usize = DEFAULT_PEPPER_LEN> {
    /// Secret bytes appended after each account's salt before hashing; (de)serialized as a hex string.
    #[serde(with = "serde_hex_array")]
    pub pepper: [u8; PEPPER_LEN],
    /// Argon2 memory cost in KiB.
    pub memory_kib: u32,
    /// Argon2 time cost (number of passes).
    pub time_cost: u32,
    /// Argon2 parallelism (number of lanes).
    pub lanes: u32,
}
21
22pub mod serde_hex_array {
23 use serde::{Deserialize, Deserializer, Serializer};
24
25 #[inline]
26 pub fn bytes_to_hex<const N: usize>(bytes: &[u8; N]) -> String {
27 let mut out = String::with_capacity(N * 2);
28 for b in bytes {
29 use core::fmt::Write;
30 let _ = write!(&mut out, "{:02x}", b);
31 }
32 out
33 }
34
35 #[inline]
36 pub fn hex_to_bytes<const N: usize>(s: &str) -> Result<[u8; N], String> {
37 if s.len() != N * 2 {
38 return Err(format!("expected {} bytes hex, got {}", N, s.len() / 2));
39 }
40 let mut out = [0u8; N];
41 for i in 0..N {
42 let idx = i * 2;
43 out[i] = u8::from_str_radix(&s[idx..idx + 2], 16)
44 .map_err(|e| format!("invalid hex: {}", e))?;
45 }
46 Ok(out)
47 }
48
49 pub fn serialize<S, const N: usize>(bytes: &[u8; N], s: S) -> Result<S::Ok, S::Error>
50 where
51 S: Serializer,
52 {
53 let out = bytes_to_hex(bytes);
54 s.serialize_str(&out)
55 }
56
57 pub fn deserialize<'de, D, const N: usize>(d: D) -> Result<[u8; N], D::Error>
58 where
59 D: Deserializer<'de>,
60 {
61 let s = String::deserialize(d)?;
62 hex_to_bytes::<N>(&s).map_err(serde::de::Error::custom)
63 }
64}
65
66pub mod serde_hex_array_vec {
67 use super::serde_hex_array::{bytes_to_hex, hex_to_bytes};
68 use serde::ser::SerializeSeq;
69 use serde::{Deserialize, Deserializer, Serializer};
70
71 pub fn serialize<S, const N: usize>(items: &Vec<[u8; N]>, s: S) -> Result<S::Ok, S::Error>
73 where
74 S: Serializer,
75 {
76 let mut seq = s.serialize_seq(Some(items.len()))?;
77 for it in items {
78 let hex = bytes_to_hex(it);
79 seq.serialize_element(&hex)?;
80 }
81 seq.end()
82 }
83
84 pub fn deserialize<'de, D, const N: usize>(d: D) -> Result<Vec<[u8; N]>, D::Error>
85 where
86 D: Deserializer<'de>,
87 {
88 let hex_strings: Vec<String> = Deserialize::deserialize(d)?;
89 let mut out = Vec::with_capacity(hex_strings.len());
90 for hex_string in hex_strings {
91 let bytes = hex_to_bytes::<N>(&hex_string).map_err(serde::de::Error::custom)?;
92 out.push(bytes);
93 }
94 Ok(out)
95 }
96}
97
98impl<const PEPPER_LEN: usize> HashConfig<PEPPER_LEN> {
100 pub fn benchmark(target_ms: u64) -> Self {
101 info!("Benchmarking HashConfig parameters...");
102 let test_password = "benchmark_password";
103 let salt = [0u8; 16];
104 let target_duration = StdDuration::from_millis(target_ms);
105
106 info!(
107 "Benchmark assumptions: target_duration={:?}, test_password='{}', salt={:?}",
108 target_duration, test_password, salt
109 );
110
111 let pepper = Self::generate_random_pepper();
112 info!("Generated random pepper for benchmark");
113
114 let best_memory = Self::binary_search_param(
115 target_duration,
116 |memory| {
117 let params = Params::new(memory, 3, 1, Some(32)).expect("argon2 params for memory");
118 let hasher = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
119 let start = Instant::now();
120 let mut out = [0u8; 32];
121 let mut adv = Vec::new();
122 adv.extend_from_slice(&salt);
123 adv.extend_from_slice(&pepper);
124 hasher
125 .hash_password_into(test_password.as_bytes(), &adv, &mut out)
126 .expect("hash during memory benchmark");
127 start.elapsed()
128 },
129 32768,
130 1048576,
131 );
132
133 let best_time = Self::binary_search_param(
134 target_duration,
135 |time| {
136 let params =
137 Params::new(best_memory, time, 1, Some(32)).expect("argon2 params for time");
138 let hasher = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
139 let start = Instant::now();
140 let mut out = [0u8; 32];
141 let mut adv = Vec::new();
142 adv.extend_from_slice(&salt);
143 adv.extend_from_slice(&pepper);
144 hasher
145 .hash_password_into(test_password.as_bytes(), &adv, &mut out)
146 .expect("hash during time benchmark");
147 start.elapsed()
148 },
149 1,
150 10,
151 );
152
153 let best_lanes = Self::binary_search_param(
154 target_duration,
155 |lanes| {
156 let params = Params::new(best_memory, best_time, lanes, Some(32))
157 .expect("argon2 params for lanes");
158 let hasher = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
159 let start = Instant::now();
160 let mut out = [0u8; 32];
161 let mut adv = Vec::new();
162 adv.extend_from_slice(&salt);
163 adv.extend_from_slice(&pepper);
164 hasher
165 .hash_password_into(test_password.as_bytes(), &adv, &mut out)
166 .expect("hash during lanes benchmark");
167 start.elapsed()
168 },
169 1,
170 8,
171 );
172
173 let best_config = Self {
174 pepper,
175 memory_kib: best_memory,
176 time_cost: best_time,
177 lanes: best_lanes,
178 };
179
180 let params = Params::new(best_memory, best_time, best_lanes, Some(32))
181 .expect("argon2 params for final measurement");
182 let hasher = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
183 let start = Instant::now();
184 let mut out = [0u8; 32];
185 let mut adv = Vec::new();
186 adv.extend_from_slice(&salt);
187 adv.extend_from_slice(&best_config.pepper);
188 hasher
189 .hash_password_into(test_password.as_bytes(), &adv, &mut out)
190 .expect("hash during final benchmark");
191 let final_duration = start.elapsed();
192
193 info!(
194 "Best HashConfig: memory={} KiB, time={}, lanes={}, duration={:?}",
195 best_config.memory_kib, best_config.time_cost, best_config.lanes, final_duration
196 );
197 best_config
198 }
199
200 fn generate_random_pepper() -> [u8; PEPPER_LEN] {
201 let mut bytes = [0u8; PEPPER_LEN];
202 getrandom::fill(&mut bytes).expect("generate random pepper");
203 bytes
204 }
205
206 fn binary_search_param<F>(target: StdDuration, measure: F, min: u32, max: u32) -> u32
207 where
208 F: Fn(u32) -> StdDuration,
209 {
210 let mut low = min;
211 let mut high = max;
212 let mut best = min;
213 let mut best_diff = StdDuration::from_secs(1000);
214
215 while low <= high {
216 let mid = low + (high - low) / 2;
217 let duration = measure(mid);
218 let diff = if duration > target {
219 duration - target
220 } else {
221 target - duration
222 };
223
224 if diff < best_diff {
225 best = mid;
226 best_diff = diff;
227 }
228
229 if duration < target {
230 low = mid + 1;
231 } else {
232 if mid == 0 {
233 break;
234 }
235 high = mid - 1;
236 }
237 }
238
239 best
240 }
241}
242
/// Minimal key-value store interface used by `AuthManager` for both sessions and accounts.
///
/// Every method takes `&self`, so an implementation that mutates has to supply its own
/// interior mutability (a lock, a concurrent map, an external database handle, ...).
/// `AuthManager` additionally requires implementations to be `Send + Sync`.
pub trait KVTrait<K: ?Sized, V> {
    /// Returns an owned copy of the value stored under `key`, if any.
    fn get(&self, key: &K) -> Option<V>;
    /// Stores `value` under `key`; callers in this crate write modified values back with
    /// this, so it is expected to overwrite an existing entry.
    fn set(&self, key: &K, value: V);
    /// Reports whether an entry exists for `key`.
    fn contains(&self, key: &K) -> bool;
    /// Removes the entry for `key`; the result is expected to say whether one was removed.
    fn delete(&self, key: &K) -> bool;
}
253
/// Per-session state kept in the session store, keyed by the session ID.
// Field order is part of the derived serde format (positional formats depend on it).
#[derive(Clone, Serialize, Deserialize)]
pub struct SessionValue<const SESSION_LEN: usize> {
    /// The session's own ID (the same bytes used as its store key); serialized as hex.
    #[serde(with = "serde_hex_array")]
    pub session_key: [u8; SESSION_LEN],
    /// Usernames that have logged in through this session. This is a cache: the
    /// account-side `authed_linked_sessions` list is what `auth_verify` consults.
    pub linked_accounts_cache: Vec<Box<str>>,
    /// When the session was created or last refreshed; idle expiry is measured from here.
    pub last_time: SystemTime,
    /// When the session was created.
    pub created_time: SystemTime,
    /// Account the session currently acts as, if any; cleared when that account no longer
    /// verifies for this session.
    pub primary_account: Option<Box<str>>,
}
263
/// Per-account state kept in the account store, keyed by username.
// Field order is part of the derived serde format (positional formats depend on it).
#[derive(Clone, Serialize, Deserialize)]
pub struct AccountValue<const SALT_LEN: usize, const HASH_LEN: usize, const SESSION_LEN: usize> {
    /// Argon2id output for the password, computed with `salt` followed by the manager's
    /// pepper as the salt input; serialized as hex.
    #[serde(with = "serde_hex_array")]
    pub password_hash: [u8; HASH_LEN],
    /// Random per-account salt generated when the account is added; serialized as hex.
    #[serde(with = "serde_hex_array")]
    pub salt: [u8; SALT_LEN],
    /// Set when the account is added and whenever `authed_linked_sessions` changes.
    pub last_time: SystemTime,
    /// IDs of the sessions that have logged into this account; each serialized as hex.
    #[serde(with = "serde_hex_array_vec")]
    pub authed_linked_sessions: Vec<[u8; SESSION_LEN]>,
}
274
275mod account_lock {
277 use ahash::AHasher;
278 use parking_lot::{Mutex, MutexGuard};
279 use std::hash::{Hash, Hasher};
280
281 pub struct AccountLocks<const SHARDS: usize> {
283 locks: [Mutex<()>; SHARDS],
284 }
285
286 impl<const SHARDS: usize> AccountLocks<SHARDS> {
287 pub fn new() -> Self {
288 debug_assert!(SHARDS.is_power_of_two());
289 Self {
290 locks: std::array::from_fn(|_| Mutex::new(())),
291 }
292 }
293
294 #[inline]
295 fn shard_for_username(username: &str) -> usize {
296 let mut h = AHasher::default();
297 username.hash(&mut h);
298 (h.finish() as usize) & (SHARDS - 1)
299 }
300
301 #[inline]
302 pub fn lock_account<'a>(&'a self, username: &str) -> MutexGuard<'a, ()> {
303 let idx = Self::shard_for_username(username);
304 self.locks[idx].lock()
305 }
306 }
307}
308
309
/// Session and account manager: sessions live in `S`, accounts in `A`, and passwords are
/// hashed with Argon2id using a per-account salt followed by a shared pepper.
///
/// The const parameters fix the byte lengths of session IDs, password hashes, the pepper and
/// salts, plus the number of per-username lock shards (`ACCOUNT_LOCK_SHARDS`).
pub struct AuthManager<
    S,
    A,
    const SESSION_LEN: usize = DEFAULT_SESSION_LEN,
    const HASH_LEN: usize = DEFAULT_HASH_LEN,
    const PEPPER_LEN: usize = DEFAULT_PEPPER_LEN,
    const SALT_LEN: usize = DEFAULT_SALT_LEN,
    const ACCOUNT_LOCK_SHARDS: usize = 4096,
> where
    S: KVTrait<[u8; SESSION_LEN], SessionValue<SESSION_LEN>> + Send + Sync,
    A: KVTrait<str, AccountValue<SALT_LEN, HASH_LEN, SESSION_LEN>> + Send + Sync,
{
    /// Session store, keyed by session ID.
    pub sessions: S,
    /// Account store, keyed by username.
    pub accounts: A,

    /// Maximum idle time, measured from `SessionValue::last_time`, before a session is
    /// treated as expired.
    pub session_timeout: Duration,
    /// Stored at construction. NOTE(review): not read by any method visible in this file.
    pub account_timeout: Duration,
    /// Argon2id hasher built from the `HashConfig` passed to `new`, with `HASH_LEN` output.
    pub password_hasher: Argon2<'static>,
    /// Secret pepper appended after each account's salt before hashing.
    /// NOTE(review): this is a `pub` field, so any code holding the manager can read it.
    pub pepper: [u8; PEPPER_LEN],

    /// Sharded per-username locks taken by the account-mutating methods.
    account_locks: account_lock::AccountLocks<ACCOUNT_LOCK_SHARDS>,
}
336
337impl<
338 S,
339 A,
340 const SESSION_LEN: usize,
341 const HASH_LEN: usize,
342 const PEPPER_LEN: usize,
343 const SALT_LEN: usize,
344 const ACCOUNT_LOCK_SHARDS: usize,
345> AuthManager<S, A, SESSION_LEN, HASH_LEN, PEPPER_LEN, SALT_LEN, ACCOUNT_LOCK_SHARDS>
346where
347 S: KVTrait<[u8; SESSION_LEN], SessionValue<SESSION_LEN>> + Send + Sync,
348 A: KVTrait<str, AccountValue<SALT_LEN, HASH_LEN, SESSION_LEN>> + Send + Sync,
349{
350 pub fn new(
351 sessions: S,
352 accounts: A,
353 session_timeout: Duration,
354 account_timeout: Duration,
355 hash_config: HashConfig<PEPPER_LEN>,
356 ) -> Self {
357 Self {
358 sessions,
359 accounts,
360 session_timeout,
361 account_timeout,
362 password_hasher: Argon2::new(
363 Algorithm::Argon2id,
364 Version::V0x13,
365 Params::new(
366 hash_config.memory_kib,
367 hash_config.time_cost,
368 hash_config.lanes,
369 Some(HASH_LEN),
370 )
371 .expect("argon2 hash params"),
372 ),
373 pepper: hash_config.pepper,
374 account_locks: account_lock::AccountLocks::new(),
375 }
376 }
377
378 pub fn create_session(&self) -> [u8; SESSION_LEN] {
380 let session_id = Self::generate_session();
381 if self.sessions.contains(&session_id) {
382 return self.create_session();
383 }
384 let session_value = SessionValue::<SESSION_LEN> {
385 session_key: session_id,
386 linked_accounts_cache: Vec::new(),
387 last_time: SystemTime::now(),
388 created_time: SystemTime::now(),
389 primary_account: None,
390 };
391 self.sessions.set(&session_id, session_value);
392 session_id
393 }
394
395 pub fn delete_session(&self, session_id: &[u8; SESSION_LEN]) -> bool {
396 self.sessions.delete(session_id)
397 }
398
399 pub fn get_and_verify_session(
402 &self,
403 session_id: &[u8; SESSION_LEN],
404 ) -> Option<SessionValue<SESSION_LEN>> {
405 if let Some(mut session) = self.update_or_gc_session(session_id) {
406 if let Some(primary) = session.primary_account.clone() {
407 if !self.auth_verify(session_id, &primary) {
409 session.primary_account = None;
410 self.sessions.set(session_id, session.clone());
412 }
413 }
414 return Some(session);
415 }
416 None
417 }
418
419 pub fn update_or_gc_session(
420 &self,
421 session_id: &[u8; SESSION_LEN],
422 ) -> Option<SessionValue<SESSION_LEN>> {
423 if let Some(mut session) = self.gc_sessions(session_id) {
424 session.last_time = SystemTime::now();
425 self.sessions.set(session_id, session.clone());
426 return Some(session);
427 }
428 None
429 }
430
431 pub fn gc_sessions(&self, session_id: &[u8; SESSION_LEN]) -> Option<SessionValue<SESSION_LEN>> {
432 if let Some(session) = self.sessions.get(session_id) {
433 let now = SystemTime::now();
434 if now.duration_since(session.last_time).unwrap_or(Duration::from_secs(0)) > self.session_timeout {
435 let _ = self.delete_session(session_id);
436 return None;
437 }
438 Some(session)
439 } else {
440 None
441 }
442 }
443
444 pub fn set_primary_account(&self, session_id: &[u8; SESSION_LEN], username: &str) -> bool {
445 if self.auth_verify(session_id, username) {
447 if let Some(mut session) = self.sessions.get(session_id) {
448 session.primary_account = Some(username.into());
449 self.sessions.set(session_id, session);
450 return true;
451 }
452 }
453 false
454 }
455
456 pub fn add_account(&self, username: &str, password: &str) {
457 let _g = self.account_locks.lock_account(username);
458
459 let salt = Self::generate_random_salt();
460 let password_hash = self.hash_password(password, &salt);
461 let account_value = AccountValue::<SALT_LEN, HASH_LEN, SESSION_LEN> {
462 password_hash,
463 salt,
464 last_time: SystemTime::now(),
465 authed_linked_sessions: Vec::new(),
466 };
467 self.accounts.set(username, account_value);
468 }
469
470 pub fn delete_account(&self, username: &str) -> bool {
471 let _g = self.account_locks.lock_account(username);
472
473 if let Some(account) = self.accounts.get(username) {
475 for session_id in &account.authed_linked_sessions {
476 if let Some(mut session_value) = self.sessions.get(session_id) {
477 session_value
478 .linked_accounts_cache
479 .retain(|a| a.as_ref() != username);
480 if session_value.primary_account.as_deref() == Some(username) {
481 session_value.primary_account = None;
482 }
483 self.sessions.set(session_id, session_value);
484 }
485 }
486 self.accounts.delete(username)
487 } else {
488 false
489 }
490 }
491
492 pub fn get_account(
493 &self,
494 username: &str,
495 ) -> Option<AccountValue<SALT_LEN, HASH_LEN, SESSION_LEN>> {
496 self.accounts.get(username)
497 }
498
499 pub fn auth_login(
500 &self,
501 session_id: &[u8; SESSION_LEN],
502 username: &str,
503 password: &str,
504 ) -> bool {
505 let _g = self.account_locks.lock_account(username);
506
507 if let Some(mut account) = self.accounts.get(username) {
508 let expected_hash = self.hash_password(password, &account.salt);
509 if expected_hash != account.password_hash {
510 return false;
511 }
512
513 if !account.authed_linked_sessions.contains(session_id) {
515 account.authed_linked_sessions.push(*session_id);
516 account.last_time = SystemTime::now();
517 self.accounts.set(username, account);
518 }
519
520 if let Some(mut session) = self.sessions.get(session_id) {
522 if !session
523 .linked_accounts_cache
524 .iter()
525 .any(|a| a.as_ref() == username)
526 {
527 session.linked_accounts_cache.push(username.into());
528 }
529 self.sessions.set(session_id, session);
530 }
531
532 return true;
533 }
534 false
535 }
536
537 pub fn auth_verify(&self, session_id: &[u8; SESSION_LEN], username: &str) -> bool {
539 if let Some(account) = self.accounts.get(username) {
540 account.authed_linked_sessions.contains(session_id)
541 } else {
542 false
543 }
544 }
545
546 pub fn auth_logout(&self, session_id: &[u8; SESSION_LEN], username: &str) -> bool {
548 let _g = self.account_locks.lock_account(username);
549
550 if let Some(mut account) = self.accounts.get(username) {
551 let before = account.authed_linked_sessions.len();
552 account.authed_linked_sessions.retain(|s| s != session_id);
553 let changed = account.authed_linked_sessions.len() != before;
554 if changed {
555 account.last_time = SystemTime::now();
556 self.accounts.set(username, account);
557 }
558 return changed;
559 }
560 false
561 }
562
563 fn hash_password(&self, password: &str, salt: &[u8; SALT_LEN]) -> [u8; HASH_LEN] {
564 let mut out = [0u8; HASH_LEN];
565 let mut adv = Vec::with_capacity(SALT_LEN + PEPPER_LEN);
566 adv.extend_from_slice(salt);
567 adv.extend_from_slice(&self.pepper);
568 self.password_hasher
569 .hash_password_into(password.as_bytes(), &adv, &mut out)
570 .expect("argon2 hash_password_into");
571 out
572 }
573
574 fn generate_random_salt() -> [u8; SALT_LEN] {
575 let mut salt = [0u8; SALT_LEN];
576 getrandom::fill(&mut salt).expect("generate random salt");
577 salt
578 }
579
580 fn generate_session() -> [u8; SESSION_LEN] {
581 let mut session_id = [0u8; SESSION_LEN];
582 getrandom::fill(&mut session_id).expect("generate random session ID");
583 session_id
584 }
585}