// abhedya_kem/lib.rs

use rand::Rng;
use rand::distr::{Distribution, Uniform}; // Fix: rand 0.9 uses `distr` not `distributions`
use abhedya_chhandas::{self as sanskrit, MatraWeight};

// Define Encryption Configuration

/// Selects how `encrypt` shapes the `u` component of each ciphertext.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EncryptionMode {
    /// Raw LWE output with no constraint on `u`; highest throughput.
    Standard,
    /// Adjusts the noise of every `u` coefficient so it satisfies the chhandas (metre)
    /// constraints, trading throughput for steganographic output.
    Metered,
}

// LWE Parameters
// Dimensions suitable for "toy" post-quantum or PoC

/// LWE dimension: length of the secret `s` and side length of the square matrix `A`.
///
/// 768 = 3 × 256, the same total dimension as Kyber-768 (three polynomials of degree 256).
pub const N: usize = 3 * 256;

/// Modulus applied to every coefficient: Kyber's prime 3329 = 13 × 256 + 1.
pub const Q: i16 = 13 * 256 + 1;

// Structures for Keys

/// LWE secret key: the secret vector `s`.
///
/// `SecretKey::new` fills it with `N` ternary coefficients drawn from `{-1, 0, 1}`.
pub struct SecretKey {
    /// Secret coefficients.
    pub s: Vec<i16>,
}

/// LWE public key `(A, b)` with `b = A·s + e (mod Q)`.
// `A` mirrors the usual matrix notation; renaming the public field would break callers.
#[allow(non_snake_case)]
pub struct PublicKey {
    /// Square `N × N` matrix ("square LWE"), stored as one `Vec` per row.
    pub A: Vec<Vec<i16>>,
    /// `b = A·s + e (mod Q)`, one entry per row of `A`.
    pub b: Vec<i16>,
}

/// An LWE ciphertext `(u, v)`.
///
/// NOTE(review): `encrypt` currently returns a `(Vec<Vec<i16>>, Vec<i16>)` tuple rather than
/// this type.
pub struct Ciphertext {
    /// The `u` component.
    pub u: Vec<i16>,
    /// The encrypted message part.
    pub v: Vec<i16>,
}

32// Math helpers
33fn dot(v1: &[i16], v2: &[i16]) -> i16 {
34    let sum: i32 = v1.iter().zip(v2.iter())
35        .map(|(a, b)| (*a as i32) * (*b as i32))
36        .sum();
37    (sum.rem_euclid(Q as i32)) as i16
38}
39
40fn add_noise(val: i16, rng: &mut impl Rng) -> i16 {
41    // Simple centralized noise (hamming or gaussian approximation)
42    let e: i16 = rng.random_range(-1..=1); 
43    (val + e).rem_euclid(Q)
44}
45
46impl SecretKey {
47    pub fn new(rng: &mut impl Rng) -> Self {
48        // Ternary secret {-1, 0, 1} is common in LWE
49        // Uniform::new(low, high) -> [low, high)
50        let dist = Uniform::new(-1, 2).unwrap(); 
51        let s: Vec<i16> = (0..N).map(|_| dist.sample(rng) as i16).collect();
52        SecretKey { s }
53    }
54}
55
56impl PublicKey {
57    pub fn new(sk: &SecretKey, rng: &mut impl Rng) -> Self {
58        // Generate uniform A
59        let mut A = vec![vec![0; N]; N];
60        for row in A.iter_mut() {
61            for val in row.iter_mut() {
62                *val = rng.random_range(0..Q);
63            }
64        }
65
66        // Calculate b = As + e
67        let mut b = Vec::with_capacity(N);
68        for row in &A {
69            let dot_prod = dot(row, &sk.s);
70            let noise = add_noise(0, rng);
71            b.push((dot_prod + noise).rem_euclid(Q));
72        }
73
74        PublicKey { A, b }
75    }
76}
77
// SIMD Implementation Module
/// AVX2 kernels for the LWE arithmetic.
///
/// The module only exists when AVX2 is enabled at compile time (for example with
/// `-C target-feature=+avx2`), so the intrinsics it calls are always available to it.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
mod simd {
    // Intrinsic calls inside the `unsafe fn`s below are not wrapped in their own `unsafe`
    // blocks; silence the Rust 2024 `unsafe_op_in_unsafe_fn` lint for this module.
    #![allow(unsafe_op_in_unsafe_fn)]

