use crate::action::Action;
use crate::observation::Observation;
use crate::types::{Confidence, Timestamp};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
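/// Identifier for a discrete state, derived from an observation or supplied directly.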
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct StateId(String);
impl StateId {
pub fn from_observation(obs: &Observation) -> Self {
let state_str = format!("{:?}_{}", obs.obs_type, obs.value.as_string());
Self(state_str)
}
pub fn from_string(s: String) -> Self {
Self(s)
}
pub fn as_str(&self) -> &str {
&self.0
}
}
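/// Identifier for an action, derived from an `Action`'s type or supplied directly.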
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ActionId(String);
impl ActionId {
pub fn from_action(action: &Action) -> Self {
Self(format!("{:?}", action.action_type))
}
pub fn from_string(s: String) -> Self {
Self(s)
}
pub fn as_str(&self) -> &str {
&self.0
}
}
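/// A `(state, action)` key into the Q-table.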
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct StateActionPair {
pub state: StateId,
pub action: ActionId,
}
impl StateActionPair {
pub fn new(state: StateId, action: ActionId) -> Self {
Self { state, action }
}
}
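/// A tabular Q-value tracked with incremental statistics: a running mean
/// estimate, a Welford-style running variance of the observed targets, an
/// update counter, and the time of the last update.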
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QValue {
pub mean: f64,
pub variance: f64,
pub update_count: u64,
pub last_updated: Timestamp,
}
impl QValue {
pub fn new(initial_value: f64) -> Self {
Self {
mean: initial_value,
variance: 0.0,
update_count: 0,
last_updated: Timestamp::now(),
}
}
    pub fn update(&mut self, new_value: f64, learning_rate: f64) {
        let old_mean = self.mean;
        // Exponential moving average of the observed targets.
        self.mean += learning_rate * (new_value - self.mean);
        if self.update_count > 0 {
            // Welford-style running variance. Strictly, Welford's recurrence
            // assumes the sample mean rather than the EMA above, so this is
            // an approximate measure of spread, not an exact variance.
            let delta = new_value - old_mean;
            let delta2 = new_value - self.mean;
            self.variance =
                self.variance + (delta * delta2 - self.variance) / (self.update_count as f64);
        }
        self.update_count += 1;
        self.last_updated = Timestamp::now();
    }
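    /// Heuristic confidence in this estimate: grows with the number of
    /// updates (saturating at 100) and shrinks as the observed variance
    /// grows.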
pub fn confidence(&self) -> Confidence {
if self.update_count == 0 {
return Confidence::new(0.0);
}
let count_factor = (self.update_count as f32).min(100.0) / 100.0;
let variance_factor = 1.0 / (1.0 + self.variance as f32);
Confidence::new(count_factor * variance_factor)
}
}
impl Default for QValue {
fn default() -> Self {
Self::new(0.0)
}
}
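/// Temporal-difference update rule applied by the engine.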
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum LearningAlgorithm {
#[default]
QLearning,
SARSA,
ExpectedSARSA,
TemporalDifference,
}
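/// A single `(state, action, reward, next_state)` transition, optionally
/// carrying the next action (needed for SARSA) and a terminal flag.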
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Experience {
pub state: StateId,
pub action: ActionId,
pub reward: f64,
pub next_state: StateId,
pub next_action: Option<ActionId>,
pub done: bool,
pub timestamp: Timestamp,
}
impl Experience {
pub fn new(
state: StateId,
action: ActionId,
reward: f64,
next_state: StateId,
done: bool,
) -> Self {
Self {
state,
action,
reward,
next_state,
next_action: None,
done,
timestamp: Timestamp::now(),
}
}
pub fn with_next_action(mut self, next_action: ActionId) -> Self {
self.next_action = Some(next_action);
self
}
}
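/// Hyperparameters for the engine: TD step size, discount factor, update
/// rule, initial Q-value for unseen pairs, replay-buffer bounds, and the
/// epsilon-greedy exploration schedule.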
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningConfig {
pub learning_rate: f64,
pub discount_factor: f64,
pub algorithm: LearningAlgorithm,
pub initial_q_value: f64,
pub replay_buffer_size: usize,
pub min_replay_size: usize,
pub epsilon: f64,
pub epsilon_decay: f64,
pub epsilon_min: f64,
}
impl Default for LearningConfig {
fn default() -> Self {
Self {
learning_rate: 0.1,
discount_factor: 0.99,
algorithm: LearningAlgorithm::QLearning,
initial_q_value: 0.0,
replay_buffer_size: 10000,
min_replay_size: 100,
epsilon: 0.1,
epsilon_decay: 0.995,
epsilon_min: 0.01,
}
}
}
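/// Tabular reinforcement-learning engine: a Q-table keyed by
/// `StateActionPair`, an experience-replay buffer, and epsilon-greedy
/// action selection.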
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningEngine {
q_values: HashMap<StateActionPair, QValue>,
config: LearningConfig,
replay_buffer: VecDeque<Experience>,
total_updates: u64,
total_episodes: u64,
}
impl LearningEngine {
pub fn new(config: LearningConfig) -> Self {
Self {
q_values: HashMap::new(),
config,
replay_buffer: VecDeque::new(),
total_updates: 0,
total_episodes: 0,
}
}
pub fn default_config() -> Self {
Self::new(LearningConfig::default())
}
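    /// Current mean Q-value estimate, or `initial_q_value` for unseen pairs.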
pub fn get_q_value(&self, state: &StateId, action: &ActionId) -> f64 {
let pair = StateActionPair::new(state.clone(), action.clone());
self.q_values
.get(&pair)
.map(|qv| qv.mean)
.unwrap_or(self.config.initial_q_value)
}
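    /// Full statistics for the pair, or a fresh `QValue` seeded with
    /// `initial_q_value` for unseen pairs.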
pub fn get_q_value_stats(&self, state: &StateId, action: &ActionId) -> QValue {
let pair = StateActionPair::new(state.clone(), action.clone());
self.q_values
.get(&pair)
.cloned()
.unwrap_or_else(|| QValue::new(self.config.initial_q_value))
}
    fn set_q_value(&mut self, state: &StateId, action: &ActionId, td_target: f64) {
        let pair = StateActionPair::new(state.clone(), action.clone());
        let qvalue = self
            .q_values
            .entry(pair)
            .or_insert_with(|| QValue::new(self.config.initial_q_value));
        // `QValue::update` moves the mean toward the target by
        // `learning_rate`, so the TD step size is applied exactly once here
        // rather than both here and in the callers.
        qvalue.update(td_target, self.config.learning_rate);
        self.total_updates += 1;
    }
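    /// `max_a Q(state, a)` over the available actions, or the initial
    /// Q-value when none are available.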
fn get_max_q_value(&self, state: &StateId, available_actions: &[ActionId]) -> f64 {
if available_actions.is_empty() {
return self.config.initial_q_value;
}
available_actions
.iter()
.map(|action| self.get_q_value(state, action))
.fold(f64::NEG_INFINITY, f64::max)
}
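    /// Uniform average of `Q(state, a)` over the available actions, or the
    /// initial Q-value when none are available.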
fn get_avg_q_value(&self, state: &StateId, available_actions: &[ActionId]) -> f64 {
if available_actions.is_empty() {
return self.config.initial_q_value;
}
let sum: f64 = available_actions
.iter()
.map(|action| self.get_q_value(state, action))
.sum();
sum / (available_actions.len() as f64)
}
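    /// Q-learning (off-policy) update:
    /// `Q(s,a) ← Q(s,a) + α [r + γ max_a' Q(s',a') − Q(s,a)]`.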
    pub fn update_q_learning(
        &mut self,
        state: &StateId,
        action: &ActionId,
        reward: f64,
        next_state: &StateId,
        available_actions: &[ActionId],
    ) {
        let max_next_q = self.get_max_q_value(next_state, available_actions);
        let td_target = reward + self.config.discount_factor * max_next_q;
        // `set_q_value` applies the learning rate, so pass the raw TD target
        // (pre-smoothing it here would apply the learning rate twice).
        self.set_q_value(state, action, td_target);
    }
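    /// SARSA (on-policy) update, bootstrapping from the action actually
    /// taken next: `Q(s,a) ← Q(s,a) + α [r + γ Q(s',a') − Q(s,a)]`.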
    pub fn update_sarsa(
        &mut self,
        state: &StateId,
        action: &ActionId,
        reward: f64,
        next_state: &StateId,
        next_action: &ActionId,
    ) {
        let next_q = self.get_q_value(next_state, next_action);
        let td_target = reward + self.config.discount_factor * next_q;
        self.set_q_value(state, action, td_target);
    }
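    /// Expected-SARSA-style update. Note that this bootstraps from the
    /// uniform average over the available actions, not the policy-weighted
    /// expectation of textbook Expected SARSA under epsilon-greedy.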
    pub fn update_expected_sarsa(
        &mut self,
        state: &StateId,
        action: &ActionId,
        reward: f64,
        next_state: &StateId,
        available_actions: &[ActionId],
    ) {
        let expected_next_q = self.get_avg_q_value(next_state, available_actions);
        let td_target = reward + self.config.discount_factor * expected_next_q;
        self.set_q_value(state, action, td_target);
    }
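    /// Generic TD update bootstrapping from the average Q-value of the next
    /// state's available actions.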
    pub fn update_td(
        &mut self,
        state: &StateId,
        action: &ActionId,
        reward: f64,
        next_state: &StateId,
        available_actions: &[ActionId],
    ) {
        let next_value = self.get_avg_q_value(next_state, available_actions);
        let td_target = reward + self.config.discount_factor * next_value;
        self.set_q_value(state, action, td_target);
    }
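    /// Applies one TD update using the configured algorithm. SARSA falls
    /// back to Q-learning when no `next_action` is supplied.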
pub fn update(
&mut self,
state: &StateId,
action: &ActionId,
reward: f64,
next_state: &StateId,
next_action: Option<&ActionId>,
available_actions: &[ActionId],
) {
match self.config.algorithm {
LearningAlgorithm::QLearning => {
self.update_q_learning(state, action, reward, next_state, available_actions);
}
LearningAlgorithm::SARSA => {
if let Some(next_act) = next_action {
self.update_sarsa(state, action, reward, next_state, next_act);
} else {
self.update_q_learning(state, action, reward, next_state, available_actions);
}
}
LearningAlgorithm::ExpectedSARSA => {
self.update_expected_sarsa(state, action, reward, next_state, available_actions);
}
LearningAlgorithm::TemporalDifference => {
self.update_td(state, action, reward, next_state, available_actions);
}
}
}
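    /// Greedy selection: the available action with the highest Q-value
    /// (ties go to the later element; NaNs compare as equal).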
pub fn get_best_action(
&self,
state: &StateId,
available_actions: &[ActionId],
) -> Option<ActionId> {
if available_actions.is_empty() {
return None;
}
available_actions
.iter()
.max_by(|a, b| {
let qa = self.get_q_value(state, a);
let qb = self.get_q_value(state, b);
qa.partial_cmp(&qb).unwrap_or(std::cmp::Ordering::Equal)
})
.cloned()
}
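    /// Epsilon-greedy selection: with probability `epsilon`, a uniformly
    /// random available action; otherwise the greedy choice.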
pub fn get_action_epsilon_greedy(
&self,
state: &StateId,
available_actions: &[ActionId],
) -> Option<ActionId> {
use rand::Rng;
if available_actions.is_empty() {
return None;
}
let mut rng = rand::rng();
if rng.random::<f64>() < self.config.epsilon {
let idx = rng.random_range(0..available_actions.len());
Some(available_actions[idx].clone())
} else {
self.get_best_action(state, available_actions)
}
}
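    /// Multiplicatively decays epsilon, clamped below at `epsilon_min`.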
pub fn decay_epsilon(&mut self) {
self.config.epsilon =
(self.config.epsilon * self.config.epsilon_decay).max(self.config.epsilon_min);
}
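    /// Appends an experience, evicting the oldest entry once the buffer
    /// reaches `replay_buffer_size`.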
pub fn add_experience(&mut self, experience: Experience) {
if self.replay_buffer.len() >= self.config.replay_buffer_size {
self.replay_buffer.pop_front();
}
self.replay_buffer.push_back(experience);
}
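    /// Replays a uniformly sampled batch of stored experiences once the
    /// buffer holds at least `min_replay_size` entries. Note that
    /// `available_actions` is applied to every sampled experience, which
    /// assumes a shared action set across states.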
pub fn replay_batch(&mut self, batch_size: usize, available_actions: &[ActionId]) {
use rand::seq::IndexedRandom;
if self.replay_buffer.len() < self.config.min_replay_size {
return;
}
        let mut rng = rand::rng();
        // Clone the buffer into a Vec: `choose_multiple` needs a slice, and
        // the sampled experiences must not borrow `self` across the mutable
        // `update` calls below.
        let experiences: Vec<Experience> = self.replay_buffer.iter().cloned().collect();
        let sample_size = batch_size.min(experiences.len());
        let batch: Vec<&Experience> = experiences.choose_multiple(&mut rng, sample_size).collect();
for exp in batch {
self.update(
&exp.state,
&exp.action,
exp.reward,
&exp.next_state,
exp.next_action.as_ref(),
available_actions,
);
}
}
pub fn end_episode(&mut self) {
self.total_episodes += 1;
self.decay_epsilon();
}
pub fn total_updates(&self) -> u64 {
self.total_updates
}
pub fn total_episodes(&self) -> u64 {
self.total_episodes
}
pub fn state_action_count(&self) -> usize {
self.q_values.len()
}
pub fn epsilon(&self) -> f64 {
self.config.epsilon
}
pub fn reset(&mut self) {
self.q_values.clear();
self.replay_buffer.clear();
self.total_updates = 0;
self.total_episodes = 0;
}
pub fn config(&self) -> &LearningConfig {
&self.config
}
pub fn config_mut(&mut self) -> &mut LearningConfig {
&mut self.config
}
pub fn get_all_q_values(&self) -> &HashMap<StateActionPair, QValue> {
&self.q_values
}
}
#[cfg(test)]
mod tests {
use super::*;
fn create_state_id(name: &str) -> StateId {
StateId::from_string(name.to_string())
}
fn create_action_id(name: &str) -> ActionId {
ActionId::from_string(name.to_string())
}
#[test]
fn test_learning_engine_creation() {
let engine = LearningEngine::default_config();
assert_eq!(engine.total_updates(), 0);
assert_eq!(engine.state_action_count(), 0);
}
#[test]
fn test_q_value_update() {
let mut qvalue = QValue::new(0.0);
assert_eq!(qvalue.mean, 0.0);
assert_eq!(qvalue.update_count, 0);
qvalue.update(1.0, 0.1);
assert!(qvalue.mean > 0.0);
assert_eq!(qvalue.update_count, 1);
}
#[test]
fn test_q_learning_update() {
let mut engine = LearningEngine::default_config();
let s0 = create_state_id("state0");
let a0 = create_action_id("action0");
let s1 = create_state_id("state1");
let actions = vec![a0.clone()];
assert_eq!(engine.get_q_value(&s0, &a0), 0.0);
engine.update_q_learning(&s0, &a0, 1.0, &s1, &actions);
assert!(engine.get_q_value(&s0, &a0) > 0.0);
assert_eq!(engine.total_updates(), 1);
}
#[test]
fn test_sarsa_update() {
let mut engine = LearningEngine::default_config();
let s0 = create_state_id("state0");
let a0 = create_action_id("action0");
let s1 = create_state_id("state1");
let a1 = create_action_id("action1");
engine.update_sarsa(&s0, &a0, 1.0, &s1, &a1);
assert!(engine.get_q_value(&s0, &a0) > 0.0);
}
#[test]
fn test_get_best_action() {
let mut engine = LearningEngine::default_config();
let state = create_state_id("state");
let action1 = create_action_id("action1");
let action2 = create_action_id("action2");
let actions = vec![action1.clone(), action2.clone()];
engine.set_q_value(&state, &action1, 0.5);
engine.set_q_value(&state, &action2, 1.0);
let best = engine.get_best_action(&state, &actions);
assert_eq!(best.unwrap(), action2);
}
#[test]
fn test_experience_replay() {
let mut engine = LearningEngine::default_config();
let s0 = create_state_id("state0");
let a0 = create_action_id("action0");
let s1 = create_state_id("state1");
let actions = vec![a0.clone()];
for _ in 0..10 {
let exp = Experience::new(s0.clone(), a0.clone(), 1.0, s1.clone(), false);
engine.add_experience(exp);
}
assert_eq!(engine.replay_buffer.len(), 10);
engine.config.min_replay_size = 100;
engine.replay_batch(5, &actions);
assert_eq!(engine.total_updates(), 0);
engine.config.min_replay_size = 5;
engine.replay_batch(5, &actions);
assert!(engine.total_updates() > 0);
}
#[test]
fn test_epsilon_decay() {
        let config = LearningConfig {
            epsilon: 1.0,
            epsilon_decay: 0.9,
            epsilon_min: 0.1,
            ..Default::default()
        };
let mut engine = LearningEngine::new(config);
let initial_epsilon = engine.epsilon();
engine.decay_epsilon();
assert!(engine.epsilon() < initial_epsilon);
assert!(engine.epsilon() >= 0.1);
}
#[test]
fn test_episode_management() {
let mut engine = LearningEngine::default_config();
assert_eq!(engine.total_episodes(), 0);
engine.end_episode();
assert_eq!(engine.total_episodes(), 1);
}
#[test]
fn test_state_action_id_from_types() {
let obs = Observation::sensor("temperature", 25.0);
let state_id = StateId::from_observation(&obs);
assert!(state_id.as_str().contains("temperature"));
let action = Action::alert("test");
let action_id = ActionId::from_action(&action);
assert!(action_id.as_str().contains("Alert"));
}
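    // Illustrative sketch: a toy two-action bandit-style problem exercising
    // the epsilon-greedy loop end to end. "left"/"right" are made-up action
    // names; only this module's types are used.
    #[test]
    fn test_epsilon_greedy_learning_sketch() {
        let mut engine = LearningEngine::default_config();
        let s0 = create_state_id("s0");
        let s1 = create_state_id("s1");
        let actions = vec![create_action_id("left"), create_action_id("right")];
        for _ in 0..50 {
            let action = engine
                .get_action_epsilon_greedy(&s0, &actions)
                .expect("actions is non-empty");
            // Reward "right", penalize "left".
            let reward = if action.as_str() == "right" { 1.0 } else { -1.0 };
            engine.update(&s0, &action, reward, &s1, None, &actions);
            engine.end_episode();
        }
        // Q("right") can only move up from 0 and Q("left") only down, so the
        // greedy choice is "right" regardless of the random exploration path.
        let best = engine.get_best_action(&s0, &actions).unwrap();
        assert_eq!(best.as_str(), "right");
    }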
}