tfhe 1.6.1

TFHE-rs is a fully homomorphic encryption (FHE) library that implements Zama's variant of TFHE.
Documentation
use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey};
use crate::shortint::parameters::TestParameters;
use rand::Rng;
use std::sync::Arc;

/// Encrypts every value of `bits` into its own block with `cks.encrypt_one_block` and
/// wraps the resulting blocks, in input order, into a [`RadixCiphertext`].
///
/// The callers in this file only pass `0`/`1` values, one per Trivium key or IV bit.
fn encrypt_bits(cks: &RadixClientKey, bits: &[u64]) -> RadixCiphertext {
    let mut blocks = Vec::with_capacity(bits.len());
    for &bit in bits {
        blocks.push(cks.encrypt_one_block(bit));
    }
    RadixCiphertext::from(blocks)
}

/// Decrypts each block of `ct` with `cks.decrypt_one_block` and returns the decrypted
/// values, in block order, narrowed to `u8`.
///
/// The narrowing is a plain `as` cast, so only the low 8 bits of each decrypted value
/// are kept.
fn decrypt_bits(cks: &RadixClientKey, ct: &RadixCiphertext) -> Vec<u8> {
    let mut bits = Vec::with_capacity(ct.blocks.len());
    for block in &ct.blocks {
        bits.push(cks.decrypt_one_block(block) as u8);
    }
    bits
}

/// Clear-text, one-bit-per-byte model of the Trivium keystream generator.
///
/// It is the software reference against which decrypted homomorphic outputs are
/// compared. The state lives in three shift registers of 93, 84 and 111 entries;
/// every entry is expected to be `0` or `1`.
struct TriviumRef {
    /// 93-entry register, seeded from the key.
    a: Vec<u8>,
    /// 84-entry register, seeded from the IV.
    b: Vec<u8>,
    /// 111-entry register, whose last three entries start at `1`.
    c: Vec<u8>,
}

impl TriviumRef {
    /// Builds the generator from an 80-bit `key` and an 80-bit `iv` (one bit per
    /// element) and runs the warm-up phase.
    ///
    /// # Panics
    ///
    /// Panics if `key` or `iv` holds fewer than 80 elements.
    fn new(key: &[u8], iv: &[u8]) -> Self {
        let mut a = vec![0u8; 93];
        let mut b = vec![0u8; 84];
        let mut c = vec![0u8; 111];

        // Register position `i` receives input bit `79 - i`, i.e. the first 80 input
        // bits are loaded in reverse order; the remaining positions stay at zero.
        for (slot, &bit) in a.iter_mut().zip(key[..80].iter().rev()) {
            *slot = bit;
        }
        for (slot, &bit) in b.iter_mut().zip(iv[..80].iter().rev()) {
            *slot = bit;
        }

        // Only the last three entries of the third register are set.
        c[108..].fill(1);

        let mut state = Self { a, b, c };
        // Warm-up: 18 * 64 = 1152 steps whose output bits are discarded.
        for _ in 0..(18 * 64) {
            state.next();
        }
        state
    }

    /// Advances the state by one step and returns the keystream bit of that step.
    fn next(&mut self) -> u8 {
        let (a, b, c) = (&self.a, &self.b, &self.c);

        // Linear taps of each register; their XOR is the output bit.
        let tap_a = a[65] ^ a[92];
        let tap_b = b[68] ^ b[83];
        let tap_c = c[65] ^ c[110];

        // Feedback bits: each register is fed from the tap of another register, one
        // of its own bits and an AND of two neighbouring bits.
        let feed_a = tap_c ^ a[68] ^ (c[108] & c[109]);
        let feed_b = tap_a ^ b[77] ^ (a[90] & a[91]);
        let feed_c = tap_b ^ c[86] ^ (b[81] & b[82]);

        Self::shift_in(&mut self.a, feed_a);
        Self::shift_in(&mut self.b, feed_b);
        Self::shift_in(&mut self.c, feed_c);

        tap_a ^ tap_b ^ tap_c
    }

    /// Shifts `register` by one position towards higher indices, dropping its last
    /// entry and placing `bit` at index 0.
    fn shift_in(register: &mut [u8], bit: u8) {
        register.rotate_right(1);
        register[0] = bit;
    }
}

