use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// Map keyed by a 3-tuple of strings; the return type of `deserialize_tuple3_map`.
type Tuple3Map<V> = HashMap<(String, String, String), V>;
/// Map keyed by a 4-tuple of strings; the return type of `deserialize_tuple4_map`.
type Tuple4Map<V> = HashMap<(String, String, String, String), V>;
/// Aggregated learning statistics: outcome-split transition counts, n-gram
/// outcome counts, strategy-selection performance, and per-context action
/// statistics.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LearnStats {
/// `(from, to)` transition counts, kept separately for success and failure.
pub episode_transitions: EpisodeTransitions,
/// Success/failure counters for action trigrams and quadgrams.
pub ngram_stats: NgramStats,
/// Per-strategy counters, the active strategy, and the strategy switch history.
pub selection_performance: SelectionPerformance,
/// Action statistics keyed by `(prev, action)` pairs.
///
/// Written as a map with `"prev:action"` string keys by
/// `serialize_tuple2_map` / `deserialize_tuple2_map` (presumably because
/// formats such as JSON only accept string map keys).
#[serde(
serialize_with = "serialize_tuple2_map",
deserialize_with = "deserialize_tuple2_map"
)]
pub contextual_stats: HashMap<(String, String), ContextualActionStats>,
}
impl LearnStats {
    /// Seeds these statistics from a previously captured learning snapshot.
    ///
    /// `episode_transitions`, `ngram_stats` and `selection_performance` are
    /// overwritten with clones of the snapshot's values. Contextual statistics
    /// are merged instead of replaced: each snapshot entry is copied into
    /// `contextual_stats`, overwriting any existing entry with the same
    /// `(prev, action)` key, while entries absent from the snapshot are kept.
    pub fn load_prior(&mut self, snapshot: &crate::learn::LearningSnapshot) {
        self.episode_transitions = snapshot.episode_transitions.clone();
        self.ngram_stats = snapshot.ngram_stats.clone();
        self.selection_performance = snapshot.selection_performance.clone();

        for (key, prior) in &snapshot.contextual_stats {
            let (prev, action) = key;
            // The snapshot stores its own stats type; copy the three counters over.
            let converted = ContextualActionStats {
                visits: prior.visits,
                successes: prior.successes,
                failures: prior.failures,
            };
            self.contextual_stats
                .insert((prev.clone(), action.clone()), converted);
        }
    }
}
/// Counters for one `(prev, action)` entry of `LearnStats::contextual_stats`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ContextualActionStats {
/// Visit count; the denominator of `success_rate`.
pub visits: u32,
/// Successful visits; the numerator of `success_rate`.
pub successes: u32,
/// Failed visits (not used by `success_rate`).
pub failures: u32,
}
impl ContextualActionStats {
    /// Share of visits that succeeded (`successes / visits`).
    ///
    /// An entry with no visits reports a neutral `0.5`. Only `visits` is used as
    /// the denominator; `failures` does not enter the calculation.
    pub fn success_rate(&self) -> f64 {
        match self.visits {
            0 => 0.5,
            visits => f64::from(self.successes) / f64::from(visits),
        }
    }
}
/// Counts of `(from, to)` action transitions, kept separately for success and
/// failure (presumably split by episode outcome, per the field names).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EpisodeTransitions {
/// Success-side transition counts keyed by `(from, to)`; serialized with
/// `"from:to"` string keys.
#[serde(
serialize_with = "serialize_tuple2_map",
deserialize_with = "deserialize_tuple2_map"
)]
pub success_transitions: HashMap<(String, String), u32>,
/// Failure-side transition counts keyed by `(from, to)`; serialized with
/// `"from:to"` string keys.
#[serde(
serialize_with = "serialize_tuple2_map",
deserialize_with = "deserialize_tuple2_map"
)]
pub failure_transitions: HashMap<(String, String), u32>,
/// Successful-episode tally. Not read by the methods in this file; presumably
/// maintained by the caller — confirm.
pub success_episodes: u32,
/// Failed-episode tally. Not read by the methods in this file; presumably
/// maintained by the caller — confirm.
pub failure_episodes: u32,
}
impl EpisodeTransitions {
pub fn success_transition_rate(&self, from: &str, to: &str) -> f64 {
let key = (from.to_string(), to.to_string());
let success_count = self.success_transitions.get(&key).copied().unwrap_or(0);
let failure_count = self.failure_transitions.get(&key).copied().unwrap_or(0);
let total = success_count + failure_count;
if total == 0 {
0.5
} else {
success_count as f64 / total as f64
}
}
pub fn transition_value(&self, from: &str, to: &str) -> f64 {
let key = (from.to_string(), to.to_string());
let success_count = self.success_transitions.get(&key).copied().unwrap_or(0) as f64;
let failure_count = self.failure_transitions.get(&key).copied().unwrap_or(0) as f64;
let total_success = self.success_transitions.values().sum::<u32>() as f64;
let total_failure = self.failure_transitions.values().sum::<u32>() as f64;
let success_rate = if total_success > 0.0 {
success_count / total_success
} else {
0.0
};
let failure_rate = if total_failure > 0.0 {
failure_count / total_failure
} else {
0.0
};
success_rate - failure_rate
}
pub fn recommended_next_actions(&self, from: &str) -> Vec<(String, f64)> {
let mut candidates: Vec<_> = self
.success_transitions
.iter()
.filter(|((f, _), _)| f == from)
.map(|((_, to), _)| {
let value = self.transition_value(from, to);
(to.clone(), value)
})
.collect();
candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
candidates
}
}
/// Outcome counters for short action sequences.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct NgramStats {
/// `(a1, a2, a3)` -> `(success, failure)` counts; serialized with `"a1:a2:a3"`
/// string keys.
#[serde(
serialize_with = "serialize_tuple3_map",
deserialize_with = "deserialize_tuple3_map"
)]
pub trigrams: HashMap<(String, String, String), (u32, u32)>,
/// `(a1, a2, a3, a4)` -> `(success, failure)` counts; serialized with
/// `"a1:a2:a3:a4"` string keys.
#[serde(
serialize_with = "serialize_tuple4_map",
deserialize_with = "deserialize_tuple4_map"
)]
pub quadgrams: HashMap<(String, String, String, String), (u32, u32)>,
}
impl NgramStats {
    /// `success / (success + failure)` for one counter pair, or `None` when both
    /// counters are zero.
    ///
    /// The denominator is widened to `u64` so two large `u32` counters cannot
    /// overflow when added.
    fn outcome_ratio(success: u32, failure: u32) -> Option<f64> {
        let total = u64::from(success) + u64::from(failure);
        if total == 0 {
            None
        } else {
            Some(f64::from(success) / total as f64)
        }
    }

