Skip to main content

aria_core/
serialiser.rs

1use std::collections::{HashMap, HashSet};
2use crate::state::ProfileState;
3use crate::error::AriaError;
4
/// Serialisable flat representation of [`ProfileState`].
///
/// Callers persist this to any store (JSON, DB, cookie, etc.) with their own
/// serialiser. No serde dependency is required — pure stdlib.
///
/// `PartialEq` is derived so snapshots can be compared directly (e.g. in tests,
/// or to skip a redundant write). Because `resolved_set` is a `Vec`, two
/// snapshots holding the same ids in a different order compare unequal, and
/// float fields follow IEEE rules (`NaN != NaN`).
#[derive(Debug, Clone, PartialEq)]
pub struct StateSnapshot {
    /// Mirrors `ProfileState::skill`.
    pub skill: f32,
    /// Mirrors `ProfileState::optimism_bias`.
    pub optimism_bias: f32,
    /// Item id → last-seen timestamp.
    pub last_seen: HashMap<String, u64>,
    /// Category name → count.
    pub category_count: HashMap<String, u32>,
    /// Resolved item ids; a `Vec` rather than a set for easier JSON serialisation.
    pub resolved_set: Vec<String>,
    /// Mirrors `ProfileState::interaction_count`.
    pub interaction_count: u64,
    /// Named `f32` extension values.
    pub extended: HashMap<String, f32>,
    /// Named string extension values.
    pub extended_str: HashMap<String, String>,
}
19
20impl From<&ProfileState> for StateSnapshot {
21    fn from(s: &ProfileState) -> Self {
22        Self {
23            skill: s.skill,
24            optimism_bias: s.optimism_bias,
25            last_seen: s.last_seen.clone(),
26            category_count: s.category_count.clone(),
27            resolved_set: s.resolved_set.iter().cloned().collect(),
28            interaction_count: s.interaction_count,
29            extended: s.extended.clone(),
30            extended_str: s.extended_str.clone(),
31        }
32    }
33}
34
35impl From<StateSnapshot> for ProfileState {
36    fn from(snap: StateSnapshot) -> Self {
37        ProfileState {
38            skill: snap.skill,
39            optimism_bias: snap.optimism_bias,
40            last_seen: snap.last_seen,
41            category_count: snap.category_count,
42            resolved_set: snap.resolved_set.into_iter().collect::<HashSet<String>>(),
43            interaction_count: snap.interaction_count,
44            extended: snap.extended,
45            extended_str: snap.extended_str,
46        }
47    }
48}
49
50/// Serialiser — converts ProfileState ↔ key-value string map.
51/// Portable to any storage layer with no external deps.
52pub struct Serialiser;
53
54impl Serialiser {
55    /// Encode ProfileState to flat string key-value map.
56    /// Callers can JSON-encode this map with their own serialiser.
57    pub fn encode(state: &ProfileState) -> HashMap<String, String> {
58        let mut map = HashMap::new();
59
60        map.insert("skill".into(), state.skill.to_string());
61        map.insert("optimism_bias".into(), state.optimism_bias.to_string());
62        map.insert("interaction_count".into(), state.interaction_count.to_string());
63
64        // last_seen: "last_seen:item_id" → timestamp
65        for (id, ts) in &state.last_seen {
66            map.insert(format!("last_seen:{id}"), ts.to_string());
67        }
68
69        // category_count: "category_count:cat" → count
70        for (cat, count) in &state.category_count {
71            map.insert(format!("category_count:{cat}"), count.to_string());
72        }
73
74        // resolved_set: comma-joined
75        let resolved: Vec<&str> = state.resolved_set.iter().map(|s| s.as_str()).collect();
76        map.insert("resolved_set".into(), resolved.join(","));
77
78        // extended floats
79        for (k, v) in &state.extended {
80            map.insert(format!("ext:{k}"), v.to_string());
81        }
82
83        // extended strings
84        for (k, v) in &state.extended_str {
85            map.insert(format!("ext_str:{k}"), v.clone());
86        }
87
88        map
89    }
90
91    /// Decode flat string map back to ProfileState.
92    pub fn decode(map: &HashMap<String, String>) -> Result<ProfileState, AriaError> {
93        let mut state = ProfileState::new();
94
95        state.skill = map
96            .get("skill")
97            .and_then(|v| v.parse().ok())
98            .unwrap_or(0.0);
99
100        state.optimism_bias = map
101            .get("optimism_bias")
102            .and_then(|v| v.parse().ok())
103            .unwrap_or(0.1);
104
105        state.interaction_count = map
106            .get("interaction_count")
107            .and_then(|v| v.parse().ok())
108            .unwrap_or(0);
109
110        if let Some(resolved_str) = map.get("resolved_set") {
111            if !resolved_str.is_empty() {
112                for id in resolved_str.split(',') {
113                    state.resolved_set.insert(id.to_string());
114                }
115            }
116        }
117
118        for (k, v) in map {
119            if let Some(id) = k.strip_prefix("last_seen:") {
120                if let Ok(ts) = v.parse::<u64>() {
121                    state.last_seen.insert(id.to_string(), ts);
122                }
123            } else if let Some(cat) = k.strip_prefix("category_count:") {
124                if let Ok(count) = v.parse::<u32>() {
125                    state.category_count.insert(cat.to_string(), count);
126                }
127            } else if let Some(key) = k.strip_prefix("ext:") {
128                if let Ok(val) = v.parse::<f32>() {
129                    state.extended.insert(key.to_string(), val);
130                }
131            } else if let Some(key) = k.strip_prefix("ext_str:") {
132                state.extended_str.insert(key.to_string(), v.clone());
133            }
134        }
135
136        Ok(state)
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::ProfileState;

    /// Builds a state with every field populated, shared by the round-trip tests.
    fn sample_state() -> ProfileState {
        let mut state = ProfileState::new();
        state.skill = 0.42;
        state.optimism_bias = 0.15;
        state.interaction_count = 7;
        state.last_seen.insert("item1".into(), 123456);
        state.category_count.insert("math".into(), 3);
        state.resolved_set.insert("item1".into());
        state.extended.insert("custom_score".into(), 0.77);
        state.extended_str.insert("mode".into(), "practice".into());
        state
    }

    #[test]
    fn round_trip() {
        let state = sample_state();

        let encoded = Serialiser::encode(&state);
        let decoded = Serialiser::decode(&encoded).unwrap();

        assert!((decoded.skill - state.skill).abs() < 1e-5);
        assert!((decoded.optimism_bias - state.optimism_bias).abs() < 1e-5);
        assert_eq!(decoded.interaction_count, state.interaction_count);
        assert_eq!(decoded.last_seen["item1"], 123456);
        assert_eq!(decoded.category_count["math"], 3);
        assert!(decoded.resolved_set.contains("item1"));
        assert!((decoded.extended["custom_score"] - 0.77).abs() < 1e-5);
        assert_eq!(decoded.extended_str["mode"], "practice");
    }

    #[test]
    fn round_trip_multiple_resolved_ids() {
        let mut state = ProfileState::new();
        for id in ["a", "b", "c"].iter() {
            state.resolved_set.insert(id.to_string());
        }

        let decoded = Serialiser::decode(&Serialiser::encode(&state)).unwrap();

        assert_eq!(decoded.resolved_set, state.resolved_set);
    }

    #[test]
    fn decode_empty_map_uses_scalar_defaults() {
        let decoded = Serialiser::decode(&HashMap::new()).unwrap();

        assert!(decoded.skill.abs() < 1e-6);
        assert!((decoded.optimism_bias - 0.1).abs() < 1e-6);
        assert_eq!(decoded.interaction_count, 0);
    }

    #[test]
    fn decode_skips_malformed_values() {
        let mut map = HashMap::new();
        map.insert("skill".to_string(), "not-a-number".to_string());
        map.insert("last_seen:item1".to_string(), "oops".to_string());
        map.insert("last_seen:item2".to_string(), "42".to_string());

        let decoded = Serialiser::decode(&map).unwrap();

        assert!(decoded.skill.abs() < 1e-6);
        assert!(!decoded.last_seen.contains_key("item1"));
        assert_eq!(decoded.last_seen["item2"], 42);
    }

    #[test]
    fn snapshot_round_trip() {
        let state = sample_state();

        let snap = StateSnapshot::from(&state);
        assert_eq!(snap.resolved_set, vec!["item1".to_string()]);

        let restored = ProfileState::from(snap);
        assert!((restored.skill - state.skill).abs() < 1e-6);
        assert!((restored.optimism_bias - state.optimism_bias).abs() < 1e-6);
        assert_eq!(restored.interaction_count, state.interaction_count);
        assert_eq!(restored.last_seen, state.last_seen);
        assert_eq!(restored.category_count, state.category_count);
        assert_eq!(restored.resolved_set, state.resolved_set);
        assert_eq!(restored.extended, state.extended);
        assert_eq!(restored.extended_str, state.extended_str);
    }
}