/// Checks the clear-text reference itself: with an all-zero key and IV, the first 64
/// keystream bits must pack to the hexadecimal string `FBE0BF265859051B`.
#[test]
fn test_trivium_ref_consistency() {
    let zero_key = [0u8; 80];
    let zero_iv = [0u8; 80];

    let mut reference = TriviumRef::new(&zero_key, &zero_iv);
    let keystream: Vec<u8> = (0..64).map(|_| reference.next()).collect();

    let hex = get_hexadecimal_string_from_lsb_first_stream(&keystream);

    assert_eq!(&hex[0..16], "FBE0BF265859051B");
}

/// Packs a stream of bits (one `0`/`1` value per element) into an uppercase
/// hexadecimal string.
///
/// Every group of 8 consecutive elements forms one byte, the first element of the
/// group being the least significant bit. Each byte is written as two hexadecimal
/// digits, most significant nibble first, so the result holds `a.len() / 4`
/// characters.
///
/// # Panics
///
/// Panics if `a.len()` is not a multiple of 8, or if any element is neither `0` nor
/// `1`. Non-binary values used to be read as a `0` bit, which could let a corrupted
/// decryption result go unnoticed when compared against an expected vector.
fn get_hexadecimal_string_from_lsb_first_stream(a: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    assert!(
        a.len().is_multiple_of(8),
        "bit stream length must be a multiple of 8, got {}",
        a.len()
    );

    let mut hexadecimal = String::with_capacity(a.len() / 4);
    for (byte_index, byte_bits) in a.chunks_exact(8).enumerate() {
        let mut byte = 0u8;
        for (bit_index, &bit) in byte_bits.iter().enumerate() {
            assert!(
                bit <= 1,
                "non-binary value {bit} at stream position {}",
                byte_index * 8 + bit_index
            );
            // The stream is LSB first: element `bit_index` carries weight 2^bit_index.
            byte |= bit << bit_index;
        }

        hexadecimal.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        hexadecimal.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
    }
    hexadecimal
}

/// Known-answer test: all-zero key and all-zero IV.
///
/// Encrypts the 80 key bits and 80 IV bits, asks `executor` for 512 keystream steps,
/// decrypts the result and compares the first 64 bytes (128 hexadecimal characters)
/// with the expected keystream.
///
/// # Panics
///
/// Panics if the executor returns an error, if the decrypted stream is shorter than
/// 64 bytes, or if it does not match the expected vector.
pub fn trivium_test_vector_1_test<P, E>(param: P, mut executor: E)
where
    P: Into<TestParameters>,
    E: for<'a> FunctionExecutor<
        (&'a RadixCiphertext, &'a RadixCiphertext, usize),
        crate::Result<RadixCiphertext>,
    >,
{
    const EXPECTED_HEX: &str = "FBE0BF265859051B517A2E4E239FC97F563203161907CF2DE7A8790FA1B2E9CDF75292030268B7382B4C1A759AA2599A285549986E74805903801A4CB5A5D4F2";
    const NUM_STEPS: usize = 512;

    let param = param.into();
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let client_key = RadixClientKey::from((client_key, 1));
    executor.setup(&client_key, Arc::new(server_key));

    // Key and IV are both 80 zero bits for this vector.
    let zero_bits = [0u64; 80];
    let ct_key = encrypt_bits(&client_key, &zero_bits);
    let ct_iv = encrypt_bits(&client_key, &zero_bits);

    let keystream = executor.execute((&ct_key, &ct_iv, NUM_STEPS)).unwrap();

    let clear_bits = decrypt_bits(&client_key, &keystream);
    let hex = get_hexadecimal_string_from_lsb_first_stream(&clear_bits);

    assert_eq!(EXPECTED_HEX, &hex[0..64 * 2]);
}