    /// Maps a success ratio in `[0, 1]` onto a signed score in `[-1, 1]`.
    fn signed_score(ratio: f64) -> f64 {
        ratio * 2.0 - 1.0
    }

    /// Signed score for a counter pair; `0.0` when nothing has been recorded.
    fn counts_score(success: u32, failure: u32) -> f64 {
        Self::outcome_ratio(success, failure).map_or(0.0, Self::signed_score)
    }

    /// Sorts `(action, score)` candidates by descending score (stable for ties).
    fn sort_by_score_desc(candidates: &mut [(String, f64)]) {
        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    }

    /// Success ratio of the trigram `a1 a2 a3`; `0.5` if unseen or all counts are zero.
    pub fn trigram_success_rate(&self, a1: &str, a2: &str, a3: &str) -> f64 {
        let key = (a1.to_string(), a2.to_string(), a3.to_string());
        self.trigrams
            .get(&key)
            .and_then(|&(success, failure)| Self::outcome_ratio(success, failure))
            .unwrap_or(0.5)
    }

    /// Success ratio of the quadgram `a1 a2 a3 a4`; `0.5` if unseen or all counts are zero.
    pub fn quadgram_success_rate(&self, a1: &str, a2: &str, a3: &str, a4: &str) -> f64 {
        let key = (
            a1.to_string(),
            a2.to_string(),
            a3.to_string(),
            a4.to_string(),
        );
        self.quadgrams
            .get(&key)
            .and_then(|&(success, failure)| Self::outcome_ratio(success, failure))
            .unwrap_or(0.5)
    }

    /// Signed score in `[-1, 1]` for the trigram `a1 a2 a3` (`2 * ratio - 1`);
    /// `0.0` if unseen or all counts are zero.
    pub fn trigram_value(&self, a1: &str, a2: &str, a3: &str) -> f64 {
        let key = (a1.to_string(), a2.to_string(), a3.to_string());
        self.trigrams
            .get(&key)
            .map_or(0.0, |&(success, failure)| Self::counts_score(success, failure))
    }