    use super::*;
    use std::arch::x86_64::*;

    /// Number of `i16` lanes in one 256-bit register.
    const LANES: usize = 16;

    // Both kernels walk `0..N` in whole registers and have no scalar tail loop.
    const _: () = assert!(N % LANES == 0, "N must be a multiple of 16 for the AVX2 kernels");

    /// Computes `⟨v1[..N], v2[..N]⟩ mod Q` with AVX2, returning a value in `[0, Q)`.
    ///
    /// Each of the 8 i32 lanes accumulates 2 · N / 16 = 96 products. These stay exact when
    /// entries have magnitude below `Q` (96 · 3328² ≈ 1.06·10⁹ < `i32::MAX`). The final
    /// cross-lane sum is done in `i64`.
    ///
    /// # Panics
    /// Panics if either slice is shorter than `N`.
    ///
    /// # Safety
    /// The CPU must support AVX2. This holds whenever this module is compiled, because it is
    /// gated on `target_feature = "avx2"`.
    pub unsafe fn dot_avx2(v1: &[i16], v2: &[i16]) -> i16 {
        // The loop reads exactly N elements from each slice through raw pointers.
        assert!(v1.len() >= N && v2.len() >= N, "dot_avx2: inputs shorter than N");

        let mut sum_vec = _mm256_setzero_si256();
        for i in (0..N).step_by(LANES) {
            // SAFETY: i + LANES <= N <= len of both slices (asserted above); `loadu` has no
            // alignment requirement.
            let a = _mm256_loadu_si256(v1.as_ptr().add(i) as *const __m256i);
            let b = _mm256_loadu_si256(v2.as_ptr().add(i) as *const __m256i);
            // madd multiplies i16 pairs and adds adjacent products into 8 i32 lanes.
            sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a, b));
        }

        let mut lanes = [0i32; 8];
        _mm256_storeu_si256(lanes.as_mut_ptr() as *mut __m256i, sum_vec);

        // Widen before the horizontal sum: 8 lane totals can exceed i32::MAX together.
        let total: i64 = lanes.iter().map(|&x| i64::from(x)).sum();
        total.rem_euclid(i64::from(Q)) as i16
    }

    /// Adds `r_val * row[k]` into `accum[k]` for every `k` in `0..N`.
    ///
    /// Products are formed with a 16-bit multiply (`_mm256_mullo_epi16`) and then
    /// sign-extended. Each product must therefore fit in `i16`. That holds for a ternary
    /// `r_val` and `row` entries in `[0, Q)`.
    ///
    /// # Panics
    /// Panics if `row` or `accum` is shorter than `N`.
    ///
    /// # Safety
    /// The CPU must support AVX2 (guaranteed whenever this module is compiled).
    pub unsafe fn acc_row_avx2(row: &[i16], r_val: i16, accum: &mut [i32]) {
        // The loop reads and writes exactly N elements through raw pointers.
        assert!(row.len() >= N && accum.len() >= N, "acc_row_avx2: inputs shorter than N");

        let r_vec = _mm256_set1_epi16(r_val);
        for k in (0..N).step_by(LANES) {
            // SAFETY: k + LANES <= N <= len of both slices (asserted above).
            let a_chunk = _mm256_loadu_si256(row.as_ptr().add(k) as *const __m256i);
            let prod = _mm256_mullo_epi16(a_chunk, r_vec);

            // Sign-extend the low and high 8 products to i32 so they can be accumulated.
            let prod_lo = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(prod));
            let prod_hi = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(prod, 1));

            let acc_ptr = accum.as_mut_ptr().add(k);
            let acc_lo = _mm256_loadu_si256(acc_ptr as *const __m256i);
            let acc_hi = _mm256_loadu_si256(acc_ptr.add(8) as *const __m256i);

            _mm256_storeu_si256(acc_ptr as *mut __m256i, _mm256_add_epi32(acc_lo, prod_lo));
            _mm256_storeu_si256(acc_ptr.add(8) as *mut __m256i, _mm256_add_epi32(acc_hi, prod_hi));
        }
    }
}

