use std::collections::{HashMap, HashSet};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::llm::ChatMessage;
/// A user's session: a collection of conversation threads plus per-session
/// settings such as the set of tools that no longer require approval.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
/// Identifier of this session (a random v4 UUID when built via `Session::new`).
pub id: Uuid,
/// Identifier of the user the session was created for.
pub user_id: String,
/// Id of the thread currently in use, if any; looked up in `threads`.
pub active_thread: Option<Uuid>,
/// Every thread in this session, keyed by the thread's own id.
pub threads: HashMap<Uuid, Thread>,
/// Moment the session was constructed.
pub created_at: DateTime<Utc>,
/// Refreshed whenever a thread is created or switched to.
pub last_active_at: DateTime<Utc>,
/// Free-form JSON attached to the session; `Null` for a new session.
pub metadata: serde_json::Value,
/// Names of tools marked as auto-approved for this session.
/// `#[serde(default)]` lets payloads without this field deserialize to an empty set.
#[serde(default)]
pub auto_approved_tools: HashSet<String>,
}
impl Session {
    /// Creates an empty session for `user_id`.
    ///
    /// The session starts with a random id, no threads, no active thread,
    /// `Null` metadata and no auto-approved tools. `created_at` and
    /// `last_active_at` are both set to the current time.
    pub fn new(user_id: impl Into<String>) -> Self {
        let now = Utc::now();
        Self {
            id: Uuid::new_v4(),
            user_id: user_id.into(),
            active_thread: None,
            threads: HashMap::new(),
            created_at: now,
            last_active_at: now,
            metadata: serde_json::Value::Null,
            auto_approved_tools: HashSet::new(),
        }
    }

    /// Returns `true` if `tool_name` has been auto-approved in this session.
    pub fn is_tool_auto_approved(&self, tool_name: &str) -> bool {
        self.auto_approved_tools.contains(tool_name)
    }

    /// Marks `tool_name` as auto-approved. Approving the same tool twice is a no-op.
    pub fn auto_approve_tool(&mut self, tool_name: impl Into<String>) {
        self.auto_approved_tools.insert(tool_name.into());
    }

    /// Creates a new thread, makes it the active one and returns it.
    ///
    /// Also refreshes `last_active_at`.
    pub fn create_thread(&mut self) -> &mut Thread {
        let thread = Thread::new(self.id);
        let thread_id = thread.id;
        self.threads.insert(thread_id, thread);
        self.active_thread = Some(thread_id);
        self.last_active_at = Utc::now();
        self.threads.get_mut(&thread_id).expect("just inserted")
    }

    /// Returns the active thread, or `None` when no thread is active or the
    /// active id does not match any stored thread.
    pub fn active_thread(&self) -> Option<&Thread> {
        self.active_thread.and_then(|id| self.threads.get(&id))
    }

    /// Mutable counterpart of [`Session::active_thread`].
    pub fn active_thread_mut(&mut self) -> Option<&mut Thread> {
        self.active_thread.and_then(|id| self.threads.get_mut(&id))
    }

    /// Returns the active thread, creating (and activating) a new one if needed.
    ///
    /// A new thread is created both when no thread is active and when
    /// `active_thread` holds an id that is missing from `threads`. The latter
    /// can happen because the fields are public and the type is deserializable;
    /// checking only `is_none()` would panic on such a dangling id.
    pub fn get_or_create_thread(&mut self) -> &mut Thread {
        if self.active_thread().is_none() {
            self.create_thread();
        }
        self.active_thread_mut()
            .expect("active thread was either found or just created")
    }

