1use anyhow::Result;
2use poulpy_core::layouts::{GLWEInfos, GLWEToBackendMut, LWEInfos};
3use poulpy_hal::{
4 api::{
5 VecZnxLshBackend, VecZnxLshTmpBytes, VecZnxRshAddCoeffIntoBackend, VecZnxRshAddIntoBackend, VecZnxRshBackend,
6 VecZnxRshSubBackend, VecZnxRshSubCoeffIntoBackend, VecZnxRshTmpBytes,
7 },
8 layouts::{Backend, ScratchArena},
9};
10
11use crate::GLWEToBackendRef;
12
13use crate::{
14 CKKSInfos, CKKSMeta, SetCKKSInfos, ensure_base2k_match, ensure_plaintext_alignment, ensure_plaintext_coeff_in_range,
15 ensure_plaintext_degree_match,
16};
17
18pub trait CKKSPlaintextDefault<BE: Backend> {
19 fn ckks_add_pt_vec_into_default<Dst, A>(&self, ct: &mut Dst, pt: &A, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
20 where
21 Self: VecZnxRshAddIntoBackend<BE>,
22 Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
23 A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
24 {
25 const OP: &str = "ckks_add_pt_vec";
26 ensure_base2k_match(OP, ct.base2k().as_usize(), pt.base2k().as_usize())?;
27 ensure_plaintext_degree_match(OP, ct.n().as_usize(), pt.n().as_usize())?;
28 let offset = ensure_plaintext_alignment(OP, ct.log_budget(), pt.log_delta(), pt.max_k().as_usize())?;
29 let base2k = ct.base2k().as_usize();
30 let mut ct_ref = GLWEToBackendMut::to_backend_mut(ct);
31 let pt_ref = GLWEToBackendRef::to_backend_ref(pt);
32 self.vec_znx_rsh_add_into_backend(base2k, offset, ct_ref.data_mut(), 0, pt_ref.data(), 0, scratch);
33 Ok(())
34 }
35
36 fn ckks_add_pt_const_into_default<Dst, A>(
37 &self,
38 ct: &mut Dst,
39 coeff_ct: usize,
40 pt: &A,
41 coeff_pt: usize,
42 scratch: &mut ScratchArena<'_, BE>,
43 ) -> Result<()>
44 where
45 Self: VecZnxRshAddCoeffIntoBackend<BE>,
46 Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
47 A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
48 {
49 const OP: &str = "ckks_add_pt_const";
50 ensure_base2k_match(OP, ct.base2k().as_usize(), pt.base2k().as_usize())?;
51 ensure_plaintext_coeff_in_range(OP, "ciphertext", coeff_ct, ct.n().as_usize())?;
52 ensure_plaintext_coeff_in_range(OP, "plaintext", coeff_pt, pt.n().as_usize())?;
53 let offset = ensure_plaintext_alignment(OP, ct.log_budget(), pt.log_delta(), pt.max_k().as_usize())?;
54 let base2k = ct.base2k().as_usize();
55 let mut ct_ref = GLWEToBackendMut::to_backend_mut(ct);
56 let pt_ref = GLWEToBackendRef::to_backend_ref(pt);
57 self.vec_znx_rsh_add_coeff_into_backend(
58 base2k,
59 offset,
60 ct_ref.data_mut(),
61 0,
62 pt_ref.data(),
63 0,
64 coeff_pt,
65 coeff_ct,
66 scratch,
67 );
68
69 Ok(())
70 }
71
72 fn ckks_sub_pt_const_into_default<Dst, A>(
73 &self,
74 ct: &mut Dst,
75 coeff_ct: usize,
76 pt: &A,
77 coeff_pt: usize,
78 scratch: &mut ScratchArena<'_, BE>,
79 ) -> Result<()>
80 where
81 Self: VecZnxRshSubCoeffIntoBackend<BE>,
82 Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
83 A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
84 {
85 const OP: &str = "ckks_sub_pt_const";
86 ensure_base2k_match(OP, ct.base2k().as_usize(), pt.base2k().as_usize())?;
87 ensure_plaintext_coeff_in_range(OP, "ciphertext", coeff_ct, ct.n().as_usize())?;
88 ensure_plaintext_coeff_in_range(OP, "plaintext", coeff_pt, pt.n().as_usize())?;
89 let offset = ensure_plaintext_alignment(OP, ct.log_budget(), pt.log_delta(), pt.max_k().as_usize())?;
90 let base2k = ct.base2k().as_usize();
91 let mut ct_ref = GLWEToBackendMut::to_backend_mut(ct);
92 let pt_ref = GLWEToBackendRef::to_backend_ref(pt);
93 self.vec_znx_rsh_sub_coeff_into_backend(
94 base2k,
95 offset,
96 ct_ref.data_mut(),
97 0,
98 pt_ref.data(),
99 0,
100 coeff_pt,
101 coeff_ct,
102 scratch,
103 );
104
105 Ok(())
106 }
107
108 fn ckks_sub_pt_vec_into_default<Dst, A>(&self, ct: &mut Dst, pt: &A, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
109 where
110 Self: VecZnxRshSubBackend<BE>,
111 Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
112 A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
113 {
114 const OP: &str = "ckks_sub_pt_vec";
115 ensure_base2k_match(OP, ct.base2k().as_usize(), pt.base2k().as_usize())?;
116 ensure_plaintext_degree_match(OP, ct.n().as_usize(), pt.n().as_usize())?;
117 let offset = ensure_plaintext_alignment(OP, ct.log_budget(), pt.log_delta(), pt.max_k().as_usize())?;
118 let base2k = ct.base2k().as_usize();
119 let mut ct_ref = GLWEToBackendMut::to_backend_mut(ct);
120 let pt_ref = GLWEToBackendRef::to_backend_ref(pt);
121 self.vec_znx_rsh_sub_backend(base2k, offset, ct_ref.data_mut(), 0, pt_ref.data(), 0, scratch);
122 Ok(())
123 }
124
125 fn ckks_extract_pt_tmp_bytes_default(&self) -> usize
126 where
127 Self: VecZnxLshTmpBytes + VecZnxRshTmpBytes,
128 {
129 self.vec_znx_rsh_tmp_bytes().max(self.vec_znx_lsh_tmp_bytes())
130 }
131
132 fn ckks_extract_pt_default<D, S>(&self, dst: &mut D, src: &S, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
133 where
134 D: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos + SetCKKSInfos,
135 S: GLWEToBackendRef<BE> + GLWEInfos + LWEInfos + CKKSInfos,
136 Self: VecZnxLshBackend<BE> + VecZnxRshBackend<BE>,
137 {
138 self.ckks_extract_pt_with_meta_default(dst, src, src.meta(), scratch)
139 }
140
141 fn ckks_extract_pt_with_meta_default<D, S>(
142 &self,
143 dst: &mut D,
144 src: &S,
145 src_meta: CKKSMeta,
146 scratch: &mut ScratchArena<'_, BE>,
147 ) -> Result<()>
148 where
149 D: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos + SetCKKSInfos,
150 S: GLWEToBackendRef<BE> + GLWEInfos + LWEInfos,
151 Self: VecZnxLshBackend<BE> + VecZnxRshBackend<BE>,
152 {
153 ensure_base2k_match("ckks_extract_pt", src.base2k().as_usize(), dst.base2k().as_usize())?;
154 let available = src_meta.log_budget() + dst.log_delta();
155 if available < dst.effective_k() {
156 return Err(crate::CKKSCompositionError::PlaintextAlignmentImpossible {
157 op: "ckks_extract_pt",
158 ct_log_budget: src_meta.log_budget(),
159 pt_log_delta: dst.log_delta(),
160 pt_k: dst.effective_k(),
161 }
162 .into());
163 }
164 let dst_k = dst.max_k().as_usize();
165 let dst_base2k: usize = dst.base2k().into();
166 let mut dst_ref = GLWEToBackendMut::to_backend_mut(dst);
167 let src_ref = GLWEToBackendRef::to_backend_ref(src);
168
169 if available < dst_k {
170 self.vec_znx_rsh_backend(
171 dst_base2k,
172 dst_k - available,
173 dst_ref.data_mut(),
174 0,
175 src_ref.data(),
176 0,
177 scratch,
178 );
179 } else if available > dst_k {
180 self.vec_znx_lsh_backend(
181 dst_base2k,
182 available - dst_k,
183 dst_ref.data_mut(),
184 0,
185 src_ref.data(),
186 0,
187 scratch,
188 );
189 } else {
190 self.vec_znx_rsh_backend(dst_base2k, 0, dst_ref.data_mut(), 0, src_ref.data(), 0, scratch);
191 }
192 Ok(())
193 }
194}