use std::hash::Hash;
use crate::composition::{compose, materialize};
use crate::semiring::{NumericalWeight, PowerWeight};
use crate::wfst::{MutableWfst, StateId, VectorWfst, Wfst};
use super::push::{push_weights, PushConfig};
use super::sample::{sample_path, SampleConfig, SampleError, SampledPath};
/// Tunable settings for an [`Rrwm`] learner.
///
/// Values are normally assembled through the chained setter methods, each of
/// which consumes the configuration and returns the updated copy.
#[derive(Clone, Debug)]
pub struct RrwmConfig {
    /// Parameter handed to `PowerWeight::one_with_eta` when the learner builds
    /// its initial transducer.
    pub eta: f64,
    /// Learning-rate setting. It is stored here but not read by any code in
    /// this file.
    pub learning_rate: f64,
    /// Upper limit on the number of rounds `Rrwm::observe` will accept.
    pub max_rounds: usize,
    /// When `true`, the learner allocates and maintains an [`RrwmStatistics`].
    pub track_statistics: bool,
    /// Optional base seed used by `Rrwm::predict` when sampling.
    pub seed: Option<u64>,
}

impl Default for RrwmConfig {
    /// Returns the baseline configuration: `eta` and `learning_rate` of `1.0`,
    /// a cap of 100 000 rounds, statistics disabled, and no seed.
    fn default() -> Self {
        Self {
            eta: 1.0,
            learning_rate: 1.0,
            max_rounds: 100_000,
            track_statistics: false,
            seed: None,
        }
    }
}

impl RrwmConfig {
    /// Creates a configuration holding the default values.
    pub fn new() -> Self {
        Default::default()
    }

    /// Replaces `eta` and returns the updated configuration.
    pub fn eta(self, eta: f64) -> Self {
        Self { eta, ..self }
    }

    /// Replaces `learning_rate` and returns the updated configuration.
    pub fn learning_rate(self, rate: f64) -> Self {
        Self {
            learning_rate: rate,
            ..self
        }
    }

    /// Replaces `max_rounds` and returns the updated configuration.
    pub fn max_rounds(self, rounds: usize) -> Self {
        Self {
            max_rounds: rounds,
            ..self
        }
    }

    /// Switches statistics tracking on and returns the updated configuration.
    pub fn with_statistics(self) -> Self {
        Self {
            track_statistics: true,
            ..self
        }
    }

    /// Stores `seed` as the base seed and returns the updated configuration.
    pub fn seed(self, seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }
}
/// Running totals collected by the learner when statistics tracking is on.
#[derive(Clone, Debug, Default)]
pub struct RrwmStatistics {
    /// Sum of every loss passed to `update`.
    pub total_loss: f64,
    /// Number of times `update` has been called.
    pub rounds: usize,
    /// `total_loss / rounds`, refreshed on every `update`.
    pub average_loss: f64,
    /// State count reported with the most recent `update` (overwritten each
    /// time, not accumulated).
    pub cumulative_states: usize,
    /// Per-round losses. `update` does not touch this; callers append to it.
    pub loss_history: Vec<f64>,
}

impl RrwmStatistics {
    /// Folds one round into the totals.
    ///
    /// `loss` is added to the running sum, the round counter advances by one,
    /// the average is recomputed from the new totals, and `num_states`
    /// replaces the stored state count.
    fn update(&mut self, loss: f64, num_states: usize) {
        let new_total = self.total_loss + loss;
        let new_rounds = self.rounds + 1;

        self.total_loss = new_total;
        self.rounds = new_rounds;
        self.average_loss = new_total / new_rounds as f64;
        self.cumulative_states = num_states;
    }
}
/// Failure modes surfaced by the RRWM learner.
#[derive(Clone, Debug)]
pub enum RrwmError {
    /// `Rrwm::observe` was called after the configured round limit was reached.
    MaxRoundsExceeded,
    /// Weight pushing reported an error; the payload is a textual rendering of it.
    PushFailed(String),
    /// Path sampling failed; wraps the sampler's own error.
    SampleFailed(SampleError),
    /// The materialized composition turned out to be empty.
    EmptyComposition,
    /// A configuration problem, described by the contained message.
    ConfigError(String),
}

