use crate::common::{concat, i2osp};
use crate::elliptic_curve::{CurveType, GroupSpec, Ristretto255GroupSpec, WeierstrassGroupSpec};
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256, Sha384, Sha512};
use std::sync::Arc;
/// Supported (group, hash) combinations for an OPRF cipher suite.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CurveHashSuite {
    P256Sha256,
    P384Sha384,
    P521Sha512,
    Ristretto255Sha512,
}

/// Error returned when a string does not name any [`CurveHashSuite`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnknownSuiteError(pub String);

impl std::fmt::Display for UnknownSuiteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Unknown suite: {}", self.0)
    }
}

impl std::error::Error for UnknownSuiteError {}

impl CurveHashSuite {
    /// Every supported suite, in declaration order.
    pub const ALL: [CurveHashSuite; 4] = [
        CurveHashSuite::P256Sha256,
        CurveHashSuite::P384Sha384,
        CurveHashSuite::P521Sha512,
        CurveHashSuite::Ristretto255Sha512,
    ];

    /// Looks up a suite by its configuration name (e.g. `"P256_SHA256"`).
    ///
    /// # Panics
    /// Panics with `"Unknown suite: <name>"` if `name` is not recognized.
    /// Use `name.parse::<CurveHashSuite>()` for a non-panicking alternative.
    pub fn from_name(name: &str) -> Self {
        match name.parse() {
            Ok(suite) => suite,
            Err(err) => panic!("{}", err),
        }
    }

    /// Configuration name of this suite; the exact inverse of [`Self::from_name`].
    pub fn name(&self) -> &'static str {
        match self {
            CurveHashSuite::P256Sha256 => "P256_SHA256",
            CurveHashSuite::P384Sha384 => "P384_SHA384",
            CurveHashSuite::P521Sha512 => "P521_SHA512",
            CurveHashSuite::Ristretto255Sha512 => "RISTRETTO255_SHA512",
        }
    }
}

impl std::str::FromStr for CurveHashSuite {
    type Err = UnknownSuiteError;

    /// Parses a configuration name such as `"RISTRETTO255_SHA512"`.
    ///
    /// Driven by [`CurveHashSuite::name`] so the two directions cannot drift apart.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        CurveHashSuite::ALL
            .iter()
            .copied()
            .find(|suite| suite.name() == s)
            .ok_or_else(|| UnknownSuiteError(s.to_string()))
    }
}
/// Hash function selected by a cipher suite for its `hash` and `hmac` operations.
#[derive(Clone, Copy, Debug)]
enum HashAlgorithm {
    /// SHA-256.
    Sha256,
    /// SHA-384.
    Sha384,
    /// SHA-512.
    Sha512,
}
/// Precomputed parameters for one OPRF cipher suite: identity, group, hash and
/// the domain-separation strings derived from them.
pub struct OprfCipherSuite {
    // --- identity -----------------------------------------------------------
    /// Suite identifier, e.g. `"P256-SHA256"`.
    identifier: String,

    // --- primitives ---------------------------------------------------------
    /// Group operations (scalars, elements) backing this suite.
    group_spec: Arc<dyn GroupSpec>,
    /// Hash function used for `hash` and `hmac`.
    hash_algorithm: HashAlgorithm,
    /// Output length of `hash_algorithm`, in bytes.
    hash_output_length: usize,

