use std::collections::HashSet;
use crate::channel::Channel;
use zeph_llm::provider::{Message, MessagePart, Role};
use zeph_memory::sqlite::role_str;
use super::Agent;
/// Repairs tool-call pairing in a freshly restored history slice.
///
/// A restored history window (e.g. one cut off by the history limit) can split a
/// tool exchange in half. This function removes, in order:
///
/// 1. every trailing `Assistant` message that contains a `ToolUse` part — nothing
///    after it in the slice can carry the matching `ToolResult`;
/// 2. every leading `User` message that contains a `ToolResult` part — the
///    `ToolUse` it answers is outside the slice;
/// 3. unmatched tool parts in the interior, via [`strip_mid_history_orphans`].
///
/// Steps 1 and 2 drop the whole message, even if it also carries text.
///
/// Returns the total number of messages removed from `messages`.
fn sanitize_tool_pairs(messages: &mut Vec<Message>) -> usize {
    /// IDs of every `ToolUse` part in `msg`.
    fn tool_use_ids(msg: &Message) -> Vec<&str> {
        msg.parts
            .iter()
            .filter_map(|p| match p {
                MessagePart::ToolUse { id, .. } => Some(id.as_str()),
                _ => None,
            })
            .collect()
    }

    /// IDs referenced by every `ToolResult` part in `msg`.
    fn tool_result_ids(msg: &Message) -> Vec<&str> {
        msg.parts
            .iter()
            .filter_map(|p| match p {
                MessagePart::ToolResult { tool_use_id, .. } => Some(tool_use_id.as_str()),
                _ => None,
            })
            .collect()
    }

    /// Whether `msg` is a user turn carrying at least one `ToolResult` part.
    fn is_tool_result_turn(msg: &Message) -> bool {
        msg.role == Role::User
            && msg
                .parts
                .iter()
                .any(|p| matches!(p, MessagePart::ToolResult { .. }))
    }

    let mut removed = 0;

    // Step 1: pop trailing assistant tool calls one at a time (`pop` is O(1)).
    while let Some(last) = messages.last() {
        if last.role != Role::Assistant {
            break;
        }
        let ids = tool_use_ids(last);
        if ids.is_empty() {
            break;
        }
        tracing::warn!(
            tool_ids = ?ids,
            "removing orphaned trailing tool_use message from restored history"
        );
        messages.pop();
        removed += 1;
    }

    // Step 2: measure the run of leading tool-result turns, log each one, then drop
    // the whole run with a single `drain`. Calling `remove(0)` per message would
    // shift every remaining element each time. Removing a prefix cannot create a
    // new trailing orphan: the last message is either untouched or drained too.
    let leading = messages
        .iter()
        .take_while(|m| is_tool_result_turn(m))
        .count();
    for msg in &messages[..leading] {
        tracing::warn!(
            tool_use_ids = ?tool_result_ids(msg),
            "removing orphaned leading tool_result message from restored history"
        );
    }
    messages.drain(..leading);
    removed += leading;

    // Step 3: strip unmatched tool parts from the remaining interior messages.
    removed + strip_mid_history_orphans(messages)
}
/// Removes unmatched tool parts from the interior of a restored history.
///
/// Walks `messages` front to back. At each index:
/// - for an `Assistant` message with `ToolUse` parts, strips every `ToolUse` whose
///   id is not answered by a `ToolResult` in the immediately following `User`
///   message;
/// - for a `User` message with `ToolResult` parts, strips every `ToolResult` whose
///   `tool_use_id` is not requested by a `ToolUse` in the immediately preceding
///   `Assistant` message.
///
/// A message is dropped only when stripping leaves it with no parts and blank
/// content. Otherwise it is kept, and its `content` text is left as-is.
///
/// Returns the number of messages dropped.
fn strip_mid_history_orphans(messages: &mut Vec<Message>) -> usize {
    let mut removed = 0;
    let mut i = 0;
    while i < messages.len() {
        // When a message is dropped, index `i` now holds its successor, which must
        // be examined before advancing.
        if strip_orphan_tool_uses_at(messages, i) || strip_orphan_tool_results_at(messages, i) {
            removed += 1;
        } else {
            i += 1;
        }
    }
    removed
}

/// Strips `ToolUse` parts of the assistant message at `i` that the next message
/// does not answer.
///
/// Returns `true` if the message was left empty and removed. Non-assistant
/// messages and assistant messages without `ToolUse` parts are left untouched.
fn strip_orphan_tool_uses_at(messages: &mut Vec<Message>, i: usize) -> bool {
    let msg = &messages[i];
    if msg.role != Role::Assistant
        || !msg
            .parts
            .iter()
            .any(|p| matches!(p, MessagePart::ToolUse { .. }))
    {
        return false;
    }
    let orphaned: HashSet<String> = {
        // Tool-use ids answered by the following message, if it is a user turn.
        let answered: HashSet<&str> = messages
            .get(i + 1)
            .filter(|next| next.role == Role::User)
            .map(|next| {
                next.parts
                    .iter()
                    .filter_map(|p| match p {
                        MessagePart::ToolResult { tool_use_id, .. } => {
                            Some(tool_use_id.as_str())
                        }
                        _ => None,
                    })
                    .collect()
            })
            .unwrap_or_default();
        msg.parts
            .iter()
            .filter_map(|p| match p {
                MessagePart::ToolUse { id, .. } if !answered.contains(id.as_str()) => {
                    Some(id.clone())
                }
                _ => None,
            })
            .collect()
    };
    if orphaned.is_empty() {
        return false;
    }
    tracing::warn!(
        tool_ids = ?orphaned,
        index = i,
        "stripping orphaned mid-history tool_use parts from assistant message"
    );
    messages[i]
        .parts
        .retain(|p| !matches!(p, MessagePart::ToolUse { id, .. } if orphaned.contains(id)));
    remove_if_emptied(messages, i)
}

/// Strips `ToolResult` parts of the user message at `i` that the previous message
/// did not request.
///
/// Returns `true` if the message was left empty and removed. Non-user messages and
/// user messages without `ToolResult` parts are left untouched.
fn strip_orphan_tool_results_at(messages: &mut Vec<Message>, i: usize) -> bool {
    let msg = &messages[i];
    if msg.role != Role::User
        || !msg
            .parts
            .iter()
            .any(|p| matches!(p, MessagePart::ToolResult { .. }))
    {
        return false;
    }
    let orphaned: HashSet<String> = {
        // Tool-use ids requested by the preceding message, if it is an assistant turn.
        let requested: HashSet<&str> = match i.checked_sub(1).and_then(|j| messages.get(j)) {
            Some(prev) if prev.role == Role::Assistant => prev
                .parts
                .iter()
                .filter_map(|p| match p {
                    MessagePart::ToolUse { id, .. } => Some(id.as_str()),
                    _ => None,
                })
                .collect(),
            _ => HashSet::new(),
        };
        msg.parts
            .iter()
            .filter_map(|p| match p {
                MessagePart::ToolResult { tool_use_id, .. }
                    if !requested.contains(tool_use_id.as_str()) =>
                {
                    Some(tool_use_id.clone())
                }
                _ => None,
            })
            .collect()
    };
    if orphaned.is_empty() {
        return false;
    }
    tracing::warn!(
        tool_use_ids = ?orphaned,
        index = i,
        "stripping orphaned mid-history tool_result parts from user message"
    );
    messages[i].parts.retain(|p| {
        !matches!(
            p,
            MessagePart::ToolResult { tool_use_id, .. }
                if orphaned.contains(tool_use_id.as_str())
        )
    });
    remove_if_emptied(messages, i)
}

