use crate::{Factor, FactorGraph, PgmError, Result};
use scirs2_core::ndarray::{Array1, Array2};
/// A scoring feature that a `LinearChainCRF` evaluates for candidate labels.
///
/// Implementors must be `Send + Sync`.
pub trait FeatureFunction: Send + Sync {
/// Returns the feature value for assigning `curr_label` at `position` of
/// `input_sequence`.
///
/// Within this file the CRF passes `Some(prev)` as `prev_label` when it builds
/// transition scores and `None` when it builds per-position (emission) scores.
fn compute(
&self,
prev_label: Option<usize>,
curr_label: usize,
input_sequence: &[usize],
position: usize,
) -> f64;
/// Returns the name of this feature.
fn name(&self) -> &str;
}
/// Linear-chain conditional random field over a fixed number of label states.
pub struct LinearChainCRF {
// Number of label states; sizes every score table built by the methods.
num_states: usize,
// Registered features, each paired with the weight its value is multiplied by.
features: Vec<(Box<dyn FeatureFunction>, f64)>,
// Optional `num_states x num_states` matrix indexed `[prev_state, curr_state]`;
// when present it is used in place of feature-derived transition scores.
transition_weights: Option<Array2<f64>>,
// Optional matrix with `num_states` rows, indexed `[state, observation]`.
emission_weights: Option<Array2<f64>>,
}
impl LinearChainCRF {
/// Creates a CRF over `num_states` labels with no features registered and no
/// transition or emission weight matrices set.
pub fn new(num_states: usize) -> Self {
    let features = Vec::new();
    Self {
        transition_weights: None,
        emission_weights: None,
        features,
        num_states,
    }
}
/// Registers `feature`; its value is multiplied by `weight` whenever scores
/// are computed.
pub fn add_feature(&mut self, feature: Box<dyn FeatureFunction>, weight: f64) {
    let weighted_feature = (feature, weight);
    self.features.push(weighted_feature);
}
/// Installs an explicit transition score matrix.
///
/// # Errors
/// Returns `PgmError::DimensionMismatch` unless `weights` is
/// `num_states x num_states`.
pub fn set_transition_weights(&mut self, weights: Array2<f64>) -> Result<()> {
    let n = self.num_states;
    if weights.shape() == [n, n] {
        self.transition_weights = Some(weights);
        return Ok(());
    }
    Err(PgmError::DimensionMismatch {
        expected: vec![n, n],
        got: weights.shape().to_vec(),
    })
}
/// Installs an explicit emission score matrix indexed `[state, observation]`.
///
/// Only the row count is validated; any number of columns is accepted.
///
/// # Errors
/// Returns `PgmError::DimensionMismatch` when the row count differs from
/// `num_states`.
pub fn set_emission_weights(&mut self, weights: Array2<f64>) -> Result<()> {
    let (rows, cols) = weights.dim();
    if rows == self.num_states {
        self.emission_weights = Some(weights);
        Ok(())
    } else {
        Err(PgmError::DimensionMismatch {
            expected: vec![self.num_states, cols],
            got: vec![rows, cols],
        })
    }
}
/// Builds the `num_states x num_states` table whose `[prev, curr]` entry is the
/// weighted sum of every registered feature evaluated with `Some(prev)` as the
/// previous label and `curr` as the current label at `position`.
fn compute_feature_scores(&self, input_sequence: &[usize], position: usize) -> Array2<f64> {
    let n = self.num_states;
    let mut table = Array2::zeros((n, n));
    for from in 0..n {
        for to in 0..n {
            // Accumulate in registration order, starting from 0.0.
            let total: f64 = self.features.iter().fold(0.0, |acc, (feature, weight)| {
                acc + weight * feature.compute(Some(from), to, input_sequence, position)
            });
            table[[from, to]] = total;
        }
    }
    table
}
fn compute_emission_scores(&self, input_sequence: &[usize], position: usize) -> Array1<f64> {
let mut scores = Array1::zeros(self.num_states);
for state in 0..self.num_states {
let mut score = 0.0;
for (feature, weight) in &self.features {
let feat_val = feature.compute(None, state, input_sequence, position);
score += weight * feat_val;
}
if let Some(ref emission_weights) = self.emission_weights {
if position < input_sequence.len() {
let obs = input_sequence[position];
if obs < emission_weights.shape()[1] {
score += emission_weights[[state, obs]];
}
}
}
scores[state] = score;
}
scores
}
pub fn viterbi(&self, input_sequence: &[usize]) -> Result<(Vec<usize>, f64)> {
if input_sequence.is_empty() {
return Err(PgmError::InvalidGraph("Empty input sequence".to_string()));
}
let seq_len = input_sequence.len();
let mut viterbi_table = Array2::zeros((seq_len, self.num_states));
let mut backpointers = Array2::zeros((seq_len, self.num_states));
let emission_scores = self.compute_emission_scores(input_sequence, 0);
for state in 0..self.num_states {
viterbi_table[[0, state]] = emission_scores[state];
}
for t in 1..seq_len {
let emission_scores = self.compute_emission_scores(input_sequence, t);
let transition_scores = if let Some(ref weights) = self.transition_weights {
weights.clone()
} else {
self.compute_feature_scores(input_sequence, t)
};
for curr_state in 0..self.num_states {
let mut max_score = f64::NEG_INFINITY;
let mut best_prev_state = 0;
for prev_state in 0..self.num_states {
let score = viterbi_table[[t - 1, prev_state]]
+ transition_scores[[prev_state, curr_state]]
+ emission_scores[curr_state];
if score > max_score {
max_score = score;
best_prev_state = prev_state;
}
}
viterbi_table[[t, curr_state]] = max_score;
backpointers[[t, curr_state]] = best_prev_state as f64;
}
}
let mut best_final_state = 0;
let mut best_final_score = f64::NEG_INFINITY;
for state in 0..self.num_states {
let score = viterbi_table[[seq_len - 1, state]];
if score > best_final_score {
best_final_score = score;
best_final_state = state;
}
}
let mut path = vec![0; seq_len];
path[seq_len - 1] = best_final_state;
for t in (1..seq_len).rev() {
path[t - 1] = backpointers[[t, path[t]]] as usize;
}
Ok((path, best_final_score))
}
/// Runs the forward pass and returns the per-position normalised forward
/// messages as a `(sequence length, num_states)` array.
///
/// Every row is rescaled to sum to one whenever its sum is positive. Before
/// exponentiating, the largest relevant score of the step is subtracted so
/// that large scores do not overflow to infinity (which would turn the row
/// into `NaN` after normalisation); the per-row normalisation cancels this
/// shift, so the returned values are otherwise unchanged.
///
/// # Errors
/// Returns `PgmError::InvalidGraph` when `input_sequence` is empty.
pub fn forward(&self, input_sequence: &[usize]) -> Result<Array2<f64>> {
    if input_sequence.is_empty() {
        return Err(PgmError::InvalidGraph("Empty input sequence".to_string()));
    }
    let seq_len = input_sequence.len();
    let mut alpha = Array2::<f64>::zeros((seq_len, self.num_states));

    let emission_scores = self.compute_emission_scores(input_sequence, 0);
    let mut shift = emission_scores
        .iter()
        .copied()
        .fold(f64::NEG_INFINITY, f64::max);
    if !shift.is_finite() {
        // No usable maximum (no states, or infinite scores): do not shift.
        shift = 0.0;
    }
    for state in 0..self.num_states {
        alpha[[0, state]] = (emission_scores[state] - shift).exp();
    }
    let init_sum: f64 = alpha.row(0).sum();
    if init_sum > 0.0 {
        for state in 0..self.num_states {
            alpha[[0, state]] /= init_sum;
        }
    }

    for t in 1..seq_len {
        let emission_scores = self.compute_emission_scores(input_sequence, t);
        // Borrow the explicit matrix instead of cloning it at every step.
        let computed_scores;
        let transition_scores: &Array2<f64> = match self.transition_weights.as_ref() {
            Some(weights) => weights,
            None => {
                computed_scores = self.compute_feature_scores(input_sequence, t);
                &computed_scores
            }
        };

        // Largest exponent among predecessors that carry mass; subtracting it
        // bounds every exponential computed below by 1.
        let mut shift = f64::NEG_INFINITY;
        for prev_state in 0..self.num_states {
            if alpha[[t - 1, prev_state]] != 0.0 {
                for curr_state in 0..self.num_states {
                    shift = shift.max(
                        transition_scores[[prev_state, curr_state]]
                            + emission_scores[curr_state],
                    );
                }
            }
        }
        if !shift.is_finite() {
            shift = 0.0;
        }

        for curr_state in 0..self.num_states {
            let mut sum = 0.0;
            for prev_state in 0..self.num_states {
                let prev_mass = alpha[[t - 1, prev_state]];
                // Zero-mass predecessors contribute nothing; skipping them
                // avoids `0 * inf` when their own scores are very large.
                if prev_mass != 0.0 {
                    sum += prev_mass
                        * (transition_scores[[prev_state, curr_state]]
                            + emission_scores[curr_state]
                            - shift)
                            .exp();
                }
            }
            alpha[[t, curr_state]] = sum;
        }
        let row_sum: f64 = alpha.row(t).sum();
        if row_sum > 0.0 {
            for state in 0..self.num_states {
                alpha[[t, state]] /= row_sum;
            }
        }
    }
    Ok(alpha)
}
/// Runs the backward pass and returns the per-position normalised backward
/// messages as a `(sequence length, num_states)` array.
///
/// The last row is all ones; every earlier row is rescaled to sum to one
/// whenever its sum is positive. Before exponentiating, the largest relevant
/// score of the step is subtracted so that large scores do not overflow to
/// infinity; the per-row normalisation cancels this shift.
///
/// # Errors
/// Returns `PgmError::InvalidGraph` when `input_sequence` is empty.
pub fn backward(&self, input_sequence: &[usize]) -> Result<Array2<f64>> {
    if input_sequence.is_empty() {
        return Err(PgmError::InvalidGraph("Empty input sequence".to_string()));
    }
    let seq_len = input_sequence.len();
    let mut beta = Array2::<f64>::zeros((seq_len, self.num_states));
    for state in 0..self.num_states {
        beta[[seq_len - 1, state]] = 1.0;
    }

    for t in (0..seq_len - 1).rev() {
        let emission_scores = self.compute_emission_scores(input_sequence, t + 1);
        // Borrow the explicit matrix instead of cloning it at every step.
        let computed_scores;
        let transition_scores: &Array2<f64> = match self.transition_weights.as_ref() {
            Some(weights) => weights,
            None => {
                computed_scores = self.compute_feature_scores(input_sequence, t + 1);
                &computed_scores
            }
        };

        // Largest exponent among successors that carry mass; subtracting it
        // bounds every exponential computed below by 1.
        let mut shift = f64::NEG_INFINITY;
        for next_state in 0..self.num_states {
            if beta[[t + 1, next_state]] != 0.0 {
                for curr_state in 0..self.num_states {
                    shift = shift.max(
                        transition_scores[[curr_state, next_state]]
                            + emission_scores[next_state],
                    );
                }
            }
        }
        if !shift.is_finite() {
            // No usable maximum (no states, or infinite scores): do not shift.
            shift = 0.0;
        }

        for curr_state in 0..self.num_states {
            let mut sum = 0.0;
            for next_state in 0..self.num_states {
                let next_mass = beta[[t + 1, next_state]];
                // Zero-mass successors contribute nothing; skipping them
                // avoids `0 * inf` when their own scores are very large.
                if next_mass != 0.0 {
                    sum += next_mass
                        * (transition_scores[[curr_state, next_state]]
                            + emission_scores[next_state]
                            - shift)
                            .exp();
                }
            }
            beta[[t, curr_state]] = sum;
        }
        let row_sum: f64 = beta.row(t).sum();
        if row_sum > 0.0 {
            for state in 0..self.num_states {
                beta[[t, state]] /= row_sum;
            }
        }
    }
    Ok(beta)
}
/// Computes per-position label marginals from the forward and backward passes.
///
/// Each entry is the product of the forward and backward message for that
/// position and state; each row is then rescaled to sum to one whenever its
/// sum is positive.
///
/// # Errors
/// Propagates any error returned by `forward` or `backward`.
pub fn marginals(&self, input_sequence: &[usize]) -> Result<Array2<f64>> {
    let alpha = self.forward(input_sequence)?;
    let beta = self.backward(input_sequence)?;
    let steps = input_sequence.len();
    let n = self.num_states;

    // Element-wise product of the two message tables.
    let mut posterior =
        Array2::from_shape_fn((steps, n), |(t, state)| alpha[[t, state]] * beta[[t, state]]);

    for t in 0..steps {
        let total: f64 = posterior.row(t).sum();
        if total > 0.0 {
            for state in 0..n {
                posterior[[t, state]] /= total;
            }
        }
    }
    Ok(posterior)
}
/// Converts the CRF, unrolled over `input_sequence`, into a `FactorGraph`.
///
/// The graph gets one variable `y_t` per position, one unary factor
/// `emission_t` per position holding the exponentiated per-state scores, and
/// one pairwise factor `transition_t` for every `t >= 1` over
/// `(y_{t-1}, y_t)` holding the exponentiated transition scores.
///
/// # Errors
/// Propagates any error from `Factor::new` or `FactorGraph::add_factor`.
pub fn to_factor_graph(&self, input_sequence: &[usize]) -> Result<FactorGraph> {
    let label_var = |t: usize| format!("y_{}", t);
    let steps = input_sequence.len();
    let mut graph = FactorGraph::new();

    for t in 0..steps {
        graph.add_variable_with_card(label_var(t), "Label".to_string(), self.num_states);
    }

    for t in 0..steps {
        let potentials = self
            .compute_emission_scores(input_sequence, t)
            .mapv(f64::exp);
        let unary = Factor::new(
            format!("emission_{}", t),
            vec![label_var(t)],
            potentials.into_dyn(),
        )?;
        graph.add_factor(unary)?;
    }

    for t in 1..steps {
        // Explicit weights take precedence over feature-derived scores.
        let potentials = match self.transition_weights.as_ref() {
            Some(weights) => weights.mapv(f64::exp),
            None => self
                .compute_feature_scores(input_sequence, t)
                .mapv(f64::exp),
        };
        let pairwise = Factor::new(
            format!("transition_{}", t),
            vec![label_var(t - 1), label_var(t)],
            potentials.into_dyn(),
        )?;
        graph.add_factor(pairwise)?;
    }
    Ok(graph)
}
}
/// A named feature carrying no parameters other than its name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IdentityFeature {
    // Name reported for this feature.
    name: String,
}

