okvs 0.2.0

WIP implementation of Oblivious Key-Value Stores
Documentation
use std::collections::BTreeSet;

use crate::bits::Bits;

/// A boolean matrix stored sparsely: each row keeps only the column indices
/// whose entry is `1`. Rows are interpreted as XOR equations by `solve`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SparseBoolMatrix {
    /// For each row, the sorted set of column indices holding a `1`.
    ones_by_row: Vec<BTreeSet<usize>>,
    /// Number of columns (unknowns); column indices are expected in `0..column_count`.
    column_count: usize,
}

impl SparseBoolMatrix {
    /// Creates a matrix with no rows whose rows may reference columns `0..column_count`.
    pub fn new(column_count: usize) -> Self {
        Self {
            ones_by_row: Vec::new(),
            column_count,
        }
    }

    /// Appends a row, given as the set of column indices that hold a `1`.
    ///
    /// # Panics
    ///
    /// Panics if any index is `>= column_count`. Such an index could never be
    /// solved for (it would only show up later as an out-of-bounds panic in `solve`).
    pub fn push_row(&mut self, one_indices: BTreeSet<usize>) {
        // The set is sorted, so its last element is the largest index.
        if let Some(&max) = one_indices.iter().next_back() {
            assert!(
                max < self.column_count,
                "column index {} out of range for a matrix with {} columns",
                max,
                self.column_count
            );
        }
        self.ones_by_row.push(one_indices);
    }

    /// Builds the `row_count x row_count` identity matrix.
    pub fn identity(row_count: usize) -> Self {
        Self {
            ones_by_row: (0..row_count).map(|i| BTreeSet::from([i])).collect(),
            column_count: row_count,
        }
    }

    /// Solves the linear system `M · x = targets`, where `M` is this matrix and
    /// values are combined with XOR.
    ///
    /// Row `r` encodes the equation "XOR of `x[c]` over all columns `c` in row `r`
    /// equals `targets[r]`".
    ///
    /// Returns one entry per column: `Some(value)` for every column that appears in
    /// at least one row, and `None` for unconstrained columns (those appearing in no
    /// row). Wherever the system leaves freedom, the free choices are drawn from
    /// `V::random()`.
    ///
    /// # Panics
    ///
    /// - If `targets.len()` differs from the number of rows.
    /// - If the rows are linearly dependent. After elimination such a row has no
    ///   free column left, and its consistency cannot be checked here.
    pub fn solve<V: Bits>(mut self, mut targets: Vec<V>) -> Vec<Option<V>> {
        let row_count = self.ones_by_row.len();

        assert_eq!(row_count, targets.len());

        // Forward elimination into row echelon form. Every row operation applied to
        // `ones_by_row` is mirrored on `targets` so the equations stay equivalent.
        let mut pivot_row = 0;
        let mut pivot_column = 0;
        while pivot_row < row_count && pivot_column < self.column_count {
            // Any row at or below `pivot_row` with a 1 in `pivot_column` can be the pivot.
            let found =
                (pivot_row..row_count).find(|&i| self.ones_by_row[i].contains(&pivot_column));
            let next = match found {
                Some(i) => i,
                None => {
                    // Column already clear below the pivot row; move to the next column.
                    pivot_column += 1;
                    continue;
                }
            };
            self.ones_by_row.swap(pivot_row, next);
            targets.swap(pivot_row, next);

            // Clear `pivot_column` from every lower row by XOR-ing the pivot row (and
            // its target) into each lower row that has a 1 there.
            let (upper_rows, lower_rows) = self.ones_by_row.split_at_mut(pivot_row + 1);
            let (upper_targets, lower_targets) = targets.split_at_mut(pivot_row + 1);
            let pivot_ones = &upper_rows[pivot_row];
            let pivot_target = upper_targets[pivot_row];
            for (ones, target) in lower_rows.iter_mut().zip(lower_targets.iter_mut()) {
                if ones.contains(&pivot_column) {
                    // In-place symmetric difference: toggle each column of the pivot row.
                    for &column in pivot_ones {
                        if !ones.remove(&column) {
                            ones.insert(column);
                        }
                    }
                    *target ^= pivot_target;
                }
            }

            pivot_row += 1;
            pivot_column += 1;
        }

        // Back substitution, bottom row first. In echelon form a row's pivot column
        // does not occur in any row below it, so every non-zero row still has at
        // least one unassigned column when it is reached.
        let mut assignment: Vec<Option<V>> = vec![None; self.column_count];

        for (ones, mut residual) in self.ones_by_row.into_iter().zip(targets).rev() {
            // The last unassigned column of this row and its provisional random value.
            let mut free: Option<(usize, V)> = None;
            for column in ones {
                let value = match assignment[column] {
                    Some(value) => value,
                    None => {
                        let value = V::random();
                        assignment[column] = Some(value);
                        free = Some((column, value));
                        value
                    }
                };
                residual ^= value;
            }

            // `residual` is now target ^ (XOR of the row's current values). XOR-ing it
            // into the free column makes the whole row XOR to exactly `target`.
            let (column, value) = free.expect(
                "cannot solve: rows are linearly dependent (a row has no free column left after elimination)",
            );
            assignment[column] = Some(value ^ residual);
        }

        assignment
    }
}

#[cfg(test)]
mod tests {
    use super::SparseBoolMatrix;

    /// Against the identity, every target lands unchanged in its own column.
    #[test]
    fn test_identity_5x5() {
        let solution = SparseBoolMatrix::identity(5).solve(vec![10, 20, 30, 40, 50]);
        let expected = vec![Some(10), Some(20), Some(30), Some(40), Some(50)];
        assert_eq!(solution, expected);
    }

    /// Dropping the final equation leaves column 4 in no row, so it stays `None`.
    #[test]
    fn test_identity_4x5() {
        let mut matrix = SparseBoolMatrix::identity(5);
        matrix.ones_by_row.truncate(4);

        let solution = matrix.solve(vec![10, 20, 30, 40]);
        let (determined, unconstrained) = solution.split_at(4);
        assert_eq!(determined, [Some(10), Some(20), Some(30), Some(40)]);
        assert!(unconstrained[0].is_none());
    }

    /// Dropping the first equation leaves column 0 in no row, so it stays `None`.
    #[test]
    fn test_identity_4x5_shifted() {
        let mut matrix = SparseBoolMatrix::identity(5);
        let _ = matrix.ones_by_row.remove(0);

        let solution = matrix.solve(vec![10, 20, 30, 40]);
        let (unconstrained, determined) = solution.split_at(1);
        assert!(unconstrained[0].is_none());
        assert_eq!(determined, [Some(10), Some(20), Some(30), Some(40)]);
    }
}