pub mod message;
pub mod turn;
/// Message-count constant shared within the crate.
///
/// NOTE(review): not referenced anywhere in this file; presumably the number
/// of most recent messages kept verbatim when history is compressed or
/// windowed — confirm at the use sites.
pub(crate) const KEEP_MESSAGES: usize = 20;
use crate::tool::{ToolCall, ToolCallBuffer, ToolResult};
use message::{Message, MessageContent, Role};
use turn::{TurnStatus, TurnTracker};
/// Counters describing the context sent to a provider.
///
/// NOTE(review): this struct is not filled in anywhere in this file; the field
/// descriptions below are inferred from the names only — confirm against the
/// code that populates it.
#[derive(Debug, Clone, Default)]
pub struct ContextStats {
/// Presumably the token count of the system prompt.
pub system_tokens: usize,
/// Presumably the token count of the history actually sent.
pub sent_tokens: usize,
/// Presumably the token count of history left out of the request.
pub dropped_tokens: usize,
/// Presumably the number of messages in the conversation.
pub total_messages: usize,
}
/// Chat history plus the transient state of the response currently being
/// streamed.
#[derive(Debug)]
pub struct Conversation {
/// Committed history, oldest first. This is the only field `save` persists.
pub messages: Vec<Message>,
/// Assistant text accumulated by `push_delta`. It is committed by
/// `finalize_stream` (also called from `cancel_current_turn`) or the
/// `finalize_stream_with_tool_call*` methods, and discarded by
/// `clear_stream_buffer` and `cancel_current_turn_including_user`.
pub stream_buffer: Option<String>,
/// Partially received tool call; cleared whenever a turn is cancelled.
/// NOTE(review): it is filled outside this file — confirm where.
pub tool_call_buffer: Option<ToolCallBuffer>,
/// Turn records (start index, message count, status) over `messages`.
pub turn_tracker: TurnTracker,
/// Summaries of history removed by `apply_compression`, oldest first;
/// at most three are kept.
pub cold_summaries: Vec<String>,
}
impl Default for Conversation {
fn default() -> Self {
Self {
messages: Vec::new(),
stream_buffer: None,
tool_call_buffer: None,
turn_tracker: TurnTracker::new(),
cold_summaries: Vec::new(),
}
}
}
impl Conversation {
    /// Upper bound on `cold_summaries`; the oldest summary is evicted first.
    const MAX_COLD_SUMMARIES: usize = 3;

    /// Number of messages, starting at the raw window start, that
    /// `to_provider_messages_windowed` inspects for a user/system message to
    /// open the window on.
    const WINDOW_BOUNDARY_SEARCH: usize = 6;

    /// Creates an empty conversation; equivalent to [`Conversation::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Loads a conversation from the JSON message list stored at `path`.
    ///
    /// This never fails:
    /// - if the file cannot be read, an empty conversation is returned;
    /// - if it does not parse as a `Vec<Message>`, the file is renamed (best
    ///   effort) to `path` with its extension replaced by `json.bak`
    ///   (`history.json` → `history.json.bak`) so the unreadable data survives
    ///   the next [`Self::save`], and an empty conversation is returned.
    ///
    /// Only messages are persisted, so the turn tracker is rebuilt from them
    /// and every other field starts empty.
    pub fn load(path: &std::path::Path) -> Self {
        let Ok(data) = std::fs::read_to_string(path) else {
            return Self::default();
        };
        let messages = match serde_json::from_str::<Vec<Message>>(&data) {
            Ok(msgs) => msgs,
            Err(_) => {
                let backup = path.with_extension("json.bak");
                let _ = std::fs::rename(path, &backup);
                return Self::default();
            }
        };
        let turn_tracker = TurnTracker::rebuild(&messages);
        Self {
            messages,
            stream_buffer: None,
            tool_call_buffer: None,
            turn_tracker,
            cold_summaries: Vec::new(),
        }
    }

    /// Persists `messages` (and nothing else) as JSON at `path`.
    ///
    /// Best effort: the parent directory is created if needed, and the data is
    /// written to a sibling temp file (`path` with extension `json.tmp`). That
    /// file is then renamed over `path`, so the previous file is only replaced
    /// once the new data has been written in full. Errors are ignored. If the
    /// write or the rename fails, the temp file is removed rather than left
    /// lying around.
    pub fn save(&self, path: &std::path::Path) {
        if let Some(parent) = path.parent() {
            let _ = std::fs::create_dir_all(parent);
        }
        let Ok(data) = serde_json::to_string(&self.messages) else {
            return;
        };
        let temp_path = path.with_extension("json.tmp");
        let persisted = std::fs::write(&temp_path, &data).is_ok()
            && std::fs::rename(&temp_path, path).is_ok();
        if !persisted {
            // Don't leave a stale `*.json.tmp` next to the history file.
            let _ = std::fs::remove_file(&temp_path);
        }
    }

    /// Default on-disk location of the history: `history.json` inside the
    /// configuration directory.
    pub fn history_path() -> std::path::PathBuf {
        crate::config::Config::config_dir().join("history.json")
    }

    /// Appends a user message and starts a new turn at it.
    ///
    /// If the last message is already a plain-text user message, `content` is
    /// appended to it on a new line instead. In that case no new turn is
    /// started, so the history never holds two consecutive user messages.
    pub fn add_user_message(&mut self, content: &str) {
        if let Some(last) = self.messages.last_mut() {
            if matches!(last.role, Role::User) {
                if let MessageContent::Text(ref mut text) = last.content {
                    text.push('\n');
                    text.push_str(content);
                    return;
                }
            }
        }
        let idx = self.messages.len();
        self.messages.push(Message::new(Role::User, content));
        self.turn_tracker.on_user_message(idx);
    }

    /// Cancels the active turn but keeps the work it already produced.
    ///
    /// The steps, in order:
    /// 1. buffered streamed text is committed via [`Self::finalize_stream`];
    /// 2. a partially received tool call is discarded;
    /// 3. every tool call in the history that has no result gets a synthetic
    ///    failed `"(cancelled)"` result;
    /// 4. the last turn is marked completed with its message count updated.
    ///
    /// Does nothing when there is no active turn.
    pub fn cancel_current_turn(&mut self) {
        let Some(start_idx) = self.turn_tracker.active_turn().map(|turn| turn.start_idx) else {
            return;
        };
        self.finalize_stream();
        self.tool_call_buffer = None;
        self.backfill_cancelled_tool_results();
        let msg_count = self.messages.len() - start_idx;
        if let Some(current) = self.turn_tracker.turns.last_mut() {
            current.msg_count = msg_count;
            current.status = TurnStatus::Completed;
        }
    }

    /// Appends a failed `"(cancelled)"` tool result for every tool call in the
    /// history whose id has no result message yet. Results are appended at the
    /// end of the history, in the order the calls appear.
    ///
    /// NOTE(review): the purpose is presumably to keep every tool call
    /// answered, since provider APIs typically reject unanswered tool calls.
    fn backfill_cancelled_tool_results(&mut self) {
        let answered: std::collections::HashSet<String> = self
            .messages
            .iter()
            .filter_map(|msg| msg.tool_result_call_id())
            .map(|call_id| call_id.to_string())
            .collect();
        let missing: Vec<String> = self
            .messages
            .iter()
            .filter_map(|msg| match &msg.content {
                MessageContent::AssistantWithToolCalls { tool_calls, .. } => Some(tool_calls),
                _ => None,
            })
            .flatten()
            .filter(|call| !answered.contains(&call.id))
            .map(|call| call.id.clone())
            .collect();
        for call_id in missing {
            self.add_tool_result(ToolResult {
                call_id,
                output: "(cancelled)".into(),
                success: false,
            });
        }
    }

