Skip to main content

poulpy_ckks/default/
plaintext.rs

1use anyhow::Result;
2use poulpy_core::layouts::{GLWEInfos, GLWEToBackendMut, LWEInfos};
3use poulpy_hal::{
4    api::{
5        VecZnxLshBackend, VecZnxLshTmpBytes, VecZnxRshAddCoeffIntoBackend, VecZnxRshAddIntoBackend, VecZnxRshBackend,
6        VecZnxRshSubBackend, VecZnxRshSubCoeffIntoBackend, VecZnxRshTmpBytes,
7    },
8    layouts::{Backend, ScratchArena},
9};
10
11use crate::GLWEToBackendRef;
12
13use crate::{
14    CKKSInfos, CKKSMeta, SetCKKSInfos, ensure_base2k_match, ensure_plaintext_alignment, ensure_plaintext_coeff_in_range,
15    ensure_plaintext_degree_match,
16};
17
18pub trait CKKSPlaintextDefault<BE: Backend> {
19    fn ckks_add_pt_vec_into_default<Dst, A>(&self, ct: &mut Dst, pt: &A, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
20    where
21        Self: VecZnxRshAddIntoBackend<BE>,
22        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
23        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
24    {
25        const OP: &str = "ckks_add_pt_vec";
26        ensure_base2k_match(OP, ct.base2k().as_usize(), pt.base2k().as_usize())?;
27        ensure_plaintext_degree_match(OP, ct.n().as_usize(), pt.n().as_usize())?;
28        let offset = ensure_plaintext_alignment(OP, ct.log_budget(), pt.log_delta(), pt.max_k().as_usize())?;
29        let base2k = ct.base2k().as_usize();
30        let mut ct_ref = GLWEToBackendMut::to_backend_mut(ct);
31        let pt_ref = GLWEToBackendRef::to_backend_ref(pt);
32        self.vec_znx_rsh_add_into_backend(base2k, offset, ct_ref.data_mut(), 0, pt_ref.data(), 0, scratch);
33        Ok(())
34    }
35
36    fn ckks_add_pt_const_into_default<Dst, A>(
37        &self,
38        ct: &mut Dst,
39        coeff_ct: usize,
40        pt: &A,
41        coeff_pt: usize,
42        scratch: &mut ScratchArena<'_, BE>,
43    ) -> Result<()>
44    where
45        Self: VecZnxRshAddCoeffIntoBackend<BE>,
46        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
47        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
48    {
49        const OP: &str = "ckks_add_pt_const";
50        ensure_base2k_match(OP, ct.base2k().as_usize(), pt.base2k().as_usize())?;
51        ensure_plaintext_coeff_in_range(OP, "ciphertext", coeff_ct, ct.n().as_usize())?;
52        ensure_plaintext_coeff_in_range(OP, "plaintext", coeff_pt, pt.n().as_usize())?;
53        let offset = ensure_plaintext_alignment(OP, ct.log_budget(), pt.log_delta(), pt.max_k().as_usize())?;
54        let base2k = ct.base2k().as_usize();
55        let mut ct_ref = GLWEToBackendMut::to_backend_mut(ct);
56        let pt_ref = GLWEToBackendRef::to_backend_ref(pt);
57        self.vec_znx_rsh_add_coeff_into_backend(
58            base2k,
59            offset,
60            ct_ref.data_mut(),
61            0,
62            pt_ref.data(),
63            0,
64            coeff_pt,
65            coeff_ct,
66            scratch,
67        );
68
69        Ok(())
70    }
71
72    fn ckks_sub_pt_const_into_default<Dst, A>(
73        &self,
74        ct: &mut Dst,
75        coeff_ct: usize,
76        pt: &A,
77        coeff_pt: usize,
78        scratch: &mut ScratchArena<'_, BE>,
79    ) -> Result<()>
80    where
81        Self: VecZnxRshSubCoeffIntoBackend<BE>,
82        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
83        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
84    {
85        const OP: &str = "ckks_sub_pt_const";
86        ensure_base2k_match(OP, ct.base2k().as_usize(), pt.base2k().as_usize())?;
87        ensure_plaintext_coeff_in_range(OP, "ciphertext", coeff_ct, ct.n().as_usize())?;
88        ensure_plaintext_coeff_in_range(OP, "plaintext", coeff_pt, pt.n().as_usize())?;
89        let offset = ensure_plaintext_alignment(OP, ct.log_budget(), pt.log_delta(), pt.max_k().as_usize())?;
90        let base2k = ct.base2k().as_usize();
91        let mut ct_ref = GLWEToBackendMut::to_backend_mut(ct);
92        let pt_ref = GLWEToBackendRef::to_backend_ref(pt);
93        self.vec_znx_rsh_sub_coeff_into_backend(
94            base2k,
95            offset,
96            ct_ref.data_mut(),
97            0,
98            pt_ref.data(),
99            0,
100            coeff_pt,
101            coeff_ct,
102            scratch,
103        );
104
105        Ok(())
106    }
107
108    fn ckks_sub_pt_vec_into_default<Dst, A>(&self, ct: &mut Dst, pt: &A, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
109    where
110        Self: VecZnxRshSubBackend<BE>,
111        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
112        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
113    {
114        const OP: &str = "ckks_sub_pt_vec";
115        ensure_base2k_match(OP, ct.base2k().as_usize(), pt.base2k().as_usize())?;
116        ensure_plaintext_degree_match(OP, ct.n().as_usize(), pt.n().as_usize())?;
117        let offset = ensure_plaintext_alignment(OP, ct.log_budget(), pt.log_delta(), pt.max_k().as_usize())?;
118        let base2k = ct.base2k().as_usize();
119        let mut ct_ref = GLWEToBackendMut::to_backend_mut(ct);
120        let pt_ref = GLWEToBackendRef::to_backend_ref(pt);
121        self.vec_znx_rsh_sub_backend(base2k, offset, ct_ref.data_mut(), 0, pt_ref.data(), 0, scratch);
122        Ok(())
123    }
124
125    fn ckks_extract_pt_tmp_bytes_default(&self) -> usize
126    where
127        Self: VecZnxLshTmpBytes + VecZnxRshTmpBytes,
128    {
129        self.vec_znx_rsh_tmp_bytes().max(self.vec_znx_lsh_tmp_bytes())
130    }
131
132    fn ckks_extract_pt_default<D, S>(&self, dst: &mut D, src: &S, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
133    where
134        D: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos + SetCKKSInfos,
135        S: GLWEToBackendRef<BE> + GLWEInfos + LWEInfos + CKKSInfos,
136        Self: VecZnxLshBackend<BE> + VecZnxRshBackend<BE>,
137    {
138        self.ckks_extract_pt_with_meta_default(dst, src, src.meta(), scratch)
139    }
140
141    fn ckks_extract_pt_with_meta_default<D, S>(
142        &self,
143        dst: &mut D,
144        src: &S,
145        src_meta: CKKSMeta,
146        scratch: &mut ScratchArena<'_, BE>,
147    ) -> Result<()>
148    where
149        D: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos + SetCKKSInfos,
150        S: GLWEToBackendRef<BE> + GLWEInfos + LWEInfos,
151        Self: VecZnxLshBackend<BE> + VecZnxRshBackend<BE>,
152    {
153        ensure_base2k_match("ckks_extract_pt", src.base2k().as_usize(), dst.base2k().as_usize())?;
154        let available = src_meta.log_budget() + dst.log_delta();
155        if available < dst.effective_k() {
156            return Err(crate::CKKSCompositionError::PlaintextAlignmentImpossible {
157                op: "ckks_extract_pt",
158                ct_log_budget: src_meta.log_budget(),
159                pt_log_delta: dst.log_delta(),
160                pt_k: dst.effective_k(),
161            }
162            .into());
163        }
164        let dst_k = dst.max_k().as_usize();
165        let dst_base2k: usize = dst.base2k().into();
166        let mut dst_ref = GLWEToBackendMut::to_backend_mut(dst);
167        let src_ref = GLWEToBackendRef::to_backend_ref(src);
168
169        if available < dst_k {
170            self.vec_znx_rsh_backend(
171                dst_base2k,
172                dst_k - available,
173                dst_ref.data_mut(),
174                0,
175                src_ref.data(),
176                0,
177                scratch,
178            );
179        } else if available > dst_k {
180            self.vec_znx_lsh_backend(
181                dst_base2k,
182                available - dst_k,
183                dst_ref.data_mut(),
184                0,
185                src_ref.data(),
186                0,
187                scratch,
188            );
189        } else {
190            self.vec_znx_rsh_backend(dst_base2k, 0, dst_ref.data_mut(), 0, src_ref.data(), 0, scratch);
191        }
192        Ok(())
193    }
194}