use std::collections::HashMap;
use std::sync::Arc;
use super::stats::LearnStats;
use super::stats_model::ScoreModel;
/// A question posed to a [`LearnedProvider`].
///
/// All strings are borrowed for `'a`, so building a query allocates nothing.
/// Each variant selects a different kind of lookup; every variant carries an
/// optional `target` qualifier alongside the action being scored.
#[derive(Debug, Clone)]
pub enum LearningQuery<'a> {
    /// Score of `action` taken immediately after `prev`.
    Transition {
        prev: &'a str,
        action: &'a str,
        target: Option<&'a str>,
    },
    /// Context-conditioned score of `action` given the preceding `prev`.
    Contextual {
        prev: &'a str,
        action: &'a str,
        target: Option<&'a str>,
    },
    /// Score of `action` given the two preceding steps, `prev_prev` then `prev`.
    Ngram {
        prev_prev: &'a str,
        prev: &'a str,
        action: &'a str,
        target: Option<&'a str>,
    },
    /// Confidence in `action`, optionally conditioned on up to two preceding steps.
    Confidence {
        action: &'a str,
        target: Option<&'a str>,
        prev: Option<&'a str>,
        prev_prev: Option<&'a str>,
    },
}

impl<'a> LearningQuery<'a> {
    /// Builds a [`LearningQuery::Transition`] query.
    pub fn transition(prev: &'a str, action: &'a str, target: Option<&'a str>) -> Self {
        Self::Transition {
            prev,
            action,
            target,
        }
    }

    /// Builds a [`LearningQuery::Contextual`] query.
    pub fn contextual(prev: &'a str, action: &'a str, target: Option<&'a str>) -> Self {
        Self::Contextual {
            prev,
            action,
            target,
        }
    }

    /// Builds a [`LearningQuery::Ngram`] query for `prev_prev` → `prev` → `action`.
    pub fn ngram(
        prev_prev: &'a str,
        prev: &'a str,
        action: &'a str,
        target: Option<&'a str>,
    ) -> Self {
        Self::Ngram {
            prev_prev,
            prev,
            action,
            target,
        }
    }

    /// Builds a [`LearningQuery::Confidence`] query with explicit, possibly
    /// absent, preceding steps.
    pub fn confidence_with_context(
        action: &'a str,
        target: Option<&'a str>,
        prev: Option<&'a str>,
        prev_prev: Option<&'a str>,
    ) -> Self {
        Self::Confidence {
            action,
            target,
            prev,
            prev_prev,
        }
    }

    /// Builds a context-free [`LearningQuery::Confidence`] query.
    ///
    /// Shorthand for `confidence_with_context(action, target, None, None)`.
    pub fn confidence(action: &'a str, target: Option<&'a str>) -> Self {
        Self::confidence_with_context(action, target, None, None)
    }
}
/// Outcome of asking a [`LearnedProvider`] a [`LearningQuery`].
///
/// The [`Default`] value is [`LearningResult::NotAvailable`].
#[derive(Debug, Clone, PartialEq, Default)]
pub enum LearningResult {
    /// The provider produced a score.
    Score(f64),
    /// The provider had no score for this query.
    #[default]
    NotAvailable,
}

impl LearningResult {
    /// Returns the carried score, or `default` when no score is available.
    pub fn score_or(&self, default: f64) -> f64 {
        if let Self::Score(value) = *self {
            value
        } else {
            default
        }
    }

    /// Returns the carried score, treating an unavailable result as `0.0`.
    pub fn score(&self) -> f64 {
        self.score_or(0.0)
    }

