use std::{
cmp::{max, min},
fmt,
ops::{Add, Div, Index, IndexMut, Mul, Neg, Sub},
};
use anyhow::{bail, Result};
use matrixmultiply::CGemmOption;
use num_complex::Complex;
use peroxide_num::{ExpLogOps, PowOps, TrigOps};
use rand_distr::num_traits::{One, Zero};
use crate::{
complex::C64,
structure::matrix::Shape,
traits::fp::{FPMatrix, FPVector},
traits::general::Algorithm,
traits::math::{InnerProduct, LinearOp, MatrixProduct, Norm, Normed, Vector},
traits::matrix::{Form, LinearAlgebra, MatrixTrait, SolveKind, PQLU, QR, SVD, WAZD},
traits::mutable::MutMatrix,
util::low_level::{copy_vec_ptr, swap_vec_ptr},
util::non_macro::ConcatenateError,
util::useful::{nearly_eq, tab},
};
/// Dense matrix with `C64` (complex) entries.
///
/// The fields are public, so the invariant `data.len() == row * col` is not
/// enforced on construction; several methods (e.g. `change_shape`, `spread`)
/// assert it.
#[derive(Debug, Clone, Default)]
pub struct ComplexMatrix {
    /// Flat element buffer laid out according to `shape`:
    /// `Shape::Row` stores (i, j) at `i * col + j`, `Shape::Col` at `i + j * row`.
    pub data: Vec<C64>,
    /// Number of rows.
    pub row: usize,
    /// Number of columns.
    pub col: usize,
    /// Storage order of `data` (row-major or column-major).
    pub shape: Shape,
}
/// Builds an `r x c` [`ComplexMatrix`] from a flat buffer, converting every
/// element with `Into<C64>`.
///
/// `v` is interpreted in the storage order given by `shape`; its length is not
/// checked against `r * c`.
pub fn cmatrix<T>(v: Vec<T>, r: usize, c: usize, shape: Shape) -> ComplexMatrix
where
    T: Into<C64>,
{
    let data: Vec<C64> = v.into_iter().map(Into::into).collect();
    ComplexMatrix {
        data,
        row: r,
        col: c,
        shape,
    }
}
pub fn r_cmatrix<T>(v: Vec<T>, r: usize, c: usize, shape: Shape) -> ComplexMatrix
where
T: Into<C64>,
{
cmatrix(v, r, c, shape)
}
/// Builds a row-major complex matrix from nested rows (Python-list style).
///
/// An empty outer `Vec` yields a `0 x 0` matrix.
///
/// # Panics
/// Panics if the rows do not all have the same length. A ragged input would
/// otherwise produce a matrix whose `data.len() != row * col`.
pub fn py_cmatrix<T>(v: Vec<Vec<T>>) -> ComplexMatrix
where
    T: Into<C64> + Copy,
{
    let r = v.len();
    let c = v.first().map_or(0, Vec::len);
    assert!(
        v.iter().all(|row| row.len() == c),
        "py_cmatrix: all rows must have the same length"
    );
    let data: Vec<T> = v.into_iter().flatten().collect();
    cmatrix(data, r, c, Shape::Row)
}
/// Parses a MATLAB-style literal such as `"1 2+1i; 3 4"` into a row-major
/// complex matrix.
///
/// Rows are separated by `;`, and entries by any run of whitespace. Empty
/// rows (e.g. from a trailing `;`) are ignored, and an empty string yields a
/// `0 x 0` matrix.
///
/// # Panics
/// Panics if an entry cannot be parsed as `C64`, or if the rows have
/// differing numbers of entries.
pub fn ml_cmatrix(s: &str) -> ComplexMatrix {
    let rows: Vec<Vec<C64>> = s
        .split(';')
        .map(str::trim)
        .filter(|row| !row.is_empty())
        .map(|row| {
            row.split_whitespace()
                .map(|x| {
                    x.parse::<C64>().unwrap_or_else(|_| {
                        panic!("ml_cmatrix: cannot parse `{}` as a complex number", x)
                    })
                })
                .collect::<Vec<C64>>()
        })
        .collect();
    let r = rows.len();
    let c = rows.first().map_or(0, Vec::len);
    assert!(
        rows.iter().all(|row| row.len() == c),
        "ml_cmatrix: all rows must have the same number of entries"
    );
    let data: Vec<C64> = rows.into_iter().flatten().collect();
    cmatrix(data, r, c, Shape::Row)
}
impl fmt::Display for ComplexMatrix {
    /// Writes the tabular rendering produced by `MatrixTrait::spread`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
        let rendered = self.spread();
        f.write_str(&rendered)
    }
}
impl PartialEq for ComplexMatrix {
    /// Approximate equality: same dimensions and every entry equal up to
    /// `nearly_eq` on both the real and the imaginary part.
    ///
    /// Matrices stored in different orders are compared after converting
    /// `other` to `self`'s storage order.
    fn eq(&self, other: &ComplexMatrix) -> bool {
        // Compare both dimensions and the buffer length up front: `zip` would
        // otherwise truncate and let e.g. a 2x2 equal the first 4 entries of a 2x3.
        if self.row != other.row
            || self.col != other.col
            || self.data.len() != other.data.len()
        {
            return false;
        }
        if self.shape != other.shape {
            return self.eq(&other.change_shape());
        }
        self.data
            .iter()
            .zip(other.data.iter())
            .all(|(x, y)| nearly_eq(x.re, y.re) && nearly_eq(x.im, y.im))
    }
}
impl MatrixTrait for ComplexMatrix {
type Scalar = C64;
fn ptr(&self) -> *const C64 {
&self.data[0] as *const C64
}
fn mut_ptr(&mut self) -> *mut C64 {
&mut self.data[0] as *mut C64
}
fn as_slice(&self) -> &[C64] {
&self.data[..]
}
fn as_mut_slice(&mut self) -> &mut [C64] {
&mut self.data[..]
}
fn change_shape(&self) -> Self {
let r = self.row;
let c = self.col;
assert_eq!(r * c, self.data.len());
let l = r * c - 1;
let mut data: Vec<C64> = self.data.clone();
let ref_data = &self.data;
match self.shape {
Shape::Row => {
for (i, slot) in data.iter_mut().enumerate().take(l) {
let s = (i * c) % l;
*slot = ref_data[s];
}
data[l] = ref_data[l];
cmatrix(data, r, c, Shape::Col)
}
Shape::Col => {
for (i, slot) in data.iter_mut().enumerate().take(l) {
let s = (i * r) % l;
*slot = ref_data[s];
}
data[l] = ref_data[l];
cmatrix(data, r, c, Shape::Row)
}
}
}
fn change_shape_mut(&mut self) {
let r = self.row;
let c = self.col;
assert_eq!(r * c, self.data.len());
let l = r * c - 1;
let ref_data = self.data.clone();
match self.shape {
Shape::Row => {
for i in 0..l {
let s = (i * c) % l;
self.data[i] = ref_data[s];
}
self.data[l] = ref_data[l];
self.shape = Shape::Col;
}
Shape::Col => {
for i in 0..l {
let s = (i * r) % l;
self.data[i] = ref_data[s];
}
self.data[l] = ref_data[l];
self.shape = Shape::Row;
}
}
}
fn spread(&self) -> String {
assert_eq!(self.row * self.col, self.data.len());
let r = self.row;
let c = self.col;
let mut key_row = 20usize;
let mut key_col = 20usize;
if r > 100 || c > 100 || (r > 20 && c > 20) {
let part = if r <= 10 {
key_row = r;
key_col = 100;
self.take_col(100)
} else if c <= 10 {
key_row = 100;
key_col = c;
self.take_row(100)
} else {
self.take_row(20).take_col(20)
};
return format!(
"Result is too large to print - {}x{}\n only print {}x{} parts:\n{}",
self.row,
self.col,
key_row,
key_col,
part.spread()
);
}
let sample = self.data.clone();
let mut space: usize = sample
.into_iter()
.map(
|x| min(format!("{:.4}", x).len(), x.to_string().len()), )
.fold(0, max)
+ 1;
if space < 5 {
space = 5;
}
let mut result = String::new();
result.push_str(&tab("", 5));
for i in 0..c {
result.push_str(&tab(&format!("c[{}]", i), space)); }
result.push('\n');
for i in 0..r {
result.push_str(&tab(&format!("r[{}]", i), 5));
for j in 0..c {
let st1 = format!("{:.4}", self[(i, j)]); let st2 = self[(i, j)].to_string(); let mut st = st2.clone();
if st1.len() < st2.len() {
st = st1;
}
result.push_str(&tab(&st, space));
}
if i == (r - 1) {
break;
}
result.push('\n');
}
result
}
fn col(&self, index: usize) -> Vec<C64> {
assert!(index < self.col);
let mut container: Vec<C64> = vec![Complex::zero(); self.row];
for i in 0..self.row {
container[i] = self[(i, index)];
}
container
}
fn row(&self, index: usize) -> Vec<C64> {
assert!(index < self.row);
let mut container: Vec<C64> = vec![Complex::zero(); self.col];
for i in 0..self.col {
container[i] = self[(index, i)];
}
container
}
fn diag(&self) -> Vec<C64> {
let mut container = vec![Complex::zero(); self.row];
let r = self.row;
let c = self.col;
assert_eq!(r, c);
let c2 = c + 1;
for (i, slot) in container.iter_mut().enumerate().take(r) {
*slot = self.data[i * c2];
}
container
}
fn transpose(&self) -> Self {
match self.shape {
Shape::Row => cmatrix(self.data.clone(), self.col, self.row, Shape::Col),
Shape::Col => cmatrix(self.data.clone(), self.col, self.row, Shape::Row),
}
}
#[inline]
fn subs_col(&mut self, idx: usize, v: &[C64]) {
for i in 0..self.row {
self[(i, idx)] = v[i];
}
}
#[inline]
fn subs_row(&mut self, idx: usize, v: &[C64]) {
for j in 0..self.col {
self[(idx, j)] = v[j];
}
}
fn from_index<F, G>(f: F, size: (usize, usize)) -> ComplexMatrix
where
F: Fn(usize, usize) -> G + Copy,
G: Into<C64>,
{
let row = size.0;
let col = size.1;
let mut mat = cmatrix(vec![Complex::zero(); row * col], row, col, Shape::Row);
for i in 0..row {
for j in 0..col {
mat[(i, j)] = f(i, j).into();
}
}
mat
}
fn to_vec(&self) -> Vec<Vec<C64>> {
let mut result = vec![vec![Complex::zero(); self.col]; self.row];
for (i, slot) in result.iter_mut().enumerate().take(self.row) {
*slot = self.row(i);
}
result
}
fn to_diag(&self) -> ComplexMatrix {
assert_eq!(self.row, self.col, "Should be square matrix");
let mut result = cmatrix(
vec![Complex::zero(); self.row * self.col],
self.row,
self.col,
Shape::Row,
);
let diag = self.diag();
for i in 0..self.row {
result[(i, i)] = diag[i];
}
result
}
fn submat(&self, start: (usize, usize), end: (usize, usize)) -> ComplexMatrix {
let row = end.0 - start.0 + 1;
let col = end.1 - start.1 + 1;
let mut result = cmatrix(vec![Complex::zero(); row * col], row, col, self.shape);
for i in 0..row {
for j in 0..col {
result[(i, j)] = self[(start.0 + i, start.1 + j)];
}
}
result
}
fn subs_mat(&mut self, start: (usize, usize), end: (usize, usize), m: &ComplexMatrix) {
let row = end.0 - start.0 + 1;
let col = end.1 - start.1 + 1;
for i in 0..row {
for j in 0..col {
self[(start.0 + i, start.1 + j)] = m[(i, j)];
}
}
}
}
impl Vector for ComplexMatrix {
    type Scalar = C64;

