fss_rs/
dpf.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (C) 2023 Yulong Ming (myl7)
3
4//! See [`Dpf`].
5
6use bitvec::prelude::*;
7#[cfg(feature = "multi-thread")]
8use rayon::prelude::*;
9
10use crate::group::Group;
11use crate::utils::{xor, xor_inplace};
12pub use crate::PointFn;
13use crate::{Cw, Prg, Share};
14
15/// API of distributed point functions (DPFs).
16///
17/// `PointFn` used here means `$f(x) = \beta$` iff. `$x = \alpha$`, otherwise `$f(x) = 0$`.
18///
19/// - See [`PointFn`] for `IN_BLEN` and `OUT_BLEN`.
20/// - See [`DpfImpl`] for the implementation.
21pub trait Dpf<const IN_BLEN: usize, const OUT_BLEN: usize, G>
22where
23    G: Group<OUT_BLEN>,
24{
25    /// `s0s` is `$s^{(0)}_0$` and `$s^{(0)}_1$` which should be randomly sampled.
26    fn gen(
27        &self,
28        f: &PointFn<IN_BLEN, OUT_BLEN, G>,
29        s0s: [&[u8; OUT_BLEN]; 2],
30    ) -> Share<OUT_BLEN, G>;
31
32    /// `b` is the party. `false` is 0 and `true` is 1.
33    fn eval(&self, b: bool, k: &Share<OUT_BLEN, G>, xs: &[&[u8; IN_BLEN]], ys: &mut [&mut G]);
34
35    /// Full domain eval.
36    /// See [`Dpf::eval`] for `b`.
37    /// The corresponding `xs` to `ys` is the big endian representation of `0..=u*::MAX`.
38    fn full_eval(&self, b: bool, k: &Share<OUT_BLEN, G>, ys: &mut [&mut G]);
39}
40
41/// Implementation of [`Dpf`].
42pub struct DpfImpl<const IN_BLEN: usize, const OUT_BLEN: usize, P>
43where
44    P: Prg<OUT_BLEN, 1>,
45{
46    prg: P,
47    filter_bitn: usize,
48}
49
50impl<const IN_BLEN: usize, const OUT_BLEN: usize, P> DpfImpl<IN_BLEN, OUT_BLEN, P>
51where
52    P: Prg<OUT_BLEN, 1>,
53{
54    pub fn new(prg: P) -> Self {
55        Self {
56            prg,
57            filter_bitn: 8 * IN_BLEN,
58        }
59    }
60
61    pub fn new_with_filter(prg: P, filter_bitn: usize) -> Self {
62        assert!(filter_bitn <= 8 * IN_BLEN && filter_bitn > 1);
63        Self { prg, filter_bitn }
64    }
65}
66
/// Position of the left child in a `[left, right]` pair from one PRG expansion.
const IDX_L: usize = 0;
/// Position of the right child: the slot directly after [`IDX_L`].
const IDX_R: usize = IDX_L + 1;
69
70impl<const IN_BLEN: usize, const OUT_BLEN: usize, P, G> Dpf<IN_BLEN, OUT_BLEN, G>
71    for DpfImpl<IN_BLEN, OUT_BLEN, P>
72where
73    P: Prg<OUT_BLEN, 1>,
74    G: Group<OUT_BLEN>,
75{
76    fn gen(
77        &self,
78        f: &PointFn<IN_BLEN, OUT_BLEN, G>,
79        s0s: [&[u8; OUT_BLEN]; 2],
80    ) -> Share<OUT_BLEN, G> {
81        // The bit size of `$\alpha$`.
82        let n = self.filter_bitn;
83        // Set `$s^{(1)}_0$` and `$s^{(1)}_1$`.
84        let mut ss_prev = [s0s[0].to_owned(), s0s[1].to_owned()];
85        // Set `$t^{(0)}_0$` and `$t^{(0)}_1$`.
86        let mut ts_prev = [false, true];
87        let mut cws = Vec::<Cw<OUT_BLEN, G>>::with_capacity(n);
88        for i in 0..n {
89            // MSB is required since we index from high to low in arrays.
90            let alpha_i = f.alpha.view_bits::<Msb0>()[i];
91            let [([s0l], t0l), ([s0r], t0r)] = self.prg.gen(&ss_prev[0]);
92            let [([s1l], t1l), ([s1r], t1r)] = self.prg.gen(&ss_prev[1]);
93            let (keep, lose) = if alpha_i {
94                (IDX_R, IDX_L)
95            } else {
96                (IDX_L, IDX_R)
97            };
98            let s_cw = xor(&[[&s0l, &s0r][lose], [&s1l, &s1r][lose]]);
99            let tl_cw = t0l ^ t1l ^ alpha_i ^ true;
100            let tr_cw = t0r ^ t1r ^ alpha_i;
101            let cw = Cw {
102                s: s_cw,
103                v: G::zero(),
104                tl: tl_cw,
105                tr: tr_cw,
106            };
107            cws.push(cw);
108            ss_prev = [
109                xor(&[
110                    [&s0l, &s0r][keep],
111                    if ts_prev[0] { &s_cw } else { &[0; OUT_BLEN] },
112                ]),
113                xor(&[
114                    [&s1l, &s1r][keep],
115                    if ts_prev[1] { &s_cw } else { &[0; OUT_BLEN] },
116                ]),
117            ];
118            ts_prev = [
119                [t0l, t0r][keep] ^ (ts_prev[0] & [tl_cw, tr_cw][keep]),
120                [t1l, t1r][keep] ^ (ts_prev[1] & [tl_cw, tr_cw][keep]),
121            ];
122        }
123        let cw_np1 =
124            (f.beta.clone() + -Into::<G>::into(ss_prev[0]) + ss_prev[1].into()).neg_if(ts_prev[1]);
125        Share {
126            s0s: vec![s0s[0].to_owned(), s0s[1].to_owned()],
127            cws,
128            cw_np1,
129        }
130    }
131
132    fn eval(&self, b: bool, k: &Share<OUT_BLEN, G>, xs: &[&[u8; IN_BLEN]], ys: &mut [&mut G]) {
133        #[cfg(feature = "multi-thread")]
134        self.eval_mt(b, k, xs, ys);
135        #[cfg(not(feature = "multi-thread"))]
136        self.eval_st(b, k, xs, ys);
137    }
138
139    fn full_eval(&self, b: bool, k: &Share<OUT_BLEN, G>, ys: &mut [&mut G]) {
140        let n = k.cws.len();
141        assert_eq!(n, self.filter_bitn);
142
143        let s = k.s0s[0];
144        let t = b;
145        self.full_eval_layer(b, k, ys, 0, (s, t));
146    }
147}
148
149impl<const IN_BLEN: usize, const OUT_BLEN: usize, P> DpfImpl<IN_BLEN, OUT_BLEN, P>
150where
151    P: Prg<OUT_BLEN, 1>,
152{
153    /// Eval with single-threading.
154    /// See [`Dpf::eval`].
155    pub fn eval_st<G>(
156        &self,
157        b: bool,
158        k: &Share<OUT_BLEN, G>,
159        xs: &[&[u8; IN_BLEN]],
160        ys: &mut [&mut G],
161    ) where
162        G: Group<OUT_BLEN>,
163    {
164        xs.iter()
165            .zip(ys.iter_mut())
166            .for_each(|(x, y)| self.eval_point(b, k, x, y));
167    }
168
169    #[cfg(feature = "multi-thread")]
170    /// Eval with multi-threading.
171    /// See [`Dpf::eval`].
172    pub fn eval_mt<G>(
173        &self,
174        b: bool,
175        k: &Share<OUT_BLEN, G>,
176        xs: &[&[u8; IN_BLEN]],
177        ys: &mut [&mut G],
178    ) where
179        G: Group<OUT_BLEN>,
180    {
181        xs.par_iter()
182            .zip(ys.par_iter_mut())
183            .for_each(|(x, y)| self.eval_point(b, k, x, y));
184    }
185
186    fn full_eval_layer<G>(
187        &self,
188        b: bool,
189        k: &Share<OUT_BLEN, G>,
190        ys: &mut [&mut G],
191        layer_i: usize,
192        (s, t): ([u8; OUT_BLEN], bool),
193    ) where
194        G: Group<OUT_BLEN>,
195    {
196        assert_eq!(ys.len(), 1 << (self.filter_bitn - layer_i));
197        if ys.len() == 1 {
198            *ys[0] = (Into::<G>::into(s) + if t { k.cw_np1.clone() } else { G::zero() }).neg_if(b);
199            return;
200        }
201
202        let cw = &k.cws[layer_i];
203        let [([mut sl], mut tl), ([mut sr], mut tr)] = self.prg.gen(&s);
204        xor_inplace(&mut sl, &[if t { &cw.s } else { &[0; OUT_BLEN] }]);
205        xor_inplace(&mut sr, &[if t { &cw.s } else { &[0; OUT_BLEN] }]);
206        tl ^= t & cw.tl;
207        tr ^= t & cw.tr;
208
209        let (ys_l, ys_r) = ys.split_at_mut(ys.len() / 2);
210        #[cfg(feature = "multi-thread")]
211        rayon::join(
212            || self.full_eval_layer(b, k, ys_l, layer_i + 1, (sl, tl)),
213            || self.full_eval_layer(b, k, ys_r, layer_i + 1, (sr, tr)),
214        );
215        #[cfg(not(feature = "multi-thread"))]
216        {
217            self.full_eval_layer(b, k, ys_l, layer_i + 1, (sl, tl));
218            self.full_eval_layer(b, k, ys_r, layer_i + 1, (sr, tr));
219        }
220    }
221
222    pub fn eval_point<G>(&self, b: bool, k: &Share<OUT_BLEN, G>, x: &[u8; IN_BLEN], y: &mut G)
223    where
224        G: Group<OUT_BLEN>,
225    {
226        let n = k.cws.len();
227        assert_eq!(n, self.filter_bitn);
228        let v = y;
229
230        let mut s_prev = k.s0s[0];
231        let mut t_prev = b;
232        for i in 0..n {
233            let cw = &k.cws[i];
234            let [([mut sl], mut tl), ([mut sr], mut tr)] = self.prg.gen(&s_prev);
235            xor_inplace(&mut sl, &[if t_prev { &cw.s } else { &[0; OUT_BLEN] }]);
236            xor_inplace(&mut sr, &[if t_prev { &cw.s } else { &[0; OUT_BLEN] }]);
237            tl ^= t_prev & cw.tl;
238            tr ^= t_prev & cw.tr;
239            if x.view_bits::<Msb0>()[i] {
240                s_prev = sr;
241                t_prev = tr;
242            } else {
243                s_prev = sl;
244                t_prev = tl;
245            }
246        }
247        *v =
248            (Into::<G>::into(s_prev) + if t_prev { k.cw_np1.clone() } else { G::zero() }).neg_if(b);
249    }
250}
251
#[cfg(all(test, feature = "prg"))]
mod tests {
    use arbtest::arbtest;
    use std::iter;

