arcis-compiler 0.9.7

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use ff::Field;
use num_traits::Zero;
use std::ops::{Add, AddAssign, Index, IndexMut, Mul, Sub, SubAssign};

/// A dense two-dimensional matrix of `Copy` elements.
///
/// The entries live in one flat vector in row-major order: the accessors in this file
/// locate entry `(x, y)` at offset `x * ncols + y`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Matrix<T: Copy> {
    /// Flat, row-major storage; the constructors in this file fill it with
    /// `nrows * ncols` entries.
    data: Vec<T>,
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
}

impl<T: Copy> Matrix<T> {
    /// Builds a matrix in which every entry is a copy of `item`.
    ///
    /// * `size` is `(nrows, ncols)`.
    /// * `item` is the value used to fill the matrix.
    pub fn new(size: (usize, usize), item: T) -> Self {
        let (nrows, ncols) = size;
        Matrix {
            data: vec![item; nrows * ncols],
            nrows,
            ncols,
        }
    }

    /// Builds a matrix from an iterator that yields the entries in row-major order.
    ///
    /// * `size` is `(nrows, ncols)`.
    /// * `iterator` is consumed completely.
    ///
    /// # Panics
    ///
    /// Panics if the iterator does not yield exactly `nrows * ncols` items.
    pub fn new_from_iter<U: Iterator<Item = T>>(size: (usize, usize), iterator: U) -> Self {
        let (nrows, ncols) = size;
        let data: Vec<T> = iterator.collect();
        assert_eq!(
            data.len(),
            nrows * ncols,
            "iterator of size {} for matrix of size {}x{}",
            data.len(),
            nrows,
            ncols
        );
        Matrix { data, nrows, ncols }
    }

    /// Builds a matrix from an iterator that yields the entries in column-major order
    /// (first column top to bottom, then the second column, and so on).
    ///
    /// * `size` is `(nrows, ncols)`.
    /// * `iterator` is consumed completely.
    ///
    /// # Panics
    ///
    /// Panics if the iterator does not yield exactly `nrows * ncols` items.
    pub fn new_from_column_major_iter<U: Iterator<Item = T>>(
        size: (usize, usize),
        iterator: U,
    ) -> Self {
        let (nrows, ncols) = size;
        let column_major_data: Vec<T> = iterator.collect();
        assert_eq!(
            column_major_data.len(),
            nrows * ncols,
            "iterator of size {} for matrix of size {}x{}",
            column_major_data.len(),
            nrows,
            ncols
        );
        // In column-major order entry (i, j) sits at offset `i + j * nrows`. Reading the
        // entries row by row transposes the storage without any per-row temporary vector.
        let source = &column_major_data;
        let row_major_data: Vec<T> = (0..nrows)
            .flat_map(|i| (0..ncols).map(move |j| source[i + j * nrows]))
            .collect();

        Matrix {
            data: row_major_data,
            nrows,
            ncols,
        }
    }

    /// Offset of entry `(x, y)` inside the row-major `data` vector.
    fn index(&self, x: usize, y: usize) -> usize {
        x * self.ncols + y
    }

    /// Returns a reference to the entry at `location`, given as `(row, column)`.
    ///
    /// Returns `None` when the row or the column is out of range. Both coordinates are
    /// checked separately so that an out-of-range column can never alias an entry of a
    /// following row.
    pub fn get(&self, location: (usize, usize)) -> Option<&T> {
        let (x, y) = location;
        if x >= self.nrows || y >= self.ncols {
            return None;
        }
        let index = self.index(x, y);
        self.data.get(index)
    }

    /// Returns a mutable reference to the entry at `location`, given as `(row, column)`.
    ///
    /// Returns `None` when the row or the column is out of range (see [`Matrix::get`]).
    pub fn get_mut(&mut self, location: (usize, usize)) -> Option<&mut T> {
        let (x, y) = location;
        if x >= self.nrows || y >= self.ncols {
            return None;
        }
        let index = self.index(x, y);
        self.data.get_mut(index)
    }

    /// Extracts column `index` as a new `nrows x 1` matrix.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not less than `ncols`.
    pub fn col(&self, index: usize) -> Self {
        assert!(
            index < self.ncols,
            "index for column extraction must be less than {} (found {})",
            self.ncols,
            index
        );
        let data = (0..self.nrows)
            .map(|row| self.data[self.index(row, index)])
            .collect();
        Self {
            data,
            nrows: self.nrows,
            ncols: 1,
        }
    }

    /// Applies a function to all items in self, in-place.
    ///
    /// `f` is called once per entry, in row-major order.
    pub fn map_mut(&mut self, mut f: impl FnMut(T) -> T) {
        for item in self.data.iter_mut() {
            *item = f(*item);
        }
    }

