// okvs 0.2.0
//
// Work-in-progress (WIP) implementation of Oblivious Key-Value Stores (OKVS).
use std::collections::{BTreeSet, VecDeque};

use rand::{thread_rng, Rng};

use crate::{
    bits::{bit_at_is_set, Bits},
    elimination::SparseBoolMatrix,
    graph::UndirectedGraph,
    hashable::Hashable,
};

use super::Okvs;

/// A garbled cuckoo table with two hash functions proposed by Pinkas, Rosulek, Trieu, and Yanai.
/// See <https://eprint.iacr.org/2020/193>.
///
/// A key decodes to `left[h1(key)] ^ left[h2(key)]`, further XORed with every `right[i]` whose
/// bit `i` is set in a third per-key hash (see `decode`).
#[derive(Debug, PartialEq, Clone)]
pub struct Paxos<V: Bits> {
    /// The cuckoo-table bins; each key selects two of them via `first_seed` and `second_seed`.
    left: Vec<V>,
    /// Extra values selected by the bits of the key's third hash; `decode` reads
    /// `right[0..TOTAL]`.
    right: Vec<V>,
    /// Seed of the hash that picks the key's first bin in `left`.
    first_seed: u64,
    /// Seed of the hash that picks the key's second bin in `left`.
    second_seed: u64,
    /// Seed of the hash that produces the key's bit mask over `right`.
    third_seed: u64,
}

/// Maximum number of edges allowed to remain in the cuckoo graph after peeling;
/// `try_encode` returns `None` when more remain.
const LIMIT: usize = 8;
/// Additional random columns beyond `LIMIT`. Presumably the statistical security
/// parameter λ from the paper — TODO confirm. Note that `try_encode` ignores its
/// `_lambda` argument and always uses this constant.
const LAMBDA: usize = 24;
/// Number of random columns: the length of `right` and the number of mask bits per key.
const TOTAL: usize = LIMIT + LAMBDA;
/// Number of `u64` words needed to hold `TOTAL` bits (rounded up).
const TOTAL_WORDS: usize = (TOTAL + 63) / 64;

/// Payload attached to each cuckoo-graph edge while encoding (one edge per key).
#[derive(Debug, Clone, Copy)]
struct EdgeData<V: Bits> {
    /// The value the key must decode to.
    target: V,
    /// The key's third-hash bit mask (`TOTAL` bits packed into `u64` words); bit `i` set
    /// means `right[i]` is XORed into the key's decoded value.
    randomness: [u64; TOTAL_WORDS],
}

/// Size in bytes of the three serialized `u64` seeds that prefix the encoding.
const SEEDS_BYTE_LEN: usize = 3 * 8;

