use std::any::Any;
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use super::base::{Model, ModelMetadata, ModelType, ModelVersion};
use crate::learn::offline::{OfflineModel, RecommendedPath, StrategyConfig};
use crate::util::epoch_millis;
/// A [`Model`] whose behaviour is configured through named, dynamically-typed
/// parameters.
pub trait Parametric: Model {
    /// Looks up the parameter stored under `key`.
    ///
    /// Returns an owned [`ParamValue`], or `None` when no parameter with that
    /// name exists.
    fn get_param(&self, key: &str) -> Option<ParamValue>;

    /// Returns every parameter as an owned map from parameter name to value.
    fn all_params(&self) -> HashMap<String, ParamValue>;
}
/// A dynamically-typed parameter value stored by [`Parametric`] models.
///
/// Read values through the `as_*` accessors, which return `None` when the
/// value cannot be interpreted as the requested type.
///
/// Variant order is significant for serde's encoding in non-self-describing
/// formats, so new variants should only be appended.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ParamValue {
    /// A floating-point number.
    Float(f64),
    /// A signed 64-bit integer.
    Int(i64),
    /// A boolean flag.
    Bool(bool),
    /// An owned string.
    String(String),
    /// An ordered list of nested values.
    Array(Vec<ParamValue>),
}
impl ParamValue {
pub fn as_f64(&self) -> Option<f64> {
match self {
Self::Float(v) => Some(*v),
Self::Int(v) => Some(*v as f64),
_ => None,
}
}
pub fn as_i64(&self) -> Option<i64> {
match self {
Self::Int(v) => Some(*v),
Self::Float(v) => Some(*v as i64),
_ => None,
}
}
pub fn as_bool(&self) -> Option<bool> {
match self {
Self::Bool(v) => Some(*v),
_ => None,
}
}
pub fn as_str(&self) -> Option<&str> {
match self {
Self::String(v) => Some(v),
_ => None,
}
}
}
impl From<f64> for ParamValue {
fn from(v: f64) -> Self {
Self::Float(v)
}
}
impl From<i64> for ParamValue {
fn from(v: i64) -> Self {
Self::Int(v)
}
}
impl From<bool> for ParamValue {
fn from(v: bool) -> Self {
Self::Bool(v)
}
}
impl From<String> for ParamValue {
fn from(v: String) -> Self {
Self::String(v)
}
}
impl From<&str> for ParamValue {
fn from(v: &str) -> Self {
Self::String(v.to_string())
}
}
/// Well-known keys for the parameter map of `OptimalParamsModel`.
pub mod param_keys {
    // --- Learned parameters, each with a typed getter on `OptimalParamsModel`. ---

    /// UCB1 constant `c`; read by `OptimalParamsModel::ucb1_c`.
    pub const UCB1_C: &str = "ucb1_c";
    /// Learning weight; read by `OptimalParamsModel::learning_weight`.
    pub const LEARNING_WEIGHT: &str = "learning_weight";
    /// N-gram weight; read by `OptimalParamsModel::ngram_weight`.
    pub const NGRAM_WEIGHT: &str = "ngram_weight";

    // --- Mirrors of `StrategyConfig` fields of the same names. ---

