poulpy_hal/api/vmp_pmat.rs

1use crate::layouts::{
2    Backend, MatZnxToRef, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef,
3};
4
5pub trait VmpPMatAlloc<B: Backend> {
6    fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B>;
7}
8
/// Byte-size query for a prepared vector-matrix-product matrix layout.
pub trait VmpPMatBytesOf {
    /// Returns how many bytes a prepared matrix with the given `rows`,
    /// `cols_in`, `cols_out` and `size` occupies.
    fn bytes_of_vmp_pmat(
        &self,
        rows: usize,
        cols_in: usize,
        cols_out: usize,
        size: usize,
    ) -> usize;
}
12
13pub trait VmpPMatFromBytes<B: Backend> {
14    fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> VmpPMatOwned<B>;
15}
16
/// Temporary-space requirement for preparing a matrix.
pub trait VmpPrepareTmpBytes {
    /// Returns the number of temporary bytes needed for a matrix with the
    /// given `rows`, `cols_in`, `cols_out` and `size`.
    // NOTE(review): by naming, this presumably sizes the `scratch` argument of
    // `VmpPrepare::vmp_prepare` — confirm against backend implementations.
    fn vmp_prepare_tmp_bytes(
        &self,
        rows: usize,
        cols_in: usize,
        cols_out: usize,
        size: usize,
    ) -> usize;
}
20
21pub trait VmpPrepare<B: Backend> {
22    fn vmp_prepare<R, A>(&self, pmat: &mut R, mat: &A, scratch: &mut Scratch<B>)
23    where
24        R: VmpPMatToMut<B>,
25        A: MatZnxToRef;
26}
27
// NOTE(review): including `&self` the method below has 7 inputs, which does not
// exceed clippy's default `too-many-arguments-threshold` (7); this allow is only
// needed if the project lowers that threshold.
#[allow(clippy::too_many_arguments)]
/// Temporary-space requirement for the vector-matrix product with a non-DFT input.
pub trait VmpApplyDftTmpBytes {
    /// Returns the number of temporary bytes needed for the product.
    ///
    /// # Arguments
    ///
    /// * `res_size`: size of the result.
    /// * `a_size`: size of the vector operand `a`.
    /// * `b_rows`, `b_cols_in`, `b_cols_out`, `b_size`: layout of the matrix operand `b`.
    fn vmp_apply_dft_tmp_bytes(
        &self,
        res_size: usize,
        a_size: usize,
        b_rows: usize,
        b_cols_in: usize,
        b_cols_out: usize,
        b_size: usize,
    ) -> usize;
}
40
41pub trait VmpApplyDft<B: Backend> {
42    fn vmp_apply_dft<R, A, C>(&self, res: &mut R, a: &A, pmat: &C, scratch: &mut Scratch<B>)
43    where
44        R: VecZnxDftToMut<B>,
45        A: VecZnxToRef,
46        C: VmpPMatToRef<B>;
47}
48
// NOTE(review): including `&self` the method below has 7 inputs, which does not
// exceed clippy's default `too-many-arguments-threshold` (7); this allow is only
// needed if the project lowers that threshold.
#[allow(clippy::too_many_arguments)]
/// Temporary-space requirement for the DFT-to-DFT vector-matrix product.
pub trait VmpApplyDftToDftTmpBytes {
    /// Returns the number of temporary bytes needed for the product.
    ///
    /// # Arguments
    ///
    /// * `res_size`: size of the result.
    /// * `a_size`: size of the vector operand `a`.
    /// * `b_rows`, `b_cols_in`, `b_cols_out`, `b_size`: layout of the matrix operand `b`.
    fn vmp_apply_dft_to_dft_tmp_bytes(
        &self,
        res_size: usize,
        a_size: usize,
        b_rows: usize,
        b_cols_in: usize,
        b_cols_out: usize,
        b_size: usize,
    ) -> usize;
}
61
62pub trait VmpApplyDftToDft<B: Backend> {
63    /// Applies the vector matrix product [crate::layouts::VecZnxDft] x [crate::layouts::VmpPMat].
64    ///
65    /// A vector matrix product numerically equivalent to a sum of [crate::api::SvpApply],
66    /// where each [crate::layouts::SvpPPol] is a limb of the input [crate::layouts::VecZnx] in DFT,
67    /// and each vector a [crate::layouts::VecZnxDft] (row) of the [crate::layouts::VmpPMat].
68    ///
69    /// As such, given an input [crate::layouts::VecZnx] of `i` size and a [crate::layouts::VmpPMat] of `i` rows and
70    /// `j` size, the output is a [crate::layouts::VecZnx] of `j` size.
71    ///
72    /// If there is a mismatch between the dimensions the largest valid ones are used.
73    ///
74    /// ```text
75    /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
76    ///             |h i j|
77    ///             |k l m|
78    /// ```
79    /// where each element is a [crate::layouts::VecZnxDft].
80    ///
81    /// # Arguments
82    ///
83    /// * `c`: the output of the vector matrix product, as a [crate::layouts::VecZnxDft].
84    /// * `a`: the left operand [crate::layouts::VecZnxDft] of the vector matrix product.
85    /// * `b`: the right operand [crate::layouts::VmpPMat] of the vector matrix product.
86    /// * `buf`: scratch space, the size can be obtained with [VmpApplyTmpBytes::vmp_apply_tmp_bytes].
87    fn vmp_apply_dft_to_dft<R, A, C>(&self, res: &mut R, a: &A, pmat: &C, scratch: &mut Scratch<B>)
88    where
89        R: VecZnxDftToMut<B>,
90        A: VecZnxDftToRef<B>,
91        C: VmpPMatToRef<B>;
92}
93
// NOTE(review): including `&self` the method below has 7 inputs, which does not
// exceed clippy's default `too-many-arguments-threshold` (7); this allow is only
// needed if the project lowers that threshold.
#[allow(clippy::too_many_arguments)]
/// Temporary-space requirement for the `_add` DFT-to-DFT vector-matrix product.
pub trait VmpApplyDftToDftAddTmpBytes {
    /// Returns the number of temporary bytes needed for the product.
    ///
    /// # Arguments
    ///
    /// * `res_size`: size of the result.
    /// * `a_size`: size of the vector operand `a`.
    /// * `b_rows`, `b_cols_in`, `b_cols_out`, `b_size`: layout of the matrix operand `b`.
    fn vmp_apply_dft_to_dft_add_tmp_bytes(
        &self,
        res_size: usize,
        a_size: usize,
        b_rows: usize,
        b_cols_in: usize,
        b_cols_out: usize,
        b_size: usize,
    ) -> usize;
}
106
107pub trait VmpApplyDftToDftAdd<B: Backend> {
108    fn vmp_apply_dft_to_dft_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, limb_offset: usize, scratch: &mut Scratch<B>)
109    where
110        R: VecZnxDftToMut<B>,
111        A: VecZnxDftToRef<B>,
112        C: VmpPMatToRef<B>;
113}
114
115pub trait VmpZero<B: Backend> {
116    fn vmp_zero<R>(&self, res: &mut R)
117    where
118        R: VmpPMatToMut<B>;
119}