    /// `true` when this result carries a score.
    pub fn is_available(&self) -> bool {
        !matches!(self, Self::NotAvailable)
    }
}
/// Source of learned scores, consulted through [`LearningQuery`] values.
///
/// Implementors must be `Send + Sync`, which lets them be shared across
/// threads as a [`SharedLearnedProvider`]. Only [`query`](Self::query) is
/// required; the `stats` and `model` accessors default to `None` for
/// providers that have no backing statistics or model.
pub trait LearnedProvider: Send + Sync {
    /// Answers `q`, returning [`LearningResult::NotAvailable`] when the
    /// provider has no score for it.
    fn query(&self, q: LearningQuery<'_>) -> LearningResult;

    /// Statistics backing this provider, if any (default: `None`).
    fn stats(&self) -> Option<&LearnStats> {
        None
    }

    /// Score model backing this provider, if any (default: `None`).
    fn model(&self) -> Option<&ScoreModel> {
        None
    }
}

/// Reference-counted, thread-shareable handle to any [`LearnedProvider`].
pub type SharedLearnedProvider = Arc<dyn LearnedProvider>;
/// [`LearnedProvider`] that answers every query from a [`ScoreModel`].
///
/// When built with [`ScoreModelProvider::from_stats`], the source
/// [`LearnStats`] are retained and exposed through [`LearnedProvider::stats`].
/// When built with [`ScoreModelProvider::new`], no statistics are kept.
pub struct ScoreModelProvider {
    /// Model every query is delegated to.
    model: ScoreModel,
    /// Statistics the provider was constructed from, if any.
    stats: Option<LearnStats>,
}

impl ScoreModelProvider {
    /// Wraps an already-built model without any backing statistics.
    pub fn new(model: ScoreModel) -> Self {
        Self { model, stats: None }
    }

    /// Derives a model from `stats` and keeps `stats` alongside it.
    pub fn from_stats(stats: LearnStats) -> Self {
        Self {
            // Evaluated before `stats` is moved into the next field.
            model: ScoreModel::from_stats(&stats),
            stats: Some(stats),
        }
    }

    /// Borrows the model currently used to answer queries.
    pub fn inner(&self) -> &ScoreModel {
        &self.model
    }

    /// Swaps in a new model.
    ///
    /// NOTE(review): any retained `stats` are left untouched, so after this
    /// call [`LearnedProvider::stats`] may describe data the active model was
    /// not built from. Confirm that callers expect this.
    pub fn update_model(&mut self, model: ScoreModel) {
        self.model = model;
    }
}

impl LearnedProvider for ScoreModelProvider {
    /// Routes each query variant to the matching [`ScoreModel`] lookup.
    /// A `None` from the model becomes [`LearningResult::NotAvailable`].
    fn query(&self, q: LearningQuery<'_>) -> LearningResult {
        let model = &self.model;
        let looked_up = match q {
            LearningQuery::Transition {
                prev,
                action,
                target,
            } => model.transition(prev, action, target),
            LearningQuery::Contextual {
                prev,
                action,
                target,
            } => model.contextual(prev, action, target),
            LearningQuery::Ngram {
                prev_prev,
                prev,
                action,
                target,
            } => model.ngram(prev_prev, prev, action, target),
            LearningQuery::Confidence {
                action,
                target,
                prev,
                prev_prev,
            } => model.confidence(action, target, prev, prev_prev),
        };
        looked_up.map_or(LearningResult::NotAvailable, LearningResult::Score)
    }

    /// Statistics the provider was built from, or `None` if built via `new`.
    fn stats(&self) -> Option<&LearnStats> {
        self.stats.as_ref()
    }

    /// Always `Some`: this provider is backed by a model.
    fn model(&self) -> Option<&ScoreModel> {
        Some(&self.model)
    }
}
/// Provider with no knowledge at all.
///
/// Every query yields [`LearningResult::NotAvailable`]. The `stats` and
/// `model` accessors keep their trait defaults (`None`).
#[derive(Debug, Clone, Default)]
pub struct NullProvider;

impl LearnedProvider for NullProvider {
    /// Ignores the query entirely.
    fn query(&self, _ignored: LearningQuery<'_>) -> LearningResult {
        LearningResult::NotAvailable
    }
}
/// Provider backed by a fixed table of per-action confidences.
///
/// Only [`LearningQuery::Confidence`] queries are answered, and only by
/// `action`; `target`, `prev` and `prev_prev` are ignored. The reported score
/// is the stored confidence minus `0.5`. A confidence of `0.5` therefore
/// scores `0.0`; higher confidences score positive and lower ones negative.
#[derive(Debug, Clone, Default)]
pub struct ConfidenceMapProvider {
    /// Confidence keyed by action name.
    confidence: HashMap<String, f64>,
}

impl ConfidenceMapProvider {
    /// Subtracted from a stored confidence to produce the query score.
    const CONFIDENCE_OFFSET: f64 = 0.5;