    /// Element-wise sum. The result keeps `self`'s storage order.
    ///
    /// # Panics
    /// Panics if the dimensions differ.
    fn add_vec(&self, other: &Self) -> Self {
        assert_eq!(self.row, other.row);
        assert_eq!(self.col, other.col);
        let cols = self.col;
        let mut total = self.clone();
        for idx in (0..self.row).flat_map(|i| (0..cols).map(move |j| (i, j))) {
            total[idx] += other[idx];
        }
        total
    }

    /// Element-wise difference `self - other`. The result keeps `self`'s
    /// storage order.
    ///
    /// # Panics
    /// Panics if the dimensions differ.
    fn sub_vec(&self, other: &Self) -> Self {
        assert_eq!(self.row, other.row);
        assert_eq!(self.col, other.col);
        let cols = self.col;
        let mut diff = self.clone();
        for idx in (0..self.row).flat_map(|i| (0..cols).map(move |j| (i, j))) {
            diff[idx] -= other[idx];
        }
        diff
    }

    /// Multiplies every entry by `other`.
    fn mul_scalar(&self, other: Self::Scalar) -> Self {
        self.fmap(move |z| z * other)
    }
}
impl Normed for ComplexMatrix {
    type UnsignedScalar = f64;

    /// Matrix norms built on the modulus `|z|` of each entry.
    ///
    /// - `F`: Frobenius norm `sqrt(sum |a_ij|^2)`.
    /// - `Lpq(p, q)`: `(sum_j (sum_i |a_ij|^p)^(q/p))^(1/q)`, taken column-wise.
    /// - `L1`: maximum absolute column sum.
    /// - `LInf`: maximum absolute row sum.
    /// - `L2`, `Lp`: not implemented (panic).
    fn norm(&self, kind: Norm) -> Self::UnsignedScalar {
        match kind {
            // Complex entries need |z|^2 = z * conj(z). Squaring `z` itself
            // (z^2) can cancel, e.g. a matrix of `i` would give norm 0.
            Norm::F => self
                .data
                .iter()
                .map(|z| z.norm_sqr())
                .sum::<f64>()
                .sqrt(),
            Norm::Lpq(p, q) => (0..self.col)
                .map(|j| {
                    let col_sum: f64 = (0..self.row).map(|i| self[(i, j)].norm().powf(p)).sum();
                    col_sum.powf(q / p)
                })
                .sum::<f64>()
                .powf(1f64 / q),
            Norm::L1 => (0..self.col)
                .map(|j| self.col(j).iter().map(|z| z.norm()).sum::<f64>())
                .fold(0f64, f64::max),
            Norm::LInf => (0..self.row)
                .map(|i| self.row(i).iter().map(|z| z.norm()).sum::<f64>())
                .fold(0f64, f64::max),
            Norm::L2 => {
                unimplemented!()
            }
            Norm::Lp(_) => unimplemented!(),
        }
    }

