use std::any::Any;
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use super::base::{Model, ModelMetadata, ModelType, ModelVersion};
use crate::learn::stats::LearnStats;
use crate::util::epoch_millis;
/// A [`Model`] that can assign a numeric score to a [`ScoreQuery`].
pub trait Scorable: Model {
    /// Scores a single query; `None` means the model has no data for it.
    fn score(&self, query: &ScoreQuery) -> Option<f64>;

    /// Scores every query in `queries`, returning results in input order.
    ///
    /// The default implementation calls [`Scorable::score`] once per query.
    fn score_batch(&self, queries: &[ScoreQuery]) -> Vec<Option<f64>> {
        let mut results = Vec::with_capacity(queries.len());
        for query in queries {
            results.push(self.score(query));
        }
        results
    }
}
/// A request for a score from a [`Scorable`] model.
///
/// Usually built with the constructor helpers on `ScoreQuery`
/// (`transition`, `contextual`, `ngram`, `confidence`).
#[derive(Debug, Clone)]
pub enum ScoreQuery {
/// Score for taking `action` right after `prev_action`, optionally
/// narrowed to a specific `target`.
Transition {
prev_action: String,
action: String,
target: Option<String>,
},
/// Contextual score for `action` given the preceding `prev_action`,
/// optionally narrowed to a specific `target`.
Contextual {
prev_action: String,
action: String,
target: Option<String>,
},
/// Score for a sequence of actions. `ScoreModel` only scores sequences of
/// exactly three, read as `[prev_prev, prev, action]`.
Ngram {
actions: Vec<String>,
target: Option<String>,
},
/// Overall confidence in `action` given the preceding actions in `context`.
Confidence {
action: String,
target: Option<String>,
context: ScoreContext,
},
}
impl ScoreQuery {
pub fn transition(prev: &str, action: &str, target: Option<&str>) -> Self {
Self::Transition {
prev_action: prev.to_string(),
action: action.to_string(),
target: target.map(String::from),
}
}
pub fn contextual(prev: &str, action: &str, target: Option<&str>) -> Self {
Self::Contextual {
prev_action: prev.to_string(),
action: action.to_string(),
target: target.map(String::from),
}
}
pub fn ngram(actions: Vec<String>, target: Option<&str>) -> Self {
Self::Ngram {
actions,
target: target.map(String::from),
}
}
pub fn confidence(action: &str, target: Option<&str>, context: ScoreContext) -> Self {
Self::Confidence {
action: action.to_string(),
target: target.map(String::from),
context,
}
}
}
/// Preceding-action context used when computing a confidence score.
///
/// Build it with `ScoreContext::new()` and the `with_*` builder methods.
#[derive(Debug, Clone, Default)]
pub struct ScoreContext {
/// Action taken immediately before the one being scored.
pub prev_action: Option<String>,
/// Action taken before `prev_action`; enables n-gram lookups when set.
pub prev_prev_action: Option<String>,
/// Free-form key/value extras.
/// NOTE(review): not read by `ScoreModel` in this file — confirm other consumers.
pub additional: HashMap<String, String>,
}
impl ScoreContext {
    /// Returns an empty context: no preceding actions and no extra entries.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns this context with the immediately preceding action set to `prev`.
    pub fn with_prev(self, prev: &str) -> Self {
        Self {
            prev_action: Some(prev.to_owned()),
            ..self
        }
    }

    /// Returns this context with the action before `prev_action` set to `prev_prev`.
    pub fn with_prev_prev(self, prev_prev: &str) -> Self {
        Self {
            prev_prev_action: Some(prev_prev.to_owned()),
            ..self
        }
    }

