use std::collections::VecDeque;
use std::hash::Hash;
use serde::{Serialize, Deserialize};
use anyhow::{Result, anyhow};
use nalgebra::{DMatrix, DVector};
use rand::Rng;
/// The four operating regimes modelled in this file.
///
/// Each variant maps to a fixed index (see `Regime::index`) and to a fixed
/// preset of behavioural weights (see `Regime::characteristics`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Regime {
/// Index 0. Preset with the highest risk tolerance and innovation bias.
Exploration,
/// Index 1. Preset with the highest quality focus and a low risk tolerance.
Exploitation,
/// Index 2. Preset with the lowest risk tolerance and the highest speed priority.
Crisis,
/// Index 3. Preset with mid-range values on every axis.
Transition,
}
impl Regime {
pub fn index(&self) -> usize {
match self {
Regime::Exploration => 0,
Regime::Exploitation => 1,
Regime::Crisis => 2,
Regime::Transition => 3,
}
}
pub fn from_index(idx: usize) -> Result<Self> {
match idx {
0 => Ok(Regime::Exploration),
1 => Ok(Regime::Exploitation),
2 => Ok(Regime::Crisis),
3 => Ok(Regime::Transition),
_ => Err(anyhow!("Invalid regime index: {}", idx)),
}
}
/// Returns the fixed preset of behavioural weights for this regime.
///
/// The five values are also used, in field order, as the mean of the
/// regime's emission distribution by `RegimeDetector::init_emission_params`.
pub fn characteristics(&self) -> RegimeCharacteristics {
    // Columns: risk tolerance, cost sensitivity, quality focus,
    //          speed priority, innovation bias.
    let (risk_tolerance, cost_sensitivity, quality_focus, speed_priority, innovation_bias) =
        match self {
            Regime::Exploration => (0.8, 0.3, 0.6, 0.4, 0.9),
            Regime::Exploitation => (0.2, 0.7, 0.9, 0.8, 0.2),
            Regime::Crisis => (0.1, 0.5, 0.5, 0.9, 0.1),
            Regime::Transition => (0.5, 0.5, 0.7, 0.6, 0.5),
        };
    RegimeCharacteristics {
        risk_tolerance,
        cost_sensitivity,
        quality_focus,
        speed_priority,
        innovation_bias,
    }
}
}
/// Five behavioural weights describing a regime.
///
/// The presets returned by `Regime::characteristics` all use values between
/// 0.0 and 1.0; nothing in this type enforces that range.
// NOTE(review): the per-field descriptions below are inferred from the field
// names only — confirm the intended semantics with the consumers of this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegimeCharacteristics {
/// Presumably how much risk is acceptable (higher = more tolerant).
pub risk_tolerance: f64,
/// Presumably how strongly cost is weighted.
pub cost_sensitivity: f64,
/// Presumably how strongly output quality is weighted.
pub quality_focus: f64,
/// Presumably how strongly speed is weighted.
pub speed_priority: f64,
/// Presumably the preference for novel options.
pub innovation_bias: f64,
}
/// Tracks a belief distribution over regimes and classifies observations.
///
/// The detector holds a transition matrix, one Gaussian emission model per
/// regime, and a bounded history of the raw observations passed to `update`.
pub struct RegimeDetector {
// Regime-to-regime transition weights; the built-in matrix has rows that sum to 1.
transition_matrix: DMatrix<f64>,
// One emission model per regime, stored in regime-index order.
emission_params: Vec<EmissionParams>,
// Current belief weight for each regime, indexed by regime index.
belief_state: DVector<f64>,
// Most recent observations passed to `update`, oldest first.
observation_history: VecDeque<Vec<f64>>,
// Upper bound on `observation_history`'s length.
max_history: usize,
// `is_transition_likely` reports true while the largest belief is below
// `1.0 - transition_threshold`.
transition_threshold: f64,
}
/// Parameters of one regime's Gaussian emission model.
#[derive(Debug, Clone)]
struct EmissionParams {
// Expected observation vector for the regime (five entries in the built-in model).
mean: DVector<f64>,
// Covariance of the observation vector (diagonal in the built-in model).
covariance: DMatrix<f64>,
}
impl RegimeDetector {
pub fn new(num_regimes: usize, transition_threshold: f64) -> Result<Self> {
let transition_matrix = Self::init_transition_matrix(num_regimes);
let emission_params = Self::init_emission_params(num_regimes);
let mut belief_state = DVector::from_element(num_regimes, 1.0 / num_regimes as f64);
belief_state[0] = 0.7;
belief_state[1] = 0.2;
belief_state[2] = 0.05;
belief_state[3] = 0.05;
Ok(Self {
transition_matrix,
emission_params,
belief_state,
observation_history: VecDeque::with_capacity(100),
max_history: 100,
transition_threshold,
})
}
/// Builds the default `n`×`n` transition matrix.
///
/// The top-left 4×4 block is filled from a fixed table whose rows each sum
/// to 1.0; any rows/columns beyond the fourth stay zero.
///
/// # Panics
///
/// Panics with an out-of-bounds index when `n < 4`.
fn init_transition_matrix(n: usize) -> DMatrix<f64> {
    // Row = source regime index, column = destination regime index.
    const DEFAULT_ROWS: [[f64; 4]; 4] = [
        [0.7, 0.2, 0.05, 0.05],
        [0.1, 0.8, 0.05, 0.05],
        [0.05, 0.05, 0.7, 0.2],
        [0.25, 0.25, 0.1, 0.4],
    ];

    let mut transitions = DMatrix::zeros(n, n);
    for (from, row) in DEFAULT_ROWS.iter().enumerate() {
        for (to, &probability) in row.iter().enumerate() {
            transitions[(from, to)] = probability;
        }
    }
    transitions
}
/// Builds one Gaussian emission model per regime index in `0..n`.
///
/// The mean is the regime's characteristic vector (risk tolerance, cost
/// sensitivity, quality focus, speed priority, innovation bias) and the
/// covariance is a 5×5 diagonal matrix with a regime-specific variance.
///
/// # Panics
///
/// Panics when `n > 4`, because `Regime::from_index` has no regime for
/// indices above 3.
fn init_emission_params(n: usize) -> Vec<EmissionParams> {
    (0..n)
        .map(|idx| {
            let regime = Regime::from_index(idx).unwrap();
            let profile = regime.characteristics();
            // Per-regime isotropic variance.
            let spread = match regime {
                Regime::Exploration => 0.1,
                Regime::Exploitation => 0.02,
                Regime::Crisis => 0.05,
                Regime::Transition => 0.08,
            };
            EmissionParams {
                mean: DVector::from_vec(vec![
                    profile.risk_tolerance,
                    profile.cost_sensitivity,
                    profile.quality_focus,
                    profile.speed_priority,
                    profile.innovation_bias,
                ]),
                covariance: DMatrix::from_diagonal(&DVector::from_element(5, spread)),
            }
        })
        .collect()
}
/// Classifies an observation as the regime with the highest
/// belief-weighted emission likelihood.
///
/// Only the first five entries of `observations` are used; extra entries are
/// ignored, mirroring how the observation history is averaged. The belief
/// state is read but not modified.
///
/// # Errors
///
/// Returns an error when fewer than five observations are supplied, when a
/// likelihood cannot be computed, or when no regime yields a comparable
/// (non-NaN) score.
pub fn detect_regime(&self, observations: &[f64]) -> Result<Regime> {
    const FEATURE_COUNT: usize = 5;
    if observations.len() < FEATURE_COUNT {
        return Err(anyhow!("Need at least 5 observations"));
    }
    // Truncate so the vector always matches the emission model's dimension;
    // a longer vector would otherwise panic in the subtraction.
    let obs_vector = DVector::from_vec(observations[..FEATURE_COUNT].to_vec());

    let mut best: Option<(usize, f64)> = None;
    for (idx, params) in self.emission_params.iter().enumerate() {
        let score = self.calculate_likelihood(&obs_vector, params)? * self.belief_state[idx];
        if score.is_nan() {
            // NaN cannot be ordered; skip it instead of panicking.
            continue;
        }
        // `>=` keeps the later regime on ties, as `Iterator::max_by` did.
        if best.map_or(true, |(_, top)| score >= top) {
            best = Some((idx, score));
        }
    }

    let (best_idx, _) =
        best.ok_or_else(|| anyhow!("No regime produced a comparable likelihood"))?;
    Regime::from_index(best_idx)
}
fn calculate_likelihood(&self, obs: &DVector<f64>, params: &EmissionParams) -> Result<f64> {
let diff = obs - ¶ms.mean;
let inv_cov = params.covariance.clone().try_inverse()
.ok_or_else(|| anyhow!("Covariance matrix not invertible"))?;
let exponent = -0.5 * diff.transpose() * &inv_cov * &diff;
let det = params.covariance.determinant();
let norm = (2.0 * std::f64::consts::PI).powi(obs.len() as i32) * det.abs();
Ok((exponent[(0, 0)] / norm.sqrt()).exp())
}
pub fn update(&mut self, observations: Vec<f64>) -> Result<()> {
self.observation_history.push_back(observations.clone());
if self.observation_history.len() > self.max_history {
self.observation_history.pop_front();
}
let obs_vector = DVector::from_vec(observations);
let mut new_belief = DVector::zeros(self.belief_state.len());
for i in 0..self.belief_state.len() {
let mut prediction = 0.0;
for j in 0..self.belief_state.len() {
prediction += self.transition_matrix[(i, j)] * self.belief_state[j];
}
let likelihood = self.calculate_likelihood(&obs_vector, &self.emission_params[i])?;
new_belief[i] = likelihood * prediction;
}
let sum: f64 = new_belief.iter().sum();
if sum > 0.0 {
self.belief_state = new_belief / sum;
}
Ok(())
}
/// Produces a serialisable snapshot of the detector.
///
/// `current_regime` is the result of `detect_regime` applied to the
/// element-wise mean of the observation history (or to `[0.5; 5]` when the
/// history is empty). `transition_matrix` is exported row by row, so
/// `snapshot.transition_matrix[r][c]` equals the internal entry `(r, c)`.
///
/// # Errors
///
/// Propagates any error from `detect_regime`.
pub fn get_current_state(&self) -> Result<RegimeState> {
    let current_regime = self.detect_regime(&self.get_average_observations())?;

    // nalgebra stores matrices column-major, so chunking the raw storage would
    // yield columns; iterate rows explicitly to export a row-major copy.
    let transition_matrix = self
        .transition_matrix
        .row_iter()
        .map(|row| row.iter().copied().collect())
        .collect();

    Ok(RegimeState {
        current_regime,
        belief_distribution: self.belief_state.as_slice().to_vec(),
        transition_matrix,
        observation_count: self.observation_history.len(),
    })
}
/// Returns the element-wise mean of the first five entries of every stored
/// observation, or `[0.5; 5]` when no observation has been recorded yet.
///
/// Entries beyond the fifth are ignored. An observation shorter than five
/// entries contributes nothing to the positions it lacks, while the divisor
/// is still the total number of stored observations.
fn get_average_observations(&self) -> Vec<f64> {
    if self.observation_history.is_empty() {
        return vec![0.5; 5];
    }

    let count = self.observation_history.len() as f64;
    let mut means = vec![0.0; 5];
    for sample in &self.observation_history {
        // Zipping against the 5-slot accumulator caps each sample at five entries.
        for (slot, &value) in means.iter_mut().zip(sample.iter()) {
            *slot += value / count;
        }
    }
    means
}
/// Reports whether no regime is currently dominant.
///
/// Returns `true` when the largest belief is strictly below
/// `1.0 - transition_threshold`.
///
/// # Panics
///
/// Panics if the belief vector is empty or contains NaN.
pub fn is_transition_likely(&self) -> bool {
    let confidence_floor = 1.0 - self.transition_threshold;
    let strongest_belief = self
        .belief_state
        .iter()
        .copied()
        .max_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap())
        .unwrap();
    strongest_belief < confidence_floor
}
}
/// A discrete Markov chain over regimes that can be sampled step by step.
pub struct MarkovChain {
// State labels; position `i` corresponds to row/column `i` of the matrix.
states: Vec<Regime>,
// Row `i` is used by `next_state` as the distribution over successors of state `i`.
transition_matrix: DMatrix<f64>,
// Index into `states` of the chain's current state; starts at 0.
current_state: usize,
}
impl MarkovChain {
/// Creates a chain that starts in `states[0]`.
///
/// # Arguments
///
/// * `states` - state labels; must not be empty.
/// * `transition_matrix` - square matrix with one row and column per state;
///   row `i` is sampled by `next_state` when the chain is in state `i`.
///
/// # Errors
///
/// Returns an error when `states` is empty (the chain would have no valid
/// starting state) or when the matrix is not `states.len()` × `states.len()`.
pub fn new(states: Vec<Regime>, transition_matrix: DMatrix<f64>) -> Result<Self> {
    if states.is_empty() {
        return Err(anyhow!("Markov chain requires at least one state"));
    }
    let state_count = states.len();
    if transition_matrix.nrows() != state_count || transition_matrix.ncols() != state_count {
        return Err(anyhow!("States and transition matrix dimensions must match"));
    }
    Ok(Self {
        states,
        transition_matrix,
        // Index 0 is valid because `states` is non-empty.
        current_state: 0,
    })
}
/// Samples the next state from the current state's transition row, makes it
/// the current state, and returns it.
///
/// A uniform draw in `[0, 1)` is compared against the running sum of the
/// row. If the draw is not below the row's total (for example because the
/// row sums to less than 1), the chain stays in its current state.
pub fn next_state(&mut self) -> Regime {
    let draw: f64 = rand::thread_rng().gen();

    let mut running_total = 0.0;
    let chosen = self
        .transition_matrix
        .row(self.current_state)
        .iter()
        .position(|&probability| {
            running_total += probability;
            draw < running_total
        });

    if let Some(idx) = chosen {
        self.current_state = idx;
    }
    self.states[self.current_state]
}
/// Solves for a distribution `π` satisfying `Pᵀ π = π` and `Σ π = 1`.
///
/// The linear system `(Pᵀ - I) π = 0` is formed, its last equation is
/// replaced by the normalisation constraint, and the result is solved with
/// an LU decomposition.
///
/// # Errors
///
/// Returns an error when the chain has no states or when the system is
/// singular.
pub fn stationary_distribution(&self) -> Result<DVector<f64>> {
    let n = self.transition_matrix.nrows();
    if n == 0 {
        // `n - 1` below would underflow for an empty matrix.
        return Err(anyhow!(
            "Cannot compute stationary distribution of an empty chain"
        ));
    }

    let mut system = self.transition_matrix.transpose() - DMatrix::identity(n, n);
    // Replace the last (redundant) balance equation with Σ π = 1.
    for col in 0..n {
        system[(n - 1, col)] = 1.0;
    }
    let mut rhs = DVector::zeros(n);
    rhs[n - 1] = 1.0;

    system
        .lu()
        .solve(&rhs)
        .ok_or_else(|| anyhow!("Cannot solve for stationary distribution"))
}
}
/// Serialisable snapshot of a `RegimeDetector`, produced by
/// `RegimeDetector::get_current_state`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegimeState {
/// Regime detected from the averaged observation history.
pub current_regime: Regime,
/// Copy of the detector's belief vector, indexed by regime index.
pub belief_distribution: Vec<f64>,
/// Nested-`Vec` copy of the detector's transition matrix.
pub transition_matrix: Vec<Vec<f64>>,
/// Number of observations held in the detector's history at snapshot time.
pub observation_count: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An observation equal to the Exploration profile must be classified as
    /// Exploration under the default prior.
    #[test]
    fn test_regime_detection() {
        let detector = RegimeDetector::new(4, 0.15).unwrap();
        let obs = vec![0.8, 0.3, 0.6, 0.4, 0.9];
        let regime = detector.detect_regime(&obs).unwrap();
        assert_eq!(regime, Regime::Exploration);
    }

    /// Sampling must work on the default matrix and must follow the matrix
    /// exactly when every row puts all its mass on a single successor.
    #[test]
    fn test_markov_chain() {
        let states = vec![
            Regime::Exploration,
            Regime::Exploitation,
            Regime::Crisis,
            Regime::Transition,
        ];

        // Smoke test: sampling from the default matrix must not panic.
        let matrix = RegimeDetector::init_transition_matrix(4);
        let mut chain = MarkovChain::new(states.clone(), matrix).unwrap();
        let _ = chain.next_state();

        // Deterministic cycle 0 -> 2 -> 1 -> 3 -> 0. Asserting on every variant
        // (as before) could never fail, so pin the exact sequence instead.
        let cycle = DMatrix::from_row_slice(
            4,
            4,
            &[
                0.0, 0.0, 1.0, 0.0, //
                0.0, 0.0, 0.0, 1.0, //
                0.0, 1.0, 0.0, 0.0, //
                1.0, 0.0, 0.0, 0.0, //
            ],
        );
        let mut chain = MarkovChain::new(states, cycle).unwrap();
        assert_eq!(chain.next_state(), Regime::Crisis);
        assert_eq!(chain.next_state(), Regime::Exploitation);
        assert_eq!(chain.next_state(), Regime::Transition);
        assert_eq!(chain.next_state(), Regime::Exploration);
    }
}