use std::collections::{HashMap, HashSet};
use crate::state::ProfileState;
use crate::error::AriaError;
/// Owned, plain-data copy of a [`ProfileState`].
///
/// Holds exactly the same fields as `ProfileState`. The one difference is that
/// `resolved_set` is kept as a `Vec<String>` here. Converting the snapshot back
/// into a `ProfileState` collects that vector into a set, so neither its order
/// nor any duplicates carry meaning.
#[derive(Debug, Clone)]
pub struct StateSnapshot {
    /// Copied from `ProfileState::skill`.
    pub skill: f32,
    /// Copied from `ProfileState::optimism_bias`.
    pub optimism_bias: f32,
    /// Per-id `u64` values. The serialiser calls them `ts`, so they are
    /// presumably timestamps — confirm against the producer.
    pub last_seen: HashMap<String, u64>,
    /// Per-category counters.
    pub category_count: HashMap<String, u32>,
    /// Resolved ids, flattened out of the live state's set.
    pub resolved_set: Vec<String>,
    /// Copied from `ProfileState::interaction_count`.
    pub interaction_count: u64,
    /// Free-form numeric extension values keyed by name.
    pub extended: HashMap<String, f32>,
    /// Free-form string extension values keyed by name.
    pub extended_str: HashMap<String, String>,
}
impl From<&ProfileState> for StateSnapshot {
    /// Copies every field of `s` into an owned snapshot, cloning each collection.
    ///
    /// On the live state `resolved_set` is a `HashSet`, whose iteration order is
    /// unspecified. The ids are sorted here so that snapshots of equal states are
    /// laid out identically (for example, their `Debug` output matches).
    fn from(s: &ProfileState) -> Self {
        let mut resolved_set: Vec<String> = s.resolved_set.iter().cloned().collect();
        resolved_set.sort_unstable();
        Self {
            skill: s.skill,
            optimism_bias: s.optimism_bias,
            last_seen: s.last_seen.clone(),
            category_count: s.category_count.clone(),
            resolved_set,
            interaction_count: s.interaction_count,
            extended: s.extended.clone(),
            extended_str: s.extended_str.clone(),
        }
    }
}
impl From<StateSnapshot> for ProfileState {
fn from(snap: StateSnapshot) -> Self {
ProfileState {
skill: snap.skill,
optimism_bias: snap.optimism_bias,
last_seen: snap.last_seen,
category_count: snap.category_count,
resolved_set: snap.resolved_set.into_iter().collect::<HashSet<String>>(),
interaction_count: snap.interaction_count,
extended: snap.extended,
extended_str: snap.extended_str,
}
}
}
/// Flattens a [`ProfileState`] into a `String -> String` map and back.
///
/// Key layout written by [`Serialiser::encode`] and read by [`Serialiser::decode`]:
///
/// * `skill`, `optimism_bias`: `f32`, rendered with `Display`.
/// * `interaction_count`: `u64`.
/// * `last_seen:{id}`: one entry per id, `u64` value.
/// * `category_count:{cat}`: one entry per category, `u32` value.
/// * `resolved_set`: all resolved ids in a single comma-separated string.
///   - Inside each id, `\` is written as `\\` and `,` as `\,`, so ids that
///     contain commas survive a round trip.
///   - Ids are sorted, so the same set always encodes to the same string.
/// * `ext:{key}`: one entry per numeric extension value (`f32`).
/// * `ext_str:{key}`: one entry per string extension value, stored verbatim.
///
/// Known limitation: a `resolved_set` whose only member is the empty-string id
/// encodes to `""`, which decodes as an empty set.
pub struct Serialiser;

impl Serialiser {
    /// Encodes `state` into the flat key/value layout described on [`Serialiser`].
    pub fn encode(state: &ProfileState) -> HashMap<String, String> {
        // Four fixed keys, plus one key per entry of each keyed collection.
        let mut map = HashMap::with_capacity(
            4 + state.last_seen.len()
                + state.category_count.len()
                + state.extended.len()
                + state.extended_str.len(),
        );
        map.insert("skill".into(), state.skill.to_string());
        map.insert("optimism_bias".into(), state.optimism_bias.to_string());
        map.insert("interaction_count".into(), state.interaction_count.to_string());
        for (id, ts) in &state.last_seen {
            map.insert(format!("last_seen:{id}"), ts.to_string());
        }
        for (cat, count) in &state.category_count {
            map.insert(format!("category_count:{cat}"), count.to_string());
        }
        map.insert(
            "resolved_set".into(),
            Self::join_ids(state.resolved_set.iter().map(String::as_str)),
        );
        for (k, v) in &state.extended {
            map.insert(format!("ext:{k}"), v.to_string());
        }
        for (k, v) in &state.extended_str {
            map.insert(format!("ext_str:{k}"), v.clone());
        }
        map
    }

