use scirs2_core::RngExt;
use tensorlogic_quantrs_hooks::{BaumWelchLearner, SimpleHMM};
/// Demonstrates Baum-Welch parameter learning end to end.
///
/// 1. Build a known two-state "weather" HMM.
/// 2. Sample training sequences from it.
/// 3. Fit a randomly initialised HMM to those sequences.
/// 4. Compare the learned parameters and likelihoods against the known model.
fn main() -> anyhow::Result<()> {
    // Size of the synthetic training set.
    let (num_sequences, sequence_length) = (50, 20);
    // Arguments forwarded to `BaumWelchLearner::with_verbose`; presumably
    // (max iterations, convergence tolerance) — confirm against the learner's docs.
    let (max_iterations, tolerance) = (100, 1e-4);

    println!("=== HMM Parameter Learning Example ===\n");

    // Ground-truth model that generates the data.
    println!("Creating true weather model...");
    let reference = create_true_weather_model();
    println!("True Model Parameters:");
    println!("---------------------");
    print_hmm_parameters(&reference);
    println!();

    // Training data drawn from the ground-truth model.
    println!("Generating observation sequences from true model...");
    let training_data = generate_observations(&reference, num_sequences, sequence_length);
    println!(
        "Generated {} sequences of length {}",
        num_sequences, sequence_length
    );
    println!("Sample sequence: {:?}", training_data[0]);
    println!();

    // Random starting point: 2 hidden states, 3 observation symbols (same shape as the reference).
    println!("Initializing random HMM for learning...");
    let mut estimate = SimpleHMM::new_random(2, 3);
    println!("Initial (Random) Parameters:");
    println!("---------------------------");
    print_hmm_parameters(&estimate);
    println!();

    // Fit the estimate in place; the learner reports the final log-likelihood.
    println!("=== Learning Parameters with Baum-Welch ===\n");
    let learner = BaumWelchLearner::with_verbose(max_iterations, tolerance);
    let final_ll = learner.learn(&mut estimate, &training_data)?;
    println!("\nFinal log-likelihood: {:.4}", final_ll);
    println!();

    println!("Learned Model Parameters:");
    println!("------------------------");
    print_hmm_parameters(&estimate);
    println!();

    println!("=== Parameter Comparison ===\n");
    compare_parameters(&reference, &estimate);

    println!("\n=== Testing Learned Model ===\n");
    test_model_predictions(&reference, &estimate);

    println!("\n✓ Parameter learning completed successfully!");
    Ok(())
}
/// Builds the fixed "ground truth" HMM used to generate training data.
///
/// Hidden states are indexed 0 = Sunny, 1 = Rainy and observations 0 = Walk, 1 = Shop,
/// 2 = Clean (the labels used by `print_hmm_parameters`). Row `i` of each matrix is the
/// distribution conditioned on state `i`. `generate_observations` samples the next state
/// from `transition_probabilities.row(state)` and the emitted symbol from
/// `emission_probabilities.row(state)`.
fn create_true_weather_model() -> SimpleHMM {
    use scirs2_core::ndarray::{Array1, Array2};

    let start_probs = Array1::from_vec(vec![0.6, 0.4]);

    // Sunny tends to stay Sunny (0.8); Rainy stays Rainy with 0.6.
    let transitions = Array2::from_shape_vec((2, 2), vec![0.8, 0.2, 0.4, 0.6])
        .expect("Failed to create transition_probabilities array");

    // Sunny mostly emits Walk (0.6); Rainy mostly emits Clean (0.7).
    let emissions = Array2::from_shape_vec((2, 3), vec![0.6, 0.3, 0.1, 0.1, 0.2, 0.7])
        .expect("Failed to create emission_probabilities array");

    let mut model = SimpleHMM::new(2, 3);
    model.initial_distribution = start_probs;
    model.transition_probabilities = transitions;
    model.emission_probabilities = emissions;
    model
}
/// Samples `num_sequences` observation sequences, each `length` symbols long, from `hmm`.
///
/// Each sequence starts from a state drawn from `hmm.initial_distribution`. At every
/// step, a symbol is drawn from the current state's emission row, and then the next state
/// is drawn from the current state's transition row. Uses the thread-local RNG, so the
/// output is not reproducible between runs.
fn generate_observations(hmm: &SimpleHMM, num_sequences: usize, length: usize) -> Vec<Vec<usize>> {
    use scirs2_core::random::thread_rng;

    let mut rng = thread_rng();
    let start_probs = hmm.initial_distribution.to_vec();

    (0..num_sequences)
        .map(|_| {
            let mut current = sample_discrete(&start_probs, &mut rng);
            (0..length)
                .map(|_| {
                    // Emit from the current state first, then move to the next state.
                    let emitted = sample_discrete(
                        &hmm.emission_probabilities.row(current).to_vec(),
                        &mut rng,
                    );
                    current = sample_discrete(
                        &hmm.transition_probabilities.row(current).to_vec(),
                        &mut rng,
                    );
                    emitted
                })
                .collect::<Vec<usize>>()
        })
        .collect()
}
/// Draws an index from the categorical distribution `probs` by inverse-CDF sampling.
///
/// A uniform `u` is drawn from `rng`, and the first index whose cumulative probability
/// exceeds `u` is returned.
///
/// # Panics
/// If `probs` is empty. Without this check, `probs.len() - 1` would underflow: it panics
/// in debug builds and wraps to `usize::MAX` in release builds.
fn sample_discrete(probs: &[f64], rng: &mut impl scirs2_core::Rng) -> usize {
    assert!(
        !probs.is_empty(),
        "sample_discrete: probability vector must not be empty"
    );
    let u: f64 = rng.random();
    let mut cumsum = 0.0;
    for (i, &p) in probs.iter().enumerate() {
        cumsum += p;
        if u < cumsum {
            return i;
        }
    }
    // Reached only when rounding leaves the total mass at or below `u`. Fall back to the
    // last outcome that actually has positive probability, so a zero-probability outcome
    // is never returned. If every entry is zero, keep the original "last index" behavior.
    probs
        .iter()
        .rposition(|&p| p > 0.0)
        .unwrap_or(probs.len() - 1)
}
/// Prints the initial, transition and emission probabilities of a 2-state, 3-symbol HMM.
///
/// States are labelled Sunny/Rainy and observations Walk/Shop/Clean, by index. Indexing
/// assumes `hmm` has at least those dimensions.
fn print_hmm_parameters(hmm: &SimpleHMM) {
    let state_labels = ["Sunny", "Rainy"];
    let symbol_labels = ["Walk", "Shop", "Clean"];

    println!("Initial Distribution:");
    for (s, label) in state_labels.iter().enumerate() {
        println!(" P({}) = {:.3}", label, hmm.initial_distribution[s]);
    }

    println!("\nTransition Probabilities:");
    for (from, from_label) in state_labels.iter().enumerate() {
        for (to, to_label) in state_labels.iter().enumerate() {
            println!(
                " {} -> {}: {:.3}",
                from_label,
                to_label,
                hmm.transition_probabilities[[from, to]]
            );
        }
    }

    println!("\nEmission Probabilities:");
    for (s, state_label) in state_labels.iter().enumerate() {
        for (o, symbol_label) in symbol_labels.iter().enumerate() {
            println!(
                " {} -> {}: {:.3}",
                state_label,
                symbol_label,
                hmm.emission_probabilities[[s, o]]
            );
        }
    }
}
/// Prints each parameter of `true_hmm` next to the corresponding parameter of
/// `learned_hmm` with their absolute difference. Then prints the mean absolute error
/// over all reported parameters.
///
/// Only the first 2 states and 3 observation symbols are compared (the shape of the
/// weather model). Note that Baum-Welch may recover the states in a permuted order, in
/// which case per-entry errors are large even for a good fit.
fn compare_parameters(true_hmm: &SimpleHMM, learned_hmm: &SimpleHMM) {
    let states = ["Sunny", "Rainy"];
    let observations = ["Walk", "Shop", "Clean"];

    // Each error is computed once. It is printed and added to the running total in the
    // same order as before (initial, transitions row-major, emissions row-major), so the
    // average is bit-for-bit unchanged.
    let mut total_error = 0.0_f64;
    let mut count = 0_usize;
    let mut report = |label: &str, true_val: f64, learned_val: f64| {
        let error = (true_val - learned_val).abs();
        println!(
            " {}: true={:.3}, learned={:.3}, error={:.3}",
            label, true_val, learned_val, error
        );
        total_error += error;
        count += 1;
    };

    println!("Parameter Errors (Absolute Difference):\n");

    println!("Initial Distribution:");
    for (i, state) in states.iter().enumerate() {
        report(
            &format!("P({})", state),
            true_hmm.initial_distribution[i],
            learned_hmm.initial_distribution[i],
        );
    }

    println!("\nTransition Probabilities:");
    for (i, from) in states.iter().enumerate() {
        for (j, to) in states.iter().enumerate() {
            report(
                &format!("{} -> {}", from, to),
                true_hmm.transition_probabilities[[i, j]],
                learned_hmm.transition_probabilities[[i, j]],
            );
        }
    }

    println!("\nEmission Probabilities:");
    for (i, state) in states.iter().enumerate() {
        for (j, obs) in observations.iter().enumerate() {
            report(
                &format!("{} -> {}", state, obs),
                true_hmm.emission_probabilities[[i, j]],
                learned_hmm.emission_probabilities[[i, j]],
            );
        }
    }

    let avg_error = total_error / count as f64;
    println!("\nAverage absolute error: {:.4}", avg_error);
}
/// Scores a fixed observation sequence under both models and prints each log-likelihood.
/// It also prints their difference, `ln P_true - ln P_learned`.
///
/// A difference near zero means the learned model assigns this sequence about the same
/// probability as the true model.
fn test_model_predictions(true_hmm: &SimpleHMM, learned_hmm: &SimpleHMM) {
    let test_sequence: [usize; 6] = [0, 0, 1, 2, 2, 0];
    println!("Test sequence: {:?}", test_sequence);
    println!("(0=Walk, 1=Shop, 2=Clean)\n");
    println!("Comparing likelihood under both models:");

    let true_log_likelihood = compute_sequence_likelihood(true_hmm, &test_sequence).ln();
    let learned_log_likelihood = compute_sequence_likelihood(learned_hmm, &test_sequence).ln();

    println!(
        " True model log-likelihood: {:.4}",
        true_log_likelihood
    );
    println!(
        " Learned model log-likelihood: {:.4}",
        learned_log_likelihood
    );
    // Subtract in log space (the log likelihood ratio). Taking `ln` of the raw-likelihood
    // difference is NaN whenever the learned model scores higher than the true one.
    println!(
        " Difference: {:.4}",
        true_log_likelihood - learned_log_likelihood
    );
}
/// Computes the likelihood `P(sequence | hmm)` with the forward algorithm.
///
/// `alpha[s]` is the joint probability of the observations processed so far and of being
/// in state `s` at the latest step. Summing `alpha` after the last observation gives the
/// sequence likelihood.
///
/// An empty sequence has likelihood `1.0` (observing nothing is certain); previously it
/// panicked on `sequence[0]`. Probabilities are multiplied without rescaling, so very long
/// sequences can underflow to `0.0`.
///
/// # Panics
/// If an observation index is out of range for `hmm.emission_probabilities` (ndarray
/// indexing is bounds-checked).
fn compute_sequence_likelihood(hmm: &SimpleHMM, sequence: &[usize]) -> f64 {
    let (first, rest) = match sequence.split_first() {
        Some((&first, rest)) => (first, rest),
        None => return 1.0,
    };
    let num_states = hmm.num_states;

    // Initialisation: alpha_1(s) = pi(s) * b_s(o_1).
    let mut alpha: Vec<f64> = (0..num_states)
        .map(|s| hmm.initial_distribution[s] * hmm.emission_probabilities[[s, first]])
        .collect();

    // Induction: alpha_t(s2) = (sum_{s1} alpha_{t-1}(s1) * a(s1, s2)) * b_{s2}(o_t).
    for &obs in rest {
        alpha = (0..num_states)
            .map(|s2| {
                let inflow: f64 = alpha
                    .iter()
                    .enumerate()
                    .map(|(s1, &a)| a * hmm.transition_probabilities[[s1, s2]])
                    .sum();
                inflow * hmm.emission_probabilities[[s2, obs]]
            })
            .collect();
    }

    alpha.iter().sum()
}