    /// Cancels the active turn and removes it completely, including the user
    /// message that opened it. Both buffers are cleared, the history is
    /// truncated back to the turn's start and the turn record is popped.
    /// Does nothing when there is no active turn.
    pub fn cancel_current_turn_including_user(&mut self) {
        let Some(start_idx) = self.turn_tracker.active_turn().map(|turn| turn.start_idx) else {
            return;
        };
        self.stream_buffer = None;
        self.tool_call_buffer = None;
        self.messages.truncate(start_idx);
        self.turn_tracker.turns.pop();
    }

    /// Appends a streamed text fragment to the stream buffer, creating the
    /// buffer on the first fragment.
    pub fn push_delta(&mut self, delta: &str) {
        match &mut self.stream_buffer {
            Some(buf) => buf.push_str(delta),
            None => self.stream_buffer = Some(delta.to_string()),
        }
    }

    /// Discards buffered streamed text without committing it.
    pub fn clear_stream_buffer(&mut self) {
        self.stream_buffer = None;
    }

    /// Commits the buffered streamed text as an assistant message.
    ///
    /// The buffer is always drained. The text is cleaned first (think and
    /// chat-template tokens, stray tool-call XML, leaked reasoning, duplicated
    /// trailing sections). If nothing usable is left, or the result looks
    /// corrupted, no message is added.
    pub fn finalize_stream(&mut self) {
        if let Some(content) = self.stream_buffer.take() {
            let Some(content) = clean_assistant_text(&content) else {
                return;
            };
            let idx = self.messages.len();
            self.messages.push(Message::new(Role::Assistant, content));
            self.turn_tracker.on_message_added(idx);
        }
    }

    /// Appends an assistant message carrying `tool_calls`, with optional text
    /// and reasoning content and no thinking blocks.
    pub fn add_assistant_tool_calls(
        &mut self,
        text: Option<&str>,
        tool_calls: Vec<ToolCall>,
        reasoning: Option<&str>,
    ) {
        self.add_assistant_tool_calls_with_thinking(text, tool_calls, reasoning, Vec::new());
    }

    /// Appends an assistant message carrying `tool_calls`, optional text,
    /// optional reasoning content and the given thinking blocks, and records
    /// it in the active turn.
    pub fn add_assistant_tool_calls_with_thinking(
        &mut self,
        text: Option<&str>,
        tool_calls: Vec<ToolCall>,
        reasoning: Option<&str>,
        thinking_blocks: Vec<crate::conversation::message::ThinkingBlock>,
    ) {
        let idx = self.messages.len();
        self.messages.push(Message {
            role: Role::Assistant,
            content: MessageContent::AssistantWithToolCalls {
                text: text.map(|s| s.to_string()),
                tool_calls,
                reasoning_content: reasoning.map(|s| s.to_string()),
                thinking_blocks,
            },
        });
        self.turn_tracker.on_message_added(idx);
    }

    /// Appends a tool result message and records it in the active turn.
    pub fn add_tool_result(&mut self, result: ToolResult) {
        let idx = self.messages.len();
        self.messages.push(Message {
            role: Role::Tool,
            content: MessageContent::ToolResult(result),
        });
        self.turn_tracker.on_message_added(idx);
    }

    /// Drains the stream buffer and commits it, cleaned, as the text of an
    /// assistant message carrying `tool_call`. The text is `None` when the
    /// buffer was empty or cleaned away entirely.
    pub fn finalize_stream_with_tool_call(&mut self, tool_call: ToolCall, reasoning: Option<&str>) {
        let text = self
            .stream_buffer
            .take()
            .and_then(|s| clean_assistant_text(&s));
        self.add_assistant_tool_calls(text.as_deref(), vec![tool_call], reasoning);
    }

    /// Like [`Self::finalize_stream_with_tool_call`] for several tool calls,
    /// with no thinking blocks.
    pub fn finalize_stream_with_tool_calls(
        &mut self,
        tool_calls: &[ToolCall],
        reasoning: Option<&str>,
    ) {
        self.finalize_stream_with_tool_calls_and_thinking(tool_calls, reasoning, Vec::new());
    }

    /// Drains the stream buffer and commits it, cleaned, as the text of an
    /// assistant message carrying a copy of `tool_calls`, the reasoning
    /// content and `thinking_blocks`.
    pub fn finalize_stream_with_tool_calls_and_thinking(
        &mut self,
        tool_calls: &[ToolCall],
        reasoning: Option<&str>,
        thinking_blocks: Vec<crate::conversation::message::ThinkingBlock>,
    ) {
        let text = self
            .stream_buffer
            .take()
            .and_then(|s| clean_assistant_text(&s));
        self.add_assistant_tool_calls_with_thinking(
            text.as_deref(),
            tool_calls.to_vec(),
            reasoning,
            thinking_blocks,
        );
    }

    /// Builds a provider request: `system_prompt` as a system message followed
    /// by a copy of the whole history.
    pub fn to_provider_messages(&self, system_prompt: &str) -> Vec<Message> {
        let mut msgs = Vec::with_capacity(self.messages.len() + 1);
        msgs.push(Message::new(Role::System, system_prompt));
        msgs.extend(self.messages.iter().cloned());
        msgs
    }

    /// Like [`Self::to_provider_messages`], but sends at most the last
    /// `window` history messages.
    ///
    /// The window start is adjusted so the request stays well-formed:
    /// 1. Leading tool results are skipped, because their tool call would lie
    ///    outside the window.
    /// 2. The start then moves forward to the first user/system message among
    ///    the next `WINDOW_BOUNDARY_SEARCH` messages, so the window opens on a
    ///    turn boundary.
    /// 3. If no such message exists, the start from step 1 is kept. The
    ///    history is never reduced to just the system prompt merely because the
    ///    tail lacks a user message.
    pub fn to_provider_messages_windowed(
        &self,
        system_prompt: &str,
        window: usize,
    ) -> Vec<Message> {
        let len = self.messages.len();
        let raw_start = len.saturating_sub(window);
        let window_start = self.messages[raw_start..]
            .iter()
            .position(|msg| {
                !matches!(
                    msg.content,
                    MessageContent::ToolResult(_) | MessageContent::ToolResultRef(_)
                )
            })
            .map_or(len, |offset| raw_start + offset);
        let search_end = window_start
            .saturating_add(Self::WINDOW_BOUNDARY_SEARCH)
            .min(len);
        // Fall back to `window_start` instead of letting the search run off the
        // end of the history, which would send no history at all.
        let start = self.messages[window_start..search_end]
            .iter()
            .position(|msg| matches!(msg.role, Role::User | Role::System))
            .map_or(window_start, |offset| window_start + offset);
        let mut msgs = Vec::with_capacity(len - start + 1);
        msgs.push(Message::new(Role::System, system_prompt));
        msgs.extend(self.messages[start..].iter().cloned());
        msgs
    }

