fss_rs/dcf.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (C) 2023 Yulong Ming (myl7)
3
4//! See [`Dcf`].
5
6use bitvec::prelude::*;
7#[cfg(feature = "multi-thread")]
8use rayon::prelude::*;
9
10use crate::group::Group;
11use crate::utils::{xor, xor_inplace};
12use crate::{Cw, PointFn, Prg, Share};
13
/// API of distributed comparison functions (DCFs).
///
/// - See [`CmpFn`] for `IN_BLEN` and `OUT_BLEN`.
/// - See [`DcfImpl`] for the implementation.
pub trait Dcf<const IN_BLEN: usize, const OUT_BLEN: usize, G>
where
    G: Group<OUT_BLEN>,
{
    /// Generates the key shares for the comparison function `f`.
    ///
    /// `s0s` is `$s^{(0)}_0$` and `$s^{(0)}_1$` which should be randomly sampled.
    fn gen(&self, f: &CmpFn<IN_BLEN, OUT_BLEN, G>, s0s: [&[u8; OUT_BLEN]; 2])
        -> Share<OUT_BLEN, G>;

    /// Evaluates the key share `k` at each point of `xs`, writing the results into `ys`.
    ///
    /// `b` is the party. `false` is 0 and `true` is 1.
    fn eval(&self, b: bool, k: &Share<OUT_BLEN, G>, xs: &[&[u8; IN_BLEN]], ys: &mut [&mut G]);

    /// Full domain eval.
    /// See [`Dcf::eval`] for `b`.
    /// The corresponding `xs` to `ys` is the big endian representation of `0..=u*::MAX`.
    fn full_eval(&self, b: bool, k: &Share<OUT_BLEN, G>, ys: &mut [&mut G]);
}
34
/// Comparison function.
///
/// - See [`BoundState`] for available `bound` values.
/// - See [`PointFn`] for `IN_BLEN`, `OUT_BLEN`, `alpha`, and `beta`.
pub struct CmpFn<const IN_BLEN: usize, const OUT_BLEN: usize, G>
where
    G: Group<OUT_BLEN>,
{
    /// `$\alpha$`. The comparison point, stored big-endian.
    pub alpha: [u8; IN_BLEN],
    /// `$\beta$`. The non-zero output value of the function.
    pub beta: G,
    /// See [`BoundState`]. Selects whether `$x < \alpha$` or `$x > \alpha$` yields `$\beta$`.
    pub bound: BoundState,
}
50
51impl<const IN_BLEN: usize, const OUT_BLEN: usize, G> CmpFn<IN_BLEN, OUT_BLEN, G>
52where
53    G: Group<OUT_BLEN>,
54{
55    pub fn from_point(point: PointFn<IN_BLEN, OUT_BLEN, G>, bound: BoundState) -> Self {
56        Self {
57            alpha: point.alpha,
58            beta: point.beta,
59            bound,
60        }
61    }
62}
63
/// Implementation of [`Dcf`].
///
/// `$\alpha$` itself is not included (or say exclusive endpoint), which means `$f(\alpha)$ = 0`.
pub struct DcfImpl<const IN_BLEN: usize, const OUT_BLEN: usize, P>
where
    P: Prg<OUT_BLEN, 2>,
{
    /// PRG used to expand a seed into the two child (seed, value, bit) triples per level.
    prg: P,
    /// Number of most-significant input bits actually evaluated.
    /// Defaults to `8 * IN_BLEN`; see [`DcfImpl::new_with_filter`].
    filter_bitn: usize,
}
74
75impl<const IN_BLEN: usize, const OUT_BLEN: usize, P> DcfImpl<IN_BLEN, OUT_BLEN, P>
76where
77    P: Prg<OUT_BLEN, 2>,
78{
79    pub fn new(prg: P) -> Self {
80        Self {
81            prg,
82            filter_bitn: 8 * IN_BLEN,
83        }
84    }
85
86    pub fn new_with_filter(prg: P, filter_bitn: usize) -> Self {
87        assert!(filter_bitn <= 8 * IN_BLEN && filter_bitn > 1);
88        Self { prg, filter_bitn }
89    }
90}
91
/// Index of the left child in the 2-way PRG expansion output.
const IDX_L: usize = 0;
/// Index of the right child in the 2-way PRG expansion output.
const IDX_R: usize = 1;
94
impl<const IN_BLEN: usize, const OUT_BLEN: usize, P, G> Dcf<IN_BLEN, OUT_BLEN, G>
    for DcfImpl<IN_BLEN, OUT_BLEN, P>
