use cryptography::vt::{AffinePoint, BigUint, CurveParams};
use super::Rng;
/// Dual_EC_DRBG random bit generator over an elliptic curve.
///
/// SECURITY: Dual_EC_DRBG was withdrawn in NIST SP 800-90A Rev. 1 (2015). Anyone who
/// knows the discrete-log relation between `P` and `Q` can predict its output, so do
/// not use it where unpredictability matters.
pub struct DualEcDrbg {
    /// Curve on which `p` and `q` live and on which scalar multiplication is done.
    curve: CurveParams,
    /// Point multiplied by the state to advance it.
    p: AffinePoint,
    /// Point multiplied by the state to derive each output block.
    q: AffinePoint,
    /// Secret internal state (a scalar).
    s: BigUint,
    /// Output bits produced per block; validated as a positive multiple of 8.
    outlen: usize,
    /// Current output block; unread bytes start at `pos`.
    buf: Vec<u8>,
    /// Index of the next unread byte in `buf`.
    pos: usize,
}
impl DualEcDrbg {
    /// Creates a generator over `curve` using the points `p` and `q`.
    ///
    /// `seed` is read as a big-endian integer and used directly as the initial
    /// state `s` (no derivation function is applied). `outlen` is the number of
    /// output bits produced per internal block.
    ///
    /// # Panics
    ///
    /// Panics if `outlen` is zero, is not a multiple of 8, or exceeds the curve's
    /// coordinate size (`curve.coord_len` bytes). Each block is cut from a single
    /// x-coordinate, so it cannot be larger than one.
    pub fn new(
        curve: CurveParams,
        p: AffinePoint,
        q: AffinePoint,
        seed: &[u8],
        outlen: usize,
    ) -> Self {
        assert!(
            outlen > 0 && outlen.is_multiple_of(8),
            "outlen must be a positive multiple of 8"
        );
        // `generate_block` keeps the rightmost `outlen / 8` bytes of a
        // `coord_len`-byte coordinate. Reject an oversized `outlen` here rather
        // than letting that subtraction underflow at the first draw.
        assert!(
            outlen / 8 <= curve.coord_len,
            "outlen must not exceed the curve's coordinate size"
        );
        let s = BigUint::from_be_bytes(seed);
        Self {
            curve,
            p,
            q,
            s,
            outlen,
            buf: Vec::new(),
            pos: 0,
        }
    }

    /// Creates a generator on P-256 with the curve base point as `P`, the fixed
    /// SP 800-90A `Q` point, and 240-bit (30-byte) output blocks.
    pub fn p256(seed: &[u8]) -> Self {
        let curve = cryptography::vt::p256();
        let p = curve.base_point();
        let q = point_from_hex(
            "c97445f45cdef9f0d3e05e1e585fc297235b82b5be8ff3efca67c59852018192",
            "b28ef557ba31dfcbdd21ac46e2a91e3c304f44cb87058ada2cb815151e610046",
        );
        Self::new(curve, p, q, seed, 240)
    }

    /// Creates a generator on P-384 with the curve base point as `P`, the fixed
    /// SP 800-90A `Q` point, and 368-bit (46-byte) output blocks.
    pub fn p384(seed: &[u8]) -> Self {
        let curve = cryptography::vt::p384();
        let p = curve.base_point();
        let q = point_from_hex(
            "8e722de3125bddb05580164bfe20b8b432216a62926c57502ceede31c47816ed\
             d1e89769124179d0b695106428815065",
            "023b1660dd701d0839fd45eec36f9ee7b32e13b315dc02610aa1b636e346df67\
             1f790f84c5e09b05674dbb7e45c803dd",
        );
        Self::new(curve, p, q, seed, 368)
    }

    /// Creates a generator on P-521 with the curve base point as `P`, the fixed
    /// SP 800-90A `Q` point, and 504-bit (63-byte) output blocks.
    pub fn p521(seed: &[u8]) -> Self {
        let curve = cryptography::vt::p521();
        let p = curve.base_point();
        let q = point_from_hex(
            "01b9fa3e518d683c6b65763694ac8efbaec6fab44f2276171a42726507dd08ad\
             d4c3b3f4c1ebc5b1222ddba077f722943b24c3edfa0f85fe24d0c8c01591f0be6f63",
            "01f3bdba585295d9a1110d1df1f9430ef8442c5018976ff3437ef91b81dc0b81\
             32c8d5c39c32d0e004a3092b7d327c0e7a4d26d2c7b69b58f9066652911e457779de",
        );
        Self::new(curve, p, q, seed, 504)
    }

