1use std::fmt;
17use std::time::{SystemTime, UNIX_EPOCH};
18
19use layer_crypto::{AuthKey, aes, check_p_and_g, factorize, generate_key_data_from_nonce, rsa};
20use layer_tl_types::{Cursor, Deserializable, Serializable};
21use num_bigint::{BigUint, ToBigUint};
22use sha1::{Digest, Sha1};
23
/// Errors produced while negotiating an authorization key.
///
/// Each variant corresponds to one check performed by [`step2`], [`step3`]
/// or [`finish`].
// The variants are documented below; their struct fields are not documented
// individually, hence the lint allowance.
#[allow(missing_docs)]
#[derive(Clone, Debug, PartialEq)]
pub enum Error {
    /// A reply's `nonce` differs from the nonce generated in [`step1`].
    InvalidNonce {
        got: [u8; 16],
        expected: [u8; 16],
    },
    /// The `pq` value in `resPQ` is not exactly 8 bytes long.
    InvalidPqSize {
        size: usize,
    },
    /// None of the server's public-key fingerprints has a built-in key
    /// (see [`key_for_fingerprint`]); carries the fingerprints offered.
    UnknownFingerprints {
        fingerprints: Vec<i64>,
    },
    /// The server answered `req_DH_params` with its failure constructor
    /// (after its nonces and `new_nonce_hash` were verified).
    DhParamsFail,
    /// A reply's `server_nonce` differs from the one received in `resPQ`.
    InvalidServerNonce {
        got: [u8; 16],
        expected: [u8; 16],
    },
    /// The encrypted DH answer's length is not a multiple of 16 bytes.
    EncryptedResponseNotPadded {
        len: usize,
    },
    /// The decrypted DH answer could not be deserialized as server DH inner data.
    InvalidDhInnerData {
        error: layer_tl_types::deserialize::Error,
    },
    /// `check_p_and_g` rejected the server's DH prime / generator pair.
    InvalidDhPrime {
        source: layer_crypto::DhError,
    },
    /// A DH value (`g`, `g_a` or `g_b`) is not strictly between `low` and `high`.
    GParameterOutOfRange {
        value: BigUint,
        low: BigUint,
        high: BigUint,
    },
    /// The server answered `set_client_DH_params` with `dh_gen_retry`.
    DhGenRetry,
    /// The server answered `set_client_DH_params` with `dh_gen_fail`.
    DhGenFail,
    /// The SHA-1 prefix of the decrypted DH answer does not match its contents.
    InvalidAnswerHash {
        got: [u8; 20],
        expected: [u8; 20],
    },
    /// The server's `new_nonce_hash` does not match the locally computed one.
    InvalidNewNonceHash {
        got: [u8; 16],
        expected: [u8; 16],
    },
}
70
// Uses the default `source()` (which returns `None`); the causes wrapped by
// `InvalidDhInnerData` and `InvalidDhPrime` are surfaced only through `Display`.
impl std::error::Error for Error {}
72
73impl fmt::Display for Error {
74 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75 match self {
76 Self::InvalidNonce { got, expected } => {
77 write!(f, "nonce mismatch: got {got:?}, expected {expected:?}")
78 }
79 Self::InvalidPqSize { size } => write!(f, "pq size {size} invalid (expected 8)"),
80 Self::UnknownFingerprints { fingerprints } => {
81 write!(f, "no known fingerprint in {fingerprints:?}")
82 }
83 Self::DhParamsFail => write!(f, "server returned DH params failure"),
84 Self::InvalidServerNonce { got, expected } => write!(
85 f,
86 "server_nonce mismatch: got {got:?}, expected {expected:?}"
87 ),
88 Self::EncryptedResponseNotPadded { len } => {
89 write!(f, "encrypted answer len {len} is not 16-byte aligned")
90 }
91 Self::InvalidDhInnerData { error } => {
92 write!(f, "DH inner data deserialization error: {error}")
93 }
94 Self::InvalidDhPrime { source } => {
95 write!(f, "DH prime/generator validation failed: {source}")
96 }
97 Self::GParameterOutOfRange { value, low, high } => {
98 write!(f, "g={value} not in range ({low}, {high})")
99 }
100 Self::DhGenRetry => write!(f, "DH gen retry requested"),
101 Self::DhGenFail => write!(f, "DH gen failed"),
102 Self::InvalidAnswerHash { got, expected } => write!(
103 f,
104 "answer hash mismatch: got {got:?}, expected {expected:?}"
105 ),
106 Self::InvalidNewNonceHash { got, expected } => write!(
107 f,
108 "new nonce hash mismatch: got {got:?}, expected {expected:?}"
109 ),
110 }
111 }
112}
113
/// State carried from [`step1`] into [`step2`].
pub struct Step1 {
    /// Random nonce sent in `req_pq_multi`; later server replies are checked
    /// against it.
    nonce: [u8; 16],
}
120
/// State carried from [`step2`] into [`step3`].
pub struct Step2 {
    /// Our nonce from [`step1`].
    nonce: [u8; 16],
    /// The server nonce received in `resPQ`.
    server_nonce: [u8; 16],
    /// 32 random bytes placed in the RSA-encrypted `p_q_inner_data`; used to
    /// derive the temporary AES key/IV and the new-nonce hashes.
    new_nonce: [u8; 32],
}
127
/// State carried from [`step3`] into [`finish`].
pub struct Step3 {
    /// Our nonce from [`step1`].
    nonce: [u8; 16],
    /// The server nonce received in `resPQ`.
    server_nonce: [u8; 16],
    /// The secret new nonce generated in [`step2`].
    new_nonce: [u8; 32],
    /// Shared Diffie-Hellman secret `g_a^b mod dh_prime`; [`finish`] turns it
    /// into the authorization key.
    gab: BigUint,
    /// Server time from the DH inner data minus the local Unix time (seconds).
    time_offset: i32,
}
136
/// Outcome of a successful key exchange, returned by [`finish`].
#[derive(Clone, Debug, PartialEq)]
pub struct Finished {
    /// Raw authorization key bytes, as returned by `AuthKey::to_bytes` for the
    /// key built from the shared secret `g^ab` (big-endian, left-padded to
    /// 256 bytes).
    pub auth_key: [u8; 256],
    /// Server time minus local Unix time (seconds), measured in [`step3`].
    pub time_offset: i32,
    /// Initial server salt: the first 8 bytes of `new_nonce` XOR the first
    /// 8 bytes of `server_nonce`, read as a little-endian `i64`.
    pub first_salt: i64,
}
147
148pub fn step1() -> Result<(layer_tl_types::functions::ReqPqMulti, Step1), Error> {
152 let mut buf = [0u8; 16];
153 getrandom::getrandom(&mut buf).expect("getrandom");
154 do_step1(&buf)
155}
156
157fn do_step1(random: &[u8; 16]) -> Result<(layer_tl_types::functions::ReqPqMulti, Step1), Error> {
158 let nonce = *random;
159 Ok((
160 layer_tl_types::functions::ReqPqMulti { nonce },
161 Step1 { nonce },
162 ))
163}
164
165pub fn step2(
169 data: Step1,
170 response: layer_tl_types::enums::ResPq,
171) -> Result<(layer_tl_types::functions::ReqDhParams, Step2), Error> {
172 let mut rnd = [0u8; 256];
173 getrandom::getrandom(&mut rnd).expect("getrandom");
174 do_step2(data, response, &rnd)
175}
176
177fn do_step2(
178 data: Step1,
179 response: layer_tl_types::enums::ResPq,
180 random: &[u8; 256],
181) -> Result<(layer_tl_types::functions::ReqDhParams, Step2), Error> {
182 let Step1 { nonce } = data;
183
184 let layer_tl_types::enums::ResPq::ResPq(res_pq) = response;
186
187 check_nonce(&res_pq.nonce, &nonce)?;
188
189 if res_pq.pq.len() != 8 {
190 return Err(Error::InvalidPqSize {
191 size: res_pq.pq.len(),
192 });
193 }
194
195 let pq = u64::from_be_bytes(res_pq.pq.as_slice().try_into().unwrap());
196 let (p, q) = factorize(pq);
197
198 let mut new_nonce = [0u8; 32];
199 new_nonce.copy_from_slice(&random[..32]);
200
201 let rnd224: &[u8; 224] = random[32..].try_into().unwrap();
203
204 fn trim_be(v: u64) -> Vec<u8> {
205 let b = v.to_be_bytes();
206 let skip = b.iter().position(|&x| x != 0).unwrap_or(7);
207 b[skip..].to_vec()
208 }
209
210 let p_bytes = trim_be(p);
211 let q_bytes = trim_be(q);
212
213 let pq_inner =
216 layer_tl_types::enums::PQInnerData::PQInnerData(layer_tl_types::types::PQInnerData {
217 pq: pq.to_be_bytes().to_vec(),
218 p: p_bytes.clone(),
219 q: q_bytes.clone(),
220 nonce,
221 server_nonce: res_pq.server_nonce,
222 new_nonce,
223 })
224 .to_bytes();
225
226 let fingerprint = res_pq
227 .server_public_key_fingerprints
228 .iter()
229 .copied()
230 .find(|&fp| key_for_fingerprint(fp).is_some())
231 .ok_or_else(|| Error::UnknownFingerprints {
232 fingerprints: res_pq.server_public_key_fingerprints.clone(),
233 })?;
234
235 let key = key_for_fingerprint(fingerprint).unwrap();
236 let ciphertext = rsa::encrypt_hashed(&pq_inner, &key, rnd224);
237
238 Ok((
239 layer_tl_types::functions::ReqDhParams {
240 nonce,
241 server_nonce: res_pq.server_nonce,
242 p: p_bytes,
243 q: q_bytes,
244 public_key_fingerprint: fingerprint,
245 encrypted_data: ciphertext,
246 },
247 Step2 {
248 nonce,
249 server_nonce: res_pq.server_nonce,
250 new_nonce,
251 },
252 ))
253}
254
255pub fn step3(
259 data: Step2,
260 response: layer_tl_types::enums::ServerDhParams,
261) -> Result<(layer_tl_types::functions::SetClientDhParams, Step3), Error> {
262 let mut rnd = [0u8; 272]; getrandom::getrandom(&mut rnd).expect("getrandom");
264 let now = SystemTime::now()
265 .duration_since(UNIX_EPOCH)
266 .unwrap()
267 .as_secs() as i32;
268 do_step3(data, response, &rnd, now)
269}
270
271fn do_step3(
272 data: Step2,
273 response: layer_tl_types::enums::ServerDhParams,
274 random: &[u8; 272],
275 now: i32,
276) -> Result<(layer_tl_types::functions::SetClientDhParams, Step3), Error> {
277 let Step2 {
278 nonce,
279 server_nonce,
280 new_nonce,
281 } = data;
282
283 let mut server_dh_ok = match response {
284 layer_tl_types::enums::ServerDhParams::Fail(f) => {
285 check_nonce(&f.nonce, &nonce)?;
286 check_server_nonce(&f.server_nonce, &server_nonce)?;
287 let digest: [u8; 20] = {
289 let mut sha = Sha1::new();
290 sha.update(new_nonce);
291 sha.finalize().into()
292 };
293 let mut expected_hash = [0u8; 16];
294 expected_hash.copy_from_slice(&digest[4..]);
295 check_new_nonce_hash(&f.new_nonce_hash, &expected_hash)?;
296 return Err(Error::DhParamsFail);
297 }
298 layer_tl_types::enums::ServerDhParams::Ok(x) => x,
299 };
300
301 check_nonce(&server_dh_ok.nonce, &nonce)?;
302 check_server_nonce(&server_dh_ok.server_nonce, &server_nonce)?;
303
304 if server_dh_ok.encrypted_answer.len() % 16 != 0 {
305 return Err(Error::EncryptedResponseNotPadded {
306 len: server_dh_ok.encrypted_answer.len(),
307 });
308 }
309
310 let (key, iv) = generate_key_data_from_nonce(&server_nonce, &new_nonce);
311 aes::ige_decrypt(&mut server_dh_ok.encrypted_answer, &key, &iv);
312 let plain = server_dh_ok.encrypted_answer;
313
314 let got_hash: [u8; 20] = plain[..20].try_into().unwrap();
315 let mut cursor = Cursor::from_slice(&plain[20..]);
316
317 let inner = match layer_tl_types::enums::ServerDhInnerData::deserialize(&mut cursor) {
320 Ok(layer_tl_types::enums::ServerDhInnerData::ServerDhInnerData(x)) => x,
321 Err(e) => return Err(Error::InvalidDhInnerData { error: e }),
322 };
323
324 let expected_hash: [u8; 20] = {
325 let mut sha = Sha1::new();
326 sha.update(&plain[20..20 + cursor.pos()]);
327 sha.finalize().into()
328 };
329 if got_hash != expected_hash {
330 return Err(Error::InvalidAnswerHash {
331 got: got_hash,
332 expected: expected_hash,
333 });
334 }
335
336 check_nonce(&inner.nonce, &nonce)?;
337 check_server_nonce(&inner.server_nonce, &server_nonce)?;
338
339 check_p_and_g(&inner.dh_prime, inner.g as u32)
341 .map_err(|source| Error::InvalidDhPrime { source })?;
342
343 let dh_prime = BigUint::from_bytes_be(&inner.dh_prime);
344 let g = inner.g.to_biguint().unwrap();
345 let g_a = BigUint::from_bytes_be(&inner.g_a);
346 let time_offset = inner.server_time - now;
347
348 let b = BigUint::from_bytes_be(&random[..256]);
349 let g_b = g.modpow(&b, &dh_prime);
350 let gab = g_a.modpow(&b, &dh_prime);
351
352 let one = BigUint::from(1u32);
354 check_g_in_range(&g, &one, &(&dh_prime - &one))?;
355 check_g_in_range(&g_a, &one, &(&dh_prime - &one))?;
356 check_g_in_range(&g_b, &one, &(&dh_prime - &one))?;
357 let safety = one.clone() << (2048 - 64);
358 check_g_in_range(&g_a, &safety, &(&dh_prime - &safety))?;
359 check_g_in_range(&g_b, &safety, &(&dh_prime - &safety))?;
360
361 let client_dh_inner = layer_tl_types::enums::ClientDhInnerData::ClientDhInnerData(
364 layer_tl_types::types::ClientDhInnerData {
365 nonce,
366 server_nonce,
367 retry_id: 0,
368 g_b: g_b.to_bytes_be(),
369 },
370 )
371 .to_bytes();
372
373 let digest: [u8; 20] = {
374 let mut sha = Sha1::new();
375 sha.update(&client_dh_inner);
376 sha.finalize().into()
377 };
378
379 let pad_len = (16 - ((20 + client_dh_inner.len()) % 16)) % 16;
380 let rnd16 = &random[256..256 + pad_len.min(16)];
381
382 let mut hashed = Vec::with_capacity(20 + client_dh_inner.len() + pad_len);
383 hashed.extend_from_slice(&digest);
384 hashed.extend_from_slice(&client_dh_inner);
385 hashed.extend_from_slice(&rnd16[..pad_len]);
386
387 aes::ige_encrypt(&mut hashed, &key, &iv);
388
389 Ok((
390 layer_tl_types::functions::SetClientDhParams {
391 nonce,
392 server_nonce,
393 encrypted_data: hashed,
394 },
395 Step3 {
396 nonce,
397 server_nonce,
398 new_nonce,
399 gab,
400 time_offset,
401 },
402 ))
403}
404
405pub fn finish(
409 data: Step3,
410 response: layer_tl_types::enums::SetClientDhParamsAnswer,
411) -> Result<Finished, Error> {
412 let Step3 {
413 nonce,
414 server_nonce,
415 new_nonce,
416 gab,
417 time_offset,
418 } = data;
419
420 struct DhData {
421 nonce: [u8; 16],
422 server_nonce: [u8; 16],
423 hash: [u8; 16],
424 num: u8,
425 }
426
427 let dh = match response {
428 layer_tl_types::enums::SetClientDhParamsAnswer::DhGenOk(x) => DhData {
430 nonce: x.nonce,
431 server_nonce: x.server_nonce,
432 hash: x.new_nonce_hash1,
433 num: 1,
434 },
435 layer_tl_types::enums::SetClientDhParamsAnswer::DhGenRetry(x) => DhData {
436 nonce: x.nonce,
437 server_nonce: x.server_nonce,
438 hash: x.new_nonce_hash2,
439 num: 2,
440 },
441 layer_tl_types::enums::SetClientDhParamsAnswer::DhGenFail(x) => DhData {
442 nonce: x.nonce,
443 server_nonce: x.server_nonce,
444 hash: x.new_nonce_hash3,
445 num: 3,
446 },
447 };
448
449 check_nonce(&dh.nonce, &nonce)?;
450 check_server_nonce(&dh.server_nonce, &server_nonce)?;
451
452 let mut key_bytes = [0u8; 256];
453 let gab_bytes = gab.to_bytes_be();
454 let skip = 256 - gab_bytes.len();
455 key_bytes[skip..].copy_from_slice(&gab_bytes);
456
457 let auth_key = AuthKey::from_bytes(key_bytes);
458 let expected_hash = auth_key.calc_new_nonce_hash(&new_nonce, dh.num);
459 check_new_nonce_hash(&dh.hash, &expected_hash)?;
460
461 let first_salt = {
462 let mut buf = [0u8; 8];
463 for ((dst, a), b) in buf.iter_mut().zip(&new_nonce[..8]).zip(&server_nonce[..8]) {
464 *dst = a ^ b;
465 }
466 i64::from_le_bytes(buf)
467 };
468
469 match dh.num {
470 1 => Ok(Finished {
471 auth_key: auth_key.to_bytes(),
472 time_offset,
473 first_salt,
474 }),
475 2 => Err(Error::DhGenRetry),
476 _ => Err(Error::DhGenFail),
477 }
478}
479
480fn check_nonce(got: &[u8; 16], expected: &[u8; 16]) -> Result<(), Error> {
483 if got == expected {
484 Ok(())
485 } else {
486 Err(Error::InvalidNonce {
487 got: *got,
488 expected: *expected,
489 })
490 }
491}
492fn check_server_nonce(got: &[u8; 16], expected: &[u8; 16]) -> Result<(), Error> {
493 if got == expected {
494 Ok(())
495 } else {
496 Err(Error::InvalidServerNonce {
497 got: *got,
498 expected: *expected,
499 })
500 }
501}
502fn check_new_nonce_hash(got: &[u8; 16], expected: &[u8; 16]) -> Result<(), Error> {
503 if got == expected {
504 Ok(())
505 } else {
506 Err(Error::InvalidNewNonceHash {
507 got: *got,
508 expected: *expected,
509 })
510 }
511}
512fn check_g_in_range(val: &BigUint, lo: &BigUint, hi: &BigUint) -> Result<(), Error> {
513 if lo < val && val < hi {
514 Ok(())
515 } else {
516 Err(Error::GParameterOutOfRange {
517 value: val.clone(),
518 low: lo.clone(),
519 high: hi.clone(),
520 })
521 }
522}
523
524#[allow(clippy::unreadable_literal)]
526pub fn key_for_fingerprint(fp: i64) -> Option<rsa::Key> {
527 Some(match fp {
528 -3414540481677951611 => rsa::Key::new(
530 "29379598170669337022986177149456128565388431120058863768162556424047512191330847455146576344487764408661701890505066208632169112269581063774293102577308490531282748465986139880977280302242772832972539403531316010870401287642763009136156734339538042419388722777357134487746169093539093850251243897188928735903389451772730245253062963384108812842079887538976360465290946139638691491496062099570836476454855996319192747663615955633778034897140982517446405334423701359108810182097749467210509584293428076654573384828809574217079944388301239431309115013843331317877374435868468779972014486325557807783825502498215169806323",
531 "65537",
532 )?,
533 -5595554452916591101 => rsa::Key::new(
535 "25342889448840415564971689590713473206898847759084779052582026594546022463853940585885215951168491965708222649399180603818074200620463776135424884632162512403163793083921641631564740959529419359595852941166848940585952337613333022396096584117954892216031229237302943701877588456738335398602461675225081791820393153757504952636234951323237820036543581047826906120927972487366805292115792231423684261262330394324750785450942589751755390156647751460719351439969059949569615302809050721500330239005077889855323917509948255722081644689442127297605422579707142646660768825302832201908302295573257427896031830742328565032949",
536 "65537",
537 )?,
538 _ => return None,
539 })
540}