use crate::error::Unspecified;
use native_ossl::params::ParamBuilder;
use native_ossl::pkey::{DeriveCtx, KeygenCtx, Pkey, Private, Public};
use std::ffi::CStr;
use crate::spki::{P256_SPKI_HEADER, P384_SPKI_HEADER, X25519_SPKI_HEADER};
/// A key-agreement algorithm: which OpenSSL key type / curve to generate and how its
/// public keys are framed for OpenSSL's DER decoder.
///
/// Two `Algorithm`s compare equal when all of their fields are equal, so `==` can be
/// used to check which algorithm a key belongs to.
#[derive(Debug, PartialEq, Eq)]
pub struct Algorithm {
    /// Key type name handed to the OpenSSL key-generation context (e.g. `"X25519"`, `"EC"`).
    keygen_name: &'static CStr,
    /// Curve name set as the `"group"` key-generation parameter; `None` when the key
    /// type needs no group parameter.
    group_param: Option<&'static CStr>,
    /// DER SubjectPublicKeyInfo prefix that precedes the raw public-key bytes.
    spki_header: &'static [u8],
    /// Exact length in bytes of a raw public key for this algorithm.
    raw_key_len: usize,
}
/// X25519 key agreement; public keys are exchanged as exactly 32 raw bytes.
pub static X25519: Algorithm = Algorithm {
    raw_key_len: 32,
    spki_header: X25519_SPKI_HEADER,
    group_param: None,
    keygen_name: c"X25519",
};
/// ECDH on the `P-256` curve; public keys are 65-byte points beginning with `0x04`
/// (SEC1 uncompressed form).
pub static ECDH_P256: Algorithm = Algorithm {
    raw_key_len: 65,
    spki_header: P256_SPKI_HEADER,
    group_param: Some(c"P-256"),
    keygen_name: c"EC",
};
/// ECDH on the `P-384` curve; public keys are 97-byte points beginning with `0x04`
/// (SEC1 uncompressed form).
pub static ECDH_P384: Algorithm = Algorithm {
    raw_key_len: 97,
    spki_header: P384_SPKI_HEADER,
    group_param: Some(c"P-384"),
    keygen_name: c"EC",
};
/// A freshly generated private key, consumed by a single call to `agree_ephemeral`.
///
/// The public half is encoded once at generation time and cached in `pub_key_bytes`,
/// so `compute_public_key` only has to copy it.
pub struct EphemeralPrivateKey {
    /// Algorithm the key was generated for.
    alg: &'static Algorithm,
    /// Private key object from `native_ossl`.
    priv_key: Pkey<Private>,
    /// Raw public key: the DER SubjectPublicKeyInfo with `alg.spki_header` removed.
    pub_key_bytes: Vec<u8>,
}
impl EphemeralPrivateKey {
    /// Generates a new key pair for `alg`.
    ///
    /// `_rng` is not consulted; the randomness comes from the OpenSSL key-generation
    /// context.
    ///
    /// # Errors
    ///
    /// Returns [`Unspecified`] in any of these cases:
    /// - OpenSSL cannot create the key-generation context or apply the `"group"` parameter.
    /// - OpenSSL cannot generate the key or DER-encode its public half.
    /// - The encoded SubjectPublicKeyInfo does not start with `alg`'s expected header.
    /// - The remaining raw key does not have `alg`'s expected length.
    pub fn generate(
        alg: &'static Algorithm,
        _rng: &dyn crate::rand::SecureRandom,
    ) -> Result<Self, Unspecified> {
        let mut ctx = KeygenCtx::new(alg.keygen_name).map_err(|_| Unspecified)?;
        if let Some(group_name) = alg.group_param {
            let params = ParamBuilder::new()
                .and_then(|b| b.push_utf8_string(c"group", group_name))
                .and_then(ParamBuilder::build)
                .map_err(|_| Unspecified)?;
            ctx.set_params(&params).map_err(|_| Unspecified)?;
        }
        let priv_key = ctx.generate().map_err(|_| Unspecified)?;
        let spki = priv_key.public_key_to_der().map_err(|_| Unspecified)?;
        // Verify the DER prefix really is the header we expect instead of blindly
        // skipping `spki_header.len()` bytes. A different encoding (e.g. a compressed
        // point) would otherwise produce a garbage "raw" key. The length check keeps
        // our public key in the exact shape a peer's `agree_ephemeral` will accept.
        let raw_pub = spki
            .strip_prefix(alg.spki_header)
            .filter(|raw| raw.len() == alg.raw_key_len)
            .ok_or(Unspecified)?
            .to_vec();
        Ok(Self {
            alg,
            priv_key,
            pub_key_bytes: raw_pub,
        })
    }

    /// Returns the raw public key to send to the peer.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`. The public key was already encoded by
    /// [`EphemeralPrivateKey::generate`].
    pub fn compute_public_key(&self) -> Result<PublicKey, Unspecified> {
        Ok(PublicKey {
            alg: self.alg,
            bytes: self.pub_key_bytes.clone(),
        })
    }

