use rand_core::CryptoRng;
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::ciphersuite::{CipherSuite, Kem};
/// Length in bytes of a serialized X-Wing encapsulation key, taken from the `x_wing` crate.
pub const ENCAPSULATION_KEY_SIZE: usize = x_wing::ENCAPSULATION_KEY_SIZE;
/// Length in bytes of an X-Wing ciphertext, taken from the `x_wing` crate.
pub const CIPHERTEXT_SIZE: usize = x_wing::CIPHERTEXT_SIZE;
/// Zero-sized marker type that implements the crate's `Kem` trait using X-Wing.
#[derive(Debug, Clone, Copy)]
pub struct XWingKem;
/// 32-byte shared secret produced by X-Wing encapsulation and decapsulation.
///
/// The bytes are wiped when the value is dropped (`ZeroizeOnDrop`). No `Debug` is derived,
/// so the secret cannot leak through debug formatting.
#[derive(Zeroize, ZeroizeOnDrop)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct XWingSharedSecret([u8; 32]);

impl AsRef<[u8]> for XWingSharedSecret {
    /// Borrows the raw shared-secret bytes.
    fn as_ref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
/// X-Wing decapsulation key, a thin wrapper around `x_wing::DecapsulationKey`.
pub struct XWingDecapsulationKey {
    inner: x_wing::DecapsulationKey,
}

impl Zeroize for XWingDecapsulationKey {
    /// Replaces the wrapped key with the key derived from the all-zero seed.
    ///
    /// The assignment drops the previous inner key. Whether that drop actually wipes its
    /// memory depends on `x_wing::DecapsulationKey`'s own `Drop` behaviour.
    /// NOTE(review): confirm `x_wing` is built with its zeroize support.
    /// No `Drop` impl for this wrapper is visible here, so this only runs when called
    /// explicitly.
    fn zeroize(&mut self) {
        let zero_seed = [0u8; 32];
        self.inner = <x_wing::DecapsulationKey as x_wing::KeyInit>::new(&zero_seed.into());
    }
}
/// Serializes the decapsulation key as the byte string returned by
/// `x_wing::DecapsulationKey::as_bytes`.
///
/// The matching `Deserialize` impl expects exactly 32 seed bytes and rebuilds the key with
/// `XWingDecapsulationKey::from_seed`. NOTE(review): this presumes `as_bytes` yields that
/// 32-byte seed; confirm against the `x_wing` version in use.
#[cfg(feature = "serde")]
impl serde::Serialize for XWingDecapsulationKey {
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let encoded = self.inner.as_bytes();
        serializer.serialize_bytes(encoded)
    }
}
/// Deserializes a decapsulation key from exactly 32 seed bytes via
/// `XWingDecapsulationKey::from_seed`.
///
/// Two input shapes are accepted:
/// - a native byte string, which binary formats deliver to `visit_bytes`;
/// - a sequence of `u8`. Human-readable formats such as JSON encode `serialize_bytes` as an
///   array and deliver it to `visit_seq`, so without this branch those formats could not
///   round-trip.
///
/// Temporary seed buffers are held in `zeroize::Zeroizing`, so they are wiped on every exit
/// path, including early error returns.
///
/// # Errors
///
/// Returns `invalid_length` when the input does not contain exactly 32 bytes. Other input
/// shapes get serde's default `invalid_type` error.
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for XWingDecapsulationKey {
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        /// Seed length accepted by `XWingDecapsulationKey::from_seed`.
        const SEED_SIZE: usize = 32;

        struct DkVisitor;

        impl<'de> serde::de::Visitor<'de> for DkVisitor {
            type Value = XWingDecapsulationKey;

            fn expecting(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                write!(f, "32 bytes for X-Wing decapsulation key seed")
            }

            fn visit_bytes<E: serde::de::Error>(self, v: &[u8]) -> Result<Self::Value, E> {
                if v.len() != SEED_SIZE {
                    return Err(E::invalid_length(v.len(), &self));
                }
                let mut seed = zeroize::Zeroizing::new([0u8; SEED_SIZE]);
                seed.copy_from_slice(v);
                Ok(XWingDecapsulationKey::from_seed(*seed))
            }

            fn visit_seq<A: serde::de::SeqAccess<'de>>(
                self,
                mut seq: A,
            ) -> Result<Self::Value, A::Error> {
                let mut seed = zeroize::Zeroizing::new([0u8; SEED_SIZE]);
                for (i, byte) in seed.iter_mut().enumerate() {
                    *byte = seq.next_element::<u8>()?.ok_or_else(|| {
                        <A::Error as serde::de::Error>::invalid_length(i, &self)
                    })?;
                }
                // Drain and count trailing elements so the error reports the real length.
                let mut extra = 0usize;
                while seq.next_element::<serde::de::IgnoredAny>()?.is_some() {
                    extra += 1;
                }
                if extra != 0 {
                    return Err(<A::Error as serde::de::Error>::invalid_length(
                        SEED_SIZE + extra,
                        &self,
                    ));
                }
                Ok(XWingDecapsulationKey::from_seed(*seed))
            }
        }

        deserializer.deserialize_bytes(DkVisitor)
    }
}
impl XWingDecapsulationKey {
    /// Builds a decapsulation key deterministically from a 32-byte seed.
    ///
    /// The same seed always yields the same key (and therefore the same encapsulation key).
    pub fn from_seed(seed: [u8; 32]) -> Self {
        let inner = <x_wing::DecapsulationKey as x_wing::KeyInit>::new(&seed.into());
        Self { inner }
    }

