use crate::error::{Result, TransformError};
use scirs2_core::ndarray::{Array2, ArrayBase, Axis, Data, Ix2};
use scirs2_core::numeric::{Float, NumCast};
use scirs2_core::random::{Rng, RngExt};
/// Small positive constant guarding divisions by (near-)zero and keeping
/// randomly initialized simplex entries strictly positive.
const EPS: f64 = 1e-10;
/// Fitted archetypal-analysis model: the data X (n_samples x n_features)
/// is approximated by `A^T Z`, where the archetypes are `Z = B^T X`.
#[derive(Debug, Clone)]
pub struct ArchetypalModel {
/// Archetype matrix Z, shape (n_archetypes x n_features).
pub archetypes: Array2<f64>,
/// Per-sample mixture weights A, shape (n_archetypes x n_samples);
/// each column is intended to lie on the probability simplex.
pub a: Array2<f64>,
/// Archetype-construction weights B, shape (n_samples x n_archetypes);
/// each column is intended to lie on the probability simplex.
pub b: Array2<f64>,
/// Frobenius-norm reconstruction error ||X - A^T Z||_F at termination.
pub reconstruction_error: f64,
/// Number of outer iterations actually performed by `fit`.
pub n_iter: usize,
}
/// Builder-style configuration for archetypal-analysis fitting.
#[derive(Debug, Clone)]
pub struct ArchetypalAnalysis {
/// Number of archetypes k; `fit` requires 1 <= k <= n_samples.
pub n_archetypes: usize,
/// Maximum number of outer alternating iterations (default 300).
pub max_iter: usize,
/// Relative-improvement threshold for convergence (default 1e-5).
pub tol: f64,
/// Frank-Wolfe inner steps per A/B update (default 20).
pub n_inner: usize,
/// Optional seed forwarded to the random initialization routine.
pub seed: Option<u64>,
}
impl ArchetypalAnalysis {
    /// Creates a configuration for `n_archetypes` archetypes with defaults:
    /// `max_iter = 300`, `tol = 1e-5`, `n_inner = 20`, no seed.
    pub fn new(n_archetypes: usize) -> Self {
        Self {
            n_archetypes,
            max_iter: 300,
            tol: 1e-5,
            n_inner: 20,
            seed: None,
        }
    }

    /// Sets the maximum number of outer alternating iterations.
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Sets the relative-improvement convergence tolerance.
    pub fn with_tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Sets the number of Frank-Wolfe inner steps per coefficient update.
    pub fn with_n_inner(mut self, n_inner: usize) -> Self {
        self.n_inner = n_inner;
        self
    }

