use scirs2_core::ndarray::{Array1, Array2, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::iter::Sum;
use crate::error::{LinalgError, LinalgResult};
use crate::solve::solve_multiple;
use crate::validation::validate_squarematrix;
/// Euclidean (L2) norm of a 1-D array: sqrt of the sum of squared entries.
fn norm2_1d<F: Float + Sum>(v: &Array1<F>) -> F {
    let sum_of_squares: F = v.iter().map(|&x| x * x).sum();
    sum_of_squares.sqrt()
}
/// Manhattan (L1) norm of a 1-D array: sum of absolute values.
fn norm1_1d<F: Float + Sum>(v: &Array1<F>) -> F {
    v.iter().map(|&x| x.abs()).sum()
}
/// Row-vector/matrix product `pi * P`: one step of distribution propagation
/// for a row-stochastic transition matrix (assumes `p` is square).
fn left_multiply_row<F>(pi: &Array1<F>, p: &ArrayView2<F>) -> Array1<F>
where
    F: Float + NumAssign + ScalarOperand,
{
    let n = p.ncols();
    Array1::from_iter((0..n).map(|col| {
        (0..n).fold(F::zero(), |acc, row| acc + pi[row] * p[[row, col]])
    }))
}
/// Returns `true` if `p` is a square, row-stochastic matrix: every entry is
/// non-negative (within `tol`) and every row sums to 1 (within `tol`).
pub fn is_stochastic<F>(p: &ArrayView2<F>, tol: F) -> bool
where
    F: Float + Sum + NumAssign + ScalarOperand,
{
    let n = p.nrows();
    if n != p.ncols() {
        return false;
    }
    (0..n).all(|i| {
        // Entries may dip slightly below zero from rounding, hence `-tol`.
        let entries_ok = (0..n).all(|j| p[[i, j]] >= -tol);
        let row_sum: F = (0..n).map(|j| p[[i, j]]).sum();
        entries_ok && (row_sum - F::one()).abs() <= tol
    })
}
/// Convenience wrapper around [`is_stochastic`] specialized to `f64`.
pub fn is_stochastic_f64(p: &ArrayView2<f64>, tol: f64) -> bool {
    is_stochastic(p, tol)
}
/// Computes the stationary distribution `pi` of a row-stochastic transition
/// matrix by power iteration on the row vector (`pi_{k+1} = pi_k * P`),
/// starting from the uniform distribution.
///
/// # Arguments
/// * `transition_matrix` - square, row-stochastic matrix `P`
/// * `tol` - convergence threshold on the L2 change between iterates
/// * `max_iter` - maximum number of power-iteration steps (must be > 0)
///
/// # Errors
/// Fails when the matrix is not square, not row-stochastic, `max_iter` is
/// zero, or the iteration does not converge within `max_iter` steps.
pub fn stationary_distribution<F>(
    transition_matrix: &ArrayView2<F>,
    tol: F,
    max_iter: usize,
) -> LinalgResult<Array1<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static + std::fmt::Display,
{
    validate_squarematrix(transition_matrix, "Stationary distribution")?;
    let n = transition_matrix.nrows();
    // Fixed tolerance for the stochasticity test; falls back to a scaled
    // epsilon when 1e-6 is not representable in F.
    let stoch_tol =
        F::from(1e-6).unwrap_or(F::epsilon() * F::from(1000.0).unwrap_or(F::one()));
    if !is_stochastic(transition_matrix, stoch_tol) {
        return Err(LinalgError::InvalidInputError(
            "Transition matrix must be row-stochastic (rows sum to 1, all entries non-negative)".to_string(),
        ));
    }
    if max_iter == 0 {
        return Err(LinalgError::InvalidInputError(
            "max_iter must be positive".to_string(),
        ));
    }
    let n_as_float = F::from(n).ok_or_else(|| {
        LinalgError::ComputationError("Cannot convert n to float".to_string())
    })?;
    // Start from the uniform distribution over the n states.
    let mut pi: Array1<F> = Array1::from_elem(n, F::one() / n_as_float);
    for _ in 0..max_iter {
        let next = left_multiply_row(&pi, transition_matrix);
        // Convergence is measured on the raw (pre-normalization) update.
        let delta: Array1<F> =
            Array1::from_iter(next.iter().zip(pi.iter()).map(|(&a, &b)| a - b));
        let change = norm2_1d(&delta);
        pi = next;
        // Renormalize to guard against floating-point drift.
        let total: F = pi.iter().copied().fold(F::zero(), |acc, x| acc + x);
        if total > F::epsilon() {
            pi.mapv_inplace(|v| v / total);
        }
        if change < tol {
            return Ok(pi);
        }
    }
    Err(LinalgError::ConvergenceError(format!(
        "Stationary distribution did not converge in {max_iter} iterations"
    )))
}
/// Computes the Kemeny–Snell fundamental matrix `Z = (I - P + 1·pi^T)^{-1}`
/// of an ergodic chain, where `pi` is its stationary distribution.
///
/// # Errors
/// Propagates failures from the squareness check, the stationary-distribution
/// computation, and the linear solve.
pub fn fundamental_matrix<F>(
    transition_matrix: &ArrayView2<F>,
) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static + std::fmt::Display,
{
    validate_squarematrix(transition_matrix, "Fundamental matrix")?;
    let n = transition_matrix.nrows();
    let pi = stationary_distribution(
        transition_matrix,
        F::from(1e-12).unwrap_or(F::epsilon()),
        50_000,
    )?;
    // A[i][j] = delta_ij - P[i][j] + pi[j]  (pi broadcast across rows).
    let a = Array2::from_shape_fn((n, n), |(i, j)| {
        let delta = if i == j { F::one() } else { F::zero() };
        delta - transition_matrix[[i, j]] + pi[j]
    });
    // Solving A·Z = I yields Z = A^{-1}.
    let identity: Array2<F> = Array2::eye(n);
    solve_multiple(&a.view(), &identity.view(), None)
}
/// Computes the mean first passage time matrix `M`, where `M[i][j]` is the
/// expected number of steps to first reach state `j` starting from state `i`.
/// Diagonal entries are the expected recurrence times `1 / pi[j]`.
///
/// States whose stationary probability is numerically zero get infinite
/// passage times.
///
/// # Errors
/// Propagates failures from the stationary-distribution and
/// fundamental-matrix computations.
pub fn mean_first_passage_time<F>(
    transition_matrix: &ArrayView2<F>,
) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static + std::fmt::Display,
{
    validate_squarematrix(transition_matrix, "Mean first passage time")?;
    let n = transition_matrix.nrows();
    let pi = stationary_distribution(
        transition_matrix,
        F::from(1e-12).unwrap_or(F::epsilon()),
        50_000,
    )?;
    let z = fundamental_matrix(transition_matrix)?;
    // Single pass: the diagonal gets the recurrence time 1/pi[j]; the
    // off-diagonal uses the standard formula (Z[j][j] - Z[i][j]) / pi[j].
    let m = Array2::from_shape_fn((n, n), |(i, j)| {
        if pi[j] <= F::epsilon() {
            F::infinity()
        } else if i == j {
            F::one() / pi[j]
        } else {
            (z[[j, j]] - z[[i, j]]) / pi[j]
        }
    });
    Ok(m)
}
/// Estimates the mixing time of an ergodic chain: the number of steps after
/// which the distribution is within `epsilon` of stationarity, approximated
/// as `ceil(ln(1/epsilon) / gap)` where `gap` is the spectral gap.
///
/// The subdominant eigenvalue magnitude is approximated by the largest
/// singular value of the centered matrix `P - 1·pi^T`; the centering
/// annihilates the all-ones right eigenvector of the dominant eigenvalue 1.
///
/// # Errors
/// Fails when the matrix is not square, `epsilon` is outside `(0, 1)`, the
/// SVD fails, or the spectral gap is non-positive.
pub fn mixing_time<F>(
    transition_matrix: &ArrayView2<F>,
    epsilon: F,
) -> LinalgResult<usize>
where
    F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static + std::fmt::Display,
{
    validate_squarematrix(transition_matrix, "Mixing time")?;
    if epsilon <= F::zero() || epsilon >= F::one() {
        return Err(LinalgError::InvalidInputError(
            "epsilon must be in (0, 1)".to_string(),
        ));
    }
    let n = transition_matrix.nrows();
    // NOTE(review): on failure this silently falls back to the uniform
    // distribution, which yields a meaningless estimate for non-ergodic
    // chains; kept for backward compatibility with existing callers.
    let pi_res = stationary_distribution(transition_matrix, F::from(1e-12).unwrap_or(F::epsilon()), 50_000);
    let pi = pi_res.unwrap_or_else(|_| {
        Array1::from_elem(n, F::one() / F::from(n).unwrap_or(F::one()))
    });
    // Center the transition matrix: P - 1·pi^T.
    let mut p_centered = Array2::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            p_centered[[i, j]] = transition_matrix[[i, j]] - pi[j];
        }
    }
    use crate::decomposition::svd;
    let (_, s, _) = svd(&p_centered.view(), false, None)?;
    // Bug fix: the guard used to require `s.len() > 1` while reading `s[0]`,
    // skipping the singular value whenever exactly one exists. Read the
    // leading singular value for any non-empty spectrum.
    let lambda2 = if s.is_empty() { F::zero() } else { s[0] };
    let gap = F::one() - lambda2;
    if gap <= F::epsilon() {
        return Err(LinalgError::ComputationError(
            "Spectral gap is zero or negative; chain may not be ergodic".to_string(),
        ));
    }
    // Clamp at zero so epsilon values near 1 cannot produce a negative count.
    let ln_inv_eps = (-epsilon.ln()).max(F::zero());
    let t = (ln_inv_eps / gap).ceil();
    // Saturate on overflow; a mixing time is always at least one step.
    Ok(t.to_usize().unwrap_or(usize::MAX).max(1))
}
/// Classification of a Markov chain's states into absorbing and transient
/// sets, as produced by `analyze_absorbing_chain`.
#[derive(Debug, Clone)]
pub struct AbsorbingMarkovChain {
// Indices of absorbing states (diagonal entry 1, zero off-diagonal row mass).
pub absorbing_states: Vec<usize>,
// Indices of all remaining (non-absorbing) states.
pub transient_states: Vec<usize>,
}
/// Partitions the states of `transition_matrix` into absorbing states
/// (diagonal entry 1 and all off-diagonal entries 0, within tolerance) and
/// transient states (everything else).
///
/// # Errors
/// Fails when the matrix is not square or some row does not sum to 1 within
/// a fixed tolerance.
pub fn analyze_absorbing_chain<F>(
    transition_matrix: &ArrayView2<F>,
) -> LinalgResult<AbsorbingMarkovChain>
where
    F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static + std::fmt::Display,
{
    validate_squarematrix(transition_matrix, "Absorbing chain analysis")?;
    let n = transition_matrix.nrows();
    // Every row must be a probability distribution (sum to 1).
    let tol_stoch = F::from(1e-6).unwrap_or(F::epsilon());
    for i in 0..n {
        let row_sum: F = (0..n).map(|j| transition_matrix[[i, j]]).sum();
        if (row_sum - F::one()).abs() > tol_stoch {
            return Err(LinalgError::InvalidInputError(format!(
                "Row {i} of transition matrix does not sum to 1 (sum = {row_sum})"
            )));
        }
    }
    // Tighter tolerance for classifying individual entries.
    let tol = F::from(1e-10).unwrap_or(F::epsilon());
    let mut chain = AbsorbingMarkovChain {
        absorbing_states: Vec::new(),
        transient_states: Vec::new(),
    };
    for i in 0..n {
        let diag_is_one = (transition_matrix[[i, i]] - F::one()).abs() <= tol;
        let off_diag_mass: F = (0..n)
            .filter(|&j| j != i)
            .map(|j| transition_matrix[[i, j]].abs())
            .sum();
        if diag_is_one && off_diag_mass <= tol {
            chain.absorbing_states.push(i);
        } else {
            chain.transient_states.push(i);
        }
    }
    Ok(chain)
}
/// Computes the absorption probability matrix `B = (I - Q)^{-1} R` of an
/// absorbing Markov chain, where `Q` is the transient-to-transient block and
/// `R` the transient-to-absorbing block of the transition matrix.
/// `B[i][j]` is the probability that the chain, started in the i-th
/// transient state, is eventually absorbed in the j-th absorbing state.
///
/// # Errors
/// Fails when the matrix is not square, the chain has no transient or no
/// absorbing states, a state index in `absorbing_chain` is out of bounds
/// for the matrix, or the linear solve fails.
pub fn absorption_probabilities<F>(
    transition_matrix: &ArrayView2<F>,
    absorbing_chain: &AbsorbingMarkovChain,
) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static + std::fmt::Display,
{
    validate_squarematrix(transition_matrix, "Absorption probabilities")?;
    let n = transition_matrix.nrows();
    let t_states = &absorbing_chain.transient_states;
    let a_states = &absorbing_chain.absorbing_states;
    if t_states.is_empty() {
        return Err(LinalgError::InvalidInputError(
            "No transient states found in the chain".to_string(),
        ));
    }
    if a_states.is_empty() {
        return Err(LinalgError::InvalidInputError(
            "No absorbing states found in the chain".to_string(),
        ));
    }
    // Bug fix: previously an out-of-range state index in `absorbing_chain`
    // (e.g. one built against a different matrix) caused an index panic;
    // report it as an input error instead.
    if let Some(&bad) = t_states.iter().chain(a_states.iter()).find(|&&s| s >= n) {
        return Err(LinalgError::InvalidInputError(format!(
            "State index {bad} is out of bounds for a {n}x{n} transition matrix"
        )));
    }
    let nt = t_states.len();
    let na = a_states.len();
    // Q: transitions among transient states.
    let q = Array2::from_shape_fn((nt, nt), |(i, j)| {
        transition_matrix[[t_states[i], t_states[j]]]
    });
    // R: transitions from transient to absorbing states.
    let r = Array2::from_shape_fn((nt, na), |(i, j)| {
        transition_matrix[[t_states[i], a_states[j]]]
    });
    // (I - Q)
    let i_minus_q = Array2::from_shape_fn((nt, nt), |(i, j)| {
        let delta = if i == j { F::one() } else { F::zero() };
        delta - q[[i, j]]
    });
    // B = (I - Q)^{-1} R, obtained by solving (I - Q)·B = R.
    solve_multiple(&i_minus_q.view(), &r.view(), None)
}
// Unit tests for the Markov-chain routines above.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use scirs2_core::ndarray::array;
// Fixture: ergodic 2-state chain; its stationary distribution solves
// pi*P = pi, giving pi = (4/7, 3/7).
fn two_state_chain() -> Array2<f64> {
array![[0.7_f64, 0.3], [0.4, 0.6]]
}
// Fixture: ergodic 3-state chain with strictly positive entries.
fn three_state_ergodic() -> Array2<f64> {
array![
[0.5_f64, 0.3, 0.2],
[0.2, 0.6, 0.2],
[0.3, 0.2, 0.5]
]
}
// Fixture: absorbing chain — states 0 and 2 absorbing, state 1 transient.
fn absorbing_chain_3state() -> Array2<f64> {
array![
[1.0_f64, 0.0, 0.0],
[0.3, 0.4, 0.3],
[0.0, 0.0, 1.0],
]
}
#[test]
fn test_is_stochastic_valid() {
let p = two_state_chain();
assert!(is_stochastic(&p.view(), 1e-10));
}
// Row 0 sums to 1.1, so the matrix must be rejected.
#[test]
fn test_is_stochastic_invalid_row_sum() {
let p = array![[0.7_f64, 0.4], [0.4, 0.6]]; assert!(!is_stochastic(&p.view(), 1e-10));
}
// A negative entry disqualifies the matrix even if rows sum to 1.
#[test]
fn test_is_stochastic_negative_entry() {
let p = array![[0.7_f64, 0.3], [-0.1, 1.1]];
assert!(!is_stochastic(&p.view(), 1e-10));
}
// Non-square matrices are never stochastic.
#[test]
fn test_is_stochastic_non_square() {
let p = array![[0.5_f64, 0.5, 0.0], [0.3, 0.4, 0.3]];
assert!(!is_stochastic(&p.view(), 1e-10));
}
// Closed-form check against the analytic solution (4/7, 3/7).
#[test]
fn test_stationary_distribution_two_state() {
let p = two_state_chain();
let pi = stationary_distribution(&p.view(), 1e-12, 10_000)
.expect("Must converge");
assert_relative_eq!(pi[0], 4.0 / 7.0, epsilon = 1e-8);
assert_relative_eq!(pi[1], 3.0 / 7.0, epsilon = 1e-8);
}
// Verifies the defining properties: pi sums to 1, is positive, and is a
// fixed point of left-multiplication by P.
#[test]
fn test_stationary_distribution_three_state() {
let p = three_state_ergodic();
let pi = stationary_distribution(&p.view(), 1e-12, 10_000)
.expect("Must converge");
let sum: f64 = pi.iter().sum();
assert_relative_eq!(sum, 1.0, epsilon = 1e-10);
for &v in pi.iter() {
assert!(v > 0.0);
}
// Recompute pi * P by hand and compare with pi.
let pi_p: Array1<f64> = {
let n = 3;
let mut r = Array1::zeros(n);
for j in 0..n {
r[j] = (0..n).map(|i| pi[i] * p[[i, j]]).sum::<f64>();
}
r
};
for i in 0..3 {
assert_relative_eq!(pi[i], pi_p[i], epsilon = 1e-8);
}
}
// A doubly-stochastic symmetric chain has the uniform stationary law.
#[test]
fn test_stationary_distribution_uniform() {
let p = array![[0.5_f64, 0.5], [0.5, 0.5]];
let pi = stationary_distribution(&p.view(), 1e-12, 10_000)
.expect("Must converge");
assert_relative_eq!(pi[0], 0.5, epsilon = 1e-8);
assert_relative_eq!(pi[1], 0.5, epsilon = 1e-8);
}
// Shape smoke test for the fundamental matrix.
#[test]
fn test_fundamental_matrix_two_state() {
let p = two_state_chain();
let z = fundamental_matrix(&p.view()).expect("Must succeed");
assert_eq!(z.nrows(), 2);
assert_eq!(z.ncols(), 2);
}
// Z must be the inverse of A = I - P + 1*pi^T, i.e. Z*A == I.
#[test]
fn test_fundamental_matrix_identity_row() {
let p = two_state_chain();
let z = fundamental_matrix(&p.view()).expect("Must succeed");
let pi = stationary_distribution(&p.view(), 1e-12, 10_000).expect("failed to create pi");
let n = 2;
let mut a = Array2::<f64>::zeros((n, n));
for i in 0..n {
for j in 0..n {
a[[i, j]] = (if i == j { 1.0 } else { 0.0 }) - p[[i, j]] + pi[j];
}
}
// Multiply Z * A by hand and compare with the identity.
let mut za = Array2::<f64>::zeros((n, n));
for i in 0..n {
for j in 0..n {
for k in 0..n {
za[[i, j]] += z[[i, k]] * a[[k, j]];
}
}
}
for i in 0..n {
for j in 0..n {
let expected = if i == j { 1.0 } else { 0.0 };
assert_relative_eq!(za[[i, j]], expected, epsilon = 1e-8);
}
}
}
// Diagonal of the MFPT matrix equals the expected recurrence time 1/pi[j].
#[test]
fn test_mfpt_diagonal_is_recurrence_time() {
let p = two_state_chain();
let mfpt = mean_first_passage_time(&p.view()).expect("Must succeed");
let pi = stationary_distribution(&p.view(), 1e-12, 10_000).expect("failed to create pi");
assert_relative_eq!(mfpt[[0, 0]], 1.0 / pi[0], epsilon = 1e-6);
assert_relative_eq!(mfpt[[1, 1]], 1.0 / pi[1], epsilon = 1e-6);
}
// All passage times of an ergodic chain are strictly positive.
#[test]
fn test_mfpt_positive_entries() {
let p = two_state_chain();
let mfpt = mean_first_passage_time(&p.view()).expect("Must succeed");
for i in 0..2 {
for j in 0..2 {
assert!(mfpt[[i, j]] > 0.0, "MFPT entry [{i},{j}] must be positive");
}
}
}
#[test]
fn test_mixing_time_positive() {
let p = two_state_chain();
let t = mixing_time(&p.view(), 0.01).expect("Must succeed");
assert!(t >= 1);
}
// A larger spectral gap must yield a smaller mixing-time estimate.
#[test]
fn test_mixing_time_faster_for_large_gap() {
let p_fast = array![[0.5_f64, 0.5], [0.5, 0.5]];
let p_slow = array![[0.99_f64, 0.01], [0.01, 0.99]];
let t_fast = mixing_time(&p_fast.view(), 0.05).expect("Must succeed");
let t_slow = mixing_time(&p_slow.view(), 0.05).expect("Must succeed");
assert!(t_slow > t_fast, "Slower chain should need more steps: {t_slow} vs {t_fast}");
}
// epsilon must lie strictly inside (0, 1).
#[test]
fn test_mixing_time_invalid_epsilon() {
let p = two_state_chain();
assert!(mixing_time(&p.view(), 0.0).is_err());
assert!(mixing_time(&p.view(), 1.0).is_err());
assert!(mixing_time(&p.view(), -0.1).is_err());
}
#[test]
fn test_analyze_absorbing_chain_basic() {
let p = absorbing_chain_3state();
let chain = analyze_absorbing_chain(&p.view()).expect("Must succeed");
assert_eq!(chain.absorbing_states, vec![0, 2]);
assert_eq!(chain.transient_states, vec![1]);
}
// An ergodic chain has no absorbing states.
#[test]
fn test_analyze_absorbing_chain_no_absorbing() {
let p = two_state_chain();
let chain = analyze_absorbing_chain(&p.view()).expect("Must succeed");
assert!(chain.absorbing_states.is_empty());
assert_eq!(chain.transient_states.len(), 2);
}
// The identity matrix makes every state absorbing.
#[test]
fn test_analyze_absorbing_chain_all_absorbing() {
let p: Array2<f64> = Array2::eye(3);
let chain = analyze_absorbing_chain(&p.view()).expect("Must succeed");
assert_eq!(chain.absorbing_states.len(), 3);
assert!(chain.transient_states.is_empty());
}
// Symmetric exits (0.3 each way) give 50/50 absorption probabilities.
#[test]
fn test_absorption_probs_symmetric() {
let p = absorbing_chain_3state();
let chain = analyze_absorbing_chain(&p.view()).expect("Must succeed");
let b = absorption_probabilities(&p.view(), &chain).expect("Must succeed");
assert_relative_eq!(b[[0, 0]], 0.5, epsilon = 1e-8);
assert_relative_eq!(b[[0, 1]], 0.5, epsilon = 1e-8);
}
// Each transient state is absorbed somewhere with total probability 1.
#[test]
fn test_absorption_probs_row_sum_to_one() {
let p = absorbing_chain_3state();
let chain = analyze_absorbing_chain(&p.view()).expect("Must succeed");
let b = absorption_probabilities(&p.view(), &chain).expect("Must succeed");
for i in 0..b.nrows() {
let row_sum: f64 = b.row(i).iter().sum();
assert_relative_eq!(row_sum, 1.0, epsilon = 1e-8);
}
}
// B has one row per transient state and one column per absorbing state.
#[test]
fn test_absorption_probs_shape() {
let p = absorbing_chain_3state();
let chain = analyze_absorbing_chain(&p.view()).expect("Must succeed");
let b = absorption_probabilities(&p.view(), &chain).expect("Must succeed");
assert_eq!(b.nrows(), chain.transient_states.len());
assert_eq!(b.ncols(), chain.absorbing_states.len());
}
// Single-step absorption: probabilities equal the direct exit rates.
#[test]
fn test_absorption_probs_asymmetric() {
let p = array![
[1.0_f64, 0.0, 0.0],
[0.7, 0.0, 0.3],
[0.0, 0.0, 1.0],
];
let chain = analyze_absorbing_chain(&p.view()).expect("Must succeed");
let b = absorption_probabilities(&p.view(), &chain).expect("Must succeed");
assert_relative_eq!(b[[0, 0]], 0.7, epsilon = 1e-8);
assert_relative_eq!(b[[0, 1]], 0.3, epsilon = 1e-8);
}
}