use crate::context::RunContext;
use async_trait::async_trait;
use serdes_ai_core::{ModelRequest, ModelRequestPart, ModelResponsePart};
use std::collections::HashSet;
use std::marker::PhantomData;
#[cfg(feature = "tracing-integration")]
use tracing::debug;
// No-op fallback for the `debug!` macro when the `tracing-integration`
// feature is disabled: swallows all arguments so call sites below compile
// unchanged with zero runtime cost.
#[cfg(not(feature = "tracing-integration"))]
macro_rules! debug {
    ($($arg:tt)*) => {};
}
/// Returns the `tool_call_id` of every `ToolCall` and `BuiltinToolCall`
/// found inside the `ModelResponse` parts of `message`, in order of
/// appearance. Calls without an id are skipped.
fn extract_tool_use_ids(message: &ModelRequest) -> Vec<String> {
    message
        .parts
        .iter()
        // Only ModelResponse parts can contain tool calls.
        .filter_map(|part| match part {
            ModelRequestPart::ModelResponse(response) => Some(&response.parts),
            _ => None,
        })
        .flatten()
        // Both call kinds carry an optional id; clone the present ones.
        .filter_map(|response_part| match response_part {
            ModelResponsePart::ToolCall(tc) => tc.tool_call_id.clone(),
            ModelResponsePart::BuiltinToolCall(btc) => btc.tool_call_id.clone(),
            _ => None,
        })
        .collect()
}
/// Returns the `tool_call_id` of every tool-result-bearing part in
/// `message`: `ToolReturn`, `BuiltinToolReturn`, and `RetryPrompt`.
/// Optional ids that are absent are skipped; `BuiltinToolReturn` ids are
/// plain `String`s and are always included.
fn extract_tool_result_ids(message: &ModelRequest) -> Vec<String> {
    message
        .parts
        .iter()
        .filter_map(|part| match part {
            ModelRequestPart::ToolReturn(tr) => tr.tool_call_id.clone(),
            ModelRequestPart::BuiltinToolReturn(btr) => Some(btr.tool_call_id.clone()),
            ModelRequestPart::RetryPrompt(rp) => rp.tool_call_id.clone(),
            _ => None,
        })
        .collect()
}
fn collect_all_tool_use_ids(messages: &[ModelRequest]) -> HashSet<String> {
messages.iter().flat_map(extract_tool_use_ids).collect()
}
fn collect_all_tool_result_ids(messages: &[ModelRequest]) -> HashSet<String> {
messages.iter().flat_map(extract_tool_result_ids).collect()
}
/// Drops tool-result parts whose `tool_call_id` has no matching tool call
/// in `valid_tool_ids`, then discards any message left with no parts.
///
/// Keep rules per part kind:
/// * `ToolReturn` / `RetryPrompt`: kept when the id matches, or when the
///   part has no id at all (a `None` id cannot be orphaned).
/// * `BuiltinToolReturn`: its id is a plain `String`, so an empty id is
///   treated as orphaned in addition to unmatched ids.
///
/// (The local flag was previously named `dominated` in two arms and
/// `keep` in the third; it is now uniformly `keep`, which is what the
/// `retain` return value actually means.)
fn remove_orphaned_tool_results(
    messages: Vec<ModelRequest>,
    valid_tool_ids: &HashSet<String>,
) -> Vec<ModelRequest> {
    messages
        .into_iter()
        .filter_map(|mut msg| {
            msg.parts.retain(|part| {
                match part {
                    ModelRequestPart::ToolReturn(tr) => {
                        // None id => keep; otherwise require a matching call.
                        let keep = tr
                            .tool_call_id
                            .as_ref()
                            .map_or(true, |id| valid_tool_ids.contains(id));
                        if !keep {
                            debug!(
                                tool_name = %tr.tool_name,
                                tool_call_id = ?tr.tool_call_id,
                                "Removing orphaned ToolReturn: no matching tool_use found"
                            );
                        }
                        keep
                    }
                    ModelRequestPart::BuiltinToolReturn(btr) => {
                        // String id: empty counts as orphaned too.
                        let keep = !btr.tool_call_id.is_empty()
                            && valid_tool_ids.contains(&btr.tool_call_id);
                        if !keep {
                            debug!(
                                tool_name = %btr.tool_name,
                                tool_call_id = %btr.tool_call_id,
                                "Removing orphaned BuiltinToolReturn: no matching tool_use found"
                            );
                        }
                        keep
                    }
                    ModelRequestPart::RetryPrompt(rp) => {
                        let keep = rp
                            .tool_call_id
                            .as_ref()
                            .map_or(true, |id| valid_tool_ids.contains(id));
                        if !keep {
                            debug!(
                                tool_name = ?rp.tool_name,
                                tool_call_id = ?rp.tool_call_id,
                                "Removing orphaned RetryPrompt: no matching tool_use found"
                            );
                        }
                        keep
                    }
                    // All other part kinds are unaffected.
                    _ => true,
                }
            });
            // Messages emptied by the filtering are removed entirely.
            if msg.parts.is_empty() {
                None
            } else {
                Some(msg)
            }
        })
        .collect()
}
/// Inverse of `remove_orphaned_tool_results`: drops `ToolCall` /
/// `BuiltinToolCall` parts whose id has no matching tool result in
/// `valid_result_ids`. A `ModelResponse` left with no parts is removed,
/// and a message left with no parts is removed as well.
///
/// Keep rules:
/// * `ToolCall` with a `None` id is kept (cannot be orphaned).
/// * `BuiltinToolCall` with a `None` id is kept; a `Some` id must be
///   non-empty AND present in `valid_result_ids`.
/// * Non-tool-call response parts (text, thinking, files) always survive.
fn remove_orphaned_tool_uses(
    messages: Vec<ModelRequest>,
    valid_result_ids: &HashSet<String>,
) -> Vec<ModelRequest> {
    messages
        .into_iter()
        .filter_map(|mut msg| {
            // Rebuild the part list so empty ModelResponses can be dropped.
            msg.parts = msg.parts
                .into_iter()
                .filter_map(|part| {
                    match part {
                        ModelRequestPart::ModelResponse(mut response) => {
                            response.parts.retain(|response_part| {
                                match response_part {
                                    ModelResponsePart::ToolCall(tc) => {
                                        let keep = tc.tool_call_id
                                            .as_ref()
                                            .map_or(true, |id| valid_result_ids.contains(id));
                                        if !keep {
                                            debug!(
                                                tool_name = %tc.tool_name,
                                                tool_call_id = ?tc.tool_call_id,
                                                "Removing orphaned ToolCall: no matching tool_result found"
                                            );
                                        }
                                        keep
                                    }
                                    ModelResponsePart::BuiltinToolCall(btc) => {
                                        // Empty string ids are orphaned by definition.
                                        let keep = btc.tool_call_id
                                            .as_ref()
                                            .map_or(true, |id| !id.is_empty() && valid_result_ids.contains(id));
                                        if !keep {
                                            debug!(
                                                tool_name = %btc.tool_name,
                                                tool_call_id = ?btc.tool_call_id,
                                                "Removing orphaned BuiltinToolCall: no matching tool_result found"
                                            );
                                        }
                                        keep
                                    }
                                    _ => true,
                                }
                            });
                            // A response stripped of all its parts is dropped.
                            if response.parts.is_empty() {
                                debug!("Removing empty ModelResponse after filtering orphaned tool calls");
                                None
                            } else {
                                Some(ModelRequestPart::ModelResponse(response))
                            }
                        }
                        // Non-response parts pass through untouched.
                        other => Some(other),
                    }
                })
                .collect();
            // Drop messages emptied by the filtering.
            if msg.parts.is_empty() {
                None
            } else {
                Some(msg)
            }
        })
        .collect()
}
/// Transforms the message history before it is sent to the model.
/// Implementations receive the full run context and the current history,
/// and return the (typically reduced) history to use instead.
#[async_trait]
pub trait HistoryProcessor<Deps>: Send + Sync {
    /// Processes `messages` and returns the history to forward.
    async fn process(
        &self,
        ctx: &RunContext<Deps>,
        messages: Vec<ModelRequest>,
    ) -> Vec<ModelRequest>;
}
/// Caps the history at a fixed number of messages, optionally pinning
/// the very first message (often the system/priming turn).
#[derive(Debug, Clone)]
pub struct TruncateHistory {
    max_messages: usize,
    keep_first: bool,
}
impl TruncateHistory {
    /// Creates a truncator limited to `max_messages`; the first message
    /// is retained by default.
    pub fn new(max_messages: usize) -> Self {
        TruncateHistory {
            keep_first: true,
            max_messages,
        }
    }
    /// Sets whether the first message is always retained during truncation.
    pub fn keep_first(self, keep: bool) -> Self {
        TruncateHistory {
            keep_first: keep,
            ..self
        }
    }
}
#[async_trait]
impl<Deps: Send + Sync> HistoryProcessor<Deps> for TruncateHistory {
    /// Keeps at most `max_messages` messages (optionally pinning the
    /// first), then strips tool calls/results orphaned by the cut.
    async fn process(
        &self,
        _ctx: &RunContext<Deps>,
        mut messages: Vec<ModelRequest>,
    ) -> Vec<ModelRequest> {
        // Already within budget: nothing to do.
        if messages.len() <= self.max_messages {
            return messages;
        }
        let truncated = if self.keep_first && !messages.is_empty() {
            // Pin the first message, then fill the remaining budget with
            // the most recent messages.
            let first = messages.remove(0);
            let budget = self.max_messages.saturating_sub(1);
            let split_at = messages.len().saturating_sub(budget);
            let mut kept = vec![first];
            kept.append(&mut messages.split_off(split_at));
            kept
        } else {
            let split_at = messages.len().saturating_sub(self.max_messages);
            messages.split_off(split_at)
        };
        // Truncation can separate tool calls from their results; remove
        // both kinds of orphans so the history stays provider-valid.
        let valid_tool_use_ids = collect_all_tool_use_ids(&truncated);
        let truncated = remove_orphaned_tool_results(truncated, &valid_tool_use_ids);
        let valid_tool_result_ids = collect_all_tool_result_ids(&truncated);
        remove_orphaned_tool_uses(truncated, &valid_tool_result_ids)
    }
}
/// Caps the history by an estimated token budget rather than a message count.
#[derive(Debug, Clone)]
pub struct TruncateByTokens {
    // Total token budget for the kept history.
    max_tokens: u64,
    // Heuristic characters-per-token ratio used by `estimate_tokens`.
    chars_per_token: f64,
    // Number of leading messages always kept, even over budget.
    keep_first_n: usize,
}
impl TruncateByTokens {
    /// Creates a truncator with a ~4 characters/token heuristic and the
    /// first two messages pinned.
    pub fn new(max_tokens: u64) -> Self {
        Self {
            max_tokens,
            chars_per_token: 4.0,
            keep_first_n: 2,
        }
    }
    /// Overrides the characters-per-token estimation ratio.
    pub fn chars_per_token(mut self, ratio: f64) -> Self {
        self.chars_per_token = ratio;
        self
    }
    /// Sets how many leading messages are always kept.
    pub fn keep_first_n(mut self, n: usize) -> Self {
        self.keep_first_n = n;
        self
    }
    /// Backwards-compatible boolean form of `keep_first_n`: `true` keeps
    /// one leading message, `false` keeps none.
    pub fn keep_first(mut self, keep: bool) -> Self {
        self.keep_first_n = if keep { 1 } else { 0 };
        self
    }
    /// Estimates a message's token count from its character length using
    /// `chars_per_token`. Non-text content uses fixed placeholder costs
    /// (100 chars, or 50 for unserializable tool-call args); this is a
    /// rough heuristic, not a tokenizer.
    fn estimate_tokens(&self, message: &ModelRequest) -> u64 {
        let chars: usize = message
            .parts
            .iter()
            .map(|p| {
                match p {
                    serdes_ai_core::ModelRequestPart::SystemPrompt(s) => s.content.len(),
                    serdes_ai_core::ModelRequestPart::UserPrompt(u) => {
                        match &u.content {
                            serdes_ai_core::messages::UserContent::Text(t) => t.len(),
                            serdes_ai_core::messages::UserContent::Parts(parts) => {
                                parts
                                    .iter()
                                    .map(|p| {
                                        match p {
                                            serdes_ai_core::messages::UserContentPart::Text {
                                                text,
                                            } => text.len(),
                                            // Non-text user parts (images, etc.):
                                            // flat 100-char placeholder.
                                            _ => 100,
                                        }
                                    })
                                    .sum()
                            }
                        }
                    }
                    serdes_ai_core::ModelRequestPart::ToolReturn(t) => {
                        t.content.to_string_content().len()
                    }
                    serdes_ai_core::ModelRequestPart::RetryPrompt(r) => r.content.message().len(),
                    serdes_ai_core::ModelRequestPart::BuiltinToolReturn(b) => {
                        // Content body isn't stringified; count type name + 100.
                        b.content_type().len() + 100
                    }
                    serdes_ai_core::ModelRequestPart::ModelResponse(r) => {
                        r.parts
                            .iter()
                            .map(|p| match p {
                                serdes_ai_core::ModelResponsePart::Text(t) => t.content.len(),
                                serdes_ai_core::ModelResponsePart::ToolCall(tc) => {
                                    tc.tool_name.len()
                                        + tc.args.to_json_string().map(|s| s.len()).unwrap_or(50)
                                }
                                serdes_ai_core::ModelResponsePart::Thinking(t) => t.content.len(),
                                serdes_ai_core::ModelResponsePart::File(_) => 100,
                                serdes_ai_core::ModelResponsePart::BuiltinToolCall(_) => 100,
                            })
                            .sum::<usize>()
                    }
                }
            })
            .sum();
        // Round up so short messages still cost at least one token.
        (chars as f64 / self.chars_per_token).ceil() as u64
    }
}
#[async_trait]
impl<Deps: Send + Sync> HistoryProcessor<Deps> for TruncateByTokens {
    /// Pins the first `keep_first_n` messages (even over budget), then
    /// fills the remaining token budget from the newest messages
    /// backwards, and finally strips orphaned tool calls/results.
    async fn process(
        &self,
        _ctx: &RunContext<Deps>,
        messages: Vec<ModelRequest>,
    ) -> Vec<ModelRequest> {
        if messages.is_empty() {
            return messages;
        }
        let pinned = self.keep_first_n.min(messages.len());
        let mut kept: Vec<ModelRequest> = Vec::new();
        let mut budget_used = 0u64;
        // Leading messages are kept unconditionally.
        for msg in &messages[..pinned] {
            budget_used += self.estimate_tokens(msg);
            kept.push(msg.clone());
        }
        // Walk the tail newest-first, taking messages while they fit.
        let mut recent: Vec<ModelRequest> = Vec::new();
        for msg in messages[pinned..].iter().rev() {
            let cost = self.estimate_tokens(msg);
            if budget_used + cost > self.max_tokens {
                break;
            }
            budget_used += cost;
            recent.push(msg.clone());
        }
        // Restore chronological order for the tail.
        kept.extend(recent.into_iter().rev());
        // Remove tool calls/results that lost their counterpart.
        let valid_tool_use_ids = collect_all_tool_use_ids(&kept);
        let kept = remove_orphaned_tool_results(kept, &valid_tool_use_ids);
        let valid_tool_result_ids = collect_all_tool_result_ids(&kept);
        remove_orphaned_tool_uses(kept, &valid_tool_result_ids)
    }
}
/// Removes selected part kinds (system prompts, tool returns, retry
/// prompts) from every message in the history.
#[derive(Debug, Clone)]
pub struct FilterHistory {
    remove_system: bool,
    remove_tool_returns: bool,
    remove_retries: bool,
}
impl FilterHistory {
    /// Creates a filter that removes nothing.
    pub fn new() -> Self {
        FilterHistory {
            remove_system: false,
            remove_tool_returns: false,
            remove_retries: false,
        }
    }
    /// Enables/disables dropping of `SystemPrompt` parts.
    pub fn remove_system(self, remove: bool) -> Self {
        FilterHistory {
            remove_system: remove,
            ..self
        }
    }
    /// Enables/disables dropping of `ToolReturn` parts.
    pub fn remove_tool_returns(self, remove: bool) -> Self {
        FilterHistory {
            remove_tool_returns: remove,
            ..self
        }
    }
    /// Enables/disables dropping of `RetryPrompt` parts.
    pub fn remove_retries(self, remove: bool) -> Self {
        FilterHistory {
            remove_retries: remove,
            ..self
        }
    }
}
impl Default for FilterHistory {
    fn default() -> Self {
        FilterHistory::new()
    }
}
#[async_trait]
impl<Deps: Send + Sync> HistoryProcessor<Deps> for FilterHistory {
async fn process(
&self,
_ctx: &RunContext<Deps>,
messages: Vec<ModelRequest>,
) -> Vec<ModelRequest> {
messages
.into_iter()
.map(|mut msg| {
msg.parts.retain(|part| {
use serdes_ai_core::ModelRequestPart::*;
match part {
SystemPrompt(_) => !self.remove_system,
ToolReturn(_) => !self.remove_tool_returns,
RetryPrompt(_) => !self.remove_retries,
_ => true,
}
});
msg
})
.filter(|msg| !msg.parts.is_empty())
.collect()
}
}
/// Composes multiple [`HistoryProcessor`]s, applied in insertion order.
pub struct ChainedProcessor<Deps> {
    processors: Vec<Box<dyn HistoryProcessor<Deps>>>,
}
impl<Deps: Send + Sync + 'static> ChainedProcessor<Deps> {
    /// Creates a chain with no stages.
    pub fn new() -> Self {
        ChainedProcessor { processors: vec![] }
    }
    /// Appends `processor` as the final stage and returns the chain.
    #[allow(clippy::should_implement_trait)]
    pub fn add<P: HistoryProcessor<Deps> + 'static>(mut self, processor: P) -> Self {
        self.processors.push(Box::new(processor));
        self
    }
}
impl<Deps: Send + Sync + 'static> Default for ChainedProcessor<Deps> {
    fn default() -> Self {
        ChainedProcessor::new()
    }
}
#[async_trait]
impl<Deps: Send + Sync> HistoryProcessor<Deps> for ChainedProcessor<Deps> {
    /// Feeds the history through every stage, in the order they were added;
    /// each stage sees the previous stage's output.
    async fn process(
        &self,
        ctx: &RunContext<Deps>,
        mut messages: Vec<ModelRequest>,
    ) -> Vec<ModelRequest> {
        for stage in self.processors.iter() {
            messages = stage.process(ctx, messages).await;
        }
        messages
    }
}
/// Placeholder processor that currently keeps only the `keep_recent`
/// newest messages. Despite the name, no model-generated summary is
/// produced yet; older messages are simply dropped.
#[derive(Debug, Clone)]
pub struct SummarizeHistory {
    // Number of trailing messages to keep verbatim.
    keep_recent: usize,
    // Intended token threshold that should eventually trigger
    // summarization; unused until that is implemented (hence dead_code).
    #[allow(dead_code)]
    threshold_tokens: u64,
}
impl SummarizeHistory {
    /// Creates the processor; `threshold_tokens` is stored but not yet used.
    pub fn new(keep_recent: usize, threshold_tokens: u64) -> Self {
        Self {
            keep_recent,
            threshold_tokens,
        }
    }
}
#[async_trait]
impl<Deps: Send + Sync> HistoryProcessor<Deps> for SummarizeHistory {
async fn process(
&self,
_ctx: &RunContext<Deps>,
messages: Vec<ModelRequest>,
) -> Vec<ModelRequest> {
if messages.len() <= self.keep_recent {
return messages;
}
let start = messages.len().saturating_sub(self.keep_recent);
messages[start..].to_vec()
}
}
/// Adapts a plain synchronous closure into a [`HistoryProcessor`].
pub struct FnProcessor<F, Deps>
where
    F: Fn(&RunContext<Deps>, Vec<ModelRequest>) -> Vec<ModelRequest> + Send + Sync,
{
    // The wrapped transformation.
    func: F,
    // Marks the `Deps` type parameter as logically owned by this type.
    _phantom: PhantomData<Deps>,
}
impl<F, Deps> FnProcessor<F, Deps>
where
    F: Fn(&RunContext<Deps>, Vec<ModelRequest>) -> Vec<ModelRequest> + Send + Sync,
{
    /// Wraps `func` so it can be used wherever a processor is expected.
    pub fn new(func: F) -> Self {
        Self {
            func,
            _phantom: PhantomData,
        }
    }
}
#[async_trait]
impl<F, Deps> HistoryProcessor<Deps> for FnProcessor<F, Deps>
where
    F: Fn(&RunContext<Deps>, Vec<ModelRequest>) -> Vec<ModelRequest> + Send + Sync,
    Deps: Send + Sync,
{
    async fn process(
        &self,
        ctx: &RunContext<Deps>,
        messages: Vec<ModelRequest>,
    ) -> Vec<ModelRequest> {
        // Synchronous call; the async wrapper exists only to satisfy the trait.
        (self.func)(ctx, messages)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::Utc;
use std::sync::Arc;
/// Builds a minimal `RunContext<()>` for exercising processors in tests.
fn make_test_context() -> RunContext<()> {
    RunContext {
        deps: Arc::new(()),
        run_id: "test".to_string(),
        start_time: Utc::now(),
        model_name: "test".to_string(),
        model_settings: Default::default(),
        tool_name: None,
        tool_call_id: None,
        retry_count: 0,
        metadata: None,
    }
}
/// Produces `count` messages, each holding one user prompt "Message {i}".
fn make_messages(count: usize) -> Vec<ModelRequest> {
    (0..count)
        .map(|i| {
            let mut req = ModelRequest::new();
            req.add_user_prompt(format!("Message {}", i));
            req
        })
        .collect()
}
// With keep_first(false) only the 3 newest messages survive.
#[tokio::test]
async fn test_truncate_history() {
    let processor = TruncateHistory::new(3).keep_first(false);
    let ctx = make_test_context();
    let messages = make_messages(5);
    let result = processor.process(&ctx, messages).await;
    assert_eq!(result.len(), 3);
}
// keep_first(true): first message pinned, total still capped at 3.
#[tokio::test]
async fn test_truncate_keep_first() {
    let processor = TruncateHistory::new(3).keep_first(true);
    let ctx = make_test_context();
    let messages = make_messages(5);
    let result = processor.process(&ctx, messages).await;
    assert_eq!(result.len(), 3);
}
// Histories already within the limit are returned untouched.
#[tokio::test]
async fn test_truncate_no_change() {
    let processor = TruncateHistory::new(10);
    let ctx = make_test_context();
    let messages = make_messages(5);
    let result = processor.process(&ctx, messages).await;
    assert_eq!(result.len(), 5);
}
// Chained processors run in order; the tighter limit wins.
#[tokio::test]
async fn test_chained_processor() {
    let processor = ChainedProcessor::<()>::new()
        .add(TruncateHistory::new(5))
        .add(TruncateHistory::new(3));
    let ctx = make_test_context();
    let messages = make_messages(10);
    let result = processor.process(&ctx, messages).await;
    assert_eq!(result.len(), 3);
}
// A closure-based processor can arbitrarily transform the history.
#[tokio::test]
async fn test_fn_processor() {
    let processor = FnProcessor::new(|_ctx: &RunContext<()>, mut msgs: Vec<ModelRequest>| {
        msgs.pop();
        msgs
    });
    let ctx = make_test_context();
    let messages = make_messages(5);
    let result = processor.process(&ctx, messages).await;
    assert_eq!(result.len(), 4);
}
// Even with a 1-token budget, the default pins the first two messages.
#[tokio::test]
async fn test_truncate_by_tokens_default_keeps_first_two() {
    let processor = TruncateByTokens::new(1);
    let ctx = make_test_context();
    let messages = make_messages(5);
    let result = processor.process(&ctx, messages).await;
    assert!(result.len() >= 2);
}
#[tokio::test]
async fn test_truncate_by_tokens_keep_first_n() {
    let processor = TruncateByTokens::new(1).keep_first_n(3);
    let ctx = make_test_context();
    let messages = make_messages(5);
    let result = processor.process(&ctx, messages).await;
    assert!(result.len() >= 3);
}
#[tokio::test]
async fn test_truncate_by_tokens_keep_first_n_zero() {
    let processor = TruncateByTokens::new(1).keep_first_n(0);
    let ctx = make_test_context();
    let messages = make_messages(5);
    let result = processor.process(&ctx, messages).await;
    assert!(result.len() <= 1);
}
// keep_first(true) maps to keep_first_n(1).
#[tokio::test]
async fn test_truncate_by_tokens_backwards_compat_keep_first_true() {
    let processor = TruncateByTokens::new(1).keep_first(true);
    let ctx = make_test_context();
    let messages = make_messages(5);
    let result = processor.process(&ctx, messages).await;
    assert!(!result.is_empty());
}
// keep_first(false) maps to keep_first_n(0).
#[tokio::test]
async fn test_truncate_by_tokens_backwards_compat_keep_first_false() {
    let processor = TruncateByTokens::new(1).keep_first(false);
    let ctx = make_test_context();
    let messages = make_messages(5);
    let result = processor.process(&ctx, messages).await;
    assert!(result.len() <= 1);
}
// A generous budget keeps everything.
#[tokio::test]
async fn test_truncate_by_tokens_with_sufficient_tokens() {
    let processor = TruncateByTokens::new(10000);
    let ctx = make_test_context();
    let messages = make_messages(5);
    let result = processor.process(&ctx, messages).await;
    assert_eq!(result.len(), 5);
}
#[tokio::test]
async fn test_truncate_by_tokens_keeps_most_recent() {
    let processor = TruncateByTokens::new(100).keep_first_n(1);
    let ctx = make_test_context();
    let messages = make_messages(10);
    let result = processor.process(&ctx, messages).await;
    assert!(!result.is_empty());
}
use serdes_ai_core::{
    messages::tool_return::ToolReturnContent, ModelResponse, ToolCallPart, ToolReturnPart,
};
/// Builds a message containing one assistant ToolCall with `tool_call_id`.
fn make_tool_call_message(tool_call_id: &str) -> ModelRequest {
    let mut response = ModelResponse::new();
    let tool_call = ToolCallPart::new("test_tool", serde_json::json!({"arg": "value"}))
        .with_tool_call_id(tool_call_id);
    response.add_part(ModelResponsePart::ToolCall(tool_call));
    ModelRequest::with_parts(vec![ModelRequestPart::ModelResponse(Box::new(response))])
}
/// Builds a message containing one ToolReturn with `tool_call_id`.
fn make_tool_return_message(tool_call_id: &str) -> ModelRequest {
    let tool_return = ToolReturnPart::new("test_tool", ToolReturnContent::text("result"))
        .with_tool_call_id(tool_call_id);
    ModelRequest::with_parts(vec![ModelRequestPart::ToolReturn(tool_return)])
}
#[test]
fn test_extract_tool_use_ids() {
    let msg = make_tool_call_message("call_123");
    let ids = extract_tool_use_ids(&msg);
    assert_eq!(ids, vec!["call_123"]);
}
// A plain user message contains no tool calls.
#[test]
fn test_extract_tool_use_ids_empty() {
    let msg = make_messages(1).pop().unwrap();
    let ids = extract_tool_use_ids(&msg);
    assert!(ids.is_empty());
}
#[test]
fn test_extract_tool_result_ids() {
    let msg = make_tool_return_message("call_456");
    let ids = extract_tool_result_ids(&msg);
    assert_eq!(ids, vec!["call_456"]);
}
#[test]
fn test_extract_tool_result_ids_empty() {
    let msg = make_messages(1).pop().unwrap();
    let ids = extract_tool_result_ids(&msg);
    assert!(ids.is_empty());
}
// The orphan return (no matching call) is dropped; the matched pair stays.
#[test]
fn test_remove_orphaned_tool_results() {
    let tool_call_msg = make_tool_call_message("call_abc");
    let tool_return_msg = make_tool_return_message("call_abc");
    let orphan_return_msg = make_tool_return_message("call_orphan");
    let messages = vec![tool_call_msg, tool_return_msg, orphan_return_msg];
    let valid_ids = collect_all_tool_use_ids(&messages);
    assert!(valid_ids.contains("call_abc"));
    assert!(!valid_ids.contains("call_orphan"));
    let result = remove_orphaned_tool_results(messages, &valid_ids);
    assert_eq!(result.len(), 2);
}
// Only the orphaned part is removed; other parts in the same message survive.
#[test]
fn test_remove_orphaned_preserves_mixed_messages() {
    let mut mixed_msg = ModelRequest::new();
    mixed_msg.add_user_prompt("This is a user message");
    let orphan_return =
        ToolReturnPart::new("test_tool", ToolReturnContent::text("orphan result"))
            .with_tool_call_id("orphan_id");
    mixed_msg.add_part(ModelRequestPart::ToolReturn(orphan_return));
    let messages = vec![mixed_msg];
    let valid_ids: HashSet<String> = HashSet::new();
    let result = remove_orphaned_tool_results(messages, &valid_ids);
    assert_eq!(result.len(), 1);
    assert_eq!(result[0].parts.len(), 1);
    assert!(matches!(
        result[0].parts[0],
        ModelRequestPart::UserPrompt(_)
    ));
}
// After truncation, no tool_result may reference a dropped tool_use.
#[tokio::test]
async fn test_truncate_history_removes_orphaned_tool_results() {
    let mut messages = Vec::new();
    let mut user_msg = ModelRequest::new();
    user_msg.add_user_prompt("Hello");
    messages.push(user_msg);
    messages.push(make_tool_call_message("call_1"));
    messages.push(make_tool_return_message("call_1"));
    messages.push(make_tool_call_message("call_2"));
    messages.push(make_tool_return_message("call_2"));
    let processor = TruncateHistory::new(3).keep_first(false);
    let ctx = make_test_context();
    let result = processor.process(&ctx, messages).await;
    assert_eq!(result.len(), 2);
    let tool_use_ids = collect_all_tool_use_ids(&result);
    let tool_result_ids = collect_all_tool_result_ids(&result);
    for id in &tool_result_ids {
        assert!(
            tool_use_ids.contains(id),
            "Orphaned tool_result found: {}",
            id
        );
    }
}
// Same invariant for token-based truncation.
#[tokio::test]
async fn test_truncate_by_tokens_removes_orphaned_tool_results() {
    let mut messages = Vec::new();
    let mut user_msg = ModelRequest::new();
    user_msg.add_user_prompt("Hello");
    messages.push(user_msg);
    messages.push(make_tool_call_message("call_a"));
    messages.push(make_tool_return_message("call_a"));
    messages.push(make_tool_call_message("call_b"));
    messages.push(make_tool_return_message("call_b"));
    let processor = TruncateByTokens::new(200).keep_first_n(0);
    let ctx = make_test_context();
    let result = processor.process(&ctx, messages).await;
    let tool_use_ids = collect_all_tool_use_ids(&result);
    let tool_result_ids = collect_all_tool_result_ids(&result);
    for id in &tool_result_ids {
        assert!(
            tool_use_ids.contains(id),
            "Orphaned tool_result found: {}",
            id
        );
    }
}
// A complete call/return pair within budget survives intact.
#[tokio::test]
async fn test_tool_pair_aware_truncation_keeps_complete_pairs() {
    let messages = vec![
        make_tool_call_message("call_x"),
        make_tool_return_message("call_x"),
    ];
    let processor = TruncateByTokens::new(10000).keep_first_n(0);
    let ctx = make_test_context();
    let result = processor.process(&ctx, messages).await;
    assert_eq!(result.len(), 2);
    let tool_use_ids = collect_all_tool_use_ids(&result);
    let tool_result_ids = collect_all_tool_result_ids(&result);
    assert_eq!(tool_use_ids.len(), 1);
    assert_eq!(tool_result_ids.len(), 1);
    assert!(tool_use_ids.contains("call_x"));
    assert!(tool_result_ids.contains("call_x"));
}
// Duplicate ids collapse because the collectors return HashSets.
#[test]
fn test_collect_all_tool_use_ids() {
    let messages = vec![
        make_tool_call_message("id_1"),
        make_tool_call_message("id_2"),
        make_tool_return_message("id_1"),
    ];
    let ids = collect_all_tool_use_ids(&messages);
    assert_eq!(ids.len(), 2);
    assert!(ids.contains("id_1"));
    assert!(ids.contains("id_2"));
}
#[test]
fn test_collect_all_tool_result_ids() {
    let messages = vec![
        make_tool_call_message("id_1"),
        make_tool_return_message("id_1"),
        make_tool_return_message("id_2"),
    ];
    let ids = collect_all_tool_result_ids(&messages);
    assert_eq!(ids.len(), 2);
    assert!(ids.contains("id_1"));
    assert!(ids.contains("id_2"));
}
// A ToolReturn with no id cannot be orphaned and must survive.
#[test]
fn test_tool_return_with_none_id_is_kept() {
    let tool_return_no_id = ToolReturnPart::new("test_tool", ToolReturnContent::text("result"));
    let msg = ModelRequest::with_parts(vec![ModelRequestPart::ToolReturn(tool_return_no_id)]);
    let messages = vec![msg];
    let valid_ids: HashSet<String> = HashSet::new();
    let result = remove_orphaned_tool_results(messages, &valid_ids);
    assert_eq!(result.len(), 1);
    assert_eq!(result[0].parts.len(), 1);
}
// BuiltinToolReturn ids are plain Strings: empty means orphaned.
#[test]
fn test_builtin_tool_return_with_empty_string_id_is_removed() {
    use serdes_ai_core::messages::parts::{BuiltinToolReturnContent, WebSearchResults};
    let empty_results = WebSearchResults::new("query", vec![]);
    let content = BuiltinToolReturnContent::web_search(empty_results);
    let builtin_return = serdes_ai_core::BuiltinToolReturnPart::new(
        "web_search",
        content,
        "",
    );
    let msg =
        ModelRequest::with_parts(vec![ModelRequestPart::BuiltinToolReturn(builtin_return)]);
    let messages = vec![msg];
    let valid_ids: HashSet<String> = HashSet::new();
    let result = remove_orphaned_tool_results(messages, &valid_ids);
    assert_eq!(result.len(), 0);
}
#[test]
fn test_builtin_tool_return_with_valid_id_is_kept() {
    use serdes_ai_core::messages::parts::{BuiltinToolReturnContent, WebSearchResults};
    let empty_results = WebSearchResults::new("query", vec![]);
    let content = BuiltinToolReturnContent::web_search(empty_results);
    let builtin_return =
        serdes_ai_core::BuiltinToolReturnPart::new("web_search", content, "valid_call_id");
    let msg =
        ModelRequest::with_parts(vec![ModelRequestPart::BuiltinToolReturn(builtin_return)]);
    let messages = vec![msg];
    let mut valid_ids: HashSet<String> = HashSet::new();
    valid_ids.insert("valid_call_id".to_string());
    let result = remove_orphaned_tool_results(messages, &valid_ids);
    assert_eq!(result.len(), 1);
    assert_eq!(result[0].parts.len(), 1);
}
// A tool call with no matching result is removed, emptying the message.
#[test]
fn test_remove_orphaned_tool_uses_basic() {
    let orphan_call_msg = make_tool_call_message("orphan_call");
    let messages = vec![orphan_call_msg];
    let valid_result_ids: HashSet<String> = HashSet::new();
    let result = remove_orphaned_tool_uses(messages, &valid_result_ids);
    assert_eq!(result.len(), 0);
}
#[test]
fn test_remove_orphaned_tool_uses_keeps_matched() {
    let tool_call_msg = make_tool_call_message("matched_call");
    let messages = vec![tool_call_msg];
    let mut valid_result_ids: HashSet<String> = HashSet::new();
    valid_result_ids.insert("matched_call".to_string());
    let result = remove_orphaned_tool_uses(messages, &valid_result_ids);
    assert_eq!(result.len(), 1);
}
// Non-tool parts of a ModelResponse survive the orphan filtering.
#[test]
fn test_remove_orphaned_tool_uses_preserves_text() {
    let mut response = ModelResponse::new();
    response.add_part(ModelResponsePart::Text(serdes_ai_core::TextPart::new(
        "Some text",
    )));
    let tool_call = ToolCallPart::new("test_tool", serde_json::json!({"arg": "value"}))
        .with_tool_call_id("orphan_id");
    response.add_part(ModelResponsePart::ToolCall(tool_call));
    let msg =
        ModelRequest::with_parts(vec![ModelRequestPart::ModelResponse(Box::new(response))]);
    let messages = vec![msg];
    let valid_result_ids: HashSet<String> = HashSet::new();
    let result = remove_orphaned_tool_uses(messages, &valid_result_ids);
    assert_eq!(result.len(), 1);
    if let ModelRequestPart::ModelResponse(ref resp) = result[0].parts[0] {
        assert_eq!(resp.parts.len(), 1);
        assert!(matches!(resp.parts[0], ModelResponsePart::Text(_)));
    } else {
        panic!("Expected ModelResponse");
    }
}
// A ToolCall with no id cannot be orphaned and must survive.
#[test]
fn test_remove_orphaned_tool_uses_with_none_id_is_kept() {
    let mut response = ModelResponse::new();
    let tool_call = ToolCallPart::new("test_tool", serde_json::json!({"arg": "value"}));
    response.add_part(ModelResponsePart::ToolCall(tool_call));
    let msg =
        ModelRequest::with_parts(vec![ModelRequestPart::ModelResponse(Box::new(response))]);
    let messages = vec![msg];
    let valid_result_ids: HashSet<String> = HashSet::new();
    let result = remove_orphaned_tool_uses(messages, &valid_result_ids);
    assert_eq!(result.len(), 1);
}
// After message-count truncation, every remaining call has its result
// and every remaining result has its call.
#[tokio::test]
async fn test_truncate_removes_both_orphaned_directions() {
    let messages = vec![
        make_tool_call_message("call_1"),
        make_tool_return_message("call_1"),
        make_tool_call_message("call_2"),
        make_tool_return_message("call_2"),
    ];
    let processor = TruncateHistory::new(2).keep_first(true);
    let ctx = make_test_context();
    let result = processor.process(&ctx, messages).await;
    let tool_use_ids = collect_all_tool_use_ids(&result);
    let tool_result_ids = collect_all_tool_result_ids(&result);
    for id in &tool_result_ids {
        assert!(
            tool_use_ids.contains(id),
            "Orphaned tool_result found: {}",
            id
        );
    }
    for id in &tool_use_ids {
        assert!(
            tool_result_ids.contains(id),
            "Orphaned tool_use found: {}",
            id
        );
    }
}
// Same bidirectional invariant for token-based truncation.
#[tokio::test]
async fn test_truncate_by_tokens_removes_both_orphaned_directions() {
    let messages = vec![
        make_tool_call_message("call_a"),
        make_tool_return_message("call_a"),
        make_tool_call_message("call_b"),
        make_tool_return_message("call_b"),
    ];
    let processor = TruncateByTokens::new(50).keep_first_n(0);
    let ctx = make_test_context();
    let result = processor.process(&ctx, messages).await;
    let tool_use_ids = collect_all_tool_use_ids(&result);
    let tool_result_ids = collect_all_tool_result_ids(&result);
    for id in &tool_result_ids {
        assert!(
            tool_use_ids.contains(id),
            "Orphaned tool_result found: {}",
            id
        );
    }
    for id in &tool_use_ids {
        assert!(
            tool_result_ids.contains(id),
            "Orphaned tool_use found: {}",
            id
        );
    }
}
}