mockforge_intelligence/intelligent_behavior/
context.rs1use std::sync::Arc;
7use tokio::sync::RwLock;
8
9use super::config::IntelligentBehaviorConfig;
10use super::memory::VectorMemoryStore;
11use super::types::{InteractionRecord, SessionState};
12use mockforge_foundation::Result;
13
/// Per-session context used by the intelligent mock behavior.
///
/// Bundles the session identifier, the session's [`SessionState`] (key/value
/// state plus interaction history) behind an async `RwLock`, an optional vector
/// memory store, and the behavior configuration.
///
/// Because `state` and `memory_store` are held in `Arc`s, clones of a context
/// share the same underlying session state and memory store.
#[derive(Clone)]
pub struct StatefulAiContext {
    /// Identifier of the session this context belongs to.
    session_id: String,

    /// Mutable session state, shared between clones of this context.
    state: Arc<RwLock<SessionState>>,

    /// Optional vector memory store used to persist and retrieve interactions.
    memory_store: Option<Arc<VectorMemoryStore>>,

    /// Behavior configuration (history length limit, vector store toggle, ...).
    config: IntelligentBehaviorConfig,
}
32
33impl StatefulAiContext {
34 pub fn new(session_id: impl Into<String>, config: IntelligentBehaviorConfig) -> Self {
36 let session_id = session_id.into();
37 let state = Arc::new(RwLock::new(SessionState::new(session_id.clone())));
38
39 Self {
40 session_id,
41 state,
42 memory_store: None,
43 config,
44 }
45 }
46
47 pub fn with_memory_store(mut self, store: Arc<VectorMemoryStore>) -> Self {
49 self.memory_store = Some(store);
50 self
51 }
52
53 pub fn session_id(&self) -> &str {
55 &self.session_id
56 }
57
58 pub async fn record_interaction(
63 &self,
64 method: impl Into<String>,
65 path: impl Into<String>,
66 request: Option<serde_json::Value>,
67 response: Option<serde_json::Value>,
68 ) -> Result<()> {
69 let interaction = InteractionRecord::new(
70 method, path, request, 200, response,
72 );
73
74 let mut state = self.state.write().await;
76 state.record_interaction(interaction.clone());
77
78 let max_history = self.config.performance.max_history_length;
80 let history_len = state.history.len();
81 if history_len > max_history {
82 state.history.drain(0..history_len - max_history);
83 }
84
85 drop(state);
86
87 if let Some(ref store) = self.memory_store {
89 if self.config.vector_store.enabled {
90 store.store_interaction(&self.session_id, &interaction).await?;
91 }
92 }
93
94 Ok(())
95 }
96
97 pub async fn get_state(&self) -> SessionState {
99 let state = self.state.read().await;
100 state.clone()
101 }
102
103 pub async fn set_value(&self, key: impl Into<String>, value: serde_json::Value) {
105 let mut state = self.state.write().await;
106 state.set(key, value);
107 }
108
109 pub async fn get_value(&self, key: &str) -> Option<serde_json::Value> {
111 let state = self.state.read().await;
112 state.get(key).cloned()
113 }
114
115 pub async fn remove_value(&self, key: &str) -> Option<serde_json::Value> {
117 let mut state = self.state.write().await;
118 state.remove(key)
119 }
120
121 pub async fn get_history(&self) -> Vec<InteractionRecord> {
123 let state = self.state.read().await;
124 state.history.clone()
125 }
126
127 pub async fn get_relevant_context(
129 &self,
130 query: &str,
131 limit: usize,
132 ) -> Result<Vec<InteractionRecord>> {
133 if let Some(ref store) = self.memory_store {
134 if self.config.vector_store.enabled {
135 return store.retrieve_context(&self.session_id, query, limit).await;
136 }
137 }
138
139 let state = self.state.read().await;
141 let history = state.history.clone();
142 Ok(history.into_iter().rev().take(limit).collect())
143 }
144
145 pub async fn build_context_summary(&self) -> String {
147 let state = self.state.read().await;
148
149 let mut summary = String::new();
150 summary.push_str("# Session Context\n\n");
151
152 if !state.state.is_empty() {
154 summary.push_str("## Current State\n");
155 for (key, value) in &state.state {
156 summary.push_str(&format!("- {}: {}\n", key, value));
157 }
158 summary.push('\n');
159 }
160
161 if !state.history.is_empty() {
163 summary.push_str("## Recent Interactions\n");
164 let recent = state.history.iter().rev().take(5);
165 for interaction in recent {
166 summary.push_str(&format!(
167 "- {} {} (status {})\n",
168 interaction.method, interaction.path, interaction.status
169 ));
170 }
171 }
172
173 summary
174 }
175
176 pub async fn clear(&self) {
178 let mut state = self.state.write().await;
179 *state = SessionState::new(self.session_id.clone());
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_context_creation() {
        let config = IntelligentBehaviorConfig::default();
        let context = StatefulAiContext::new("test_session", config);

        assert_eq!(context.session_id(), "test_session");
    }

    #[tokio::test]
    async fn test_record_interaction() {
        let config = IntelligentBehaviorConfig::default();
        let context = StatefulAiContext::new("test_session", config);

        context
            .record_interaction(
                "POST",
                "/api/users",
                Some(serde_json::json!({"name": "Alice"})),
                Some(serde_json::json!({"id": "user_1", "name": "Alice"})),
            )
            .await
            .unwrap();

        let history = context.get_history().await;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].method, "POST");
        assert_eq!(history[0].path, "/api/users");
    }

    #[tokio::test]
    async fn test_history_is_trimmed_to_max_length() {
        let mut config = IntelligentBehaviorConfig::default();
        config.performance.max_history_length = 2;
        let context = StatefulAiContext::new("test_session", config);

        for path in ["/a", "/b", "/c"] {
            context.record_interaction("GET", path, None, None).await.unwrap();
        }

        // The oldest entry is dropped; the newest two remain in order.
        let history = context.get_history().await;
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].path, "/b");
        assert_eq!(history[1].path, "/c");
    }

    #[tokio::test]
    async fn test_relevant_context_falls_back_to_recent_history() {
        let mut config = IntelligentBehaviorConfig::default();
        config.performance.max_history_length = 10;
        let context = StatefulAiContext::new("test_session", config);

        for path in ["/a", "/b", "/c"] {
            context.record_interaction("GET", path, None, None).await.unwrap();
        }

        // No memory store attached: newest-first slice of the in-memory history.
        let relevant = context.get_relevant_context("anything", 2).await.unwrap();
        assert_eq!(relevant.len(), 2);
        assert_eq!(relevant[0].path, "/c");
        assert_eq!(relevant[1].path, "/b");
    }

    #[tokio::test]
    async fn test_state_management() {
        let config = IntelligentBehaviorConfig::default();
        let context = StatefulAiContext::new("test_session", config);

        context.set_value("user_id", serde_json::json!("user_123")).await;
        context.set_value("logged_in", serde_json::json!(true)).await;

        assert_eq!(context.get_value("user_id").await, Some(serde_json::json!("user_123")));
        assert_eq!(context.get_value("logged_in").await, Some(serde_json::json!(true)));

        let removed = context.remove_value("logged_in").await;
        assert_eq!(removed, Some(serde_json::json!(true)));
        assert_eq!(context.get_value("logged_in").await, None);
    }

    #[tokio::test]
    async fn test_clear_resets_state_and_history() {
        let config = IntelligentBehaviorConfig::default();
        let context = StatefulAiContext::new("test_session", config);

        context.set_value("user_id", serde_json::json!("user_1")).await;
        context.record_interaction("GET", "/api/me", None, None).await.unwrap();

        context.clear().await;

        assert_eq!(context.get_value("user_id").await, None);
        assert!(context.get_history().await.is_empty());
        assert_eq!(context.session_id(), "test_session");
    }

    #[tokio::test]
    async fn test_context_summary() {
        let config = IntelligentBehaviorConfig::default();
        let context = StatefulAiContext::new("test_session", config);

        context.set_value("user_id", serde_json::json!("user_1")).await;

        context
            .record_interaction(
                "POST",
                "/api/login",
                Some(serde_json::json!({"email": "test@example.com"})),
                Some(serde_json::json!({"token": "abc123"})),
            )
            .await
            .unwrap();

        let summary = context.build_context_summary().await;

        assert!(summary.contains("Session Context"));
        assert!(summary.contains("user_id"));
        assert!(summary.contains("POST /api/login"));
    }
}