    // --- domain separation --------------------------------------------------
    /// `"OPRFV1-" || 0x00 || "-" || identifier`.
    context_string: Vec<u8>,
    /// `"HashToGroup-" || context_string`.
    hash_to_group_dst: Vec<u8>,
    /// `"HashToScalar-" || context_string`.
    hash_to_scalar_dst: Vec<u8>,
    /// `"DeriveKeyPair" || context_string`.
    derive_key_pair_dst: Vec<u8>,
}
impl OprfCipherSuite {
    /// Builds the cipher suite for `suite` in OPRF mode (mode byte `0x00`).
    ///
    /// Precomputes the context string and the domain-separation tags for
    /// `HashToGroup`, `HashToScalar` and `DeriveKeyPair`.
    pub fn new(suite: CurveHashSuite) -> Self {
        // The suite identifier is also the suffix of the context string.
        let (identifier, group_spec, hash_alg, hash_len): (
            &str,
            Arc<dyn GroupSpec>,
            HashAlgorithm,
            usize,
        ) = match suite {
            CurveHashSuite::P256Sha256 => (
                "P256-SHA256",
                Arc::new(WeierstrassGroupSpec::new(CurveType::P256)),
                HashAlgorithm::Sha256,
                32,
            ),
            CurveHashSuite::P384Sha384 => (
                "P384-SHA384",
                Arc::new(WeierstrassGroupSpec::new(CurveType::P384)),
                HashAlgorithm::Sha384,
                48,
            ),
            CurveHashSuite::P521Sha512 => (
                "P521-SHA512",
                Arc::new(WeierstrassGroupSpec::new(CurveType::P521)),
                HashAlgorithm::Sha512,
                64,
            ),
            CurveHashSuite::Ristretto255Sha512 => (
                "ristretto255-SHA512",
                Arc::new(Ristretto255GroupSpec),
                HashAlgorithm::Sha512,
                64,
            ),
        };
        let context_string = build_context_string(identifier);
        let hash_to_group_dst = concat(&[b"HashToGroup-", &context_string]);
        let hash_to_scalar_dst = concat(&[b"HashToScalar-", &context_string]);
        // Unlike the two tags above, this one has no hyphen before the context string.
        let derive_key_pair_dst = concat(&[b"DeriveKeyPair", &context_string]);
        Self {
            identifier: identifier.to_string(),
            context_string,
            hash_to_group_dst,
            hash_to_scalar_dst,
            derive_key_pair_dst,
            group_spec,
            hash_algorithm: hash_alg,
            hash_output_length: hash_len,
        }
    }

    /// Suite identifier, e.g. `"P256-SHA256"`.
    pub fn identifier(&self) -> &str {
        &self.identifier
    }

    /// `"OPRFV1-" || 0x00 || "-" || identifier`.
    pub fn context_string(&self) -> &[u8] {
        &self.context_string
    }

    /// Domain-separation tag `"HashToGroup-" || context_string`.
    pub fn hash_to_group_dst(&self) -> &[u8] {
        &self.hash_to_group_dst
    }

    /// Domain-separation tag `"HashToScalar-" || context_string`.
    pub fn hash_to_scalar_dst(&self) -> &[u8] {
        &self.hash_to_scalar_dst
    }

    /// Domain-separation tag `"DeriveKeyPair" || context_string`.
    pub fn derive_key_pair_dst(&self) -> &[u8] {
        &self.derive_key_pair_dst
    }

    /// Group operations backing this suite.
    pub fn group_spec(&self) -> &dyn GroupSpec {
        self.group_spec.as_ref()
    }

    /// Output length of [`Self::hash`] in bytes.
    pub fn hash_output_length(&self) -> usize {
        self.hash_output_length
    }

    /// Serialized element size, as reported by the group.
    pub fn element_size(&self) -> usize {
        self.group_spec.element_size()
    }

    /// Samples a random scalar from `rng` via the group.
    pub fn random_scalar(&self, rng: &mut dyn rand_core::CryptoRngCore) -> Vec<u8> {
        self.group_spec.random_scalar(rng)
    }

    /// Hashes `input` to a serialized scalar under domain-separation tag `dst`.
    pub fn hash_to_scalar(&self, input: &[u8], dst: &[u8]) -> Vec<u8> {
        self.group_spec.hash_to_scalar(input, dst)
    }

    /// Deterministically derives a non-zero private scalar from `seed` and `info`.
    ///
    /// Computes `HashToScalar(seed || len(info) || info || counter)` under
    /// [`Self::derive_key_pair_dst`], for `counter` = 0, 1, …, 255. It returns the
    /// first result that is not all-zero bytes. Only the secret scalar is
    /// returned; the public key is not computed here.
    ///
    /// # Panics
    /// - If `info` is longer than 65535 bytes, because its length is encoded in
    ///   two bytes.
    /// - If all 256 counter values produce a zero scalar.
    pub fn derive_key_pair(&self, seed: &[u8], info: &[u8]) -> Vec<u8> {
        let info_len = Self::checked_len_u16(info, "DeriveKeyPair info");
        let derive_input = concat(&[seed, &i2osp(info_len, 2), info]);
        for counter in 0..=u8::MAX {
            let counter_input = concat(&[&derive_input, &i2osp(u32::from(counter), 1)]);
            let sk_s = self.hash_to_scalar(&counter_input, &self.derive_key_pair_dst);
            // A zero scalar is not a usable private key; retry with the next counter.
            if sk_s.iter().any(|&b| b != 0) {
                return sk_s;
            }
        }
        panic!("DeriveKeyPair: exceeded counter limit");
    }