    /// Replaces the oldest `remove_count` messages with `summary`.
    ///
    /// `summary` is appended to `cold_summaries`, which keeps only the newest
    /// `MAX_COLD_SUMMARIES` entries. The first `remove_count` messages (clamped
    /// to the history length) are then dropped, and the turns are re-indexed:
    /// - turns that ended inside the removed range disappear;
    /// - a turn straddling the cut is shortened to start at index 0;
    /// - later turns are shifted down.
    ///
    /// Does nothing when `remove_count` is 0 or `summary` is empty.
    pub fn apply_compression(&mut self, remove_count: usize, summary: String) {
        if remove_count == 0 || summary.is_empty() {
            return;
        }
        self.cold_summaries.push(summary);
        let excess = self
            .cold_summaries
            .len()
            .saturating_sub(Self::MAX_COLD_SUMMARIES);
        self.cold_summaries.drain(..excess);

        let remove_end = remove_count.min(self.messages.len());
        self.messages.drain(..remove_end);
        let new_msg_len = self.messages.len();
        let mut surviving_turns = Vec::new();
        for turn in self.turn_tracker.turns.drain(..) {
            let turn_end = turn.end_idx();
            if turn_end <= remove_end {
                // Entirely inside the removed prefix.
                continue;
            }
            let new_start = if turn.start_idx < remove_end {
                0
            } else {
                turn.start_idx - remove_end
            };
            let new_count = if turn.start_idx < remove_end {
                // Straddles the cut: keep only the part after it.
                turn_end - remove_end
            } else {
                turn.msg_count
            };
            // Never let a turn claim messages past the end of the history.
            let new_count = new_count.min(new_msg_len.saturating_sub(new_start));
            if new_count > 0 && new_start < new_msg_len {
                surviving_turns.push(turn::Turn {
                    start_idx: new_start,
                    msg_count: new_count,
                    status: turn.status,
                    summary: turn.summary,
                });
            }
        }
        self.turn_tracker.turns = surviving_turns;
    }
}
/// Removes chain-of-thought paragraphs that leaked into the visible answer.
///
/// The text is returned unchanged in three cases:
/// - it is longer than 1000 bytes after trimming;
/// - it contains a code fence;
/// - it has fewer than two non-empty paragraphs (split on blank lines).
///
/// If the first paragraph contains a reasoning marker, the leading run of
/// marker-bearing paragraphs is dropped. The remaining paragraphs are returned
/// trimmed and joined with `"\n\n"`. If every paragraph looks like reasoning,
/// only the last one is kept.
fn strip_leaked_reasoning(text: &str) -> String {
    /// Phrases that typically appear in leaked reasoning paragraphs.
    const REASONING_MARKERS: [&str; 15] = [
        "要求",     // "requirement"
        "需要",     // "need to"
        "这个问题", // "this question"
        "用户",     // "the user"
        "根据规则", // "according to the rules"
        "我应该",   // "I should"
        "让我",     // "let me"
        "分析",     // "analyze"
        "涉及到",   // "involves"
        "敏感",     // "sensitive"
        "回避",     // "avoid"
        "I need to",
        "I should",
        "Let me",
        "The user",
    ];

    /// `contains` already covers `starts_with`, so one substring test suffices.
    fn is_reasoning(paragraph: &str) -> bool {
        REASONING_MARKERS
            .iter()
            .any(|marker| paragraph.contains(marker))
    }

    let trimmed = text.trim();
    if trimmed.len() > 1000 || trimmed.contains("```") {
        return text.to_string();
    }
    let paragraphs: Vec<&str> = trimmed
        .split("\n\n")
        .map(str::trim)
        .filter(|p| !p.is_empty())
        .collect();
    if paragraphs.len() < 2 || !is_reasoning(paragraphs[0]) {
        return text.to_string();
    }
    // First paragraph after the opening one that carries no marker; if all do,
    // fall back to keeping just the last paragraph.
    let start = paragraphs
        .iter()
        .skip(1)
        .position(|p| !is_reasoning(p))
        .map_or(paragraphs.len() - 1, |offset| offset + 1);
    paragraphs[start..].join("\n\n")
}
/// Cuts off a trailing section that repeats an earlier one.
///
/// Trailing whitespace is always trimmed. Texts shorter than 100 bytes or with
/// fewer than six lines are otherwise left alone. For every "section marker"
/// line in the first half of the text, the second half is searched for a
/// later occurrence of it. A marker line is at least 8 bytes after trimming
/// and starts with `**`, `##`, `1.` or `1、`. When at least 3 lines follow
/// that occurrence and at least 60% of them match the lines after the first
/// occurrence position by position, everything from the repeat onwards is
/// dropped.
fn dedup_trailing_repeat(text: &str) -> String {
    let body = text.trim_end();
    if body.len() < 100 {
        return body.to_string();
    }
    let lines: Vec<&str> = body.lines().collect();
    if lines.len() < 6 {
        return body.to_string();
    }
    let midpoint = lines.len() / 2;
    let is_section_marker = |line: &str| {
        ["**", "##", "1.", "1、"]
            .iter()
            .any(|prefix| line.starts_with(prefix))
    };
    for (first_at, head) in lines[..midpoint].iter().map(|l| l.trim()).enumerate() {
        if head.len() < 8 || !is_section_marker(head) {
            continue;
        }
        for (repeat_at, candidate) in lines.iter().enumerate().skip(midpoint) {
            if candidate.trim() != head {
                continue;
            }
            let tail = &lines[repeat_at..];
            let matching = lines[first_at..]
                .iter()
                .zip(tail)
                .filter(|(a, b)| a.trim() == b.trim())
                .count();
            if tail.len() >= 3 && matching * 100 / tail.len() >= 60 {
                return lines[..repeat_at].join("\n");
            }
        }
    }
    body.to_string()
}
/// Normalises raw streamed assistant text before it is committed to history.
///
/// The following are removed, in order:
/// 1. `<think>`, `</think>`, `<|im_start|>` and `<|im_end|>` tokens;
/// 2. orphaned tool-call XML;
/// 3. leaked reasoning paragraphs;
/// 4. a duplicated trailing section.
///
/// Returns `None` when only whitespace remains or the result looks corrupted,
/// so the caller can drop it instead of storing it.
fn clean_assistant_text(raw: &str) -> Option<String> {
    const TEMPLATE_TOKENS: [&str; 4] = ["<think>", "</think>", "<|im_start|>", "<|im_end|>"];
    let without_tokens = TEMPLATE_TOKENS
        .iter()
        .fold(raw.to_string(), |acc, token| acc.replace(token, ""));
    let without_xml = strip_orphan_tool_call_xml(&without_tokens);
    let without_reasoning = strip_leaked_reasoning(&without_xml);
    let cleaned = dedup_trailing_repeat(&without_reasoning);
    let usable = !cleaned.trim().is_empty() && looks_corrupted(&cleaned).is_none();
    usable.then_some(cleaned)
}
/// Removes tool-call XML fragments that leaked into plain assistant text.
///
/// Nothing is touched unless at least one closing tag (`</tool_call>`,
/// `</tool_name>`, `</arg_key>`, `</arg_value>`) is present, so prose that
/// merely mentions a tag name is preserved. Otherwise, for `tool_name`,
/// `arg_key` and `arg_value` in turn:
/// - every `<tag>…</tag>` pair is removed together with its payload;
/// - an opening tag with no later closing tag is removed on its own;
/// - any leftover closing tags are removed.
///
/// Finally all `<tool_call>` / `</tool_call>` tags are stripped.
fn strip_orphan_tool_call_xml(text: &str) -> String {
    const CLOSING_TAGS: [&str; 4] = ["</tool_call>", "</tool_name>", "</arg_key>", "</arg_value>"];
    if !CLOSING_TAGS.iter().any(|tag| text.contains(tag)) {
        return text.to_string();
    }
    let mut result = text.to_string();
    for tag in ["tool_name", "arg_key", "arg_value"] {
        let open = format!("<{tag}>");
        let close = format!("</{tag}>");
        while let Some(open_at) = result.find(&open) {
            let payload_from = open_at + open.len();
            let remove_to = match result[payload_from..].find(&close) {
                Some(rel) => payload_from + rel + close.len(),
                None => payload_from,
            };
            result.replace_range(open_at..remove_to, "");
        }
        result = result.replace(&close, "");
    }
    result.replace("<tool_call>", "").replace("</tool_call>", "")
}
/// Heuristically decides whether `text` looks like corrupted model output.
///
/// Texts shorter than four characters are never flagged. Otherwise this
/// returns the name of the first heuristic that fires, checked in this order:
/// - `"replacement_char_density"`: more than 5% of characters are U+FFFD;
/// - `"c0_control_bytes"`: any C0 control character other than tab, LF or CR;
/// - `"latin_extended_a_mojibake"`: more than 40% of characters are in
///   U+0100..=U+017F;
/// - `"stuck_non_ascii_repeat"`: the same non-ASCII character appears five or
///   more times in a row, unless `is_typographic_repeat_safe` allows it.
pub fn looks_corrupted(text: &str) -> Option<&'static str> {
    let mut total = 0usize;
    let mut replacement = 0usize;
    let mut control = 0usize;
    let mut latin_ext_a = 0usize;
    let mut stuck = false;
    let mut last = '\0';
    let mut streak = 0;
    // One pass gathers every statistic; the verdicts below are then checked in
    // the documented priority order.
    for ch in text.chars() {
        total += 1;
        let cp = ch as u32;
        if ch == '\u{FFFD}' {
            replacement += 1;
        }
        if cp < 0x20 && !matches!(cp, 0x09 | 0x0A | 0x0D) {
            control += 1;
        }
        if (0x0100..=0x017F).contains(&cp) {
            latin_ext_a += 1;
        }
        if ch == last && cp > 0x7F && !is_typographic_repeat_safe(ch) {
            streak += 1;
            if streak >= 4 {
                stuck = true;
            }
        } else {
            streak = 0;
            last = ch;
        }
    }
    if total < 4 {
        None
    } else if replacement * 20 > total {
        Some("replacement_char_density")
    } else if control > 0 {
        Some("c0_control_bytes")
    } else if latin_ext_a * 10 > total * 4 {
        Some("latin_extended_a_mojibake")
    } else if stuck {
        Some("stuck_non_ascii_repeat")
    } else {
        None
    }
}
/// Returns `true` for non-ASCII characters that legitimately appear in long
/// runs in well-formed output, so `looks_corrupted` must not mistake a run of
/// them for a stuck decoder.
///
/// Allowed:
/// - Box Drawing (U+2500–U+257F) and Block Elements (U+2580–U+259F), used for
///   tables, rules and progress bars;
/// - Geometric Shapes (U+25A0–U+25FF, e.g. `■□▬●◦`), used for progress bars,
///   separators and bullets;
/// - black/white stars (U+2605–U+2606), used in ratings such as `★★★★★`;
/// - the hyphen/dash range U+2010–U+2015;
/// - horizontal ellipsis U+2026, bullet U+2022 and middle dot U+00B7.
fn is_typographic_repeat_safe(c: char) -> bool {
    matches!(
        c,
        '\u{2500}'..='\u{257F}'     // box drawing
            | '\u{2580}'..='\u{259F}' // block elements
            | '\u{25A0}'..='\u{25FF}' // geometric shapes (includes U+25E6)
            | '\u{2605}'..='\u{2606}' // ★ ☆
            | '\u{2010}'..='\u{2015}' // hyphens and dashes
            | '\u{2026}'              // …
            | '\u{2022}'              // •
            | '\u{00B7}'              // ·
    )
}
#[cfg(test)]
mod tests {
use super::*;
use crate::conversation::message::Role;
#[test]
fn test_new_conversation_is_empty() {
let conv = Conversation::new();
assert!(conv.messages.is_empty());
assert!(conv.stream_buffer.is_none());
}
#[test]
fn strip_orphan_xml_no_op_on_plain_prose() {
let text = "答案是可以 ping 通 10.0.0.1,因为服务端用了 TUN 设备。";
assert_eq!(strip_orphan_tool_call_xml(text), text);
}
#[test]
fn strip_orphan_xml_no_op_on_rust_generics() {
let text = "let x: Vec<HashMap<String, Arc<dyn Trait>>> = vec![];\n\
println!(\"<not_a_tag>\");";
assert_eq!(strip_orphan_tool_call_xml(text), text);
}
#[test]
fn strip_orphan_xml_handles_dribbled_close() {
let text = "actual_host, e\n);\npanic!(...);\n}</arg_value>\
<arg_key>limit</arg_key><arg_value>100</arg_value>\
<arg_key>offset</arg_key><arg_value>350</arg_value></tool_call>";
let cleaned = strip_orphan_tool_call_xml(text);
assert!(!cleaned.contains("</tool_call>"), "got: {}", cleaned);
assert!(!cleaned.contains("<arg_key>"), "got: {}", cleaned);
assert!(!cleaned.contains("</arg_value>"), "got: {}", cleaned);
assert!(cleaned.contains("actual_host, e"));
assert!(cleaned.contains("panic!"));
}
#[test]
fn strip_orphan_xml_consumes_paired_inner_payloads() {
let text = "Sure, let me check\n<tool_name>read_file</tool_name>\
<arg_key>path</arg_key><arg_value>/tmp/x.rs</arg_value>";
let cleaned = strip_orphan_tool_call_xml(text);
assert!(!cleaned.contains("read_file"), "got: {}", cleaned);
assert!(!cleaned.contains("/tmp/x.rs"), "got: {}", cleaned);
assert!(cleaned.contains("Sure, let me check"));
}
#[test]
fn strip_orphan_xml_through_clean_assistant_text() {
let only_residue = "<arg_key>limit</arg_key>\
<arg_value>100</arg_value></tool_call>";
assert_eq!(clean_assistant_text(only_residue), None);
}
#[test]
fn strip_orphan_xml_leaves_lone_open_alone_when_no_closes_present() {
let text = "the field is called `<tool_name>` and contains the function name";
assert_eq!(strip_orphan_tool_call_xml(text), text);
}
#[test]
fn test_add_user_message() {
let mut conv = Conversation::new();
conv.add_user_message("hello");
assert_eq!(conv.messages.len(), 1);
assert!(matches!(conv.messages[0].role, Role::User));
assert_eq!(conv.messages[0].text().unwrap(), "hello");
}
#[test]
fn test_push_delta_creates_buffer() {
let mut conv = Conversation::new();
conv.push_delta("Hello");
assert_eq!(conv.stream_buffer, Some("Hello".to_string()));
conv.push_delta(" world");
assert_eq!(conv.stream_buffer, Some("Hello world".to_string()));
}
#[test]
fn test_finalize_stream() {
let mut conv = Conversation::new();
conv.push_delta("Hello world");
conv.finalize_stream();
assert!(conv.stream_buffer.is_none());
assert_eq!(conv.messages.len(), 1);
assert!(matches!(conv.messages[0].role, Role::Assistant));
assert_eq!(conv.messages[0].text().unwrap(), "Hello world");
}
#[test]
fn test_finalize_empty_buffer_is_noop() {
let mut conv = Conversation::new();
conv.finalize_stream();
assert!(conv.messages.is_empty());
}
#[test]
fn test_to_provider_messages_prepends_system() {
let mut conv = Conversation::new();
conv.add_user_message("hi");
let msgs = conv.to_provider_messages("You are helpful.");
assert_eq!(msgs.len(), 2);
assert!(matches!(msgs[0].role, Role::System));
assert_eq!(msgs[0].text().unwrap(), "You are helpful.");
assert!(matches!(msgs[1].role, Role::User));
}
#[test]
fn test_add_assistant_tool_calls() {
use crate::tool::ToolCall;
let mut conv = Conversation::new();
conv.add_user_message("hello");
let call = ToolCall {
id: "call_1".to_string(),
name: "read_file".to_string(),
arguments: r#"{"file_path":"/tmp/test"}"#.to_string(),
};
conv.add_assistant_tool_calls(Some("Let me read that file."), vec![call], None);
assert_eq!(conv.messages.len(), 2);
match &conv.messages[1].content {
MessageContent::AssistantWithToolCalls {
text, tool_calls, ..
} => {
assert_eq!(text.as_deref(), Some("Let me read that file."));
assert_eq!(tool_calls.len(), 1);
}
_ => panic!("Expected AssistantWithToolCalls"),
}
}
#[test]
fn test_add_tool_result() {
use crate::tool::ToolResult;
let mut conv = Conversation::new();
let result = ToolResult {
call_id: "call_1".to_string(),
output: "file contents".to_string(),
success: true,
};
conv.add_tool_result(result);
assert_eq!(conv.messages.len(), 1);
assert!(matches!(conv.messages[0].role, Role::Tool));
}
#[test]
fn test_finalize_stream_with_tool_call() {
use crate::tool::ToolCall;
let mut conv = Conversation::new();
conv.push_delta("Let me check...");
let call = ToolCall {
id: "call_1".to_string(),
name: "read_file".to_string(),
arguments: "{}".to_string(),
};
conv.finalize_stream_with_tool_call(call, None);
assert!(conv.stream_buffer.is_none());
assert_eq!(conv.messages.len(), 1);
match &conv.messages[0].content {
MessageContent::AssistantWithToolCalls {
text, tool_calls, ..
} => {
assert_eq!(text.as_deref(), Some("Let me check..."));
assert_eq!(tool_calls.len(), 1);
}
_ => panic!("Expected AssistantWithToolCalls"),
}
}
#[test]
fn test_cold_zone_fifo_max_3() {
let mut conv = Conversation::new();
conv.cold_summaries.push("summary 1".to_string());
conv.cold_summaries.push("summary 2".to_string());
conv.cold_summaries.push("summary 3".to_string());
for i in 0..4 {
conv.add_user_message(&format!("t{}", i));
conv.messages.push(Message::new(Role::Assistant, "ok"));
conv.turn_tracker.on_message_added(conv.messages.len() - 1);
}
conv.apply_compression(2, "summary 4".to_string());
assert_eq!(conv.cold_summaries.len(), 3);
assert_eq!(conv.cold_summaries[0], "summary 2");
assert_eq!(conv.cold_summaries[2], "summary 4");
}
#[test]
fn test_compression_then_add_user_message_no_underflow() {
let mut conv = Conversation::new();
conv.add_user_message("task 1");
assert_eq!(conv.turn_tracker.turns.len(), 1);
conv.push_delta("response 1");
conv.finalize_stream();
conv.turn_tracker.complete_current();
conv.add_user_message("task 2");
assert_eq!(conv.turn_tracker.turns.len(), 2);
conv.push_delta("response 2");
conv.finalize_stream();
conv.turn_tracker.complete_current();
assert_eq!(conv.messages.len(), 4);
assert_eq!(
conv.turn_tracker.turns[0].status,
turn::TurnStatus::Completed
);
assert_eq!(
conv.turn_tracker.turns[1].status,
turn::TurnStatus::Completed
);
assert_eq!(conv.turn_tracker.turns[0].msg_count, 2);
assert_eq!(conv.turn_tracker.turns[1].msg_count, 2);
conv.apply_compression(2, "Turn 1 summary".to_string());
assert_eq!(conv.messages.len(), 2);
assert_eq!(conv.turn_tracker.turns.len(), 1);
assert_eq!(conv.turn_tracker.turns[0].start_idx, 0);
assert_eq!(conv.turn_tracker.turns[0].msg_count, 2);
conv.add_user_message("task 3");
assert_eq!(conv.messages.len(), 3);
assert_eq!(conv.turn_tracker.turns.len(), 2);
assert_eq!(
conv.turn_tracker.turns[0].status,
turn::TurnStatus::Completed
);
assert_eq!(conv.turn_tracker.turns[0].msg_count, 2);
assert_eq!(conv.turn_tracker.turns[1].status, turn::TurnStatus::Active);
assert_eq!(conv.turn_tracker.turns[1].start_idx, 2);
}
#[test]
fn test_compression_partial_turn_overlap() {
let mut conv = Conversation::new();
conv.add_user_message("task 1");
conv.push_delta("response 1");
conv.finalize_stream();
conv.turn_tracker.complete_current();
conv.add_user_message("task 2");
conv.push_delta("response 2");
conv.finalize_stream();
use crate::tool::ToolResult;
conv.add_tool_result(ToolResult {
call_id: "call_1".to_string(),
output: "result".to_string(),
success: true,
});
conv.turn_tracker.complete_current();
assert_eq!(conv.messages.len(), 5);
assert_eq!(conv.turn_tracker.turns.len(), 2);
assert_eq!(conv.turn_tracker.turns[0].msg_count, 2);
assert_eq!(conv.turn_tracker.turns[1].msg_count, 3);
conv.apply_compression(3, "Old history".to_string());
assert_eq!(conv.messages.len(), 2);
assert_eq!(conv.turn_tracker.turns.len(), 1);
let surviving_turn = &conv.turn_tracker.turns[0];
assert_eq!(surviving_turn.start_idx, 0);
assert_eq!(surviving_turn.msg_count, 2); assert_eq!(surviving_turn.end_idx(), 2);
conv.add_user_message("task 3");
assert_eq!(conv.messages.len(), 3);
assert_eq!(conv.turn_tracker.turns.len(), 2);
assert_eq!(conv.turn_tracker.turns[0].msg_count, 2);
assert_eq!(conv.turn_tracker.turns[1].start_idx, 2);
}
#[test]
fn test_compression_removes_most_messages() {
let mut conv = Conversation::new();
for i in 1..=3 {
conv.add_user_message(&format!("task {}", i));
conv.push_delta(&format!("response {}", i));
conv.finalize_stream();
conv.turn_tracker.complete_current();
}
assert_eq!(conv.messages.len(), 6);
assert_eq!(conv.turn_tracker.turns.len(), 3);
conv.apply_compression(5, "Entire history summarized".to_string());
assert_eq!(conv.messages.len(), 1);
assert_eq!(conv.turn_tracker.turns.len(), 1);
assert_eq!(conv.turn_tracker.turns[0].start_idx, 0);
assert_eq!(conv.turn_tracker.turns[0].msg_count, 1);
conv.add_user_message("new task");
assert_eq!(conv.messages.len(), 2);
assert_eq!(conv.turn_tracker.turns.len(), 2);
assert_eq!(conv.turn_tracker.turns[1].start_idx, 1);
}
#[test]
fn test_compression_exceeds_message_count() {
let mut conv = Conversation::new();
conv.add_user_message("hello");
conv.push_delta("response");
conv.finalize_stream();
assert_eq!(conv.messages.len(), 2);
conv.apply_compression(100, "Summary".to_string());
assert_eq!(conv.messages.is_empty(), true);
assert_eq!(conv.turn_tracker.turns.is_empty(), true);
conv.add_user_message("new message");
assert_eq!(conv.messages.len(), 1);
assert_eq!(conv.turn_tracker.turns.len(), 1);
}
#[test]
fn looks_corrupted_catches_real_datalog_fixture() {
assert_eq!(
looks_corrupted("P<ďĎĎĎĎ"),
Some("latin_extended_a_mojibake")
);
}
#[test]
fn looks_corrupted_catches_replacement_char_density() {
let s: String = (0..10).map(|_| '\u{FFFD}').collect();
assert_eq!(looks_corrupted(&s), Some("replacement_char_density"));
}
#[test]
fn looks_corrupted_catches_c0_control_bytes() {
assert_eq!(
looks_corrupted("hello\x01world"),
Some("c0_control_bytes")
);
}
#[test]
fn looks_corrupted_catches_stuck_repeat() {
let s = format!("hi {}", "中".repeat(5));
assert_eq!(looks_corrupted(&s), Some("stuck_non_ascii_repeat"));
}
#[test]
fn looks_corrupted_passes_normal_chinese() {
assert_eq!(looks_corrupted("你好,让我帮你写代码"), None);
}
#[test]
fn looks_corrupted_passes_normal_english() {
assert_eq!(
looks_corrupted("Let me read the file and figure out what changed."),
None
);
}
#[test]
fn looks_corrupted_passes_short_czech() {
assert_eq!(looks_corrupted("čaj"), None);
assert_eq!(looks_corrupted("čajov"), None);
}
#[test]
fn looks_corrupted_passes_ascii_separators() {
assert_eq!(looks_corrupted("====================="), None);
assert_eq!(looks_corrupted("Done. ......"), None);
}
#[test]
fn looks_corrupted_passes_markdown_table_borders() {
let table = "┌───────────────────────┬──────────────────────────────────┐\n\
│ 文件 │ 动作 │\n\
├───────────────────────┼──────────────────────────────────┤\n\
│ src/main.rs │ CLI 改为子命令 │\n\
└───────────────────────┴──────────────────────────────────┘";
assert_eq!(looks_corrupted(table), None);
}
#[test]
fn looks_corrupted_passes_horizontal_rules_and_typography() {
assert_eq!(looks_corrupted(&"─".repeat(80)), None);
assert_eq!(looks_corrupted(&"═".repeat(40)), None);
assert_eq!(looks_corrupted(&"━".repeat(40)), None);
assert_eq!(looks_corrupted(&"—".repeat(20)), None); assert_eq!(looks_corrupted(&"…".repeat(20)), None); assert_eq!(looks_corrupted(&"•".repeat(10)), None); assert_eq!(looks_corrupted(&"█".repeat(20)), None);
}
#[test]
fn looks_corrupted_still_catches_real_cjk_corruption() {
assert_eq!(
looks_corrupted(&format!("hi {}", "中".repeat(5))),
Some("stuck_non_ascii_repeat")
);
}
#[test]
fn looks_corrupted_too_short_returns_none() {
assert_eq!(looks_corrupted("P"), None);
assert_eq!(looks_corrupted("ok"), None);
}
#[test]
fn finalize_stream_drops_corrupted_output() {
let mut conv = Conversation::new();
conv.push_delta("P<ďĎĎĎĎ");
conv.finalize_stream();
assert!(
conv.messages.is_empty(),
"corrupted assistant output must not be committed to history"
);
assert!(
conv.stream_buffer.is_none(),
"stream buffer must be drained even on drop"
);
}
#[test]
fn cancel_preserves_completed_assistant_text() {
let mut conv = Conversation::new();
conv.add_user_message("创建 index.html");
conv.push_delta("好的,我来帮你创建");
conv.finalize_stream();
assert_eq!(conv.messages.len(), 2);
conv.cancel_current_turn();
assert_eq!(conv.messages.len(), 2);
assert!(matches!(conv.messages[0].role, Role::User));
assert!(matches!(conv.messages[1].role, Role::Assistant));
assert_eq!(conv.messages[0].text().unwrap(), "创建 index.html");
assert_eq!(conv.messages[1].text().unwrap(), "好的,我来帮你创建");
assert_eq!(conv.turn_tracker.turns.len(), 1);
assert_eq!(conv.turn_tracker.turns[0].status, TurnStatus::Completed);
assert_eq!(conv.turn_tracker.turns[0].msg_count, 2);
}
#[test]
fn cancel_backfills_missing_tool_results() {
let mut conv = Conversation::new();
conv.add_user_message("创建 index.html");
conv.add_assistant_tool_calls(
Some("creating file"),
vec![ToolCall {
id: "call_1".into(),
name: "write_file".into(),
arguments: "{}".into(),
}],
None,
);
assert_eq!(conv.messages.len(), 2);
conv.cancel_current_turn();
assert_eq!(conv.messages.len(), 3);
assert!(matches!(conv.messages[0].role, Role::User));
assert!(matches!(conv.messages[1].role, Role::Assistant));
assert!(matches!(conv.messages[2].role, Role::Tool));
if let MessageContent::ToolResult(r) = &conv.messages[2].content {
assert!(!r.success);
assert_eq!(r.output, "(cancelled)");
assert_eq!(r.call_id, "call_1");
} else {
panic!("expected ToolResult");
}
}
#[test]
fn cancel_preserves_completed_tool_pairs_and_backfills_incomplete() {
let mut conv = Conversation::new();
conv.add_user_message("读取 main.rs 然后修改它");
conv.add_assistant_tool_calls(
None,
vec![ToolCall {
id: "call_1".into(),
name: "read_file".into(),
arguments: r#"{"file_path":"main.rs"}"#.into(),
}],
None,
);
conv.add_tool_result(ToolResult {
call_id: "call_1".into(),
output: "fn main() {}".into(),
success: true,
});
conv.add_assistant_tool_calls(
Some("editing file"),
vec![ToolCall {
id: "call_2".into(),
name: "edit_file".into(),
arguments: r#"{"file_path":"main.rs"}"#.into(),
}],
None,
);
assert_eq!(conv.messages.len(), 4);
conv.cancel_current_turn();
assert_eq!(conv.messages.len(), 5);
assert!(matches!(conv.messages[0].role, Role::User));
assert!(matches!(conv.messages[1].role, Role::Assistant)); assert!(matches!(conv.messages[2].role, Role::Tool)); assert!(matches!(conv.messages[3].role, Role::Assistant)); assert!(matches!(conv.messages[4].role, Role::Tool)); if let MessageContent::ToolResult(r) = &conv.messages[4].content {
assert_eq!(r.call_id, "call_2");
assert!(!r.success);
}
}
#[test]
fn cancel_preserves_previous_turns() {
let mut conv = Conversation::new();
conv.add_user_message("你好");
conv.push_delta("你好!有什么可以帮你?");
conv.finalize_stream();
conv.turn_tracker.complete_current();
conv.add_user_message("创建 index.html");
conv.push_delta("好的,我来创建...");
conv.finalize_stream();
assert_eq!(conv.messages.len(), 4);
conv.cancel_current_turn();
assert_eq!(conv.messages.len(), 4);
assert_eq!(conv.turn_tracker.turns.len(), 2);
assert_eq!(conv.turn_tracker.turns[0].status, TurnStatus::Completed);
assert_eq!(conv.turn_tracker.turns[1].status, TurnStatus::Completed);
}
#[test]
fn cancel_then_follow_up_sees_completed_work() {
let mut conv = Conversation::new();
conv.add_user_message("创建 index.html");
conv.add_assistant_tool_calls(
Some("creating file"),
vec![ToolCall {
id: "call_1".into(),
name: "write_file".into(),
arguments: r#"{"file_path":"index.html","content":"hello"}"#.into(),
}],
None,
);
conv.add_tool_result(ToolResult {
call_id: "call_1".into(),
output: "File written successfully".into(),
success: true,
});
conv.cancel_current_turn();
conv.add_user_message("不要删那行,改成 XXX");
let msgs = conv.to_provider_messages("You are helpful.");
let all_text: String = msgs.iter().map(|m| m.text().unwrap_or("")).collect();
assert!(
all_text.contains("write_file") || all_text.contains("index.html"),
"LLM must see what it already did"
);
assert!(all_text.contains("不要删那行"), "LLM must see the corrective prompt");
}
#[test]
fn cancel_finalizes_stream_buffer() {
let mut conv = Conversation::new();
conv.add_user_message("你好");
conv.push_delta("你好!我是");
assert!(conv.stream_buffer.is_some());
conv.cancel_current_turn();
assert!(conv.stream_buffer.is_none());
assert_eq!(conv.messages.len(), 2);
assert!(matches!(conv.messages[1].role, Role::Assistant));
}
#[test]
fn cancel_including_user_removes_everything() {
let mut conv = Conversation::new();
conv.add_user_message("hello");
conv.push_delta("partial response");
conv.finalize_stream();
conv.cancel_current_turn_including_user();
assert!(conv.messages.is_empty());
assert!(conv.turn_tracker.turns.is_empty());
}
#[test]
fn cancel_including_user_preserves_previous_turns() {
let mut conv = Conversation::new();
conv.add_user_message("你好");
conv.push_delta("你好!");
conv.finalize_stream();
conv.turn_tracker.complete_current();
conv.add_user_message("创建文件");
conv.push_delta("好的...");
conv.finalize_stream();
conv.cancel_current_turn_including_user();
assert_eq!(conv.messages.len(), 2);
assert_eq!(conv.turn_tracker.turns.len(), 1);
}
#[test]
fn cancel_on_empty_conversation_is_noop() {
let mut conv = Conversation::new();
conv.cancel_current_turn();
assert!(conv.messages.is_empty());
}
#[test]
fn cancel_backfills_multi_tool_calls_partial_results() {
let mut conv = Conversation::new();
conv.add_user_message("读取 a.rs 和 b.rs");
conv.add_assistant_tool_calls(
None,
vec![
ToolCall {
id: "call_1".into(),
name: "read_file".into(),
arguments: r#"{"file_path":"a.rs"}"#.into(),
},
ToolCall {
id: "call_2".into(),
name: "read_file".into(),
arguments: r#"{"file_path":"b.rs"}"#.into(),
},
],
None,
);
conv.add_tool_result(ToolResult {
call_id: "call_1".into(),
output: "a content".into(),
success: true,
});
conv.cancel_current_turn();
assert_eq!(conv.messages.len(), 4); if let MessageContent::ToolResult(r) = &conv.messages[3].content {
assert_eq!(r.call_id, "call_2");
assert!(!r.success);
assert_eq!(r.output, "(cancelled)");
}
}
#[test]
fn cancel_backfill_recognises_tool_result_ref() {
use crate::tool::result_store::ToolResultRef;
let mut conv = Conversation::new();
conv.add_user_message("读取 big_file.rs");
conv.add_assistant_tool_calls(
Some("reading"),
vec![ToolCall {
id: "call_1".into(),
name: "read_file".into(),
arguments: r#"{"file_path":"big_file.rs"}"#.into(),
}],
None,
);
let idx = conv.messages.len();
conv.messages.push(Message {
role: Role::Tool,
content: MessageContent::ToolResultRef(ToolResultRef {
call_id: "call_1".into(),
hash: "abc123".into(),
summary: "500 lines of Rust code".into(),
byte_size: 20_000,
success: true,
}),
});
conv.turn_tracker.on_message_added(idx);
conv.add_assistant_tool_calls(
None,
vec![ToolCall {
id: "call_2".into(),
name: "edit_file".into(),
arguments: r#"{"file_path":"big_file.rs"}"#.into(),
}],
None,
);
conv.cancel_current_turn();
assert_eq!(conv.messages.len(), 5);
if let MessageContent::ToolResult(r) = &conv.messages[4].content {
assert_eq!(r.call_id, "call_2");
assert!(!r.success);
assert_eq!(r.output, "(cancelled)");
} else {
panic!("expected ToolResult for call_2");
}
}
#[test]
fn cancel_double_cancel_is_noop() {
let mut conv = Conversation::new();
conv.add_user_message("hello");
conv.push_delta("world");
conv.finalize_stream();
conv.cancel_current_turn();
assert_eq!(conv.messages.len(), 2);
conv.cancel_current_turn();
assert_eq!(conv.messages.len(), 2);
}
#[test]
fn cancel_after_including_user_is_noop() {
let mut conv = Conversation::new();
conv.add_user_message("hello");
conv.push_delta("partial");
conv.finalize_stream();
conv.cancel_current_turn_including_user();
assert!(conv.messages.is_empty());
conv.cancel_current_turn();
assert!(conv.messages.is_empty());
assert!(conv.turn_tracker.turns.is_empty());
}
#[test]
fn cancel_including_user_clears_stream_buffer() {
let mut conv = Conversation::new();
conv.add_user_message("hello");
conv.push_delta("partial response still streaming");
assert!(conv.stream_buffer.is_some());
conv.cancel_current_turn_including_user();
assert!(conv.stream_buffer.is_none(), "stream_buffer must be cleared");
assert!(conv.messages.is_empty());
}
#[test]
fn cancel_including_user_clears_tool_call_buffer() {
    use crate::tool::ToolCallBuffer;
    let mut conv = Conversation::new();
    conv.add_user_message("hello");
    // Simulate a partially received tool call that is still buffered when the
    // user cancels the turn.
    conv.tool_call_buffer = Some(ToolCallBuffer {
        id: "call_partial".into(),
        name: "bash".into(),
        arguments: r#"{"command":"ls"}"#.into(),
        hint_sent: false,
    });
    assert!(conv.tool_call_buffer.is_some());
    conv.cancel_current_turn_including_user();
    assert!(conv.tool_call_buffer.is_none(), "tool_call_buffer must be cleared");
    // As in the stream-buffer case, the unfinished turn and its user prompt
    // must be removed as well, not just the buffer.
    assert!(conv.messages.is_empty());
    assert!(conv.turn_tracker.turns.is_empty());
}
#[test]
fn cancel_including_user_on_completed_turn_is_noop() {
    // One fully completed exchange; cancelling "including user" must not touch
    // a turn that has already been marked complete.
    let mut history = Conversation::new();
    history.add_user_message("hello");
    history.push_delta("world");
    history.finalize_stream();
    history.turn_tracker.complete_current();

    let (msgs_before, turns_before) =
        (history.messages.len(), history.turn_tracker.turns.len());
    assert_eq!(msgs_before, 2);
    assert_eq!(turns_before, 1);

    history.cancel_current_turn_including_user();
    assert_eq!(history.messages.len(), 2, "completed turn must not be removed");
    assert_eq!(history.turn_tracker.turns.len(), 1);
}
#[test]
fn cancel_including_user_then_new_turn_produces_valid_messages() {
    let mut chat = Conversation::new();

    // First turn is discarded entirely, user prompt included.
    chat.add_user_message("bad prompt");
    chat.push_delta("bad response");
    chat.finalize_stream();
    chat.cancel_current_turn_including_user();

    // Second turn completes normally.
    chat.add_user_message("good prompt");
    chat.push_delta("good response");
    chat.finalize_stream();
    chat.turn_tracker.complete_current();

    // Only the system prompt plus the good exchange may reach the provider.
    let provider_msgs = chat.to_provider_messages("system");
    assert_eq!(provider_msgs.len(), 3);
    let (first, second, third) = (&provider_msgs[0], &provider_msgs[1], &provider_msgs[2]);
    assert!(matches!(first.role, Role::System));
    assert!(matches!(second.role, Role::User));
    assert!(matches!(third.role, Role::Assistant));
}
#[test]
fn cancel_backfill_all_tool_result_refs() {
    use crate::tool::result_store::ToolResultRef;
    // Every tool call in the turn already has a result stored as a
    // `ToolResultRef`, so cancelling must not append any backfilled result.
    let mut conv = Conversation::new();
    conv.add_user_message("读取大文件");
    conv.add_assistant_tool_calls(
        None,
        vec![
            ToolCall {
                id: "call_1".into(),
                name: "read_file".into(),
                arguments: r#"{"file_path":"a.rs"}"#.into(),
            },
            ToolCall {
                id: "call_2".into(),
                name: "read_file".into(),
                arguments: r#"{"file_path":"b.rs"}"#.into(),
            },
        ],
        None,
    );
    for (call_id, summary) in [("call_1", "a.rs content"), ("call_2", "b.rs content")] {
        let idx = conv.messages.len();
        conv.messages.push(Message {
            role: Role::Tool,
            content: MessageContent::ToolResultRef(ToolResultRef {
                call_id: call_id.into(),
                hash: format!("hash_{}", call_id),
                summary: summary.into(),
                byte_size: 10_000,
                success: true,
            }),
        });
        conv.turn_tracker.on_message_added(idx);
    }
    conv.cancel_current_turn();
    assert_eq!(conv.messages.len(), 4);
    // The stored refs must survive the cancel unchanged and in call order; a
    // bare length check would not catch a ref being replaced by a placeholder.
    for (msg, expected_id) in conv.messages[2..].iter().zip(["call_1", "call_2"]) {
        match &msg.content {
            MessageContent::ToolResultRef(r) => assert_eq!(r.call_id, expected_id),
            _ => panic!("expected ToolResultRef for {}", expected_id),
        }
    }
}
#[test]
fn cancel_backfill_mixed_result_types() {
    use crate::tool::result_store::ToolResultRef;
    // Three calls: call_1 receives an inline result, call_2 a stored ref, and
    // call_3 nothing. After cancelling, only call_3 may be backfilled.
    let mut chat = Conversation::new();
    chat.add_user_message("读取文件并编辑");
    let calls = vec![
        ToolCall {
            id: "call_1".into(),
            name: "read_file".into(),
            arguments: r#"{"file_path":"x.rs"}"#.into(),
        },
        ToolCall {
            id: "call_2".into(),
            name: "bash".into(),
            arguments: r#"{"command":"make"}"#.into(),
        },
        ToolCall {
            id: "call_3".into(),
            name: "edit_file".into(),
            arguments: r#"{"file_path":"x.rs"}"#.into(),
        },
    ];
    chat.add_assistant_tool_calls(None, calls, None);

    chat.add_tool_result(ToolResult {
        call_id: "call_1".into(),
        output: "file content".into(),
        success: true,
    });

    let ref_msg = Message {
        role: Role::Tool,
        content: MessageContent::ToolResultRef(ToolResultRef {
            call_id: "call_2".into(),
            hash: "hash_call_2".into(),
            summary: "make output".into(),
            byte_size: 50_000,
            success: true,
        }),
    };
    let ref_idx = chat.messages.len();
    chat.messages.push(ref_msg);
    chat.turn_tracker.on_message_added(ref_idx);

    chat.cancel_current_turn();
    assert_eq!(chat.messages.len(), 5);
    match &chat.messages[4].content {
        MessageContent::ToolResult(res) => {
            assert_eq!(res.call_id, "call_3");
            assert!(!res.success);
            assert_eq!(res.output, "(cancelled)");
        }
        _ => panic!("expected ToolResult for call_3"),
    }
}
#[test]
fn cancel_then_provider_messages_are_api_legal() {
    let mut conv = Conversation::new();
    conv.add_user_message("读取 main.rs 然后修改它");
    conv.add_assistant_tool_calls(
        Some("reading file"),
        vec![ToolCall {
            id: "call_1".into(),
            name: "read_file".into(),
            arguments: r#"{"file_path":"main.rs"}"#.into(),
        }],
        None,
    );
    conv.add_tool_result(ToolResult {
        call_id: "call_1".into(),
        output: "fn main() {}".into(),
        success: true,
    });
    conv.add_assistant_tool_calls(
        Some("editing file"),
        vec![ToolCall {
            id: "call_2".into(),
            name: "edit_file".into(),
            arguments: r#"{"file_path":"main.rs"}"#.into(),
        }],
        None,
    );
    // call_2 never received a result; cancelling must backfill one.
    conv.cancel_current_turn();
    let msgs = conv.to_provider_messages("You are helpful.");
    // System + user + (assistant, tool) for call_1 + (assistant, backfilled tool)
    // for call_2. Check the count first so a short sequence fails with a clear
    // message instead of an index-out-of-bounds panic below.
    assert_eq!(msgs.len(), 6, "unexpected number of provider messages");
    assert!(matches!(msgs[0].role, Role::System));
    assert!(matches!(msgs[1].role, Role::User));
    assert!(matches!(msgs[2].role, Role::Assistant));
    assert!(matches!(msgs[3].role, Role::Tool));
    assert!(matches!(msgs[4].role, Role::Assistant));
    assert!(matches!(msgs[5].role, Role::Tool));

    // Every tool call id issued by an assistant message, in order.
    let mut expected_call_ids: Vec<String> = Vec::new();
    for msg in &msgs {
        if let MessageContent::AssistantWithToolCalls { tool_calls, .. } = &msg.content {
            expected_call_ids.extend(tool_calls.iter().map(|tc| tc.id.clone()));
        }
    }
    // Every call id answered by a tool-result message, in order.
    let got_call_ids: Vec<String> = msgs
        .iter()
        .filter_map(|msg| msg.tool_result_call_id())
        .map(|id| id.to_string())
        .collect();
    assert_eq!(expected_call_ids, ["call_1", "call_2"]);
    assert_eq!(
        expected_call_ids, got_call_ids,
        "every tool call must have a matching result"
    );
}
#[test]
fn cancel_then_follow_up_full_sequence_api_legal() {
    let mut conv = Conversation::new();
    conv.add_user_message("你好");
    conv.push_delta("你好!");
    conv.finalize_stream();
    conv.turn_tracker.complete_current();
    conv.add_user_message("读取 main.rs");
    conv.add_assistant_tool_calls(
        None,
        vec![ToolCall {
            id: "call_1".into(),
            name: "read_file".into(),
            arguments: "{}".into(),
        }],
        None,
    );
    // Cancel while call_1 is still unanswered, then continue with a new prompt.
    conv.cancel_current_turn();
    conv.add_user_message("不要修改那行");
    conv.push_delta("好的,我只添加新代码");
    conv.finalize_stream();
    conv.turn_tracker.complete_current();
    let msgs = conv.to_provider_messages("system");

    // Two adjacent User messages are illegal in the provider sequence.
    for (i, pair) in msgs.windows(2).enumerate() {
        assert!(
            !(matches!(pair[0].role, Role::User) && matches!(pair[1].role, Role::User)),
            "consecutive User messages at index {}-{} are illegal",
            i,
            i + 1
        );
    }

    let mut expected: Vec<String> = Vec::new();
    for msg in &msgs {
        if let MessageContent::AssistantWithToolCalls { tool_calls, .. } = &msg.content {
            expected.extend(tool_calls.iter().map(|tc| tc.id.clone()));
        }
    }
    let got: Vec<String> = msgs
        .iter()
        .filter_map(|msg| msg.tool_result_call_id())
        .map(|id| id.to_string())
        .collect();
    // The cancelled call must still be present; otherwise the pairing check
    // below would pass vacuously on two empty lists.
    assert_eq!(expected, ["call_1"], "cancelled tool call must remain in history");
    assert_eq!(expected, got, "all tool calls must have matching results");
}
#[test]
fn cancel_updates_turn_tracker_correctly() {
    let mut chat = Conversation::new();
    chat.add_user_message("hello");
    let pending = ToolCall {
        id: "call_1".into(),
        name: "bash".into(),
        arguments: "{}".into(),
    };
    chat.add_assistant_tool_calls(None, vec![pending], None);

    // Before cancelling: one active turn holding the prompt and the tool-call message.
    assert_eq!(chat.turn_tracker.turns.len(), 1);
    let before = &chat.turn_tracker.turns[0];
    assert_eq!(before.status, TurnStatus::Active);
    assert_eq!(before.msg_count, 2);

    // Cancelling adds a result for the unanswered call and closes the turn.
    chat.cancel_current_turn();
    assert_eq!(chat.messages.len(), 3);
    let after = &chat.turn_tracker.turns[0];
    assert_eq!(after.status, TurnStatus::Completed);
    assert_eq!(
        after.msg_count, 3,
        "msg_count must include the backfilled result"
    );
}
#[test]
fn cancel_including_user_removes_turn_not_just_marks() {
    let mut chat = Conversation::new();

    // Turn 1: a normal exchange that is marked complete.
    chat.add_user_message("hello");
    chat.push_delta("hi");
    chat.finalize_stream();
    chat.turn_tracker.complete_current();

    // Turn 2: left open, then cancelled together with its user prompt.
    chat.add_user_message("bad");
    chat.push_delta("oops");
    chat.finalize_stream();
    assert_eq!(chat.turn_tracker.turns.len(), 2);

    chat.cancel_current_turn_including_user();
    // The cancelled turn must be dropped from the tracker, leaving only turn 1.
    let remaining = &chat.turn_tracker.turns;
    assert_eq!(remaining.len(), 1);
    assert_eq!(remaining[0].status, TurnStatus::Completed);
}
}