Skip to main content

mockforge_intelligence/intelligent_behavior/
context.rs

//! Stateful AI context management
//!
//! This module provides [`StatefulAiContext`], which maintains
//! conversation-like state across multiple API requests.

6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9use super::config::IntelligentBehaviorConfig;
10use super::memory::VectorMemoryStore;
11use super::types::{InteractionRecord, SessionState};
12use mockforge_foundation::Result;
13
14/// Stateful AI context manager
15///
16/// Tracks state across multiple requests within a session, maintaining
17/// conversation history and enabling intelligent, context-aware responses.
18#[derive(Clone)]
19pub struct StatefulAiContext {
20    /// Session ID
21    session_id: String,
22
23    /// Current session state
24    state: Arc<RwLock<SessionState>>,
25
26    /// Vector memory store for long-term semantic memory
27    memory_store: Option<Arc<VectorMemoryStore>>,
28
29    /// Configuration
30    config: IntelligentBehaviorConfig,
31}
32
33impl StatefulAiContext {
34    /// Create a new stateful AI context
35    pub fn new(session_id: impl Into<String>, config: IntelligentBehaviorConfig) -> Self {
36        let session_id = session_id.into();
37        let state = Arc::new(RwLock::new(SessionState::new(session_id.clone())));
38
39        Self {
40            session_id,
41            state,
42            memory_store: None,
43            config,
44        }
45    }
46
47    /// Create with vector memory store
48    pub fn with_memory_store(mut self, store: Arc<VectorMemoryStore>) -> Self {
49        self.memory_store = Some(store);
50        self
51    }
52
53    /// Get the session ID
54    pub fn session_id(&self) -> &str {
55        &self.session_id
56    }
57
58    /// Record an interaction
59    ///
60    /// Note: Takes `&self` instead of `&mut self` because internal state
61    /// is protected by `RwLock`, allowing concurrent access.
62    pub async fn record_interaction(
63        &self,
64        method: impl Into<String>,
65        path: impl Into<String>,
66        request: Option<serde_json::Value>,
67        response: Option<serde_json::Value>,
68    ) -> Result<()> {
69        let interaction = InteractionRecord::new(
70            method, path, request, 200, // Default status
71            response,
72        );
73
74        // Store in session state
75        let mut state = self.state.write().await;
76        state.record_interaction(interaction.clone());
77
78        // Trim history if needed
79        let max_history = self.config.performance.max_history_length;
80        let history_len = state.history.len();
81        if history_len > max_history {
82            state.history.drain(0..history_len - max_history);
83        }
84
85        drop(state);
86
87        // Store in vector memory if enabled
88        if let Some(ref store) = self.memory_store {
89            if self.config.vector_store.enabled {
90                store.store_interaction(&self.session_id, &interaction).await?;
91            }
92        }
93
94        Ok(())
95    }
96
97    /// Get current session state
98    pub async fn get_state(&self) -> SessionState {
99        let state = self.state.read().await;
100        state.clone()
101    }
102
103    /// Set a state value
104    pub async fn set_value(&self, key: impl Into<String>, value: serde_json::Value) {
105        let mut state = self.state.write().await;
106        state.set(key, value);
107    }
108
109    /// Get a state value
110    pub async fn get_value(&self, key: &str) -> Option<serde_json::Value> {
111        let state = self.state.read().await;
112        state.get(key).cloned()
113    }
114
115    /// Remove a state value
116    pub async fn remove_value(&self, key: &str) -> Option<serde_json::Value> {
117        let mut state = self.state.write().await;
118        state.remove(key)
119    }
120
121    /// Get interaction history
122    pub async fn get_history(&self) -> Vec<InteractionRecord> {
123        let state = self.state.read().await;
124        state.history.clone()
125    }
126
127    /// Get relevant past interactions using semantic search
128    pub async fn get_relevant_context(
129        &self,
130        query: &str,
131        limit: usize,
132    ) -> Result<Vec<InteractionRecord>> {
133        if let Some(ref store) = self.memory_store {
134            if self.config.vector_store.enabled {
135                return store.retrieve_context(&self.session_id, query, limit).await;
136            }
137        }
138
139        // Fallback to recent history
140        let state = self.state.read().await;
141        let history = state.history.clone();
142        Ok(history.into_iter().rev().take(limit).collect())
143    }
144
145    /// Build context summary for LLM prompt
146    pub async fn build_context_summary(&self) -> String {
147        let state = self.state.read().await;
148
149        let mut summary = String::new();
150        summary.push_str("# Session Context\n\n");
151
152        // Current state
153        if !state.state.is_empty() {
154            summary.push_str("## Current State\n");
155            for (key, value) in &state.state {
156                summary.push_str(&format!("- {}: {}\n", key, value));
157            }
158            summary.push('\n');
159        }
160
161        // Recent interactions
162        if !state.history.is_empty() {
163            summary.push_str("## Recent Interactions\n");
164            let recent = state.history.iter().rev().take(5);
165            for interaction in recent {
166                summary.push_str(&format!(
167                    "- {} {} (status {})\n",
168                    interaction.method, interaction.path, interaction.status
169                ));
170            }
171        }
172
173        summary
174    }
175
176    /// Clear all state
177    pub async fn clear(&self) {
178        let mut state = self.state.write().await;
179        *state = SessionState::new(self.session_id.clone());
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_context_creation() {
        let config = IntelligentBehaviorConfig::default();
        let context = StatefulAiContext::new("test_session", config);

        assert_eq!(context.session_id(), "test_session");
    }

    #[tokio::test]
    async fn test_record_interaction() {
        let config = IntelligentBehaviorConfig::default();
        let context = StatefulAiContext::new("test_session", config);

        context
            .record_interaction(
                "POST",
                "/api/users",
                Some(serde_json::json!({"name": "Alice"})),
                Some(serde_json::json!({"id": "user_1", "name": "Alice"})),
            )
            .await
            .unwrap();

        let history = context.get_history().await;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].method, "POST");
        assert_eq!(history[0].path, "/api/users");
    }

    /// History must be capped at `max_history_length`, dropping the oldest entries.
    #[tokio::test]
    async fn test_history_is_trimmed_to_max_length() {
        let mut config = IntelligentBehaviorConfig::default();
        config.performance.max_history_length = 2;
        let context = StatefulAiContext::new("test_session", config);

        for path in ["/api/a", "/api/b", "/api/c"] {
            context.record_interaction("GET", path, None, None).await.unwrap();
        }

        let history = context.get_history().await;
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].path, "/api/b");
        assert_eq!(history[1].path, "/api/c");
    }

    /// Without a memory store, relevant context is the most recent history, newest first.
    #[tokio::test]
    async fn test_relevant_context_falls_back_to_recent_history() {
        let mut config = IntelligentBehaviorConfig::default();
        config.performance.max_history_length = 10;
        let context = StatefulAiContext::new("test_session", config);

        for path in ["/api/a", "/api/b", "/api/c"] {
            context.record_interaction("GET", path, None, None).await.unwrap();
        }

        let relevant = context.get_relevant_context("anything", 2).await.unwrap();
        assert_eq!(relevant.len(), 2);
        assert_eq!(relevant[0].path, "/api/c");
        assert_eq!(relevant[1].path, "/api/b");

        let none = context.get_relevant_context("anything", 0).await.unwrap();
        assert!(none.is_empty());
    }

    #[tokio::test]
    async fn test_state_management() {
        let config = IntelligentBehaviorConfig::default();
        let context = StatefulAiContext::new("test_session", config);

        // Set values
        context.set_value("user_id", serde_json::json!("user_123")).await;
        context.set_value("logged_in", serde_json::json!(true)).await;

        // Get values
        assert_eq!(context.get_value("user_id").await, Some(serde_json::json!("user_123")));
        assert_eq!(context.get_value("logged_in").await, Some(serde_json::json!(true)));

        // Remove value
        let removed = context.remove_value("logged_in").await;
        assert_eq!(removed, Some(serde_json::json!(true)));
        assert_eq!(context.get_value("logged_in").await, None);
    }

    /// `clear` resets state and history but keeps the session ID.
    #[tokio::test]
    async fn test_clear_resets_state_and_history() {
        let config = IntelligentBehaviorConfig::default();
        let context = StatefulAiContext::new("test_session", config);

        context.set_value("user_id", serde_json::json!("user_1")).await;
        context.record_interaction("GET", "/api/me", None, None).await.unwrap();

        context.clear().await;

        assert!(context.get_history().await.is_empty());
        assert_eq!(context.get_value("user_id").await, None);
        assert_eq!(context.session_id(), "test_session");
    }

    #[tokio::test]
    async fn test_context_summary() {
        let config = IntelligentBehaviorConfig::default();
        let context = StatefulAiContext::new("test_session", config);

        context.set_value("user_id", serde_json::json!("user_1")).await;

        context
            .record_interaction(
                "POST",
                "/api/login",
                Some(serde_json::json!({"email": "test@example.com"})),
                Some(serde_json::json!({"token": "abc123"})),
            )
            .await
            .unwrap();

        let summary = context.build_context_summary().await;

        assert!(summary.contains("Session Context"));
        assert!(summary.contains("user_id"));
        assert!(summary.contains("POST /api/login"));
    }

    /// Pins the exact line format of the summary sections.
    #[tokio::test]
    async fn test_context_summary_formatting() {
        let config = IntelligentBehaviorConfig::default();
        let context = StatefulAiContext::new("test_session", config);

        // An empty session renders only the header.
        assert_eq!(context.build_context_summary().await, "# Session Context\n\n");

        context.set_value("user_id", serde_json::json!("user_1")).await;
        context.record_interaction("POST", "/api/login", None, None).await.unwrap();

        let summary = context.build_context_summary().await;
        assert!(summary.starts_with("# Session Context\n\n## Current State\n"));
        assert!(summary.contains("- user_id: \"user_1\"\n"));
        assert!(summary.contains("## Recent Interactions\n- POST /api/login (status 200)\n"));
    }
}