    /// Rebuilds a [`ProfileState`] from a map in the layout written by
    /// [`Serialiser::encode`].
    ///
    /// Decoding is lenient and starts from `ProfileState::new()`:
    ///
    /// * A missing or unparsable `skill` becomes `0.0`, `optimism_bias`
    ///   becomes `0.1`, and `interaction_count` becomes `0`.
    /// * `last_seen:`, `category_count:` and `ext:` entries whose value does
    ///   not parse are skipped.
    /// * Unrecognised keys are ignored.
    ///
    /// # Errors
    ///
    /// The current implementation never returns `Err`.
    pub fn decode(map: &HashMap<String, String>) -> Result<ProfileState, AriaError> {
        let mut state = ProfileState::new();
        state.skill = map
            .get("skill")
            .and_then(|v| v.parse().ok())
            .unwrap_or(0.0);
        state.optimism_bias = map
            .get("optimism_bias")
            .and_then(|v| v.parse().ok())
            .unwrap_or(0.1);
        state.interaction_count = map
            .get("interaction_count")
            .and_then(|v| v.parse().ok())
            .unwrap_or(0);
        if let Some(joined) = map.get("resolved_set") {
            state.resolved_set.extend(Self::split_ids(joined));
        }
        for (k, v) in map {
            // The prefixes are mutually exclusive. "ext_str:" does not start
            // with "ext:" because the two differ at the fourth character.
            if let Some(id) = k.strip_prefix("last_seen:") {
                if let Ok(ts) = v.parse::<u64>() {
                    state.last_seen.insert(id.to_string(), ts);
                }
            } else if let Some(cat) = k.strip_prefix("category_count:") {
                if let Ok(count) = v.parse::<u32>() {
                    state.category_count.insert(cat.to_string(), count);
                }
            } else if let Some(key) = k.strip_prefix("ext:") {
                if let Ok(val) = v.parse::<f32>() {
                    state.extended.insert(key.to_string(), val);
                }
            } else if let Some(key) = k.strip_prefix("ext_str:") {
                state.extended_str.insert(key.to_string(), v.clone());
            }
        }
        Ok(state)
    }

    /// Joins `ids` with `,` after sorting them, escaping `\` as `\\` and `,`
    /// as `\,` inside each id.
    fn join_ids<'a>(ids: impl Iterator<Item = &'a str>) -> String {
        let mut sorted: Vec<&str> = ids.collect();
        sorted.sort_unstable();
        let mut joined = String::new();
        for (index, id) in sorted.into_iter().enumerate() {
            if index > 0 {
                joined.push(',');
            }
            for ch in id.chars() {
                if ch == '\\' || ch == ',' {
                    joined.push('\\');
                }
                joined.push(ch);
            }
        }
        joined
    }

    /// Inverse of [`Serialiser::join_ids`]. An empty string yields no ids.
    ///
    /// A `\` followed by anything other than `\` or `,`, or a `\` at the end of
    /// the input, is kept literally. Values in the older unescaped format
    /// therefore decode exactly as before, unless they happen to contain the
    /// sequences `\\` or `\,`.
    fn split_ids(joined: &str) -> Vec<String> {
        if joined.is_empty() {
            return Vec::new();
        }
        let mut ids = Vec::new();
        let mut current = String::new();
        let mut chars = joined.chars();
        while let Some(ch) = chars.next() {
            match ch {
                '\\' => match chars.next() {
                    Some(escaped @ ('\\' | ',')) => current.push(escaped),
                    Some(other) => {
                        current.push('\\');
                        current.push(other);
                    }
                    None => current.push('\\'),
                },
                ',' => ids.push(std::mem::take(&mut current)),
                _ => current.push(ch),
            }
        }
        ids.push(current);
        ids
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::ProfileState;

    /// Tolerance for `f32` values that have gone through a string round trip.
    const EPS: f32 = 1e-5;

    fn close(a: f32, b: f32) -> bool {
        (a - b).abs() < EPS
    }

    /// Every field kind written by `Serialiser::encode` must come back out of
    /// `Serialiser::decode` with the same value.
    #[test]
    fn round_trip() {
        let mut original = ProfileState::new();
        original.skill = 0.42;
        original.optimism_bias = 0.15;
        original.interaction_count = 7;
        original.last_seen.insert("item1".into(), 123456);
        original.category_count.insert("math".into(), 3);
        original.resolved_set.insert("item1".into());
        original.extended.insert("custom_score".into(), 0.77);
        original.extended_str.insert("mode".into(), "practice".into());

        let wire = Serialiser::encode(&original);
        let restored = Serialiser::decode(&wire)
            .expect("decoding a freshly encoded state should succeed");

        assert!(close(restored.skill, original.skill));
        assert!(close(restored.optimism_bias, original.optimism_bias));
        assert_eq!(restored.interaction_count, original.interaction_count);
        assert_eq!(restored.last_seen.get("item1"), Some(&123456));
        assert_eq!(restored.category_count.get("math"), Some(&3));
        assert!(restored.resolved_set.contains("item1"));
        assert!(close(restored.extended["custom_score"], 0.77));
        assert_eq!(
            restored.extended_str.get("mode").map(String::as_str),
            Some("practice")
        );
    }
}