use std::collections::{HashMap, HashSet, VecDeque};
use chrono::{DateTime, TimeDelta, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::llm::{ChatMessage, ToolCall, generate_tool_call_id};
use ironclaw_common::truncate_preview;
/// A user's conversational session: a collection of threads plus
/// session-wide state such as tool auto-approvals.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
/// Unique session id (a random v4 UUID when created via `Session::new`).
pub id: Uuid,
/// Identifier of the user that owns this session.
pub user_id: String,
/// Id of the thread currently selected as active, if any.
pub active_thread: Option<Uuid>,
/// All threads of this session, keyed by thread id.
pub threads: HashMap<Uuid, Thread>,
/// When the session was created.
pub created_at: DateTime<Utc>,
/// Last activity timestamp; refreshed by `create_thread` and `switch_thread`.
pub last_active_at: DateTime<Utc>,
/// Free-form metadata; `Null` for new sessions.
pub metadata: serde_json::Value,
/// Names of tools the user approved to run without asking again.
/// Defaults to an empty set when missing from serialized data.
#[serde(default)]
pub auto_approved_tools: HashSet<String>,
}
impl Session {
    /// Builds an empty session for `user_id`: no threads, no active thread,
    /// no auto-approved tools, and `Null` metadata.
    pub fn new(user_id: impl Into<String>) -> Self {
        let created = Utc::now();
        Self {
            id: Uuid::new_v4(),
            user_id: user_id.into(),
            active_thread: None,
            threads: HashMap::new(),
            created_at: created,
            last_active_at: created,
            metadata: serde_json::Value::Null,
            auto_approved_tools: HashSet::new(),
        }
    }

    /// Returns `true` when `tool_name` has been auto-approved in this session.
    pub fn is_tool_auto_approved(&self, tool_name: &str) -> bool {
        self.auto_approved_tools.contains(tool_name)
    }

    /// Adds `tool_name` to the auto-approved set (idempotent).
    pub fn auto_approve_tool(&mut self, tool_name: impl Into<String>) {
        let name: String = tool_name.into();
        self.auto_approved_tools.insert(name);
    }

    /// Creates a new thread, makes it the active one, and returns it.
    pub fn create_thread(&mut self) -> &mut Thread {
        let fresh = Thread::new(self.id);
        let fresh_id = fresh.id;
        self.last_active_at = Utc::now();
        self.active_thread = Some(fresh_id);
        self.threads.entry(fresh_id).or_insert(fresh)
    }

    /// Returns the active thread, or `None` if none is selected or the
    /// selected id is not present in `threads`.
    pub fn active_thread(&self) -> Option<&Thread> {
        let id = self.active_thread?;
        self.threads.get(&id)
    }

    /// Mutable counterpart of [`Session::active_thread`].
    pub fn active_thread_mut(&mut self) -> Option<&mut Thread> {
        let id = self.active_thread?;
        self.threads.get_mut(&id)
    }

    /// Returns the active thread, creating (and activating) a new one when no
    /// thread is selected or the selected id no longer exists.
    pub fn get_or_create_thread(&mut self) -> &mut Thread {
        let existing = self
            .active_thread
            .filter(|id| self.threads.contains_key(id));
        match existing {
            Some(id) => self
                .threads
                .get_mut(&id)
                .expect("active thread presence was checked just above"),
            None => self.create_thread(),
        }
    }

    /// Makes `thread_id` the active thread if it belongs to this session.
    ///
    /// Returns `false`, leaving the session unchanged, for unknown ids.
    pub fn switch_thread(&mut self, thread_id: Uuid) -> bool {
        if !self.threads.contains_key(&thread_id) {
            return false;
        }
        self.active_thread = Some(thread_id);
        self.last_active_at = Utc::now();
        true
    }
}
/// Lifecycle state of a [`Thread`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ThreadState {
/// No turn in flight. Set by `Thread::new`/`with_id`, `complete_turn`,
/// `fail_turn`, `clear_pending_approval`, `resume`, and `restore_from_messages`.
Idle,
/// A turn is in progress (set by `Thread::start_turn`).
Processing,
/// Waiting for the user to approve a tool call (set by `Thread::await_approval`).
AwaitingApproval,
/// NOTE(review): not assigned by any method in this chunk — confirm where it is used.
Completed,
/// The in-flight turn was interrupted (set by `Thread::interrupt`);
/// `Thread::resume` moves it back to `Idle`.
Interrupted,
}
/// Lifetime, in seconds, of a pending auth prompt before it counts as expired.
const AUTH_MODE_TTL_SECS: i64 = 300;
/// `AUTH_MODE_TTL_SECS` as a `TimeDelta`; compared against a `PendingAuth`'s age.
const AUTH_MODE_TTL: TimeDelta = TimeDelta::seconds(AUTH_MODE_TTL_SECS);
/// Records that a thread is waiting for the user to authenticate an extension.
/// Valid for `AUTH_MODE_TTL` after `created_at`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PendingAuth {
/// Name of the extension awaiting authentication.
pub extension_name: String,
/// When auth mode was entered. Serialized data lacking this field is
/// deserialized with the current time, so such entries start a fresh TTL window.
#[serde(default = "Utc::now")]
pub created_at: DateTime<Utc>,
}
impl PendingAuth {
    /// Reports whether strictly more than `AUTH_MODE_TTL` has passed since
    /// `created_at`. At exactly the TTL the entry is still considered live.
    pub fn is_expired(&self) -> bool {
        let age = Utc::now().signed_duration_since(self.created_at);
        age > AUTH_MODE_TTL
    }
}
/// A tool call parked until the user approves or rejects it
/// (stored on a thread via `Thread::await_approval`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PendingApproval {
/// Unique id of this approval request.
pub request_id: Uuid,
/// Name of the tool awaiting approval.
pub tool_name: String,
/// Parameters the tool would be invoked with.
pub parameters: serde_json::Value,
/// Parameters intended for display to the user; JSON `null` when absent
/// from serialized data. NOTE(review): presumably a redacted/friendly view
/// of `parameters` — confirm.
#[serde(default)]
pub display_parameters: serde_json::Value,
/// Human-readable description of the pending action.
pub description: String,
/// Id of the tool call awaiting approval.
pub tool_call_id: String,
/// Conversation messages captured with the request. NOTE(review): presumably
/// used to resume the LLM loop once the approval is resolved — confirm.
pub context_messages: Vec<ChatMessage>,
/// Further tool calls held back alongside this one; empty when absent.
/// NOTE(review): presumably the remaining calls from the same model response — confirm.
#[serde(default)]
pub deferred_tool_calls: Vec<ToolCall>,
/// The user's timezone, if known; `None` when absent.
#[serde(default)]
pub user_timezone: Option<String>,
/// Defaults to `true` when absent. NOTE(review): presumably controls whether an
/// "always allow" choice may be offered — confirm.
#[serde(default = "default_true")]
pub allow_always: bool,
}
/// Serde `default = "..."` helper that yields `true` for boolean fields
/// missing from the input.
fn default_true() -> bool {
    true
}
/// A single conversation: an ordered list of turns plus transient state
/// (pending approval, pending auth, queued user messages).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Thread {
/// Unique thread id.
pub id: Uuid,
/// Id of the owning session.
pub session_id: Uuid,
/// Current lifecycle state.
pub state: ThreadState,
/// Turns in chronological order; `turn_number` equals the index.
pub turns: Vec<Turn>,
/// When the thread was created.
pub created_at: DateTime<Utc>,
/// Last modification time; bumped by most mutating methods.
pub updated_at: DateTime<Utc>,
/// Free-form metadata; `Null` for new threads.
pub metadata: serde_json::Value,
/// Tool call awaiting user approval (see `await_approval`); `None` when absent.
#[serde(default)]
pub pending_approval: Option<PendingApproval>,
/// Extension awaiting user authentication (see `enter_auth_mode`); `None` when absent.
#[serde(default)]
pub pending_auth: Option<PendingAuth>,
/// FIFO of messages added via `queue_message`, capped at `MAX_PENDING_MESSAGES`.
/// Omitted from serialized output when empty.
#[serde(default, skip_serializing_if = "VecDeque::is_empty")]
pub pending_messages: VecDeque<String>,
}
/// Maximum number of entries `Thread::queue_message` accepts into
/// `Thread::pending_messages`; further messages are rejected.
pub const MAX_PENDING_MESSAGES: usize = 10;
impl Thread {
    /// Creates an idle, empty thread with a freshly generated id belonging to
    /// the session `session_id`.
    pub fn new(session_id: Uuid) -> Self {
        Self::with_id(Uuid::new_v4(), session_id)
    }

