use std::convert::TryFrom;
use aes_kw::KekAes256 as Kek;
use serde::{Deserialize, Serialize};
use super::keys::{KeyError, PublicKey, SecretKey, PUBLIC_KEY_SIZE};
use super::secret::{Secret, SecretError, SECRET_SIZE};
/// Number of bytes a share reserves on top of the secret itself for the AES-KW
/// wrapped output (`SecretShare::new` requires the wrapped secret to fill
/// exactly `SECRET_SIZE + KW_NONCE_SIZE` bytes after the ephemeral public key).
pub const KW_NONCE_SIZE: usize = 8;
/// Total size in bytes of a serialized share: ephemeral public key, followed
/// by the wrapped secret including its `KW_NONCE_SIZE` bytes of overhead.
pub const SECRET_SHARE_SIZE: usize = PUBLIC_KEY_SIZE + SECRET_SIZE + KW_NONCE_SIZE;
/// Errors returned while parsing, creating, or recovering a `SecretShare`.
#[derive(Debug, thiserror::Error)]
pub enum SecretShareError {
/// Catch-all failure carrying an `anyhow::Error`; in this file it is used for
/// size mismatches, hex decoding failures, and AES-KW wrap/unwrap failures.
#[error("share error: {0}")]
Default(#[from] anyhow::Error),
/// A key operation failed (propagated from `KeyError` via `?`).
#[error("key error: {0}")]
Key(#[from] KeyError),
/// A secret operation failed (converted from `SecretError`).
#[error("secret error: {0}")]
Secret(#[from] SecretError),
}
/// A fixed-size buffer of `SECRET_SHARE_SIZE` bytes holding a secret wrapped
/// for one recipient.
///
/// As written by `SecretShare::new`, the first `PUBLIC_KEY_SIZE` bytes are an
/// ephemeral public key and the remaining bytes are the AES-KW wrapped secret.
/// Shares built through `Default` or `From<[u8; SECRET_SHARE_SIZE]>` carry
/// whatever bytes they were given and are only validated by `recover`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct SecretShare(pub(crate) [u8; SECRET_SHARE_SIZE]);
/// Serializes the share as a single byte string covering the whole buffer.
impl Serialize for SecretShare {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Hand the raw buffer over as one opaque byte slice.
        let raw: &[u8] = &self.0[..];
        serializer.serialize_bytes(raw)
    }
}
impl<'de> Deserialize<'de> for SecretShare {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::{Error, Visitor};
use std::fmt;
struct ShareVisitor;
impl<'de> Visitor<'de> for ShareVisitor {
type Value = SecretShare;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a byte array or sequence of SHARE_SIZE")
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: Error,
{
if v.len() != SECRET_SHARE_SIZE {
return Err(E::invalid_length(
v.len(),
&format!("expected {} bytes", SECRET_SHARE_SIZE).as_str(),
));
}
let mut array = [0u8; SECRET_SHARE_SIZE];
array.copy_from_slice(v);
Ok(SecretShare(array))
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut bytes = Vec::new();
while let Some(byte) = seq.next_element::<u8>()? {
bytes.push(byte);
}
if bytes.len() != SECRET_SHARE_SIZE {
return Err(A::Error::invalid_length(
bytes.len(),
&format!("expected {} bytes", SECRET_SHARE_SIZE).as_str(),
));
}
let mut array = [0u8; SECRET_SHARE_SIZE];
array.copy_from_slice(&bytes);
Ok(SecretShare(array))
}
}
deserializer.deserialize_byte_buf(ShareVisitor)
}
}
impl Default for SecretShare {
fn default() -> Self {
SecretShare([0; SECRET_SHARE_SIZE])
}
}
impl From<[u8; SECRET_SHARE_SIZE]> for SecretShare {
fn from(bytes: [u8; SECRET_SHARE_SIZE]) -> Self {
SecretShare(bytes)
}
}
impl From<SecretShare> for [u8; SECRET_SHARE_SIZE] {
fn from(share: SecretShare) -> Self {
share.0
}
}
impl TryFrom<&[u8]> for SecretShare {
type Error = SecretShareError;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
if bytes.len() != SECRET_SHARE_SIZE {
return Err(anyhow::anyhow!(
"invalid share size, expected {}, got {}",
SECRET_SHARE_SIZE,
bytes.len()
)
.into());
}
let mut share = SecretShare::default();
share.0.copy_from_slice(bytes);
Ok(share)
}
}
impl SecretShare {
pub fn from_hex(hex: &str) -> Result<Self, SecretShareError> {
let hex = hex.strip_prefix("0x").unwrap_or(hex);
let mut buff = [0; SECRET_SHARE_SIZE];
hex::decode_to_slice(hex, &mut buff).map_err(|_| anyhow::anyhow!("hex decode error"))?;
Ok(SecretShare::from(buff))
}
#[allow(clippy::wrong_self_convention)]
pub fn to_hex(&self) -> String {
hex::encode(self.0)
}
pub fn new(secret: &Secret, recipient: &PublicKey) -> Result<Self, SecretShareError> {
let ephemeral_private = SecretKey::generate();
let ephemeral_public = ephemeral_private.public();
let ephemeral_x25519_private = ephemeral_private.to_x25519();
let recipient_x25519_public = recipient.to_x25519()?;
let shared_secret = ephemeral_x25519_private.diffie_hellman(&recipient_x25519_public);
let mut shared_secret_bytes = [0; SECRET_SIZE];
shared_secret_bytes.copy_from_slice(shared_secret.as_bytes());
let kek = Kek::from(shared_secret_bytes);
let wrapped = kek
.wrap_vec(secret.bytes())
.map_err(|_| anyhow::anyhow!("AES-KW wrap error"))?;
let mut share = SecretShare::default();
let ephemeral_bytes = ephemeral_public.to_bytes();
if ephemeral_bytes.len() + wrapped.len() != SECRET_SHARE_SIZE {
return Err(anyhow::anyhow!("expected share size is incorrect").into());
};
share.0[..PUBLIC_KEY_SIZE].copy_from_slice(&ephemeral_bytes);
share.0[PUBLIC_KEY_SIZE..PUBLIC_KEY_SIZE + wrapped.len()].copy_from_slice(&wrapped);
Ok(share)
}
pub fn recover(&self, recipient_secret: &SecretKey) -> Result<Secret, SecretShareError> {
let ephemeral_public_bytes = &self.0[..PUBLIC_KEY_SIZE];
let ephemeral_public = PublicKey::try_from(ephemeral_public_bytes)?;
let recipient_x25519_private = recipient_secret.to_x25519();
let ephemeral_x25519_public = ephemeral_public.to_x25519()?;
let shared_secret = recipient_x25519_private.diffie_hellman(&ephemeral_x25519_public);
let shared_secret_bytes = *shared_secret.as_bytes();
let kek = Kek::from(shared_secret_bytes);
let wrapped_data = &self.0[PUBLIC_KEY_SIZE..];
let unwrapped = kek
.unwrap_vec(wrapped_data)
.map_err(|_| anyhow::anyhow!("AES-KW unwrap error"))?;
if unwrapped.len() != SECRET_SIZE {
return Err(anyhow::anyhow!("unwrapped secret has wrong size").into());
}
let mut secret_bytes = [0; SECRET_SIZE];
secret_bytes.copy_from_slice(&unwrapped);
Ok(Secret::from(secret_bytes))
}
pub fn bytes(&self) -> &[u8] {
&self.0
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// Generates a fresh recipient key pair and wraps `secret` for it.
    fn share_for_new_recipient(secret: &Secret) -> (SecretKey, SecretShare) {
        let recipient_key = SecretKey::generate();
        let recipient_public = recipient_key.public();
        let share = SecretShare::new(secret, &recipient_public).unwrap();
        (recipient_key, share)
    }

    /// A share created for a key is recoverable with that key.
    #[test]
    fn test_share_secret() {
        let secret = Secret::from_slice(&[42u8; SECRET_SIZE]).unwrap();
        let (private_key, share) = share_for_new_recipient(&secret);

        let recovered_secret = share.recover(&private_key).unwrap();
        assert_eq!(secret, recovered_secret);
    }

    /// Only the intended recipient can recover the secret.
    #[test]
    fn test_share_different_keys() {
        let secret = Secret::generate();
        let (alice_private, share) = share_for_new_recipient(&secret);
        let bob_private = SecretKey::generate();

        let recovered_by_alice = share.recover(&alice_private).unwrap();
        assert_eq!(secret, recovered_by_alice);

        let result = share.recover(&bob_private);
        assert!(result.is_err());
    }

    /// Hex encoding followed by decoding yields an identical, usable share.
    #[test]
    fn test_share_hex_roundtrip() {
        let secret = Secret::generate();
        let (private_key, share) = share_for_new_recipient(&secret);

        let hex = share.to_hex();
        let recovered_share = SecretShare::from_hex(&hex).unwrap();
        assert_eq!(share, recovered_share);

        let recovered_secret = recovered_share.recover(&private_key).unwrap();
        assert_eq!(secret, recovered_secret);
    }

    /// JSON serialization round-trips the share.
    #[test]
    fn test_share_serde_json_roundtrip() {
        let secret = Secret::generate();
        let (private_key, share) = share_for_new_recipient(&secret);

        let json = serde_json::to_string(&share).unwrap();
        let recovered_share: SecretShare = serde_json::from_str(&json).unwrap();
        assert_eq!(share, recovered_share);

        let recovered_secret = recovered_share.recover(&private_key).unwrap();
        assert_eq!(secret, recovered_secret);
    }

    /// Bincode serialization round-trips the share.
    #[test]
    fn test_share_serde_bincode_roundtrip() {
        let secret = Secret::generate();
        let (private_key, share) = share_for_new_recipient(&secret);

        let binary = bincode::serialize(&share).unwrap();
        let recovered_share: SecretShare = bincode::deserialize(&binary).unwrap();
        assert_eq!(share, recovered_share);

        let recovered_secret = recovered_share.recover(&private_key).unwrap();
        assert_eq!(secret, recovered_secret);
    }

    /// Inputs one byte too short or too long are rejected.
    #[test]
    fn test_share_deserialize_invalid_length() {
        // Encodes `len` zero bytes with bincode and tries to read them as a share.
        let decode_zeros = |len: usize| {
            let encoded = bincode::serialize(&vec![0u8; len]).unwrap();
            bincode::deserialize::<SecretShare>(&encoded)
        };

        assert!(decode_zeros(SECRET_SHARE_SIZE - 1).is_err());
        assert!(decode_zeros(SECRET_SHARE_SIZE + 1).is_err());
    }

    /// An input of exactly the share size is accepted verbatim.
    #[test]
    fn test_share_deserialize_exact_size() {
        let serialized = bincode::serialize(&vec![0u8; SECRET_SHARE_SIZE]).unwrap();

        let result: Result<SecretShare, _> = bincode::deserialize(&serialized);
        assert!(result.is_ok());

        let share = result.unwrap();
        assert_eq!(share.0, [0u8; SECRET_SHARE_SIZE]);
    }

    /// JSON and bincode round trips agree with each other and with the original.
    #[test]
    fn test_share_serde_multiple_formats() {
        let secret = Secret::generate();
        let (private_key, original_share) = share_for_new_recipient(&secret);

        let json = serde_json::to_string(&original_share).unwrap();
        let json_share: SecretShare = serde_json::from_str(&json).unwrap();
        assert_eq!(original_share, json_share);

        let binary = bincode::serialize(&original_share).unwrap();
        let binary_share: SecretShare = bincode::deserialize(&binary).unwrap();
        assert_eq!(original_share, binary_share);
        assert_eq!(json_share, binary_share);

        let secret1 = json_share.recover(&private_key).unwrap();
        let secret2 = binary_share.recover(&private_key).unwrap();
        assert_eq!(secret, secret1);
        assert_eq!(secret, secret2);
        assert_eq!(secret1, secret2);
    }
}