    /// Creates a provider from an action → confidence table.
    pub fn new(confidence: HashMap<String, f64>) -> Self {
        Self { confidence }
    }

    /// Raw stored confidence for `action`, if present (no offset applied).
    pub fn get(&self, action: &str) -> Option<f64> {
        self.confidence.get(action).copied()
    }
}

impl LearnedProvider for ConfidenceMapProvider {
    /// Returns `confidence - 0.5` for a known action in a `Confidence`
    /// query; every other query is [`LearningResult::NotAvailable`].
    fn query(&self, q: LearningQuery<'_>) -> LearningResult {
        if let LearningQuery::Confidence { action, .. } = q {
            self.get(action).map_or(LearningResult::NotAvailable, |raw| {
                LearningResult::Score(raw - Self::CONFIDENCE_OFFSET)
            })
        } else {
            LearningResult::NotAvailable
        }
    }
}
/// [`LearnedProvider`] that owns [`LearnStats`] and a [`ScoreModel`] derived
/// from them.
///
/// Both mutators, [`update_stats`](Self::update_stats) and
/// [`replace_stats`](Self::replace_stats), rebuild the model after changing
/// the statistics.
pub struct LearnStatsProvider {
    /// Current statistics.
    stats: LearnStats,
    /// Model rebuilt from `stats` after every change.
    model: ScoreModel,
}

impl LearnStatsProvider {
    /// Takes ownership of `stats` and builds the initial model from them.
    pub fn new(stats: LearnStats) -> Self {
        Self {
            // Evaluated before `stats` is moved into the next field.
            model: ScoreModel::from_stats(&stats),
            stats,
        }
    }

    /// Borrows the current statistics.
    pub fn stats(&self) -> &LearnStats {
        &self.stats
    }

    /// Borrows the model derived from the current statistics.
    pub fn model(&self) -> &ScoreModel {
        &self.model
    }

    /// Mutates the statistics in place via `f`, then rebuilds the model.
    pub fn update_stats<F: FnOnce(&mut LearnStats)>(&mut self, f: F) {
        f(&mut self.stats);
        self.rebuild_model();
    }

    /// Replaces the statistics wholesale, then rebuilds the model.
    pub fn replace_stats(&mut self, stats: LearnStats) {
        self.stats = stats;
        self.rebuild_model();
    }

    /// Recomputes `model` from the current `stats`.
    fn rebuild_model(&mut self) {
        self.model = ScoreModel::from_stats(&self.stats);
    }
}

impl LearnedProvider for LearnStatsProvider {
    /// Routes each query variant to the matching [`ScoreModel`] lookup.
    /// A `None` from the model becomes [`LearningResult::NotAvailable`].
    fn query(&self, q: LearningQuery<'_>) -> LearningResult {
        let model = &self.model;
        let looked_up = match q {
            LearningQuery::Transition {
                prev,
                action,
                target,
            } => model.transition(prev, action, target),
            LearningQuery::Contextual {
                prev,
                action,
                target,
            } => model.contextual(prev, action, target),
            LearningQuery::Ngram {
                prev_prev,
                prev,
                action,
                target,
            } => model.ngram(prev_prev, prev, action, target),
            LearningQuery::Confidence {
                action,
                target,
                prev,
                prev_prev,
            } => model.confidence(action, target, prev, prev_prev),
        };
        looked_up.map_or(LearningResult::NotAvailable, LearningResult::Score)
    }

    /// Always `Some`: this provider owns its statistics.
    fn stats(&self) -> Option<&LearnStats> {
        Some(&self.stats)
    }