/// Removes `messages[i]` if it has blank content and no parts.
///
/// Returns whether the message was removed.
fn remove_if_emptied(messages: &mut Vec<Message>, i: usize) -> bool {
    let empty = messages[i].content.trim().is_empty() && messages[i].parts.is_empty();
    if empty {
        messages.remove(i);
    }
    empty
}
impl<C: Channel> Agent<C> {
    /// Restores the stored conversation into `self.messages`.
    ///
    /// Returns `Ok(())` without doing anything when either the memory backend or
    /// the conversation id is unset. Otherwise it:
    /// 1. loads up to `memory_state.history_limit` messages with
    ///    `load_history_filtered`;
    /// 2. drops messages with blank content and no parts;
    /// 3. repairs broken tool_use/tool_result pairing with [`sanitize_tool_pairs`];
    /// 4. appends the survivors after any messages already present.
    ///
    /// It then refreshes the SQLite message-count metric and `unsummarized_count`
    /// (each only if its query succeeds) and recomputes the prompt token estimate.
    ///
    /// # Errors
    ///
    /// Returns an error if `load_history_filtered` fails.
    pub async fn load_history(&mut self) -> Result<(), super::error::AgentError> {
        let (Some(memory), Some(cid)) =
            (&self.memory_state.memory, self.memory_state.conversation_id)
        else {
            return Ok(());
        };
        let history = memory
            .sqlite()
            .load_history_filtered(cid, self.memory_state.history_limit, Some(true), None)
            .await?;
        if !history.is_empty() {
            let mut skipped = 0;
            // Stage restored messages separately so sanitisation only ever touches
            // restored history, never messages already held by the agent.
            let mut restored = Vec::with_capacity(history.len());
            for msg in history {
                if msg.content.trim().is_empty() && msg.parts.is_empty() {
                    tracing::warn!("skipping empty message from history (role: {:?})", msg.role);
                    skipped += 1;
                    continue;
                }
                restored.push(msg);
            }
            skipped += sanitize_tool_pairs(&mut restored);
            let loaded = restored.len();
            self.messages.append(&mut restored);
            tracing::info!("restored {loaded} message(s) from conversation {cid}");
            if skipped > 0 {
                tracing::warn!("skipped {skipped} empty/orphaned message(s) from history");
            }
        }
        if let Ok(count) = memory.message_count(cid).await {
            let count_u64 = u64::try_from(count).unwrap_or(0);
            self.update_metrics(|m| {
                m.sqlite_message_count = count_u64;
            });
        }
        if let Ok(count) = memory.unsummarized_message_count(cid).await {
            self.memory_state.unsummarized_count = usize::try_from(count).unwrap_or(0);
        }
        self.recompute_prompt_tokens();
        Ok(())
    }

    /// Persists one message to the memory backend and updates related counters.
    ///
    /// This is a no-op when the memory backend or the conversation id is unset.
    /// `parts` is stored as JSON; `"[]"` is stored when `parts` is empty or
    /// serialization fails.
    ///
    /// The message is written with `remember_with_parts` when it should be
    /// embedded, and with `save_only` otherwise:
    /// - embedding is skipped whenever the exfiltration guard fires for
    ///   `has_injection_flags`; a metric and a security event are recorded;
    /// - assistant messages are embedded only if `autosave_assistant` is set and
    ///   `content.len()` (in bytes) is at least `autosave_min_length`;
    /// - messages with any other role are embedded.
    ///
    /// On a storage error the failure is logged and the function returns early,
    /// leaving all counters untouched. On success it increments
    /// `unsummarized_count` and the message metrics. It then runs
    /// `check_summarization` and `maybe_spawn_graph_extraction`.
    pub(crate) async fn persist_message(
        &mut self,
        role: Role,
        content: &str,
        parts: &[MessagePart],
        has_injection_flags: bool,
    ) {
        let (Some(memory), Some(cid)) =
            (&self.memory_state.memory, self.memory_state.conversation_id)
        else {
            return;
        };
        let parts_json = if parts.is_empty() {
            "[]".to_string()
        } else {
            serde_json::to_string(parts).unwrap_or_else(|e| {
                tracing::warn!("failed to serialize message parts, storing empty: {e}");
                "[]".to_string()
            })
        };
        let guard_event = self
            .exfiltration_guard
            .should_guard_memory_write(has_injection_flags);
        if let Some(ref event) = guard_event {
            tracing::warn!(
                ?event,
                "exfiltration guard: skipping Qdrant embedding for flagged content"
            );
            self.update_metrics(|m| m.exfiltration_memory_guards += 1);
            self.push_security_event(
                crate::metrics::SecurityEventCategory::ExfiltrationBlock,
                "memory_write",
                "Qdrant embedding skipped: flagged content",
            );
        }
        let skip_embedding = guard_event.is_some();
        let should_embed = if skip_embedding {
            false
        } else {
            match role {
                Role::Assistant => {
                    self.memory_state.autosave_assistant
                        && content.len() >= self.memory_state.autosave_min_length
                }
                _ => true,
            }
        };
        let embedding_stored = if should_embed {
            match memory
                .remember_with_parts(cid, role_str(role), content, &parts_json)
                .await
            {
                Ok((_message_id, stored)) => stored,
                Err(e) => {
                    tracing::error!("failed to persist message: {e:#}");
                    return;
                }
            }
        } else {
            match memory
                .save_only(cid, role_str(role), content, &parts_json)
                .await
            {
                Ok(_) => false,
                Err(e) => {
                    tracing::error!("failed to persist message: {e:#}");
                    return;
                }
            }
        };
        self.memory_state.unsummarized_count += 1;
        self.update_metrics(|m| {
            m.sqlite_message_count += 1;
            if embedding_stored {
                m.embeddings_generated += 1;
            }
        });
        self.check_summarization().await;
        self.maybe_spawn_graph_extraction(content, has_injection_flags)
            .await;
    }

    /// Passes `content` to the memory backend's `spawn_graph_extraction`.
    ///
    /// Skipped when:
    /// - memory or the conversation id is unset;
    /// - `has_injection_flags` is true (a warning is logged);
    /// - the graph config is disabled.
    ///
    /// The most recent user messages currently in `self.messages` (up to 4, newest
    /// first) are sent as context. Afterwards the graph-related metrics and counts
    /// are synced.
    async fn maybe_spawn_graph_extraction(&mut self, content: &str, has_injection_flags: bool) {
        use zeph_memory::semantic::GraphExtractionConfig;
        if self.memory_state.memory.is_none() || self.memory_state.conversation_id.is_none() {
            return;
        }
        if has_injection_flags {
            tracing::warn!("graph extraction skipped: injection patterns detected in content");
            return;
        }
        let extraction_cfg = {
            let cfg = &self.memory_state.graph_config;
            if !cfg.enabled {
                return;
            }
            GraphExtractionConfig {
                max_entities: cfg.max_entities_per_message,
                max_edges: cfg.max_edges_per_message,
                extraction_timeout_secs: cfg.extraction_timeout_secs,
                community_refresh_interval: cfg.community_refresh_interval,
                expired_edge_retention_days: cfg.expired_edge_retention_days,
                max_entities_cap: cfg.max_entities,
                community_summary_max_prompt_bytes: cfg.community_summary_max_prompt_bytes,
                community_summary_concurrency: cfg.community_summary_concurrency,
            }
        };
        let context_messages: Vec<String> = self
            .messages
            .iter()
            .rev()
            .filter(|m| m.role == Role::User)
            .take(4)
            .map(|m| m.content.clone())
            .collect();
        // NOTE(review): unlike `check_summarization`, this status is never cleared
        // here — confirm something else resets it.
        let _ = self.channel.send_status("extracting graph...").await;
        if let Some(memory) = &self.memory_state.memory {
            memory.spawn_graph_extraction(content.to_owned(), context_messages, extraction_cfg);
        }
        self.sync_community_detection_failures();
        self.sync_graph_extraction_metrics();
        self.sync_graph_counts().await;
    }

