use aip::edm::core::{EmotionDataModel, EmotionState};
use aip::edm::roguelite::core::RogueliteEdm;
use aip::director::core::{InteractionStrategy, InteractionState, EmotionStats};
use aip::director::roguelite::core::RogueliteDirector;
use candle_core::Device;
use std::collections::HashMap;
/// Builds the fixed feature map shared by the tests in this file.
///
/// Keys are the consecutive indices `0..=14`; each value is the constant
/// found at the same position in `VALUES`.
fn create_test_features() -> HashMap<u32, f32> {
    // The position of a value in this table is its feature index.
    const VALUES: [f32; 15] = [
        0.6, 0.7, 0.3, 0.4, 0.2, 0.1, 0.3, 0.8, 0.1, 0.9, 0.5, 0.2, 0.85, 0.4, 0.15,
    ];

    (0u32..).zip(VALUES.iter().copied()).collect()
}
/// Wraps `emotion` in an `InteractionState` filled with placeholder context.
///
/// `user_traits` receives keys `0..8` and `env_state` receives keys `0..6`,
/// every one mapped to `0.5`; `emotion_stats` is `EmotionStats::default()`.
fn create_test_state(emotion: EmotionState) -> InteractionState {
    let user_traits: HashMap<_, _> = (0..8).map(|trait_id| (trait_id, 0.5)).collect();
    let env_state: HashMap<_, _> = (0..6).map(|slot| (slot, 0.5)).collect();

    InteractionState {
        user_traits,
        env_state,
        emotion,
        emotion_stats: EmotionStats::default(),
    }
}
/// Smoke test for the whole pipeline: the shared test features go through
/// `RogueliteEdm::infer`, the resulting emotion is wrapped by
/// `create_test_state`, and `RogueliteDirector::decide` produces parameters.
///
/// Only range membership is checked; no exact output value is pinned.
#[test]
fn test_edm_director_pipeline() {
    /// Panics, naming the field and the observed value, unless
    /// `lo <= value <= hi`. A NaN never satisfies the check.
    fn assert_within<T>(field: &str, value: T, lo: T, hi: T)
    where
        T: PartialOrd + std::fmt::Debug,
    {
        let expected = lo..=hi;
        assert!(
            expected.contains(&value),
            "{} = {:?} is outside the expected range {:?}",
            field,
            value,
            expected
        );
    }

    let device = Device::Cpu;
    let edm = RogueliteEdm::new(device.clone()).unwrap();
    let director = RogueliteDirector::new(device).unwrap();

    let features = create_test_features();
    let emotion = edm.infer(&features).unwrap();

    // Check the emotion before it is moved into the interaction state.
    assert_within("emotion.valence", emotion.valence, 0.0, 1.0);
    assert_within("emotion.arousal", emotion.arousal, 0.0, 1.0);
    assert_within("emotion.dominance", emotion.dominance, 0.0, 1.0);

    let state = create_test_state(emotion);
    let params = director.decide(&state).unwrap();

    assert_within("params.intensity_factor", params.intensity_factor, 0.5, 2.0);
    assert_within("params.feedback_intensity", params.feedback_intensity, 0.3, 1.5);
    assert_within("params.pace_speed", params.pace_speed, 0.6, 1.8);
}
/// Runs the EDM -> Director pipeline twice, once with features 0 and 1 raised
/// (0.9 / 0.95) and once with both lowered (0.1 / 0.1), and checks that every
/// output stays inside its asserted range in both scenarios.
///
/// NOTE(review): despite the name, nothing here compares the two scenarios
/// with each other; only per-scenario range membership is verified.
#[test]
fn test_emotion_state_propagation() {
    /// Panics, naming the scenario, the field and the observed value, unless
    /// `lo <= value <= hi`. A NaN never satisfies the check.
    fn assert_within<T>(scenario: &str, field: &str, value: T, lo: T, hi: T)
    where
        T: PartialOrd + std::fmt::Debug,
    {
        let expected = lo..=hi;
        assert!(
            expected.contains(&value),
            "[{}] {} = {:?} is outside the expected range {:?}",
            scenario,
            field,
            value,
            expected
        );
    }

    let device = Device::Cpu;
    let edm = RogueliteEdm::new(device.clone()).unwrap();
    let director = RogueliteDirector::new(device).unwrap();

    // (scenario label, override for feature 0, override for feature 1).
    // The labels follow the naming of the original test; presumably features
    // 0 and 1 drive arousal -- confirm against the EDM feature layout.
    let scenarios: [(&str, f32, f32); 2] = [
        ("high_arousal", 0.9, 0.95),
        ("low_arousal", 0.1, 0.1),
    ];

    for &(scenario, feature_0, feature_1) in scenarios.iter() {
        let mut features = create_test_features();
        features.insert(0, feature_0);
        features.insert(1, feature_1);

        let emotion = edm.infer(&features).unwrap();

        // Check the emotion before it is moved into the interaction state.
        assert_within(scenario, "emotion.valence", emotion.valence, 0.0, 1.0);
        assert_within(scenario, "emotion.arousal", emotion.arousal, 0.0, 1.0);
        assert_within(scenario, "emotion.dominance", emotion.dominance, 0.0, 1.0);

        let state = create_test_state(emotion);
        let params = director.decide(&state).unwrap();

        assert_within(scenario, "params.intensity_factor", params.intensity_factor, 0.5, 2.0);
        assert_within(scenario, "params.feedback_intensity", params.feedback_intensity, 0.3, 1.5);
        assert_within(scenario, "params.pace_speed", params.pace_speed, 0.6, 1.8);
    }
}
/// Times 100 back-to-back `infer` + `decide` round trips on the CPU device
/// and asserts that the mean wall-clock time per round trip is below 10 ms.
///
/// NOTE(review): this is a wall-clock assertion, so the outcome depends on
/// the build profile and on machine load.
#[test]
fn test_inference_latency() {
    use std::time::Instant;

    let device = Device::Cpu;
    let edm = RogueliteEdm::new(device.clone()).unwrap();
    let director = RogueliteDirector::new(device).unwrap();

    // Inputs are prepared once, outside the timed region.
    let features = create_test_features();
    let state = create_test_state(edm.infer(&features).unwrap());

    let iterations = 100;

    // Each round trip is timed on its own and the per-iteration
    // millisecond figures are summed.
    let total_time_ms: f64 = (0..iterations)
        .map(|_| {
            let timer = Instant::now();
            let _ = edm.infer(&features).unwrap();
            let _ = director.decide(&state).unwrap();
            timer.elapsed().as_secs_f64() * 1000.0
        })
        .sum();

    let avg_latency = total_time_ms / iterations as f64;
    assert!(
        avg_latency < 10.0,
        "Average latency {}ms exceeds 10ms target",
        avg_latency
    );
}