    /// Returns this context with `key` mapped to `value` in the extras,
    /// replacing any earlier value for the same key.
    pub fn with_additional(self, key: &str, value: &str) -> Self {
        let mut ctx = self;
        ctx.additional.insert(key.to_owned(), value.to_owned());
        ctx
    }
}
/// Lookup-table model of action scores derived from [`LearnStats`].
///
/// Score tables are keyed by strings built from action names:
/// `"prev->action"` for transition and contextual scores,
/// `"prev_prev->prev->action"` for n-gram scores, and the bare action name
/// for `action_scores`. Lookups also try an `action@target` form first when
/// a target is supplied.
///
/// Serialized with serde; the derive emits fields in declaration order, so
/// reordering fields changes the output of order-sensitive formats.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ScoreModel {
/// Model version; set with `with_version`.
version: ModelVersion,
/// Model metadata; set with `with_metadata`.
metadata: ModelMetadata,
/// Creation timestamp; `ScoreModel::new` sets it from `epoch_millis()`.
created_at: u64,
/// Aggregated score per action name.
pub action_scores: HashMap<String, f64>,
/// Transition scores keyed by `"prev->action"`.
pub transition_scores: HashMap<String, f64>,
/// Contextual scores keyed by `"prev->action"`.
pub contextual_scores: HashMap<String, f64>,
/// Trigram scores keyed by `"prev_prev->prev->action"`.
pub ngram_scores: HashMap<String, f64>,
}
impl ScoreModel {
    /// Weights applied to the mean transition, contextual and n-gram scores
    /// (in that order) when aggregating a per-action score.
    const ACTION_SOURCE_WEIGHTS: [f64; 3] = [0.4, 0.4, 0.2];

    /// Creates an empty model whose `created_at` is taken from `epoch_millis()`.
    pub fn new() -> Self {
        Self {
            created_at: epoch_millis(),
            ..Default::default()
        }
    }

    /// Creates a new model and fills it from `stats` (see [`Self::compute_from_stats`]).
    pub fn from_stats(stats: &LearnStats) -> Self {
        let mut model = Self::new();
        model.compute_from_stats(stats);
        model
    }

    /// Builder-style setter for the model version.
    pub fn with_version(mut self, version: ModelVersion) -> Self {
        self.version = version;
        self
    }

    /// Builder-style setter for the model metadata.
    pub fn with_metadata(mut self, metadata: ModelMetadata) -> Self {
        self.metadata = metadata;
        self
    }

    /// Derives transition, contextual and n-gram scores from `stats`, then
    /// aggregates all three tables into per-action scores.
    ///
    /// Existing entries are not cleared first. Keys derived from `stats`
    /// overwrite their previous values. Keys no longer present in `stats`
    /// remain and still feed the per-action aggregation.
    pub fn compute_from_stats(&mut self, stats: &LearnStats) {
        self.compute_transition_scores(stats);
        self.compute_contextual_scores(stats);
        self.compute_ngram_scores(stats);
        self.compute_action_scores();
    }

    /// Scores every `(prev, action)` pair that appears in either the success
    /// or the failure transition table, using `transition_value`.
    fn compute_transition_scores(&mut self, stats: &LearnStats) {
        let transitions = &stats.episode_transitions;
        // Take the union of both tables so that pairs seen only in failures
        // are scored too. Borrowing the keys avoids cloning every pair.
        let pairs: std::collections::HashSet<&(String, String)> = transitions
            .success_transitions
            .keys()
            .chain(transitions.failure_transitions.keys())
            .collect();
        for (prev, action) in pairs {
            let score = transitions.transition_value(prev, action);
            self.transition_scores.insert(Self::key2(prev, action), score);
        }
    }

    /// Scores each visited `(prev, action)` context as its success rate minus 0.5.
    fn compute_contextual_scores(&mut self, stats: &LearnStats) {
        for ((prev, action), ctx_stats) in &stats.contextual_stats {
            // Contexts with no visits are left unscored.
            if ctx_stats.visits > 0 {
                let score = ctx_stats.success_rate() - 0.5;
                self.contextual_scores.insert(Self::key2(prev, action), score);
            }
        }
    }

