1use bitvec::prelude::*;
7#[cfg(feature = "multi-thread")]
8use rayon::prelude::*;
9
10use crate::group::Group;
11use crate::utils::{xor, xor_inplace};
12pub use crate::PointFn;
13use crate::{Cw, Prg, Share};
14
/// Distributed point function (DPF) API.
///
/// `IN_BLEN` is the input byte length, `OUT_BLEN` the output byte length, and
/// `G` the output group that the point function maps into.
pub trait Dpf<const IN_BLEN: usize, const OUT_BLEN: usize, G>
where
    G: Group<OUT_BLEN>,
{
    /// Generates a shared key for the point function `f` from the two
    /// parties' initial seeds `s0s`.
    fn gen(
        &self,
        f: &PointFn<IN_BLEN, OUT_BLEN, G>,
        s0s: [&[u8; OUT_BLEN]; 2],
    ) -> Share<OUT_BLEN, G>;

    /// Evaluates party `b`'s share `k` at every point in `xs`, writing one
    /// result per point into the corresponding slot of `ys`.
    fn eval(&self, b: bool, k: &Share<OUT_BLEN, G>, xs: &[&[u8; IN_BLEN]], ys: &mut [&mut G]);

    /// Evaluates party `b`'s share `k` over the whole (filtered) input
    /// domain; `ys` must provide one slot per domain point.
    fn full_eval(&self, b: bool, k: &Share<OUT_BLEN, G>, ys: &mut [&mut G]);
}
40
/// Default [`Dpf`] implementation backed by a PRG `P`.
pub struct DpfImpl<const IN_BLEN: usize, const OUT_BLEN: usize, P>
where
    P: Prg<OUT_BLEN, 1>,
{
    /// PRG used to expand a seed into the two child seeds at each tree level.
    prg: P,
    /// Number of leading input bits actually used (= key-tree depth); at most
    /// `8 * IN_BLEN`.
    filter_bitn: usize,
}
49
50impl<const IN_BLEN: usize, const OUT_BLEN: usize, P> DpfImpl<IN_BLEN, OUT_BLEN, P>
51where
52 P: Prg<OUT_BLEN, 1>,
53{
54 pub fn new(prg: P) -> Self {
55 Self {
56 prg,
57 filter_bitn: 8 * IN_BLEN,
58 }
59 }
60
61 pub fn new_with_filter(prg: P, filter_bitn: usize) -> Self {
62 assert!(filter_bitn <= 8 * IN_BLEN && filter_bitn > 1);
63 Self { prg, filter_bitn }
64 }
65}
66
// Indices of the left/right child within the `[left, right]` pairs produced
// by each PRG expansion below.
const IDX_L: usize = 0;
const IDX_R: usize = 1;
69
impl<const IN_BLEN: usize, const OUT_BLEN: usize, P, G> Dpf<IN_BLEN, OUT_BLEN, G>
    for DpfImpl<IN_BLEN, OUT_BLEN, P>