    /// Not implemented for complex matrices.
    fn normalize(&self, _kind: Norm) -> Self
    where
        Self: Sized,
    {
        unimplemented!()
    }
}
impl InnerProduct for ComplexMatrix {
    /// Flat dot product of the two element buffers (via the `Vec<C64>` `dot`).
    /// `rhs` is first converted to `self`'s storage order when the orders differ.
    fn dot(&self, rhs: &Self) -> C64 {
        if self.shape != rhs.shape {
            let aligned = rhs.change_shape();
            return self.data.dot(&aligned.data);
        }
        self.data.dot(&rhs.data)
    }
}
#[allow(non_snake_case)]
impl LinearOp<Vec<C64>, Vec<C64>> for ComplexMatrix {
    /// Matrix-vector product `self * other`, computed via `cgemv`.
    ///
    /// # Panics
    /// Panics if `other.len() != self.col`.
    fn apply(&self, other: &Vec<C64>) -> Vec<C64> {
        assert_eq!(self.col, other.len());
        let (one, zero): (C64, C64) = (Complex::one(), Complex::zero());
        let mut product = vec![zero; self.row];
        cgemv(one, self, other, zero, &mut product);
        product
    }
}
/// Concatenates two matrices side by side (`[m1 m2]`).
///
/// Both inputs are brought into column-major order first, so the result is
/// their buffers appended and is stored column-major.
///
/// # Errors
/// Returns `ConcatenateError::DifferentLength` if the row counts differ.
pub fn complex_cbind(m1: ComplexMatrix, m2: ComplexMatrix) -> Result<ComplexMatrix> {
    let left = if m1.shape == Shape::Col { m1 } else { m1.change_shape() };
    let right = if m2.shape == Shape::Col { m2 } else { m2.change_shape() };
    if left.row != right.row {
        bail!(ConcatenateError::DifferentLength);
    }
    let rows = left.row;
    let cols = left.col + right.col;
    let mut data = left.data;
    data.extend_from_slice(&right.data[..]);
    Ok(cmatrix(data, rows, cols, Shape::Col))
}
/// Stacks two matrices vertically (`[m1; m2]`).
///
/// Both inputs are brought into row-major order first, so the result is
/// their buffers appended and is stored row-major.
///
/// # Errors
/// Returns `ConcatenateError::DifferentLength` if the column counts differ.
pub fn complex_rbind(m1: ComplexMatrix, m2: ComplexMatrix) -> Result<ComplexMatrix> {
    let top = if m1.shape == Shape::Row { m1 } else { m1.change_shape() };
    let bottom = if m2.shape == Shape::Row { m2 } else { m2.change_shape() };
    if top.col != bottom.col {
        bail!(ConcatenateError::DifferentLength);
    }
    let cols = top.col;
    let rows = top.row + bottom.row;
    let mut data = top.data;
    data.extend_from_slice(&bottom.data[..]);
    Ok(cmatrix(data, rows, cols, Shape::Row))
}
impl MatrixProduct for ComplexMatrix {
    /// Kronecker product `self ⊗ other`, assembled block by block with
    /// `complex_cbind` (within a block row) and `complex_rbind` (across rows).
    ///
    /// # Panics
    /// Panics if `self` is empty.
    fn kronecker(&self, other: &Self) -> Self {
        let block_row = |i: usize| {
            (1..self.col).fold(self[(i, 0)] * other, |acc, j| {
                complex_cbind(acc, self[(i, j)] * other).unwrap()
            })
        };
        (1..self.row).fold(block_row(0), |acc, i| {
            complex_rbind(acc, block_row(i)).unwrap()
        })
    }

    /// Element-wise (Hadamard) product. The result uses `self`'s storage order.
    ///
    /// # Panics
    /// Panics if the dimensions differ.
    fn hadamard(&self, other: &Self) -> Self {
        assert_eq!(self.row, other.row);
        assert_eq!(self.col, other.col);
        let (rows, cols) = (self.row, self.col);
        let mut product = cmatrix(vec![Complex::zero(); rows * cols], rows, cols, self.shape);
        for idx in (0..rows).flat_map(|i| (0..cols).map(move |j| (i, j))) {
            product[idx] = self[idx] * other[idx];
        }
        product
    }
}
impl From<ComplexMatrix> for Vec<C64> {
fn from(val: ComplexMatrix) -> Self {
val.data
}
}
impl<'a> From<&'a ComplexMatrix> for &'a Vec<C64> {
fn from(val: &'a ComplexMatrix) -> Self {
&val.data
}
}
impl From<Vec<C64>> for ComplexMatrix {
fn from(val: Vec<C64>) -> Self {
let l = val.len();
cmatrix(val, l, 1, Shape::Col)
}
}
impl From<&Vec<C64>> for ComplexMatrix {
fn from(val: &Vec<C64>) -> Self {
let l = val.len();
cmatrix(val.clone(), l, 1, Shape::Col)
}
}
impl Add<ComplexMatrix> for ComplexMatrix {
type Output = Self;
fn add(self, other: Self) -> Self {
assert_eq!(&self.row, &other.row);
assert_eq!(&self.col, &other.col);
let mut result = cmatrix(self.data.clone(), self.row, self.col, self.shape);
for i in 0..self.row {
for j in 0..self.col {
result[(i, j)] += other[(i, j)];
}
}
result
}
}
impl<'b> Add<&'b ComplexMatrix> for &ComplexMatrix {
type Output = ComplexMatrix;
fn add(self, rhs: &'b ComplexMatrix) -> Self::Output {
self.add_vec(rhs)
}
}
impl<T> Add<T> for ComplexMatrix
where
T: Into<C64> + Copy,
{
type Output = Self;
fn add(self, other: T) -> Self {
self.fmap(|x| x + other.into())
}
}
impl<T> Add<T> for &ComplexMatrix
where
T: Into<C64> + Copy,
{
type Output = ComplexMatrix;
fn add(self, other: T) -> Self::Output {
self.fmap(|x| x + other.into())
}
}
impl Add<ComplexMatrix> for C64 {
type Output = ComplexMatrix;
fn add(self, other: ComplexMatrix) -> Self::Output {
other.add(self)
}
}
impl<'a> Add<&'a ComplexMatrix> for C64 {
type Output = ComplexMatrix;
fn add(self, other: &'a ComplexMatrix) -> Self::Output {
other.add(self)
}
}
impl Neg for ComplexMatrix {
type Output = Self;
fn neg(self) -> Self {
cmatrix(
self.data.into_iter().map(|x: C64| -x).collect::<Vec<C64>>(),
self.row,
self.col,
self.shape,
)
}
}
impl Neg for &ComplexMatrix {
type Output = ComplexMatrix;
fn neg(self) -> Self::Output {
cmatrix(
self.data
.clone()
.into_iter()
.map(|x: C64| -x)
.collect::<Vec<C64>>(),
self.row,
self.col,
self.shape,
)
}
}
impl Sub<ComplexMatrix> for ComplexMatrix {
type Output = Self;
fn sub(self, other: Self) -> Self::Output {
assert_eq!(&self.row, &other.row);
assert_eq!(&self.col, &other.col);
let mut result = cmatrix(self.data.clone(), self.row, self.col, self.shape);
for i in 0..self.row {
for j in 0..self.col {
result[(i, j)] -= other[(i, j)];
}
}
result
}
}
impl<'b> Sub<&'b ComplexMatrix> for &ComplexMatrix {
type Output = ComplexMatrix;
fn sub(self, rhs: &'b ComplexMatrix) -> Self::Output {
self.sub_vec(rhs)
}
}
impl<T> Sub<T> for ComplexMatrix
where
T: Into<C64> + Copy,
{
type Output = Self;
fn sub(self, other: T) -> Self::Output {
self.fmap(|x| x - other.into())
}
}
impl<T> Sub<T> for &ComplexMatrix
where
T: Into<C64> + Copy,
{
type Output = ComplexMatrix;
fn sub(self, other: T) -> Self::Output {
self.fmap(|x| x - other.into())
}
}
impl Sub<ComplexMatrix> for C64 {
type Output = ComplexMatrix;
fn sub(self, other: ComplexMatrix) -> Self::Output {
-other.sub(self)
}
}
impl<'a> Sub<&'a ComplexMatrix> for f64 {
type Output = ComplexMatrix;
fn sub(self, other: &'a ComplexMatrix) -> Self::Output {
-other.sub(self)
}
}
/// Scales every entry by a complex scalar.
impl Mul<C64> for ComplexMatrix {
    type Output = Self;
    fn mul(self, other: C64) -> Self::Output {
        self.mul_scalar(other)
    }
}

/// `scalar * matrix` (commutes to `matrix * scalar`).
impl Mul<ComplexMatrix> for C64 {
    type Output = ComplexMatrix;
    fn mul(self, other: ComplexMatrix) -> Self::Output {
        other * self
    }
}

/// `scalar * &matrix`: scales every entry into a new matrix.
impl<'a> Mul<&'a ComplexMatrix> for C64 {
    type Output = ComplexMatrix;
    fn mul(self, other: &'a ComplexMatrix) -> Self::Output {
        other.fmap(move |z| z * self)
    }
}

