1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
//! Defines the Pooled Gauss solving algorithms by Esser, Kübler and May
use crate::{
    oracle::{LpnOracle, StorageBlock},
    random::{lpn_thread_rng, ThreadRng},
};
use m4ri_rust::friendly::solve_left;
use m4ri_rust::friendly::BinMatrix;
use m4ri_rust::friendly::BinVector;
use rand::prelude::SliceRandom;
use rayon::prelude::*;

use std::{
    cell::RefCell,
    sync::{Arc, Mutex},
};

/// Solves an LPN problem using Pooled Gauss.
///
/// The algorithm works in two phases:
///
/// 1. A pool of `m` samples `(Am, bm)` is drawn once and is only used to *test*
///    candidate secrets.
/// 2. Worker threads (via rayon) repeatedly draw `k` samples forming a full-rank
///    `k × k` system `A·s = b`, solve it for a candidate `s'`, and accept the
///    first candidate for which `Am·s' + bm` has Hamming weight at most `c`.
///
/// `m` and `c` are derived from `k` and `tau = (1 - delta) / 2`, where `delta`
/// is read from the oracle.
///
/// The search only ends once some candidate is accepted; if no candidate ever
/// passes the test, this function does not return.
///
/// # Panics
///
/// In debug builds, panics if the weight test accepts a candidate that differs
/// from `oracle.secret`, or rejects `oracle.secret` itself.
// `k`, `m`, `c`, `a`, `b` follow the paper's notation; the oracle is taken by
// value as part of the public API.
#[allow(clippy::many_single_char_names, clippy::needless_pass_by_value)]
pub fn pooled_gauss_solve(oracle: LpnOracle) -> BinVector {
    let mut rng = lpn_thread_rng();

    let k = oracle.get_k();
    let alpha = 0.5f64.powi(k as i32);
    let tau = (1.0 - oracle.delta) / 2.0;
    let beta = ((1f64 - tau) / 2f64).powi(k as i32);
    // Pool size `m` and acceptance threshold `c` for the weight test.
    let m: f64 = (((1.5 * (1.0 / alpha).ln()).sqrt() + (1.0 / beta).ln().sqrt()) / (0.5 - tau))
        .powi(2)
        .floor();
    let c = (tau * m + (3.0 * (0.5 - tau) * (1.0 / alpha).ln() * m).sqrt().floor()) as u32;
    let m = m as usize;

    log::info!(
        "Attempting Pooled Gauss solving method, k={}, tau={}",
        k,
        tau
    );
    log::trace!("Target secret weight <= {}", c);
    log::trace!("Building (Am, b) with length {}", m);
    let (am, bm) = sample_matrix(m, &oracle, &mut rng);
    debug_assert_eq!(am.ncols(), k);
    debug_assert_eq!(am.nrows(), m);
    debug_assert_eq!(bm.nrows(), m);
    debug_assert_eq!(bm.ncols(), 1);

    // Only used by the debug-build sanity check inside `test`.
    let secret = &oracle.secret.as_binvector(k);

    // Accepts `s_prime` iff `Am·s' + bm` has weight at most `c`.
    let test = |s_prime: &BinMatrix| {
        debug_assert_eq!(s_prime.nrows(), k);
        debug_assert_eq!(s_prime.ncols(), 1);

        let mut testproduct = &am * s_prime;
        testproduct += &bm;
        let result = testproduct.count_ones() <= c;
        debug_assert_eq!(
            result,
            &s_prime.as_vector() == secret,
            "Test will reject or accept an (in)correct secret with weight {} <= {}",
            testproduct.count_ones(),
            c
        );
        result
    };

    log::debug!("Starting random sampling of invertible (A, b)");

    // Runs one batch of attempts on a worker thread. Returns `None` once a
    // solution has been published, which short-circuits `try_for_each_init`.
    let s_prime_finder = move |(slot, rng): &mut (Arc<Mutex<Option<BinMatrix>>>, _), _| {
        for _ in 0..10000 {
            // Another worker may already have succeeded: stop right away
            // instead of finishing the whole batch.
            if slot.lock().unwrap().is_some() {
                break;
            }
            // Draw k samples until they form a full-rank k × k system, so the
            // candidate solution is unique. The rank check runs on a copy
            // because `a` must stay intact for `solve_left`.
            // TODO: avoid this allocation.
            let (a, mut b) = loop {
                let (a_try, b_try) = sample_matrix(k, &oracle, rng);
                if a_try.clone().echelonize() == k {
                    break (a_try, b_try);
                }
            };
            // Solve A*s = b in place; afterwards `b` holds the candidate s'.
            if !solve_left(a, &mut b) {
                log::warn!("Somehow, solving failed....");
                continue;
            }
            if test(&b) {
                log::debug!("Found {:?}!", b.as_vector());
                slot.lock().unwrap().replace(b);
                break;
            }
        }

        if slot.lock().unwrap().is_none() {
            Some(())
        } else {
            None
        }
    };

    // Shared slot for the accepted candidate; doubles as the stop signal.
    let found = Arc::new(Mutex::new(None));

    rayon::iter::repeat(())
        .try_for_each_init(|| (Arc::clone(&found), lpn_thread_rng()), s_prime_finder);

    let slot = found.lock().unwrap();
    let s_prime = slot
        .as_ref()
        .expect("the parallel search only stops after a solution has been stored");

    s_prime.as_vector()
}

/// Randomly samples `k` distinct queries from the oracle as a system `(A, b)`.
///
/// Returns the matrix `A`, whose rows are the sampled vectors (with
/// `oracle.get_k()` columns), and the column matrix `b` of the corresponding
/// products, in the same row order.
///
/// Rows are chosen with `SliceRandom::choose_multiple`. If the oracle holds
/// fewer than `k` samples, all of them are used, so `A` then has fewer than
/// `k` rows.
fn sample_matrix(k: usize, oracle: &LpnOracle, rng: &mut ThreadRng) -> (BinMatrix, BinMatrix) {
    // Plain borrowed buffers: the rows only need to live until `from_slices`
    // has copied them, so no lifetime tricks or thread-local state are needed.
    let mut rows: Vec<&[StorageBlock]> = Vec::with_capacity(k);
    let mut products = BinVector::new();
    for sample in oracle.samples.choose_multiple(rng, k) {
        products.push(sample.get_product());
        rows.push(sample.get_sample());
    }
    let a = BinMatrix::from_slices(&rows, oracle.get_k());
    (a, products.as_column_matrix())
}

#[cfg(test)]
mod test {
    use super::*;

    /// End-to-end check: Pooled Gauss must recover the secret of a 32-bit
    /// oracle (constructed with parameter `1.0 / 4.0`) from ~4M samples.
    #[test]
    fn run_gauss() {
        let mut lpn: LpnOracle = LpnOracle::new(32, 1.0 / 4.0);
        lpn.get_samples(4_000_555);
        let expected = lpn.secret.as_binvector(32);
        let recovered = pooled_gauss_solve(lpn);
        assert_eq!(recovered, expected);
    }
}