/// Known-answer test: key with only bit 7 set, all-zero IV.
///
/// Encrypts the 80 key bits and 80 IV bits, asks `executor` for 512 keystream steps,
/// decrypts the result and compares the first 64 bytes (128 hexadecimal characters)
/// with the expected keystream.
///
/// # Panics
///
/// Panics if the executor returns an error, if the decrypted stream is shorter than
/// 64 bytes, or if it does not match the expected vector.
pub fn trivium_test_vector_2_test<P, E>(param: P, mut executor: E)
where
    P: Into<TestParameters>,
    E: for<'a> FunctionExecutor<
        (&'a RadixCiphertext, &'a RadixCiphertext, usize),
        crate::Result<RadixCiphertext>,
    >,
{
    const EXPECTED_HEX: &str = "38EB86FF730D7A9CAF8DF13A4420540DBB7B651464C87501552041C249F29A64D2FBF515610921EBE06C8F92CECF7F8098FF20CCCC6A62B97BE8EF7454FC80F9";
    const NUM_STEPS: usize = 512;

    let param = param.into();
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let client_key = RadixClientKey::from((client_key, 1));
    executor.setup(&client_key, Arc::new(server_key));

    // Single non-zero key bit at index 7; the IV stays all zero.
    let key_bits: Vec<u64> = (0..80).map(|index| u64::from(index == 7)).collect();
    let iv_bits = [0u64; 80];

    let ct_key = encrypt_bits(&client_key, &key_bits);
    let ct_iv = encrypt_bits(&client_key, &iv_bits);

    let keystream = executor.execute((&ct_key, &ct_iv, NUM_STEPS)).unwrap();

    let clear_bits = decrypt_bits(&client_key, &keystream);
    let hex = get_hexadecimal_string_from_lsb_first_stream(&clear_bits);

    assert_eq!(EXPECTED_HEX, &hex[0..64 * 2]);
}

/// Known-answer test: all-zero key, IV with only bit 7 set.
///
/// Encrypts the 80 key bits and 80 IV bits, asks `executor` for 512 keystream steps,
/// decrypts the result and compares the first 64 bytes (128 hexadecimal characters)
/// with the expected keystream.
///
/// # Panics
///
/// Panics if the executor returns an error, if the decrypted stream is shorter than
/// 64 bytes, or if it does not match the expected vector.
pub fn trivium_test_vector_3_test<P, E>(param: P, mut executor: E)
where
    P: Into<TestParameters>,
    E: for<'a> FunctionExecutor<
        (&'a RadixCiphertext, &'a RadixCiphertext, usize),
        crate::Result<RadixCiphertext>,
    >,
{
    const EXPECTED_HEX: &str = "F8901736640549E3BA7D42EA2D07B9F49233C18D773008BD755585B1A8CBAB86C1E9A9B91F1AD33483FD6EE3696D659C9374260456A36AAE11F033A519CBD5D7";
    const NUM_STEPS: usize = 512;

    let param = param.into();
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let client_key = RadixClientKey::from((client_key, 1));
    executor.setup(&client_key, Arc::new(server_key));

    // The key stays all zero; the IV has a single non-zero bit at index 7.
    let key_bits = [0u64; 80];
    let iv_bits: Vec<u64> = (0..80).map(|index| u64::from(index == 7)).collect();

    let ct_key = encrypt_bits(&client_key, &key_bits);
    let ct_iv = encrypt_bits(&client_key, &iv_bits);

    let keystream = executor.execute((&ct_key, &ct_iv, NUM_STEPS)).unwrap();

    let clear_bits = decrypt_bits(&client_key, &keystream);
    let hex = get_hexadecimal_string_from_lsb_first_stream(&clear_bits);

    assert_eq!(EXPECTED_HEX, &hex[0..64 * 2]);
}

