1use rand::Rng;
2use rand::distr::{Distribution, Uniform}; use abhedya_chhandas::{self as sanskrit, MatraWeight};
4
/// Selects how [`encrypt`] chooses the error added to each coefficient of `u`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EncryptionMode {
    /// Plain error drawn by `add_noise` (a value in {-1, 0, 1}, reduced mod `Q`).
    Standard,
    /// The error is additionally adjusted so each `u` coefficient is below 3168 and its
    /// low four bits map to an alternating Laghu / Guru weight pattern (see [`encrypt`]).
    Metered,
}
11
/// Length of the secret vector, of each public-key row, and of each ciphertext `u` vector.
pub const N: usize = 768;

/// Modulus for all ciphertext arithmetic; reduced coefficients lie in `[0, Q)`.
pub const Q: i16 = 3329;

/// Secret key: a vector of `N` small coefficients.
///
/// `SecretKey::new` samples each coefficient uniformly from {-1, 0, 1}.
pub struct SecretKey {
    /// Secret coefficients.
    pub s: Vec<i16>,
}
21
/// Public key `(A, b)`; `PublicKey::new` builds it as `b[j] = A[j]·s + e_j (mod Q)`.
#[derive(Debug, Clone)]
pub struct PublicKey {
    /// Matrix stored row-major: `A[j]` is row `j`.
    // The capitalised name mirrors the matrix notation and is part of the public API,
    // so the snake_case lint is silenced rather than renaming the field.
    #[allow(non_snake_case)]
    pub A: Vec<Vec<i16>>,
    /// One entry per row of `A`.
    pub b: Vec<i16>,
}
26
/// A ciphertext as a `(u, v)` pair.
///
/// NOTE(review): `encrypt` and `decrypt` currently exchange `(Vec<Vec<i16>>, Vec<i16>)`
/// tuples and do not construct this type; the intended layout of `u` and `v` here should be
/// confirmed before relying on it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Ciphertext {
    /// The `u` component.
    pub u: Vec<i16>,
    /// The `v` component.
    pub v: Vec<i16>,
}
31
32fn dot(v1: &[i16], v2: &[i16]) -> i16 {
34 let sum: i32 = v1.iter().zip(v2.iter())
35 .map(|(a, b)| (*a as i32) * (*b as i32))
36 .sum();
37 (sum.rem_euclid(Q as i32)) as i16
38}
39
40fn add_noise(val: i16, rng: &mut impl Rng) -> i16 {
41 let e: i16 = rng.random_range(-1..=1);
43 (val + e).rem_euclid(Q)
44}
45
46impl SecretKey {
47 pub fn new(rng: &mut impl Rng) -> Self {
48 let dist = Uniform::new(-1, 2).unwrap();
51 let s: Vec<i16> = (0..N).map(|_| dist.sample(rng) as i16).collect();
52 SecretKey { s }
53 }
54}
55
56impl PublicKey {
57 pub fn new(sk: &SecretKey, rng: &mut impl Rng) -> Self {
58 let mut A = vec![vec![0; N]; N];
60 for row in A.iter_mut() {
61 for val in row.iter_mut() {
62 *val = rng.random_range(0..Q);
63 }
64 }
65
66 let mut b = Vec::with_capacity(N);
68 for row in &A {
69 let dot_prod = dot(row, &sk.s);
70 let noise = add_noise(0, rng);
71 b.push((dot_prod + noise).rem_euclid(Q));
72 }
73
74 PublicKey { A, b }
75 }
76}
77
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
mod simd {
    // The `unsafe fn`s below call AVX2 intrinsics directly; allow that without wrapping
    // every intrinsic call in its own `unsafe` block.
    #![allow(unsafe_op_in_unsafe_fn)]

    use super::*;
    use std::arch::x86_64::*;

    /// Number of `i16` lanes in one 256-bit AVX2 register.
    const LANES: usize = 16;

    // Both kernels walk `0..N` in whole registers with no scalar tail.
    const _: () = assert!(N % LANES == 0, "N must be a multiple of the AVX2 lane count");