impl<V: Bits> Okvs<V> for Paxos<V> {
    // TODO: Consider making lambda a const generic
    /// Attempts to encode `key_value_pairs` into a fresh garbled cuckoo table.
    ///
    /// Each key becomes an edge between the two `left` bins chosen by `first_seed` and
    /// `second_seed`. Edges are peeled off the graph. The (at most `LIMIT`) edges that cannot
    /// be peeled are solved as a linear system over `left` and `right`. The peeled edges are
    /// then assigned in reverse peeling order.
    ///
    /// Returns `None` if the cuckoo graph rejects an edge or if more than `LIMIT` edges survive
    /// peeling. New seeds are drawn on every call, so a later attempt may succeed.
    ///
    /// `_lambda` is currently ignored; the compile-time `LAMBDA` is used instead.
    fn try_encode<K: Hashable>(key_value_pairs: &[(K, V)], _lambda: usize) -> Option<Self> {
        // From Figure 4 in https://eprint.iacr.org/2021/883.pdf
        // TODO: Consider making it a power of two, so we can do a cheap modular reduction
        //
        // Keep at least one bin. With zero pairs the expansion would give 0 bins, leaving
        // `left` empty, and every later `decode` would index into an empty vector.
        let bin_count = ((2.4 * key_value_pairs.len() as f64).ceil() as usize).max(1);

        let mut rng = thread_rng();
        let first_seed = rng.gen();
        let second_seed = rng.gen();
        let third_seed = rng.gen();

        // Build the cuckoo graph: one edge per key, between its two candidate bins.
        let mut graph = UndirectedGraph::new(bin_count);
        for (key, value) in key_value_pairs {
            let first_hash = key.hash_to_index(first_seed, bin_count);
            let second_hash = key.hash_to_index(second_seed, bin_count);
            let edge_data = EdgeData {
                target: *value,
                randomness: key.hash_to_bytes::<TOTAL_WORDS>(third_seed),
            };

            if !graph.add_edge(first_hash, second_hash, edge_data) {
                return None;
            }
        }

        // Peel: repeatedly take the only edge of a degree-one vertex and follow it to the
        // other endpoint. `push_front` stores edges in reverse peeling order, which is the
        // order in which unpeeling must assign them.
        //
        // NOTE(review): unpeeling assumes `pop_only_edge` never returns a self-loop
        // (`to_vertex == current_vertex`). Such an edge cannot be satisfied by assigning
        // `left[current_vertex]`. Confirm against `UndirectedGraph`.
        let mut peeled_edges = VecDeque::new();
        for start in 0..bin_count {
            let mut current_vertex = start;
            while let Some(edge) = graph.pop_only_edge(current_vertex) {
                let next_vertex = edge.to_vertex;
                peeled_edges.push_front((current_vertex, edge));
                current_vertex = next_vertex;
            }
        }

        // Every key is either peeled or still in the graph.
        assert_eq!(peeled_edges.len() + graph.edge_count, key_value_pairs.len());

        // Too many edges remain for this attempt; fail.
        if graph.edge_count > LIMIT {
            return None;
        }

        // Each remaining edge (u, w) becomes the equation
        //     left[u] ^ left[w] ^ (XOR of right[i] over set mask bits i) == target,
        // where column `v` stands for `left[v]` and column `bin_count + i` for `right[i]`.
        let mut matrix = SparseBoolMatrix::new(bin_count + TOTAL);
        let mut targets = Vec::with_capacity(graph.edge_count);
        for vertex in 0..bin_count {
            if !graph.has_edges(vertex) {
                continue;
            }

            while let Some(edge) = graph.pop_edge(vertex) {
                // A self-loop contributes left[v] ^ left[v] = 0, so only `right` columns remain.
                let mut row = if vertex == edge.to_vertex {
                    BTreeSet::new()
                } else {
                    BTreeSet::from([vertex, edge.to_vertex])
                };
                row.extend(
                    (0..TOTAL)
                        .filter(|&i| bit_at_is_set(&edge.value.randomness, i))
                        .map(|i| bin_count + i),
                );

                matrix.push_row(row);
                targets.push(edge.value.target);
            }
        }

        // Solve the system; entries the solver leaves as `None` are filled in below.
        let mut assignment = matrix.solve(targets);

        // Split off `right` (the columns from `bin_count` on), randomizing unassigned entries.
        let right: Vec<V> = assignment
            .drain(bin_count..)
            .map(|slot| slot.unwrap_or_else(V::random))
            .collect();

        // Unpeel in reverse peeling order. For each edge (from, to), `left[to]` is fixed first:
        // by the solver, by an earlier unpeeling step, or randomly right here. Then
        // `left[from]` is chosen so that the key decodes to its target.
        for (from, edge) in peeled_edges {
            let mut value_for_from = edge.value.target;
            for i in (0..TOTAL).filter(|&i| bit_at_is_set(&edge.value.randomness, i)) {
                value_for_from ^= right[i];
            }
            value_for_from ^= *assignment[edge.to_vertex].get_or_insert_with(V::random);
            assignment[from] = Some(value_for_from);
        }

        // Any bin still unassigned is unconstrained; fill it randomly.
        let left = assignment
            .into_iter()
            .map(|slot| slot.unwrap_or_else(V::random))
            .collect();

        Some(Self {
            left,
            right,
            first_seed,
            second_seed,
            third_seed,
        })
    }

    /// Decodes `key`. XORs the two `left` bins selected by `first_seed`/`second_seed` with
    /// every `right[i]` whose bit `i` is set in the key's `third_seed` mask.
    fn decode<K: Hashable>(&self, key: &K) -> V {
        let bin_count = self.left.len();

        let mut result = self.left[key.hash_to_index(self.first_seed, bin_count)];
        result ^= self.left[key.hash_to_index(self.second_seed, bin_count)];

        let mask = key.hash_to_bytes::<TOTAL_WORDS>(self.third_seed);
        for i in (0..TOTAL).filter(|&i| bit_at_is_set(&mask, i)) {
            result ^= self.right[i];
        }

        result
    }

    /// Serializes as `first_seed || second_seed || third_seed || left || right`. Each
    /// element is written with its `Bits::to_bytes` encoding.
    fn to_bytes(self) -> Vec<u8> {
        let mut output =
            Vec::with_capacity(SEEDS_BYTE_LEN + (self.left.len() + self.right.len()) * V::BYTES);
        output.extend(self.first_seed.to_bytes());
        output.extend(self.second_seed.to_bytes());
        output.extend(self.third_seed.to_bytes());
        output.extend(self.left.into_iter().flat_map(|x| x.to_bytes()));
        output.extend(self.right.into_iter().flat_map(|x| x.to_bytes()));

        output
    }