/// Matrix product via `cmatmul`.
impl Mul<ComplexMatrix> for ComplexMatrix {
    type Output = Self;
    fn mul(self, other: Self) -> Self::Output {
        &self * &other
    }
}

/// Borrowing matrix product via `cmatmul`.
impl<'b> Mul<&'b ComplexMatrix> for &ComplexMatrix {
    type Output = ComplexMatrix;
    fn mul(self, other: &'b ComplexMatrix) -> Self::Output {
        cmatmul(self, other)
    }
}

/// Matrix-vector product (`LinearOp::apply`).
#[allow(non_snake_case)]
impl Mul<Vec<C64>> for ComplexMatrix {
    type Output = Vec<C64>;
    fn mul(self, other: Vec<C64>) -> Self::Output {
        &self * &other
    }
}

/// Borrowing matrix-vector product (`LinearOp::apply`).
#[allow(non_snake_case)]
impl<'b> Mul<&'b Vec<C64>> for &ComplexMatrix {
    type Output = Vec<C64>;
    fn mul(self, other: &'b Vec<C64>) -> Self::Output {
        self.apply(other)
    }
}

/// Row-vector times matrix (`v^T * M`); panics if `v.len() != M.row`.
impl Mul<ComplexMatrix> for Vec<C64> {
    type Output = Vec<C64>;
    fn mul(self, other: ComplexMatrix) -> Self::Output {
        &self * &other
    }
}

/// Borrowing row-vector times matrix; panics if `v.len() != M.row`.
impl<'b> Mul<&'b ComplexMatrix> for &Vec<C64> {
    type Output = Vec<C64>;
    fn mul(self, other: &'b ComplexMatrix) -> Self::Output {
        assert_eq!(self.len(), other.row);
        let (one, zero): (C64, C64) = (Complex::one(), Complex::zero());
        let mut product = vec![zero; other.col];
        complex_gevm(one, self, other, zero, &mut product);
        product
    }
}
/// Divides every entry by a complex scalar.
impl Div<C64> for ComplexMatrix {
    type Output = Self;
    fn div(self, other: C64) -> Self::Output {
        &self / other
    }
}

/// Divides every entry of a borrowed matrix by a complex scalar.
impl Div<C64> for &ComplexMatrix {
    type Output = ComplexMatrix;
    fn div(self, other: C64) -> Self::Output {
        let divisor = other;
        self.fmap(move |z| z / divisor)
    }
}
impl Index<(usize, usize)> for ComplexMatrix {
    type Output = C64;

    /// Returns entry `(i, j)`, honouring the storage order.
    ///
    /// Uses checked slice indexing rather than raw-pointer arithmetic. Because
    /// the fields are public, `data` may be shorter than `row * col`. In that
    /// case this panics instead of reading out of bounds.
    ///
    /// # Panics
    /// Panics with "Index out of range" if `i >= row` or `j >= col`.
    fn index(&self, pair: (usize, usize)) -> &C64 {
        let (i, j) = pair;
        assert!(i < self.row && j < self.col, "Index out of range");
        let idx = match self.shape {
            Shape::Row => i * self.col + j,
            Shape::Col => i + j * self.row,
        };
        &self.data[idx]
    }
}
impl IndexMut<(usize, usize)> for ComplexMatrix {
    /// Mutable access to entry `(i, j)`, honouring the storage order.
    ///
    /// Uses checked slice indexing so that an inconsistent `data` buffer
    /// (possible because the fields are public) panics instead of causing an
    /// out-of-bounds write.
    ///
    /// # Panics
    /// Panics with "Index out of range" if `i >= row` or `j >= col`.
    fn index_mut(&mut self, pair: (usize, usize)) -> &mut C64 {
        let (i, j) = pair;
        let r = self.row;
        let c = self.col;
        assert!(i < r && j < c, "Index out of range");
        let idx = match self.shape {
            Shape::Row => i * c + j,
            Shape::Col => i + j * r,
        };
        &mut self.data[idx]
    }
}
impl FPMatrix for ComplexMatrix {
    type Scalar = C64;

    /// First `n` rows as a row-major matrix (the whole matrix if `n >= row`).
    fn take_row(&self, n: usize) -> Self {
        if n >= self.row {
            return self.clone();
        }
        match self.shape {
            Shape::Row => {
                // Row-major: the first n rows are a prefix of the buffer.
                let new_data: Vec<C64> = self.data.iter().take(n * self.col).copied().collect();
                cmatrix(new_data, n, self.col, Shape::Row)
            }
            Shape::Col => {
                let mut temp_data: Vec<C64> = Vec::new();
                for i in 0..n {
                    temp_data.extend(self.row(i));
                }
                cmatrix(temp_data, n, self.col, Shape::Row)
            }
        }
    }

    /// First `n` columns as a column-major matrix (the whole matrix if `n >= col`).
    fn take_col(&self, n: usize) -> Self {
        if n >= self.col {
            return self.clone();
        }
        match self.shape {
            Shape::Col => {
                // Column-major: the first n columns are a prefix of the buffer.
                let new_data: Vec<C64> = self.data.iter().take(n * self.row).copied().collect();
                cmatrix(new_data, self.row, n, Shape::Col)
            }
            Shape::Row => {
                let mut temp_data: Vec<C64> = Vec::new();
                for i in 0..n {
                    temp_data.extend(self.col(i));
                }
                cmatrix(temp_data, self.row, n, Shape::Col)
            }
        }
    }

    /// Drops the first `n` rows; the result is row-major.
    ///
    /// # Panics
    /// Panics if `n >= row`.
    fn skip_row(&self, n: usize) -> Self {
        assert!(n < self.row, "Skip range is larger than row of matrix");
        let mut temp_data: Vec<C64> = Vec::new();
        for i in n..self.row {
            temp_data.extend(self.row(i));
        }
        cmatrix(temp_data, self.row - n, self.col, Shape::Row)
    }

    /// Drops the first `n` columns; the result is column-major.
    ///
    /// # Panics
    /// Panics if `n >= col`.
    fn skip_col(&self, n: usize) -> Self {
        assert!(n < self.col, "Skip range is larger than col of matrix");
        let mut temp_data: Vec<C64> = Vec::new();
        for i in n..self.col {
            temp_data.extend(self.col(i));
        }
        cmatrix(temp_data, self.row, self.col - n, Shape::Col)
    }

    /// Applies `f` to every entry, keeping dimensions and storage order.
    fn fmap<F>(&self, f: F) -> Self
    where
        F: Fn(C64) -> C64,
    {
        let result = self.data.iter().map(|x| f(*x)).collect::<Vec<C64>>();
        cmatrix(result, self.row, self.col, self.shape)
    }

    /// Replaces each column `c` with `f(c)`; the result is column-major.
    fn col_map<F>(&self, f: F) -> ComplexMatrix
    where
        F: Fn(Vec<C64>) -> Vec<C64>,
    {
        let mut result = cmatrix(
            vec![Complex::zero(); self.row * self.col],
            self.row,
            self.col,
            Shape::Col,
        );
        for i in 0..self.col {
            result.subs_col(i, &f(self.col(i)));
        }
        result
    }

    /// Replaces each row `r` with `f(r)`; the result is row-major.
    fn row_map<F>(&self, f: F) -> ComplexMatrix
    where
        F: Fn(Vec<C64>) -> Vec<C64>,
    {
        let mut result = cmatrix(
            vec![Complex::zero(); self.row * self.col],
            self.row,
            self.col,
            Shape::Row,
        );
        for i in 0..self.row {
            result.subs_row(i, &f(self.row(i)));
        }
        result
    }

    /// In-place `col_map`: overwrites every column with `f(column)`.
    fn col_mut_map<F>(&mut self, f: F)
    where
        F: Fn(Vec<C64>) -> Vec<C64>,
    {
        for i in 0..self.col {
            unsafe {
                let mut p = self.col_mut(i);
                let fv = f(self.col(i));
                for j in 0..p.len() {
                    *p[j] = fv[j];
                }
            }
        }
    }

