use std::collections::HashMap;
/// Discrete emotion labels used to summarise a point in PAD
/// (valence, arousal, dominance) space.
///
/// Eight variants own one octant centroid each in `OCTANT_CENTROIDS`.
/// `Disgust` and `Neutral` have no centroid: they are classified by the
/// predicate methods below but always receive a membership of `0.0` from
/// [`Emotion::membership_from_pad`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Emotion {
    /// Positive valence, high arousal, high dominance.
    Exuberant,
    /// Positive valence, high arousal, low dominance.
    Dependent,
    /// Positive valence, low arousal, high dominance.
    Relaxed,
    /// Positive valence, low arousal, low dominance.
    Docile,
    /// Negative valence, high arousal, high dominance.
    Hostile,
    /// Classified as negative, high arousal and high dominance; has no
    /// octant centroid of its own.
    Disgust,
    /// Negative valence, high arousal, low dominance.
    Anxious,
    /// Negative valence, low arousal, high dominance.
    Bored,
    /// Negative valence, low arousal, low dominance.
    Depressed,
    /// Neither positive nor negative; has no octant centroid.
    Neutral,
}

impl Emotion {
    /// Returns every variant exactly once, in declaration order.
    #[must_use]
    pub const fn all() -> [Emotion; 10] {
        [
            Emotion::Exuberant,
            Emotion::Dependent,
            Emotion::Relaxed,
            Emotion::Docile,
            Emotion::Hostile,
            Emotion::Disgust,
            Emotion::Anxious,
            Emotion::Bored,
            Emotion::Depressed,
            Emotion::Neutral,
        ]
    }

    /// Returns the variant's name as a static string (identical to the
    /// identifier used in the enum declaration).
    #[must_use]
    pub const fn name(&self) -> &'static str {
        match self {
            Emotion::Exuberant => "Exuberant",
            Emotion::Dependent => "Dependent",
            Emotion::Relaxed => "Relaxed",
            Emotion::Docile => "Docile",
            Emotion::Hostile => "Hostile",
            Emotion::Disgust => "Disgust",
            Emotion::Anxious => "Anxious",
            Emotion::Bored => "Bored",
            Emotion::Depressed => "Depressed",
            Emotion::Neutral => "Neutral",
        }
    }

    /// Returns `true` for `Exuberant`, `Dependent`, `Relaxed` and `Docile`.
    #[must_use]
    pub const fn is_positive(&self) -> bool {
        matches!(
            self,
            Emotion::Exuberant | Emotion::Dependent | Emotion::Relaxed | Emotion::Docile
        )
    }

    /// Returns `true` for `Hostile`, `Disgust`, `Anxious`, `Bored` and
    /// `Depressed`.
    #[must_use]
    pub const fn is_negative(&self) -> bool {
        matches!(
            self,
            Emotion::Hostile
                | Emotion::Disgust
                | Emotion::Anxious
                | Emotion::Bored
                | Emotion::Depressed
        )
    }

    /// Returns `true` only for `Neutral`.
    #[must_use]
    pub const fn is_neutral(&self) -> bool {
        matches!(self, Emotion::Neutral)
    }

    /// Returns `true` for `Exuberant`, `Dependent`, `Hostile`, `Disgust` and
    /// `Anxious`.
    #[must_use]
    pub const fn is_high_arousal(&self) -> bool {
        matches!(
            self,
            Emotion::Exuberant
                | Emotion::Dependent
                | Emotion::Hostile
                | Emotion::Disgust
                | Emotion::Anxious
        )
    }

    /// Returns `true` for `Exuberant`, `Relaxed`, `Hostile`, `Disgust` and
    /// `Bored`.
    #[must_use]
    pub const fn is_high_dominance(&self) -> bool {
        matches!(
            self,
            Emotion::Exuberant
                | Emotion::Relaxed
                | Emotion::Hostile
                | Emotion::Disgust
                | Emotion::Bored
        )
    }

    /// Computes a soft membership distribution over emotions for a PAD point.
    ///
    /// Each emotion listed in `OCTANT_CENTROIDS` is weighted by
    /// `exp(-3 * d²)`, where `d²` is the squared Euclidean distance between
    /// `(valence, arousal, dominance)` and that emotion's centroid; the
    /// weights are then normalised so they sum to `1.0`.
    ///
    /// The returned map contains an entry for every variant of [`Emotion`].
    /// `Disgust` and `Neutral` have no centroid and are always `0.0`.
    ///
    /// Inputs are not clamped. Any finite point, however far from the
    /// centroids, yields a finite distribution. Non-finite inputs (NaN or
    /// infinity) yield NaN for the eight centroid-backed emotions.
    #[must_use]
    pub fn membership_from_pad(valence: f32, arousal: f32, dominance: f32) -> HashMap<Emotion, f64> {
        const DISTANCE_SHARPNESS: f64 = 3.0;

        // `f64::from` is the lossless widening conversion; no `as` cast needed.
        let pad = (f64::from(valence), f64::from(arousal), f64::from(dominance));

        let distances: Vec<(Emotion, f64)> = OCTANT_CENTROIDS
            .iter()
            .map(|&(emotion, centroid)| (emotion, squared_distance(pad, centroid)))
            .collect();

        // Shift all exponents by the nearest centroid's distance. The common
        // factor cancels during normalisation, but the largest weight becomes
        // exactly exp(0) = 1, so the sum can never underflow to zero and the
        // division below cannot produce 0/0 = NaN for far-away finite points.
        let nearest = distances
            .iter()
            .map(|&(_, distance_sq)| distance_sq)
            .fold(f64::INFINITY, f64::min);

        let weights: Vec<(Emotion, f64)> = distances
            .into_iter()
            .map(|(emotion, distance_sq)| {
                (emotion, (-DISTANCE_SHARPNESS * (distance_sq - nearest)).exp())
            })
            .collect();
        let weight_sum: f64 = weights.iter().map(|&(_, weight)| weight).sum();

        let mut membership: HashMap<Emotion, f64> = weights
            .into_iter()
            .map(|(emotion, weight)| (emotion, weight / weight_sum))
            .collect();

        // Emotions without a centroid still get an explicit zero entry so
        // callers can index the map by any variant.
        for emotion in Emotion::all() {
            membership.entry(emotion).or_insert(0.0);
        }
        membership
    }
}