    /// Inverse of `to_bytes`. Everything between the seeds and the final `TOTAL` values is
    /// read as `left`.
    ///
    /// # Panics
    ///
    /// Panics in either of these cases:
    /// - `bytes` is shorter than the three seeds plus `TOTAL` values;
    /// - the bytes left over for `left` are not a whole number of values.
    fn from_bytes(bytes: &[u8]) -> Self {
        let right_byte_len = TOTAL * V::BYTES;
        let left_byte_len = bytes
            .len()
            .checked_sub(SEEDS_BYTE_LEN + right_byte_len)
            .expect("Paxos::from_bytes: input is too short for the seeds and `right`");
        assert_eq!(
            left_byte_len % V::BYTES,
            0,
            "Paxos::from_bytes: `left` is not a whole number of values"
        );

        let (seeds, values) = bytes.split_at(SEEDS_BYTE_LEN);
        let (left_bytes, right_bytes) = values.split_at(left_byte_len);

        let first_seed = u64::from_bytes(&seeds[..8]);
        let second_seed = u64::from_bytes(&seeds[8..16]);
        let third_seed = u64::from_bytes(&seeds[16..24]);

        let left = left_bytes
            .chunks_exact(V::BYTES)
            .map(|chunk| V::from_bytes(chunk))
            .collect();
        let right = right_bytes
            .chunks_exact(V::BYTES)
            .map(|chunk| V::from_bytes(chunk))
            .collect();

        Self {
            left,
            right,
            first_seed,
            second_seed,
            third_seed,
        }
    }
}

#[cfg(test)]
mod tests {
    //use test::Bencher;

    use crate::{
        bits::Bits,
        schemes::{paxos::Paxos, Okvs},
    };

    /// Members decode to their values; a few non-member keys decode to none of them.
    #[test]
    fn encode_decode_u64() {
        let r1 = u64::random();
        let r2 = u64::random();
        let r3 = u64::random();

        let okvs = Paxos::encode(&[(1u64, r1), (1000u64, r2), (123u64, r3)], 40);

        assert_eq!(okvs.decode(&1u64), r1);
        assert_eq!(okvs.decode(&1000u64), r2);
        assert_eq!(okvs.decode(&123u64), r3);

        assert_ne!(okvs.decode(&0u64), r1);
        assert_ne!(okvs.decode(&0u64), r2);
        assert_ne!(okvs.decode(&0u64), r3);

        assert_ne!(okvs.decode(&2u64), r1);
        assert_ne!(okvs.decode(&2u64), r2);
        assert_ne!(okvs.decode(&2u64), r3);

        assert_ne!(okvs.decode(&u64::MAX), r1);
        assert_ne!(okvs.decode(&u64::MAX), r2);
        assert_ne!(okvs.decode(&u64::MAX), r3);
    }

    /// A large encoding must not only terminate but also decode every key correctly.
    #[test]
    fn encode_many_u64() {
        let key_value_pairs = (0..10_000).map(|i| (i, u64::random())).collect::<Vec<_>>();
        let okvs = Paxos::encode(&key_value_pairs, 40);

        for (key, value) in &key_value_pairs {
            assert_eq!(okvs.decode(key), *value);
        }
    }

    /// Serialization round-trips to an equal table that still decodes every key.
    #[test]
    fn encode_decode_u64_serialization() {
        let key_value_pairs = [
            (1u64, u64::random()),
            (1000u64, u64::random()),
            (123u64, u64::random()),
        ];

        let okvs = Paxos::encode(&key_value_pairs, 40);

        let bytes = okvs.clone().to_bytes();
        let okvs_retrieved: Paxos<u64> = Paxos::from_bytes(&bytes);

        assert_eq!(okvs, okvs_retrieved);
        for (key, value) in &key_value_pairs {
            assert_eq!(okvs_retrieved.decode(key), *value);
        }
    }

    // Benchmark (needs the nightly-only `test` crate; re-enable together with the
    // `use test::Bencher` import above):
    //
    // #[bench]
    // fn bench_encode_many_u64(b: &mut Bencher) {
    //     let key_value_pairs = (0..10_000).map(|i| (i, u64::random())).collect::<Vec<_>>();
    //     b.iter(|| Paxos::encode(&key_value_pairs, 40));
    // }

    // Historical results of the benchmark above, per graph implementation:
    // Self-reference BTreeMap: 9,924,100 ns/iter (+/- 4,215,162)
    // Self-reference HashMap: 10,031,772 ns/iter (+/- 2,780,928)
    // Self-reference HashMap + Copy: 7,761,418 ns/iter (+/- 2,183,830)
    // Self-reference HashMap + Copy + xxh3: 5,549,983 ns/iter (+/- 1,348,701)
    // Self-reference HashMap + Copy + both xxh3: 4,544,847 ns/iter (+/- 1,309,729)
}