where
    P: Prg<OUT_BLEN, 2>,
    G: Group<OUT_BLEN>,
{
    /// DCF key generation: walks the bits of `$\alpha$` from the MSB, producing one
    /// correction word per level plus a final one, shared by both parties.
    fn gen(
        &self,
        f: &CmpFn<IN_BLEN, OUT_BLEN, G>,
        s0s: [&[u8; OUT_BLEN]; 2],
    ) -> Share<OUT_BLEN, G> {
        // The bit size of `$\alpha$`.
        let n = self.filter_bitn;
        // Running value difference accumulated along the path to `$\alpha$`.
        let mut v_alpha = G::zero();
        // Set `$s^{(1)}_0$` and `$s^{(1)}_1$`.
        let mut ss_prev = [*s0s[0], *s0s[1]];
        // Set `$t^{(0)}_0$` and `$t^{(0)}_1$`.
        let mut ts_prev = [false, true];
        let mut cws = Vec::<Cw<OUT_BLEN, G>>::with_capacity(n);
        for i in 0..n {
            // MSB is required since we index from high to low in arrays.
            let alpha_i = f.alpha.view_bits::<Msb0>()[i];
            // Expand each party's current seed into (left, right) child triples.
            let [([s0l, v0l], t0l), ([s0r, v0r], t0r)] = self.prg.gen(&ss_prev[0]);
            let [([s1l, v1l], t1l), ([s1r, v1r], t1r)] = self.prg.gen(&ss_prev[1]);
            // `keep` is the child on `$\alpha$`'s path; `lose` is the sibling.
            let (keep, lose) = if alpha_i {
                (IDX_R, IDX_L)
            } else {
                (IDX_L, IDX_R)
            };
            // Seed correction word: XOR of the two parties' seeds on the lose side,
            // so the parties' states become equal off `$\alpha$`'s path.
            let s_cw = xor(&[[&s0l, &s0r][lose], [&s1l, &s1r][lose]]);
            let mut v_cw =
                (G::from(*[&v0l, &v0r][lose]) + -G::from(*[&v1l, &v1r][lose]) + -v_alpha.clone())
                    .neg_if(ts_prev[1]);
            // Inject `$\beta$` into the value CW on the side covered by the bound.
            match f.bound {
                BoundState::LtAlpha => {
                    if lose == IDX_L {
                        v_cw += f.beta.clone()
                    }
                }
                BoundState::GtAlpha => {
                    if lose == IDX_R {
                        v_cw += f.beta.clone()
                    }
                }
            }
            // Update the running value difference along the kept path.
            v_alpha += -G::from(*[&v0l, &v0r][keep])
                + (*[&v1l, &v1r][keep]).into()
                + v_cw.clone().neg_if(ts_prev[1]);
            // Control-bit correction words for the left and right children.
            let tl_cw = t0l ^ t1l ^ alpha_i ^ true;
            let tr_cw = t0r ^ t1r ^ alpha_i;
            let cw = Cw {
                s: s_cw,
                v: v_cw,
                tl: tl_cw,
                tr: tr_cw,
            };
            cws.push(cw);
            // Advance both parties' seeds along the kept side, applying the seed CW
            // only when the party's control bit is set.
            ss_prev = [
                xor(&[
                    [&s0l, &s0r][keep],
                    if ts_prev[0] { &s_cw } else { &[0; OUT_BLEN] },
                ]),
                xor(&[
                    [&s1l, &s1r][keep],
                    if ts_prev[1] { &s_cw } else { &[0; OUT_BLEN] },
                ]),
            ];
            // Advance both parties' control bits likewise.
            ts_prev = [
                [t0l, t0r][keep] ^ (ts_prev[0] & [tl_cw, tr_cw][keep]),
                [t1l, t1r][keep] ^ (ts_prev[1] & [tl_cw, tr_cw][keep]),
            ];
        }
        // Final correction word `$CW^{(n+1)}$` reconciling the leaf values.
        let cw_np1 = (G::from(ss_prev[1]) + -G::from(ss_prev[0]) + -v_alpha).neg_if(ts_prev[1]);
        Share {
            s0s: vec![s0s[0].to_owned(), s0s[1].to_owned()],
            cws,
            cw_np1,
        }
    }

    fn eval(&self, b: bool, k: &Share<OUT_BLEN, G>, xs: &[&[u8; IN_BLEN]], ys: &mut [&mut G]) {
        // Dispatch to the multi-threaded eval iff. the feature is enabled at compile time.
        #[cfg(feature = "multi-thread")]
        self.eval_mt(b, k, xs, ys);
        #[cfg(not(feature = "multi-thread"))]
        self.eval_st(b, k, xs, ys);
    }

    fn full_eval(&self, b: bool, k: &Share<OUT_BLEN, G>, ys: &mut [&mut G]) {
        // The key must have exactly one correction word per filtered input bit.
        let n = k.cws.len();
        assert_eq!(n, self.filter_bitn);

        // Root state: this party's initial seed, zero value, control bit = party bit.
        let s = k.s0s[0];
        let v = G::zero();
        let t = b;
        self.full_eval_layer(b, k, ys, 0, (s, v, t));
    }
}
193
194impl<const IN_BLEN: usize, const OUT_BLEN: usize, P> DcfImpl<IN_BLEN, OUT_BLEN, P>
195where
196    P: Prg<OUT_BLEN, 2>,
197{
198    /// Eval with single-threading.
199    /// See [`Dcf::eval`].
200    pub fn eval_st<G>(
201        &self,
202        b: bool,
203        k: &Share<OUT_BLEN, G>,
204        xs: &[&[u8; IN_BLEN]],
205        ys: &mut [&mut G],
206    ) where
207        G: Group<OUT_BLEN>,
208    {
209        xs.iter()
210            .zip(ys.iter_mut())
211            .for_each(|(x, y)| self.eval_point(b, k, x, y));
212    }
213
214    #[cfg(feature = "multi-thread")]
215    /// Eval with multi-threading.
216    /// See [`Dcf::eval`].
217    pub fn eval_mt<G>(
218        &self,
219        b: bool,
220        k: &Share<OUT_BLEN, G>,
221        xs: &[&[u8; IN_BLEN]],
222        ys: &mut [&mut G],
223    ) where
224        G: Group<OUT_BLEN>,
225    {
226        xs.par_iter()
227            .zip(ys.par_iter_mut())
228            .for_each(|(x, y)| self.eval_point(b, k, x, y));
229    }
230
231    fn full_eval_layer<G>(
232        &self,
233        b: bool,
234        k: &Share<OUT_BLEN, G>,
235        ys: &mut [&mut G],
236        layer_i: usize,
237        (s, v, t): ([u8; OUT_BLEN], G, bool),
238    ) where
239        G: Group<OUT_BLEN>,
240    {
241        assert_eq!(ys.len(), 1 << (self.filter_bitn - layer_i));
242        if ys.len() == 1 {
243            *ys[0] = v + (G::from(s) + if t { k.cw_np1.clone() } else { G::zero() }).neg_if(b);
244            return;
245        }
246
247        let cw = &k.cws[layer_i];
248        // `*_hat` before in-place XOR.
249        let [([mut sl, vl_hat], mut tl), ([mut sr, vr_hat], mut tr)] = self.prg.gen(&s);
250        xor_inplace(&mut sl, &[if t { &cw.s } else { &[0; OUT_BLEN] }]);
251        xor_inplace(&mut sr, &[if t { &cw.s } else { &[0; OUT_BLEN] }]);
252        tl ^= t & cw.tl;
253        tr ^= t & cw.tr;
254        let vl = v.clone() + (G::from(vl_hat) + if t { cw.v.clone() } else { G::zero() }).neg_if(b);
255        let vr = v + (G::from(vr_hat) + if t { cw.v.clone() } else { G::zero() }).neg_if(b);
256
257        let (ys_l, ys_r) = ys.split_at_mut(ys.len() / 2);
258        #[cfg(feature = "multi-thread")]
259        rayon::join(
260            || self.full_eval_layer(b, k, ys_l, layer_i + 1, (sl, vl, tl)),
261            || self.full_eval_layer(b, k, ys_r, layer_i + 1, (sr, vr, tr)),
262        );
263        #[cfg(not(feature = "multi-thread"))]
264        {
265            self.full_eval_layer(b, k, ys_l, layer_i + 1, (sl, vl, tl));
266            self.full_eval_layer(b, k, ys_r, layer_i + 1, (sr, vr, tr));
267        }
268    }
269
270    pub fn eval_point<G>(&self, b: bool, k: &Share<OUT_BLEN, G>, x: &[u8; IN_BLEN], y: &mut G)
271    where
272        G: Group<OUT_BLEN>,
273    {
274        let n = k.cws.len();
275        assert_eq!(n, self.filter_bitn);
276        let v = y;
277
278        let mut s_prev = k.s0s[0];
279        let mut t_prev = b;
280        *v = G::zero();
281        for i in 0..n {
282            let cw = &k.cws[i];
283            // `*_hat` before in-place XOR.
284            let [([mut sl, vl_hat], mut tl), ([mut sr, vr_hat], mut tr)] = self.prg.gen(&s_prev);
285            xor_inplace(&mut sl, &[if t_prev { &cw.s } else { &[0; OUT_BLEN] }]);
286            xor_inplace(&mut sr, &[if t_prev { &cw.s } else { &[0; OUT_BLEN] }]);
287            tl ^= t_prev & cw.tl;
288            tr ^= t_prev & cw.tr;
289            if x.view_bits::<Msb0>()[i] {
290                *v += (G::from(vr_hat) + if t_prev { cw.v.clone() } else { G::zero() }).neg_if(b);
291                s_prev = sr;
292                t_prev = tr;
293            } else {
294                *v += (G::from(vl_hat) + if t_prev { cw.v.clone() } else { G::zero() }).neg_if(b);
295                s_prev = sl;
296                t_prev = tl;
297            }
298        }
299        *v += (G::from(s_prev) + if t_prev { k.cw_np1.clone() } else { G::zero() }).neg_if(b);
300    }
301}
302
/// Bound side of a [`CmpFn`]: selects which side of `$\alpha$` maps to `$\beta$`.
pub enum BoundState {
    /// `$f(x) = \beta$` iff. `$x < \alpha$`, otherwise `$f(x) = 0$`.
    ///
    /// This is the choice of the paper.
    LtAlpha,
    /// `$f(x) = \beta$` iff. `$x > \alpha$`, otherwise `$f(x) = 0$`.
    GtAlpha,
}
311
#[cfg(all(test, feature = "prg"))]
mod tests {
    use std::iter;