    /// Unblinds `evaluated_element` and hashes it together with `input` (OPRF `Finalize`).
    ///
    /// Computes `blind^-1 * evaluated_element`, then returns
    /// `Hash(len(input) || input || len(element) || element || "Finalize")`.
    /// Both lengths are two-byte prefixes.
    ///
    /// # Panics
    /// Panics if `input` or the serialized unblinded element is longer than
    /// 65535 bytes. Such a length would not fit its two-byte prefix.
    pub fn finalize(&self, input: &[u8], blind: &[u8], evaluated_element: &[u8]) -> Vec<u8> {
        // Validate caller input before doing any group arithmetic.
        let input_len = Self::checked_len_u16(input, "Finalize input");
        let inverse_blind = self.group_spec.scalar_inverse(blind);
        let unblinded_element = self
            .group_spec
            .scalar_multiply(&inverse_blind, evaluated_element);
        let element_len = Self::checked_len_u16(&unblinded_element, "Finalize unblinded element");
        let hash_input = concat(&[
            &i2osp(input_len, 2),
            input,
            &i2osp(element_len, 2),
            &unblinded_element,
            b"Finalize",
        ]);
        self.hash(&hash_input)
    }

    /// Hashes `data` with this suite's hash function.
    pub fn hash(&self, data: &[u8]) -> Vec<u8> {
        match self.hash_algorithm {
            HashAlgorithm::Sha256 => Sha256::digest(data).to_vec(),
            HashAlgorithm::Sha384 => Sha384::digest(data).to_vec(),
            HashAlgorithm::Sha512 => Sha512::digest(data).to_vec(),
        }
    }

    /// Computes HMAC(`key`, `data`) with this suite's hash function.
    pub fn hmac(&self, key: &[u8], data: &[u8]) -> Vec<u8> {
        // HMAC is defined for keys of any length, so `new_from_slice` is not
        // expected to fail here.
        match self.hash_algorithm {
            HashAlgorithm::Sha256 => {
                let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("HMAC key length error");
                mac.update(data);
                mac.finalize().into_bytes().to_vec()
            }
            HashAlgorithm::Sha384 => {
                let mut mac = Hmac::<Sha384>::new_from_slice(key).expect("HMAC key length error");
                mac.update(data);
                mac.finalize().into_bytes().to_vec()
            }
            HashAlgorithm::Sha512 => {
                let mut mac = Hmac::<Sha512>::new_from_slice(key).expect("HMAC key length error");
                mac.update(data);
                mac.finalize().into_bytes().to_vec()
            }
        }
    }

    /// Returns `bytes.len()` as the value for a two-byte length prefix.
    ///
    /// # Panics
    /// Panics, naming `what`, if the length exceeds `u16::MAX`. A two-byte prefix
    /// cannot represent such a length, and silently truncating it would make the
    /// encoding ambiguous.
    fn checked_len_u16(bytes: &[u8], what: &str) -> u32 {
        let len = bytes.len();
        assert!(
            len <= usize::from(u16::MAX),
            "{} is {} bytes; at most {} bytes are allowed",
            what,
            len,
            u16::MAX
        );
        // Lossless: `len` is bounded by u16::MAX just above.
        len as u32
    }
}
/// Builds the OPRF context string `"OPRFV1-" || 0x00 || "-" || suffix`.
///
/// The single `0x00` byte is the protocol mode; it is fixed here.
fn build_context_string(suffix: &str) -> Vec<u8> {
    const PREFIX: &[u8] = b"OPRFV1-";
    const MODE: u8 = 0x00;

    let mut context = Vec::with_capacity(PREFIX.len() + 2 + suffix.len());
    context.extend_from_slice(PREFIX);
    context.push(MODE);
    context.push(b'-');
    context.extend_from_slice(suffix.as_bytes());
    context
}