    /// Mirrors `StrategyConfig::maturity_threshold` (stored as an `Int`).
    pub const MATURITY_THRESHOLD: &str = "maturity_threshold";
    /// Mirrors `StrategyConfig::error_rate_threshold` (stored as a `Float`).
    pub const ERROR_RATE_THRESHOLD: &str = "error_rate_threshold";
    /// Mirrors `StrategyConfig::initial_strategy` (stored as a `String`).
    pub const INITIAL_STRATEGY: &str = "initial_strategy";
}
/// Parameter model holding named tuning values alongside a strategy
/// configuration and a list of recommended paths.
///
/// Values live in a string-keyed map and are read through the [`Parametric`]
/// trait or this type's typed getters. Instances are built fresh via `Default`
/// or migrated from a legacy [`OfflineModel`] via `From`.
// NOTE: field order is part of the serde/Debug layout; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimalParamsModel {
    /// Model version: `ModelVersion::new(1, 0)` for fresh models; migrated
    /// models use the legacy version as the first component.
    version: ModelVersion,
    /// Model metadata; both constructors start from `ModelMetadata::default()`.
    metadata: ModelMetadata,
    /// Creation timestamp: `epoch_millis()` for fresh models, the legacy
    /// `updated_at` for migrated ones. Presumably milliseconds since the Unix
    /// epoch — TODO confirm against `epoch_millis`.
    created_at: u64,
    /// Named parameter values. Keys are arbitrary strings; the well-known ones
    /// are the constants in `param_keys`.
    params: HashMap<String, ParamValue>,
    /// Strategy settings; migrated models also mirror some of these into
    /// `params`. NOTE(review): `set_param` only updates `params`, so the two
    /// copies can drift apart.
    pub strategy_config: StrategyConfig,
    /// Recommended paths carried over from the legacy model (empty for fresh
    /// models).
    pub recommended_paths: Vec<RecommendedPath>,
    /// Analyzed-session count carried over from the legacy model (0 for fresh
    /// models).
    pub analyzed_sessions: usize,
}
impl Default for OptimalParamsModel {
    /// Builds a fresh model with the built-in parameter defaults
    /// (`ucb1_c = √2`, `learning_weight = 0.3`, `ngram_weight = 1.0`), a
    /// default [`StrategyConfig`], no recommended paths, zero analyzed sessions,
    /// and `created_at` stamped from `epoch_millis()`.
    ///
    /// The strategy-derived keys (`maturity_threshold`, `error_rate_threshold`,
    /// `initial_strategy`) are mirrored into the parameter map from that default
    /// config. This matches `From<OfflineModel>`, so `get_param`/`all_params`
    /// expose the same key set however the model was constructed.
    fn default() -> Self {
        let strategy_config = StrategyConfig::default();
        let params = HashMap::from([
            (
                param_keys::UCB1_C.to_string(),
                ParamValue::Float(std::f64::consts::SQRT_2),
            ),
            (
                param_keys::LEARNING_WEIGHT.to_string(),
                ParamValue::Float(0.3),
            ),
            (param_keys::NGRAM_WEIGHT.to_string(), ParamValue::Float(1.0)),
            // Same casts/clones as `From<OfflineModel>` so both paths agree.
            (
                param_keys::MATURITY_THRESHOLD.to_string(),
                ParamValue::Int(strategy_config.maturity_threshold as i64),
            ),
            (
                param_keys::ERROR_RATE_THRESHOLD.to_string(),
                ParamValue::Float(strategy_config.error_rate_threshold),
            ),
            (
                param_keys::INITIAL_STRATEGY.to_string(),
                ParamValue::String(strategy_config.initial_strategy.clone()),
            ),
        ]);
        Self {
            version: ModelVersion::new(1, 0),
            metadata: ModelMetadata::default(),
            created_at: epoch_millis(),
            params,
            strategy_config,
            recommended_paths: Vec::new(),
            analyzed_sessions: 0,
        }
    }
}
impl OptimalParamsModel {
pub fn new() -> Self {
Self::default()
}
pub fn set_param(&mut self, key: &str, value: impl Into<ParamValue>) {
self.params.insert(key.to_string(), value.into());
}
pub fn ucb1_c(&self) -> f64 {
self.get_param(param_keys::UCB1_C)
.and_then(|v| v.as_f64())
.unwrap_or(std::f64::consts::SQRT_2)
}
pub fn learning_weight(&self) -> f64 {
self.get_param(param_keys::LEARNING_WEIGHT)
.and_then(|v| v.as_f64())
.unwrap_or(0.3)
}
pub fn ngram_weight(&self) -> f64 {
self.get_param(param_keys::NGRAM_WEIGHT)
.and_then(|v| v.as_f64())
.unwrap_or(1.0)
}
}
impl Model for OptimalParamsModel {
    /// Always reports [`ModelType::OptimalParams`].
    fn model_type(&self) -> ModelType {
        ModelType::OptimalParams
    }