    /// Makes `thread_id` the active thread.
    ///
    /// Returns `true` and refreshes `last_active_at` when the thread exists;
    /// returns `false` and leaves the session untouched otherwise.
    pub fn switch_thread(&mut self, thread_id: Uuid) -> bool {
        if self.threads.contains_key(&thread_id) {
            self.active_thread = Some(thread_id);
            self.last_active_at = Utc::now();
            true
        } else {
            false
        }
    }
}
/// Lifecycle state of a [`Thread`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ThreadState {
/// Not working on anything; the state of a new thread and the state set after
/// a turn completes or fails, an approval is cleared, or an interrupted thread resumes.
Idle,
/// A turn has been started and has not finished yet.
Processing,
/// A pending approval has been recorded on the thread.
AwaitingApproval,
/// NOTE(review): never assigned by the code in this file — confirm where it is used.
Completed,
/// The thread was interrupted; `Thread::resume` moves it back to `Idle`.
Interrupted,
}
/// Marker stored on a thread by `Thread::enter_auth_mode`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PendingAuth {
/// The extension name handed to `enter_auth_mode`
/// (presumably the extension awaiting credentials — confirm against callers).
pub extension_name: String,
}
/// A tool invocation parked on a thread via `Thread::await_approval`.
///
/// NOTE(review): the field descriptions below are inferred from the field names;
/// the code in this file only stores and returns the value — confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PendingApproval {
/// Identifier of this approval request.
pub request_id: Uuid,
/// Name of the tool awaiting approval.
pub tool_name: String,
/// JSON parameters for the tool call.
pub parameters: serde_json::Value,
/// Human-readable description of the request.
pub description: String,
/// Identifier of the tool call (presumably assigned by the LLM provider).
pub tool_call_id: String,
/// Chat messages kept alongside the request (presumably to resume the conversation).
pub context_messages: Vec<ChatMessage>,
}
/// A single conversation inside a session: an ordered list of turns plus the
/// bookkeeping needed to pause for approval or authentication.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Thread {
/// Identifier of this thread.
pub id: Uuid,
/// Id of the session this thread was created for.
pub session_id: Uuid,
/// Current lifecycle state.
pub state: ThreadState,
/// Turns in chronological order; a turn's `turn_number` is its 0-based index here.
pub turns: Vec<Turn>,
/// Moment the thread was constructed.
pub created_at: DateTime<Utc>,
/// Refreshed by the state-changing methods on `Thread`.
pub updated_at: DateTime<Utc>,
/// Free-form JSON attached to the thread; `Null` for a new thread.
pub metadata: serde_json::Value,
/// Approval request currently parked on the thread, if any.
/// `#[serde(default)]` lets payloads without this field deserialize to `None`.
#[serde(default)]
pub pending_approval: Option<PendingApproval>,
/// Authentication marker set by `enter_auth_mode`, if any.
/// `#[serde(default)]` lets payloads without this field deserialize to `None`.
#[serde(default)]
pub pending_auth: Option<PendingAuth>,
/// NOTE(review): not written by any method in this file; presumably the id of the
/// most recent provider response — confirm against callers.
#[serde(default)]
pub last_response_id: Option<String>,
}
impl Thread {
    /// Builds an idle, empty thread with a freshly generated id for `session_id`.
    pub fn new(session_id: Uuid) -> Self {
        Self::with_id(Uuid::new_v4(), session_id)
    }

    /// Builds an idle, empty thread that uses the caller-supplied `id`.
    ///
    /// `created_at` and `updated_at` share the same timestamp.
    pub fn with_id(id: Uuid, session_id: Uuid) -> Self {
        let created = Utc::now();
        Self {
            id,
            session_id,
            state: ThreadState::Idle,
            turns: Vec::new(),
            created_at: created,
            updated_at: created,
            metadata: serde_json::Value::Null,
            pending_approval: None,
            pending_auth: None,
            last_response_id: None,
        }
    }

    /// 1-based number the *next* turn would have (stored turns + 1).
    ///
    /// Note that `Turn::turn_number` itself is 0-based.
    pub fn turn_number(&self) -> usize {
        self.turns.len() + 1
    }

    /// Most recent turn, if the thread has any.
    pub fn last_turn(&self) -> Option<&Turn> {
        self.turns.last()
    }

    /// Mutable access to the most recent turn, if the thread has any.
    pub fn last_turn_mut(&mut self) -> Option<&mut Turn> {
        self.turns.last_mut()
    }

    /// Appends a new turn for `user_input`, switches the thread to
    /// `Processing` and returns the turn that was just added.
    pub fn start_turn(&mut self, user_input: impl Into<String>) -> &mut Turn {
        // The new turn's number is its index in `turns`.
        let index = self.turns.len();
        self.turns.push(Turn::new(index, user_input));
        self.state = ThreadState::Processing;
        self.updated_at = Utc::now();
        self.turns.last_mut().expect("just pushed")
    }