    /// In-place `row_map`: overwrites every row with `f(row)`.
    fn row_mut_map<F>(&mut self, f: F)
    where
        F: Fn(Vec<C64>) -> Vec<C64>,
    {
        // Iterate over rows (not columns). Using `col` here panics in `row_mut`
        // when col > row and skips rows when col < row.
        for i in 0..self.row {
            unsafe {
                let mut p = self.row_mut(i);
                let fv = f(self.row(i));
                for j in 0..p.len() {
                    *p[j] = fv[j];
                }
            }
        }
    }

    /// Left fold over the buffer in storage order, starting from `init`.
    fn reduce<F, T>(&self, init: T, f: F) -> C64
    where
        F: Fn(C64, C64) -> C64,
        T: Into<C64>,
    {
        self.data.iter().fold(init.into(), |x, y| f(x, *y))
    }

    /// Element-wise combination `f(self_ij, other_ij)`. `other` is converted to
    /// `self`'s storage order first.
    ///
    /// # Panics
    /// Panics if the buffer lengths differ.
    fn zip_with<F>(&self, f: F, other: &ComplexMatrix) -> Self
    where
        F: Fn(C64, C64) -> C64,
    {
        assert_eq!(self.data.len(), other.data.len());
        let mut a = other.clone();
        if self.shape != other.shape {
            a = a.change_shape();
        }
        let result = self
            .data
            .iter()
            .zip(a.data.iter())
            .map(|(x, y)| f(*x, *y))
            .collect::<Vec<C64>>();
        cmatrix(result, self.row, self.col, self.shape)
    }

    /// Reduces each column to one value.
    fn col_reduce<F>(&self, f: F) -> Vec<C64>
    where
        F: Fn(Vec<C64>) -> C64,
    {
        (0..self.col).map(|i| f(self.col(i))).collect()
    }

    /// Reduces each row to one value.
    fn row_reduce<F>(&self, f: F) -> Vec<C64>
    where
        F: Fn(Vec<C64>) -> C64,
    {
        (0..self.row).map(|i| f(self.row(i))).collect()
    }
}
/// Row-major `n x n` identity matrix.
pub fn cdiag(n: usize) -> ComplexMatrix {
    let mut buffer: Vec<C64> = vec![Complex::zero(); n * n];
    // Diagonal entries sit every `n + 1` slots in the flat buffer.
    buffer
        .iter_mut()
        .step_by(n + 1)
        .for_each(|slot| *slot = Complex::one());
    cmatrix(buffer, n, n, Shape::Row)
}
impl PQLU<ComplexMatrix> {
pub fn extract(&self) -> (Vec<usize>, Vec<usize>, ComplexMatrix, ComplexMatrix) {
(
self.p.clone(),
self.q.clone(),
self.l.clone(),
self.u.clone(),
)
}
pub fn det(&self) -> C64 {
let mut sgn_p = 1f64;
let mut sgn_q = 1f64;
for (i, &j) in self.p.iter().enumerate() {
if i != j {
sgn_p *= -1f64;
}
}
for (i, &j) in self.q.iter().enumerate() {
if i != j {
sgn_q *= -1f64;
}
}
self.u.diag().reduce(Complex::one(), |x, y| x * y) * sgn_p * sgn_q
}
pub fn inv(&self) -> ComplexMatrix {
let (p, q, l, u) = self.extract();
let mut m = complex_inv_u(u) * complex_inv_l(l);
for (idx1, idx2) in q.into_iter().enumerate().rev() {
unsafe {
m.swap(idx1, idx2, Shape::Row);
}
}
for (idx1, idx2) in p.into_iter().enumerate().rev() {
unsafe {
m.swap(idx1, idx2, Shape::Col);
}
}
m
}
}
pub fn ceye(n: usize) -> ComplexMatrix {
let mut m = cmatrix(vec![Complex::zero(); n * n], n, n, Shape::Row);
for i in 0..n {
m[(i, i)] = Complex::one();
}
m
}
impl LinearAlgebra<ComplexMatrix> for ComplexMatrix {
    /// Solves `self * y = b` by back substitution. Only the upper triangle of
    /// `self` is read.
    ///
    /// # Panics
    /// Panics if `self.col == 0` or if `b` has fewer than `self.col` entries.
    fn back_subs(&self, b: &[C64]) -> Vec<C64> {
        let n = self.col;
        let mut y = vec![Complex::zero(); n];
        y[n - 1] = b[n - 1] / self[(n - 1, n - 1)];
        for i in (0..n - 1).rev() {
            let mut s = Complex::zero();
            for j in i + 1..n {
                s += self[(i, j)] * y[j];
            }
            y[i] = 1f64 / self[(i, i)] * (b[i] - s);
        }
        y
    }

    /// Solves `self * y = b` by forward substitution. Only the lower triangle
    /// of `self` is read.
    ///
    /// # Panics
    /// Panics if `self.col == 0` or if `b` has fewer than `self.col` entries.
    fn forward_subs(&self, b: &[C64]) -> Vec<C64> {
        let n = self.col;
        let mut y = vec![Complex::zero(); n];
        y[0] = b[0] / self[(0, 0)];
        for i in 1..n {
            let mut s = Complex::zero();
            for j in 0..i {
                s += self[(i, j)] * y[j];
            }
            y[i] = 1f64 / self[(i, i)] * (b[i] - s);
        }
        y
    }

    /// LU decomposition with complete pivoting (via `gecp`).
    ///
    /// `p`/`q` hold the row/column transposition targets chosen at each
    /// elimination step. `l` starts as the identity and receives the negated
    /// multipliers; `u` receives the upper triangle of the eliminated matrix.
    ///
    /// # Panics
    /// Panics if the matrix is not square.
    fn lu(&self) -> PQLU<ComplexMatrix> {
        assert_eq!(self.col, self.row);
        let n = self.row;
        let len: usize = n * n;
        let mut l = ceye(n);
        let mut u = cmatrix(vec![Complex::zero(); len], n, n, self.shape);
        let mut temp = self.clone();
        let (p, q) = gecp(&mut temp);
        for i in 0..n {
            // gecp stores negated multipliers below the diagonal.
            for j in 0..i {
                l[(i, j)] = -temp[(i, j)];
            }
            for j in i..n {
                u[(i, j)] = temp[(i, j)];
            }
        }
        // Re-apply the row transpositions to the stored multipliers (gecp only
        // swaps columns k..n of each row). NOTE(review): the inner range stops
        // at `l.col - 1`, so the last row is never swapped; confirm this is
        // intended.
        for i in 0..n - 1 {
            unsafe {
                let l_i = l.col_mut(i);
                for j in i + 1..l.col - 1 {
                    let dst = p[j];
                    std::ptr::swap(l_i[j], l_i[dst]);
                }
            }
        }
        PQLU { p, q, l, u }
    }

    fn waz(&self, _d_form: Form) -> Option<WAZD<ComplexMatrix>> {
        unimplemented!()
    }

    fn qr(&self) -> QR<ComplexMatrix> {
        unimplemented!()
    }

    fn svd(&self) -> SVD<ComplexMatrix> {
        unimplemented!()
    }

    // NOTE(review): `UPLO` is not imported in this file; this only compiles with
    // the `O3` feature if it is in scope elsewhere. Confirm.
    #[cfg(feature = "O3")]
    fn cholesky(&self, uplo: UPLO) -> ComplexMatrix {
        unimplemented!()
    }

    fn rref(&self) -> ComplexMatrix {
        unimplemented!()
    }

    /// Determinant via the LU decomposition.
    ///
    /// # Panics
    /// Panics if the matrix is not square.
    fn det(&self) -> C64 {
        assert_eq!(self.row, self.col);
        self.lu().det()
    }

