use ndarray::{Array1, ArrayView1};
use std::ops::Deref;
use std::sync::Arc;
/// Borrowed, read-only view of per-entry weights that may take any sign.
///
/// A `Copy` newtype around an [`ArrayView1<f64>`]; construction places no constraint on the
/// values. Use [`SignedWeightsView::as_psd`] to obtain a [`PsdWeightsView`] once every entry
/// is known to be nonnegative.
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct SignedWeightsView<'a>(ArrayView1<'a, f64>);

impl<'a> SignedWeightsView<'a> {
    /// Wraps an existing view without inspecting its values.
    #[inline]
    pub fn new(view: ArrayView1<'a, f64>) -> Self {
        Self(view)
    }

    /// Borrows an owned array as signed weights.
    #[inline]
    pub fn from_array(array: &'a Array1<f64>) -> Self {
        Self(array.view())
    }

    /// Borrows a slice as signed weights.
    #[inline]
    pub fn from_slice(slice: &'a [f64]) -> Self {
        Self(ArrayView1::from(slice))
    }

    /// Returns the underlying view (a cheap copy; `ArrayView1` is `Copy`).
    #[inline]
    pub fn view(&self) -> ArrayView1<'a, f64> {
        self.0
    }

    /// Number of weights.
    #[inline]
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` when there are no weights.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Returns the weights as a slice if the view is contiguous and in standard order,
    /// or `None` otherwise (e.g. for a strided view).
    ///
    /// The slice borrows the underlying data for `'a`, not this wrapper, so it remains valid
    /// after the wrapper is dropped. For example, `SignedWeightsView::from_array(&a).as_slice()`
    /// can be bound to a local.
    #[inline]
    pub fn as_slice(&self) -> Option<&'a [f64]> {
        // `ArrayView::to_slice` propagates the view lifetime `'a`. `ArrayBase::as_slice`
        // would tie the result to the short-lived borrow of `self`.
        self.0.to_slice()
    }

    /// Reinterprets these weights as a [`PsdWeightsView`] if every entry satisfies `w >= 0.0`.
    ///
    /// Returns `None` if any entry is negative or NaN (`NaN >= 0.0` is false). `-0.0` and
    /// `+inf` are accepted. In the worst case every entry is scanned.
    #[inline]
    pub fn as_psd(self) -> Option<PsdWeightsView<'a>> {
        if self.0.iter().all(|&w| w >= 0.0) {
            Some(PsdWeightsView(self.0))
        } else {
            None
        }
    }
}
/// Borrowed, read-only view of weights that are all nonnegative.
///
/// The checked constructors ([`PsdWeightsView::try_new`], [`PsdWeightsView::try_from_array`])
/// admit a view only when every entry satisfies `w >= 0.0`. That rejects negatives and NaN
/// and accepts `-0.0` and `+inf`.
///
/// NOTE(review): the `Psd` name presumably refers to the diagonal matrix `diag(w)`, which
/// is positive semidefinite exactly when all `w >= 0`. Confirm against callers.
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct PsdWeightsView<'a>(ArrayView1<'a, f64>);

impl<'a> PsdWeightsView<'a> {
    /// Wraps `view` after checking that every entry is `>= 0.0`.
    ///
    /// # Errors
    /// Returns an error string if any entry is negative or NaN.
    #[inline]
    pub fn try_new(view: ArrayView1<'a, f64>) -> Result<Self, String> {
        if view.iter().all(|&w| w >= 0.0) {
            Ok(Self(view))
        } else {
            Err("PsdWeights::try_new: weights must be nonneg (use SignedWeightsView for observed-Hessian assembly)".to_string())
        }
    }

    /// Borrows an owned array, with the same check as [`PsdWeightsView::try_new`].
    ///
    /// # Errors
    /// Returns an error string if any entry is negative or NaN.
    #[inline]
    pub fn try_from_array(array: &'a Array1<f64>) -> Result<Self, String> {
        Self::try_new(array.view())
    }

    /// Wraps `view` without checking its values.
    ///
    /// The caller is responsible for ensuring every entry is `>= 0.0`. This function does
    /// not validate, so code relying on this type's nonnegativity can be handed bad data.
    #[inline]
    pub fn from_view_unchecked(view: ArrayView1<'a, f64>) -> Self {
        Self(view)
    }