    /// Derives the matching `XWingEncapsulationKey` from this decapsulation key.
    ///
    /// # Panics
    ///
    /// Panics if the exported encapsulation key's length differs from
    /// `ENCAPSULATION_KEY_SIZE`. Both values come from the `x_wing` crate, so a mismatch
    /// indicates a bug rather than bad input.
    pub fn encapsulation_key(&self) -> XWingEncapsulationKey {
        use x_wing::{Decapsulator as _, KeyExport as _};
        let public = self.inner.encapsulation_key();
        let exported = public.to_bytes();
        let mut out = [0u8; ENCAPSULATION_KEY_SIZE];
        out.copy_from_slice(exported.as_slice());
        XWingEncapsulationKey(out)
    }
}
/// Generates a fixed-size byte-array newtype `$name([u8; $size])`.
///
/// The generated type gets:
/// - `Clone`, `Debug`, `PartialEq` and `Eq` derives;
/// - a length-checked `from_bytes` constructor and an `as_bytes` accessor;
/// - `AsRef<[u8]>` and `Zeroize`;
/// - behind the `serde` feature, `Serialize` as a byte string and `Deserialize` from either a
///   byte string or a sequence of `u8`. The sequence form is how human-readable formats such
///   as JSON hand back data written with `serialize_bytes`.
///
/// `label` is used only in the deserializer's "expecting" message.
macro_rules! byte_array_newtype {
    (
        $(#[$meta:meta])*
        $vis:vis struct $name:ident([u8; $size:expr]);
        label = $label:expr;
    ) => {
        $(#[$meta])*
        #[derive(Clone, Debug, PartialEq, Eq)]
        $vis struct $name([u8; $size]);

        #[cfg(feature = "serde")]
        impl serde::Serialize for $name {
            fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
                serializer.serialize_bytes(&self.0)
            }
        }

        #[cfg(feature = "serde")]
        impl<'de> serde::Deserialize<'de> for $name {
            fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
                struct Visitor;

                impl<'de> serde::de::Visitor<'de> for Visitor {
                    type Value = $name;

                    fn expecting(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                        write!(f, "{} bytes for {}", $size, $label)
                    }

                    fn visit_bytes<E: serde::de::Error>(self, v: &[u8]) -> Result<Self::Value, E> {
                        $name::from_bytes(v).ok_or_else(|| E::invalid_length(v.len(), &self))
                    }

                    // Human-readable formats (e.g. JSON) encode `serialize_bytes` as an array
                    // of numbers and hand it back through `visit_seq`, not `visit_bytes`.
                    fn visit_seq<A: serde::de::SeqAccess<'de>>(
                        self,
                        mut seq: A,
                    ) -> Result<Self::Value, A::Error> {
                        let mut arr = [0u8; $size];
                        for (i, byte) in arr.iter_mut().enumerate() {
                            *byte = seq.next_element::<u8>()?.ok_or_else(|| {
                                <A::Error as serde::de::Error>::invalid_length(i, &self)
                            })?;
                        }
                        // Count trailing elements so the error reports the real length.
                        let mut extra = 0usize;
                        while seq.next_element::<serde::de::IgnoredAny>()?.is_some() {
                            extra += 1;
                        }
                        if extra != 0 {
                            return Err(<A::Error as serde::de::Error>::invalid_length(
                                $size + extra,
                                &self,
                            ));
                        }
                        Ok($name(arr))
                    }
                }

                deserializer.deserialize_bytes(Visitor)
            }
        }

        impl $name {
            /// Copies `bytes` into a new value.
            ///
            /// Returns `None` unless `bytes` has exactly the type's fixed length.
            pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
                if bytes.len() != $size {
                    return None;
                }
                let mut arr = [0u8; $size];
                arr.copy_from_slice(bytes);
                Some(Self(arr))
            }

            /// Borrows the underlying fixed-size byte array.
            pub fn as_bytes(&self) -> &[u8; $size] {
                &self.0
            }
        }

        impl AsRef<[u8]> for $name {
            fn as_ref(&self) -> &[u8] {
                &self.0
            }
        }

        impl Zeroize for $name {
            /// Overwrites the stored bytes with zeros.
            fn zeroize(&mut self) {
                self.0.zeroize();
            }
        }
    };
}
byte_array_newtype! {
    /// Serialized X-Wing encapsulation key (`ENCAPSULATION_KEY_SIZE` bytes), as exported by
    /// `x_wing` and consumed by `XWingKem::encaps`.
    pub struct XWingEncapsulationKey([u8; ENCAPSULATION_KEY_SIZE]);
    label = "X-Wing encapsulation key";
}
byte_array_newtype! {
    /// Serialized X-Wing ciphertext (`CIPHERTEXT_SIZE` bytes), produced by `XWingKem::encaps`
    /// and consumed by `XWingKem::decaps`.
    pub struct XWingCiphertext([u8; CIPHERTEXT_SIZE]);
    label = "X-Wing ciphertext";
}
/// Error returned by `XWingKem` operations.
///
/// Within this module it is produced only when `XWingKem::encaps` cannot convert the
/// encapsulation-key bytes into an `x_wing` encapsulation key. It carries no further detail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct XWingKemError;

