//! ntrulp 0.2.4
//!
//! Pure implementation of high-security prime-degree large-Galois-group inert-modulus
//! ideal-lattice-based cryptography.
//!
//! Documentation
use rand::Rng;

use crate::{
    encode::shuffle::{shuffle_array, unshuffle_array},
    params::params::{DIFFICULT, P, R3_BYTES, W},
    rng::random_sign,
};

use super::error::CompressError;

/// Number of ternary digits (trits) stored per byte by [`convert_to_ternary`] and
/// [`convert_to_decimal`]; `3^6 = 729`, so six trits cover every `u8` value.
pub const BITS_SIZE: usize = 6;
/// Width in bytes of a native `usize`; used for the size-table fields written by
/// [`pack_bytes`] and read back by [`unpack_bytes`].
const SYS_SIZE: usize = std::mem::size_of::<usize>();

/// Decodes a byte slice into native-endian `usize` values, `SYS_SIZE` bytes each.
///
/// Trailing bytes that do not form a complete `usize` are ignored.
fn byte_to_usize_vec(list: &[u8]) -> Vec<usize> {
    list.chunks_exact(SYS_SIZE)
        .map(|word| {
            let mut raw = [0u8; SYS_SIZE];
            raw.copy_from_slice(word);
            usize::from_ne_bytes(raw)
        })
        .collect()
}

/// Appends a size table and a seed to `bytes` and returns the combined buffer.
///
/// Layout of the result, front to back (all integers native-endian):
///
/// ```text
/// [ bytes | size[0] .. size[n-1] (n * SYS_SIZE) | n * SYS_SIZE (SYS_SIZE) | seed (8) ]
/// ```
///
/// [`unpack_bytes`] splits such a buffer back apart.
pub fn pack_bytes(mut bytes: Vec<u8>, size: Vec<usize>, seed: u64) -> Vec<u8> {
    let table_len = SYS_SIZE * size.len();

    // Reserve room for the size table, its length field and the seed in one go.
    bytes.reserve(table_len + SYS_SIZE + std::mem::size_of::<u64>());

    bytes.extend(size.iter().flat_map(|entry| entry.to_ne_bytes()));
    bytes.extend(table_len.to_ne_bytes());
    bytes.extend(seed.to_ne_bytes());

    bytes
}

/// Splits a buffer produced by [`pack_bytes`] back into its payload, size table and seed.
///
/// Expected layout, front to back (all integers native-endian):
///
/// ```text
/// [ payload | size table (n * SYS_SIZE) | size-table byte length (SYS_SIZE) | seed (8) ]
/// ```
///
/// # Errors
///
/// * [`CompressError::SeedSliceError`] if the buffer is too short to hold the 8-byte seed.
/// * [`CompressError::SizeSliceError`] if the buffer is too short to also hold the
///   size-table length field.
/// * [`CompressError::ByteslengthError`] if the recorded size-table length is zero, is not a
///   whole number of `usize` entries, does not fit in the buffer, or fails the `R3_BYTES`
///   ratio check.
pub fn unpack_bytes(bytes: &[u8]) -> Result<(Vec<u8>, Vec<usize>, u64), CompressError> {
    // `pack_bytes` always writes the seed as a `u64`, so the trailer is `SYS_SIZE + 8` bytes.
    // (The previous `2 * SYS_SIZE` was only correct on 64-bit targets.)
    const SEED_SIZE: usize = std::mem::size_of::<u64>();
    const TRAILER_SIZE: usize = SYS_SIZE + SEED_SIZE;

    let bytes_len = bytes.len();

    // Checked subtraction: short or truncated input must be reported, not panic.
    let seed_start = bytes_len
        .checked_sub(SEED_SIZE)
        .ok_or(CompressError::SeedSliceError)?;
    let trailer_start = bytes_len
        .checked_sub(TRAILER_SIZE)
        .ok_or(CompressError::SizeSliceError)?;

    let seed_bytes: [u8; SEED_SIZE] = bytes[seed_start..]
        .try_into()
        .map_err(|_| CompressError::SeedSliceError)?;
    let size_len_bytes: [u8; SYS_SIZE] = bytes[trailer_start..seed_start]
        .try_into()
        .map_err(|_| CompressError::SizeSliceError)?;

    let size_len = usize::from_ne_bytes(size_len_bytes);
    let seed = u64::from_ne_bytes(seed_bytes);

    // `size_len == 0` would divide by zero in the ratio check, and a length that is not a
    // whole number of `usize` entries cannot have come from `pack_bytes`.
    if size_len == 0
        || size_len % SYS_SIZE != 0
        || bytes_len < size_len
        || (bytes_len / size_len) < R3_BYTES
    {
        return Err(CompressError::ByteslengthError);
    }

    // The size table must fit between the payload and the trailer.
    let size_start = trailer_start
        .checked_sub(size_len)
        .ok_or(CompressError::ByteslengthError)?;

    let size = byte_to_usize_vec(&bytes[size_start..trailer_start]);
    let bytes_data = &bytes[..size_start];

    Ok((bytes_data.to_vec(), size, seed))
}

