use crate::edm::core::{EdmError, EmotionDataModel, EmotionDataModelTrainer, TrainingDataset, TrainingResult};
use crate::edm::roguelite::core::RogueliteEdm;
use crate::utils::math::{mse, concordance_correlation_coefficient};
use candle_core::{Device, Tensor, DType, Var};
use candle_nn::{AdamW, Optimizer, ParamsAdamW};
use std::collections::HashMap;
use std::path::Path;
/// Learning rate handed to the AdamW optimizer in `train_epoch`.
const LEARNING_RATE: f64 = 1e-3;
/// Maximum number of samples per mini-batch in `train_epoch`.
const BATCH_SIZE: usize = 32;
/// Number of epochs run by `train` and by every fold of `cross_validate`.
const EPOCHS: usize = 100;
// NOTE(review): not referenced anywhere in the visible code — presumably the
// intended length of the input history window; confirm or remove.
const HISTORY_STEPS: usize = 5;
/// Number of output channels of the 1-D convolution (also the LSTM input width).
const CONV_FILTERS: usize = 32;
/// Width of the LSTM hidden state and of the attention projections.
const LSTM_HIDDEN: usize = 64;
/// Width of the hidden layer of the final two-layer MLP head.
const MLP_HIDDEN: usize = 64;
/// Channel reduction factor of the squeeze-and-excitation bottleneck.
const SE_REDUCTION: usize = 8;
/// Element-wise rectified linear unit: the maximum of `x` and a zero tensor
/// of the same shape.
fn relu(x: &Tensor) -> candle_core::Result<Tensor> {
    let zeros = x.zeros_like()?;
    x.maximum(&zeros)
}
/// Element-wise logistic function, computed as `1 / (1 + exp(-x))`.
fn sigmoid(x: &Tensor) -> candle_core::Result<Tensor> {
    let exp_neg = x.neg()?.exp()?;
    let denominator = (exp_neg + 1.0)?;
    denominator.recip()
}
/// Numerically stable softmax along `dim`.
///
/// The per-slice maximum is subtracted before exponentiation so that large
/// inputs do not overflow. The keep-dim maximum and sum have size 1 along
/// `dim`, so they must be combined with `x` through the broadcasting
/// variants of subtraction and division; the plain operators require
/// identical shapes and fail as soon as `dim` holds more than one element.
fn softmax(x: &Tensor, dim: usize) -> candle_core::Result<Tensor> {
    let max = x.max_keepdim(dim)?;
    let exp = x.broadcast_sub(&max)?.exp()?;
    let sum = exp.sum_keepdim(dim)?;
    exp.broadcast_div(&sum)
}
/// Trainable parameter set for the roguelite emotion model.
///
/// The forward pass chains a 1-D convolution, a squeeze-and-excitation gate,
/// an LSTM, a self-attention layer and a two-layer MLP head. Every parameter
/// is a `Var` so the optimizer can update it. Shapes quoted below are the
/// ones allocated in `new`.
pub struct RogueliteEdmTrainer {
// Device on which the parameters and the LSTM state are allocated.
device: Device,
// Convolution kernel, shape (CONV_FILTERS, 15, 3).
conv1_weight: Var,
// Convolution bias, shape (CONV_FILTERS).
conv1_bias: Var,
// Squeeze-and-excitation reduction layer, shape (CONV_FILTERS / SE_REDUCTION, CONV_FILTERS).
se_fc1_weight: Var,
// Bias of the reduction layer, shape (CONV_FILTERS / SE_REDUCTION).
se_fc1_bias: Var,
// Squeeze-and-excitation expansion layer, shape (CONV_FILTERS, CONV_FILTERS / SE_REDUCTION).
se_fc2_weight: Var,
// Bias of the expansion layer, shape (CONV_FILTERS).
se_fc2_bias: Var,
// LSTM input-to-hidden weights for the four gates, shape (4 * LSTM_HIDDEN, CONV_FILTERS).
lstm_weight_ih: Var,
// LSTM hidden-to-hidden weights for the four gates, shape (4 * LSTM_HIDDEN, LSTM_HIDDEN).
lstm_weight_hh: Var,
// LSTM gate bias, shape (4 * LSTM_HIDDEN).
lstm_bias: Var,
// Attention query projection, shape (LSTM_HIDDEN, LSTM_HIDDEN).
attn_query_weight: Var,
// Attention key projection, shape (LSTM_HIDDEN, LSTM_HIDDEN).
attn_key_weight: Var,
// Attention value projection, shape (LSTM_HIDDEN, LSTM_HIDDEN).
attn_value_weight: Var,
// First MLP layer, shape (MLP_HIDDEN, LSTM_HIDDEN).
fc1_weight: Var,
// Bias of the first MLP layer, shape (MLP_HIDDEN).
fc1_bias: Var,
// Output layer, shape (3, MLP_HIDDEN) — one row per predicted emotion component.
fc2_weight: Var,
// Bias of the output layer, shape (3).
fc2_bias: Var,
}
impl RogueliteEdmTrainer {
pub fn new(device: Device) -> Result<Self, EdmError> {
let conv1_weight = Var::zeros((CONV_FILTERS, 15, 3), DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let conv1_bias = Var::zeros(CONV_FILTERS, DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let se_fc1_weight = Var::zeros((CONV_FILTERS / SE_REDUCTION, CONV_FILTERS), DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let se_fc1_bias = Var::zeros(CONV_FILTERS / SE_REDUCTION, DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let se_fc2_weight = Var::zeros((CONV_FILTERS, CONV_FILTERS / SE_REDUCTION), DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let se_fc2_bias = Var::zeros(CONV_FILTERS, DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let lstm_weight_ih = Var::zeros((4 * LSTM_HIDDEN, CONV_FILTERS), DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let lstm_weight_hh = Var::zeros((4 * LSTM_HIDDEN, LSTM_HIDDEN), DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let lstm_bias = Var::zeros(4 * LSTM_HIDDEN, DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let attn_query_weight = Var::zeros((LSTM_HIDDEN, LSTM_HIDDEN), DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let attn_key_weight = Var::zeros((LSTM_HIDDEN, LSTM_HIDDEN), DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let attn_value_weight = Var::zeros((LSTM_HIDDEN, LSTM_HIDDEN), DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let fc1_weight = Var::zeros((MLP_HIDDEN, LSTM_HIDDEN), DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let fc1_bias = Var::zeros(MLP_HIDDEN, DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let fc2_weight = Var::zeros((3, MLP_HIDDEN), DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let fc2_bias = Var::zeros(3, DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
Ok(Self {
device,
conv1_weight,
conv1_bias,
se_fc1_weight,
se_fc1_bias,
se_fc2_weight,
se_fc2_bias,
lstm_weight_ih,
lstm_weight_hh,
lstm_bias,
attn_query_weight,
attn_key_weight,
attn_value_weight,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias,
})
}
/// Weighted training loss over a batch of (valence, arousal, dominance)
/// predictions.
///
/// Both tensors are interpreted as row-major `(rows, 3)` data, i.e. the
/// three components of one sample are adjacent — the layout produced by
/// `forward` and by the targets built in `train_epoch`. The loss is
/// `0.4 * MSE(valence) + 0.4 * MSE(arousal) + 0.2 * MSE(dominance)
/// + 0.2 * (1 - CCC)`.
///
/// Slicing the flattened data into three contiguous thirds would mix the
/// components of different samples, so the per-component errors are taken
/// column-wise instead.
///
/// # Errors
/// Fails if the element count is not a multiple of 3, if the two tensors
/// differ in size, or if they are not `f32`.
pub fn compute_loss(
    predicted: &Tensor,
    target: &Tensor,
) -> candle_core::Result<Tensor> {
    let pred_flat = predicted.flatten_all()?;
    let target_flat = target.flatten_all()?;
    let pred_vals = pred_flat.to_vec1::<f32>()?;
    let target_vals = target_flat.to_vec1::<f32>()?;

    // One row per sample, one column per emotion component. The reshape
    // rejects element counts that are not divisible by 3.
    let pred_rows = pred_flat.reshape((pred_vals.len() / 3, 3))?;
    let target_rows = target_flat.reshape((target_vals.len() / 3, 3))?;
    let squared_error = (&pred_rows - &target_rows)?.sqr()?;

    let valence_mse = squared_error.narrow(1, 0, 1)?.mean_all()?;
    let arousal_mse = squared_error.narrow(1, 1, 1)?.mean_all()?;
    let dominance_mse = squared_error.narrow(1, 2, 1)?.mean_all()?;

    // The CCC term is computed on host-side copies, so it enters the loss
    // as a constant: it shifts the reported value but carries no gradient.
    let ccc = concordance_correlation_coefficient(&pred_vals, &target_vals);
    let ccc_tensor = Tensor::new(1.0 - ccc, predicted.device())?;

    (valence_mse * 0.4)? + (arousal_mse * 0.4)? + (dominance_mse * 0.2)? + (ccc_tensor * 0.2)?
}
/// Same-padded 1-D convolution with kernel width 3, implemented as an
/// unfold followed by a matrix product.
///
/// `input` is `(batch, in_ch, seq_len)`; the result is
/// `(batch, CONV_FILTERS, seq_len)`.
///
/// The sliding windows are stacked along axis 1 so that the unfolded
/// tensor is `(batch, seq_len, in_ch, kernel)`. Only in that layout does
/// flattening to `(batch * seq_len, in_ch * kernel)` keep one row per
/// output position with columns ordered like the flattened kernel;
/// stacking on the last-but-one axis scrambles rows for `seq_len > 1`.
fn conv1d_forward(&self, input: &Tensor) -> candle_core::Result<Tensor> {
    let (batch, in_ch, seq_len) = input.dims3()?;
    let kernel_size = 3;
    // One zero on each side of the time axis keeps the output length equal to seq_len.
    let input_padded = input.pad_with_zeros(2, 1, 1)?;

    let windows = (0..seq_len)
        .map(|i| input_padded.narrow(2, i, kernel_size))
        .collect::<candle_core::Result<Vec<_>>>()?;
    // (batch, seq_len, in_ch, kernel_size)
    let unfolded = Tensor::stack(&windows, 1)?;
    let unfolded = unfolded.reshape((batch * seq_len, in_ch * kernel_size))?;

    let weight = self.conv1_weight.reshape((CONV_FILTERS, in_ch * kernel_size))?;
    let weight_t = weight.t()?;
    let output = unfolded.matmul(&weight_t)?;

    // Back to channel-major layout: (batch, CONV_FILTERS, seq_len).
    let output = output.reshape((batch, seq_len, CONV_FILTERS))?;
    let output = output.permute((0, 2, 1))?;
    let bias = self.conv1_bias.reshape((1, CONV_FILTERS, 1))?;
    output.broadcast_add(&bias)
}
/// Squeeze-and-excitation channel gating.
///
/// Averages `input` (`(batch, channels, seq_len)`) over the last axis,
/// passes the result through a two-layer bottleneck ending in a sigmoid,
/// and rescales every channel of `input` by its gate value.
fn se_attention(&self, input: &Tensor) -> candle_core::Result<Tensor> {
    let (batch, channels, _seq_len) = input.dims3()?;

    // Squeeze: one average per (sample, channel).
    let pooled = input.mean(2)?;

    // Excitation: reduce, ReLU, expand.
    let reduced = pooled
        .matmul(&self.se_fc1_weight.t()?)?
        .broadcast_add(&self.se_fc1_bias)?;
    let reduced = relu(&reduced)?;
    let gate_logits = reduced
        .matmul(&self.se_fc2_weight.t()?)?
        .broadcast_add(&self.se_fc2_bias)?;

    // Trailing unit axis lets the gate broadcast over the time dimension.
    let gates = sigmoid(&gate_logits)?.reshape((batch, channels, 1))?;
    input.broadcast_mul(&gates)
}
/// Runs a single-layer LSTM over a time-major sequence.
///
/// `input` is `(seq_len, batch, features)`; the result stacks the hidden
/// state of every step into `(seq_len, batch, LSTM_HIDDEN)`. Hidden and
/// cell state start at zero. The packed gate pre-activations are split in
/// the order input, forget, cell candidate, output.
fn lstm_forward(&self, input: &Tensor) -> candle_core::Result<Tensor> {
    let (seq_len, batch, _) = input.dims3()?;

    // Loop-invariant operands are prepared once instead of on every step.
    let weight_ih_t = self.lstm_weight_ih.t()?;
    let weight_hh_t = self.lstm_weight_hh.t()?;
    let bias = self.lstm_bias.reshape((1, 4 * LSTM_HIDDEN))?;

    let mut h = Tensor::zeros((batch, LSTM_HIDDEN), DType::F32, &self.device)?;
    let mut c = Tensor::zeros((batch, LSTM_HIDDEN), DType::F32, &self.device)?;
    let mut outputs = Vec::with_capacity(seq_len);

    for t in 0..seq_len {
        // The caller passes a permuted view, so a time slice may have strides
        // that not every matmul backend accepts; make it contiguous first.
        let x_t = input.narrow(0, t, 1)?.squeeze(0)?.contiguous()?;
        let gates = x_t
            .matmul(&weight_ih_t)?
            .broadcast_add(&h.matmul(&weight_hh_t)?)?
            .broadcast_add(&bias)?;
        let gates = gates.chunk(4, 1)?;
        let i = sigmoid(&gates[0])?;
        let f = sigmoid(&gates[1])?;
        let g = gates[2].tanh()?;
        let o = sigmoid(&gates[3])?;
        c = ((&f * &c)? + (&i * &g)?)?;
        h = (&o * &c.tanh()?)?;
        outputs.push(h.clone());
    }

    Tensor::stack(&outputs, 0)
}
/// Single-head scaled dot-product self-attention over the time axis.
///
/// `input` and the result are both time-major `(seq_len, batch, hidden)`.
/// The tensor is transposed to batch-major before the projections and
/// transposed back afterwards: reinterpreting the time-major buffer as
/// `(batch, seq_len, hidden)` with a plain reshape would interleave
/// different samples whenever both `batch` and `seq_len` exceed 1, so
/// attention would be taken across the batch instead of across time.
fn self_attention(&self, input: &Tensor) -> candle_core::Result<Tensor> {
    let (seq_len, batch, hidden) = input.dims3()?;

    // (seq_len, batch, hidden) -> (batch, seq_len, hidden)
    let batch_major = input.transpose(0, 1)?.contiguous()?;
    let input_flat = batch_major.reshape((batch * seq_len, hidden))?;

    let q = input_flat
        .matmul(&self.attn_query_weight)?
        .reshape((batch, seq_len, hidden))?;
    let k = input_flat
        .matmul(&self.attn_key_weight)?
        .reshape((batch, seq_len, hidden))?;
    let v = input_flat
        .matmul(&self.attn_value_weight)?
        .reshape((batch, seq_len, hidden))?;

    // Scale by 1/sqrt(hidden) so the logits do not grow with the width.
    let scale = 1.0 / (hidden as f64).sqrt();
    let scores = (q.matmul(&k.transpose(1, 2)?)? * scale)?;
    let attn_weights = softmax(&scores, 2)?;

    // (batch, seq_len, hidden) -> (seq_len, batch, hidden)
    let output = attn_weights.matmul(&v)?;
    output.transpose(0, 1)?.contiguous()
}
/// Full forward pass: convolution, ReLU, squeeze-and-excitation, LSTM,
/// self-attention, mean over time, then a two-layer MLP with a sigmoid
/// output.
///
/// `input` is `(batch, channels, seq_len)`; the result has one row per
/// sample and three columns, each squashed into (0, 1) by the sigmoid.
fn forward(&self, input: &Tensor) -> candle_core::Result<Tensor> {
    // Convolutional front end with channel gating.
    let conv = relu(&self.conv1d_forward(input)?)?;
    let gated = self.se_attention(&conv)?;

    // The recurrent layers expect time-major data: (seq_len, batch, channels).
    let time_major = gated.permute((2, 0, 1))?;
    let lstm_out = self.lstm_forward(&time_major)?;
    let attn_out = self.self_attention(&lstm_out)?;

    // Collapse the time axis, then apply the MLP head.
    let pooled = attn_out.mean(0)?;
    let hidden = pooled
        .matmul(&self.fc1_weight.t()?)?
        .broadcast_add(&self.fc1_bias)?;
    let hidden = relu(&hidden)?;
    let logits = hidden
        .matmul(&self.fc2_weight.t()?)?
        .broadcast_add(&self.fc2_bias)?;
    sigmoid(&logits)
}
pub fn train_epoch(&mut self, dataset: &TrainingDataset) -> Result<f32, EdmError> {
let params = ParamsAdamW {
lr: LEARNING_RATE,
..Default::default()
};
let mut optimizer = AdamW::new(
vec![
self.conv1_weight.clone(), self.conv1_bias.clone(),
self.se_fc1_weight.clone(), self.se_fc1_bias.clone(),
self.se_fc2_weight.clone(), self.se_fc2_bias.clone(),
self.lstm_weight_ih.clone(), self.lstm_weight_hh.clone(), self.lstm_bias.clone(),
self.attn_query_weight.clone(), self.attn_key_weight.clone(), self.attn_value_weight.clone(),
self.fc1_weight.clone(), self.fc1_bias.clone(), self.fc2_weight.clone(), self.fc2_bias.clone(),
],
params,
).map_err(|e| EdmError::ModelError(e.to_string()))?;
let samples = dataset.samples();
let mut total_loss = 0.0f32;
let mut num_batches = 0;
for chunk in samples.chunks(BATCH_SIZE) {
let mut batch_inputs = Vec::new();
let mut batch_targets = Vec::new();
for sample in chunk {
let input_vec: Vec<f32> = sample.features.values()
.copied()
.chain(std::iter::repeat(0.0).take(15))
.take(15)
.collect();
batch_inputs.push(input_vec);
batch_targets.push(vec![
sample.emotion.valence,
sample.emotion.arousal,
sample.emotion.dominance,
]);
}
let input = Tensor::from_vec(
batch_inputs.concat(),
(chunk.len(), 15, 1),
&self.device,
).map_err(|e| EdmError::ModelError(e.to_string()))?;
let target = Tensor::from_vec(
batch_targets.concat(),
(chunk.len(), 3),
&self.device,
).map_err(|e| EdmError::ModelError(e.to_string()))?;
let output = self.forward(&input)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let target_flat = target.flatten_all()
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let loss = Self::compute_loss(&output, &target_flat)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let loss_val = loss.to_scalar::<f32>()
.map_err(|e| EdmError::ModelError(e.to_string()))?;
total_loss += loss_val;
num_batches += 1;
optimizer.backward_step(&loss)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
}
Ok(total_loss / num_batches.max(1) as f32)
}
pub fn to_model(&self) -> RogueliteEdm {
RogueliteEdm {
device: self.device.clone(),
conv1_weight: self.conv1_weight.as_tensor().clone(),
conv1_bias: self.conv1_bias.as_tensor().clone(),
se_fc1_weight: self.se_fc1_weight.as_tensor().clone(),
se_fc1_bias: self.se_fc1_bias.as_tensor().clone(),
se_fc2_weight: self.se_fc2_weight.as_tensor().clone(),
se_fc2_bias: self.se_fc2_bias.as_tensor().clone(),
lstm_weight_ih: self.lstm_weight_ih.as_tensor().clone(),
lstm_weight_hh: self.lstm_weight_hh.as_tensor().clone(),
lstm_bias: self.lstm_bias.as_tensor().clone(),
attn_query_weight: self.attn_query_weight.as_tensor().clone(),
attn_key_weight: self.attn_key_weight.as_tensor().clone(),
attn_value_weight: self.attn_value_weight.as_tensor().clone(),
fc1_weight: self.fc1_weight.as_tensor().clone(),
fc1_bias: self.fc1_bias.as_tensor().clone(),
fc2_weight: self.fc2_weight.as_tensor().clone(),
fc2_bias: self.fc2_bias.as_tensor().clone(),
}
}
pub fn cross_validate(dataset: &TrainingDataset, k: usize) -> Result<Vec<f32>, EdmError> {
let device = Device::Cpu;
let samples = dataset.samples();
let fold_size = samples.len() / k;
let mut fold_losses = Vec::new();
for fold in 0..k {
let val_start = fold * fold_size;
let val_end = if fold == k - 1 { samples.len() } else { (fold + 1) * fold_size };
let train_samples: Vec<_> = samples.iter()
.enumerate()
.filter(|(i, _)| *i < val_start || *i >= val_end)
.map(|(_, s)| s.clone())
.collect();
let train_dataset = TrainingDataset::new(train_samples);
let mut trainer = Self::new(device.clone())?;
for _ in 0..EPOCHS {
trainer.train_epoch(&train_dataset)?;
}
let val_samples: Vec<_> = samples[val_start..val_end].to_vec();
let mut val_loss = 0.0;
for sample in &val_samples {
let model = trainer.to_model();
let pred = model.infer(&sample.features)?;
let pred_vec = vec![pred.valence, pred.arousal, pred.dominance];
let target_vec = vec![sample.emotion.valence, sample.emotion.arousal, sample.emotion.dominance];
val_loss += Self::compute_loss_scalar(&pred_vec, &target_vec);
}
fold_losses.push(val_loss / val_samples.len() as f32);
}
Ok(fold_losses)
}
/// Host-side counterpart of the training loss for a single
/// (valence, arousal, dominance) triple.
///
/// Combines the per-component squared errors with weights 0.4 / 0.4 / 0.2
/// and adds `0.2 * (1 - CCC)` computed over the whole triple. Both slices
/// must hold at least three elements; shorter slices panic on indexing.
fn compute_loss_scalar(predicted: &[f32], target: &[f32]) -> f32 {
    // MSE restricted to one component.
    let component_mse = |idx: usize| mse(&predicted[idx..idx + 1], &target[idx..idx + 1]);

    let valence_loss = component_mse(0);
    let arousal_loss = component_mse(1);
    let dominance_loss = component_mse(2);
    let agreement = concordance_correlation_coefficient(predicted, target);

    0.4 * valence_loss + 0.4 * arousal_loss + 0.2 * dominance_loss + 0.2 * (1.0 - agreement)
}
}
impl EmotionDataModelTrainer for RogueliteEdmTrainer {
/// Trains for `EPOCHS` epochs, logging the loss to stderr every tenth
/// epoch, and reports the lowest epoch loss seen.
///
/// NOTE(review): `final_loss` receives the best loss, not the loss of the
/// last epoch, while the parameters kept are those of the last epoch —
/// confirm this is the intended meaning of the field.
fn train(&mut self, dataset: &TrainingDataset) -> Result<TrainingResult, EdmError> {
    let best_loss = (0..EPOCHS).try_fold(f32::MAX, |best, epoch| -> Result<f32, EdmError> {
        let loss = self.train_epoch(dataset)?;
        if epoch % 10 == 0 {
            eprintln!("Epoch {}: loss = {:.4}", epoch, loss);
        }
        Ok(if loss < best { loss } else { best })
    })?;

    Ok(TrainingResult {
        final_loss: best_loss,
        epochs: EPOCHS,
        ..TrainingResult::default()
    })
}
fn save(&self, path: &Path) -> Result<(), EdmError> {
let weights = HashMap::from([
("conv1_weight", self.conv1_weight.as_tensor().clone()),
("conv1_bias", self.conv1_bias.as_tensor().clone()),
("se_fc1_weight", self.se_fc1_weight.as_tensor().clone()),
("se_fc1_bias", self.se_fc1_bias.as_tensor().clone()),
("se_fc2_weight", self.se_fc2_weight.as_tensor().clone()),
("se_fc2_bias", self.se_fc2_bias.as_tensor().clone()),
("lstm_weight_ih", self.lstm_weight_ih.as_tensor().clone()),
("lstm_weight_hh", self.lstm_weight_hh.as_tensor().clone()),
("lstm_bias", self.lstm_bias.as_tensor().clone()),
("attn_query_weight", self.attn_query_weight.as_tensor().clone()),
("attn_key_weight", self.attn_key_weight.as_tensor().clone()),
("attn_value_weight", self.attn_value_weight.as_tensor().clone()),
("fc1_weight", self.fc1_weight.as_tensor().clone()),
("fc1_bias", self.fc1_bias.as_tensor().clone()),
("fc2_weight", self.fc2_weight.as_tensor().clone()),
("fc2_bias", self.fc2_bias.as_tensor().clone()),
]);
candle_core::safetensors::save(&weights, path)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
Ok(())
}
fn load(&mut self, path: &Path) -> Result<(), EdmError> {
let weights = candle_core::safetensors::load(path, &self.device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
self.conv1_weight = Var::from_tensor(
weights.get("conv1_weight")
.ok_or_else(|| EdmError::ModelError("conv1_weight not found".into()))?
).map_err(|e| EdmError::ModelError(e.to_string()))?;
self.conv1_bias = Var::from_tensor(
weights.get("conv1_bias")
.ok_or_else(|| EdmError::ModelError("conv1_bias not found".into()))?
).map_err(|e| EdmError::ModelError(e.to_string()))?;
self.se_fc1_weight = Var::from_tensor(
weights.get("se_fc1_weight")
.ok_or_else(|| EdmError::ModelError("se_fc1_weight not found".into()))?
).map_err(|e| EdmError::ModelError(e.to_string()))?;
self.se_fc1_bias = Var::from_tensor(
weights.get("se_fc1_bias")
.ok_or_else(|| EdmError::ModelError("se_fc1_bias not found".into()))?
).map_err(|e| EdmError::ModelError(e.to_string()))?;
self.se_fc2_weight = Var::from_tensor(
weights.get("se_fc2_weight")
.ok_or_else(|| EdmError::ModelError("se_fc2_weight not found".into()))?
).map_err(|e| EdmError::ModelError(e.to_string()))?;
self.se_fc2_bias = Var::from_tensor(
weights.get("se_fc2_bias")
.ok_or_else(|| EdmError::ModelError("se_fc2_bias not found".into()))?
).map_err(|e| EdmError::ModelError(e.to_string()))?;
self.lstm_weight_ih = Var::from_tensor(
weights.get("lstm_weight_ih")
.ok_or_else(|| EdmError::ModelError("lstm_weight_ih not found".into()))?
).map_err(|e| EdmError::ModelError(e.to_string()))?;
self.lstm_weight_hh = Var::from_tensor(
weights.get("lstm_weight_hh")
.ok_or_else(|| EdmError::ModelError("lstm_weight_hh not found".into()))?
).map_err(|e| EdmError::ModelError(e.to_string()))?;
self.lstm_bias = Var::from_tensor(
weights.get("lstm_bias")
.ok_or_else(|| EdmError::ModelError("lstm_bias not found".into()))?
).map_err(|e| EdmError::ModelError(e.to_string()))?;
self.attn_query_weight = Var::from_tensor(
weights.get("attn_query_weight")
.ok_or_else(|| EdmError::ModelError("attn_query_weight not found".into()))?
).map_err(|e| EdmError::ModelError(e.to_string()))?;
self.attn_key_weight = Var::from_tensor(
weights.get("attn_key_weight")
.ok_or_else(|| EdmError::ModelError("attn_key_weight not found".into()))?
).map_err(|e| EdmError::ModelError(e.to_string()))?;
self.attn_value_weight = Var::from_tensor(
weights.get("attn_value_weight")
.ok_or_else(|| EdmError::ModelError("attn_value_weight not found".into()))?
).map_err(|e| EdmError::ModelError(e.to_string()))?;
self.fc1_weight = Var::from_tensor(
weights.get("fc1_weight")
.ok_or_else(|| EdmError::ModelError("fc1_weight not found".into()))?
).map_err(|e| EdmError::ModelError(e.to_string()))?;
self.fc1_bias = Var::from_tensor(
weights.get("fc1_bias")
.ok_or_else(|| EdmError::ModelError("fc1_bias not found".into()))?
).map_err(|e| EdmError::ModelError(e.to_string()))?;
self.fc2_weight = Var::from_tensor(
weights.get("fc2_weight")
.ok_or_else(|| EdmError::ModelError("fc2_weight not found".into()))?
).map_err(|e| EdmError::ModelError(e.to_string()))?;
self.fc2_bias = Var::from_tensor(
weights.get("fc2_bias")
.ok_or_else(|| EdmError::ModelError("fc2_bias not found".into()))?
).map_err(|e| EdmError::ModelError(e.to_string()))?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing a trainer on the CPU device succeeds.
    #[test]
    fn test_trainer_new() {
        assert!(RogueliteEdmTrainer::new(Device::Cpu).is_ok());
    }

    /// The scalar loss of a prediction identical to its target is non-negative.
    #[test]
    fn test_compute_loss_scalar() {
        let predicted = [0.5_f32; 3];
        let target = [0.5_f32; 3];
        let loss = RogueliteEdmTrainer::compute_loss_scalar(&predicted, &target);
        assert!(loss >= 0.0);
    }

    /// A model exported from a fresh trainer can run inference on an empty
    /// feature map.
    #[test]
    fn test_to_model() {
        let trainer = RogueliteEdmTrainer::new(Device::Cpu).unwrap();
        let outcome = trainer.to_model().infer(&HashMap::new());
        assert!(outcome.is_ok());
    }
}