    /// Splits into four blocks `(top-left, top-right, bottom-left,
    /// bottom-right)`. The top/left halves have `row / 2` rows and `col / 2`
    /// columns. Blocks keep `self`'s storage order.
    fn block(&self) -> (Self, Self, Self, Self) {
        let r = self.row;
        let c = self.col;
        let l_r = self.row / 2;
        let l_c = self.col / 2;
        let r_l = r - l_r;
        let c_l = c - l_c;

        let mut m1 = cmatrix(vec![Complex::zero(); l_r * l_c], l_r, l_c, self.shape);
        let mut m2 = cmatrix(vec![Complex::zero(); l_r * c_l], l_r, c_l, self.shape);
        let mut m3 = cmatrix(vec![Complex::zero(); r_l * l_c], r_l, l_c, self.shape);
        let mut m4 = cmatrix(vec![Complex::zero(); r_l * c_l], r_l, c_l, self.shape);

        for idx_row in 0..r {
            for idx_col in 0..c {
                match (idx_row, idx_col) {
                    (i, j) if (i < l_r) && (j < l_c) => {
                        m1[(i, j)] = self[(i, j)];
                    }
                    (i, j) if (i < l_r) && (j >= l_c) => {
                        m2[(i, j - l_c)] = self[(i, j)];
                    }
                    (i, j) if (i >= l_r) && (j < l_c) => {
                        m3[(i - l_r, j)] = self[(i, j)];
                    }
                    (i, j) if (i >= l_r) && (j >= l_c) => {
                        m4[(i - l_r, j - l_c)] = self[(i, j)];
                    }
                    _ => (),
                }
            }
        }
        (m1, m2, m3, m4)
    }

    /// Inverse via the LU decomposition.
    fn inv(&self) -> Self {
        self.lu().inv()
    }

    fn pseudo_inv(&self) -> ComplexMatrix {
        unimplemented!()
    }

    /// Solves `self * x = b`.
    ///
    /// With LU: permute `b` by `p`, forward/back substitute, then undo the
    /// column transpositions in `q` last-recorded first.
    fn solve(&self, b: &[C64], sk: SolveKind) -> Vec<C64> {
        match sk {
            SolveKind::LU => {
                let lu = self.lu();
                let (p, q, l, u) = lu.extract();
                let mut v = b.to_vec();
                v.swap_with_perm(&p.into_iter().enumerate().collect::<Vec<_>>());
                let z = l.forward_subs(&v);
                let mut y = u.back_subs(&z);
                y.swap_with_perm(&q.into_iter().enumerate().rev().collect::<Vec<_>>());
                y
            }
            SolveKind::WAZ => {
                unimplemented!()
            }
        }
    }

    /// Solves `self * X = m` column by column. The result is column-major.
    fn solve_mat(&self, m: &ComplexMatrix, sk: SolveKind) -> ComplexMatrix {
        match sk {
            SolveKind::LU => {
                let lu = self.lu();
                let (p, q, l, u) = lu.extract();
                let mut x = cmatrix(
                    vec![Complex::zero(); self.col * m.col],
                    self.col,
                    m.col,
                    Shape::Col,
                );
                for i in 0..m.col {
                    let mut v = m.col(i);
                    // Row transpositions are applied in the order they were recorded.
                    for (r, &s) in p.iter().enumerate() {
                        v.swap(r, s);
                    }
                    let z = l.forward_subs(&v);
                    let mut y = u.back_subs(&z);
                    // Column transpositions must be undone last-recorded first,
                    // matching `solve` and `PQLU::inv`.
                    for (r, &s) in q.iter().enumerate().rev() {
                        y.swap(r, s);
                    }
                    unsafe {
                        let mut c = x.col_mut(i);
                        copy_vec_ptr(&mut c, &y);
                    }
                }
                x
            }
            SolveKind::WAZ => {
                unimplemented!()
            }
        }
    }

    /// True when the matrix is square and `a_ij ≈ a_ji` for all i, j. This is
    /// plain symmetry, not Hermitian: no conjugation is applied.
    fn is_symmetric(&self) -> bool {
        if self.row != self.col {
            return false;
        }

        for i in 0..self.row {
            for j in i..self.col {
                let a = self[(i, j)];
                let b = self[(j, i)];
                // Both the real and imaginary parts must match.
                if !(nearly_eq(a.re, b.re) && nearly_eq(a.im, b.im)) {
                    return false;
                }
            }
        }
        true
    }
}
#[allow(non_snake_case)]
pub fn csolve(A: &ComplexMatrix, b: &ComplexMatrix, sk: SolveKind) -> ComplexMatrix {
A.solve_mat(b, sk)
}
impl MutMatrix for ComplexMatrix {
    type Scalar = C64;

    /// Raw pointers to every entry of column `idx`, top to bottom.
    ///
    /// # Safety
    /// The pointers alias `self.data` and are only valid while the buffer is
    /// neither moved nor reallocated.
    unsafe fn col_mut(&mut self, idx: usize) -> Vec<*mut C64> {
        assert!(idx < self.col, "Index out of range");
        let rows = self.row;
        let cols = self.col;
        let base = self.mut_ptr();
        match self.shape {
            Shape::Col => (0..rows).map(|i| base.add(idx * rows + i)).collect(),
            Shape::Row => (0..rows).map(|i| base.add(idx + i * cols)).collect(),
        }
    }

    /// Raw pointers to every entry of row `idx`, left to right.
    ///
    /// # Safety
    /// The pointers alias `self.data` and are only valid while the buffer is
    /// neither moved nor reallocated.
    unsafe fn row_mut(&mut self, idx: usize) -> Vec<*mut C64> {
        assert!(idx < self.row, "Index out of range");
        let rows = self.row;
        let cols = self.col;
        let base = self.mut_ptr();
        match self.shape {
            Shape::Row => (0..cols).map(|j| base.add(idx * cols + j)).collect(),
            Shape::Col => (0..cols).map(|j| base.add(idx + j * rows)).collect(),
        }
    }

    /// Swaps two columns (`Shape::Col`) or two rows (`Shape::Row`).
    unsafe fn swap(&mut self, idx1: usize, idx2: usize, shape: Shape) {
        match shape {
            Shape::Row => swap_vec_ptr(&mut self.row_mut(idx1), &mut self.row_mut(idx2)),
            Shape::Col => swap_vec_ptr(&mut self.col_mut(idx1), &mut self.col_mut(idx2)),
        }
    }

    /// Applies the `(a, b)` swaps in the given order.
    unsafe fn swap_with_perm(&mut self, p: &[(usize, usize)], shape: Shape) {
        for &(a, b) in p {
            self.swap(a, b, shape)
        }
    }
}
impl ExpLogOps for ComplexMatrix {
    type Float = C64;

    /// Element-wise complex exponential.
    fn exp(&self) -> Self {
        self.fmap(|x| x.exp())
    }

    /// Element-wise principal natural logarithm.
    fn ln(&self) -> Self {
        self.fmap(|x| x.ln())
    }

    /// Element-wise logarithm to a complex `base`: `ln(x) / ln(base)`.
    fn log(&self, base: Self::Float) -> Self {
        // Loop-invariant: compute ln(base) once, not once per entry.
        let ln_base = base.ln();
        self.fmap(|x| x.ln() / ln_base)
    }

    /// Element-wise base-2 logarithm.
    fn log2(&self) -> Self {
        self.fmap(|x| x.ln() / std::f64::consts::LN_2)
    }

    /// Element-wise base-10 logarithm.
    fn log10(&self) -> Self {
        self.fmap(|x| x.ln() / std::f64::consts::LN_10)
    }
}
impl PowOps for ComplexMatrix {
    type Float = C64;

    /// Raises every entry to the integer power `n`.
    fn powi(&self, n: i32) -> Self {
        self.fmap(move |z| z.powi(n))
    }

    /// Raises every entry to the complex power `f` (`Complex::powc`).
    fn powf(&self, f: Self::Float) -> Self {
        let exponent = f;
        self.fmap(move |z| z.powc(exponent))
    }

    /// Not implemented.
    fn pow(&self, _f: Self) -> Self {
        unimplemented!()
    }

