use crate::edm::core::{EdmError, EmotionDataModel, EmotionState, EmotionDataModelTrainer, TrainingDataset, TrainingResult};
use crate::edm::features::{FeatureExtractor, FEATURE_COUNT};
use candle_core::{Device, Tensor, DType};
use std::collections::HashMap;
use std::path::Path;
// Number of past time steps retained as model history.
// NOTE(review): not referenced anywhere in this file — presumably consumed by
// the training/feature pipeline; confirm before removing.
const HISTORY_STEPS: usize = 5;
// Output channels of the initial 1-D convolution.
const CONV_FILTERS: usize = 32;
// Hidden-state width of the LSTM (each of the 4 gates uses this many units).
const LSTM_HIDDEN: usize = 64;
// Width of the hidden layer in the final MLP head.
const MLP_HIDDEN: usize = 64;
// Channel-reduction factor of the squeeze-and-excitation bottleneck.
const SE_REDUCTION: usize = 8;
/// Element-wise ReLU: max(x, 0).
///
/// Uses candle's built-in `Tensor::relu` unary op instead of the hand-rolled
/// `maximum(zeros_like)` form, which allocated a throwaway zero tensor per call.
fn relu(x: &Tensor) -> candle_core::Result<Tensor> {
    x.relu()
}
/// Element-wise logistic sigmoid: 1 / (1 + e^{-x}).
fn sigmoid(x: &Tensor) -> candle_core::Result<Tensor> {
    let exp_neg = x.neg()?.exp()?;
    let denom = (exp_neg + 1.0)?;
    denom.recip()
}
/// Numerically-stable softmax along `dim`: subtracts the per-slice max before
/// exponentiating, then normalizes so each slice sums to 1.
fn softmax(x: &Tensor, dim: usize) -> candle_core::Result<Tensor> {
    let max = x.max_keepdim(dim)?;
    // BUG FIX: candle's plain `Sub`/`div` require *identical* shapes, but the
    // keepdim reductions have size 1 along `dim`. The old `(x - &max)` /
    // `exp.div(&sum)` therefore errored at runtime whenever the softmax axis
    // had more than one element (seq_len == 1 in the infer path masked it).
    // The broadcasting variants handle the size-1 axis correctly.
    let exp = x.broadcast_sub(&max)?.exp()?;
    let sum = exp.sum_keepdim(dim)?;
    exp.broadcast_div(&sum)
}
/// Emotion model: 1-D conv → squeeze-and-excitation gating → LSTM →
/// self-attention → MLP head with sigmoid outputs.
/// All parameters are held as raw tensors; see `new` for their exact shapes.
pub struct RogueliteEdm {
    // Device every parameter tensor lives on (CPU or accelerator).
    pub(crate) device: Device,
    // 1-D conv over feature channels: (CONV_FILTERS, FEATURE_COUNT, 3) weight + per-filter bias.
    pub(crate) conv1_weight: Tensor,
    pub(crate) conv1_bias: Tensor,
    // Squeeze-and-excitation bottleneck MLP (per-channel gating).
    pub(crate) se_fc1_weight: Tensor,
    pub(crate) se_fc1_bias: Tensor,
    pub(crate) se_fc2_weight: Tensor,
    pub(crate) se_fc2_bias: Tensor,
    // LSTM parameters; the 4*LSTM_HIDDEN rows pack the i, f, g, o gates in that order
    // (see the chunk(4, 1) in lstm_forward).
    pub(crate) lstm_weight_ih: Tensor,
    pub(crate) lstm_weight_hh: Tensor,
    pub(crate) lstm_bias: Tensor,
    // Self-attention projection matrices, each (LSTM_HIDDEN, LSTM_HIDDEN).
    pub(crate) attn_query_weight: Tensor,
    pub(crate) attn_key_weight: Tensor,
    pub(crate) attn_value_weight: Tensor,
    // MLP head: hidden layer, then a 3-way output (valence, arousal, dominance).
    pub(crate) fc1_weight: Tensor,
    pub(crate) fc1_bias: Tensor,
    pub(crate) fc2_weight: Tensor,
    pub(crate) fc2_bias: Tensor,
}
impl RogueliteEdm {
pub fn new(device: Device) -> Result<Self, EdmError> {
let conv1_weight = Tensor::zeros((CONV_FILTERS, FEATURE_COUNT, 3), DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let conv1_bias = Tensor::zeros(CONV_FILTERS, DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let se_fc1_weight = Tensor::zeros((CONV_FILTERS / SE_REDUCTION, CONV_FILTERS), DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let se_fc1_bias = Tensor::zeros(CONV_FILTERS / SE_REDUCTION, DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let se_fc2_weight = Tensor::zeros((CONV_FILTERS, CONV_FILTERS / SE_REDUCTION), DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let se_fc2_bias = Tensor::zeros(CONV_FILTERS, DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let lstm_weight_ih = Tensor::zeros((4 * LSTM_HIDDEN, CONV_FILTERS), DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let lstm_weight_hh = Tensor::zeros((4 * LSTM_HIDDEN, LSTM_HIDDEN), DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let lstm_bias = Tensor::zeros(4 * LSTM_HIDDEN, DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let attn_query_weight = Tensor::zeros((LSTM_HIDDEN, LSTM_HIDDEN), DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let attn_key_weight = Tensor::zeros((LSTM_HIDDEN, LSTM_HIDDEN), DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let attn_value_weight = Tensor::zeros((LSTM_HIDDEN, LSTM_HIDDEN), DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let fc1_weight = Tensor::zeros((MLP_HIDDEN, LSTM_HIDDEN), DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let fc1_bias = Tensor::zeros(MLP_HIDDEN, DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let fc2_weight = Tensor::zeros((3, MLP_HIDDEN), DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let fc2_bias = Tensor::zeros(3, DType::F32, &device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
Ok(Self {
device,
conv1_weight,
conv1_bias,
se_fc1_weight,
se_fc1_bias,
se_fc2_weight,
se_fc2_bias,
lstm_weight_ih,
lstm_weight_hh,
lstm_bias,
attn_query_weight,
attn_key_weight,
attn_value_weight,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias,
})
}
/// 1-D convolution (kernel 3, stride 1, padding 1) via im2col + matmul.
/// input: (batch, in_ch, seq_len) → output: (batch, CONV_FILTERS, seq_len).
fn conv1d_forward(&self, input: &Tensor) -> candle_core::Result<Tensor> {
    let (batch, in_ch, seq_len) = input.dims3()?;
    let kernel_size = 3;
    // Zero-pad one step on each side of the temporal axis (dim 2).
    let input_padded = input.pad_with_zeros(2, 1, 1)?;
    let mut unfolded = Vec::new();
    for i in 0..seq_len {
        let window = input_padded.narrow(2, i, kernel_size)?;
        unfolded.push(window);
    }
    // BUG FIX: stack the (batch, in_ch, k) windows along dim 1 so the result
    // is laid out (batch, seq_len, in_ch, k). The previous code stacked along
    // dim 2, giving (batch, in_ch, seq_len, k); the reshape below then
    // interleaved channel and time elements and multiplied the wrong inputs
    // against the weights whenever seq_len > 1 (the seq_len == 1 inference
    // path happened to be unaffected).
    let unfolded = Tensor::stack(&unfolded, 1)?;
    // Each row is one output position's patch, ordered (in_ch, k) — matching
    // the flattened layout of conv1_weight (CONV_FILTERS, in_ch, k).
    let unfolded = unfolded.reshape((batch * seq_len, in_ch * kernel_size))?;
    let weight = self.conv1_weight.reshape((CONV_FILTERS, in_ch * kernel_size))?;
    let weight_t = weight.t()?;
    let output = unfolded.matmul(&weight_t)?;
    let output = output.reshape((batch, seq_len, CONV_FILTERS))?;
    // (batch, seq_len, filters) → (batch, filters, seq_len), then per-filter bias.
    let output = output.permute((0, 2, 1))?;
    let bias = self.conv1_bias.reshape((1, CONV_FILTERS, 1))?;
    output.broadcast_add(&bias)
}
/// Squeeze-and-excitation channel gating.
/// input: (batch, channels, seq_len) → same shape, with each channel scaled
/// by a learned gate in (0, 1).
fn se_attention(&self, input: &Tensor) -> candle_core::Result<Tensor> {
    let (batch, channels, _seq_len) = input.dims3()?;
    // Squeeze: global average pool over the temporal axis → (batch, channels).
    let pooled = input.mean(2)?;
    // Excitation: bottleneck MLP (channels → channels/SE_REDUCTION → channels).
    let bottleneck = pooled
        .matmul(&self.se_fc1_weight.t()?)?
        .broadcast_add(&self.se_fc1_bias)?;
    let bottleneck = relu(&bottleneck)?;
    let expanded = bottleneck
        .matmul(&self.se_fc2_weight.t()?)?
        .broadcast_add(&self.se_fc2_bias)?;
    // Per-channel gates in (0, 1), broadcast over the time axis.
    let gates = sigmoid(&expanded)?.reshape((batch, channels, 1))?;
    input.broadcast_mul(&gates)
}
/// Single-layer LSTM unrolled over time.
/// input: (seq_len, batch, CONV_FILTERS) → (seq_len, batch, LSTM_HIDDEN),
/// returning the hidden state at every step.
fn lstm_forward(&self, input: &Tensor) -> candle_core::Result<Tensor> {
    let (seq_len, batch, _) = input.dims3()?;
    // Hidden and cell state both start at zero.
    let mut h = Tensor::zeros((batch, LSTM_HIDDEN), DType::F32, &self.device)?;
    let mut c = Tensor::zeros((batch, LSTM_HIDDEN), DType::F32, &self.device)?;
    let mut outputs = Vec::new();
    for t in 0..seq_len {
        // x_t: (batch, CONV_FILTERS) — the t-th time step.
        let x_t = input.narrow(0, t, 1)?.squeeze(0)?;
        // All four gate pre-activations at once: (batch, 4 * LSTM_HIDDEN).
        let gates = x_t.matmul(&self.lstm_weight_ih.t()?)?
            .broadcast_add(&h.matmul(&self.lstm_weight_hh.t()?)?)?
            .broadcast_add(&self.lstm_bias.reshape((1, 4 * LSTM_HIDDEN))?)?;
        // Gate layout inside the 4*LSTM_HIDDEN block: input, forget, cell, output.
        // Checkpoints loaded via `load` must pack lstm_weight_* in this order.
        let gates = gates.chunk(4, 1)?;
        let i = sigmoid(&gates[0])?;
        let f = sigmoid(&gates[1])?;
        let g = gates[2].tanh()?;
        let o = sigmoid(&gates[3])?;
        // Standard LSTM cell update: c' = f*c + i*g, h' = o * tanh(c').
        c = ((&f * &c)? + (&i * &g)?)?;
        h = (&o * &c.tanh()?)?;
        outputs.push(h.clone());
    }
    Tensor::stack(&outputs, 0)
}
/// Scaled dot-product self-attention over the time axis.
/// input and output layout: (seq_len, batch, hidden).
fn self_attention(&self, input: &Tensor) -> candle_core::Result<Tensor> {
    let (seq_len, batch, hidden) = input.dims3()?;
    // BUG FIX: the previous version reshaped the (seq, batch, h) tensor
    // directly to (batch, seq, h) and back, which merely relabels the axes —
    // silently scrambling sequence and batch positions whenever batch > 1 and
    // seq_len > 1 (the 1x1 inference path was unaffected). Transpose first so
    // each row genuinely belongs to the (batch, seq) slot it is labeled with.
    let x = input.transpose(0, 1)?.contiguous()?; // (batch, seq, hidden)
    let flat = x.reshape((batch * seq_len, hidden))?;
    let q = flat.matmul(&self.attn_query_weight)?.reshape((batch, seq_len, hidden))?;
    let k = flat.matmul(&self.attn_key_weight)?.reshape((batch, seq_len, hidden))?;
    let v = flat.matmul(&self.attn_value_weight)?.reshape((batch, seq_len, hidden))?;
    // Scale scores by 1/sqrt(d) before the softmax.
    let scale = 1.0 / (hidden as f64).sqrt();
    let scores = (q.matmul(&k.transpose(1, 2)?)? * scale)?;
    let attn_weights = softmax(&scores, 2)?;
    let output = attn_weights.matmul(&v)?; // (batch, seq, hidden)
    // Back to the (seq, batch, hidden) layout the caller expects.
    output.transpose(0, 1)?.contiguous()
}
/// Full forward pass: conv → ReLU → SE gating → LSTM → self-attention →
/// mean-pool over time → 2-layer MLP → sigmoid.
/// Input layout: (batch, FEATURE_COUNT, seq_len); output: (batch, 3), each
/// value squashed into (0, 1) by the final sigmoid.
fn forward(&self, input: &Tensor) -> candle_core::Result<Tensor> {
    let x = self.conv1d_forward(input)?;
    let x = relu(&x)?;
    let x = self.se_attention(&x)?;
    // (batch, channels, seq) → (seq, batch, channels) for the time-major LSTM.
    let x = x.permute((2, 0, 1))?;
    let lstm_out = self.lstm_forward(&x)?;
    let attn_out = self.self_attention(&lstm_out)?;
    // Average the attended states over the time axis → (batch, hidden).
    let hidden = attn_out.mean(0)?;
    let x = hidden.matmul(&self.fc1_weight.t()?)?;
    let x = x.broadcast_add(&self.fc1_bias)?;
    let x = relu(&x)?;
    let x = x.matmul(&self.fc2_weight.t()?)?;
    let x = x.broadcast_add(&self.fc2_bias)?;
    sigmoid(&x)
}
}
impl EmotionDataModel for RogueliteEdm {
fn infer(&self, features: &HashMap<u32, f32>) -> Result<EmotionState, EdmError> {
let feature_vec = FeatureExtractor::to_vector(features);
let input = Tensor::from_vec(feature_vec, (1, FEATURE_COUNT, 1), &self.device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let output = self.forward(&input)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let output = output.squeeze(0)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
let values: Vec<f32> = output.to_vec1()
.map_err(|e| EdmError::ModelError(e.to_string()))?;
Ok(EmotionState::new(values[0], values[1], values[2]))
}
fn load(&mut self, path: &Path) -> Result<(), EdmError> {
let weights = candle_core::safetensors::load(path, &self.device)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
self.conv1_weight = weights.get("conv1_weight")
.ok_or_else(|| EdmError::ModelError("conv1_weight not found".into()))?
.clone();
self.conv1_bias = weights.get("conv1_bias")
.ok_or_else(|| EdmError::ModelError("conv1_bias not found".into()))?
.clone();
self.se_fc1_weight = weights.get("se_fc1_weight")
.ok_or_else(|| EdmError::ModelError("se_fc1_weight not found".into()))?
.clone();
self.se_fc1_bias = weights.get("se_fc1_bias")
.ok_or_else(|| EdmError::ModelError("se_fc1_bias not found".into()))?
.clone();
self.se_fc2_weight = weights.get("se_fc2_weight")
.ok_or_else(|| EdmError::ModelError("se_fc2_weight not found".into()))?
.clone();
self.se_fc2_bias = weights.get("se_fc2_bias")
.ok_or_else(|| EdmError::ModelError("se_fc2_bias not found".into()))?
.clone();
self.lstm_weight_ih = weights.get("lstm_weight_ih")
.ok_or_else(|| EdmError::ModelError("lstm_weight_ih not found".into()))?
.clone();
self.lstm_weight_hh = weights.get("lstm_weight_hh")
.ok_or_else(|| EdmError::ModelError("lstm_weight_hh not found".into()))?
.clone();
self.lstm_bias = weights.get("lstm_bias")
.ok_or_else(|| EdmError::ModelError("lstm_bias not found".into()))?
.clone();
self.attn_query_weight = weights.get("attn_query_weight")
.ok_or_else(|| EdmError::ModelError("attn_query_weight not found".into()))?
.clone();
self.attn_key_weight = weights.get("attn_key_weight")
.ok_or_else(|| EdmError::ModelError("attn_key_weight not found".into()))?
.clone();
self.attn_value_weight = weights.get("attn_value_weight")
.ok_or_else(|| EdmError::ModelError("attn_value_weight not found".into()))?
.clone();
self.fc1_weight = weights.get("fc1_weight")
.ok_or_else(|| EdmError::ModelError("fc1_weight not found".into()))?
.clone();
self.fc1_bias = weights.get("fc1_bias")
.ok_or_else(|| EdmError::ModelError("fc1_bias not found".into()))?
.clone();
self.fc2_weight = weights.get("fc2_weight")
.ok_or_else(|| EdmError::ModelError("fc2_weight not found".into()))?
.clone();
self.fc2_bias = weights.get("fc2_bias")
.ok_or_else(|| EdmError::ModelError("fc2_bias not found".into()))?
.clone();
Ok(())
}
}
impl EmotionDataModelTrainer for RogueliteEdm {
/// Training is not implemented yet; calling this panics via `todo!`.
fn train(&mut self, _dataset: &TrainingDataset) -> Result<TrainingResult, EdmError> {
    todo!("Training implementation")
}
fn save(&self, path: &Path) -> Result<(), EdmError> {
let weights = HashMap::from([
("conv1_weight", self.conv1_weight.clone()),
("conv1_bias", self.conv1_bias.clone()),
("se_fc1_weight", self.se_fc1_weight.clone()),
("se_fc1_bias", self.se_fc1_bias.clone()),
("se_fc2_weight", self.se_fc2_weight.clone()),
("se_fc2_bias", self.se_fc2_bias.clone()),
("lstm_weight_ih", self.lstm_weight_ih.clone()),
("lstm_weight_hh", self.lstm_weight_hh.clone()),
("lstm_bias", self.lstm_bias.clone()),
("attn_query_weight", self.attn_query_weight.clone()),
("attn_key_weight", self.attn_key_weight.clone()),
("attn_value_weight", self.attn_value_weight.clone()),
("fc1_weight", self.fc1_weight.clone()),
("fc1_bias", self.fc1_bias.clone()),
("fc2_weight", self.fc2_weight.clone()),
("fc2_bias", self.fc2_bias.clone()),
]);
candle_core::safetensors::save(&weights, path)
.map_err(|e| EdmError::ModelError(e.to_string()))?;
Ok(())
}
/// Delegates to the `EmotionDataModel::load` implementation so both traits
/// share a single checkpoint-loading path.
fn load(&mut self, path: &Path) -> Result<(), EdmError> {
    EmotionDataModel::load(self, path)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Construction on CPU must succeed.
    #[test]
    fn test_edm_new() {
        let device = Device::Cpu;
        let model = RogueliteEdm::new(device);
        assert!(model.is_ok());
    }

    // Inference on a zero-initialized model must succeed and keep each
    // emotion axis inside the sigmoid range [0, 1].
    #[test]
    fn test_edm_infer() {
        let device = Device::Cpu;
        let model = RogueliteEdm::new(device).unwrap();
        let mut features = HashMap::new();
        for i in 0u32..15 {
            features.insert(i, 0.5);
        }
        let result = model.infer(&features);
        if let Err(ref e) = result {
            eprintln!("Error: {:?}", e);
        }
        assert!(result.is_ok());
        let state = result.unwrap();
        assert!(state.valence >= 0.0 && state.valence <= 1.0);
        assert!(state.arousal >= 0.0 && state.arousal <= 1.0);
        assert!(state.dominance >= 0.0 && state.dominance <= 1.0);
    }

    // Round-trip: save a checkpoint, then load it into a second model.
    #[test]
    fn test_edm_save_load() {
        let device = Device::Cpu;
        let model = RogueliteEdm::new(device.clone()).unwrap();
        // BUG FIX: the old test wrote into `<cwd>/models` and never cleaned
        // up, littering the working tree on every run. Use the OS temp dir
        // and remove the file afterwards (best-effort).
        let model_path = std::env::temp_dir().join("test_edm_model.safetensors");
        let save_result = model.save(&model_path);
        assert!(save_result.is_ok());
        let mut model2 = RogueliteEdm::new(device).unwrap();
        let load_result = EmotionDataModel::load(&mut model2, &model_path);
        let _ = std::fs::remove_file(&model_path);
        assert!(load_result.is_ok());
    }
}