use crate::prelude::*;
use crate::{ClientKey, FheBool, FheInt16, FheInt32, FheInt64, FheInt8, FheUint64, FheUint8};
use rand::prelude::*;
mod cpu;
#[cfg(feature = "gpu")]
mod gpu;
#[allow(clippy::eq_op)]
fn test_case_int32_compare(cks: &ClientKey) {
let mut rng = thread_rng();
let clear_a = rng.gen::<i32>();
let clear_b = rng.gen::<i32>();
let a = FheInt32::encrypt(clear_a, cks);
let b = FheInt32::encrypt(clear_b, cks);
{
let result = &a.eq(&b);
let decrypted_result = result.decrypt(cks);
let clear_result = clear_a == clear_b;
assert_eq!(decrypted_result, clear_result);
let result = &a.eq(&a);
let decrypted_result = result.decrypt(cks);
let clear_result = clear_a == clear_a;
assert_eq!(decrypted_result, clear_result);
let result = &a.ne(&b);
let decrypted_result = result.decrypt(cks);
let clear_result = clear_a != clear_b;
assert_eq!(decrypted_result, clear_result);
let result = &a.ne(&a);
let decrypted_result = result.decrypt(cks);
let clear_result = clear_a != clear_a;
assert_eq!(decrypted_result, clear_result);
let result = &a.le(&b);
let decrypted_result = result.decrypt(cks);
let clear_result = clear_a <= clear_b;
assert_eq!(decrypted_result, clear_result);
let result = &a.lt(&b);
let decrypted_result = result.decrypt(cks);
let clear_result = clear_a < clear_b;
assert_eq!(decrypted_result, clear_result);
let result = &a.ge(&b);
let decrypted_result = result.decrypt(cks);
let clear_result = clear_a >= clear_b;
assert_eq!(decrypted_result, clear_result);
let result = &a.gt(&b);
let decrypted_result = result.decrypt(cks);
let clear_result = clear_a > clear_b;
assert_eq!(decrypted_result, clear_result);
}
{
let result = &a.eq(clear_b);
let decrypted_result = result.decrypt(cks);
let clear_result = clear_a == clear_b;
assert_eq!(decrypted_result, clear_result);
let result = &a.eq(clear_a);
let decrypted_result = result.decrypt(cks);
let clear_result = clear_a == clear_a;
assert_eq!(decrypted_result, clear_result);
let result = &a.ne(clear_b);
let decrypted_result = result.decrypt(cks);
let clear_result = clear_a != clear_b;
assert_eq!(decrypted_result, clear_result);
let result = &a.ne(clear_a);
let decrypted_result = result.decrypt(cks);
let clear_result = clear_a != clear_a;
assert_eq!(decrypted_result, clear_result);
let result = &a.le(clear_b);
let decrypted_result = result.decrypt(cks);
let clear_result = clear_a <= clear_b;
assert_eq!(decrypted_result, clear_result);
let result = &a.lt(clear_b);
let decrypted_result = result.decrypt(cks);
let clear_result = clear_a < clear_b;
assert_eq!(decrypted_result, clear_result);
let result = &a.ge(clear_b);
let decrypted_result = result.decrypt(cks);
let clear_result = clear_a >= clear_b;
assert_eq!(decrypted_result, clear_result);
let result = &a.gt(clear_b);
let decrypted_result = result.decrypt(cks);
let clear_result = clear_a > clear_b;
assert_eq!(decrypted_result, clear_result);
}
}
fn test_case_int32_bitwise(cks: &ClientKey) {
let mut rng = rand::thread_rng();
let clear_a = rng.gen::<i32>();
let clear_b = rng.gen::<i32>();
let a = FheInt32::try_encrypt(clear_a, cks).unwrap();
let b = FheInt32::try_encrypt(clear_b, cks).unwrap();
{
let c = &a | &b;
let decrypted: i32 = c.decrypt(cks);
assert_eq!(decrypted, clear_a | clear_b);
let c = &a & &b;
let decrypted: i32 = c.decrypt(cks);
assert_eq!(decrypted, clear_a & clear_b);
let c = &a ^ &b;
let decrypted: i32 = c.decrypt(cks);
assert_eq!(decrypted, clear_a ^ clear_b);
let mut c = a.clone();
c |= &b;
let decrypted: i32 = c.decrypt(cks);
assert_eq!(decrypted, clear_a | clear_b);
let mut c = a.clone();
c &= &b;
let decrypted: i32 = c.decrypt(cks);
assert_eq!(decrypted, clear_a & clear_b);
let mut c = a.clone();
c ^= &b;
let decrypted: i32 = c.decrypt(cks);
assert_eq!(decrypted, clear_a ^ clear_b);
}
{
let c = &a | b;
let decrypted: i32 = c.decrypt(cks);
assert_eq!(decrypted, clear_a | clear_b);
let c = &a & clear_b;
let decrypted: i32 = c.decrypt(cks);
assert_eq!(decrypted, clear_a & clear_b);
let c = &a ^ clear_b;
let decrypted: i32 = c.decrypt(cks);
assert_eq!(decrypted, clear_a ^ clear_b);
let mut c = a.clone();
c |= clear_b;
let decrypted: i32 = c.decrypt(cks);
assert_eq!(decrypted, clear_a | clear_b);
let mut c = a.clone();
c &= clear_b;
let decrypted: i32 = c.decrypt(cks);
assert_eq!(decrypted, clear_a & clear_b);
let mut c = a;
c ^= clear_b;
let decrypted: i32 = c.decrypt(cks);
assert_eq!(decrypted, clear_a ^ clear_b);
}
}
fn test_case_int64_rotate(cks: &ClientKey) {
let mut rng = thread_rng();
let clear_a = rng.gen::<i64>();
let clear_b = rng.gen_range(0u32..64u32);
let a = FheInt64::try_encrypt(clear_a, cks).unwrap();
let b = FheUint64::try_encrypt(clear_b, cks).unwrap();
{
let c = (&a).rotate_left(&b);
let decrypted: i64 = c.decrypt(cks);
assert_eq!(decrypted, clear_a.rotate_left(clear_b));
let c = (&a).rotate_right(&b);
let decrypted: i64 = c.decrypt(cks);
assert_eq!(decrypted, clear_a.rotate_right(clear_b));
let mut c = a.clone();
c.rotate_right_assign(&b);
let decrypted: i64 = c.decrypt(cks);
assert_eq!(decrypted, clear_a.rotate_right(clear_b));
let mut c = a.clone();
c.rotate_left_assign(&b);
let decrypted: i64 = c.decrypt(cks);
assert_eq!(decrypted, clear_a.rotate_left(clear_b));
}
{
let c = (&a).rotate_left(clear_b);
let decrypted: i64 = c.decrypt(cks);
assert_eq!(decrypted, clear_a.rotate_left(clear_b));
let c = (&a).rotate_right(clear_b);
let decrypted: i64 = c.decrypt(cks);
assert_eq!(decrypted, clear_a.rotate_right(clear_b));
let mut c = a.clone();
c.rotate_right_assign(clear_b);
let decrypted: i64 = c.decrypt(cks);
assert_eq!(decrypted, clear_a.rotate_right(clear_b));
let mut c = a;
c.rotate_left_assign(clear_b);
let decrypted: i64 = c.decrypt(cks);
assert_eq!(decrypted, clear_a.rotate_left(clear_b));
}
}
fn test_case_int32_div_rem(cks: &ClientKey) {
let mut rng = rand::thread_rng();
let clear_a = rng.gen::<i32>();
let clear_b = loop {
let value = rng.gen::<i32>();
if value != 0 {
break value;
}
};
let a = FheInt32::try_encrypt(clear_a, cks).unwrap();
let b = FheInt32::try_encrypt(clear_b, cks).unwrap();
{
let c = &a / &b;
let decrypted: i32 = c.decrypt(cks);
assert_eq!(decrypted, clear_a / clear_b);
let c = &a % &b;
let decrypted: i32 = c.decrypt(cks);
assert_eq!(decrypted, clear_a % clear_b);
let (q, r) = (&a).div_rem(&b);
let decrypted_q: i32 = q.decrypt(cks);
let decrypted_r: i32 = r.decrypt(cks);
assert_eq!(decrypted_q, clear_a / clear_b);
assert_eq!(decrypted_r, clear_a % clear_b);
let mut c = a.clone();
c /= &b;
let decrypted: i32 = c.decrypt(cks);
assert_eq!(decrypted, clear_a / clear_b);
let mut c = a.clone();
c %= &b;
let decrypted: i32 = c.decrypt(cks);
assert_eq!(decrypted, clear_a % clear_b);
}
{
let c = &a / clear_b;
let decrypted: i32 = c.decrypt(cks);
assert_eq!(decrypted, clear_a / clear_b);
let c = &a % clear_b;
let decrypted: i32 = c.decrypt(cks);
assert_eq!(decrypted, clear_a % clear_b);
let (q, r) = (&a).div_rem(clear_b);
let decrypted_q: i32 = q.decrypt(cks);
let decrypted_r: i32 = r.decrypt(cks);
assert_eq!(decrypted_q, clear_a / clear_b);
assert_eq!(decrypted_r, clear_a % clear_b);
let mut c = a.clone();
c /= clear_b;
let decrypted: i32 = c.decrypt(cks);
assert_eq!(decrypted, clear_a / clear_b);
let mut c = a;
c %= clear_b;
let decrypted: i32 = c.decrypt(cks);
assert_eq!(decrypted, clear_a % clear_b);
}
}
fn test_case_integer_casting(cks: &ClientKey) {
let mut rng = rand::thread_rng();
for clear in [rng.gen_range(i16::MIN..0), rng.gen_range(0..=i16::MAX)] {
{
let a = FheInt16::encrypt(clear, cks);
let a: FheInt8 = a.cast_into();
let da: i8 = a.decrypt(cks);
assert_eq!(da, clear as i8);
let a: FheInt32 = a.cast_into();
let da: i32 = a.decrypt(cks);
assert_eq!(da, (clear as i8) as i32);
}
{
let a = FheInt16::encrypt(clear, cks);
let a = FheInt32::cast_from(a);
let da: i32 = a.decrypt(cks);
assert_eq!(da, clear as i32);
let a = FheInt8::cast_from(a);
let da: i8 = a.decrypt(cks);
assert_eq!(da, (clear as i32) as i8);
}
{
let a = FheInt16::encrypt(clear, cks);
let a = FheInt16::cast_from(a);
let da: i16 = a.decrypt(cks);
assert_eq!(da, clear);
}
{
let a = FheInt16::encrypt(clear, cks);
let a: FheUint8 = a.cast_into();
let da: u8 = a.decrypt(cks);
assert_eq!(da, clear as u8);
let a: FheInt32 = a.cast_into();
let da: i32 = a.decrypt(cks);
assert_eq!(da, (clear as u8) as i32);
}
{
let a = FheInt16::encrypt(clear, cks);
let a: FheUint64 = a.cast_into();
let da: u64 = a.decrypt(cks);
assert_eq!(da, clear as u64);
let a: FheInt32 = a.cast_into();
let da: i32 = a.decrypt(cks);
assert_eq!(da, (clear as u64) as i32);
}
}
}
fn test_case_if_then_else(cks: &ClientKey) {
let mut rng = rand::thread_rng();
let clear_a = rng.gen::<i8>();
let clear_b = rng.gen::<i8>();
let a = FheInt8::encrypt(clear_a, cks);
let b = FheInt8::encrypt(clear_b, cks);
let result = a.le(&b).if_then_else(&a, &b);
let decrypted_result: i8 = result.decrypt(cks);
assert_eq!(
decrypted_result,
if clear_a <= clear_b { clear_a } else { clear_b }
);
let result = a.le(&b).if_then_else(&b, &a);
let decrypted_result: i8 = result.decrypt(cks);
assert_eq!(
decrypted_result,
if clear_a <= clear_b { clear_b } else { clear_a }
);
}
fn test_case_flip(client_key: &ClientKey) {
let clear_a = rand::random::<i32>();
let clear_b = rand::random::<i32>();
let a = FheInt32::encrypt(clear_a, client_key);
let b = FheInt32::encrypt(clear_b, client_key);
let c = FheBool::encrypt(true, client_key);
let (ra, rb) = c.flip(&a, &b);
let decrypted_a: i32 = ra.decrypt(client_key);
let decrypted_b: i32 = rb.decrypt(client_key);
assert_eq!((decrypted_a, decrypted_b), (clear_b, clear_a));
let c = FheBool::encrypt(false, client_key);
let (ra, rb) = c.flip(&a, &b);
let decrypted_a: i32 = ra.decrypt(client_key);
let decrypted_b: i32 = rb.decrypt(client_key);
assert_eq!((decrypted_a, decrypted_b), (clear_a, clear_b));
}
fn test_case_scalar_flip(client_key: &ClientKey) {
let clear_a = rand::random::<i32>();
let clear_b = rand::random::<i32>();
let a = FheInt32::encrypt(clear_a, client_key);
let b = FheInt32::encrypt(clear_b, client_key);
let c = FheBool::encrypt(true, client_key);
let (ra, rb) = c.flip(&a, clear_b);
let decrypted_a: i32 = ra.decrypt(client_key);
let decrypted_b: i32 = rb.decrypt(client_key);
assert_eq!((decrypted_a, decrypted_b), (clear_b, clear_a));
let c = FheBool::encrypt(false, client_key);
let (ra, rb) = c.flip(clear_a, &b);
let decrypted_a: i32 = ra.decrypt(client_key);
let decrypted_b: i32 = rb.decrypt(client_key);
assert_eq!((decrypted_a, decrypted_b), (clear_a, clear_b));
}
fn test_case_abs(cks: &ClientKey) {
let mut rng = rand::thread_rng();
for clear in [rng.gen_range(i64::MIN..0), rng.gen_range(0..=i64::MAX)] {
let a = FheInt64::encrypt(clear, cks);
let abs_a = a.abs();
let decrypted_result: i64 = abs_a.decrypt(cks);
assert_eq!(decrypted_result, clear.abs());
}
}
fn test_case_integer_compress_decompress(cks: &ClientKey) {
let a = FheInt8::try_encrypt(-83i8, cks).unwrap();
let clear: i8 = a.compress().decompress().decrypt(cks);
assert_eq!(clear, -83i8);
}
fn test_case_leading_trailing_zeros_ones(cks: &ClientKey) {
let mut rng = thread_rng();
for _ in 0..5 {
let clear_a = rng.gen::<i32>();
let a = FheInt32::try_encrypt(clear_a, cks).unwrap();
let leading_zeros: u32 = a.leading_zeros().decrypt(cks);
assert_eq!(leading_zeros, clear_a.leading_zeros());
let leading_ones: u32 = a.leading_ones().decrypt(cks);
assert_eq!(leading_ones, clear_a.leading_ones());
let trailing_zeros: u32 = a.trailing_zeros().decrypt(cks);
assert_eq!(trailing_zeros, clear_a.trailing_zeros());
let trailing_ones: u32 = a.trailing_ones().decrypt(cks);
assert_eq!(trailing_ones, clear_a.trailing_ones());
}
}
fn test_case_ilog2(cks: &ClientKey) {
let mut rng = thread_rng();
for _ in 0..5 {
let clear_a = rng.gen_range(1..=i32::MAX);
let a = FheInt32::try_encrypt(clear_a, cks).unwrap();
let ilog2: u32 = a.ilog2().decrypt(cks);
assert_eq!(ilog2, clear_a.ilog2());
let (ilog2, is_ok) = a.checked_ilog2();
let ilog2: u32 = ilog2.decrypt(cks);
let is_ok = is_ok.decrypt(cks);
assert!(is_ok);
assert_eq!(ilog2, clear_a.ilog2());
}
for _ in 0..5 {
let a = FheInt32::try_encrypt(rng.gen_range(i32::MIN..=0), cks).unwrap();
let (_ilog2, is_ok) = a.checked_ilog2();
let is_ok = is_ok.decrypt(cks);
assert!(!is_ok);
}
}
fn test_case_min_max(cks: &ClientKey) {
let mut rng = rand::thread_rng();
let a_val: i8 = rng.gen();
let b_val: i8 = rng.gen();
let a = FheInt8::encrypt(a_val, cks);
let b = FheInt8::encrypt(b_val, cks);
let encrypted_min = a.min(&b);
let encrypted_max = a.max(&b);
let decrypted_min: i8 = encrypted_min.decrypt(cks);
let decrypted_max: i8 = encrypted_max.decrypt(cks);
assert_eq!(decrypted_min, a_val.min(b_val));
assert_eq!(decrypted_max, a_val.max(b_val));
let encrypted_min = a.min(b.clone());
let encrypted_max = a.max(b);
let decrypted_min: i8 = encrypted_min.decrypt(cks);
let decrypted_max: i8 = encrypted_max.decrypt(cks);
assert_eq!(decrypted_min, a_val.min(b_val));
assert_eq!(decrypted_max, a_val.max(b_val));
}
fn test_case_int16_fused_mul_div(cks: &ClientKey) {
let mut rng = rand::thread_rng();
{
let input = -200i16;
let mul = 200i16;
let div = 100i16;
let expected = -400i16;
let a = FheInt16::try_encrypt(input, cks).unwrap();
let b = FheInt16::try_encrypt(mul, cks).unwrap();
let result = (&a).fused_mul_scalar_div(&b, div);
let decrypted: i16 = result.decrypt(cks);
assert_eq!(decrypted, expected);
let result = (&a).fused_scalar_mul_scalar_div(mul, div);
let decrypted: i16 = result.decrypt(cks);
assert_eq!(decrypted, expected);
}
for _ in 0..5 {
let clear_a: i16 = rng.gen();
let clear_b: i16 = rng.gen();
let clear_c: i16 = loop {
let v: i16 = rng.gen();
if v != 0 {
break v;
}
};
let a = FheInt16::try_encrypt(clear_a, cks).unwrap();
let b = FheInt16::try_encrypt(clear_b, cks).unwrap();
let expected = (i32::from(clear_a) * i32::from(clear_b) / i32::from(clear_c)) as i16;
{
let result = (&a).fused_mul_scalar_div(&b, clear_c);
let decrypted: i16 = result.decrypt(cks);
assert_eq!(decrypted, expected);
}
{
let result = (&a).fused_scalar_mul_scalar_div(clear_b, clear_c);
let decrypted: i16 = result.decrypt(cks);
assert_eq!(decrypted, expected);
}
{
let result = a.fused_mul_scalar_div(b, clear_c);
let decrypted: i16 = result.decrypt(cks);
assert_eq!(decrypted, expected);
let a = FheInt16::try_encrypt(clear_a, cks).unwrap();
let result = a.fused_scalar_mul_scalar_div(clear_b, clear_c);
let decrypted: i16 = result.decrypt(cks);
assert_eq!(decrypted, expected);
}
}
}