    /// Candidate third actions after `a1 a2`, ranked by signed trigram score (highest first).
    pub fn recommended_after(&self, a1: &str, a2: &str) -> Vec<(String, f64)> {
        let mut candidates: Vec<(String, f64)> = self
            .trigrams
            .iter()
            .filter(|((x1, x2, _), _)| x1 == a1 && x2 == a2)
            .map(|((_, _, next), &(success, failure))| {
                (next.clone(), Self::counts_score(success, failure))
            })
            .collect();
        Self::sort_by_score_desc(&mut candidates);
        candidates
    }

    /// Candidate fourth actions after `a1 a2 a3`, ranked by signed quadgram score
    /// (highest first).
    pub fn recommended_after_three(&self, a1: &str, a2: &str, a3: &str) -> Vec<(String, f64)> {
        let mut candidates: Vec<(String, f64)> = self
            .quadgrams
            .iter()
            .filter(|((x1, x2, x3, _), _)| x1 == a1 && x2 == a2 && x3 == a3)
            .map(|((_, _, _, next), &(success, failure))| {
                (next.clone(), Self::counts_score(success, failure))
            })
            .collect();
        Self::sort_by_score_desc(&mut candidates);
        candidates
    }

    /// Number of distinct trigrams recorded.
    pub fn trigram_count(&self) -> usize {
        self.trigrams.len()
    }

    /// Number of distinct quadgrams recorded.
    pub fn quadgram_count(&self) -> usize {
        self.quadgrams.len()
    }
}
/// Tracks how each action-selection strategy (e.g. "UCB1", "Greedy",
/// "Thompson") performs and when the active strategy changes.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SelectionPerformance {
/// Counters per strategy name.
pub strategy_stats: HashMap<String, StrategyStats>,
/// Events appended by `start_strategy` whenever the active strategy changes.
pub switch_history: Vec<StrategySwitchEvent>,
/// Name of the active strategy; `None` in the default value, set by `start_strategy`.
pub current_strategy: Option<String>,
/// Caller-supplied visit count captured when the active strategy was last started.
pub strategy_start_visits: u32,
/// Caller-supplied success rate captured when the active strategy was last started.
pub strategy_start_success_rate: f64,
}
/// Counters for a single selection strategy.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StrategyStats {
/// Actions recorded via `record_action` while this strategy was active.
pub visits: u32,
/// Recorded actions that succeeded.
pub successes: u32,
/// Recorded actions that failed.
pub failures: u32,
/// Number of `start_strategy` calls that selected this strategy.
pub usage_count: u32,
/// Episodes ended via `record_episode_end(true)` while this strategy was active.
pub episodes_success: u32,
/// Episodes ended via `record_episode_end(false)` while this strategy was active.
pub episodes_failure: u32,
}
impl StrategyStats {
    /// Per-action success rate (`successes / visits`).
    ///
    /// Returns a neutral `0.5` when no actions have been recorded.
    pub fn success_rate(&self) -> f64 {
        if self.visits == 0 {
            0.5
        } else {
            self.successes as f64 / self.visits as f64
        }
    }

