poulpy_core/layouts/compressed/
gglwe_tsk.rs

1use poulpy_hal::{
2    api::{VecZnxCopy, VecZnxFillUniform},
3    layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, Reset, WriterTo},
4    source::Source,
5};
6
7use crate::layouts::{
8    GGLWETensorKey, Infos,
9    compressed::{Decompress, GGLWESwitchingKeyCompressed},
10};
11use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
12use std::fmt;
13
14#[derive(PartialEq, Eq, Clone)]
15pub struct GGLWETensorKeyCompressed<D: Data> {
16    pub(crate) keys: Vec<GGLWESwitchingKeyCompressed<D>>,
17}
18
19impl<D: DataRef> fmt::Debug for GGLWETensorKeyCompressed<D> {
20    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
21        write!(f, "{}", self)
22    }
23}
24
25impl<D: DataMut> FillUniform for GGLWETensorKeyCompressed<D> {
26    fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
27        self.keys
28            .iter_mut()
29            .for_each(|key: &mut GGLWESwitchingKeyCompressed<D>| key.fill_uniform(log_bound, source))
30    }
31}
32
33impl<D: DataMut> Reset for GGLWETensorKeyCompressed<D>
34where
35    MatZnx<D>: Reset,
36{
37    fn reset(&mut self) {
38        self.keys
39            .iter_mut()
40            .for_each(|key: &mut GGLWESwitchingKeyCompressed<D>| key.reset())
41    }
42}
43
44impl<D: DataRef> fmt::Display for GGLWETensorKeyCompressed<D> {
45    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46        writeln!(f, "(GLWETensorKeyCompressed)",)?;
47        for (i, key) in self.keys.iter().enumerate() {
48            write!(f, "{}: {}", i, key)?;
49        }
50        Ok(())
51    }
52}
53
54impl GGLWETensorKeyCompressed<Vec<u8>> {
55    pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self {
56        let mut keys: Vec<GGLWESwitchingKeyCompressed<Vec<u8>>> = Vec::new();
57        let pairs: usize = (((rank + 1) * rank) >> 1).max(1);
58        (0..pairs).for_each(|_| {
59            keys.push(GGLWESwitchingKeyCompressed::alloc(
60                n, basek, k, rows, digits, 1, rank,
61            ));
62        });
63        Self { keys }
64    }
65
66    pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize {
67        let pairs: usize = (((rank + 1) * rank) >> 1).max(1);
68        pairs * GGLWESwitchingKeyCompressed::bytes_of(n, basek, k, rows, digits, 1)
69    }
70}
71
72impl<D: Data> Infos for GGLWETensorKeyCompressed<D> {
73    type Inner = MatZnx<D>;
74
75    fn inner(&self) -> &Self::Inner {
76        self.keys[0].inner()
77    }
78
79    fn basek(&self) -> usize {
80        self.keys[0].basek()
81    }
82
83    fn k(&self) -> usize {
84        self.keys[0].k()
85    }
86}
87
88impl<D: Data> GGLWETensorKeyCompressed<D> {
89    pub fn rank(&self) -> usize {
90        self.keys[0].rank()
91    }
92
93    pub fn digits(&self) -> usize {
94        self.keys[0].digits()
95    }
96
97    pub fn rank_in(&self) -> usize {
98        self.keys[0].rank_in()
99    }
100
101    pub fn rank_out(&self) -> usize {
102        self.keys[0].rank_out()
103    }
104}
105
106impl<D: DataMut> ReaderFrom for GGLWETensorKeyCompressed<D> {
107    fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
108        let len: usize = reader.read_u64::<LittleEndian>()? as usize;
109        if self.keys.len() != len {
110            return Err(std::io::Error::new(
111                std::io::ErrorKind::InvalidData,
112                format!("self.keys.len()={} != read len={}", self.keys.len(), len),
113            ));
114        }
115        for key in &mut self.keys {
116            key.read_from(reader)?;
117        }
118        Ok(())
119    }
120}
121
122impl<D: DataRef> WriterTo for GGLWETensorKeyCompressed<D> {
123    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
124        writer.write_u64::<LittleEndian>(self.keys.len() as u64)?;
125        for key in &self.keys {
126            key.write_to(writer)?;
127        }
128        Ok(())
129    }
130}
131
132impl<D: DataMut> GGLWETensorKeyCompressed<D> {
133    pub(crate) fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWESwitchingKeyCompressed<D> {
134        if i > j {
135            std::mem::swap(&mut i, &mut j);
136        };
137        let rank: usize = self.rank();
138        &mut self.keys[i * rank + j - (i * (i + 1) / 2)]
139    }
140}
141
142impl<D: DataMut, DR: DataRef, B: Backend> Decompress<B, GGLWETensorKeyCompressed<DR>> for GGLWETensorKey<D>
143where
144    Module<B>: VecZnxFillUniform + VecZnxCopy,
145{
146    fn decompress(&mut self, module: &Module<B>, other: &GGLWETensorKeyCompressed<DR>) {
147        #[cfg(debug_assertions)]
148        {
149            assert_eq!(
150                self.keys.len(),
151                other.keys.len(),
152                "invalid receiver: self.keys.len()={} != other.keys.len()={}",
153                self.keys.len(),
154                other.keys.len()
155            );
156        }
157
158        self.keys
159            .iter_mut()
160            .zip(other.keys.iter())
161            .for_each(|(a, b)| {
162                a.decompress(module, b);
163            });
164    }
165}