use aip::edm::core::{EmotionDataModel, EmotionDataModelTrainer, EmotionState, TrainingDataset, TrainingSample};
use aip::edm::roguelite::core::RogueliteEdm;
use aip::edm::roguelite::training::RogueliteEdmTrainer;
use aip::director::core::{InteractionStrategy, InteractionStrategyTrainer, InteractionState, EmotionStats, Trajectory, TrajectoryStep, InteractionParams};
use aip::director::roguelite::core::RogueliteDirector;
use aip::director::roguelite::training::RogueliteDirectorTrainer;
use candle_core::Device;
use std::collections::HashMap;
/// Returns the `models` directory under the current working directory,
/// creating it (and any missing parents) when it does not exist yet.
///
/// # Panics
///
/// Panics if the current working directory cannot be determined or the
/// directory cannot be created (e.g. a regular file named `models` exists).
fn get_model_dir() -> std::path::PathBuf {
    let path = std::env::current_dir()
        .expect("failed to read current working directory")
        .join("models");
    // `create_dir_all` returns Ok when the directory already exists, so no
    // separate `exists()` pre-check is needed.
    std::fs::create_dir_all(&path)
        .unwrap_or_else(|e| panic!("failed to create model directory {:?}: {}", path, e));
    path
}
/// Builds `count` synthetic EDM training samples.
///
/// Each sample carries 15 features keyed `0..15`. Most of them grow linearly
/// with the sample's relative position `idx / count`. Features 2, 4 and 8
/// instead cycle with periods 5, 3 and 4. The target emotion is a fixed
/// linear combination of features 0, 1, 7 and 12.
fn generate_mock_edm_samples(count: usize) -> Vec<TrainingSample> {
    (0..count)
        .map(|idx| {
            let t = idx as f32 / count as f32;
            let values: [f32; 15] = [
                0.3 + t * 0.4,
                0.2 + t * 0.5,
                0.1 + (idx % 5) as f32 * 0.1,
                0.2 + t * 0.3,
                0.1 + (idx % 3) as f32 * 0.1,
                0.05 + t * 0.2,
                0.1 + t * 0.3,
                0.5 + t * 0.3,
                0.05 + (idx % 4) as f32 * 0.05,
                0.6 + t * 0.3,
                0.3 + t * 0.4,
                0.1 + t * 0.2,
                0.7 + t * 0.2,
                0.2 + t * 0.3,
                0.05 + t * 0.15,
            ];
            let valence = 0.3 + values[0] * 0.4 + values[7] * 0.2;
            let arousal = 0.2 + values[1] * 0.5 + values[0] * 0.3;
            let dominance = 0.3 + values[12] * 0.3 + values[7] * 0.2;
            let features: HashMap<u32, f32> = (0u32..).zip(values.iter().copied()).collect();
            TrainingSample {
                features,
                emotion: EmotionState::new(valence, arousal, dominance),
            }
        })
        .collect()
}
fn generate_mock_trajectories(count: usize, steps_per_trajectory: usize) -> Vec<Trajectory> {
let mut trajectories = Vec::with_capacity(count);
for t in 0..count {
let mut steps = Vec::with_capacity(steps_per_trajectory);
for s in 0..steps_per_trajectory {
let mut user_traits: HashMap<u32, f32> = HashMap::new();
for i in 0u32..8 {
user_traits.insert(i, 0.3 + ((t + s + i as usize) as f32 % 100.0) / 100.0 * 0.4);
}
let mut env_state: HashMap<u32, f32> = HashMap::new();
for i in 0u32..6 {
env_state.insert(i, 0.2 + ((t * s + i as usize) as f32 % 100.0) / 100.0 * 0.6);
}
let emotion = EmotionState::new(
0.3 + ((t + s) as f32 % 100.0) / 100.0 * 0.4,
0.2 + ((t * 2 + s) as f32 % 100.0) / 100.0 * 0.5,
0.3 + ((t + s * 2) as f32 % 100.0) / 100.0 * 0.4,
);
let progress = (s as f32 + 1.0) / steps_per_trajectory as f32;
let emotion_improvement = if s > 0 { 0.1 } else { 0.0 };
let retention = s < steps_per_trajectory - 1;
let reward = aip::director::roguelite::training::RogueliteDirectorTrainer::compute_reward(
progress,
emotion_improvement,
retention,
);
let action = InteractionParams::default();
steps.push(TrajectoryStep {
state: InteractionState {
user_traits,
env_state,
emotion,
emotion_stats: EmotionStats::default(),
},
action,
reward,
});
}
trajectories.push(Trajectory { steps });
}
trajectories
}
/// End-to-end check of the EDM workflow.
///
/// Steps: generate 1000 synthetic samples, train for 20 epochs, save to
/// `models/test_edm_model.safetensors`, reload into a fresh `RogueliteEdm`,
/// and run inference on a few samples.
///
/// Fails if any epoch loss or any predicted component is non-finite
/// (NaN/inf), which would otherwise pass silently because values are only
/// printed. The saved model file is not removed by this test.
#[test]
fn test_edm_full_pipeline() {
    let device = Device::Cpu;
    println!("\n=== EDM Full Pipeline Test ===\n");
    println!("1. Generating mock training data...");
    let samples = generate_mock_edm_samples(1000);
    let dataset = TrainingDataset::new(samples);
    println!(" Generated {} training samples (simulating ~100 players, 10 sessions each)", dataset.samples().len());
    println!("\n2. Creating EDM trainer...");
    let mut trainer = RogueliteEdmTrainer::new(device.clone()).unwrap();
    println!("\n3. Training model (20 epochs)...");
    let mut best_loss = f32::MAX;
    for epoch in 0..20 {
        let loss = trainer.train_epoch(&dataset).unwrap();
        // A diverged run would otherwise only be visible in the log output.
        assert!(loss.is_finite(), "EDM loss is not finite at epoch {}: {}", epoch, loss);
        if loss < best_loss {
            best_loss = loss;
        }
        if epoch % 5 == 0 {
            println!(" Epoch {}: loss = {:.4}", epoch, loss);
        }
    }
    println!(" Best loss: {:.4}", best_loss);
    println!("\n4. Saving model...");
    let model_dir = get_model_dir();
    let model_path = model_dir.join("test_edm_model.safetensors");
    trainer.save(&model_path).unwrap();
    assert!(model_path.is_file(), "model file was not written to {:?}", model_path);
    println!(" Model saved to {:?}", model_path);
    println!("\n5. Loading model into new instance...");
    let mut loaded_model = RogueliteEdm::new(device).unwrap();
    EmotionDataModel::load(&mut loaded_model, &model_path).unwrap();
    println!(" Model loaded successfully");
    println!("\n6. Testing inference...");
    let test_features = generate_mock_edm_samples(5);
    for (i, sample) in test_features.iter().enumerate() {
        let predicted = loaded_model.infer(&sample.features).unwrap();
        let target = &sample.emotion;
        println!(" Sample {}: predicted=({:.3}, {:.3}, {:.3}), target=({:.3}, {:.3}, {:.3})",
            i,
            predicted.valence, predicted.arousal, predicted.dominance,
            target.valence, target.arousal, target.dominance
        );
        assert!(
            predicted.valence.is_finite()
                && predicted.arousal.is_finite()
                && predicted.dominance.is_finite(),
            "non-finite EDM prediction for sample {}",
            i
        );
    }
    println!("\n=== EDM Pipeline Test Complete ===\n");
}
/// End-to-end check of the director workflow.
///
/// Steps: generate 100 synthetic trajectories of 50 steps, train for 20
/// epochs, save to `models/test_director_model.safetensors`, reload into a
/// fresh `RogueliteDirector`, and make one decision on a neutral state.
///
/// Fails if any epoch loss or any decision parameter is non-finite. Also fails
/// if `intensity_factor`/`pace_speed` leave the ranges asserted by
/// `test_model_persistence_across_restarts`. The saved model file is not
/// removed by this test.
#[test]
fn test_director_full_pipeline() {
    let device = Device::Cpu;
    println!("\n=== Director Full Pipeline Test ===\n");
    println!("1. Generating mock trajectories...");
    let trajectories = generate_mock_trajectories(100, 50);
    println!(" Generated {} trajectories with {} steps each (simulating 100 players, 50 rooms per run)",
        trajectories.len(), trajectories.first().map(|t| t.steps.len()).unwrap_or(0));
    println!("\n2. Creating Director trainer...");
    let mut trainer = RogueliteDirectorTrainer::new(device.clone()).unwrap();
    println!("\n3. Training model (20 epochs)...");
    let mut best_loss = f32::MAX;
    for epoch in 0..20 {
        let loss = trainer.train_epoch(&trajectories).unwrap();
        // A diverged run would otherwise only be visible in the log output.
        assert!(loss.is_finite(), "Director loss is not finite at epoch {}: {}", epoch, loss);
        if loss < best_loss {
            best_loss = loss;
        }
        if epoch % 5 == 0 {
            println!(" Epoch {}: loss = {:.4}", epoch, loss);
        }
    }
    println!(" Best loss: {:.4}", best_loss);
    println!("\n4. Saving model...");
    let model_dir = get_model_dir();
    let model_path = model_dir.join("test_director_model.safetensors");
    trainer.save(&model_path).unwrap();
    assert!(model_path.is_file(), "model file was not written to {:?}", model_path);
    println!(" Model saved to {:?}", model_path);
    println!("\n5. Loading model into new instance...");
    let mut loaded_model = RogueliteDirector::new(device).unwrap();
    InteractionStrategy::load(&mut loaded_model, &model_path).unwrap();
    println!(" Model loaded successfully");
    println!("\n6. Testing decision making...");
    let test_state = InteractionState {
        user_traits: HashMap::from([
            (0, 0.5), (1, 0.6), (2, 0.4), (3, 0.5),
            (4, 0.5), (5, 0.5), (6, 0.5), (7, 0.5),
        ]),
        env_state: HashMap::from([
            (0, 0.5), (1, 0.5), (2, 0.5),
            (3, 0.5), (4, 0.5), (5, 0.5),
        ]),
        emotion: EmotionState::new(0.5, 0.5, 0.5),
        emotion_stats: EmotionStats::default(),
    };
    let params = loaded_model.decide(&test_state).unwrap();
    println!(" Decision parameters:");
    println!(" intensity_factor: {:.3}", params.intensity_factor);
    println!(" feedback_intensity: {:.3}", params.feedback_intensity);
    println!(" pace_speed: {:.3}", params.pace_speed);
    println!(" reward_scarcity: {:.3}", params.reward_scarcity);
    println!(" env_arousal: {:.3}", params.env_arousal);
    println!(" rhythm_modulation: {:.3}", params.rhythm_modulation);
    println!(" challenge_curve: {:.3}", params.challenge_curve);
    // Same bounds as the persistence test; `contains` is also false for NaN.
    assert!((0.5..=2.0).contains(&params.intensity_factor),
        "intensity_factor out of range: {}", params.intensity_factor);
    assert!((0.6..=1.8).contains(&params.pace_speed),
        "pace_speed out of range: {}", params.pace_speed);
    assert!(params.feedback_intensity.is_finite(), "feedback_intensity is not finite");
    assert!(params.reward_scarcity.is_finite(), "reward_scarcity is not finite");
    assert!(params.env_arousal.is_finite(), "env_arousal is not finite");
    assert!(params.rhythm_modulation.is_finite(), "rhythm_modulation is not finite");
    assert!(params.challenge_curve.is_finite(), "challenge_curve is not finite");
    println!("\n=== Director Pipeline Test Complete ===\n");
}
/// Trains an EDM and a director from scratch and runs them together in a
/// simulated game loop.
///
/// Each iteration infers an emotion from the current features and asks the
/// director for parameters. It then feeds `intensity_factor / 2` and
/// `pace_speed / 2` back into features 0 and 1.
///
/// Fails if training loss or inferred emotions are non-finite. Also fails if
/// `intensity_factor`/`pace_speed` leave the ranges asserted by
/// `test_model_persistence_across_restarts`.
#[test]
fn test_end_to_end_pipeline() {
    let device = Device::Cpu;
    println!("\n=== End-to-End Pipeline Test ===\n");
    println!("1. Training EDM...");
    let samples = generate_mock_edm_samples(500);
    let dataset = TrainingDataset::new(samples);
    let mut edm_trainer = RogueliteEdmTrainer::new(device.clone()).unwrap();
    for epoch in 0..10 {
        let loss = edm_trainer.train_epoch(&dataset).unwrap();
        assert!(loss.is_finite(), "EDM loss is not finite at epoch {}: {}", epoch, loss);
    }
    let edm = edm_trainer.to_model();
    println!(" EDM trained with 500 samples");
    println!("\n2. Training Director...");
    let trajectories = generate_mock_trajectories(50, 30);
    // Last use of `device`, so it can be moved instead of cloned.
    let mut director_trainer = RogueliteDirectorTrainer::new(device).unwrap();
    for epoch in 0..10 {
        let loss = director_trainer.train_epoch(&trajectories).unwrap();
        assert!(loss.is_finite(), "Director loss is not finite at epoch {}: {}", epoch, loss);
    }
    let director = director_trainer.to_model();
    println!(" Director trained with 50 trajectories x 30 steps");
    println!("\n3. Simulating game loop (10 iterations)...");
    let mut features = HashMap::new();
    for i in 0..15 {
        features.insert(i as u32, 0.5);
    }
    for i in 0..10 {
        let emotion = edm.infer(&features).unwrap();
        assert!(
            emotion.valence.is_finite() && emotion.arousal.is_finite() && emotion.dominance.is_finite(),
            "iteration {}: EDM produced a non-finite emotion",
            i
        );
        let state = InteractionState {
            user_traits: HashMap::from([
                (0, 0.5), (1, 0.5), (2, 0.5), (3, 0.5),
                (4, 0.5), (5, 0.5), (6, 0.5), (7, 0.5),
            ]),
            env_state: HashMap::from([
                (0, 0.5), (1, 0.5), (2, 0.5),
                (3, 0.5), (4, 0.5), (5, 0.5),
            ]),
            emotion,
            emotion_stats: EmotionStats::default(),
        };
        let params = director.decide(&state).unwrap();
        // Same bounds as the persistence test; `contains` is also false for NaN.
        assert!((0.5..=2.0).contains(&params.intensity_factor),
            "iteration {}: intensity_factor out of range: {}", i, params.intensity_factor);
        assert!((0.6..=1.8).contains(&params.pace_speed),
            "iteration {}: pace_speed out of range: {}", i, params.pace_speed);
        if i % 2 == 0 {
            println!(" Iteration {}: emotion=({:.2}, {:.2}, {:.2}), intensity={:.2}, pace={:.2}",
                i,
                emotion.valence, emotion.arousal, emotion.dominance,
                params.intensity_factor, params.pace_speed
            );
        }
        // Feed the director's output back into the next EDM input.
        features.insert(0, params.intensity_factor / 2.0);
        features.insert(1, params.pace_speed / 2.0);
    }
    println!("\n=== End-to-End Pipeline Test Complete ===\n");
}
/// Checks that an EDM written with `save` and read back with
/// `EmotionDataModel::load` reproduces the original model's predictions.
///
/// Each of valence, arousal and dominance must agree within 1e-5 on three
/// synthetic samples. The temporary model file is removed afterwards.
#[test]
fn test_edm_model_save_load() {
    let device = Device::Cpu;
    let model_path = get_model_dir().join("test_edm_save_load.safetensors");
    println!("\n=== EDM Model Save/Load Test ===\n");

    println!("1. Creating and training original model...");
    let mut trainer = RogueliteEdmTrainer::new(device.clone()).unwrap();
    let dataset = TrainingDataset::new(generate_mock_edm_samples(500));
    for _epoch in 0..10 {
        trainer.train_epoch(&dataset).unwrap();
    }
    let original_model = trainer.to_model();
    println!(" Original model trained with 500 samples");

    println!("\n2. Testing original model inference...");
    let probes = generate_mock_edm_samples(3);
    let mut reference: Vec<EmotionState> = Vec::with_capacity(probes.len());
    for probe in &probes {
        reference.push(original_model.infer(&probe.features).unwrap());
    }
    for (idx, emo) in reference.iter().enumerate() {
        println!(" Sample {}: ({:.3}, {:.3}, {:.3})", idx, emo.valence, emo.arousal, emo.dominance);
    }

    println!("\n3. Saving model to disk...");
    original_model.save(&model_path).unwrap();
    let file_len = std::fs::metadata(&model_path).unwrap().len();
    println!(" Model saved, file size: {} bytes", file_len);

    println!("\n4. Loading model into new instance...");
    let mut reloaded = RogueliteEdm::new(device).unwrap();
    EmotionDataModel::load(&mut reloaded, &model_path).unwrap();
    println!(" Model loaded successfully");

    println!("\n5. Verifying loaded model produces same results...");
    for (idx, (probe, expected)) in probes.iter().zip(&reference).enumerate() {
        let actual = reloaded.infer(&probe.features).unwrap();
        let dv = (actual.valence - expected.valence).abs();
        let da = (actual.arousal - expected.arousal).abs();
        let dd = (actual.dominance - expected.dominance).abs();
        println!(" Sample {}: diffs=({:.6}, {:.6}, {:.6})", idx, dv, da, dd);
        assert!(dv < 1e-5, "Valence mismatch");
        assert!(da < 1e-5, "Arousal mismatch");
        assert!(dd < 1e-5, "Dominance mismatch");
    }

    let _ = std::fs::remove_file(&model_path);
    println!("\n=== EDM Model Save/Load Test Complete ===\n");
}
/// Checks that a director written with `save` and read back with
/// `InteractionStrategy::load` decides exactly like the original.
///
/// Uses one fixed state; every output parameter must agree within 1e-5. The
/// temporary model file is removed afterwards.
#[test]
fn test_director_model_save_load() {
    let device = Device::Cpu;
    let model_path = get_model_dir().join("test_director_save_load.safetensors");
    println!("\n=== Director Model Save/Load Test ===\n");

    println!("1. Creating and training original model...");
    let mut trainer = RogueliteDirectorTrainer::new(device.clone()).unwrap();
    let trajectories = generate_mock_trajectories(50, 30);
    for _epoch in 0..10 {
        trainer.train_epoch(&trajectories).unwrap();
    }
    let original_model = trainer.to_model();
    println!(" Original model trained with 50 trajectories x 30 steps");

    println!("\n2. Testing original model decision...");
    let user_traits = HashMap::from([
        (0, 0.6), (1, 0.7), (2, 0.4), (3, 0.5),
        (4, 0.5), (5, 0.5), (6, 0.5), (7, 0.5),
    ]);
    let env_state = HashMap::from([
        (0, 0.5), (1, 0.6), (2, 0.4),
        (3, 0.5), (4, 0.5), (5, 0.5),
    ]);
    let test_state = InteractionState {
        user_traits,
        env_state,
        emotion: EmotionState::new(0.6, 0.7, 0.5),
        emotion_stats: EmotionStats::default(),
    };
    let reference = original_model.decide(&test_state).unwrap();
    println!(" Original decision: intensity={:.3}, pace={:.3}",
        reference.intensity_factor, reference.pace_speed);

    println!("\n3. Saving model to disk...");
    original_model.save(&model_path).unwrap();
    let file_len = std::fs::metadata(&model_path).unwrap().len();
    println!(" Model saved, file size: {} bytes", file_len);

    println!("\n4. Loading model into new instance...");
    let mut reloaded_model = RogueliteDirector::new(device).unwrap();
    InteractionStrategy::load(&mut reloaded_model, &model_path).unwrap();
    println!(" Model loaded successfully");

    println!("\n5. Verifying loaded model produces same results...");
    let reloaded = reloaded_model.decide(&test_state).unwrap();
    // (field name, assertion message, absolute difference)
    let table = [
        ("intensity_factor", "Intensity mismatch",
            (reloaded.intensity_factor - reference.intensity_factor).abs()),
        ("feedback_intensity", "Feedback mismatch",
            (reloaded.feedback_intensity - reference.feedback_intensity).abs()),
        ("pace_speed", "Pace mismatch",
            (reloaded.pace_speed - reference.pace_speed).abs()),
        ("reward_scarcity", "Reward mismatch",
            (reloaded.reward_scarcity - reference.reward_scarcity).abs()),
        ("env_arousal", "Arousal mismatch",
            (reloaded.env_arousal - reference.env_arousal).abs()),
        ("rhythm_modulation", "Rhythm mismatch",
            (reloaded.rhythm_modulation - reference.rhythm_modulation).abs()),
        ("challenge_curve", "Challenge mismatch",
            (reloaded.challenge_curve - reference.challenge_curve).abs()),
    ];
    println!(" Parameter differences:");
    for (field, _, delta) in &table {
        println!(" {}: {:.6}", field, delta);
    }
    for (_, message, delta) in &table {
        assert!(*delta < 1e-5, "{}", message);
    }

    let _ = std::fs::remove_file(&model_path);
    println!("\n=== Director Model Save/Load Test Complete ===\n");
}
/// Simulates an application restart with both models.
///
/// Trains an EDM and a director, saves both, and loads them into fresh
/// instances. It then checks two things:
/// - the reloaded EDM → director chain reproduces the original chain's
///   outputs within 1e-5;
/// - the director's `intensity_factor`/`pace_speed` stay inside their
///   expected ranges.
///
/// The temporary model files are removed afterwards.
#[test]
fn test_model_persistence_across_restarts() {
    let device = Device::Cpu;
    let model_dir = get_model_dir();
    let edm_path = model_dir.join("test_persistence_edm.safetensors");
    let director_path = model_dir.join("test_persistence_director.safetensors");

    // Neutral user/environment context; only the emotion differs per sample.
    let make_state = |emotion: EmotionState| InteractionState {
        user_traits: HashMap::from([
            (0, 0.5), (1, 0.5), (2, 0.5), (3, 0.5),
            (4, 0.5), (5, 0.5), (6, 0.5), (7, 0.5),
        ]),
        env_state: HashMap::from([
            (0, 0.5), (1, 0.5), (2, 0.5),
            (3, 0.5), (4, 0.5), (5, 0.5),
        ]),
        emotion,
        emotion_stats: EmotionStats::default(),
    };
    let test_features = generate_mock_edm_samples(5);

    println!("\n=== Model Persistence Test ===\n");
    println!("Phase 1: Train and save models...");
    let mut edm_trainer = RogueliteEdmTrainer::new(device.clone()).unwrap();
    let samples = generate_mock_edm_samples(300);
    let dataset = TrainingDataset::new(samples);
    for _ in 0..5 {
        edm_trainer.train_epoch(&dataset).unwrap();
    }
    let edm = edm_trainer.to_model();

    let mut director_trainer = RogueliteDirectorTrainer::new(device.clone()).unwrap();
    let trajectories = generate_mock_trajectories(30, 20);
    for _ in 0..5 {
        director_trainer.train_epoch(&trajectories).unwrap();
    }
    let director = director_trainer.to_model();

    // Record the original chain's outputs before saving. The comparison in
    // phase 3 then does not rely on the in-memory models being usable after
    // `save`.
    let reference: Vec<_> = test_features
        .iter()
        .map(|sample| {
            let emotion = edm.infer(&sample.features).unwrap();
            let params = director.decide(&make_state(emotion)).unwrap();
            (emotion, params)
        })
        .collect();

    edm.save(&edm_path).unwrap();
    println!(" EDM saved (300 samples)");
    director.save(&director_path).unwrap();
    println!(" Director saved (30 trajectories x 20 steps)");

    println!("\nPhase 2: Simulate restart - load models into fresh instances...");
    let mut edm_loaded = RogueliteEdm::new(device.clone()).unwrap();
    EmotionDataModel::load(&mut edm_loaded, &edm_path).unwrap();
    println!(" EDM loaded");
    let mut director_loaded = RogueliteDirector::new(device).unwrap();
    InteractionStrategy::load(&mut director_loaded, &director_path).unwrap();
    println!(" Director loaded");

    println!("\nPhase 3: Verify inference consistency...");
    for (i, (sample, (ref_emotion, ref_params))) in
        test_features.iter().zip(&reference).enumerate()
    {
        let emotion = edm_loaded.infer(&sample.features).unwrap();
        let params = director_loaded.decide(&make_state(emotion)).unwrap();

        assert!((emotion.valence - ref_emotion.valence).abs() < 1e-5,
            "sample {}: valence changed after reload", i);
        assert!((emotion.arousal - ref_emotion.arousal).abs() < 1e-5,
            "sample {}: arousal changed after reload", i);
        assert!((emotion.dominance - ref_emotion.dominance).abs() < 1e-5,
            "sample {}: dominance changed after reload", i);

        assert!((params.intensity_factor - ref_params.intensity_factor).abs() < 1e-5,
            "sample {}: intensity_factor changed after reload", i);
        assert!((params.feedback_intensity - ref_params.feedback_intensity).abs() < 1e-5,
            "sample {}: feedback_intensity changed after reload", i);
        assert!((params.pace_speed - ref_params.pace_speed).abs() < 1e-5,
            "sample {}: pace_speed changed after reload", i);
        assert!((params.reward_scarcity - ref_params.reward_scarcity).abs() < 1e-5,
            "sample {}: reward_scarcity changed after reload", i);
        assert!((params.env_arousal - ref_params.env_arousal).abs() < 1e-5,
            "sample {}: env_arousal changed after reload", i);
        assert!((params.rhythm_modulation - ref_params.rhythm_modulation).abs() < 1e-5,
            "sample {}: rhythm_modulation changed after reload", i);
        assert!((params.challenge_curve - ref_params.challenge_curve).abs() < 1e-5,
            "sample {}: challenge_curve changed after reload", i);

        assert!(params.intensity_factor >= 0.5 && params.intensity_factor <= 2.0,
            "sample {}: intensity_factor out of range: {}", i, params.intensity_factor);
        assert!(params.pace_speed >= 0.6 && params.pace_speed <= 1.8,
            "sample {}: pace_speed out of range: {}", i, params.pace_speed);
        println!(" Sample {}: emotion=({:.2},{:.2},{:.2}), params=({:.2},{:.2})",
            i, emotion.valence, emotion.arousal, emotion.dominance,
            params.intensity_factor, params.pace_speed);
    }
    let _ = std::fs::remove_file(&edm_path);
    let _ = std::fs::remove_file(&director_path);
    println!("\n=== Model Persistence Test Complete ===\n");
}