    use super::*;
    use crate::group::byte::ByteGroup;
    use crate::prg::Aes128MatyasMeyerOseasPrg;

    type GroupImpl = ByteGroup<16>;
    type PrgImpl = Aes128MatyasMeyerOseasPrg<16, 1, 2>;
    type DpfImplImpl = DpfImpl<2, 16, PrgImpl>;

    #[test]
    fn test_correctness() {
        arbtest(|u| {
            let keys: [[u8; 16]; 2] = u.arbitrary()?;
            let prg = PrgImpl::new(&std::array::from_fn(|i| &keys[i]));
            let filter_bitn = u.arbitrary::<usize>()? % 15 + 2; // 2..=16
            // Inputs only carry information in their `filter_bitn` most significant bits.
            let shift = 16 - filter_bitn;
            let dpf = DpfImplImpl::new_with_filter(prg, filter_bitn);
            let s0s: [[u8; 16]; 2] = u.arbitrary()?;
            let alpha_i: u16 = u.arbitrary::<u16>()? >> shift;
            let alpha: [u8; 2] = (alpha_i << shift).to_be_bytes();
            let beta: [u8; 16] = u.arbitrary()?;
            let f = PointFn {
                alpha,
                beta: beta.into(),
            };
            let k = dpf.gen(&f, [&s0s[0], &s0s[1]]);
            // Evaluation reads `s0s[0]`, so each party keeps only its own seed.
            let mut k0 = k.clone();
            k0.s0s = vec![k0.s0s[0]];
            let mut k1 = k;
            k1.s0s = vec![k1.s0s[1]];

            let xs: Vec<_> = (0u16..=u16::MAX >> shift)
                .map(|i| (i << shift).to_be_bytes())
                .collect();
            assert_eq!(xs.len(), 1 << filter_bitn);
            let xs_lt_num = alpha_i;
            let xs_gt_num = (u16::MAX >> shift) - alpha_i;
            // A point function is `beta` at `alpha` and zero at every other input,
            // both below and above `alpha`.
            let ys_expected: Vec<_> = iter::repeat(GroupImpl::zero())
                .take(xs_lt_num as usize)
                .chain([beta.into()])
                .chain(iter::repeat(GroupImpl::zero()).take(xs_gt_num as usize))
                .collect();

            let mut ys0 = vec![GroupImpl::zero(); xs.len()];
            let mut ys1 = vec![GroupImpl::zero(); xs.len()];
            dpf.eval(
                false,
                &k0,
                &xs.iter().collect::<Vec<_>>(),
                &mut ys0.iter_mut().collect::<Vec<_>>(),
            );
            ys0.iter().for_each(|y0| {
                assert_ne!(*y0, GroupImpl::zero());
                assert_ne!(*y0, [0xff; 16].into());
            });
            dpf.eval(
                true,
                &k1,
                &xs.iter().collect::<Vec<_>>(),
                &mut ys1.iter_mut().collect::<Vec<_>>(),
            );
            ys1.iter().for_each(|y1| {
                assert_ne!(*y1, GroupImpl::zero());
                assert_ne!(*y1, [0xff; 16].into());
            });
            let ys: Vec<_> = ys0
                .iter()
                .zip(ys1.iter())
                .map(|(y0, y1)| y0.clone() + y1.clone())
                .collect();
            assert_ys_eq(&ys, &ys_expected, &xs, &alpha);

            // Full-domain evaluation must agree with point-wise evaluation for each party.
            let mut ys0_full_eval = vec![GroupImpl::zero(); 1 << filter_bitn];
            dpf.full_eval(
                false,
                &k0,
                &mut ys0_full_eval.iter_mut().collect::<Vec<_>>(),
            );
            assert_ys_eq(&ys0_full_eval, &ys0, &xs, &alpha);
            let mut ys1_full_eval = vec![GroupImpl::zero(); 1 << filter_bitn];
            dpf.full_eval(true, &k1, &mut ys1_full_eval.iter_mut().collect::<Vec<_>>());
            assert_ys_eq(&ys1_full_eval, &ys1, &xs, &alpha);

            Ok(())
        });
    }

    /// Asserts `ys == ys_expected` element-wise, reporting the failing index and how its
    /// input compares to `alpha`.
    fn assert_ys_eq(ys: &[GroupImpl], ys_expected: &[GroupImpl], xs: &[[u8; 2]], alpha: &[u8; 2]) {
        let alpha_int = u16::from_be_bytes(*alpha);
        for (i, (x, (y, y_expected))) in
            xs.iter().zip(ys.iter().zip(ys_expected.iter())).enumerate()
        {
            let x_int = u16::from_be_bytes(*x);
            let cmp = if x_int < alpha_int {
                "<"
            } else if x_int > alpha_int {
                ">"
            } else {
                "="
            };
            assert_eq!(y, y_expected, "where i={}, x={:?}, x{}alpha", i, *x, cmp);
        }
    }
}