    use arbtest::arbtest;

    use super::*;
    use crate::group::byte::ByteGroup;
    use crate::prg::Aes128MatyasMeyerOseasPrg;

    // 16-byte XOR group, AES-based PRG, and a DCF over 2-byte inputs.
    type GroupImpl = ByteGroup<16>;
    type PrgImpl = Aes128MatyasMeyerOseasPrg<16, 2, 4>;
    type DcfImplImpl = DcfImpl<2, 16, PrgImpl>;

    // Property test: for random keys, bounds, and filter widths, the two parties'
    // shares must sum to the comparison function over the whole (filtered) domain,
    // and full-domain eval must agree with point-by-point eval.
    #[test]
    fn test_correctness() {
        arbtest(|u| {
            let keys: [[u8; 16]; 4] = u.arbitrary()?;
            let prg = PrgImpl::new(&std::array::from_fn(|i| &keys[i]));
            let filter_bitn = u.arbitrary::<usize>()? % 15 + 2; // 2..=16
            let dcf = DcfImplImpl::new_with_filter(prg, filter_bitn);
            let s0s: [[u8; 16]; 2] = u.arbitrary()?;
            // `alpha` only uses the `filter_bitn` most significant bits;
            // the rest are zeroed by the shift round trip.
            let alpha_i: u16 = u.arbitrary::<u16>()? >> (16 - filter_bitn);
            let alpha: [u8; 2] = (alpha_i << (16 - filter_bitn)).to_be_bytes();
            let beta: [u8; 16] = u.arbitrary()?;
            let bound_is_gt: bool = u.arbitrary()?;
            let f = CmpFn {
                alpha,
                beta: beta.into(),
                bound: if bound_is_gt {
                    BoundState::GtAlpha
                } else {
                    BoundState::LtAlpha
                },
            };
            // Split the generated share into the two per-party keys.
            let k = dcf.gen(&f, [&s0s[0], &s0s[1]]);
            let mut k0 = k.clone();
            k0.s0s = vec![k0.s0s[0]];
            let mut k1 = k;
            k1.s0s = vec![k1.s0s[1]];

            // All `2^filter_bitn` domain points, placed in the high bits (big endian).
            let xs: Vec<_> = (0u16..=u16::MAX >> (16 - filter_bitn))
                .map(|i| (i << (16 - filter_bitn)).to_be_bytes())
                .collect();
            assert_eq!(xs.len(), 1 << filter_bitn);
            // Expected outputs: `beta` strictly below/above `alpha` depending on the
            // bound, zero elsewhere, and always zero exactly at `alpha`.
            let xs_lt_num = alpha_i;
            let xs_gt_num = (u16::MAX >> (16 - filter_bitn)) - alpha_i;
            let ys_expected: Vec<_> = iter::repeat(if bound_is_gt {
                GroupImpl::zero()
            } else {
                beta.into()
            })
            .take(xs_lt_num as usize)
            .chain([GroupImpl::zero()])
            .chain(
                iter::repeat({
                    if bound_is_gt {
                        beta.into()
                    } else {
                        GroupImpl::zero()
                    }
                })
                .take(xs_gt_num as usize),
            )
            .collect();

            let mut ys0 = vec![GroupImpl::zero(); xs.len()];
            let mut ys1 = vec![GroupImpl::zero(); xs.len()];
            dcf.eval(
                false,
                &k0,
                &xs.iter().collect::<Vec<_>>(),
                &mut ys0.iter_mut().collect::<Vec<_>>(),
            );
            // Sanity check that individual shares look pseudorandom
            // (not all-zero / all-one), so secrecy is not trivially broken.
            ys0.iter().for_each(|y0| {
                assert_ne!(*y0, GroupImpl::zero());
                assert_ne!(*y0, [0xff; 16].into());
            });
            dcf.eval(
                true,
                &k1,
                &xs.iter().collect::<Vec<_>>(),
                &mut ys1.iter_mut().collect::<Vec<_>>(),
            );
            ys1.iter().for_each(|y1| {
                assert_ne!(*y1, GroupImpl::zero());
                assert_ne!(*y1, [0xff; 16].into());
            });
            // The two shares must reconstruct (sum) to the plain function values.
            let ys: Vec<_> = ys0
                .iter()
                .zip(ys1.iter())
                .map(|(y0, y1)| y0.clone() + y1.clone())
                .collect();
            assert_ys_eq(&ys, &ys_expected, &xs, &alpha);

            // Full-domain eval must match point-by-point eval for each party.
            let mut ys0_full_eval = vec![ByteGroup::zero(); 1 << filter_bitn];
            dcf.full_eval(
                false,
                &k0,
                &mut ys0_full_eval.iter_mut().collect::<Vec<_>>(),
            );
            assert_ys_eq(&ys0_full_eval, &ys0, &xs, &alpha);
            let mut ys1_full_eval = vec![ByteGroup::zero(); 1 << filter_bitn];
            dcf.full_eval(true, &k1, &mut ys1_full_eval.iter_mut().collect::<Vec<_>>());
            assert_ys_eq(&ys1_full_eval, &ys1, &xs, &alpha);

            Ok(())
        });
    }

    // Asserts element-wise equality, annotating failures with the index, the point,
    // and its ordering relative to `alpha` for easier debugging.
    fn assert_ys_eq(ys: &[GroupImpl], ys_expected: &[GroupImpl], xs: &[[u8; 2]], alpha: &[u8; 2]) {
        let alpha_int = u16::from_be_bytes(*alpha);
        for (i, (x, (y, y_expected))) in
            xs.iter().zip(ys.iter().zip(ys_expected.iter())).enumerate()
        {
            let x_int = u16::from_be_bytes(*x);
            let cmp = if x_int < alpha_int {
                "<"
            } else if x_int > alpha_int {
                ">"
            } else {
                "="
            };
            assert_eq!(y, y_expected, "where i={}, x={:?}, x{}alpha", i, *x, cmp);
        }
    }
}