    /// Returns the algorithm this key was generated for.
    #[must_use]
    pub fn algorithm(&self) -> &'static Algorithm {
        self.alg
    }
}
impl std::fmt::Debug for EphemeralPrivateKey {
    /// Shows only the algorithm. The private key and cached public key are left out,
    /// and `finish_non_exhaustive` marks that with a trailing `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("EphemeralPrivateKey");
        builder.field("alg", &self.alg);
        builder.finish_non_exhaustive()
    }
}
/// The public half of an [`EphemeralPrivateKey`], as raw bytes to send to the peer.
///
/// `Clone` lets callers keep a copy (for example for a handshake transcript) while
/// also sending it.
#[derive(Debug, Clone)]
pub struct PublicKey {
    /// Algorithm the key belongs to.
    alg: &'static Algorithm,
    /// Raw public-key bytes, without the SubjectPublicKeyInfo header.
    bytes: Vec<u8>,
}
impl PublicKey {
    /// Returns the key-agreement algorithm this public key belongs to.
    #[must_use]
    pub fn algorithm(&self) -> &'static Algorithm {
        let Self { alg, .. } = *self;
        alg
    }
}
impl AsRef<[u8]> for PublicKey {
    /// Borrows the raw public-key bytes, without the SubjectPublicKeyInfo header.
    fn as_ref(&self) -> &[u8] {
        self.bytes.as_slice()
    }
}
/// A peer's public key as received, not yet validated.
///
/// The bytes are only checked when the key is passed to [`agree_ephemeral`].
/// `Clone`/`Copy` are available whenever the byte container `B` supports them, for
/// example `UnparsedPublicKey<&[u8]>`.
#[derive(Debug, Clone, Copy)]
pub struct UnparsedPublicKey<B> {
    /// Algorithm the peer key is declared to use.
    alg: &'static Algorithm,
    /// The peer's raw public-key bytes.
    bytes: B,
}
impl<B: AsRef<[u8]>> UnparsedPublicKey<B> {
    /// Wraps `bytes` as a peer public key declared to use `algorithm`.
    ///
    /// No validation happens here. The bytes are checked only when the key is used in
    /// [`agree_ephemeral`].
    #[must_use]
    pub fn new(algorithm: &'static Algorithm, bytes: B) -> Self {
        Self {
            alg: algorithm,
            bytes,
        }
    }

    /// Returns the algorithm the peer key was declared to use.
    #[must_use]
    pub fn algorithm(&self) -> &'static Algorithm {
        self.alg
    }
}
impl<B: AsRef<[u8]>> AsRef<[u8]> for UnparsedPublicKey<B> {
    /// Borrows the peer's raw, unvalidated public-key bytes.
    fn as_ref(&self) -> &[u8] {
        <B as AsRef<[u8]>>::as_ref(&self.bytes)
    }
}
/// Performs an ephemeral key agreement and hands the shared secret to `kdf`.
///
/// `my_private_key` is taken by value, so an ephemeral key cannot be reused for a
/// second agreement. If the agreement fails, `kdf` is never called and `error_value`
/// is returned. On success, whatever `kdf` returns is passed through unchanged.
///
/// # Errors
///
/// Returns `Err(error_value)` if the peer key is rejected or the derivation fails.
/// Otherwise returns any error produced by `kdf` itself.
// Consuming the key is the point of the signature (single use), even though the body
// only borrows it, so clippy's by-value lint is silenced deliberately.
#[allow(clippy::needless_pass_by_value)]
pub fn agree_ephemeral<B, F, R, E>(
    my_private_key: EphemeralPrivateKey,
    peer_public_key: &UnparsedPublicKey<B>,
    error_value: E,
    kdf: F,
) -> Result<R, E>
where
    B: AsRef<[u8]>,
    F: FnOnce(&[u8]) -> Result<R, E>,
{
    let outcome = do_ecdh(&my_private_key.priv_key, my_private_key.alg, peer_public_key);
    let Ok(shared_secret) = outcome else {
        return Err(error_value);
    };
    kdf(&shared_secret)
}
/// Computes the raw (EC)DH shared secret between `priv_key` and the peer's public key.
///
/// The peer's bytes are validated against `alg` in four ways:
/// - the declared algorithm must match;
/// - the length must be exact;
/// - for algorithms with a `group_param`, the bytes must start with `0x04`;
/// - OpenSSL must be able to parse them.
///
/// The bytes are wrapped in `alg`'s SubjectPublicKeyInfo header so OpenSSL's DER
/// decoder can parse them, then fed to an OpenSSL derive context.
///
/// # Errors
///
/// Returns `Err(())` on any validation or OpenSSL failure. No detail is exposed; the
/// caller substitutes its own error value.
fn do_ecdh<B>(
    priv_key: &Pkey<Private>,
    alg: &'static Algorithm,
    peer_public_key: &UnparsedPublicKey<B>,
) -> Result<Vec<u8>, ()>
where
    B: AsRef<[u8]>,
{
    // The peer key must be declared for the same algorithm as our private key.
    // The length check below happens to separate the curves defined today, but that is
    // incidental; an explicit identity check keeps this correct if curves are added.
    let peer_alg = peer_public_key.alg;
    if peer_alg.keygen_name != alg.keygen_name || peer_alg.group_param != alg.group_param {
        return Err(());
    }
    let peer_bytes = peer_public_key.bytes.as_ref();
    if peer_bytes.len() != alg.raw_key_len {
        return Err(());
    }
    // EC keys must use the 0x04-prefixed (uncompressed) point encoding that the fixed
    // SPKI header and length assume.
    if alg.group_param.is_some() && peer_bytes.first() != Some(&0x04) {
        return Err(());
    }
    // Rebuild a DER SubjectPublicKeyInfo around the raw bytes so OpenSSL can decode it.
    let mut spki = Vec::with_capacity(alg.spki_header.len() + peer_bytes.len());
    spki.extend_from_slice(alg.spki_header);
    spki.extend_from_slice(peer_bytes);
    let peer_key = Pkey::<Public>::from_der(&spki).map_err(|_| ())?;
    let mut derive = DeriveCtx::new(priv_key).map_err(|_| ())?;
    derive.set_peer(&peer_key).map_err(|_| ())?;
    // Size the output from the derive context rather than hard-coding per-curve sizes.
    let len = derive.derive_len().map_err(|_| ())?;
    let mut secret = vec![0u8; len];
    derive.derive(&mut secret).map_err(|_| ())?;
    Ok(secret)
}