Skip to main content

adele_ring/
batch.rs

//! `RnsBatch` — the single flat buffer format shared by both backends.
//!
//! Having one canonical layout means there is **no re-layout** when switching
//! between the CPU (rayon) and GPU (wgpu) backends: the GPU storage buffer holds
//! this struct's `data` in exactly the same order (each residue narrowed from
//! `u64` to `u32`), and the CPU just iterates it in place.
//!
//! Layout is row-major `[batch_size × n_channels]`: element
//! `data[b * n_channels + c]` is the residue of item `b` in channel `c`.
9
10use crate::rns::{Channels, RnsInt};
11
12/// A batch of RNS values in a flat, backend-agnostic buffer.
13#[derive(Clone, Debug)]
14pub struct RnsBatch {
15    /// Flat row-major residues: length `batch_size * channels.len()`.
16    pub data: Vec<u64>,
17    /// Number of items `B`.
18    pub batch_size: usize,
19    /// The `K` channels shared by every item.
20    pub channels: Channels,
21}
22
23impl RnsBatch {
24    /// Allocate a zeroed batch (alias of [`RnsBatch::zeros`]).
25    pub fn new(batch_size: usize, channels: Channels) -> Self {
26        Self::zeros(batch_size, channels)
27    }
28
29    /// Allocate a batch with all residues set to zero.
30    pub fn zeros(batch_size: usize, channels: Channels) -> Self {
31        let k = channels.len();
32        RnsBatch {
33            data: vec![0; batch_size * k],
34            batch_size,
35            channels,
36        }
37    }
38
39    /// Number of channels `K`.
40    #[inline]
41    pub fn channels_len(&self) -> usize {
42        self.channels.len()
43    }
44
45    /// Residue of item `b` in channel `c`.
46    #[inline]
47    pub fn get(&self, b: usize, c: usize) -> u64 {
48        self.data[b * self.channels.len() + c]
49    }
50
51    /// Set the residue of item `b` in channel `c`.
52    #[inline]
53    pub fn set(&mut self, b: usize, c: usize, val: u64) {
54        let k = self.channels.len();
55        self.data[b * k + c] = val;
56    }
57
58    /// Pack a slice of [`RnsInt`] values into a batch.
59    ///
60    /// Panics if `items` is empty (no channels to infer) or if the items do not
61    /// all share the same channels.
62    pub fn from_rns_ints(items: &[RnsInt]) -> Self {
63        assert!(!items.is_empty(), "cannot build an RnsBatch from zero items");
64        let channels = items[0].channels.clone();
65        let k = channels.len();
66        let mut data = Vec::with_capacity(items.len() * k);
67        for item in items {
68            debug_assert_eq!(item.channels, channels, "all items must share channels");
69            debug_assert_eq!(item.residues.len(), k);
70            data.extend_from_slice(&item.residues);
71        }
72        RnsBatch {
73            data,
74            batch_size: items.len(),
75            channels,
76        }
77    }
78
79    /// Unpack the batch back into individual [`RnsInt`] values.
80    pub fn to_rns_ints(&self) -> Vec<RnsInt> {
81        let k = self.channels.len();
82        (0..self.batch_size)
83            .map(|b| {
84                let residues = self.data[b * k..(b + 1) * k].to_vec();
85                RnsInt::from_residues(residues, self.channels.clone())
86            })
87            .collect()
88    }
89
90    /// Pack for GPU upload: residues as little-endian `u32` bytes.
91    ///
92    /// Moduli are chosen `<= 2^31` so every residue fits in a `u32`. The cast to
93    /// bytes is zero-copy via bytemuck.
94    pub fn as_u32_bytes(&self) -> Vec<u8> {
95        let as_u32: Vec<u32> = self.data.iter().map(|&v| v as u32).collect();
96        bytemuck::cast_slice(&as_u32).to_vec()
97    }
98
99    /// Rebuild a batch from a flat `u32` slice downloaded from the GPU.
100    pub fn from_u32(values: &[u32], batch_size: usize, channels: Channels) -> Self {
101        let data = values.iter().map(|&v| v as u64).collect();
102        RnsBatch {
103            data,
104            batch_size,
105            channels,
106        }
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn get_set_roundtrip() {
        let ch = Channels::standard(8);
        let mut batch = RnsBatch::zeros(4, ch);
        batch.set(2, 3, 7);
        assert_eq!(batch.get(2, 3), 7);
        assert_eq!(batch.get(0, 0), 0);
    }

    /// `new` must behave exactly like `zeros`: right shape, all residues zero.
    #[test]
    fn new_matches_zeros() {
        let ch = Channels::standard(8);
        let k = ch.len();
        let batch = RnsBatch::new(3, ch);
        assert_eq!(batch.batch_size, 3);
        assert_eq!(batch.channels_len(), k);
        assert_eq!(batch.data.len(), 3 * k);
        assert!(batch.data.iter().all(|&v| v == 0));
    }

    #[test]
    fn pack_unpack_roundtrip() {
        let ch = Channels::standard(16);
        let items = vec![
            RnsInt::from_i64(123, ch.clone()),
            RnsInt::from_i64(456, ch.clone()),
            RnsInt::from_i64(789, ch.clone()),
        ];
        let batch = RnsBatch::from_rns_ints(&items);
        assert_eq!(batch.batch_size, 3);
        let back = batch.to_rns_ints();
        // `zip` stops at the shorter side, so check the lengths explicitly.
        assert_eq!(back.len(), items.len());
        for (a, b) in items.iter().zip(back.iter()) {
            assert_eq!(a.to_bigint(), b.to_bigint());
        }
    }

    #[test]
    fn u32_byte_layout() {
        let ch = Channels::standard(4);
        let mut batch = RnsBatch::zeros(1, ch);
        batch.set(0, 0, 1);
        batch.set(0, 1, 2);
        let bytes = batch.as_u32_bytes();
        assert_eq!(bytes.len(), 4 * 4); // 4 channels * 4 bytes
        assert_eq!(bytes[0], 1);
        assert_eq!(bytes[4], 2);
    }

    /// Simulates a GPU round trip: pack to little-endian bytes, decode the
    /// words as the device would, and rebuild the batch with `from_u32`.
    #[test]
    fn u32_upload_download_roundtrip() {
        let ch = Channels::standard(4);
        let mut batch = RnsBatch::zeros(2, ch.clone());
        for b in 0..2 {
            for c in 0..4 {
                batch.set(b, c, (b * 4 + c + 1) as u64);
            }
        }
        let bytes = batch.as_u32_bytes();
        let words: Vec<u32> = bytes
            .chunks_exact(4)
            .map(|w| u32::from_le_bytes([w[0], w[1], w[2], w[3]]))
            .collect();
        let back = RnsBatch::from_u32(&words, 2, ch);
        assert_eq!(back.batch_size, 2);
        assert_eq!(back.data, batch.data);
    }
}