1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::{Mutex, RwLock};
10
11use crate::error::Error;
12
13pub struct StateManager {
15 current_context: Arc<RwLock<ExecutionContext>>,
17 snapshots: Arc<Mutex<Vec<ContextSnapshot>>>,
19 persistent_storage: Arc<Mutex<HashMap<String, serde_json::Value>>>,
21 memory_tracker: Arc<Mutex<MemoryTracker>>,
23 config: StateManagerConfig,
25}
26
27impl Default for StateManager {
28 fn default() -> Self {
29 Self::new()
30 }
31}
32
33impl StateManager {
34 pub fn new() -> Self {
35 Self {
36 current_context: Arc::new(RwLock::new(ExecutionContext::new())),
37 snapshots: Arc::new(Mutex::new(Vec::new())),
38 persistent_storage: Arc::new(Mutex::new(HashMap::new())),
39 memory_tracker: Arc::new(Mutex::new(MemoryTracker::new())),
40 config: StateManagerConfig::default(),
41 }
42 }
43
44 pub async fn initialize_context(&self, initial_state: &serde_json::Value) -> Result<(), Error> {
46 {
47 let mut context = self.current_context.write().await;
48 context.initialize(initial_state).await?;
49 }
50
51 self.create_snapshot().await?;
52
53 tracing::info!(
54 "Execution context initialized with {} bytes of initial state",
55 serde_json::to_string(initial_state)?.len()
56 );
57
58 Ok(())
59 }
60
61 pub async fn update_context(
63 &self,
64 updates: &HashMap<String, serde_json::Value>,
65 ) -> Result<(), Error> {
66 let mut context = self.current_context.write().await;
67
68 for (key, value) in updates {
69 context.update_variable(key, value).await?;
70 }
71
72 tracing::debug!("Context updated with {} variables", updates.len());
73 Ok(())
74 }
75
76 pub async fn get_current_context(&self) -> Result<serde_json::Value, Error> {
78 let context = self.current_context.read().await;
79 context.serialize().await
80 }
81
82 pub async fn store_persistent(&self, key: &str, value: serde_json::Value) -> Result<(), Error> {
84 let mut storage = self.persistent_storage.lock().await;
85 storage.insert(key.to_string(), value.clone());
86
87 tracing::debug!(
88 "Stored persistent data: {} ({} bytes)",
89 key,
90 serde_json::to_string(&value)?.len()
91 );
92
93 Ok(())
94 }
95
96 pub async fn get_persistent(&self, key: &str) -> Result<Option<serde_json::Value>, Error> {
98 let storage = self.persistent_storage.lock().await;
99 Ok(storage.get(key).cloned())
100 }
101
102 pub async fn create_snapshot(&self) -> Result<ContextSnapshot, Error> {
104 let (execution_context, tool_call_count) = {
105 let context = self.current_context.read().await;
106 (context.serialize().await?, context.get_tool_call_count())
107 };
108
109 let persistent_data = {
110 let storage = self.persistent_storage.lock().await;
111 storage.clone()
112 };
113
114 let context_size_bytes = serde_json::to_string(&execution_context)?.len();
115 let persistent_storage_size_bytes = serde_json::to_string(&persistent_data)?.len();
116 let total_size_bytes = context_size_bytes + persistent_storage_size_bytes;
117
118 let (memory_efficiency, peak_usage_mb) = {
119 let tracker = self.memory_tracker.lock().await;
120 (tracker.calculate_efficiency(), tracker.peak_usage_mb)
121 };
122
123 let current_usage_mb = total_size_bytes as f64 / 1_048_576.0;
124
125 let snapshot = ContextSnapshot {
126 id: format!("snapshot_{}", chrono::Utc::now().timestamp()),
127 timestamp: chrono::Utc::now().timestamp(),
128 execution_context,
129 persistent_data,
130 memory_usage: MemoryUsage {
131 context_size_bytes,
132 persistent_storage_size_bytes,
133 total_size_bytes,
134 memory_efficiency,
135 peak_usage_mb,
136 current_usage_mb,
137 },
138 tool_call_count,
139 checkpoint_metadata: HashMap::new(),
140 compressed: false,
141 compression_ratio: 1.0,
142 };
143
144 let mut snapshots = self.snapshots.lock().await;
146 snapshots.push(snapshot.clone());
147
148 if snapshots.len() > self.config.max_snapshots {
150 let removed = snapshots.remove(0);
151 tracing::debug!("Removed old snapshot: {}", removed.id);
152 }
153
154 {
156 let mut tracker = self.memory_tracker.lock().await;
157 tracker.record_snapshot(&snapshot);
158 }
159
160 tracing::info!(
161 "Created snapshot {} with {} bytes of data",
162 snapshot.id,
163 serde_json::to_string(&snapshot.execution_context)?.len()
164 );
165
166 Ok(snapshot)
167 }
168
169 pub async fn restore_snapshot(&self, snapshot_id: &str) -> Result<(), Error> {
171 let snapshots = self.snapshots.lock().await;
172
173 let snapshot = snapshots
174 .iter()
175 .find(|s| s.id == snapshot_id)
176 .ok_or_else(|| Error::Validation(format!("Snapshot '{}' not found", snapshot_id)))?;
177
178 {
180 let mut context = self.current_context.write().await;
181 context.deserialize(&snapshot.execution_context).await?;
182 }
183
184 {
186 let mut storage = self.persistent_storage.lock().await;
187 *storage = snapshot.persistent_data.clone();
188 }
189
190 tracing::info!("Restored from snapshot: {}", snapshot_id);
191 Ok(())
192 }
193
194 pub async fn get_latest_snapshot(&self) -> Result<Option<ContextSnapshot>, Error> {
196 let snapshots = self.snapshots.lock().await;
197 Ok(snapshots.last().cloned())
198 }
199
200 pub async fn optimize_memory(&self) -> Result<MemoryOptimizationResult, Error> {
202 let mut storage = self.persistent_storage.lock().await;
203 let mut snapshots = self.snapshots.lock().await;
204
205 let mut compression_count = 0;
206 let mut original_size = 0;
207 let mut compressed_size = 0;
208
209 for snapshot in snapshots.iter_mut() {
211 if snapshot.timestamp < chrono::Utc::now().timestamp() - 3600 {
212 let serialized = serde_json::to_string(&snapshot.execution_context)?;
213 original_size += serialized.len();
214
215 let compressed = base64::Engine::encode(
217 &base64::engine::general_purpose::STANDARD,
218 serialized.as_bytes(),
219 );
220 compressed_size += compressed.len();
221
222 snapshot.compressed = true;
223 snapshot.compression_ratio = if original_size > 0 {
224 compressed_size as f64 / original_size as f64
225 } else {
226 1.0
227 };
228
229 compression_count += 1;
230 }
231 }
232
233 let before_count = storage.len();
235 storage.retain(|_key, value| {
236 let expire_time = chrono::Utc::now().timestamp() - self.config.data_ttl_seconds;
237 value
238 .get("timestamp")
239 .and_then(|ts| ts.as_i64())
240 .map(|ts| ts > expire_time)
241 .unwrap_or(true)
242 });
243 let after_count = storage.len();
244 let cleaned_count = before_count - after_count;
245
246 let result = MemoryOptimizationResult {
247 compressed_snapshots: compression_count,
248 cleaned_data_items: cleaned_count as u32,
249 memory_saved_mb: ((original_size - compressed_size) as f64 / 1_048_576.0).max(0.0),
250 optimization_timestamp: chrono::Utc::now().timestamp(),
251 };
252
253 tracing::info!(
254 "Memory optimization completed: {} snapshots compressed, {} data items cleaned",
255 compression_count,
256 cleaned_count
257 );
258
259 Ok(result)
260 }
261
262 pub async fn get_current_memory_usage(&self) -> Result<MemoryUsage, Error> {
264 let context = self.current_context.read().await;
265 let storage = self.persistent_storage.lock().await;
266 let tracker = self.memory_tracker.lock().await;
267
268 let context_size = serde_json::to_string(&context.serialize().await?)?.len();
269 let storage_size = serde_json::to_string(&*storage)?.len();
270 let total_size = context_size + storage_size;
271
272 Ok(MemoryUsage {
273 context_size_bytes: context_size,
274 persistent_storage_size_bytes: storage_size,
275 total_size_bytes: total_size,
276 memory_efficiency: tracker.calculate_efficiency(),
277 peak_usage_mb: tracker.peak_usage_mb,
278 current_usage_mb: total_size as f64 / 1_048_576.0,
279 })
280 }
281
282 pub async fn cleanup_expired_data(&self) -> Result<u32, Error> {
284 let mut storage = self.persistent_storage.lock().await;
285 let expire_time = chrono::Utc::now().timestamp() - self.config.data_ttl_seconds;
286
287 let before_count = storage.len();
288 storage.retain(|_key, value| {
289 value
290 .get("timestamp")
291 .and_then(|ts| ts.as_i64())
292 .map(|ts| ts > expire_time)
293 .unwrap_or(true)
294 });
295
296 let cleaned_count = before_count - storage.len();
297
298 tracing::debug!("Cleaned up {} expired data items", cleaned_count);
299 Ok(cleaned_count as u32)
300 }
301}
302
303#[derive(Debug)]
305struct ExecutionContext {
306 tool_call_count: u32,
308 shared_variables: HashMap<String, serde_json::Value>,
310 metadata: HashMap<String, serde_json::Value>,
312 component_states: HashMap<String, ComponentState>,
314 #[allow(dead_code)]
316 context_cache: ContextCache,
317 created_at: u64,
319}
320
321impl ExecutionContext {
322 fn new() -> Self {
323 Self {
324 tool_call_count: 0,
325 shared_variables: HashMap::new(),
326 metadata: HashMap::new(),
327 component_states: HashMap::new(),
328 context_cache: ContextCache::new(),
329 created_at: chrono::Utc::now().timestamp() as u64,
330 }
331 }
332
333 async fn initialize(&mut self, initial_state: &serde_json::Value) -> Result<(), Error> {
335 if let Some(variables) = initial_state.get("variables") {
336 if let Ok(vars) =
337 serde_json::from_value::<HashMap<String, serde_json::Value>>(variables.clone())
338 {
339 self.shared_variables = vars;
340 }
341 }
342
343 if let Some(metadata) = initial_state.get("metadata") {
344 if let Ok(meta) =
345 serde_json::from_value::<HashMap<String, serde_json::Value>>(metadata.clone())
346 {
347 self.metadata = meta;
348 }
349 }
350
351 Ok(())
352 }
353
354 async fn update_variable(&mut self, key: &str, value: &serde_json::Value) -> Result<(), Error> {
356 self.shared_variables.insert(key.to_string(), value.clone());
357 Ok(())
358 }
359
360 fn get_tool_call_count(&self) -> u32 {
362 self.tool_call_count
363 }
364
365 #[allow(dead_code)]
367 fn increment_tool_call_count(&mut self) {
368 self.tool_call_count += 1;
369 }
370
371 async fn serialize(&self) -> Result<serde_json::Value, Error> {
373 Ok(serde_json::json!({
374 "tool_call_count": self.tool_call_count,
375 "shared_variables": self.shared_variables,
376 "metadata": self.metadata,
377 "component_states": self.component_states,
378 "created_at": self.created_at,
379 }))
380 }
381
382 async fn deserialize(&mut self, data: &serde_json::Value) -> Result<(), Error> {
384 if let Some(tool_call_count) = data.get("tool_call_count").and_then(|v| v.as_u64()) {
385 self.tool_call_count = tool_call_count as u32;
386 }
387
388 if let Some(variables) = data.get("shared_variables") {
389 if let Ok(vars) =
390 serde_json::from_value::<HashMap<String, serde_json::Value>>(variables.clone())
391 {
392 self.shared_variables = vars;
393 }
394 }
395
396 if let Some(metadata) = data.get("metadata") {
397 if let Ok(meta) =
398 serde_json::from_value::<HashMap<String, serde_json::Value>>(metadata.clone())
399 {
400 self.metadata = meta;
401 }
402 }
403
404 if let Some(component_states) = data.get("component_states") {
405 if let Ok(states) =
406 serde_json::from_value::<HashMap<String, ComponentState>>(component_states.clone())
407 {
408 self.component_states = states;
409 }
410 }
411
412 Ok(())
413 }
414}
415
416#[derive(Debug, Clone, Serialize, Deserialize)]
418struct ComponentState {
419 pub component_name: String,
420 pub state_data: serde_json::Value,
421 pub last_updated: u64,
422 pub access_count: u32,
423}
424
425#[derive(Debug)]
427#[allow(dead_code)]
428struct ContextCache {
429 #[allow(dead_code)]
431 lru_cache: HashMap<String, serde_json::Value>,
432 #[allow(dead_code)]
434 capacity: usize,
435 #[allow(dead_code)]
437 current_size: usize,
438}
439
440#[allow(dead_code)]
441impl ContextCache {
442 fn new() -> Self {
443 Self {
444 lru_cache: HashMap::new(),
445 capacity: 100, current_size: 0,
447 }
448 }
449
450 fn add(&mut self, key: &str, value: serde_json::Value) {
452 if self.lru_cache.len() >= self.capacity && !self.lru_cache.contains_key(key) {
453 if let Some(key_to_remove) = self.lru_cache.keys().next().cloned() {
455 if let Some(removed_value) = self.lru_cache.remove(&key_to_remove) {
456 self.current_size -= serde_json::to_string(&removed_value)
457 .unwrap_or_default()
458 .len();
459 }
460 }
461 }
462
463 self.lru_cache.insert(key.to_string(), value);
464 self.current_size += key.len(); }
466
467 fn get(&self, key: &str) -> Option<&serde_json::Value> {
469 self.lru_cache.get(key)
470 }
471}
472
473#[derive(Debug)]
475struct MemoryTracker {
476 peak_usage_mb: f64,
477 usage_history: Vec<MemorySample>,
478 #[allow(dead_code)]
479 optimization_threshold_mb: f64,
480}
481
482impl MemoryTracker {
483 fn new() -> Self {
484 Self {
485 peak_usage_mb: 0.0,
486 usage_history: Vec::new(),
487 optimization_threshold_mb: 100.0, }
489 }
490
491 #[allow(dead_code)]
493 fn record_sample(&mut self, usage: &MemoryUsage) {
494 let sample = MemorySample {
495 timestamp: chrono::Utc::now().timestamp(),
496 usage_mb: usage.current_usage_mb,
497 };
498
499 self.usage_history.push(sample);
500 self.peak_usage_mb = self.peak_usage_mb.max(usage.current_usage_mb);
501
502 if self.usage_history.len() > 1000 {
504 self.usage_history.remove(0);
505 }
506 }
507
508 fn record_snapshot(&mut self, snapshot: &ContextSnapshot) {
510 let snapshot_size_mb = snapshot.memory_usage.current_usage_mb;
512 self.peak_usage_mb = self.peak_usage_mb.max(snapshot_size_mb);
513 }
514
515 fn calculate_efficiency(&self) -> f64 {
517 if self.peak_usage_mb == 0.0 {
518 return 1.0;
519 }
520
521 let current_usage = self.usage_history.last().map(|s| s.usage_mb).unwrap_or(0.0);
523
524 (self.peak_usage_mb / current_usage.max(self.peak_usage_mb)).min(1.0)
525 }
526}
527
/// Tunable limits for a `StateManager`.
#[derive(Debug, Clone)]
pub struct StateManagerConfig {
    /// Maximum number of snapshots kept in the history.
    pub max_snapshots: usize,
    /// Age in seconds after which a timestamped persistent entry is expired.
    pub data_ttl_seconds: i64,
    /// Memory limit in MB; not read anywhere in the visible code.
    pub memory_limit_mb: u64,
    /// Auto-optimization switch; not read anywhere in the visible code.
    pub auto_optimize: bool,
    /// Checkpoint interval; not read anywhere in the visible code, and its
    /// unit is not established here.
    pub checkpoint_interval: u32,
}
537
538impl Default for StateManagerConfig {
539 fn default() -> Self {
540 Self {
541 max_snapshots: 50,
542 data_ttl_seconds: 3600, memory_limit_mb: 1024, auto_optimize: true,
545 checkpoint_interval: 10, }
547 }
548}
549
550#[derive(Debug, Clone, Serialize, Deserialize)]
552pub struct ContextSnapshot {
553 pub id: String,
554 pub timestamp: i64,
555 pub execution_context: serde_json::Value,
556 pub persistent_data: HashMap<String, serde_json::Value>,
557 pub memory_usage: MemoryUsage,
558 pub tool_call_count: u32,
559 pub checkpoint_metadata: HashMap<String, serde_json::Value>,
560 pub compressed: bool,
562 pub compression_ratio: f64,
563}
564
565#[derive(Debug, Clone, Serialize, Deserialize)]
566pub struct MemoryUsage {
567 pub context_size_bytes: usize,
568 pub persistent_storage_size_bytes: usize,
569 pub total_size_bytes: usize,
570 pub memory_efficiency: f64,
571 pub peak_usage_mb: f64,
572 pub current_usage_mb: f64,
573}
574
/// One entry of the memory tracker's usage history.
#[derive(Debug)]
struct MemorySample {
    /// Unix timestamp (seconds) at which the sample was recorded; stored but
    /// not read anywhere in this file.
    #[allow(dead_code)]
    timestamp: i64,
    /// Usage at that moment, in MB.
    usage_mb: f64,
}
581
/// Summary returned by a memory optimization pass.
#[derive(Debug)]
pub struct MemoryOptimizationResult {
    /// Number of snapshots flagged as compressed during the pass.
    pub compressed_snapshots: u32,
    /// Number of persistent entries removed as expired.
    pub cleaned_data_items: u32,
    /// Difference between original and encoded snapshot sizes, in MB.
    pub memory_saved_mb: f64,
    /// Unix timestamp (seconds) at which the pass produced this result.
    pub optimization_timestamp: i64,
}
589
590#[async_trait::async_trait]
592pub trait StatePersistence {
593 async fn save_state(&self, state: &serde_json::Value) -> Result<String, Error>;
594 async fn load_state(&self, state_id: &str) -> Result<serde_json::Value, Error>;
595 async fn delete_state(&self, state_id: &str) -> Result<(), Error>;
596 async fn list_states(&self) -> Result<Vec<String>, Error>;
597}
598
/// Smoke tests for `StateManager`.
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built manager can serialize its context.
    #[tokio::test]
    async fn test_state_manager_creation() {
        let state_manager = StateManager::new();
        let context = state_manager.get_current_context().await;
        assert!(context.is_ok());
    }

    /// Initializing from a state with variables and metadata succeeds.
    #[tokio::test]
    async fn test_context_initialization() {
        let state_manager = StateManager::new();
        let seed = serde_json::json!({
            "variables": {
                "user_id": "12345",
                "session_type": "analysis"
            },
            "metadata": {
                "created_by": "test",
                "version": "1.0"
            }
        });

        let outcome = state_manager.initialize_context(&seed).await;
        assert!(outcome.is_ok());
    }

    /// A stored value is returned unchanged by a lookup of the same key.
    #[tokio::test]
    async fn test_persistent_storage() {
        let state_manager = StateManager::new();
        let payload = serde_json::json!({"key": "value", "timestamp": 1234567890});

        let stored = state_manager
            .store_persistent("test_key", payload.clone())
            .await;
        assert!(stored.is_ok());

        let fetched = state_manager.get_persistent("test_key").await.unwrap();
        assert_eq!(fetched, Some(payload));
    }

    /// A snapshot of a just-initialized context has an id and zero tool calls.
    #[tokio::test]
    async fn test_snapshot_creation() {
        let state_manager = StateManager::new();
        let seed = serde_json::json!({"test": "data"});

        state_manager.initialize_context(&seed).await.unwrap();
        let taken = state_manager.create_snapshot().await.unwrap();

        assert!(!taken.id.is_empty());
        assert_eq!(taken.tool_call_count, 0);
    }

    /// Usage figures of a fresh manager are non-negative and efficiency lies
    /// in `[0, 1]`.
    #[tokio::test]
    async fn test_memory_usage_tracking() {
        let state_manager = StateManager::new();
        let report = state_manager.get_current_memory_usage().await.unwrap();

        assert!(report.current_usage_mb >= 0.0);
        assert!(report.memory_efficiency >= 0.0 && report.memory_efficiency <= 1.0);
    }
}