1use poulpy_hal::{
2 layouts::{Backend, Data, FillUniform, HostDataMut, HostDataRef, VecZnx, VecZnxToBackendMut, VecZnxToBackendRef},
3 source::Source,
4};
5
6use crate::layouts::{
7 Base2K, Degree, GLWE, GLWEBackendMut, GLWEBackendRef, GLWEInfos, GLWEToBackendMut, GLWEToBackendRef, LWEInfos, Rank,
8 SetLWEInfos, TorusPrecision,
9};
10use std::fmt;
11
12#[derive(PartialEq, Eq, Clone)]
13pub struct GLWETensor<D: Data> {
14 pub(crate) data: VecZnx<D>,
15 pub(crate) base2k: Base2K,
16 pub(crate) rank: Rank,
17}
18
19pub type GLWETensorBackendRef<'a, BE> = GLWETensor<<BE as Backend>::BufRef<'a>>;
20pub type GLWETensorBackendMut<'a, BE> = GLWETensor<<BE as Backend>::BufMut<'a>>;
21
22impl<D: HostDataMut> SetLWEInfos for GLWETensor<D> {
23 fn set_base2k(&mut self, base2k: Base2K) {
24 self.base2k = base2k
25 }
26}
27
28impl<D: HostDataRef> GLWETensor<D> {
29 pub fn data(&self) -> &VecZnx<D> {
30 &self.data
31 }
32}
33
34impl<D: HostDataMut> GLWETensor<D> {
35 pub fn data_mut(&mut self) -> &mut VecZnx<D> {
36 &mut self.data
37 }
38}
39
40impl<D: Data> LWEInfos for GLWETensor<D> {
41 fn base2k(&self) -> Base2K {
42 self.base2k
43 }
44
45 fn n(&self) -> Degree {
46 Degree(self.data.n() as u32)
47 }
48
49 fn size(&self) -> usize {
50 self.data.size()
51 }
52}
53
54impl<D: Data> LWEInfos for &mut GLWETensor<D> {
55 fn base2k(&self) -> Base2K {
56 self.base2k
57 }
58
59 fn n(&self) -> Degree {
60 Degree(self.data.n() as u32)
61 }
62
63 fn size(&self) -> usize {
64 self.data.size()
65 }
66}
67
68impl<D: Data> GLWEInfos for GLWETensor<D> {
69 fn rank(&self) -> Rank {
71 self.rank
72 }
73}
74
75impl<D: Data> GLWEInfos for &mut GLWETensor<D> {
76 fn rank(&self) -> Rank {
77 self.rank
78 }
79}
80
81impl<D: HostDataRef> fmt::Debug for GLWETensor<D> {
82 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
83 write!(f, "{self}")
84 }
85}
86
87impl<D: HostDataRef> fmt::Display for GLWETensor<D> {
88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 write!(
90 f,
91 "GLWETensor: base2k={} k={}: {}",
92 self.base2k().0,
93 self.max_k().0,
94 self.data
95 )
96 }
97}
98
99impl<D: HostDataMut> FillUniform for GLWETensor<D> {
100 fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
101 self.data.fill_uniform(log_bound, source);
102 }
103}
104
105#[expect(
106 dead_code,
107 reason = "host-owned constructors are kept for serialization and host-only staging"
108)]
109impl GLWETensor<Vec<u8>> {
110 pub(crate) fn alloc_from_infos<A>(infos: &A) -> Self
111 where
112 A: GLWEInfos,
113 {
114 Self::alloc(infos.n(), infos.base2k(), infos.max_k(), infos.rank())
115 }
116
117 pub(crate) fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self {
118 let cols: usize = rank.as_usize() + 1;
119 let pairs: usize = (((cols + 1) * cols) >> 1).max(1);
120 let size: usize = k.0.div_ceil(base2k.0) as usize;
121 GLWETensor {
122 data: VecZnx::from_data(
123 poulpy_hal::layouts::HostBytesBackend::alloc_bytes(VecZnx::<Vec<u8>>::bytes_of(n.into(), pairs, size)),
124 n.into(),
125 pairs,
126 size,
127 ),
128 base2k,
129 rank,
130 }
131 }
132
133 pub fn bytes_of_from_infos<A>(infos: &A) -> usize
134 where
135 A: GLWEInfos,
136 {
137 Self::bytes_of(infos.n(), infos.base2k(), infos.max_k(), infos.rank())
138 }
139
140 pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize {
141 let cols: usize = rank.as_usize() + 1;
142 let pairs: usize = (((cols + 1) * cols) >> 1).max(1);
143 VecZnx::bytes_of(n.into(), pairs, k.0.div_ceil(base2k.0) as usize)
144 }
145}
146
147impl<BE: Backend, D: Data> GLWEToBackendRef<BE> for GLWETensor<D>
148where
149 VecZnx<D>: VecZnxToBackendRef<BE>,
150{
151 fn to_backend_ref(&self) -> GLWEBackendRef<'_, BE> {
152 GLWE {
153 base2k: self.base2k,
154 data: self.data.to_backend_ref(),
155 }
156 }
157}
158
159impl<BE: Backend, D: Data> GLWEToBackendRef<BE> for &GLWETensor<D>
160where
161 VecZnx<D>: VecZnxToBackendRef<BE>,
162{
163 fn to_backend_ref(&self) -> GLWEBackendRef<'_, BE> {
164 GLWE {
165 base2k: self.base2k,
166 data: self.data.to_backend_ref(),
167 }
168 }
169}
170
171impl<BE: Backend, D: Data> GLWEToBackendMut<BE> for GLWETensor<D>
172where
173 VecZnx<D>: VecZnxToBackendRef<BE> + VecZnxToBackendMut<BE>,
174{
175 fn to_backend_mut(&mut self) -> GLWEBackendMut<'_, BE> {
176 GLWE {
177 base2k: self.base2k,
178 data: self.data.to_backend_mut(),
179 }
180 }
181}
182
183impl<'b, BE: Backend + 'b> GLWEToBackendRef<BE> for &mut GLWETensor<BE::BufMut<'b>> {
184 fn to_backend_ref(&self) -> GLWEBackendRef<'_, BE> {
185 GLWE {
186 base2k: self.base2k,
187 data: poulpy_hal::layouts::vec_znx_backend_ref_from_mut::<BE>(&self.data),
188 }
189 }
190}
191
192impl<'b, BE: Backend + 'b> GLWEToBackendMut<BE> for &mut GLWETensor<BE::BufMut<'b>> {
193 fn to_backend_mut(&mut self) -> GLWEBackendMut<'_, BE> {
194 GLWE {
195 base2k: self.base2k,
196 data: poulpy_hal::layouts::vec_znx_backend_mut_from_mut::<BE>(&mut self.data),
197 }
198 }
199}