dpf_fss/
lib.rs

1// Copyright (C) myl7
2// SPDX-License-Identifier: Apache-2.0
3
4//! See [`Dpf`]
5
6extern crate group_math as group;
7
8#[cfg(feature = "prg")]
9pub mod prg;
10
11use bitvec::prelude::*;
12pub use fss_types::PointFn;
13use fss_types::{decl_prg_trait, Cw, Share};
14use group::byte::utils::{xor, xor_inplace};
15pub use group::Group;
16#[cfg(feature = "multithread")]
17use rayon::prelude::*;
18
19/// API of Distributed point function.
20///
21/// `PointFn` used here means `$f(x) = \beta$` iff. `$x = \alpha$`, otherwise `$f(x) = 0$`.
22///
23/// See [`PointFn`] for `N` and `LAMBDA`.
24pub trait Dpf<const N: usize, const LAMBDA: usize, G>
25where
26    G: Group<LAMBDA>,
27{
28    /// `s0s` is `$s^{(0)}_0$` and `$s^{(0)}_1$` which should be randomly sampled
29    fn gen(&self, f: &PointFn<N, LAMBDA, G>, s0s: [&[u8; LAMBDA]; 2]) -> Share<LAMBDA, G>;
30
31    /// `b` is the party. `false` is 0 and `true` is 1.
32    fn eval(&self, b: bool, k: &Share<LAMBDA, G>, xs: &[&[u8; N]], ys: &mut [&mut G]);
33}
34
35decl_prg_trait!(([u8; LAMBDA], bool));
36
37/// Implementation of [`Dpf`].
38///
39/// `$\alpha$` itself is not included, which means `$f(\alpha)$ = 0`.
40pub struct DpfImpl<const N: usize, const LAMBDA: usize, PrgT>
41where
42    PrgT: Prg<LAMBDA>,
43{
44    prg: PrgT,
45}
46
47impl<const N: usize, const LAMBDA: usize, PrgT> DpfImpl<N, LAMBDA, PrgT>
48where
49    PrgT: Prg<LAMBDA>,
50{
51    pub fn new(prg: PrgT) -> Self {
52        Self { prg }
53    }
54}
55
/// Position of the left child inside a `[left, right]` pair.
const IDX_L: usize = 0;
/// Position of the right child inside a `[left, right]` pair (the slot after the left one).
const IDX_R: usize = IDX_L + 1;
58
59impl<const N: usize, const LAMBDA: usize, PrgT, G> Dpf<N, LAMBDA, G> for DpfImpl<N, LAMBDA, PrgT>
60where
61    PrgT: Prg<LAMBDA>,
62    G: Group<LAMBDA>,
63{
64    fn gen(&self, f: &PointFn<N, LAMBDA, G>, s0s: [&[u8; LAMBDA]; 2]) -> Share<LAMBDA, G> {
65        // The bit size of `$\alpha$`
66        let n = 8 * N;
67        // let mut v_alpha = G::zero();
68        let mut ss = Vec::<[[u8; LAMBDA]; 2]>::with_capacity(n + 1);
69        // Set `$s^{(1)}_0$` and `$s^{(1)}_1$`
70        ss.push([s0s[0].to_owned(), s0s[1].to_owned()]);
71        let mut ts = Vec::<[bool; 2]>::with_capacity(n + 1);
72        // Set `$t^{(0)}_0$` and `$t^{(0)}_1$`
73        ts.push([false, true]);
74        let mut cws = Vec::<Cw<LAMBDA, G>>::with_capacity(n);
75        for i in 1..n + 1 {
76            let [(s0l, t0l), (s0r, t0r)] = self.prg.gen(&ss[i - 1][0]);
77            let [(s1l, t1l), (s1r, t1r)] = self.prg.gen(&ss[i - 1][1]);
78            // MSB is required since we index from high to low in arrays
79            let alpha_i = f.alpha.view_bits::<Msb0>()[i - 1];
80            let (keep, lose) = if alpha_i {
81                (IDX_R, IDX_L)
82            } else {
83                (IDX_L, IDX_R)
84            };
85            let s_cw = xor(&[[&s0l, &s0r][lose], [&s1l, &s1r][lose]]);
86            let tl_cw = t0l ^ t1l ^ alpha_i ^ true;
87            let tr_cw = t0r ^ t1r ^ alpha_i;
88            let cw = Cw {
89                s: s_cw,
90                v: G::zero(),
91                tl: tl_cw,
92                tr: tr_cw,
93            };
94            cws.push(cw);
95            ss.push([
96                xor(&[
97                    [&s0l, &s0r][keep],
98                    if ts[i - 1][0] { &s_cw } else { &[0; LAMBDA] },
99                ]),
100                xor(&[
101                    [&s1l, &s1r][keep],
102                    if ts[i - 1][1] { &s_cw } else { &[0; LAMBDA] },
103                ]),
104            ]);
105            ts.push([
106                [t0l, t0r][keep] ^ (ts[i - 1][0] & [tl_cw, tr_cw][keep]),
107                [t1l, t1r][keep] ^ (ts[i - 1][1] & [tl_cw, tr_cw][keep]),
108            ]);
109        }
110        assert_eq!((ss.len(), ts.len(), cws.len()), (n + 1, n + 1, n));
111        let cw_np1 = (f.beta.clone() + Into::<G>::into(ss[n][0]).add_inverse() + ss[n][1].into())
112            .add_inverse_if(ts[n][1]);
113        Share {
114            s0s: vec![s0s[0].to_owned(), s0s[1].to_owned()],
115            cws,
116            cw_np1,
117        }
118    }
119
120    fn eval(&self, b: bool, k: &Share<LAMBDA, G>, xs: &[&[u8; N]], ys: &mut [&mut G]) {
121        let n = k.cws.len();
122        assert_eq!(n, N * 8);
123        let f = |x: &[u8; N], v: &mut G| {
124            let mut ss = Vec::<[u8; LAMBDA]>::with_capacity(n + 1);
125            ss.push(k.s0s[0].to_owned());
126            let mut ts = Vec::<bool>::with_capacity(n + 1);
127            ts.push(b);
128            for i in 1..n + 1 {
129                let cw = &k.cws[i - 1];
130                let [(mut sl, mut tl), (mut sr, mut tr)] = self.prg.gen(&ss[i - 1]);
131                xor_inplace(&mut sl, &[if ts[i - 1] { &cw.s } else { &[0; LAMBDA] }]);
132                xor_inplace(&mut sr, &[if ts[i - 1] { &cw.s } else { &[0; LAMBDA] }]);
133                tl ^= ts[i - 1] & cw.tl;
134                tr ^= ts[i - 1] & cw.tr;
135                if x.view_bits::<Msb0>()[i - 1] {
136                    ss.push(sr);
137                    ts.push(tr);
138                } else {
139                    ss.push(sl);
140                    ts.push(tl);
141                }
142            }
143            assert_eq!((ss.len(), ts.len()), (n + 1, n + 1));
144            *v = (Into::<G>::into(ss[n]) + if ts[n] { k.cw_np1.clone() } else { G::zero() })
145                .add_inverse_if(b);
146        };
147        #[cfg(feature = "multithread")]
148        {
149            xs.par_iter()
150                .zip(ys.par_iter_mut())
151                .for_each(|(x, y)| f(x, y));
152        }
153        #[cfg(not(feature = "multithread"))]
154        {
155            xs.iter().zip(ys.iter_mut()).for_each(|(x, y)| f(x, y));
156        }
157    }
158}
159
#[cfg(all(test, feature = "prg"))]
mod tests {
    use super::*;