impl core::fmt::Display for XWingKemError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("X-Wing KEM operation failed")
    }
}

// Implementing the standard error trait lets callers box this error or propagate it with `?`
// into generic error types. This is `core::error::Error`, so it also works without `std`.
impl core::error::Error for XWingKemError {}
impl Kem for XWingKem {
    type EncapsulationKey = XWingEncapsulationKey;
    type DecapsulationKey = XWingDecapsulationKey;
    type Ciphertext = XWingCiphertext;
    type SharedSecret = XWingSharedSecret;
    type Error = XWingKemError;

    /// Generates a fresh key pair from `rng` and returns `(decapsulation_key, encapsulation_key)`.
    ///
    /// The encapsulation key is exported through `XWingDecapsulationKey::encapsulation_key`,
    /// so there is a single conversion path from `x_wing` keys to our byte representation.
    fn generate(rng: &mut impl CryptoRng) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
        let dk = XWingDecapsulationKey {
            inner: <x_wing::DecapsulationKey as x_wing::Generate>::generate_from_rng(rng),
        };
        let ek = dk.encapsulation_key();
        (dk, ek)
    }

    /// Encapsulates a fresh shared secret to `ek` using randomness from `rng`.
    ///
    /// # Errors
    ///
    /// Returns `XWingKemError` if `x_wing` rejects the encapsulation-key bytes.
    fn encaps(
        ek: &Self::EncapsulationKey,
        rng: &mut impl CryptoRng,
    ) -> Result<(Self::Ciphertext, Self::SharedSecret), Self::Error> {
        use x_wing::Encapsulate as _;
        let recipient =
            x_wing::EncapsulationKey::try_from(&ek.0[..]).map_err(|_| XWingKemError)?;
        let (wire_ct, raw_secret) = recipient.encapsulate_with_rng(rng);

        let mut ciphertext = [0u8; CIPHERTEXT_SIZE];
        ciphertext.copy_from_slice(wire_ct.as_slice());
        let mut secret = [0u8; 32];
        secret.copy_from_slice(raw_secret.as_slice());

        Ok((XWingCiphertext(ciphertext), XWingSharedSecret(secret)))
    }

    /// Recovers the shared secret carried by `ct` using `dk`.
    ///
    /// The underlying `x_wing` decapsulation is infallible here, so this always returns `Ok`.
    fn decaps(
        dk: &Self::DecapsulationKey,
        ct: &Self::Ciphertext,
    ) -> Result<Self::SharedSecret, Self::Error> {
        use x_wing::Decapsulate as _;
        let mut wire_ct = x_wing::Ciphertext::default();
        wire_ct.copy_from_slice(&ct.0);

        let raw_secret = dk.inner.decapsulate(&wire_ct);
        let mut secret = [0u8; 32];
        secret.copy_from_slice(raw_secret.as_slice());

        Ok(XWingSharedSecret(secret))
    }
}
/// Cipher suite that pairs the X-Wing KEM (`XWingKem`) with SHA3-256 as its hash.
#[derive(Debug, Clone, Copy)]
pub struct XWingSha3;
impl CipherSuite for XWingSha3 {
    type Kem = XWingKem;
    type Hash = sha3::Sha3_256;
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand_core::UnwrapErr;

    /// OS-backed RNG that panics instead of returning errors (acceptable in tests).
    fn test_rng() -> UnwrapErr<getrandom::SysRng> {
        UnwrapErr(getrandom::SysRng)
    }