impl IdentityFeature {
    /// Creates an identity feature called `name`.
    pub fn new(name: String) -> Self {
        Self { name }
    }
}
impl FeatureFunction for IdentityFeature {
    /// Returns `1.0` regardless of the labels, the sequence or the position.
    fn compute(
        &self,
        _previous: Option<usize>,
        _current: usize,
        _observations: &[usize],
        _index: usize,
    ) -> f64 {
        const ALWAYS_ACTIVE: f64 = 1.0;
        ALWAYS_ACTIVE
    }

    /// Returns the name given at construction.
    fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Indicator feature for one specific `from_state -> to_state` transition.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TransitionFeature {
    // Previous label the transition must start from.
    from_state: usize,
    // Current label the transition must end in.
    to_state: usize,
    // Name of the form `transition_<from>_<to>`.
    name: String,
}

impl TransitionFeature {
    /// Creates the feature for the transition `from_state -> to_state`, named
    /// `transition_<from_state>_<to_state>`.
    pub fn new(from_state: usize, to_state: usize) -> Self {
        Self {
            from_state,
            to_state,
            name: format!("transition_{}_{}", from_state, to_state),
        }
    }
}
impl FeatureFunction for TransitionFeature {
    /// Returns `1.0` when a previous label is supplied and the pair
    /// `(previous, current)` equals `(from_state, to_state)`; otherwise `0.0`.
    fn compute(
        &self,
        prev_label: Option<usize>,
        curr_label: usize,
        _observations: &[usize],
        _index: usize,
    ) -> f64 {
        let fires = prev_label == Some(self.from_state) && curr_label == self.to_state;
        if fires {
            1.0
        } else {
            0.0
        }
    }

