poulpy_hal/test_suite/
mod.rs1use crate::layouts::{
10 Backend, DataView, HostBytesBackend, HostDataRef, MatZnx, ScalarZnx, ScalarZnxBackendRef, ScalarZnxToBackendRef, VecZnx,
11 VecZnxBackendMut, VecZnxBackendRef, VecZnxToBackendMut, VecZnxToBackendRef,
12};
13
14pub mod convolution;
15pub mod serialization;
16pub mod svp;
17pub mod vec_znx;
18pub mod vec_znx_big;
19pub mod vec_znx_dft;
20pub mod vmp;
21
/// Parameters shared by every test generated by the backend test-suite macros.
///
/// Deriving `PartialEq`/`Eq`/`Hash` lets test code compare parameter sets
/// directly or use them as keys (e.g. to deduplicate a parameter sweep).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TestParams {
    /// Value handed (cast to `u64`) to `Module::new` when the suite macros
    /// construct their modules.
    pub size: usize,
    /// Base-2 exponent parameter consumed by the individual test
    /// implementations — presumably the limb decomposition width `k` of
    /// base `2^k`; confirm against the test impls.
    pub base2k: usize,
}
40
41pub trait TestBackend: Backend {}
46
47impl<BE: Backend> TestBackend for BE {}
48
49pub fn vec_znx_backend_ref<'a, BE: Backend>(vec: &'a VecZnx<BE::OwnedBuf>) -> VecZnxBackendRef<'a, BE> {
50 <VecZnx<BE::OwnedBuf> as VecZnxToBackendRef<BE>>::to_backend_ref(vec)
51}
52
53pub fn vec_znx_backend_mut<'a, BE: Backend>(vec: &'a mut VecZnx<BE::OwnedBuf>) -> VecZnxBackendMut<'a, BE> {
54 <VecZnx<BE::OwnedBuf> as VecZnxToBackendMut<BE>>::to_backend_mut(vec)
55}
56
57pub fn scalar_znx_backend_ref<'a, BE: Backend>(scalar: &'a ScalarZnx<BE::OwnedBuf>) -> ScalarZnxBackendRef<'a, BE> {
58 <ScalarZnx<BE::OwnedBuf> as ScalarZnxToBackendRef<BE>>::to_backend_ref(scalar)
59}
60
61pub fn upload_scalar_znx<BE: Backend>(host: &ScalarZnx<impl HostDataRef>) -> ScalarZnx<BE::OwnedBuf> {
62 let shape = host.shape();
63 ScalarZnx::from_data(BE::from_host_bytes(host.data.as_ref()), shape.n(), shape.cols())
64}
65
66pub fn download_scalar_znx<BE: Backend>(backend: &ScalarZnx<BE::OwnedBuf>) -> ScalarZnx<Vec<u8>> {
67 let shape = backend.shape();
68 let host_bytes = BE::to_host_bytes(&backend.data);
69 ScalarZnx::from_data(HostBytesBackend::from_host_bytes(&host_bytes), shape.n(), shape.cols())
70}
71
72pub fn upload_vec_znx<BE: Backend>(host: &VecZnx<impl HostDataRef>) -> VecZnx<BE::OwnedBuf> {
73 let shape = host.shape();
74 VecZnx::from_data_with_max_size(
75 BE::from_host_bytes(host.data.as_ref()),
76 shape.n(),
77 shape.cols(),
78 shape.size(),
79 shape.max_size(),
80 )
81}
82
83pub fn download_vec_znx<BE: Backend>(backend: &VecZnx<BE::OwnedBuf>) -> VecZnx<Vec<u8>> {
84 let shape = backend.shape();
85 let host_bytes = BE::to_host_bytes(&backend.data);
86 VecZnx::from_data_with_max_size(
87 HostBytesBackend::from_host_bytes(&host_bytes),
88 shape.n(),
89 shape.cols(),
90 shape.size(),
91 shape.max_size(),
92 )
93}
94
95pub fn upload_mat_znx<BE: Backend>(host: &MatZnx<impl HostDataRef>) -> MatZnx<BE::OwnedBuf> {
96 let shape = host.shape();
97 MatZnx::from_data(
98 BE::from_host_bytes(host.data().as_ref()),
99 shape.n(),
100 shape.rows(),
101 shape.cols_in(),
102 shape.cols_out(),
103 shape.size(),
104 )
105}
106
107pub fn download_mat_znx<BE: Backend>(backend: &MatZnx<BE::OwnedBuf>) -> MatZnx<Vec<u8>> {
108 let shape = backend.shape();
109 let host_bytes = BE::to_host_bytes(backend.data());
110 MatZnx::from_data(
111 HostBytesBackend::from_host_bytes(&host_bytes),
112 shape.n(),
113 shape.rows(),
114 shape.cols_in(),
115 shape.cols_out(),
116 shape.size(),
117 )
118}
119
/// Generates `mod $modname` containing one `#[test]` function per `tests` entry.
///
/// Every generated test calls `$impl(&TestParams, &Module<$backend>)`. The
/// `TestParams` value (from `params`) and the module are built lazily, once
/// per generated module, and shared by all of its tests. The module is created
/// with `Module::<$backend>::new(params.size as u64)`. Attributes written before
/// an entry (e.g. `#[ignore]`) are forwarded to the generated test.
///
/// Items from this crate are referenced through `$crate`. The expansion
/// therefore resolves inside this crate itself and in dependents that rename
/// the crate in `Cargo.toml`. The expansion does use `once_cell::sync::Lazy`,
/// so the invoking crate must depend on `once_cell`.
///
/// ```ignore
/// backend_test_suite! {
///     mod my_backend_suite,
///     backend = MyBackend,
///     params = TestParams { size: 64, base2k: 12 },
///     tests = {
///         some_case => path::to::test_some_case,
///     }
/// }
/// ```
#[macro_export]
macro_rules! backend_test_suite {
    (
        mod $modname:ident,
        backend = $backend:ty,
        params = $params:expr,
        tests = {
            $( $(#[$attr:meta])* $test_name:ident => $impl:path ),+ $(,)?
        }
    ) => {
        mod $modname {
            // `$crate` always names the crate that defines this macro, whatever
            // name (if any) the invoking crate knows it by.
            use $crate::{api::ModuleNew, layouts::Module, test_suite::TestParams};

            use once_cell::sync::Lazy;

            static PARAMS: Lazy<TestParams> = Lazy::new(|| $params);
            static MODULE: Lazy<Module<$backend>> =
                Lazy::new(|| Module::<$backend>::new(PARAMS.size as u64));

            $(
                $(#[$attr])*
                #[test]
                fn $test_name() {
                    ($impl)(&*PARAMS, &*MODULE);
                }
            )+
        }
    };
}
149
/// Generates `mod $modname` with one `#[test]` per `tests` entry, wired to three modules.
///
/// Every generated test calls
/// `$impl(&TestParams, &Module<HostBytesBackend>, &Module<$backend_ref>, &Module<$backend_test>)`.
/// Presumably this lets results from `backend_test` be checked against
/// `backend_ref`, with the host module used for staging data (confirm against
/// the test impls). The params and all three modules are built lazily, once per
/// generated module. Each module is created with `Module::new(params.size as u64)`.
/// Attributes written before an entry (e.g. `#[ignore]`) are forwarded to the
/// generated test.
///
/// Items from this crate are referenced through `$crate`. The expansion
/// therefore resolves inside this crate itself and in dependents that rename
/// the crate in `Cargo.toml`. The expansion does use `once_cell::sync::Lazy`,
/// so the invoking crate must depend on `once_cell`.
#[macro_export]
macro_rules! cross_backend_test_suite {
    (
        mod $modname:ident,
        backend_ref = $backend_ref:ty,
        backend_test = $backend_test:ty,
        params = $params:expr,
        tests = {
            $( $(#[$attr:meta])* $test_name:ident => $impl:path ),+ $(,)?
        }
    ) => {
        mod $modname {
            // `$crate` always names the crate that defines this macro, whatever
            // name (if any) the invoking crate knows it by.
            use $crate::{
                api::ModuleNew,
                layouts::{HostBytesBackend, Module},
                test_suite::TestParams,
            };

            use once_cell::sync::Lazy;

            static PARAMS: Lazy<TestParams> = Lazy::new(|| $params);
            static MODULE_HOST: Lazy<Module<HostBytesBackend>> =
                Lazy::new(|| Module::<HostBytesBackend>::new(PARAMS.size as u64));
            static MODULE_REF: Lazy<Module<$backend_ref>> =
                Lazy::new(|| Module::<$backend_ref>::new(PARAMS.size as u64));
            static MODULE_TEST: Lazy<Module<$backend_test>> =
                Lazy::new(|| Module::<$backend_test>::new(PARAMS.size as u64));

            $(
                $(#[$attr])*
                #[test]
                fn $test_name() {
                    ($impl)(&*PARAMS, &*MODULE_HOST, &*MODULE_REF, &*MODULE_TEST);
                }
            )+
        }
    };
}