use std::collections::HashMap;
use std::path::Path;
use thiserror::Error;
#[cfg(feature = "edm")]
use crate::edm::core::EmotionState;
/// The seven `f32` parameters returned by [`InteractionStrategy::decide`].
///
/// Every field has a default value (see the `Default` impl) and a permitted
/// range that [`InteractionParams::clamp`] enforces.
#[derive(Debug, Clone)]
pub struct InteractionParams {
    /// Default `1.0`; clamped to `[0.5, 2.0]`.
    pub intensity_factor: f32,
    /// Default `1.0`; clamped to `[0.3, 1.5]`.
    pub feedback_intensity: f32,
    /// Default `1.0`; clamped to `[0.6, 1.8]`.
    pub pace_speed: f32,
    /// Default `0.5`; clamped to `[0.0, 1.0]`.
    pub reward_scarcity: f32,
    /// Default `0.5`; clamped to `[0.3, 1.0]`.
    pub env_arousal: f32,
    /// Default `1.0`; clamped to `[0.8, 1.5]`.
    pub rhythm_modulation: f32,
    /// Default `0.0`; clamped to `[-1.0, 1.0]`.
    pub challenge_curve: f32,
}

impl Default for InteractionParams {
    /// Returns the baseline parameter set; every value lies inside the range
    /// enforced by [`InteractionParams::clamp`].
    fn default() -> Self {
        Self {
            intensity_factor: 1.0,
            feedback_intensity: 1.0,
            pace_speed: 1.0,
            reward_scarcity: 0.5,
            env_arousal: 0.5,
            rhythm_modulation: 1.0,
            challenge_curve: 0.0,
        }
    }
}

impl InteractionParams {
    /// Forces every field into its permitted range, in place.
    ///
    /// Finite and infinite values are limited with [`f32::clamp`]. A `NaN`
    /// field is replaced by that field's default value, because `f32::clamp`
    /// returns `NaN` for a `NaN` input and would otherwise leave the field
    /// outside its range after this call.
    pub fn clamp(&mut self) {
        // Limits `*field` to `[lo, hi]`, substituting `fallback` for NaN.
        fn bound(field: &mut f32, fallback: f32, lo: f32, hi: f32) {
            let value = *field;
            *field = if value.is_nan() {
                fallback
            } else {
                value.clamp(lo, hi)
            };
        }

        let base = Self::default();
        bound(&mut self.intensity_factor, base.intensity_factor, 0.5, 2.0);
        bound(&mut self.feedback_intensity, base.feedback_intensity, 0.3, 1.5);
        bound(&mut self.pace_speed, base.pace_speed, 0.6, 1.8);
        bound(&mut self.reward_scarcity, base.reward_scarcity, 0.0, 1.0);
        bound(&mut self.env_arousal, base.env_arousal, 0.3, 1.0);
        bound(&mut self.rhythm_modulation, base.rhythm_modulation, 0.8, 1.5);
        bound(&mut self.challenge_curve, base.challenge_curve, -1.0, 1.0);
    }

    /// Returns the seven fields as a freshly allocated vector, in declaration
    /// order: `intensity_factor`, `feedback_intensity`, `pace_speed`,
    /// `reward_scarcity`, `env_arousal`, `rhythm_modulation`,
    /// `challenge_curve`.
    pub fn to_vec(&self) -> Vec<f32> {
        vec![
            self.intensity_factor,
            self.feedback_intensity,
            self.pace_speed,
            self.reward_scarcity,
            self.env_arousal,
            self.rhythm_modulation,
            self.challenge_curve,
        ]
    }
}
/// Mean / std pairs for the valence, arousal and dominance dimensions.
///
/// NOTE(review): the `_std` suffix presumably means standard deviation, and
/// how these statistics are computed is not visible in this file — confirm
/// against the producer.
#[derive(Debug, Clone)]
pub struct EmotionStats {
    /// Mean of the valence dimension; defaults to `0.5`.
    pub valence_mean: f32,
    /// Spread of the valence dimension; defaults to `0.1`.
    pub valence_std: f32,
    /// Mean of the arousal dimension; defaults to `0.5`.
    pub arousal_mean: f32,
    /// Spread of the arousal dimension; defaults to `0.1`.
    pub arousal_std: f32,
    /// Mean of the dominance dimension; defaults to `0.5`.
    pub dominance_mean: f32,
    /// Spread of the dominance dimension; defaults to `0.1`.
    pub dominance_std: f32,
}

