poulpy_hal/api/
vmp_pmat.rs

1use crate::layouts::{Backend, MatZnxToRef, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef};
2
/// Backend-side constructor for an owned [crate::layouts::VmpPMat].
///
/// Implementors return a [VmpPMatOwned] bound to backend `B`.
3pub trait VmpPMatAlloc<B: Backend> {
    /// Builds a [VmpPMatOwned] from five dimension parameters.
    ///
    /// NOTE(review): this file does not define the dimensions. Presumably `n` is the
    /// ring degree, `rows` / `cols_in` / `cols_out` describe the matrix shape and `size`
    /// is the number of limbs per element — confirm against the layout documentation.
4    fn vmp_pmat_alloc(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B>;
5}
6
/// Size query that accompanies [VmpPMatAlloc] and [VmpPMatFromBytes].
///
/// Takes the same five dimension parameters as [VmpPMatAlloc::vmp_pmat_alloc].
7pub trait VmpPMatAllocBytes {
    /// Returns a `usize` computed from the given dimensions.
    ///
    /// NOTE(review): judging by the name, this is the number of bytes needed to back a
    /// [crate::layouts::VmpPMat] of these dimensions — confirm against a backend implementation.
8    fn vmp_pmat_alloc_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
9}
10
/// Builds an owned [crate::layouts::VmpPMat] from a caller-supplied byte buffer.
11pub trait VmpPMatFromBytes<B: Backend> {
    /// Consumes `bytes` and returns a [VmpPMatOwned] with the given dimensions.
    ///
    /// The buffer is taken by value, so the caller gives up ownership of it.
    ///
    /// NOTE(review): presumably `bytes` must be at least as long as the value returned by
    /// [VmpPMatAllocBytes::vmp_pmat_alloc_bytes] for the same dimensions, and may have to
    /// satisfy backend alignment rules — this file does not state either; confirm.
12    fn vmp_pmat_from_bytes(
13        &self,
14        n: usize,
15        rows: usize,
16        cols_in: usize,
17        cols_out: usize,
18        size: usize,
19        bytes: Vec<u8>,
20    ) -> VmpPMatOwned<B>;
21}
22
/// Size query that accompanies [VmpPrepare].
23pub trait VmpPrepareTmpBytes {
    /// Returns a `usize` computed from the given matrix dimensions.
    ///
    /// NOTE(review): by analogy with [VmpApplyTmpBytes] (which the [VmpApply] docs name as
    /// the source of the scratch size), this is presumably the number of scratch bytes
    /// [VmpPrepare::vmp_prepare] needs — confirm against a backend implementation.
24    fn vmp_prepare_tmp_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
25}
26
/// Writes a [crate::layouts::VmpPMat] from a matrix given as a `MatZnx` view.
27pub trait VmpPrepare<B: Backend> {
    /// Fills `res` from `a`, using `scratch` as working memory.
    ///
    /// # Arguments
    ///
    /// * `res`: the destination, anything that can be viewed mutably as a [crate::layouts::VmpPMat].
    /// * `a`: the source, anything that can be viewed as a `MatZnx` reference.
    /// * `scratch`: backend scratch space.
    ///
    /// NOTE(review): presumably this converts `a` into the backend-specific prepared form
    /// consumed by [VmpApply::vmp_apply], with the scratch size given by
    /// [VmpPrepareTmpBytes::vmp_prepare_tmp_bytes] — neither is stated in this file; confirm.
28    fn vmp_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<B>)
29    where
30        R: VmpPMatToMut<B>,
31        A: MatZnxToRef;
32}
33
/// Scratch-size query for [VmpApply::vmp_apply].
///
/// The [VmpApply] documentation names this method as the way to obtain the size of the
/// scratch space passed to `vmp_apply`.
// The allow below silences clippy's `too_many_arguments` lint: the method takes seven
// dimension parameters in addition to `&self`.
34#[allow(clippy::too_many_arguments)]
35pub trait VmpApplyTmpBytes {
    /// Returns the scratch size for a `vmp_apply` call with the given dimensions.
    ///
    /// NOTE(review): the parameter names suggest `res_size` and `a_size` are the sizes of
    /// the output and left operand, and the `b_*` values are the dimensions of the
    /// [crate::layouts::VmpPMat] right operand; the unit is presumably bytes — confirm.
36    fn vmp_apply_tmp_bytes(
37        &self,
38        n: usize,
39        res_size: usize,
40        a_size: usize,
41        b_rows: usize,
42        b_cols_in: usize,
43        b_cols_out: usize,
44        b_size: usize,
45    ) -> usize;
46}
47
/// Vector matrix product between a [crate::layouts::VecZnxDft] and a [crate::layouts::VmpPMat].
48pub trait VmpApply<B: Backend> {
49    /// Applies the vector matrix product [crate::layouts::VecZnxDft] x [crate::layouts::VmpPMat].
50    ///
51    /// The vector matrix product is numerically equivalent to a sum of [crate::api::SvpApply],
52    /// where each [crate::layouts::SvpPPol] is a limb of the input [crate::layouts::VecZnx] in DFT,
53    /// and each vector a [crate::layouts::VecZnxDft] (row) of the [crate::layouts::VmpPMat].
54    ///
55    /// As such, given an input [crate::layouts::VecZnxDft] of `i` size and a [crate::layouts::VmpPMat] of `i` rows and
56    /// `j` size, the output is a [crate::layouts::VecZnxDft] of `j` size.
57    ///
58    /// If there is a mismatch between the dimensions, the largest valid ones are used.
59    ///
60    /// ```text
61    /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
62    ///             |h i j|
63    ///             |k l m|
64    /// ```
65    /// where each element is a [crate::layouts::VecZnxDft].
    ///
    /// NOTE(review): in this example `d` has no matching row and does not appear in the
    /// sum, presumably illustrating the dimension-mismatch rule above.
66    ///
67    /// # Arguments
68    ///
69    /// * `res`: the output of the vector matrix product, as a [crate::layouts::VecZnxDft].
70    /// * `a`: the left operand [crate::layouts::VecZnxDft] of the vector matrix product.
71    /// * `b`: the right operand [crate::layouts::VmpPMat] of the vector matrix product.
72    /// * `scratch`: scratch space, the size can be obtained with [VmpApplyTmpBytes::vmp_apply_tmp_bytes].
73    fn vmp_apply<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
74    where
75        R: VecZnxDftToMut<B>,
76        A: VecZnxDftToRef<B>,
77        C: VmpPMatToRef<B>;
78}
79
/// Size query that accompanies [VmpApplyAdd].
///
/// The signature mirrors [VmpApplyTmpBytes::vmp_apply_tmp_bytes] parameter for parameter.
// The allow below silences clippy's `too_many_arguments` lint: the method takes seven
// dimension parameters in addition to `&self`.
80#[allow(clippy::too_many_arguments)]
81pub trait VmpApplyAddTmpBytes {
    /// Returns a `usize` computed from the operand dimensions.
    ///
    /// NOTE(review): presumably the number of scratch bytes needed by
    /// [VmpApplyAdd::vmp_apply_add] for operands of these dimensions — confirm against a
    /// backend implementation.
82    fn vmp_apply_add_tmp_bytes(
83        &self,
84        n: usize,
85        res_size: usize,
86        a_size: usize,
87        b_rows: usize,
88        b_cols_in: usize,
89        b_cols_out: usize,
90        b_size: usize,
91    ) -> usize;
92}
93
/// Variant of [VmpApply] that takes an extra `scale` argument.
///
/// The operand bounds are the same as for [VmpApply::vmp_apply].
94pub trait VmpApplyAdd<B: Backend> {
    /// Vector matrix product of `a` and `b` involving `res`, with an additional `scale`.
    ///
    /// # Arguments
    ///
    /// * `res`: the mutable [crate::layouts::VecZnxDft] operand.
    /// * `a`: the left operand, viewed as a [crate::layouts::VecZnxDft].
    /// * `b`: the right operand, viewed as a [crate::layouts::VmpPMat].
    /// * `scale`: NOTE(review): meaning not visible in this file — presumably an offset or
    ///   shift applied to the product before it is combined with `res`; confirm.
    /// * `scratch`: backend scratch space.
    ///
    /// NOTE(review): the name suggests the product is added to the existing content of
    /// `res` rather than overwriting it, and that the scratch size comes from
    /// [VmpApplyAddTmpBytes::vmp_apply_add_tmp_bytes] — confirm against a backend implementation.
95    fn vmp_apply_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
96    where
97        R: VecZnxDftToMut<B>,
98        A: VecZnxDftToRef<B>,
99        C: VmpPMatToRef<B>;
100}