impl std::fmt::Display for RrwmError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Variants without a payload are fixed strings.
            Self::MaxRoundsExceeded => f.write_str("Maximum rounds exceeded"),
            Self::EmptyComposition => f.write_str("Composition produced empty result"),
            // Variants with a payload append it after a fixed prefix.
            Self::PushFailed(detail) => write!(f, "Weight pushing failed: {}", detail),
            Self::SampleFailed(cause) => write!(f, "Sampling failed: {}", cause),
            Self::ConfigError(detail) => write!(f, "Configuration error: {}", detail),
        }
    }
}

impl std::error::Error for RrwmError {}

impl From<SampleError> for RrwmError {
    /// Wraps a sampler error so it can be propagated with `?`.
    fn from(cause: SampleError) -> Self {
        RrwmError::SampleFailed(cause)
    }
}
/// Online learner that carries a cumulative weighted transducer from round to
/// round.
///
/// Every successful [`Rrwm::observe`] call composes the cumulative transducer
/// with the supplied loss transducer, pushes weights on the result, and keeps
/// that result as the new cumulative transducer.
pub struct Rrwm<L>
where
    L: Clone + Eq + Hash + Send + Sync,
{
    /// Settings supplied at construction.
    config: RrwmConfig,
    /// Transducer produced by the most recent round (or the initial
    /// single-state transducer when nothing has been observed).
    cumulative: VectorWfst<L, PowerWeight>,
    /// Number of successful `observe` calls since construction or `reset`.
    round: usize,
    /// Running statistics; populated at construction only when
    /// `config.track_statistics` is set.
    statistics: Option<RrwmStatistics>,
}

