Skip to main content

srv_session/
lib.rs

1use argon2::{Algorithm, Argon2, Params, Version};
2use chrono::{DateTime, Duration, Utc};
3use log::info;
4use serde::{Deserialize, Serialize};
5use std::time::{Duration as StdDuration, Instant};
6
/// Default length, in bytes, of a randomly generated session ID.
pub const DEFAULT_SESSION_LEN: usize = 32;
/// Default length, in bytes, of the random per-account salt.
pub const DEFAULT_SALT_LEN: usize = 16;
/// Default length, in bytes, of the server-wide pepper appended after each salt.
pub const DEFAULT_PEPPER_LEN: usize = 16;
/// Default length, in bytes, of the Argon2id password hash output.
pub const DEFAULT_HASH_LEN: usize = 32;
11
12/// Serde helpers: hex array
13#[derive(Clone, Serialize, Deserialize)]
14pub struct HashConfig<const PEPPER_LEN: usize = DEFAULT_PEPPER_LEN> {
15    #[serde(with = "serde_hex_array")]
16    pub pepper: [u8; PEPPER_LEN],
17    pub memory_kib: u32,
18    pub time_cost: u32,
19    pub lanes: u32,
20}
21
22pub mod serde_hex_array {
23    use serde::{Deserialize, Deserializer, Serializer};
24
25    #[inline]
26    pub fn bytes_to_hex<const N: usize>(bytes: &[u8; N]) -> String {
27        let mut out = String::with_capacity(N * 2);
28        for b in bytes {
29            use core::fmt::Write;
30            let _ = write!(&mut out, "{:02x}", b);
31        }
32        out
33    }
34
35    #[inline]
36    pub fn hex_to_bytes<const N: usize>(s: &str) -> Result<[u8; N], String> {
37        if s.len() != N * 2 {
38            return Err(format!("expected {} bytes hex, got {}", N, s.len() / 2));
39        }
40        let mut out = [0u8; N];
41        for i in 0..N {
42            let idx = i * 2;
43            out[i] = u8::from_str_radix(&s[idx..idx + 2], 16)
44                .map_err(|e| format!("invalid hex: {}", e))?;
45        }
46        Ok(out)
47    }
48
49    pub fn serialize<S, const N: usize>(bytes: &[u8; N], s: S) -> Result<S::Ok, S::Error>
50    where
51        S: Serializer,
52    {
53        let out = bytes_to_hex(bytes);
54        s.serialize_str(&out)
55    }
56
57    pub fn deserialize<'de, D, const N: usize>(d: D) -> Result<[u8; N], D::Error>
58    where
59        D: Deserializer<'de>,
60    {
61        let s = String::deserialize(d)?;
62        hex_to_bytes::<N>(&s).map_err(serde::de::Error::custom)
63    }
64}
65
66pub mod serde_hex_array_vec {
67    use super::serde_hex_array::{bytes_to_hex, hex_to_bytes};
68    use serde::ser::SerializeSeq;
69    use serde::{Deserialize, Deserializer, Serializer};
70
71    // Vec<[u8; N]> <-> Vec<String(hex)>
72    pub fn serialize<S, const N: usize>(items: &Vec<[u8; N]>, s: S) -> Result<S::Ok, S::Error>
73    where
74        S: Serializer,
75    {
76        let mut seq = s.serialize_seq(Some(items.len()))?;
77        for it in items {
78            let hex = bytes_to_hex(it);
79            seq.serialize_element(&hex)?;
80        }
81        seq.end()
82    }
83
84    pub fn deserialize<'de, D, const N: usize>(d: D) -> Result<Vec<[u8; N]>, D::Error>
85    where
86        D: Deserializer<'de>,
87    {
88        let hex_strings: Vec<String> = Deserialize::deserialize(d)?;
89        let mut out = Vec::with_capacity(hex_strings.len());
90        for hex_string in hex_strings {
91            let bytes = hex_to_bytes::<N>(&hex_string).map_err(serde::de::Error::custom)?;
92            out.push(bytes);
93        }
94        Ok(out)
95    }
96}
97
98/// Hash benchmarking and config generation
99impl<const PEPPER_LEN: usize> HashConfig<PEPPER_LEN> {
100    pub fn benchmark(target_ms: u64) -> Self {
101        info!("Benchmarking HashConfig parameters...");
102        let test_password = "benchmark_password";
103        let salt = [0u8; 16];
104        let target_duration = StdDuration::from_millis(target_ms);
105
106        info!(
107            "Benchmark assumptions: target_duration={:?}, test_password='{}', salt={:?}",
108            target_duration, test_password, salt
109        );
110
111        let pepper = Self::generate_random_pepper();
112        info!("Generated random pepper for benchmark");
113
114        let best_memory = Self::binary_search_param(
115            target_duration,
116            |memory| {
117                let params = Params::new(memory, 3, 1, Some(32)).expect("argon2 params for memory");
118                let hasher = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
119                let start = Instant::now();
120                let mut out = [0u8; 32];
121                let mut adv = Vec::new();
122                adv.extend_from_slice(&salt);
123                adv.extend_from_slice(&pepper);
124                hasher
125                    .hash_password_into(test_password.as_bytes(), &adv, &mut out)
126                    .expect("hash during memory benchmark");
127                start.elapsed()
128            },
129            32768,
130            1048576,
131        );
132
133        let best_time = Self::binary_search_param(
134            target_duration,
135            |time| {
136                let params =
137                    Params::new(best_memory, time, 1, Some(32)).expect("argon2 params for time");
138                let hasher = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
139                let start = Instant::now();
140                let mut out = [0u8; 32];
141                let mut adv = Vec::new();
142                adv.extend_from_slice(&salt);
143                adv.extend_from_slice(&pepper);
144                hasher
145                    .hash_password_into(test_password.as_bytes(), &adv, &mut out)
146                    .expect("hash during time benchmark");
147                start.elapsed()
148            },
149            1,
150            10,
151        );
152
153        let best_lanes = Self::binary_search_param(
154            target_duration,
155            |lanes| {
156                let params = Params::new(best_memory, best_time, lanes, Some(32))
157                    .expect("argon2 params for lanes");
158                let hasher = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
159                let start = Instant::now();
160                let mut out = [0u8; 32];
161                let mut adv = Vec::new();
162                adv.extend_from_slice(&salt);
163                adv.extend_from_slice(&pepper);
164                hasher
165                    .hash_password_into(test_password.as_bytes(), &adv, &mut out)
166                    .expect("hash during lanes benchmark");
167                start.elapsed()
168            },
169            1,
170            8,
171        );
172
173        let best_config = Self {
174            pepper,
175            memory_kib: best_memory,
176            time_cost: best_time,
177            lanes: best_lanes,
178        };
179
180        let params = Params::new(best_memory, best_time, best_lanes, Some(32))
181            .expect("argon2 params for final measurement");
182        let hasher = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
183        let start = Instant::now();
184        let mut out = [0u8; 32];
185        let mut adv = Vec::new();
186        adv.extend_from_slice(&salt);
187        adv.extend_from_slice(&best_config.pepper);
188        hasher
189            .hash_password_into(test_password.as_bytes(), &adv, &mut out)
190            .expect("hash during final benchmark");
191        let final_duration = start.elapsed();
192
193        info!(
194            "Best HashConfig: memory={} KiB, time={}, lanes={}, duration={:?}",
195            best_config.memory_kib, best_config.time_cost, best_config.lanes, final_duration
196        );
197        best_config
198    }
199
200    fn generate_random_pepper() -> [u8; PEPPER_LEN] {
201        let mut bytes = [0u8; PEPPER_LEN];
202        getrandom::fill(&mut bytes).expect("generate random pepper");
203        bytes
204    }
205
206    fn binary_search_param<F>(target: StdDuration, measure: F, min: u32, max: u32) -> u32
207    where
208        F: Fn(u32) -> StdDuration,
209    {
210        let mut low = min;
211        let mut high = max;
212        let mut best = min;
213        let mut best_diff = StdDuration::from_secs(1000);
214
215        while low <= high {
216            let mid = low + (high - low) / 2;
217            let duration = measure(mid);
218            let diff = if duration > target {
219                duration - target
220            } else {
221                target - duration
222            };
223
224            if diff < best_diff {
225                best = mid;
226                best_diff = diff;
227            }
228
229            if duration < target {
230                low = mid + 1;
231            } else {
232                if mid == 0 {
233                    break;
234                }
235                high = mid - 1;
236            }
237        }
238
239        best
240    }
241}
242
/// Minimal key-value store interface used by `AuthManager` for sessions and accounts.
///
/// Every method takes `&self`, so an implementation that mutates must use
/// interior mutability. `K: ?Sized` allows unsized keys such as `str`.
pub trait KVTrait<K: ?Sized, V> {
    /// Returns the value stored under `key` (owned), or `None` if absent.
    fn get(&self, key: &K) -> Option<V>;
    /// Stores `value` under `key`.
    ///
    /// `AuthManager` re-`set`s modified records under the same key, so this is
    /// expected to overwrite any previous value.
    fn set(&self, key: &K, value: V);
    /// Returns whether `key` currently has a value.
    fn contains(&self, key: &K) -> bool;
    /// Removes `key`. The returned flag is expected to report whether an entry
    /// was removed; `AuthManager` passes it straight through to its callers.
    fn delete(&self, key: &K) -> bool;
}
253
/// Per-session record kept in the session store `S` of `AuthManager`.
///
/// `AuthManager` treats this as a cache. Authorization is decided from the
/// account's `authed_linked_sessions`, not from the fields here. Field order is
/// the serialization order; do not reorder.
#[derive(Clone, Serialize, Deserialize)]
pub struct SessionValue<const SESSION_LEN: usize> {
    /// The session ID, serialized as hex. `create_session` sets it to the same
    /// value as the store key.
    #[serde(with = "serde_hex_array")]
    pub session_key: [u8; SESSION_LEN],
    /// Usernames this session logged in to (cache). `auth_login` adds entries
    /// and `delete_account` removes them. `auth_logout` does not, so this can
    /// list accounts the session is no longer authorized for.
    pub linked_accounts_cache: Vec<Box<str>>,
    /// Last activity time, refreshed by `update_or_gc_session`. A session is
    /// expired once `now - last_time` exceeds `session_timeout`.
    pub last_time: DateTime<Utc>,
    /// Set once by `create_session`; not read anywhere else in this file.
    pub created_time: DateTime<Utc>,
    /// Preferred-account hint. It is set by `set_primary_account` only after
    /// `auth_verify` succeeds. It is cleared when it no longer verifies
    /// (`get_and_verify_session`) or the account is deleted.
    pub primary_account: Option<Box<str>>,
}
263
/// Per-account record kept in the account store `A` of `AuthManager`; the
/// authoritative side for authorization.
///
/// Field order is the serialization order; do not reorder.
#[derive(Clone, Serialize, Deserialize)]
pub struct AccountValue<const SALT_LEN: usize, const HASH_LEN: usize, const SESSION_LEN: usize> {
    /// Argon2id output for the password, computed with `salt || pepper` as the
    /// Argon2 salt input. Serialized as hex.
    #[serde(with = "serde_hex_array")]
    pub password_hash: [u8; HASH_LEN],
    /// Random per-account salt generated when the account is created; serialized as hex.
    #[serde(with = "serde_hex_array")]
    pub salt: [u8; SALT_LEN],
    /// Set on creation. Updated by `auth_login` / `auth_logout` when they add or
    /// remove a linked session.
    pub last_time: DateTime<Utc>,
    /// Session IDs currently authorized for this account, serialized as a list of hex strings.
    #[serde(with = "serde_hex_array_vec")]
    pub authed_linked_sessions: Vec<[u8; SESSION_LEN]>,
}
274
275/// Fast Lock
276mod account_lock {
277    use ahash::AHasher;
278    use parking_lot::{Mutex, MutexGuard};
279    use std::hash::{Hash, Hasher};
280
281    // SHARDS must be power of two for bitmask
282    pub struct AccountLocks<const SHARDS: usize> {
283        locks: [Mutex<()>; SHARDS],
284    }
285
286    impl<const SHARDS: usize> AccountLocks<SHARDS> {
287        pub fn new() -> Self {
288            debug_assert!(SHARDS.is_power_of_two());
289            Self {
290                locks: std::array::from_fn(|_| Mutex::new(())),
291            }
292        }
293
294        #[inline]
295        fn shard_for_username(username: &str) -> usize {
296            let mut h = AHasher::default();
297            username.hash(&mut h);
298            (h.finish() as usize) & (SHARDS - 1)
299        }
300
301        #[inline]
302        pub fn lock_account<'a>(&'a self, username: &str) -> MutexGuard<'a, ()> {
303            let idx = Self::shard_for_username(username);
304            self.locks[idx].lock()
305        }
306    }
307}
308
309
/// Main AuthManager
///
/// Coordinates a session store `S` and an account store `A`:
/// - Sessions are treated as a cache. When needed, a session is verified and a
///   stale `primary_account` is dropped.
/// - The account side is authoritative; writes to sessions are cache updates.
///
/// Account-mutating methods take a per-username lock from a sharded table with
/// `ACCOUNT_LOCK_SHARDS` shards (default 4096).
pub struct AuthManager<
    S,
    A,
    const SESSION_LEN: usize = DEFAULT_SESSION_LEN,
    const HASH_LEN: usize = DEFAULT_HASH_LEN,
    const PEPPER_LEN: usize = DEFAULT_PEPPER_LEN,
    const SALT_LEN: usize = DEFAULT_SALT_LEN,
    const ACCOUNT_LOCK_SHARDS: usize = 4096,
