use pounce_common::types::{Index, Number};
/// Interface for row-structured data that is set once and then applied to vectors.
///
/// The per-method notes describe the contract as exercised by `IndexSchurData`, the only
/// implementor visible in this file; other implementors may differ — confirm against them.
pub trait SchurData {
/// Returns the number of rows held by this data.
fn nrows(&self) -> Index;
/// Returns `true` once the data has been populated.
fn is_initialized(&self) -> bool;
/// Populates the data from a per-position flag slice and a scale `v`.
///
/// In `IndexSchurData`, a flag of 1 selects that position as a column, a flag of 0 skips
/// it, and only the sign of `v` is kept.
fn set_from_flags(&mut self, flags: &[Index], v: Number) -> Result<(), SchurDataError>;
/// Populates the data from an explicit list of column indices and a scale `v`.
///
/// In `IndexSchurData`, one row is created per entry of `cols`, in order, and only the
/// sign of `v` is kept.
fn set_from_list(&mut self, cols: &[Index], v: Number) -> Result<(), SchurDataError>;
/// Returns the column indices and matching factors that make up row `i`.
fn multiplying_row(&self, i: Index) -> Result<(Vec<Index>, Vec<Number>), SchurDataError>;
/// Applies the data to `v`, writing one result per row into `u`.
fn multiply(&self, v: &[Number], u: &mut [Number]) -> Result<(), SchurDataError>;
/// Applies the transpose to `u`, writing the result into `v`.
fn trans_multiply(&self, u: &[Number], v: &mut [Number]) -> Result<(), SchurDataError>;
}
/// Failure modes reported by [`SchurData`] operations and by
/// [`IndexSchurData::from_parts`].
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be propagated with
/// `?` into `Box<dyn std::error::Error>` and printed for users.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchurDataError {
    /// The data was read before it had been set.
    NotInitialized,
    /// A setter was called on data that had already been set. `set_from_flags` also
    /// reports this for a flag value other than 0 or 1.
    AlreadyInitialized,
    /// The requested row does not exist.
    RowOutOfRange,
    /// A slice length or a stored column index does not fit the other operand.
    DimensionMismatch,
    /// A scale of zero was supplied, or a stored sign was neither 1 nor -1.
    ZeroSign,
}

impl std::fmt::Display for SchurDataError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: a new variant must get a message before this compiles.
        let msg = match *self {
            SchurDataError::NotInitialized => "Schur data has not been initialized",
            SchurDataError::AlreadyInitialized => "Schur data is already initialized",
            SchurDataError::RowOutOfRange => "row index is out of range",
            SchurDataError::DimensionMismatch => "dimension mismatch",
            SchurDataError::ZeroSign => "sign value is zero or not +/-1",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for SchurDataError {}
/// Row-selection data in which every row has exactly one entry: row `k` carries the sign
/// `val[k]` in column `idx[k]`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct IndexSchurData {
// Column index of each row's single entry, in row order.
idx: Vec<Index>,
// Sign of each row's entry; every constructor and setter in this file stores only 1 or -1.
val: Vec<Index>,
// Set once the data has been populated; `Default` (and therefore `new`) leaves it `false`.
initialized: bool,
}
impl IndexSchurData {
    /// Creates an empty instance that still has to be populated through one of the
    /// `SchurData` setters.
    pub fn new() -> Self {
        Self {
            idx: Vec::new(),
            val: Vec::new(),
            initialized: false,
        }
    }

    /// Builds an already-initialized instance from parallel column-index and sign vectors.
    ///
    /// The column indices themselves are stored as given, without any range check.
    ///
    /// # Errors
    ///
    /// * [`SchurDataError::DimensionMismatch`] when `idx` and `val` differ in length.
    /// * [`SchurDataError::ZeroSign`] when any entry of `val` is neither `1` nor `-1`.
    pub fn from_parts(idx: Vec<Index>, val: Vec<Index>) -> Result<Self, SchurDataError> {
        if idx.len() != val.len() {
            return Err(SchurDataError::DimensionMismatch);
        }
        let every_sign_is_unit = val.iter().all(|&sign| matches!(sign, 1 | -1));
        if !every_sign_is_unit {
            return Err(SchurDataError::ZeroSign);
        }
        Ok(Self {
            idx,
            val,
            initialized: true,
        })
    }

