Skip to main content

mofa_plugins/hot_reload/
state.rs

1//! Plugin state management
2//!
3//! Handles state preservation and restoration during hot-reload
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use tracing::{debug, info, warn};
10
11/// Plugin hot-reload state
12#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
13pub enum PluginState {
14    /// Not loaded
15    #[default]
16    Unloaded,
17    /// Loading in progress
18    Loading,
19    /// Loaded and ready
20    Loaded,
21    /// Running
22    Running,
23    /// Reloading in progress
24    Reloading,
25    /// Failed to load/reload
26    Failed(String),
27    /// Unloading in progress
28    Unloading,
29}
30
31impl std::fmt::Display for PluginState {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        match self {
34            PluginState::Unloaded => write!(f, "Unloaded"),
35            PluginState::Loading => write!(f, "Loading"),
36            PluginState::Loaded => write!(f, "Loaded"),
37            PluginState::Running => write!(f, "Running"),
38            PluginState::Reloading => write!(f, "Reloading"),
39            PluginState::Failed(err) => write!(f, "Failed: {}", err),
40            PluginState::Unloading => write!(f, "Unloading"),
41        }
42    }
43}
44
45/// A snapshot of plugin state that can be preserved across reloads
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct StateSnapshot {
48    /// Plugin ID
49    pub plugin_id: String,
50    /// Snapshot timestamp
51    pub timestamp: u64,
52    /// Serialized state data
53    pub data: HashMap<String, serde_json::Value>,
54    /// Plugin version at snapshot time
55    pub plugin_version: String,
56    /// Custom metadata
57    pub metadata: HashMap<String, String>,
58}
59
60impl StateSnapshot {
61    /// Create a new state snapshot
62    pub fn new(plugin_id: &str, plugin_version: &str) -> Self {
63        Self {
64            plugin_id: plugin_id.to_string(),
65            timestamp: std::time::SystemTime::now()
66                .duration_since(std::time::UNIX_EPOCH)
67                .unwrap_or_default()
68                .as_secs(),
69            data: HashMap::new(),
70            plugin_version: plugin_version.to_string(),
71            metadata: HashMap::new(),
72        }
73    }
74
75    /// Add state data
76    pub fn with_data<T: Serialize>(mut self, key: &str, value: &T) -> Self {
77        if let Ok(json_value) = serde_json::to_value(value) {
78            self.data.insert(key.to_string(), json_value);
79        }
80        self
81    }
82
83    /// Add metadata
84    pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
85        self.metadata.insert(key.to_string(), value.to_string());
86        self
87    }
88
89    /// Get state data
90    pub fn get<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
91        self.data
92            .get(key)
93            .and_then(|v| serde_json::from_value(v.clone()).ok())
94    }
95
96    /// Get metadata
97    pub fn get_metadata(&self, key: &str) -> Option<&str> {
98        self.metadata.get(key).map(|s| s.as_str())
99    }
100
101    /// Check if snapshot is compatible with a plugin version
102    pub fn is_compatible(&self, plugin_version: &str) -> bool {
103        // Simple semantic version major check
104        let snapshot_major = self.plugin_version.split('.').next();
105        let plugin_major = plugin_version.split('.').next();
106        snapshot_major == plugin_major
107    }
108
109    /// Serialize to bytes
110    pub fn to_bytes(&self) -> Result<Vec<u8>, serde_json::Error> {
111        serde_json::to_vec(self)
112    }
113
114    /// Deserialize from bytes
115    pub fn from_bytes(bytes: &[u8]) -> Result<Self, serde_json::Error> {
116        serde_json::from_slice(bytes)
117    }
118}
119
120/// State manager for plugin hot-reload
121pub struct StateManager {
122    /// Active snapshots by plugin ID
123    snapshots: Arc<RwLock<HashMap<String, StateSnapshot>>>,
124    /// Historical snapshots (for rollback)
125    history: Arc<RwLock<HashMap<String, Vec<StateSnapshot>>>>,
126    /// Maximum history entries per plugin
127    max_history: usize,
128    /// Enable state persistence
129    persist_enabled: bool,
130    /// Persistence directory
131    persist_dir: Option<std::path::PathBuf>,
132}
133
134impl StateManager {
135    /// Create a new state manager
136    pub fn new() -> Self {
137        Self {
138            snapshots: Arc::new(RwLock::new(HashMap::new())),
139            history: Arc::new(RwLock::new(HashMap::new())),
140            max_history: 10,
141            persist_enabled: false,
142            persist_dir: None,
143        }
144    }
145
146    /// Set maximum history entries
147    pub fn with_max_history(mut self, max: usize) -> Self {
148        self.max_history = max;
149        self
150    }
151
152    /// Enable state persistence
153    pub fn with_persistence<P: AsRef<std::path::Path>>(mut self, dir: P) -> Self {
154        self.persist_enabled = true;
155        self.persist_dir = Some(dir.as_ref().to_path_buf());
156        self
157    }
158
159    /// Save a state snapshot
160    pub async fn save_snapshot(&self, snapshot: StateSnapshot) -> Result<(), String> {
161        let plugin_id = snapshot.plugin_id.clone();
162
163        info!("Saving state snapshot for plugin: {}", plugin_id);
164
165        // Move current snapshot to history
166        let mut snapshots = self.snapshots.write().await;
167        if let Some(current) = snapshots.remove(&plugin_id) {
168            let mut history = self.history.write().await;
169            let entry = history.entry(plugin_id.clone()).or_insert_with(Vec::new);
170            entry.push(current);
171
172            // Trim history
173            if entry.len() > self.max_history {
174                let to_remove = entry.len() - self.max_history;
175                entry.drain(0..to_remove);
176            }
177        }
178
179        // Save new snapshot
180        if self.persist_enabled
181            && let Err(e) = self.persist_snapshot(&snapshot).await
182        {
183            warn!("Failed to persist snapshot: {}", e);
184        }
185
186        snapshots.insert(plugin_id, snapshot);
187        Ok(())
188    }
189
190    /// Load a state snapshot
191    pub async fn load_snapshot(&self, plugin_id: &str) -> Option<StateSnapshot> {
192        let snapshots = self.snapshots.read().await;
193        snapshots.get(plugin_id).cloned()
194    }
195
196    /// Get the latest snapshot for a plugin
197    pub async fn get_latest(&self, plugin_id: &str) -> Option<StateSnapshot> {
198        // First try current snapshots
199        if let Some(snapshot) = self.load_snapshot(plugin_id).await {
200            return Some(snapshot);
201        }
202
203        // Try to load from persistence
204        if self.persist_enabled
205            && let Ok(snapshot) = self.load_persisted_snapshot(plugin_id).await
206        {
207            return Some(snapshot);
208        }
209
210        None
211    }
212
213    /// Rollback to a previous snapshot
214    pub async fn rollback(&self, plugin_id: &str) -> Option<StateSnapshot> {
215        let mut history = self.history.write().await;
216
217        if let Some(entry) = history.get_mut(plugin_id)
218            && let Some(snapshot) = entry.pop()
219        {
220            info!(
221                "Rolling back plugin {} to snapshot from {}",
222                plugin_id, snapshot.timestamp
223            );
224
225            // Update current snapshot
226            let mut snapshots = self.snapshots.write().await;
227            snapshots.insert(plugin_id.to_string(), snapshot.clone());
228
229            return Some(snapshot);
230        }
231
232        None
233    }
234
235    /// Clear all snapshots for a plugin
236    pub async fn clear(&self, plugin_id: &str) {
237        debug!("Clearing snapshots for plugin: {}", plugin_id);
238
239        let mut snapshots = self.snapshots.write().await;
240        snapshots.remove(plugin_id);
241
242        let mut history = self.history.write().await;
243        history.remove(plugin_id);
244    }
245
246    /// Clear all snapshots
247    pub async fn clear_all(&self) {
248        debug!("Clearing all snapshots");
249
250        let mut snapshots = self.snapshots.write().await;
251        snapshots.clear();
252
253        let mut history = self.history.write().await;
254        history.clear();
255    }
256
257    /// Get history for a plugin
258    pub async fn get_history(&self, plugin_id: &str) -> Vec<StateSnapshot> {
259        let history = self.history.read().await;
260        history.get(plugin_id).cloned().unwrap_or_default()
261    }
262
263    /// Get all managed plugin IDs
264    pub async fn plugin_ids(&self) -> Vec<String> {
265        let snapshots = self.snapshots.read().await;
266        snapshots.keys().cloned().collect()
267    }
268
269    /// Persist snapshot to disk
270    async fn persist_snapshot(&self, snapshot: &StateSnapshot) -> Result<(), String> {
271        let dir = self
272            .persist_dir
273            .as_ref()
274            .ok_or_else(|| "Persistence directory not set".to_string())?;
275
276        // Ensure directory exists
277        std::fs::create_dir_all(dir)
278            .map_err(|e| format!("Failed to create persistence directory: {}", e))?;
279
280        let file_path = dir.join(format!("{}.json", snapshot.plugin_id));
281
282        let json = serde_json::to_string_pretty(snapshot)
283            .map_err(|e| format!("Failed to serialize snapshot: {}", e))?;
284
285        std::fs::write(&file_path, json)
286            .map_err(|e| format!("Failed to write snapshot file: {}", e))?;
287
288        debug!("Persisted snapshot to {:?}", file_path);
289        Ok(())
290    }
291
292    /// Load persisted snapshot from disk
293    async fn load_persisted_snapshot(&self, plugin_id: &str) -> Result<StateSnapshot, String> {
294        let dir = self
295            .persist_dir
296            .as_ref()
297            .ok_or_else(|| "Persistence directory not set".to_string())?;
298
299        let file_path = dir.join(format!("{}.json", plugin_id));
300
301        let json = std::fs::read_to_string(&file_path)
302            .map_err(|e| format!("Failed to read snapshot file: {}", e))?;
303
304        serde_json::from_str(&json).map_err(|e| format!("Failed to deserialize snapshot: {}", e))
305    }
306}
307
308impl Default for StateManager {
309    fn default() -> Self {
310        Self::new()
311    }
312}
313
314/// Trait for plugins that support state preservation
315pub trait StatefulPlugin {
316    /// Create a state snapshot
317    fn create_snapshot(&self) -> StateSnapshot;
318
319    /// Restore from a state snapshot
320    fn restore_snapshot(&mut self, snapshot: &StateSnapshot) -> Result<(), String>;
321
322    /// Check if a snapshot is compatible
323    fn is_snapshot_compatible(&self, snapshot: &StateSnapshot) -> bool {
324        snapshot.is_compatible(&self.plugin_version())
325    }
326
327    /// Get plugin version
328    fn plugin_version(&self) -> String;
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_plugin_state_display() {
        assert_eq!(PluginState::Unloaded.to_string(), "Unloaded");
        assert_eq!(PluginState::Running.to_string(), "Running");
        assert_eq!(
            PluginState::Failed("test".to_string()).to_string(),
            "Failed: test"
        );
    }

    #[test]
    fn test_state_snapshot() {
        let snapshot = StateSnapshot::new("test-plugin", "1.0.0")
            .with_data("counter", &42)
            .with_data("name", &"test")
            .with_metadata("author", "developer");

        assert_eq!(snapshot.plugin_id, "test-plugin");
        assert_eq!(snapshot.plugin_version, "1.0.0");
        assert_eq!(snapshot.get::<i32>("counter"), Some(42));
        assert_eq!(snapshot.get::<String>("name"), Some("test".to_string()));
        assert_eq!(snapshot.get_metadata("author"), Some("developer"));
    }

    #[test]
    fn test_snapshot_missing_or_mistyped_key() {
        let snapshot = StateSnapshot::new("test", "1.0.0").with_data("counter", &1);

        assert_eq!(snapshot.get::<i32>("missing"), None);
        // A type mismatch yields `None` rather than panicking.
        assert_eq!(snapshot.get::<String>("counter"), None);
        assert_eq!(snapshot.get_metadata("missing"), None);
    }

    #[test]
    fn test_snapshot_compatibility() {
        let snapshot = StateSnapshot::new("test", "1.0.0");

        assert!(snapshot.is_compatible("1.0.0"));
        assert!(snapshot.is_compatible("1.1.0"));
        assert!(snapshot.is_compatible("1.2.3"));
        assert!(!snapshot.is_compatible("2.0.0"));
    }

    #[test]
    fn test_snapshot_serialization() {
        let snapshot = StateSnapshot::new("test", "1.0.0").with_data("value", &123);

        let bytes = snapshot.to_bytes().unwrap();
        let restored = StateSnapshot::from_bytes(&bytes).unwrap();

        assert_eq!(restored.plugin_id, snapshot.plugin_id);
        assert_eq!(restored.get::<i32>("value"), Some(123));
    }

    #[tokio::test]
    async fn test_state_manager() {
        let manager = StateManager::new();

        let snapshot = StateSnapshot::new("plugin-1", "1.0.0").with_data("state", &"active");

        manager.save_snapshot(snapshot.clone()).await.unwrap();

        let loaded = manager.load_snapshot("plugin-1").await.unwrap();
        assert_eq!(loaded.plugin_id, "plugin-1");
        assert_eq!(loaded.get::<String>("state"), Some("active".to_string()));
    }

    #[tokio::test]
    async fn test_state_manager_rollback() {
        let manager = StateManager::new();

        // Save first snapshot
        let snapshot1 = StateSnapshot::new("plugin-1", "1.0.0").with_data("version", &1);
        manager.save_snapshot(snapshot1).await.unwrap();

        // Save second snapshot (moves first to history)
        let snapshot2 = StateSnapshot::new("plugin-1", "1.0.0").with_data("version", &2);
        manager.save_snapshot(snapshot2).await.unwrap();

        // Current should be version 2
        let current = manager.load_snapshot("plugin-1").await.unwrap();
        assert_eq!(current.get::<i32>("version"), Some(2));

        // Rollback to version 1
        let rolled_back = manager.rollback("plugin-1").await.unwrap();
        assert_eq!(rolled_back.get::<i32>("version"), Some(1));

        // Current should now be version 1
        let current = manager.load_snapshot("plugin-1").await.unwrap();
        assert_eq!(current.get::<i32>("version"), Some(1));
    }

    #[tokio::test]
    async fn test_rollback_without_history() {
        let manager = StateManager::new();
        assert!(manager.rollback("unknown").await.is_none());

        // A single save has nothing to supersede, so there is still no history.
        manager
            .save_snapshot(StateSnapshot::new("plugin-1", "1.0.0"))
            .await
            .unwrap();
        assert!(manager.rollback("plugin-1").await.is_none());
    }

    #[tokio::test]
    async fn test_history_is_trimmed() {
        let manager = StateManager::new().with_max_history(2);
        for version in 0..5 {
            let snapshot = StateSnapshot::new("plugin-1", "1.0.0").with_data("version", &version);
            manager.save_snapshot(snapshot).await.unwrap();
        }

        // Snapshots 0..=3 were superseded; only the newest two are kept, oldest first.
        let versions: Vec<i32> = manager
            .get_history("plugin-1")
            .await
            .iter()
            .filter_map(|s| s.get::<i32>("version"))
            .collect();
        assert_eq!(versions, vec![2, 3]);
    }

    #[tokio::test]
    async fn test_clear_and_clear_all() {
        let manager = StateManager::new();
        manager.save_snapshot(StateSnapshot::new("a", "1.0.0")).await.unwrap();
        manager.save_snapshot(StateSnapshot::new("a", "1.0.0")).await.unwrap();
        manager.save_snapshot(StateSnapshot::new("b", "1.0.0")).await.unwrap();

        manager.clear("a").await;
        assert!(manager.load_snapshot("a").await.is_none());
        assert!(manager.get_history("a").await.is_empty());
        assert_eq!(manager.plugin_ids().await, vec!["b".to_string()]);

        manager.clear_all().await;
        assert!(manager.plugin_ids().await.is_empty());
    }

    #[tokio::test]
    async fn test_persistence_round_trip() {
        let dir = std::env::temp_dir().join(format!(
            "mofa-state-test-{}-{}",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_nanos()
        ));

        let writer = StateManager::new().with_persistence(&dir);
        writer
            .save_snapshot(StateSnapshot::new("plugin-1", "1.0.0").with_data("value", &7))
            .await
            .unwrap();

        // A fresh manager holds nothing in memory and must fall back to disk.
        let reader = StateManager::new().with_persistence(&dir);
        let restored = reader.get_latest("plugin-1").await.unwrap();
        assert_eq!(restored.get::<i32>("value"), Some(7));

        let _ = std::fs::remove_dir_all(&dir);
    }
}