    /// Element-wise principal square root.
    fn sqrt(&self) -> Self {
        self.fmap(|z| z.sqrt())
    }
}
impl TrigOps for ComplexMatrix {
    /// Element-wise `(sin, cos)`, both keeping dimensions and storage order.
    fn sin_cos(&self) -> (Self, Self) {
        let sines = self.fmap(|z| z.sin());
        let cosines = self.fmap(|z| z.cos());
        (sines, cosines)
    }

    /// Element-wise sine.
    fn sin(&self) -> Self {
        self.fmap(|z| z.sin())
    }

    /// Element-wise cosine.
    fn cos(&self) -> Self {
        self.fmap(|z| z.cos())
    }

    /// Element-wise tangent.
    fn tan(&self) -> Self {
        self.fmap(|z| z.tan())
    }

    /// Element-wise hyperbolic sine.
    fn sinh(&self) -> Self {
        self.fmap(|z| z.sinh())
    }

    /// Element-wise hyperbolic cosine.
    fn cosh(&self) -> Self {
        self.fmap(|z| z.cosh())
    }

    /// Element-wise hyperbolic tangent.
    fn tanh(&self) -> Self {
        self.fmap(|z| z.tanh())
    }

    /// Element-wise arcsine.
    fn asin(&self) -> Self {
        self.fmap(|z| z.asin())
    }

    /// Element-wise arccosine.
    fn acos(&self) -> Self {
        self.fmap(|z| z.acos())
    }

    /// Element-wise arctangent.
    fn atan(&self) -> Self {
        self.fmap(|z| z.atan())
    }

    /// Element-wise inverse hyperbolic sine.
    fn asinh(&self) -> Self {
        self.fmap(|z| z.asinh())
    }

    /// Element-wise inverse hyperbolic cosine.
    fn acosh(&self) -> Self {
        self.fmap(|z| z.acosh())
    }

