use alloc::{string::ToString, vec::Vec};
use core::ops::Range;
use miden_crypto_derive::{SilentDebug, SilentDisplay};
use num::Integer;
use rand::{
Rng,
distr::{Distribution, StandardUniform, Uniform},
};
use subtle::ConstantTimeEq;
use crate::{
Felt, ONE, Word, ZERO,
aead::{AeadScheme, DataType, EncryptionError},
hash::poseidon2::Poseidon2,
utils::{
ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
bytes_to_elements_exact, bytes_to_elements_with_padding, elements_to_bytes,
padded_elements_to_bytes,
zeroize::{Zeroize, ZeroizeOnDrop},
},
};
#[cfg(test)]
mod test;
/// Number of field elements in a [`SecretKey`].
pub const SECRET_KEY_SIZE: usize = 4;
/// Number of bytes read when deserializing a [`SecretKey`].
pub const SK_SIZE_BYTES: usize = SECRET_KEY_SIZE * Felt::NUM_BYTES;
/// Number of field elements in a [`Nonce`].
pub const NONCE_SIZE: usize = 4;
/// Number of bytes read when deserializing a [`Nonce`].
pub const NONCE_SIZE_BYTES: usize = NONCE_SIZE * Felt::NUM_BYTES;
/// Number of field elements in an [`AuthTag`].
pub const AUTH_TAG_SIZE: usize = 4;

// Sponge geometry, taken from the Poseidon2 permutation parameters.
const STATE_WIDTH: usize = Poseidon2::STATE_WIDTH;
const CAPACITY_RANGE: Range<usize> = Poseidon2::CAPACITY_RANGE;
const RATE_RANGE: Range<usize> = Poseidon2::RATE_RANGE;
const RATE_START: usize = RATE_RANGE.start;
const RATE_WIDTH: usize = RATE_RANGE.end - RATE_RANGE.start;
const HALF_RATE_WIDTH: usize = RATE_WIDTH / 2;
/// First half of the rate: receives the secret key on initialization and supplies the tag.
const RATE_RANGE_FIRST_HALF: Range<usize> = RATE_START..RATE_START + HALF_RATE_WIDTH;
/// Second half of the rate: receives the nonce on initialization.
const RATE_RANGE_SECOND_HALF: Range<usize> = RATE_START + HALF_RATE_WIDTH..RATE_RANGE.end;

/// Source of padding elements: `pad` appends a prefix of this block (a `ONE` marker followed by
/// `ZERO`s) to complete the final rate block.
const PADDING_BLOCK: [Felt; RATE_WIDTH] = [ONE, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO];

