Skip to main content

poulpy_core/layouts/
glwe_tensor.rs

1use poulpy_hal::{
2    layouts::{Backend, Data, FillUniform, HostDataMut, HostDataRef, VecZnx, VecZnxToBackendMut, VecZnxToBackendRef},
3    source::Source,
4};
5
6use crate::layouts::{
7    Base2K, Degree, GLWE, GLWEBackendMut, GLWEBackendRef, GLWEInfos, GLWEToBackendMut, GLWEToBackendRef, LWEInfos, Rank,
8    SetLWEInfos, TorusPrecision,
9};
10use std::fmt;
11
12#[derive(PartialEq, Eq, Clone)]
13pub struct GLWETensor<D: Data> {
14    pub(crate) data: VecZnx<D>,
15    pub(crate) base2k: Base2K,
16    pub(crate) rank: Rank,
17}
18
19pub type GLWETensorBackendRef<'a, BE> = GLWETensor<<BE as Backend>::BufRef<'a>>;
20pub type GLWETensorBackendMut<'a, BE> = GLWETensor<<BE as Backend>::BufMut<'a>>;
21
22impl<D: HostDataMut> SetLWEInfos for GLWETensor<D> {
23    fn set_base2k(&mut self, base2k: Base2K) {
24        self.base2k = base2k
25    }
26}
27
28impl<D: HostDataRef> GLWETensor<D> {
29    pub fn data(&self) -> &VecZnx<D> {
30        &self.data
31    }
32}
33
34impl<D: HostDataMut> GLWETensor<D> {
35    pub fn data_mut(&mut self) -> &mut VecZnx<D> {
36        &mut self.data
37    }
38}
39
40impl<D: Data> LWEInfos for GLWETensor<D> {
41    fn base2k(&self) -> Base2K {
42        self.base2k
43    }
44
45    fn n(&self) -> Degree {
46        Degree(self.data.n() as u32)
47    }
48
49    fn size(&self) -> usize {
50        self.data.size()
51    }
52}
53
54impl<D: Data> LWEInfos for &mut GLWETensor<D> {
55    fn base2k(&self) -> Base2K {
56        self.base2k
57    }
58
59    fn n(&self) -> Degree {
60        Degree(self.data.n() as u32)
61    }
62
63    fn size(&self) -> usize {
64        self.data.size()
65    }
66}
67
68impl<D: Data> GLWEInfos for GLWETensor<D> {
69    ///NOTE: self.rank() != self.to_ref().rank() if self is of type [GLWETensor]
70    fn rank(&self) -> Rank {
71        self.rank
72    }
73}
74
75impl<D: Data> GLWEInfos for &mut GLWETensor<D> {
76    fn rank(&self) -> Rank {
77        self.rank
78    }
79}
80
81impl<D: HostDataRef> fmt::Debug for GLWETensor<D> {
82    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
83        write!(f, "{self}")
84    }
85}
86
87impl<D: HostDataRef> fmt::Display for GLWETensor<D> {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        write!(
90            f,
91            "GLWETensor: base2k={} k={}: {}",
92            self.base2k().0,
93            self.max_k().0,
94            self.data
95        )
96    }
97}
98
99impl<D: HostDataMut> FillUniform for GLWETensor<D> {
100    fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
101        self.data.fill_uniform(log_bound, source);
102    }
103}
104
105#[expect(
106    dead_code,
107    reason = "host-owned constructors are kept for serialization and host-only staging"
108)]
109impl GLWETensor<Vec<u8>> {
110    pub(crate) fn alloc_from_infos<A>(infos: &A) -> Self
111    where
112        A: GLWEInfos,
113    {
114        Self::alloc(infos.n(), infos.base2k(), infos.max_k(), infos.rank())
115    }
116
117    pub(crate) fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self {
118        let cols: usize = rank.as_usize() + 1;
119        let pairs: usize = (((cols + 1) * cols) >> 1).max(1);
120        let size: usize = k.0.div_ceil(base2k.0) as usize;
121        GLWETensor {
122            data: VecZnx::from_data(
123                poulpy_hal::layouts::HostBytesBackend::alloc_bytes(VecZnx::<Vec<u8>>::bytes_of(n.into(), pairs, size)),
124                n.into(),
125                pairs,
126                size,
127            ),
128            base2k,
129            rank,
130        }
131    }
132
133    pub fn bytes_of_from_infos<A>(infos: &A) -> usize
134    where
135        A: GLWEInfos,
136    {
137        Self::bytes_of(infos.n(), infos.base2k(), infos.max_k(), infos.rank())
138    }
139
140    pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize {
141        let cols: usize = rank.as_usize() + 1;
142        let pairs: usize = (((cols + 1) * cols) >> 1).max(1);
143        VecZnx::bytes_of(n.into(), pairs, k.0.div_ceil(base2k.0) as usize)
144    }
145}
146
147impl<BE: Backend, D: Data> GLWEToBackendRef<BE> for GLWETensor<D>
148where
149    VecZnx<D>: VecZnxToBackendRef<BE>,
150{
151    fn to_backend_ref(&self) -> GLWEBackendRef<'_, BE> {
152        GLWE {
153            base2k: self.base2k,
154            data: self.data.to_backend_ref(),
155        }
156    }
157}
158
159impl<BE: Backend, D: Data> GLWEToBackendRef<BE> for &GLWETensor<D>
160where
161    VecZnx<D>: VecZnxToBackendRef<BE>,
162{
163    fn to_backend_ref(&self) -> GLWEBackendRef<'_, BE> {
164        GLWE {
165            base2k: self.base2k,
166            data: self.data.to_backend_ref(),
167        }
168    }
169}
170
171impl<BE: Backend, D: Data> GLWEToBackendMut<BE> for GLWETensor<D>
172where
173    VecZnx<D>: VecZnxToBackendRef<BE> + VecZnxToBackendMut<BE>,
174{
175    fn to_backend_mut(&mut self) -> GLWEBackendMut<'_, BE> {
176        GLWE {
177            base2k: self.base2k,
178            data: self.data.to_backend_mut(),
179        }
180    }
181}
182
183impl<'b, BE: Backend + 'b> GLWEToBackendRef<BE> for &mut GLWETensor<BE::BufMut<'b>> {
184    fn to_backend_ref(&self) -> GLWEBackendRef<'_, BE> {
185        GLWE {
186            base2k: self.base2k,
187            data: poulpy_hal::layouts::vec_znx_backend_ref_from_mut::<BE>(&self.data),
188        }
189    }
190}
191
192impl<'b, BE: Backend + 'b> GLWEToBackendMut<BE> for &mut GLWETensor<BE::BufMut<'b>> {
193    fn to_backend_mut(&mut self) -> GLWEBackendMut<'_, BE> {
194        GLWE {
195            base2k: self.base2k,
196            data: poulpy_hal::layouts::vec_znx_backend_mut_from_mut::<BE>(&mut self.data),
197        }
198    }
199}