impl Default for EmotionStats {
    /// Returns statistics with every mean at `0.5` and every std at `0.1`.
    fn default() -> Self {
        // All three dimensions share the same baseline mean and spread.
        const BASE_MEAN: f32 = 0.5;
        const BASE_STD: f32 = 0.1;

        Self {
            valence_mean: BASE_MEAN,
            valence_std: BASE_STD,
            arousal_mean: BASE_MEAN,
            arousal_std: BASE_STD,
            dominance_mean: BASE_MEAN,
            dominance_std: BASE_STD,
        }
    }
}
/// Input handed to [`InteractionStrategy::decide`] and recorded in each
/// [`TrajectoryStep`].
#[derive(Debug, Clone)]
pub struct InteractionState {
    /// `u32`-keyed map of `f32` values describing the user.
    /// NOTE(review): the meaning of the keys is not defined in this file.
    pub user_traits: HashMap<u32, f32>,

    /// `u32`-keyed map of `f32` values describing the environment.
    /// NOTE(review): the meaning of the keys is not defined in this file.
    pub env_state: HashMap<u32, f32>,

    /// Emotion state from `crate::edm::core`; the field only exists when the
    /// `edm` feature is enabled.
    #[cfg(feature = "edm")]
    pub emotion: EmotionState,

    /// Summary statistics over the emotion dimensions.
    pub emotion_stats: EmotionStats,
}
#[derive(Error, Debug)]
pub enum DirectorError {
#[error("Invalid state: {0}")]
InvalidState(String),
#[error("Model error: {0}")]
ModelError(String),
#[error("IO error: {0}")]
IoError(#[from] std::io::Error),
}
/// A policy that maps an [`InteractionState`] to [`InteractionParams`].
///
/// Implementors must be `Send + Sync`.
pub trait InteractionStrategy: Send + Sync {
    /// Chooses the parameters to use for `state`.
    ///
    /// # Errors
    ///
    /// Returns a [`DirectorError`] when the implementation cannot produce a
    /// decision.
    fn decide(&self, state: &InteractionState) -> Result<InteractionParams, DirectorError>;

    /// Updates this strategy from the resource at `path`.
    ///
    /// NOTE(review): presumably this reads persisted strategy/model data;
    /// the expected file format is defined by each implementor — confirm.
    ///
    /// # Errors
    ///
    /// Returns a [`DirectorError`] when loading fails.
    fn load(&mut self, path: &Path) -> Result<(), DirectorError>;
}
/// One entry of a [`Trajectory`]: a state, the parameters chosen for it, and
/// a scalar reward.
#[derive(Debug, Clone)]
pub struct TrajectoryStep {
    /// State recorded for this step.
    pub state: InteractionState,

    /// Parameters recorded as the action for this step.
    pub action: InteractionParams,

    /// Reward recorded for this step.
    /// NOTE(review): scale and sign convention are not defined in this file.
    pub reward: f32,
}
/// A sequence of [`TrajectoryStep`]s, consumed by
/// [`InteractionStrategyTrainer::train`].
#[derive(Debug, Clone)]
pub struct Trajectory {
    /// The recorded steps.
    /// NOTE(review): presumably stored in chronological order — confirm
    /// against the code that builds trajectories.
    pub steps: Vec<TrajectoryStep>,
}
/// Summary returned by [`InteractionStrategyTrainer::train`].
#[derive(Debug, Clone)]
pub struct TrainingResult {
    /// Mean reward reported by the trainer.
    /// NOTE(review): what the mean is taken over is up to the trainer.
    pub mean_reward: f32,