    /// Borrows the stored model version.
    fn version(&self) -> &ModelVersion {
        let Self { version, .. } = self;
        version
    }

    /// Returns the stored creation timestamp unchanged.
    fn created_at(&self) -> u64 {
        let Self { created_at, .. } = self;
        *created_at
    }

    /// Borrows the stored model metadata.
    fn metadata(&self) -> &ModelMetadata {
        let Self { metadata, .. } = self;
        metadata
    }

    /// Exposes `self` as [`Any`], which allows
    /// `downcast_ref::<OptimalParamsModel>()` on the result.
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
}
impl Parametric for OptimalParamsModel {
    /// Returns a clone of the value stored under `key`, if present.
    fn get_param(&self, key: &str) -> Option<ParamValue> {
        let stored = self.params.get(key)?;
        Some(stored.clone())
    }

    /// Returns a full copy of the parameter map.
    fn all_params(&self) -> HashMap<String, ParamValue> {
        HashMap::clone(&self.params)
    }
}
impl From<OfflineModel> for OptimalParamsModel {
fn from(old: OfflineModel) -> Self {
let mut params = HashMap::new();
params.insert(
param_keys::UCB1_C.to_string(),
ParamValue::Float(old.parameters.ucb1_c),
);
params.insert(
param_keys::LEARNING_WEIGHT.to_string(),
ParamValue::Float(old.parameters.learning_weight),
);
params.insert(
param_keys::NGRAM_WEIGHT.to_string(),
ParamValue::Float(old.parameters.ngram_weight),
);
params.insert(
param_keys::MATURITY_THRESHOLD.to_string(),
ParamValue::Int(old.strategy_config.maturity_threshold as i64),
);
params.insert(
param_keys::ERROR_RATE_THRESHOLD.to_string(),
ParamValue::Float(old.strategy_config.error_rate_threshold),
);
params.insert(
param_keys::INITIAL_STRATEGY.to_string(),
ParamValue::String(old.strategy_config.initial_strategy.clone()),
);
Self {
version: ModelVersion::new(old.version, 0),
metadata: ModelMetadata::default(),
created_at: old.updated_at,
params,
strategy_config: old.strategy_config,
recommended_paths: old.recommended_paths,
analyzed_sessions: old.analyzed_sessions,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const EPS: f64 = 1e-10;

    #[test]
    fn test_param_value_conversions() {
        let f = ParamValue::Float(1.5);
        assert_eq!(f.as_f64(), Some(1.5));
        assert_eq!(f.as_i64(), Some(1));
        let i = ParamValue::Int(42);
        assert_eq!(i.as_i64(), Some(42));
        assert_eq!(i.as_f64(), Some(42.0));
        let b = ParamValue::Bool(true);
        assert_eq!(b.as_bool(), Some(true));
        let s = ParamValue::String("test".to_string());
        assert_eq!(s.as_str(), Some("test"));
    }

    #[test]
    fn test_param_value_mismatched_accessors_return_none() {
        assert_eq!(ParamValue::Bool(true).as_f64(), None);
        assert_eq!(ParamValue::String("1".to_string()).as_i64(), None);
        assert_eq!(ParamValue::Int(1).as_bool(), None);
        assert_eq!(ParamValue::Float(1.0).as_str(), None);
        assert_eq!(ParamValue::Array(Vec::new()).as_f64(), None);
    }

    #[test]
    fn test_param_value_from_impls() {
        assert_eq!(ParamValue::from(2.5).as_f64(), Some(2.5));
        assert!(matches!(ParamValue::from(7_i64), ParamValue::Int(7)));
        assert_eq!(ParamValue::from(false).as_bool(), Some(false));
        assert_eq!(ParamValue::from("abc").as_str(), Some("abc"));
        assert_eq!(ParamValue::from(String::from("xyz")).as_str(), Some("xyz"));
    }

    #[test]
    fn test_optimal_params_model_default() {
        let model = OptimalParamsModel::new();
        assert!((model.ucb1_c() - std::f64::consts::SQRT_2).abs() < EPS);
        assert!((model.learning_weight() - 0.3).abs() < EPS);
        assert!((model.ngram_weight() - 1.0).abs() < EPS);
        assert_eq!(model.analyzed_sessions, 0);
        assert!(model.recommended_paths.is_empty());
    }

    #[test]
    fn test_optimal_params_model_set_param() {
        let mut model = OptimalParamsModel::new();
        model.set_param(param_keys::UCB1_C, 2.0);
        assert!((model.ucb1_c() - 2.0).abs() < EPS);
        // Integer values are widened when read through the f64 getters.
        model.set_param(param_keys::LEARNING_WEIGHT, 1_i64);
        assert!((model.learning_weight() - 1.0).abs() < EPS);
    }

    #[test]
    fn test_getters_fall_back_on_non_numeric_values() {
        let mut model = OptimalParamsModel::new();
        model.set_param(param_keys::NGRAM_WEIGHT, "not a number");
        assert!((model.ngram_weight() - 1.0).abs() < EPS);
    }

    #[test]
    fn test_parametric_trait() {
        let model = OptimalParamsModel::new();
        let value = model.get_param(param_keys::UCB1_C);
        assert!(value.is_some());
        let all = model.all_params();
        assert!(all.contains_key(param_keys::UCB1_C));
        assert!(all.contains_key(param_keys::LEARNING_WEIGHT));
    }

    #[test]
    fn test_model_trait() {
        let model = OptimalParamsModel::new();
        assert!(matches!(model.model_type(), ModelType::OptimalParams));
        assert!(model
            .as_any()
            .downcast_ref::<OptimalParamsModel>()
            .is_some());
    }

    #[test]
    fn test_from_offline_model() {
        use crate::learn::offline::{OfflineModel, OptimalParameters, StrategyConfig};
        let defaults = StrategyConfig::default();
        let old = OfflineModel {
            version: 2,
            parameters: OptimalParameters {
                ucb1_c: 1.5,
                learning_weight: 0.4,
                ngram_weight: 1.2,
            },
            recommended_paths: vec![],
            strategy_config: StrategyConfig::default(),
            analyzed_sessions: 5,
            updated_at: 12345,
            action_order: None,
        };
        let model: OptimalParamsModel = old.into();
        assert!((model.ucb1_c() - 1.5).abs() < EPS);
        assert!((model.learning_weight() - 0.4).abs() < EPS);
        assert!((model.ngram_weight() - 1.2).abs() < EPS);
        assert_eq!(model.analyzed_sessions, 5);
        assert_eq!(model.created_at(), 12345);
        assert!(model.recommended_paths.is_empty());

        // Strategy settings are mirrored into the parameter map.
        assert_eq!(
            model
                .get_param(param_keys::MATURITY_THRESHOLD)
                .and_then(|v| v.as_i64()),
            Some(defaults.maturity_threshold as i64)
        );
        assert_eq!(
            model
                .get_param(param_keys::ERROR_RATE_THRESHOLD)
                .and_then(|v| v.as_f64()),
            Some(defaults.error_rate_threshold)
        );
        let initial = model.get_param(param_keys::INITIAL_STRATEGY);
        assert_eq!(
            initial.as_ref().and_then(|v| v.as_str()),
            Some(defaults.initial_strategy.as_str())
        );
        assert_eq!(model.all_params().len(), 6);
    }
}