    use group::byte::ByteGroup;
    use rand::{thread_rng, Rng};

    use crate::prg::Aes256HirosePrg;

    const KEYS: [&[u8; 32]; 2] = [
        b"j9\x1b_\xb3X\xf33\xacW\x15\x1b\x0812K\xb3I\xb9\x90r\x1cN\xb5\xee9W\xd3\xbb@\xc6d",
        b"\x9b\x15\xc8\x0f\xb7\xbc!q\x9e\x89\xb8\xf7\x0e\xa0S\x9dN\xfa\x0c;\x16\xe4\x98\x82b\xfcdy\xb5\x8c{\xc2",
    ];
    /// Evaluation points. `ALPHAS[1]` and `ALPHAS[3]` differ from `ALPHAS[2]` only in the
    /// lowest bits, so they leave `$\alpha$`'s path at the last tree levels.
    const ALPHAS: &[&[u8; 16]] = &[
        b"K\xa9W\xf5\xdd\x05\xe9\xfc?\x04\xf6\xfbUo\xa8C",
        b"\xc2GK\xda\xc6\xbb\x99\x98Fq\"f\xb7\x8csU",
        b"\xc2GK\xda\xc6\xbb\x99\x98Fq\"f\xb7\x8csV",
        b"\xc2GK\xda\xc6\xbb\x99\x98Fq\"f\xb7\x8csW",
        b"\xef\x96\x97\xd7\x8f\x8a\xa4AP\n\xb35\xb5k\xff\x97",
    ];
    /// Index into `ALPHAS` of the `$\alpha$` every test generates keys for.
    const ALPHA_IDX: usize = 2;
    const BETA: &[u8; 16] = b"\x03\x11\x97\x12C\x8a\xe9#\x81\xa8\xde\xa8\x8f \xc0\xbb";

