use crate::director::core::{
DirectorError, InteractionStrategyTrainer, Trajectory, TrainingResult,
};
use crate::director::roguelite::core::RogueliteDirector;
use candle_core::{Device, Tensor, DType, Var};
use candle_nn::{AdamW, Optimizer, ParamsAdamW};
use std::collections::HashMap;
use std::path::Path;
// PPO / GAE hyperparameters.
const PPO_CLIP_EPS: f32 = 0.2;
const GAE_LAMBDA: f32 = 0.95;
const ENTROPY_COEF: f32 = 0.01;
const LEARNING_RATE: f64 = 1e-4;
const EPOCHS: usize = 10;
// Network dimensions.
const STATE_DIM: usize = 23;
const ACTION_DIM: usize = 7;
const HIDDEN_DIM: usize = 128;
// Elementwise ReLU: max(x, 0).
fn relu(x: &Tensor) -> candle_core::Result<Tensor> {
    x.maximum(&x.zeros_like()?)
}
// Numerically stable sigmoid: inputs are clamped to [-10, 10] before exp, so
// outputs stay strictly inside (0, 1) and downstream log() calls are finite.
fn sigmoid(x: &Tensor) -> candle_core::Result<Tensor> {
    let clamped = x.clamp(-10.0, 10.0)?;
    (clamped.neg()?.exp()? + 1.0)?.recip()
}
// Xavier/Glorot-scaled initialisation. The values follow a fixed deterministic
// pattern in [-scale, scale) rather than a random draw, so runs are exactly
// reproducible (Tensor::rand(-scale, scale, ..) would give a uniform draw).
fn xavier_init(rows: usize, cols: usize, device: &Device) -> candle_core::Result<Var> {
    let scale = (2.0 / (rows + cols) as f64).sqrt() as f32;
    let data: Vec<f32> = (0..rows * cols)
        .map(|i| ((i as f32 * 0.01) % 1.0 - 0.5) * 2.0 * scale)
        .collect();
    let tensor = Tensor::from_vec(data, (rows, cols), device)?;
    Var::from_tensor(&tensor)
}
/// PPO-style actor-critic trainer holding the trainable parameters of both
/// networks as candle `Var`s.
pub struct RogueliteDirectorTrainer {
device: Device,
actor_fc1_weight: Var,
actor_fc1_bias: Var,
actor_fc2_weight: Var,
actor_fc2_bias: Var,
actor_fc3_weight: Var,
actor_fc3_bias: Var,
critic_fc1_weight: Var,
critic_fc1_bias: Var,
critic_fc2_weight: Var,
critic_fc2_bias: Var,
critic_fc3_weight: Var,
critic_fc3_bias: Var,
}
impl RogueliteDirectorTrainer {
    /// Builds the actor and critic MLPs
    /// (STATE_DIM -> HIDDEN_DIM -> HIDDEN_DIM -> output) on the given device.
    pub fn new(device: Device) -> Result<Self, DirectorError> {
let actor_fc1_weight = xavier_init(HIDDEN_DIM, STATE_DIM, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let actor_fc1_bias = Var::zeros(HIDDEN_DIM, DType::F32, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let actor_fc2_weight = xavier_init(HIDDEN_DIM, HIDDEN_DIM, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let actor_fc2_bias = Var::zeros(HIDDEN_DIM, DType::F32, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let actor_fc3_weight = xavier_init(ACTION_DIM, HIDDEN_DIM, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let actor_fc3_bias = Var::zeros(ACTION_DIM, DType::F32, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let critic_fc1_weight = xavier_init(HIDDEN_DIM, STATE_DIM, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let critic_fc1_bias = Var::zeros(HIDDEN_DIM, DType::F32, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let critic_fc2_weight = xavier_init(HIDDEN_DIM, HIDDEN_DIM, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let critic_fc2_bias = Var::zeros(HIDDEN_DIM, DType::F32, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let critic_fc3_weight = xavier_init(1, HIDDEN_DIM, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let critic_fc3_bias = Var::zeros(1, DType::F32, &device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
Ok(Self {
device,
actor_fc1_weight,
actor_fc1_bias,
actor_fc2_weight,
actor_fc2_bias,
actor_fc3_weight,
actor_fc3_bias,
critic_fc1_weight,
critic_fc1_bias,
critic_fc2_weight,
critic_fc2_bias,
critic_fc3_weight,
critic_fc3_bias,
})
}
    /// Generalized Advantage Estimation, computed backwards over the episode:
    /// `delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)` and
    /// `A_t = delta_t + gamma * lambda * A_{t+1}`, bootstrapping the value
    /// after the final step as zero.
    pub fn compute_gae(rewards: &[f32], values: &[f32], gamma: f32) -> Vec<f32> {
        let n = rewards.len();
        if n == 0 {
            return Vec::new();
        }
        let mut advantages = Vec::with_capacity(n);
        let mut gae = 0.0;
        for t in (0..n).rev() {
            let next_value = if t + 1 < values.len() { values[t + 1] } else { 0.0 };
            let delta = rewards[t] + gamma * next_value - values[t];
            gae = delta + gamma * GAE_LAMBDA * gae;
            advantages.push(gae);
        }
        // The loop runs backwards; reverse once into chronological order
        // (O(n)) rather than repeatedly inserting at the front (O(n^2)).
        advantages.reverse();
        advantages
    }
    /// Scalar reward: a weighted blend of session progress (0.4), emotion
    /// improvement (0.4), and a retention bonus (0.2).
    pub fn compute_reward(
        progress: f32,
        emotion_improvement: f32,
        retention: bool,
    ) -> f32 {
        let retention_reward = if retention { 1.0 } else { 0.0 };
        0.4 * progress + 0.4 * emotion_improvement + 0.2 * retention_reward
    }
    /// Per-sample PPO clipped surrogate loss:
    /// `-min(ratio * A, clamp(ratio, 1 - eps, 1 + eps) * A)`,
    /// negated so that minimising the loss maximises the surrogate objective.
    pub fn ppo_clip(ratio: f32, advantage: f32) -> f32 {
        let clipped = ratio.clamp(1.0 - PPO_CLIP_EPS, 1.0 + PPO_CLIP_EPS);
        -(ratio * advantage).min(clipped * advantage)
    }
    // Actor: three-layer MLP with a sigmoid head yielding ACTION_DIM values
    // in (0, 1).
    fn actor_forward(&self, state_tensor: &Tensor) -> candle_core::Result<Tensor> {
let x = state_tensor.matmul(&self.actor_fc1_weight.t()?)?;
let x = x.broadcast_add(&self.actor_fc1_bias)?;
let x = relu(&x)?;
let x = x.matmul(&self.actor_fc2_weight.t()?)?;
let x = x.broadcast_add(&self.actor_fc2_bias)?;
let x = relu(&x)?;
let x = x.matmul(&self.actor_fc3_weight.t()?)?;
let x = x.broadcast_add(&self.actor_fc3_bias)?;
sigmoid(&x)
}
    // Critic: three-layer MLP with a single linear state-value output.
    fn critic_forward(&self, state_tensor: &Tensor) -> candle_core::Result<Tensor> {
let x = state_tensor.matmul(&self.critic_fc1_weight.t()?)?;
let x = x.broadcast_add(&self.critic_fc1_bias)?;
let x = relu(&x)?;
let x = x.matmul(&self.critic_fc2_weight.t()?)?;
let x = x.broadcast_add(&self.critic_fc2_bias)?;
let x = relu(&x)?;
let x = x.matmul(&self.critic_fc3_weight.t()?)?;
x.broadcast_add(&self.critic_fc3_bias)
}
    // Entropy of the sigmoid outputs treated as independent Bernoulli
    // probabilities, -(p ln p + (1 - p) ln(1 - p)), averaged over all entries
    // so ENTROPY_COEF stays batch-size invariant. The clamp in sigmoid() keeps
    // p away from 0 and 1, so both logs are finite.
    fn compute_entropy(action_probs: &Tensor) -> candle_core::Result<Tensor> {
        let one_minus = action_probs.affine(-1.0, 1.0)?; // 1 - p
        let p_term = (action_probs * action_probs.log()?)?;
        let q_term = (&one_minus * one_minus.log()?)?;
        (p_term + q_term)?.neg()?.mean_all()
    }
pub fn train_epoch(&mut self, trajectories: &[Trajectory]) -> Result<f32, DirectorError> {
let params = ParamsAdamW {
lr: LEARNING_RATE,
..Default::default()
};
        // Var clones share underlying storage, so AdamW updates the trainer's
        // parameters in place; note that the moment estimates reset on every call.
        let mut optimizer = AdamW::new(
vec![
self.actor_fc1_weight.clone(), self.actor_fc1_bias.clone(),
self.actor_fc2_weight.clone(), self.actor_fc2_bias.clone(),
self.actor_fc3_weight.clone(), self.actor_fc3_bias.clone(),
self.critic_fc1_weight.clone(), self.critic_fc1_bias.clone(),
self.critic_fc2_weight.clone(), self.critic_fc2_bias.clone(),
self.critic_fc3_weight.clone(), self.critic_fc3_bias.clone(),
],
params,
).map_err(|e| DirectorError::ModelError(e.to_string()))?;
let mut total_loss = 0.0f32;
let mut num_samples = 0;
for trajectory in trajectories {
if trajectory.steps.is_empty() {
continue;
}
            let states: Vec<f32> = trajectory.steps.iter()
                .flat_map(|step| {
                    let mut v = Vec::with_capacity(STATE_DIM);
                    // 8 user-trait slots + 6 environment slots + 3 emotion
                    // dims give 17 features; missing map entries default to
                    // a neutral 0.5.
                    for i in 0..8 {
                        v.push(*step.state.user_traits.get(&(i as u32)).unwrap_or(&0.5));
                    }
                    for i in 0..6 {
                        v.push(*step.state.env_state.get(&(i as u32)).unwrap_or(&0.5));
                    }
                    v.push(step.state.emotion.valence);
                    v.push(step.state.emotion.arousal);
                    v.push(step.state.emotion.dominance);
                    // Pad with the neutral 0.5 up to STATE_DIM so the reshape
                    // below matches the network's input width.
                    v.resize(STATE_DIM, 0.5);
                    v
                })
                .collect();
            if states.is_empty() {
                continue;
            }
            let state_tensor = Tensor::from_vec(
                states,
                (trajectory.steps.len(), STATE_DIM),
                &self.device,
            ).map_err(|e| DirectorError::ModelError(e.to_string()))?;
let values = self.critic_forward(&state_tensor)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let values_flat = values.flatten_all()
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let values_vec: Vec<f32> = values_flat.to_vec1()
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let rewards: Vec<f32> = trajectory.steps.iter()
.map(|s| s.reward)
.collect();
let advantages = Self::compute_gae(&rewards, &values_vec, 0.99);
if advantages.is_empty() {
continue;
}
            let action_probs = self.actor_forward(&state_tensor)
                .map_err(|e| DirectorError::ModelError(e.to_string()))?;
            // The "old" policy is the current policy re-evaluated, so the
            // ratio is identically 1: trajectories do not carry the behaviour
            // policy's probabilities. With stored old probabilities this
            // becomes the usual PPO importance ratio.
            let old_action_probs = action_probs.clone();
            let ratio = (&action_probs / &old_action_probs)
                .map_err(|e| DirectorError::ModelError(e.to_string()))?;
            let ratio_flat = ratio.flatten_all()
                .map_err(|e| DirectorError::ModelError(e.to_string()))?;
            let ratio_vals: Vec<f32> = ratio_flat.to_vec1()
                .map_err(|e| DirectorError::ModelError(e.to_string()))?;
            // Each step's advantage covers ACTION_DIM consecutive entries of
            // the flattened ratio tensor.
            let mut clip_loss_sum = 0.0f32;
            for (i, &r) in ratio_vals.iter().enumerate() {
                let adv_idx = i / ACTION_DIM;
                let adv = if adv_idx < advantages.len() { advantages[adv_idx] } else { 0.0 };
                clip_loss_sum += Self::ppo_clip(r, adv);
            }
            // Built from host-side scalars, this term is constant w.r.t. the
            // graph: actor gradients flow only through the entropy bonus.
            let clip_loss = Tensor::new(clip_loss_sum / ratio_vals.len() as f32, &self.device)
                .map_err(|e| DirectorError::ModelError(e.to_string()))?;
let entropy = Self::compute_entropy(&action_probs)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
            // Value targets are advantages + values, i.e. GAE-based returns.
            let value_targets: Vec<f32> = advantages.iter()
                .zip(values_vec.iter())
                .map(|(a, v)| a + v)
                .collect();
            let n_targets = value_targets.len();
            let value_targets_tensor = Tensor::from_vec(
                value_targets,
                n_targets,
                &self.device,
            ).map_err(|e| DirectorError::ModelError(e.to_string()))?;
            // Reuses the flattened critic output from above; this difference
            // carries gradients back into the critic.
            let value_diff = (values_flat - value_targets_tensor)
                .map_err(|e| DirectorError::ModelError(e.to_string()))?;
            let value_loss = value_diff
                .sqr().map_err(|e| DirectorError::ModelError(e.to_string()))?
                .mean_all().map_err(|e| DirectorError::ModelError(e.to_string()))?;
let entropy_coef_tensor = Tensor::new(ENTROPY_COEF, &self.device)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let entropy_scaled = (&entropy * entropy_coef_tensor)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
            // Total loss: clipped policy term + value MSE - entropy bonus.
            let loss = (clip_loss + value_loss - entropy_scaled)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
let loss_val = loss.to_scalar::<f32>()
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
            // Skip non-finite losses instead of stepping the optimizer on them.
            if !loss_val.is_finite() {
continue;
}
total_loss += loss_val;
num_samples += 1;
optimizer.backward_step(&loss)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
}
Ok(total_loss / num_samples.max(1) as f32)
}
    // Snapshots the trained weights into an inference-only RogueliteDirector.
    pub fn to_model(&self) -> RogueliteDirector {
RogueliteDirector {
device: self.device.clone(),
actor_fc1_weight: self.actor_fc1_weight.as_tensor().clone(),
actor_fc1_bias: self.actor_fc1_bias.as_tensor().clone(),
actor_fc2_weight: self.actor_fc2_weight.as_tensor().clone(),
actor_fc2_bias: self.actor_fc2_bias.as_tensor().clone(),
actor_fc3_weight: self.actor_fc3_weight.as_tensor().clone(),
actor_fc3_bias: self.actor_fc3_bias.as_tensor().clone(),
critic_fc1_weight: self.critic_fc1_weight.as_tensor().clone(),
critic_fc1_bias: self.critic_fc1_bias.as_tensor().clone(),
critic_fc2_weight: self.critic_fc2_weight.as_tensor().clone(),
critic_fc2_bias: self.critic_fc2_bias.as_tensor().clone(),
critic_fc3_weight: self.critic_fc3_weight.as_tensor().clone(),
critic_fc3_bias: self.critic_fc3_bias.as_tensor().clone(),
}
}
}
impl InteractionStrategyTrainer for RogueliteDirectorTrainer {
fn train(&mut self, trajectories: &[Trajectory]) -> Result<TrainingResult, DirectorError> {
let mut best_loss = f32::MAX;
for epoch in 0..EPOCHS {
let loss = self.train_epoch(trajectories)?;
if loss < best_loss {
best_loss = loss;
}
if epoch % 5 == 0 {
eprintln!("Epoch {}: loss = {:.4}", epoch, loss);
}
}
        Ok(TrainingResult {
            // Proxy metric: the negated best epoch loss, since true episode
            // rewards are folded into the loss rather than tracked separately.
            mean_reward: -best_loss,
            episodes: trajectories.len(),
        })
}
fn save(&self, path: &Path) -> Result<(), DirectorError> {
let weights = HashMap::from([
("actor_fc1_weight", self.actor_fc1_weight.as_tensor().clone()),
("actor_fc1_bias", self.actor_fc1_bias.as_tensor().clone()),
("actor_fc2_weight", self.actor_fc2_weight.as_tensor().clone()),
("actor_fc2_bias", self.actor_fc2_bias.as_tensor().clone()),
("actor_fc3_weight", self.actor_fc3_weight.as_tensor().clone()),
("actor_fc3_bias", self.actor_fc3_bias.as_tensor().clone()),
("critic_fc1_weight", self.critic_fc1_weight.as_tensor().clone()),
("critic_fc1_bias", self.critic_fc1_bias.as_tensor().clone()),
("critic_fc2_weight", self.critic_fc2_weight.as_tensor().clone()),
("critic_fc2_bias", self.critic_fc2_bias.as_tensor().clone()),
("critic_fc3_weight", self.critic_fc3_weight.as_tensor().clone()),
("critic_fc3_bias", self.critic_fc3_bias.as_tensor().clone()),
]);
candle_core::safetensors::save(&weights, path)
.map_err(|e| DirectorError::ModelError(e.to_string()))?;
Ok(())
}
    fn load(&mut self, path: &Path) -> Result<(), DirectorError> {
        let weights = candle_core::safetensors::load(path, &self.device)
            .map_err(|e| DirectorError::ModelError(e.to_string()))?;
        // Look up a named tensor and rewrap it as a trainable Var.
        let get = |name: &str| -> Result<Var, DirectorError> {
            let tensor = weights
                .get(name)
                .ok_or_else(|| DirectorError::ModelError(format!("{} not found", name)))?;
            Var::from_tensor(tensor).map_err(|e| DirectorError::ModelError(e.to_string()))
        };
        self.actor_fc1_weight = get("actor_fc1_weight")?;
        self.actor_fc1_bias = get("actor_fc1_bias")?;
        self.actor_fc2_weight = get("actor_fc2_weight")?;
        self.actor_fc2_bias = get("actor_fc2_bias")?;
        self.actor_fc3_weight = get("actor_fc3_weight")?;
        self.actor_fc3_bias = get("actor_fc3_bias")?;
        self.critic_fc1_weight = get("critic_fc1_weight")?;
        self.critic_fc1_bias = get("critic_fc1_bias")?;
        self.critic_fc2_weight = get("critic_fc2_weight")?;
        self.critic_fc2_bias = get("critic_fc2_bias")?;
        self.critic_fc3_weight = get("critic_fc3_weight")?;
        self.critic_fc3_bias = get("critic_fc3_bias")?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
use super::*;
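    // Sanity check for the module's activation helpers: relu zeroes negatives
    // and the clamped sigmoid maps 0 to 0.5 (a minimal sketch on CPU).
    #[test]
    fn test_activation_helpers() {
        let device = Device::Cpu;
        let x = Tensor::from_vec(vec![-1.0f32, 0.0, 2.0], 3usize, &device).unwrap();
        let r: Vec<f32> = relu(&x).unwrap().to_vec1().unwrap();
        assert_eq!(r, vec![0.0, 0.0, 2.0]);
        let s: Vec<f32> = sigmoid(&x).unwrap().to_vec1().unwrap();
        assert!((s[1] - 0.5).abs() < 1e-6);
        assert!(s[0] < 0.5 && s[2] > 0.5);
    }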
#[test]
fn test_trainer_new() {
let device = Device::Cpu;
let trainer = RogueliteDirectorTrainer::new(device);
assert!(trainer.is_ok());
}
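    // Shape check for the two MLP heads; a sketch assuming a freshly
    // initialised trainer on the CPU device.
    #[test]
    fn test_forward_shapes() {
        let device = Device::Cpu;
        let trainer = RogueliteDirectorTrainer::new(device.clone()).unwrap();
        let state = Tensor::zeros((2, STATE_DIM), DType::F32, &device).unwrap();
        let probs = trainer.actor_forward(&state).unwrap();
        assert_eq!(probs.dims(), &[2, ACTION_DIM]);
        // The sigmoid head keeps every output strictly inside (0, 1).
        let vals: Vec<f32> = probs.flatten_all().unwrap().to_vec1().unwrap();
        assert!(vals.iter().all(|&p| p > 0.0 && p < 1.0));
        let values = trainer.critic_forward(&state).unwrap();
        assert_eq!(values.dims(), &[2, 1]);
    }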
    #[test]
    fn test_compute_gae() {
        let rewards = vec![1.0, 0.5, 0.3, 0.8];
        let values = vec![0.5, 0.4, 0.3, 0.6];
        let advantages = RogueliteDirectorTrainer::compute_gae(&rewards, &values, 0.99);
        assert_eq!(advantages.len(), rewards.len());
    }
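    // Hand-computed two-step GAE check (gamma = 0.9, GAE_LAMBDA = 0.95):
    // delta_1 = 2.0 - 1.0 = 1.0, so A_1 = 1.0;
    // delta_0 = 1.0 + 0.9 * 1.0 - 0.5 = 1.4, so A_0 = 1.4 + 0.9 * 0.95 = 2.255.
    #[test]
    fn test_compute_gae_values() {
        let rewards = vec![1.0, 2.0];
        let values = vec![0.5, 1.0];
        let advantages = RogueliteDirectorTrainer::compute_gae(&rewards, &values, 0.9);
        assert!((advantages[0] - 2.255).abs() < 1e-5);
        assert!((advantages[1] - 1.0).abs() < 1e-5);
    }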
#[test]
fn test_compute_reward() {
let reward = RogueliteDirectorTrainer::compute_reward(0.5, 0.3, true);
assert!(reward > 0.0);
let reward2 = RogueliteDirectorTrainer::compute_reward(0.5, 0.3, false);
assert!(reward > reward2);
}
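    // Worked example of the reward blend: 0.4 * 0.5 + 0.4 * 0.25 + 0.2 = 0.5,
    // and dropping retention removes exactly the 0.2 bonus.
    #[test]
    fn test_compute_reward_weights() {
        let reward = RogueliteDirectorTrainer::compute_reward(0.5, 0.25, true);
        assert!((reward - 0.5).abs() < 1e-6);
        let no_retention = RogueliteDirectorTrainer::compute_reward(0.5, 0.25, false);
        assert!((reward - no_retention - 0.2).abs() < 1e-6);
    }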
    #[test]
    fn test_ppo_clip() {
        // ratio = 1 recovers the plain surrogate: -advantage.
        let loss = RogueliteDirectorTrainer::ppo_clip(1.0, 1.0);
        assert!((loss + 1.0).abs() < 1e-6);
        // With a positive advantage, any ratio above 1 + eps clamps to 1.2.
        let loss2 = RogueliteDirectorTrainer::ppo_clip(1.5, 1.0);
        assert!((loss2 + 1.2).abs() < 1e-6);
        assert_eq!(loss2, RogueliteDirectorTrainer::ppo_clip(5.0, 1.0));
    }
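    // Entropy sanity check, assuming the Bernoulli formulation in
    // compute_entropy: p = 0.5 everywhere gives exactly ln 2 nats per output.
    #[test]
    fn test_compute_entropy_uniform() {
        let device = Device::Cpu;
        let probs = Tensor::full(0.5f32, (3, ACTION_DIM), &device).unwrap();
        let entropy = RogueliteDirectorTrainer::compute_entropy(&probs)
            .unwrap()
            .to_scalar::<f32>()
            .unwrap();
        assert!((entropy - std::f32::consts::LN_2).abs() < 1e-5);
    }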
}