Skip to main content

poulpy_hal/test_suite/
mod.rs

1//! Backend-parametric test functions.
2//!
3//! Provides fully generic test functions that can be instantiated for any
4//! backend via the [`backend_test_suite!`](crate::backend_test_suite) and
5//! [`cross_backend_test_suite!`](crate::cross_backend_test_suite) macros.
6//! Tests validate correctness against the reference implementation in
7//! [`poulpy-cpu-ref`](https://docs.rs/poulpy-cpu-ref).
8
9use crate::layouts::{
10    Backend, DataView, HostBytesBackend, HostDataRef, MatZnx, ScalarZnx, ScalarZnxBackendRef, ScalarZnxToBackendRef, VecZnx,
11    VecZnxBackendMut, VecZnxBackendRef, VecZnxToBackendMut, VecZnxToBackendRef,
12};
13
14pub mod convolution;
15pub mod serialization;
16pub mod svp;
17pub mod vec_znx;
18pub mod vec_znx_big;
19pub mod vec_znx_dft;
20pub mod vmp;
21
/// Shared configuration handed to each test function generated by
/// [`backend_test_suite!`](crate::backend_test_suite) and
/// [`cross_backend_test_suite!`](crate::cross_backend_test_suite).
///
/// Both macros evaluate their `params = ...` expression once and pass a
/// reference to the resulting value to every test, so a suite can be
/// instantiated several times with parameters suited to each backend
/// (for instance a different `base2k` for FFT64 and NTT120).
#[derive(Debug, Clone, Copy)]
pub struct TestParams {
    /// Ring degree N (polynomial degree).
    ///
    /// The test-suite macros pass this value, cast to `u64`, to `Module::new`.
    pub size: usize,
    /// Primary decomposition base: limbs are expressed in base 2^`base2k`.
    ///
    /// Individual tests derive any secondary bases from this value using
    /// fixed offsets, which keeps the relative relationships between bases
    /// unchanged.
    pub base2k: usize,
}
40
41/// Backend bound used by the generic test suites.
42///
43/// Tests upload only coefficient-domain host layouts (`ScalarZnx`, `VecZnx`,
44/// `MatZnx`) and keep all intermediate layouts backend-local.
45pub trait TestBackend: Backend {}
46
47impl<BE: Backend> TestBackend for BE {}
48
49pub fn vec_znx_backend_ref<'a, BE: Backend>(vec: &'a VecZnx<BE::OwnedBuf>) -> VecZnxBackendRef<'a, BE> {
50    <VecZnx<BE::OwnedBuf> as VecZnxToBackendRef<BE>>::to_backend_ref(vec)
51}
52
53pub fn vec_znx_backend_mut<'a, BE: Backend>(vec: &'a mut VecZnx<BE::OwnedBuf>) -> VecZnxBackendMut<'a, BE> {
54    <VecZnx<BE::OwnedBuf> as VecZnxToBackendMut<BE>>::to_backend_mut(vec)
55}
56
57pub fn scalar_znx_backend_ref<'a, BE: Backend>(scalar: &'a ScalarZnx<BE::OwnedBuf>) -> ScalarZnxBackendRef<'a, BE> {
58    <ScalarZnx<BE::OwnedBuf> as ScalarZnxToBackendRef<BE>>::to_backend_ref(scalar)
59}
60
61pub fn upload_scalar_znx<BE: Backend>(host: &ScalarZnx<impl HostDataRef>) -> ScalarZnx<BE::OwnedBuf> {
62    let shape = host.shape();
63    ScalarZnx::from_data(BE::from_host_bytes(host.data.as_ref()), shape.n(), shape.cols())
64}
65
66pub fn download_scalar_znx<BE: Backend>(backend: &ScalarZnx<BE::OwnedBuf>) -> ScalarZnx<Vec<u8>> {
67    let shape = backend.shape();
68    let host_bytes = BE::to_host_bytes(&backend.data);
69    ScalarZnx::from_data(HostBytesBackend::from_host_bytes(&host_bytes), shape.n(), shape.cols())
70}
71
72pub fn upload_vec_znx<BE: Backend>(host: &VecZnx<impl HostDataRef>) -> VecZnx<BE::OwnedBuf> {
73    let shape = host.shape();
74    VecZnx::from_data_with_max_size(
75        BE::from_host_bytes(host.data.as_ref()),
76        shape.n(),
77        shape.cols(),
78        shape.size(),
79        shape.max_size(),
80    )
81}
82
83pub fn download_vec_znx<BE: Backend>(backend: &VecZnx<BE::OwnedBuf>) -> VecZnx<Vec<u8>> {
84    let shape = backend.shape();
85    let host_bytes = BE::to_host_bytes(&backend.data);
86    VecZnx::from_data_with_max_size(
87        HostBytesBackend::from_host_bytes(&host_bytes),
88        shape.n(),
89        shape.cols(),
90        shape.size(),
91        shape.max_size(),
92    )
93}
94
95pub fn upload_mat_znx<BE: Backend>(host: &MatZnx<impl HostDataRef>) -> MatZnx<BE::OwnedBuf> {
96    let shape = host.shape();
97    MatZnx::from_data(
98        BE::from_host_bytes(host.data().as_ref()),
99        shape.n(),
100        shape.rows(),
101        shape.cols_in(),
102        shape.cols_out(),
103        shape.size(),
104    )
105}
106
107pub fn download_mat_znx<BE: Backend>(backend: &MatZnx<BE::OwnedBuf>) -> MatZnx<Vec<u8>> {
108    let shape = backend.shape();
109    let host_bytes = BE::to_host_bytes(backend.data());
110    MatZnx::from_data(
111        HostBytesBackend::from_host_bytes(&host_bytes),
112        shape.n(),
113        shape.rows(),
114        shape.cols_in(),
115        shape.cols_out(),
116        shape.size(),
117    )
118}
119
/// Generates a module of `#[test]` functions that run generic test
/// implementations against a single backend.
///
/// # Syntax
///
/// ```text
/// backend_test_suite! {
///     mod <module name>,
///     backend = <backend type>,
///     params = <expression evaluating to TestParams>,
///     tests = {
///         <test name> => <path to test implementation>,
///         ...
///     }
/// }
/// ```
///
/// # Expansion
///
/// The generated module contains:
/// - `PARAMS`, a lazily initialised static holding the value of `params`;
/// - `MODULE`, a lazily initialised `Module<backend>` created with
///   `Module::new(PARAMS.size as u64)`;
/// - one `#[test]` function per entry, which calls the given path with
///   `(&PARAMS, &MODULE)`.
///
/// Attributes written in front of an entry (e.g. `#[ignore]`) are forwarded
/// to the generated test function. A trailing comma after the last entry is
/// accepted.
///
/// # Requirements
///
/// The expansion names `once_cell::sync::Lazy` directly, so the invoking
/// crate must have `once_cell` available as a dependency.
#[macro_export]
macro_rules! backend_test_suite {
    (
        mod $modname:ident,
        backend = $backend:ty,
        params = $params:expr,
        tests = {
            $( $(#[$attr:meta])* $test_name:ident => $impl:path ),+ $(,)?
        }
    ) => {
        mod $modname {
            // `$crate` resolves to the crate defining this macro no matter how
            // (or whether) the invoking crate names that dependency.
            use $crate::{api::ModuleNew, layouts::Module, test_suite::TestParams};

            use once_cell::sync::Lazy;

            // Shared by every test in the module; each is built on first use.
            static PARAMS: Lazy<TestParams> = Lazy::new(|| $params);
            static MODULE: Lazy<Module<$backend>> =
                Lazy::new(|| Module::<$backend>::new(PARAMS.size as u64));

            $(
                $(#[$attr])*
                #[test]
                fn $test_name() {
                    ($impl)(&*PARAMS, &*MODULE);
                }
            )+
        }
    };
}
149
/// Generates a module of `#[test]` functions that run generic test
/// implementations against a reference backend and a backend under test.
///
/// # Syntax
///
/// ```text
/// cross_backend_test_suite! {
///     mod <module name>,
///     backend_ref = <reference backend type>,
///     backend_test = <backend type under test>,
///     params = <expression evaluating to TestParams>,
///     tests = {
///         <test name> => <path to test implementation>,
///         ...
///     }
/// }
/// ```
///
/// # Expansion
///
/// The generated module contains:
/// - `PARAMS`, a lazily initialised static holding the value of `params`;
/// - `MODULE_HOST`, `MODULE_REF` and `MODULE_TEST`, lazily initialised
///   modules for `HostBytesBackend`, `backend_ref` and `backend_test`
///   respectively, each created with `Module::new(PARAMS.size as u64)`;
/// - one `#[test]` function per entry, which calls the given path with
///   `(&PARAMS, &MODULE_HOST, &MODULE_REF, &MODULE_TEST)`.
///
/// Attributes written in front of an entry (e.g. `#[ignore]`) are forwarded
/// to the generated test function. A trailing comma after the last entry is
/// accepted.
///
/// # Requirements
///
/// The expansion names `once_cell::sync::Lazy` directly, so the invoking
/// crate must have `once_cell` available as a dependency.
#[macro_export]
macro_rules! cross_backend_test_suite {
    (
        mod $modname:ident,
        backend_ref = $backend_ref:ty,
        backend_test = $backend_test:ty,
        params = $params:expr,
        tests = {
            $( $(#[$attr:meta])* $test_name:ident => $impl:path ),+ $(,)?
        }
    ) => {
        mod $modname {
            // `$crate` resolves to the crate defining this macro no matter how
            // (or whether) the invoking crate names that dependency.
            use $crate::{api::ModuleNew, layouts::{HostBytesBackend, Module}, test_suite::TestParams};

            use once_cell::sync::Lazy;

            // Shared by every test in the module; each is built on first use.
            static PARAMS: Lazy<TestParams> = Lazy::new(|| $params);
            static MODULE_HOST: Lazy<Module<HostBytesBackend>> =
                Lazy::new(|| Module::<HostBytesBackend>::new(PARAMS.size as u64));
            static MODULE_REF: Lazy<Module<$backend_ref>> =
                Lazy::new(|| Module::<$backend_ref>::new(PARAMS.size as u64));
            static MODULE_TEST: Lazy<Module<$backend_test>> =
                Lazy::new(|| Module::<$backend_test>::new(PARAMS.size as u64));

            $(
                $(#[$attr])*
                #[test]
                fn $test_name() {
                    ($impl)(&*PARAMS, &*MODULE_HOST, &*MODULE_REF, &*MODULE_TEST);
                }
            )+
        }
    };
}