use std::sync::Arc;
use tokio::sync::RwLock;
use super::config::IntelligentBehaviorConfig;
use super::memory::VectorMemoryStore;
use super::types::{InteractionRecord, SessionState};
use crate::Result;
/// Per-session context holding key/value state and an interaction history,
/// optionally backed by a vector memory store.
///
/// Clones share the same session state (it lives behind an `Arc<RwLock<_>>`)
/// and the same optional memory store; the configuration is cloned by value.
#[derive(Clone)]
pub struct StatefulAiContext {
    /// Identifier of the session this context belongs to.
    session_id: String,
    /// Mutable session data (key/value entries plus interaction history),
    /// shared between clones of this context.
    state: Arc<RwLock<SessionState>>,
    /// Optional store used to persist and retrieve interactions; only used
    /// when the configuration enables the vector store.
    memory_store: Option<Arc<VectorMemoryStore>>,
    /// Behavior configuration (history length limit, vector-store toggle, ...).
    config: IntelligentBehaviorConfig,
}
impl StatefulAiContext {
    /// Creates a context for `session_id` with a fresh [`SessionState`] and no
    /// memory store attached.
    pub fn new(session_id: impl Into<String>, config: IntelligentBehaviorConfig) -> Self {
        let session_id = session_id.into();
        let state = Arc::new(RwLock::new(SessionState::new(session_id.clone())));
        Self {
            session_id,
            state,
            memory_store: None,
            config,
        }
    }

    /// Attaches a vector memory store (builder style).
    ///
    /// The store is only consulted when `config.vector_store.enabled` is true.
    pub fn with_memory_store(mut self, store: Arc<VectorMemoryStore>) -> Self {
        self.memory_store = Some(store);
        self
    }

    /// Returns the session identifier this context was created with.
    pub fn session_id(&self) -> &str {
        &self.session_id
    }

    /// Returns the memory store only when one is attached *and* the vector
    /// store is enabled in the configuration.
    fn active_memory_store(&self) -> Option<&Arc<VectorMemoryStore>> {
        self.memory_store
            .as_ref()
            .filter(|_| self.config.vector_store.enabled)
    }

    /// Records an interaction in the session history, always with status 200.
    ///
    /// After recording, the in-memory history is trimmed to the newest
    /// `config.performance.max_history_length` entries. If a memory store is
    /// active, the interaction is also persisted there. The state lock is
    /// released before that call is awaited.
    ///
    /// # Errors
    ///
    /// Propagates any error from the memory store. In that case the
    /// interaction has already been added to the in-memory history.
    pub async fn record_interaction(
        &self,
        method: impl Into<String>,
        path: impl Into<String>,
        request: Option<serde_json::Value>,
        response: Option<serde_json::Value>,
    ) -> Result<()> {
        let store = self.active_memory_store();
        let interaction = InteractionRecord::new(method, path, request, 200, response);
        // Only pay for a copy when the record must also be sent to the store.
        let for_store = store.map(|_| interaction.clone());

        {
            let mut state = self.state.write().await;
            state.record_interaction(interaction);
            let max_history = self.config.performance.max_history_length;
            let len = state.history.len();
            if len > max_history {
                // Drop the oldest entries so only the newest `max_history` remain.
                state.history.drain(..len - max_history);
            }
        } // write lock released here, before awaiting the store

        if let (Some(store), Some(interaction)) = (store, for_store) {
            store.store_interaction(&self.session_id, &interaction).await?;
        }
        Ok(())
    }

    /// Returns a snapshot (clone) of the whole session state.
    pub async fn get_state(&self) -> SessionState {
        self.state.read().await.clone()
    }

    /// Stores `value` under `key` in the session state via `SessionState::set`.
    pub async fn set_value(&self, key: impl Into<String>, value: serde_json::Value) {
        let mut state = self.state.write().await;
        state.set(key, value);
    }

    /// Returns a clone of the value stored under `key`, if any.
    pub async fn get_value(&self, key: &str) -> Option<serde_json::Value> {
        let state = self.state.read().await;
        state.get(key).cloned()
    }

    /// Removes `key` from the session state and returns its previous value, if any.
    pub async fn remove_value(&self, key: &str) -> Option<serde_json::Value> {
        let mut state = self.state.write().await;
        state.remove(key)
    }

    /// Returns a copy of the full in-memory interaction history, oldest first.
    pub async fn get_history(&self) -> Vec<InteractionRecord> {
        let state = self.state.read().await;
        state.history.clone()
    }

    /// Returns up to `limit` interactions relevant to `query`.
    ///
    /// When a memory store is active, the query is delegated to it. Otherwise
    /// the `limit` most recent interactions are returned, newest first, and
    /// `query` is ignored.
    ///
    /// # Errors
    ///
    /// Propagates errors from the memory store. The fallback path never fails.
    pub async fn get_relevant_context(
        &self,
        query: &str,
        limit: usize,
    ) -> Result<Vec<InteractionRecord>> {
        if let Some(store) = self.active_memory_store() {
            return store.retrieve_context(&self.session_id, query, limit).await;
        }
        // Clone only the entries being returned, not the entire history.
        let state = self.state.read().await;
        Ok(state.history.iter().rev().take(limit).cloned().collect())
    }

