use crate::director::core::{
DirectorError, InteractionParams, InteractionState, InteractionStrategy,
InteractionStrategyTrainer, Trajectory, TrainingResult,
};
use candle_core::{Device, Tensor, DType};
use std::path::Path;
use std::collections::HashMap;
const STATE_DIM: usize = 23;
const ACTION_DIM: usize = 7;
const HIDDEN_DIM: usize = 128;
fn relu(x: &Tensor) -> candle_core::Result<Tensor> {
x.maximum(&x.zeros_like()?)
}
fn sigmoid(x: &Tensor) -> candle_core::Result<Tensor> {
(x.neg()?.exp()? + 1.0)?.recip()
}
pub struct RogueliteDirector {
pub(crate) device: Device,
pub(crate) actor_fc1_weight: Tensor,
pub(crate) actor_fc1_bias: Tensor,
pub(crate) actor_fc2_weight: Tensor,
pub(crate) actor_fc2_bias: Tensor,
pub(crate) actor_fc3_weight: Tensor,
pub(crate) actor_fc3_bias: Tensor,
pub(crate) critic_fc1_weight: Tensor,
pub(crate) critic_fc1_bias: Tensor,
pub(crate) critic_fc2_weight: Tensor,
pub(crate) critic_fc2_bias: Tensor,
pub(crate) critic_fc3_weight: Tensor,
pub(crate) critic_fc3_bias: Tensor,
}
impl RogueliteDirector {
pub fn new(device: Device) -> Result<Self, DirectorError> {
let actor_fc1_weight = Tensor::zeros((HIDDEN_DIM, STATE_DIM), DType::F32, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let actor_fc1_bias = Tensor::zeros(HIDDEN_DIM, DType::F32, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let actor_fc2_weight = Tensor::zeros((HIDDEN_DIM, HIDDEN_DIM), DType::F32, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let actor_fc2_bias = Tensor::zeros(HIDDEN_DIM, DType::F32, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let actor_fc3_weight = Tensor::zeros((ACTION_DIM, HIDDEN_DIM), DType::F32, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let actor_fc3_bias = Tensor::zeros(ACTION_DIM, DType::F32, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let critic_fc1_weight = Tensor::zeros((HIDDEN_DIM, STATE_DIM), DType::F32, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let critic_fc1_bias = Tensor::zeros(HIDDEN_DIM, DType::F32, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let critic_fc2_weight = Tensor::zeros((HIDDEN_DIM, HIDDEN_DIM), DType::F32, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let critic_fc2_bias = Tensor::zeros(HIDDEN_DIM, DType::F32, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let critic_fc3_weight = Tensor::zeros((1, HIDDEN_DIM), DType::F32, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let critic_fc3_bias = Tensor::zeros(1, DType::F32, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
Ok(Self {
device,
actor_fc1_weight,
actor_fc1_bias,
actor_fc2_weight,
actor_fc2_bias,
actor_fc3_weight,
actor_fc3_bias,
critic_fc1_weight,
critic_fc1_bias,
critic_fc2_weight,
critic_fc2_bias,
critic_fc3_weight,
critic_fc3_bias,
})
}
fn encode_state(&self, state: &InteractionState) -> candle_core::Result<Tensor> {
let mut state_vec = Vec::with_capacity(STATE_DIM);
for i in 0..8 {
state_vec.push(*state.user_traits.get(&(i as u32)).unwrap_or(&0.5));
}
for i in 0..6 {
state_vec.push(*state.env_state.get(&(i as u32)).unwrap_or(&0.5));
}
#[cfg(feature = "edm")]
{
state_vec.push(state.emotion.valence);
state_vec.push(state.emotion.arousal);
state_vec.push(state.emotion.dominance);
}
#[cfg(not(feature = "edm"))]
{
state_vec.push(0.5);
state_vec.push(0.5);
state_vec.push(0.5);
}
state_vec.push(state.emotion_stats.valence_mean);
state_vec.push(state.emotion_stats.valence_std);
state_vec.push(state.emotion_stats.arousal_mean);
state_vec.push(state.emotion_stats.arousal_std);
state_vec.push(state.emotion_stats.dominance_mean);
state_vec.push(state.emotion_stats.dominance_std);
Tensor::from_vec(state_vec, (1, STATE_DIM), &self.device)
}
fn actor_forward(&self, state_tensor: &Tensor) -> candle_core::Result<Tensor> {
let x = state_tensor.matmul(&self.actor_fc1_weight.t()?)?;
let x = x.broadcast_add(&self.actor_fc1_bias)?;
let x = relu(&x)?;
let x = x.matmul(&self.actor_fc2_weight.t()?)?;
let x = x.broadcast_add(&self.actor_fc2_bias)?;
let x = relu(&x)?;
let x = x.matmul(&self.actor_fc3_weight.t()?)?;
let x = x.broadcast_add(&self.actor_fc3_bias)?;
sigmoid(&x)
}
fn critic_forward(&self, state_tensor: &Tensor) -> candle_core::Result<Tensor> {
let x = state_tensor.matmul(&self.critic_fc1_weight.t()?)?;
let x = x.broadcast_add(&self.critic_fc1_bias)?;
let x = relu(&x)?;
let x = x.matmul(&self.critic_fc2_weight.t()?)?;
let x = x.broadcast_add(&self.critic_fc2_bias)?;
let x = relu(&x)?;
let x = x.matmul(&self.critic_fc3_weight.t()?)?;
x.broadcast_add(&self.critic_fc3_bias)
}
fn tensor_to_params(&self, tensor: &Tensor) -> candle_core::Result<InteractionParams> {
let tensor = tensor.squeeze(0)?;
let values: Vec<f32> = tensor.to_vec1()?;
Ok(InteractionParams {
intensity_factor: 0.5 + values[0] * 1.5,
feedback_intensity: 0.3 + values[1] * 1.2,
pace_speed: 0.6 + values[2] * 1.2,
reward_scarcity: values[3],
env_arousal: 0.3 + values[4] * 0.7,
rhythm_modulation: 0.8 + values[5] * 0.7,
challenge_curve: values[6] * 2.0 - 1.0,
})
}
}
impl InteractionStrategy for RogueliteDirector {
fn decide(&self, state: &InteractionState) -> Result<InteractionParams, DirectorError> {
let state_tensor = self.encode_state(state)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let action_tensor = self.actor_forward(&state_tensor)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let mut params = self.tensor_to_params(&action_tensor)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
params.clamp();
Ok(params)
}
fn load(&mut self, path: &Path) -> Result<(), DirectorError> {
let weights = candle_core::safetensors::load(path, &self.device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
self.actor_fc1_weight = weights.get("actor_fc1_weight")
.ok_or_else(|| DirectorError::ModelError("actor_fc1_weight not found".into()))?
.clone();
self.actor_fc1_bias = weights.get("actor_fc1_bias")
.ok_or_else(|| DirectorError::ModelError("actor_fc1_bias not found".into()))?
.clone();
self.actor_fc2_weight = weights.get("actor_fc2_weight")
.ok_or_else(|| DirectorError::ModelError("actor_fc2_weight not found".into()))?
.clone();
self.actor_fc2_bias = weights.get("actor_fc2_bias")
.ok_or_else(|| DirectorError::ModelError("actor_fc2_bias not found".into()))?
.clone();
self.actor_fc3_weight = weights.get("actor_fc3_weight")
.ok_or_else(|| DirectorError::ModelError("actor_fc3_weight not found".into()))?
.clone();
self.actor_fc3_bias = weights.get("actor_fc3_bias")
.ok_or_else(|| DirectorError::ModelError("actor_fc3_bias not found".into()))?
.clone();
self.critic_fc1_weight = weights.get("critic_fc1_weight")
.ok_or_else(|| DirectorError::ModelError("critic_fc1_weight not found".into()))?
.clone();
self.critic_fc1_bias = weights.get("critic_fc1_bias")
.ok_or_else(|| DirectorError::ModelError("critic_fc1_bias not found".into()))?
.clone();
self.critic_fc2_weight = weights.get("critic_fc2_weight")
.ok_or_else(|| DirectorError::ModelError("critic_fc2_weight not found".into()))?
.clone();
self.critic_fc2_bias = weights.get("critic_fc2_bias")
.ok_or_else(|| DirectorError::ModelError("critic_fc2_bias not found".into()))?
.clone();
self.critic_fc3_weight = weights.get("critic_fc3_weight")
.ok_or_else(|| DirectorError::ModelError("critic_fc3_weight not found".into()))?
.clone();
self.critic_fc3_bias = weights.get("critic_fc3_bias")
.ok_or_else(|| DirectorError::ModelError("critic_fc3_bias not found".into()))?
.clone();
Ok(())
}
}
impl InteractionStrategyTrainer for RogueliteDirector {
fn train(&mut self, _trajectories: &[Trajectory]) -> Result<TrainingResult, DirectorError> {
todo!("PPO training implementation")
}
fn save(&self, path: &Path) -> Result<(), DirectorError> {
let weights = HashMap::from([
("actor_fc1_weight", self.actor_fc1_weight.clone()),
("actor_fc1_bias", self.actor_fc1_bias.clone()),
("actor_fc2_weight", self.actor_fc2_weight.clone()),
("actor_fc2_bias", self.actor_fc2_bias.clone()),
("actor_fc3_weight", self.actor_fc3_weight.clone()),
("actor_fc3_bias", self.actor_fc3_bias.clone()),
("critic_fc1_weight", self.critic_fc1_weight.clone()),
("critic_fc1_bias", self.critic_fc1_bias.clone()),
("critic_fc2_weight", self.critic_fc2_weight.clone()),
("critic_fc2_bias", self.critic_fc2_bias.clone()),
("critic_fc3_weight", self.critic_fc3_weight.clone()),
("critic_fc3_bias", self.critic_fc3_bias.clone()),
]);
candle_core::safetensors::save(&weights, path)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
Ok(())
}
fn load(&mut self, path: &Path) -> Result<(), DirectorError> {
InteractionStrategy::load(self, path)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap as StdHashMap;
#[test]
fn test_director_new() {
let device = Device::Cpu;
let model = RogueliteDirector::new(device);
assert!(model.is_ok());
}
#[test]
fn test_director_decide() {
let device = Device::Cpu;
let model = RogueliteDirector::new(device).unwrap();
let state = InteractionState {
user_traits: StdHashMap::from([
(0, 0.5), (1, 0.5), (2, 0.5), (3, 0.5),
(4, 0.5), (5, 0.5), (6, 0.5), (7, 0.5),
]),
env_state: StdHashMap::from([
(0, 0.5), (1, 0.5), (2, 0.5),
(3, 0.5), (4, 0.5), (5, 0.5),
]),
#[cfg(feature = "edm")]
emotion: crate::edm::core::EmotionState::new(0.5, 0.5, 0.5),
emotion_stats: crate::director::core::EmotionStats::default(),
};
let result = model.decide(&state);
if let Err(ref e) = result {
eprintln!("Error: {:?}", e);
}
assert!(result.is_ok());
let params = result.unwrap();
assert!(params.intensity_factor >= 0.5 && params.intensity_factor <= 2.0);
assert!(params.feedback_intensity >= 0.3 && params.feedback_intensity <= 1.5);
assert!(params.pace_speed >= 0.6 && params.pace_speed <= 1.8);
assert!(params.reward_scarcity >= 0.0 && params.reward_scarcity <= 1.0);
assert!(params.env_arousal >= 0.3 && params.env_arousal <= 1.0);
assert!(params.rhythm_modulation >= 0.8 && params.rhythm_modulation <= 1.5);
assert!(params.challenge_curve >= -1.0 && params.challenge_curve <= 1.0);
}
#[test]
fn test_director_save_load() {
let device = Device::Cpu;
let model = RogueliteDirector::new(device.clone()).unwrap();
let mut model_dir = std::env::current_dir().unwrap();
model_dir.push("models");
if !model_dir.exists() {
std::fs::create_dir_all(&model_dir).unwrap();
}
let model_path = model_dir.join("test_director_model.safetensors");
let save_result = model.save(&model_path);
assert!(save_result.is_ok());
let mut model2 = RogueliteDirector::new(device).unwrap();
let load_result = InteractionStrategy::load(&mut model2, &model_path);
assert!(load_result.is_ok());
}
}