    /// Per-episode success rate:
    /// `episodes_success / (episodes_success + episodes_failure)`.
    ///
    /// Returns a neutral `0.5` when no episodes have ended under this strategy.
    pub fn episode_success_rate(&self) -> f64 {
        // Widen before adding so two large `u32` counters cannot overflow.
        let total = u64::from(self.episodes_success) + u64::from(self.episodes_failure);
        if total == 0 {
            0.5
        } else {
            f64::from(self.episodes_success) / total as f64
        }
    }
}
/// Record of one change of active strategy, appended by
/// `SelectionPerformance::start_strategy`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StrategySwitchEvent {
/// Strategy that was active before the switch.
pub from: String,
/// Strategy that became active.
pub to: String,
/// `current_visits` value passed to `start_strategy` at the switch.
pub visits_at_switch: u32,
/// `current_success_rate` value passed to `start_strategy` at the switch.
pub success_rate_at_switch: f64,
/// Per-action success rate of the outgoing strategy (0.5 if it had no stats).
pub from_strategy_success_rate: f64,
}
impl SelectionPerformance {
pub fn start_strategy(
&mut self,
strategy: &str,
current_visits: u32,
current_success_rate: f64,
) {
if let Some(ref current) = self.current_strategy {
if current != strategy {
let from_stats = self
.strategy_stats
.get(current)
.cloned()
.unwrap_or_default();
self.switch_history.push(StrategySwitchEvent {
from: current.clone(),
to: strategy.to_string(),
visits_at_switch: current_visits,
success_rate_at_switch: current_success_rate,
from_strategy_success_rate: from_stats.success_rate(),
});
}
}
self.current_strategy = Some(strategy.to_string());
self.strategy_start_visits = current_visits;
self.strategy_start_success_rate = current_success_rate;
self.strategy_stats
.entry(strategy.to_string())
.or_default()
.usage_count += 1;
}
pub fn record_action(&mut self, success: bool) {
if let Some(ref strategy) = self.current_strategy {
let stats = self.strategy_stats.entry(strategy.clone()).or_default();
stats.visits += 1;
if success {
stats.successes += 1;
} else {
stats.failures += 1;
}
}
}
pub fn record_episode_end(&mut self, success: bool) {
if let Some(ref strategy) = self.current_strategy {
let stats = self.strategy_stats.entry(strategy.clone()).or_default();
if success {
stats.episodes_success += 1;
} else {
stats.episodes_failure += 1;
}
}
}
pub fn strategy_effectiveness(&self, strategy: &str) -> Option<f64> {
self.strategy_stats.get(strategy).map(|s| s.success_rate())
}
pub fn best_strategy(&self) -> Option<(&str, f64)> {
self.strategy_stats
.iter()
.filter(|(_, stats)| stats.visits >= 10)
.max_by(|(_, a), (_, b)| {
a.success_rate()
.partial_cmp(&b.success_rate())
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(name, stats)| (name.as_str(), stats.success_rate()))
}
pub fn recommended_strategy(&self, failure_rate: f64, visits: u32) -> &str {
let ucb1_score = self.strategy_score_for_context("UCB1", failure_rate, visits);
let greedy_score = self.strategy_score_for_context("Greedy", failure_rate, visits);
let thompson_score = self.strategy_score_for_context("Thompson", failure_rate, visits);
if ucb1_score >= greedy_score && ucb1_score >= thompson_score {
"UCB1"
} else if greedy_score >= thompson_score {
"Greedy"
} else {
"Thompson"
}
}
fn strategy_score_for_context(&self, strategy: &str, failure_rate: f64, _visits: u32) -> f64 {
let base_score = self
.strategy_stats
.get(strategy)
.map(|s| s.success_rate())
.unwrap_or(0.5);
match strategy {
"UCB1" => base_score + failure_rate * 0.2,
"Greedy" => base_score + (1.0 - failure_rate) * 0.2,
"Thompson" => {
let distance_from_middle = (failure_rate - 0.5).abs();
base_score + (0.5 - distance_from_middle) * 0.2
}
_ => base_score,
}
}
}
/// Serializes a map with `(String, String)` keys as a map with `"first:second"`
/// string keys. Used through `#[serde(serialize_with = ...)]`.
///
/// NOTE(review): the key encoding is not escaped. `deserialize_tuple2_map` splits
/// at the first `:`, so a first component that itself contains `:` does not
/// round-trip.
fn serialize_tuple2_map<V, S>(
    map: &HashMap<(String, String), V>,
    serializer: S,
) -> Result<S::Ok, S::Error>
where
    V: Serialize,
    S: serde::Serializer,
{
    use serde::ser::SerializeMap;

    let mut entries = serializer.serialize_map(Some(map.len()))?;
    map.iter().try_for_each(|((first, second), value)| {
        let joined = format!("{}:{}", first, second);
        entries.serialize_entry(&joined, value)
    })?;
    entries.end()
}
/// Inverse of `serialize_tuple2_map`: rebuilds `(String, String)` keys from
/// `"first:second"` strings.
///
/// Each key is split at its first `:`, so everything after it (including any
/// further colons) becomes the second component. A key without `:` is rejected
/// with an `invalid tuple2 key` error.
fn deserialize_tuple2_map<'de, V, D>(
    deserializer: D,
) -> Result<HashMap<(String, String), V>, D::Error>
where
    V: Deserialize<'de>,
    D: serde::Deserializer<'de>,
{
    use serde::de::Error;

    let flat: HashMap<String, V> = HashMap::deserialize(deserializer)?;
    flat.into_iter()
        .map(|(joined, value)| {
            let mut pieces = joined.splitn(2, ':');
            let parts = (pieces.next(), pieces.next());
            match parts {
                (Some(first), Some(second)) => {
                    Ok(((first.to_string(), second.to_string()), value))
                }
                _ => Err(D::Error::custom(format!("invalid tuple2 key: {}", joined))),
            }
        })
        .collect()
}
/// Serializes a map with `(String, String, String)` keys as a map with `"a:b:c"`
/// string keys. Used through `#[serde(serialize_with = ...)]`.
///
/// NOTE(review): the key encoding is not escaped. `deserialize_tuple3_map` splits
/// at the first two `:`, so a `:` inside either of the first two components does
/// not round-trip.
fn serialize_tuple3_map<V, S>(
    map: &HashMap<(String, String, String), V>,
    serializer: S,
) -> Result<S::Ok, S::Error>
where
    V: Serialize,
    S: serde::Serializer,
{
    use serde::ser::SerializeMap;

    let mut entries = serializer.serialize_map(Some(map.len()))?;
    map.iter().try_for_each(|((first, second, third), value)| {
        let joined = format!("{}:{}:{}", first, second, third);
        entries.serialize_entry(&joined, value)
    })?;
    entries.end()
}
/// Inverse of `serialize_tuple3_map`: rebuilds `(String, String, String)` keys
/// from `"a:b:c"` strings.
///
/// Each key is split at its first two `:`; everything after the second `:`
/// (including further colons) becomes the third component. A key with fewer than
/// two `:` is rejected with an `invalid tuple3 key` error.
fn deserialize_tuple3_map<'de, V, D>(deserializer: D) -> Result<Tuple3Map<V>, D::Error>
where
    V: Deserialize<'de>,
    D: serde::Deserializer<'de>,
{
    use serde::de::Error;

    let flat: HashMap<String, V> = HashMap::deserialize(deserializer)?;
    flat.into_iter()
        .map(|(joined, value)| {
            let mut pieces = joined.splitn(3, ':');
            let parts = (pieces.next(), pieces.next(), pieces.next());
            match parts {
                (Some(first), Some(second), Some(third)) => Ok((
                    (first.to_string(), second.to_string(), third.to_string()),
                    value,
                )),
                _ => Err(D::Error::custom(format!("invalid tuple3 key: {}", joined))),
            }
        })
        .collect()
}
/// Serializes a map with `(String, String, String, String)` keys as a map with
/// `"a:b:c:d"` string keys. Used through `#[serde(serialize_with = ...)]`.
///
/// NOTE(review): the key encoding is not escaped. `deserialize_tuple4_map` splits
/// at the first three `:`, so a `:` inside any of the first three components does
/// not round-trip.
fn serialize_tuple4_map<V, S>(
    map: &HashMap<(String, String, String, String), V>,
    serializer: S,
) -> Result<S::Ok, S::Error>
where
    V: Serialize,
    S: serde::Serializer,
{
    use serde::ser::SerializeMap;

    let mut entries = serializer.serialize_map(Some(map.len()))?;
    map.iter()
        .try_for_each(|((first, second, third, fourth), value)| {
            let joined = format!("{}:{}:{}:{}", first, second, third, fourth);
            entries.serialize_entry(&joined, value)
        })?;
    entries.end()
}
/// Inverse of `serialize_tuple4_map`: rebuilds `(String, String, String, String)`
/// keys from `"a:b:c:d"` strings.
///
/// Each key is split at its first three `:`; everything after the third `:`
/// (including further colons) becomes the fourth component. A key with fewer than
/// three `:` is rejected with an `invalid tuple4 key` error.
fn deserialize_tuple4_map<'de, V, D>(deserializer: D) -> Result<Tuple4Map<V>, D::Error>
where
    V: Deserialize<'de>,
    D: serde::Deserializer<'de>,
{
    use serde::de::Error;

    let flat: HashMap<String, V> = HashMap::deserialize(deserializer)?;
    flat.into_iter()
        .map(|(joined, value)| {
            let mut pieces = joined.splitn(4, ':');
            let parts = (pieces.next(), pieces.next(), pieces.next(), pieces.next());
            match parts {
                (Some(first), Some(second), Some(third), Some(fourth)) => Ok((
                    (
                        first.to_string(),
                        second.to_string(),
                        third.to_string(),
                        fourth.to_string(),
                    ),
                    value,
                )),
                _ => Err(D::Error::custom(format!("invalid tuple4 key: {}", joined))),
            }
        })
        .collect()
}