impl<L> Rrwm<L>
where
    L: Clone + Eq + Hash + Send + Sync + 'static,
{
    /// Builds the starting transducer: one state that is both the start state
    /// and final, with final weight `PowerWeight::one_with_eta(eta)`.
    fn pristine_cumulative(eta: f64) -> VectorWfst<L, PowerWeight> {
        let mut wfst: VectorWfst<L, PowerWeight> = VectorWfst::new();
        let origin = wfst.add_state();
        wfst.set_start(origin);
        wfst.set_final(origin, PowerWeight::one_with_eta(eta));
        wfst
    }

    /// Creates a learner at round zero with the single-state starting
    /// transducer. Statistics are allocated only if the configuration asks
    /// for them.
    pub fn new(config: RrwmConfig) -> Self {
        let cumulative = Self::pristine_cumulative(config.eta);
        let statistics = config.track_statistics.then(RrwmStatistics::default);
        Self {
            config,
            cumulative,
            round: 0,
            statistics,
        }
    }

    /// Number of rounds observed so far.
    pub fn round(&self) -> usize {
        self.round
    }

    /// The `eta` value from the configuration.
    pub fn eta(&self) -> f64 {
        self.config.eta
    }

    /// Borrow of the current cumulative transducer.
    pub fn cumulative_weights(&self) -> &VectorWfst<L, PowerWeight> {
        &self.cumulative
    }

    /// Collected statistics, or `None` when tracking is disabled.
    pub fn statistics(&self) -> Option<&RrwmStatistics> {
        self.statistics.as_ref()
    }

    /// Incorporates one round's loss transducer and returns the round's loss.
    ///
    /// The loss is the sum of `numerical_value()` over the final weights of
    /// the pushed composition.
    ///
    /// # Errors
    ///
    /// * [`RrwmError::MaxRoundsExceeded`] if `max_rounds` rounds were already
    ///   observed.
    /// * [`RrwmError::EmptyComposition`] if the materialized composition is
    ///   empty.
    /// * [`RrwmError::PushFailed`] if weight pushing fails.
    ///
    /// On any error the learner's state is left untouched.
    pub fn observe<T>(&mut self, loss_transducer: T) -> Result<f64, RrwmError>
    where
        T: Wfst<L, PowerWeight>,
    {
        if self.round >= self.config.max_rounds {
            return Err(RrwmError::MaxRoundsExceeded);
        }

        // Work on a copy so a failure below cannot corrupt `self.cumulative`.
        let mut updated: VectorWfst<L, PowerWeight> =
            materialize(compose(self.cumulative.clone(), loss_transducer));
        if updated.is_empty() {
            return Err(RrwmError::EmptyComposition);
        }

        if let Err(cause) = push_weights(&mut updated, PushConfig::backward()) {
            return Err(RrwmError::PushFailed(format!("{:?}", cause)));
        }

        let loss = self.extract_loss(&updated);
        self.cumulative = updated;
        self.round += 1;

        let state_count = self.cumulative.num_states();
        if let Some(stats) = self.statistics.as_mut() {
            stats.update(loss, state_count);
            // NOTE(review): `statistics` is only allocated when this flag is
            // set, so the check looks redundant; kept to preserve behaviour.
            if self.config.track_statistics {
                stats.loss_history.push(loss);
            }
        }

        Ok(loss)
    }

    /// Samples a path from the cumulative transducer.
    ///
    /// The sampler seed is the configured seed plus the current round
    /// (wrapping on overflow). Without a configured seed the round number
    /// alone is used.
    /// NOTE(review): this means sampling is explicitly seeded even when
    /// `seed` is `None` — confirm that is intended.
    ///
    /// # Errors
    ///
    /// Returns [`RrwmError::SampleFailed`] if the sampler reports an error.
    pub fn predict(&self) -> Result<SampledPath<L, PowerWeight>, RrwmError> {
        let round_offset = self.round as u64;
        let seed = match self.config.seed {
            Some(base) => base.wrapping_add(round_offset),
            None => round_offset,
        };
        let sample_config = SampleConfig::default().seed(seed);
        sample_path(&self.cumulative, sample_config).map_err(RrwmError::from)
    }

    /// Evaluates `2 * max_loss * sqrt(round * ln(num_experts))`.
    ///
    /// Returns `0.0` when no round has been observed or `num_experts` is zero.
    pub fn regret_bound(&self, max_loss: f64, num_experts: usize) -> f64 {
        if self.round == 0 || num_experts == 0 {
            return 0.0;
        }
        let rounds = self.round as f64;
        let log_experts = (num_experts as f64).ln();
        2.0 * max_loss * (rounds * log_experts).sqrt()
    }

    /// Returns the learner to its freshly constructed state: starting
    /// transducer, round zero, and (if tracked) zeroed statistics. The
    /// configuration is kept.
    pub fn reset(&mut self) {
        self.cumulative = Self::pristine_cumulative(self.config.eta);
        self.round = 0;
        if let Some(stats) = self.statistics.as_mut() {
            *stats = RrwmStatistics::default();
        }
    }

    /// Sums `numerical_value()` of the final weight of every final state in
    /// `wfst`, visiting states in increasing id order.
    fn extract_loss(&self, wfst: &VectorWfst<L, PowerWeight>) -> f64 {
        // Explicit fold from 0.0 keeps the accumulation order and the
        // empty-case result identical to a plain running sum.
        (0..wfst.num_states() as StateId)
            .filter(|&state| wfst.is_final(state))
            .fold(0.0_f64, |acc, state| {
                acc + wfst.final_weight(state).numerical_value()
            })
    }
}
pub struct RrwmBuilder<L>
where
L: Clone + Eq + Hash + Send + Sync,
{
config: RrwmConfig,
initial_experts: Vec<VectorWfst<L, PowerWeight>>,
}
impl<L> RrwmBuilder<L>
where
L: Clone + Eq + Hash + Send + Sync + 'static,
{
pub fn new() -> Self {
Self {
config: RrwmConfig::default(),
initial_experts: Vec::new(),
}
}
pub fn config(mut self, config: RrwmConfig) -> Self {
self.config = config;
self
}
pub fn eta(mut self, eta: f64) -> Self {
self.config.eta = eta;
self
}
pub fn add_expert(mut self, expert: VectorWfst<L, PowerWeight>) -> Self {
self.initial_experts.push(expert);
self
}
pub fn build(self) -> Rrwm<L> {
let mut rrwm = Rrwm::new(self.config);
for expert in self.initial_experts {
if rrwm.cumulative.num_states() == 1 {
rrwm.cumulative = expert;
}
}
rrwm
}
}
impl<L> Default for RrwmBuilder<L>
where
L: Clone + Eq + Hash + Send + Sync + 'static,
{
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wfst::MutableWfst;

    /// Builds a two-state transducer: a start state joined to a final state by
    /// a single `'a'`/`'a'` arc weighted `PowerWeight::from_probability(0.8, 1.0)`.
    fn make_simple_expert() -> VectorWfst<char, PowerWeight> {
        let mut expert = VectorWfst::new();
        let source = expert.add_state();
        let target = expert.add_state();
        expert.set_start(source);
        expert.set_final(target, PowerWeight::one_with_eta(1.0));
        let arc_weight = PowerWeight::from_probability(0.8, 1.0);
        expert.add_arc(source, Some('a'), Some('a'), target, arc_weight);
        expert
    }

    /// A fresh learner starts at round zero with a one-state transducer.
    #[test]
    fn test_rrwm_creation() {
        let learner = Rrwm::<char>::new(RrwmConfig::default());
        assert_eq!(learner.round(), 0);
        assert_eq!(learner.eta(), 1.0);
        assert_eq!(learner.cumulative_weights().num_states(), 1);
    }

    /// Every chained setter lands in the matching field.
    #[test]
    fn test_rrwm_config() {
        let settings = RrwmConfig::default()
            .eta(2.0)
            .learning_rate(0.5)
            .max_rounds(1000)
            .with_statistics()
            .seed(42);
        assert_eq!(settings.eta, 2.0);
        assert_eq!(settings.learning_rate, 0.5);
        assert_eq!(settings.max_rounds, 1000);
        assert!(settings.track_statistics);
        assert_eq!(settings.seed, Some(42));
    }

    /// The builder forwards its `eta` override to the learner.
    #[test]
    fn test_rrwm_builder() {
        let learner = RrwmBuilder::<char>::new()
            .eta(2.0)
            .add_expert(make_simple_expert())
            .build();
        assert_eq!(learner.eta(), 2.0);
    }

    /// The bound is zero before any round and positive but below 35 after
    /// 100 rounds with ten experts and unit loss.
    #[test]
    fn test_regret_bound() {
        let mut learner = Rrwm::<char>::new(RrwmConfig::default());
        assert_eq!(learner.regret_bound(1.0, 10), 0.0);

        learner.round = 100;
        let bound = learner.regret_bound(1.0, 10);
        assert!(bound > 0.0);
        assert!(bound < 35.0);
    }

    /// `reset` restores the round counter, the transducer, and the statistics.
    #[test]
    fn test_rrwm_reset() {
        let mut learner = Rrwm::<char>::new(RrwmConfig::default().with_statistics());
        learner.round = 10;
        learner.reset();

        assert_eq!(learner.round(), 0);
        assert_eq!(learner.cumulative_weights().num_states(), 1);
        let recorded_rounds = learner
            .statistics()
            .expect("algorithms/rrwm.rs: required value was None/Err")
            .rounds;
        assert_eq!(recorded_rounds, 0);
    }

    /// Enabling statistics yields a zeroed record on a fresh learner.
    #[test]
    fn test_rrwm_statistics() {
        let learner = Rrwm::<char>::new(RrwmConfig::default().with_statistics());
        let record = learner.statistics().expect("Statistics should be enabled");
        assert_eq!(record.rounds, 0);
        assert_eq!(record.total_loss, 0.0);
    }

    /// Sampling from the initial transducer succeeds.
    #[test]
    fn test_rrwm_predict_initial() {
        let learner = Rrwm::<char>::new(RrwmConfig::default().seed(42));
        let outcome = learner.predict();
        assert!(outcome.is_ok());
    }
}