reasonkit/orchestration/
state_manager.rs

1//! # State Management for Long-Horizon Execution
2//!
3//! This module provides advanced state management capabilities for maintaining context,
4//! memory, and execution state across extended tool calling sequences (100+ calls).
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::{Mutex, RwLock};
10
11use crate::error::Error;
12
13/// Global execution state manager
14pub struct StateManager {
15    /// Current execution context
16    current_context: Arc<RwLock<ExecutionContext>>,
17    /// Historical state snapshots for recovery
18    snapshots: Arc<Mutex<Vec<ContextSnapshot>>>,
19    /// Persistent storage for long-term state
20    persistent_storage: Arc<Mutex<HashMap<String, serde_json::Value>>>,
21    /// Memory tracking for context optimization
22    memory_tracker: Arc<Mutex<MemoryTracker>>,
23    /// Configuration
24    config: StateManagerConfig,
25}
26
27impl Default for StateManager {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33impl StateManager {
34    pub fn new() -> Self {
35        Self {
36            current_context: Arc::new(RwLock::new(ExecutionContext::new())),
37            snapshots: Arc::new(Mutex::new(Vec::new())),
38            persistent_storage: Arc::new(Mutex::new(HashMap::new())),
39            memory_tracker: Arc::new(Mutex::new(MemoryTracker::new())),
40            config: StateManagerConfig::default(),
41        }
42    }
43
44    /// Initialize the execution context with initial state
45    pub async fn initialize_context(&self, initial_state: &serde_json::Value) -> Result<(), Error> {
46        {
47            let mut context = self.current_context.write().await;
48            context.initialize(initial_state).await?;
49        }
50
51        self.create_snapshot().await?;
52
53        tracing::info!(
54            "Execution context initialized with {} bytes of initial state",
55            serde_json::to_string(initial_state)?.len()
56        );
57
58        Ok(())
59    }
60
61    /// Update the current execution context
62    pub async fn update_context(
63        &self,
64        updates: &HashMap<String, serde_json::Value>,
65    ) -> Result<(), Error> {
66        let mut context = self.current_context.write().await;
67
68        for (key, value) in updates {
69            context.update_variable(key, value).await?;
70        }
71
72        tracing::debug!("Context updated with {} variables", updates.len());
73        Ok(())
74    }
75
76    /// Get current context snapshot
77    pub async fn get_current_context(&self) -> Result<serde_json::Value, Error> {
78        let context = self.current_context.read().await;
79        context.serialize().await
80    }
81
82    /// Store persistent data that survives across tool calls
83    pub async fn store_persistent(&self, key: &str, value: serde_json::Value) -> Result<(), Error> {
84        let mut storage = self.persistent_storage.lock().await;
85        storage.insert(key.to_string(), value.clone());
86
87        tracing::debug!(
88            "Stored persistent data: {} ({} bytes)",
89            key,
90            serde_json::to_string(&value)?.len()
91        );
92
93        Ok(())
94    }
95
96    /// Retrieve persistent data
97    pub async fn get_persistent(&self, key: &str) -> Result<Option<serde_json::Value>, Error> {
98        let storage = self.persistent_storage.lock().await;
99        Ok(storage.get(key).cloned())
100    }
101
102    /// Create a checkpoint snapshot of current state
103    pub async fn create_snapshot(&self) -> Result<ContextSnapshot, Error> {
104        let (execution_context, tool_call_count) = {
105            let context = self.current_context.read().await;
106            (context.serialize().await?, context.get_tool_call_count())
107        };
108
109        let persistent_data = {
110            let storage = self.persistent_storage.lock().await;
111            storage.clone()
112        };
113
114        let context_size_bytes = serde_json::to_string(&execution_context)?.len();
115        let persistent_storage_size_bytes = serde_json::to_string(&persistent_data)?.len();
116        let total_size_bytes = context_size_bytes + persistent_storage_size_bytes;
117
118        let (memory_efficiency, peak_usage_mb) = {
119            let tracker = self.memory_tracker.lock().await;
120            (tracker.calculate_efficiency(), tracker.peak_usage_mb)
121        };
122
123        let current_usage_mb = total_size_bytes as f64 / 1_048_576.0;
124
125        let snapshot = ContextSnapshot {
126            id: format!("snapshot_{}", chrono::Utc::now().timestamp()),
127            timestamp: chrono::Utc::now().timestamp(),
128            execution_context,
129            persistent_data,
130            memory_usage: MemoryUsage {
131                context_size_bytes,
132                persistent_storage_size_bytes,
133                total_size_bytes,
134                memory_efficiency,
135                peak_usage_mb,
136                current_usage_mb,
137            },
138            tool_call_count,
139            checkpoint_metadata: HashMap::new(),
140            compressed: false,
141            compression_ratio: 1.0,
142        };
143
144        // Add to snapshots history
145        let mut snapshots = self.snapshots.lock().await;
146        snapshots.push(snapshot.clone());
147
148        // Maintain snapshot limit
149        if snapshots.len() > self.config.max_snapshots {
150            let removed = snapshots.remove(0);
151            tracing::debug!("Removed old snapshot: {}", removed.id);
152        }
153
154        // Update memory tracking
155        {
156            let mut tracker = self.memory_tracker.lock().await;
157            tracker.record_snapshot(&snapshot);
158        }
159
160        tracing::info!(
161            "Created snapshot {} with {} bytes of data",
162            snapshot.id,
163            serde_json::to_string(&snapshot.execution_context)?.len()
164        );
165
166        Ok(snapshot)
167    }
168
169    /// Restore from a specific snapshot
170    pub async fn restore_snapshot(&self, snapshot_id: &str) -> Result<(), Error> {
171        let snapshots = self.snapshots.lock().await;
172
173        let snapshot = snapshots
174            .iter()
175            .find(|s| s.id == snapshot_id)
176            .ok_or_else(|| Error::Validation(format!("Snapshot '{}' not found", snapshot_id)))?;
177
178        // Restore execution context
179        {
180            let mut context = self.current_context.write().await;
181            context.deserialize(&snapshot.execution_context).await?;
182        }
183
184        // Restore persistent storage
185        {
186            let mut storage = self.persistent_storage.lock().await;
187            *storage = snapshot.persistent_data.clone();
188        }
189
190        tracing::info!("Restored from snapshot: {}", snapshot_id);
191        Ok(())
192    }
193
194    /// Get the most recent snapshot
195    pub async fn get_latest_snapshot(&self) -> Result<Option<ContextSnapshot>, Error> {
196        let snapshots = self.snapshots.lock().await;
197        Ok(snapshots.last().cloned())
198    }
199
200    /// Optimize memory usage by compressing old snapshots
201    pub async fn optimize_memory(&self) -> Result<MemoryOptimizationResult, Error> {
202        let mut storage = self.persistent_storage.lock().await;
203        let mut snapshots = self.snapshots.lock().await;
204
205        let mut compression_count = 0;
206        let mut original_size = 0;
207        let mut compressed_size = 0;
208
209        // Compress old snapshots if needed
210        for snapshot in snapshots.iter_mut() {
211            if snapshot.timestamp < chrono::Utc::now().timestamp() - 3600 {
212                let serialized = serde_json::to_string(&snapshot.execution_context)?;
213                original_size += serialized.len();
214
215                // Simple compression (in real implementation, use proper compression)
216                let compressed = base64::Engine::encode(
217                    &base64::engine::general_purpose::STANDARD,
218                    serialized.as_bytes(),
219                );
220                compressed_size += compressed.len();
221
222                snapshot.compressed = true;
223                snapshot.compression_ratio = if original_size > 0 {
224                    compressed_size as f64 / original_size as f64
225                } else {
226                    1.0
227                };
228
229                compression_count += 1;
230            }
231        }
232
233        // Clean up expired persistent data
234        let before_count = storage.len();
235        storage.retain(|_key, value| {
236            let expire_time = chrono::Utc::now().timestamp() - self.config.data_ttl_seconds;
237            value
238                .get("timestamp")
239                .and_then(|ts| ts.as_i64())
240                .map(|ts| ts > expire_time)
241                .unwrap_or(true)
242        });
243        let after_count = storage.len();
244        let cleaned_count = before_count - after_count;
245
246        let result = MemoryOptimizationResult {
247            compressed_snapshots: compression_count,
248            cleaned_data_items: cleaned_count as u32,
249            memory_saved_mb: ((original_size - compressed_size) as f64 / 1_048_576.0).max(0.0),
250            optimization_timestamp: chrono::Utc::now().timestamp(),
251        };
252
253        tracing::info!(
254            "Memory optimization completed: {} snapshots compressed, {} data items cleaned",
255            compression_count,
256            cleaned_count
257        );
258
259        Ok(result)
260    }
261
262    /// Get current memory usage statistics
263    pub async fn get_current_memory_usage(&self) -> Result<MemoryUsage, Error> {
264        let context = self.current_context.read().await;
265        let storage = self.persistent_storage.lock().await;
266        let tracker = self.memory_tracker.lock().await;
267
268        let context_size = serde_json::to_string(&context.serialize().await?)?.len();
269        let storage_size = serde_json::to_string(&*storage)?.len();
270        let total_size = context_size + storage_size;
271
272        Ok(MemoryUsage {
273            context_size_bytes: context_size,
274            persistent_storage_size_bytes: storage_size,
275            total_size_bytes: total_size,
276            memory_efficiency: tracker.calculate_efficiency(),
277            peak_usage_mb: tracker.peak_usage_mb,
278            current_usage_mb: total_size as f64 / 1_048_576.0,
279        })
280    }
281
282    /// Clean up expired data
283    pub async fn cleanup_expired_data(&self) -> Result<u32, Error> {
284        let mut storage = self.persistent_storage.lock().await;
285        let expire_time = chrono::Utc::now().timestamp() - self.config.data_ttl_seconds;
286
287        let before_count = storage.len();
288        storage.retain(|_key, value| {
289            value
290                .get("timestamp")
291                .and_then(|ts| ts.as_i64())
292                .map(|ts| ts > expire_time)
293                .unwrap_or(true)
294        });
295
296        let cleaned_count = before_count - storage.len();
297
298        tracing::debug!("Cleaned up {} expired data items", cleaned_count);
299        Ok(cleaned_count as u32)
300    }
301}
302
303/// Execution context that maintains state across tool calls
304#[derive(Debug)]
305struct ExecutionContext {
306    /// Current tool call sequence number
307    tool_call_count: u32,
308    /// Shared variables accessible across tool calls
309    shared_variables: HashMap<String, serde_json::Value>,
310    /// Execution metadata
311    metadata: HashMap<String, serde_json::Value>,
312    /// Component-specific state
313    component_states: HashMap<String, ComponentState>,
314    /// Memory-efficient context cache
315    #[allow(dead_code)]
316    context_cache: ContextCache,
317    /// Created timestamp
318    created_at: u64,
319}
320
321impl ExecutionContext {
322    fn new() -> Self {
323        Self {
324            tool_call_count: 0,
325            shared_variables: HashMap::new(),
326            metadata: HashMap::new(),
327            component_states: HashMap::new(),
328            context_cache: ContextCache::new(),
329            created_at: chrono::Utc::now().timestamp() as u64,
330        }
331    }
332
333    /// Initialize with provided state
334    async fn initialize(&mut self, initial_state: &serde_json::Value) -> Result<(), Error> {
335        if let Some(variables) = initial_state.get("variables") {
336            if let Ok(vars) =
337                serde_json::from_value::<HashMap<String, serde_json::Value>>(variables.clone())
338            {
339                self.shared_variables = vars;
340            }
341        }
342
343        if let Some(metadata) = initial_state.get("metadata") {
344            if let Ok(meta) =
345                serde_json::from_value::<HashMap<String, serde_json::Value>>(metadata.clone())
346            {
347                self.metadata = meta;
348            }
349        }
350
351        Ok(())
352    }
353
354    /// Update a shared variable
355    async fn update_variable(&mut self, key: &str, value: &serde_json::Value) -> Result<(), Error> {
356        self.shared_variables.insert(key.to_string(), value.clone());
357        Ok(())
358    }
359
360    /// Get tool call count
361    fn get_tool_call_count(&self) -> u32 {
362        self.tool_call_count
363    }
364
365    /// Increment tool call count
366    #[allow(dead_code)]
367    fn increment_tool_call_count(&mut self) {
368        self.tool_call_count += 1;
369    }
370
371    /// Serialize context for persistence
372    async fn serialize(&self) -> Result<serde_json::Value, Error> {
373        Ok(serde_json::json!({
374            "tool_call_count": self.tool_call_count,
375            "shared_variables": self.shared_variables,
376            "metadata": self.metadata,
377            "component_states": self.component_states,
378            "created_at": self.created_at,
379        }))
380    }
381
382    /// Deserialize context from snapshot
383    async fn deserialize(&mut self, data: &serde_json::Value) -> Result<(), Error> {
384        if let Some(tool_call_count) = data.get("tool_call_count").and_then(|v| v.as_u64()) {
385            self.tool_call_count = tool_call_count as u32;
386        }
387
388        if let Some(variables) = data.get("shared_variables") {
389            if let Ok(vars) =
390                serde_json::from_value::<HashMap<String, serde_json::Value>>(variables.clone())
391            {
392                self.shared_variables = vars;
393            }
394        }
395
396        if let Some(metadata) = data.get("metadata") {
397            if let Ok(meta) =
398                serde_json::from_value::<HashMap<String, serde_json::Value>>(metadata.clone())
399            {
400                self.metadata = meta;
401            }
402        }
403
404        if let Some(component_states) = data.get("component_states") {
405            if let Ok(states) =
406                serde_json::from_value::<HashMap<String, ComponentState>>(component_states.clone())
407            {
408                self.component_states = states;
409            }
410        }
411
412        Ok(())
413    }
414}
415
416/// Component-specific state
417#[derive(Debug, Clone, Serialize, Deserialize)]
418struct ComponentState {
419    pub component_name: String,
420    pub state_data: serde_json::Value,
421    pub last_updated: u64,
422    pub access_count: u32,
423}
424
425/// Context cache for memory optimization
426#[derive(Debug)]
427#[allow(dead_code)]
428struct ContextCache {
429    /// Frequently accessed data with LRU eviction
430    #[allow(dead_code)]
431    lru_cache: HashMap<String, serde_json::Value>,
432    /// Cache capacity
433    #[allow(dead_code)]
434    capacity: usize,
435    /// Current cache size
436    #[allow(dead_code)]
437    current_size: usize,
438}
439
440#[allow(dead_code)]
441impl ContextCache {
442    fn new() -> Self {
443        Self {
444            lru_cache: HashMap::new(),
445            capacity: 100, // Cache up to 100 items
446            current_size: 0,
447        }
448    }
449
450    /// Add item to cache
451    fn add(&mut self, key: &str, value: serde_json::Value) {
452        if self.lru_cache.len() >= self.capacity && !self.lru_cache.contains_key(key) {
453            // Remove least recently used item
454            if let Some(key_to_remove) = self.lru_cache.keys().next().cloned() {
455                if let Some(removed_value) = self.lru_cache.remove(&key_to_remove) {
456                    self.current_size -= serde_json::to_string(&removed_value)
457                        .unwrap_or_default()
458                        .len();
459                }
460            }
461        }
462
463        self.lru_cache.insert(key.to_string(), value);
464        self.current_size += key.len(); // Simplified size calculation
465    }
466
467    /// Get item from cache
468    fn get(&self, key: &str) -> Option<&serde_json::Value> {
469        self.lru_cache.get(key)
470    }
471}
472
473/// Memory tracker for optimization
474#[derive(Debug)]
475struct MemoryTracker {
476    peak_usage_mb: f64,
477    usage_history: Vec<MemorySample>,
478    #[allow(dead_code)]
479    optimization_threshold_mb: f64,
480}
481
482impl MemoryTracker {
483    fn new() -> Self {
484        Self {
485            peak_usage_mb: 0.0,
486            usage_history: Vec::new(),
487            optimization_threshold_mb: 100.0, // Trigger optimization at 100MB
488        }
489    }
490
491    /// Record a memory usage sample
492    #[allow(dead_code)]
493    fn record_sample(&mut self, usage: &MemoryUsage) {
494        let sample = MemorySample {
495            timestamp: chrono::Utc::now().timestamp(),
496            usage_mb: usage.current_usage_mb,
497        };
498
499        self.usage_history.push(sample);
500        self.peak_usage_mb = self.peak_usage_mb.max(usage.current_usage_mb);
501
502        // Keep only recent samples
503        if self.usage_history.len() > 1000 {
504            self.usage_history.remove(0);
505        }
506    }
507
508    /// Record snapshot creation
509    fn record_snapshot(&mut self, snapshot: &ContextSnapshot) {
510        // Update peak usage
511        let snapshot_size_mb = snapshot.memory_usage.current_usage_mb;
512        self.peak_usage_mb = self.peak_usage_mb.max(snapshot_size_mb);
513    }
514
515    /// Calculate memory efficiency
516    fn calculate_efficiency(&self) -> f64 {
517        if self.peak_usage_mb == 0.0 {
518            return 1.0;
519        }
520
521        // Efficiency based on how close current usage is to peak
522        let current_usage = self.usage_history.last().map(|s| s.usage_mb).unwrap_or(0.0);
523
524        (self.peak_usage_mb / current_usage.max(self.peak_usage_mb)).min(1.0)
525    }
526}
527
/// Tunable limits for the state manager.
#[derive(Debug, Clone)]
pub struct StateManagerConfig {
    /// Maximum number of snapshots kept in the history.
    pub max_snapshots: usize,
    /// Age in seconds after which timestamped persistent entries expire.
    pub data_ttl_seconds: i64,
    /// Memory limit in megabytes.
    pub memory_limit_mb: u64,
    /// Whether optimisation should run automatically.
    pub auto_optimize: bool,
    /// Number of tool calls between checkpoints.
    pub checkpoint_interval: u32,
}

impl Default for StateManagerConfig {
    /// 50 snapshots, a one-hour TTL, a 1 GB limit, auto-optimisation on and a
    /// checkpoint every 10 tool calls.
    fn default() -> Self {
        const ONE_HOUR_SECONDS: i64 = 60 * 60;
        const ONE_GB_IN_MB: u64 = 1024;
        const TOOL_CALLS_PER_CHECKPOINT: u32 = 10;

        StateManagerConfig {
            max_snapshots: 50,
            data_ttl_seconds: ONE_HOUR_SECONDS,
            memory_limit_mb: ONE_GB_IN_MB,
            auto_optimize: true,
            checkpoint_interval: TOOL_CALLS_PER_CHECKPOINT,
        }
    }
}
549
550/// Context snapshot for state persistence
551#[derive(Debug, Clone, Serialize, Deserialize)]
552pub struct ContextSnapshot {
553    pub id: String,
554    pub timestamp: i64,
555    pub execution_context: serde_json::Value,
556    pub persistent_data: HashMap<String, serde_json::Value>,
557    pub memory_usage: MemoryUsage,
558    pub tool_call_count: u32,
559    pub checkpoint_metadata: HashMap<String, serde_json::Value>,
560    /// Compression status
561    pub compressed: bool,
562    pub compression_ratio: f64,
563}
564
565#[derive(Debug, Clone, Serialize, Deserialize)]
566pub struct MemoryUsage {
567    pub context_size_bytes: usize,
568    pub persistent_storage_size_bytes: usize,
569    pub total_size_bytes: usize,
570    pub memory_efficiency: f64,
571    pub peak_usage_mb: f64,
572    pub current_usage_mb: f64,
573}
574
/// One usage measurement kept in the tracker's history.
#[derive(Debug)]
struct MemorySample {
    /// When the sample was taken, as UTC seconds since the Unix epoch.
    #[allow(dead_code)]
    timestamp: i64,
    /// Measured usage in megabytes.
    usage_mb: f64,
}
581
/// Summary returned by a memory optimisation pass.
#[derive(Debug)]
pub struct MemoryOptimizationResult {
    /// Number of snapshots marked as compressed during the pass.
    pub compressed_snapshots: u32,
    /// Number of persistent entries removed as expired.
    pub cleaned_data_items: u32,
    /// Reported saving in megabytes; never negative.
    pub memory_saved_mb: f64,
    /// When the pass ran, as UTC seconds since the Unix epoch.
    pub optimization_timestamp: i64,
}
589
590/// State persistence interface
591#[async_trait::async_trait]
592pub trait StatePersistence {
593    async fn save_state(&self, state: &serde_json::Value) -> Result<String, Error>;
594    async fn load_state(&self, state_id: &str) -> Result<serde_json::Value, Error>;
595    async fn delete_state(&self, state_id: &str) -> Result<(), Error>;
596    async fn list_states(&self) -> Result<Vec<String>, Error>;
597}
598
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built manager can serialise its context.
    #[tokio::test]
    async fn test_state_manager_creation() {
        let state_manager = StateManager::new();
        let context = state_manager.get_current_context().await;
        assert!(context.is_ok());
    }

    /// Initialising from `variables` and `metadata` sections succeeds.
    #[tokio::test]
    async fn test_context_initialization() {
        let state_manager = StateManager::new();
        let seed = serde_json::json!({
            "variables": {
                "user_id": "12345",
                "session_type": "analysis"
            },
            "metadata": {
                "created_by": "test",
                "version": "1.0"
            }
        });

        let outcome = state_manager.initialize_context(&seed).await;
        assert!(outcome.is_ok());
    }

    /// A stored value can be read back unchanged under the same key.
    #[tokio::test]
    async fn test_persistent_storage() {
        let state_manager = StateManager::new();
        let payload = serde_json::json!({"key": "value", "timestamp": 1234567890});

        let stored = state_manager
            .store_persistent("test_key", payload.clone())
            .await;
        assert!(stored.is_ok());

        let fetched = state_manager.get_persistent("test_key").await.unwrap();
        assert_eq!(fetched, Some(payload));
    }

    /// A snapshot has a non-empty id and starts with a zero tool call count.
    #[tokio::test]
    async fn test_snapshot_creation() {
        let state_manager = StateManager::new();
        let seed = serde_json::json!({"test": "data"});

        state_manager.initialize_context(&seed).await.unwrap();
        let checkpoint = state_manager.create_snapshot().await.unwrap();

        assert!(!checkpoint.id.is_empty());
        assert_eq!(checkpoint.tool_call_count, 0);
    }

    /// Usage figures of a fresh manager are non-negative and efficiency is in `[0, 1]`.
    #[tokio::test]
    async fn test_memory_usage_tracking() {
        let state_manager = StateManager::new();
        let report = state_manager.get_current_memory_usage().await.unwrap();

        assert!(report.current_usage_mb >= 0.0);
        assert!(report.memory_efficiency >= 0.0 && report.memory_efficiency <= 1.0);
    }
}