use crate::layouts::{
Backend, DataView, HostBytesBackend, HostDataRef, MatZnx, ScalarZnx, ScalarZnxBackendRef, ScalarZnxToBackendRef, VecZnx,
VecZnxBackendMut, VecZnxBackendRef, VecZnxToBackendMut, VecZnxToBackendRef,
};
// Public submodules of the test suite; each one is declared here and defined out of line.
pub mod convolution;
pub mod serialization;
pub mod svp;
pub mod vec_znx;
pub mod vec_znx_big;
pub mod vec_znx_dft;
pub mod vmp;
/// Parameters shared by all tests generated from a single suite invocation.
#[derive(Debug, Clone, Copy)]
pub struct TestParams {
    /// Size handed to `Module::new` (after an `as u64` cast) by the suite macros in this file.
    pub size: usize,
    /// Not read anywhere in this file; presumably the base-2^k decomposition parameter
    /// consumed by the individual tests (confirm against the submodules).
    pub base2k: usize,
}
/// Marker trait for backends accepted by this test suite.
///
/// It declares no items of its own; the blanket impl below makes every type
/// that implements `Backend` a `TestBackend` as well.
pub trait TestBackend: Backend {}

impl<B> TestBackend for B where B: Backend {}
/// Borrows `vec` as a `VecZnxBackendRef` for backend `BE`.
///
/// The trait call is fully qualified so that both the implementing type
/// (`VecZnx<BE::OwnedBuf>`) and the backend parameter of `VecZnxToBackendRef`
/// are spelled out explicitly instead of being left to inference.
pub fn vec_znx_backend_ref<'a, BE: Backend>(vec: &'a VecZnx<BE::OwnedBuf>) -> VecZnxBackendRef<'a, BE> {
<VecZnx<BE::OwnedBuf> as VecZnxToBackendRef<BE>>::to_backend_ref(vec)
}
/// Mutably borrows `vec` as a `VecZnxBackendMut` for backend `BE`.
///
/// The trait call is fully qualified so that both the implementing type
/// (`VecZnx<BE::OwnedBuf>`) and the backend parameter of `VecZnxToBackendMut`
/// are spelled out explicitly instead of being left to inference.
pub fn vec_znx_backend_mut<'a, BE: Backend>(vec: &'a mut VecZnx<BE::OwnedBuf>) -> VecZnxBackendMut<'a, BE> {
<VecZnx<BE::OwnedBuf> as VecZnxToBackendMut<BE>>::to_backend_mut(vec)
}
/// Borrows `scalar` as a `ScalarZnxBackendRef` for backend `BE`.
///
/// The trait call is fully qualified so that both the implementing type
/// (`ScalarZnx<BE::OwnedBuf>`) and the backend parameter of `ScalarZnxToBackendRef`
/// are spelled out explicitly instead of being left to inference.
pub fn scalar_znx_backend_ref<'a, BE: Backend>(scalar: &'a ScalarZnx<BE::OwnedBuf>) -> ScalarZnxBackendRef<'a, BE> {
<ScalarZnx<BE::OwnedBuf> as ScalarZnxToBackendRef<BE>>::to_backend_ref(scalar)
}
/// Builds a `ScalarZnx` backed by a `BE::OwnedBuf` from a `ScalarZnx` whose data
/// implements `HostDataRef`.
///
/// The source bytes are handed to `BE::from_host_bytes`; the resulting buffer is
/// wrapped with the same `n` and `cols` as `host`.
pub fn upload_scalar_znx<BE: Backend>(host: &ScalarZnx<impl HostDataRef>) -> ScalarZnx<BE::OwnedBuf> {
    let dims = host.shape();
    let backend_buf = BE::from_host_bytes(host.data.as_ref());
    ScalarZnx::from_data(backend_buf, dims.n(), dims.cols())
}
/// Builds a `Vec<u8>`-backed `ScalarZnx` from one stored in a `BE::OwnedBuf`.
///
/// The backend buffer is converted with `BE::to_host_bytes`, and those bytes are
/// then passed through `HostBytesBackend::from_host_bytes`; `n` and `cols` are
/// taken from `backend`.
pub fn download_scalar_znx<BE: Backend>(backend: &ScalarZnx<BE::OwnedBuf>) -> ScalarZnx<Vec<u8>> {
    let dims = backend.shape();
    let raw = BE::to_host_bytes(&backend.data);
    let host_buf = HostBytesBackend::from_host_bytes(&raw);
    ScalarZnx::from_data(host_buf, dims.n(), dims.cols())
}
/// Builds a `VecZnx` backed by a `BE::OwnedBuf` from a `VecZnx` whose data
/// implements `HostDataRef`.
///
/// The source bytes are handed to `BE::from_host_bytes`; `n`, `cols`, `size` and
/// `max_size` are carried over from `host` unchanged.
pub fn upload_vec_znx<BE: Backend>(host: &VecZnx<impl HostDataRef>) -> VecZnx<BE::OwnedBuf> {
    let dims = host.shape();
    let backend_buf = BE::from_host_bytes(host.data.as_ref());
    VecZnx::from_data_with_max_size(backend_buf, dims.n(), dims.cols(), dims.size(), dims.max_size())
}
/// Builds a `Vec<u8>`-backed `VecZnx` from one stored in a `BE::OwnedBuf`.
///
/// The backend buffer is converted with `BE::to_host_bytes`, and those bytes are
/// then passed through `HostBytesBackend::from_host_bytes`; `n`, `cols`, `size`
/// and `max_size` are carried over from `backend` unchanged.
pub fn download_vec_znx<BE: Backend>(backend: &VecZnx<BE::OwnedBuf>) -> VecZnx<Vec<u8>> {
    let dims = backend.shape();
    let raw = BE::to_host_bytes(&backend.data);
    let host_buf = HostBytesBackend::from_host_bytes(&raw);
    VecZnx::from_data_with_max_size(host_buf, dims.n(), dims.cols(), dims.size(), dims.max_size())
}
/// Builds a `MatZnx` backed by a `BE::OwnedBuf` from a `MatZnx` whose data
/// implements `HostDataRef`.
///
/// The source bytes are handed to `BE::from_host_bytes`; `n`, `rows`, `cols_in`,
/// `cols_out` and `size` are carried over from `host` unchanged.
pub fn upload_mat_znx<BE: Backend>(host: &MatZnx<impl HostDataRef>) -> MatZnx<BE::OwnedBuf> {
    let dims = host.shape();
    let backend_buf = BE::from_host_bytes(host.data().as_ref());
    MatZnx::from_data(
        backend_buf,
        dims.n(),
        dims.rows(),
        dims.cols_in(),
        dims.cols_out(),
        dims.size(),
    )
}
/// Builds a `Vec<u8>`-backed `MatZnx` from one stored in a `BE::OwnedBuf`.
///
/// The backend buffer is converted with `BE::to_host_bytes`, and those bytes are
/// then passed through `HostBytesBackend::from_host_bytes`; `n`, `rows`,
/// `cols_in`, `cols_out` and `size` are carried over from `backend` unchanged.
pub fn download_mat_znx<BE: Backend>(backend: &MatZnx<BE::OwnedBuf>) -> MatZnx<Vec<u8>> {
    let dims = backend.shape();
    let raw = BE::to_host_bytes(backend.data());
    let host_buf = HostBytesBackend::from_host_bytes(&raw);
    MatZnx::from_data(
        host_buf,
        dims.n(),
        dims.rows(),
        dims.cols_in(),
        dims.cols_out(),
        dims.size(),
    )
}
/// Generates a module of `#[test]` functions that run test routines against one backend.
///
/// The expansion defines `mod $modname` containing:
///
/// * `PARAMS`: a lazily initialised `TestParams` built from `$params`;
/// * `MODULE`: a lazily initialised `Module<$backend>` created with
///   `Module::new(PARAMS.size as u64)`;
/// * one `#[test]` function per `name => path` entry, which calls
///   `path(&*PARAMS, &*MODULE)`.
///
/// Attributes written in front of an entry are forwarded to the generated test
/// function. A trailing comma after the last entry is accepted.
///
/// The expansion names `once_cell::sync::Lazy` directly, so the invoking crate
/// must have `once_cell` among its dependencies.
///
/// ```ignore
/// backend_test_suite! {
///     mod my_backend,
///     backend = MyBackend,
///     params = TestParams { size: 1 << 8, base2k: 12 },
///     tests = {
///         some_test => path::to::test_fn,
///         #[ignore] slow_test => path::to::slow_test_fn,
///     }
/// }
/// ```
#[macro_export]
macro_rules! backend_test_suite {
    (
        mod $modname:ident,
        backend = $backend:ty,
        params = $params:expr,
        tests = {
            $( $(#[$attr:meta])* $test_name:ident => $impl:path ),+ $(,)?
        }
    ) => {
        mod $modname {
            // `$crate` resolves to the crate defining this macro, so these paths stay
            // valid when the dependency is renamed or the macro is used in-crate.
            use $crate::{api::ModuleNew, layouts::Module, test_suite::TestParams};
            use once_cell::sync::Lazy;

            // Shared by every generated test; each is initialised on first access.
            static PARAMS: Lazy<TestParams> = Lazy::new(|| $params);
            static MODULE: Lazy<Module<$backend>> =
                Lazy::new(|| Module::<$backend>::new(PARAMS.size as u64));

            $(
                $(#[$attr])*
                #[test]
                fn $test_name() {
                    ($impl)(&*PARAMS, &*MODULE);
                }
            )+
        }
    };
}
/// Generates a module of `#[test]` functions that run test routines against a
/// reference backend and a backend under test.
///
/// The expansion defines `mod $modname` containing:
///
/// * `PARAMS`: a lazily initialised `TestParams` built from `$params`;
/// * `MODULE_HOST`, `MODULE_REF`, `MODULE_TEST`: lazily initialised modules for
///   `HostBytesBackend`, `$backend_ref` and `$backend_test` respectively, each
///   created with `Module::new(PARAMS.size as u64)`;
/// * one `#[test]` function per `name => path` entry, which calls
///   `path(&*PARAMS, &*MODULE_HOST, &*MODULE_REF, &*MODULE_TEST)`.
///
/// Attributes written in front of an entry are forwarded to the generated test
/// function. A trailing comma after the last entry is accepted.
///
/// The expansion names `once_cell::sync::Lazy` directly, so the invoking crate
/// must have `once_cell` among its dependencies.
///
/// ```ignore
/// cross_backend_test_suite! {
///     mod my_cross_checks,
///     backend_ref = ReferenceBackend,
///     backend_test = BackendUnderTest,
///     params = TestParams { size: 1 << 8, base2k: 12 },
///     tests = {
///         some_test => path::to::cross_test_fn,
///     }
/// }
/// ```
#[macro_export]
macro_rules! cross_backend_test_suite {
    (
        mod $modname:ident,
        backend_ref = $backend_ref:ty,
        backend_test = $backend_test:ty,
        params = $params:expr,
        tests = {
            $( $(#[$attr:meta])* $test_name:ident => $impl:path ),+ $(,)?
        }
    ) => {
        mod $modname {
            // `$crate` resolves to the crate defining this macro, so these paths stay
            // valid when the dependency is renamed or the macro is used in-crate.
            use $crate::{api::ModuleNew, layouts::{HostBytesBackend, Module}, test_suite::TestParams};
            use once_cell::sync::Lazy;

            // Shared by every generated test; each is initialised on first access.
            static PARAMS: Lazy<TestParams> = Lazy::new(|| $params);
            static MODULE_HOST: Lazy<Module<HostBytesBackend>> =
                Lazy::new(|| Module::<HostBytesBackend>::new(PARAMS.size as u64));
            static MODULE_REF: Lazy<Module<$backend_ref>> =
                Lazy::new(|| Module::<$backend_ref>::new(PARAMS.size as u64));
            static MODULE_TEST: Lazy<Module<$backend_test>> =
                Lazy::new(|| Module::<$backend_test>::new(PARAMS.size as u64));

            $(
                $(#[$attr])*
                #[test]
                fn $test_name() {
                    ($impl)(&*PARAMS, &*MODULE_HOST, &*MODULE_REF, &*MODULE_TEST);
                }
            )+
        }
    };
}