    /// Runs one round of the generator and replaces `buf` with a fresh block.
    ///
    /// With current state `s`, it computes `t = x(s·P)`, advances the state to
    /// `x(t·P)`, and emits the rightmost `outlen / 8` bytes of the big-endian,
    /// `coord_len`-padded encoding of `x(t·Q)`. This matches a single-block
    /// SP 800-90A Generate call with no additional input, including the final
    /// state update.
    fn generate_block(&mut self) {
        let t_point = self.curve.scalar_mul(&self.p, &self.s);
        let t = t_point.x;
        let s_point = self.curve.scalar_mul(&self.p, &t);
        self.s = s_point.x;
        let r_point = self.curve.scalar_mul(&self.q, &t);
        let r_bytes = to_be_padded(&r_point.x, self.curve.coord_len);
        // Cannot underflow: `new` guarantees `outlen / 8 <= coord_len == r_bytes.len()`.
        let skip = r_bytes.len() - self.outlen / 8;
        self.buf = r_bytes[skip..].to_vec();
        self.pos = 0;
    }
}
impl Rng for DualEcDrbg {
    /// Returns the next four output bytes as a big-endian `u32`.
    ///
    /// Bytes are consumed in order from the current block. Whenever the block is
    /// exhausted a new one is generated, so a word may span several blocks. This
    /// also covers blocks shorter than four bytes (`outlen` of 8, 16 or 24), which
    /// previously caused an out-of-bounds slice panic.
    fn next_u32(&mut self) -> u32 {
        let mut bytes = [0u8; 4];
        for byte in &mut bytes {
            if self.pos >= self.buf.len() {
                self.generate_block();
            }
            *byte = self.buf[self.pos];
            self.pos += 1;
        }
        u32::from_be_bytes(bytes)
    }
}
/// Encodes `x` big-endian in exactly `len` bytes.
///
/// Shorter encodings are left-padded with zeros. Longer ones keep only their
/// rightmost (least significant) `len` bytes.
fn to_be_padded(x: &BigUint, len: usize) -> Vec<u8> {
    let encoded = x.to_be_bytes();
    // At most `len` low-order bytes survive; the remainder of the width is zero-filled.
    let tail = &encoded[encoded.len().saturating_sub(len)..];
    let mut padded = vec![0u8; len - tail.len()];
    padded.extend_from_slice(tail);
    padded
}
/// Builds an affine point from big-endian hexadecimal x and y coordinates.
///
/// # Panics
///
/// Panics if either string is not valid hex (see `decode_hex`).
fn point_from_hex(x_hex: &str, y_hex: &str) -> AffinePoint {
    // Array `map` visits elements in order, so `x_hex` is decoded before `y_hex`.
    let [x, y] = [x_hex, y_hex].map(|hex| BigUint::from_be_bytes(&decode_hex(hex)));
    AffinePoint::new(x, y)
}
/// Decodes a hexadecimal string into bytes, ignoring all whitespace.
///
/// Upper- and lower-case digits are accepted. Digits are validated one
/// character at a time, so sign characters such as `+` (which
/// `u8::from_str_radix` would accept) and non-ASCII characters are rejected.
///
/// # Panics
///
/// Panics with "hex string must have even length" if the number of
/// non-whitespace characters is odd. Panics with "valid hex digit" if any of
/// them is not an ASCII hex digit.
fn decode_hex(s: &str) -> Vec<u8> {
    let cleaned: Vec<char> = s.chars().filter(|c| !c.is_whitespace()).collect();
    assert!(
        cleaned.len().is_multiple_of(2),
        "hex string must have even length"
    );
    cleaned
        .chunks_exact(2)
        .map(|pair| {
            let hi = pair[0].to_digit(16).expect("valid hex digit");
            let lo = pair[1].to_digit(16).expect("valid hex digit");
            // Both digits are < 16, so the combined value always fits in a u8.
            (hi * 16 + lo) as u8
        })
        .collect()
}