    /// Creates an idle, empty thread with a caller-supplied `id`, e.g. when the
    /// id is already known and history will be rebuilt via
    /// `restore_from_messages`.
    pub fn with_id(id: Uuid, session_id: Uuid) -> Self {
        let now = Utc::now();
        Self {
            id,
            session_id,
            state: ThreadState::Idle,
            turns: Vec::new(),
            created_at: now,
            updated_at: now,
            metadata: serde_json::Value::Null,
            pending_approval: None,
            pending_auth: None,
            pending_messages: VecDeque::new(),
        }
    }

    /// Returns the 1-based number of the *next* turn (`turns.len() + 1`).
    ///
    /// Note: `Turn::turn_number` values assigned by `start_turn` are 0-based
    /// indices into `turns`.
    pub fn turn_number(&self) -> usize {
        self.turns.len() + 1
    }

    /// Returns the most recent turn, if any.
    pub fn last_turn(&self) -> Option<&Turn> {
        self.turns.last()
    }

    /// Returns the most recent turn mutably, if any.
    pub fn last_turn_mut(&mut self) -> Option<&mut Turn> {
        self.turns.last_mut()
    }

    /// Appends `content` to the back of the pending-message queue.
    ///
    /// Returns `false` without modifying anything when the queue already
    /// holds `MAX_PENDING_MESSAGES` entries.
    pub fn queue_message(&mut self, content: String) -> bool {
        if self.pending_messages.len() >= MAX_PENDING_MESSAGES {
            return false;
        }
        self.pending_messages.push_back(content);
        self.updated_at = Utc::now();
        true
    }

    /// Removes and returns the oldest queued message, if any.
    ///
    /// Like the other queue mutators this bumps `updated_at`, but only when a
    /// message was actually removed (the queue is part of the persisted state).
    pub fn take_pending_message(&mut self) -> Option<String> {
        let message = self.pending_messages.pop_front();
        if message.is_some() {
            self.updated_at = Utc::now();
        }
        message
    }

    /// Empties the queue, returning all pending messages joined with `'\n'`
    /// in FIFO order, or `None` when nothing was queued.
    pub fn drain_pending_messages(&mut self) -> Option<String> {
        if self.pending_messages.is_empty() {
            return None;
        }
        let parts: Vec<String> = self.pending_messages.drain(..).collect();
        self.updated_at = Utc::now();
        Some(parts.join("\n"))
    }

    /// Pushes `content` back onto the *front* of the queue.
    ///
    /// Does not enforce `MAX_PENDING_MESSAGES`. NOTE(review): presumably used
    /// to put back text from `drain_pending_messages` that could not be
    /// processed — confirm with callers.
    pub fn requeue_drained(&mut self, content: String) {
        self.pending_messages.push_front(content);
        self.updated_at = Utc::now();
    }

    /// Appends a new turn for `user_input`, moves the thread to `Processing`,
    /// and returns the new turn. Its `turn_number` is its 0-based index.
    pub fn start_turn(&mut self, user_input: impl Into<String>) -> &mut Turn {
        let turn_number = self.turns.len();
        let turn = Turn::new(turn_number, user_input);
        self.turns.push(turn);
        self.state = ThreadState::Processing;
        self.updated_at = Utc::now();
        &mut self.turns[turn_number]
    }

    /// Completes the latest turn (if any) with `response` and returns the
    /// thread to `Idle`. The state change happens even when there are no turns.
    pub fn complete_turn(&mut self, response: impl Into<String>) {
        if let Some(turn) = self.turns.last_mut() {
            turn.complete(response);
        }
        self.state = ThreadState::Idle;
        self.updated_at = Utc::now();
    }

    /// Fails the latest turn (if any) with `error` and returns the thread to
    /// `Idle`. The state change happens even when there are no turns.
    pub fn fail_turn(&mut self, error: impl Into<String>) {
        if let Some(turn) = self.turns.last_mut() {
            turn.fail(error);
        }
        self.state = ThreadState::Idle;
        self.updated_at = Utc::now();
    }

    /// Parks the thread in `AwaitingApproval`, storing `pending` (replacing
    /// any previous pending approval).
    pub fn await_approval(&mut self, pending: PendingApproval) {
        self.state = ThreadState::AwaitingApproval;
        self.pending_approval = Some(pending);
        self.updated_at = Utc::now();
    }

    /// Takes the pending approval, leaving `None`.
    ///
    /// Does not change `state`; use `clear_pending_approval` to also return
    /// the thread to `Idle`.
    pub fn take_pending_approval(&mut self) -> Option<PendingApproval> {
        self.pending_approval.take()
    }

    /// Discards any pending approval and returns the thread to `Idle`.
    pub fn clear_pending_approval(&mut self) {
        self.pending_approval = None;
        self.state = ThreadState::Idle;
        self.updated_at = Utc::now();
    }

    /// Records that the thread awaits authentication for `extension_name`,
    /// replacing any previous pending auth.
    ///
    /// A single timestamp is used for both `PendingAuth::created_at` and the
    /// thread's `updated_at`, so the two always agree.
    pub fn enter_auth_mode(&mut self, extension_name: String) {
        let now = Utc::now();
        self.pending_auth = Some(PendingAuth {
            extension_name,
            created_at: now,
        });
        self.updated_at = now;
    }

    /// Takes the pending auth, leaving `None`. Expiry is *not* checked here;
    /// see `PendingAuth::is_expired`.
    pub fn take_pending_auth(&mut self) -> Option<PendingAuth> {
        self.pending_auth.take()
    }

    /// Marks the latest turn (if any) interrupted, discards all queued
    /// messages, and moves the thread to `Interrupted`.
    pub fn interrupt(&mut self) {
        if let Some(turn) = self.turns.last_mut() {
            turn.interrupt();
        }
        self.pending_messages.clear();
        self.state = ThreadState::Interrupted;
        self.updated_at = Utc::now();
    }

