1extern crate group_math as group;
7
8#[cfg(feature = "prg")]
9pub mod prg;
10
11use bitvec::prelude::*;
12pub use fss_types::PointFn;
13use fss_types::{decl_prg_trait, Cw, Share};
14use group::byte::utils::{xor, xor_inplace};
15pub use group::Group;
16#[cfg(feature = "multithread")]
17use rayon::prelude::*;
18
19pub trait Dpf<const N: usize, const LAMBDA: usize, G>
25where
26 G: Group<LAMBDA>,
27{
28 fn gen(&self, f: &PointFn<N, LAMBDA, G>, s0s: [&[u8; LAMBDA]; 2]) -> Share<LAMBDA, G>;
30
31 fn eval(&self, b: bool, k: &Share<LAMBDA, G>, xs: &[&[u8; N]], ys: &mut [&mut G]);
33}
34
35decl_prg_trait!(([u8; LAMBDA], bool));
36
/// The DPF scheme implemented in this module, parameterized by the PRG used to
/// expand seeds.
///
/// `N` is the input length in bytes and `LAMBDA` the seed length in bytes.
/// The PRG requirement (`PrgT: Prg<LAMBDA>`) is stated on the `impl` blocks
/// that use it rather than on the type, so generic code can name
/// `DpfImpl<N, LAMBDA, PrgT>` without repeating the bound.
pub struct DpfImpl<const N: usize, const LAMBDA: usize, PrgT> {
    /// Seed expander shared by key generation and evaluation.
    prg: PrgT,
}

47impl<const N: usize, const LAMBDA: usize, PrgT> DpfImpl<N, LAMBDA, PrgT>
48where
49 PrgT: Prg<LAMBDA>,
50{
51 pub fn new(prg: PrgT) -> Self {
52 Self { prg }
53 }
54}
55
/// Position of the left child in the `[left, right]` pair returned by the PRG.
const IDX_L: usize = 0;
/// Position of the right child in the `[left, right]` pair returned by the PRG.
const IDX_R: usize = IDX_L + 1;

59impl<const N: usize, const LAMBDA: usize, PrgT, G> Dpf<N, LAMBDA, G> for DpfImpl<N, LAMBDA, PrgT>
60where
61 PrgT: Prg<LAMBDA>,
62 G: Group<LAMBDA>,
63{
64 fn gen(&self, f: &PointFn<N, LAMBDA, G>, s0s: [&[u8; LAMBDA]; 2]) -> Share<LAMBDA, G> {
65 let n = 8 * N;
67 let mut ss = Vec::<[[u8; LAMBDA]; 2]>::with_capacity(n + 1);
69 ss.push([s0s[0].to_owned(), s0s[1].to_owned()]);
71 let mut ts = Vec::<[bool; 2]>::with_capacity(n + 1);
72 ts.push([false, true]);
74 let mut cws = Vec::<Cw<LAMBDA, G>>::with_capacity(n);
75 for i in 1..n + 1 {
76 let [(s0l, t0l), (s0r, t0r)] = self.prg.gen(&ss[i - 1][0]);
77 let [(s1l, t1l), (s1r, t1r)] = self.prg.gen(&ss[i - 1][1]);
78 let alpha_i = f.alpha.view_bits::<Msb0>()[i - 1];
80 let (keep, lose) = if alpha_i {
81 (IDX_R, IDX_L)
82 } else {
83 (IDX_L, IDX_R)
84 };
85 let s_cw = xor(&[[&s0l, &s0r][lose], [&s1l, &s1r][lose]]);
86 let tl_cw = t0l ^ t1l ^ alpha_i ^ true;
87 let tr_cw = t0r ^ t1r ^ alpha_i;
88 let cw = Cw {
89 s: s_cw,
90 v: G::zero(),
91 tl: tl_cw,
92 tr: tr_cw,
93 };
94 cws.push(cw);
95 ss.push([
96 xor(&[
97 [&s0l, &s0r][keep],
98 if ts[i - 1][0] { &s_cw } else { &[0; LAMBDA] },
99 ]),
100 xor(&[
101 [&s1l, &s1r][keep],
102 if ts[i - 1][1] { &s_cw } else { &[0; LAMBDA] },
103 ]),
104 ]);
105 ts.push([
106 [t0l, t0r][keep] ^ (ts[i - 1][0] & [tl_cw, tr_cw][keep]),
107 [t1l, t1r][keep] ^ (ts[i - 1][1] & [tl_cw, tr_cw][keep]),
108 ]);
109 }
110 assert_eq!((ss.len(), ts.len(), cws.len()), (n + 1, n + 1, n));
111 let cw_np1 = (f.beta.clone() + Into::<G>::into(ss[n][0]).add_inverse() + ss[n][1].into())
112 .add_inverse_if(ts[n][1]);
113 Share {
114 s0s: vec![s0s[0].to_owned(), s0s[1].to_owned()],
115 cws,
116 cw_np1,
117 }
118 }
119
120 fn eval(&self, b: bool, k: &Share<LAMBDA, G>, xs: &[&[u8; N]], ys: &mut [&mut G]) {
121 let n = k.cws.len();
122 assert_eq!(n, N * 8);
123 let f = |x: &[u8; N], v: &mut G| {
124 let mut ss = Vec::<[u8; LAMBDA]>::with_capacity(n + 1);
125 ss.push(k.s0s[0].to_owned());
126 let mut ts = Vec::<bool>::with_capacity(n + 1);
127 ts.push(b);
128 for i in 1..n + 1 {
129 let cw = &k.cws[i - 1];
130 let [(mut sl, mut tl), (mut sr, mut tr)] = self.prg.gen(&ss[i - 1]);
131 xor_inplace(&mut sl, &[if ts[i - 1] { &cw.s } else { &[0; LAMBDA] }]);
132 xor_inplace(&mut sr, &[if ts[i - 1] { &cw.s } else { &[0; LAMBDA] }]);
133 tl ^= ts[i - 1] & cw.tl;
134 tr ^= ts[i - 1] & cw.tr;
135 if x.view_bits::<Msb0>()[i - 1] {
136 ss.push(sr);
137 ts.push(tr);
138 } else {
139 ss.push(sl);
140 ts.push(tl);
141 }
142 }
143 assert_eq!((ss.len(), ts.len()), (n + 1, n + 1));
144 *v = (Into::<G>::into(ss[n]) + if ts[n] { k.cw_np1.clone() } else { G::zero() })
145 .add_inverse_if(b);
146 };
147 #[cfg(feature = "multithread")]
148 {
149 xs.par_iter()
150 .zip(ys.par_iter_mut())
151 .for_each(|(x, y)| f(x, y));
152 }
153 #[cfg(not(feature = "multithread"))]
154 {
155 xs.iter().zip(ys.iter_mut()).for_each(|(x, y)| f(x, y));
156 }
157 }
158}
159
#[cfg(all(test, feature = "prg"))]
mod tests {
    use super::*;

