use std::sync::Arc;
use itertools::Itertools;
use rand::{RngCore, rng};
use crate::{
ArgsBuilder, Byte, Memory, ToArg,
proc::IsaOp,
register_names::*,
test_utils::{Bits, BitsUnsigned, MaybeEncryptedUInt, make_computer_128},
};
use parasol_runtime::{Encryption, SecretKey, test_utils::get_secret_keys_128};
/// Selects which cast instruction the `casting` test helper exercises.
enum CastType {
/// Emits `IsaOp::Zext`; the reference result is the input masked to its own width.
/// Treated as valid when the input width does not exceed the output width.
ZeroExtension,
/// Emits `IsaOp::Sext`; the reference result replicates the input's top bit into the
/// added high bits. Treated as valid when the input width does not exceed the output width.
SignExtension,
/// Emits `IsaOp::Trunc`; the reference result is the input masked to the output width.
/// Treated as valid when the input width is at least the output width.
Truncation,
}
/// Runs one cast instruction over every (input width, output width) pairing of 8, 16 and
/// 32 bits and compares the processor's output with a reference computed on the host.
///
/// * `cast_type` - which cast instruction (`Zext`, `Sext` or `Trunc`) the program contains.
/// * `encrypted_computation` - forwarded to `MaybeEncryptedUInt::new` when building the input.
///
/// Extensions are treated as valid when the input is no wider than the output, truncation
/// when the input is no narrower than the output. A valid pairing must run successfully and
/// store the expected value; any other pairing must make `run_program` return an error.
///
/// # Panics
///
/// Panics if a valid pairing fails to run, if an invalid pairing runs successfully, if the
/// stored result differs from the reference value, or if any memory operation fails.
fn casting(cast_type: CastType, encrypted_computation: bool) {
    // Mask selecting the lowest `bits` bits of a `u32`.
    fn low_mask(bits: u32) -> u32 {
        // Computed in 64 bits so that `bits == 32` does not overflow the shift.
        ((1u64 << bits) - 1) as u32
    }

    // Rebuilds an `N`-bit value from the stored bytes and returns its plaintext as `u32`.
    fn read_back<const N: usize>(raw: Vec<Byte>, enc: &Encryption, sk: &SecretKey) -> u32
    where
        BitsUnsigned: Bits<N>,
        <BitsUnsigned as Bits<N>>::PlaintextType: Into<u64>,
    {
        let decoded = MaybeEncryptedUInt::<N>::try_from_bytes(raw).unwrap();
        let plain: u64 = decoded.get(enc, sk).into();
        plain as u32
    }

    let widths = [8u32, 16, 32];

    // Every ordered width pairing, tagged with whether the cast is expected to be accepted.
    let cases = widths
        .iter()
        .cartesian_product(widths.iter())
        .map(|(&input_width, &output_width)| {
            let legal = match cast_type {
                CastType::Truncation => input_width >= output_width,
                CastType::SignExtension | CastType::ZeroExtension => input_width <= output_width,
            };
            ((input_width, output_width), legal)
        });

    for ((input_width, output_width), valid) in cases {
        let (mut proc, enc) = make_computer_128();
        let sk = get_secret_keys_128();

        let value = rng().next_u32();

        // Host-side reference result for this pairing.
        let expected = match cast_type {
            CastType::ZeroExtension => value & low_mask(input_width),
            CastType::Truncation => value & low_mask(output_width),
            CastType::SignExtension => {
                let top_bit = value & (1 << (input_width - 1));
                let base = value & low_mask(input_width);
                // Copy the input's top bit into each bit position added by the extension.
                let added_bits = output_width.saturating_sub(input_width);
                (0..added_bits)
                    .fold((top_bit, base), |(bit, acc), _| {
                        let bit = bit << 1;
                        (bit, acc | bit)
                    })
                    .1
            }
        };

        let memory = Arc::new(Memory::new_default_stack());

        let cast_op = match cast_type {
            CastType::SignExtension => IsaOp::Sext(T0, T0, output_width),
            CastType::ZeroExtension => IsaOp::Zext(T0, T0, output_width),
            CastType::Truncation => IsaOp::Trunc(T0, T0, output_width),
        };

        // NOTE(review): the two `SP`-relative loads presumably fetch the input and output
        // pointer arguments supplied through `ArgsBuilder` below - confirm against the ABI.
        let program = memory.allocate_program(&[
            IsaOp::Load(T0, SP, 32, 0),
            IsaOp::Load(T0, T0, input_width, 0),
            IsaOp::Load(T1, SP, 32, 4),
            cast_op,
            IsaOp::Store(T1, T0, output_width, 0),
            IsaOp::Ret(),
        ]);

        let input =
            MaybeEncryptedUInt::<32>::new(value as u128, &enc, &sk, encrypted_computation);

        let input_ptr = memory.try_allocate(64).unwrap();
        let output_ptr = memory.try_allocate(64).unwrap();

        // Write the input value into memory one byte at a time.
        for (offset, byte) in input.to_bytes().into_iter().enumerate() {
            let dest = input_ptr.try_offset(offset as u32).unwrap();
            memory.try_store(dest, byte).unwrap();
        }

        let args = ArgsBuilder::new()
            .arg(input_ptr)
            .arg(output_ptr)
            .no_return_value();

        // Settle every outcome except "valid and succeeded" before decoding the output.
        match proc.run_program(program, &memory, args) {
            Err(_) if !valid => continue,
            Err(e) => panic!("Unexpected error: {e:?}"),
            Ok(()) if !valid => panic!("Expected error"),
            Ok(()) => {}
        }

        let byte_count = output_width / 8;
        let ans_bytes: Vec<_> = (0..byte_count)
            .map(|offset| {
                let src = output_ptr.try_offset(offset).unwrap();
                memory.try_load(src).unwrap()
            })
            .collect();

        // The bit width is a const generic, so dispatch on the runtime width.
        let ans = match output_width {
            8 => read_back::<8>(ans_bytes, &enc, &sk),
            16 => read_back::<16>(ans_bytes, &enc, &sk),
            32 => read_back::<32>(ans_bytes, &enc, &sk),
            _ => unreachable!(),
        };

        assert_eq!(
            expected, ans,
            "input_width: {input_width}, output_width: {output_width}"
        );
    }
}
/// Zero extension in plaintext mode: five rounds, each drawing fresh random inputs.
#[test]
fn can_cast_zero_extend_plaintext() {
    (0..5).for_each(|_| casting(CastType::ZeroExtension, false));
}
/// Zero extension in encrypted mode: three rounds, each drawing fresh random inputs.
#[test]
fn can_cast_zero_extend_ciphertext() {
    (0..3).for_each(|_| casting(CastType::ZeroExtension, true));
}
/// Sign extension in plaintext mode: five rounds, each drawing fresh random inputs.
#[test]
fn can_cast_sign_extend_plaintext() {
    (0..5).for_each(|_| casting(CastType::SignExtension, false));
}
/// Sign extension in encrypted mode: three rounds, each drawing fresh random inputs.
#[test]
fn can_cast_sign_extend_ciphertext() {
    (0..3).for_each(|_| casting(CastType::SignExtension, true));
}
/// Truncation in plaintext mode: five rounds, each drawing fresh random inputs.
#[test]
fn can_cast_truncate_plaintext() {
    (0..5).for_each(|_| casting(CastType::Truncation, false));
}
/// Truncation in encrypted mode: three rounds, each drawing fresh random inputs.
///
/// The round count matches the other ciphertext cast tests in this file. The loop
/// previously iterated over `3..5`, which runs only two rounds because the loop variable
/// is unused and the range start merely shortens the range.
#[test]
fn can_cast_truncate_ciphertext() {
    for _ in 0..3 {
        casting(CastType::Truncation, true);
    }
}