    /// Matrix multiplication: returns `self * rhs`.
    ///
    /// The element types may differ. Each elementary product is evaluated as
    /// `rhs[(k, j)] * self[(i, k)]` (a `U` times a `T`), and the products are summed
    /// starting from `U::zero()`.
    ///
    /// # Panics
    ///
    /// Panics if `self.ncols != rhs.nrows`.
    pub fn mat_mul<U: Copy + Add<Output = U> + Mul<T, Output = U> + Zero>(
        &self,
        rhs: &Matrix<U>,
    ) -> Matrix<U> {
        assert_eq!(
            self.ncols, rhs.nrows,
            "cannot multiply a {}x{} matrix by a {}x{} matrix",
            self.nrows, self.ncols, rhs.nrows, rhs.ncols
        );
        let mut mat: Matrix<U> = Matrix::new((self.nrows, rhs.ncols), U::zero());
        for i in 0..self.nrows {
            for j in 0..rhs.ncols {
                let mut acc = U::zero();
                for k in 0..self.ncols {
                    acc = acc + rhs[(k, j)] * self[(i, k)];
                }
                mat[(i, j)] = acc;
            }
        }
        mat
    }

    /// Returns the determinant of a square matrix.
    ///
    /// The determinant is computed via Gaussian elimination. Each round picks the first
    /// row with a non-zero leading entry as pivot, multiplies the pivot onto the result,
    /// eliminates the leading entry of every other row and then drops the pivot row and
    /// the first column. Moving rows around flips the sign of the determinant, so the
    /// parity of all row reorderings is tracked and applied at the end.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square or is of size 0x0.
    pub fn det(&self) -> T
    where
        T: Field,
    {
        // non-empty square matrix
        assert!(self.nrows == self.ncols && self.ncols != 0);
        let n = self.ncols;

        let mut det = T::ONE;
        // `true` when the row reorderings performed so far form an odd permutation.
        let mut negate = false;
        // We start with the complete matrix; each round reduces the dimension by one.
        let mut rows: Vec<Vec<T>> = self.data.chunks(n).map(|c| c.to_vec()).collect();

        for _ in 0..n {
            // Partition into rows that have a leading zero and rows that don't. The rows
            // are later re-assembled as "non-zero lead first, zero lead last", so every
            // non-zero-lead row jumps over all zero-lead rows that preceded it. Each such
            // jump is one transposition.
            let mut lz_rows: Vec<Vec<T>> = Vec::new();
            let mut nlz_rows: Vec<Vec<T>> = Vec::new();
            for row in rows.drain(..) {
                if row.starts_with(&[T::ZERO]) {
                    lz_rows.push(row);
                } else {
                    if lz_rows.len() % 2 == 1 {
                        negate = !negate;
                    }
                    nlz_rows.push(row);
                }
            }

            let mut nlz_rows = nlz_rows.into_iter();
            // take pivot row
            let Some(pivot) = nlz_rows.next() else {
                // no pivot row implies the rank is less than n, i.e. the determinant is zero
                return T::ZERO;
            };

            // multiply pivot onto the determinant
            det *= pivot[0];

            // The pivot's leading entry is non-zero by construction, so it is invertible.
            let pivot_inverse = pivot[0].invert().unwrap();
            // Precompute the pivot row scaled so that its leading entry is one; the
            // elimination below then needs a single multiplication per entry.
            let normalized_pivot: Vec<T> = pivot.iter().map(|f| *f * pivot_inverse).collect();

            // Forward elimination with the normalized pivot row. The first column would
            // become zero in every remaining row, so it is skipped instead of computed,
            // which also removes the pivot column from the next round.
            rows = nlz_rows
                .map(|row| {
                    let lead = row[0];
                    row.iter()
                        .zip(&normalized_pivot)
                        .skip(1)
                        .map(|(f, p)| *f - lead * p)
                        .collect::<Vec<T>>()
                })
                .chain(lz_rows.into_iter().map(|row| row[1..].to_vec()))
                .collect();
        }

        if negate {
            -det
        } else {
            det
        }
    }

    /// Returns a new matrix of the same shape with every entry converted through
    /// `U::from`.
    pub fn convert<U: From<T> + Copy>(&self) -> Matrix<U> {
        Matrix::new_from_iter(
            (self.nrows, self.ncols),
            self.data.iter().copied().map(U::from),
        )
    }
}

impl<T: Copy> Index<(usize, usize)> for Matrix<T> {
    type Output = T;

    /// Returns the entry at `index`, given as `(row, column)`.
    ///
    /// # Panics
    ///
    /// Panics when `get` returns `None` for `index`.
    fn index(&self, index: (usize, usize)) -> &T {
        let entry: Option<&T> = self.get(index);
        entry.unwrap()
    }
}
impl<T: Copy> IndexMut<(usize, usize)> for Matrix<T> {
    /// Returns the entry at `index`, given as `(row, column)`, for in-place modification.
    ///
    /// # Panics
    ///
    /// Panics when `get_mut` returns `None` for `index`.
    fn index_mut(&mut self, index: (usize, usize)) -> &mut T {
        let entry: Option<&mut T> = self.get_mut(index);
        entry.unwrap()
    }
}
impl<T: Copy> IntoIterator for Matrix<T> {
    type Item = T;
    type IntoIter = std::vec::IntoIter<Self::Item>;
    fn into_iter(self) -> Self::IntoIter {
        self.data.into_iter()
    }
}

impl<T: Copy> IntoIterator for &Matrix<T> {
    type Item = T;
    type IntoIter = std::vec::IntoIter<T>;

