1use num_bigint::BigUint;
10use num_traits::{Num, Zero};
11
12use crate::error::{Error, Result};
13use rand::RngCore;
14use sha1::Sha1;
15use sha2::{Digest, Sha256};
16
/// Hash function used for the outer SRP client-proof digest.
///
/// Only the final proof digest depends on this choice; the scrambling
/// parameter, password hash and session key always use SHA-1.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SrpHash {
    /// SHA-1 proof digest (plugin name `Srp`).
    Sha1,
    /// SHA-256 proof digest (plugin name `Srp256`).
    Sha256,
}
25
26impl SrpHash {
27 pub fn plugin_name(self) -> &'static str {
29 match self {
30 SrpHash::Sha1 => "Srp",
31 SrpHash::Sha256 => "Srp256",
32 }
33 }
34
35 fn digest(self, parts: &[&[u8]]) -> Vec<u8> {
36 match self {
37 SrpHash::Sha1 => sha1_digest(parts),
38 SrpHash::Sha256 => {
39 let mut h = Sha256::new();
40 for p in parts {
41 h.update(p);
42 }
43 h.finalize().to_vec()
44 }
45 }
46 }
47}
48
49fn sha1_digest(parts: &[&[u8]]) -> Vec<u8> {
50 let mut h = Sha1::new();
51 for p in parts {
52 h.update(p);
53 }
54 h.finalize().to_vec()
55}
56
57const N_HEX: &str = "E67D2E994B2F900C3F41F08F5BB2627ED0D49EE1FE767A52EFCD565CD6E768812C3E1E9CE8F0A8BEA6CB13CD29DDEBF7A96D4A93B55D488DF099A15C89DCB0640738EB2CBDD9A8F7BAB561AB1B0DC1C6CDABF303264A08D1BCA932D1F1EE428B619D970F342ABA9A65793B8B2F041AE5364350C16F735F56ECBCA87BD57B29E7";
59const K_DEC: &str = "1277432915985975349439481660349303019122249719989";
61
62fn n() -> BigUint {
63 BigUint::from_str_radix(N_HEX, 16).expect("valid N")
64}
65fn g() -> BigUint {
66 BigUint::from(2u32)
67}
68fn k() -> BigUint {
69 BigUint::from_str_radix(K_DEC, 10).expect("valid k")
70}
71
72#[inline]
73fn to_bytes(n: &BigUint) -> Vec<u8> {
74 n.to_bytes_be()
75}
76#[inline]
77fn from_bytes(b: &[u8]) -> BigUint {
78 BigUint::from_bytes_be(b)
79}
80
81fn scramble(a_pub: &BigUint, b_pub: &BigUint) -> BigUint {
83 from_bytes(&sha1_digest(&[&to_bytes(a_pub), &to_bytes(b_pub)]))
84}
85
86fn user_hash(user: &str, password: &str, salt: &[u8]) -> BigUint {
88 let inner = sha1_digest(&[user.as_bytes(), b":", password.as_bytes()]);
89 from_bytes(&sha1_digest(&[salt, &inner]))
90}
91
92#[derive(Debug, Clone)]
94pub struct SrpClient {
95 hash: SrpHash,
96 a: BigUint,
98 a_pub: BigUint,
100}
101
102impl SrpClient {
103 pub fn new(hash: SrpHash) -> Self {
105 let mut secret = [0u8; 32];
106 rand::thread_rng().fill_bytes(&mut secret);
107 Self::with_secret(hash, &secret)
108 }
109
110 pub fn with_secret(hash: SrpHash, secret: &[u8]) -> Self {
112 let n = n();
113 let a = from_bytes(secret) % &n;
114 let a_pub = g().modpow(&a, &n);
115 SrpClient { hash, a, a_pub }
116 }
117
118 pub fn hash(&self) -> SrpHash {
120 self.hash
121 }
122
123 pub fn set_hash(&mut self, hash: SrpHash) {
126 self.hash = hash;
127 }
128
129 pub fn public_key_hex(&self) -> String {
131 to_hex(&to_bytes(&self.a_pub))
132 }
133
134 fn session_key(&self, b_pub: &BigUint, x: &BigUint) -> Vec<u8> {
136 let n = n();
137 let u = scramble(&self.a_pub, b_pub);
138 let gx = g().modpow(x, &n);
139 let kgx = (k() * gx) % &n;
140 let diff = ((b_pub + &n) - kgx) % &n;
142 let aux = (&self.a + (u * x)) % &n;
143 let secret = diff.modpow(&aux, &n);
144 sha1_digest(&[&to_bytes(&secret)])
145 }
146
147 pub fn proof(
157 &self,
158 user: &str,
159 password: &str,
160 salt: &[u8],
161 b_pub: &BigUint,
162 ) -> Result<(Vec<u8>, Vec<u8>)> {
163 let n = n();
164 if (b_pub % &n).is_zero() {
165 return Err(Error::auth("invalid SRP server ephemeral: B mod N == 0"));
166 }
167 if scramble(&self.a_pub, b_pub).is_zero() {
168 return Err(Error::auth("invalid SRP scrambling parameter: u == 0"));
169 }
170 let x = user_hash(user, password, salt);
171 let key = self.session_key(b_pub, &x);
172
173 let hn = from_bytes(&sha1_digest(&[&to_bytes(&n)]));
175 let hg = from_bytes(&sha1_digest(&[&to_bytes(&g())]));
176 let hng = hn.modpow(&hg, &n);
177
178 let hu = from_bytes(&sha1_digest(&[user.as_bytes()]));
181
182 let proof = self.hash.digest(&[
183 &to_bytes(&hng),
184 &to_bytes(&hu),
185 salt,
186 &to_bytes(&self.a_pub),
187 &to_bytes(b_pub),
188 &key,
189 ]);
190 Ok((proof, key))
191 }
192}
193
194pub fn parse_server_data(data: &[u8]) -> crate::error::Result<(Vec<u8>, BigUint)> {
197 use crate::error::Error;
198 let rd = |buf: &[u8], at: usize| -> crate::error::Result<(usize, usize)> {
199 if at + 2 > buf.len() {
200 return Err(Error::auth("truncated SRP server data"));
201 }
202 let len = (buf[at] as usize) | ((buf[at + 1] as usize) << 8);
203 Ok((at + 2, len))
204 };
205
206 let (p, salt_len) = rd(data, 0)?;
207 if p + salt_len > data.len() {
208 return Err(Error::auth("truncated SRP salt"));
209 }
210 let salt = data[p..p + salt_len].to_vec();
211
212 let (p, key_len) = rd(data, p + salt_len)?;
213 if p + key_len > data.len() {
214 return Err(Error::auth("truncated SRP server key"));
215 }
216 let key_hex = &data[p..p + key_len];
217 let b_pub = BigUint::from_str_radix(
218 std::str::from_utf8(key_hex).map_err(|_| Error::auth("server key not valid hex"))?,
219 16,
220 )
221 .map_err(|_| Error::auth("server key not valid hex"))?;
222
223 Ok((salt, b_pub))
224}
225
/// Encodes `bytes` as lowercase hexadecimal, two digits per byte.
pub fn to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    bytes
        .iter()
        .flat_map(|&byte| [DIGITS[usize::from(byte >> 4)], DIGITS[usize::from(byte & 0x0f)]])
        .map(char::from)
        .collect()
}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Simulates the server side (v = g^x, B = k*v + g^b) and checks that both
    /// sides derive the same session key.
    #[test]
    fn client_server_session_keys_agree() {
        let user = "SYSDBA";
        let password = "masterkey";
        let salt = [0x11u8; 32];

        let n = n();
        // Server-side verifier v = g^x mod N.
        let x = user_hash(user, password, &salt);
        let v = g().modpow(&x, &n);

        let b_priv = BigUint::from_bytes_be(&[0x42u8; 32]) % &n;
        // B = (k*v + g^b) mod N.
        let b_pub = (k() * &v + g().modpow(&b_priv, &n)) % &n;

        let client = SrpClient::with_secret(SrpHash::Sha256, &[0x37u8; 32]);
        let (_proof, client_key) = client.proof(user, password, &salt, &b_pub).unwrap();

        // Server secret S = (A * v^u)^b mod N.
        let u = scramble(&client.a_pub, &b_pub);
        let base = (&client.a_pub * v.modpow(&u, &n)) % &n;
        let server_secret = base.modpow(&b_priv, &n);
        let server_key = sha1_digest(&[&server_secret.to_bytes_be()]);

        assert_eq!(
            client_key, server_key,
            "client and server session keys must match"
        );
    }

    /// The session key is always SHA-1; only the proof digest follows `SrpHash`.
    #[test]
    fn session_key_is_independent_of_proof_hash() {
        let salt = [0x22u8; 16];
        let modulus = n();
        let b_pub = g().modpow(&BigUint::from(0xBEEFu32), &modulus);
        let secret = [0x55u8; 32];

        let (proof1, key1) = SrpClient::with_secret(SrpHash::Sha1, &secret)
            .proof("SYSDBA", "masterkey", &salt, &b_pub)
            .unwrap();
        let (proof256, key256) = SrpClient::with_secret(SrpHash::Sha256, &secret)
            .proof("SYSDBA", "masterkey", &salt, &b_pub)
            .unwrap();

        assert_eq!(key1, key256, "session key must not depend on the proof hash");
        assert_eq!(key1.len(), 20);
        assert_eq!(proof1.len(), 20);
        assert_eq!(proof256.len(), 32);
    }

    /// B values that are multiples of N must be rejected before any key is derived.
    #[test]
    fn rejects_server_ephemeral_divisible_by_n() {
        let client = SrpClient::with_secret(SrpHash::Sha1, &[0x37u8; 32]);
        let modulus = n();
        for b_pub in [BigUint::from(0u32), modulus.clone(), &modulus + &modulus] {
            assert!(client
                .proof("SYSDBA", "masterkey", &[0u8; 4], &b_pub)
                .is_err());
        }
    }

    #[test]
    fn set_hash_switches_proof_digest() {
        let mut client = SrpClient::with_secret(SrpHash::Sha1, &[0x01u8; 32]);
        assert_eq!(client.hash(), SrpHash::Sha1);
        client.set_hash(SrpHash::Sha256);
        assert_eq!(client.hash(), SrpHash::Sha256);
    }

    #[test]
    fn plugin_names() {
        assert_eq!(SrpHash::Sha1.plugin_name(), "Srp");
        assert_eq!(SrpHash::Sha256.plugin_name(), "Srp256");
    }

    #[test]
    fn server_data_roundtrip() {
        let salt = [0xABu8; 32];
        let b = BigUint::from(0x1234_5678u32);
        let b_hex = format!("{b:x}");
        let mut data = Vec::new();
        data.extend_from_slice(&(salt.len() as u16).to_le_bytes());
        data.extend_from_slice(&salt);
        data.extend_from_slice(&(b_hex.len() as u16).to_le_bytes());
        data.extend_from_slice(b_hex.as_bytes());

        let (got_salt, got_b) = parse_server_data(&data).unwrap();
        assert_eq!(got_salt, salt);
        assert_eq!(got_b, b);
    }

    /// Every truncation point and malformed key must produce an error, never a panic.
    #[test]
    fn rejects_truncated_or_malformed_server_data() {
        // Missing or partial salt length prefix.
        assert!(parse_server_data(&[]).is_err());
        assert!(parse_server_data(&[4]).is_err());
        // Salt length larger than the payload.
        assert!(parse_server_data(&[5, 0, 1, 2]).is_err());
        // Salt present but key length prefix missing.
        assert!(parse_server_data(&[1, 0, 0xAA]).is_err());
        // Key length larger than the payload.
        assert!(parse_server_data(&[0, 0, 4, 0, b'a']).is_err());
        // Key that is not hex, not UTF-8, or empty.
        assert!(parse_server_data(&[0, 0, 2, 0, b'z', b'z']).is_err());
        assert!(parse_server_data(&[0, 0, 1, 0, 0xFF]).is_err());
        assert!(parse_server_data(&[0, 0, 0, 0]).is_err());
    }

    #[test]
    fn hex_encoding() {
        assert_eq!(to_hex(&[0x00, 0x0f, 0xff]), "000fff");
    }
}