    /// Summarizes part of the conversation once the unsummarized backlog exceeds
    /// `summarization_threshold`.
    ///
    /// This is a no-op when memory or the conversation id is unset, or when the
    /// backlog is at or below the threshold. Otherwise it sends a "summarizing..."
    /// status, asks `summarize` for a batch of half the threshold (at least 1), and
    /// clears the status afterwards.
    ///
    /// On a successful summary, `unsummarized_count` is re-read from storage and the
    /// summary metric is incremented. A `None` result is logged at debug level;
    /// failures are logged and swallowed.
    pub(crate) async fn check_summarization(&mut self) {
        let (Some(memory), Some(cid)) =
            (&self.memory_state.memory, self.memory_state.conversation_id)
        else {
            return;
        };
        if self.memory_state.unsummarized_count <= self.memory_state.summarization_threshold {
            return;
        }
        let _ = self.channel.send_status("summarizing...").await;
        // With a threshold of 0 or 1, `threshold / 2` would be 0 and the request
        // could never shrink the backlog, so always ask for at least one message.
        let batch_size = (self.memory_state.summarization_threshold / 2).max(1);
        match memory.summarize(cid, batch_size).await {
            Ok(Some(summary_id)) => {
                tracing::info!("created summary {summary_id} for conversation {cid}");
                // Only a batch was requested, so the backlog is not necessarily
                // empty now. Re-read it from the same source `load_history` uses,
                // falling back to 0 if the query fails.
                self.memory_state.unsummarized_count = memory
                    .unsummarized_message_count(cid)
                    .await
                    .ok()
                    .and_then(|count| usize::try_from(count).ok())
                    .unwrap_or(0);
                self.update_metrics(|m| {
                    m.summaries_count += 1;
                });
            }
            Ok(None) => {
                tracing::debug!("no summarization needed");
            }
            Err(e) => {
                tracing::error!("summarization failed: {e:#}");
            }
        }
        let _ = self.channel.send_status("").await;
    }
}
#[cfg(test)]
mod tests {
use super::super::agent_tests::{
MetricsSnapshot, MockChannel, MockToolExecutor, create_test_registry, mock_provider,
};
use super::*;
use zeph_llm::any::AnyProvider;
use zeph_memory::semantic::SemanticMemory;
async fn test_memory(provider: &AnyProvider) -> SemanticMemory {
SemanticMemory::new(
":memory:",
"http://127.0.0.1:1",
provider.clone(),
"test-model",
)
.await
.unwrap()
}
#[tokio::test]
async fn load_history_without_memory_returns_ok() {
let provider = mock_provider(vec![]);
let channel = MockChannel::new(vec![]);
let registry = create_test_registry();
let executor = MockToolExecutor::no_tools();
let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
let result = agent.load_history().await;
assert!(result.is_ok());
assert_eq!(agent.messages.len(), 1); }
#[tokio::test]
async fn load_history_with_messages_injects_into_agent() {
let provider = mock_provider(vec![]);
let channel = MockChannel::new(vec![]);
let registry = create_test_registry();
let executor = MockToolExecutor::no_tools();
let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
let cid = memory.sqlite().create_conversation().await.unwrap();
memory
.sqlite()
.save_message(cid, "user", "hello from history")
.await
.unwrap();
memory
.sqlite()
.save_message(cid, "assistant", "hi back")
.await
.unwrap();
let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
std::sync::Arc::new(memory),
cid,
50,
5,
100,
);
let messages_before = agent.messages.len();
agent.load_history().await.unwrap();
assert_eq!(agent.messages.len(), messages_before + 2);
}
#[tokio::test]
async fn load_history_skips_empty_messages() {
let provider = mock_provider(vec![]);
let channel = MockChannel::new(vec![]);
let registry = create_test_registry();
let executor = MockToolExecutor::no_tools();
let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
let cid = memory.sqlite().create_conversation().await.unwrap();
memory
.sqlite()
.save_message(cid, "user", " ")
.await
.unwrap();
memory
.sqlite()
.save_message(cid, "user", "real message")
.await
.unwrap();
let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
std::sync::Arc::new(memory),
cid,
50,
5,
100,
);
let messages_before = agent.messages.len();
agent.load_history().await.unwrap();
assert_eq!(agent.messages.len(), messages_before + 1);
}
#[tokio::test]
async fn load_history_with_empty_store_returns_ok() {
let provider = mock_provider(vec![]);
let channel = MockChannel::new(vec![]);
let registry = create_test_registry();
let executor = MockToolExecutor::no_tools();
let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
let cid = memory.sqlite().create_conversation().await.unwrap();
let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
std::sync::Arc::new(memory),
cid,
50,
5,
100,
);
let messages_before = agent.messages.len();
agent.load_history().await.unwrap();
assert_eq!(agent.messages.len(), messages_before);
}
#[tokio::test]
async fn persist_message_without_memory_silently_returns() {
let provider = mock_provider(vec![]);
let channel = MockChannel::new(vec![]);
let registry = create_test_registry();
let executor = MockToolExecutor::no_tools();
let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
agent.persist_message(Role::User, "hello", &[], false).await;
}
#[tokio::test]
async fn persist_message_assistant_autosave_false_uses_save_only() {
let provider = mock_provider(vec![]);
let channel = MockChannel::new(vec![]);
let registry = create_test_registry();
let executor = MockToolExecutor::no_tools();
let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
let cid = memory.sqlite().create_conversation().await.unwrap();
let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
.with_metrics(tx)
.with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
.with_autosave_config(false, 20);
agent
.persist_message(Role::Assistant, "short assistant reply", &[], false)
.await;
let history = agent
.memory_state
.memory
.as_ref()
.unwrap()
.sqlite()
.load_history(cid, 50)
.await
.unwrap();
assert_eq!(history.len(), 1, "message must be saved");
assert_eq!(history[0].content, "short assistant reply");
assert_eq!(rx.borrow().embeddings_generated, 0);
}
#[tokio::test]
async fn persist_message_assistant_below_min_length_uses_save_only() {
let provider = mock_provider(vec![]);
let channel = MockChannel::new(vec![]);
let registry = create_test_registry();
let executor = MockToolExecutor::no_tools();
let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
let cid = memory.sqlite().create_conversation().await.unwrap();
let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
.with_metrics(tx)
.with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
.with_autosave_config(true, 1000);
agent
.persist_message(Role::Assistant, "too short", &[], false)
.await;
let history = agent
.memory_state
.memory
.as_ref()
.unwrap()
.sqlite()
.load_history(cid, 50)
.await
.unwrap();
assert_eq!(history.len(), 1, "message must be saved");
assert_eq!(history[0].content, "too short");
assert_eq!(rx.borrow().embeddings_generated, 0);
}
#[tokio::test]
async fn persist_message_assistant_at_min_length_boundary_uses_embed() {
let provider = mock_provider(vec![]);
let channel = MockChannel::new(vec![]);
let registry = create_test_registry();
let executor = MockToolExecutor::no_tools();
let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
let cid = memory.sqlite().create_conversation().await.unwrap();
let min_length = 10usize;
let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
.with_metrics(tx)
.with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
.with_autosave_config(true, min_length);
let content_at_boundary = "A".repeat(min_length);
assert_eq!(content_at_boundary.len(), min_length);
agent
.persist_message(Role::Assistant, &content_at_boundary, &[], false)
.await;
assert_eq!(rx.borrow().sqlite_message_count, 1);
}
#[tokio::test]
async fn persist_message_assistant_one_below_min_length_uses_save_only() {
let provider = mock_provider(vec![]);
let channel = MockChannel::new(vec![]);
let registry = create_test_registry();
let executor = MockToolExecutor::no_tools();
let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
let cid = memory.sqlite().create_conversation().await.unwrap();
let min_length = 10usize;
let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
.with_metrics(tx)
.with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
.with_autosave_config(true, min_length);
let content_below_boundary = "A".repeat(min_length - 1);
assert_eq!(content_below_boundary.len(), min_length - 1);
agent
.persist_message(Role::Assistant, &content_below_boundary, &[], false)
.await;
let history = agent
.memory_state
.memory
.as_ref()
.unwrap()
.sqlite()
.load_history(cid, 50)
.await
.unwrap();
assert_eq!(history.len(), 1, "message must still be saved");
assert_eq!(rx.borrow().embeddings_generated, 0);
}
#[tokio::test]
async fn persist_message_increments_unsummarized_count() {
let provider = mock_provider(vec![]);
let channel = MockChannel::new(vec![]);
let registry = create_test_registry();
let executor = MockToolExecutor::no_tools();
let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
let cid = memory.sqlite().create_conversation().await.unwrap();
let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
std::sync::Arc::new(memory),
cid,
50,
5,
100,
);
assert_eq!(agent.memory_state.unsummarized_count, 0);
agent.persist_message(Role::User, "first", &[], false).await;
assert_eq!(agent.memory_state.unsummarized_count, 1);
agent
.persist_message(Role::User, "second", &[], false)
.await;
assert_eq!(agent.memory_state.unsummarized_count, 2);
}
#[tokio::test]
async fn check_summarization_resets_counter_on_success() {
let provider = mock_provider(vec![]);
let channel = MockChannel::new(vec![]);
let registry = create_test_registry();
let executor = MockToolExecutor::no_tools();
let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
let cid = memory.sqlite().create_conversation().await.unwrap();
let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
std::sync::Arc::new(memory),
cid,
50,
5,
1,
);
agent.persist_message(Role::User, "msg1", &[], false).await;
agent.persist_message(Role::User, "msg2", &[], false).await;
assert!(agent.memory_state.unsummarized_count <= 2);
}
#[tokio::test]
async fn unsummarized_count_not_incremented_without_memory() {
let provider = mock_provider(vec![]);
let channel = MockChannel::new(vec![]);
let registry = create_test_registry();
let executor = MockToolExecutor::no_tools();
let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
agent.persist_message(Role::User, "hello", &[], false).await;
assert_eq!(agent.memory_state.unsummarized_count, 0);
}
mod graph_extraction_guards {
use super::*;
use crate::config::GraphConfig;
use zeph_memory::graph::GraphStore;
fn enabled_graph_config() -> GraphConfig {
GraphConfig {
enabled: true,
..GraphConfig::default()
}
}
async fn agent_with_graph(
provider: &AnyProvider,
config: GraphConfig,
) -> Agent<MockChannel> {
let memory =
test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
let cid = memory.sqlite().create_conversation().await.unwrap();
Agent::new(
provider.clone(),
MockChannel::new(vec![]),
create_test_registry(),
None,
5,
MockToolExecutor::no_tools(),
)
.with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
.with_graph_config(config)
}
#[tokio::test]
async fn injection_flag_guard_skips_extraction() {
let provider = mock_provider(vec![]);
let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
let pool = agent
.memory_state
.memory
.as_ref()
.unwrap()
.sqlite()
.pool()
.clone();
agent.maybe_spawn_graph_extraction("I use Rust", true).await;
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let store = GraphStore::new(pool);
let count = store.get_metadata("extraction_count").await.unwrap();
assert!(
count.is_none(),
"injection flag must prevent extraction_count from being written"
);
}
#[tokio::test]
async fn disabled_config_guard_skips_extraction() {
let provider = mock_provider(vec![]);
let disabled_cfg = GraphConfig {
enabled: false,
..GraphConfig::default()
};
let mut agent = agent_with_graph(&provider, disabled_cfg).await;
let pool = agent
.memory_state
.memory
.as_ref()
.unwrap()
.sqlite()
.pool()
.clone();
agent
.maybe_spawn_graph_extraction("I use Rust", false)
.await;
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let store = GraphStore::new(pool);
let count = store.get_metadata("extraction_count").await.unwrap();
assert!(
count.is_none(),
"disabled graph config must prevent extraction"
);
}
#[tokio::test]
async fn happy_path_fires_extraction() {
let provider = mock_provider(vec![]);
let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
let pool = agent
.memory_state
.memory
.as_ref()
.unwrap()
.sqlite()
.pool()
.clone();
agent
.maybe_spawn_graph_extraction("I use Rust for systems programming", false)
.await;
tokio::time::sleep(std::time::Duration::from_millis(200)).await;
let store = GraphStore::new(pool);
let count = store.get_metadata("extraction_count").await.unwrap();
assert!(
count.is_some(),
"happy-path extraction must increment extraction_count"
);
}
}
#[tokio::test]
async fn persist_message_user_always_embeds_regardless_of_autosave_flag() {
let provider = mock_provider(vec![]);
let channel = MockChannel::new(vec![]);
let registry = create_test_registry();
let executor = MockToolExecutor::no_tools();
let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
let cid = memory.sqlite().create_conversation().await.unwrap();
let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
.with_metrics(tx)
.with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
.with_autosave_config(false, 20);
let long_user_msg = "A".repeat(100);
agent
.persist_message(Role::User, &long_user_msg, &[], false)
.await;
let history = agent
.memory_state
.memory
.as_ref()
.unwrap()
.sqlite()
.load_history(cid, 50)
.await
.unwrap();
assert_eq!(history.len(), 1, "user message must be saved");
assert_eq!(rx.borrow().sqlite_message_count, 1);
}
#[tokio::test]
async fn persist_message_saves_correct_tool_use_parts() {
use zeph_llm::provider::MessagePart;
let provider = mock_provider(vec![]);
let channel = MockChannel::new(vec![]);
let registry = create_test_registry();
let executor = MockToolExecutor::no_tools();
let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
let cid = memory.sqlite().create_conversation().await.unwrap();
let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
std::sync::Arc::new(memory),
cid,
50,
5,
100,
);
let parts = vec![MessagePart::ToolUse {
id: "call_abc123".to_string(),
name: "read_file".to_string(),
input: serde_json::json!({"path": "/tmp/test.txt"}),
}];
let content = "[tool_use: read_file(call_abc123)]";
agent
.persist_message(Role::Assistant, content, &parts, false)
.await;
let history = agent
.memory_state
.memory
.as_ref()
.unwrap()
.sqlite()
.load_history(cid, 50)
.await
.unwrap();
assert_eq!(history.len(), 1);
assert_eq!(history[0].role, Role::Assistant);
assert_eq!(history[0].content, content);
assert_eq!(history[0].parts.len(), 1);
match &history[0].parts[0] {
MessagePart::ToolUse { id, name, .. } => {
assert_eq!(id, "call_abc123");
assert_eq!(name, "read_file");
}
other => panic!("expected ToolUse part, got {other:?}"),
}
assert!(
!history[0]
.parts
.iter()
.any(|p| matches!(p, MessagePart::ToolResult { .. })),
"assistant message must not contain ToolResult parts"
);
}
#[tokio::test]
async fn persist_message_saves_correct_tool_result_parts() {
use zeph_llm::provider::MessagePart;
let provider = mock_provider(vec![]);
let channel = MockChannel::new(vec![]);
let registry = create_test_registry();
let executor = MockToolExecutor::no_tools();
let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
let cid = memory.sqlite().create_conversation().await.unwrap();
let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
std::sync::Arc::new(memory),
cid,
50,
5,
100,
);
let parts = vec![MessagePart::ToolResult {
tool_use_id: "call_abc123".to_string(),
content: "file contents here".to_string(),
is_error: false,
}];
let content = "[tool_result: call_abc123]\nfile contents here";
agent
.persist_message(Role::User, content, &parts, false)
.await;
let history = agent
.memory_state
.memory
.as_ref()
.unwrap()
.sqlite()
.load_history(cid, 50)
.await
.unwrap();
assert_eq!(history.len(), 1);
assert_eq!(history[0].role, Role::User);
assert_eq!(history[0].content, content);
assert_eq!(history[0].parts.len(), 1);
match &history[0].parts[0] {
MessagePart::ToolResult {
tool_use_id,
content: result_content,
is_error,
} => {
assert_eq!(tool_use_id, "call_abc123");
assert_eq!(result_content, "file contents here");
assert!(!is_error);
}
other => panic!("expected ToolResult part, got {other:?}"),
}
assert!(
!history[0]
.parts
.iter()
.any(|p| matches!(p, MessagePart::ToolUse { .. })),
"user ToolResult message must not contain ToolUse parts"
);
}
#[tokio::test]
async fn persist_message_roundtrip_preserves_role_part_alignment() {
use zeph_llm::provider::MessagePart;
let provider = mock_provider(vec![]);
let channel = MockChannel::new(vec![]);
let registry = create_test_registry();
let executor = MockToolExecutor::no_tools();
let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
let cid = memory.sqlite().create_conversation().await.unwrap();
let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
std::sync::Arc::new(memory),
cid,
50,
5,
100,
);
let assistant_parts = vec![MessagePart::ToolUse {
id: "id_1".to_string(),
name: "list_dir".to_string(),
input: serde_json::json!({"path": "/tmp"}),
}];
agent
.persist_message(
Role::Assistant,
"[tool_use: list_dir(id_1)]",
&assistant_parts,
false,
)
.await;
let user_parts = vec![MessagePart::ToolResult {
tool_use_id: "id_1".to_string(),
content: "file1.txt\nfile2.txt".to_string(),
is_error: false,
}];
agent
.persist_message(
Role::User,
"[tool_result: id_1]\nfile1.txt\nfile2.txt",
&user_parts,
false,
)
.await;
let history = agent
.memory_state
.memory
.as_ref()
.unwrap()
.sqlite()
.load_history(cid, 50)
.await
.unwrap();
assert_eq!(history.len(), 2);
assert_eq!(history[0].role, Role::Assistant);
assert_eq!(history[0].content, "[tool_use: list_dir(id_1)]");
assert!(
matches!(&history[0].parts[0], MessagePart::ToolUse { id, .. } if id == "id_1"),
"first message must be assistant ToolUse"
);
assert_eq!(history[1].role, Role::User);
assert_eq!(
history[1].content,
"[tool_result: id_1]\nfile1.txt\nfile2.txt"
);
assert!(
matches!(&history[1].parts[0], MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "id_1"),
"second message must be user ToolResult"
);
assert!(
!history[0]
.parts
.iter()
.any(|p| matches!(p, MessagePart::ToolResult { .. })),
"assistant message must not have ToolResult parts"
);
assert!(
!history[1]
.parts
.iter()
.any(|p| matches!(p, MessagePart::ToolUse { .. })),
"user message must not have ToolUse parts"
);
}
#[tokio::test]
async fn persist_message_saves_correct_tool_output_parts() {
use zeph_llm::provider::MessagePart;
let provider = mock_provider(vec![]);
let channel = MockChannel::new(vec![]);
let registry = create_test_registry();
let executor = MockToolExecutor::no_tools();
let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
let cid = memory.sqlite().create_conversation().await.unwrap();
let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
std::sync::Arc::new(memory),
cid,
50,
5,
100,
);
let parts = vec![MessagePart::ToolOutput {
tool_name: "shell".to_string(),
body: "hello from shell".to_string(),
compacted_at: None,
}];
let content = "[tool: shell]\nhello from shell";
agent
.persist_message(Role::User, content, &parts, false)
.await;
let history = agent
.memory_state
.memory
.as_ref()
.unwrap()
.sqlite()
.load_history(cid, 50)
.await
.unwrap();
assert_eq!(history.len(), 1);
assert_eq!(history[0].role, Role::User);
assert_eq!(history[0].content, content);
assert_eq!(history[0].parts.len(), 1);
match &history[0].parts[0] {
MessagePart::ToolOutput {
tool_name,
body,
compacted_at,
} => {
assert_eq!(tool_name, "shell");
assert_eq!(body, "hello from shell");
assert!(compacted_at.is_none());
}
other => panic!("expected ToolOutput part, got {other:?}"),
}
}
#[tokio::test]
async fn load_history_removes_trailing_orphan_tool_use() {
use zeph_llm::provider::MessagePart;
let provider = mock_provider(vec![]);
let channel = MockChannel::new(vec![]);
let registry = create_test_registry();
let executor = MockToolExecutor::no_tools();
let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
let cid = memory.sqlite().create_conversation().await.unwrap();
let sqlite = memory.sqlite();
sqlite
.save_message(cid, "user", "do something with a tool")
.await
.unwrap();
let parts = serde_json::to_string(&[MessagePart::ToolUse {
id: "call_orphan".to_string(),
name: "shell".to_string(),
input: serde_json::json!({"command": "ls"}),
}])
.unwrap();
sqlite
.save_message_with_parts(cid, "assistant", "[tool_use: shell(call_orphan)]", &parts)
.await
.unwrap();
let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
std::sync::Arc::new(memory),
cid,
50,
5,
100,
);
let messages_before = agent.messages.len();
agent.load_history().await.unwrap();
assert_eq!(
agent.messages.len(),
messages_before + 1,
"orphaned trailing tool_use must be removed"
);
assert_eq!(agent.messages.last().unwrap().role, Role::User);
}
#[tokio::test]
async fn load_history_removes_leading_orphan_tool_result() {
use zeph_llm::provider::MessagePart;
let provider = mock_provider(vec![]);
let channel = MockChannel::new(vec![]);
let registry = create_test_registry();
let executor = MockToolExecutor::no_tools();
let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
let cid = memory.sqlite().create_conversation().await.unwrap();
let sqlite = memory.sqlite();
let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
tool_use_id: "call_missing".to_string(),
content: "result data".to_string(),
is_error: false,
}])
.unwrap();
sqlite
.save_message_with_parts(
cid,
"user",
"[tool_result: call_missing]\nresult data",
&result_parts,
)
.await
.unwrap();
sqlite
.save_message(cid, "assistant", "here is my response")
.await
.unwrap();
let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
std::sync::Arc::new(memory),
cid,
50,
5,
100,
);
let messages_before = agent.messages.len();
agent.load_history().await.unwrap();
assert_eq!(
agent.messages.len(),
messages_before + 1,
"orphaned leading tool_result must be removed"
);
assert_eq!(agent.messages.last().unwrap().role, Role::Assistant);
}
#[tokio::test]
async fn load_history_preserves_complete_tool_pairs() {
    // An assistant ToolUse that is immediately answered by a user ToolResult with the
    // same id is a well-formed pair; restoring history must keep both messages, in order.
    use zeph_llm::provider::MessagePart;

    let llm = mock_provider(vec![]);
    let chan = MockChannel::new(vec![]);
    let tools = create_test_registry();
    let exec = MockToolExecutor::no_tools();
    let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
    let conv = memory.sqlite().create_conversation().await.unwrap();
    let store = memory.sqlite();

    let call = MessagePart::ToolUse {
        id: "call_ok".to_string(),
        name: "shell".to_string(),
        input: serde_json::json!({"command": "pwd"}),
    };
    let call_json = serde_json::to_string(&[call]).unwrap();
    store
        .save_message_with_parts(conv, "assistant", "[tool_use: shell(call_ok)]", &call_json)
        .await
        .unwrap();

    let answer = MessagePart::ToolResult {
        tool_use_id: "call_ok".to_string(),
        content: "/home/user".to_string(),
        is_error: false,
    };
    let answer_json = serde_json::to_string(&[answer]).unwrap();
    store
        .save_message_with_parts(conv, "user", "[tool_result: call_ok]\n/home/user", &answer_json)
        .await
        .unwrap();

    let mut agent = Agent::new(llm, chan, tools, None, 5, exec)
        .with_memory(std::sync::Arc::new(memory), conv, 50, 5, 100);
    let baseline = agent.messages.len();
    agent.load_history().await.unwrap();

    let expected_len = baseline + 2;
    assert_eq!(
        agent.messages.len(),
        expected_len,
        "complete tool_use/tool_result pair must be preserved"
    );
    let (use_msg, result_msg) = (&agent.messages[baseline], &agent.messages[baseline + 1]);
    assert_eq!(use_msg.role, Role::Assistant);
    assert_eq!(result_msg.role, Role::User);
}
#[tokio::test]
async fn load_history_handles_multiple_trailing_orphans() {
    // Two back-to-back assistant ToolUse messages end the history with no results after
    // them; the trailing-orphan pass must peel off both, leaving only the opening user turn.
    use zeph_llm::provider::MessagePart;

    let llm = mock_provider(vec![]);
    let chan = MockChannel::new(vec![]);
    let tools = create_test_registry();
    let exec = MockToolExecutor::no_tools();
    let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
    let conv = memory.sqlite().create_conversation().await.unwrap();
    let store = memory.sqlite();

    store.save_message(conv, "user", "start").await.unwrap();

    // Persist the dangling tool calls in order: call_1 first, then call_2.
    let dangling = [
        ("call_1", "shell", "[tool_use: shell(call_1)]"),
        ("call_2", "read_file", "[tool_use: read_file(call_2)]"),
    ];
    for (call_id, tool_name, label) in dangling {
        let parts_json = serde_json::to_string(&[MessagePart::ToolUse {
            id: call_id.to_string(),
            name: tool_name.to_string(),
            input: serde_json::json!({}),
        }])
        .unwrap();
        store
            .save_message_with_parts(conv, "assistant", label, &parts_json)
            .await
            .unwrap();
    }

    let mut agent = Agent::new(llm, chan, tools, None, 5, exec)
        .with_memory(std::sync::Arc::new(memory), conv, 50, 5, 100);
    let baseline = agent.messages.len();
    agent.load_history().await.unwrap();

    let expected_len = baseline + 1;
    assert_eq!(
        agent.messages.len(),
        expected_len,
        "all trailing orphaned tool_use messages must be removed"
    );
    let tail = agent.messages.last().unwrap();
    assert_eq!(tail.role, Role::User);
}
#[tokio::test]
async fn load_history_no_tool_messages_unchanged() {
    // With no tool parts anywhere, sanitization has nothing to do: every stored
    // message must come back.
    let llm = mock_provider(vec![]);
    let chan = MockChannel::new(vec![]);
    let tools = create_test_registry();
    let exec = MockToolExecutor::no_tools();
    let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
    let conv = memory.sqlite().create_conversation().await.unwrap();
    let store = memory.sqlite();

    let transcript = [
        ("user", "hello"),
        ("assistant", "hi there"),
        ("user", "how are you?"),
    ];
    for (role, text) in transcript {
        store.save_message(conv, role, text).await.unwrap();
    }

    let mut agent = Agent::new(llm, chan, tools, None, 5, exec)
        .with_memory(std::sync::Arc::new(memory), conv, 50, 5, 100);
    let baseline = agent.messages.len();
    agent.load_history().await.unwrap();

    let expected_len = baseline + 3;
    assert_eq!(
        agent.messages.len(),
        expected_len,
        "plain messages without tool parts must pass through unchanged"
    );
}
#[tokio::test]
async fn load_history_removes_both_leading_and_trailing_orphans() {
    // History opens with a ToolResult nobody asked for and closes with a ToolUse nobody
    // answered; both edges must be trimmed, leaving the plain Q/A exchange in between.
    use zeph_llm::provider::MessagePart;

    let llm = mock_provider(vec![]);
    let chan = MockChannel::new(vec![]);
    let tools = create_test_registry();
    let exec = MockToolExecutor::no_tools();
    let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
    let conv = memory.sqlite().create_conversation().await.unwrap();
    let store = memory.sqlite();

    // Leading orphan: a tool result with no preceding tool call.
    let stray_result = MessagePart::ToolResult {
        tool_use_id: "call_leading".to_string(),
        content: "orphaned result".to_string(),
        is_error: false,
    };
    let stray_result_json = serde_json::to_string(&[stray_result]).unwrap();
    store
        .save_message_with_parts(
            conv,
            "user",
            "[tool_result: call_leading]\norphaned result",
            &stray_result_json,
        )
        .await
        .unwrap();

    // The exchange that must survive.
    store.save_message(conv, "user", "what is 2+2?").await.unwrap();
    store.save_message(conv, "assistant", "4").await.unwrap();

    // Trailing orphan: a tool call with no following result.
    let stray_call = MessagePart::ToolUse {
        id: "call_trailing".to_string(),
        name: "shell".to_string(),
        input: serde_json::json!({"command": "date"}),
    };
    let stray_call_json = serde_json::to_string(&[stray_call]).unwrap();
    store
        .save_message_with_parts(
            conv,
            "assistant",
            "[tool_use: shell(call_trailing)]",
            &stray_call_json,
        )
        .await
        .unwrap();

    let mut agent = Agent::new(llm, chan, tools, None, 5, exec)
        .with_memory(std::sync::Arc::new(memory), conv, 50, 5, 100);
    let baseline = agent.messages.len();
    agent.load_history().await.unwrap();

    let expected_len = baseline + 2;
    assert_eq!(
        agent.messages.len(),
        expected_len,
        "both leading and trailing orphans must be removed"
    );
    let question = &agent.messages[baseline];
    assert_eq!(question.role, Role::User);
    assert_eq!(question.content, "what is 2+2?");
    let reply = &agent.messages[baseline + 1];
    assert_eq!(reply.role, Role::Assistant);
    assert_eq!(reply.content, "4");
}
#[tokio::test]
async fn sanitize_tool_pairs_strips_mid_history_orphan_tool_use() {
    // A mid-history assistant message mixes an unanswered ToolUse with plain text.
    // The ToolUse part must be dropped, but the message itself (and its text) stays.
    use zeph_llm::provider::MessagePart;

    let llm = mock_provider(vec![]);
    let chan = MockChannel::new(vec![]);
    let tools = create_test_registry();
    let exec = MockToolExecutor::no_tools();
    let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
    let conv = memory.sqlite().create_conversation().await.unwrap();
    let store = memory.sqlite();

    store.save_message(conv, "user", "first question").await.unwrap();
    store.save_message(conv, "assistant", "first answer").await.unwrap();

    let mixed_parts = vec![
        MessagePart::ToolUse {
            id: "call_mid_1".to_string(),
            name: "shell".to_string(),
            input: serde_json::json!({"command": "ls"}),
        },
        MessagePart::Text {
            text: "Let me check the files.".to_string(),
        },
    ];
    let mixed_json = serde_json::to_string(&mixed_parts[..]).unwrap();
    store
        .save_message_with_parts(conv, "assistant", "Let me check the files.", &mixed_json)
        .await
        .unwrap();

    store.save_message(conv, "user", "second question").await.unwrap();
    store.save_message(conv, "assistant", "second answer").await.unwrap();

    let mut agent = Agent::new(llm, chan, tools, None, 5, exec)
        .with_memory(std::sync::Arc::new(memory), conv, 50, 5, 100);
    let baseline = agent.messages.len();
    agent.load_history().await.unwrap();

    let expected_len = baseline + 5;
    assert_eq!(
        agent.messages.len(),
        expected_len,
        "message count must be 5 (orphan message kept — has text content)"
    );

    // Third restored message is the one that originally carried the orphan ToolUse.
    let stripped = &agent.messages[baseline + 2];
    assert_eq!(stripped.role, Role::Assistant);

    let still_has_tool_use = stripped
        .parts
        .iter()
        .any(|part| matches!(part, MessagePart::ToolUse { .. }));
    assert!(
        !still_has_tool_use,
        "orphaned ToolUse parts must be stripped from mid-history message"
    );

    let kept_text = stripped.parts.iter().any(
        |part| matches!(part, MessagePart::Text { text } if text == "Let me check the files."),
    );
    assert!(
        kept_text,
        "text content of orphaned assistant message must be preserved"
    );
}
#[tokio::test]
async fn load_history_keeps_tool_only_user_message() {
    // A user message whose text content is empty but which carries a matched ToolResult
    // part must not be discarded as "empty" during restoration.
    use zeph_llm::provider::MessagePart;

    let llm = mock_provider(vec![]);
    let chan = MockChannel::new(vec![]);
    let tools = create_test_registry();
    let exec = MockToolExecutor::no_tools();
    let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
    let conv = memory.sqlite().create_conversation().await.unwrap();
    let store = memory.sqlite();

    let call = MessagePart::ToolUse {
        id: "call_rc3".to_string(),
        name: "memory_save".to_string(),
        input: serde_json::json!({"content": "something"}),
    };
    let call_json = serde_json::to_string(&[call]).unwrap();
    store
        .save_message_with_parts(conv, "assistant", "[tool_use: memory_save]", &call_json)
        .await
        .unwrap();

    let answer = MessagePart::ToolResult {
        tool_use_id: "call_rc3".to_string(),
        content: "saved".to_string(),
        is_error: false,
    };
    let answer_json = serde_json::to_string(&[answer]).unwrap();
    // Deliberately empty text content: only the parts carry the result.
    store
        .save_message_with_parts(conv, "user", "", &answer_json)
        .await
        .unwrap();

    store.save_message(conv, "assistant", "done").await.unwrap();

    let mut agent = Agent::new(llm, chan, tools, None, 5, exec)
        .with_memory(std::sync::Arc::new(memory), conv, 50, 5, 100);
    let baseline = agent.messages.len();
    agent.load_history().await.unwrap();

    let expected_len = baseline + 3;
    assert_eq!(
        agent.messages.len(),
        expected_len,
        "user message with empty content but ToolResult parts must not be dropped"
    );

    let result_msg = &agent.messages[baseline + 1];
    assert_eq!(result_msg.role, Role::User);
    let carries_result = result_msg.parts.iter().any(|part| {
        matches!(part, MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_rc3")
    });
    assert!(
        carries_result,
        "ToolResult part must be preserved on user message with empty content"
    );
}
#[tokio::test]
async fn strip_orphans_removes_orphaned_tool_result() {
use zeph_llm::provider::MessagePart;
let provider = mock_provider(vec![]);
let channel = MockChannel::new(vec![]);
let registry = create_test_registry();
let executor = MockToolExecutor::no_tools();
let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
let cid = memory.sqlite().create_conversation().await.unwrap();
let sqlite = memory.sqlite();
sqlite.save_message(cid, "user", "hello").await.unwrap();
sqlite.save_message(cid, "assistant", "hi").await.unwrap();
sqlite
.save_message(cid, "assistant", "plain answer")
.await
.unwrap();
let orphan_result_parts = serde_json::to_string(&[MessagePart::ToolResult {
tool_use_id: "call_nonexistent".to_string(),
content: "stale result".to_string(),
is_error: false,
}])
.unwrap();
sqlite
.save_message_with_parts(
cid,
"user",
"[tool_result: call_nonexistent]\nstale result",
&orphan_result_parts,
)
.await
.unwrap();
sqlite
.save_message(cid, "assistant", "final")
.await
.unwrap();
let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
std::sync::Arc::new(memory),
cid,
50,
5,
100,
);
let messages_before = agent.messages.len();
agent.load_history().await.unwrap();
let loaded = &agent.messages[messages_before..];
for msg in loaded {
assert!(
!msg.parts.iter().any(|p| matches!(
p,
MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_nonexistent"
)),
"orphaned ToolResult part must be stripped from history"
);
}
}
#[tokio::test]
async fn strip_orphans_keeps_complete_pair() {
    // The orphan stripper must leave a ToolResult alone when the directly preceding
    // assistant message issued the matching ToolUse.
    use zeph_llm::provider::MessagePart;

    let llm = mock_provider(vec![]);
    let chan = MockChannel::new(vec![]);
    let tools = create_test_registry();
    let exec = MockToolExecutor::no_tools();
    let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
    let conv = memory.sqlite().create_conversation().await.unwrap();
    let store = memory.sqlite();

    let call = MessagePart::ToolUse {
        id: "call_valid".to_string(),
        name: "shell".to_string(),
        input: serde_json::json!({"command": "ls"}),
    };
    let call_json = serde_json::to_string(&[call]).unwrap();
    store
        .save_message_with_parts(conv, "assistant", "[tool_use: shell]", &call_json)
        .await
        .unwrap();

    let answer = MessagePart::ToolResult {
        tool_use_id: "call_valid".to_string(),
        content: "file.rs".to_string(),
        is_error: false,
    };
    let answer_json = serde_json::to_string(&[answer]).unwrap();
    store
        .save_message_with_parts(conv, "user", "", &answer_json)
        .await
        .unwrap();

    let mut agent = Agent::new(llm, chan, tools, None, 5, exec)
        .with_memory(std::sync::Arc::new(memory), conv, 50, 5, 100);
    let baseline = agent.messages.len();
    agent.load_history().await.unwrap();

    let expected_len = baseline + 2;
    assert_eq!(
        agent.messages.len(),
        expected_len,
        "complete tool_use/tool_result pair must be preserved"
    );

    let result_msg = &agent.messages[baseline + 1];
    let kept = result_msg.parts.iter().any(|part| {
        matches!(part, MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_valid")
    });
    assert!(
        kept,
        "ToolResult part for a matched tool_use must not be stripped"
    );
}
#[tokio::test]
async fn strip_orphans_mixed_history() {
    // One history holds both a matched pair (call_good) and, later, an unmatched
    // ToolResult (call_ghost). Only the ghost may be stripped; the matched result stays.
    use zeph_llm::provider::MessagePart;

    let llm = mock_provider(vec![]);
    let chan = MockChannel::new(vec![]);
    let tools = create_test_registry();
    let exec = MockToolExecutor::no_tools();
    let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
    let conv = memory.sqlite().create_conversation().await.unwrap();
    let store = memory.sqlite();

    // Matched pair.
    let good_call_json = serde_json::to_string(&[MessagePart::ToolUse {
        id: "call_good".to_string(),
        name: "shell".to_string(),
        input: serde_json::json!({"command": "pwd"}),
    }])
    .unwrap();
    store
        .save_message_with_parts(conv, "assistant", "[tool_use: shell]", &good_call_json)
        .await
        .unwrap();
    let good_result_json = serde_json::to_string(&[MessagePart::ToolResult {
        tool_use_id: "call_good".to_string(),
        content: "/home".to_string(),
        is_error: false,
    }])
    .unwrap();
    store
        .save_message_with_parts(conv, "user", "", &good_result_json)
        .await
        .unwrap();

    store.save_message(conv, "assistant", "text only").await.unwrap();

    // Unmatched result: the assistant turn right before it has no ToolUse.
    let ghost_json = serde_json::to_string(&[MessagePart::ToolResult {
        tool_use_id: "call_ghost".to_string(),
        content: "ghost result".to_string(),
        is_error: false,
    }])
    .unwrap();
    store
        .save_message_with_parts(
            conv,
            "user",
            "[tool_result: call_ghost]\nghost result",
            &ghost_json,
        )
        .await
        .unwrap();

    store.save_message(conv, "assistant", "final reply").await.unwrap();

    let mut agent = Agent::new(llm, chan, tools, None, 5, exec)
        .with_memory(std::sync::Arc::new(memory), conv, 50, 5, 100);
    let baseline = agent.messages.len();
    agent.load_history().await.unwrap();

    let restored = &agent.messages[baseline..];
    let carries_result = |msg: &zeph_llm::provider::Message, wanted: &str| {
        msg.parts.iter().any(|part| {
            matches!(part, MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == wanted)
        })
    };

    for msg in restored {
        assert!(
            !carries_result(msg, "call_ghost"),
            "orphaned ToolResult (call_ghost) must be stripped from history"
        );
    }

    let good_survived = restored
        .iter()
        .any(|msg| msg.role == Role::User && carries_result(msg, "call_good"));
    assert!(
        good_survived,
        "matched ToolResult (call_good) must be preserved in history"
    );
}
#[tokio::test]
async fn sanitize_tool_pairs_preserves_matched_tool_pair() {
    // A matched ToolUse/ToolResult pair sitting in the middle of history must survive
    // sanitization, and the assistant message must keep its ToolUse part.
    use zeph_llm::provider::MessagePart;

    let llm = mock_provider(vec![]);
    let chan = MockChannel::new(vec![]);
    let tools = create_test_registry();
    let exec = MockToolExecutor::no_tools();
    let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
    let conv = memory.sqlite().create_conversation().await.unwrap();
    let store = memory.sqlite();

    store.save_message(conv, "user", "run a command").await.unwrap();

    let call_json = serde_json::to_string(&[MessagePart::ToolUse {
        id: "call_ok".to_string(),
        name: "shell".to_string(),
        input: serde_json::json!({"command": "echo hi"}),
    }])
    .unwrap();
    store
        .save_message_with_parts(conv, "assistant", "[tool_use: shell]", &call_json)
        .await
        .unwrap();

    let answer_json = serde_json::to_string(&[MessagePart::ToolResult {
        tool_use_id: "call_ok".to_string(),
        content: "hi".to_string(),
        is_error: false,
    }])
    .unwrap();
    store
        .save_message_with_parts(conv, "user", "[tool_result: call_ok]\nhi", &answer_json)
        .await
        .unwrap();

    store.save_message(conv, "assistant", "done").await.unwrap();

    let mut agent = Agent::new(llm, chan, tools, None, 5, exec)
        .with_memory(std::sync::Arc::new(memory), conv, 50, 5, 100);
    let baseline = agent.messages.len();
    agent.load_history().await.unwrap();

    let expected_len = baseline + 4;
    assert_eq!(
        agent.messages.len(),
        expected_len,
        "matched tool pair must not be removed"
    );

    // Second restored message is the assistant turn that issued the tool call.
    let caller = &agent.messages[baseline + 1];
    let has_call = caller
        .parts
        .iter()
        .any(|part| matches!(part, MessagePart::ToolUse { id, .. } if id == "call_ok"));
    assert!(has_call, "matched ToolUse parts must be preserved");
}
#[tokio::test]
async fn persist_cancelled_tool_results_pairs_tool_use() {
    // Cancelling a batch of tool calls must persist one user message holding an
    // is_error=true ToolResult "tombstone" for every cancelled call id.
    use zeph_llm::provider::MessagePart;

    let llm = mock_provider(vec![]);
    let chan = MockChannel::new(vec![]);
    let tools = create_test_registry();
    let exec = MockToolExecutor::no_tools();
    let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
    let conv = memory.sqlite().create_conversation().await.unwrap();

    let mut agent = Agent::new(llm, chan, tools, None, 5, exec)
        .with_memory(std::sync::Arc::new(memory), conv, 50, 5, 100);

    let cancelled: Vec<zeph_llm::provider::ToolUseRequest> =
        [("cancel_id_1", "shell"), ("cancel_id_2", "read_file")]
            .into_iter()
            .map(|(call_id, tool_name)| zeph_llm::provider::ToolUseRequest {
                id: call_id.to_string(),
                name: tool_name.to_string(),
                input: serde_json::json!({}),
            })
            .collect();

    agent.persist_cancelled_tool_results(&cancelled).await;

    let store = agent.memory_state.memory.as_ref().unwrap().sqlite();
    let history = store.load_history(conv, 50).await.unwrap();

    assert_eq!(history.len(), 1);
    let tombstones = &history[0];
    assert_eq!(tombstones.role, Role::User);

    for call in &cancelled {
        let found = tombstones.parts.iter().any(|part| {
            matches!(
                part,
                MessagePart::ToolResult { tool_use_id, is_error, .. }
                    if tool_use_id == &call.id && *is_error
            )
        });
        assert!(
            found,
            "tombstone ToolResult for {} must be present and is_error=true",
            call.id
        );
    }
}
}