1use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use tracing::{debug, info, warn};
10
/// Lifecycle state of a plugin.
///
/// This is a plain value type: it records a state but does not enforce which
/// transitions between states are allowed. `Unloaded` is the [`Default`].
///
/// Keep the variant order stable: serde's derived impls identify variants by
/// index in non-self-describing formats, so reordering would change encodings.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum PluginState {
    /// Default state.
    #[default]
    Unloaded,
    Loading,
    Loaded,
    Running,
    Reloading,
    /// Failure state carrying an error message (displayed as `Failed: <message>`).
    Failed(String),
    Unloading,
}
30
31impl std::fmt::Display for PluginState {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 match self {
34 PluginState::Unloaded => write!(f, "Unloaded"),
35 PluginState::Loading => write!(f, "Loading"),
36 PluginState::Loaded => write!(f, "Loaded"),
37 PluginState::Running => write!(f, "Running"),
38 PluginState::Reloading => write!(f, "Reloading"),
39 PluginState::Failed(err) => write!(f, "Failed: {}", err),
40 PluginState::Unloading => write!(f, "Unloading"),
41 }
42 }
43}
44
/// A point-in-time capture of a plugin's state.
///
/// Serialized with serde (JSON via `to_bytes`/`from_bytes`); the field order
/// below is the order fields appear in serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateSnapshot {
    /// Id of the plugin this snapshot belongs to. `StateManager` keys its maps
    /// by this id and uses it as the stem of the persisted file name.
    pub plugin_id: String,
    /// Creation time in whole seconds since the Unix epoch (set by `new`).
    pub timestamp: u64,
    /// Arbitrary state values, stored as JSON (see `with_data` / `get`).
    pub data: HashMap<String, serde_json::Value>,
    /// Version of the plugin that produced this snapshot; `is_compatible`
    /// compares its first `.`-separated component.
    pub plugin_version: String,
    /// Free-form string annotations (see `with_metadata` / `get_metadata`).
    pub metadata: HashMap<String, String>,
}
59
60impl StateSnapshot {
61 pub fn new(plugin_id: &str, plugin_version: &str) -> Self {
63 Self {
64 plugin_id: plugin_id.to_string(),
65 timestamp: std::time::SystemTime::now()
66 .duration_since(std::time::UNIX_EPOCH)
67 .unwrap_or_default()
68 .as_secs(),
69 data: HashMap::new(),
70 plugin_version: plugin_version.to_string(),
71 metadata: HashMap::new(),
72 }
73 }
74
75 pub fn with_data<T: Serialize>(mut self, key: &str, value: &T) -> Self {
77 if let Ok(json_value) = serde_json::to_value(value) {
78 self.data.insert(key.to_string(), json_value);
79 }
80 self
81 }
82
83 pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
85 self.metadata.insert(key.to_string(), value.to_string());
86 self
87 }
88
89 pub fn get<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
91 self.data
92 .get(key)
93 .and_then(|v| serde_json::from_value(v.clone()).ok())
94 }
95
96 pub fn get_metadata(&self, key: &str) -> Option<&str> {
98 self.metadata.get(key).map(|s| s.as_str())
99 }
100
101 pub fn is_compatible(&self, plugin_version: &str) -> bool {
103 let snapshot_major = self.plugin_version.split('.').next();
105 let plugin_major = plugin_version.split('.').next();
106 snapshot_major == plugin_major
107 }
108
109 pub fn to_bytes(&self) -> Result<Vec<u8>, serde_json::Error> {
111 serde_json::to_vec(self)
112 }
113
114 pub fn from_bytes(bytes: &[u8]) -> Result<Self, serde_json::Error> {
116 serde_json::from_slice(bytes)
117 }
118}
119
120pub struct StateManager {
122 snapshots: Arc<RwLock<HashMap<String, StateSnapshot>>>,
124 history: Arc<RwLock<HashMap<String, Vec<StateSnapshot>>>>,
126 max_history: usize,
128 persist_enabled: bool,
130 persist_dir: Option<std::path::PathBuf>,
132}
133
134impl StateManager {
135 pub fn new() -> Self {
137 Self {
138 snapshots: Arc::new(RwLock::new(HashMap::new())),
139 history: Arc::new(RwLock::new(HashMap::new())),
140 max_history: 10,
141 persist_enabled: false,
142 persist_dir: None,
143 }
144 }
145
146 pub fn with_max_history(mut self, max: usize) -> Self {
148 self.max_history = max;
149 self
150 }
151
152 pub fn with_persistence<P: AsRef<std::path::Path>>(mut self, dir: P) -> Self {
154 self.persist_enabled = true;
155 self.persist_dir = Some(dir.as_ref().to_path_buf());
156 self
157 }
158
159 pub async fn save_snapshot(&self, snapshot: StateSnapshot) -> Result<(), String> {
161 let plugin_id = snapshot.plugin_id.clone();
162
163 info!("Saving state snapshot for plugin: {}", plugin_id);
164
165 let mut snapshots = self.snapshots.write().await;
167 if let Some(current) = snapshots.remove(&plugin_id) {
168 let mut history = self.history.write().await;
169 let entry = history.entry(plugin_id.clone()).or_insert_with(Vec::new);
170 entry.push(current);
171
172 if entry.len() > self.max_history {
174 let to_remove = entry.len() - self.max_history;
175 entry.drain(0..to_remove);
176 }
177 }
178
179 if self.persist_enabled
181 && let Err(e) = self.persist_snapshot(&snapshot).await
182 {
183 warn!("Failed to persist snapshot: {}", e);
184 }
185
186 snapshots.insert(plugin_id, snapshot);
187 Ok(())
188 }
189
190 pub async fn load_snapshot(&self, plugin_id: &str) -> Option<StateSnapshot> {
192 let snapshots = self.snapshots.read().await;
193 snapshots.get(plugin_id).cloned()
194 }
195
196 pub async fn get_latest(&self, plugin_id: &str) -> Option<StateSnapshot> {
198 if let Some(snapshot) = self.load_snapshot(plugin_id).await {
200 return Some(snapshot);
201 }
202
203 if self.persist_enabled
205 && let Ok(snapshot) = self.load_persisted_snapshot(plugin_id).await
206 {
207 return Some(snapshot);
208 }
209
210 None
211 }
212
213 pub async fn rollback(&self, plugin_id: &str) -> Option<StateSnapshot> {
215 let mut history = self.history.write().await;
216
217 if let Some(entry) = history.get_mut(plugin_id)
218 && let Some(snapshot) = entry.pop()
219 {
220 info!(
221 "Rolling back plugin {} to snapshot from {}",
222 plugin_id, snapshot.timestamp
223 );
224
225 let mut snapshots = self.snapshots.write().await;
227 snapshots.insert(plugin_id.to_string(), snapshot.clone());
228
229 return Some(snapshot);
230 }
231
232 None
233 }
234
235 pub async fn clear(&self, plugin_id: &str) {
237 debug!("Clearing snapshots for plugin: {}", plugin_id);
238
239 let mut snapshots = self.snapshots.write().await;
240 snapshots.remove(plugin_id);
241
242 let mut history = self.history.write().await;
243 history.remove(plugin_id);
244 }
245
246 pub async fn clear_all(&self) {
248 debug!("Clearing all snapshots");
249
250 let mut snapshots = self.snapshots.write().await;
251 snapshots.clear();
252
253 let mut history = self.history.write().await;
254 history.clear();
255 }
256
257 pub async fn get_history(&self, plugin_id: &str) -> Vec<StateSnapshot> {
259 let history = self.history.read().await;
260 history.get(plugin_id).cloned().unwrap_or_default()
261 }
262
263 pub async fn plugin_ids(&self) -> Vec<String> {
265 let snapshots = self.snapshots.read().await;
266 snapshots.keys().cloned().collect()
267 }
268
269 async fn persist_snapshot(&self, snapshot: &StateSnapshot) -> Result<(), String> {
271 let dir = self
272 .persist_dir
273 .as_ref()
274 .ok_or_else(|| "Persistence directory not set".to_string())?;
275
276 std::fs::create_dir_all(dir)
278 .map_err(|e| format!("Failed to create persistence directory: {}", e))?;
279
280 let file_path = dir.join(format!("{}.json", snapshot.plugin_id));
281
282 let json = serde_json::to_string_pretty(snapshot)
283 .map_err(|e| format!("Failed to serialize snapshot: {}", e))?;
284
285 std::fs::write(&file_path, json)
286 .map_err(|e| format!("Failed to write snapshot file: {}", e))?;
287
288 debug!("Persisted snapshot to {:?}", file_path);
289 Ok(())
290 }
291
292 async fn load_persisted_snapshot(&self, plugin_id: &str) -> Result<StateSnapshot, String> {
294 let dir = self
295 .persist_dir
296 .as_ref()
297 .ok_or_else(|| "Persistence directory not set".to_string())?;
298
299 let file_path = dir.join(format!("{}.json", plugin_id));
300
301 let json = std::fs::read_to_string(&file_path)
302 .map_err(|e| format!("Failed to read snapshot file: {}", e))?;
303
304 serde_json::from_str(&json).map_err(|e| format!("Failed to deserialize snapshot: {}", e))
305 }
306}
307
308impl Default for StateManager {
309 fn default() -> Self {
310 Self::new()
311 }
312}
313
314pub trait StatefulPlugin {
316 fn create_snapshot(&self) -> StateSnapshot;
318
319 fn restore_snapshot(&mut self, snapshot: &StateSnapshot) -> Result<(), String>;
321
322 fn is_snapshot_compatible(&self, snapshot: &StateSnapshot) -> bool {
324 snapshot.is_compatible(&self.plugin_version())
325 }
326
327 fn plugin_version(&self) -> String;
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_plugin_state_display() {
        assert_eq!(PluginState::Unloaded.to_string(), "Unloaded");
        assert_eq!(PluginState::Running.to_string(), "Running");
        assert_eq!(
            PluginState::Failed("test".to_string()).to_string(),
            "Failed: test"
        );
    }

    #[test]
    fn test_plugin_state_default_is_unloaded() {
        assert_eq!(PluginState::default(), PluginState::Unloaded);
    }

    #[test]
    fn test_state_snapshot() {
        let snapshot = StateSnapshot::new("test-plugin", "1.0.0")
            .with_data("counter", &42)
            .with_data("name", &"test")
            .with_metadata("author", "developer");

        assert_eq!(snapshot.plugin_id, "test-plugin");
        assert_eq!(snapshot.plugin_version, "1.0.0");
        assert_eq!(snapshot.get::<i32>("counter"), Some(42));
        assert_eq!(snapshot.get::<String>("name"), Some("test".to_string()));
        assert_eq!(snapshot.get_metadata("author"), Some("developer"));
    }

    #[test]
    fn test_snapshot_get_missing_or_mismatched() {
        let snapshot = StateSnapshot::new("test", "1.0.0").with_data("counter", &42);

        // Absent keys and type mismatches both yield `None` rather than panicking.
        assert_eq!(snapshot.get::<i32>("absent"), None);
        assert_eq!(snapshot.get::<String>("counter"), None);
        assert_eq!(snapshot.get_metadata("absent"), None);
    }

    #[test]
    fn test_snapshot_compatibility() {
        let snapshot = StateSnapshot::new("test", "1.0.0");

        assert!(snapshot.is_compatible("1.0.0"));
        assert!(snapshot.is_compatible("1.1.0"));
        assert!(snapshot.is_compatible("1.2.3"));
        assert!(!snapshot.is_compatible("2.0.0"));
    }

    #[test]
    fn test_snapshot_serialization() {
        let snapshot = StateSnapshot::new("test", "1.0.0").with_data("value", &123);

        let bytes = snapshot.to_bytes().unwrap();
        let restored = StateSnapshot::from_bytes(&bytes).unwrap();

        assert_eq!(restored.plugin_id, snapshot.plugin_id);
        assert_eq!(restored.get::<i32>("value"), Some(123));
    }

    #[tokio::test]
    async fn test_state_manager() {
        let manager = StateManager::new();

        let snapshot = StateSnapshot::new("plugin-1", "1.0.0").with_data("state", &"active");

        manager.save_snapshot(snapshot.clone()).await.unwrap();

        let loaded = manager.load_snapshot("plugin-1").await.unwrap();
        assert_eq!(loaded.plugin_id, "plugin-1");
        assert_eq!(loaded.get::<String>("state"), Some("active".to_string()));
    }

    #[tokio::test]
    async fn test_state_manager_rollback() {
        let manager = StateManager::new();

        let snapshot1 = StateSnapshot::new("plugin-1", "1.0.0").with_data("version", &1);
        manager.save_snapshot(snapshot1).await.unwrap();

        let snapshot2 = StateSnapshot::new("plugin-1", "1.0.0").with_data("version", &2);
        manager.save_snapshot(snapshot2).await.unwrap();

        let current = manager.load_snapshot("plugin-1").await.unwrap();
        assert_eq!(current.get::<i32>("version"), Some(2));

        let rolled_back = manager.rollback("plugin-1").await.unwrap();
        assert_eq!(rolled_back.get::<i32>("version"), Some(1));

        let current = manager.load_snapshot("plugin-1").await.unwrap();
        assert_eq!(current.get::<i32>("version"), Some(1));
    }

    #[tokio::test]
    async fn test_rollback_without_history_returns_none() {
        let manager = StateManager::new();
        assert!(manager.rollback("missing").await.is_none());

        // A single save has nothing to supersede, so there is no history yet.
        manager
            .save_snapshot(StateSnapshot::new("plugin-1", "1.0.0"))
            .await
            .unwrap();
        assert!(manager.rollback("plugin-1").await.is_none());
        assert!(manager.load_snapshot("plugin-1").await.is_some());
    }

    #[tokio::test]
    async fn test_history_is_bounded() {
        let manager = StateManager::new().with_max_history(2);

        for version in 1..=4 {
            let snapshot = StateSnapshot::new("plugin-1", "1.0.0").with_data("version", &version);
            manager.save_snapshot(snapshot).await.unwrap();
        }

        // Versions 1..=3 were superseded; only the newest two are retained, oldest first.
        let history: Vec<i32> = manager
            .get_history("plugin-1")
            .await
            .iter()
            .filter_map(|s| s.get::<i32>("version"))
            .collect();
        assert_eq!(history, vec![2, 3]);

        let current = manager.load_snapshot("plugin-1").await.unwrap();
        assert_eq!(current.get::<i32>("version"), Some(4));
    }

    #[tokio::test]
    async fn test_clear_and_clear_all() {
        let manager = StateManager::new();
        for id in ["plugin-1", "plugin-1", "plugin-2"] {
            manager
                .save_snapshot(StateSnapshot::new(id, "1.0.0"))
                .await
                .unwrap();
        }

        manager.clear("plugin-1").await;
        assert!(manager.load_snapshot("plugin-1").await.is_none());
        assert!(manager.get_history("plugin-1").await.is_empty());
        assert_eq!(manager.plugin_ids().await, vec!["plugin-2".to_string()]);

        manager.clear_all().await;
        assert!(manager.plugin_ids().await.is_empty());
    }
}