1use crate::rns::{Channels, RnsInt};
11
/// A batch of RNS integers stored as one flat residue buffer.
///
/// Layout is element-major: the residue of element `b` in channel `c` lives at
/// `data[b * channels.len() + c]` (the indexing used by `get`/`set`).
/// `zeros` and `from_rns_ints` size `data` to exactly `batch_size * channels.len()`.
#[derive(Clone, Debug)]
pub struct RnsBatch {
    /// Flattened residues, `channels.len()` consecutive values per element.
    pub data: Vec<u64>,
    /// Number of RNS integers held in the batch.
    pub batch_size: usize,
    /// Channel set used by every element of the batch.
    pub channels: Channels,
}
22
23impl RnsBatch {
24 pub fn new(batch_size: usize, channels: Channels) -> Self {
26 Self::zeros(batch_size, channels)
27 }
28
29 pub fn zeros(batch_size: usize, channels: Channels) -> Self {
31 let k = channels.len();
32 RnsBatch {
33 data: vec![0; batch_size * k],
34 batch_size,
35 channels,
36 }
37 }
38
39 #[inline]
41 pub fn channels_len(&self) -> usize {
42 self.channels.len()
43 }
44
45 #[inline]
47 pub fn get(&self, b: usize, c: usize) -> u64 {
48 self.data[b * self.channels.len() + c]
49 }
50
51 #[inline]
53 pub fn set(&mut self, b: usize, c: usize, val: u64) {
54 let k = self.channels.len();
55 self.data[b * k + c] = val;
56 }
57
58 pub fn from_rns_ints(items: &[RnsInt]) -> Self {
63 assert!(!items.is_empty(), "cannot build an RnsBatch from zero items");
64 let channels = items[0].channels.clone();
65 let k = channels.len();
66 let mut data = Vec::with_capacity(items.len() * k);
67 for item in items {
68 debug_assert_eq!(item.channels, channels, "all items must share channels");
69 debug_assert_eq!(item.residues.len(), k);
70 data.extend_from_slice(&item.residues);
71 }
72 RnsBatch {
73 data,
74 batch_size: items.len(),
75 channels,
76 }
77 }
78
79 pub fn to_rns_ints(&self) -> Vec<RnsInt> {
81 let k = self.channels.len();
82 (0..self.batch_size)
83 .map(|b| {
84 let residues = self.data[b * k..(b + 1) * k].to_vec();
85 RnsInt::from_residues(residues, self.channels.clone())
86 })
87 .collect()
88 }
89
90 pub fn as_u32_bytes(&self) -> Vec<u8> {
95 let as_u32: Vec<u32> = self.data.iter().map(|&v| v as u32).collect();
96 bytemuck::cast_slice(&as_u32).to_vec()
97 }
98
99 pub fn from_u32(values: &[u32], batch_size: usize, channels: Channels) -> Self {
101 let data = values.iter().map(|&v| v as u64).collect();
102 RnsBatch {
103 data,
104 batch_size,
105 channels,
106 }
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes the `i`-th native-endian `u32` from a byte buffer.
    ///
    /// `as_u32_bytes` emits native-endian words, so the tests must not assume
    /// little-endian byte positions.
    fn u32_at(bytes: &[u8], i: usize) -> u32 {
        let w = &bytes[i * 4..(i + 1) * 4];
        u32::from_ne_bytes([w[0], w[1], w[2], w[3]])
    }

    #[test]
    fn get_set_roundtrip() {
        let ch = Channels::standard(8);
        let mut batch = RnsBatch::zeros(4, ch);
        batch.set(2, 3, 7);
        assert_eq!(batch.get(2, 3), 7);
        assert_eq!(batch.get(0, 0), 0);
    }

    #[test]
    fn pack_unpack_roundtrip() {
        let ch = Channels::standard(16);
        let items = vec![
            RnsInt::from_i64(123, ch.clone()),
            RnsInt::from_i64(456, ch.clone()),
            RnsInt::from_i64(789, ch.clone()),
        ];
        let batch = RnsBatch::from_rns_ints(&items);
        assert_eq!(batch.batch_size, 3);
        let back = batch.to_rns_ints();
        assert_eq!(back.len(), items.len());
        for (a, b) in items.iter().zip(back.iter()) {
            assert_eq!(a.to_bigint(), b.to_bigint());
        }
    }

    #[test]
    fn u32_byte_layout() {
        let ch = Channels::standard(4);
        let mut batch = RnsBatch::zeros(1, ch);
        batch.set(0, 0, 1);
        batch.set(0, 1, 2);
        let bytes = batch.as_u32_bytes();
        // One 4-byte word per residue; `standard(4)` is expected to yield 4 channels.
        assert_eq!(bytes.len(), 4 * 4);
        assert_eq!(u32_at(&bytes, 0), 1);
        assert_eq!(u32_at(&bytes, 1), 2);
        assert_eq!(u32_at(&bytes, 2), 0);
    }

    #[test]
    fn from_u32_roundtrip() {
        let ch = Channels::standard(4);
        let values = [1u32, 2, 3, 4];
        let batch = RnsBatch::from_u32(&values, 1, ch);
        assert_eq!(batch.batch_size, 1);
        assert_eq!(batch.get(0, 2), 3);
        let bytes = batch.as_u32_bytes();
        let back: Vec<u32> = (0..values.len()).map(|i| u32_at(&bytes, i)).collect();
        assert_eq!(back, values.to_vec());
    }
}