1use bitvec::prelude::*;
7#[cfg(feature = "multi-thread")]
8use rayon::prelude::*;
9
10use crate::group::Group;
11use crate::utils::{xor, xor_inplace};
12use crate::{Cw, PointFn, Prg, Share};
13
14pub trait Dcf<const IN_BLEN: usize, const OUT_BLEN: usize, G>
19where
20 G: Group<OUT_BLEN>,
21{
22 fn gen(&self, f: &CmpFn<IN_BLEN, OUT_BLEN, G>, s0s: [&[u8; OUT_BLEN]; 2])
24 -> Share<OUT_BLEN, G>;
25
26 fn eval(&self, b: bool, k: &Share<OUT_BLEN, G>, xs: &[&[u8; IN_BLEN]], ys: &mut [&mut G]);
28
29 fn full_eval(&self, b: bool, k: &Share<OUT_BLEN, G>, ys: &mut [&mut G]);
33}
34
35pub struct CmpFn<const IN_BLEN: usize, const OUT_BLEN: usize, G>
40where
41 G: Group<OUT_BLEN>,
42{
43 pub alpha: [u8; IN_BLEN],
45 pub beta: G,
47 pub bound: BoundState,
49}
50
51impl<const IN_BLEN: usize, const OUT_BLEN: usize, G> CmpFn<IN_BLEN, OUT_BLEN, G>
52where
53 G: Group<OUT_BLEN>,
54{
55 pub fn from_point(point: PointFn<IN_BLEN, OUT_BLEN, G>, bound: BoundState) -> Self {
56 Self {
57 alpha: point.alpha,
58 beta: point.beta,
59 bound,
60 }
61 }
62}
63
64pub struct DcfImpl<const IN_BLEN: usize, const OUT_BLEN: usize, P>
68where
69 P: Prg<OUT_BLEN, 2>,
70{
71 prg: P,
72 filter_bitn: usize,
73}
74
75impl<const IN_BLEN: usize, const OUT_BLEN: usize, P> DcfImpl<IN_BLEN, OUT_BLEN, P>
76where
77 P: Prg<OUT_BLEN, 2>,
78{
79 pub fn new(prg: P) -> Self {
80 Self {
81 prg,
82 filter_bitn: 8 * IN_BLEN,
83 }
84 }
85
86 pub fn new_with_filter(prg: P, filter_bitn: usize) -> Self {
87 assert!(filter_bitn <= 8 * IN_BLEN && filter_bitn > 1);
88 Self { prg, filter_bitn }
89 }
90}
91
/// Position of the left child in the `[left, right]` pairs produced by the PRG.
const IDX_L: usize = 0;
/// Position of the right child, immediately after the left one.
const IDX_R: usize = IDX_L + 1;
94
95impl<const IN_BLEN: usize, const OUT_BLEN: usize, P, G> Dcf<IN_BLEN, OUT_BLEN, G>
96 for DcfImpl<IN_BLEN, OUT_BLEN, P>
97where
98 P: Prg<OUT_BLEN, 2>,
99 G: Group<OUT_BLEN>,
100{
101 fn gen(
102 &self,
103 f: &CmpFn<IN_BLEN, OUT_BLEN, G>,
104 s0s: [&[u8; OUT_BLEN]; 2],
105 ) -> Share<OUT_BLEN, G> {
106 let n = self.filter_bitn;
108 let mut v_alpha = G::zero();
109 let mut ss_prev = [*s0s[0], *s0s[1]];
111 let mut ts_prev = [false, true];
113 let mut cws = Vec::<Cw<OUT_BLEN, G>>::with_capacity(n);
114 for i in 0..n {
115 let alpha_i = f.alpha.view_bits::<Msb0>()[i];
117 let [([s0l, v0l], t0l), ([s0r, v0r], t0r)] = self.prg.gen(&ss_prev[0]);
118 let [([s1l, v1l], t1l), ([s1r, v1r], t1r)] = self.prg.gen(&ss_prev[1]);
119 let (keep, lose) = if alpha_i {
121 (IDX_R, IDX_L)
122 } else {
123 (IDX_L, IDX_R)
124 };
125 let s_cw = xor(&[[&s0l, &s0r][lose], [&s1l, &s1r][lose]]);
126 let mut v_cw =
127 (G::from(*[&v0l, &v0r][lose]) + -G::from(*[&v1l, &v1r][lose]) + -v_alpha.clone())
128 .neg_if(ts_prev[1]);
129 match f.bound {
130 BoundState::LtAlpha => {
131 if lose == IDX_L {
132 v_cw += f.beta.clone()
133 }
134 }
135 BoundState::GtAlpha => {
136 if lose == IDX_R {
137 v_cw += f.beta.clone()
138 }
139 }
140 }
141 v_alpha += -G::from(*[&v0l, &v0r][keep])
142 + (*[&v1l, &v1r][keep]).into()
143 + v_cw.clone().neg_if(ts_prev[1]);
144 let tl_cw = t0l ^ t1l ^ alpha_i ^ true;
145 let tr_cw = t0r ^ t1r ^ alpha_i;
146 let cw = Cw {
147 s: s_cw,
148 v: v_cw,
149 tl: tl_cw,
150 tr: tr_cw,
151 };
152 cws.push(cw);
153 ss_prev = [
154 xor(&[
155 [&s0l, &s0r][keep],
156 if ts_prev[0] { &s_cw } else { &[0; OUT_BLEN] },
157 ]),
158 xor(&[
159 [&s1l, &s1r][keep],
160 if ts_prev[1] { &s_cw } else { &[0; OUT_BLEN] },
161 ]),
162 ];
163 ts_prev = [
164 [t0l, t0r][keep] ^ (ts_prev[0] & [tl_cw, tr_cw][keep]),
165 [t1l, t1r][keep] ^ (ts_prev[1] & [tl_cw, tr_cw][keep]),
166 ];
167 }
168 let cw_np1 = (G::from(ss_prev[1]) + -G::from(ss_prev[0]) + -v_alpha).neg_if(ts_prev[1]);
169 Share {
170 s0s: vec![s0s[0].to_owned(), s0s[1].to_owned()],
171 cws,
172 cw_np1,
173 }
174 }
175
176 fn eval(&self, b: bool, k: &Share<OUT_BLEN, G>, xs: &[&[u8; IN_BLEN]], ys: &mut [&mut G]) {
177 #[cfg(feature = "multi-thread")]
178 self.eval_mt(b, k, xs, ys);
179 #[cfg(not(feature = "multi-thread"))]
180 self.eval_st(b, k, xs, ys);
181 }
182
183 fn full_eval(&self, b: bool, k: &Share<OUT_BLEN, G>, ys: &mut [&mut G]) {
184 let n = k.cws.len();
185 assert_eq!(n, self.filter_bitn);
186
187 let s = k.s0s[0];
188 let v = G::zero();
189 let t = b;
190 self.full_eval_layer(b, k, ys, 0, (s, v, t));
191 }
192}
193
194impl<const IN_BLEN: usize, const OUT_BLEN: usize, P> DcfImpl<IN_BLEN, OUT_BLEN, P>
195where
196 P: Prg<OUT_BLEN, 2>,
197{
198 pub fn eval_st<G>(
201 &self,
202 b: bool,
203 k: &Share<OUT_BLEN, G>,
204 xs: &[&[u8; IN_BLEN]],
205 ys: &mut [&mut G],
206 ) where
207 G: Group<OUT_BLEN>,
208 {
209 xs.iter()
210 .zip(ys.iter_mut())
211 .for_each(|(x, y)| self.eval_point(b, k, x, y));
212 }
213
214 #[cfg(feature = "multi-thread")]
215 pub fn eval_mt<G>(
218 &self,
219 b: bool,
220 k: &Share<OUT_BLEN, G>,
221 xs: &[&[u8; IN_BLEN]],
222 ys: &mut [&mut G],
223 ) where
224 G: Group<OUT_BLEN>,
225 {
226 xs.par_iter()
227 .zip(ys.par_iter_mut())
228 .for_each(|(x, y)| self.eval_point(b, k, x, y));
229 }
230
231 fn full_eval_layer<G>(
232 &self,
233 b: bool,
234 k: &Share<OUT_BLEN, G>,
235 ys: &mut [&mut G],
236 layer_i: usize,
237 (s, v, t): ([u8; OUT_BLEN], G, bool),
238 ) where
239 G: Group<OUT_BLEN>,
240 {
241 assert_eq!(ys.len(), 1 << (self.filter_bitn - layer_i));
242 if ys.len() == 1 {
243 *ys[0] = v + (G::from(s) + if t { k.cw_np1.clone() } else { G::zero() }).neg_if(b);
244 return;
245 }
246
247 let cw = &k.cws[layer_i];
248 let [([mut sl, vl_hat], mut tl), ([mut sr, vr_hat], mut tr)] = self.prg.gen(&s);
250 xor_inplace(&mut sl, &[if t { &cw.s } else { &[0; OUT_BLEN] }]);
251 xor_inplace(&mut sr, &[if t { &cw.s } else { &[0; OUT_BLEN] }]);
252 tl ^= t & cw.tl;
253 tr ^= t & cw.tr;
254 let vl = v.clone() + (G::from(vl_hat) + if t { cw.v.clone() } else { G::zero() }).neg_if(b);
255 let vr = v + (G::from(vr_hat) + if t { cw.v.clone() } else { G::zero() }).neg_if(b);
256
257 let (ys_l, ys_r) = ys.split_at_mut(ys.len() / 2);
258 #[cfg(feature = "multi-thread")]
259 rayon::join(
260 || self.full_eval_layer(b, k, ys_l, layer_i + 1, (sl, vl, tl)),
261 || self.full_eval_layer(b, k, ys_r, layer_i + 1, (sr, vr, tr)),
262 );
263 #[cfg(not(feature = "multi-thread"))]
264 {
265 self.full_eval_layer(b, k, ys_l, layer_i + 1, (sl, vl, tl));
266 self.full_eval_layer(b, k, ys_r, layer_i + 1, (sr, vr, tr));
267 }
268 }
269
270 pub fn eval_point<G>(&self, b: bool, k: &Share<OUT_BLEN, G>, x: &[u8; IN_BLEN], y: &mut G)
271 where
272 G: Group<OUT_BLEN>,
273 {
274 let n = k.cws.len();
275 assert_eq!(n, self.filter_bitn);
276 let v = y;
277
278 let mut s_prev = k.s0s[0];
279 let mut t_prev = b;
280 *v = G::zero();
281 for i in 0..n {
282 let cw = &k.cws[i];
283 let [([mut sl, vl_hat], mut tl), ([mut sr, vr_hat], mut tr)] = self.prg.gen(&s_prev);
285 xor_inplace(&mut sl, &[if t_prev { &cw.s } else { &[0; OUT_BLEN] }]);
286 xor_inplace(&mut sr, &[if t_prev { &cw.s } else { &[0; OUT_BLEN] }]);
287 tl ^= t_prev & cw.tl;
288 tr ^= t_prev & cw.tr;
289 if x.view_bits::<Msb0>()[i] {
290 *v += (G::from(vr_hat) + if t_prev { cw.v.clone() } else { G::zero() }).neg_if(b);
291 s_prev = sr;
292 t_prev = tr;
293 } else {
294 *v += (G::from(vl_hat) + if t_prev { cw.v.clone() } else { G::zero() }).neg_if(b);
295 s_prev = sl;
296 t_prev = tl;
297 }
298 }
299 *v += (G::from(s_prev) + if t_prev { k.cw_np1.clone() } else { G::zero() }).neg_if(b);
300 }
301}
302
/// Side of `alpha` on which a comparison function outputs `beta`.
///
/// The comparison is strict in both directions: the input `alpha` itself always maps to 0.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BoundState {
    /// `f(x) = beta` for `x < alpha`.
    LtAlpha,
    /// `f(x) = beta` for `x > alpha`.
    GtAlpha,
}
311
#[cfg(all(test, feature = "prg"))]
mod tests {
    use std::iter;