/// Centroid of each octant emotion in (valence, arousal, dominance) order.
///
/// Every coordinate is `±0.5`, giving one centroid per sign combination.
/// `Disgust` and `Neutral` are intentionally absent.
const OCTANT_CENTROIDS: [(Emotion, (f64, f64, f64)); 8] = [
    (Emotion::Exuberant, (0.5, 0.5, 0.5)),
    (Emotion::Dependent, (0.5, 0.5, -0.5)),
    (Emotion::Relaxed, (0.5, -0.5, 0.5)),
    (Emotion::Docile, (0.5, -0.5, -0.5)),
    (Emotion::Hostile, (-0.5, 0.5, 0.5)),
    (Emotion::Anxious, (-0.5, 0.5, -0.5)),
    (Emotion::Bored, (-0.5, -0.5, 0.5)),
    (Emotion::Depressed, (-0.5, -0.5, -0.5)),
];

/// Returns the squared Euclidean distance between two 3-component points.
fn squared_distance(
    pad: (f64, f64, f64),
    centroid: (f64, f64, f64),
) -> f64 {
    let dx = pad.0 - centroid.0;
    let dy = pad.1 - centroid.1;
    let dz = pad.2 - centroid.2;
    dx * dx + dy * dy + dz * dz
}
/// Displays an emotion as the string returned by its `name()` method.
impl std::fmt::Display for Emotion {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The name is already a plain `&str`, so write it straight through
        // instead of routing it via a nested format string.
        formatter.write_str(self.name())
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the emotion enum: variant listing, naming, the
    //! classification predicates, derived traits, and PAD membership.

    use super::*;

    #[test]
    fn emotion_all_variants() {
        assert_eq!(Emotion::all().len(), 10);
    }

    #[test]
    fn emotion_names() {
        let expected_names = [
            (Emotion::Exuberant, "Exuberant"),
            (Emotion::Dependent, "Dependent"),
            (Emotion::Relaxed, "Relaxed"),
            (Emotion::Docile, "Docile"),
            (Emotion::Hostile, "Hostile"),
            (Emotion::Disgust, "Disgust"),
            (Emotion::Anxious, "Anxious"),
            (Emotion::Bored, "Bored"),
            (Emotion::Depressed, "Depressed"),
            (Emotion::Neutral, "Neutral"),
        ];
        for (emotion, expected) in expected_names {
            assert_eq!(emotion.name(), expected);
        }
    }

    #[test]
    fn positive_emotions() {
        for emotion in [
            Emotion::Exuberant,
            Emotion::Dependent,
            Emotion::Relaxed,
            Emotion::Docile,
        ] {
            assert!(emotion.is_positive());
        }
        for emotion in [
            Emotion::Hostile,
            Emotion::Disgust,
            Emotion::Anxious,
            Emotion::Bored,
            Emotion::Depressed,
            Emotion::Neutral,
        ] {
            assert!(!emotion.is_positive());
        }
    }

    #[test]
    fn negative_emotions() {
        for emotion in [
            Emotion::Hostile,
            Emotion::Disgust,
            Emotion::Anxious,
            Emotion::Bored,
            Emotion::Depressed,
        ] {
            assert!(emotion.is_negative());
        }
        for emotion in [
            Emotion::Exuberant,
            Emotion::Dependent,
            Emotion::Relaxed,
            Emotion::Docile,
            Emotion::Neutral,
        ] {
            assert!(!emotion.is_negative());
        }
    }