    /// Records `response` on the latest turn (when there is one) and returns
    /// the thread to `Idle`.
    pub fn complete_turn(&mut self, response: impl Into<String>) {
        if let Some(current) = self.last_turn_mut() {
            current.complete(response);
        }
        self.state = ThreadState::Idle;
        self.updated_at = Utc::now();
    }

    /// Records `error` on the latest turn (when there is one) and returns
    /// the thread to `Idle`.
    pub fn fail_turn(&mut self, error: impl Into<String>) {
        if let Some(current) = self.last_turn_mut() {
            current.fail(error);
        }
        self.state = ThreadState::Idle;
        self.updated_at = Utc::now();
    }

    /// Parks `pending` on the thread and moves it to `AwaitingApproval`.
    pub fn await_approval(&mut self, pending: PendingApproval) {
        self.pending_approval = Some(pending);
        self.state = ThreadState::AwaitingApproval;
        self.updated_at = Utc::now();
    }

    /// Removes and returns the parked approval request.
    ///
    /// Neither `state` nor `updated_at` is modified.
    pub fn take_pending_approval(&mut self) -> Option<PendingApproval> {
        self.pending_approval.take()
    }

    /// Drops any parked approval request and returns the thread to `Idle`.
    pub fn clear_pending_approval(&mut self) {
        self.state = ThreadState::Idle;
        self.pending_approval = None;
        self.updated_at = Utc::now();
    }

    /// Stores a `PendingAuth` for `extension_name`; `state` is left as it is.
    pub fn enter_auth_mode(&mut self, extension_name: String) {
        let auth = PendingAuth { extension_name };
        self.pending_auth = Some(auth);
        self.updated_at = Utc::now();
    }

    /// Removes and returns the stored `PendingAuth`, if any.
    pub fn take_pending_auth(&mut self) -> Option<PendingAuth> {
        self.pending_auth.take()
    }

    /// Marks the latest turn (when there is one) as interrupted and moves the
    /// thread to `Interrupted`.
    pub fn interrupt(&mut self) {
        if let Some(current) = self.last_turn_mut() {
            current.interrupt();
        }
        self.state = ThreadState::Interrupted;
        self.updated_at = Utc::now();
    }

    /// Moves an `Interrupted` thread back to `Idle`; does nothing in any other state.
    pub fn resume(&mut self) {
        if self.state != ThreadState::Interrupted {
            return;
        }
        self.state = ThreadState::Idle;
        self.updated_at = Utc::now();
    }

    /// Flattens the turns into chat messages: one user message per turn,
    /// followed by an assistant message when the turn has a response.
    pub fn messages(&self) -> Vec<ChatMessage> {
        self.turns
            .iter()
            .flat_map(|turn| {
                let question = ChatMessage::user(&turn.user_input);
                let answer = turn
                    .response
                    .as_ref()
                    .map(|reply| ChatMessage::assistant(reply));
                std::iter::once(question).chain(answer)
            })
            .collect()
    }

    /// Keeps only the newest `keep` turns and renumbers them from 0.
    ///
    /// Nothing happens when the thread already holds `keep` turns or fewer.
    pub fn truncate_turns(&mut self, keep: usize) {
        let total = self.turns.len();
        if total <= keep {
            return;
        }
        self.turns.drain(..total - keep);
        for (index, turn) in self.turns.iter_mut().enumerate() {
            turn.turn_number = index;
        }
    }