/// Known-answer test with a non-trivial key (`0053A6F94C9FF24598EB`) and IV
/// (`0D74DB42A91077DE45AC`).
///
/// Both hexadecimal strings are expanded byte by byte into bits, least significant
/// bit of each byte first. The bits are encrypted, `executor` is asked for 512
/// keystream steps, and the first 64 bytes (128 hexadecimal characters) of the
/// decrypted result are compared with the expected keystream.
///
/// # Panics
///
/// Panics if the executor returns an error, if the decrypted stream is shorter than
/// 64 bytes, or if it does not match the expected vector.
pub fn trivium_test_vector_4_test<P, E>(param: P, mut executor: E)
where
    P: Into<TestParameters>,
    E: for<'a> FunctionExecutor<
        (&'a RadixCiphertext, &'a RadixCiphertext, usize),
        crate::Result<RadixCiphertext>,
    >,
{
    const KEY_HEX: &str = "0053A6F94C9FF24598EB";
    const IV_HEX: &str = "0D74DB42A91077DE45AC";
    const EXPECTED_HEX: &str = "F4CD954A717F26A7D6930830C4E7CF0819F80E03F25F342C64ADC66ABA7F8A8E6EAA49F23632AE3CD41A7BD290A0132F81C6D4043B6E397D7388F3A03B5FE358";
    const NUM_STEPS: usize = 512;

    let param = param.into();
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let client_key = RadixClientKey::from((client_key, 1));
    executor.setup(&client_key, Arc::new(server_key));

    // Expands each pair of hexadecimal digits into 8 bits, LSB of the byte first.
    let hex_to_lsb_first_bits = |hex: &str| -> Vec<u64> {
        (0..hex.len())
            .step_by(2)
            .flat_map(|start| {
                let byte = u8::from_str_radix(&hex[start..start + 2], 16).unwrap();
                (0..8).map(move |shift| u64::from((byte >> shift) & 1))
            })
            .collect()
    };

    let key_bits = hex_to_lsb_first_bits(KEY_HEX);
    let iv_bits = hex_to_lsb_first_bits(IV_HEX);

    let ct_key = encrypt_bits(&client_key, &key_bits);
    let ct_iv = encrypt_bits(&client_key, &iv_bits);

    let keystream = executor.execute((&ct_key, &ct_iv, NUM_STEPS)).unwrap();

    let clear_bits = decrypt_bits(&client_key, &keystream);
    let hex = get_hexadecimal_string_from_lsb_first_stream(&clear_bits);

    assert_eq!(EXPECTED_HEX, &hex[0..64 * 2]);
}

/// Randomized cross-check of `executor` against the clear-text [`TriviumRef`] model.
///
/// For each of 5 runs, a random 80-bit key and a random 80-bit IV are drawn, 512
/// keystream bits are produced both by the reference and by `executor` (on the
/// encrypted key and IV), and the decrypted executor output must equal the reference
/// output bit for bit.
///
/// # Panics
///
/// Panics if the executor returns an error or if the two keystreams differ in length
/// or content.
pub fn trivium_comparison_test<P, E>(param: P, mut executor: E)
where
    P: Into<TestParameters>,
    E: for<'a> FunctionExecutor<
        (&'a RadixCiphertext, &'a RadixCiphertext, usize),
        crate::Result<RadixCiphertext>,
    >,
{
    const NUM_RUNS: usize = 5;
    const NUM_STEPS: usize = 512;

    let param = param.into();
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let client_key = RadixClientKey::from((client_key, 1));
    executor.setup(&client_key, Arc::new(server_key));

    for i in 0..NUM_RUNS {
        let mut rng = rand::thread_rng();

        // Draw the key first, then the IV, one random bit per element.
        let mut plain_key: Vec<u8> = Vec::with_capacity(80);
        for _ in 0..80 {
            plain_key.push(rng.gen_range(0..=1));
        }
        let mut plain_iv: Vec<u8> = Vec::with_capacity(80);
        for _ in 0..80 {
            plain_iv.push(rng.gen_range(0..=1));
        }

        // The encryption helper takes `u64` values, the reference takes `u8`.
        let widened_key: Vec<u64> = plain_key.iter().map(|&bit| u64::from(bit)).collect();
        let widened_iv: Vec<u64> = plain_iv.iter().map(|&bit| u64::from(bit)).collect();

        let ct_key = encrypt_bits(&client_key, &widened_key);
        let ct_iv = encrypt_bits(&client_key, &widened_iv);

        let mut reference = TriviumRef::new(&plain_key, &plain_iv);
        let cpu_output: Vec<u8> = (0..NUM_STEPS).map(|_| reference.next()).collect();

        let encrypted_keystream = executor.execute((&ct_key, &ct_iv, NUM_STEPS)).unwrap();
        let fhe_output = decrypt_bits(&client_key, &encrypted_keystream);

        assert_eq!(cpu_output.len(), fhe_output.len());
        assert_eq!(cpu_output, fhe_output, "Mismatch at iteration {i}");
    }
}