use alloc::{vec, vec::Vec};
use miden_core::events::EventName;
use miden_crypto::aead::{
DataType, EncryptionError,
aead_poseidon2::{AuthTag, EncryptedData, Nonce, SecretKey},
};
use miden_processor::{ProcessorState, advice::AdviceMutation, event::EventError};
use crate::handlers::read_memory_region;
/// Event name `miden::core::crypto::aead::decrypt`, identifying the AEAD decrypt event.
///
/// NOTE(review): presumably the name under which [`handle_aead_decrypt`] is registered; the
/// registration is not visible in this file - confirm against the handler registry.
pub const AEAD_DECRYPT_EVENT_NAME: EventName = EventName::new("miden::core::crypto::aead::decrypt");
/// Handles the AEAD decrypt event: decrypts a ciphertext held in memory and queues the resulting
/// elements on the advice stack.
///
/// Inputs are read from the operand stack of `process`:
/// - the secret key word via `get_stack_word(1)`,
/// - the nonce word via `get_stack_word(5)`,
/// - the ciphertext source pointer (`src_ptr`) from stack item 9,
/// - the block count (`num_blocks`) from stack item 11.
///
/// Stack item 10 is not read by this handler.
///
/// The ciphertext spans `(num_blocks + 1) * 8` elements starting at `src_ptr`, and the 4-element
/// authentication tag is read from the memory word located immediately after it.
///
/// # Returns
/// A vector with a single [`AdviceMutation`] that extends the advice stack with the decrypted
/// elements, truncated to at most `num_blocks * 8` elements.
///
/// # Errors
/// Returns an error if:
/// - the size or address arithmetic overflows,
/// - the ciphertext region or the tag word cannot be read from memory (this includes a tag
///   address that does not fit in `u32`),
/// - decryption of the reconstructed ciphertext fails.
pub fn handle_aead_decrypt(process: &ProcessorState) -> Result<Vec<AdviceMutation>, EventError> {
    let key_word = process.get_stack_word(1);
    let nonce_word = process.get_stack_word(5);
    let src_ptr = process.get_stack_item(9).as_canonical_u64();
    let num_blocks = process.get_stack_item(11).as_canonical_u64();

    let (num_ciphertext_elements, tag_ptr, data_blocks_count) =
        compute_sizes(num_blocks, src_ptr)?;

    let ciphertext = read_memory_region(process, src_ptr, num_ciphertext_elements).ok_or(
        AeadDecryptError::MemoryReadFailed {
            addr: src_ptr,
            len: num_ciphertext_elements,
        },
    )?;

    // Every way of failing to obtain the tag word is reported against the same 4-element region,
    // so the error value is built in one place.
    let tag_read_failed = || AeadDecryptError::MemoryReadFailed { addr: tag_ptr, len: 4 };

    // `get_mem_word` takes a `u32` address, so the `u64` tag pointer must be narrowed first.
    let tag_addr = u32::try_from(tag_ptr).map_err(|_| tag_read_failed())?;
    let ctx = process.ctx();
    let tag_word = process
        .get_mem_word(ctx, tag_addr)
        .map_err(|_| tag_read_failed())?
        .ok_or_else(tag_read_failed)?;
    let tag_elements: [miden_core::Felt; 4] = tag_word.into();

    let secret_key = SecretKey::from_elements(key_word.into());
    let auth_tag = AuthTag::new(tag_elements);
    // The nonce is not needed after this point, so it is moved instead of cloned.
    let nonce = Nonce::from(nonce_word);
    let encrypted_data =
        EncryptedData::from_parts(DataType::Elements, ciphertext, auth_tag, nonce);

    let mut plaintext = secret_key.decrypt_elements(&encrypted_data)?;
    // Keep at most `num_blocks * 8` elements; anything beyond that is not pushed to the advice
    // stack.
    plaintext.truncate(data_blocks_count);

    Ok(vec![AdviceMutation::extend_stack(plaintext)])
}
/// Derives the element counts and the tag address used by the AEAD decrypt handler.
///
/// Given `num_blocks` and the source pointer `src_ptr`, returns a tuple of:
/// 1. `(num_blocks + 1) * 8` - the ciphertext length in elements,
/// 2. `src_ptr` plus that length - the address at which the tag is read,
/// 3. `num_blocks * 8` converted to `usize` - the number of elements kept after decryption.
///
/// # Errors
/// Returns [`AeadDecryptError::SizeOverflow`] if any of the `u64` computations overflows or if
/// the third value does not fit in `usize`.
fn compute_sizes(num_blocks: u64, src_ptr: u64) -> Result<(u64, u64, usize), AeadDecryptError> {
    let Some(ciphertext_len) = num_blocks.checked_add(1).and_then(|total| total.checked_mul(8))
    else {
        return Err(AeadDecryptError::SizeOverflow);
    };

    let Some(tag_addr) = src_ptr.checked_add(ciphertext_len) else {
        return Err(AeadDecryptError::SizeOverflow);
    };

    let Some(data_len) = num_blocks.checked_mul(8) else {
        return Err(AeadDecryptError::SizeOverflow);
    };
    // On targets where `usize` is narrower than `u64` the count may not be representable.
    let Ok(data_len) = usize::try_from(data_len) else {
        return Err(AeadDecryptError::SizeOverflow);
    };

    Ok((ciphertext_len, tag_addr, data_len))
}
/// Errors raised by the AEAD decrypt event handler and its size computations.
#[derive(Debug, thiserror::Error)]
enum AeadDecryptError {
/// A memory read required by the handler failed.
///
/// `addr` and `len` describe the requested region (start address and length in elements).
/// The handler also uses this variant when the tag address does not fit in `u32`.
#[error("failed to read memory region at addr={addr}, len={len}")]
MemoryReadFailed { addr: u64, len: u64 },
/// Computing an element count or the tag address overflowed `u64`, or an element count did not
/// fit in `usize`.
#[error("size overflow in AEAD decrypt handler")]
SizeOverflow,
/// Wraps an [`EncryptionError`]; its message is forwarded unchanged (`transparent`) and a `From`
/// conversion is generated by `#[from]`.
#[error(transparent)]
DecryptionFailed(#[from] EncryptionError),
}
#[cfg(test)]
mod tests {
    use super::{AEAD_DECRYPT_EVENT_NAME, AeadDecryptError, compute_sizes};

