use crate::compiler;
use crate::db::Database;
use crate::index::VectorIndex;
use crate::storage::StorageBackend;
use crate::types::{
CompilerConfig, Message, MessageRole, Result, Session, WorkingSet,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Session manager that works directly against a [`Database`] handle.
///
/// Every method delegates persistence to the wrapped database. See
/// [`SessionManagerGeneric`] for the variant that is generic over a
/// [`StorageBackend`].
pub struct SessionManager {
/// Database handle that all methods delegate to.
db: Database,
}
impl SessionManager {
/// Creates a manager that takes ownership of the given database handle.
pub fn new(db: Database) -> Self {
Self { db }
}
/// Starts a new session by calling `create_session` on the database.
///
/// `user_id` is forwarded unchanged; the second argument of
/// `create_session` is always passed as `None`.
///
/// # Errors
/// Returns whatever error `create_session` reports.
pub fn start_session(&self, user_id: Option<&str>) -> Result<Session> {
self.db.create_session(user_id, None)
}
/// Records a user query in the session and compiles a working set for it.
///
/// The steps run in this order:
/// 1. the query is stored as a [`MessageRole::User`] message,
/// 2. [`compiler::compile`] builds a working set for the query,
/// 3. the working set is associated with the session and the new message.
///
/// Returns the stored message together with the compiled working set.
///
/// # Errors
/// Propagates the first error from any of the three steps. The message is
/// handed to `add_message` before compilation starts, so an error from a
/// later step is returned after that call has already succeeded.
pub async fn add_user_message(
    &self,
    session_id: &str,
    query: &str,
    config: CompilerConfig,
    index: &VectorIndex,
    api_key: Option<&str>,
) -> Result<(Message, WorkingSet)> {
    let db = &self.db;

    let user_message = db.add_message(session_id, MessageRole::User, query, None)?;

    // `compile` takes the config by value; pass a copy so the original is
    // still available for the association call below.
    let compiled = compiler::compile(query, config.clone(), db, index, api_key).await?;

    db.associate_working_set(session_id, Some(&user_message.id), &compiled, query, &config)?;

    Ok((user_message, compiled))
}
/// Stores `content` in the session as a [`MessageRole::Assistant`] message.
///
/// `metadata` is forwarded unchanged to the database's `add_message`.
///
/// # Errors
/// Returns whatever error `add_message` reports.
pub fn add_assistant_message(
&self,
session_id: &str,
content: &str,
metadata: Option<&serde_json::Value>,
) -> Result<Message> {
self.db
.add_message(session_id, MessageRole::Assistant, content, metadata)
}
/// Renders the session's messages as a plain-text transcript.
///
/// Each message is formatted as `"<Role>: <content>"`, where the role label
/// is `User`, `Assistant`, `System` or `Tool`, and the formatted messages
/// are joined with a blank line (`"\n\n"`) in the order returned by
/// `get_messages`.
///
/// When `max_tokens` is `Some(budget)`, messages are taken starting from the
/// end of the list and walking backwards for as long as the running total of
/// `estimate_tokens` over the formatted messages stays within `budget`. The
/// walk stops at the first message that does not fit, so the result is always
/// a contiguous tail of the conversation. The `"\n\n"` separators are not
/// counted against the budget.
///
/// Returns an empty string when the session has no messages or when not even
/// the last message fits the budget.
///
/// # Errors
/// Returns whatever error `get_messages` reports.
pub fn get_conversation_history(
    &self,
    session_id: &str,
    max_tokens: Option<usize>,
) -> Result<String> {
    let messages = self.db.get_messages(session_id, None)?;

    // Lazy formatter: nothing is allocated until an arm below consumes it.
    let formatted = messages.iter().map(|msg| {
        let role = match msg.role {
            MessageRole::User => "User",
            MessageRole::Assistant => "Assistant",
            MessageRole::System => "System",
            MessageRole::Tool => "Tool",
        };
        format!("{}: {}", role, msg.content)
    });

    let selected: Vec<String> = match max_tokens {
        None => formatted.collect(),
        Some(budget) => {
            // Track the unused part of the budget rather than a running sum
            // so the comparison can never overflow.
            let mut remaining = budget;
            let mut tail: Vec<String> = formatted
                .rev()
                .take_while(|line| {
                    let cost = estimate_tokens(line);
                    let fits = cost <= remaining;
                    if fits {
                        remaining -= cost;
                    }
                    fits
                })
                .collect();
            // Collected from the end backwards; restore conversation order.
            tail.reverse();
            tail
        }
    };

    Ok(selected.join("\n\n"))
}
pub fn replay_session(&self, session_id: &str) -> Result<SessionReplay> {
let session_data = self.db.get_session_full(session_id)?;
if session_data.is_none() {
return Err(crate::types::Error::NotFound(format!(
"Session not found: {}",
session_id
)));
}
let session_data = session_data.unwrap();
let session = session_data.session;
let messages = session_data.messages;
let working_sets = session_data.working_sets;
let mut working_set_map = std::collections::HashMap::new();
for ws in working_sets {
if let Some(msg_id) = &ws.message_id {
working_set_map.insert(msg_id.clone(), ws.working_set);
}
}
let mut turns = Vec::new();
let mut i = 0;
while i < messages.len() {
let msg = &messages[i];
if matches!(msg.role, MessageRole::User) {
let user_message = msg.clone();
let working_set = working_set_map.get(&user_message.id).cloned();
let assistant_message = if i + 1 < messages.len()
&& matches!(messages[i + 1].role, MessageRole::Assistant)
{
i += 1; Some(messages[i].clone())
} else {
None
};
turns.push(SessionTurn {
user_message,
working_set,
assistant_message,
});
}
i += 1;
}
Ok(SessionReplay { session, turns })
}
}
/// Session manager that is generic over a [`StorageBackend`].
///
/// The backend is held in an [`Arc`]; all methods delegate to it.
pub struct SessionManagerGeneric<B: StorageBackend> {
/// Reference-counted backend handle that all methods delegate to.
backend: Arc<B>,
}
impl<B: StorageBackend> SessionManagerGeneric<B> {
/// Creates a manager that takes ownership of `backend`, wrapping it in a
/// fresh [`Arc`].
pub fn new(backend: B) -> Self {
Self {
backend: Arc::new(backend),
}
}
/// Creates a manager from an existing [`Arc`], storing the handle as given
/// instead of wrapping the backend in a new one.
pub fn from_arc(backend: Arc<B>) -> Self {
Self { backend }
}
/// Returns a shared reference to the underlying backend.
pub fn backend(&self) -> &B {
&self.backend
}
/// Starts a new session by awaiting the backend's `create_session`.
///
/// `user_id` is forwarded unchanged; the second argument of
/// `create_session` is always passed as `None`.
///
/// # Errors
/// Returns whatever error `create_session` reports.
pub async fn start_session(&self, user_id: Option<&str>) -> Result<Session> {
self.backend.create_session(user_id, None).await
}
pub async fn add_user_message(
&self,
session_id: &str,
query: &str,
config: CompilerConfig,
api_key: Option<&str>,
) -> Result<(Message, WorkingSet)> {
let message = self
.backend
.add_message(session_id, MessageRole::User, query, None)
.await?;
let working_set = compiler::compile_with_backend(
query,
config.clone(),
self.backend.as_ref(),
api_key,
)
.await?;
self.backend
.associate_working_set(session_id, Some(&message.id), &working_set, query, &config)
.await?;
Ok((message, working_set))
}
pub async fn add_user_message_with_explain(
&self,
session_id: &str,
query: &str,
config: CompilerConfig,
api_key: Option<&str>,
explain: bool,
) -> Result<(Message, WorkingSet)> {
let message = self
.backend
.add_message(session_id, MessageRole::User, query, None)
.await?;
let working_set = compiler::compile_with_backend_options(
query,
config.clone(),
self.backend.as_ref(),
api_key,
explain,
)
.await?;
self.backend
.associate_working_set(session_id, Some(&message.id), &working_set, query, &config)
.await?;
Ok((message, working_set))
}
/// Stores `content` in the session as a [`MessageRole::Assistant`] message.
///
/// `metadata` is forwarded unchanged to the backend's `add_message`.
///
/// # Errors
/// Returns whatever error `add_message` reports.
pub async fn add_assistant_message(
&self,
session_id: &str,
content: &str,
metadata: Option<&serde_json::Value>,
) -> Result<Message> {
self.backend
.add_message(session_id, MessageRole::Assistant, content, metadata)
.await
}
/// Renders the session's messages as a plain-text transcript.
///
/// Each message is formatted as `"<Role>: <content>"`, where the role label
/// is `User`, `Assistant`, `System` or `Tool`, and the formatted messages
/// are joined with a blank line (`"\n\n"`) in the order returned by the
/// backend's `get_messages`.
///
/// When `max_tokens` is `Some(budget)`, messages are taken starting from the
/// end of the list and walking backwards for as long as the running total of
/// `estimate_tokens` over the formatted messages stays within `budget`. The
/// walk stops at the first message that does not fit, so the result is always
/// a contiguous tail of the conversation. The `"\n\n"` separators are not
/// counted against the budget.
///
/// Returns an empty string when the session has no messages or when not even
/// the last message fits the budget.
///
/// # Errors
/// Returns whatever error `get_messages` reports.
pub async fn get_conversation_history(
    &self,
    session_id: &str,
    max_tokens: Option<usize>,
) -> Result<String> {
    let messages = self.backend.get_messages(session_id, None).await?;

    // Lazy formatter: nothing is allocated until an arm below consumes it.
    let formatted = messages.iter().map(|msg| {
        let role = match msg.role {
            MessageRole::User => "User",
            MessageRole::Assistant => "Assistant",
            MessageRole::System => "System",
            MessageRole::Tool => "Tool",
        };
        format!("{}: {}", role, msg.content)
    });

    let selected: Vec<String> = match max_tokens {
        None => formatted.collect(),
        Some(budget) => {
            // Track the unused part of the budget rather than a running sum
            // so the comparison can never overflow.
            let mut remaining = budget;
            let mut tail: Vec<String> = formatted
                .rev()
                .take_while(|line| {
                    let cost = estimate_tokens(line);
                    let fits = cost <= remaining;
                    if fits {
                        remaining -= cost;
                    }
                    fits
                })
                .collect();
            // Collected from the end backwards; restore conversation order.
            tail.reverse();
            tail
        }
    };

    Ok(selected.join("\n\n"))
}
pub async fn replay_session(&self, session_id: &str) -> Result<SessionReplay> {
let session_data = self.backend.get_session_full(session_id).await?;
if session_data.is_none() {
return Err(crate::types::Error::NotFound(format!(
"Session not found: {}",
session_id
)));
}
let session_data = session_data.unwrap();
let session = session_data.session;
let messages = session_data.messages;
let working_sets = session_data.working_sets;
let mut working_set_map = std::collections::HashMap::new();
for ws in working_sets {
if let Some(msg_id) = &ws.message_id {
working_set_map.insert(msg_id.clone(), ws.working_set);
}
}
let mut turns = Vec::new();
let mut i = 0;
while i < messages.len() {
let msg = &messages[i];
if matches!(msg.role, MessageRole::User) {
let user_message = msg.clone();
let working_set = working_set_map.get(&user_message.id).cloned();
let assistant_message = if i + 1 < messages.len()
&& matches!(messages[i + 1].role, MessageRole::Assistant)
{
i += 1;
Some(messages[i].clone())
} else {
None
};
turns.push(SessionTurn {
user_message,
working_set,
assistant_message,
});
}
i += 1;
}
Ok(SessionReplay { session, turns })
}
/// Looks up a session by id by awaiting the backend's `get_session`.
///
/// The backend's result, including the `Option`, is returned unchanged.
pub async fn get_session(&self, session_id: &str) -> Result<Option<Session>> {
self.backend.get_session(session_id).await
}
/// Lists sessions by awaiting the backend's `list_sessions`.
///
/// `user_id` and `limit` are forwarded unchanged; how they affect the
/// result is decided by the backend.
pub async fn list_sessions(
&self,
user_id: Option<&str>,
limit: Option<usize>,
) -> Result<Vec<Session>> {
self.backend.list_sessions(user_id, limit).await
}
/// Deletes a session by awaiting the backend's `delete_session`.
///
/// # Errors
/// Returns whatever error `delete_session` reports.
pub async fn delete_session(&self, session_id: &str) -> Result<()> {
self.backend.delete_session(session_id).await
}
}
/// Result of `replay_session`: a session together with its messages grouped
/// into turns.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionReplay {
/// The session that was replayed.
pub session: Session,
/// Turns in the order their user messages appeared in the loaded session.
pub turns: Vec<SessionTurn>,
}
/// One turn of a replayed session: a user message plus whatever was found
/// for it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionTurn {
/// The user message that opens the turn.
pub user_message: Message,
/// Working set recorded for the user message, if one was found.
pub working_set: Option<WorkingSet>,
/// Assistant message that directly followed the user message, if any.
pub assistant_message: Option<Message>,
}
/// Rough token estimate: one token per four bytes of `text`, rounded up.
///
/// The count is based on `str::len`, i.e. the UTF-8 byte length, not the
/// number of characters. An empty string yields `0`.
fn estimate_tokens(text: &str) -> usize {
    let byte_len = text.len();
    let full_groups = byte_len / 4;
    if byte_len % 4 == 0 {
        full_groups
    } else {
        // A partial trailing group still counts as one token.
        full_groups + 1
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::Artifact;
use crate::types::Span;
use uuid::Uuid;
// Smoke test: a manager can be constructed over a `:memory:` database.
#[test]
fn test_session_manager_new() {
let db = Database::new(":memory:").unwrap();
let _manager = SessionManager::new(db);
}
// A started session has a non-empty id and carries the supplied user id.
#[test]
fn test_start_session() {
let db = Database::new(":memory:").unwrap();
let manager = SessionManager::new(db);
let session = manager.start_session(Some("test_user")).unwrap();
assert!(!session.id.is_empty());
assert_eq!(session.user_id, Some("test_user".to_string()));
}
// Seeds one artifact with a single embedded span, then checks that
// `add_user_message` returns the stored user message and a working set with
// non-empty text.
#[tokio::test]
async fn test_add_user_message() {
let db = Database::new(":memory:").unwrap();
let manager = SessionManager::new(db.clone());
let session = manager.start_session(Some("user1")).unwrap();
// Fixture: one artifact and one span with a 384-element embedding.
let artifact = Artifact {
id: Uuid::new_v4().to_string(),
path: "test.txt".to_string(),
content: "This is a test document about Rust programming.".to_string(),
content_hash: "hash123".to_string(),
metadata: None,
created_at: chrono::Utc::now(),
};
db.insert_artifact(&artifact).unwrap();
let span = Span {
id: Uuid::new_v4().to_string(),
artifact_id: artifact.id.clone(),
start_line: 1,
end_line: 1,
text: "This is a test document about Rust programming.".to_string(),
embedding: Some(vec![0.1; 384]), embedding_model: Some("test".to_string()),
token_count: 10,
metadata: None,
};
db.insert_spans(&[span]).unwrap();
// The index is fetched after the span has been inserted.
let index = db.get_vector_index().unwrap();
let config = CompilerConfig::default();
let (message, working_set) = manager
.add_user_message(&session.id, "What is Rust?", config, &index, None)
.await
.unwrap();
assert_eq!(message.content, "What is Rust?");
assert_eq!(message.role.as_str(), "user");
assert!(!working_set.text.is_empty());
}
// An assistant message is returned with the given content and the
// "assistant" role string.
#[test]
fn test_add_assistant_message() {
let db = Database::new(":memory:").unwrap();
let manager = SessionManager::new(db);
let session = manager.start_session(Some("user1")).unwrap();
let message = manager
.add_assistant_message(&session.id, "Rust is a systems programming language.", None)
.unwrap();
assert_eq!(message.content, "Rust is a systems programming language.");
assert_eq!(message.role.as_str(), "assistant");
}
// Without a token limit the history contains every message, formatted as
// "<Role>: <content>", separated by blank lines and in insertion order.
#[test]
fn test_get_conversation_history() {
let db = Database::new(":memory:").unwrap();
let manager = SessionManager::new(db.clone());
let session = manager.start_session(Some("user1")).unwrap();
db.add_message(&session.id, MessageRole::User, "Hello", None)
.unwrap();
db.add_message(&session.id, MessageRole::Assistant, "Hi there!", None)
.unwrap();
db.add_message(&session.id, MessageRole::User, "How are you?", None)
.unwrap();
let history = manager
.get_conversation_history(&session.id, None)
.unwrap();
assert!(history.contains("User: Hello"));
assert!(history.contains("Assistant: Hi there!"));
assert!(history.contains("User: How are you?"));
// Splitting on the separator must give the three messages in order.
let lines: Vec<&str> = history.split("\n\n").collect();
assert_eq!(lines.len(), 3);
assert_eq!(lines[0], "User: Hello");
assert_eq!(lines[1], "Assistant: Hi there!");
assert_eq!(lines[2], "User: How are you?");
}
// Exercises the token budget of `get_conversation_history`.
//
// With `estimate_tokens` counting one token per four bytes (rounded up), the
// formatted messages cost, from oldest to latest: 4, 6, 4 and 6 tokens.
#[test]
fn test_get_conversation_history_with_token_limit() {
    let db = Database::new(":memory:").unwrap();
    let manager = SessionManager::new(db.clone());
    let session = manager.start_session(Some("user1")).unwrap();
    db.add_message(&session.id, MessageRole::User, "Message 1", None)
        .unwrap();
    db.add_message(&session.id, MessageRole::Assistant, "Response 1", None)
        .unwrap();
    db.add_message(&session.id, MessageRole::User, "Message 2", None)
        .unwrap();
    db.add_message(&session.id, MessageRole::Assistant, "Response 2", None)
        .unwrap();

    // A budget of 20 is exactly the cost of all four messages, so this part
    // only checks that the latest exchange is present.
    let history = manager
        .get_conversation_history(&session.id, Some(20))
        .unwrap();
    assert!(history.contains("Message 2"));
    assert!(history.contains("Response 2"));
    let message_count = history.split("\n\n").count();
    assert!(message_count <= 4);

    // A budget of 10 fits only the last two messages (4 + 6); the next one
    // (6 more) would exceed it, so truncation must actually happen.
    let truncated = manager
        .get_conversation_history(&session.id, Some(10))
        .unwrap();
    assert_eq!(truncated, "User: Message 2\n\nAssistant: Response 2");
    assert!(!truncated.contains("Message 1"));
    assert!(!truncated.contains("Response 1"));

    // A budget too small for even the last message yields an empty history.
    let nothing = manager
        .get_conversation_history(&session.id, Some(0))
        .unwrap();
    assert_eq!(nothing, "");
}
// A session without messages produces an empty history string.
#[test]
fn test_get_conversation_history_empty() {
let db = Database::new(":memory:").unwrap();
let manager = SessionManager::new(db);
let session = manager.start_session(Some("user1")).unwrap();
let history = manager
.get_conversation_history(&session.id, None)
.unwrap();
assert_eq!(history, "");
}
// Two complete user/assistant exchanges must replay as two turns, each with
// its user message, a working set and the matching assistant message.
#[tokio::test]
async fn test_replay_session() {
let db = Database::new(":memory:").unwrap();
let manager = SessionManager::new(db.clone());
let session = manager.start_session(Some("user1")).unwrap();
// Fixture: one artifact and one span with a 384-element embedding.
let artifact = Artifact {
id: Uuid::new_v4().to_string(),
path: "test.txt".to_string(),
content: "Test content for replay.".to_string(),
content_hash: "hash123".to_string(),
metadata: None,
created_at: chrono::Utc::now(),
};
db.insert_artifact(&artifact).unwrap();
let span = Span {
id: Uuid::new_v4().to_string(),
artifact_id: artifact.id.clone(),
start_line: 1,
end_line: 1,
text: "Test content for replay.".to_string(),
embedding: Some(vec![0.1; 384]),
embedding_model: Some("test".to_string()),
token_count: 5,
metadata: None,
};
db.insert_spans(&[span]).unwrap();
let index = db.get_vector_index().unwrap();
let config = CompilerConfig::default();
// First exchange.
manager
.add_user_message(&session.id, "First query", config.clone(), &index, None)
.await
.unwrap();
manager
.add_assistant_message(&session.id, "First response", None)
.unwrap();
// Second exchange.
manager
.add_user_message(&session.id, "Second query", config, &index, None)
.await
.unwrap();
manager
.add_assistant_message(&session.id, "Second response", None)
.unwrap();
let replay = manager.replay_session(&session.id).unwrap();
assert_eq!(replay.session.id, session.id);
assert_eq!(replay.turns.len(), 2);
let turn1 = &replay.turns[0];
assert_eq!(turn1.user_message.content, "First query");
assert!(turn1.working_set.is_some());
assert!(turn1.assistant_message.is_some());
assert_eq!(
turn1.assistant_message.as_ref().unwrap().content,
"First response"
);
let turn2 = &replay.turns[1];
assert_eq!(turn2.user_message.content, "Second query");
assert!(turn2.working_set.is_some());
assert!(turn2.assistant_message.is_some());
assert_eq!(
turn2.assistant_message.as_ref().unwrap().content,
"Second response"
);
}
// Replaying an id that was never created must return an error.
#[test]
fn test_replay_session_not_found() {
let db = Database::new(":memory:").unwrap();
let manager = SessionManager::new(db);
let result = manager.replay_session("nonexistent-id");
assert!(result.is_err());
}
// A user message with no assistant reply still replays as one turn, with a
// working set and `assistant_message` set to `None`.
#[tokio::test]
async fn test_replay_session_incomplete_turns() {
let db = Database::new(":memory:").unwrap();
let manager = SessionManager::new(db.clone());
let session = manager.start_session(Some("user1")).unwrap();
// Fixture: one artifact and one span with a 384-element embedding.
let artifact = Artifact {
id: Uuid::new_v4().to_string(),
path: "test.txt".to_string(),
content: "Test content.".to_string(),
content_hash: "hash123".to_string(),
metadata: None,
created_at: chrono::Utc::now(),
};
db.insert_artifact(&artifact).unwrap();
let span = Span {
id: Uuid::new_v4().to_string(),
artifact_id: artifact.id.clone(),
start_line: 1,
end_line: 1,
text: "Test content.".to_string(),
embedding: Some(vec![0.1; 384]),
embedding_model: Some("test".to_string()),
token_count: 3,
metadata: None,
};
db.insert_spans(&[span]).unwrap();
let index = db.get_vector_index().unwrap();
let config = CompilerConfig::default();
// Only the user side of the exchange is recorded.
manager
.add_user_message(&session.id, "Query without response", config, &index, None)
.await
.unwrap();
let replay = manager.replay_session(&session.id).unwrap();
assert_eq!(replay.turns.len(), 1);
let turn = &replay.turns[0];
assert_eq!(turn.user_message.content, "Query without response");
assert!(turn.working_set.is_some());
assert!(turn.assistant_message.is_none());
}
// Checks the token estimate on a short and a longer string.
#[test]
fn test_estimate_tokens() {
// "Hello world" is 11 bytes, so (11 + 3) / 4 == 3.
let text = "Hello world";
let tokens = estimate_tokens(text);
assert_eq!(tokens, 3);
// For the longer text only a range is asserted, not an exact value.
let longer_text = "This is a longer piece of text for testing token estimation.";
let tokens = estimate_tokens(longer_text);
assert!(tokens > 10);
assert!(tokens < 20);
}
// End-to-end flow on one session: seed three documents, run two
// user/assistant exchanges, then check the conversation history (with and
// without a token limit) and the replayed turns.
#[tokio::test]
async fn test_full_session_workflow() {
let db = Database::new(":memory:").unwrap();
let manager = SessionManager::new(db.clone());
// Fixture documents as (path, content) pairs.
let docs = vec![
("rust_basics.md", "Rust is a systems programming language that runs blazingly fast, prevents segfaults, and guarantees thread safety."),
("rust_ownership.md", "Ownership is Rust's most unique feature. It enables Rust to make memory safety guarantees without needing a garbage collector."),
("rust_concurrency.md", "Rust's type system and ownership model guarantee thread safety. You can't have data races in safe Rust code."),
];
// Each document becomes one artifact with a single span covering it.
for (path, content) in &docs {
let artifact = Artifact {
id: Uuid::new_v4().to_string(),
path: path.to_string(),
content: content.to_string(),
content_hash: format!("hash_{}", path),
metadata: None,
created_at: chrono::Utc::now(),
};
db.insert_artifact(&artifact).unwrap();
let span = Span {
id: Uuid::new_v4().to_string(),
artifact_id: artifact.id.clone(),
start_line: 1,
end_line: 1,
text: content.to_string(),
embedding: Some(vec![0.1; 384]), embedding_model: Some("test".to_string()),
token_count: content.split_whitespace().count(),
metadata: None,
};
db.insert_spans(&[span]).unwrap();
}
// The index is fetched after all spans have been inserted.
let index = db.get_vector_index().unwrap();
let session = manager.start_session(Some("alice")).unwrap();
assert_eq!(session.user_id, Some("alice".to_string()));
let config = CompilerConfig::default();
// First exchange.
let (msg1, ws1) = manager
.add_user_message(&session.id, "What is Rust?", config.clone(), &index, None)
.await
.unwrap();
assert_eq!(msg1.content, "What is Rust?");
assert!(!ws1.text.is_empty());
assert!(!ws1.citations.is_empty());
let resp1 = manager
.add_assistant_message(
&session.id,
"Rust is a systems programming language known for memory safety.",
None,
)
.unwrap();
assert!(resp1.content.contains("memory safety"));
// Second exchange.
let (msg2, ws2) = manager
.add_user_message(
&session.id,
"Tell me about ownership",
config.clone(),
&index,
None,
)
.await
.unwrap();
assert_eq!(msg2.content, "Tell me about ownership");
assert!(!ws2.text.is_empty());
let resp2 = manager
.add_assistant_message(
&session.id,
"Ownership is Rust's unique feature for memory management.",
None,
)
.unwrap();
assert!(resp2.content.contains("Ownership"));
// The unlimited history must contain all four messages.
let history = manager
.get_conversation_history(&session.id, None)
.unwrap();
assert!(history.contains("What is Rust?"));
assert!(history.contains("memory safety"));
assert!(history.contains("Tell me about ownership"));
assert!(history.contains("Ownership is Rust's unique feature"));
// The limited history must be non-empty and mention "Ownership".
let limited_history = manager
.get_conversation_history(&session.id, Some(100))
.unwrap();
assert!(!limited_history.is_empty());
assert!(limited_history.contains("Ownership"));
// Replay must give two complete turns in order.
let replay = manager.replay_session(&session.id).unwrap();
assert_eq!(replay.session.id, session.id);
assert_eq!(replay.turns.len(), 2);
let turn1 = &replay.turns[0];
assert_eq!(turn1.user_message.content, "What is Rust?");
assert!(turn1.working_set.is_some());
assert!(turn1.assistant_message.is_some());
let turn2 = &replay.turns[1];
assert_eq!(turn2.user_message.content, "Tell me about ownership");
assert!(turn2.working_set.is_some());
assert!(turn2.assistant_message.is_some());
assert!(turn1.user_message.sequence_number < turn2.user_message.sequence_number);
}
}