    /// Replaces the turns with ones rebuilt from `messages` and resets the
    /// state to `Idle`.
    ///
    /// Each user message opens a turn; an assistant message immediately after
    /// it completes that turn. Any other message is skipped.
    pub fn restore_from_messages(&mut self, messages: Vec<ChatMessage>) {
        self.turns.clear();
        self.state = ThreadState::Idle;

        let mut remaining = messages.into_iter().peekable();
        while let Some(message) = remaining.next() {
            if message.role != crate::llm::Role::User {
                continue;
            }
            // `turns` was cleared above, so its length is the next turn number.
            let mut turn = Turn::new(self.turns.len(), &message.content);
            if let Some(reply) = remaining.next_if(|m| m.role == crate::llm::Role::Assistant) {
                turn.complete(&reply.content);
            }
            self.turns.push(turn);
        }

        self.updated_at = Utc::now();
    }
}
/// Outcome state of a single [`Turn`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TurnState {
/// Initial state of a turn created by `Turn::new`.
Processing,
/// Set by `Turn::complete` once a response has been recorded.
Completed,
/// Set by `Turn::fail` once an error has been recorded.
Failed,
/// Set by `Turn::interrupt`.
Interrupted,
}
/// One exchange in a thread: the user's input, the eventual response or
/// error, and the tool calls recorded along the way.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Turn {
/// 0-based position of the turn within its thread.
pub turn_number: usize,
/// Text the user supplied for this turn.
pub user_input: String,
/// Response text; `None` until `complete` is called.
pub response: Option<String>,
/// Tool calls recorded for this turn, in call order.
pub tool_calls: Vec<TurnToolCall>,
/// Current state of the turn.
pub state: TurnState,
/// Moment the turn was constructed.
pub started_at: DateTime<Utc>,
/// Moment the turn was completed, failed or interrupted; `None` before that.
pub completed_at: Option<DateTime<Utc>>,
/// Error text; `None` unless `fail` was called.
pub error: Option<String>,
}
impl Turn {
    /// Starts a turn numbered `turn_number` for `user_input`.
    ///
    /// The turn begins in `Processing` with no response, error or tool calls;
    /// `started_at` is the current time.
    pub fn new(turn_number: usize, user_input: impl Into<String>) -> Self {
        let user_input = user_input.into();
        let started_at = Utc::now();
        Self {
            turn_number,
            user_input,
            response: None,
            tool_calls: Vec::new(),
            state: TurnState::Processing,
            started_at,
            completed_at: None,
            error: None,
        }
    }

    /// Stores `response`, marks the turn `Completed` and stamps `completed_at`.
    pub fn complete(&mut self, response: impl Into<String>) {
        let text = response.into();
        self.response = Some(text);
        self.state = TurnState::Completed;
        self.completed_at = Some(Utc::now());
    }

    /// Stores `error`, marks the turn `Failed` and stamps `completed_at`.
    pub fn fail(&mut self, error: impl Into<String>) {
        let reason = error.into();
        self.error = Some(reason);
        self.state = TurnState::Failed;
        self.completed_at = Some(Utc::now());
    }

    /// Marks the turn `Interrupted` and stamps `completed_at`.
    pub fn interrupt(&mut self) {
        self.state = TurnState::Interrupted;
        self.completed_at = Some(Utc::now());
    }

    /// Appends a tool call named `name` with `params`; its result and error
    /// start out empty.
    pub fn record_tool_call(&mut self, name: impl Into<String>, params: serde_json::Value) {
        let call = TurnToolCall {
            name: name.into(),
            parameters: params,
            result: None,
            error: None,
        };
        self.tool_calls.push(call);
    }

    /// Attaches `result` to the most recently recorded tool call.
    ///
    /// Does nothing when no tool call has been recorded.
    pub fn record_tool_result(&mut self, result: serde_json::Value) {
        let Some(latest) = self.tool_calls.last_mut() else {
            return;
        };
        latest.result = Some(result);
    }