    /// Sets the seed forwarded to the random initialization.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Fits the model by alternating Frank-Wolfe updates of the mixture
    /// matrix A (k x n) and construction matrix B (n x k), so that
    /// `X ≈ A^T (B^T X)`.
    ///
    /// Convergence is declared when the relative change of the Frobenius
    /// reconstruction error between consecutive outer iterations drops
    /// below `tol`.
    ///
    /// # Errors
    /// Returns `TransformError::InvalidInput` when `n_archetypes == 0` or
    /// `n_archetypes > n_samples`.
    pub fn fit<S>(&self, x_raw: &ArrayBase<S, Ix2>) -> Result<ArchetypalModel>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let x = to_f64(x_raw)?;
        let n = x.nrows();
        let k = self.n_archetypes;
        if k == 0 {
            return Err(TransformError::InvalidInput(
                "n_archetypes must be ≥ 1".to_string(),
            ));
        }
        if k > n {
            return Err(TransformError::InvalidInput(format!(
                "n_archetypes={k} must be ≤ n_samples={n}"
            )));
        }
        // Random column-stochastic initialization. wrapping_add avoids a
        // (theoretical) overflow panic when seed == u64::MAX.
        let mut b = uniform_simplex_cols(n, k, self.seed);
        let mut a = uniform_simplex_cols(k, n, self.seed.map(|s| s.wrapping_add(1)));
        let mut prev_err = frob_error_za(&x, &b.t().dot(&x), &a);
        let mut n_iter = 0;
        for iter in 0..self.max_iter {
            n_iter = iter + 1;
            // Recompute Z = B^T X from the CURRENT B each outer iteration so
            // the A-update tracks the evolving archetypes (the old code kept
            // passing the stale initial Z to fw_update_a forever).
            let z = b.t().dot(&x);
            a = fw_update_a(&x, &z, &a, self.n_inner);
            b = fw_update_b(&x, &z, &a, &b, self.n_inner);
            let z_new = b.t().dot(&x);
            let err = frob_error_za(&x, &z_new, &a);
            // Relative-improvement stopping rule; skipped at iteration 0 so
            // prev_err reflects at least one full update pair.
            if iter > 0 && (prev_err - err).abs() / prev_err.max(EPS) < self.tol {
                // z_new is exactly the converged archetype matrix; no need to
                // recompute B^T X a second time as the old code did.
                return Ok(ArchetypalModel {
                    archetypes: z_new,
                    a,
                    b,
                    reconstruction_error: err,
                    n_iter,
                });
            }
            prev_err = err;
        }
        let z_final = b.t().dot(&x);
        let final_err = frob_error_za(&x, &z_final, &a);
        Ok(ArchetypalModel {
            archetypes: z_final,
            a,
            b,
            reconstruction_error: final_err,
            n_iter,
        })
    }
}
/// One batch of Frank-Wolfe steps on the mixture matrix A (k x n), holding
/// the archetypes Z (k x p) fixed.
///
/// Minimizes 0.5 * ||X^T - Z^T A||_F^2 subject to each column of A lying on
/// the probability simplex. Returns the updated (and re-projected) A.
fn fw_update_a(x: &Array2<f64>, z: &Array2<f64>, a: &Array2<f64>, n_inner: usize) -> Array2<f64> {
    let n = x.nrows();
    let k = z.nrows();
    let mut a_new = a.clone();
    // Hoist loop invariants: Z^T (p x k) and X^T (p x n) never change across
    // the inner steps (the old code re-allocated X^T on every step).
    let zt = z.t().to_owned();
    let xt = x.t().to_owned();
    for step in 0..n_inner {
        // Gradient wrt A: Z (Z^T A - X^T), shape (k x n).
        let za = zt.dot(&a_new);
        let residual = &za - &xt;
        let grad = z.dot(&residual);
        // Linear minimization oracle over each column simplex: the vertex
        // e_j with the most negative gradient entry for sample i.
        let mut a_vertex = Array2::<f64>::zeros((k, n));
        for i in 0..n {
            let mut best_j = 0;
            let mut best_val = grad[[0, i]];
            for j in 1..k {
                if grad[[j, i]] < best_val {
                    best_val = grad[[j, i]];
                    best_j = j;
                }
            }
            a_vertex[[best_j, i]] = 1.0;
        }
        // Classic Frank-Wolfe step size 2 / (t + 2).
        let gamma = 2.0 / (step as f64 + 2.0);
        a_new = (1.0 - gamma) * &a_new + gamma * &a_vertex;
    }
    // Absorb any accumulated floating-point drift off the simplex.
    project_simplex_cols(&a_new)
}
/// One batch of Frank-Wolfe steps on the construction matrix B (n x k),
/// with the archetypes recomputed as Z = B^T X inside each inner step.
///
/// Minimizes 0.5 * ||X - A^T B^T X||_F^2 subject to each column of B lying
/// on the probability simplex. `_z` is kept for signature compatibility;
/// Z is derived from the current B instead.
fn fw_update_b(
    x: &Array2<f64>,
    _z: &Array2<f64>,
    a: &Array2<f64>,
    b: &Array2<f64>,
    n_inner: usize,
) -> Array2<f64> {
    let n = x.nrows();
    let k = b.ncols();
    let mut b_new = b.clone();
    // A^T (n x k) is invariant across inner steps.
    let at = a.t().to_owned();
    for step in 0..n_inner {
        // Current archetypes induced by B: Z = B^T X, shape (k x p).
        let z_cur = b_new.t().dot(x);
        // Reconstruction residual A^T Z - X, shape (n x p).
        let residual = &at.dot(&z_cur) - x;
        // Chain rule through Z = B^T X:
        //   dL/dZ = A (A^T Z - X)   (k x p)
        //   dL/dB = X (dL/dZ)^T     (n x k)
        // The old code additionally evaluated x.t().dot(a_res.t()), a dead
        // expression with shapes (p x n)·(p x k) that panics whenever n != p.
        let dl_dz = a.dot(&residual);
        let grad = x.dot(&dl_dz.t());
        // Linear minimization oracle over each column simplex of B.
        let mut b_vertex = Array2::<f64>::zeros((n, k));
        for j in 0..k {
            let mut best_i = 0;
            let mut best_val = grad[[0, j]];
            for i in 1..n {
                if grad[[i, j]] < best_val {
                    best_val = grad[[i, j]];
                    best_i = i;
                }
            }
            b_vertex[[best_i, j]] = 1.0;
        }
        let gamma = 2.0 / (step as f64 + 2.0);
        b_new = (1.0 - gamma) * &b_new + gamma * &b_vertex;
    }
    // Project the COLUMNS of B onto the simplex — the constraint the model
    // (and archetypal_simplex) actually enforces. The old code transposed
    // first, projecting the rows and destroying the column-stochastic
    // property the FW iterates had maintained.
    project_simplex_cols(&b_new)
}
pub fn archetypal_error(x: &Array2<f64>, model: &ArchetypalModel) -> f64 {
frob_error_za(x, &model.archetypes, &model.a)
}
/// Verifies the convexity constraints of a fitted model within a loose
/// tolerance: all entries of A and B (approximately) non-negative, and every
/// column of A and of B summing to one.
pub fn archetypal_simplex(model: &ArchetypalModel) -> bool {
    const TOL: f64 = 1e-4;
    // Negated strict comparisons are used throughout so that NaN entries do
    // not trip any individual check (a NaN compares false either way).
    let entries_nonnegative = model
        .a
        .iter()
        .chain(model.b.iter())
        .all(|&v| !(v < -TOL));
    if !entries_nonnegative {
        return false;
    }
    let (k, n) = model.a.dim();
    let a_columns_stochastic = (0..n).all(|i| {
        let total: f64 = (0..k).map(|j| model.a[[j, i]]).sum();
        !((total - 1.0).abs() > TOL)
    });
    if !a_columns_stochastic {
        return false;
    }
    let (n_b, k_b) = model.b.dim();
    (0..k_b).all(|j| {
        let total: f64 = (0..n_b).map(|i| model.b[[i, j]]).sum();
        !((total - 1.0).abs() > TOL)
    })
}
/// Frobenius norm of the residual X - A^T Z.
fn frob_error_za(x: &Array2<f64>, z: &Array2<f64>, a: &Array2<f64>) -> f64 {
    let approximation = a.t().to_owned().dot(z);
    let residual = x - &approximation;
    residual.mapv(|d| d * d).sum().sqrt()
}
/// Converts an arbitrary float-element 2-D array to `f64`.
///
/// NOTE(review): a failed `NumCast` silently becomes 0.0. For the `Float`
/// element types accepted here the cast should always succeed, but confirm
/// before widening the trait bounds.
fn to_f64<S>(x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
where
    S: Data,
    S::Elem: Float + NumCast,
{
    let converted = x.mapv(|value| NumCast::from(value).unwrap_or(0.0));
    Ok(converted)
}
/// Builds an (nrows x ncols) matrix whose columns are random points on the
/// probability simplex (strictly positive entries summing to one).
///
/// With `Some(seed)` the values come from a SplitMix64 stream so results are
/// reproducible; with `None` the process-wide RNG is used. The previous
/// version accepted the seed and then discarded it (`let _ = seed;`), which
/// silently made `ArchetypalAnalysis::with_seed` a no-op.
fn uniform_simplex_cols(nrows: usize, ncols: usize, seed: Option<u64>) -> Array2<f64> {
    let total = nrows * ncols;
    let mut raw = Vec::with_capacity(total);
    match seed {
        Some(s) => {
            // SplitMix64: small, high-quality, dependency-free PRNG.
            let mut state = s;
            for _ in 0..total {
                state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
                let mut z = state;
                z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
                z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
                z ^= z >> 31;
                // Top 53 bits -> uniform f64 in [0, 1).
                raw.push((z >> 11) as f64 / (1u64 << 53) as f64);
            }
        }
        None => {
            let mut rng = scirs2_core::random::rng();
            for _ in 0..total {
                raw.push(rng.random::<f64>());
            }
        }
    }
    let mut m = Array2::<f64>::zeros((nrows, ncols));
    let mut values = raw.into_iter();
    for j in 0..ncols {
        let mut col_sum = 0.0;
        for i in 0..nrows {
            // EPS keeps every entry strictly positive before normalization.
            let v = values.next().expect("vector sized to nrows * ncols") + EPS;
            m[[i, j]] = v;
            col_sum += v;
        }
        for i in 0..nrows {
            m[[i, j]] /= col_sum;
        }
    }
    m
}
/// Projects every column of `m` onto the probability simplex
/// { v : v_i >= 0, sum_i v_i = 1 } using the sort-based algorithm of
/// Duchi et al. (2008), then renormalizes each column as a safety net
/// against floating-point drift.
fn project_simplex_cols(m: &Array2<f64>) -> Array2<f64> {
let (nrows, ncols) = m.dim();
let mut out = m.clone();
for j in 0..ncols {
// Copy the column and sort it in descending order (NaNs treated as equal).
let mut col: Vec<f64> = (0..nrows).map(|i| m[[i, j]]).collect();
col.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
// rho = largest index whose sorted entry stays positive after the shift;
// at idx 0 the test value is exactly 1.0, so rho is always well-defined.
let mut cumsum = 0.0;
let mut rho = 0usize;
for (idx, &val) in col.iter().enumerate() {
cumsum += val;
if val - (cumsum - 1.0) / (idx as f64 + 1.0) > 0.0 {
rho = idx;
}
}
// theta: the uniform shift that makes the clipped column sum to one.
let cumsum_rho: f64 = col.iter().take(rho + 1).sum();
let theta = (cumsum_rho - 1.0) / (rho as f64 + 1.0);
for i in 0..nrows {
out[[i, j]] = (m[[i, j]] - theta).max(0.0);
}
// Renormalize to absorb rounding error; EPS guards an all-zero column.
let col_sum: f64 = (0..nrows).map(|i| out[[i, j]]).sum::<f64>().max(EPS);
for i in 0..nrows {
out[[i, j]] /= col_sum;
}
}
out
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array2;
// Fixture: 7 samples in 3-D lying on/near the probability simplex — three
// near-vertex pairs plus one central point, so k = 3 archetypes fit well.
fn simplex_data() -> Array2<f64> {
let data = vec![
1.0, 0.0, 0.0, 0.9, 0.1, 0.0, 0.0, 1.0, 0.0, 0.1, 0.8, 0.1, 0.0, 0.0, 1.0, 0.05, 0.05, 0.9, 0.33, 0.33, 0.34, ];
Array2::from_shape_vec((7, 3), data).expect("shape ok")
}
// Z should be (k x p), A (k x n), B (n x k) for n=7, p=3, k=3.
#[test]
fn test_archetypal_fit_shapes() {
let x = simplex_data();
let model = ArchetypalAnalysis::new(3)
.with_max_iter(100)
.fit(&x)
.expect("AA fit ok");
assert_eq!(model.archetypes.shape(), &[3, 3]);
assert_eq!(model.a.shape(), &[3, 7]);
assert_eq!(model.b.shape(), &[7, 3]);
}
// Columns of A and B must stay (approximately) on the simplex after fitting.
#[test]
fn test_archetypal_simplex_constraints() {
let x = simplex_data();
let model = ArchetypalAnalysis::new(3)
.with_max_iter(200)
.fit(&x)
.expect("AA fit ok");
assert!(
archetypal_simplex(&model),
"convexity constraints should hold: {:?}",
model
);
}
// Loose sanity bound: the relative error must at least beat a zero model.
#[test]
fn test_archetypal_reconstruction_reasonable() {
let x = simplex_data();
let model = ArchetypalAnalysis::new(3)
.with_max_iter(300)
.fit(&x)
.expect("AA fit ok");
let x_norm = x.mapv(|v| v * v).sum().sqrt();
let rel_err = model.reconstruction_error / x_norm.max(EPS);
assert!(
rel_err < 1.0,
"relative reconstruction error {rel_err} should be < 1.0"
);
}
// archetypal_error recomputed from the stored factors must agree with the
// error recorded by fit at termination.
#[test]
fn test_archetypal_error_function() {
let x = simplex_data();
let model = ArchetypalAnalysis::new(3)
.with_max_iter(50)
.fit(&x)
.expect("AA fit ok");
let err = archetypal_error(&x, &model);
let delta = (err - model.reconstruction_error).abs();
assert!(
delta < 1e-6,
"archetypal_error differs from stored error: {delta}"
);
}
// k > n_samples is rejected as InvalidInput.
#[test]
fn test_archetypal_too_many_archetypes() {
let x = simplex_data();
let result = ArchetypalAnalysis::new(100).fit(&x);
assert!(result.is_err(), "should reject k > n_samples");
}
// k == 0 is rejected as InvalidInput.
#[test]
fn test_archetypal_zero_archetypes() {
let x = simplex_data();
let result = ArchetypalAnalysis::new(0).fit(&x);
assert!(result.is_err(), "should reject k=0");
}
// Direct unit test of the simplex projection: each projected column must be
// non-negative and sum to one.
#[test]
fn test_project_simplex_valid() {
let m = Array2::from_shape_vec((4, 3), vec![2.0, -1.0, 0.5, 1.0, 3.0, 2.0, -0.5, 1.0, 0.3, 0.0, 1.0, 0.2]).expect("valid shape");
let p = project_simplex_cols(&m);
for j in 0..3 {
let s: f64 = (0..4).map(|i| p[[i, j]]).sum();
assert!((s - 1.0).abs() < 1e-8, "col {j} sums to {s}");
for i in 0..4 {
assert!(p[[i, j]] >= -1e-10, "p[{i},{j}]={} is negative", p[[i,j]]);
}
}
}
}