use approx::assert_relative_eq;
use ndarray::array;
use rhmm::algorithms::{backward_algorithm, forward_algorithm, viterbi_algorithm};
use rhmm::base::HiddenMarkovModel;
use rhmm::models::GaussianHMM;
use rhmm::utils::{normalize_vector, validate_probability_vector};
/// End-to-end smoke test: fit a two-state Gaussian HMM, then decode and score
/// unseen observations with the fitted model.
#[test]
fn test_gaussian_hmm_workflow() {
    let n_states = 2;
    let mut hmm = GaussianHMM::new(n_states);
    let observations = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
    // `expect` (instead of `assert!(r.is_ok())`) prints the underlying error on failure.
    hmm.fit(&observations, None)
        .expect("fitting a GaussianHMM on well-formed data should succeed");
    assert!(hmm.is_fitted());

    let test_obs = array![[1.5, 2.5], [3.5, 4.5]];
    let states = hmm
        .predict(&test_obs)
        .expect("predict should succeed on a fitted model");
    // Decoding must yield one state per observation row, each a valid state index.
    assert_eq!(states.len(), test_obs.nrows());
    assert!(states.iter().all(|&s| s < n_states));

    hmm.score(&test_obs)
        .expect("score should succeed on a fitted model");
}
/// The forward and backward passes over the same model and observation
/// likelihoods must produce matrices of identical shape: presumably one row
/// per time step and one column per hidden state.
#[test]
fn test_forward_backward_consistency() {
    let initial = array![0.6, 0.4];
    let transitions = array![[0.7, 0.3], [0.4, 0.6]];
    let emissions = array![[0.9, 0.1], [0.8, 0.2], [0.7, 0.3]];
    // (time steps, states) = (3, 2) for the inputs above.
    let expected_dims = (emissions.nrows(), initial.len());

    let alpha = forward_algorithm(&initial, &transitions, &emissions).unwrap();
    let beta = backward_algorithm(&transitions, &emissions).unwrap();

    assert_eq!(alpha.shape(), beta.shape());
    assert_eq!(alpha.dim(), expected_dims);
    // NOTE(review): despite the name, only shape agreement is checked. A true
    // consistency check would assert sum_i alpha[T-1, i] == sum_i initial[i] *
    // emissions[0, i] * beta[0, i] (both equal P(O)); add it once it is confirmed
    // whether these functions return raw, scaled, or log-space values.
}
/// Viterbi decoding on a model whose most likely prefix is known in advance.
///
/// The chain is forced to start in state 0 (start probability 1.0) and is sticky
/// (0.9 self-transition). The first two observations strongly favour state 0, so
/// the decoded path must begin `0, 0`.
#[test]
fn test_viterbi_with_known_path() {
    let start_prob = array![1.0, 0.0];
    let transition_matrix = array![[0.9, 0.1], [0.1, 0.9]];
    let emission_probs = array![[0.9, 0.1], [0.9, 0.1], [0.1, 0.9]];
    let (_log_prob, path) =
        viterbi_algorithm(&start_prob, &transition_matrix, &emission_probs).unwrap();

    // One decoded state per time step.
    assert_eq!(path.len(), emission_probs.nrows());
    assert_eq!(path[0], 0);
    assert_eq!(path[1], 0);
    // The final state is deliberately NOT asserted. Staying in 0 scores
    // 0.9 * 0.9 * 0.9 * 0.9 * 0.1, while switching to 1 scores
    // 0.9 * 0.9 * 0.9 * 0.1 * 0.9. Both equal 0.06561, so either answer is
    // optimal, and which one wins depends on floating-point rounding.
    // NOTE(review): `_log_prob` is unchecked. The best path probability is 0.06561
    // (ln ~= -2.724); pin it once the returned scale (log vs. linear) is confirmed.
}
/// `normalize_vector` output must sum to one and must be accepted by the
/// crate's own probability-vector validator.
#[test]
fn test_normalization_and_validation() {
    let normalized = normalize_vector(array![1.0, 2.0, 3.0]);

    let total = normalized.sum();
    assert_relative_eq!(total, 1.0, epsilon = 1e-10);

    let validation = validate_probability_vector(&normalized, "test");
    assert!(validation.is_ok());
}
/// Fitting on several sequences stacked row-wise, delimited via `lengths`.
#[test]
fn test_multiple_sequences() {
    let mut hmm = GaussianHMM::new(2);
    // Presumably `lengths` splits the rows into consecutive sequences
    // (rows 0..2 and 2..4 here); the lengths sum to the row count.
    let observations = array![[1.0, 2.0], [2.0, 3.0], [5.0, 6.0], [6.0, 7.0]];
    let lengths = vec![2, 2];
    // `expect` surfaces the underlying error instead of a bare "is_ok()" failure.
    hmm.fit(&observations, Some(&lengths))
        .expect("fitting with per-sequence lengths should succeed");
    assert!(hmm.is_fitted());
}
/// Invalid input and misuse must be reported as `Err` values.
#[test]
fn test_error_handling() {
    let mut hmm = GaussianHMM::new(2);

    // No observations at all: shape (0, 2).
    let empty_obs = ndarray::Array2::<f64>::zeros((0, 2));
    assert!(hmm.fit(&empty_obs, None).is_err());

    // One row with zero features: shape (1, 0).
    let empty_cols = array![[]];
    assert!(hmm.fit(&empty_cols, None).is_err());

    // A rejected fit must not leave the model looking trained.
    assert!(!hmm.is_fitted());

    // Predicting before any fit is an error. A fresh model keeps this case
    // independent of whatever state the failed fits above may have left behind.
    let unfitted = GaussianHMM::new(2);
    let obs = array![[1.0, 2.0]];
    assert!(unfitted.predict(&obs).is_err());
}