    /// Builds a Markdown-style summary of the session.
    ///
    /// The summary has a "Current State" section listing every key/value pair
    /// (omitted when the state is empty). It also has a "Recent Interactions"
    /// section with up to the 5 newest interactions, newest first (omitted when
    /// the history is empty).
    pub async fn build_context_summary(&self) -> String {
        let state = self.state.read().await;
        let mut summary = String::from("# Session Context\n\n");

        if !state.state.is_empty() {
            summary.push_str("## Current State\n");
            for (key, value) in &state.state {
                summary.push_str(&format!("- {}: {}\n", key, value));
            }
            summary.push('\n');
        }

        if !state.history.is_empty() {
            summary.push_str("## Recent Interactions\n");
            for interaction in state.history.iter().rev().take(5) {
                summary.push_str(&format!(
                    "- {} {} (status {})\n",
                    interaction.method, interaction.path, interaction.status
                ));
            }
        }

        summary
    }

    /// Resets the in-memory session state to a fresh `SessionState` for the same
    /// session id.
    ///
    /// Any attached memory store is not touched.
    pub async fn clear(&self) {
        let mut state = self.state.write().await;
        *state = SessionState::new(self.session_id.clone());
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_context_creation() {
        let config = IntelligentBehaviorConfig::default();
        let context = StatefulAiContext::new("test_session", config);
        assert_eq!(context.session_id(), "test_session");
    }

    #[tokio::test]
    async fn test_record_interaction() {
        let config = IntelligentBehaviorConfig::default();
        let context = StatefulAiContext::new("test_session", config);
        context
            .record_interaction(
                "POST",
                "/api/users",
                Some(serde_json::json!({"name": "Alice"})),
                Some(serde_json::json!({"id": "user_1", "name": "Alice"})),
            )
            .await
            .unwrap();
        let history = context.get_history().await;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].method, "POST");
        assert_eq!(history[0].path, "/api/users");
    }

    #[tokio::test]
    async fn test_state_management() {
        let config = IntelligentBehaviorConfig::default();
        let context = StatefulAiContext::new("test_session", config);
        context.set_value("user_id", serde_json::json!("user_123")).await;
        context.set_value("logged_in", serde_json::json!(true)).await;
        assert_eq!(context.get_value("user_id").await, Some(serde_json::json!("user_123")));
        assert_eq!(context.get_value("logged_in").await, Some(serde_json::json!(true)));
        let removed = context.remove_value("logged_in").await;
        assert_eq!(removed, Some(serde_json::json!(true)));
        assert_eq!(context.get_value("logged_in").await, None);
    }

    #[tokio::test]
    async fn test_context_summary() {
        let config = IntelligentBehaviorConfig::default();
        let context = StatefulAiContext::new("test_session", config);
        context.set_value("user_id", serde_json::json!("user_1")).await;
        context
            .record_interaction(
                "POST",
                "/api/login",
                Some(serde_json::json!({"email": "test@example.com"})),
                Some(serde_json::json!({"token": "abc123"})),
            )
            .await
            .unwrap();
        let summary = context.build_context_summary().await;
        assert!(summary.contains("Session Context"));
        assert!(summary.contains("user_id"));
        assert!(summary.contains("POST /api/login"));
    }

    /// History must be capped at `max_history_length`, keeping the newest entries.
    #[tokio::test]
    async fn test_history_is_capped_to_most_recent() {
        let mut config = IntelligentBehaviorConfig::default();
        config.performance.max_history_length = 2;
        let context = StatefulAiContext::new("test_session", config);
        for path in ["/a", "/b", "/c"] {
            context.record_interaction("GET", path, None, None).await.unwrap();
        }
        let history = context.get_history().await;
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].path, "/b");
        assert_eq!(history[1].path, "/c");
    }

    /// Without a memory store, relevant context is the newest `limit` entries,
    /// newest first.
    #[tokio::test]
    async fn test_relevant_context_falls_back_to_recent_history() {
        let mut config = IntelligentBehaviorConfig::default();
        config.performance.max_history_length = 10;
        let context = StatefulAiContext::new("test_session", config);
        for path in ["/a", "/b", "/c"] {
            context.record_interaction("GET", path, None, None).await.unwrap();
        }
        let relevant = context.get_relevant_context("anything", 2).await.unwrap();
        assert_eq!(relevant.len(), 2);
        assert_eq!(relevant[0].path, "/c");
        assert_eq!(relevant[1].path, "/b");
    }

    /// `clear` discards recorded history and stored values but keeps the session id.
    #[tokio::test]
    async fn test_clear_resets_state_and_history() {
        let config = IntelligentBehaviorConfig::default();
        let context = StatefulAiContext::new("test_session", config);
        context.set_value("user_id", serde_json::json!("user_1")).await;
        context.record_interaction("GET", "/a", None, None).await.unwrap();
        context.clear().await;
        assert!(context.get_history().await.is_empty());
        assert_eq!(context.get_value("user_id").await, None);
        assert_eq!(context.session_id(), "test_session");
    }
}