// `SpongeState::new` copies the key and nonce into the two rate halves with `copy_from_slice`,
// and `SpongeState::squeeze_tag` converts the first rate half into an `AuthTag` with
// `try_into().expect(..)`; both panic at runtime on a size mismatch. Enforce the size
// relationships at compile time instead.
const _: () = {
    assert!(SECRET_KEY_SIZE == HALF_RATE_WIDTH, "secret key must fill exactly half of the rate");
    assert!(NONCE_SIZE == HALF_RATE_WIDTH, "nonce must fill exactly half of the rate");
    assert!(AUTH_TAG_SIZE == HALF_RATE_WIDTH, "auth tag must fill exactly half of the rate");
};
/// Result of an encryption: the ciphertext together with everything needed to authenticate and
/// decrypt it (the nonce and the authentication tag), plus the kind of plaintext it came from.
#[derive(Debug, PartialEq, Eq)]
pub struct EncryptedData {
    /// Whether the plaintext was field elements or bytes; the decryption methods reject data of
    /// the other kind with `EncryptionError::InvalidDataType`.
    data_type: DataType,
    /// Ciphertext elements. When produced by this module it covers the padded plaintext, so its
    /// length is a multiple of the sponge rate width.
    ciphertext: Vec<Felt>,
    /// Tag squeezed from the sponge after the ciphertext has been processed.
    auth_tag: AuthTag,
    /// Nonce that was loaded into the sponge together with the secret key.
    nonce: Nonce,
}
impl EncryptedData {
pub fn from_parts(
data_type: DataType,
ciphertext: Vec<Felt>,
auth_tag: AuthTag,
nonce: Nonce,
) -> Self {
Self { data_type, ciphertext, auth_tag, nonce }
}
pub fn data_type(&self) -> DataType {
self.data_type
}
pub fn ciphertext(&self) -> &[Felt] {
&self.ciphertext
}
pub fn auth_tag(&self) -> &AuthTag {
&self.auth_tag
}
pub fn nonce(&self) -> &Nonce {
&self.nonce
}
}
/// Authentication tag of `AUTH_TAG_SIZE` field elements, produced during encryption and
/// recomputed and compared during decryption.
///
/// `PartialEq` is implemented by hand (not derived) so that tags are compared element-wise via
/// `subtle::ConstantTimeEq`.
#[derive(Debug, Default, Clone)]
pub struct AuthTag([Felt; AUTH_TAG_SIZE]);
impl AuthTag {
pub fn new(elements: [Felt; AUTH_TAG_SIZE]) -> Self {
Self(elements)
}
pub fn to_elements(&self) -> [Felt; AUTH_TAG_SIZE] {
self.0
}
}
impl PartialEq for AuthTag {
fn eq(&self, other: &Self) -> bool {
let mut result = true;
for (a, b) in self.0.iter().zip(other.0.iter()) {
result &= bool::from(a.as_canonical_u64_ct().ct_eq(&b.as_canonical_u64_ct()));
}
result
}
}
impl Eq for AuthTag {}
/// Secret key for the Poseidon2-based AEAD scheme: `SECRET_KEY_SIZE` field elements.
///
/// The key material is overwritten with zeros when the value is dropped (see the `Drop` and
/// `Zeroize` impls). `SilentDebug`/`SilentDisplay` presumably keep the key out of formatted
/// output, judging by their names — NOTE(review): confirm against `miden_crypto_derive`.
#[derive(Clone, SilentDebug, SilentDisplay)]
pub struct SecretKey([Felt; SECRET_KEY_SIZE]);
impl SecretKey {
#[cfg(feature = "std")]
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
let mut rng = rand::rng();
Self::with_rng(&mut rng)
}
pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
rng.sample(StandardUniform)
}
pub fn from_elements(elements: [Felt; SECRET_KEY_SIZE]) -> Self {
Self(elements)
}
pub fn to_elements(&self) -> [Felt; SECRET_KEY_SIZE] {
self.0
}
#[cfg(feature = "std")]
pub fn encrypt_elements(&self, data: &[Felt]) -> Result<EncryptedData, EncryptionError> {
self.encrypt_elements_with_associated_data(data, &[])
}
#[cfg(feature = "std")]
pub fn encrypt_elements_with_associated_data(
&self,
data: &[Felt],
associated_data: &[Felt],
) -> Result<EncryptedData, EncryptionError> {
let mut rng = rand::rng();
let nonce = Nonce::with_rng(&mut rng);
self.encrypt_elements_with_nonce(data, associated_data, nonce)
}
pub fn encrypt_elements_with_nonce(
&self,
data: &[Felt],
associated_data: &[Felt],
nonce: Nonce,
) -> Result<EncryptedData, EncryptionError> {
let mut sponge = SpongeState::new(self, &nonce);
let padded_associated_data = pad(associated_data);
padded_associated_data.chunks(RATE_WIDTH).for_each(|chunk| {
sponge.duplex_overwrite(chunk);
});
let mut ciphertext = Vec::with_capacity(data.len() + RATE_WIDTH);
let data = pad(data);
let mut data_block_iterator = data.chunks_exact(RATE_WIDTH);
data_block_iterator.by_ref().for_each(|data_block| {
let keystream = sponge.duplex_add(data_block);
for (i, &plaintext_felt) in data_block.iter().enumerate() {
ciphertext.push(plaintext_felt + keystream[i]);
}
});
let auth_tag = sponge.squeeze_tag();
Ok(EncryptedData {
data_type: DataType::Elements,
ciphertext,
auth_tag,
nonce,
})
}
#[cfg(feature = "std")]
pub fn encrypt_bytes(&self, data: &[u8]) -> Result<EncryptedData, EncryptionError> {
self.encrypt_bytes_with_associated_data(data, &[])
}
#[cfg(feature = "std")]
pub fn encrypt_bytes_with_associated_data(
&self,
data: &[u8],
associated_data: &[u8],
) -> Result<EncryptedData, EncryptionError> {
let mut rng = rand::rng();
let nonce = Nonce::with_rng(&mut rng);
self.encrypt_bytes_with_nonce(data, associated_data, nonce)
}
pub fn encrypt_bytes_with_nonce(
&self,
data: &[u8],
associated_data: &[u8],
nonce: Nonce,
) -> Result<EncryptedData, EncryptionError> {
let data_felt = bytes_to_elements_with_padding(data);
let ad_felt = bytes_to_elements_with_padding(associated_data);
let mut encrypted_data = self.encrypt_elements_with_nonce(&data_felt, &ad_felt, nonce)?;
encrypted_data.data_type = DataType::Bytes;
Ok(encrypted_data)
}
pub fn decrypt_elements(
&self,
encrypted_data: &EncryptedData,
) -> Result<Vec<Felt>, EncryptionError> {
self.decrypt_elements_with_associated_data(encrypted_data, &[])
}
pub fn decrypt_elements_with_associated_data(
&self,
encrypted_data: &EncryptedData,
associated_data: &[Felt],
) -> Result<Vec<Felt>, EncryptionError> {
if encrypted_data.data_type != DataType::Elements {
return Err(EncryptionError::InvalidDataType {
expected: DataType::Elements,
found: encrypted_data.data_type,
});
}
self.decrypt_elements_with_associated_data_unchecked(encrypted_data, associated_data)
}
fn decrypt_elements_with_associated_data_unchecked(
&self,
encrypted_data: &EncryptedData,
associated_data: &[Felt],
) -> Result<Vec<Felt>, EncryptionError> {
if !encrypted_data.ciphertext.len().is_multiple_of(RATE_WIDTH) {
return Err(EncryptionError::CiphertextLenNotMultipleRate);
}
let mut sponge = SpongeState::new(self, &encrypted_data.nonce);
let padded_associated_data = pad(associated_data);
padded_associated_data.chunks(RATE_WIDTH).for_each(|chunk| {
sponge.duplex_overwrite(chunk);
});
let mut plaintext = Vec::with_capacity(encrypted_data.ciphertext.len());
let mut ciphertext_block_iterator = encrypted_data.ciphertext.chunks_exact(RATE_WIDTH);
ciphertext_block_iterator.by_ref().for_each(|ciphertext_data_block| {
let keystream = sponge.duplex_add(&[]);
for (i, &ciphertext_felt) in ciphertext_data_block.iter().enumerate() {
let plaintext_felt = ciphertext_felt - keystream[i];
plaintext.push(plaintext_felt);
}
sponge.state[RATE_RANGE].copy_from_slice(ciphertext_data_block);
});
let computed_tag = sponge.squeeze_tag();
if computed_tag != encrypted_data.auth_tag {
return Err(EncryptionError::InvalidAuthTag);
}
unpad(plaintext)
}
pub fn decrypt_bytes(
&self,
encrypted_data: &EncryptedData,
) -> Result<Vec<u8>, EncryptionError> {
self.decrypt_bytes_with_associated_data(encrypted_data, &[])
}
pub fn decrypt_bytes_with_associated_data(
&self,
encrypted_data: &EncryptedData,
associated_data: &[u8],
) -> Result<Vec<u8>, EncryptionError> {
if encrypted_data.data_type != DataType::Bytes {
return Err(EncryptionError::InvalidDataType {
expected: DataType::Bytes,
found: encrypted_data.data_type,
});
}
let ad_felt = bytes_to_elements_with_padding(associated_data);
let data_felts =
self.decrypt_elements_with_associated_data_unchecked(encrypted_data, &ad_felt)?;
match padded_elements_to_bytes(&data_felts) {
Some(bytes) => Ok(bytes),
None => Err(EncryptionError::MalformedPadding),
}
}
}
impl Distribution<SecretKey> for StandardUniform {
    /// Samples every key element independently and uniformly from `[0, Felt::ORDER)`.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> SecretKey {
        let field =
            Uniform::new(0, Felt::ORDER).expect("should not fail given the size of the field");
        // `from_fn` fills indices in ascending order, so the RNG is consumed element by element.
        SecretKey(core::array::from_fn(|_| Felt::new_unchecked(field.sample(rng))))
    }
}
#[cfg(any(test, feature = "testing"))]
impl PartialEq for SecretKey {
    /// Compares two keys element-wise in constant time.
    fn eq(&self, other: &Self) -> bool {
        // Keep the running result as a `subtle::Choice` and convert to `bool` only once, so no
        // intermediate plain boolean is available for the optimizer to branch on.
        let all_equal = self.0.iter().zip(other.0.iter()).fold(
            subtle::Choice::from(1u8),
            |acc, (a, b)| acc & a.as_canonical_u64_ct().ct_eq(&b.as_canonical_u64_ct()),
        );
        bool::from(all_equal)
    }
}
#[cfg(any(test, feature = "testing"))]
impl Eq for SecretKey {}
impl Zeroize for SecretKey {
    /// Overwrites every key element with `ZERO`.
    fn zeroize(&mut self) {
        for element in self.0.iter_mut() {
            // SAFETY: `element` is a `&mut Felt` obtained from `iter_mut`, so it is non-null,
            // aligned and valid for writes. `write_volatile` does not drop the old value, which
            // is fine because `Felt` is `Copy` (`to_elements` returns `self.0` by value). The
            // volatile write keeps the compiler from eliding the store as a dead write.
            unsafe {
                core::ptr::write_volatile(element, ZERO);
            }
        }
        // Prevent the compiler from reordering later memory operations ahead of the zeroing
        // writes above.
        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
    }
}
impl Drop for SecretKey {
fn drop(&mut self) {
self.zeroize();
}
}
impl ZeroizeOnDrop for SecretKey {}
/// Keyed duplex sponge over the Poseidon2 permutation, used for both encryption and tag
/// computation.
struct SpongeState {
    /// Full permutation state (capacity and rate elements).
    state: [Felt; STATE_WIDTH],
}
impl SpongeState {
    /// Creates a sponge with the secret key in the first half of the rate, the nonce in the
    /// second half, and every other element (including the capacity) set to `ZERO`.
    ///
    /// No permutation is applied here; the first duplex call permutes before using the state.
    fn new(sk: &SecretKey, nonce: &Nonce) -> Self {
        let mut state = [ZERO; STATE_WIDTH];
        state[RATE_RANGE_FIRST_HALF].copy_from_slice(&sk.0);
        state[RATE_RANGE_SECOND_HALF].copy_from_slice(&nonce.0);
        Self { state }
    }
    /// Permutes, increments the first capacity element, then overwrites the whole rate with
    /// `data` (used to absorb associated data).
    ///
    /// `data` must contain exactly `RATE_WIDTH` elements; `copy_from_slice` panics otherwise.
    fn duplex_overwrite(&mut self, data: &[Felt]) {
        self.permute();
        // NOTE(review): presumably a domain-separation counter distinguishing associated-data
        // blocks from message blocks — confirm the intended design.
        self.state[CAPACITY_RANGE.start] += ONE;
        self.state[RATE_RANGE].copy_from_slice(data);
    }
    /// Permutes, returns the rate as keystream, then adds `data` element-wise into the rate.
    ///
    /// `data` may be shorter than `RATE_WIDTH` (decryption passes an empty slice and writes the
    /// rate itself afterwards); only the first `data.len()` rate elements are updated.
    fn duplex_add(&mut self, data: &[Felt]) -> [Felt; RATE_WIDTH] {
        self.permute();
        // The keystream is captured before `data` is mixed into the rate.
        let squeezed_data = self.squeeze_rate();
        for (idx, &element) in data.iter().enumerate() {
            self.state[RATE_START + idx] += element;
        }
        squeezed_data
    }
    /// Permutes and returns the first half of the rate as the authentication tag.
    fn squeeze_tag(&mut self) -> AuthTag {
        self.permute();
        AuthTag(
            self.state[RATE_RANGE_FIRST_HALF]
                .try_into()
                .expect("rate first half is exactly AUTH_TAG_SIZE elements"),
        )
    }
    /// Applies the Poseidon2 permutation to the full state in place.
    fn permute(&mut self) {
        Poseidon2::apply_permutation(&mut self.state);
    }
    /// Returns a copy of the current rate elements without permuting.
    fn squeeze_rate(&self) -> [Felt; RATE_WIDTH] {
        self.state[RATE_RANGE]
            .try_into()
            .expect("rate range is exactly RATE_WIDTH elements")
    }
}
/// Per-message nonce of `NONCE_SIZE` field elements, loaded into the sponge together with the
/// secret key. The sponge's initial state depends only on the key and the nonce, so a nonce
/// should not be reused with the same key.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Nonce([Felt; NONCE_SIZE]);
impl Nonce {
    /// Samples a nonce whose elements are drawn uniformly from the field using `rng`.
    pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
        StandardUniform.sample(rng)
    }
}
impl From<Word> for Nonce {
fn from(word: Word) -> Self {
Nonce(word.into())
}
}
impl From<[Felt; NONCE_SIZE]> for Nonce {
fn from(elements: [Felt; NONCE_SIZE]) -> Self {
Nonce(elements)
}
}
impl From<Nonce> for Word {
fn from(nonce: Nonce) -> Self {
nonce.0.into()
}
}
impl From<Nonce> for [Felt; NONCE_SIZE] {
fn from(nonce: Nonce) -> Self {
nonce.0
}
}
impl Distribution<Nonce> for StandardUniform {
    /// Samples every nonce element independently and uniformly from `[0, Felt::ORDER)`.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Nonce {
        let uniform_u64 =
            Uniform::new(0, Felt::ORDER).expect("should not fail given the size of the field");
        let mut elements = [ZERO; NONCE_SIZE];
        elements
            .iter_mut()
            .for_each(|slot| *slot = Felt::new_unchecked(uniform_u64.sample(rng)));
        Nonce(elements)
    }
}
impl Serializable for SecretKey {
    /// Writes the byte encoding of the key elements produced by `elements_to_bytes`.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_bytes(&elements_to_bytes(&self.0));
    }
}
impl Deserializable for SecretKey {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let bytes: [u8; SK_SIZE_BYTES] = source.read_array()?;
match bytes_to_elements_exact(&bytes) {
Some(inner) => {
let inner: [Felt; 4] = inner.try_into().map_err(|_| {
DeserializationError::InvalidValue("malformed secret key".to_string())
})?;
Ok(Self(inner))
},
None => Err(DeserializationError::InvalidValue("malformed secret key".to_string())),
}
}
}
impl Serializable for Nonce {
    /// Writes the byte encoding of the nonce elements produced by `elements_to_bytes`.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let encoded = elements_to_bytes(&self.0);
        target.write_bytes(encoded.as_slice());
    }
}
impl Deserializable for Nonce {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let bytes: [u8; NONCE_SIZE_BYTES] = source.read_array()?;
match bytes_to_elements_exact(&bytes) {
Some(inner) => {
let inner: [Felt; 4] = inner.try_into().map_err(|_| {
DeserializationError::InvalidValue("malformed nonce".to_string())
})?;
Ok(Self(inner))
},
None => Err(DeserializationError::InvalidValue("malformed nonce".to_string())),
}
}
}
/// Wire format, in order:
/// 1. the data type as one `u8`,
/// 2. the ciphertext using `Vec<Felt>`'s own serialization,
/// 3. the `NONCE_SIZE` nonce elements,
/// 4. the `AUTH_TAG_SIZE` tag elements.
///
/// `Deserializable::read_from` below must read the fields in exactly this order.
impl Serializable for EncryptedData {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_u8(self.data_type as u8);
        self.ciphertext.write_into(target);
        target.write_many(self.nonce.0);
        target.write_many(self.auth_tag.0);
    }
}
impl Deserializable for EncryptedData {
    /// Reads the fields in the order written by `write_into`.
    ///
    /// Returns `DeserializationError::InvalidValue` if the data-type byte is not accepted by
    /// `DataType`'s `TryFrom<u8>` conversion, and otherwise propagates reader errors. No
    /// consistency checks (e.g. ciphertext length) are done here; decryption performs them.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let data_type_value: u8 = source.read_u8()?;
        let data_type = data_type_value.try_into().map_err(|_| {
            DeserializationError::InvalidValue("invalid data type value".to_string())
        })?;
        let ciphertext = Vec::<Felt>::read_from(source)?;
        let nonce: [Felt; NONCE_SIZE] = source.read()?;
        let auth_tag: [Felt; AUTH_TAG_SIZE] = source.read()?;
        Ok(Self {
            ciphertext,
            nonce: Nonce(nonce),
            auth_tag: AuthTag(auth_tag),
            data_type,
        })
    }
}
/// Pads `data` to a non-empty multiple of `RATE_WIDTH` elements.
///
/// A `ONE` marker is appended, followed by as many `ZERO`s as needed to complete the block. If
/// `data.len()` is already a multiple of `RATE_WIDTH`, a whole extra block `[ONE, ZERO, ..]` is
/// added, so the marker is always present for `unpad` to find.
fn pad(data: &[Felt]) -> Vec<Felt> {
    let pad_len = RATE_WIDTH - data.len() % RATE_WIDTH;
    let mut padded = Vec::with_capacity(data.len() + pad_len);
    padded.extend_from_slice(data);
    padded.extend_from_slice(&PADDING_BLOCK[..pad_len]);
    padded
}
fn unpad(mut plaintext: Vec<Felt>) -> Result<Vec<Felt>, EncryptionError> {
let (num_blocks, remainder) = plaintext.len().div_rem(&RATE_WIDTH);
assert_eq!(remainder, 0);
let final_block: &[Felt; RATE_WIDTH] =
plaintext.last_chunk().ok_or(EncryptionError::MalformedPadding)?;
let pos = match final_block.iter().rposition(|entry| *entry == ONE) {
Some(pos) => pos,
None => return Err(EncryptionError::MalformedPadding),
};
plaintext.truncate((num_blocks - 1) * RATE_WIDTH + pos);
Ok(plaintext)
}
pub struct AeadPoseidon2;
impl AeadScheme for AeadPoseidon2 {
const KEY_SIZE: usize = SK_SIZE_BYTES;
type Key = SecretKey;
fn key_from_bytes(bytes: &[u8]) -> Result<Self::Key, EncryptionError> {
if bytes.len() != SK_SIZE_BYTES {
return Err(EncryptionError::FailedOperation);
}
SecretKey::read_from_bytes_with_budget(bytes, SK_SIZE_BYTES)
.map_err(|_| EncryptionError::FailedOperation)
}
fn encrypt_bytes<R: rand::CryptoRng + rand::RngCore>(
key: &Self::Key,
rng: &mut R,
plaintext: &[u8],
associated_data: &[u8],
) -> Result<Vec<u8>, EncryptionError> {
let nonce = Nonce::with_rng(rng);
let encrypted_data = key
.encrypt_bytes_with_nonce(plaintext, associated_data, nonce)
.map_err(|_| EncryptionError::FailedOperation)?;
Ok(encrypted_data.to_bytes())
}
fn decrypt_bytes_with_associated_data(
key: &Self::Key,
ciphertext: &[u8],
associated_data: &[u8],
) -> Result<Vec<u8>, EncryptionError> {
let encrypted_data =
EncryptedData::read_from_bytes_with_budget(ciphertext, ciphertext.len())
.map_err(|_| EncryptionError::FailedOperation)?;
key.decrypt_bytes_with_associated_data(&encrypted_data, associated_data)
}
fn encrypt_elements<R: rand::CryptoRng + rand::RngCore>(
key: &Self::Key,
rng: &mut R,
plaintext: &[Felt],
associated_data: &[Felt],
) -> Result<Vec<u8>, EncryptionError> {
let nonce = Nonce::with_rng(rng);
let encrypted_data = key
.encrypt_elements_with_nonce(plaintext, associated_data, nonce)
.map_err(|_| EncryptionError::FailedOperation)?;
Ok(encrypted_data.to_bytes())
}
fn decrypt_elements_with_associated_data(
key: &Self::Key,
ciphertext: &[u8],
associated_data: &[Felt],
) -> Result<Vec<Felt>, EncryptionError> {
let encrypted_data =
EncryptedData::read_from_bytes_with_budget(ciphertext, ciphertext.len())
.map_err(|_| EncryptionError::FailedOperation)?;
key.decrypt_elements_with_associated_data(&encrypted_data, associated_data)
}
}