> where
    S: KVTrait<[u8; SESSION_LEN], SessionValue<SESSION_LEN>> + Send + Sync,
    A: KVTrait<str, AccountValue<SALT_LEN, HASH_LEN, SESSION_LEN>> + Send + Sync,
{
    // NOTE: Making these `pub` raises the chance that updates which bypass
    // AuthManager (and its account locks) break consistency between the stores.
    /// Session store, keyed by session ID.
    pub sessions: S,
    /// Account store, keyed by username.
    pub accounts: A,

    /// Idle timeout. A session whose `last_time` is older than this is treated as expired.
    pub session_timeout: Duration,
    /// NOTE(review): stored by `new` but not read anywhere in this file —
    /// confirm whether account expiry is enforced elsewhere.
    pub account_timeout: Duration,
    /// Argon2id hasher built in `new` from the `HashConfig` parameters.
    pub password_hasher: Argon2<'static>,
    /// Server-wide secret appended after each account's salt before hashing.
    pub pepper: [u8; PEPPER_LEN],

    /// Sharded per-username locks guarding account read-modify-write sequences.
    account_locks: account_lock::AccountLocks<ACCOUNT_LOCK_SHARDS>,
}
336
337impl<
338    S,
339    A,
340    const SESSION_LEN: usize,
341    const HASH_LEN: usize,
342    const PEPPER_LEN: usize,
343    const SALT_LEN: usize,
344    const ACCOUNT_LOCK_SHARDS: usize,
345> AuthManager<S, A, SESSION_LEN, HASH_LEN, PEPPER_LEN, SALT_LEN, ACCOUNT_LOCK_SHARDS>
346where
347    S: KVTrait<[u8; SESSION_LEN], SessionValue<SESSION_LEN>> + Send + Sync,
348    A: KVTrait<str, AccountValue<SALT_LEN, HASH_LEN, SESSION_LEN>> + Send + Sync,
349{
350    pub fn new(
351        sessions: S,
352        accounts: A,
353        session_timeout: Duration,
354        account_timeout: Duration,
355        hash_config: HashConfig<PEPPER_LEN>,
356    ) -> Self {
357        Self {
358            sessions,
359            accounts,
360            session_timeout,
361            account_timeout,
362            password_hasher: Argon2::new(
363                Algorithm::Argon2id,
364                Version::V0x13,
365                Params::new(
366                    hash_config.memory_kib,
367                    hash_config.time_cost,
368                    hash_config.lanes,
369                    Some(HASH_LEN),
370                )
371                .expect("argon2 hash params"),
372            ),
373            pepper: hash_config.pepper,
374            account_locks: account_lock::AccountLocks::new(),
375        }
376    }
377
378    /// session_id はランダム生成。衝突したらリトライ。
379    pub fn create_session(&self) -> [u8; SESSION_LEN] {
380        let session_id = Self::generate_session();
381        if self.sessions.contains(&session_id) {
382            return self.create_session();
383        }
384        let session_value = SessionValue::<SESSION_LEN> {
385            session_key: session_id,
386            linked_accounts_cache: Vec::new(),
387            last_time: Utc::now(),
388            created_time: Utc::now(),
389            primary_account: None,
390        };
391        self.sessions.set(&session_id, session_value);
392        session_id
393    }
394
395    pub fn delete_session(&self, session_id: &[u8; SESSION_LEN]) -> bool {
396        self.sessions.delete(session_id)
397    }
398
399    /// ガード入り
400    /// Noneならcreate sessionするなりするとよい
401    pub fn get_and_verify_session(
402        &self,
403        session_id: &[u8; SESSION_LEN],
404    ) -> Option<SessionValue<SESSION_LEN>> {
405        if let Some(mut session) = self.update_or_gc_session(session_id) {
406            if let Some(primary) = session.primary_account.clone() {
407                // auth_verify の正は account 側
408                if !self.auth_verify(session_id, &primary) {
409                    session.primary_account = None;
410                    // sessionはキャッシュ。ここで set するのはベストエフォート
411                    self.sessions.set(session_id, session.clone());
412                }
413            }
414            return Some(session);
415        }
416        None
417    }
418
419    pub fn update_or_gc_session(
420        &self,
421        session_id: &[u8; SESSION_LEN],
422    ) -> Option<SessionValue<SESSION_LEN>> {
423        if let Some(mut session) = self.gc_sessions(session_id) {
424            session.last_time = Utc::now();
425            self.sessions.set(session_id, session.clone());
426            return Some(session);
427        }
428        None
429    }
430
431    pub fn gc_sessions(&self, session_id: &[u8; SESSION_LEN]) -> Option<SessionValue<SESSION_LEN>> {
432        if let Some(session) = self.sessions.get(session_id) {
433            let now = Utc::now();
434            if now - session.last_time > self.session_timeout {
435                let _ = self.delete_session(session_id);
436                return None;
437            }
438            Some(session)
439        } else {
440            None
441        }
442    }
443
444    pub fn set_primary_account(&self, session_id: &[u8; SESSION_LEN], username: &str) -> bool {
445        // primary_account はヒントなので verify 必須
446        if self.auth_verify(session_id, username) {
447            if let Some(mut session) = self.sessions.get(session_id) {
448                session.primary_account = Some(username.into());
449                self.sessions.set(session_id, session);
450                return true;
451            }
452        }
453        false
454    }
455
456    pub fn add_account(&self, username: &str, password: &str) {
457        let _g = self.account_locks.lock_account(username);
458
459        let salt = Self::generate_random_salt();
460        let password_hash = self.hash_password(password, &salt);
461        let account_value = AccountValue::<SALT_LEN, HASH_LEN, SESSION_LEN> {
462            password_hash,
463            salt,
464            last_time: Utc::now(),
465            authed_linked_sessions: Vec::new(),
466        };
467        self.accounts.set(username, account_value);
468    }
469
470    pub fn delete_account(&self, username: &str) -> bool {
471        let _g = self.account_locks.lock_account(username);
472
473        // session側はキャッシュ扱いなので、この掃除はベストエフォート
474        if let Some(account) = self.accounts.get(username) {
475            for session_id in &account.authed_linked_sessions {
476                if let Some(mut session_value) = self.sessions.get(session_id) {
477                    session_value
478                        .linked_accounts_cache
479                        .retain(|a| a.as_ref() != username);
480                    if session_value.primary_account.as_deref() == Some(username) {
481                        session_value.primary_account = None;
482                    }
483                    self.sessions.set(session_id, session_value);
484                }
485            }
486            self.accounts.delete(username)
487        } else {
488            false
489        }
490    }
491
492    pub fn get_account(
493        &self,
494        username: &str,
495    ) -> Option<AccountValue<SALT_LEN, HASH_LEN, SESSION_LEN>> {
496        self.accounts.get(username)
497    }
498
499    pub fn auth_login(
500        &self,
501        session_id: &[u8; SESSION_LEN],
502        username: &str,
503        password: &str,
504    ) -> bool {
505        let _g = self.account_locks.lock_account(username);
506
507        if let Some(mut account) = self.accounts.get(username) {
508            let expected_hash = self.hash_password(password, &account.salt);
509            if expected_hash != account.password_hash {
510                return false;
511            }
512
513            // account側が正。ここだけ確実に更新する
514            if !account.authed_linked_sessions.contains(session_id) {
515                account.authed_linked_sessions.push(*session_id);
516                account.last_time = Utc::now();
517                self.accounts.set(username, account);
518            }
519
520            // session側はキャッシュ更新(ベストエフォート)
521            if let Some(mut session) = self.sessions.get(session_id) {
522                if !session
523                    .linked_accounts_cache
524                    .iter()
525                    .any(|a| a.as_ref() == username)
526                {
527                    session.linked_accounts_cache.push(username.into());
528                }
529                self.sessions.set(session_id, session);
530            }
531
532            return true;
533        }
534        false
535    }
536
537    /// 認可
538    pub fn auth_verify(&self, session_id: &[u8; SESSION_LEN], username: &str) -> bool {
539        if let Some(account) = self.accounts.get(username) {
540            account.authed_linked_sessions.contains(session_id)
541        } else {
542            false
543        }
544    }
545
546    /// logout は account 側だけ更新
547    pub fn auth_logout(&self, session_id: &[u8; SESSION_LEN], username: &str) -> bool {
548        let _g = self.account_locks.lock_account(username);
549
550        if let Some(mut account) = self.accounts.get(username) {
551            let before = account.authed_linked_sessions.len();
552            account.authed_linked_sessions.retain(|s| s != session_id);
553            let changed = account.authed_linked_sessions.len() != before;
554            if changed {
555                account.last_time = Utc::now();
556                self.accounts.set(username, account);
557            }
558            return changed;
559        }
560        false
561    }
562
563    fn hash_password(&self, password: &str, salt: &[u8; SALT_LEN]) -> [u8; HASH_LEN] {
564        let mut out = [0u8; HASH_LEN];
565        let mut adv = Vec::with_capacity(SALT_LEN + PEPPER_LEN);
566        adv.extend_from_slice(salt);
567        adv.extend_from_slice(&self.pepper);
568        self.password_hasher
569            .hash_password_into(password.as_bytes(), &adv, &mut out)
570            .expect("argon2 hash_password_into");
571        out
572    }
573
574    fn generate_random_salt() -> [u8; SALT_LEN] {
575        let mut salt = [0u8; SALT_LEN];
576        getrandom::fill(&mut salt).expect("generate random salt");
577        salt
578    }
579
580    fn generate_session() -> [u8; SESSION_LEN] {
581        let mut session_id = [0u8; SESSION_LEN];
582        getrandom::fill(&mut session_id).expect("generate random session ID");
583        session_id
584    }
585}