    /// Column index of each row, in row order.
    pub fn col_indices(&self) -> &[Index] {
        self.idx.as_slice()
    }

    /// Sign of each row, in row order.
    pub fn signs(&self) -> &[Index] {
        self.val.as_slice()
    }
}
impl SchurData for IndexSchurData {
    /// Number of rows, i.e. the number of stored signs.
    fn nrows(&self) -> Index {
        self.val.len() as Index
    }

    /// Whether the data has been populated.
    fn is_initialized(&self) -> bool {
        self.initialized
    }

    /// Populates the data from a flag slice: every position whose flag is `1` becomes a
    /// row selecting that position as its column; positions flagged `0` are skipped.
    ///
    /// Only the sign of `v` is kept: `1` when `v > 0.0`, otherwise `-1` (this includes
    /// NaN, which is neither equal to nor greater than zero).
    ///
    /// A rejected call leaves `self` untouched, so it can be retried with valid input.
    ///
    /// # Errors
    ///
    /// * [`SchurDataError::AlreadyInitialized`] when the data is already populated.
    /// * [`SchurDataError::ZeroSign`] when `v == 0.0`.
    /// * [`SchurDataError::AlreadyInitialized`] when a flag is neither `0` nor `1`.
    fn set_from_flags(&mut self, flags: &[Index], v: Number) -> Result<(), SchurDataError> {
        if self.initialized {
            return Err(SchurDataError::AlreadyInitialized);
        }
        if v == 0.0 {
            return Err(SchurDataError::ZeroSign);
        }
        // Validate every flag before mutating: bailing out halfway would leave stale
        // entries behind that a later, successful initialization would silently inherit.
        // NOTE(review): the error variant for a bad flag is kept as-is for backward
        // compatibility; no dedicated variant exists — consider adding one.
        if flags.iter().any(|&flag| flag != 0 && flag != 1) {
            return Err(SchurDataError::AlreadyInitialized);
        }
        let w: Index = if v > 0.0 { 1 } else { -1 };
        // Assign rather than append so the two vectors always describe exactly this call.
        self.idx = flags
            .iter()
            .enumerate()
            .filter(|&(_, &flag)| flag == 1)
            .map(|(pos, _)| pos as Index)
            .collect();
        self.val = vec![w; self.idx.len()];
        self.initialized = true;
        Ok(())
    }

    /// Populates the data from an explicit column list: one row per entry of `cols`, in
    /// order (duplicates are kept). The column values are stored without a range check.
    ///
    /// Only the sign of `v` is kept: `1` when `v > 0.0`, otherwise `-1`.
    ///
    /// # Errors
    ///
    /// * [`SchurDataError::AlreadyInitialized`] when the data is already populated.
    /// * [`SchurDataError::ZeroSign`] when `v == 0.0`.
    fn set_from_list(&mut self, cols: &[Index], v: Number) -> Result<(), SchurDataError> {
        if self.initialized {
            return Err(SchurDataError::AlreadyInitialized);
        }
        if v == 0.0 {
            return Err(SchurDataError::ZeroSign);
        }
        let w: Index = if v > 0.0 { 1 } else { -1 };
        // Assign rather than extend/resize so `idx` and `val` can never differ in length.
        self.idx = cols.to_vec();
        self.val = vec![w; cols.len()];
        self.initialized = true;
        Ok(())
    }

    /// Returns row `i` as `(columns, factors)`; each row has exactly one entry.
    ///
    /// # Errors
    ///
    /// * [`SchurDataError::NotInitialized`] when the data has not been populated.
    /// * [`SchurDataError::RowOutOfRange`] when `i` does not name a stored row
    ///   (a negative `i` wraps to a huge `usize` and lands here as well).
    fn multiplying_row(&self, i: Index) -> Result<(Vec<Index>, Vec<Number>), SchurDataError> {
        if !self.initialized {
            return Err(SchurDataError::NotInitialized);
        }
        let row = i as usize;
        match (self.idx.get(row), self.val.get(row)) {
            (Some(&col), Some(&sign)) => Ok((vec![col], vec![sign as Number])),
            _ => Err(SchurDataError::RowOutOfRange),
        }
    }

