use crate::{c2_Element, c2_Scalar, Group};
use alloc::vec::Vec;
use curve25519_dalek::{constants::ED25519_BASEPOINT_POINT, edwards::CompressedEdwardsY};
use hkdf::Hkdf;
use rand_core::{CryptoRng, RngCore};
use sha2::{Digest, Sha256};
/// SPAKE2 [`Group`] implementation over Ed25519, built on `curve25519-dalek`.
///
/// Elements are Edwards points (`c2_Element`), scalars are `c2_Scalar`, and the
/// transcript hash is SHA-256.
#[derive(Debug, PartialEq, Eq)]
pub struct Ed25519Group;

impl Group for Ed25519Group {
    type Scalar = c2_Scalar;
    type Element = c2_Element;
    type TranscriptHash = Sha256;

    /// Human-readable name of this group.
    fn name() -> &'static str {
        "Ed25519"
    }

    /// Fixed protocol element `M`, decoded from a hard-coded compressed encoding.
    ///
    /// NOTE(review): these bytes presumably must match other SPAKE2 implementations
    /// for interoperability — do not change them.
    fn const_m() -> c2_Element {
        CompressedEdwardsY([
            0x15, 0xcf, 0xd1, 0x8e, 0x38, 0x59, 0x52, 0x98, 0x2b, 0x6a, 0x8f, 0x8c, 0x78, 0x54,
            0x96, 0x3b, 0x58, 0xe3, 0x43, 0x88, 0xc8, 0xe6, 0xda, 0xe8, 0x91, 0xdb, 0x75, 0x64,
            0x81, 0xa0, 0x23, 0x12,
        ])
        .decompress()
        .expect("hard-coded M is a valid compressed Ed25519 point")
    }

    /// Fixed protocol element `N`, decoded from a hard-coded compressed encoding.
    ///
    /// NOTE(review): presumably fixed for interoperability — do not change.
    fn const_n() -> c2_Element {
        CompressedEdwardsY([
            0xf0, 0x4f, 0x2e, 0x7e, 0xb7, 0x34, 0xb2, 0xa8, 0xf8, 0xb4, 0x72, 0xea, 0xf9, 0xc3,
            0xc6, 0x32, 0x57, 0x6a, 0xc6, 0x4a, 0xea, 0x65, 0x0b, 0x49, 0x6a, 0x8a, 0x20, 0xff,
            0x00, 0xe5, 0x83, 0xc3,
        ])
        .decompress()
        .expect("hard-coded N is a valid compressed Ed25519 point")
    }

    /// Fixed protocol element `S`, decoded from a hard-coded compressed encoding.
    ///
    /// NOTE(review): presumably the element used by the symmetric variant and fixed
    /// for interoperability — confirm and do not change.
    fn const_s() -> c2_Element {
        CompressedEdwardsY([
            0x6f, 0x00, 0xda, 0xe8, 0x7c, 0x1b, 0xe1, 0xa7, 0x3b, 0x59, 0x22, 0xef, 0x43, 0x1c,
            0xd8, 0xf5, 0x78, 0x79, 0x56, 0x9c, 0x22, 0x2d, 0x22, 0xb1, 0xcd, 0x71, 0xe8, 0x54,
            0x6a, 0xb8, 0xe6, 0xf1,
        ])
        .decompress()
        .expect("hard-coded S is a valid compressed Ed25519 point")
    }

    /// Maps an arbitrary byte string (e.g. the password) to a scalar.
    fn hash_to_scalar(s: &[u8]) -> c2_Scalar {
        ed25519_hash_to_scalar(s)
    }

    /// Draws a random scalar from the supplied cryptographically secure RNG.
    fn random_scalar<T>(cspring: &mut T) -> c2_Scalar
    where
        T: RngCore + CryptoRng,
    {
        c2_Scalar::random(cspring)
    }

    /// Returns the additive inverse of `s` modulo the group order.
    fn scalar_neg(s: &c2_Scalar) -> c2_Scalar {
        -s
    }

    /// Serializes `s` as its 32-byte compressed Edwards-Y encoding.
    fn element_to_bytes(s: &c2_Element) -> Vec<u8> {
        s.compress().as_bytes().to_vec()
    }

    /// Length in bytes of the encoding produced by `element_to_bytes`.
    fn element_length() -> usize {
        32
    }

    /// Parses a 32-byte compressed Edwards-Y encoding.
    ///
    /// Returns `None` if `b` is not exactly 32 bytes long or does not decompress to a
    /// curve point.
    ///
    /// NOTE(review): `decompress` only checks that the bytes encode a curve point; it
    /// does not reject low-order points or the identity. Confirm callers don't need that.
    fn bytes_to_element(b: &[u8]) -> Option<c2_Element> {
        if b.len() != 32 {
            return None;
        }
        let mut bytes = [0u8; 32];
        bytes.copy_from_slice(b);
        CompressedEdwardsY(bytes).decompress()
    }

    /// Computes `s * B`, where `B` is the Ed25519 basepoint.
    fn basepoint_mult(s: &c2_Scalar) -> c2_Element {
        ED25519_BASEPOINT_POINT * s
    }

    /// Computes `s * e`.
    fn scalarmult(e: &c2_Element, s: &c2_Scalar) -> c2_Element {
        e * s
    }

    /// Computes the group sum `a + b`.
    fn add(a: &c2_Element, b: &c2_Element) -> c2_Element {
        a + b
    }
}
/// Derives a scalar from the byte string `s`.
///
/// Runs HKDF-SHA256 over `s` (empty salt, info `"SPAKE2 pw"`) to get 48 bytes of output.
/// That output is read as a big-endian integer and reduced modulo the group order.
/// `from_bytes_mod_order_wide` expects 64 little-endian bytes, so the output is
/// byte-reversed into the low end of a zero-padded 64-byte buffer.
///
/// NOTE(review): this exact derivation is presumably fixed for interoperability with
/// other SPAKE2 implementations — do not change it.
fn ed25519_hash_to_scalar(s: &[u8]) -> c2_Scalar {
    // 32 + 16 bytes: enough extra width that the modular reduction is close to uniform.
    const OKM_LEN: usize = 32 + 16;

    let hkdf = Hkdf::<Sha256>::new(Some(b""), s);
    let mut okm = [0u8; OKM_LEN];
    // 48 bytes is far below HKDF-SHA256's maximum output length, so this cannot fail.
    hkdf.expand(b"SPAKE2 pw", &mut okm).unwrap();

    // Big-endian HKDF output -> little-endian wide scalar; upper 16 bytes stay zero.
    let mut wide = [0u8; 64];
    wide[..OKM_LEN].copy_from_slice(&okm);
    wide[..OKM_LEN].reverse();

    c2_Scalar::from_bytes_mod_order_wide(&wide)
}
/// Hashes the asymmetric (A/B) SPAKE2 transcript into a 32-byte SHA-256 digest.
///
/// The digest is computed over the 192-byte transcript
/// `SHA256(password) || SHA256(id_a) || SHA256(id_b) || first_msg || second_msg || key_bytes`.
///
/// # Panics
///
/// Panics if `first_msg`, `second_msg` or `key_bytes` is not exactly 32 bytes long.
pub(crate) fn hash_ab(
    password_vec: &[u8],
    id_a: &[u8],
    id_b: &[u8],
    first_msg: &[u8],
    second_msg: &[u8],
    key_bytes: &[u8],
) -> Vec<u8> {
    assert_eq!(first_msg.len(), 32);
    assert_eq!(second_msg.len(), 32);
    // Previously enforced only implicitly by a slice copy. Checked explicitly so that all
    // fixed-width transcript fields share the same stated precondition.
    assert_eq!(key_bytes.len(), 32);

    // Feeding the fields in order is equivalent to hashing their concatenation, so no
    // intermediate transcript buffer is needed.
    let mut transcript = Sha256::new();
    transcript.update(Sha256::digest(password_vec));
    transcript.update(Sha256::digest(id_a));
    transcript.update(Sha256::digest(id_b));
    transcript.update(first_msg);
    transcript.update(second_msg);
    transcript.update(key_bytes);
    transcript.finalize().to_vec()
}
/// Hashes the symmetric SPAKE2 transcript into a 32-byte SHA-256 digest.
///
/// The two messages are placed in lexicographic byte order, so the result does not depend
/// on which side's message is passed as `msg_u` or `msg_v`. The digest is computed over the
/// 160-byte transcript
/// `SHA256(password) || SHA256(id_s) || min(msg_u, msg_v) || max(msg_u, msg_v) || key_bytes`.
///
/// # Panics
///
/// Panics if `msg_u`, `msg_v` or `key_bytes` is not exactly 32 bytes long.
pub(crate) fn hash_symmetric(
    password_vec: &[u8],
    id_s: &[u8],
    msg_u: &[u8],
    msg_v: &[u8],
    key_bytes: &[u8],
) -> Vec<u8> {
    assert_eq!(msg_u.len(), 32);
    assert_eq!(msg_v.len(), 32);
    // Previously enforced only implicitly by a slice copy. Checked explicitly so that all
    // fixed-width transcript fields share the same stated precondition.
    assert_eq!(key_bytes.len(), 32);

    // Equal messages take the second arm, which yields the same bytes either way.
    let (lower, upper) = if msg_u < msg_v {
        (msg_u, msg_v)
    } else {
        (msg_v, msg_u)
    };

    // Feeding the fields in order is equivalent to hashing their concatenation.
    let mut transcript = Sha256::new();
    transcript.update(Sha256::digest(password_vec));
    transcript.update(Sha256::digest(id_s));
    transcript.update(lower);
    transcript.update(upper);
    transcript.update(key_bytes);
    transcript.finalize().to_vec()
}