use super::learned_component::{LearnedDepGraph, LearnedExploration, LearnedStrategy};
use super::offline::{
ActionOrderSource, LearnedActionOrder, OfflineModel, OptimalParameters, StrategyConfig,
};
use super::scenario_profile::ScenarioProfile;
/// Converts a [`ScenarioProfile`] into the legacy [`OfflineModel`] representation.
///
/// Exploration parameters, strategy configuration, recommended paths, and the
/// learned action order (when present) are copied across; any component the
/// profile lacks keeps the `OfflineModel` default.
pub fn profile_to_offline_model(profile: &ScenarioProfile) -> OfflineModel {
    let mut model = OfflineModel::default();

    if let Some(exp) = profile.exploration.as_ref() {
        model.parameters = OptimalParameters {
            ucb1_c: exp.ucb1_c,
            learning_weight: exp.learning_weight,
            ngram_weight: exp.ngram_weight,
        };
    }

    if let Some(st) = profile.strategy.as_ref() {
        model.strategy_config = StrategyConfig {
            initial_strategy: st.initial_strategy.clone(),
            maturity_threshold: st.maturity_threshold as u32,
            error_rate_threshold: st.error_rate_threshold,
        };
    }

    if let Some(dg) = profile.dep_graph.as_ref() {
        // New-format profiles split the ordering into discover / not-discover
        // lists; legacy ones carry everything in a single `action_order`.
        let new_format = !(dg.discover_order.is_empty() && dg.not_discover_order.is_empty());
        if new_format || !dg.action_order.is_empty() {
            let (discover, not_discover) = match new_format {
                true => (dg.discover_order.clone(), dg.not_discover_order.clone()),
                false => (dg.action_order.clone(), Vec::new()),
            };
            // The hash covers the combined action set, so a change in either
            // list invalidates the stored order.
            let mut combined = discover.clone();
            combined.extend_from_slice(&not_discover);
            model.action_order = Some(LearnedActionOrder {
                action_set_hash: LearnedActionOrder::compute_hash(&combined),
                discover,
                not_discover,
                source: ActionOrderSource::Manual,
                lora: None,
                validated_accuracy: None,
            });
        }
        model.recommended_paths = dg.recommended_paths.clone();
    }

    model.updated_at = profile.updated_at;
    // Session count is derived from the dep-graph provenance list, if any.
    model.analyzed_sessions = profile
        .dep_graph
        .as_ref()
        .map_or(0, |d| d.learned_from.len());
    model
}
/// Decomposes an [`OfflineModel`] into the per-component learned pieces used
/// by profile-based storage.
///
/// Returns `(dep_graph, exploration, strategy)`. Exploration and strategy are
/// always produced (the model carries values for them even when defaulted);
/// the dependency graph is only present when the model has a learned action
/// order.
pub fn offline_model_to_components(
    model: &OfflineModel,
) -> (
    Option<LearnedDepGraph>,
    Option<LearnedExploration>,
    Option<LearnedStrategy>,
) {
    // Offline models do not record per-component confidence, so every
    // migrated component receives this fixed, moderately-trusted value.
    let migrated_confidence = 0.8;

    let exploration = Some(LearnedExploration {
        ucb1_c: model.parameters.ucb1_c,
        learning_weight: model.parameters.learning_weight,
        ngram_weight: model.parameters.ngram_weight,
        confidence: migrated_confidence,
        session_count: model.analyzed_sessions,
        updated_at: model.updated_at,
    });

    let strategy = Some(LearnedStrategy {
        initial_strategy: model.strategy_config.initial_strategy.clone(),
        maturity_threshold: model.strategy_config.maturity_threshold as usize,
        error_rate_threshold: model.strategy_config.error_rate_threshold,
        confidence: migrated_confidence,
        session_count: model.analyzed_sessions,
        updated_at: model.updated_at,
    });

    let dep_graph = model.action_order.as_ref().map(|order| {
        use crate::exploration::DependencyGraph;
        // The flat `action_order` is the discover list followed by the
        // not-discover list, matching the hash computed on migration.
        let mut all_actions = order.discover.clone();
        all_actions.extend(order.not_discover.clone());
        LearnedDepGraph {
            // The offline model does not persist graph edges or provenance,
            // so the graph starts empty and `learned_from` is unknown.
            graph: DependencyGraph::new(),
            action_order: all_actions,
            discover_order: order.discover.clone(),
            not_discover_order: order.not_discover.clone(),
            recommended_paths: model.recommended_paths.clone(),
            confidence: migrated_confidence,
            learned_from: Vec::new(),
            updated_at: model.updated_at,
        }
    });

    (dep_graph, exploration, strategy)
}
/// Builds a fresh [`ScenarioProfile`] from a legacy [`OfflineModel`].
///
/// The resulting profile carries the model's components, is marked active
/// immediately, and inherits the model's `updated_at` timestamp.
pub fn migrate_offline_model_to_profile(
    profile_id: impl Into<String>,
    scenario_path: impl Into<std::path::PathBuf>,
    model: &OfflineModel,
) -> ScenarioProfile {
    use super::scenario_profile::{ProfileState, ScenarioSource};

    let source = ScenarioSource::from_path(scenario_path.into());
    let mut profile = ScenarioProfile::new(profile_id, source);

    let (dep_graph, exploration, strategy) = offline_model_to_components(model);
    profile.dep_graph = dep_graph;
    profile.exploration = exploration;
    profile.strategy = strategy;

    profile.state = ProfileState::Active;
    profile.updated_at = model.updated_at;
    profile
}
/// Conversion of a profile into the legacy [`OfflineModel`] format.
pub trait ProfileToOfflineModel {
    /// Builds an [`OfflineModel`] snapshot of `self`.
    fn to_offline_model(&self) -> OfflineModel;
}
impl ProfileToOfflineModel for ScenarioProfile {
    // Delegates to the free function so method and function call styles
    // share a single implementation.
    fn to_offline_model(&self) -> OfflineModel {
        profile_to_offline_model(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::exploration::DependencyGraph;
    use crate::learn::scenario_profile::ScenarioSource;

    /// An empty profile must still produce a model with sane default
    /// exploration parameters (ucb1_c positive).
    #[test]
    fn test_profile_to_offline_model_empty() {
        let profile = ScenarioProfile::new("test", ScenarioSource::from_path("/test.toml"));
        let model = profile_to_offline_model(&profile);
        assert!(model.parameters.ucb1_c > 0.0);
    }

    /// A fully-populated profile round-trips exploration parameters,
    /// strategy config, and the dep-graph action order into the model.
    #[test]
    fn test_profile_to_offline_model_with_components() {
        let mut profile = ScenarioProfile::new("test", ScenarioSource::from_path("/test.toml"));
        profile.exploration = Some(LearnedExploration {
            ucb1_c: 2.5,
            learning_weight: 0.4,
            ngram_weight: 1.2,
            confidence: 0.9,
            session_count: 10,
            updated_at: 12345,
        });
        profile.strategy = Some(LearnedStrategy {
            initial_strategy: "greedy".to_string(),
            maturity_threshold: 10,
            error_rate_threshold: 0.3,
            confidence: 0.85,
            session_count: 10,
            updated_at: 12345,
        });
        profile.dep_graph = Some(
            LearnedDepGraph::new(
                DependencyGraph::new(),
                vec!["A".to_string(), "B".to_string()],
            )
            .with_confidence(0.95),
        );
        let model = profile_to_offline_model(&profile);
        assert_eq!(model.parameters.ucb1_c, 2.5);
        assert_eq!(model.parameters.learning_weight, 0.4);
        assert_eq!(model.strategy_config.initial_strategy, "greedy");
        assert_eq!(model.strategy_config.maturity_threshold, 10);
        // Legacy single-list `action_order` should land in `discover`.
        assert!(model.action_order.is_some());
        assert_eq!(model.action_order.as_ref().unwrap().discover.len(), 2);
    }

    /// Model fields map back into the three learned components, with the
    /// combined discover + not-discover lists forming `action_order`.
    #[test]
    fn test_offline_model_to_components() {
        use super::ActionOrderSource;
        let mut model = OfflineModel::default();
        model.parameters.ucb1_c = 1.8;
        model.strategy_config.initial_strategy = "ucb1".to_string();
        model.action_order = Some(LearnedActionOrder {
            discover: vec!["X".to_string(), "Y".to_string()],
            not_discover: vec![],
            action_set_hash: 12345,
            source: ActionOrderSource::Manual,
            lora: None,
            validated_accuracy: None,
        });
        model.analyzed_sessions = 5;
        let (dep_graph, exploration, strategy) = offline_model_to_components(&model);
        assert!(dep_graph.is_some());
        assert!(exploration.is_some());
        assert!(strategy.is_some());
        let exploration = exploration.unwrap();
        assert_eq!(exploration.ucb1_c, 1.8);
        let strategy = strategy.unwrap();
        assert_eq!(strategy.initial_strategy, "ucb1");
        let dep_graph = dep_graph.unwrap();
        assert_eq!(dep_graph.action_order.len(), 2);
    }

    /// Migration produces an active profile carrying the model's values.
    #[test]
    fn test_migrate_offline_model_to_profile() {
        let mut model = OfflineModel::default();
        model.parameters.ucb1_c = 2.0;
        model.analyzed_sessions = 10;
        let profile =
            migrate_offline_model_to_profile("test-profile", "/path/to/scenario.toml", &model);
        assert_eq!(profile.id.0, "test-profile");
        assert!(profile.exploration.is_some());
        assert_eq!(profile.exploration.as_ref().unwrap().ucb1_c, 2.0);
    }

    /// The trait method must behave like the free function.
    #[test]
    fn test_profile_to_offline_model_trait() {
        let profile = ScenarioProfile::new("test", ScenarioSource::from_path("/test.toml"));
        let model = profile.to_offline_model();
        assert!(model.parameters.ucb1_c > 0.0);
    }
}