    /// Attaches `error` to the most recently recorded tool call.
    ///
    /// Does nothing when no tool call has been recorded.
    pub fn record_tool_error(&mut self, error: impl Into<String>) {
        let Some(latest) = self.tool_calls.last_mut() else {
            return;
        };
        latest.error = Some(error.into());
    }
}
/// A tool invocation recorded on a [`Turn`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TurnToolCall {
/// Name of the tool that was called.
pub name: String,
/// JSON parameters passed to the tool.
pub parameters: serde_json::Value,
/// Result recorded via `Turn::record_tool_result`; `None` until then.
pub result: Option<serde_json::Value>,
/// Error recorded via `Turn::record_tool_error`; `None` until then.
pub error: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    // A new session has no active thread until one is created.
    #[test]
    fn test_session_creation() {
        let mut session = Session::new("user-123");
        assert!(session.active_thread.is_none());
        session.create_thread();
        assert!(session.active_thread.is_some());
    }

    // Starting and completing a turn drives the thread state and stores the response.
    #[test]
    fn test_thread_turns() {
        let mut thread = Thread::new(Uuid::new_v4());
        thread.start_turn("Hello");
        assert_eq!(thread.state, ThreadState::Processing);
        assert_eq!(thread.turns.len(), 1);
        thread.complete_turn("Hi there!");
        assert_eq!(thread.state, ThreadState::Idle);
        assert_eq!(thread.turns[0].response, Some("Hi there!".to_string()));
    }

    // Two completed turns flatten into four chat messages.
    #[test]
    fn test_thread_messages() {
        let mut thread = Thread::new(Uuid::new_v4());
        thread.start_turn("First message");
        thread.complete_turn("First response");
        thread.start_turn("Second message");
        thread.complete_turn("Second response");
        let messages = thread.messages();
        assert_eq!(messages.len(), 4);
    }

    // A recorded result lands on the most recent tool call.
    #[test]
    fn test_turn_tool_calls() {
        let mut turn = Turn::new(0, "Test input");
        turn.record_tool_call("echo", serde_json::json!({"message": "test"}));
        turn.record_tool_result(serde_json::json!("test"));
        assert_eq!(turn.tool_calls.len(), 1);
        assert!(turn.tool_calls[0].result.is_some());
    }

    // Restoring replaces existing turns with user/assistant pairs from the messages.
    #[test]
    fn test_restore_from_messages() {
        let mut thread = Thread::new(Uuid::new_v4());
        thread.start_turn("Original message");
        thread.complete_turn("Original response");
        let messages = vec![
            ChatMessage::user("Hello"),
            ChatMessage::assistant("Hi there!"),
            ChatMessage::user("How are you?"),
            ChatMessage::assistant("I'm good!"),
        ];
        thread.restore_from_messages(messages);
        assert_eq!(thread.turns.len(), 2);
        assert_eq!(thread.turns[0].user_input, "Hello");
        assert_eq!(thread.turns[0].response, Some("Hi there!".to_string()));
        assert_eq!(thread.turns[1].user_input, "How are you?");
        assert_eq!(thread.turns[1].response, Some("I'm good!".to_string()));
        assert_eq!(thread.state, ThreadState::Idle);
    }

    // A trailing user message becomes a turn without a response.
    #[test]
    fn test_restore_from_messages_incomplete_turn() {
        let mut thread = Thread::new(Uuid::new_v4());
        let messages = vec![
            ChatMessage::user("Hello"),
            ChatMessage::assistant("Hi there!"),
            ChatMessage::user("How are you?"),
        ];
        thread.restore_from_messages(messages);
        assert_eq!(thread.turns.len(), 2);
        assert_eq!(thread.turns[1].user_input, "How are you?");
        assert!(thread.turns[1].response.is_none());
    }

    // Entering auth mode stores the extension name.
    #[test]
    fn test_enter_auth_mode() {
        let mut thread = Thread::new(Uuid::new_v4());
        assert!(thread.pending_auth.is_none());
        thread.enter_auth_mode("telegram".to_string());
        assert!(thread.pending_auth.is_some());
        assert_eq!(
            thread.pending_auth.as_ref().unwrap().extension_name,
            "telegram"
        );
    }

    // Taking the pending auth returns it once and leaves nothing behind.
    #[test]
    fn test_take_pending_auth() {
        let mut thread = Thread::new(Uuid::new_v4());
        thread.enter_auth_mode("notion".to_string());
        let pending = thread.take_pending_auth();
        assert!(pending.is_some());
        assert_eq!(pending.unwrap().extension_name, "notion");
        assert!(thread.pending_auth.is_none());
        assert!(thread.take_pending_auth().is_none());
    }

    // `pending_auth` survives a JSON round trip.
    #[test]
    fn test_pending_auth_serialization() {
        let mut thread = Thread::new(Uuid::new_v4());
        thread.enter_auth_mode("openai".to_string());
        let json = serde_json::to_string(&thread).expect("should serialize");
        assert!(json.contains("pending_auth"));
        assert!(json.contains("openai"));
        let restored: Thread = serde_json::from_str(&json).expect("should deserialize");
        assert!(restored.pending_auth.is_some());
        assert_eq!(restored.pending_auth.unwrap().extension_name, "openai");
    }

    // JSON without a `pending_auth` key still deserializes, defaulting to `None`.
    #[test]
    fn test_pending_auth_default_none() {
        let mut thread = Thread::new(Uuid::new_v4());
        thread.pending_auth = None;
        let json = serde_json::to_string(&thread).expect("serialize");
        let json = json.replace(",\"pending_auth\":null", "");
        // Guard against a vacuous pass: if the pattern above ever stops matching
        // (for example after a field reorder), the key would still be present and
        // this test would no longer exercise the missing-field path.
        assert!(
            !json.contains("pending_auth"),
            "pending_auth key was not stripped from the JSON"
        );
        let restored: Thread = serde_json::from_str(&json).expect("should deserialize");
        assert!(restored.pending_auth.is_none());
    }

    // `with_id` uses the supplied id and yields an idle, empty thread.
    #[test]
    fn test_thread_with_id() {
        let specific_id = Uuid::new_v4();
        let session_id = Uuid::new_v4();
        let thread = Thread::with_id(specific_id, session_id);
        assert_eq!(thread.id, specific_id);
        assert_eq!(thread.session_id, session_id);
        assert_eq!(thread.state, ThreadState::Idle);
        assert!(thread.turns.is_empty());
    }

    // Restoring messages keeps the id chosen through `with_id`.
    #[test]
    fn test_thread_with_id_restore_messages() {
        let thread_id = Uuid::new_v4();
        let session_id = Uuid::new_v4();
        let mut thread = Thread::with_id(thread_id, session_id);
        let messages = vec![
            ChatMessage::user("Hello from DB"),
            ChatMessage::assistant("Restored response"),
        ];
        thread.restore_from_messages(messages);
        assert_eq!(thread.id, thread_id);
        assert_eq!(thread.turns.len(), 1);
        assert_eq!(thread.turns[0].user_input, "Hello from DB");
        assert_eq!(
            thread.turns[0].response,
            Some("Restored response".to_string())
        );
    }

    // Restoring from no messages clears the existing turns.
    #[test]
    fn test_restore_from_messages_empty() {
        let mut thread = Thread::new(Uuid::new_v4());
        thread.start_turn("hello");
        thread.complete_turn("hi");
        assert_eq!(thread.turns.len(), 1);
        thread.restore_from_messages(Vec::new());
        assert!(thread.turns.is_empty());
        assert_eq!(thread.state, ThreadState::Idle);
    }

    // Assistant messages with no preceding user message produce no turns.
    #[test]
    fn test_restore_from_messages_only_assistant_messages() {
        let mut thread = Thread::new(Uuid::new_v4());
        let messages = vec![
            ChatMessage::assistant("I'm here"),
            ChatMessage::assistant("Still here"),
        ];
        thread.restore_from_messages(messages);
        assert!(thread.turns.is_empty());
    }

    // Consecutive user messages each open a turn; only the last one gets the reply.
    #[test]
    fn test_restore_from_messages_multiple_user_messages_in_a_row() {
        let mut thread = Thread::new(Uuid::new_v4());
        let messages = vec![
            ChatMessage::user("first"),
            ChatMessage::user("second"),
            ChatMessage::assistant("reply to second"),
        ];
        thread.restore_from_messages(messages);
        assert_eq!(thread.turns.len(), 2);
        assert_eq!(thread.turns[0].user_input, "first");
        assert!(thread.turns[0].response.is_none());
        assert_eq!(thread.turns[1].user_input, "second");
        assert_eq!(
            thread.turns[1].response,
            Some("reply to second".to_string())
        );
    }

    // Switching succeeds for known threads and is rejected for unknown ids.
    #[test]
    fn test_thread_switch() {
        let mut session = Session::new("user-1");
        let t1_id = session.create_thread().id;
        let t2_id = session.create_thread().id;
        assert_eq!(session.active_thread, Some(t2_id));
        assert!(session.switch_thread(t1_id));
        assert_eq!(session.active_thread, Some(t1_id));
        let fake_id = Uuid::new_v4();
        assert!(!session.switch_thread(fake_id));
        assert_eq!(session.active_thread, Some(t1_id));
    }

    // Repeated calls return the same thread rather than creating new ones.
    #[test]
    fn test_get_or_create_thread_idempotent() {
        let mut session = Session::new("user-1");
        let tid1 = session.get_or_create_thread().id;
        let tid2 = session.get_or_create_thread().id;
        assert_eq!(tid1, tid2);
        assert_eq!(session.threads.len(), 1);
    }

    // Truncation keeps the newest turns and renumbers them from zero.
    #[test]
    fn test_truncate_turns() {
        let mut thread = Thread::new(Uuid::new_v4());
        for i in 0..5 {
            thread.start_turn(format!("msg-{}", i));
            thread.complete_turn(format!("resp-{}", i));
        }
        assert_eq!(thread.turns.len(), 5);
        thread.truncate_turns(3);
        assert_eq!(thread.turns.len(), 3);
        assert_eq!(thread.turns[0].user_input, "msg-2");
        assert_eq!(thread.turns[1].user_input, "msg-3");
        assert_eq!(thread.turns[2].user_input, "msg-4");
        assert_eq!(thread.turns[0].turn_number, 0);
        assert_eq!(thread.turns[1].turn_number, 1);
        assert_eq!(thread.turns[2].turn_number, 2);
    }

    // Truncating to more turns than exist changes nothing.
    #[test]
    fn test_truncate_turns_noop_when_fewer() {
        let mut thread = Thread::new(Uuid::new_v4());
        thread.start_turn("only one");
        thread.complete_turn("response");
        thread.truncate_turns(10);
        assert_eq!(thread.turns.len(), 1);
        assert_eq!(thread.turns[0].user_input, "only one");
    }

    // Interrupting marks both thread and turn; resuming returns the thread to idle.
    #[test]
    fn test_thread_interrupt_and_resume() {
        let mut thread = Thread::new(Uuid::new_v4());
        thread.start_turn("do something");
        assert_eq!(thread.state, ThreadState::Processing);
        thread.interrupt();
        assert_eq!(thread.state, ThreadState::Interrupted);
        let last_turn = thread.last_turn().unwrap();
        assert_eq!(last_turn.state, TurnState::Interrupted);
        assert!(last_turn.completed_at.is_some());
        thread.resume();
        assert_eq!(thread.state, ThreadState::Idle);
    }

    // `resume` leaves non-interrupted states alone.
    #[test]
    fn test_resume_only_from_interrupted() {
        let mut thread = Thread::new(Uuid::new_v4());
        assert_eq!(thread.state, ThreadState::Idle);
        thread.resume();
        assert_eq!(thread.state, ThreadState::Idle);
        thread.start_turn("work");
        assert_eq!(thread.state, ThreadState::Processing);
        thread.resume();
        assert_eq!(thread.state, ThreadState::Processing);
    }

    // Failing a turn records the error and leaves the response empty.
    #[test]
    fn test_turn_fail() {
        let mut thread = Thread::new(Uuid::new_v4());
        thread.start_turn("risky operation");
        thread.fail_turn("connection timed out");
        assert_eq!(thread.state, ThreadState::Idle);
        let turn = thread.last_turn().unwrap();
        assert_eq!(turn.state, TurnState::Failed);
        assert_eq!(turn.error, Some("connection timed out".to_string()));
        assert!(turn.response.is_none());
        assert!(turn.completed_at.is_some());
    }

    // An in-progress turn contributes only its user message.
    #[test]
    fn test_messages_with_incomplete_last_turn() {
        let mut thread = Thread::new(Uuid::new_v4());
        thread.start_turn("first");
        thread.complete_turn("first reply");
        thread.start_turn("second (in progress)");
        let messages = thread.messages();
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].content, "first");
        assert_eq!(messages[1].content, "first reply");
        assert_eq!(messages[2].content, "second (in progress)");
    }

    // A thread with a turn and a response id survives a JSON round trip.
    #[test]
    fn test_thread_serialization_round_trip() {
        let mut thread = Thread::new(Uuid::new_v4());
        thread.start_turn("hello");
        thread.complete_turn("world");
        thread.last_response_id = Some("resp_abc123".to_string());
        let json = serde_json::to_string(&thread).unwrap();
        let restored: Thread = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.id, thread.id);
        assert_eq!(restored.session_id, thread.session_id);
        assert_eq!(restored.turns.len(), 1);
        assert_eq!(restored.turns[0].user_input, "hello");
        assert_eq!(restored.turns[0].response, Some("world".to_string()));
        assert_eq!(restored.last_response_id, Some("resp_abc123".to_string()));
    }

    // A session with a thread and an approved tool survives a JSON round trip.
    #[test]
    fn test_session_serialization_round_trip() {
        let mut session = Session::new("user-ser");
        session.create_thread();
        session.auto_approve_tool("echo");
        let json = serde_json::to_string(&session).unwrap();
        let restored: Session = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.user_id, "user-ser");
        assert_eq!(restored.threads.len(), 1);
        assert!(restored.is_tool_auto_approved("echo"));
        assert!(!restored.is_tool_auto_approved("shell"));
    }

    // Approving a tool is remembered and approving it twice does not duplicate it.
    #[test]
    fn test_auto_approved_tools() {
        let mut session = Session::new("user-1");
        assert!(!session.is_tool_auto_approved("shell"));
        session.auto_approve_tool("shell");
        assert!(session.is_tool_auto_approved("shell"));
        session.auto_approve_tool("shell");
        assert_eq!(session.auto_approved_tools.len(), 1);
    }

    // A recorded error lands on the most recent tool call and leaves its result empty.
    #[test]
    fn test_turn_tool_call_error() {
        let mut turn = Turn::new(0, "test");
        turn.record_tool_call("http", serde_json::json!({"url": "example.com"}));
        turn.record_tool_error("timeout");
        assert_eq!(turn.tool_calls.len(), 1);
        assert_eq!(turn.tool_calls[0].error, Some("timeout".to_string()));
        assert!(turn.tool_calls[0].result.is_none());
    }

    // `turn_number` is always one more than the number of stored turns.
    #[test]
    fn test_turn_number_increments() {
        let mut thread = Thread::new(Uuid::new_v4());
        assert_eq!(thread.turn_number(), 1);
        thread.start_turn("first");
        thread.complete_turn("done");
        assert_eq!(thread.turn_number(), 2);
        thread.start_turn("second");
        assert_eq!(thread.turn_number(), 3);
    }

    // Completing with no turns does not invent a turn.
    #[test]
    fn test_complete_turn_on_empty_thread() {
        let mut thread = Thread::new(Uuid::new_v4());
        thread.complete_turn("phantom response");
        assert_eq!(thread.state, ThreadState::Idle);
        assert!(thread.turns.is_empty());
    }

    // Failing with no turns does not invent a turn.
    #[test]
    fn test_fail_turn_on_empty_thread() {
        let mut thread = Thread::new(Uuid::new_v4());
        thread.fail_turn("phantom error");
        assert_eq!(thread.state, ThreadState::Idle);
        assert!(thread.turns.is_empty());
    }

    // An approval can be parked on a thread and taken back exactly once.
    #[test]
    fn test_pending_approval_flow() {
        let mut thread = Thread::new(Uuid::new_v4());
        let approval = PendingApproval {
            request_id: Uuid::new_v4(),
            tool_name: "shell".to_string(),
            parameters: serde_json::json!({"command": "rm -rf /"}),
            description: "dangerous command".to_string(),
            tool_call_id: "call_123".to_string(),
            context_messages: vec![ChatMessage::user("do it")],
        };
        thread.await_approval(approval);
        assert_eq!(thread.state, ThreadState::AwaitingApproval);
        assert!(thread.pending_approval.is_some());
        let taken = thread.take_pending_approval();
        assert!(taken.is_some());
        assert_eq!(taken.unwrap().tool_name, "shell");
        assert!(thread.pending_approval.is_none());
    }

    // Clearing drops the parked approval and returns the thread to idle.
    #[test]
    fn test_clear_pending_approval() {
        let mut thread = Thread::new(Uuid::new_v4());
        let approval = PendingApproval {
            request_id: Uuid::new_v4(),
            tool_name: "http".to_string(),
            parameters: serde_json::json!({}),
            description: "test".to_string(),
            tool_call_id: "call_456".to_string(),
            context_messages: vec![],
        };
        thread.await_approval(approval);
        thread.clear_pending_approval();
        assert_eq!(thread.state, ThreadState::Idle);
        assert!(thread.pending_approval.is_none());
    }

    // Both accessors track the active thread, and the mutable one can change it.
    #[test]
    fn test_active_thread_accessors() {
        let mut session = Session::new("user-1");
        assert!(session.active_thread().is_none());
        assert!(session.active_thread_mut().is_none());
        let tid = session.create_thread().id;
        assert!(session.active_thread().is_some());
        assert_eq!(session.active_thread().unwrap().id, tid);
        session.active_thread_mut().unwrap().start_turn("test");
        assert_eq!(
            session.active_thread().unwrap().state,
            ThreadState::Processing
        );
    }
}