    /// Returns an `Interrupted` thread to `Idle`; a no-op in any other state.
    pub fn resume(&mut self) {
        if self.state == ThreadState::Interrupted {
            self.state = ThreadState::Idle;
            self.updated_at = Utc::now();
        }
    }

    /// Rebuilds the chat history from the recorded turns.
    ///
    /// Each turn produces, in order:
    /// 1. The user message. It is multimodal (`user_with_parts`) when the turn
    ///    still holds image parts.
    /// 2. If tools were called, one assistant message carrying all of the
    ///    turn's tool calls, followed by one tool-result message per call.
    /// 3. The assistant response, if any.
    ///
    /// Tool call ids are regenerated from `(turn index, call index)` via
    /// `generate_tool_call_id` rather than read from
    /// `TurnToolCall::tool_call_id`. Each call therefore pairs with its result
    /// message. Tool message content is chosen in this order:
    /// - the error, if present;
    /// - otherwise the result (JSON strings are unquoted);
    /// - otherwise `"OK"`.
    ///
    /// Errors and results are passed through `truncate_preview(…, 1000)`.
    pub fn messages(&self) -> Vec<ChatMessage> {
        let mut messages = Vec::new();
        for (turn_idx, turn) in self.turns.iter().enumerate() {
            if turn.image_content_parts.is_empty() {
                messages.push(ChatMessage::user(&turn.user_input));
            } else {
                messages.push(ChatMessage::user_with_parts(
                    &turn.user_input,
                    turn.image_content_parts.clone(),
                ));
            }
            if !turn.tool_calls.is_empty() {
                let tool_calls_with_ids: Vec<(String, &_)> = turn
                    .tool_calls
                    .iter()
                    .enumerate()
                    .map(|(tc_idx, tc)| (generate_tool_call_id(turn_idx, tc_idx), tc))
                    .collect();
                let tool_calls: Vec<ToolCall> = tool_calls_with_ids
                    .iter()
                    .map(|(call_id, tc)| ToolCall {
                        id: call_id.clone(),
                        name: tc.name.clone(),
                        arguments: tc.parameters.clone(),
                        reasoning: None,
                    })
                    .collect();
                messages.push(ChatMessage::assistant_with_tool_calls(None, tool_calls));
                for (call_id, tc) in tool_calls_with_ids {
                    let content = if let Some(ref err) = tc.error {
                        truncate_preview(err, 1000)
                    } else if let Some(ref res) = tc.result {
                        // Unwrap JSON strings so the model sees raw text, not a quoted literal.
                        let raw = match res {
                            serde_json::Value::String(s) => s.clone(),
                            other => other.to_string(),
                        };
                        truncate_preview(&raw, 1000)
                    } else {
                        "OK".to_string()
                    };
                    messages.push(ChatMessage::tool_result(call_id, &tc.name, content));
                }
            }
            if let Some(ref response) = turn.response {
                messages.push(ChatMessage::assistant(response));
            }
        }
        messages
    }

    /// Keeps only the most recent `keep` turns and renumbers them from 0.
    /// A no-op when there are `keep` or fewer turns.
    pub fn truncate_turns(&mut self, keep: usize) {
        if self.turns.len() > keep {
            let drain_count = self.turns.len() - keep;
            self.turns.drain(0..drain_count);
            for (i, turn) in self.turns.iter_mut().enumerate() {
                turn.turn_number = i;
            }
        }
    }