    #[test]
    fn neutral_emotion() {
        assert!(Emotion::Neutral.is_neutral());
        for emotion in Emotion::all() {
            if emotion == Emotion::Neutral {
                continue;
            }
            assert!(!emotion.is_neutral());
        }
    }

    #[test]
    fn high_arousal_emotions() {
        for emotion in [
            Emotion::Exuberant,
            Emotion::Dependent,
            Emotion::Hostile,
            Emotion::Disgust,
            Emotion::Anxious,
        ] {
            assert!(emotion.is_high_arousal());
        }
        for emotion in [
            Emotion::Relaxed,
            Emotion::Docile,
            Emotion::Bored,
            Emotion::Depressed,
            Emotion::Neutral,
        ] {
            assert!(!emotion.is_high_arousal());
        }
    }

    #[test]
    fn high_dominance_emotions() {
        for emotion in [
            Emotion::Exuberant,
            Emotion::Relaxed,
            Emotion::Hostile,
            Emotion::Disgust,
            Emotion::Bored,
        ] {
            assert!(emotion.is_high_dominance());
        }
        for emotion in [
            Emotion::Dependent,
            Emotion::Docile,
            Emotion::Anxious,
            Emotion::Depressed,
            Emotion::Neutral,
        ] {
            assert!(!emotion.is_high_dominance());
        }
    }

    #[test]
    fn display_format() {
        assert_eq!(Emotion::Anxious.to_string(), "Anxious");
        assert_eq!(Emotion::Neutral.to_string(), "Neutral");
    }

    #[test]
    fn debug_format() {
        assert!(format!("{:?}", Emotion::Hostile).contains("Hostile"));
    }

    #[test]
    fn clone_and_copy() {
        let original = Emotion::Relaxed;
        // A plain assignment exercises `Copy`; the explicit call exercises `Clone`.
        let copied = original;
        let cloned = original.clone();
        assert_eq!(original, copied);
        assert_eq!(original, cloned);
    }

    #[test]
    fn hash() {
        use std::collections::HashSet;

        let mut seen = HashSet::new();
        // Inserting the same variant twice must leave a single element.
        for _ in 0..2 {
            seen.insert(Emotion::Anxious);
        }
        assert_eq!(seen.len(), 1);

        seen.insert(Emotion::Relaxed);
        assert_eq!(seen.len(), 2);
    }

    #[test]
    fn positive_and_negative_are_mutually_exclusive() {
        for emotion in Emotion::all() {
            let positive = emotion.is_positive();
            let negative = emotion.is_negative();
            let neutral = emotion.is_neutral();

            // At most one of the three categories may apply.
            let categories = [positive, negative, neutral];
            let matching = categories.iter().filter(|&&flag| flag).count();
            assert!(matching <= 1);

            // Every non-neutral emotion is exactly one of positive/negative.
            if emotion != Emotion::Neutral {
                assert_ne!(positive, negative);
            }
        }
    }

    #[test]
    fn membership_balanced_at_center() {
        let membership = Emotion::membership_from_pad(0.0, 0.0, 0.0);

        let mut octant_weights: Vec<f64> = Vec::new();
        for emotion in [
            Emotion::Exuberant,
            Emotion::Dependent,
            Emotion::Relaxed,
            Emotion::Docile,
            Emotion::Hostile,
            Emotion::Anxious,
            Emotion::Bored,
            Emotion::Depressed,
        ] {
            octant_weights.push(membership.get(&emotion).copied().unwrap_or(0.0));
        }

        // Every octant weight must match the first one.
        let (reference, rest) = octant_weights
            .split_first()
            .expect("the octant list is not empty");
        for weight in rest {
            assert!((weight - reference).abs() < 1e-6);
        }
    }

    #[test]
    fn membership_octant_is_dominant() {
        let membership = Emotion::membership_from_pad(1.0, 1.0, 1.0);
        let exuberant = match membership.get(&Emotion::Exuberant) {
            Some(weight) => *weight,
            None => 0.0,
        };
        assert!(exuberant > 0.8);
    }

    #[test]
    fn membership_distribution_sums_to_one() {
        let membership = Emotion::membership_from_pad(0.2, -0.4, 0.1);
        let mut total: f64 = 0.0;
        for weight in membership.values() {
            total += *weight;
        }
        assert!((total - 1.0).abs() < 1e-6);
    }

    #[test]
    fn membership_includes_neutral_and_disgust() {
        let membership = Emotion::membership_from_pad(-0.2, 0.7, -0.3);
        // A missing key falls back to 1.0 so that absence fails the check.
        for emotion in [Emotion::Neutral, Emotion::Disgust] {
            assert_eq!(membership.get(&emotion).copied().unwrap_or(1.0), 0.0);
        }
    }
}