    /// Asserts that `compute_sizes` rejects the given inputs with `SizeOverflow`, panicking with
    /// `msg` if the call unexpectedly succeeds.
    fn assert_size_overflow(num_blocks: u64, src_ptr: u64, msg: &str) {
        let err = compute_sizes(num_blocks, src_ptr).expect_err(msg);
        assert!(matches!(err, AeadDecryptError::SizeOverflow));
    }

    /// The event name constant carries the expected string.
    #[test]
    fn test_event_name() {
        let name = AEAD_DECRYPT_EVENT_NAME.as_str();
        assert_eq!(name, "miden::core::crypto::aead::decrypt");
    }

    /// One block at address 0: 16 ciphertext elements, tag at 16, 8 elements kept.
    #[test]
    fn test_compute_sizes_happy_path() {
        let sizes = compute_sizes(1, 0).expect("sizes should fit");
        assert_eq!(sizes, (16, 16, 8));
    }

    /// `num_blocks + 1` overflows `u64`.
    #[test]
    fn test_compute_sizes_overflow_num_blocks() {
        assert_size_overflow(u64::MAX, 0, "should overflow");
    }

    /// `src_ptr + ciphertext length` overflows `u64`.
    #[test]
    fn test_compute_sizes_overflow_tag_ptr() {
        assert_size_overflow(0, u64::MAX, "should overflow tag ptr");
    }

    /// On 32-bit targets `num_blocks * 8` can exceed `usize::MAX` while still fitting in `u64`.
    #[cfg(target_pointer_width = "32")]
    #[test]
    fn test_compute_sizes_overflow_data_blocks_count() {
        let num_blocks = (usize::MAX as u64 / 8) + 1;
        assert_size_overflow(num_blocks, 0, "should overflow usize");
    }
}