    /// Returns the generated `transition_<from>_<to>` name.
    fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Indicator feature for one specific `(state, observation)` pairing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EmissionFeature {
    // Label the feature is tied to.
    state: usize,
    // Observed symbol the feature is tied to.
    observation: usize,
    // Name of the form `emission_<state>_<observation>`.
    name: String,
}

impl EmissionFeature {
    /// Creates the feature for `state` emitting `observation`, named
    /// `emission_<state>_<observation>`.
    pub fn new(state: usize, observation: usize) -> Self {
        Self {
            state,
            observation,
            name: format!("emission_{}_{}", state, observation),
        }
    }
}
impl FeatureFunction for EmissionFeature {
    /// Returns `1.0` when `curr_label` equals the configured state and the
    /// symbol at `position` exists and equals the configured observation;
    /// otherwise `0.0` (including when `position` is past the end).
    fn compute(
        &self,
        _previous: Option<usize>,
        curr_label: usize,
        input_sequence: &[usize],
        position: usize,
    ) -> f64 {
        match input_sequence.get(position) {
            Some(&observed) if curr_label == self.state && observed == self.observation => 1.0,
            _ => 0.0,
        }
    }

    /// Returns the generated `emission_<state>_<observation>` name.
    fn name(&self) -> &str {
        self.name.as_str()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    /// Builds a two-state CRF whose transition matrix holds `values` in
    /// row-major order.
    fn two_state_crf(values: [f64; 4]) -> LinearChainCRF {
        let mut crf = LinearChainCRF::new(2);
        let weights = Array::from_shape_vec((2, 2), values.to_vec()).expect("unwrap");
        crf.set_transition_weights(weights).expect("unwrap");
        crf
    }

    #[test]
    fn test_linear_chain_crf_creation() {
        let crf = LinearChainCRF::new(3);
        assert_eq!(crf.num_states, 3);
        assert_eq!(crf.features.len(), 0);
    }

    #[test]
    fn test_add_feature() {
        let mut crf = LinearChainCRF::new(2);
        crf.add_feature(Box::new(IdentityFeature::new("test".to_string())), 1.0);
        assert_eq!(crf.features.len(), 1);
    }

    #[test]
    fn test_viterbi_simple() {
        let crf = two_state_crf([1.0, -1.0, -1.0, 1.0]);
        let (path, _score) = crf.viterbi(&[0, 0, 0]).expect("unwrap");
        assert_eq!(path.len(), 3);
    }

    #[test]
    fn test_forward_backward() {
        let crf = two_state_crf([0.0; 4]);
        let observations = [0, 1];
        let alpha = crf.forward(&observations).expect("unwrap");
        assert_eq!(alpha.shape(), &[2, 2]);
        let beta = crf.backward(&observations).expect("unwrap");
        assert_eq!(beta.shape(), &[2, 2]);
        // Every forward row is normalised.
        for t in 0..2 {
            assert_abs_diff_eq!(alpha.row(t).sum(), 1.0, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_marginals() {
        let crf = two_state_crf([0.0; 4]);
        let marginals = crf.marginals(&[0, 1]).expect("unwrap");
        assert_eq!(marginals.shape(), &[2, 2]);
        // Every marginal row is a probability distribution.
        for t in 0..2 {
            assert_abs_diff_eq!(marginals.row(t).sum(), 1.0, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_transition_feature() {
        let feature = TransitionFeature::new(0, 1);
        let matching = feature.compute(Some(0), 1, &[0, 1], 1);
        assert_abs_diff_eq!(matching, 1.0, epsilon = 1e-10);
        let wrong_target = feature.compute(Some(0), 0, &[0, 1], 1);
        assert_abs_diff_eq!(wrong_target, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_emission_feature() {
        let feature = EmissionFeature::new(0, 5);
        let matching = feature.compute(None, 0, &[5, 3], 0);
        assert_abs_diff_eq!(matching, 1.0, epsilon = 1e-10);
        let wrong_observation = feature.compute(None, 0, &[3, 5], 0);
        assert_abs_diff_eq!(wrong_observation, 0.0, epsilon = 1e-10);
        let wrong_state = feature.compute(None, 1, &[5, 3], 0);
        assert_abs_diff_eq!(wrong_state, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_to_factor_graph() {
        let crf = two_state_crf([0.5; 4]);
        let graph = crf.to_factor_graph(&[0, 1, 0]).expect("unwrap");
        // Three positions: 3 variables, 3 emission factors + 2 transition factors.
        assert_eq!(graph.num_variables(), 3);
        assert_eq!(graph.num_factors(), 5);
    }
}