    /// Generates keys for `$f(\alpha) = \beta$` with `$\alpha$ = ALPHAS[ALPHA_IDX]`, then
    /// evaluates each party's key on every point of `ALPHAS`.
    ///
    /// Returns `(party 0 output shares, party 1 output shares)` in `ALPHAS` order.
    fn gen_then_eval<G: Group<16>>(beta: G) -> (Vec<G>, Vec<G>) {
        let prg = Aes256HirosePrg::new(KEYS);
        let dpf = DpfImpl::<16, 16, _>::new(prg);
        let s0s: [[u8; 16]; 2] = thread_rng().gen();
        let f = PointFn {
            alpha: ALPHAS[ALPHA_IDX].to_owned(),
            beta,
        };
        let k = dpf.gen(&f, [&s0s[0], &s0s[1]]);
        // Each party only receives its own initial seed
        let mut k0 = k.clone();
        k0.s0s = vec![k0.s0s[0]];
        let mut k1 = k;
        k1.s0s = vec![k1.s0s[1]];
        let mut ys0 = vec![G::zero(); ALPHAS.len()];
        let mut ys1 = vec![G::zero(); ALPHAS.len()];
        dpf.eval(false, &k0, ALPHAS, &mut ys0.iter_mut().collect::<Vec<_>>());
        dpf.eval(true, &k1, ALPHAS, &mut ys1.iter_mut().collect::<Vec<_>>());
        (ys0, ys1)
    }

    #[test]
    fn test_dpf_gen_then_eval_ok() {
        let beta = ByteGroup::from(*BETA);
        let (mut ys, ys1) = gen_then_eval(beta.clone());
        // Reconstruct `$f(x)$` by adding the two parties' shares
        ys.iter_mut().zip(ys1).for_each(|(y0, y1)| *y0 += y1);
        let mut expected = vec![ByteGroup::zero(); ALPHAS.len()];
        expected[ALPHA_IDX] = beta;
        assert_eq!(ys, expected);
    }

    #[test]
    fn test_dpf_gen_then_eval_not_zeros() {
        let (ys0, ys1) = gen_then_eval(ByteGroup::from(*BETA));
        // Neither party's share at `$\alpha$` should be zero on its own
        assert_ne!(ys0[ALPHA_IDX], ByteGroup::zero());
        assert_ne!(ys1[ALPHA_IDX], ByteGroup::zero());
    }
}