use crate::types::{
AgentOptions, ContentBlock, Message, MessageRole, OpenAIContent, OpenAIContentPart,
OpenAIFunction, OpenAIMessage, OpenAIRequest, OpenAIToolCall, TextBlock,
};
use crate::utils::{ToolCallAggregator, parse_sse_stream};
use crate::{Error, Result};
use futures::stream::{Stream, StreamExt};
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
/// Boxed, pinned, `Send` stream whose items are `Result<ContentBlock>`.
///
/// Returned by `query` and stored by `Client` as the response stream of the latest request.
pub type ContentStream = Pin<Box<dyn Stream<Item = Result<ContentBlock>> + Send>>;
pub async fn query(prompt: &str, options: &AgentOptions) -> Result<ContentStream> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(options.timeout()))
.build()
.map_err(Error::Http)?;
let mut messages = Vec::new();
if !options.system_prompt().is_empty() {
messages.push(OpenAIMessage {
role: "system".to_string(),
content: Some(OpenAIContent::Text(options.system_prompt().to_string())),
tool_calls: None,
tool_call_id: None,
});
}
messages.push(OpenAIMessage {
role: "user".to_string(),
content: Some(OpenAIContent::Text(prompt.to_string())),
tool_calls: None,
tool_call_id: None,
});
let tools = if !options.tools().is_empty() {
Some(
options
.tools()
.iter()
.map(|t| t.to_openai_format())
.collect(),
)
} else {
None
};
let request = OpenAIRequest {
model: options.model().to_string(),
messages,
stream: true, max_tokens: options.max_tokens(),
temperature: Some(options.temperature()),
tools,
};
let url = format!("{}/chat/completions", options.base_url());
let response = client
.post(&url)
.header("Authorization", format!("Bearer {}", options.api_key()))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(Error::Http)?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_else(|e| {
eprintln!("WARNING: Failed to read error response body: {}", e);
"Unknown error (failed to read response body)".to_string()
});
return Err(Error::api(format!("API error {}: {}", status, body)));
}
let sse_stream = parse_sse_stream(response);
let stream = sse_stream.scan(ToolCallAggregator::new(), |aggregator, chunk_result| {
let result = match chunk_result {
Ok(chunk) => match aggregator.process_chunk(chunk) {
Ok(blocks) => {
if blocks.is_empty() {
Some(None) } else {
Some(Some(Ok(blocks))) }
}
Err(e) => Some(Some(Err(e))), },
Err(e) => Some(Some(Err(e))), };
futures::future::ready(result)
});
let flattened = stream
.filter_map(|item| async move { item })
.flat_map(|result| {
futures::stream::iter(match result {
Ok(blocks) => blocks.into_iter().map(Ok).collect(),
Err(e) => vec![Err(e)],
})
});
Ok(Box::pin(flattened))
}
/// Stateful chat client that keeps the conversation history between requests.
pub struct Client {
/// Configuration supplied at construction time.
options: AgentOptions,
/// Conversation messages; the whole list is serialized into every request.
history: Vec<Message>,
/// Response stream of the latest request; `None` before the first request and after the
/// stream has been read to its end.
current_stream: Option<ContentStream>,
/// HTTP client built once in `new` with the configured timeout.
http_client: reqwest::Client,
/// Interrupt flag: set by `interrupt`, cleared when a new request is sent, and shared
/// with callers through `interrupt_handle`.
interrupted: Arc<AtomicBool>,
/// Blocks produced by the auto-execution loop, handed out one at a time by `receive`.
auto_exec_buffer: Vec<ContentBlock>,
/// Index of the next block in `auto_exec_buffer` that `receive` will return.
auto_exec_index: usize,
/// Blocks received in manual mode that have not yet been committed to `history` as an
/// assistant message.
manual_receive_buffer: Vec<ContentBlock>,
}
impl Client {
pub fn new(options: AgentOptions) -> Result<Self> {
let http_client = reqwest::Client::builder()
.timeout(Duration::from_secs(options.timeout()))
.build()
.map_err(|e| Error::config(format!("Failed to build HTTP client: {}", e)))?;
Ok(Self {
options,
history: Vec::new(), current_stream: None, http_client,
interrupted: Arc::new(AtomicBool::new(false)), auto_exec_buffer: Vec::new(), auto_exec_index: 0, manual_receive_buffer: Vec::new(), })
}
pub async fn send(&mut self, prompt: &str) -> Result<()> {
use crate::hooks::UserPromptSubmitEvent;
self.interrupted.store(false, Ordering::SeqCst);
self.manual_receive_buffer.clear();
self.current_stream = None;
let mut final_prompt = prompt.to_string();
let history_snapshot: Vec<serde_json::Value> = self
.history
.iter()
.map(|_| serde_json::json!({})) .collect();
let event = UserPromptSubmitEvent::new(final_prompt.clone(), history_snapshot);
if let Some(decision) = self.options.hooks().execute_user_prompt_submit(event).await {
if !decision.continue_execution() {
return Err(Error::other(format!(
"Prompt blocked by hook: {}",
decision.reason().unwrap_or("")
)));
}
if let Some(modified) = decision.modified_prompt() {
final_prompt = modified.to_string();
}
}
self.history.push(Message::user(final_prompt));
let mut messages = Vec::new();
if !self.options.system_prompt().is_empty() {
messages.push(OpenAIMessage {
role: "system".to_string(),
content: Some(OpenAIContent::Text(
self.options.system_prompt().to_string(),
)),
tool_calls: None,
tool_call_id: None,
});
}
for msg in &self.history {
let mut text_blocks = Vec::new();
let mut image_blocks = Vec::new();
let mut tool_use_blocks = Vec::new();
let mut tool_result_blocks = Vec::new();
for block in &msg.content {
match block {
ContentBlock::Text(text) => text_blocks.push(text),
ContentBlock::Image(image) => image_blocks.push(image),
ContentBlock::ToolUse(tool_use) => tool_use_blocks.push(tool_use),
ContentBlock::ToolResult(tool_result) => tool_result_blocks.push(tool_result),
}
}
if !tool_result_blocks.is_empty() {
for tool_result in tool_result_blocks {
let content =
serde_json::to_string(tool_result.content()).unwrap_or_else(|e| {
format!("{{\"error\": \"Failed to serialize: {}\"}}", e)
});
messages.push(OpenAIMessage {
role: "tool".to_string(),
content: Some(OpenAIContent::Text(content)),
tool_calls: None,
tool_call_id: Some(tool_result.tool_use_id().to_string()),
});
}
}
else if !tool_use_blocks.is_empty() {
let tool_calls: Vec<OpenAIToolCall> = tool_use_blocks
.iter()
.map(|tool_use| {
let arguments = serde_json::to_string(tool_use.input())
.unwrap_or_else(|_| "{}".to_string());
OpenAIToolCall {
id: tool_use.id().to_string(),
call_type: "function".to_string(),
function: OpenAIFunction {
name: tool_use.name().to_string(),
arguments,
},
}
})
.collect();
let content = if !text_blocks.is_empty() {
let text = text_blocks
.iter()
.map(|t| t.text.as_str())
.collect::<Vec<_>>()
.join("\n");
Some(OpenAIContent::Text(text))
} else {
Some(OpenAIContent::Text(String::new()))
};
messages.push(OpenAIMessage {
role: "assistant".to_string(),
content,
tool_calls: Some(tool_calls),
tool_call_id: None,
});
}
else if !image_blocks.is_empty() {
log::debug!(
"Serializing message with {} image(s) for {:?} role",
image_blocks.len(),
msg.role
);
let mut content_parts = Vec::new();
for block in &msg.content {
match block {
ContentBlock::Text(text) => {
content_parts.push(OpenAIContentPart::text(&text.text));
}
ContentBlock::Image(image) => {
let url_display = if image.url().len() > 100 {
format!("{}... ({} chars)", &image.url()[..100], image.url().len())
} else {
image.url().to_string()
};
let detail_str = match image.detail() {
crate::types::ImageDetail::Low => "low",
crate::types::ImageDetail::High => "high",
crate::types::ImageDetail::Auto => "auto",
};
log::debug!(" - Image: {} (detail: {})", url_display, detail_str);
content_parts.push(OpenAIContentPart::from_image(image));
}
ContentBlock::ToolUse(_) | ContentBlock::ToolResult(_) => {}
}
}
if content_parts.is_empty() {
return Err(Error::other(
"Internal error: Message with images produced empty content array",
));
}
let role_str = match msg.role {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
MessageRole::Tool => "tool",
};
messages.push(OpenAIMessage {
role: role_str.to_string(),
content: Some(OpenAIContent::Parts(content_parts)),
tool_calls: None,
tool_call_id: None,
});
}
else {
let content = text_blocks
.iter()
.map(|t| t.text.as_str())
.collect::<Vec<_>>()
.join("\n");
let role_str = match msg.role {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
MessageRole::Tool => "tool",
};
messages.push(OpenAIMessage {
role: role_str.to_string(),
content: Some(OpenAIContent::Text(content)),
tool_calls: None,
tool_call_id: None,
});
}
}
let tools = if !self.options.tools().is_empty() {
Some(
self.options
.tools()
.iter()
.map(|t| t.to_openai_format())
.collect(),
)
} else {
None
};
let request = OpenAIRequest {
model: self.options.model().to_string(),
messages,
stream: true, max_tokens: self.options.max_tokens(),
temperature: Some(self.options.temperature()),
tools,
};
let url = format!("{}/chat/completions", self.options.base_url());
let response = self
.http_client
.post(&url)
.header(
"Authorization",
format!("Bearer {}", self.options.api_key()),
)
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(Error::Http)?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_else(|e| {
eprintln!("WARNING: Failed to read error response body: {}", e);
"Unknown error (failed to read response body)".to_string()
});
return Err(Error::api(format!("API error {}: {}", status, body)));
}
let sse_stream = parse_sse_stream(response);
let stream = sse_stream.scan(ToolCallAggregator::new(), |aggregator, chunk_result| {
let result = match chunk_result {
Ok(chunk) => match aggregator.process_chunk(chunk) {
Ok(blocks) => {
if blocks.is_empty() {
Some(None) } else {
Some(Some(Ok(blocks))) }
}
Err(e) => Some(Some(Err(e))), },
Err(e) => Some(Some(Err(e))), };
futures::future::ready(result)
});
let flattened = stream
.filter_map(|item| async move { item })
.flat_map(|result| {
futures::stream::iter(match result {
Ok(blocks) => blocks.into_iter().map(Ok).collect(),
Err(e) => vec![Err(e)],
})
});
self.current_stream = Some(Box::pin(flattened));
Ok(())
}
/// Reads the next block from the current response stream.
///
/// Returns `Ok(None)` when the client has been interrupted, when there is no stream, or when
/// the stream has ended (in which case the stream is dropped). Stream errors are returned
/// as `Err`.
async fn receive_one(&mut self) -> Result<Option<ContentBlock>> {
    if self.interrupted.load(Ordering::SeqCst) {
        return Ok(None);
    }

    let next_item = match self.current_stream.as_mut() {
        Some(stream) => stream.next().await,
        None => return Ok(None),
    };

    match next_item {
        Some(item) => item.map(Some),
        None => {
            // End of stream: release it so later calls return immediately.
            self.current_stream = None;
            Ok(None)
        }
    }
}
/// Drains the current response stream into a vector of blocks.
///
/// # Errors
///
/// Propagates stream errors. If the interrupt flag is found set right after a block has been
/// read, the stream is dropped and an "interrupted" error is returned.
async fn collect_all_blocks(&mut self) -> Result<Vec<ContentBlock>> {
    let mut collected = Vec::new();
    loop {
        let block = match self.receive_one().await? {
            Some(block) => block,
            None => return Ok(collected),
        };
        if self.interrupted.load(Ordering::SeqCst) {
            self.current_stream = None;
            return Err(Error::other(
                "Operation interrupted during block collection",
            ));
        }
        collected.push(block);
    }
}
async fn execute_tool_internal(
&self,
tool_name: &str,
input: serde_json::Value,
) -> Result<serde_json::Value> {
let tool = self
.options
.tools()
.iter()
.find(|t| t.name() == tool_name)
.ok_or_else(|| Error::tool(format!("Tool '{}' not found", tool_name)))?;
tool.execute(input).await
}
/// Drives the automatic tool-execution cycle for the current response.
///
/// Each iteration collects the complete response, and if it contains tool calls, records the
/// assistant message, runs every tool (subject to the pre/post tool-use hooks), records each
/// result in the history and sends a follow-up request. The loop ends when a response has no
/// tool calls, when no blocks were collected, or when the configured iteration limit is
/// exceeded.
///
/// Returns the text blocks of the final response (possibly empty).
///
/// # Errors
///
/// Propagates errors from collecting the response and from the follow-up `send`. Tool
/// execution errors are NOT propagated; they are turned into a JSON error object that is
/// recorded as the tool's result.
async fn auto_execute_loop(&mut self) -> Result<Vec<ContentBlock>> {
use crate::types::ToolResultBlock;
let mut iteration = 0;
let max_iterations = self.options.max_tool_iterations();
loop {
let blocks = self.collect_all_blocks().await?;
if blocks.is_empty() {
return Ok(Vec::new());
}
// Split the response into text and tool calls; tool-result and image blocks found in
// a response are dropped.
let mut text_blocks = Vec::new();
let mut tool_blocks = Vec::new();
for block in blocks {
match block {
ContentBlock::Text(_) => text_blocks.push(block),
ContentBlock::ToolUse(_) => tool_blocks.push(block),
ContentBlock::ToolResult(_) | ContentBlock::Image(_) => {} }
}
// No tool calls: this is the final answer. Record it (if non-empty) and return it.
if tool_blocks.is_empty() {
if !text_blocks.is_empty() {
let assistant_msg = Message::assistant(text_blocks.clone());
self.history.push(assistant_msg);
}
return Ok(text_blocks);
}
iteration += 1;
// Iteration limit exceeded: the pending tool calls are neither executed nor recorded;
// only the text of this response is kept.
if iteration > max_iterations {
if !text_blocks.is_empty() {
let assistant_msg = Message::assistant(text_blocks.clone());
self.history.push(assistant_msg);
}
return Ok(text_blocks);
}
// Record the assistant turn (text followed by tool calls) before running the tools.
let mut all_blocks = text_blocks.clone();
all_blocks.extend(tool_blocks.clone());
let assistant_msg = Message::assistant(all_blocks);
self.history.push(assistant_msg);
for block in tool_blocks {
if let ContentBlock::ToolUse(tool_use) = block {
// Placeholder snapshot: one empty JSON object per history message.
let history_snapshot: Vec<serde_json::Value> =
self.history.iter().map(|_| serde_json::json!({})).collect();
use crate::hooks::PreToolUseEvent;
let pre_event = PreToolUseEvent::new(
tool_use.name().to_string(),
tool_use.input().clone(),
tool_use.id().to_string(),
history_snapshot.clone(),
);
let mut tool_input = tool_use.input().clone();
let mut should_execute = true;
let mut block_reason = None;
// The pre-tool-use hook may block the call or replace the tool input.
if let Some(decision) =
self.options.hooks().execute_pre_tool_use(pre_event).await
{
if !decision.continue_execution() {
should_execute = false;
block_reason = decision.reason().map(|s| s.to_string());
} else if let Some(modified) = decision.modified_input() {
tool_input = modified.clone();
}
}
// A failed or blocked call yields a JSON error object as the tool result instead of
// aborting the loop.
let result = if should_execute {
match self
.execute_tool_internal(tool_use.name(), tool_input.clone())
.await
{
Ok(res) => res, Err(e) => {
serde_json::json!({
"error": e.to_string(),
"tool": tool_use.name(),
"id": tool_use.id()
})
}
}
} else {
serde_json::json!({
"error": "Tool execution blocked by hook",
"reason": block_reason.unwrap_or_else(|| "No reason provided".to_string()),
"tool": tool_use.name(),
"id": tool_use.id()
})
};
use crate::hooks::PostToolUseEvent;
let post_event = PostToolUseEvent::new(
tool_use.name().to_string(),
tool_input,
tool_use.id().to_string(),
result.clone(),
history_snapshot,
);
let mut final_result = result;
// The post-tool-use hook can replace the result; the replacement is read from the
// decision's `modified_input()` value.
if let Some(decision) =
self.options.hooks().execute_post_tool_use(post_event).await
{
if let Some(modified) = decision.modified_input() {
final_result = modified.clone();
}
}
let tool_result = ToolResultBlock::new(tool_use.id(), final_result);
let tool_result_msg =
Message::user_with_blocks(vec![ContentBlock::ToolResult(tool_result)]);
self.history.push(tool_result_msg);
}
}
// Ask the model to continue now that the tool results are in the history.
// NOTE(review): `send` also records its prompt as a user message, so an empty user
// message is appended to the history here — confirm this is intended.
self.send("").await?;
}
}
pub async fn send_message(&mut self, message: Message) -> Result<()> {
self.interrupted.store(false, Ordering::SeqCst);
self.manual_receive_buffer.clear();
self.current_stream = None;
self.history.push(message);
let mut messages = Vec::new();
if !self.options.system_prompt().is_empty() {
messages.push(OpenAIMessage {
role: "system".to_string(),
content: Some(OpenAIContent::Text(
self.options.system_prompt().to_string(),
)),
tool_calls: None,
tool_call_id: None,
});
}
for msg in &self.history {
let mut text_blocks = Vec::new();
let mut image_blocks = Vec::new();
let mut tool_use_blocks = Vec::new();
let mut tool_result_blocks = Vec::new();
for block in &msg.content {
match block {
ContentBlock::Text(text) => text_blocks.push(text),
ContentBlock::Image(image) => image_blocks.push(image),
ContentBlock::ToolUse(tool_use) => tool_use_blocks.push(tool_use),
ContentBlock::ToolResult(tool_result) => tool_result_blocks.push(tool_result),
}
}
if !tool_result_blocks.is_empty() {
for tool_result in tool_result_blocks {
let content =
serde_json::to_string(tool_result.content()).unwrap_or_else(|e| {
format!("{{\"error\": \"Failed to serialize: {}\"}}", e)
});
messages.push(OpenAIMessage {
role: "tool".to_string(),
content: Some(OpenAIContent::Text(content)),
tool_calls: None,
tool_call_id: Some(tool_result.tool_use_id().to_string()),
});
}
}
else if !tool_use_blocks.is_empty() {
let tool_calls: Vec<OpenAIToolCall> = tool_use_blocks
.iter()
.map(|tool_use| {
let arguments = serde_json::to_string(tool_use.input())
.unwrap_or_else(|_| "{}".to_string());
OpenAIToolCall {
id: tool_use.id().to_string(),
call_type: "function".to_string(),
function: OpenAIFunction {
name: tool_use.name().to_string(),
arguments,
},
}
})
.collect();
let content = if !text_blocks.is_empty() {
let text = text_blocks
.iter()
.map(|t| t.text.as_str())
.collect::<Vec<_>>()
.join("\n");
Some(OpenAIContent::Text(text))
} else {
Some(OpenAIContent::Text(String::new()))
};
messages.push(OpenAIMessage {
role: "assistant".to_string(),
content,
tool_calls: Some(tool_calls),
tool_call_id: None,
});
}
else if !image_blocks.is_empty() {
log::debug!(
"Serializing message with {} image(s) for {:?} role",
image_blocks.len(),
msg.role
);
let mut content_parts = Vec::new();
for block in &msg.content {
match block {
ContentBlock::Text(text) => {
content_parts.push(OpenAIContentPart::text(&text.text));
}
ContentBlock::Image(image) => {
let url_display = if image.url().len() > 100 {
format!("{}... ({} chars)", &image.url()[..100], image.url().len())
} else {
image.url().to_string()
};
let detail_str = match image.detail() {
crate::types::ImageDetail::Low => "low",
crate::types::ImageDetail::High => "high",
crate::types::ImageDetail::Auto => "auto",
};
log::debug!(" - Image: {} (detail: {})", url_display, detail_str);
content_parts.push(OpenAIContentPart::from_image(image));
}
ContentBlock::ToolUse(_) | ContentBlock::ToolResult(_) => {}
}
}
if content_parts.is_empty() {
return Err(Error::other(
"Internal error: Message with images produced empty content array",
));
}
let role_str = match msg.role {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
MessageRole::Tool => "tool",
};
messages.push(OpenAIMessage {
role: role_str.to_string(),
content: Some(OpenAIContent::Parts(content_parts)),
tool_calls: None,
tool_call_id: None,
});
}
else {
let content = text_blocks
.iter()
.map(|t| t.text.as_str())
.collect::<Vec<_>>()
.join("\n");
let role_str = match msg.role {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
MessageRole::Tool => "tool",
};
messages.push(OpenAIMessage {
role: role_str.to_string(),
content: Some(OpenAIContent::Text(content)),
tool_calls: None,
tool_call_id: None,
});
}
}
let tools = if !self.options.tools().is_empty() {
Some(
self.options
.tools()
.iter()
.map(|t| t.to_openai_format())
.collect(),
)
} else {
None
};
let request = OpenAIRequest {
model: self.options.model().to_string(),
messages,
stream: true,
max_tokens: self.options.max_tokens(),
temperature: Some(self.options.temperature()),
tools,
};
let url = format!("{}/chat/completions", self.options.base_url());
let response = self
.http_client
.post(&url)
.header(
"Authorization",
format!("Bearer {}", self.options.api_key()),
)
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(Error::Http)?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_else(|e| {
eprintln!("WARNING: Failed to read error response body: {}", e);
"Unknown error (failed to read response body)".to_string()
});
return Err(Error::api(format!("API error {}: {}", status, body)));
}
let sse_stream = parse_sse_stream(response);
let stream = sse_stream.scan(ToolCallAggregator::new(), |aggregator, chunk_result| {
let result = match chunk_result {
Ok(chunk) => match aggregator.process_chunk(chunk) {
Ok(blocks) => {
if blocks.is_empty() {
Some(None) } else {
Some(Some(Ok(blocks))) }
}
Err(e) => Some(Some(Err(e))),
},
Err(e) => Some(Some(Err(e))),
};
futures::future::ready(result)
});
let stream = stream.filter_map(futures::future::ready);
let stream = stream.flat_map(|result| {
futures::stream::iter(match result {
Ok(blocks) => blocks.into_iter().map(Ok).collect(),
Err(e) => vec![Err(e)],
})
});
self.current_stream = Some(Box::pin(stream));
Ok(())
}
/// Returns the next content block of the current response, or `Ok(None)` at the end of it.
///
/// **Auto-execute mode** (`auto_execute_tools()` is true): the first call of a turn runs the
/// tool-execution loop to completion and buffers the final blocks; that call and the
/// following ones hand the blocks out one by one. Once the buffer has been fully delivered,
/// `Ok(None)` is returned and the buffer is reset so the next request starts a fresh turn.
/// Without that reset the buffer would stay non-empty forever and every later turn would
/// yield nothing.
///
/// **Manual mode**: blocks are read straight from the stream and remembered. When the stream
/// ends they are committed to the history as one assistant message. An interrupt while the
/// stream is still open, or a stream error, discards the remembered blocks instead.
///
/// # Errors
///
/// Propagates stream errors and, in auto-execute mode, errors from the execution loop.
pub async fn receive(&mut self) -> Result<Option<ContentBlock>> {
    if self.options.auto_execute_tools() {
        // Hand out blocks that are still buffered from this turn.
        if let Some(block) = self.auto_exec_buffer.get(self.auto_exec_index).cloned() {
            self.auto_exec_index += 1;
            return Ok(Some(block));
        }

        // A non-empty buffer at this point has been fully delivered: signal the end of the
        // turn once and clear it so the next call can process a new response.
        if !self.auto_exec_buffer.is_empty() {
            self.auto_exec_buffer.clear();
            self.auto_exec_index = 0;
            return Ok(None);
        }

        // Nothing buffered: run the loop. With no active stream it returns no blocks.
        self.auto_exec_buffer = self.auto_execute_loop().await?;
        self.auto_exec_index = 0;
        match self.auto_exec_buffer.first().cloned() {
            Some(block) => {
                self.auto_exec_index = 1;
                Ok(Some(block))
            }
            None => Ok(None),
        }
    } else {
        match self.receive_one().await {
            Err(e) => {
                // A failed response is never committed to the history.
                self.manual_receive_buffer.clear();
                Err(e)
            }
            Ok(Some(block)) => {
                self.manual_receive_buffer.push(block.clone());
                Ok(Some(block))
            }
            Ok(None) => {
                if self.interrupted.load(Ordering::SeqCst) && self.current_stream.is_some() {
                    // Interrupted mid-stream: drop the partial reply.
                    self.current_stream = None;
                    self.manual_receive_buffer.clear();
                } else if !self.manual_receive_buffer.is_empty() {
                    // Stream finished: commit the complete reply to the history.
                    let blocks = std::mem::take(&mut self.manual_receive_buffer);
                    self.history.push(Message::assistant(blocks));
                }
                Ok(None)
            }
        }
    }
}
/// Sets the interrupt flag; while it is set, reading from the response stream yields no
/// further blocks. The flag is cleared again when a new request is sent.
pub fn interrupt(&self) {
self.interrupted.store(true, Ordering::SeqCst);
}
/// Returns a shared handle to the interrupt flag.
///
/// Storing `true` through the handle has the same effect as calling `interrupt`.
pub fn interrupt_handle(&self) -> Arc<AtomicBool> {
    Arc::clone(&self.interrupted)
}
/// Returns the conversation history as a read-only slice.
pub fn history(&self) -> &[Message] {
&self.history
}
/// Returns mutable access to the conversation history, allowing callers to edit it directly.
pub fn history_mut(&mut self) -> &mut Vec<Message> {
&mut self.history
}
/// Returns the options this client was created with.
pub fn options(&self) -> &AgentOptions {
&self.options
}
/// Removes all messages from the history and discards any blocks received in manual mode
/// that have not been committed to the history yet.
pub fn clear_history(&mut self) {
self.history.clear();
self.manual_receive_buffer.clear();
}
/// Records the result of a manually executed tool call in the history.
///
/// Blocks received so far in manual mode (typically the assistant's tool call) are first
/// committed to the history as an assistant message, so the result always follows the call
/// it answers. The result is then stored as a `Tool` message holding a tool-result block.
///
/// The result is stored as a structured block rather than as serialized text because the
/// request serialization emits a `tool` message with `tool_call_id` only for tool-result
/// blocks; a text block under the `Tool` role would be sent without the id.
///
/// # Errors
///
/// Currently always returns `Ok(())`; the `Result` return type is kept for compatibility.
pub fn add_tool_result(&mut self, tool_use_id: &str, content: serde_json::Value) -> Result<()> {
    use crate::types::ToolResultBlock;

    if !self.manual_receive_buffer.is_empty() {
        let blocks = std::mem::take(&mut self.manual_receive_buffer);
        self.history.push(Message::assistant(blocks));
    }

    let result_block = ToolResultBlock::new(tool_use_id, content);
    self.history.push(Message::new(
        MessageRole::Tool,
        vec![ContentBlock::ToolResult(result_block)],
    ));
    Ok(())
}
/// Looks up a configured tool by name; returns `None` when no tool has that name.
pub fn get_tool(&self, name: &str) -> Option<&crate::tools::Tool> {
    let registered = self
        .options
        .tools()
        .iter()
        .find(|candidate| candidate.name() == name)?;
    Some(registered.as_ref())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Options shared by every test: a local endpoint that is never actually contacted.
    fn test_options() -> AgentOptions {
        AgentOptions::builder()
            .system_prompt("Test")
            .model("test-model")
            .base_url("http://localhost:1234/v1")
            .build()
            .unwrap()
    }

    /// Client built from `test_options`.
    fn test_client() -> Client {
        Client::new(test_options()).expect("Should create client successfully")
    }

    /// Wraps pre-made stream items so they can stand in for a server response.
    fn stream_of(items: Vec<Result<ContentBlock>>) -> ContentStream {
        Box::pin(futures::stream::iter(items))
    }

    /// Shorthand for a successful text block item.
    fn text_item(text: &str) -> Result<ContentBlock> {
        Ok(ContentBlock::Text(TextBlock::new(text)))
    }

    #[test]
    fn test_client_creation() {
        let client = test_client();
        assert_eq!(client.history().len(), 0);
    }

    #[test]
    fn test_client_new_returns_result() {
        let result = Client::new(test_options());
        assert!(result.is_ok(), "Client::new() should return Ok");
        let client = result.unwrap();
        assert_eq!(client.history().len(), 0);
    }

    #[test]
    fn test_interrupt_flag_initial_state() {
        let client = test_client();
        assert!(!client.interrupted.load(Ordering::SeqCst));
    }

    #[test]
    fn test_interrupt_sets_flag() {
        let client = test_client();
        client.interrupt();
        assert!(client.interrupted.load(Ordering::SeqCst));
    }

    #[test]
    fn test_interrupt_idempotent() {
        let client = test_client();
        client.interrupt();
        assert!(client.interrupted.load(Ordering::SeqCst));
        client.interrupt();
        assert!(client.interrupted.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_receive_returns_none_when_interrupted() {
        let mut client = test_client();
        client.interrupt();
        let result = client.receive().await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_receive_returns_ok_none_when_no_stream() {
        let mut client = test_client();
        let result = client.receive().await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_receive_error_propagation() {
        let client = test_client();
        let _: Result<Option<ContentBlock>> = std::future::ready(Ok(None)).await;
        drop(client);
    }

    #[tokio::test]
    async fn test_manual_mode_adds_assistant_to_history() {
        let mut client = test_client();
        client
            .history
            .push(Message::user("What's the capital of France?"));
        client.current_stream = Some(stream_of(vec![
            text_item("Paris is"),
            text_item(" the capital of France."),
        ]));

        let mut received = Vec::new();
        while let Some(block) = client.receive().await.unwrap() {
            received.push(block);
        }

        assert_eq!(received.len(), 2);
        assert_eq!(client.history().len(), 2);
        assert_eq!(client.history()[0].role, MessageRole::User);
        assert_eq!(client.history()[1].role, MessageRole::Assistant);
        assert_eq!(client.history()[1].content.len(), 2);
    }

    #[tokio::test]
    async fn test_manual_mode_empty_stream_no_assistant_message() {
        let mut client = test_client();
        let result = client.receive().await.unwrap();
        assert!(result.is_none());
        assert_eq!(client.history().len(), 0);
    }

    #[tokio::test]
    async fn test_manual_mode_tool_call_flushed_on_send() {
        let mut client = test_client();
        client.history.push(Message::user("Calculate 2+2"));
        let tool_use =
            crate::types::ToolUseBlock::new("call_1", "calculator", serde_json::json!({"a": 2}));
        client.current_stream = Some(stream_of(vec![Ok(ContentBlock::ToolUse(tool_use))]));

        let block = client.receive().await.unwrap().unwrap();
        assert!(matches!(block, ContentBlock::ToolUse(_)));
        assert_eq!(client.manual_receive_buffer.len(), 1);
        assert_eq!(client.history().len(), 1);

        client
            .add_tool_result("call_1", serde_json::json!({"result": 4}))
            .unwrap();

        assert_eq!(client.history().len(), 3);
        assert_eq!(client.history()[0].role, MessageRole::User);
        assert_eq!(client.history()[1].role, MessageRole::Assistant);
        assert!(matches!(
            client.history()[1].content[0],
            ContentBlock::ToolUse(_)
        ));
        assert_eq!(client.history()[2].role, MessageRole::Tool);
        assert!(client.manual_receive_buffer.is_empty());
    }

    #[tokio::test]
    async fn test_manual_mode_interrupt_discards_buffer() {
        let mut client = test_client();
        client.history.push(Message::user("Tell me a story"));
        client.current_stream = Some(stream_of(vec![
            text_item("Once upon"),
            text_item(" a time..."),
        ]));

        let block = client.receive().await.unwrap().unwrap();
        assert!(matches!(block, ContentBlock::Text(_)));
        assert_eq!(client.manual_receive_buffer.len(), 1);

        client.interrupt();
        let result = client.receive().await.unwrap();
        assert!(result.is_none());
        assert_eq!(client.history().len(), 1);
        assert_eq!(client.history()[0].role, MessageRole::User);
        assert!(client.manual_receive_buffer.is_empty());
    }

    #[tokio::test]
    async fn test_manual_mode_interrupt_after_eof_commits() {
        let mut client = test_client();
        client.history.push(Message::user("Hello"));
        client.current_stream = Some(stream_of(vec![text_item("Hi there!")]));

        let block = client.receive().await.unwrap().unwrap();
        assert!(matches!(block, ContentBlock::Text(_)));

        let eof = client.receive().await.unwrap();
        assert!(eof.is_none());
        assert!(client.current_stream.is_none());
        assert_eq!(client.history().len(), 2);
        assert_eq!(client.history()[1].role, MessageRole::Assistant);
    }

    #[tokio::test]
    async fn test_manual_mode_send_discards_unfinished_stream() {
        let mut client = test_client();
        client.history.push(Message::user("Tell me everything"));
        client.current_stream = Some(stream_of(vec![
            text_item("First"),
            text_item("Second"),
            text_item("Third"),
        ]));

        let block = client.receive().await.unwrap().unwrap();
        assert!(matches!(block, ContentBlock::Text(_)));
        assert_eq!(client.manual_receive_buffer.len(), 1);

        // Mirrors the state reset performed at the start of a new request.
        client.manual_receive_buffer.clear();
        client.current_stream = None;
        assert_eq!(client.history().len(), 1);
    }

    #[tokio::test]
    async fn test_manual_mode_error_discards_buffer() {
        let mut client = test_client();
        client.history.push(Message::user("Hello"));
        client.current_stream = Some(stream_of(vec![
            text_item("Partial"),
            Err(Error::stream("connection reset")),
        ]));

        let block = client.receive().await.unwrap().unwrap();
        assert!(matches!(block, ContentBlock::Text(_)));
        assert_eq!(client.manual_receive_buffer.len(), 1);

        let err = client.receive().await.unwrap_err();
        assert!(err.to_string().contains("connection reset"));
        assert!(client.manual_receive_buffer.is_empty());
        assert_eq!(client.history().len(), 1);
    }

    #[tokio::test]
    async fn test_clear_history_also_clears_manual_buffer() {
        let mut client = test_client();
        client.history.push(Message::user("Hello"));
        client.current_stream = Some(stream_of(vec![text_item("Hi there")]));

        client.receive().await.unwrap();
        assert_eq!(client.manual_receive_buffer.len(), 1);

        client.clear_history();
        assert!(client.history().is_empty());
        assert!(client.manual_receive_buffer.is_empty());
    }

    #[test]
    fn test_empty_content_parts_protection() {
        use crate::types::{ContentBlock, ImageBlock, Message, MessageRole};

        let img = ImageBlock::from_url("https://example.com/test.jpg").expect("Valid URL");
        let msg = Message::new(MessageRole::User, vec![ContentBlock::Image(img)]);

        let mut content_parts = Vec::new();
        for block in &msg.content {
            match block {
                ContentBlock::Text(text) => {
                    content_parts.push(crate::types::OpenAIContentPart::text(&text.text));
                }
                ContentBlock::Image(image) => {
                    content_parts.push(crate::types::OpenAIContentPart::from_image(image));
                }
                ContentBlock::ToolUse(_) | ContentBlock::ToolResult(_) => {}
            }
        }

        assert!(
            !content_parts.is_empty(),
            "Messages with images should produce non-empty content_parts"
        );
    }
}