/// Expands a byte into `BITS_SIZE` base-3 digits, most significant digit first.
///
/// Digit values are stored as `0 -> 0`, `1 -> 1`, `2 -> -1`.
pub fn convert_to_ternary(num: u8) -> [i8; BITS_SIZE] {
    let mut trits = [0i8; BITS_SIZE];
    let mut rest = num;

    // Fill from the back so the least significant digit ends up last.
    for slot in trits.iter_mut().rev() {
        *slot = match rest % 3 {
            0 => 0,
            1 => 1,
            _ => -1,
        };
        rest /= 3;
    }

    trits
}

/// Collapses `BITS_SIZE` base-3 digits (most significant first) back into a byte.
///
/// Digits are read as `0 -> 0`, `1 -> 1`, `-1 -> 2`; this inverts [`convert_to_ternary`].
/// Six digits can encode up to `728`, so inputs whose value exceeds `255` are reduced to
/// their low byte by the final `as u8` cast.
///
/// # Panics
///
/// Panics if any digit is not `-1`, `0` or `1`.
pub fn convert_to_decimal(ternary: [i8; BITS_SIZE]) -> u8 {
    let value = ternary.iter().fold(0u16, |acc, &trit| {
        let digit: u16 = match trit {
            0 => 0,
            1 => 1,
            -1 => 2,
            _ => unreachable!(),
        };
        acc * 3 + digit
    });

    value as u8
}

/// Packs a ternary sequence into bytes, `BITS_SIZE` digits per byte.
///
/// The input is cut into consecutive groups of `BITS_SIZE` digits. A short final group is
/// padded with trailing zeros. Each group becomes one byte via [`convert_to_decimal`].
pub fn r3_encode_chunks(r3: &[i8]) -> Vec<u8> {
    r3.chunks(BITS_SIZE)
        .map(|group| {
            let mut trits = [0i8; BITS_SIZE];
            trits[..group.len()].copy_from_slice(group);
            convert_to_decimal(trits)
        })
        .collect()
}

/// Expands every byte into `BITS_SIZE` ternary digits via [`convert_to_ternary`] and
/// concatenates the results in input order.
pub fn r3_decode_chunks(bytes: &[u8]) -> Vec<i8> {
    let mut trits = Vec::with_capacity(BITS_SIZE * bytes.len());
    bytes
        .iter()
        .for_each(|&byte| trits.extend(convert_to_ternary(byte)));
    trits
}