    /// Element-wise inverse hyperbolic tangent.
    fn atanh(&self) -> Self {
        self.fmap(|z| z.atanh())
    }
}
/// Assembles `[m1 m2; m3 m4]` into a single matrix stored in `m1`'s order.
///
/// The row split comes from `m1.row`/`m3.row` and the column split from
/// `m1.col`/`m2.col`.
pub fn complex_combine(
    m1: ComplexMatrix,
    m2: ComplexMatrix,
    m3: ComplexMatrix,
    m4: ComplexMatrix,
) -> ComplexMatrix {
    let top = m1.row;
    let left = m1.col;
    let rows = top + m3.row;
    let cols = left + m2.col;
    let mut combined = cmatrix(vec![Complex::zero(); rows * cols], rows, cols, m1.shape);
    for i in 0..rows {
        for j in 0..cols {
            combined[(i, j)] = match (i < top, j < left) {
                (true, true) => m1[(i, j)],
                (true, false) => m2[(i, j - left)],
                (false, true) => m3[(i - top, j)],
                (false, false) => m4[(i - top, j - left)],
            };
        }
    }
    combined
}
/// Recursive block inverse of a lower-triangular matrix.
///
/// The 1x1 case returns `l` unchanged and the 2x2 case only negates the
/// sub-diagonal entry, so both assume a unit diagonal (as produced by `lu`).
pub fn complex_inv_l(l: ComplexMatrix) -> ComplexMatrix {
    if l.row == 1 {
        return l;
    }
    if l.row == 2 {
        let mut inverse = l;
        inverse[(1, 0)] = -inverse[(1, 0)];
        return inverse;
    }
    // [L1 0; L3 L4]^-1 = [L1^-1 0; -L4^-1 L3 L1^-1  L4^-1]
    let (l1, l2, l3, l4) = l.block();
    let top_left = complex_inv_l(l1);
    let bottom_right = complex_inv_l(l4);
    let bottom_left = -(&(&bottom_right * &l3) * &top_left);
    complex_combine(top_left, l2, bottom_left, bottom_right)
}
/// Recursive block inverse of an upper-triangular matrix.
pub fn complex_inv_u(u: ComplexMatrix) -> ComplexMatrix {
    if u.row == 1 {
        let mut inverse = u;
        inverse[(0, 0)] = 1f64 / inverse[(0, 0)];
        return inverse;
    }
    if u.row == 2 {
        let mut inverse = u;
        let a = inverse[(0, 0)];
        let b = inverse[(0, 1)];
        let c = inverse[(1, 1)];
        // [a b; 0 c]^-1 = [1/a  -b/(ac); 0  1/c]
        let d = a * c;
        inverse[(0, 0)] = 1f64 / a;
        inverse[(0, 1)] = -b / d;
        inverse[(1, 1)] = 1f64 / c;
        return inverse;
    }
    // [U1 U2; 0 U4]^-1 = [U1^-1  -U1^-1 U2 U4^-1; 0  U4^-1]
    let (u1, u2, u3, u4) = u.block();
    let top_left = complex_inv_u(u1);
    let bottom_right = complex_inv_u(u4);
    let top_right = -(top_left.clone() * u2 * bottom_right.clone());
    complex_combine(top_left, top_right, u3, bottom_right)
}
/// Matrix product `a * b` via `cgemm`. The result uses `a`'s storage order.
///
/// # Panics
/// Panics if `a.col != b.row`.
pub fn cmatmul(a: &ComplexMatrix, b: &ComplexMatrix) -> ComplexMatrix {
    assert_eq!(a.col, b.row);
    let (rows, cols) = (a.row, b.col);
    let mut product = cmatrix(vec![Complex::zero(); rows * cols], rows, cols, a.shape);
    cgemm(Complex::one(), a, b, Complex::zero(), &mut product);
    product
}
/// General complex matrix multiply: `c = alpha * a * b + beta * c`, using
/// `matrixmultiply::zgemm` with strides derived from each matrix's storage order.
///
/// # Panics
/// Panics if the dimensions are inconsistent (`a.col != b.row`, `c` not
/// `a.row x b.col`) or if any buffer length differs from `row * col`.
/// `zgemm` trusts its arguments, so unchecked inputs would read or write out
/// of bounds.
pub fn cgemm(alpha: C64, a: &ComplexMatrix, b: &ComplexMatrix, beta: C64, c: &mut ComplexMatrix) {
    let m = a.row;
    let k = a.col;
    let n = b.col;
    assert_eq!(b.row, k, "cgemm: a.col must equal b.row");
    assert_eq!((c.row, c.col), (m, n), "cgemm: c must be a.row x b.col");
    assert_eq!(a.data.len(), m * k, "cgemm: a.data length must equal a.row * a.col");
    assert_eq!(b.data.len(), k * n, "cgemm: b.data length must equal b.row * b.col");
    assert_eq!(c.data.len(), m * n, "cgemm: c.data length must equal c.row * c.col");

    let (rsa, csa) = match a.shape {
        Shape::Row => (a.col as isize, 1isize),
        Shape::Col => (1isize, a.row as isize),
    };
    let (rsb, csb) = match b.shape {
        Shape::Row => (b.col as isize, 1isize),
        Shape::Col => (1isize, b.row as isize),
    };
    let (rsc, csc) = match c.shape {
        Shape::Row => (c.col as isize, 1isize),
        Shape::Col => (1isize, c.row as isize),
    };
    // SAFETY: dimensions and buffer lengths were validated above, so every
    // strided access stays inside its buffer. The pointers come from
    // `as_ptr`/`as_mut_ptr`, so they cover the whole buffers and are valid
    // even when a dimension is zero. `Complex<f64>` is `#[repr(C)]` `{re, im}`,
    // which matches zgemm's `[f64; 2]` element layout.
    unsafe {
        matrixmultiply::zgemm(
            CGemmOption::Standard,
            CGemmOption::Standard,
            m,
            k,
            n,
            [alpha.re, alpha.im],
            a.data.as_ptr() as *const _,
            rsa,
            csa,
            b.data.as_ptr() as *const _,
            rsb,
            csb,
            [beta.re, beta.im],
            c.data.as_mut_ptr() as *mut _,
            rsc,
            csc,
        )
    }
}
/// Complex matrix-vector multiply: `c = alpha * a * b + beta * c`, with `b`
/// and `c` treated as contiguous column vectors.
///
/// # Panics
/// Panics if `b.len() != a.col`, `c.len() != a.row`, or
/// `a.data.len() != a.row * a.col`. `zgemm` does not check any of these.
pub fn cgemv(alpha: C64, a: &ComplexMatrix, b: &[C64], beta: C64, c: &mut [C64]) {
    let m = a.row;
    let k = a.col;
    let n = 1usize;
    assert_eq!(b.len(), k, "cgemv: b.len() must equal a.col");
    assert_eq!(c.len(), m, "cgemv: c.len() must equal a.row");
    assert_eq!(a.data.len(), m * k, "cgemv: a.data length must equal a.row * a.col");

    let (rsa, csa) = match a.shape {
        Shape::Row => (a.col as isize, 1isize),
        Shape::Col => (1isize, a.row as isize),
    };
    // Vectors are contiguous; the column stride is irrelevant since n == 1.
    let (rsb, csb) = (1isize, 1isize);
    let (rsc, csc) = (1isize, 1isize);
    // SAFETY: lengths validated above; pointers span whole buffers; Complex<f64>
    // is #[repr(C)] {re, im}, matching zgemm's [f64; 2] elements.
    unsafe {
        matrixmultiply::zgemm(
            CGemmOption::Standard,
            CGemmOption::Standard,
            m,
            k,
            n,
            [alpha.re, alpha.im],
            a.data.as_ptr() as *const _,
            rsa,
            csa,
            b.as_ptr() as *const _,
            rsb,
            csb,
            [beta.re, beta.im],
            c.as_mut_ptr() as *mut _,
            rsc,
            csc,
        )
    }
}
/// Complex vector-matrix multiply: `c = alpha * a^T * b + beta * c`, with `a`
/// and `c` treated as contiguous row vectors.
///
/// # Panics
/// Panics if `a.len() != b.row`, `c.len() != b.col`, or
/// `b.data.len() != b.row * b.col`. `zgemm` does not check any of these.
pub fn complex_gevm(alpha: C64, a: &[C64], b: &ComplexMatrix, beta: C64, c: &mut [C64]) {
    let m = 1usize;
    let k = a.len();
    let n = b.col;
    assert_eq!(b.row, k, "complex_gevm: a.len() must equal b.row");
    assert_eq!(c.len(), n, "complex_gevm: c.len() must equal b.col");
    assert_eq!(b.data.len(), k * n, "complex_gevm: b.data length must equal b.row * b.col");

    // Vectors are contiguous; the row stride is irrelevant since m == 1.
    let (rsa, csa) = (1isize, 1isize);
    let (rsb, csb) = match b.shape {
        Shape::Row => (b.col as isize, 1isize),
        Shape::Col => (1isize, b.row as isize),
    };
    let (rsc, csc) = (1isize, 1isize);
    // SAFETY: lengths validated above; pointers span whole buffers; Complex<f64>
    // is #[repr(C)] {re, im}, matching zgemm's [f64; 2] elements.
    unsafe {
        matrixmultiply::zgemm(
            CGemmOption::Standard,
            CGemmOption::Standard,
            m,
            k,
            n,
            [alpha.re, alpha.im],
            a.as_ptr() as *const _,
            rsa,
            csa,
            b.data.as_ptr() as *const _,
            rsb,
            csb,
            [beta.re, beta.im],
            c.as_mut_ptr() as *mut _,
            rsc,
            csc,
        )
    }
}
/// In-place Gaussian elimination with partial (row) pivoting on a square
/// matrix.
///
/// At step `k` the row with the largest `|m[i][k]|` (for `i >= k`) is swapped
/// into row `k` (columns `k..` only), and its index is recorded. The negated
/// multipliers are stored below the diagonal.
///
/// Returns the pivot row chosen at each step. Currently unused.
#[allow(dead_code)]
fn gepp(m: &mut ComplexMatrix) -> Vec<usize> {
    let mut r = vec![0usize; m.col - 1];
    for k in 0..(m.col - 1) {
        let r_k = m
            .col(k)
            .into_iter()
            .skip(k)
            .enumerate()
            .max_by(|x1, x2| x1.1.norm().partial_cmp(&x2.1.norm()).unwrap())
            .unwrap()
            .0
            + k;
        r[k] = r_k;

        // Safe element swap (entries are Copy); no debug output.
        for j in k..m.col {
            let tmp = m[(k, j)];
            m[(k, j)] = m[(r_k, j)];
            m[(r_k, j)] = tmp;
        }

        for i in k + 1..m.col {
            m[(i, k)] = -m[(i, k)] / m[(k, k)];
        }
        for i in k + 1..m.col {
            for j in k + 1..m.col {
                let local_m = m[(i, k)] * m[(k, j)];
                m[(i, j)] += local_m;
            }
        }
    }
    r
}
/// In-place Gaussian elimination with complete pivoting on a square matrix.
///
/// At each step `k`, the entry with the largest modulus in the trailing block
/// `[k.., k..]` is found. Its row is swapped into row `k` (columns `k..` only)
/// and its column into column `k` (all rows). Then the rows below are
/// eliminated, and the negated multipliers are stored below the diagonal.
///
/// Returns `(r, s)`: the row and column transposition targets chosen at each
/// step.
///
/// # Panics
/// Panics if the matrix has zero columns (`n - 1` underflows), or if an entry
/// is NaN (`partial_cmp(..).unwrap()`).
fn gecp(m: &mut ComplexMatrix) -> (Vec<usize>, Vec<usize>) {
    let n = m.col;
    let mut r = vec![0usize; n - 1];
    let mut s = vec![0usize; n - 1];
    for k in 0..n - 1 {
        let (r_k, s_k) = match m.shape {
            Shape::Col => {
                // Default to the current diagonal position. If the trailing
                // block is all zeros, no swap happens, instead of swapping with
                // row/column 0 and disturbing already-eliminated entries.
                let mut row_ics = k;
                let mut col_ics = k;
                let mut max_val = 0f64;
                for i in k..n {
                    let c = m
                        .col(i)
                        .into_iter()
                        .skip(k)
                        .enumerate()
                        .max_by(|x1, x2| x1.1.norm().partial_cmp(&x2.1.norm()).unwrap())
                        .unwrap();
                    let c_ics = c.0 + k;
                    let c_val = c.1.norm();
                    if c_val > max_val {
                        row_ics = c_ics;
                        col_ics = i;
                        max_val = c_val;
                    }
                }
                (row_ics, col_ics)
            }
            Shape::Row => {
                let mut row_ics = k;
                let mut col_ics = k;
                let mut max_val = 0f64;
                for i in k..n {
                    let c = m
                        .row(i)
                        .into_iter()
                        .skip(k)
                        .enumerate()
                        .max_by(|x1, x2| x1.1.norm().partial_cmp(&x2.1.norm()).unwrap())
                        .unwrap();
                    let c_ics = c.0 + k;
                    let c_val = c.1.norm();
                    if c_val > max_val {
                        col_ics = c_ics;
                        row_ics = i;
                        max_val = c_val;
                    }
                }
                (row_ics, col_ics)
            }
        };
        r[k] = r_k;
        s[k] = s_k;

        // Row swap over columns k.. (earlier columns hold multipliers).
        for j in k..n {
            let tmp = m[(k, j)];
            m[(k, j)] = m[(r_k, j)];
            m[(r_k, j)] = tmp;
        }
        // Column swap over all rows.
        for i in 0..n {
            let tmp = m[(i, k)];
            m[(i, k)] = m[(i, s_k)];
            m[(i, s_k)] = tmp;
        }

        for i in k + 1..n {
            m[(i, k)] = -m[(i, k)] / m[(k, k)];
            for j in k + 1..n {
                let local_m = m[(i, k)] * m[(k, j)];
                m[(i, j)] += local_m;
            }
        }
    }
    (r, s)
}