where
    P: Prg<OUT_BLEN, 1>,
    G: Group<OUT_BLEN>,
{
    /// Tree-based DPF key generation: walks down the bits of `f.alpha`,
    /// emitting one correction word per level plus a final output correction.
    fn gen(
        &self,
        f: &PointFn<IN_BLEN, OUT_BLEN, G>,
        s0s: [&[u8; OUT_BLEN]; 2],
    ) -> Share<OUT_BLEN, G> {
        let n = self.filter_bitn;
        // Current seed of each party at the current tree level.
        let mut ss_prev = [s0s[0].to_owned(), s0s[1].to_owned()];
        // Control bits; the two parties start with complementary bits.
        let mut ts_prev = [false, true];
        let mut cws = Vec::<Cw<OUT_BLEN, G>>::with_capacity(n);
        for i in 0..n {
            // i-th bit of alpha, MSB-first; selects the on-path ("keep") child.
            let alpha_i = f.alpha.view_bits::<Msb0>()[i];
            // Expand each party's seed into (left, right) child seeds and bits.
            let [([s0l], t0l), ([s0r], t0r)] = self.prg.gen(&ss_prev[0]);
            let [([s1l], t1l), ([s1r], t1r)] = self.prg.gen(&ss_prev[1]);
            let (keep, lose) = if alpha_i {
                (IDX_R, IDX_L)
            } else {
                (IDX_L, IDX_R)
            };
            // Seed correction: XOR of the parties' off-path seeds, so that
            // applying it makes the off-path subtrees identical (shares of 0).
            let s_cw = xor(&[[&s0l, &s0r][lose], [&s1l, &s1r][lose]]);
            let tl_cw = t0l ^ t1l ^ alpha_i ^ true;
            let tr_cw = t0r ^ t1r ^ alpha_i;
            let cw = Cw {
                s: s_cw,
                v: G::zero(),
                tl: tl_cw,
                tr: tr_cw,
            };
            cws.push(cw);
            // Advance both parties along the "keep" branch; the correction is
            // applied only when that party's control bit is set.
            ss_prev = [
                xor(&[
                    [&s0l, &s0r][keep],
                    if ts_prev[0] { &s_cw } else { &[0; OUT_BLEN] },
                ]),
                xor(&[
                    [&s1l, &s1r][keep],
                    if ts_prev[1] { &s_cw } else { &[0; OUT_BLEN] },
                ]),
            ];
            ts_prev = [
                [t0l, t0r][keep] ^ (ts_prev[0] & [tl_cw, tr_cw][keep]),
                [t1l, t1r][keep] ^ (ts_prev[1] & [tl_cw, tr_cw][keep]),
            ];
        }
        // Final correction word so that the two leaf shares at alpha sum to
        // `beta`; negated depending on party 1's final control bit.
        let cw_np1 =
            (f.beta.clone() + -Into::<G>::into(ss_prev[0]) + ss_prev[1].into()).neg_if(ts_prev[1]);
        Share {
            s0s: vec![s0s[0].to_owned(), s0s[1].to_owned()],
            cws,
            cw_np1,
        }
    }

    /// Multi-point evaluation; dispatches to the rayon-parallel or
    /// single-threaded implementation depending on the `multi-thread` feature.
    fn eval(&self, b: bool, k: &Share<OUT_BLEN, G>, xs: &[&[u8; IN_BLEN]], ys: &mut [&mut G]) {
        #[cfg(feature = "multi-thread")]
        self.eval_mt(b, k, xs, ys);
        #[cfg(not(feature = "multi-thread"))]
        self.eval_st(b, k, xs, ys);
    }

    /// Full-domain evaluation via one recursive walk over the key tree,
    /// sharing PRG expansions across neighboring points.
    fn full_eval(&self, b: bool, k: &Share<OUT_BLEN, G>, ys: &mut [&mut G]) {
        let n = k.cws.len();
        assert_eq!(n, self.filter_bitn);

        // The share is expected to hold this party's own seed at index 0
        // (see how `s0s` is split per party in the tests).
        let s = k.s0s[0];
        let t = b;
        self.full_eval_layer(b, k, ys, 0, (s, t));
    }
}
148
impl<const IN_BLEN: usize, const OUT_BLEN: usize, P> DpfImpl<IN_BLEN, OUT_BLEN, P>
where
    P: Prg<OUT_BLEN, 1>,
{
    /// Single-threaded multi-point evaluation: runs [`Self::eval_point`] for
    /// each `(x, y)` pair in order.
    pub fn eval_st<G>(
        &self,
        b: bool,
        k: &Share<OUT_BLEN, G>,
        xs: &[&[u8; IN_BLEN]],
        ys: &mut [&mut G],
    ) where
        G: Group<OUT_BLEN>,
    {
        xs.iter()
            .zip(ys.iter_mut())
            .for_each(|(x, y)| self.eval_point(b, k, x, y));
    }

    /// Multi-threaded multi-point evaluation via rayon parallel iterators;
    /// each point is evaluated independently.
    #[cfg(feature = "multi-thread")]
    pub fn eval_mt<G>(
        &self,
        b: bool,
        k: &Share<OUT_BLEN, G>,
        xs: &[&[u8; IN_BLEN]],
        ys: &mut [&mut G],
    ) where
        G: Group<OUT_BLEN>,
    {
        xs.par_iter()
            .zip(ys.par_iter_mut())
            .for_each(|(x, y)| self.eval_point(b, k, x, y));
    }

    /// Recursive full-domain evaluation step.
    ///
    /// Expands node `(s, t)` at depth `layer_i` and fills `ys`, which must
    /// cover exactly the `2^(filter_bitn - layer_i)` leaves under this node.
    fn full_eval_layer<G>(
        &self,
        b: bool,
        k: &Share<OUT_BLEN, G>,
        ys: &mut [&mut G],
        layer_i: usize,
        (s, t): ([u8; OUT_BLEN], bool),
    ) where
        G: Group<OUT_BLEN>,
    {
        assert_eq!(ys.len(), 1 << (self.filter_bitn - layer_i));
        if ys.len() == 1 {
            // Leaf: convert the seed into a group element, add the final
            // correction word when the control bit is set, and negate for
            // party 1 so the two parties' outputs are additive shares.
            *ys[0] = (Into::<G>::into(s) + if t { k.cw_np1.clone() } else { G::zero() }).neg_if(b);
            return;
        }

        let cw = &k.cws[layer_i];
        // Expand the node into its two children; apply this level's
        // correction word only when the control bit `t` is set.
        let [([mut sl], mut tl), ([mut sr], mut tr)] = self.prg.gen(&s);
        xor_inplace(&mut sl, &[if t { &cw.s } else { &[0; OUT_BLEN] }]);
        xor_inplace(&mut sr, &[if t { &cw.s } else { &[0; OUT_BLEN] }]);
        tl ^= t & cw.tl;
        tr ^= t & cw.tr;

        // The first half of `ys` belongs to the left subtree, the second half
        // to the right subtree.
        let (ys_l, ys_r) = ys.split_at_mut(ys.len() / 2);
        #[cfg(feature = "multi-thread")]
        rayon::join(
            || self.full_eval_layer(b, k, ys_l, layer_i + 1, (sl, tl)),
            || self.full_eval_layer(b, k, ys_r, layer_i + 1, (sr, tr)),
        );
        #[cfg(not(feature = "multi-thread"))]
        {
            self.full_eval_layer(b, k, ys_l, layer_i + 1, (sl, tl));
            self.full_eval_layer(b, k, ys_r, layer_i + 1, (sr, tr));
        }
    }

    /// Evaluates party `b`'s share `k` at a single point `x`, writing the
    /// resulting group element into `y`.
    pub fn eval_point<G>(&self, b: bool, k: &Share<OUT_BLEN, G>, x: &[u8; IN_BLEN], y: &mut G)
    where
        G: Group<OUT_BLEN>,
    {
        let n = k.cws.len();
        assert_eq!(n, self.filter_bitn);
        let v = y;

        // Walk down the key tree along the bits of `x`.
        let mut s_prev = k.s0s[0];
        let mut t_prev = b;
        for i in 0..n {
            let cw = &k.cws[i];
            let [([mut sl], mut tl), ([mut sr], mut tr)] = self.prg.gen(&s_prev);
            // Apply this level's correction word only when the control bit is set.
            xor_inplace(&mut sl, &[if t_prev { &cw.s } else { &[0; OUT_BLEN] }]);
            xor_inplace(&mut sr, &[if t_prev { &cw.s } else { &[0; OUT_BLEN] }]);
            tl ^= t_prev & cw.tl;
            tr ^= t_prev & cw.tr;
            // Descend right on a 1 bit, left on a 0 bit (MSB-first order).
            if x.view_bits::<Msb0>()[i] {
                s_prev = sr;
                t_prev = tr;
            } else {
                s_prev = sl;
                t_prev = tl;
            }
        }
        // Leaf output: final correction plus party-1 negation, matching the
        // leaf case of `full_eval_layer`.
        *v =
            (Into::<G>::into(s_prev) + if t_prev { k.cw_np1.clone() } else { G::zero() }).neg_if(b);
    }
}
251
#[cfg(all(test, feature = "prg"))]
mod tests {
    use arbtest::arbtest;
    use std::iter;