    #[test]
    fn test_kem_roundtrip() {
        let mut rng = test_rng();
        let (dk, ek) = XWingKem::generate(&mut rng);
        let (ct, ss1) = XWingKem::encaps(&ek, &mut rng).unwrap();
        let ss2 = XWingKem::decaps(&dk, &ct).unwrap();
        assert_eq!(ss1.as_ref(), ss2.as_ref());
    }

    #[test]
    fn test_generated_ek_matches_dk() {
        let mut rng = test_rng();
        let (dk, ek) = XWingKem::generate(&mut rng);
        assert_eq!(dk.encapsulation_key().as_ref(), ek.as_ref());
    }

    #[test]
    fn test_key_serialization_roundtrip() {
        let mut rng = test_rng();
        let (dk, ek) = XWingKem::generate(&mut rng);
        let ek_bytes = ek.as_bytes();
        let ek2 = XWingEncapsulationKey::from_bytes(ek_bytes).unwrap();
        assert_eq!(ek.as_ref(), ek2.as_ref());
        let (ct, ss1) = XWingKem::encaps(&ek2, &mut rng).unwrap();
        let ss2 = XWingKem::decaps(&dk, &ct).unwrap();
        assert_eq!(ss1.as_ref(), ss2.as_ref());
    }

    #[test]
    fn test_ciphertext_serialization_roundtrip() {
        let mut rng = test_rng();
        let (dk, ek) = XWingKem::generate(&mut rng);
        let (ct, ss1) = XWingKem::encaps(&ek, &mut rng).unwrap();
        let ct_bytes = ct.as_bytes();
        let ct2 = XWingCiphertext::from_bytes(ct_bytes).unwrap();
        let ss2 = XWingKem::decaps(&dk, &ct2).unwrap();
        assert_eq!(ss1.as_ref(), ss2.as_ref());
    }

    #[test]
    fn test_tampered_ciphertext_changes_shared_secret() {
        let mut rng = test_rng();
        let (dk, ek) = XWingKem::generate(&mut rng);
        let (ct, ss1) = XWingKem::encaps(&ek, &mut rng).unwrap();
        let mut tampered = *ct.as_bytes();
        tampered[CIPHERTEXT_SIZE - 1] ^= 0x01;
        let ct_bad = XWingCiphertext::from_bytes(&tampered).unwrap();
        let ss2 = XWingKem::decaps(&dk, &ct_bad).unwrap();
        assert_ne!(ss1.as_ref(), ss2.as_ref());
    }

    #[test]
    fn test_wrong_length_rejected() {
        assert!(XWingEncapsulationKey::from_bytes(&[0u8; 32]).is_none());
        assert!(XWingCiphertext::from_bytes(&[0u8; 32]).is_none());
        assert!(XWingEncapsulationKey::from_bytes(&[]).is_none());
        assert!(XWingCiphertext::from_bytes(&[]).is_none());
        assert!(XWingEncapsulationKey::from_bytes(&[0u8; ENCAPSULATION_KEY_SIZE - 1]).is_none());
        assert!(XWingEncapsulationKey::from_bytes(&[0u8; ENCAPSULATION_KEY_SIZE + 1]).is_none());
        assert!(XWingCiphertext::from_bytes(&[0u8; CIPHERTEXT_SIZE - 1]).is_none());
        assert!(XWingCiphertext::from_bytes(&[0u8; CIPHERTEXT_SIZE + 1]).is_none());
    }

    #[test]
    fn test_dk_from_seed_deterministic() {
        let seed = [42u8; 32];
        let dk1 = XWingDecapsulationKey::from_seed(seed);
        let dk2 = XWingDecapsulationKey::from_seed(seed);
        assert_eq!(
            dk1.encapsulation_key().as_ref(),
            dk2.encapsulation_key().as_ref()
        );
    }

    #[test]
    fn test_dk_from_different_seeds_differ() {
        let a = XWingDecapsulationKey::from_seed([1u8; 32]);
        let b = XWingDecapsulationKey::from_seed([2u8; 32]);
        assert_ne!(a.encapsulation_key().as_ref(), b.encapsulation_key().as_ref());
    }

    #[test]
    fn test_dk_from_seed_roundtrip() {
        let mut rng = test_rng();
        let dk = XWingDecapsulationKey::from_seed([7u8; 32]);
        let ek = dk.encapsulation_key();
        let (ct, ss1) = XWingKem::encaps(&ek, &mut rng).unwrap();
        let ss2 = XWingKem::decaps(&dk, &ct).unwrap();
        assert_eq!(ss1.as_ref(), ss2.as_ref());
    }

    #[test]
    fn test_dk_zeroize_resets_to_zero_seed() {
        let mut dk = XWingDecapsulationKey::from_seed([9u8; 32]);
        dk.zeroize();
        let zero = XWingDecapsulationKey::from_seed([0u8; 32]);
        assert_eq!(
            dk.encapsulation_key().as_ref(),
            zero.encapsulation_key().as_ref()
        );
    }
}