use oxi_ai::{ContentBlock, TextContent, ToolCall, ToolResultMessage};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
/// Collects every tool call contained in an assistant message.
///
/// Walks `message.content` in order and clones the payload of each
/// `ContentBlock::ToolCall`; every other block kind is skipped. The returned
/// vector is empty when the message carries no tool calls.
pub fn extract_tool_calls(message: &oxi_ai::AssistantMessage) -> Vec<ToolCall> {
    message
        .content
        .iter()
        .filter_map(|block| match block {
            ContentBlock::ToolCall(call) => Some(call.clone()),
            _ => None,
        })
        .collect()
}
pub fn create_tool_result_message(finalized: &FinalizedToolCall) -> ToolResultMessage {
let content_blocks = if let Some(ref blocks) = finalized.result.content_blocks {
blocks.clone()
} else {
vec![ContentBlock::Text(TextContent::new(
finalized.result.output.clone(),
))]
};
ToolResultMessage::new(
finalized.tool_call.id.clone(),
&finalized.tool_call.name,
content_blocks,
)
}
/// Returns `true` only when the batch is non-empty and every call's result
/// has its `terminate` flag set.
///
/// The empty batch is handled explicitly because `Iterator::all` is vacuously
/// true on an empty iterator, which would otherwise report termination.
pub fn should_terminate_batch(finalized_calls: &[FinalizedToolCall]) -> bool {
    match finalized_calls {
        [] => false,
        calls => calls.iter().all(|call| call.result.terminate),
    }
}
/// Reports whether the shared stop flag is currently set.
///
/// Reads the flag with sequentially consistent ordering; the flag itself is
/// left untouched.
pub fn should_stop_after_turn(external_stop: &Arc<AtomicBool>) -> bool {
    // `&Arc<AtomicBool>` deref-coerces to `&AtomicBool` at the call site.
    AtomicBool::load(external_stop, Ordering::SeqCst)
}
use crate::AgentToolResult;
/// A tool call bundled with the result recorded for it.
///
/// Read by `create_tool_result_message` and `should_terminate_batch`.
pub struct FinalizedToolCall {
/// The tool call; its `id` and `name` are copied into the tool result message.
pub tool_call: oxi_ai::ToolCall,
/// The recorded result; this file reads its `output`, `content_blocks` and `terminate` fields.
pub result: AgentToolResult,
/// NOTE(review): presumably marks a failed execution — not read anywhere in this file; confirm with the code that constructs this struct.
pub is_error: bool,
}
/// Repairs a message history in place so that tool calls and tool results pair up.
///
/// A tool result is kept only when its `tool_call_id` was issued by the most
/// recent assistant message and no user message or tool-call-free assistant
/// message has appeared in between. Every other tool result is dropped.
///
/// A tool call that never received such a result is stripped from its
/// assistant message. Tool calls that *were* answered stay in place together
/// with their results, so a partially answered assistant message keeps its
/// answered pairs instead of leaving their results orphaned. An assistant
/// message whose content becomes empty after stripping is dropped entirely.
///
/// # Returns
///
/// The number of removed items: dropped tool results, plus stripped tool-call
/// blocks, plus assistant messages dropped because stripping emptied them.
pub fn sanitize_orphaned_tool_results(messages: &mut Vec<oxi_ai::Message>) -> usize {
    use oxi_ai::{ContentBlock, Message};
    use std::collections::{HashMap, HashSet};

    if messages.is_empty() {
        return 0;
    }

    /// Tool-call IDs issued by one assistant message and the subset answered so far.
    struct AssistantBatch {
        msg_idx: usize,
        issued: HashSet<String>,
        matched: HashSet<String>,
    }

    // Pass 1: walk the history once, tracking the "open" assistant batch and
    // marking each tool result that answers a call from that batch.
    let mut batches: Vec<AssistantBatch> = Vec::new();
    let mut current: Option<AssistantBatch> = None;
    let mut valid_result = vec![false; messages.len()];

    for (i, msg) in messages.iter().enumerate() {
        match msg {
            Message::Assistant(a) => {
                // A new assistant message closes the previous batch.
                if let Some(done) = current.take() {
                    batches.push(done);
                }
                let issued: HashSet<String> = a
                    .content
                    .iter()
                    .filter_map(|block| match block {
                        ContentBlock::ToolCall(tc) => Some(tc.id.clone()),
                        _ => None,
                    })
                    .collect();
                if !issued.is_empty() {
                    current = Some(AssistantBatch {
                        msg_idx: i,
                        issued,
                        matched: HashSet::new(),
                    });
                }
            }
            Message::ToolResult(t) => {
                if let Some(ref mut batch) = current
                    && batch.issued.contains(&t.tool_call_id)
                {
                    batch.matched.insert(t.tool_call_id.clone());
                    valid_result[i] = true;
                }
            }
            Message::User(_) => {
                // A user message also closes the batch: later results are orphans.
                if let Some(done) = current.take() {
                    batches.push(done);
                }
            }
        }
    }
    if let Some(done) = current {
        batches.push(done);
    }

    // Pass 2: for every assistant message with unanswered calls, record exactly
    // which call IDs are unanswered. Only those are stripped; answered calls
    // must stay, otherwise their (kept) results would become orphans.
    let unanswered: HashMap<usize, HashSet<String>> = batches
        .into_iter()
        .filter(|batch| batch.matched.len() < batch.issued.len())
        .map(|batch| {
            let AssistantBatch {
                msg_idx,
                mut issued,
                matched,
            } = batch;
            issued.retain(|id| !matched.contains(id));
            (msg_idx, issued)
        })
        .collect();

    // Pass 3: rebuild the history, dropping orphans and counting removals.
    let mut removed = 0;
    let mut kept: Vec<Message> = Vec::with_capacity(messages.len());
    let original = std::mem::take(messages);

    for (i, (msg, is_valid)) in original.into_iter().zip(valid_result).enumerate() {
        match msg {
            Message::ToolResult(_) if !is_valid => removed += 1,
            Message::Assistant(mut a) => {
                if let Some(orphan_ids) = unanswered.get(&i) {
                    let before = a.content.len();
                    a.content.retain(|block| match block {
                        ContentBlock::ToolCall(tc) => !orphan_ids.contains(&tc.id),
                        _ => true,
                    });
                    removed += before - a.content.len();
                    if a.content.is_empty() {
                        // Nothing left to say: drop the message itself.
                        removed += 1;
                        continue;
                    }
                }
                kept.push(Message::Assistant(a));
            }
            other => kept.push(other),
        }
    }

    *messages = kept;
    removed
}
// Unit tests for `should_stop_after_turn` and `sanitize_orphaned_tool_results`.
#[cfg(test)]
mod tests {
use super::*;
// A flag initialised to `false` must not request a stop.
#[test]
fn test_should_stop_returns_false_when_no_external_stop() {
let external_stop = Arc::new(AtomicBool::new(false));
assert!(!should_stop_after_turn(&external_stop));
}
// A flag initialised to `true` must request a stop.
#[test]
fn test_should_stop_returns_true_on_external_stop() {
let external_stop = Arc::new(AtomicBool::new(true));
assert!(should_stop_after_turn(&external_stop));
}
// user -> assistant(call_1) -> result(call_1): a fully paired history is left untouched.
#[test]
fn test_sanitize_no_orphans() {
use oxi_ai::{ContentBlock, Message, TextContent, ToolCall, ToolResultMessage};
let mut messages = vec![
Message::User(oxi_ai::UserMessage::new("hello")),
Message::Assistant({
let mut m =
oxi_ai::AssistantMessage::new(oxi_ai::Api::OpenAiCompletions, "agent", "gpt-4");
m.content.push(ContentBlock::ToolCall(ToolCall::new(
"call_1",
"bash",
serde_json::json!({"cmd": "ls"}),
)));
m
}),
Message::ToolResult(ToolResultMessage::new(
"call_1",
"bash",
vec![ContentBlock::Text(TextContent::new("output"))],
)),
];
let removed = sanitize_orphaned_tool_results(&mut messages);
assert_eq!(removed, 0);
assert_eq!(messages.len(), 3);
}
// A tool result with no assistant tool call before it is dropped; the user message survives.
#[test]
fn test_sanitize_removes_orphans() {
use oxi_ai::{ContentBlock, Message, TextContent, ToolResultMessage};
let mut messages = vec![
Message::User(oxi_ai::UserMessage::new("hello")),
Message::ToolResult(ToolResultMessage::new(
"orphan_1",
"bash",
vec![ContentBlock::Text(TextContent::new("orphan output"))],
)),
];
let removed = sanitize_orphaned_tool_results(&mut messages);
assert_eq!(removed, 1);
assert_eq!(messages.len(), 1);
assert!(matches!(messages[0], Message::User(_)));
}
// Same history shape as the previous test (user, then a tool result); only the removal count is checked.
#[test]
fn test_sanitize_tool_result_after_user_is_orphan() {
use oxi_ai::{ContentBlock, Message, TextContent, ToolResultMessage};
let mut messages = vec![
Message::User(oxi_ai::UserMessage::new("hello")),
Message::ToolResult(ToolResultMessage::new(
"call_x",
"bash",
vec![ContentBlock::Text(TextContent::new("result"))],
)),
];
let removed = sanitize_orphaned_tool_results(&mut messages);
assert_eq!(removed, 1);
}
// Two orphan results before and one after a valid assistant/result pair: only the pair remains.
#[test]
fn test_sanitize_multiple_orphans_removes_only_orphans() {
use oxi_ai::{ContentBlock, Message, TextContent, ToolCall, ToolResultMessage};
let mut messages = vec![
Message::ToolResult(ToolResultMessage::new(
"orphan_1",
"bash",
vec![ContentBlock::Text(TextContent::new("o1"))],
)),
Message::ToolResult(ToolResultMessage::new(
"orphan_2",
"bash",
vec![ContentBlock::Text(TextContent::new("o2"))],
)),
Message::Assistant({
let mut m =
oxi_ai::AssistantMessage::new(oxi_ai::Api::OpenAiCompletions, "agent", "gpt-4");
m.content.push(ContentBlock::ToolCall(ToolCall::new(
"call_1",
"read",
serde_json::json!({"path": "foo"}),
)));
m
}),
Message::ToolResult(ToolResultMessage::new(
"call_1",
"read",
vec![ContentBlock::Text(TextContent::new("valid"))],
)),
Message::ToolResult(ToolResultMessage::new(
"orphan_3",
"write",
vec![ContentBlock::Text(TextContent::new("o3"))],
)),
];
let removed = sanitize_orphaned_tool_results(&mut messages);
assert_eq!(removed, 3);
assert_eq!(messages.len(), 2);
assert!(matches!(messages[0], Message::Assistant(_)));
assert!(matches!(messages[1], Message::ToolResult(_)));
}
// One assistant message issuing two tool calls, each followed by its result: everything is kept.
#[test]
fn test_sanitize_multi_tool_call_assistant_preserves_all_results() {
use oxi_ai::{ContentBlock, Message, TextContent, ToolCall, ToolResultMessage};
let mut messages = vec![
Message::User(oxi_ai::UserMessage::new("do two things")),
Message::Assistant({
let mut m =
oxi_ai::AssistantMessage::new(oxi_ai::Api::OpenAiCompletions, "agent", "gpt-4");
m.content.push(ContentBlock::ToolCall(ToolCall::new(
"call_1",
"read",
serde_json::json!({"path": "a.txt"}),
)));
m.content.push(ContentBlock::ToolCall(ToolCall::new(
"call_2",
"read",
serde_json::json!({"path": "b.txt"}),
)));
m
}),
Message::ToolResult(ToolResultMessage::new(
"call_1",
"read",
vec![ContentBlock::Text(TextContent::new("aaa"))],
)),
Message::ToolResult(ToolResultMessage::new(
"call_2",
"read",
vec![ContentBlock::Text(TextContent::new("bbb"))],
)),
];
let removed = sanitize_orphaned_tool_results(&mut messages);
assert_eq!(removed, 0, "no tool results should be orphaned");
assert_eq!(messages.len(), 4, "all 4 messages should be kept");
}
// The first assistant's only tool call is never answered: the call is stripped and the
// now-empty assistant message is dropped, while the second assistant/result pair is kept.
#[test]
fn test_sanitize_orphan_tool_call_stripped_from_assistant() {
use oxi_ai::{ContentBlock, Message, TextContent, ToolCall, ToolResultMessage};
let mut messages = vec![
Message::Assistant({
let mut m =
oxi_ai::AssistantMessage::new(oxi_ai::Api::OpenAiCompletions, "agent", "gpt-4");
m.content.push(ContentBlock::ToolCall(ToolCall::new(
"call_1",
"read",
serde_json::json!({"path": "a.txt"}),
)));
m
}),
Message::Assistant({
let mut m =
oxi_ai::AssistantMessage::new(oxi_ai::Api::OpenAiCompletions, "agent", "gpt-4");
m.content.push(ContentBlock::ToolCall(ToolCall::new(
"call_2",
"bash",
serde_json::json!({"cmd": "ls"}),
)));
m
}),
Message::ToolResult(ToolResultMessage::new(
"call_2",
"bash",
vec![ContentBlock::Text(TextContent::new("ok"))],
)),
];
let removed = sanitize_orphaned_tool_results(&mut messages);
assert_eq!(
removed, 2,
"1 tool_call block stripped + 1 empty assistant dropped"
);
assert_eq!(messages.len(), 2);
assert!(matches!(messages[0], Message::Assistant(_)));
assert!(matches!(messages[1], Message::ToolResult(_)));
}
// An assistant message with text plus an unanswered tool call loses only the tool call;
// the text block keeps the message alive.
#[test]
fn test_sanitize_assistant_with_text_and_orphan_tool_call_keeps_text() {
use oxi_ai::{ContentBlock, Message, TextContent, ToolCall};
let mut messages = vec![
Message::Assistant({
let mut m =
oxi_ai::AssistantMessage::new(oxi_ai::Api::OpenAiCompletions, "agent", "gpt-4");
m.content
.push(ContentBlock::Text(TextContent::new("let me check")));
m.content.push(ContentBlock::ToolCall(ToolCall::new(
"call_1",
"read",
serde_json::json!({"path": "a.txt"}),
)));
m
}),
Message::User(oxi_ai::UserMessage::new("hi")),
];
let removed = sanitize_orphaned_tool_results(&mut messages);
assert_eq!(removed, 1);
assert_eq!(messages.len(), 2);
if let Message::Assistant(a) = &messages[0] {
assert_eq!(a.content.len(), 1, "only the text block should remain");
if let ContentBlock::Text(t) = &a.content[0] {
assert_eq!(t.text, "let me check");
} else {
panic!("expected text block");
}
} else {
panic!("expected assistant message");
}
}
// A history with no assistant message at all: the lone tool result is removed.
#[test]
fn test_sanitize_orphan_tool_result_with_no_assistant_removed() {
use oxi_ai::{ContentBlock, Message, TextContent, ToolResultMessage};
let mut messages = vec![
Message::User(oxi_ai::UserMessage::new("hello")),
Message::ToolResult(ToolResultMessage::new(
"orphan_1",
"bash",
vec![ContentBlock::Text(TextContent::new("orphan output"))],
)),
];
let removed = sanitize_orphaned_tool_results(&mut messages);
assert_eq!(removed, 1);
assert_eq!(messages.len(), 1);
}
// The result's ID does not match the issued call, so both sides are orphans:
// removed = mismatched result + stripped tool call + emptied assistant message.
#[test]
fn test_sanitize_wrong_tool_call_id_removed() {
use oxi_ai::{ContentBlock, Message, TextContent, ToolCall, ToolResultMessage};
let mut messages = vec![
Message::Assistant({
let mut m =
oxi_ai::AssistantMessage::new(oxi_ai::Api::OpenAiCompletions, "agent", "gpt-4");
m.content.push(ContentBlock::ToolCall(ToolCall::new(
"call_1",
"bash",
serde_json::json!({"cmd": "ls"}),
)));
m
}),
Message::ToolResult(ToolResultMessage::new(
"wrong_id", "bash",
vec![ContentBlock::Text(TextContent::new("orphan"))],
)),
];
let removed = sanitize_orphaned_tool_results(&mut messages);
assert_eq!(removed, 3);
assert_eq!(messages.len(), 0);
}
}
}