use cblas;
use ndarray::{
Data,
ShapeBuilder,
};
use ndarray::prelude::*;
use std::slice;
use crate::{
Error,
GramSchmidt,
Result,
utils::{
as_slice_with_layout,
get_layout,
}
};
/// Pre-allocated state for a Gram-Schmidt factorization in which every column is
/// orthogonalized twice against the previously computed columns (see `compute`).
///
/// All buffers are allocated once in `from_shape`; `compute` overwrites them in place.
#[derive(Clone, Debug)]
pub struct Reorthogonalized {
/// Factor whose columns are normalized by `compute`; has the shape passed to `from_shape`.
q: Array2<f64>,
/// Coefficient matrix filled by `compute`; allocated as a clone of the zeroed `q`, so it has
/// the same shape and memory layout as `q`.
r: Array2<f64>,
/// Scratch vector with one entry per row of `q`, used for the second projection pass.
work_vector: Array1<f64>,
/// Memory layout reported by `get_layout` for `q` at construction time; input matrices
/// handed to `compute` must have the same layout.
memory_layout: cblas::Layout,
}
impl GramSchmidt for Reorthogonalized {
fn from_shape<T>(shape: T) -> Result<Self>
where T: ShapeBuilder<Dim = Ix2>,
{
let shape = shape.into_shape();
let q = Array2::zeros(shape);
let memory_layout = match get_layout(&q) {
Some(layout) => layout,
None => Err(Error::NonContiguous)?,
};
let r = q.clone();
let work_vector = Array1::zeros(q.dim().0);
Ok(Self {
q,
r,
work_vector,
memory_layout,
})
}
fn compute<S>(&mut self, a: &ArrayBase<S, Ix2>) -> Result<()>
where S: Data<Elem = f64>,
{
use cblas::Layout::*;
use Error::*;
assert_eq!(a.shape(), self.q.shape());
let (n_rows, n_cols) = self.q.dim();
let a_slice = match (self.memory_layout, as_slice_with_layout(&a)) {
(a, Some((_, b))) if a != b => Err(IncompatibleLayouts)?,
(_, Some((a_slice, _))) => a_slice,
(_, None) => Err(NonContiguous)?,
};
let (leading_dim, next_elem, next_col) = match self.memory_layout {
ColumnMajor => (n_rows as i32, 1, n_rows),
RowMajor => (n_cols as i32, n_cols as i32, 1),
};
for i in 0..n_cols {
self.q.column_mut(i).assign(&a.column(i));
let len = self.q.len();
let q_ptr = self.q.as_mut_ptr();
let q_matrix = unsafe {
slice::from_raw_parts(q_ptr, len)
};
let q_column = match self.memory_layout {
ColumnMajor => {
let offset = n_rows * i;
unsafe {
slice::from_raw_parts_mut(q_ptr.offset(offset as isize), len - offset)
}
},
RowMajor => {
let offset = i as isize;
unsafe {
slice::from_raw_parts_mut(q_ptr.offset(offset), len - i)
}
},
};
if i > 0 {
let a_column = &a_slice[next_col * i..];
let r_slice = self.r.as_slice_memory_order_mut().unwrap();
let r_column = &mut r_slice[next_col * i..];
let work_slice = self.work_vector.as_slice_memory_order_mut().unwrap();
unsafe {
cblas::dgemv(
self.memory_layout,
cblas::Transpose::Ordinary,
n_rows as i32,
i as i32,
1.0,
q_matrix,
leading_dim,
a_column,
next_elem,
0.0,
r_column,
next_elem
);
cblas::dgemv(
self.memory_layout,
cblas::Transpose::None,
n_rows as i32,
i as i32,
-1.0,
q_matrix,
leading_dim,
r_column,
next_elem,
1.0,
q_column,
next_elem,
);
cblas::dgemv(
self.memory_layout,
cblas::Transpose::Ordinary,
n_rows as i32,
i as i32,
1.0,
q_matrix,
leading_dim,
q_column,
next_elem,
0.0,
work_slice,
1
);
cblas::dgemv(
self.memory_layout,
cblas::Transpose::None,
n_rows as i32,
i as i32,
-1.0,
q_matrix,
leading_dim,
work_slice,
1,
1.0,
q_column,
next_elem,
);
cblas::daxpy(
n_rows as i32,
1.0,
work_slice,
1,
r_column,
next_elem,
);
}
};
let norm = unsafe {
cblas::dnrm2(n_rows as i32, q_column, next_elem)
};
let mut v = self.q.column_mut(i);
v /= norm;
self.r[(i,i)] = a.column(i).dot(&v);
}
Ok(())
}
fn q(&self) -> &Array2<f64> {
&self.q
}
fn r(&self) -> &Array2<f64> {
&self.r
}
}
pub fn cgs2<S>(a: &ArrayBase<S, Ix2>) -> Result<(Array<f64, Ix2>, Array<f64, Ix2>)>
where S: Data<Elem=f64>
{
let mut cgs2 = Reorthogonalized::from_matrix(a)?;
cgs2.compute(a)?;
Ok((cgs2.q().clone(), cgs2.r().clone()))
}
// Invokes the `generate_tests!` macro (defined elsewhere in the crate) for `Reorthogonalized`;
// compiled in test builds only.
// NOTE(review): `1e-13` is presumably the numerical tolerance the generated tests compare
// against — confirm against the macro definition.
#[cfg(test)]
generate_tests!(Reorthogonalized, 1e-13);