    /// AVX2 counterpart of `dot`: inner product of the first `N` entries of `v1` and `v2`,
    /// reduced into `[0, Q)`.
    ///
    /// `_mm256_madd_epi16` sums adjacent pairs of products into `i32` lanes and the lanes are
    /// then summed in `i32`, so the result only matches `dot` while those sums do not
    /// overflow `i32`.
    ///
    /// # Safety
    /// The executing CPU must support AVX2. This module is only compiled when the `avx2`
    /// target feature is enabled at build time.
    ///
    /// # Panics
    /// Panics if either slice has fewer than `N` elements.
    pub unsafe fn dot_avx2(v1: &[i16], v2: &[i16]) -> i16 {
        // The unaligned loads below read N lanes through raw pointers; reject short inputs
        // instead of reading past the end of the slices.
        assert!(
            v1.len() >= N && v2.len() >= N,
            "dot_avx2 requires both inputs to have at least N elements"
        );

        let mut sum_vec = _mm256_setzero_si256();
        for i in (0..N).step_by(LANES) {
            let a = _mm256_loadu_si256(v1.as_ptr().add(i) as *const __m256i);
            let b = _mm256_loadu_si256(v2.as_ptr().add(i) as *const __m256i);
            sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a, b));
        }

        let mut lanes = [0i32; 8];
        _mm256_storeu_si256(lanes.as_mut_ptr() as *mut __m256i, sum_vec);
        let total: i32 = lanes.iter().sum();
        total.rem_euclid(Q as i32) as i16
    }

    /// Adds `row[k] * r_val` into `accum[k]` for every `k` in `0..N`, 16 lanes at a time.
    ///
    /// The product is formed with `_mm256_mullo_epi16`, which keeps only the low 16 bits, so
    /// it equals the exact product only when `row[k] * r_val` fits in `i16`. That holds for
    /// rows produced by `PublicKey::new` (entries in `0..Q`) combined with `r_val` in
    /// {-1, 0, 1}, as `encrypt` uses it.
    ///
    /// # Safety
    /// The executing CPU must support AVX2. This module is only compiled when the `avx2`
    /// target feature is enabled at build time.
    ///
    /// # Panics
    /// Panics if `row` or `accum` has fewer than `N` elements.
    pub unsafe fn acc_row_avx2(row: &[i16], r_val: i16, accum: &mut [i32]) {
        // Loads from `row` and loads/stores into `accum` cover N lanes via raw pointers.
        assert!(
            row.len() >= N && accum.len() >= N,
            "acc_row_avx2 requires row and accum to have at least N elements"
        );

        let r_vec = _mm256_set1_epi16(r_val);

        for k in (0..N).step_by(LANES) {
            let a_chunk = _mm256_loadu_si256(row.as_ptr().add(k) as *const __m256i);
            let prod = _mm256_mullo_epi16(a_chunk, r_vec);

            // Sign-extend the low and high eight 16-bit products to 32 bits.
            let prod_lo = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(prod));
            let prod_hi = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(prod, 1));

            let acc_ptr = accum.as_mut_ptr().add(k);
            let acc_lo = _mm256_add_epi32(_mm256_loadu_si256(acc_ptr as *const __m256i), prod_lo);
            let acc_hi =
                _mm256_add_epi32(_mm256_loadu_si256(acc_ptr.add(8) as *const __m256i), prod_hi);

            _mm256_storeu_si256(acc_ptr as *mut __m256i, acc_lo);
            _mm256_storeu_si256(acc_ptr.add(8) as *mut __m256i, acc_hi);
        }
    }
}
126
127pub fn encrypt(pk: &PublicKey, message: &[i16], rng: &mut impl Rng, mode: EncryptionMode) -> (Vec<Vec<i16>>, Vec<i16>) {
129 let mut u_vecs = Vec::new();
130 let mut v_vec = Vec::new();
131
132 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
139 let use_avx2 = std::is_x86_feature_detected!("avx2");
140 #[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
141 let use_avx2 = false;
142
143 for (m_idx, &m) in message.iter().enumerate() {
148 let r: Vec<i16> = (0..N).map(|_| rng.random_range(-1..=1)).collect();
149
150
151 let mut u_accum = vec![0i32; N]; for j in 0..N {
154 let r_val = r[j] as i32;
155 if r_val == 0 { continue; } if use_avx2 {
162 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
163 unsafe {
164 simd::acc_row_avx2(&pk.A[j], r_val as i16, &mut u_accum);
166 }
167 #[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
168 {
169 for k in 0..N {
171 u_accum[k] += (pk.A[j][k] as i32) * r_val;
172 }
173 }
174 } else {
175 for k in 0..N {
176 u_accum[k] += (pk.A[j][k] as i32) * r_val;
177 }
178 }
179 }
180
181 let mut u: Vec<i16> = Vec::with_capacity(N);
183 for i in 0..N {
184 let col_dot = u_accum[i].rem_euclid(Q as i32) as i16;
185
186 let mut noise = add_noise(0, rng);
188 let mut final_noise = noise;
189
190 if mode == EncryptionMode::Metered {
191 let target = if i % 2 == 0 { MatraWeight::Laghu } else { MatraWeight::Guru };
196 let mut found = false;
197
198 for offset in -128..=128 {
201 let candidate_noise = noise + offset;
202 let val = (col_dot + candidate_noise).rem_euclid(Q) as u16;
203
204 let weight_match = sanskrit::get_matra_weight((val % 16) as usize) == target;
206
207 let range_valid = val < 3168;
210
211 let valid = weight_match && range_valid;
213
214 if valid && !found {
216 final_noise = candidate_noise;
217 found = true;
218 }
219 }
220
221 } else {
225 final_noise = noise;
226 }
227
228 u.push((col_dot + final_noise).rem_euclid(Q));
229 }
230
231 let b_dot = dot(&pk.b, &r);
233 let mut noise = add_noise(0, rng);
234 let scaled_m = if m == 1 { Q/2 } else { 0 };
235
236 let v_val = (b_dot + noise + scaled_m).rem_euclid(Q);
237
238 u_vecs.push(u);
239 v_vec.push(v_val);
240 }
241
242 (u_vecs, v_vec)
243}
244
245pub fn decrypt(sk: &SecretKey, u_vecs: &[Vec<i16>], v_vec: &[i16]) -> Vec<i16> {
246 let mut messages = Vec::new();
247 for (u, &v) in u_vecs.iter().zip(v_vec.iter()) {
248 let s_dot_u = dot(&sk.s, u);
250 let diff = (v - s_dot_u).rem_euclid(Q);
251
252 let m = if diff > Q/4 && diff < 3*Q/4 { 1 } else { 0 };
254 messages.push(m);
255 }
256 messages
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use abhedya_chhandas as sanskrit;

    /// Samples a secret key and its matching public key, in that order, from `rng`.
    fn fresh_keys(rng: &mut impl Rng) -> (SecretKey, PublicKey) {
        let sk = SecretKey::new(rng);
        let pk = PublicKey::new(&sk, rng);
        (sk, pk)
    }

    #[test]
    fn test_encryption_correctness() {
        let mut rng = rand::rng();
        let (sk, pk) = fresh_keys(&mut rng);

        let msg = vec![0, 1, 0, 1, 1, 0];
        let (u, v) = encrypt(&pk, &msg, &mut rng, EncryptionMode::Standard);

        assert_eq!(msg, decrypt(&sk, &u, &v));
    }

    #[test]
    fn test_metered_mode_constraints() {
        let mut rng = rand::rng();
        let (_sk, pk) = fresh_keys(&mut rng);

        let msg = vec![1; 10];
        let (u_vecs, _) = encrypt(&pk, &msg, &mut rng, EncryptionMode::Metered);

        for u in &u_vecs {
            for (i, &val) in u.iter().enumerate() {
                assert!(val < 3168, "Value {} exceeds truncated modulus 3168", val);

                // Even positions must be Laghu, odd positions Guru.
                let expected = if i % 2 == 0 { MatraWeight::Laghu } else { MatraWeight::Guru };
                let actual = sanskrit::get_matra_weight((val % 16) as usize);

                assert_eq!(actual, expected, "Meter mismatch at index {}", i);
            }
        }
    }
}