Skip to main content

proof_engine/save/
migrations.rs

//! Schema migration system for save data.
//!
//! `MigrationRegistry` holds an ordered chain of `MigrationFn` values keyed by
//! source version. `MigrationRegistry::migrate(data, from, to)` runs that chain
//! one version at a time to bring save data from any older schema version up to
//! the current version without data loss.
6
7use std::collections::HashMap;
8
// ─────────────────────────────────────────────────────────────────────────────
//  SchemaVersion
// ─────────────────────────────────────────────────────────────────────────────

/// Version number of the save schema; a newer schema always has a larger number.
///
/// Versions order by their wrapped integer, and the default is version 0.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct SchemaVersion(pub u32);

impl SchemaVersion {
    /// The newest schema version.
    pub const CURRENT: SchemaVersion = SchemaVersion(10);

    /// Return the raw version number.
    pub fn value(self) -> u32 {
        let SchemaVersion(raw) = self;
        raw
    }
}

impl std::fmt::Display for SchemaVersion {
    /// Render as `v<number>`, e.g. `v10`.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let SchemaVersion(raw) = *self;
        write!(out, "v{}", raw)
    }
}
30
// ─────────────────────────────────────────────────────────────────────────────
//  SaveValue
// ─────────────────────────────────────────────────────────────────────────────

/// A dynamically typed value stored in a `SaveData` field.
///
/// `List` and `Map` may nest further `SaveValue`s to any depth.
#[derive(Debug, Clone, PartialEq)]
pub enum SaveValue {
    Bool(bool),
    Int(i64),
    Float(f64),
    Str(String),
    List(Vec<SaveValue>),
    Map(HashMap<String, SaveValue>),
    Bytes(Vec<u8>),
}

impl SaveValue {
    /// The wrapped flag for `Bool`, otherwise `None`.
    pub fn as_bool(&self) -> Option<bool> {
        match *self {
            Self::Bool(flag) => Some(flag),
            _ => None,
        }
    }

    /// The wrapped integer for `Int`, otherwise `None` (a `Float` is not converted).
    pub fn as_int(&self) -> Option<i64> {
        match *self {
            Self::Int(n) => Some(n),
            _ => None,
        }
    }

    /// The value as an `f64`: `Float` as-is, `Int` widened with `as f64`
    /// (which rounds magnitudes beyond 2^53), anything else `None`.
    pub fn as_float(&self) -> Option<f64> {
        match *self {
            Self::Float(x) => Some(x),
            Self::Int(n) => Some(n as f64),
            _ => None,
        }
    }

    /// Borrow the text of a `Str`, otherwise `None`.
    pub fn as_str(&self) -> Option<&str> {
        match self {
            Self::Str(text) => Some(text.as_str()),
            _ => None,
        }
    }

    /// Borrow the elements of a `List`, otherwise `None`.
    pub fn as_list(&self) -> Option<&[SaveValue]> {
        match self {
            Self::List(items) => Some(&items[..]),
            _ => None,
        }
    }

    /// Borrow the entries of a `Map`, otherwise `None`.
    pub fn as_map(&self) -> Option<&HashMap<String, SaveValue>> {
        match self {
            Self::Map(entries) => Some(entries),
            _ => None,
        }
    }

    /// Mutably borrow the entries of a `Map`, otherwise `None`.
    pub fn as_map_mut(&mut self) -> Option<&mut HashMap<String, SaveValue>> {
        match self {
            Self::Map(entries) => Some(entries),
            _ => None,
        }
    }

    /// Borrow the raw bytes of a `Bytes`, otherwise `None`.
    pub fn as_bytes(&self) -> Option<&[u8]> {
        match self {
            Self::Bytes(raw) => Some(&raw[..]),
            _ => None,
        }
    }