/// Reassembles a ternary message from chunks produced by [`r3_split_w_chunks`].
///
/// Each chunk `i` is un-permuted with `unshuffle_array`, using seed `seed + i`. The addition
/// wraps, matching the seeds used when splitting. The first `size[i]` coefficients of the
/// chunk are then appended to the output; the rest of the chunk is discarded.
///
/// # Panics
///
/// Panics if `size` has fewer entries than `chunks`, or if any `size[i]` (for a chunk that
/// exists) exceeds `P`.
pub fn r3_merge_w_chunks(chunks: &[[i8; P]], size: &[usize], seed: u64) -> Vec<i8> {
    // Each chunk contributes at most `P` values. Bounding the capacity this way avoids a
    // huge (or overflowing) pre-allocation when `size` comes from a malformed buffer.
    let total_size: usize = size
        .iter()
        .take(chunks.len())
        .map(|&point| point.min(P))
        .sum();
    let mut out = Vec::with_capacity(total_size);

    for (index, chunk) in chunks.iter().enumerate() {
        // Wrapping: the seed is opaque, and a plain `+` panics in debug builds near u64::MAX.
        let chunk_seed = seed.wrapping_add(index as u64);
        let point = size[index];
        let mut part: [i8; P] = *chunk;

        unshuffle_array::<i8>(&mut part, chunk_seed);
        out.extend_from_slice(&part[..point]);
    }

    out
}

/// Splits a ternary message into chunks of length `P` that each have exactly `W` non-zero
/// coefficients.
///
/// For every chunk, message coefficients are copied in order until one of these holds:
///
/// * the chunk holds `W - DIFFICULT` non-zero values;
/// * the message runs out;
/// * taking another coefficient would leave too little room for the padding still needed
///   to reach weight `W`.
///
/// The number of copied coefficients is recorded in `size`. The chunk is then padded with
/// values from `random_sign` up to weight `W` and permuted by `shuffle_array` with a
/// per-chunk seed (`origin_seed + chunk_index`, wrapping).
///
/// Returns `(chunks, size, origin_seed)`; [`r3_merge_w_chunks`] inverts the split.
///
/// `input` is expected to contain only `-1`, `0` and `1` (as produced by
/// [`r3_decode_chunks`]).
pub fn r3_split_w_chunks<R: Rng>(input: &[i8], rng: &mut R) -> (Vec<[i8; P]>, Vec<usize>, u64) {
    const LIMIT: usize = W - DIFFICULT;
    // The capacity check below needs at least one free slot beyond the required weight.
    const _: () = assert!(W < P, "chunk weight W must be smaller than chunk length P");

    // Wrapping arithmetic: the seed only has to match `r3_merge_w_chunks`. A plain `-` or
    // `+= 1` would panic in debug builds near 0 or u64::MAX.
    let origin_seed: u64 = rng.next_u64().wrapping_sub((input.len() / P) as u64);
    let mut seed = origin_seed;
    let mut chunks: Vec<[i8; P]> = Vec::new();
    let mut size: Vec<usize> = Vec::new();
    let mut part = [0i8; P];

    let mut sum: usize = 0;
    let mut input_ptr: usize = 0;
    let mut part_ptr: usize = 0;

    while input_ptr != input.len() {
        while sum < LIMIT {
            let value = match input.get(input_ptr) {
                Some(v) => *v,
                None => break,
            };
            let weight = value.unsigned_abs() as usize;

            // Slots needed after taking `value`: the data so far plus the padding still
            // required to reach weight `W`. Zeros take a slot without adding weight, so a
            // long run of zeros used to overflow `part`. Stop the chunk here instead.
            let needed = part_ptr + 1 + W.saturating_sub(sum + weight);
            if needed > P {
                break;
            }

            sum += weight;
            input_ptr += 1;
            part[part_ptr] = value;
            part_ptr += 1;
        }

        size.push(part_ptr);

        while sum < W {
            part[part_ptr] = random_sign(rng);
            sum += 1;
            part_ptr += 1;
        }

        shuffle_array(&mut part, seed);
        chunks.push(part);

        part = [0i8; P];
        seed = seed.wrapping_add(1);
        part_ptr = 0;
        sum = 0;
    }

    (chunks, size, origin_seed)
}

#[cfg(test)]
mod r3_compressro_test {
    use rand::{Rng, RngExt, SeedableRng};
    use rand_chacha::ChaCha20Rng;

    use crate::params::params1277::RQ_BYTES;

    use super::*;