    use arbtest::arbtest;

    use super::*;
    use crate::group::byte::ByteGroup;
    use crate::prg::Aes128MatyasMeyerOseasPrg;

    type GroupImpl = ByteGroup<16>;
    type PrgImpl = Aes128MatyasMeyerOseasPrg<16, 2, 4>;
    type DcfImplImpl = DcfImpl<2, 16, PrgImpl>;

    /// Checks `gen`, `eval` and `full_eval` on random PRG keys, seeds, filter widths, `alpha`,
    /// `beta` and bound directions against a plain comparison over the whole domain.
    #[test]
    fn test_correctness() {
        arbtest(|u| {
            let keys: [[u8; 16]; 4] = u.arbitrary()?;
            let prg = PrgImpl::new(&std::array::from_fn(|i| &keys[i]));
            // Compare between 2 and 16 of the 16 input bits.
            let filter_bitn = u.arbitrary::<usize>()? % 15 + 2;
            let dcf = DcfImplImpl::new_with_filter(prg, filter_bitn);
            let s0s: [[u8; 16]; 2] = u.arbitrary()?;
            // `alpha` only has bits inside the filtered prefix.
            let alpha_i: u16 = u.arbitrary::<u16>()? >> (16 - filter_bitn);
            let alpha: [u8; 2] = (alpha_i << (16 - filter_bitn)).to_be_bytes();
            let beta: [u8; 16] = u.arbitrary()?;
            let bound_is_gt: bool = u.arbitrary()?;
            let f = CmpFn {
                alpha,
                beta: beta.into(),
                bound: if bound_is_gt {
                    BoundState::GtAlpha
                } else {
                    BoundState::LtAlpha
                },
            };
            let k = dcf.gen(&f, [&s0s[0], &s0s[1]]);
            // Each party keeps only its own seed.
            let mut k0 = k.clone();
            k0.s0s = vec![k0.s0s[0]];
            let mut k1 = k;
            k1.s0s = vec![k1.s0s[1]];

            // Every filtered input, in increasing order.
            let xs: Vec<_> = (0u16..=u16::MAX >> (16 - filter_bitn))
                .map(|i| (i << (16 - filter_bitn)).to_be_bytes())
                .collect();
            assert_eq!(xs.len(), 1 << filter_bitn);
            let xs_lt_num = alpha_i;
            let xs_gt_num = (u16::MAX >> (16 - filter_bitn)) - alpha_i;
            let lt_value: GroupImpl = if bound_is_gt {
                GroupImpl::zero()
            } else {
                beta.into()
            };
            let gt_value: GroupImpl = if bound_is_gt {
                beta.into()
            } else {
                GroupImpl::zero()
            };
            let ys_expected: Vec<_> = iter::repeat(lt_value)
                .take(xs_lt_num as usize)
                .chain([GroupImpl::zero()])
                .chain(iter::repeat(gt_value).take(xs_gt_num as usize))
                .collect();

            // Individual shares must look random: never 0 or all-ones.
            let ys0 = eval_all(&dcf, false, &k0, &xs);
            for y0 in &ys0 {
                assert_ne!(*y0, GroupImpl::zero());
                assert_ne!(*y0, [0xff; 16].into());
            }
            let ys1 = eval_all(&dcf, true, &k1, &xs);
            for y1 in &ys1 {
                assert_ne!(*y1, GroupImpl::zero());
                assert_ne!(*y1, [0xff; 16].into());
            }
            let ys: Vec<_> = ys0
                .iter()
                .zip(ys1.iter())
                .map(|(y0, y1)| y0.clone() + y1.clone())
                .collect();
            assert_ys_eq(&ys, &ys_expected, &xs, &alpha);

            // Full-domain evaluation must agree with point evaluation, share by share.
            let mut ys0_full_eval = vec![GroupImpl::zero(); 1 << filter_bitn];
            dcf.full_eval(
                false,
                &k0,
                &mut ys0_full_eval.iter_mut().collect::<Vec<_>>(),
            );
            assert_ys_eq(&ys0_full_eval, &ys0, &xs, &alpha);
            let mut ys1_full_eval = vec![GroupImpl::zero(); 1 << filter_bitn];
            dcf.full_eval(true, &k1, &mut ys1_full_eval.iter_mut().collect::<Vec<_>>());
            assert_ys_eq(&ys1_full_eval, &ys1, &xs, &alpha);

            Ok(())
        });
    }

