use ndarray::{Array1, Array2, ArrayBase, Data, DataMut, Ix1, Ix2};
use super::zp;
use super::zp::{ucomp, ufield, Modulus};
/// Failure modes reported by `solve_subsystem` and `solve`.
#[derive(Debug, Eq, PartialEq)]
pub enum LinearSolverError {
/// The system does not determine a unique solution.
///
/// Returned when there are fewer equations than columns to eliminate
/// (`min_rank: 0`, `max_rank` = number of equations), when elimination finds
/// no non-zero pivot for some column (`min_rank` = number of pivots found so
/// far, `max_rank` = number of columns minus one), or when `solve` ends up
/// with a rank below the number of unknowns.
// NOTE(review): the field names suggest lower/upper bounds on the rank of the
// system; confirm that callers interpret them this way.
Underdetermined { min_rank: usize, max_rank: usize },
/// After elimination, a row beyond the rank still has a non-zero
/// right-hand side, so no solution exists.
Inconsistent,
}
/// Eliminates the first `max_col` columns of `m` in place (forward elimination).
///
/// For each column `0..max_col` a non-zero pivot is moved into the current
/// pivot row (swapping rows if needed) and every entry below the pivot is
/// cancelled by subtracting a multiple of the pivot row. All columns to the
/// right of the pivot, including any beyond `max_col`, take part in the row
/// operations. Arithmetic is delegated to `zp::inv`, `zp::mul` and `zp::sub`
/// with the modulus `p`.
///
/// Returns the number of pivots that were placed.
///
/// # Errors
///
/// * `Underdetermined { min_rank: 0, max_rank: rows }` if `m` has fewer rows
///   than `max_col`.
/// * `Underdetermined { min_rank: pivots_so_far, max_rank: max_col - 1 }` if a
///   column has no non-zero entry at or below the current pivot row.
pub fn solve_subsystem<S1: DataMut<Elem = ufield>, M: Modulus<ucomp, ufield>>(
    m: &mut ArrayBase<S1, Ix2>,
    max_col: usize,
    p: M,
) -> Result<usize, LinearSolverError> {
    let (n_rows, n_cols) = (m.shape()[0], m.shape()[1]);

    if n_rows < max_col {
        return Err(LinearSolverError::Underdetermined {
            min_rank: 0,
            max_rank: n_rows,
        });
    }

    let mut pivot_row = 0;
    for col in 0..max_col {
        // Make sure the pivot position holds a non-zero entry.
        if m[(pivot_row, col)] == 0 {
            let donor = (pivot_row + 1..n_rows).find(|&r| m[(r, col)] != 0);
            match donor {
                Some(r) => {
                    // Entries left of `col` are not touched by the swap.
                    for c in col..n_cols {
                        m.swap((pivot_row, c), (r, c));
                    }
                }
                None => {
                    return Err(LinearSolverError::Underdetermined {
                        min_rank: pivot_row,
                        max_rank: max_col - 1,
                    });
                }
            }
        }

        // Cancel the entries below the pivot.
        let pivot_inv = zp::inv(m[(pivot_row, col)], p);
        for r in pivot_row + 1..n_rows {
            let lead = m[(r, col)];
            if lead == 0 {
                continue;
            }
            let factor = zp::mul(lead, pivot_inv, p);
            m[(r, col)] = 0;
            for c in col + 1..n_cols {
                let scaled = zp::mul(m[(pivot_row, c)], factor, p);
                m[(r, c)] = zp::sub(m[(r, c)], scaled, p);
            }
        }

        pivot_row += 1;
        if pivot_row >= n_rows {
            break;
        }
    }

    Ok(pivot_row)
}
/// Solves the linear system `a * x = b` using the `zp` field operations with
/// modulus `p`.
///
/// `a` holds one equation per row and one unknown per column; `b` holds the
/// right-hand sides. The inputs are copied into an augmented matrix, which is
/// reduced with [`solve_subsystem`] and then back-substituted.
///
/// Returns the vector of unknowns, whose length is the number of columns of `a`.
///
/// # Errors
///
/// * `LinearSolverError::Underdetermined` if there are fewer equations than
///   unknowns, or if elimination cannot find a pivot for every unknown.
/// * `LinearSolverError::Inconsistent` if, after elimination, a row beyond the
///   rank has a non-zero right-hand side.
///
/// # Panics
///
/// Panics if `a` and `b` do not have the same number of rows.
pub fn solve<S1: Data<Elem = ufield>, S2: Data<Elem = ufield>, M: Modulus<ucomp, ufield>>(
    a: &ArrayBase<S1, Ix2>,
    b: &ArrayBase<S2, Ix1>,
    p: M,
) -> Result<Array1<ufield>, LinearSolverError> {
    assert!(a.shape()[0] == b.shape()[0]);

    let neqs = a.shape()[0];
    let nvars = a.shape()[1];

    if neqs < nvars {
        return Err(LinearSolverError::Underdetermined {
            min_rank: 0,
            max_rank: neqs,
        });
    }

    // Augmented matrix [a | b]. Zero-initialised so no `unsafe` is needed;
    // every cell is overwritten by the two copy loops below.
    let mut m = Array2::<ufield>::zeros((neqs, nvars + 1));
    for (i, e) in a.indexed_iter() {
        m[i] = *e;
    }
    for (i, e) in b.indexed_iter() {
        m[(i, nvars)] = *e;
    }

    let rank = solve_subsystem(&mut m, nvars, p)?;

    // Rows past the rank have had their coefficient entries eliminated, so a
    // non-zero right-hand side there cannot be satisfied.
    for k in rank..neqs {
        if m[(k, nvars)] != 0 {
            return Err(LinearSolverError::Inconsistent);
        }
    }

    if rank < nvars {
        return Err(LinearSolverError::Underdetermined {
            min_rank: rank,
            max_rank: rank,
        });
    }
    assert_eq!(rank, nvars);

    // Back substitution. With full rank the pivot of column `j` sits in row
    // `j`, so no separate row counter is needed. This also makes a system
    // with zero unknowns a no-op instead of underflowing a `usize`.
    for j in (0..nvars).rev() {
        let pivot = m[(j, j)];
        if pivot != 1 {
            // Normalise the pivot row; only its right-hand side is read later.
            let inv_pivot = zp::inv(pivot, p);
            m[(j, nvars)] = zp::mul(m[(j, nvars)], inv_pivot, p);
        }

        // Remove the contribution of unknown `j` from every row above.
        for k in 0..j {
            let coeff = m[(k, j)];
            if coeff != 0 {
                m[(k, nvars)] = zp::sub(m[(k, nvars)], zp::mul(m[(j, nvars)], coeff, p), p);
            }
        }
    }

    Ok(Array1::<ufield>::from_iter(
        (0..nvars).map(|i| m[(i, nvars)]),
    ))
}
// Square 3x3 system with p = 17 that is expected to have a unique solution.
#[test]
fn test_solve() {
    use ndarray::{arr1, arr2};

    let modulus: ufield = 17;
    let coeffs: Array2<ufield> = arr2(&[[1, 1, 2], [3, 4, 3], [16, 5, 5]]);
    let rhs: Array1<ufield> = arr1(&[3, 15, 8]);

    let solution = solve(&coeffs, &rhs, modulus).unwrap();

    assert_eq!(solution, arr1(&[2, 3, 16]));
}
// A right-hand side with more entries than the matrix has rows must trip the
// shape assertion in `solve`.
#[test]
#[should_panic]
fn test_solve_bad_shape() {
    use ndarray::{arr1, arr2};

    let modulus: ufield = 17;
    let coeffs: Array2<ufield> = arr2(&[[1, 1, 2], [3, 4, 3], [16, 5, 5]]);
    let rhs: Array1<ufield> = arr1(&[3, 15, 8, 1]);

    let _ = solve(&coeffs, &rhs, modulus);
}
// Two equations for three unknowns: rejected before any elimination happens.
#[test]
fn test_solve_underdetermined1() {
    use ndarray::{arr1, arr2};

    let modulus: ufield = 17;
    let coeffs: Array2<ufield> = arr2(&[[1, 1, 2], [3, 4, 3]]);
    let rhs: Array1<ufield> = arr1(&[3, 15]);

    let expected: Result<Array1<ufield>, _> = Err(LinearSolverError::Underdetermined {
        min_rank: 0,
        max_rank: 2,
    });

    assert_eq!(solve(&coeffs, &rhs, modulus), expected);
}
// 4x4 system whose second column has no non-zero entry below the first row,
// so elimination stops after one pivot.
#[test]
fn test_solve_underdetermined2() {
    use ndarray::{arr1, arr2};

    let modulus: ufield = 17;
    let coeffs: Array2<ufield> =
        arr2(&[[1, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 2]]);
    let rhs: Array1<ufield> = arr1(&[1, 15, 1, 2]);

    let expected: Result<Array1<ufield>, _> = Err(LinearSolverError::Underdetermined {
        min_rank: 1,
        max_rank: 3,
    });

    assert_eq!(solve(&coeffs, &rhs, modulus), expected);
}
// Square system whose third equation is linearly dependent on the first two
// modulo 17, so only two pivots can be found.
#[test]
fn test_solve_underdetermined3() {
    use ndarray::{arr1, arr2};

    let modulus: ufield = 17;
    let coeffs: Array2<ufield> = arr2(&[[1, 1, 2], [3, 4, 3], [10, 7, 12]]);
    let rhs: Array1<ufield> = arr1(&[3, 15, 12]);

    let expected: Result<Array1<ufield>, _> = Err(LinearSolverError::Underdetermined {
        min_rank: 2,
        max_rank: 2,
    });

    assert_eq!(solve(&coeffs, &rhs, modulus), expected);
}
// Five equations for three unknowns; the test expects them to be mutually
// consistent and to yield a unique solution.
#[test]
fn test_solve_overdetermined() {
    use ndarray::{arr1, arr2};

    let modulus: ufield = 17;
    let coeffs: Array2<ufield> =
        arr2(&[[1, 1, 2], [3, 4, 3], [9, 0, 11], [1, 1, 7], [2, 3, 8]]);
    let rhs: Array1<ufield> = arr1(&[3, 15, 7, 6, 6]);

    let solution = solve(&coeffs, &rhs, modulus).unwrap();

    assert_eq!(solution, arr1(&[11, 1, 4]));
}
// The first three equations match `test_solve`; the fourth one is not
// satisfied by their solution, so the solver must report an inconsistency.
#[test]
fn test_solve_inconsistent() {
    use ndarray::{arr1, arr2};

    let modulus: ufield = 17;
    let coeffs: Array2<ufield> = arr2(&[[1, 1, 2], [3, 4, 3], [16, 5, 5], [14, 2, 4]]);
    let rhs: Array1<ufield> = arr1(&[3, 15, 8, 3]);

    let outcome = solve(&coeffs, &rhs, modulus);

    assert_eq!(outcome, Err(LinearSolverError::Inconsistent));
}