use std::collections::HashMap;
use std::path::Path;
use thiserror::Error;
/// A point in a three-component emotion space (valence, arousal, dominance).
///
/// [`EmotionState::new`] keeps every component in `[0.0, 1.0]`. The fields are public,
/// so values assigned directly (or via struct literals) are NOT range-checked.
#[derive(Debug, Clone, Copy, Default)]
pub struct EmotionState {
    /// Valence component.
    pub valence: f32,
    /// Arousal component.
    pub arousal: f32,
    /// Dominance component.
    pub dominance: f32,
}

impl EmotionState {
    /// Creates a state, clamping each component into `[0.0, 1.0]`.
    ///
    /// Infinities saturate to the nearest bound. `NaN` components become `0.0`, the same
    /// value [`Default`] uses. This is needed because `f32::clamp` passes `NaN` through
    /// unchanged, which would otherwise break the range guarantee.
    #[must_use]
    pub fn new(valence: f32, arousal: f32, dominance: f32) -> Self {
        Self {
            valence: Self::clamp_unit(valence),
            arousal: Self::clamp_unit(arousal),
            dominance: Self::clamp_unit(dominance),
        }
    }

    /// Returns the components as a newly allocated `[valence, arousal, dominance]` vector.
    #[must_use]
    pub fn to_vec(&self) -> Vec<f32> {
        vec![self.valence, self.arousal, self.dominance]
    }

    /// Clamps `x` into `[0.0, 1.0]`, mapping `NaN` to `0.0`.
    fn clamp_unit(x: f32) -> f32 {
        if x.is_nan() {
            0.0
        } else {
            x.clamp(0.0, 1.0)
        }
    }
}
#[derive(Error, Debug)]
pub enum EdmError {
#[error("Feature missing: {0}")]
FeatureMissing(u32),
#[error("Invalid input: {0}")]
InvalidInput(String),
#[error("Model error: {0}")]
ModelError(String),
#[error("IO error: {0}")]
IoError(#[from] std::io::Error),
}
/// A model that maps a sparse feature map to an [`EmotionState`].
///
/// The `Send + Sync` supertrait bound means implementors can be moved to, and shared
/// between, threads.
pub trait EmotionDataModel
where
    Self: Send + Sync,
{
    /// Produces an [`EmotionState`] from `features`, a map from `u32` keys
    /// (presumably feature ids; confirm with implementors) to `f32` values.
    ///
    /// # Errors
    /// Returns an [`EdmError`] when inference fails; which variant is used is up to
    /// the implementation.
    fn infer(&self, features: &HashMap<u32, f32>) -> Result<EmotionState, EdmError>;

    /// Mutates `self` from the resource at `path`. The on-disk format is left to
    /// implementors.
    ///
    /// # Errors
    /// Returns an [`EdmError`] if loading fails.
    fn load(&mut self, path: &Path) -> Result<(), EdmError>;
}
/// A single labelled example: a sparse feature map paired with its target emotion.
#[derive(Clone, Debug)]
pub struct TrainingSample {
    /// Input values keyed by `u32` feature key.
    pub features: HashMap<u32, f32>,
    /// The emotion label associated with `features`.
    pub emotion: EmotionState,
}
/// An in-memory, ordered collection of [`TrainingSample`]s.
#[derive(Clone, Debug)]
pub struct TrainingDataset {
    /// The samples, in the order they were supplied.
    pub samples: Vec<TrainingSample>,
}

impl TrainingDataset {
    /// Builds a dataset that takes ownership of `samples`, without copying or
    /// validating them.
    pub fn new(samples: Vec<TrainingSample>) -> Self {
        TrainingDataset { samples }
    }

