1use rand::Rng;
4
5use crate::gadget::{bit_decomp, flatten_matrix, powers_of_2};
6use crate::lwe::{keygen, PublicKey, SecretKey};
7use crate::modular::mod_q;
8use crate::params::Params;
9
/// A GSW ciphertext: a matrix of `u64` values stored row by row.
///
/// The homomorphic operations in this module index it as an
/// `n_expanded × n_expanded` matrix.
pub type Ciphertext = Vec<Vec<u64>>;

/// GSW secret key; the underlying LWE [`SecretKey`] is reused unchanged.
pub type GswSecretKey = SecretKey;

/// GSW public key; the underlying LWE [`PublicKey`] is reused unchanged.
pub type GswPublicKey = PublicKey;
18
19pub fn gsw_keygen<R: Rng>(rng: &mut R, params: &Params) -> (GswSecretKey, GswPublicKey) {
21 keygen(rng, params)
22}
23
24pub fn encrypt<R: Rng>(rng: &mut R, pk: &GswPublicKey, bit: u8) -> Ciphertext {
29 let params = pk.params();
30 let n_expanded = params.n_expanded;
31 let m = params.m;
32 let q = params.q;
33
34 let r: Vec<Vec<u64>> = (0..n_expanded)
36 .map(|_| (0..m).map(|_| rng.gen_range(0..=1) as u64).collect())
37 .collect();
38
39 let mut ra = vec![vec![0u64; params.n + 1]; n_expanded];
41 for i in 0..n_expanded {
42 for j in 0..(params.n + 1) {
43 let mut sum: i64 = 0;
44 for k in 0..m {
45 sum += (r[i][k] as i64) * (pk.a[k][j] as i64);
46 }
47 ra[i][j] = mod_q(sum, q);
48 }
49 }
50
51 let bit_decomp_ra: Vec<Vec<u64>> = ra
53 .iter()
54 .map(|row| bit_decomp(row, params))
55 .collect();
56
57 let mut sum = bit_decomp_ra;
59 for i in 0..n_expanded {
60 sum[i][i] = mod_q((sum[i][i] as i64) + (bit as i64), q);
61 }
62
63 flatten_matrix(&sum, params)
65}
66
67pub fn decrypt(sk: &GswSecretKey, ct: &Ciphertext) -> u8 {
71 let params = sk.params();
72 let q = params.q;
73 let l = params.l;
74 let n_expanded = params.n_expanded;
75
76 let v = powers_of_2(&sk.s, params);
77 let row_idx = l - 1;
78
79 let mut dot: i64 = 0;
80 for j in 0..n_expanded {
81 dot += (ct[row_idx][j] as i64) * (v[j] as i64);
82 }
83 let val = mod_q(dot, q) as i64;
84
85 let scale = v[l - 1] as i64;
86 if scale == 0 {
87 return 0;
88 }
89
90 let msg = ((val as f64) / (scale as f64)).round() as i64;
91 (msg.rem_euclid(2)).abs() as u8
92}
93
94pub fn homomorphic_add(params: &Params, ct1: &Ciphertext, ct2: &Ciphertext) -> Ciphertext {
96 let q = params.q;
97 let n_expanded = params.n_expanded;
98 let mut sum = vec![vec![0u64; n_expanded]; n_expanded];
99 for i in 0..n_expanded {
100 for j in 0..n_expanded {
101 sum[i][j] = mod_q(
102 (ct1[i][j] as i64) + (ct2[i][j] as i64),
103 q,
104 );
105 }
106 }
107 flatten_matrix(&sum, params)
108}
109
110pub fn homomorphic_mult(params: &Params, ct1: &Ciphertext, ct2: &Ciphertext) -> Ciphertext {
114 let q = params.q;
115 let n_expanded = params.n_expanded;
116
117 let mut prod = vec![vec![0u64; n_expanded]; n_expanded];
118 for i in 0..n_expanded {
119 for j in 0..n_expanded {
120 let mut sum: i64 = 0;
121 for k in 0..n_expanded {
122 sum += (ct1[i][k] as i64) * (ct2[k][j] as i64);
123 }
124 prod[i][j] = mod_q(sum, q);
125 }
126 }
127 flatten_matrix(&prod, params)
128}
129
130pub fn homomorphic_nand(params: &Params, ct1: &Ciphertext, ct2: &Ciphertext) -> Ciphertext {
132 let q = params.q;
133 let n_expanded = params.n_expanded;
134
135 let mut prod = vec![vec![0u64; n_expanded]; n_expanded];
136 for i in 0..n_expanded {
137 for j in 0..n_expanded {
138 let mut sum: i64 = 0;
139 for k in 0..n_expanded {
140 sum += (ct1[i][k] as i64) * (ct2[k][j] as i64);
141 }
142 prod[i][j] = mod_q(sum, q);
143 }
144 }
145
146 let mut result = vec![vec![0u64; n_expanded]; n_expanded];
147 for i in 0..n_expanded {
148 for j in 0..n_expanded {
149 let val = if i == j {
150 mod_q(1 - (prod[i][j] as i64), q)
151 } else {
152 mod_q(-(prod[i][j] as i64), q)
153 };
154 result[i][j] = val;
155 }
156 }
157 flatten_matrix(&result, params)
158}