1use hmac::Hmac;
16use num_bigint::{BigInt, Sign};
17use num_traits::ops::euclid::Euclid;
18use sha2::{Digest, Sha256, Sha512};
19
20fn sha256(parts: &[&[u8]]) -> [u8; 32] {
21 let mut h = Sha256::new();
22 for p in parts {
23 h.update(p);
24 }
25 h.finalize().into()
26}
27
28fn sh(data: &[u8], salt: &[u8]) -> [u8; 32] {
29 sha256(&[salt, data, salt])
30}
31
32fn ph1(password: &[u8], salt1: &[u8], salt2: &[u8]) -> [u8; 32] {
33 sh(&sh(password, salt1), salt2)
34}
35
36fn ph2(password: &[u8], salt1: &[u8], salt2: &[u8]) -> [u8; 32] {
37 let hash1 = ph1(password, salt1, salt2);
38 let mut dk = [0u8; 64];
39 pbkdf2::pbkdf2::<Hmac<Sha512>>(&hash1, salt1, 100_000, &mut dk).unwrap();
40 sh(&dk, salt2)
41}
42
/// Left-pads `data` with zero bytes to exactly 256 bytes.
///
/// Inputs longer than 256 bytes are truncated from the front: only their last
/// 256 bytes are kept, and the leading bytes are silently dropped.
fn pad256(data: &[u8]) -> [u8; 256] {
    const WIDTH: usize = 256;

    let tail = match data.len().checked_sub(WIDTH) {
        Some(excess) => &data[excess..],
        None => data,
    };

    let mut padded = [0u8; WIDTH];
    padded[WIDTH - tail.len()..].copy_from_slice(tail);
    padded
}
49
/// Byte-wise XOR of two 32-byte arrays.
fn xor32(a: &[u8; 32], b: &[u8; 32]) -> [u8; 32] {
    let mut mixed = *a;
    mixed
        .iter_mut()
        .zip(b.iter())
        .for_each(|(lhs, rhs)| *lhs ^= *rhs);
    mixed
}
57
/// Errors returned by the SRP proof computation when a Diffie-Hellman public
/// value is not strictly between `1` and `p - 1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SrpError {
    /// The server-supplied `g_b` is not in the range `1 < g_b < p - 1`.
    GbOutOfRange,
    /// The locally computed `g_a` is not in the range `1 < g_a < p - 1`.
    GaOutOfRange,
}

impl std::fmt::Display for SrpError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            SrpError::GbOutOfRange => "SRP: server g_b out of safe range",
            SrpError::GaOutOfRange => "SRP: client g_a out of safe range",
        })
    }
}

impl std::error::Error for SrpError {}
77
78pub fn calculate_2fa(
83 salt1: &[u8],
84 salt2: &[u8],
85 p: &[u8],
86 g: i32,
87 g_b: &[u8],
88 a: &[u8],
89 password: impl AsRef<[u8]>,
90) -> Result<([u8; 32], [u8; 256]), SrpError> {
91 let big_p = BigInt::from_bytes_be(Sign::Plus, p);
92 let g_b = pad256(g_b);
93 let a = pad256(a);
94 let g_hash = pad256(&[g as u8]);
95
96 let big_g_b = BigInt::from_bytes_be(Sign::Plus, &g_b);
97 let big_g = BigInt::from(g as u32);
98 let big_a = BigInt::from_bytes_be(Sign::Plus, &a);
99
100 {
102 let one = BigInt::from(1u32);
103 let p_minus_one = &big_p - &one;
104 if big_g_b <= one || big_g_b >= p_minus_one {
105 return Err(SrpError::GbOutOfRange);
106 }
107 }
108
109 let k = sha256(&[p, &g_hash]);
110 let big_k = BigInt::from_bytes_be(Sign::Plus, &k);
111
112 let g_a = big_g.modpow(&big_a, &big_p);
113 let g_a = pad256(&g_a.to_bytes_be().1);
114
115 {
117 let big_g_a = BigInt::from_bytes_be(Sign::Plus, &g_a);
118 let one = BigInt::from(1u32);
119 let p_minus_one = &big_p - &one;
120 if big_g_a <= one || big_g_a >= p_minus_one {
121 return Err(SrpError::GaOutOfRange);
122 }
123 }
124
125 let u = sha256(&[&g_a, &g_b]);
126 let big_u = BigInt::from_bytes_be(Sign::Plus, &u);
127
128 let x = ph2(password.as_ref(), salt1, salt2);
129 let big_x = BigInt::from_bytes_be(Sign::Plus, &x);
130
131 let big_v = big_g.modpow(&big_x, &big_p);
132 let big_kv = (big_k * big_v) % &big_p;
133
134 let big_t = (big_g_b - big_kv).rem_euclid(&big_p);
135
136 let exp = big_a + big_u * big_x;
137 let big_sa = big_t.modpow(&exp, &big_p);
138
139 let k_a = sha256(&[&pad256(&big_sa.to_bytes_be().1)]);
140
141 let h_p = sha256(&[p]);
142 let h_g = sha256(&[&g_hash]);
143 let p_xg = xor32(&h_p, &h_g);
144 let m1 = sha256(&[
145 &p_xg,
146 &sha256(&[salt1]),
147 &sha256(&[salt2]),
148 &g_a,
149 &g_b,
150 &k_a,
151 ]);
152
153 Ok((m1, g_a))
154}