    /// Number of episodes reported by the trainer.
    pub episodes: usize,
}
/// Trains an interaction strategy from recorded [`Trajectory`] data and
/// exposes save/load hooks taking a filesystem path.
///
/// Implementors must be `Send + Sync`.
pub trait InteractionStrategyTrainer: Send + Sync {
    /// Runs training over `trajectories` and reports a [`TrainingResult`].
    ///
    /// # Errors
    ///
    /// Returns a [`DirectorError`] when training fails.
    fn train(&mut self, trajectories: &[Trajectory]) -> Result<TrainingResult, DirectorError>;

    /// Writes trainer output to `path`.
    ///
    /// NOTE(review): what exactly is persisted (model weights, optimizer
    /// state, ...) is defined by each implementor — confirm.
    ///
    /// # Errors
    ///
    /// Returns a [`DirectorError`] when saving fails.
    fn save(&self, path: &Path) -> Result<(), DirectorError>;

    /// Updates this trainer from the resource at `path`.
    ///
    /// NOTE(review): presumably the counterpart of [`Self::save`] — confirm.
    ///
    /// # Errors
    ///
    /// Returns a [`DirectorError`] when loading fails.
    fn load(&mut self, path: &Path) -> Result<(), DirectorError>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used for every floating-point comparison in this module.
    const EPS: f32 = 1e-6;

    /// Asserts that `actual` differs from `expected` by less than `EPS`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < EPS);
    }

    /// `Default` yields the documented baseline for every field.
    #[test]
    fn test_interaction_params_default() {
        let defaults = InteractionParams::default();

        assert_close(defaults.intensity_factor, 1.0);
        assert_close(defaults.feedback_intensity, 1.0);
        assert_close(defaults.pace_speed, 1.0);
        assert_close(defaults.reward_scarcity, 0.5);
        assert_close(defaults.env_arousal, 0.5);
        assert_close(defaults.rhythm_modulation, 1.0);
        assert_close(defaults.challenge_curve, 0.0);
    }

    /// Out-of-range values end up on the nearest bound of each field's range.
    #[test]
    fn test_interaction_params_clamp() {
        let mut out_of_range = InteractionParams {
            intensity_factor: 3.0,
            feedback_intensity: 2.0,
            pace_speed: 0.0,
            reward_scarcity: -0.5,
            env_arousal: 1.5,
            rhythm_modulation: 0.0,
            challenge_curve: 2.0,
        };

        out_of_range.clamp();

        assert_close(out_of_range.intensity_factor, 2.0);
        assert_close(out_of_range.feedback_intensity, 1.5);
        assert_close(out_of_range.pace_speed, 0.6);
        assert_close(out_of_range.reward_scarcity, 0.0);
        assert_close(out_of_range.env_arousal, 1.0);
        assert_close(out_of_range.rhythm_modulation, 0.8);
        assert_close(out_of_range.challenge_curve, 1.0);
    }

    /// `to_vec` exposes one element per field.
    #[test]
    fn test_interaction_params_to_vec() {
        let flattened = InteractionParams::default().to_vec();
        assert_eq!(flattened.len(), 7);
    }

    /// `EmotionStats::default` uses 0.5 / 0.1 for the valence mean / std.
    #[test]
    fn test_emotion_stats_default() {
        let defaults = EmotionStats::default();

        assert_close(defaults.valence_mean, 0.5);
        assert_close(defaults.valence_std, 0.1);
    }

    /// A trajectory built from a single step reports exactly one step.
    #[test]
    fn test_trajectory() {
        let state = InteractionState {
            user_traits: HashMap::new(),
            env_state: HashMap::new(),
            #[cfg(feature = "edm")]
            emotion: crate::edm::core::EmotionState::default(),
            emotion_stats: EmotionStats::default(),
        };
        let only_step = TrajectoryStep {
            state,
            action: InteractionParams::default(),
            reward: 1.0,
        };

        let trajectory = Trajectory {
            steps: vec![only_step],
        };

        assert_eq!(trajectory.steps.len(), 1);
    }
}