Skip to main content

ainl_persona/
axes.rs

//! Named persona axes — soft spectra, not discrete classes.
2
3use chrono::{DateTime, Utc};
4use std::collections::HashMap;
5
/// Smoothing factor (α) of the exponential moving average used for axis scores.
///
/// Each update computes `α * sample + (1 - α) * previous`, so a new sample
/// contributes 20% and the accumulated score 80%.
pub const EMA_ALPHA: f32 = 0.2;
8
/// One named dimension of a persona.
///
/// Axes are soft spectra rather than discrete classes: the enum only names the
/// dimension, while the position along it is tracked separately as a score.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum PersonaAxis {
    Instrumentality,
    Verbosity,
    Persistence,
    Systematicity,
    Curiosity,
}

impl PersonaAxis {
    /// Every axis, in declaration order.
    pub const ALL: [PersonaAxis; 5] = [
        PersonaAxis::Instrumentality,
        PersonaAxis::Verbosity,
        PersonaAxis::Persistence,
        PersonaAxis::Systematicity,
        PersonaAxis::Curiosity,
    ];

    /// Returns the canonical, capitalised name of the axis.
    pub const fn name(self) -> &'static str {
        match self {
            PersonaAxis::Instrumentality => "Instrumentality",
            PersonaAxis::Verbosity => "Verbosity",
            PersonaAxis::Persistence => "Persistence",
            PersonaAxis::Systematicity => "Systematicity",
            PersonaAxis::Curiosity => "Curiosity",
        }
    }

    /// Parses an axis from its name.
    ///
    /// Surrounding whitespace is ignored and the comparison is ASCII
    /// case-insensitive, so `"Verbosity"`, `"verbosity"` and `"VERBOSITY"` all
    /// resolve to [`PersonaAxis::Verbosity`]. Returns `None` when the trimmed
    /// input is not the name of any axis.
    pub fn parse(s: &str) -> Option<Self> {
        let wanted = s.trim();
        // Matching against `name()` keeps the accepted spellings tied to the
        // canonical names instead of a second, hand-maintained literal list.
        Self::ALL
            .iter()
            .copied()
            .find(|axis| wanted.eq_ignore_ascii_case(axis.name()))
    }
}
48
49#[derive(Debug, Clone, PartialEq)]
50pub struct AxisState {
51    pub axis: PersonaAxis,
52    /// EMA score in \[0, 1\].
53    pub score: f32,
54    pub sample_count: u32,
55    pub last_updated: DateTime<Utc>,
56}
57
58impl AxisState {
59    pub fn new(axis: PersonaAxis, initial_score: f32) -> Self {
60        Self {
61            axis,
62            score: initial_score.clamp(0.0, 1.0),
63            sample_count: 0,
64            last_updated: Utc::now(),
65        }
66    }
67
68    /// Plain EMA toward `reward` (no per-signal weighting).
69    pub fn update_score(&mut self, reward: f32) {
70        let r = reward.clamp(0.0, 1.0);
71        self.score = (EMA_ALPHA * r + (1.0 - EMA_ALPHA) * self.score).clamp(0.0, 1.0);
72        self.sample_count = self.sample_count.saturating_add(1);
73        self.last_updated = Utc::now();
74    }
75
76    /// Weighted EMA: effective target is `reward * weight` (clamped to \[0,1\]).
77    pub fn update_weighted(&mut self, reward: f32, weight: f32) {
78        let w = weight.clamp(0.0, 1.0);
79        let r = reward.clamp(0.0, 1.0);
80        let target = (r * w).clamp(0.0, 1.0);
81        self.score = (EMA_ALPHA * target + (1.0 - EMA_ALPHA) * self.score).clamp(0.0, 1.0);
82        self.sample_count = self.sample_count.saturating_add(1);
83        self.last_updated = Utc::now();
84    }
85}
86
87pub fn default_axis_map(initial: f32) -> HashMap<PersonaAxis, AxisState> {
88    PersonaAxis::ALL
89        .iter()
90        .copied()
91        .map(|a| (a, AxisState::new(a, initial)))
92        .collect()
93}