    /// Borrows all samples as a slice.
    pub fn samples(&self) -> &[TrainingSample] {
        self.samples.as_slice()
    }
}
/// Metrics returned by [`EmotionDataModelTrainer::train`].
///
/// `Default` yields `0` / `0.0` for every field. It is derived, because the previous
/// hand-written impl produced exactly the derived values (clippy `derivable_impls`).
#[derive(Debug, Clone, Default)]
pub struct TrainingResult {
    /// Loss value reported when training finished.
    pub final_loss: f32,
    /// Number of epochs run.
    pub epochs: usize,
    /// `r` score for valence (presumably a correlation coefficient; confirm with trainers).
    pub valence_r: f32,
    /// `r` score for arousal (see `valence_r`).
    pub arousal_r: f32,
    /// `r` score for dominance (see `valence_r`).
    pub dominance_r: f32,
    /// Valence error, presumably root-mean-square error as the name suggests.
    pub valence_rmse: f32,
    /// Arousal error (see `valence_rmse`).
    pub arousal_rmse: f32,
    /// Dominance error (see `valence_rmse`).
    pub dominance_rmse: f32,
}
/// Trains an emotion model from a [`TrainingDataset`] and saves/restores its state.
///
/// The `Send + Sync` supertrait bound means implementors can be moved to, and shared
/// between, threads.
pub trait EmotionDataModelTrainer
where
    Self: Send + Sync,
{
    /// Trains on `dataset`, mutating `self`, and reports the resulting metrics.
    ///
    /// # Errors
    /// Returns an [`EdmError`] if training fails.
    fn train(&mut self, dataset: &TrainingDataset) -> Result<TrainingResult, EdmError>;

    /// Persists state to `path`. The format is implementation-defined.
    ///
    /// # Errors
    /// Returns an [`EdmError`] if saving fails.
    fn save(&self, path: &Path) -> Result<(), EdmError>;

    /// Restores state from `path`.
    ///
    /// NOTE: `EmotionDataModel` also declares `load(&mut self, &Path)`. When both traits
    /// are in scope, a type implementing both cannot use method syntax (`x.load(p)` is
    /// ambiguous). Call `EmotionDataModelTrainer::load(&mut x, p)` instead.
    ///
    /// # Errors
    /// Returns an [`EdmError`] if loading fails.
    fn load(&mut self, path: &Path) -> Result<(), EdmError>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing `f32` components.
    const EPS: f32 = 1e-6;

    /// True when `a` and `b` differ by less than [`EPS`].
    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < EPS
    }

    #[test]
    fn test_emotion_state_new() {
        let state = EmotionState::new(0.5, 0.7, 0.3);
        assert!(approx_eq(state.valence, 0.5));
        assert!(approx_eq(state.arousal, 0.7));
        assert!(approx_eq(state.dominance, 0.3));
    }

    #[test]
    fn test_emotion_state_clamp() {
        let state = EmotionState::new(-0.5, 1.5, 2.0);
        assert!(approx_eq(state.valence, 0.0));
        assert!(approx_eq(state.arousal, 1.0));
        assert!(approx_eq(state.dominance, 1.0));
    }

    #[test]
    fn test_emotion_state_to_vec() {
        let state = EmotionState::new(0.5, 0.7, 0.3);
        let vec = state.to_vec();
        assert_eq!(vec.len(), 3);
        assert!(approx_eq(vec[0], 0.5));
        assert!(approx_eq(vec[1], 0.7));
        assert!(approx_eq(vec[2], 0.3));
    }

    #[test]
    fn test_training_dataset() {
        let mut features = HashMap::new();
        features.insert(0, 0.5);
        // `features` is moved in; the former `.clone()` was never needed.
        let sample = TrainingSample {
            features,
            emotion: EmotionState::new(0.5, 0.5, 0.5),
        };
        // Go through the public constructor and getter so both are exercised.
        let dataset = TrainingDataset::new(vec![sample]);
        assert_eq!(dataset.samples().len(), 1);
        assert_eq!(dataset.samples.len(), 1);
        assert!(approx_eq(dataset.samples()[0].features[&0], 0.5));
        assert!(approx_eq(dataset.samples()[0].emotion.arousal, 0.5));
    }

    #[test]
    fn test_training_result_default() {
        let result = TrainingResult::default();
        assert_eq!(result.epochs, 0);
        assert!(approx_eq(result.final_loss, 0.0));
        assert!(approx_eq(result.valence_r, 0.0));
        assert!(approx_eq(result.arousal_r, 0.0));
        assert!(approx_eq(result.dominance_r, 0.0));
        assert!(approx_eq(result.valence_rmse, 0.0));
        assert!(approx_eq(result.arousal_rmse, 0.0));
        assert!(approx_eq(result.dominance_rmse, 0.0));
    }
}