    /// Always `Some`: this provider owns its model.
    fn model(&self) -> Option<&ScoreModel> {
        Some(&self.model)
    }
}
// Unit tests for the query/result types and for the NullProvider,
// ConfidenceMapProvider and ScoreModelProvider implementations.
#[cfg(test)]
mod tests {
use super::*;
// `score_or` returns the carried score, or the supplied fallback when unavailable.
#[test]
fn test_learning_result_score_or() {
assert_eq!(LearningResult::Score(0.5).score_or(0.0), 0.5);
assert_eq!(LearningResult::NotAvailable.score_or(0.0), 0.0);
assert_eq!(LearningResult::NotAvailable.score_or(-1.0), -1.0);
}
// `is_available` is true only for the `Score` variant.
#[test]
fn test_learning_result_is_available() {
assert!(LearningResult::Score(0.5).is_available());
assert!(!LearningResult::NotAvailable.is_available());
}
// NullProvider answers NotAvailable for every query kind exercised here.
#[test]
fn test_null_provider() {
let provider = NullProvider;
assert_eq!(
provider.query(LearningQuery::transition("A", "B", None)),
LearningResult::NotAvailable
);
assert_eq!(
provider.query(LearningQuery::contextual("A", "B", Some("svc1"))),
LearningResult::NotAvailable
);
assert_eq!(
provider.query(LearningQuery::ngram("A", "B", "C", None)),
LearningResult::NotAvailable
);
}
// ConfidenceMapProvider reports (confidence - 0.5): 0.8 -> 0.3 and 0.3 -> -0.2.
// Unknown actions and non-Confidence queries are NotAvailable.
#[test]
fn test_confidence_map_provider() {
let mut map = HashMap::new();
map.insert("grep".to_string(), 0.8);
map.insert("restart".to_string(), 0.3);
let provider = ConfidenceMapProvider::new(map);
let result = provider.query(LearningQuery::confidence("grep", None));
let score = result.score();
// Tolerance comparison because 0.8 - 0.5 is not exactly 0.3 in f64.
assert!((score - 0.3).abs() < 1e-10, "expected ~0.3, got {}", score);
let result = provider.query(LearningQuery::confidence("restart", None));
let score = result.score();
assert!(
(score - (-0.2)).abs() < 1e-10,
"expected ~-0.2, got {}",
score
);
let result = provider.query(LearningQuery::confidence("unknown", None));
assert_eq!(result, LearningResult::NotAvailable);
let result = provider.query(LearningQuery::transition("A", "B", None));
assert_eq!(result, LearningResult::NotAvailable);
}
// Each constructor fills the matching variant with its arguments in order.
#[test]
fn test_learning_query_constructors() {
let q = LearningQuery::transition("A", "B", Some("svc1"));
assert!(matches!(
q,
LearningQuery::Transition {
prev: "A",
action: "B",
target: Some("svc1")
}
));
let q = LearningQuery::ngram("A", "B", "C", None);
assert!(matches!(
q,
LearningQuery::Ngram {
prev_prev: "A",
prev: "B",
action: "C",
target: None
}
));
let q =
LearningQuery::confidence_with_context("action", None, Some("prev"), Some("prev_prev"));
assert!(matches!(
q,
LearningQuery::Confidence {
action: "action",
prev: Some("prev"),
prev_prev: Some("prev_prev"),
..
}
));
}
// Seeds transition, contextual and trigram statistics for the path X -> A -> B,
// then checks that every query kind finds data through ScoreModelProvider.
#[test]
fn test_score_model_provider() {
use crate::learn::stats::{ContextualActionStats, LearnStats};
let mut stats = LearnStats::default();
// A -> B succeeded 10 times and failed 2 times.
stats
.episode_transitions
.success_transitions
.insert(("A".to_string(), "B".to_string()), 10);
stats
.episode_transitions
.failure_transitions
.insert(("A".to_string(), "B".to_string()), 2);
stats.contextual_stats.insert(
("A".to_string(), "B".to_string()),
ContextualActionStats {
visits: 12,
successes: 10,
failures: 2,
},
);
// Trigram X -> A -> B, stored as the tuple (9, 1).
// NOTE(review): presumably (successes, failures); confirm against the stats type.
stats
.ngram_stats
.trigrams
.insert(("X".to_string(), "A".to_string(), "B".to_string()), (9, 1));
let provider = ScoreModelProvider::from_stats(stats);
let result = provider.query(LearningQuery::transition("A", "B", None));
assert!(result.is_available());
let result = provider.query(LearningQuery::contextual("A", "B", None));
assert!(result.is_available());
// Assertion message (Japanese): "the success rate is high, so the score is positive".
assert!(result.score() > 0.0, "成功率が高いので正のスコア");
let result = provider.query(LearningQuery::ngram("X", "A", "B", None));
assert!(result.is_available());
let result = provider.query(LearningQuery::confidence_with_context(
"B",
None,
Some("A"),
Some("X"),
));
assert!(result.is_available());
}
}