1use chrono::{DateTime, Utc};
4use std::collections::HashMap;
5
/// Smoothing factor for the exponential moving average applied to axis
/// scores: each new sample contributes `EMA_ALPHA` of the updated value and
/// the previous score contributes the remaining `1.0 - EMA_ALPHA`.
pub const EMA_ALPHA: f32 = 0.2_f32;
8
/// A named axis along which a persona is scored.
///
/// The derived `PartialOrd`/`Ord` follow declaration order, which is also the
/// order of [`PersonaAxis::ALL`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum PersonaAxis {
    Instrumentality,
    Verbosity,
    Persistence,
    Systematicity,
    Curiosity,
}

impl PersonaAxis {
    /// Every axis, in declaration order.
    pub const ALL: [PersonaAxis; 5] = [
        PersonaAxis::Instrumentality,
        PersonaAxis::Verbosity,
        PersonaAxis::Persistence,
        PersonaAxis::Systematicity,
        PersonaAxis::Curiosity,
    ];

    /// Canonical (Title-case) name of the axis.
    pub const fn name(self) -> &'static str {
        match self {
            PersonaAxis::Instrumentality => "Instrumentality",
            PersonaAxis::Verbosity => "Verbosity",
            PersonaAxis::Persistence => "Persistence",
            PersonaAxis::Systematicity => "Systematicity",
            PersonaAxis::Curiosity => "Curiosity",
        }
    }

    /// Parses an axis from its name, ignoring surrounding whitespace and
    /// ASCII case (`"Verbosity"`, `"verbosity"` and `"VERBOSITY"` all match).
    ///
    /// Returns `None` when the trimmed input is not the name of any axis.
    pub fn parse(s: &str) -> Option<Self> {
        let s = s.trim();
        // Matching against `name()` keeps parsing and display in sync: adding a
        // variant to `ALL` and `name()` is enough to make it parseable.
        Self::ALL
            .iter()
            .copied()
            .find(|axis| axis.name().eq_ignore_ascii_case(s))
    }
}
48
49#[derive(Debug, Clone, PartialEq)]
50pub struct AxisState {
51 pub axis: PersonaAxis,
52 pub score: f32,
54 pub sample_count: u32,
55 pub last_updated: DateTime<Utc>,
56}
57
58impl AxisState {
59 pub fn new(axis: PersonaAxis, initial_score: f32) -> Self {
60 Self {
61 axis,
62 score: initial_score.clamp(0.0, 1.0),
63 sample_count: 0,
64 last_updated: Utc::now(),
65 }
66 }
67
68 pub fn update_score(&mut self, reward: f32) {
70 let r = reward.clamp(0.0, 1.0);
71 self.score = (EMA_ALPHA * r + (1.0 - EMA_ALPHA) * self.score).clamp(0.0, 1.0);
72 self.sample_count = self.sample_count.saturating_add(1);
73 self.last_updated = Utc::now();
74 }
75
76 pub fn update_weighted(&mut self, reward: f32, weight: f32) {
78 let w = weight.clamp(0.0, 1.0);
79 let r = reward.clamp(0.0, 1.0);
80 let target = (r * w).clamp(0.0, 1.0);
81 self.score = (EMA_ALPHA * target + (1.0 - EMA_ALPHA) * self.score).clamp(0.0, 1.0);
82 self.sample_count = self.sample_count.saturating_add(1);
83 self.last_updated = Utc::now();
84 }
85}
86
87pub fn default_axis_map(initial: f32) -> HashMap<PersonaAxis, AxisState> {
88 PersonaAxis::ALL
89 .iter()
90 .copied()
91 .map(|a| (a, AxisState::new(a, initial)))
92 .collect()
93}