    /// Replaces all turns by replaying a flat message history and resets
    /// `state` to `Idle`.
    ///
    /// Replay rules:
    /// - Each `User` message starts a new turn.
    /// - Assistant messages carrying tool calls that follow it (possibly
    ///   several rounds) append their calls to that turn.
    /// - The `Tool` messages directly after each round fill in that round's
    ///   results *by position*. Surplus tool messages are consumed and dropped.
    /// - A following assistant message without tool calls completes the turn.
    /// - Any other message outside a turn is skipped.
    ///
    /// Tool errors cannot be told apart here and are stored as string results.
    /// A turn with no closing assistant message keeps `response == None` and
    /// stays in `TurnState::Processing`. `pending_approval`, `pending_auth` and
    /// `pending_messages` are left untouched.
    pub fn restore_from_messages(&mut self, messages: Vec<ChatMessage>) {
        self.turns.clear();
        self.state = ThreadState::Idle;
        let mut iter = messages.into_iter().peekable();
        let mut turn_number = 0;
        while let Some(msg) = iter.next() {
            if msg.role == crate::llm::Role::User {
                let mut turn = Turn::new(turn_number, &msg.content);
                while let Some(next) = iter.peek() {
                    if next.role == crate::llm::Role::Assistant && next.tool_calls.is_some() {
                        // Results of this round start after calls recorded by earlier rounds.
                        let call_base_idx = turn.tool_calls.len();
                        if let Some(assistant_msg) = iter.next()
                            && let Some(ref tcs) = assistant_msg.tool_calls
                        {
                            for tc in tcs {
                                turn.record_tool_call_with_reasoning(
                                    &tc.name,
                                    tc.arguments.clone(),
                                    tc.reasoning.clone(),
                                    Some(tc.id.clone()),
                                );
                            }
                        }
                        let mut pos = 0;
                        while let Some(tr) = iter.peek() {
                            if tr.role != crate::llm::Role::Tool {
                                break;
                            }
                            if let Some(tool_msg) = iter.next() {
                                let idx = call_base_idx + pos;
                                if idx < turn.tool_calls.len() {
                                    turn.tool_calls[idx].result =
                                        Some(serde_json::Value::String(tool_msg.content.clone()));
                                }
                            }
                            pos += 1;
                        }
                    } else {
                        break;
                    }
                }
                let is_final_assistant = iter.peek().is_some_and(|n| {
                    n.role == crate::llm::Role::Assistant && n.tool_calls.is_none()
                });
                if is_final_assistant && let Some(response) = iter.next() {
                    turn.complete(&response.content);
                }
                self.turns.push(turn);
                turn_number += 1;
            } else {
                continue;
            }
        }
        self.updated_at = Utc::now();
    }
}
/// Lifecycle state of a single [`Turn`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TurnState {
/// Initial state (set by `Turn::new`).
Processing,
/// Finished with a response (set by `Turn::complete`).
Completed,
/// Finished with an error (set by `Turn::fail`).
Failed,
/// Stopped early (set by `Turn::interrupt`).
Interrupted,
}
/// One user request and everything that happened while handling it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Turn {
/// 0-based position within the thread (renumbered by `Thread::truncate_turns`).
pub turn_number: usize,
/// The user's message text.
pub user_input: String,
/// Final assistant response, set by `complete`.
pub response: Option<String>,
/// Tool calls made during this turn, in call order.
pub tool_calls: Vec<TurnToolCall>,
/// Current lifecycle state.
pub state: TurnState,
/// When the turn was created.
pub started_at: DateTime<Utc>,
/// When the turn completed, failed, or was interrupted.
pub completed_at: Option<DateTime<Utc>>,
/// Error text, set by `fail`.
pub error: Option<String>,
/// Optional narrative text; omitted from serialized output when `None`.
/// NOTE(review): not set by any method in this chunk — confirm its producer.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub narrative: Option<String>,
/// Image attachments for the user message. Never serialized; cleared when
/// the turn completes, fails, or is interrupted.
#[serde(skip)]
pub image_content_parts: Vec<crate::llm::ContentPart>,
}
impl Turn {
pub fn new(turn_number: usize, user_input: impl Into<String>) -> Self {
Self {
turn_number,
user_input: user_input.into(),
response: None,
tool_calls: Vec::new(),
state: TurnState::Processing,
started_at: Utc::now(),
completed_at: None,
error: None,
narrative: None,
image_content_parts: Vec::new(),
}
}
pub fn complete(&mut self, response: impl Into<String>) {
self.response = Some(response.into());
self.state = TurnState::Completed;
self.completed_at = Some(Utc::now());
self.image_content_parts.clear();
}
pub fn fail(&mut self, error: impl Into<String>) {
self.error = Some(error.into());
self.state = TurnState::Failed;
self.completed_at = Some(Utc::now());
self.image_content_parts.clear();
}
pub fn interrupt(&mut self) {
self.state = TurnState::Interrupted;
self.completed_at = Some(Utc::now());
self.image_content_parts.clear();
}
pub fn record_tool_call(&mut self, name: impl Into<String>, params: serde_json::Value) {
self.tool_calls.push(TurnToolCall {
name: name.into(),
parameters: params,
result: None,
error: None,
rationale: None,
tool_call_id: None,
});
}
pub fn record_tool_call_with_reasoning(
&mut self,
name: impl Into<String>,
params: serde_json::Value,
rationale: Option<String>,
tool_call_id: Option<String>,
) {
self.tool_calls.push(TurnToolCall {
name: name.into(),
parameters: params,
result: None,
error: None,
rationale,
tool_call_id,
});
}
pub fn record_tool_result(&mut self, result: serde_json::Value) {
if let Some(call) = self.tool_calls.last_mut() {
call.result = Some(result);
}
}
pub fn record_tool_error(&mut self, error: impl Into<String>) {
if let Some(call) = self.tool_calls.last_mut() {
call.error = Some(error.into());
}
}
pub fn record_tool_result_for(&mut self, tool_call_id: &str, result: serde_json::Value) {
if let Some(call) = self
.tool_calls
.iter_mut()
.find(|c| c.tool_call_id.as_deref() == Some(tool_call_id))
{
call.result = Some(result);
} else if let Some(call) = self
.tool_calls
.iter_mut()
.find(|c| c.result.is_none() && c.error.is_none())
{
tracing::debug!(
tool_call_id = %tool_call_id,
fallback_tool = %call.name,
"tool_call_id not found, falling back to first pending call"
);
call.result = Some(result);
} else {
tracing::warn!(
tool_call_id = %tool_call_id,
"Tool result dropped: no matching or pending tool call"
);
}
}
pub fn record_tool_error_for(&mut self, tool_call_id: &str, error: impl Into<String>) {
if let Some(call) = self
.tool_calls
.iter_mut()
.find(|c| c.tool_call_id.as_deref() == Some(tool_call_id))
{
call.error = Some(error.into());
} else if let Some(call) = self
.tool_calls
.iter_mut()
.find(|c| c.result.is_none() && c.error.is_none())
{
tracing::debug!(
tool_call_id = %tool_call_id,
fallback_tool = %call.name,
"tool_call_id not found, falling back to first pending call"
);
call.error = Some(error.into());
} else {
tracing::warn!(
tool_call_id = %tool_call_id,
"Tool error dropped: no matching or pending tool call"
);
}
}
}
/// One tool invocation recorded within a [`Turn`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TurnToolCall {
/// Name of the invoked tool.
pub name: String,
/// Arguments the tool was called with.
pub parameters: serde_json::Value,
/// Tool output once available. `Thread::restore_from_messages` stores tool
/// message text here as a JSON string, including error text.
pub result: Option<serde_json::Value>,
/// Error text if the call failed.
pub error: Option<String>,
/// Optional rationale attached to the call (restore fills it from
/// `ToolCall::reasoning`); omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub rationale: Option<String>,
/// Id used to match later results/errors to this call
/// (see `Turn::record_tool_result_for`); omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_session_creation() {
let mut session = Session::new("user-123");
assert!(session.active_thread.is_none());
session.create_thread();
assert!(session.active_thread.is_some());
}
#[test]
fn test_thread_turns() {
let mut thread = Thread::new(Uuid::new_v4());
thread.start_turn("Hello");
assert_eq!(thread.state, ThreadState::Processing);
assert_eq!(thread.turns.len(), 1);
thread.complete_turn("Hi there!");
assert_eq!(thread.state, ThreadState::Idle);
assert_eq!(thread.turns[0].response, Some("Hi there!".to_string()));
}
#[test]
fn test_thread_messages() {
let mut thread = Thread::new(Uuid::new_v4());
thread.start_turn("First message");
thread.complete_turn("First response");
thread.start_turn("Second message");
thread.complete_turn("Second response");
let messages = thread.messages();
assert_eq!(messages.len(), 4);
}
#[test]
fn test_turn_tool_calls() {
let mut turn = Turn::new(0, "Test input");
turn.record_tool_call("echo", serde_json::json!({"message": "test"}));
turn.record_tool_result(serde_json::json!("test"));
assert_eq!(turn.tool_calls.len(), 1);
assert!(turn.tool_calls[0].result.is_some());
}
#[test]
fn test_restore_from_messages() {
let mut thread = Thread::new(Uuid::new_v4());
thread.start_turn("Original message");
thread.complete_turn("Original response");
let messages = vec![
ChatMessage::user("Hello"),
ChatMessage::assistant("Hi there!"),
ChatMessage::user("How are you?"),
ChatMessage::assistant("I'm good!"),
];
thread.restore_from_messages(messages);
assert_eq!(thread.turns.len(), 2);
assert_eq!(thread.turns[0].user_input, "Hello");
assert_eq!(thread.turns[0].response, Some("Hi there!".to_string()));
assert_eq!(thread.turns[1].user_input, "How are you?");
assert_eq!(thread.turns[1].response, Some("I'm good!".to_string()));
assert_eq!(thread.state, ThreadState::Idle);
}
#[test]
fn test_restore_from_messages_incomplete_turn() {
let mut thread = Thread::new(Uuid::new_v4());
let messages = vec![
ChatMessage::user("Hello"),
ChatMessage::assistant("Hi there!"),
ChatMessage::user("How are you?"),
];
thread.restore_from_messages(messages);
assert_eq!(thread.turns.len(), 2);
assert_eq!(thread.turns[1].user_input, "How are you?");
assert!(thread.turns[1].response.is_none());
}
#[test]
fn test_enter_auth_mode() {
let before = Utc::now();
let mut thread = Thread::new(Uuid::new_v4());
assert!(thread.pending_auth.is_none());
thread.enter_auth_mode("telegram".to_string());
assert!(thread.pending_auth.is_some());
let pending = thread.pending_auth.as_ref().unwrap();
assert_eq!(pending.extension_name, "telegram");
assert!(pending.created_at >= before);
assert!(!pending.is_expired());
}
#[test]
fn test_take_pending_auth() {
let mut thread = Thread::new(Uuid::new_v4());
thread.enter_auth_mode("notion".to_string());
let pending = thread.take_pending_auth();
assert!(pending.is_some());
let pending = pending.unwrap();
assert_eq!(pending.extension_name, "notion");
assert!(!pending.is_expired());
assert!(thread.pending_auth.is_none());
assert!(thread.take_pending_auth().is_none());
}
#[test]
fn test_pending_auth_serialization() {
let mut thread = Thread::new(Uuid::new_v4());
thread.enter_auth_mode("openai".to_string());
let json = serde_json::to_string(&thread).expect("should serialize");
assert!(json.contains("pending_auth"));
assert!(json.contains("openai"));
assert!(json.contains("created_at"));
let restored: Thread = serde_json::from_str(&json).expect("should deserialize");
assert!(restored.pending_auth.is_some());
let pending = restored.pending_auth.unwrap();
assert_eq!(pending.extension_name, "openai");
assert!(!pending.is_expired());
}
#[test]
fn test_pending_auth_expiry() {
let mut pending = PendingAuth {
extension_name: "test".to_string(),
created_at: Utc::now(),
};
assert!(!pending.is_expired());
pending.created_at = Utc::now() - AUTH_MODE_TTL - TimeDelta::seconds(1);
assert!(pending.is_expired());
}
#[test]
fn test_pending_auth_default_none() {
let mut thread = Thread::new(Uuid::new_v4());
thread.pending_auth = None;
let json = serde_json::to_string(&thread).expect("serialize");
let json = json.replace(",\"pending_auth\":null", "");
let restored: Thread = serde_json::from_str(&json).expect("should deserialize");
assert!(restored.pending_auth.is_none());
}
#[test]
fn test_thread_with_id() {
let specific_id = Uuid::new_v4();
let session_id = Uuid::new_v4();
let thread = Thread::with_id(specific_id, session_id);
assert_eq!(thread.id, specific_id);
assert_eq!(thread.session_id, session_id);
assert_eq!(thread.state, ThreadState::Idle);
assert!(thread.turns.is_empty());
}
#[test]
fn test_thread_with_id_restore_messages() {
let thread_id = Uuid::new_v4();
let session_id = Uuid::new_v4();
let mut thread = Thread::with_id(thread_id, session_id);
let messages = vec![
ChatMessage::user("Hello from DB"),
ChatMessage::assistant("Restored response"),
];
thread.restore_from_messages(messages);
assert_eq!(thread.id, thread_id);
assert_eq!(thread.turns.len(), 1);
assert_eq!(thread.turns[0].user_input, "Hello from DB");
assert_eq!(
thread.turns[0].response,
Some("Restored response".to_string())
);
}
#[test]
fn test_restore_from_messages_empty() {
let mut thread = Thread::new(Uuid::new_v4());
thread.start_turn("hello");
thread.complete_turn("hi");
assert_eq!(thread.turns.len(), 1);
thread.restore_from_messages(Vec::new());
assert!(thread.turns.is_empty());
assert_eq!(thread.state, ThreadState::Idle);
}
#[test]
fn test_restore_from_messages_only_assistant_messages() {
let mut thread = Thread::new(Uuid::new_v4());
let messages = vec![
ChatMessage::assistant("I'm here"),
ChatMessage::assistant("Still here"),
];
thread.restore_from_messages(messages);
assert!(thread.turns.is_empty());
}
#[test]
fn test_restore_from_messages_multiple_user_messages_in_a_row() {
let mut thread = Thread::new(Uuid::new_v4());
let messages = vec![
ChatMessage::user("first"),
ChatMessage::user("second"),
ChatMessage::assistant("reply to second"),
];
thread.restore_from_messages(messages);
assert_eq!(thread.turns.len(), 2);
assert_eq!(thread.turns[0].user_input, "first");
assert!(thread.turns[0].response.is_none());
assert_eq!(thread.turns[1].user_input, "second");
assert_eq!(
thread.turns[1].response,
Some("reply to second".to_string())
);
}
#[test]
fn test_thread_switch() {
let mut session = Session::new("user-1");
let t1_id = session.create_thread().id;
let t2_id = session.create_thread().id;
assert_eq!(session.active_thread, Some(t2_id));
assert!(session.switch_thread(t1_id));
assert_eq!(session.active_thread, Some(t1_id));
let fake_id = Uuid::new_v4();
assert!(!session.switch_thread(fake_id));
assert_eq!(session.active_thread, Some(t1_id));
}
#[test]
fn test_get_or_create_thread_idempotent() {
let mut session = Session::new("user-1");
let tid1 = session.get_or_create_thread().id;
let tid2 = session.get_or_create_thread().id;
assert_eq!(tid1, tid2);
assert_eq!(session.threads.len(), 1);
}
#[test]
fn test_truncate_turns() {
let mut thread = Thread::new(Uuid::new_v4());
for i in 0..5 {
thread.start_turn(format!("msg-{}", i));
thread.complete_turn(format!("resp-{}", i));
}
assert_eq!(thread.turns.len(), 5);
thread.truncate_turns(3);
assert_eq!(thread.turns.len(), 3);
assert_eq!(thread.turns[0].user_input, "msg-2");
assert_eq!(thread.turns[1].user_input, "msg-3");
assert_eq!(thread.turns[2].user_input, "msg-4");
assert_eq!(thread.turns[0].turn_number, 0);
assert_eq!(thread.turns[1].turn_number, 1);
assert_eq!(thread.turns[2].turn_number, 2);
}
#[test]
fn test_truncate_turns_noop_when_fewer() {
let mut thread = Thread::new(Uuid::new_v4());
thread.start_turn("only one");
thread.complete_turn("response");
thread.truncate_turns(10);
assert_eq!(thread.turns.len(), 1);
assert_eq!(thread.turns[0].user_input, "only one");
}
#[test]
fn test_thread_interrupt_and_resume() {
let mut thread = Thread::new(Uuid::new_v4());
thread.start_turn("do something");
assert_eq!(thread.state, ThreadState::Processing);
thread.interrupt();
assert_eq!(thread.state, ThreadState::Interrupted);
let last_turn = thread.last_turn().unwrap();
assert_eq!(last_turn.state, TurnState::Interrupted);
assert!(last_turn.completed_at.is_some());
thread.resume();
assert_eq!(thread.state, ThreadState::Idle);
}
#[test]
fn test_resume_only_from_interrupted() {
let mut thread = Thread::new(Uuid::new_v4());
assert_eq!(thread.state, ThreadState::Idle);
thread.resume();
assert_eq!(thread.state, ThreadState::Idle);
thread.start_turn("work");
assert_eq!(thread.state, ThreadState::Processing);
thread.resume();
assert_eq!(thread.state, ThreadState::Processing);
}
#[test]
fn test_turn_fail() {
let mut thread = Thread::new(Uuid::new_v4());
thread.start_turn("risky operation");
thread.fail_turn("connection timed out");
assert_eq!(thread.state, ThreadState::Idle);
let turn = thread.last_turn().unwrap();
assert_eq!(turn.state, TurnState::Failed);
assert_eq!(turn.error, Some("connection timed out".to_string()));
assert!(turn.response.is_none());
assert!(turn.completed_at.is_some());
}
#[test]
fn test_messages_with_incomplete_last_turn() {
let mut thread = Thread::new(Uuid::new_v4());
thread.start_turn("first");
thread.complete_turn("first reply");
thread.start_turn("second (in progress)");
let messages = thread.messages();
assert_eq!(messages.len(), 3);
assert_eq!(messages[0].content, "first");
assert_eq!(messages[1].content, "first reply");
assert_eq!(messages[2].content, "second (in progress)");
}
#[test]
fn test_thread_serialization_round_trip() {
let mut thread = Thread::new(Uuid::new_v4());
thread.start_turn("hello");
thread.complete_turn("world");
let json = serde_json::to_string(&thread).unwrap();
let restored: Thread = serde_json::from_str(&json).unwrap();
assert_eq!(restored.id, thread.id);
assert_eq!(restored.session_id, thread.session_id);
assert_eq!(restored.turns.len(), 1);
assert_eq!(restored.turns[0].user_input, "hello");
assert_eq!(restored.turns[0].response, Some("world".to_string()));
}
#[test]
fn test_session_serialization_round_trip() {
let mut session = Session::new("user-ser");
session.create_thread();
session.auto_approve_tool("echo");
let json = serde_json::to_string(&session).unwrap();
let restored: Session = serde_json::from_str(&json).unwrap();
assert_eq!(restored.user_id, "user-ser");
assert_eq!(restored.threads.len(), 1);
assert!(restored.is_tool_auto_approved("echo"));
assert!(!restored.is_tool_auto_approved("shell"));
}
#[test]
fn test_auto_approved_tools() {
let mut session = Session::new("user-1");
assert!(!session.is_tool_auto_approved("shell"));
session.auto_approve_tool("shell");
assert!(session.is_tool_auto_approved("shell"));
session.auto_approve_tool("shell");
assert_eq!(session.auto_approved_tools.len(), 1);
}
#[test]
fn test_turn_tool_call_error() {
let mut turn = Turn::new(0, "test");
turn.record_tool_call("http", serde_json::json!({"url": "example.com"}));
turn.record_tool_error("timeout");
assert_eq!(turn.tool_calls.len(), 1);
assert_eq!(turn.tool_calls[0].error, Some("timeout".to_string()));
assert!(turn.tool_calls[0].result.is_none());
}
#[test]
fn test_turn_number_increments() {
let mut thread = Thread::new(Uuid::new_v4());
assert_eq!(thread.turn_number(), 1);
thread.start_turn("first");
thread.complete_turn("done");
assert_eq!(thread.turn_number(), 2);
thread.start_turn("second");
assert_eq!(thread.turn_number(), 3);
}
#[test]
fn test_complete_turn_on_empty_thread() {
let mut thread = Thread::new(Uuid::new_v4());
thread.complete_turn("phantom response");
assert_eq!(thread.state, ThreadState::Idle);
assert!(thread.turns.is_empty());
}
#[test]
fn test_fail_turn_on_empty_thread() {
let mut thread = Thread::new(Uuid::new_v4());
thread.fail_turn("phantom error");
assert_eq!(thread.state, ThreadState::Idle);
assert!(thread.turns.is_empty());
}
#[test]
fn test_pending_approval_flow() {
let mut thread = Thread::new(Uuid::new_v4());
let approval = PendingApproval {
request_id: Uuid::new_v4(),
tool_name: "shell".to_string(),
parameters: serde_json::json!({"command": "rm -rf /"}),
display_parameters: serde_json::json!({"command": "rm -rf /"}),
description: "dangerous command".to_string(),
tool_call_id: "call_123".to_string(),
context_messages: vec![ChatMessage::user("do it")],
deferred_tool_calls: vec![],
user_timezone: None,
allow_always: false,
};
thread.await_approval(approval);
assert_eq!(thread.state, ThreadState::AwaitingApproval);
assert!(thread.pending_approval.is_some());
let taken = thread.take_pending_approval();
assert!(taken.is_some());
assert_eq!(taken.unwrap().tool_name, "shell");
assert!(thread.pending_approval.is_none());
}
#[test]
fn test_clear_pending_approval() {
let mut thread = Thread::new(Uuid::new_v4());
let approval = PendingApproval {
request_id: Uuid::new_v4(),
tool_name: "http".to_string(),
parameters: serde_json::json!({}),
display_parameters: serde_json::json!({}),
description: "test".to_string(),
tool_call_id: "call_456".to_string(),
context_messages: vec![],
deferred_tool_calls: vec![],
user_timezone: None,
allow_always: true,
};
thread.await_approval(approval);
thread.clear_pending_approval();
assert_eq!(thread.state, ThreadState::Idle);
assert!(thread.pending_approval.is_none());
}
#[test]
fn test_active_thread_accessors() {
let mut session = Session::new("user-1");
assert!(session.active_thread().is_none());
assert!(session.active_thread_mut().is_none());
let tid = session.create_thread().id;
assert!(session.active_thread().is_some());
assert_eq!(session.active_thread().unwrap().id, tid);
session.active_thread_mut().unwrap().start_turn("test");
assert_eq!(
session.active_thread().unwrap().state,
ThreadState::Processing
);
}
#[test]
fn test_messages_includes_tool_calls() {
let mut thread = Thread::new(Uuid::new_v4());
thread.start_turn("Search for X");
{
let turn = thread.turns.last_mut().unwrap();
turn.record_tool_call("memory_search", serde_json::json!({"query": "X"}));
turn.record_tool_result(serde_json::json!("Found X in doc.md"));
}
thread.complete_turn("I found X in doc.md.");
let messages = thread.messages();
assert_eq!(messages.len(), 4);
assert_eq!(messages[0].role, crate::llm::Role::User);
assert_eq!(messages[0].content, "Search for X");
assert_eq!(messages[1].role, crate::llm::Role::Assistant);
assert!(messages[1].tool_calls.is_some());
let tcs = messages[1].tool_calls.as_ref().unwrap();
assert_eq!(tcs.len(), 1);
assert_eq!(tcs[0].name, "memory_search");
assert_eq!(messages[2].role, crate::llm::Role::Tool);
assert!(messages[2].content.contains("Found X"));
assert_eq!(messages[3].role, crate::llm::Role::Assistant);
assert_eq!(messages[3].content, "I found X in doc.md.");
}
#[test]
fn test_messages_multiple_tool_calls_per_turn() {
let mut thread = Thread::new(Uuid::new_v4());
thread.start_turn("Do two things");
{
let turn = thread.turns.last_mut().unwrap();
turn.record_tool_call("echo", serde_json::json!({"msg": "a"}));
turn.record_tool_result(serde_json::json!("a"));
turn.record_tool_call("time", serde_json::json!({}));
turn.record_tool_error("timeout");
}
thread.complete_turn("Done.");
let messages = thread.messages();
assert_eq!(messages.len(), 5);
let tcs = messages[1].tool_calls.as_ref().unwrap();
assert_eq!(tcs.len(), 2);
assert_eq!(messages[2].content, "a");
assert!(messages[3].content.contains("timeout"));
}
#[test]
fn test_restore_from_messages_with_tool_calls() {
let mut thread = Thread::new(Uuid::new_v4());
let tc = ToolCall {
id: "call_0".to_string(),
name: "search".to_string(),
arguments: serde_json::json!({"q": "test"}),
reasoning: None,
};
let messages = vec![
ChatMessage::user("Find test"),
ChatMessage::assistant_with_tool_calls(None, vec![tc]),
ChatMessage::tool_result("call_0", "search", "result: found"),
ChatMessage::assistant("Found it."),
];
thread.restore_from_messages(messages);
assert_eq!(thread.turns.len(), 1);
let turn = &thread.turns[0];
assert_eq!(turn.user_input, "Find test");
assert_eq!(turn.tool_calls.len(), 1);
assert_eq!(turn.tool_calls[0].name, "search");
assert_eq!(
turn.tool_calls[0].result,
Some(serde_json::Value::String("result: found".to_string()))
);
assert_eq!(turn.response, Some("Found it.".to_string()));
}
#[test]
fn test_restore_from_messages_with_tool_error() {
let mut thread = Thread::new(Uuid::new_v4());
let tc = ToolCall {
id: "call_0".to_string(),
name: "http".to_string(),
arguments: serde_json::json!({}),
reasoning: None,
};
let messages = vec![
ChatMessage::user("Fetch URL"),
ChatMessage::assistant_with_tool_calls(None, vec![tc]),
ChatMessage::tool_result("call_0", "http", "Error: timeout"),
ChatMessage::assistant("The request timed out."),
];
thread.restore_from_messages(messages);
let turn = &thread.turns[0];
assert_eq!(
turn.tool_calls[0].result,
Some(serde_json::Value::String("Error: timeout".to_string()))
);
}
#[test]
fn test_messages_round_trip_with_tools() {
let mut thread = Thread::new(Uuid::new_v4());
thread.start_turn("Do search");
{
let turn = thread.turns.last_mut().unwrap();
turn.record_tool_call("search", serde_json::json!({"q": "test"}));
turn.record_tool_result(serde_json::json!("found"));
}
thread.complete_turn("Here are results.");
let messages_original = thread.messages();
let mut thread2 = Thread::new(Uuid::new_v4());
thread2.restore_from_messages(messages_original.clone());
let messages_restored = thread2.messages();
assert_eq!(messages_original.len(), messages_restored.len());
for (orig, rest) in messages_original.iter().zip(messages_restored.iter()) {
assert_eq!(orig.role, rest.role);
}
assert_eq!(
messages_original.last().unwrap().content,
messages_restored.last().unwrap().content
);
}
#[test]
fn test_restore_multi_stage_tool_calls() {
    // Builds a reasoning-free tool call from its id, name and arguments.
    let call = |id: &str, name: &str, arguments: serde_json::Value| ToolCall {
        id: id.to_string(),
        name: name.to_string(),
        arguments,
        reasoning: None,
    };
    let search_call = call("call_a", "search", serde_json::json!({"q": "data"}));
    let write_call = call("call_b", "write", serde_json::json!({"path": "out.txt"}));

    // One user request answered through two sequential tool rounds, then a final reply.
    let history = vec![
        ChatMessage::user("Find and save"),
        ChatMessage::assistant_with_tool_calls(None, vec![search_call]),
        ChatMessage::tool_result("call_a", "search", "found data"),
        ChatMessage::assistant_with_tool_calls(None, vec![write_call]),
        ChatMessage::tool_result("call_b", "write", "written"),
        ChatMessage::assistant("Done, saved to out.txt"),
    ];

    let mut thread = Thread::new(Uuid::new_v4());
    thread.restore_from_messages(history);

    // Both rounds collapse into a single turn, with calls in order and results attached.
    assert_eq!(thread.turns.len(), 1);
    let turn = &thread.turns[0];
    assert_eq!(turn.tool_calls.len(), 2);

    let expected = [("search", "found data"), ("write", "written")];
    for (recorded, &(name, _)) in turn.tool_calls.iter().zip(expected.iter()) {
        assert_eq!(recorded.name, name);
    }
    for (recorded, &(_, output)) in turn.tool_calls.iter().zip(expected.iter()) {
        assert_eq!(
            recorded.result,
            Some(serde_json::Value::String(output.to_string()))
        );
    }
    assert_eq!(turn.response, Some("Done, saved to out.txt".to_string()));
}
#[test]
fn test_messages_truncates_large_tool_results() {
    let mut thread = Thread::new(Uuid::new_v4());
    thread.start_turn("Read big file");

    // A 2000-character tool output, larger than the 1010-byte bound asserted below.
    let payload = "x".repeat(2000);
    let turn = thread.turns.last_mut().unwrap();
    turn.record_tool_call("read_file", serde_json::json!({"path": "big.txt"}));
    turn.record_tool_result(serde_json::json!(payload));
    thread.complete_turn("Here's the file content.");

    // The test expects the tool-result message at index 2 (after the user message
    // and the assistant tool-call message).
    let rendered = thread.messages();
    let tool_text = &rendered[2].content;
    let size = tool_text.len();
    assert!(
        size <= 1010,
        "Tool result should be truncated, got {} chars",
        size
    );
    assert!(tool_text.ends_with("..."));
}
#[test]
fn test_thread_message_queue() {
    let mut thread = Thread::new(Uuid::new_v4());

    // A fresh thread has nothing queued and nothing to take.
    assert!(thread.pending_messages.is_empty());
    assert!(thread.take_pending_message().is_none());

    // Messages come back out in the order they were queued.
    let small_batch = ["first", "second", "third"];
    for text in small_batch.iter() {
        assert!(thread.queue_message(text.to_string()));
    }
    assert_eq!(thread.pending_messages.len(), 3);
    for text in small_batch.iter() {
        assert_eq!(thread.take_pending_message(), Some(text.to_string()));
    }
    assert!(thread.take_pending_message().is_none());

    // Filling to capacity succeeds; one more is rejected and the queue does not grow.
    let numbered = |n: usize| format!("msg-{}", n);
    for n in 0..MAX_PENDING_MESSAGES {
        assert!(thread.queue_message(numbered(n)));
    }
    assert_eq!(thread.pending_messages.len(), MAX_PENDING_MESSAGES);
    assert!(!thread.queue_message("overflow".to_string()));
    assert_eq!(thread.pending_messages.len(), MAX_PENDING_MESSAGES);

    // Draining yields exactly the accepted messages, in FIFO order.
    for n in 0..MAX_PENDING_MESSAGES {
        assert_eq!(thread.take_pending_message(), Some(numbered(n)));
    }
    assert!(thread.take_pending_message().is_none());
}
#[test]
fn test_thread_message_queue_serialization() {
    let mut thread = Thread::new(Uuid::new_v4());

    // With nothing queued, the field is absent from the serialized form.
    let json = serde_json::to_string(&thread).unwrap();
    assert!(!json.contains("pending_messages"));

    // Assert the enqueue was accepted so a rejection is reported at this call,
    // not later as a confusing serialization mismatch.
    assert!(thread.queue_message("queued msg".to_string()));
    let json = serde_json::to_string(&thread).unwrap();
    assert!(json.contains("pending_messages"));
    assert!(json.contains("queued msg"));

    // The queued message survives a serialize/deserialize round trip.
    let restored: Thread = serde_json::from_str(&json).unwrap();
    assert_eq!(restored.pending_messages.len(), 1);
    assert_eq!(restored.pending_messages[0], "queued msg");
}
#[test]
fn test_thread_message_queue_default_on_old_data() {
    // Stands in for data written before `pending_messages` existed: a fresh thread
    // serializes without that key and must deserialize back to an empty queue.
    let serialized = serde_json::to_string(&Thread::new(Uuid::new_v4())).unwrap();
    assert!(!serialized.contains("pending_messages"));

    let decoded = serde_json::from_str::<Thread>(&serialized).unwrap();
    assert!(decoded.pending_messages.is_empty());
}
#[test]
fn test_interrupt_clears_pending_messages() {
    let mut thread = Thread::new(Uuid::new_v4());
    thread.start_turn("initial input");

    // Queue follow-ups while the turn is in flight. Each enqueue is asserted so a
    // rejection is reported at the call that failed.
    assert!(thread.queue_message("queued-1".to_string()));
    assert!(thread.queue_message("queued-2".to_string()));
    assert!(thread.queue_message("queued-3".to_string()));
    assert_eq!(thread.pending_messages.len(), 3);

    // Interrupting discards the queue and moves the thread to Interrupted.
    thread.interrupt();
    assert!(thread.pending_messages.is_empty());
    assert_eq!(thread.state, ThreadState::Interrupted);
}
#[test]
fn test_thread_state_idle_after_full_drain() {
    let mut thread = Thread::new(Uuid::new_v4());
    thread.start_turn("turn 1");
    assert_eq!(thread.state, ThreadState::Processing);

    // Messages arriving mid-turn are queued; assert each enqueue was accepted.
    assert!(thread.queue_message("queued-a".to_string()));
    assert!(thread.queue_message("queued-b".to_string()));

    // Completing the turn returns the thread to Idle while the queue is still held.
    thread.complete_turn("response 1");
    assert_eq!(thread.state, ThreadState::Idle);

    // The queued messages are merged (newline-separated) into the next turn's input.
    let merged = thread.drain_pending_messages().unwrap();
    assert_eq!(merged, "queued-a\nqueued-b");
    thread.start_turn(&merged);
    thread.complete_turn("response for merged");

    // Once everything has been drained and processed, the thread rests in Idle.
    assert!(thread.drain_pending_messages().is_none());
    assert!(thread.pending_messages.is_empty());
    assert_eq!(thread.state, ThreadState::Idle);
}
#[test]
fn test_drain_pending_messages_merges_with_newlines() {
    let mut thread = Thread::new(Uuid::new_v4());

    // Draining an empty queue yields nothing.
    assert!(thread.drain_pending_messages().is_none());

    // A single queued message is returned unchanged (no separator added).
    assert!(thread.queue_message("only one".to_string()));
    assert_eq!(
        thread.drain_pending_messages(),
        Some("only one".to_string()),
    );
    assert!(thread.pending_messages.is_empty());

    // Several queued messages are joined with '\n' in queue order.
    assert!(thread.queue_message("hey".to_string()));
    assert!(thread.queue_message("can you check the server".to_string()));
    assert!(thread.queue_message("it started 10 min ago".to_string()));
    assert_eq!(
        thread.drain_pending_messages(),
        Some("hey\ncan you check the server\nit started 10 min ago".to_string()),
    );

    // Draining empties the queue, so a second drain yields nothing.
    assert!(thread.pending_messages.is_empty());
    assert!(thread.drain_pending_messages().is_none());
}
#[test]
fn test_requeue_drained_preserves_content_at_front() {
    let mut thread = Thread::new(Uuid::new_v4());

    // A previously drained batch that failed is put back on the queue.
    thread.requeue_drained("failed batch".to_string());
    assert_eq!(thread.pending_messages.len(), 1);
    assert_eq!(thread.pending_messages[0], "failed batch");

    // A newly queued message lands behind the requeued batch.
    assert!(thread.queue_message("new msg".to_string()));
    assert_eq!(thread.pending_messages.len(), 2);

    let merged = thread.drain_pending_messages().unwrap();
    assert_eq!(merged, "failed batch\nnew msg");
}
#[test]
fn test_record_tool_result_for_by_id() {
    let mut turn = Turn::new(0, "test");

    // Two pending calls, distinguished only by their ids.
    for &(tool, call_id) in [("tool_a", "id_a"), ("tool_b", "id_b")].iter() {
        turn.record_tool_call_with_reasoning(
            tool,
            serde_json::json!({}),
            None,
            Some(call_id.into()),
        );
    }

    // Targeting the second id fills only that call's result.
    turn.record_tool_result_for("id_b", serde_json::json!("result_b"));
    let (first, second) = (&turn.tool_calls[0], &turn.tool_calls[1]);
    assert!(first.result.is_none());
    assert_eq!(
        second.result.as_ref().unwrap(),
        &serde_json::json!("result_b")
    );
}
#[test]
fn test_record_tool_error_for_by_id() {
    let mut turn = Turn::new(0, "test");

    // Two pending calls, distinguished only by their ids.
    for &(tool, call_id) in [("tool_a", "id_a"), ("tool_b", "id_b")].iter() {
        turn.record_tool_call_with_reasoning(
            tool,
            serde_json::json!({}),
            None,
            Some(call_id.into()),
        );
    }

    // Targeting the first id sets only that call's error.
    turn.record_tool_error_for("id_a", "failed");
    let (first, second) = (&turn.tool_calls[0], &turn.tool_calls[1]);
    assert_eq!(first.error.as_deref(), Some("failed"));
    assert!(second.error.is_none());
}
#[test]
fn test_record_tool_result_for_fallback_to_pending() {
    let mut turn = Turn::new(0, "test");

    // Two calls with distinct ids.
    for &(tool, call_id) in [("tool_a", "id_a"), ("tool_b", "id_b")].iter() {
        turn.record_tool_call_with_reasoning(
            tool,
            serde_json::json!({}),
            None,
            Some(call_id.into()),
        );
    }

    // Resolve the first call up front so only the second is still pending.
    let done = serde_json::json!("done");
    turn.tool_calls[0].result = Some(done.clone());

    // An id that matches no call falls back to a still-pending call (here, only
    // tool_b); the already-resolved call must be left as it was.
    turn.record_tool_result_for("unknown_id", serde_json::json!("fallback"));
    assert_eq!(turn.tool_calls[0].result.as_ref().unwrap(), &done);
    assert_eq!(
        turn.tool_calls[1].result.as_ref().unwrap(),
        &serde_json::json!("fallback")
    );
}
#[test]
fn test_record_tool_result_for_no_pending_is_noop() {
    let mut turn = Turn::new(0, "test");
    turn.record_tool_call_with_reasoning(
        "tool_a",
        serde_json::json!({}),
        None,
        Some("id_a".into()),
    );

    // With the only call already resolved, an unmatched id has nowhere to land.
    let done = serde_json::json!("done");
    turn.tool_calls[0].result = Some(done.clone());
    turn.record_tool_result_for("unknown_id", serde_json::json!("lost"));

    // The existing result must not be overwritten.
    let kept = turn.tool_calls[0].result.as_ref().unwrap();
    assert_eq!(kept, &done);
}
}