1use crate::{OpenADPError, Result};
13use sha2::{Sha256, Digest};
14use hkdf::Hkdf;
15use rand_core::{OsRng, RngCore};
16use rug::{Integer, Complete};
17use num_bigint::BigUint;
18use num_traits::{Zero, One};
19
20lazy_static::lazy_static! {
22 pub static ref P: BigUint = {
24 let p_str = "57896044618658097711785492504343953926634992332820282019728792003956564819949";
25 BigUint::parse_bytes(p_str.as_bytes(), 10).unwrap()
26 };
27
28 pub static ref D: BigUint = {
30 let inv121666 = mod_inverse(&BigUint::from(121666u32), &P);
31 let mut d = &*P - &BigUint::from(121665u32); d = (&d * &inv121666) % &*P;
33 d
34 };
35
36 pub static ref Q: BigUint = {
38 let q_str = "7237005577332262213973186563042994240857116359379907606001950938285454250989";
39 BigUint::parse_bytes(q_str.as_bytes(), 10).unwrap()
40 };
41
42 pub static ref MODP_SQRT_M1: BigUint = {
44 let exp = (&*P - 1u32) / 4u32;
45 BigUint::from(2u32).modpow(&exp, &P)
46 };
47
48 pub static ref G: Point4D = {
50 let gy = (&BigUint::from(4u32) * &mod_inverse(&BigUint::from(5u32), &P)) % &*P;
52 let gx = recover_x(&gy, 0).expect("Failed to recover base point X coordinate");
53 expand(&Point2D { x: gx, y: gy })
54 };
55
56 pub static ref ZERO_POINT: Point4D = Point4D {
58 x: BigUint::zero(),
59 y: BigUint::one(),
60 z: BigUint::one(),
61 t: BigUint::zero(),
62 };
63}
64
65#[derive(Debug, Clone, PartialEq)]
67pub struct Point2D {
68 pub x: BigUint,
69 pub y: BigUint,
70}
71
72impl Point2D {
73 pub fn new(x: BigUint, y: BigUint) -> Self {
74 Self { x, y }
75 }
76}
77
78#[derive(Debug, Clone, PartialEq)]
80pub struct Point4D {
81 pub x: BigUint,
82 pub y: BigUint,
83 pub z: BigUint,
84 pub t: BigUint,
85}
86
87impl Point4D {
88 pub fn new(x: BigUint, y: BigUint, z: BigUint, t: BigUint) -> Self {
89 Self { x, y, z, t }
90 }
91
92 pub fn identity() -> Self {
93 ZERO_POINT.clone()
94 }
95}
96
97pub fn expand(point: &Point2D) -> Point4D {
99 let xy = (&point.x * &point.y) % &*P;
100 Point4D {
101 x: point.x.clone(),
102 y: point.y.clone(),
103 z: BigUint::one(),
104 t: xy,
105 }
106}
107
108pub fn unexpand(point: &Point4D) -> Result<Point2D> {
110 let z_inv = mod_inverse(&point.z, &P);
111 let x = (&point.x * &z_inv) % &*P;
112 let y = (&point.y * &z_inv) % &*P;
113 Ok(Point2D { x, y })
114}
115
116pub fn sha256_hash(data: &[u8]) -> Vec<u8> {
118 let mut hasher = Sha256::new();
119 hasher.update(data);
120 hasher.finalize().to_vec()
121}
122
123pub fn mod_inverse(a: &BigUint, p: &BigUint) -> BigUint {
125 let exp = p - 2u32;
127 a.modpow(&exp, p)
128}
129
130pub fn point_add(p1: &Point4D, p2: &Point4D) -> Point4D {
132 let a = ((&p1.y + &*P - &p1.x) * (&p2.y + &*P - &p2.x)) % &*P;
134
135 let b = ((&p1.y + &p1.x) * (&p2.y + &p2.x)) % &*P;
137
138 let c = (2u32 * &p1.t * &p2.t % &*P * &*D) % &*P;
140
141 let d = (2u32 * &p1.z * &p2.z) % &*P;
143
144 let e = (&b + &*P - &a) % &*P;
146 let f = (&d + &*P - &c) % &*P;
147 let g = (&d + &c) % &*P;
148 let h = (&b + &a) % &*P;
149
150 Point4D {
152 x: (&e * &f) % &*P,
153 y: (&g * &h) % &*P,
154 z: (&f * &g) % &*P,
155 t: (&e * &h) % &*P,
156 }
157}
158
159pub fn point_mul(s: &BigUint, p: &Point4D) -> Point4D {
161 let mut q = ZERO_POINT.clone();
162 let mut p_copy = p.clone();
163 let mut s_copy = s.clone();
164
165 while s_copy > BigUint::zero() {
166 if s_copy.bit(0) {
167 q = point_add(&q, &p_copy);
168 }
169 p_copy = point_add(&p_copy, &p_copy);
170 s_copy >>= 1;
171 }
172
173 q
174}
175
176pub fn point_mul8(p: &Point4D) -> Point4D {
178 let mut result = point_add(p, p); result = point_add(&result, &result); result = point_add(&result, &result); result
183}
184
185pub fn point_equal(p1: &Point4D, p2: &Point4D) -> bool {
187 let left = (&p1.x * &p2.z) % &*P;
189 let right = (&p2.x * &p1.z) % &*P;
190 if left != right {
191 return false;
192 }
193
194 let left = (&p1.y * &p2.z) % &*P;
195 let right = (&p2.y * &p1.z) % &*P;
196 left == right
197}
198
199pub fn recover_x(y: &BigUint, sign: u8) -> Option<BigUint> {
201 if y >= &*P {
202 return None;
203 }
204
205 let y2 = (y * y) % &*P;
207
208 let numerator = (&y2 + &*P - 1u32) % &*P;
209 let denominator = ((&*D * &y2) + 1u32) % &*P;
210
211 let denominator_inv = mod_inverse(&denominator, &P);
212 let x2 = (&numerator * &denominator_inv) % &*P;
213
214 if x2.is_zero() {
215 return if sign != 0 { None } else { Some(BigUint::zero()) };
216 }
217
218 let exp = (&*P + 3u32) / 8u32;
220 let mut x = x2.modpow(&exp, &P);
221
222 let x_squared = (&x * &x) % &*P;
224 if x_squared != x2 {
225 x = (&x * &*MODP_SQRT_M1) % &*P;
226 }
227
228 let x_squared = (&x * &x) % &*P;
230 if x_squared != x2 {
231 return None;
232 }
233
234 if x.bit(0) != (sign != 0) {
236 x = &*P - &x;
237 }
238
239 Some(x)
240}
241
242pub fn point_compress(p: &Point4D) -> Result<Vec<u8>> {
244 let z_inv = mod_inverse(&p.z, &P);
245 let x = (&p.x * &z_inv) % &*P;
246 let mut y = (&p.y * &z_inv) % &*P;
247
248 if x.bit(0) {
250 y.set_bit(255, true);
251 }
252
253 let mut result = vec![0u8; 32];
255 let y_bytes = y.to_bytes_le();
256 let copy_len = std::cmp::min(y_bytes.len(), 32);
257 result[..copy_len].copy_from_slice(&y_bytes[..copy_len]);
258
259 Ok(result)
260}
261
262pub fn point_decompress(data: &[u8]) -> Result<Point4D> {
264 if data.len() != 32 {
265 return Err(OpenADPError::PointOperation("Invalid input length for decompression".to_string()));
266 }
267
268 let mut y = BigUint::zero();
270 for i in 0..32 {
271 for bit in 0..8 {
272 if (data[i] >> bit) & 1 == 1 {
273 y.set_bit((i * 8 + bit) as u64, true);
274 }
275 }
276 }
277
278 let sign = if y.bit(255) { 1 } else { 0 };
279 y.set_bit(255, false); let x = recover_x(&y, sign)
282 .ok_or_else(|| OpenADPError::PointOperation("Invalid point".to_string()))?;
283
284 let xy = (&x * &y) % &*P;
285 let point = Point4D {
286 x,
287 y,
288 z: BigUint::one(),
289 t: xy,
290 };
291
292 if !is_valid_point(&point) {
294 return Err(OpenADPError::PointOperation("Invalid point: failed validation".to_string()));
295 }
296
297 Ok(point)
298}
299
300pub fn is_valid_point(p: &Point4D) -> bool {
302 if point_equal(p, &ZERO_POINT) {
304 return false;
305 }
306
307 let eight_p = point_mul8(p);
310 !point_equal(&eight_p, &ZERO_POINT)
311}
312
/// Prepends `data`'s length as a 2-byte little-endian header.
///
/// # Panics
/// Panics with "Input string too long" if `data` is 65536 bytes or longer.
pub fn prefixed(data: &[u8]) -> Vec<u8> {
    if data.len() > u16::MAX as usize {
        panic!("Input string too long");
    }
    let header = (data.len() as u16).to_le_bytes();
    [&header[..], data].concat()
}
325
326#[allow(non_snake_case)]
328pub fn H(uid: &[u8], did: &[u8], bid: &[u8], pin: &[u8]) -> Result<Point4D> {
329 let mut data = prefixed(uid);
331 data.extend_from_slice(&prefixed(did));
332 data.extend_from_slice(&prefixed(bid));
333 data.extend_from_slice(pin);
334
335 let hash = sha256_hash(&data);
337
338 let y_base_full = BigUint::from_bytes_le(&hash);
340
341 let sign = if y_base_full.bit(255) { 1 } else { 0 };
342 let mut y_base = y_base_full.clone();
343 y_base.set_bit(255, false); for counter in 0..1000 {
346 let y = &y_base ^ BigUint::from(counter as u32);
348
349 if let Some(x) = recover_x(&y, sign) {
350 let p = expand(&Point2D { x, y });
352 let p = point_mul8(&p);
353
354 if is_valid_point(&p) {
355 return Ok(p);
356 }
357 }
358 }
359
360 Ok(G.clone())
362}
363
364pub fn derive_enc_key(point: &Point4D) -> Result<Vec<u8>> {
366 let point_bytes = point_compress(point)?;
367
368 let salt = b"OpenADP-EncKey-v1";
370 let info = b"AES-256-GCM";
371
372 let hk = Hkdf::<Sha256>::new(Some(salt), &point_bytes);
373 let mut okm = [0u8; 32];
374 hk.expand(info, &mut okm)
375 .map_err(|e| OpenADPError::Crypto(format!("HKDF expansion failed: {}", e)))?;
376
377 Ok(okm.to_vec())
378}
379
380const Q_HEX: &str = "1000000000000000000000000000000014def9dea2f79cd65812631a5cf5d3ed";
383
384pub struct ShamirSecretSharing;
386
387impl ShamirSecretSharing {
388 pub fn get_q() -> Integer {
390 Integer::from_str_radix(Q_HEX, 16).unwrap()
391 }
392
393 fn random_mod_q() -> Integer {
395 let q = Self::get_q();
396 let mut rng = OsRng;
397
398 let bit_len = q.significant_bits();
400 let byte_len = (bit_len + 7) / 8;
401
402 loop {
403 let mut bytes = vec![0u8; byte_len as usize];
404 rng.fill_bytes(&mut bytes);
405
406 let random_int = Integer::from_digits(&bytes, rug::integer::Order::MsfBe);
407 if random_int < q {
408 return random_int;
409 }
410 }
411 }
412
413 pub fn split_secret(secret: &Integer, threshold: usize, num_shares: usize) -> Result<Vec<(usize, Integer)>> {
415 if threshold == 0 || threshold > num_shares {
416 return Err(OpenADPError::SecretSharing("Invalid threshold".to_string()));
417 }
418
419 let q = Self::get_q();
420
421 if *secret >= q {
422 return Err(OpenADPError::SecretSharing("Secret too large for field".to_string()));
423 }
424
425 let mut coefficients = vec![secret.clone()];
428
429 for _ in 1..threshold {
430 coefficients.push(Self::random_mod_q());
431 }
432
433 let mut shares = Vec::new();
435 for x in 1..=num_shares {
436 let x_int = Integer::from(x);
437 let mut y = Integer::new();
438 let mut x_power = Integer::from(1);
439
440 for coeff in &coefficients {
441 let term = (coeff * &x_power).complete() % &q;
443 y = (&y + &term).complete() % &q;
444 x_power = (&x_power * &x_int).complete() % &q;
446 }
447
448 shares.push((x, y));
449 }
450
451 Ok(shares)
452 }
453
454 pub fn recover_secret(shares: Vec<(usize, Integer)>) -> Result<Integer> {
456 if shares.is_empty() {
457 return Err(OpenADPError::SecretSharing("No shares provided".to_string()));
458 }
459
460 let q = Self::get_q();
461 let mut secret = Integer::new();
462
463 let int_shares: Vec<(Integer, Integer)> = shares.into_iter()
465 .map(|(x, y)| (Integer::from(x), y))
466 .collect();
467
468 for (i, (xi, yi)) in int_shares.iter().enumerate() {
470 let mut numerator = Integer::from(1);
472 let mut denominator = Integer::from(1);
473
474 for (j, (xj, _)) in int_shares.iter().enumerate() {
475 if i != j {
476 let neg_xj = Integer::from(&q - xj);
478 numerator = Integer::from(&numerator * &neg_xj) % &q;
479 let xi_clone = xi.clone();
481 let xj_clone = xj.clone();
482 let diff = if &xi_clone >= &xj_clone {
483 Integer::from(&xi_clone - &xj_clone)
484 } else {
485 let xj_minus_xi = Integer::from(&xj_clone - &xi_clone);
486 Integer::from(&q - &xj_minus_xi)
487 };
488 denominator = Integer::from(&denominator * &diff) % &q;
489 }
490 }
491
492 let denominator_inv = denominator.invert(&q)
494 .map_err(|_| OpenADPError::SecretSharing("Cannot invert denominator".to_string()))?;
495 let li_0 = (&numerator * &denominator_inv).complete() % &q;
496
497 let term = (yi * &li_0).complete() % &q;
499 secret = (&secret + &term).complete() % &q;
500 }
501
502 Ok(secret)
503 }
504
505 pub fn split_secret_bytes(secret: &[u8], threshold: usize, num_shares: usize) -> Result<Vec<(usize, Vec<u8>)>> {
507 if threshold == 0 || threshold > num_shares {
508 return Err(OpenADPError::SecretSharing("Invalid threshold".to_string()));
509 }
510
511 let secret_int = Integer::from_digits(secret, rug::integer::Order::MsfBe);
513
514 let int_shares = Self::split_secret(&secret_int, threshold, num_shares)?;
516
517 let mut byte_shares = Vec::new();
519 for (x, y) in int_shares {
520 let y_bytes = y.to_digits::<u8>(rug::integer::Order::MsfBe);
521 byte_shares.push((x, y_bytes));
522 }
523
524 Ok(byte_shares)
525 }
526
527 pub fn recover_secret_bytes(shares: Vec<(usize, Vec<u8>)>) -> Result<Vec<u8>> {
529 if shares.is_empty() {
530 return Err(OpenADPError::SecretSharing("No shares provided".to_string()));
531 }
532
533 let int_shares = shares.into_iter()
535 .map(|(x, y_bytes)| {
536 let y = Integer::from_digits(&y_bytes, rug::integer::Order::MsfBe);
537 (x, y)
538 })
539 .collect();
540
541 let secret_int = Self::recover_secret(int_shares)?;
543
544 let secret_bytes = secret_int.to_digits::<u8>(rug::integer::Order::MsfBe);
546
547 Ok(secret_bytes)
548 }
549}
550
551pub struct PointShare {
553 pub x: usize,
554 pub point: Point4D,
555}
556
557impl PointShare {
558 pub fn new(x: usize, point: Point4D) -> Self {
559 Self { x, point }
560 }
561}
562
563pub fn recover_point_secret(point_shares: Vec<PointShare>) -> Result<Point4D> {
565 if point_shares.is_empty() {
566 return Err(OpenADPError::SecretSharing("No point shares provided".to_string()));
567 }
568
569 let q = &*Q;
570 let mut secret_point = ZERO_POINT.clone();
571
572 for (i, share_i) in point_shares.iter().enumerate() {
574 let xi = BigUint::from(share_i.x);
575
576 let mut numerator = BigUint::one();
578 let mut denominator = BigUint::one();
579
580 for (j, share_j) in point_shares.iter().enumerate() {
581 if i != j {
582 let xj = BigUint::from(share_j.x);
583
584 let neg_xj = (q + q - &xj) % q; numerator = (&numerator * &neg_xj) % q;
587
588 let diff = if &xi >= &xj {
590 (&xi - &xj) % q
591 } else {
592 (q + &xi - &xj) % q
593 };
594 denominator = (&denominator * &diff) % q;
595 }
596 }
597
598 let denominator_inv = mod_inverse(&denominator, q);
600 let li_0 = (&numerator * &denominator_inv) % q;
601
602 let term_point = point_mul(&li_0, &share_i.point);
604 secret_point = point_add(&secret_point, &term_point);
605 }
606
607 Ok(secret_point)
608}
609
610pub struct Ed25519;
612
613impl Ed25519 {
614 #[allow(non_snake_case)]
615 pub fn H(uid: &[u8], did: &[u8], bid: &[u8], pin: &[u8]) -> Result<Point4D> {
616 H(uid, did, bid, pin)
617 }
618
619 pub fn scalar_mult(scalar: &[u8], point: &Point4D) -> Result<Point4D> {
620 let scalar_bigint = BigUint::from_bytes_le(scalar);
621 Ok(point_mul(&scalar_bigint, point))
622 }
623
624 pub fn point_add(p1: &Point4D, p2: &Point4D) -> Result<Point4D> {
625 Ok(point_add(p1, p2))
626 }
627
628 pub fn compress(point: &Point4D) -> Result<Vec<u8>> {
629 point_compress(point)
630 }
631
632 pub fn decompress(data: &[u8]) -> Result<Point4D> {
633 point_decompress(data)
634 }
635}
636
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_point_operations() {
        let p1 = G.clone();
        let p2 = point_add(&p1, &p1);

        // Doubling a point of prime order q must change it.
        assert!(!point_equal(&p1, &p2));

        // Scalar multiplication by 2 agrees with explicit doubling.
        let p3 = point_mul(&BigUint::from(2u32), &p1);
        assert!(point_equal(&p2, &p3));
    }

    #[test]
    fn test_base_point_has_order_q() {
        assert!(point_equal(&point_mul(&Q, &G), &ZERO_POINT));
    }

    #[test]
    fn test_hash_functions() {
        let hash = sha256_hash(b"test data");
        assert_eq!(hash.len(), 32);
    }

    #[test]
    fn test_sha256_known_vector() {
        // SHA-256("abc") from FIPS 180-2.
        let expected: [u8; 32] = [
            0xba, 0x78, 0x16, 0xbf, 0x8f, 0x01, 0xcf, 0xea, 0x41, 0x41, 0x40, 0xde, 0x5d, 0xae,
            0x22, 0x23, 0xb0, 0x03, 0x61, 0xa3, 0x96, 0x17, 0x7a, 0x9c, 0xb4, 0x10, 0xff, 0x61,
            0xf2, 0x00, 0x15, 0xad,
        ];
        assert_eq!(sha256_hash(b"abc"), expected.to_vec());
    }

    // The name mirrors the public `H` function, so the snake_case lint is silenced.
    #[allow(non_snake_case)]
    #[test]
    fn test_H() {
        let point = H(b"test-user", b"test-device", b"test-backup", b"12").unwrap();
        assert!(is_valid_point(&point));
    }

    #[test]
    fn test_h_is_deterministic_and_pin_sensitive() {
        let a = H(b"u", b"d", b"b", b"1234").unwrap();
        let b = H(b"u", b"d", b"b", b"1234").unwrap();
        let c = H(b"u", b"d", b"b", b"4321").unwrap();
        assert!(point_equal(&a, &b));
        assert!(!point_equal(&a, &c));
    }

    #[test]
    fn test_compress_base_point_matches_rfc8032() {
        // RFC 8032 encoding of the Ed25519 base point: 0x58 followed by 31 bytes of 0x66.
        let mut expected = [0x66u8; 32];
        expected[0] = 0x58;
        assert_eq!(point_compress(&G).unwrap(), expected.to_vec());
    }

    #[test]
    fn test_compress_decompress_roundtrip() {
        let p = point_mul(&BigUint::from(12345u32), &G);
        let decoded = point_decompress(&point_compress(&p).unwrap()).unwrap();
        assert!(point_equal(&p, &decoded));
    }

    #[test]
    fn test_decompress_rejects_bad_input() {
        assert!(point_decompress(&[0u8; 31]).is_err());
        let identity = point_compress(&ZERO_POINT).unwrap();
        assert!(point_decompress(&identity).is_err());
    }

    #[test]
    fn test_shamir_secret_sharing() {
        let secret = Integer::from(12345);
        let threshold = 3;
        let num_shares = 5;

        let shares = ShamirSecretSharing::split_secret(&secret, threshold, num_shares).unwrap();
        assert_eq!(shares.len(), num_shares);

        let recovery_shares = shares.into_iter().take(threshold).collect();
        let recovered = ShamirSecretSharing::recover_secret(recovery_shares).unwrap();
        assert_eq!(recovered, secret);
    }

    #[test]
    fn test_shamir_rejects_invalid_parameters() {
        let one = Integer::from(1);
        assert!(ShamirSecretSharing::split_secret(&one, 0, 3).is_err());
        assert!(ShamirSecretSharing::split_secret(&one, 4, 3).is_err());
        assert!(ShamirSecretSharing::split_secret(&ShamirSecretSharing::get_q(), 2, 3).is_err());
        assert!(ShamirSecretSharing::recover_secret(Vec::new()).is_err());
    }

    #[test]
    fn test_shamir_bytes_roundtrip() {
        let secret = vec![0x01u8, 0x02, 0x03, 0x04];
        let shares = ShamirSecretSharing::split_secret_bytes(&secret, 2, 3).unwrap();
        let recovered = ShamirSecretSharing::recover_secret_bytes(shares[1..].to_vec()).unwrap();
        assert_eq!(recovered, secret);
    }

    #[test]
    fn test_recover_point_secret() {
        let secret = Integer::from(424242);
        let shares = ShamirSecretSharing::split_secret(&secret, 2, 3).unwrap();
        let to_biguint =
            |v: &Integer| BigUint::parse_bytes(v.to_string().as_bytes(), 10).unwrap();

        let point_shares: Vec<PointShare> = shares
            .iter()
            .skip(1)
            .map(|(x, y)| PointShare::new(*x, point_mul(&to_biguint(y), &G)))
            .collect();

        let recovered = recover_point_secret(point_shares).unwrap();
        let expected = point_mul(&BigUint::from(424242u32), &G);
        assert!(point_equal(&recovered, &expected));
    }

    #[test]
    fn test_key_derivation() {
        let key = derive_enc_key(&G).unwrap();
        assert_eq!(key.len(), 32);

        // Deterministic for the same point, different for a different point.
        assert_eq!(key, derive_enc_key(&G).unwrap());
        assert_ne!(key, derive_enc_key(&point_add(&G, &G)).unwrap());
    }
}