1use std::collections::{HashMap, HashSet};
2use crate::state::ProfileState;
3use crate::error::AriaError;
4
/// Owned, plain-data copy of a `ProfileState`.
///
/// Every field mirrors the `ProfileState` field of the same name. The one
/// exception is `resolved_set`: it is stored here as a `Vec<String>`, whereas
/// `ProfileState` holds a `HashSet<String>`.
#[derive(Debug, Clone)]
pub struct StateSnapshot {
    /// Mirror of `ProfileState::skill`.
    pub skill: f32,
    /// Mirror of `ProfileState::optimism_bias`.
    pub optimism_bias: f32,
    /// Mirror of `ProfileState::last_seen` (id -> `u64` value).
    pub last_seen: HashMap<String, u64>,
    /// Mirror of `ProfileState::category_count` (category -> count).
    pub category_count: HashMap<String, u32>,
    /// Resolved ids, flattened from the `HashSet` held by `ProfileState`.
    pub resolved_set: Vec<String>,
    /// Mirror of `ProfileState::interaction_count`.
    pub interaction_count: u64,
    /// Mirror of `ProfileState::extended` (numeric extension values).
    pub extended: HashMap<String, f32>,
    /// Mirror of `ProfileState::extended_str` (string extension values).
    pub extended_str: HashMap<String, String>,
}
19
20impl From<&ProfileState> for StateSnapshot {
21 fn from(s: &ProfileState) -> Self {
22 Self {
23 skill: s.skill,
24 optimism_bias: s.optimism_bias,
25 last_seen: s.last_seen.clone(),
26 category_count: s.category_count.clone(),
27 resolved_set: s.resolved_set.iter().cloned().collect(),
28 interaction_count: s.interaction_count,
29 extended: s.extended.clone(),
30 extended_str: s.extended_str.clone(),
31 }
32 }
33}
34
35impl From<StateSnapshot> for ProfileState {
36 fn from(snap: StateSnapshot) -> Self {
37 ProfileState {
38 skill: snap.skill,
39 optimism_bias: snap.optimism_bias,
40 last_seen: snap.last_seen,
41 category_count: snap.category_count,
42 resolved_set: snap.resolved_set.into_iter().collect::<HashSet<String>>(),
43 interaction_count: snap.interaction_count,
44 extended: snap.extended,
45 extended_str: snap.extended_str,
46 }
47 }
48}
49
50pub struct Serialiser;
53
54impl Serialiser {
55 pub fn encode(state: &ProfileState) -> HashMap<String, String> {
58 let mut map = HashMap::new();
59
60 map.insert("skill".into(), state.skill.to_string());
61 map.insert("optimism_bias".into(), state.optimism_bias.to_string());
62 map.insert("interaction_count".into(), state.interaction_count.to_string());
63
64 for (id, ts) in &state.last_seen {
66 map.insert(format!("last_seen:{id}"), ts.to_string());
67 }
68
69 for (cat, count) in &state.category_count {
71 map.insert(format!("category_count:{cat}"), count.to_string());
72 }
73
74 let resolved: Vec<&str> = state.resolved_set.iter().map(|s| s.as_str()).collect();
76 map.insert("resolved_set".into(), resolved.join(","));
77
78 for (k, v) in &state.extended {
80 map.insert(format!("ext:{k}"), v.to_string());
81 }
82
83 for (k, v) in &state.extended_str {
85 map.insert(format!("ext_str:{k}"), v.clone());
86 }
87
88 map
89 }
90
91 pub fn decode(map: &HashMap<String, String>) -> Result<ProfileState, AriaError> {
93 let mut state = ProfileState::new();
94
95 state.skill = map
96 .get("skill")
97 .and_then(|v| v.parse().ok())
98 .unwrap_or(0.0);
99
100 state.optimism_bias = map
101 .get("optimism_bias")
102 .and_then(|v| v.parse().ok())
103 .unwrap_or(0.1);
104
105 state.interaction_count = map
106 .get("interaction_count")
107 .and_then(|v| v.parse().ok())
108 .unwrap_or(0);
109
110 if let Some(resolved_str) = map.get("resolved_set") {
111 if !resolved_str.is_empty() {
112 for id in resolved_str.split(',') {
113 state.resolved_set.insert(id.to_string());
114 }
115 }
116 }
117
118 for (k, v) in map {
119 if let Some(id) = k.strip_prefix("last_seen:") {
120 if let Ok(ts) = v.parse::<u64>() {
121 state.last_seen.insert(id.to_string(), ts);
122 }
123 } else if let Some(cat) = k.strip_prefix("category_count:") {
124 if let Ok(count) = v.parse::<u32>() {
125 state.category_count.insert(cat.to_string(), count);
126 }
127 } else if let Some(key) = k.strip_prefix("ext:") {
128 if let Ok(val) = v.parse::<f32>() {
129 state.extended.insert(key.to_string(), val);
130 }
131 } else if let Some(key) = k.strip_prefix("ext_str:") {
132 state.extended_str.insert(key.to_string(), v.clone());
133 }
134 }
135
136 Ok(state)
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::ProfileState;

    /// Absolute tolerance for `f32` values that pass through their text form.
    const EPS: f32 = 1e-5;

    /// True when `a` and `b` differ by less than [`EPS`].
    fn close(a: f32, b: f32) -> bool {
        (a - b).abs() < EPS
    }

    /// Encoding a populated state and decoding the result must reproduce
    /// every field that was set.
    #[test]
    fn round_trip() {
        let original = {
            let mut s = ProfileState::new();
            s.skill = 0.42;
            s.optimism_bias = 0.15;
            s.interaction_count = 7;
            s.last_seen.insert("item1".into(), 123456);
            s.category_count.insert("math".into(), 3);
            s.resolved_set.insert("item1".into());
            s.extended.insert("custom_score".into(), 0.77);
            s.extended_str.insert("mode".into(), "practice".into());
            s
        };

        let flat = Serialiser::encode(&original);
        let restored = Serialiser::decode(&flat)
            .expect("decoding freshly encoded state must succeed");

        // Scalars.
        assert!(close(restored.skill, original.skill));
        assert!(close(restored.optimism_bias, original.optimism_bias));
        assert_eq!(restored.interaction_count, original.interaction_count);

        // Keyed collections.
        assert_eq!(restored.last_seen["item1"], 123456);
        assert_eq!(restored.category_count["math"], 3);
        assert!(restored.resolved_set.contains("item1"));

        // Extension maps.
        assert!(close(restored.extended["custom_score"], 0.77));
        assert_eq!(restored.extended_str["mode"], "practice");
    }
}