poulpy-core 0.6.0

A backend-agnostic crate implementing Module-LWE-based encryption and arithmetic
Documentation
use std::{marker::PhantomData, ptr::NonNull};

use poulpy_hal::{
    api::ScratchArenaTakeBasic,
    layouts::{Backend, DataViewMut, Host, Module, ScratchOwned, VecZnx},
    oep::HalModuleImpl,
};

use crate::{
    layouts::{Base2K, Degree, Dnum, Dsize, GGSWLayout, GLWELayout, Rank, TorusPrecision},
    scratch::ScratchArenaTakeCore,
};

#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
struct TestBackend;

impl Backend for TestBackend {
    type ScalarBig = i64;
    type ScalarPrep = f64;
    type OwnedBuf = Vec<u8>;
    type BufRef<'a> = &'a [u8];
    type BufMut<'a> = &'a mut [u8];
    type Handle = ();
    type Location = Host;

    fn alloc_bytes(len: usize) -> Self::OwnedBuf {
        vec![0; len]
    }

    fn from_host_bytes(bytes: &[u8]) -> Self::OwnedBuf {
        bytes.to_vec()
    }

    fn from_bytes(bytes: Vec<u8>) -> Self::OwnedBuf {
        bytes
    }

    fn to_host_bytes(buf: &Self::OwnedBuf) -> Vec<u8> {
        buf.clone()
    }

    fn copy_to_host(buf: &Self::OwnedBuf, dst: &mut [u8]) {
        dst.copy_from_slice(buf);
    }

    fn copy_from_host(buf: &mut Self::OwnedBuf, src: &[u8]) {
        buf.copy_from_slice(src);
    }

    fn len_bytes(buf: &Self::OwnedBuf) -> usize {
        buf.len()
    }

    fn view(buf: &Self::OwnedBuf) -> Self::BufRef<'_> {
        buf.as_slice()
    }

    fn view_ref<'a, 'b>(buf: &'a Self::BufRef<'b>) -> Self::BufRef<'a>
    where
        Self: 'b,
    {
        buf
    }

    fn view_ref_mut<'a, 'b>(buf: &'a Self::BufMut<'b>) -> Self::BufRef<'a>
    where
        Self: 'b,
    {
        buf
    }

    fn view_mut_ref<'a, 'b>(buf: &'a mut Self::BufMut<'b>) -> Self::BufMut<'a>
    where
        Self: 'b,
    {
        buf
    }

    fn view_mut(buf: &mut Self::OwnedBuf) -> Self::BufMut<'_> {
        buf.as_mut_slice()
    }

    fn region(buf: &Self::OwnedBuf, offset: usize, len: usize) -> Self::BufRef<'_> {
        &buf[offset..offset + len]
    }

    fn region_mut(buf: &mut Self::OwnedBuf, offset: usize, len: usize) -> Self::BufMut<'_> {
        &mut buf[offset..offset + len]
    }

    fn region_ref<'a, 'b>(buf: &'a Self::BufRef<'b>, offset: usize, len: usize) -> Self::BufRef<'a>
    where
        Self: 'b,
    {
        &buf[offset..offset + len]
    }

    fn region_ref_mut<'a, 'b>(buf: &'a Self::BufMut<'b>, offset: usize, len: usize) -> Self::BufRef<'a>
    where
        Self: 'b,
    {
        &buf[offset..offset + len]
    }

    fn region_mut_ref<'a, 'b>(buf: &'a mut Self::BufMut<'b>, offset: usize, len: usize) -> Self::BufMut<'a>
    where
        Self: 'b,
    {
        &mut buf[offset..offset + len]
    }

    unsafe fn destroy(_: NonNull<Self::Handle>) {}
}

unsafe impl HalModuleImpl<TestBackend> for TestBackend {
    fn new(n: u64) -> Module<TestBackend> {
        assert!(n.is_power_of_two(), "n must be a power of two, got {n}");
        unsafe { Module::from_nonnull(NonNull::dangling(), n) }
    }
}

#[test]
fn scratch_arena_take_core_returns_disjoint_backend_regions() {
    let module: Module<TestBackend> = Module::new(64);
    let glwe_infos = GLWELayout {
        n: Degree(64),
        base2k: Base2K(8),
        k: TorusPrecision(24),
        rank: Rank(2),
    };
    let ggsw_infos = GGSWLayout {
        n: Degree(64),
        base2k: Base2K(8),
        k: TorusPrecision(24),
        rank: Rank(1),
        dnum: Dnum(2),
        dsize: Dsize(1),
    };

    let mut scratch: ScratchOwned<TestBackend> = ScratchOwned {
        data: TestBackend::alloc_bytes(1 << 15),
        _phantom: PhantomData,
    };
    let arena = scratch.arena();
    let available_before = arena.available();

    let (mut glwe, arena) = arena.take_glwe_scratch(&glwe_infos);
    let glwe_bytes = glwe.data_mut().data_mut();
    glwe_bytes[0] = 0x5a;
    let glwe_ptr = glwe_bytes.as_mut_ptr() as usize;
    let glwe_len = glwe_bytes.len();

    let (mut prepared, arena) = arena.take_ggsw_prepared_scratch(&module, &ggsw_infos);
    let prepared_bytes = prepared.data.data_mut();
    prepared_bytes[0] = 0xa5;
    let prepared_ptr = prepared_bytes.as_mut_ptr() as usize;
    let prepared_len = prepared_bytes.len();

    assert!(arena.available() < available_before);
    assert!(glwe_ptr + glwe_len <= prepared_ptr || prepared_ptr + prepared_len <= glwe_ptr);
}

#[test]
fn scratch_arena_split_yields_independent_chunks() {
    let mut scratch: ScratchOwned<TestBackend> = ScratchOwned {
        data: TestBackend::alloc_bytes(1 << 12),
        _phantom: PhantomData,
    };
    let arena = scratch.arena();
    let chunk_len = VecZnx::bytes_of(64, 2, 2);
    let (chunks, rem) = arena.split(2, chunk_len);

    assert!(rem.available() < (1 << 12));
    assert_eq!(chunks.len(), 2);

    let mut it = chunks.into_iter();
    let (mut lhs, _) = it.next().expect("missing first scratch chunk").take_vec_znx_scratch(64, 2, 2);
    let (mut rhs, _) = it
        .next()
        .expect("missing second scratch chunk")
        .take_vec_znx_scratch(64, 2, 2);

    let lhs_bytes = lhs.data_mut();
    let rhs_bytes = rhs.data_mut();
    lhs_bytes[0] = 1;
    rhs_bytes[0] = 2;

    let lhs_ptr = lhs_bytes.as_mut_ptr() as usize;
    let rhs_ptr = rhs_bytes.as_mut_ptr() as usize;
    assert!(lhs_ptr + lhs_bytes.len() <= rhs_ptr || rhs_ptr + rhs_bytes.len() <= lhs_ptr);
}