    /// Scores each trigram with at least two observations as
    /// `(rate - 0.5) * 2.0`, where `rate` is its success fraction.
    fn compute_ngram_scores(&mut self, stats: &LearnStats) {
        for ((a1, a2, a3), &(success, failure)) in &stats.ngram_stats.trigrams {
            let total = success + failure;
            if total >= 2 {
                let rate = success as f64 / total as f64;
                let score = (rate - 0.5) * 2.0;
                self.ngram_scores.insert(Self::key3(a1, a2, a3), score);
            }
        }
    }

    /// Rebuilds per-action scores from the transition, contextual and n-gram tables.
    ///
    /// An action's score is the sum, over each table that mentions it, of the
    /// mean score for that action times the table's weight in
    /// `ACTION_SOURCE_WEIGHTS`. Tables that do not mention the action add
    /// nothing, and the weights are not renormalised.
    ///
    /// Each table is grouped in a single pass rather than rescanned for
    /// every action. Per-action sums still accumulate in the table's
    /// iteration order.
    fn compute_action_scores(&mut self) {
        // Per action: running (sum, count) for each source, indexed like ACTION_SOURCE_WEIGHTS.
        let mut per_action: HashMap<String, [(f64, usize); 3]> = HashMap::new();
        Self::accumulate_by_action(
            &self.transition_scores,
            Self::action_from_key2,
            0,
            &mut per_action,
        );
        Self::accumulate_by_action(
            &self.contextual_scores,
            Self::action_from_key2,
            1,
            &mut per_action,
        );
        Self::accumulate_by_action(
            &self.ngram_scores,
            Self::action_from_key3,
            2,
            &mut per_action,
        );

        for (action, sources) in per_action {
            let mut score = 0.0;
            for (&(sum, n), &weight) in sources.iter().zip(Self::ACTION_SOURCE_WEIGHTS.iter()) {
                if n > 0 {
                    score += (sum / n as f64) * weight;
                }
            }
            self.action_scores.insert(action, score);
        }
    }

    /// Adds every score in `scores` to slot `slot` of the running
    /// `(sum, count)` of the action that `action_of` extracts from its key.
    /// Keys with no extractable action are skipped.
    fn accumulate_by_action(
        scores: &HashMap<String, f64>,
        action_of: fn(&str) -> Option<&str>,
        slot: usize,
        per_action: &mut HashMap<String, [(f64, usize); 3]>,
    ) {
        for (key, &score) in scores {
            if let Some(action) = action_of(key.as_str()) {
                let sums = per_action
                    .entry(action.to_string())
                    .or_insert([(0.0, 0); 3]);
                sums[slot].0 += score;
                sums[slot].1 += 1;
            }
        }
    }

    /// Blends the available evidence for taking `action` in `context`.
    ///
    /// With a previous action set, the result is a weighted sum of whichever
    /// scores exist:
    /// - the transition score (weight 0.4),
    /// - the contextual score (weight 0.3),
    /// - the n-gram score (weight 0.3), only when `prev_prev_action` is also set.
    ///
    /// Returns `None` if none of these exist. Without a previous action, it
    /// returns the aggregated per-action score. Target-specific keys are
    /// not consulted.
    pub fn compute_confidence(&self, action: &str, context: &ScoreContext) -> Option<f64> {
        let prev = match context.prev_action.as_deref() {
            Some(prev) => prev,
            None => return self.action_scores.get(action).copied(),
        };

        let mut score = 0.0;
        let mut has_data = false;
        if let Some(transition_score) = self.get_transition(prev, action, None) {
            score += transition_score * 0.4;
            has_data = true;
        }
        if let Some(contextual_score) = self.get_contextual(prev, action, None) {
            score += contextual_score * 0.3;
            has_data = true;
        }
        if let Some(prev_prev) = context.prev_prev_action.as_deref() {
            if let Some(ngram_score) = self.ngram(prev_prev, prev, action, None) {
                score += ngram_score * 0.3;
                has_data = true;
            }
        }
        has_data.then_some(score)
    }