    /// Test helper: inverse of `byte_to_usize_vec`.
    fn usize_vec_to_bytes(list: &[usize]) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(list.len() * SYS_SIZE);
        for &x in list {
            bytes.extend_from_slice(&x.to_ne_bytes());
        }
        bytes
    }

    #[test]
    fn pack_unpack_bytes() {
        let mut rng = ChaCha20Rng::from_rng(&mut rand::rng());
        let bytes: Vec<u8> = (0..1000).map(|_| rng.random::<u8>()).collect();
        let unlimited_poly = r3_decode_chunks(&bytes);
        let (chunks, size, seed) = r3_split_w_chunks(&unlimited_poly, &mut rng);
        let mut bytes: Vec<u8> = Vec::with_capacity(P * size.len());

        for _ in chunks {
            let mut rq_bytes: [u8; RQ_BYTES] = [0u8; RQ_BYTES];
            rng.fill(&mut rq_bytes);
            bytes.extend(rq_bytes);
        }

        let packed = pack_bytes(bytes.clone(), size.clone(), seed);
        let unpack_bytes = unpack_bytes(&packed).unwrap();

        assert_eq!(unpack_bytes.0, bytes);
        assert_eq!(unpack_bytes.1, size);
        assert_eq!(unpack_bytes.2, seed);
    }

    #[test]
    fn test_u64_convert() {
        let mut rng = ChaCha20Rng::from_rng(&mut rand::rng());
        let usize_list: Vec<usize> = (0..1024).map(|_| rng.random::<usize>()).collect();
        let bytes = usize_vec_to_bytes(&usize_list);
        let out = byte_to_usize_vec(&bytes);

        assert_eq!(out, usize_list);
    }

    #[test]
    fn test_bit_convert() {
        // Inclusive range: `0..u8::MAX` never exercised 255.
        for n in 0..=u8::MAX {
            let bits = convert_to_ternary(n);
            let out = convert_to_decimal(bits);
            let bits0 = convert_to_ternary(out);

            assert_eq!(n, out);
            assert_eq!(bits0, bits);
        }
    }

    #[test]
    fn test_r3_encode_decode_chunks() {
        let mut rng = ChaCha20Rng::from_rng(&mut rand::rng());

        for _ in 0..10 {
            let bytes: Vec<u8> = (0..1000).map(|_| rng.random::<u8>()).collect();

            let r3 = r3_decode_chunks(&bytes);
            let out = r3_encode_chunks(&r3);

            assert_eq!(out, bytes);
        }
    }

    #[test]
    fn test_encode_decode_bytes_by_chunks_spliter_merge() {
        let mut rng = rand::rng();

        for _ in 0..100 {
            let rand_len = rng.random_range(5..1000);
            let bytes: Vec<u8> = (0..rand_len).map(|_| rng.random::<u8>()).collect();
            let r3 = r3_decode_chunks(&bytes);
            let (chunks, size, seed) = r3_split_w_chunks(&r3, &mut rng);
            let merged = r3_merge_w_chunks(&chunks, &size, seed);

            let mut r3_sum = 0usize;
            for el in &r3 {
                r3_sum += el.unsigned_abs() as usize;
            }

            let mut m_sum = 0usize;

            for el in &merged {
                m_sum += el.unsigned_abs() as usize;
            }

            assert_eq!(r3_sum, m_sum);
            assert_eq!(size.len(), chunks.len());
            assert_eq!(merged.len(), r3.len());
            assert_eq!(merged, r3);
        }
    }

    #[test]
    fn test_spliter() {
        let mut rng = rand::rng();

        for _ in 0..10 {
            let rand_len = rng.random_range(5..1000);
            let bytes: Vec<u8> = (0..rand_len).map(|_| rng.random::<u8>()).collect();
            let r3 = r3_decode_chunks(&bytes);
            let (chunks, size, _) = r3_split_w_chunks(&r3, &mut rng);

            for (chunk, index) in chunks.iter().zip(size) {
                let sum = chunk.iter().map(|&x| x.abs() as i32).sum::<i32>();

                assert_eq!(sum as usize, W);
                assert_eq!(chunk.len(), P);
                assert!(index <= P);
            }
        }
    }
}