    use group::byte::ByteGroup;
    use rand::{thread_rng, Rng};

    use crate::prg::Aes256HirosePrg;

    /// Fixed PRG keys; only the DPF root seeds are randomized per run.
    const KEYS: [&[u8; 32]; 2] = [
        b"j9\x1b_\xb3X\xf33\xacW\x15\x1b\x0812K\xb3I\xb9\x90r\x1cN\xb5\xee9W\xd3\xbb@\xc6d",
        b"\x9b\x15\xc8\x0f\xb7\xbc!q\x9e\x89\xb8\xf7\x0e\xa0S\x9dN\xfa\x0c;\x16\xe4\x98\x82b\xfcdy\xb5\x8c{\xc2",
    ];
    /// Evaluation points. Entries 1 and 3 differ from entry 2 only in the last
    /// byte, so they are the immediate neighbours of the special point.
    const ALPHAS: &[&[u8; 16]] = &[
        b"K\xa9W\xf5\xdd\x05\xe9\xfc?\x04\xf6\xfbUo\xa8C",
        b"\xc2GK\xda\xc6\xbb\x99\x98Fq\"f\xb7\x8csU",
        b"\xc2GK\xda\xc6\xbb\x99\x98Fq\"f\xb7\x8csV",
        b"\xc2GK\xda\xc6\xbb\x99\x98Fq\"f\xb7\x8csW",
        b"\xef\x96\x97\xd7\x8f\x8a\xa4AP\n\xb35\xb5k\xff\x97",
    ];
    /// Value of the point function at the special point.
    const BETA: &[u8; 16] = b"\x03\x11\x97\x12C\x8a\xe9#\x81\xa8\xde\xa8\x8f \xc0\xbb";
    /// Index in `ALPHAS` of the point where the tested function is `BETA`.
    const ALPHA_IDX: usize = 2;

    /// Generates a key for `ALPHAS[ALPHA_IDX] -> beta` from random root seeds,
    /// gives each party only its own root seed, and evaluates both parties on
    /// every entry of `ALPHAS`.
    ///
    /// Returns `(party 0 output shares, party 1 output shares)`.
    fn eval_both_parties<G>(beta: G) -> (Vec<G>, Vec<G>)
    where
        G: Group<16>,
    {
        let dpf = DpfImpl::<16, 16, _>::new(Aes256HirosePrg::new(KEYS));
        let root_seeds: [[u8; 16]; 2] = thread_rng().gen();
        let point_fn = PointFn {
            alpha: ALPHAS[ALPHA_IDX].to_owned(),
            beta,
        };
        let key = dpf.gen(&point_fn, [&root_seeds[0], &root_seeds[1]]);
        let eval_as = |party: usize| {
            let mut share = key.clone();
            share.s0s = vec![share.s0s[party]];
            let mut ys = vec![G::zero(); ALPHAS.len()];
            dpf.eval(party == 1, &share, ALPHAS, &mut ys.iter_mut().collect::<Vec<_>>());
            ys
        };
        (eval_as(0), eval_as(1))
    }

    #[test]
    fn test_dpf_gen_then_eval_ok() {
        let (mut sums, ys1) = eval_both_parties(ByteGroup::from(*BETA));
        for (sum, y1) in sums.iter_mut().zip(ys1) {
            *sum += y1;
        }
        let expected: Vec<_> = (0..ALPHAS.len())
            .map(|i| {
                if i == ALPHA_IDX {
                    ByteGroup::from(*BETA)
                } else {
                    ByteGroup::zero()
                }
            })
            .collect();
        assert_eq!(sums, expected);
    }

    #[test]
    fn test_dpf_gen_then_eval_not_zeros() {
        let (ys0, ys1) = eval_both_parties(ByteGroup::from(*BETA));
        assert_ne!(ys0[ALPHA_IDX], ByteGroup::zero());
        assert_ne!(ys1[ALPHA_IDX], ByteGroup::zero());
    }
}