    /// Forgets the nonnegativity guarantee, yielding a [`SignedWeightsView`] over the same data.
    #[inline]
    pub fn as_signed(self) -> SignedWeightsView<'a> {
        SignedWeightsView(self.0)
    }

    /// Returns the underlying view (a cheap copy; `ArrayView1` is `Copy`).
    #[inline]
    pub fn view(&self) -> ArrayView1<'a, f64> {
        self.0
    }

    /// Number of weights.
    #[inline]
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` when there are no weights.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Returns the weights as a slice if the view is contiguous and in standard order,
    /// or `None` otherwise (e.g. for a strided view).
    ///
    /// The slice borrows the underlying data for `'a`, not this wrapper, so it remains
    /// valid after the wrapper is dropped.
    #[inline]
    pub fn as_slice(&self) -> Option<&'a [f64]> {
        // `ArrayView::to_slice` propagates `'a`. `ArrayBase::as_slice` would tie the
        // result to the borrow of `self`.
        self.0.to_slice()
    }
}
/// Reference-counted owner of signed weights.
///
/// `Clone` clones the inner `Arc`, so clones share the same `Array1<f64>` and the data is
/// not copied. Borrow the weights as a [`SignedWeightsView`] with
/// [`SignedWeightsArc::view_signed`], or read the array directly through `Deref`/`AsRef`.
#[derive(Clone, Debug)]
#[repr(transparent)]
pub struct SignedWeightsArc(Arc<Array1<f64>>);

impl SignedWeightsArc {
    /// Wraps an existing shared array without copying it.
    #[inline]
    pub fn from_arc(arc: Arc<Array1<f64>>) -> Self {
        Self(arc)
    }

    /// Takes ownership of `array` and places it behind a new `Arc`.
    #[inline]
    pub fn from_array(array: Array1<f64>) -> Self {
        Self(Arc::new(array))
    }

    /// Borrows the shared array as a [`SignedWeightsView`].
    #[inline]
    pub fn view_signed(&self) -> SignedWeightsView<'_> {
        SignedWeightsView::from_array(self.0.as_ref())
    }

    /// Returns the inner `Arc`, e.g. for `Arc::clone` or `Arc::ptr_eq`.
    #[inline]
    pub fn as_arc(&self) -> &Arc<Array1<f64>> {
        &self.0
    }
}

impl Deref for SignedWeightsArc {
    type Target = Array1<f64>;

    /// Dereferences to the shared array.
    #[inline]
    fn deref(&self) -> &Array1<f64> {
        self.0.as_ref()
    }
}

impl AsRef<Array1<f64>> for SignedWeightsArc {
    /// Borrows the shared array.
    #[inline]
    fn as_ref(&self) -> &Array1<f64> {
        self.0.as_ref()
    }
}
#[cfg(test)]
mod tests {
    // Explicit imports alongside the glob so the tests do not depend on which of the
    // parent's private imports the glob re-exposes.
    use super::*;
    use ndarray::{array, s, Array1};
    use std::sync::Arc;

    #[test]
    fn signed_view_from_slice_len_and_values() {
        let s = [1.0_f64, -2.0, 3.0];
        let w = SignedWeightsView::from_slice(&s);
        assert_eq!(w.len(), 3);
        assert!(!w.is_empty());
        assert_eq!(w.as_slice().unwrap(), &s);
    }

    #[test]
    fn signed_view_from_array_round_trips() {
        let a = array![5.0_f64, -1.0];
        let w = SignedWeightsView::from_array(&a);
        assert_eq!(w.len(), 2);
        assert_eq!(w.view()[0], 5.0);
        assert_eq!(w.view()[1], -1.0);
    }

    #[test]
    fn signed_view_empty_is_empty() {
        let s: [f64; 0] = [];
        let w = SignedWeightsView::from_slice(&s);
        assert_eq!(w.len(), 0);
        assert!(w.is_empty());
    }

    #[test]
    fn signed_view_as_slice_none_for_strided_view() {
        // Every other element: not contiguous, so no slice can be produced.
        let a = array![1.0_f64, 2.0, 3.0, 4.0];
        let w = SignedWeightsView::new(a.slice(s![..;2]));
        assert_eq!(w.len(), 2);
        assert!(w.as_slice().is_none());
    }