127// Encryption: Encrypt a vector of messages m
128pub fn encrypt(pk: &PublicKey, message: &[i16], rng: &mut impl Rng, mode: EncryptionMode) -> (Vec<Vec<i16>>, Vec<i16>) {
129    let mut u_vecs = Vec::new();
130    let mut v_vec = Vec::new();
131
132    // Standard Anushtubh: 4 padas of 8 syllables.
133    // Pattern: 5th is light, 6th is heavy.
134    // For PoC: Let's use alternating Short-Long (S-L-S-L-S-L-S-L) for max entropy testing.
135    // S=Laghu, L=Guru
136    
137    // Check for AVX2 support at runtime
138    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
139    let use_avx2 = std::is_x86_feature_detected!("avx2");
140    #[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
141    let use_avx2 = false;
142
143    // Constant-Time Logic Helper
144    // Returns 1 if condition is true, 0 if false (using bitwise ops for primitive types would be better in C, here we use boolean logic)
145    // NOTE: In Rust, minimizing branches is key. 
146    
147    for (m_idx, &m) in message.iter().enumerate() {
148        let r: Vec<i16> = (0..N).map(|_| rng.random_range(-1..=1)).collect();
149        
150
151        let mut u_accum = vec![0i32; N]; // Use i32 to prevent overflow during accumulation
152        
153        for j in 0..N {
154            let r_val = r[j] as i32;
155            if r_val == 0 { continue; } // Optimization for sparse ternary r
156            
157            // SAXPY: u += r_val * A[j]
158            // We can SIMD this line easily?
159            // Actually, A[j] is Vec<i16>.
160            
161            if use_avx2 {
162                #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
163                unsafe {
164                     // Call the helper in the simd module
165                     simd::acc_row_avx2(&pk.A[j], r_val as i16, &mut u_accum);
166                }
167                #[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
168                {
169                   // Fallback (detected but not enabled at compile time?)
170                   for k in 0..N {
171                       u_accum[k] += (pk.A[j][k] as i32) * r_val;
172                   }
173                }
174            } else {
175                for k in 0..N {
176                    u_accum[k] += (pk.A[j][k] as i32) * r_val;
177                }
178            }
179        }
180        
181        // Finalize u (Mod Q and add Noise)
182        let mut u: Vec<i16> = Vec::with_capacity(N);
183        for i in 0..N {
184            let col_dot = u_accum[i].rem_euclid(Q as i32) as i16;
185            
186            // Initial Noise
187            let mut noise = add_noise(0, rng);
188            let mut final_noise = noise;
189            
190            if mode == EncryptionMode::Metered {
191                // Address Timing Oracle: Fixed Iterations
192                // We ALWAYS iterate through the fixed window, regardless of when we find a match.
193                // We use a "found" flag to latch onto the first valid value.
194                
195                let target = if i % 2 == 0 { MatraWeight::Laghu } else { MatraWeight::Guru };
196                let mut found = false;
197
198                // Candidate offsets to try: -128 to +128 (257 iterations)
199                // This covers the entire "Modulus Gap" (size 161) ensuring we can always escape the forbidden range.
200                for offset in -128..=128 {
201                   let candidate_noise = noise + offset;
202                   let val = (col_dot + candidate_noise).rem_euclid(Q) as u16;
203                   
204                   // Check 1: Meter Constraints
205                   let weight_match = sanskrit::get_matra_weight((val % 16) as usize) == target;
206                   
207                   // Check 2: Modulus Truncation (Bias Elimination)
208                   // Reject values >= 3168 to ensure perfect u16->Sanskrit mapping
209                   let range_valid = val < 3168; 
210                   
211                   // Combined Check
212                   let valid = weight_match && range_valid;
213                   
214                   // Constant-Time Selection
215                   if valid && !found {
216                       final_noise = candidate_noise;
217                       found = true;
218                   }
219                }
220                
221                // If not found, we technically fail.
222                // In this hardened version, we fallback to 'noise' (which might be >= 3168 or wrong meter).
223                // But with 81 tries, probability is effectively zero.
224            } else {
225                final_noise = noise;
226            }
227            
228            u.push((col_dot + final_noise).rem_euclid(Q));
229        }
230        
231        // v calculation
232        let b_dot = dot(&pk.b, &r);
233        let mut noise = add_noise(0, rng);
234        let scaled_m = if m == 1 { Q/2 } else { 0 };
235        
236        let v_val = (b_dot + noise + scaled_m).rem_euclid(Q);
237        
238        u_vecs.push(u);
239        v_vec.push(v_val);
240    }
241    
242    (u_vecs, v_vec)
243}
244
245pub fn decrypt(sk: &SecretKey, u_vecs: &[Vec<i16>], v_vec: &[i16]) -> Vec<i16> {
246    let mut messages = Vec::new();
247    for (u, &v) in u_vecs.iter().zip(v_vec.iter()) {
248        // m_noisy = v - s^T u
249        let s_dot_u = dot(&sk.s, u);
250        let diff = (v - s_dot_u).rem_euclid(Q);
251        
252        // If close to Q/2 -> 1, if close to 0 -> 0
253        let m = if diff > Q/4 && diff < 3*Q/4 { 1 } else { 0 };
254        messages.push(m);
255    }
256    messages
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use abhedya_chhandas as sanskrit;

    /// Builds a secret/public key pair from `rng` (secret first, then public key).
    fn keypair(rng: &mut impl Rng) -> (SecretKey, PublicKey) {
        let sk = SecretKey::new(rng);
        let pk = PublicKey::new(&sk, rng);
        (sk, pk)
    }

    #[test]
    fn test_encryption_correctness() {
        // rand 0.9: `rand::rng()` is the standard thread-local generator.
        let mut rng = rand::rng();
        let (sk, pk) = keypair(&mut rng);

        let msg = vec![0, 1, 0, 1, 1, 0];
        let (u, v) = encrypt(&pk, &msg, &mut rng, EncryptionMode::Standard);

        assert_eq!(msg, decrypt(&sk, &u, &v));
    }

    #[test]
    fn test_metered_mode_constraints() {
        let mut rng = rand::rng();
        let (_sk, pk) = keypair(&mut rng);

        // Ten ciphertexts give 10 × N coefficients to check against the constraints.
        let msg = vec![1; 10];
        let (u_vecs, _) = encrypt(&pk, &msg, &mut rng, EncryptionMode::Metered);

        // Indices restart at 0 for every ciphertext: the metre is per `u` vector.
        for (i, val) in u_vecs.iter().flat_map(|u| u.iter().enumerate()) {
            // Modulus truncation.
            assert!(*val < 3168, "Value {} exceeds truncated modulus 3168", val);

            // Metre pattern: even positions Laghu, odd positions Guru.
            let target = if i % 2 == 0 { MatraWeight::Laghu } else { MatraWeight::Guru };
            let current = sanskrit::get_matra_weight((val % 16) as usize);
            assert_eq!(current, target, "Meter mismatch at index {}", i);
        }
    }
}