    /// The variant name as a static string, e.g. `"Int"` or `"Map"`.
    pub fn type_name(&self) -> &'static str {
        match self {
            Self::Bool(..) => "Bool",
            Self::Int(..) => "Int",
            Self::Float(..) => "Float",
            Self::Str(..) => "Str",
            Self::List(..) => "List",
            Self::Map(..) => "Map",
            Self::Bytes(..) => "Bytes",
        }
    }
}
89
90// ─────────────────────────────────────────────────────────────────────────────
91//  SaveData
92// ─────────────────────────────────────────────────────────────────────────────
93
94/// Flexible key→value store that migrations read and modify.
95#[derive(Debug, Clone, Default)]
96pub struct SaveData {
97    pub fields: HashMap<String, SaveValue>,
98    pub version: SchemaVersion,
99}
100
101impl SaveData {
102    pub fn new(version: SchemaVersion) -> Self {
103        Self { fields: HashMap::new(), version }
104    }
105
106    pub fn get(&self, key: &str) -> Option<&SaveValue> {
107        self.fields.get(key)
108    }
109
110    pub fn get_mut(&mut self, key: &str) -> Option<&mut SaveValue> {
111        self.fields.get_mut(key)
112    }
113
114    pub fn set(&mut self, key: impl Into<String>, value: SaveValue) {
115        self.fields.insert(key.into(), value);
116    }
117
118    pub fn remove(&mut self, key: &str) -> Option<SaveValue> {
119        self.fields.remove(key)
120    }
121
122    pub fn contains(&self, key: &str) -> bool {
123        self.fields.contains_key(key)
124    }
125
126    /// Sum all integer and float values for checksum computation.
127    pub fn sum_numeric(&self) -> f64 {
128        fn recurse(v: &SaveValue) -> f64 {
129            match v {
130                SaveValue::Int(i)   => *i as f64,
131                SaveValue::Float(f) => *f,
132                SaveValue::List(l)  => l.iter().map(recurse).sum(),
133                SaveValue::Map(m)   => m.values().map(recurse).sum(),
134                _ => 0.0,
135            }
136        }
137        self.fields.values().map(recurse).sum()
138    }
139}
140
141// ─────────────────────────────────────────────────────────────────────────────
142//  MigrationFn
143// ─────────────────────────────────────────────────────────────────────────────
144
145/// A function that transforms `SaveData` from version N to N+1.
146pub type MigrationFn = fn(data: &mut SaveData) -> Result<(), String>;
147
148// ─────────────────────────────────────────────────────────────────────────────
149//  MigrationRegistry
150// ─────────────────────────────────────────────────────────────────────────────
151
152/// Holds an ordered chain of migrations keyed by source version.
153pub struct MigrationRegistry {
154    /// from_version → migration function
155    migrations: Vec<(SchemaVersion, MigrationFn)>,
156}
157
158impl MigrationRegistry {
159    pub fn new() -> Self {
160        Self { migrations: Vec::new() }
161    }
162
163    /// Register a migration from `from_version` to `from_version + 1`.
164    pub fn register(&mut self, from_version: u32, f: MigrationFn) {
165        self.migrations.push((SchemaVersion(from_version), f));
166        self.migrations.sort_by_key(|(v, _)| *v);
167    }
168
169    /// Run the chain of migrations to bring `data` from `from` to `to`.
170    pub fn migrate(
171        &self,
172        data: &mut SaveData,
173        from: SchemaVersion,
174        to: SchemaVersion,
175    ) -> Result<(), String> {
176        if from >= to {
177            return Ok(());
178        }
179        let mut current = from;
180        for (version, f) in &self.migrations {
181            if *version < from || *version >= to {
182                continue;
183            }
184            if *version != current {
185                return Err(format!(
186                    "missing migration from {current} (next available is {version})"
187                ));
188            }
189            f(data).map_err(|e| format!("migration {version}: {e}"))?;
190            current = SchemaVersion(current.0 + 1);
191            data.version = current;
192        }
193        if current != to {
194            return Err(format!("migration chain incomplete: reached {current}, needed {to}"));
195        }
196        Ok(())
197    }
198
199    /// Build a registry pre-populated with all 10 built-in migrations (v0→v10).
200    pub fn with_builtin_migrations() -> Self {
201        let mut reg = Self::new();
202        reg.register(0, migrate_v0_to_v1);
203        reg.register(1, migrate_v1_to_v2);
204        reg.register(2, migrate_v2_to_v3);
205        reg.register(3, migrate_v3_to_v4);
206        reg.register(4, migrate_v4_to_v5);
207        reg.register(5, migrate_v5_to_v6);
208        reg.register(6, migrate_v6_to_v7);
209        reg.register(7, migrate_v7_to_v8);
210        reg.register(8, migrate_v8_to_v9);
211        reg.register(9, migrate_v9_to_v10);
212        reg
213    }
214}
215
216impl Default for MigrationRegistry {
217    fn default() -> Self {
218        Self::new()
219    }
220}
221
222// ─────────────────────────────────────────────────────────────────────────────
223//  Concrete migrations
224// ─────────────────────────────────────────────────────────────────────────────
225
226/// v0 → v1: Add `created_at` timestamp field (defaults to 0).
227fn migrate_v0_to_v1(data: &mut SaveData) -> Result<(), String> {
228    if !data.contains("created_at") {
229        data.set("created_at", SaveValue::Int(0));
230    }
231    Ok(())
232}
233
234/// v1 → v2: Rename `hp` to `health_points` throughout the top-level fields.
235fn migrate_v1_to_v2(data: &mut SaveData) -> Result<(), String> {
236    if let Some(val) = data.remove("hp") {
237        data.set("health_points", val);
238    }
239    // Also rename inside any nested maps
240    let keys: Vec<String> = data.fields.keys().cloned().collect();
241    for key in keys {
242        if let Some(SaveValue::Map(ref mut m)) = data.fields.get_mut(&key) {
243            if let Some(val) = m.remove("hp") {
244                m.insert("health_points".into(), val);
245            }
246        }
247    }
248    Ok(())
249}
250
251/// v2 → v3: Flatten nested `stats` map — `stats.strength` becomes `stat_strength`, etc.
252fn migrate_v2_to_v3(data: &mut SaveData) -> Result<(), String> {
253    if let Some(SaveValue::Map(stats_map)) = data.remove("stats") {
254        for (k, v) in stats_map {
255            data.set(format!("stat_{k}"), v);
256        }
257    }
258    Ok(())
259}
260
261/// v3 → v4: Convert `inventory` from a list-of-strings to a list-of-objects
262/// `{name: String, quantity: 1, durability: 100}`.
263fn migrate_v3_to_v4(data: &mut SaveData) -> Result<(), String> {
264    if let Some(SaveValue::List(inv)) = data.remove("inventory") {
265        let new_inv: Vec<SaveValue> = inv
266            .into_iter()
267            .map(|item| {
268                let name = match &item {
269                    SaveValue::Str(s) => s.clone(),
270                    _ => "unknown".into(),
271                };
272                let mut m = HashMap::new();
273                m.insert("name".into(),       SaveValue::Str(name));
274                m.insert("quantity".into(),   SaveValue::Int(1));
275                m.insert("durability".into(), SaveValue::Int(100));
276                SaveValue::Map(m)
277            })
278            .collect();
279        data.set("inventory", SaveValue::List(new_inv));
280    }
281    Ok(())
282}
283
284/// v4 → v5: Add `player_level` defaulting to 1.
285fn migrate_v4_to_v5(data: &mut SaveData) -> Result<(), String> {
286    if !data.contains("player_level") {
287        data.set("player_level", SaveValue::Int(1));
288    }
289    Ok(())
290}
291
292/// v5 → v6: Convert `position` from a `[x, y]` list to `{x, y, z: 0}` map.
293fn migrate_v5_to_v6(data: &mut SaveData) -> Result<(), String> {
294    if let Some(SaveValue::List(pos)) = data.remove("position") {
295        let x = pos.get(0).and_then(|v| v.as_float()).unwrap_or(0.0);
296        let y = pos.get(1).and_then(|v| v.as_float()).unwrap_or(0.0);
297        let mut m = HashMap::new();
298        m.insert("x".into(), SaveValue::Float(x));
299        m.insert("y".into(), SaveValue::Float(y));
300        m.insert("z".into(), SaveValue::Float(0.0));
301        data.set("position", SaveValue::Map(m));
302    }
303    Ok(())
304}
305
306/// v6 → v7: Add `difficulty` defaulting to "normal".
307fn migrate_v6_to_v7(data: &mut SaveData) -> Result<(), String> {
308    if !data.contains("difficulty") {
309        data.set("difficulty", SaveValue::Str("normal".into()));
310    }
311    Ok(())
312}
313
314/// v7 → v8: Encode `player_name` as UTF-8 bytes (simulates an encoding migration).
315fn migrate_v7_to_v8(data: &mut SaveData) -> Result<(), String> {
316    if let Some(SaveValue::Str(name)) = data.remove("player_name") {
317        data.set("player_name", SaveValue::Bytes(name.into_bytes()));
318    }
319    Ok(())
320}
321
322/// v8 → v9: Split `audio_volume` (0.0–1.0) into `music_volume` and `sfx_volume`.
323fn migrate_v8_to_v9(data: &mut SaveData) -> Result<(), String> {
324    let vol = if let Some(v) = data.remove("audio_volume") {
325        v.as_float().unwrap_or(1.0)
326    } else {
327        1.0
328    };
329    if !data.contains("music_volume") {
330        data.set("music_volume", SaveValue::Float(vol));
331    }
332    if !data.contains("sfx_volume") {
333        data.set("sfx_volume", SaveValue::Float(vol));
334    }
335    Ok(())
336}
337
338/// v9 → v10: Compute a `checksum` field as the integer sum of all numeric values.
339fn migrate_v9_to_v10(data: &mut SaveData) -> Result<(), String> {
340    let sum = data.sum_numeric();
341    data.set("checksum", SaveValue::Int(sum as i64));
342    Ok(())
343}
344
// ─────────────────────────────────────────────────────────────────────────────
//  Tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Registry containing every built-in migration (v0 → v10).
    fn make_registry() -> MigrationRegistry {
        MigrationRegistry::with_builtin_migrations()
    }

    /// A representative v0 save touching every field a built-in migration rewrites.
    fn v0_data() -> SaveData {
        let stats: HashMap<String, SaveValue> = [
            ("strength", SaveValue::Int(10)),
            ("agility", SaveValue::Int(8)),
        ]
        .into_iter()
        .map(|(name, value)| (name.to_string(), value))
        .collect();
        let inventory = ["sword", "shield"]
            .iter()
            .map(|item| SaveValue::Str(item.to_string()))
            .collect();

        let mut save = SaveData::new(SchemaVersion(0));
        save.set("hp", SaveValue::Int(100));
        save.set("audio_volume", SaveValue::Float(0.8));
        save.set("player_name", SaveValue::Str("Hero".into()));
        save.set("inventory", SaveValue::List(inventory));
        save.set("stats", SaveValue::Map(stats));
        save.set("position", SaveValue::List(vec![SaveValue::Float(1.5), SaveValue::Float(2.5)]));
        save
    }

    /// Migrate a fresh `v0_data()` save up to `target`, panicking on any error.
    fn migrated_to(target: u32) -> SaveData {
        let mut save = v0_data();
        make_registry()
            .migrate(&mut save, SchemaVersion(0), SchemaVersion(target))
            .unwrap();
        save
    }

    #[test]
    fn test_v0_to_v1_adds_created_at() {
        assert!(migrated_to(1).contains("created_at"));
    }

    #[test]
    fn test_v1_to_v2_renames_hp() {
        let save = migrated_to(2);
        assert!(!save.contains("hp"));
        assert!(save.contains("health_points"));
    }

    #[test]
    fn test_v2_to_v3_flattens_stats() {
        let save = migrated_to(3);
        assert!(!save.contains("stats"));
        assert!(save.contains("stat_strength"));
        assert!(save.contains("stat_agility"));
    }

    #[test]
    fn test_v3_to_v4_converts_inventory() {
        let save = migrated_to(4);
        let inventory = save.get("inventory").unwrap().as_list().unwrap();
        assert_eq!(inventory.len(), 2);
        let first = inventory[0].as_map().unwrap();
        for key in ["name", "quantity", "durability"] {
            assert!(first.contains_key(key));
        }
    }

    #[test]
    fn test_v4_to_v5_adds_player_level() {
        assert_eq!(migrated_to(5).get("player_level").unwrap().as_int(), Some(1));
    }

    #[test]
    fn test_v5_to_v6_converts_position() {
        let save = migrated_to(6);
        let position = save.get("position").unwrap().as_map().unwrap();
        for axis in ["x", "y", "z"] {
            assert!(position.contains_key(axis));
        }
        assert_eq!(position["z"].as_float(), Some(0.0));
    }

    #[test]
    fn test_v6_to_v7_adds_difficulty() {
        assert_eq!(migrated_to(7).get("difficulty").unwrap().as_str(), Some("normal"));
    }

    #[test]
    fn test_v7_to_v8_encodes_name_as_bytes() {
        let save = migrated_to(8);
        let name = save.get("player_name").unwrap();
        assert!(matches!(name, SaveValue::Bytes(_)));
        assert_eq!(name.as_bytes(), Some(&b"Hero"[..]));
    }

    #[test]
    fn test_v8_to_v9_splits_audio_volume() {
        let save = migrated_to(9);
        assert!(!save.contains("audio_volume"));
        assert!(save.contains("music_volume"));
        assert!(save.contains("sfx_volume"));
    }

    #[test]
    fn test_v9_to_v10_adds_checksum() {
        assert!(migrated_to(10).contains("checksum"));
    }

    #[test]
    fn test_full_migration_chain() {
        assert_eq!(migrated_to(10).version, SchemaVersion(10));
    }

    #[test]
    fn test_migration_already_at_version() {
        let mut save = SaveData::new(SchemaVersion(5));
        let outcome = make_registry().migrate(&mut save, SchemaVersion(5), SchemaVersion(5));
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_schema_version_ordering() {
        assert!(SchemaVersion(0) < SchemaVersion(1));
        assert_eq!(SchemaVersion(10), SchemaVersion::CURRENT);
    }

    #[test]
    fn test_save_value_type_accessors() {
        let int = SaveValue::Int(42);
        assert_eq!(int.as_int(), Some(42));
        assert_eq!(int.as_bool(), None);
        assert_eq!(int.as_float(), Some(42.0));

        let text = SaveValue::Str("hello".into());
        assert_eq!(text.as_str(), Some("hello"));

        let bytes = SaveValue::Bytes(vec![1, 2, 3]);
        assert_eq!(bytes.as_bytes(), Some(&[1u8, 2, 3][..]));
    }
}