    /// `action@target` when a target is given, otherwise just `action`.
    fn action_key(action: &str, target: Option<&str>) -> String {
        match target {
            Some(t) => format!("{}@{}", action, t),
            None => action.to_string(),
        }
    }

    /// Key for a two-step table entry: `prev->action`.
    fn key2(prev: &str, action: &str) -> String {
        format!("{}->{}", prev, action)
    }

    /// Key for a trigram table entry: `prev_prev->prev->action`.
    fn key3(prev_prev: &str, prev: &str, action: &str) -> String {
        format!("{}->{}->{}", prev_prev, prev, action)
    }

    /// Second `->`-separated segment of a `key2` key, if any.
    /// Action names that themselves contain `->` are not disambiguated.
    fn action_from_key2(key: &str) -> Option<&str> {
        key.split("->").nth(1)
    }

    /// Third `->`-separated segment of a `key3` key, if any.
    fn action_from_key3(key: &str) -> Option<&str> {
        key.split("->").nth(2)
    }

    /// Looks up `action` in `scores`. When `target` is given, the
    /// `action@target` key is tried first; the plain action key is the
    /// fallback. `make_key` turns the (possibly target-qualified) action
    /// into the full map key.
    fn lookup_targeted(
        scores: &HashMap<String, f64>,
        action: &str,
        target: Option<&str>,
        make_key: impl Fn(&str) -> String,
    ) -> Option<f64> {
        let targeted = target.and_then(|t| {
            let qualified = Self::action_key(action, Some(t));
            scores.get(&make_key(&qualified)).copied()
        });
        targeted.or_else(|| scores.get(&make_key(action)).copied())
    }

    /// Transition score for `prev -> action`, preferring the `target`-specific entry.
    fn get_transition(&self, prev: &str, action: &str, target: Option<&str>) -> Option<f64> {
        Self::lookup_targeted(&self.transition_scores, action, target, |a| {
            Self::key2(prev, a)
        })
    }

    /// Contextual score for `prev -> action`, preferring the `target`-specific entry.
    fn get_contextual(&self, prev: &str, action: &str, target: Option<&str>) -> Option<f64> {
        Self::lookup_targeted(&self.contextual_scores, action, target, |a| {
            Self::key2(prev, a)
        })
    }

    /// True when all four score tables are empty.
    pub fn is_empty(&self) -> bool {
        self.action_scores.is_empty()
            && self.transition_scores.is_empty()
            && self.contextual_scores.is_empty()
            && self.ngram_scores.is_empty()
    }

    /// Convenience wrapper around [`Self::compute_confidence`] that takes the
    /// preceding actions as string slices. `_target` is currently ignored.
    pub fn confidence(
        &self,
        action: &str,
        _target: Option<&str>,
        prev: Option<&str>,
        prev_prev: Option<&str>,
    ) -> Option<f64> {
        let context = ScoreContext {
            prev_action: prev.map(String::from),
            prev_prev_action: prev_prev.map(String::from),
            additional: HashMap::new(),
        };
        self.compute_confidence(action, &context)
    }

    /// Public transition lookup; see `get_transition`.
    pub fn transition(&self, prev: &str, action: &str, target: Option<&str>) -> Option<f64> {
        self.get_transition(prev, action, target)
    }

    /// Public contextual lookup; see `get_contextual`.
    pub fn contextual(&self, prev: &str, action: &str, target: Option<&str>) -> Option<f64> {
        self.get_contextual(prev, action, target)
    }

    /// Trigram score for `prev_prev -> prev -> action`, preferring the
    /// `target`-specific entry and falling back to the untargeted one.
    pub fn ngram(
        &self,
        prev_prev: &str,
        prev: &str,
        action: &str,
        target: Option<&str>,
    ) -> Option<f64> {
        Self::lookup_targeted(&self.ngram_scores, action, target, |a| {
            Self::key3(prev_prev, prev, a)
        })
    }
}
impl Model for ScoreModel {
    /// Always reports [`ModelType::ActionScore`].
    fn model_type(&self) -> ModelType {
        ModelType::ActionScore
    }