    #[test]
    fn signed_view_as_psd_succeeds_when_all_nonneg() {
        let a = array![0.0_f64, 1.0, 2.0];
        let w = SignedWeightsView::from_array(&a);
        assert!(w.as_psd().is_some());
    }

    #[test]
    fn signed_view_as_psd_fails_on_negative_entry() {
        let a = array![1.0_f64, -0.001, 2.0];
        let w = SignedWeightsView::from_array(&a);
        assert!(w.as_psd().is_none());
    }

    #[test]
    fn psd_try_new_ok_for_all_nonneg() {
        let a = array![0.0_f64, 1.0, 2.0];
        assert!(PsdWeightsView::try_new(a.view()).is_ok());
    }

    #[test]
    fn psd_try_new_ok_for_all_zeros() {
        let a = array![0.0_f64, 0.0];
        assert!(PsdWeightsView::try_new(a.view()).is_ok());
    }

    #[test]
    fn psd_try_new_ok_for_negative_zero() {
        // `-0.0 >= 0.0` holds under IEEE 754, so negative zero is admitted.
        let a = array![-0.0_f64, 1.0];
        assert!(PsdWeightsView::try_new(a.view()).is_ok());
    }

    #[test]
    fn psd_try_new_err_for_negative_entry() {
        let a = array![1.0_f64, -1e-10, 2.0];
        assert!(PsdWeightsView::try_new(a.view()).is_err());
    }

    #[test]
    fn psd_checks_reject_nan() {
        // `NaN >= 0.0` is false, so NaN must not slip through either checked path.
        let a = array![1.0_f64, f64::NAN];
        assert!(PsdWeightsView::try_new(a.view()).is_err());
        assert!(SignedWeightsView::from_array(&a).as_psd().is_none());
    }

    #[test]
    fn psd_try_new_error_mentions_signed_view() {
        let a = array![-1.0_f64];
        let err = PsdWeightsView::try_new(a.view())
            .err()
            .expect("a negative entry must be rejected");
        assert!(err.contains("SignedWeightsView"));
    }

    #[test]
    fn psd_try_from_array_round_trips() {
        let a = array![3.0_f64, 4.0];
        let psd = PsdWeightsView::try_from_array(&a).unwrap();
        assert_eq!(psd.len(), 2);
        assert_eq!(psd.view()[0], 3.0);
    }

    #[test]
    fn psd_as_slice_some_for_contiguous() {
        let a = array![1.0_f64, 2.0];
        let psd = PsdWeightsView::try_from_array(&a).unwrap();
        assert_eq!(psd.as_slice().unwrap(), &[1.0_f64, 2.0]);
    }

    #[test]
    fn psd_as_signed_preserves_values() {
        let a = array![7.0_f64, 8.0];
        let psd = PsdWeightsView::try_from_array(&a).unwrap();
        let signed = psd.as_signed();
        assert_eq!(signed.len(), 2);
        assert_eq!(signed.view()[1], 8.0);
    }

    #[test]
    fn signed_weights_arc_from_array_view_signed_len() {
        let w = SignedWeightsArc::from_array(array![1.0, 2.0, 3.0]);
        assert_eq!(w.view_signed().len(), 3);
    }

    #[test]
    fn signed_weights_arc_deref_gives_array() {
        let w = SignedWeightsArc::from_array(array![10.0_f64, 20.0]);
        assert_eq!((*w)[0], 10.0);
        assert_eq!((*w)[1], 20.0);
    }

    #[test]
    fn signed_weights_arc_from_arc_and_clone_share_storage() {
        let shared = Arc::new(array![1.0_f64, 2.0]);
        let w = SignedWeightsArc::from_arc(Arc::clone(&shared));
        assert!(Arc::ptr_eq(w.as_arc(), &shared));
        let w2 = w.clone();
        assert!(Arc::ptr_eq(w.as_arc(), w2.as_arc()));
        // `shared`, `w`, and `w2` each hold one strong reference.
        assert_eq!(Arc::strong_count(&shared), 3);
    }

    #[test]
    fn signed_weights_arc_as_ref_and_view_agree() {
        let w = SignedWeightsArc::from_array(array![4.0_f64, -5.0]);
        let r: &Array1<f64> = AsRef::<Array1<f64>>::as_ref(&w);
        assert_eq!(r[1], -5.0);
        assert_eq!(w.view_signed().view()[1], -5.0);
    }
}