    /// Point-evaluates key `k` as party `b` on every input in `xs`.
    fn eval_all(
        dcf: &DcfImplImpl,
        b: bool,
        k: &Share<16, GroupImpl>,
        xs: &[[u8; 2]],
    ) -> Vec<GroupImpl> {
        let mut ys = vec![GroupImpl::zero(); xs.len()];
        dcf.eval(
            b,
            k,
            &xs.iter().collect::<Vec<_>>(),
            &mut ys.iter_mut().collect::<Vec<_>>(),
        );
        ys
    }

    /// Asserts element-wise equality, reporting the index and how `x` compares to `alpha`.
    fn assert_ys_eq(ys: &[GroupImpl], ys_expected: &[GroupImpl], xs: &[[u8; 2]], alpha: &[u8; 2]) {
        let alpha_int = u16::from_be_bytes(*alpha);
        for (i, (x, (y, y_expected))) in
            xs.iter().zip(ys.iter().zip(ys_expected.iter())).enumerate()
        {
            let x_int = u16::from_be_bytes(*x);
            let cmp = if x_int < alpha_int {
                "<"
            } else if x_int > alpha_int {
                ">"
            } else {
                "="
            };
            assert_eq!(y, y_expected, "where i={}, x={:?}, x{}alpha", i, *x, cmp);
        }
    }
}