    /// Gathers: `u[k] = sign[k] * v[col[k]]` for every row `k`.
    ///
    /// `u` is only written when the whole call succeeds.
    ///
    /// # Errors
    ///
    /// * [`SchurDataError::NotInitialized`] when the data has not been populated.
    /// * [`SchurDataError::DimensionMismatch`] when `u.len()` differs from the number of
    ///   rows, or a stored column index is not a valid position in `v`.
    fn multiply(&self, v: &[Number], u: &mut [Number]) -> Result<(), SchurDataError> {
        if !self.initialized {
            return Err(SchurDataError::NotInitialized);
        }
        if u.len() != self.idx.len() {
            return Err(SchurDataError::DimensionMismatch);
        }
        // Check all columns up front so a rejected call does not leave `u` half-written.
        if self.idx.iter().any(|&col| (col as usize) >= v.len()) {
            return Err(SchurDataError::DimensionMismatch);
        }
        for ((slot, &col), &sign) in u.iter_mut().zip(&self.idx).zip(&self.val) {
            *slot = (sign as Number) * v[col as usize];
        }
        Ok(())
    }

    /// Scatters: zeroes `v`, then adds `sign[k] * u[k]` into `v[col[k]]` for every row `k`
    /// (rows sharing a column accumulate).
    ///
    /// `v` is only written when the whole call succeeds.
    ///
    /// # Errors
    ///
    /// * [`SchurDataError::NotInitialized`] when the data has not been populated.
    /// * [`SchurDataError::DimensionMismatch`] when `u.len()` differs from the number of
    ///   rows, or a stored column index is not a valid position in `v`.
    fn trans_multiply(&self, u: &[Number], v: &mut [Number]) -> Result<(), SchurDataError> {
        if !self.initialized {
            return Err(SchurDataError::NotInitialized);
        }
        if u.len() != self.idx.len() {
            return Err(SchurDataError::DimensionMismatch);
        }
        // Check all columns before clearing `v` so a rejected call does not clobber it.
        if self.idx.iter().any(|&col| (col as usize) >= v.len()) {
            return Err(SchurDataError::DimensionMismatch);
        }
        v.fill(0.0);
        for ((&row_u, &col), &sign) in u.iter().zip(&self.idx).zip(&self.val) {
            v[col as usize] += (sign as Number) * row_u;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Flags `[0, 1, 0, 1]` with a positive scale select columns 1 and 3 with sign +1.
    #[test]
    fn set_from_flags_round_trip() {
        let mut data = IndexSchurData::new();
        data.set_from_flags(&[0, 1, 0, 1], 1.0).expect("init");
        assert!(data.is_initialized());
        assert_eq!(data.nrows(), 2);
        assert_eq!(data.col_indices(), &[1, 3]);
        assert_eq!(data.signs(), &[1, 1]);
    }

    /// Any negative scale is recorded as sign -1, whatever its magnitude.
    #[test]
    fn set_from_flags_negative_sign_records_minus_one() {
        let mut data = IndexSchurData::new();
        let flags = [1, 0, 1];
        data.set_from_flags(&flags, -2.5).expect("init");
        assert_eq!(data.signs(), &[-1, -1]);
    }

    /// A second initialization attempt is refused.
    #[test]
    fn set_from_flags_rejects_double_init() {
        let mut data = IndexSchurData::new();
        data.set_from_flags(&[1, 0], 1.0).expect("first init");
        let second = data.set_from_flags(&[0, 1], 1.0);
        assert_eq!(second, Err(SchurDataError::AlreadyInitialized));
    }

    /// A scale of exactly zero carries no sign and is refused.
    #[test]
    fn set_from_flags_rejects_zero_sign() {
        let mut data = IndexSchurData::new();
        let outcome = data.set_from_flags(&[1, 0, 1], 0.0);
        assert_eq!(outcome, Err(SchurDataError::ZeroSign));
    }

    /// The column list is stored verbatim, one row per entry, in the given order.
    #[test]
    fn set_from_list_records_each_column_once() {
        let mut data = IndexSchurData::new();
        let cols = [2, 0, 4];
        data.set_from_list(&cols, 1.0).expect("init");
        assert_eq!(data.col_indices(), &[2, 0, 4]);
        assert_eq!(data.signs(), &[1, 1, 1]);
        assert_eq!(data.nrows(), 3);
    }

    /// `from_parts` accepts ±1 signs and rejects length mismatches and other sign values.
    #[test]
    fn from_parts_validates_signs() {
        let accepted = IndexSchurData::from_parts(vec![0, 2], vec![1, -1]).expect("ok");
        assert_eq!(accepted.signs(), &[1, -1]);

        let short_signs = IndexSchurData::from_parts(vec![0, 2], vec![1]);
        assert_eq!(short_signs, Err(SchurDataError::DimensionMismatch));

        let bad_sign = IndexSchurData::from_parts(vec![0], vec![2]);
        assert_eq!(bad_sign, Err(SchurDataError::ZeroSign));
    }

    /// `multiply` gathers the selected entries of the input and applies each row's sign.
    #[test]
    fn multiply_picks_selected_columns_with_signs() {
        let data = IndexSchurData::from_parts(vec![1, 3], vec![1, -1]).expect("ok");
        let input = [10.0, 20.0, 30.0, 40.0];
        let mut output = [0.0; 2];
        data.multiply(&input, &mut output).expect("ok");
        assert_eq!(output, [20.0, -40.0]);
    }

    /// `trans_multiply` scatters the signed row values into the selected positions.
    #[test]
    fn trans_multiply_scatters_with_signs() {
        let data = IndexSchurData::from_parts(vec![1, 3], vec![1, -1]).expect("ok");
        let rows = [3.0, 5.0];
        let mut scattered = [0.0; 4];
        data.trans_multiply(&rows, &mut scattered).expect("ok");
        assert_eq!(scattered, [0.0, 3.0, 0.0, -5.0]);
    }

    /// Whatever the caller left in the output buffer is replaced, not accumulated into.
    #[test]
    fn trans_multiply_overwrites_caller_buffer() {
        let data = IndexSchurData::from_parts(vec![0, 2], vec![1, 1]).expect("ok");
        let rows = [1.0, 2.0];
        let mut scattered = [99.0; 4];
        data.trans_multiply(&rows, &mut scattered).expect("ok");
        assert_eq!(scattered, [1.0, 0.0, 2.0, 0.0]);
    }

    /// Using the data before it has been set is an error.
    #[test]
    fn multiply_rejects_uninitialized() {
        let data = IndexSchurData::new();
        let input = [0.0];
        let mut output = [0.0];
        let outcome = data.multiply(&input, &mut output);
        assert_eq!(outcome, Err(SchurDataError::NotInitialized));
    }

    /// Asking for a row past the end is reported rather than panicking.
    #[test]
    fn multiplying_row_out_of_range() {
        let data = IndexSchurData::from_parts(vec![0], vec![1]).expect("ok");
        let outcome = data.multiplying_row(2);
        assert_eq!(outcome, Err(SchurDataError::RowOutOfRange));
    }

    /// Every row of an `IndexSchurData` consists of exactly one column/factor pair.
    #[test]
    fn multiplying_row_returns_single_entry_for_index_schur_data() {
        let data = IndexSchurData::from_parts(vec![5, 7], vec![1, -1]).expect("ok");

        let (first_cols, first_factors) = data.multiplying_row(0).expect("ok");
        assert_eq!(first_cols, &[5]);
        assert_eq!(first_factors, &[1.0]);

        let (second_cols, second_factors) = data.multiplying_row(1).expect("ok");
        assert_eq!(second_cols, &[7]);
        assert_eq!(second_factors, &[-1.0]);
    }
}