    /// The stored version (set via `ScoreModel::with_version`).
    fn version(&self) -> &ModelVersion {
        let Self { version, .. } = self;
        version
    }

    /// The stored `created_at` value (`ScoreModel::new` sets it from `epoch_millis()`).
    fn created_at(&self) -> u64 {
        let Self { created_at, .. } = self;
        *created_at
    }

    /// The stored metadata (set via `ScoreModel::with_metadata`).
    fn metadata(&self) -> &ModelMetadata {
        let Self { metadata, .. } = self;
        metadata
    }

    /// Exposes `self` as [`Any`] so callers can downcast to `ScoreModel`.
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
}
impl Scorable for ScoreModel {
    /// Dispatches `query` to the matching lookup on this model.
    ///
    /// - `Transition`, `Contextual` and `Ngram` honour `target`: they prefer
    ///   the target-specific entry and fall back to the untargeted one.
    /// - `Ngram` only scores trigrams; any other length yields `None`.
    /// - `Confidence` delegates to `compute_confidence`, and its `target` is ignored.
    fn score(&self, query: &ScoreQuery) -> Option<f64> {
        match query {
            ScoreQuery::Transition {
                prev_action,
                action,
                target,
            } => self.get_transition(prev_action, action, target.as_deref()),
            ScoreQuery::Contextual {
                prev_action,
                action,
                target,
            } => self.get_contextual(prev_action, action, target.as_deref()),
            // Previously the target was dropped here, unlike every other
            // targeted lookup; route through `ngram` so it is honoured.
            ScoreQuery::Ngram { actions, target } => match actions.as_slice() {
                [prev_prev, prev, action] => {
                    self.ngram(prev_prev, prev, action, target.as_deref())
                }
                _ => None,
            },
            ScoreQuery::Confidence {
                action, context, ..
            } => self.compute_confidence(action, context),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::learn::stats::{ContextualActionStats, LearnStats};

    /// Stats with a single `A->B` transition/context (10 successes, 2 failures)
    /// and one `A->B->C` trigram (9 successes, 1 failure).
    fn create_test_stats() -> LearnStats {
        let mut stats = LearnStats::default();
        stats
            .episode_transitions
            .success_transitions
            .insert(("A".to_string(), "B".to_string()), 10);
        stats
            .episode_transitions
            .failure_transitions
            .insert(("A".to_string(), "B".to_string()), 2);
        stats.contextual_stats.insert(
            ("A".to_string(), "B".to_string()),
            ContextualActionStats {
                visits: 12,
                successes: 10,
                failures: 2,
            },
        );
        stats
            .ngram_stats
            .trigrams
            .insert(("A".to_string(), "B".to_string(), "C".to_string()), (9, 1));
        stats
    }

    #[test]
    fn test_score_model_from_stats() {
        let stats = create_test_stats();
        let model = ScoreModel::from_stats(&stats);
        assert!(!model.is_empty());
        assert!(model.transition_scores.contains_key("A->B"));
    }

    #[test]
    fn test_scorable_trait() {
        let stats = create_test_stats();
        let model = ScoreModel::from_stats(&stats);
        let score = model.score(&ScoreQuery::transition("A", "B", None));
        assert!(score.is_some());
        let score = model.score(&ScoreQuery::contextual("A", "B", None));
        assert!(score.is_some());
        let score = model.score(&ScoreQuery::ngram(
            vec!["A".to_string(), "B".to_string(), "C".to_string()],
            None,
        ));
        assert!(score.is_some());
        let score = model.score(&ScoreQuery::ngram(
            vec!["A".to_string(), "B".to_string()],
            None,
        ));
        assert!(score.is_none());
        let ctx = ScoreContext::new().with_prev("A").with_prev_prev("X");
        let score = model.score(&ScoreQuery::confidence("B", None, ctx));
        assert!(score.is_some());
    }

    #[test]
    fn test_provider_compat_methods() {
        let stats = create_test_stats();
        let model = ScoreModel::from_stats(&stats);
        let score = model.transition("A", "B", None);
        assert!(score.is_some());
        let score = model.contextual("A", "B", None);
        assert!(score.is_some());
        let score = model.ngram("A", "B", "C", None);
        assert!(score.is_some());
        let score = model.confidence("B", None, Some("A"), None);
        assert!(score.is_some());
    }

    #[test]
    fn test_score_context_builder() {
        let ctx = ScoreContext::new()
            .with_prev("A")
            .with_prev_prev("B")
            .with_additional("key", "value");
        assert_eq!(ctx.prev_action.as_deref(), Some("A"));
        assert_eq!(ctx.prev_prev_action.as_deref(), Some("B"));
        assert_eq!(ctx.additional.get("key").map(|s| s.as_str()), Some("value"));
    }

    #[test]
    fn test_new_model_is_empty() {
        let model = ScoreModel::new();
        assert!(model.is_empty());
        assert!(model.confidence("A", None, None, None).is_none());
        assert!(model.transition("A", "B", Some("t")).is_none());
        assert!(model.ngram("A", "B", "C", None).is_none());
    }

    #[test]
    fn test_ngram_and_action_scores() {
        let model = ScoreModel::from_stats(&create_test_stats());
        // 9 of 10 succeeded -> rate 0.9 -> (0.9 - 0.5) * 2.0 = 0.8.
        let ngram = model
            .ngram("A", "B", "C", None)
            .expect("trigram A->B->C should be scored");
        assert!((ngram - 0.8).abs() < 1e-12);
        // "C" only appears as the last action of a trigram, so its aggregate
        // is the n-gram weight (0.2) times the trigram score.
        let c = model
            .action_scores
            .get("C")
            .copied()
            .expect("action C should be aggregated");
        assert!((c - 0.2 * ngram).abs() < 1e-12);
        assert!(model.action_scores.contains_key("B"));
        // Without a previous action, confidence falls back to the aggregate.
        assert_eq!(model.confidence("C", None, None, None), Some(c));
        // Only n-gram evidence exists for B->C after A: weight 0.3.
        let conf = model
            .confidence("C", None, Some("B"), Some("A"))
            .expect("n-gram evidence should yield a confidence");
        assert!((conf - 0.3 * ngram).abs() < 1e-12);
        // No evidence at all for an unknown action.
        assert!(model.confidence("Z", None, Some("A"), None).is_none());
    }

    #[test]
    fn test_target_fallback() {
        let mut model = ScoreModel::from_stats(&create_test_stats());
        let plain = model.transition("A", "B", None).map(f64::to_bits);
        assert_eq!(model.transition("A", "B", Some("t")).map(f64::to_bits), plain);
        model.transition_scores.insert("A->B@t".to_string(), 0.9);
        assert_eq!(model.transition("A", "B", Some("t")), Some(0.9));
        assert_eq!(model.transition("A", "B", Some("u")).map(f64::to_bits), plain);
    }

    #[test]
    fn test_score_batch_preserves_order() {
        let model = ScoreModel::from_stats(&create_test_stats());
        let queries = vec![
            ScoreQuery::transition("A", "B", None),
            ScoreQuery::transition("Q", "R", None),
            ScoreQuery::contextual("A", "B", None),
        ];
        let scores = model.score_batch(&queries);
        assert_eq!(scores.len(), 3);
        assert_eq!(
            scores[0].map(f64::to_bits),
            model.score(&queries[0]).map(f64::to_bits)
        );
        assert!(scores[1].is_none());
        assert_eq!(
            scores[2].map(f64::to_bits),
            model.score(&queries[2]).map(f64::to_bits)
        );
    }
}