    use super::*;
    use crate::group::byte::ByteGroup;
    use crate::prg::Aes128MatyasMeyerOseasPrg;

    type GroupImpl = ByteGroup<16>;
    type PrgImpl = Aes128MatyasMeyerOseasPrg<16, 1, 2>;
    type DpfImplImpl = DpfImpl<2, 16, PrgImpl>;

    /// Checks that the two shares of `eval` reconstruct the point function —
    /// `beta` at `alpha` and zero at every other domain point — and that
    /// `full_eval` agrees with point-by-point `eval`, over random keys,
    /// filter widths, and points.
    #[test]
    fn test_correctness() {
        arbtest(|u| {
            let keys: [[u8; 16]; 2] = u.arbitrary()?;
            let prg = PrgImpl::new(&std::array::from_fn(|i| &keys[i]));
            // Filter width in 2..=16 bits so the full domain stays small.
            let filter_bitn = u.arbitrary::<usize>()? % 15 + 2;
            let dpf = DpfImplImpl::new_with_filter(prg, filter_bitn);
            let s0s: [[u8; 16]; 2] = u.arbitrary()?;
            // Random point in the filtered domain, encoded in the high bits.
            let alpha_i: u16 = u.arbitrary::<u16>()? >> (16 - filter_bitn);
            let alpha: [u8; 2] = (alpha_i << (16 - filter_bitn)).to_be_bytes();
            let beta: [u8; 16] = u.arbitrary()?;
            let f = PointFn {
                alpha,
                beta: beta.into(),
            };
            let k = dpf.gen(&f, [&s0s[0], &s0s[1]]);
            // Split the generated key so each party holds only its own seed.
            let mut k0 = k.clone();
            k0.s0s = vec![k0.s0s[0]];
            let mut k1 = k;
            k1.s0s = vec![k1.s0s[1]];

            // Every point of the filtered domain, in increasing order.
            let xs: Vec<_> = (0u16..=u16::MAX >> (16 - filter_bitn))
                .map(|i| (i << (16 - filter_bitn)).to_be_bytes())
                .collect();
            assert_eq!(xs.len(), 1 << filter_bitn);
            let xs_lt_num = alpha_i;
            let xs_gt_num = (u16::MAX >> (16 - filter_bitn)) - alpha_i;
            // A DPF reconstructs to zero everywhere except at `alpha`, where
            // it reconstructs to `beta`. (Previously this expected `beta` for
            // all x < alpha, which is the DCF expectation, not the DPF one.)
            let ys_expected: Vec<_> = iter::repeat(GroupImpl::zero())
                .take(xs_lt_num as usize)
                .chain([beta.into()])
                .chain(iter::repeat(GroupImpl::zero()).take(xs_gt_num as usize))
                .collect();

            let mut ys0 = vec![GroupImpl::zero(); xs.len()];
            let mut ys1 = vec![GroupImpl::zero(); xs.len()];
            dpf.eval(
                false,
                &k0,
                &xs.iter().collect::<Vec<_>>(),
                &mut ys0.iter_mut().collect::<Vec<_>>(),
            );
            // Individual share outputs should look pseudorandom: in particular
            // never the all-zero or all-one element.
            ys0.iter().for_each(|y0| {
                assert_ne!(*y0, GroupImpl::zero());
                assert_ne!(*y0, [0xff; 16].into());
            });
            dpf.eval(
                true,
                &k1,
                &xs.iter().collect::<Vec<_>>(),
                &mut ys1.iter_mut().collect::<Vec<_>>(),
            );
            ys1.iter().for_each(|y1| {
                assert_ne!(*y1, GroupImpl::zero());
                assert_ne!(*y1, [0xff; 16].into());
            });
            // Reconstruct by adding the two shares pointwise.
            let ys: Vec<_> = ys0
                .iter()
                .zip(ys1.iter())
                .map(|(y0, y1)| y0.clone() + y1.clone())
                .collect();
            assert_ys_eq(&ys, &ys_expected, &xs, &alpha);

            // `full_eval` must produce exactly the same share outputs as
            // point-by-point `eval`.
            let mut ys0_full_eval = vec![ByteGroup::zero(); 1 << filter_bitn];
            dpf.full_eval(
                false,
                &k0,
                &mut ys0_full_eval.iter_mut().collect::<Vec<_>>(),
            );
            assert_ys_eq(&ys0_full_eval, &ys0, &xs, &alpha);
            let mut ys1_full_eval = vec![ByteGroup::zero(); 1 << filter_bitn];
            dpf.full_eval(true, &k1, &mut ys1_full_eval.iter_mut().collect::<Vec<_>>());
            assert_ys_eq(&ys1_full_eval, &ys1, &xs, &alpha);

            Ok(())
        });
    }

    /// Asserts `ys == ys_expected` elementwise, reporting the index, the
    /// input point, and its ordering relative to `alpha` on mismatch.
    fn assert_ys_eq(ys: &[GroupImpl], ys_expected: &[GroupImpl], xs: &[[u8; 2]], alpha: &[u8; 2]) {
        let alpha_int = u16::from_be_bytes(*alpha);
        for (i, (x, (y, y_expected))) in
            xs.iter().zip(ys.iter().zip(ys_expected.iter())).enumerate()
        {
            let x_int = u16::from_be_bytes(*x);
            let cmp = if x_int < alpha_int {
                "<"
            } else if x_int > alpha_int {
                ">"
            } else {
                "="
            };
            assert_eq!(y, y_expected, "where i={}, x={:?}, x{}alpha", i, *x, cmp);
        }
    }
}