    /// Yields the entries by value, in storage order, leaving the matrix untouched.
    ///
    /// The iterator owns a copy of the backing vector, so the entries are duplicated once
    /// up front.
    fn into_iter(self) -> std::vec::IntoIter<T> {
        let copied: Vec<T> = self.data.to_vec();
        copied.into_iter()
    }
}

impl<'a, T: Copy + Add<Output = T>> AddAssign<&'a Matrix<T>> for Matrix<T> {
    /// Adds `rhs` to `self` entry by entry, in place.
    ///
    /// # Panics
    ///
    /// Panics if the two matrices do not have the same number of rows and columns.
    fn add_assign(&mut self, rhs: &'a Matrix<T>) {
        assert_eq!(self.nrows, rhs.nrows);
        assert_eq!(self.ncols, rhs.ncols);
        let (nrows, ncols) = (self.nrows, self.ncols);
        // Walk every (row, column) position in row-major order.
        for pos in (0..nrows).flat_map(|i| (0..ncols).map(move |j| (i, j))) {
            self[pos] = self[pos] + rhs[pos];
        }
    }
}

impl<'a, T: Copy + Sub<Output = T>> SubAssign<&'a Matrix<T>> for Matrix<T> {
    /// Subtracts `rhs` from `self` entry by entry, in place.
    ///
    /// # Panics
    ///
    /// Panics if the two matrices do not have the same number of rows and columns.
    fn sub_assign(&mut self, rhs: &'a Matrix<T>) {
        assert_eq!(self.nrows, rhs.nrows);
        assert_eq!(self.ncols, rhs.ncols);
        let (nrows, ncols) = (self.nrows, self.ncols);
        // Walk every (row, column) position in row-major order.
        for pos in (0..nrows).flat_map(|i| (0..ncols).map(move |j| (i, j))) {
            self[pos] = self[pos] - rhs[pos];
        }
    }
}

impl<T: Copy + Add<Output = T>> Add for Matrix<T> {
    type Output = Matrix<T>;

    fn add(mut self, rhs: Self) -> Self::Output {
        self += &rhs;
        self
    }
}

impl<T: Copy + Sub<Output = T>> Sub for Matrix<T> {
    type Output = Matrix<T>;

    fn sub(mut self, rhs: Self) -> Self::Output {
        self -= &rhs;
        self
    }
}

impl<T: Copy> From<Vec<T>> for Matrix<T> {
    fn from(v: Vec<T>) -> Self {
        let nrows = v.len();
        Self::new_from_iter((nrows, 1), v.into_iter())
    }
}
impl<'a, T: Copy> From<&'a [T]> for Matrix<T> {
    fn from(v: &'a [T]) -> Self {
        let nrows = v.len();
        Self::new_from_iter((nrows, 1), v.iter().copied())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::field::ScalarField;
    use ff::Field;

    type F = ScalarField;

    /// 3x3 matrix with a leading-zero row in the middle.
    #[test]
    fn test_det_dim3() {
        // 4 2 4
        // 0 0 3
        // 5 7 7
        //
        // Cofactor expansion along the first row:
        //   4 * (0*7 - 3*7) - 2 * (0*7 - 3*5) + 4 * (0*7 - 0*5) = -84 + 30 + 0 = -54
        let data = vec![
            F::from(4),
            F::from(2),
            F::from(4),
            F::ZERO,
            F::ZERO,
            F::from(3),
            F::from(5),
            F::from(7),
            F::from(7),
        ];

        let mat = Matrix::new_from_iter((3, 3), data.into_iter());

        let det = mat.det();
        assert_eq!(F::from(-54), det);
    }

    /// 4x4 matrix without any zero entry.
    #[test]
    fn test_det_dim4() {
        // 6 4 7 8
        // 9 3 9 8
        // 8 3 4 9
        // 5 4 1 3
        let data = vec![
            F::from(6),
            F::from(4),
            F::from(7),
            F::from(8),
            F::from(9),
            F::from(3),
            F::from(9),
            F::from(8),
            F::from(8),
            F::from(3),
            F::from(4),
            F::from(1 + 8),
            F::from(5),
            F::from(4),
            F::from(1),
            F::from(3),
        ];

        let mat = Matrix::new_from_iter((4, 4), data.into_iter());
        let det = mat.det();
        assert_eq!(F::from(-476), det);
    }

    /// Exchanging the two rows of the identity matrix gives determinant -1.
    #[test]
    fn test_det_row_swap_sign() {
        // 0 1
        // 1 0
        let data = vec![F::ZERO, F::ONE, F::ONE, F::ZERO];

        let mat = Matrix::new_from_iter((2, 2), data.into_iter());
        assert_eq!(-F::ONE, mat.det());
    }

    /// A matrix with linearly dependent rows has determinant 0.
    #[test]
    fn test_det_singular() {
        // 1 2
        // 2 4
        let data = vec![F::from(1), F::from(2), F::from(2), F::from(4)];

        let mat = Matrix::new_from_iter((2, 2), data.into_iter());
        assert_eq!(F::ZERO, mat.det());
    }
}