use crate::Error;
use crate::hooks::Hooks;
use crate::tools::Tool;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModelName(String);
impl ModelName {
pub fn new(name: impl Into<String>) -> crate::Result<Self> {
let name = name.into();
let trimmed = name.trim();
if trimmed.is_empty() {
return Err(Error::invalid_input(
"Model name cannot be empty or whitespace",
));
}
Ok(ModelName(name))
}
pub fn as_str(&self) -> &str {
&self.0
}
pub fn into_inner(self) -> String {
self.0
}
}
impl std::fmt::Display for ModelName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
/// A base URL that is guaranteed to be non-empty and to start with
/// `http://` or `https://`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BaseUrl(String);

impl BaseUrl {
    /// Validates `url` and wraps it.
    ///
    /// Surrounding whitespace is ignored during validation and is also
    /// removed from the stored value, so the wrapped string always begins
    /// with the validated scheme. (Previously the untrimmed input was stored,
    /// which let a value such as `" http://host"` pass validation while the
    /// stored string did not start with `http`.)
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error when the trimmed input is empty or does
    /// not start with `http://` or `https://`.
    pub fn new(url: impl Into<String>) -> crate::Result<Self> {
        let url = url.into();
        let trimmed = url.trim();
        if trimmed.is_empty() {
            return Err(Error::invalid_input("base_url cannot be empty"));
        }
        if !trimmed.starts_with("http://") && !trimmed.starts_with("https://") {
            return Err(Error::invalid_input(
                "base_url must start with http:// or https://",
            ));
        }
        // Store what was validated, not the raw input.
        Ok(BaseUrl(trimmed.to_owned()))
    }

    /// Borrows the URL as a string slice.
    pub fn as_str(&self) -> &str {
        &self.0
    }

    /// Consumes the wrapper and returns the owned URL.
    pub fn into_inner(self) -> String {
        self.0
    }
}

impl std::fmt::Display for BaseUrl {
    /// Writes the URL verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Temperature(f32);
impl Temperature {
pub fn new(temp: f32) -> crate::Result<Self> {
if !(0.0..=2.0).contains(&temp) {
return Err(Error::invalid_input(
"temperature must be between 0.0 and 2.0",
));
}
Ok(Temperature(temp))
}
pub fn value(&self) -> f32 {
self.0
}
}
impl std::fmt::Display for Temperature {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
/// Resolved agent configuration.
///
/// Values are normally produced by `AgentOptionsBuilder::build`, which
/// validates `model`, `base_url`, `temperature` and `max_tokens`. The
/// `Default` implementation performs no validation and leaves `model` and
/// `base_url` empty. All fields are private and exposed through accessor
/// methods of the same name.
#[derive(Clone)]
pub struct AgentOptions {
    /// System prompt text; empty when none was supplied.
    system_prompt: String,
    /// Model identifier.
    model: String,
    /// Base URL of the API endpoint.
    base_url: String,
    /// API key; masked as `***` in the hand-written `Debug` output.
    api_key: String,
    /// Turn limit (default 1). NOTE(review): exact semantics are defined by
    /// the consumer of these options — confirm there.
    max_turns: u32,
    /// Optional cap on generated tokens.
    max_tokens: Option<u32>,
    /// Sampling temperature; the builder restricts it to `0.0..=2.0`.
    temperature: f32,
    /// Request timeout (default 60). Presumably seconds — TODO confirm
    /// against the code that consumes it.
    timeout: u64,
    /// Tools made available to the agent, shared via `Arc`.
    tools: Vec<Arc<Tool>>,
    /// Whether tool calls are executed automatically (default `false`).
    auto_execute_tools: bool,
    /// Upper bound on tool iterations (default 5).
    max_tool_iterations: u32,
    /// Lifecycle hooks.
    hooks: Hooks,
}
impl std::fmt::Debug for AgentOptions {
    /// Formats the options for diagnostics.
    ///
    /// The API key is always rendered as `***`, and the tool list is
    /// summarised as a count instead of being printed in full.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tool_summary = format!("{} tools", self.tools.len());
        let redacted_key = "***";

        let mut out = f.debug_struct("AgentOptions");
        out.field("system_prompt", &self.system_prompt);
        out.field("model", &self.model);
        out.field("base_url", &self.base_url);
        out.field("api_key", &redacted_key);
        out.field("max_turns", &self.max_turns);
        out.field("max_tokens", &self.max_tokens);
        out.field("temperature", &self.temperature);
        out.field("timeout", &self.timeout);
        out.field("tools", &tool_summary);
        out.field("auto_execute_tools", &self.auto_execute_tools);
        out.field("max_tool_iterations", &self.max_tool_iterations);
        out.field("hooks", &self.hooks);
        out.finish()
    }
}
impl Default for AgentOptions {
fn default() -> Self {
Self {
system_prompt: String::new(),
model: String::new(),
base_url: String::new(),
api_key: "not-needed".to_string(),
max_turns: 1,
max_tokens: Some(4096),
temperature: 0.7,
timeout: 60,
tools: Vec::new(),
auto_execute_tools: false,
max_tool_iterations: 5,
hooks: Hooks::new(),
}
}
}
impl AgentOptions {
pub fn builder() -> AgentOptionsBuilder {
AgentOptionsBuilder::default()
}
pub fn system_prompt(&self) -> &str {
&self.system_prompt
}
pub fn model(&self) -> &str {
&self.model
}
pub fn base_url(&self) -> &str {
&self.base_url
}
pub fn api_key(&self) -> &str {
&self.api_key
}
pub fn max_turns(&self) -> u32 {
self.max_turns
}
pub fn max_tokens(&self) -> Option<u32> {
self.max_tokens
}
pub fn temperature(&self) -> f32 {
self.temperature
}
pub fn timeout(&self) -> u64 {
self.timeout
}
pub fn tools(&self) -> &[Arc<Tool>] {
&self.tools
}
pub fn auto_execute_tools(&self) -> bool {
self.auto_execute_tools
}
pub fn max_tool_iterations(&self) -> u32 {
self.max_tool_iterations
}
pub fn hooks(&self) -> &Hooks {
&self.hooks
}
}
/// Builder for `AgentOptions`.
///
/// Every scalar setting is held as an `Option` so that `build` can tell
/// "not set" apart from an explicit value and substitute its default.
/// `model` and `base_url` have no default and must be provided.
#[derive(Default)]
pub struct AgentOptionsBuilder {
    /// Optional system prompt; `build` falls back to an empty string.
    system_prompt: Option<String>,
    /// Required model identifier.
    model: Option<String>,
    /// Required API base URL.
    base_url: Option<String>,
    /// Optional API key; `build` falls back to `"not-needed"`.
    api_key: Option<String>,
    /// Optional turn limit; `build` falls back to 1.
    max_turns: Option<u32>,
    /// Optional token cap; `build` falls back to 4096.
    max_tokens: Option<u32>,
    /// Optional temperature; `build` falls back to 0.7.
    temperature: Option<f32>,
    /// Optional timeout; `build` falls back to 60.
    timeout: Option<u64>,
    /// Tools accumulated through `tool` / `tools`.
    tools: Vec<Arc<Tool>>,
    /// Optional auto-execute flag; `build` falls back to `false`.
    auto_execute_tools: Option<bool>,
    /// Optional tool-iteration limit; `build` falls back to 5.
    max_tool_iterations: Option<u32>,
    /// Hooks passed through to the built options unchanged.
    hooks: Hooks,
}
impl std::fmt::Debug for AgentOptionsBuilder {
    /// Formats a partial view of the builder: prompt, model, base URL and the
    /// number of tools. Other fields, including the API key, are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tool_summary = format!("{} tools", self.tools.len());

        let mut out = f.debug_struct("AgentOptionsBuilder");
        out.field("system_prompt", &self.system_prompt);
        out.field("model", &self.model);
        out.field("base_url", &self.base_url);
        out.field("tools", &tool_summary);
        out.finish()
    }
}
impl AgentOptionsBuilder {
pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
self.system_prompt = Some(prompt.into());
self
}
pub fn model(mut self, model: impl Into<String>) -> Self {
self.model = Some(model.into());
self
}
pub fn base_url(mut self, url: impl Into<String>) -> Self {
self.base_url = Some(url.into());
self
}
pub fn api_key(mut self, key: impl Into<String>) -> Self {
self.api_key = Some(key.into());
self
}
pub fn max_turns(mut self, turns: u32) -> Self {
self.max_turns = Some(turns);
self
}
pub fn max_tokens(mut self, tokens: u32) -> Self {
self.max_tokens = Some(tokens);
self
}
pub fn temperature(mut self, temp: f32) -> Self {
self.temperature = Some(temp);
self
}
pub fn timeout(mut self, timeout: u64) -> Self {
self.timeout = Some(timeout);
self
}
pub fn auto_execute_tools(mut self, auto: bool) -> Self {
self.auto_execute_tools = Some(auto);
self
}
pub fn max_tool_iterations(mut self, iterations: u32) -> Self {
self.max_tool_iterations = Some(iterations);
self
}
pub fn tool(mut self, tool: Tool) -> Self {
self.tools.push(Arc::new(tool));
self
}
pub fn tools(mut self, tools: Vec<Tool>) -> Self {
self.tools.extend(tools.into_iter().map(Arc::new));
self
}
pub fn hooks(mut self, hooks: Hooks) -> Self {
self.hooks = hooks;
self
}
pub fn build(self) -> crate::Result<AgentOptions> {
let model = self
.model
.ok_or_else(|| crate::Error::config("model is required"))?;
let base_url = self
.base_url
.ok_or_else(|| crate::Error::config("base_url is required"))?;
if model.trim().is_empty() {
return Err(crate::Error::invalid_input(
"model cannot be empty or whitespace",
));
}
if base_url.trim().is_empty() {
return Err(crate::Error::invalid_input("base_url cannot be empty"));
}
if !base_url.starts_with("http://") && !base_url.starts_with("https://") {
return Err(crate::Error::invalid_input(
"base_url must start with http:// or https://",
));
}
let temperature = self.temperature.unwrap_or(0.7);
if !(0.0..=2.0).contains(&temperature) {
return Err(crate::Error::invalid_input(
"temperature must be between 0.0 and 2.0",
));
}
let max_tokens = self.max_tokens.or(Some(4096));
if let Some(tokens) = max_tokens {
if tokens == 0 {
return Err(crate::Error::invalid_input(
"max_tokens must be greater than 0",
));
}
}
Ok(AgentOptions {
system_prompt: self.system_prompt.unwrap_or_default(),
model,
base_url,
api_key: self.api_key.unwrap_or_else(|| "not-needed".to_string()),
max_turns: self.max_turns.unwrap_or(1),
max_tokens,
temperature,
timeout: self.timeout.unwrap_or(60),
tools: self.tools,
auto_execute_tools: self.auto_execute_tools.unwrap_or(false),
max_tool_iterations: self.max_tool_iterations.unwrap_or(5),
hooks: self.hooks,
})
}
}
/// Author role of a [`Message`].
///
/// Serialized in lowercase (`"system"`, `"user"`, `"assistant"`, `"tool"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// System message.
    System,
    /// Message written by the user.
    User,
    /// Message produced by the assistant.
    Assistant,
    /// Message carrying a tool's output.
    Tool,
}
/// One piece of message content.
///
/// Serialized as an internally tagged object: the variant name is written in
/// snake_case to a `"type"` field (for example `"text"` or `"tool_use"`)
/// alongside the fields of the wrapped block.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text.
    Text(TextBlock),
    /// An image reference.
    Image(ImageBlock),
    /// A tool invocation.
    ToolUse(ToolUseBlock),
    /// The result of a tool invocation.
    ToolResult(ToolResultBlock),
}
/// A block of plain text content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextBlock {
    /// The text itself.
    pub text: String,
}

impl TextBlock {
    /// Creates a text block from anything convertible into a `String`.
    pub fn new(text: impl Into<String>) -> Self {
        let text: String = text.into();
        TextBlock { text }
    }
}
/// A tool invocation: call id, tool name and JSON input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolUseBlock {
    /// Identifier of this call.
    id: String,
    /// Name of the tool being called.
    name: String,
    /// JSON input supplied to the tool.
    input: serde_json::Value,
}

impl ToolUseBlock {
    /// Creates a tool-use block from its three components.
    pub fn new(id: impl Into<String>, name: impl Into<String>, input: serde_json::Value) -> Self {
        let id: String = id.into();
        let name: String = name.into();
        ToolUseBlock { id, name, input }
    }

    /// Returns the call identifier.
    pub fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Returns the tool name.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Borrows the JSON input.
    pub fn input(&self) -> &serde_json::Value {
        let Self { input, .. } = self;
        input
    }
}
/// The result of a tool invocation, keyed by the id of the call it answers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultBlock {
    /// Identifier of the tool-use call this result belongs to.
    tool_use_id: String,
    /// JSON payload of the result.
    content: serde_json::Value,
}

impl ToolResultBlock {
    /// Creates a result block for the call identified by `tool_use_id`.
    pub fn new(tool_use_id: impl Into<String>, content: serde_json::Value) -> Self {
        let tool_use_id: String = tool_use_id.into();
        ToolResultBlock {
            tool_use_id,
            content,
        }
    }

    /// Returns the id of the originating tool-use call.
    pub fn tool_use_id(&self) -> &str {
        self.tool_use_id.as_str()
    }

    /// Borrows the JSON result payload.
    pub fn content(&self) -> &serde_json::Value {
        let Self { content, .. } = self;
        content
    }
}
/// Requested detail level for an image.
///
/// Serialized and displayed in lowercase; the default is [`ImageDetail::Auto`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ImageDetail {
    /// Low detail.
    Low,
    /// High detail.
    High,
    /// Automatic selection (the default).
    #[default]
    Auto,
}

impl std::fmt::Display for ImageDetail {
    /// Writes the lowercase name of the variant (`low`, `high` or `auto`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            ImageDetail::Low => "low",
            ImageDetail::High => "high",
            ImageDetail::Auto => "auto",
        };
        f.write_str(label)
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageBlock {
url: String,
#[serde(default)]
detail: ImageDetail,
}
impl ImageBlock {
pub fn from_url(url: impl Into<String>) -> crate::Result<Self> {
let url = url.into();
if url.is_empty() {
return Err(crate::Error::invalid_input("Image URL cannot be empty"));
}
if url.contains(char::is_control) {
return Err(crate::Error::invalid_input(
"Image URL contains invalid control characters",
));
}
if url.len() > 2000 {
eprintln!(
"WARNING: Very long image URL ({} chars). \
Some APIs may have URL length limits.",
url.len()
);
}
if url.starts_with("http://") || url.starts_with("https://") {
Ok(Self {
url,
detail: ImageDetail::default(),
})
} else if let Some(mime_part) = url.strip_prefix("data:") {
if !url.contains(";base64,") {
return Err(crate::Error::invalid_input(
"Data URI must be in format: data:image/TYPE;base64,DATA",
));
}
let mime_type = if let Some(semicolon_pos) = mime_part.find(';') {
&mime_part[..semicolon_pos]
} else {
return Err(crate::Error::invalid_input(
"Malformed data URI: missing MIME type",
));
};
if mime_type.is_empty() || !mime_type.starts_with("image/") {
return Err(crate::Error::invalid_input(
"Data URI MIME type must start with 'image/'",
));
}
if let Some(base64_start_pos) = url.find(";base64,") {
let base64_data = &url[base64_start_pos + 8..];
if base64_data.is_empty() {
return Err(crate::Error::invalid_input(
"Data URI base64 data cannot be empty",
));
}
if !base64_data
.chars()
.all(|c| c.is_alphanumeric() || c == '+' || c == '/' || c == '=')
{
return Err(crate::Error::invalid_input(
"Data URI base64 data contains invalid characters. Valid characters: A-Z, a-z, 0-9, +, /, =",
));
}
if base64_data.len() % 4 != 0 {
return Err(crate::Error::invalid_input(
"Data URI base64 data has invalid length (must be multiple of 4)",
));
}
let equals_count = base64_data.chars().filter(|c| *c == '=').count();
if equals_count > 2 {
return Err(crate::Error::invalid_input(
"Data URI base64 data has invalid padding (max 2 '=' characters allowed)",
));
}
if equals_count > 0 {
let trimmed = base64_data.trim_end_matches('=');
if trimmed.len() + equals_count != base64_data.len() {
return Err(crate::Error::invalid_input(
"Data URI base64 padding characters must be at the end",
));
}
}
}
Ok(Self {
url,
detail: ImageDetail::default(),
})
} else {
Err(crate::Error::invalid_input(
"Image URL must start with http://, https://, or data:",
))
}
}
pub fn from_base64(
base64_data: impl AsRef<str>,
mime_type: impl AsRef<str>,
) -> crate::Result<Self> {
let data = base64_data.as_ref();
let mime = mime_type.as_ref();
if data.is_empty() {
return Err(crate::Error::invalid_input(
"Base64 image data cannot be empty",
));
}
if !data
.chars()
.all(|c| c.is_alphanumeric() || c == '+' || c == '/' || c == '=')
{
return Err(crate::Error::invalid_input(
"Base64 data contains invalid characters. Valid characters: A-Z, a-z, 0-9, +, /, =",
));
}
if data.len() % 4 != 0 {
return Err(crate::Error::invalid_input(
"Base64 data has invalid length (must be multiple of 4)",
));
}
let equals_count = data.chars().filter(|c| *c == '=').count();
if equals_count > 2 {
return Err(crate::Error::invalid_input(
"Base64 data has invalid padding (max 2 '=' characters allowed)",
));
}
if equals_count > 0 {
let trimmed = data.trim_end_matches('=');
if trimmed.len() + equals_count != data.len() {
return Err(crate::Error::invalid_input(
"Base64 padding characters must be at the end",
));
}
}
if mime.is_empty() {
return Err(crate::Error::invalid_input("MIME type cannot be empty"));
}
if !mime.starts_with("image/") {
return Err(crate::Error::invalid_input(
"MIME type must start with 'image/' (e.g., 'image/png', 'image/jpeg')",
));
}
if mime.contains([';', ',', '\n', '\r']) {
return Err(crate::Error::invalid_input(
"MIME type contains invalid characters (;, \\n, \\r not allowed)",
));
}
if data.len() > 10_000_000 {
eprintln!(
"WARNING: Very large base64 image data ({} chars, ~{:.1}MB). \
This may exceed API limits or cause performance issues.",
data.len(),
(data.len() as f64 * 0.75) / 1_000_000.0
);
}
let url = format!("data:{};base64,{}", mime, data);
Ok(Self {
url,
detail: ImageDetail::default(),
})
}
pub fn from_file_path(path: impl AsRef<std::path::Path>) -> crate::Result<Self> {
use base64::{Engine as _, engine::general_purpose};
let path = path.as_ref();
let bytes = std::fs::read(path).map_err(|e| {
crate::Error::invalid_input(format!(
"Failed to read image file '{}': {}",
path.display(),
e
))
})?;
let mime_type = match path.extension().and_then(|e| e.to_str()) {
Some("jpg") | Some("jpeg") => "image/jpeg",
Some("png") => "image/png",
Some("gif") => "image/gif",
Some("webp") => "image/webp",
Some("bmp") => "image/bmp",
Some("svg") => "image/svg+xml",
Some(ext) => {
return Err(crate::Error::invalid_input(format!(
"Unsupported image file extension: .{}. Supported: jpg, jpeg, png, gif, webp, bmp, svg",
ext
)));
}
None => {
return Err(crate::Error::invalid_input(
"Image file path must have a file extension (e.g., .jpg, .png)",
));
}
};
let base64_data = general_purpose::STANDARD.encode(&bytes);
Self::from_base64(&base64_data, mime_type)
}
pub fn with_detail(mut self, detail: ImageDetail) -> Self {
self.detail = detail;
self
}
pub fn url(&self) -> &str {
&self.url
}
pub fn detail(&self) -> ImageDetail {
self.detail
}
}
/// A conversation message: a role plus an ordered list of content blocks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who authored the message.
    pub role: MessageRole,
    /// The content blocks, in order.
    pub content: Vec<ContentBlock>,
}

impl Message {
    /// Creates a message with an explicit role and content.
    pub fn new(role: MessageRole, content: Vec<ContentBlock>) -> Self {
        Message { role, content }
    }

    /// Creates a user message holding a single text block.
    pub fn user(text: impl Into<String>) -> Self {
        let block = ContentBlock::Text(TextBlock::new(text));
        Message {
            role: MessageRole::User,
            content: vec![block],
        }
    }

    /// Creates an assistant message from the given blocks.
    pub fn assistant(content: Vec<ContentBlock>) -> Self {
        Message {
            role: MessageRole::Assistant,
            content,
        }
    }

    /// Creates a system message holding a single text block.
    pub fn system(text: impl Into<String>) -> Self {
        let block = ContentBlock::Text(TextBlock::new(text));
        Message {
            role: MessageRole::System,
            content: vec![block],
        }
    }

    /// Creates a user message from the given blocks.
    pub fn user_with_blocks(content: Vec<ContentBlock>) -> Self {
        Message {
            role: MessageRole::User,
            content,
        }
    }

    /// Creates a user message with a text block followed by an image loaded
    /// from `image_url` at the default detail level.
    ///
    /// # Errors
    ///
    /// Propagates the validation error from [`ImageBlock::from_url`].
    pub fn user_with_image(
        text: impl Into<String>,
        image_url: impl Into<String>,
    ) -> crate::Result<Self> {
        let prompt = ContentBlock::Text(TextBlock::new(text));
        let image = ContentBlock::Image(ImageBlock::from_url(image_url)?);
        Ok(Message {
            role: MessageRole::User,
            content: vec![prompt, image],
        })
    }

    /// Like [`Message::user_with_image`], but with an explicit detail level.
    ///
    /// # Errors
    ///
    /// Propagates the validation error from [`ImageBlock::from_url`].
    pub fn user_with_image_detail(
        text: impl Into<String>,
        image_url: impl Into<String>,
        detail: ImageDetail,
    ) -> crate::Result<Self> {
        let prompt = ContentBlock::Text(TextBlock::new(text));
        let image = ImageBlock::from_url(image_url)?.with_detail(detail);
        Ok(Message {
            role: MessageRole::User,
            content: vec![prompt, ContentBlock::Image(image)],
        })
    }

    /// Creates a user message with a text block followed by an image built
    /// from base64 data and a MIME type.
    ///
    /// # Errors
    ///
    /// Propagates the validation error from [`ImageBlock::from_base64`].
    pub fn user_with_base64_image(
        text: impl Into<String>,
        base64_data: impl AsRef<str>,
        mime_type: impl AsRef<str>,
    ) -> crate::Result<Self> {
        let prompt = ContentBlock::Text(TextBlock::new(text));
        let image = ContentBlock::Image(ImageBlock::from_base64(base64_data, mime_type)?);
        Ok(Message {
            role: MessageRole::User,
            content: vec![prompt, image],
        })
    }
}
/// Message content in the OpenAI wire format.
///
/// Untagged: `Text` serializes as a bare JSON string and `Parts` as a JSON
/// array of content parts.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum OpenAIContent {
    /// Plain string content.
    Text(String),
    /// A list of typed content parts.
    Parts(Vec<OpenAIContentPart>),
}
/// One element of a multi-part OpenAI message content array.
///
/// Internally tagged through a `"type"` field whose value is `"text"` or
/// `"image_url"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum OpenAIContentPart {
    /// A text part.
    Text {
        /// The text content.
        text: String,
    },
    /// An image part.
    #[serde(rename = "image_url")]
    ImageUrl {
        /// URL and optional detail level of the image.
        image_url: OpenAIImageUrl,
    },
}

impl OpenAIContentPart {
    /// Creates a text part.
    pub fn text(text: impl Into<String>) -> Self {
        let text: String = text.into();
        OpenAIContentPart::Text { text }
    }

    /// Creates an image part from an [`ImageBlock`], copying its URL and
    /// rendering its detail level as a lowercase string.
    pub fn from_image(image: &ImageBlock) -> Self {
        let image_url = OpenAIImageUrl {
            url: image.url().to_owned(),
            detail: Some(image.detail().to_string()),
        };
        OpenAIContentPart::ImageUrl { image_url }
    }

    /// Creates an image part directly from a URL without validating it.
    #[deprecated(
        since = "0.6.0",
        note = "Use `from_image()` instead to ensure proper validation"
    )]
    pub fn image_url(url: impl Into<String>, detail: ImageDetail) -> Self {
        let image_url = OpenAIImageUrl {
            url: url.into(),
            detail: Some(detail.to_string()),
        };
        OpenAIContentPart::ImageUrl { image_url }
    }
}
/// The `image_url` object of an OpenAI image content part.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIImageUrl {
    /// Image URL or data URI.
    pub url: String,
    /// Optional detail level; left out of the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
/// A chat message in the OpenAI wire format.
///
/// Every optional field is left out of the serialized JSON when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIMessage {
    /// Role name as a plain string (for example `"user"`).
    pub role: String,
    /// Message content, either a string or a list of parts.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<OpenAIContent>,
    /// Tool calls attached to the message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<OpenAIToolCall>>,
    /// Id of the tool call this message refers to.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
/// A tool call in the OpenAI wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIToolCall {
    /// Identifier of the call.
    pub id: String,
    /// Call kind; appears as `"type"` in JSON because `type` is a Rust keyword.
    #[serde(rename = "type")]
    pub call_type: String,
    /// The function being called.
    pub function: OpenAIFunction,
}
/// Function name and arguments of an OpenAI tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIFunction {
    /// Name of the function.
    pub name: String,
    /// Arguments carried as a string. Presumably JSON-encoded — confirm
    /// against the code that parses it.
    pub arguments: String,
}
/// Request body in the OpenAI wire format (serialize-only).
///
/// Optional fields are left out of the serialized JSON when they are `None`.
#[derive(Debug, Clone, Serialize)]
pub struct OpenAIRequest {
    /// Model identifier.
    pub model: String,
    /// Conversation messages.
    pub messages: Vec<OpenAIMessage>,
    /// Value of the `stream` field in the request.
    pub stream: bool,
    /// Optional cap on generated tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Optional tool definitions as raw JSON values.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<serde_json::Value>>,
}
/// One response chunk in the OpenAI wire format (deserialize-only).
///
/// All five fields are required when deserializing. The fields marked
/// `#[allow(dead_code)]` are parsed but apparently not read elsewhere.
#[derive(Debug, Clone, Deserialize)]
pub struct OpenAIChunk {
    /// Chunk identifier.
    #[allow(dead_code)]
    pub id: String,
    /// Object type string (for example `"chat.completion.chunk"`).
    #[allow(dead_code)]
    pub object: String,
    /// Creation timestamp. Presumably Unix seconds — TODO confirm.
    #[allow(dead_code)]
    pub created: i64,
    /// Model that produced the chunk.
    #[allow(dead_code)]
    pub model: String,
    /// The choices carried by this chunk.
    pub choices: Vec<OpenAIChoice>,
}
/// One choice inside an [`OpenAIChunk`].
#[derive(Debug, Clone, Deserialize)]
pub struct OpenAIChoice {
    /// Position of this choice in the response.
    #[allow(dead_code)]
    pub index: u32,
    /// Incremental content for this choice.
    pub delta: OpenAIDelta,
    /// Finish reason; `None` when the field is null or absent.
    pub finish_reason: Option<String>,
}
/// Incremental update carried by a chunk choice (deserialize-only).
///
/// Every field is optional: the derived `Deserialize` maps a missing or null
/// `Option` field to `None`. The former `skip_serializing_if` attributes were
/// removed because this type does not derive `Serialize`, so they had no
/// effect.
#[derive(Debug, Clone, Deserialize)]
pub struct OpenAIDelta {
    /// Role announced by the delta, if any.
    #[allow(dead_code)]
    pub role: Option<String>,
    /// Text fragment, if any.
    pub content: Option<String>,
    /// Tool-call fragments, if any.
    pub tool_calls: Option<Vec<OpenAIToolCallDelta>>,
}
/// Incremental fragment of a tool call (deserialize-only).
///
/// Only `index` is required; a missing or null `Option` field deserializes to
/// `None`. The former `skip_serializing_if` attributes were removed because
/// this type does not derive `Serialize`, so they had no effect.
#[derive(Debug, Clone, Deserialize)]
pub struct OpenAIToolCallDelta {
    /// Index of the tool call this fragment belongs to.
    pub index: u32,
    /// Call identifier, when present in this fragment.
    pub id: Option<String>,
    /// Call kind; read from the JSON field `"type"`.
    #[allow(dead_code)]
    #[serde(rename = "type")]
    pub call_type: Option<String>,
    /// Function name/argument fragment, when present.
    pub function: Option<OpenAIFunctionDelta>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct OpenAIFunctionDelta {
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<String>,
}
#[cfg(test)]
mod tests {
use super::*;
/// Every value passed to a builder setter must show up in the built options.
#[test]
fn test_agent_options_builder() {
    let builder = AgentOptions::builder()
        .system_prompt("Test prompt")
        .model("test-model")
        .base_url("http://localhost:1234/v1")
        .api_key("test-key");
    let builder = builder
        .max_turns(5)
        .max_tokens(1000)
        .temperature(0.5)
        .timeout(30);
    let built = builder
        .auto_execute_tools(true)
        .max_tool_iterations(10)
        .build()
        .unwrap();

    // String settings.
    assert_eq!(built.system_prompt, "Test prompt");
    assert_eq!(built.model, "test-model");
    assert_eq!(built.base_url, "http://localhost:1234/v1");
    assert_eq!(built.api_key, "test-key");

    // Numeric settings.
    assert_eq!(built.max_turns, 5);
    assert_eq!(built.max_tokens, Some(1000));
    assert_eq!(built.temperature, 0.5);
    assert_eq!(built.timeout, 30);

    // Tool settings.
    assert!(built.auto_execute_tools);
    assert_eq!(built.max_tool_iterations, 10);
}
/// With only the required settings supplied, the documented defaults apply.
#[test]
fn test_agent_options_builder_defaults() {
    let built = AgentOptions::builder()
        .base_url("http://localhost:1234/v1")
        .model("test-model")
        .build()
        .unwrap();

    assert_eq!(built.system_prompt, "");
    assert_eq!(built.api_key, "not-needed");
    assert_eq!(built.max_turns, 1);
    assert_eq!(built.max_tokens, Some(4096));
    assert_eq!(built.temperature, 0.7);
    assert_eq!(built.timeout, 60);
    assert!(!built.auto_execute_tools);
    assert_eq!(built.max_tool_iterations, 5);
}
/// `build` must fail when either `model` or `base_url` is missing.
#[test]
fn test_agent_options_builder_missing_required() {
    // No model.
    let without_model = AgentOptions::builder().base_url("http://localhost:1234/v1");
    assert!(without_model.build().is_err());

    // No base URL.
    let without_base_url = AgentOptions::builder().model("test-model");
    assert!(without_base_url.build().is_err());
}
/// `Message::user` yields a user message with exactly one text block.
#[test]
fn test_message_user() {
    let message = Message::user("Hello");
    assert!(matches!(message.role, MessageRole::User));
    assert_eq!(message.content.len(), 1);

    if let ContentBlock::Text(block) = &message.content[0] {
        assert_eq!(block.text, "Hello");
    } else {
        panic!("Expected TextBlock");
    }
}
/// `Message::system` yields a system message with exactly one text block.
#[test]
fn test_message_system() {
    let message = Message::system("System prompt");
    assert!(matches!(message.role, MessageRole::System));
    assert_eq!(message.content.len(), 1);

    if let ContentBlock::Text(block) = &message.content[0] {
        assert_eq!(block.text, "System prompt");
    } else {
        panic!("Expected TextBlock");
    }
}
/// `Message::assistant` keeps the supplied blocks and sets the assistant role.
#[test]
fn test_message_assistant() {
    let reply = ContentBlock::Text(TextBlock::new("Response"));
    let message = Message::assistant(vec![reply]);

    assert!(matches!(message.role, MessageRole::Assistant));
    assert_eq!(message.content.len(), 1);
}
/// `user_with_image` produces a text block followed by an image block with
/// the default detail level.
#[test]
fn test_message_user_with_image() {
    let message =
        Message::user_with_image("What's in this image?", "https://example.com/image.jpg")
            .unwrap();

    assert!(matches!(message.role, MessageRole::User));
    assert_eq!(message.content.len(), 2);

    if let ContentBlock::Text(block) = &message.content[0] {
        assert_eq!(block.text, "What's in this image?");
    } else {
        panic!("Expected TextBlock at position 0");
    }

    if let ContentBlock::Image(block) = &message.content[1] {
        assert_eq!(block.url(), "https://example.com/image.jpg");
        assert_eq!(block.detail(), ImageDetail::Auto);
    } else {
        panic!("Expected ImageBlock at position 1");
    }
}
/// `user_with_image_detail` forwards the requested detail level to the image.
#[test]
fn test_message_user_with_image_and_detail() {
    let built = Message::user_with_image_detail(
        "Analyze this in detail",
        "https://example.com/diagram.png",
        ImageDetail::High,
    );
    let message = built.unwrap();

    assert!(matches!(message.role, MessageRole::User));
    assert_eq!(message.content.len(), 2);

    if let ContentBlock::Image(block) = &message.content[1] {
        assert_eq!(block.detail(), ImageDetail::High);
    } else {
        panic!("Expected ImageBlock");
    }
}
/// `user_with_base64_image` embeds the payload in a `data:image/png` URI.
#[test]
fn test_message_user_with_base64_image() {
    let payload = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJ";
    let message = Message::user_with_base64_image("What's this?", payload, "image/png").unwrap();

    assert!(matches!(message.role, MessageRole::User));
    assert_eq!(message.content.len(), 2);

    if let ContentBlock::Image(block) = &message.content[1] {
        let uri = block.url();
        assert!(uri.starts_with("data:image/png;base64,"));
        assert!(uri.contains(payload));
    } else {
        panic!("Expected ImageBlock");
    }
}
/// `TextBlock::new` stores the given text unchanged.
#[test]
fn test_text_block() {
    let TextBlock { text } = TextBlock::new("Hello");
    assert_eq!(text, "Hello");
}
/// A `ToolUseBlock` reports the id, name and input it was built with.
#[test]
fn test_tool_use_block() {
    let expected_input = serde_json::json!({"arg": "value"});
    let call = ToolUseBlock::new("call_123", "tool_name", expected_input.clone());

    assert_eq!(call.id(), "call_123");
    assert_eq!(call.name(), "tool_name");
    assert_eq!(call.input(), &expected_input);
}
/// A `ToolResultBlock` reports the call id and content it was built with.
#[test]
fn test_tool_result_block() {
    let expected_content = serde_json::json!({"result": "success"});
    let outcome = ToolResultBlock::new("call_123", expected_content.clone());

    assert_eq!(outcome.tool_use_id(), "call_123");
    assert_eq!(outcome.content(), &expected_content);
}
/// Getter coverage for `ToolUseBlock` with a numeric input payload.
#[test]
fn test_tool_use_block_getters() {
    let expected_input = serde_json::json!({"x": 5});
    let call = ToolUseBlock::new("call_123", "calculator", expected_input.clone());

    assert_eq!(call.id(), "call_123");
    assert_eq!(call.name(), "calculator");
    assert_eq!(call.input(), &expected_input);
}
/// Getter coverage for `ToolResultBlock` with a numeric result payload.
#[test]
fn test_tool_result_block_getters() {
    let expected_content = serde_json::json!({"answer": 42});
    let outcome = ToolResultBlock::new("call_123", expected_content.clone());

    assert_eq!(outcome.tool_use_id(), "call_123");
    assert_eq!(outcome.content(), &expected_content);
}
/// Each role serializes to its lowercase JSON string.
#[test]
fn test_message_role_serialization() {
    let expectations = [
        (MessageRole::User, "\"user\""),
        (MessageRole::System, "\"system\""),
        (MessageRole::Assistant, "\"assistant\""),
        (MessageRole::Tool, "\"tool\""),
    ];

    for (role, expected_json) in expectations.iter() {
        assert_eq!(serde_json::to_string(role).unwrap(), *expected_json);
    }
}
/// A serialized request contains the model, the message text and the stream
/// flag.
#[test]
fn test_openai_request_serialization() {
    let user_message = OpenAIMessage {
        role: "user".to_string(),
        content: Some(OpenAIContent::Text("Hello".to_string())),
        tool_calls: None,
        tool_call_id: None,
    };
    let request = OpenAIRequest {
        model: "gpt-3.5".to_string(),
        messages: vec![user_message],
        stream: true,
        max_tokens: Some(100),
        temperature: Some(0.7),
        tools: None,
    };

    let serialized = serde_json::to_string(&request).unwrap();
    for needle in ["gpt-3.5", "Hello", "\"stream\":true"].iter() {
        assert!(serialized.contains(*needle));
    }
}
/// A minimal chunk deserializes, with fields missing from the delta left as
/// `None`.
#[test]
fn test_openai_chunk_deserialization() {
    // The raw literal is kept flush-left so its bytes stay unchanged.
    let payload = r#"{
"id": "chunk_1",
"object": "chat.completion.chunk",
"created": 1234567890,
"model": "gpt-3.5",
"choices": [{
"index": 0,
"delta": {
"content": "Hello"
},
"finish_reason": null
}]
}"#;
    let parsed: OpenAIChunk = serde_json::from_str(payload).unwrap();

    assert_eq!(parsed.id, "chunk_1");
    assert_eq!(parsed.choices.len(), 1);

    let first_choice = &parsed.choices[0];
    assert_eq!(first_choice.delta.content, Some("Hello".to_string()));
}
/// A text content block serializes with a `"type":"text"` tag.
#[test]
fn test_content_block_serialization() {
    let block = ContentBlock::Text(TextBlock::new("Hello"));
    let serialized = serde_json::to_string(&block).unwrap();

    assert!(serialized.contains("\"type\":\"text\""));
    assert!(serialized.contains("Hello"));
}
/// Cloning options preserves the model and base URL.
#[test]
fn test_agent_options_clone() {
    let original = AgentOptions::builder()
        .base_url("http://localhost:1234/v1")
        .model("test-model")
        .build()
        .unwrap();
    let copy = original.clone();

    assert_eq!(original.model, copy.model);
    assert_eq!(original.base_url, copy.base_url);
}
/// Temperatures outside `0.0..=2.0` are rejected; both bounds are accepted.
#[test]
fn test_temperature_validation() {
    let build_with = |temp: f32| {
        AgentOptions::builder()
            .model("test-model")
            .base_url("http://localhost:1234/v1")
            .temperature(temp)
            .build()
    };

    // Just below and just above the allowed range.
    for &out_of_range in &[-0.1_f32, 2.1] {
        let result = build_with(out_of_range);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("temperature"));
    }

    // The bounds themselves are valid.
    for &boundary in &[0.0_f32, 2.0] {
        let result = build_with(boundary);
        assert!(result.is_ok());
    }
}
/// Empty or scheme-less base URLs are rejected; http and https are accepted.
#[test]
fn test_url_validation() {
    let build_with = |url: &str| {
        AgentOptions::builder()
            .model("test-model")
            .base_url(url)
            .build()
    };

    for &rejected in &["", "not-a-url"] {
        let result = build_with(rejected);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("base_url"));
    }

    for &accepted in &["http://localhost:1234/v1", "https://api.openai.com/v1"] {
        let result = build_with(accepted);
        assert!(result.is_ok());
    }
}
/// Empty and whitespace-only model names are rejected.
#[test]
fn test_model_validation() {
    for &blank in &["", " "] {
        let result = AgentOptions::builder()
            .model(blank)
            .base_url("http://localhost:1234/v1")
            .build();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("model"));
    }
}
/// `max_tokens` of zero is rejected; one is the smallest accepted value.
#[test]
fn test_max_tokens_validation() {
    let build_with = |tokens: u32| {
        AgentOptions::builder()
            .model("test-model")
            .base_url("http://localhost:1234/v1")
            .max_tokens(tokens)
            .build()
    };

    let zero = build_with(0);
    assert!(zero.is_err());
    assert!(zero.unwrap_err().to_string().contains("max_tokens"));

    let one = build_with(1);
    assert!(one.is_ok());
}
/// Each accessor returns the value supplied to the matching builder setter.
#[test]
fn test_agent_options_getters() {
    let builder = AgentOptions::builder()
        .model("test-model")
        .base_url("http://localhost:1234/v1")
        .system_prompt("Test prompt")
        .api_key("test-key");
    let builder = builder
        .max_turns(5)
        .max_tokens(1000)
        .temperature(0.5)
        .timeout(30);
    let built = builder
        .auto_execute_tools(true)
        .max_tool_iterations(10)
        .build()
        .unwrap();

    assert_eq!(built.system_prompt(), "Test prompt");
    assert_eq!(built.model(), "test-model");
    assert_eq!(built.base_url(), "http://localhost:1234/v1");
    assert_eq!(built.api_key(), "test-key");

    assert_eq!(built.max_turns(), 5);
    assert_eq!(built.max_tokens(), Some(1000));
    assert_eq!(built.temperature(), 0.5);
    assert_eq!(built.timeout(), 30);

    assert!(built.auto_execute_tools());
    assert_eq!(built.max_tool_iterations(), 10);
    assert_eq!(built.tools().len(), 0);
}
/// An https URL is accepted verbatim with the default detail level.
#[test]
fn test_image_block_from_url() {
    let source = "https://example.com/image.jpg";
    let image = ImageBlock::from_url(source).unwrap();

    assert_eq!(image.url(), "https://example.com/image.jpg");
    assert!(matches!(image.detail(), ImageDetail::Auto));
}
/// Base64 data is wrapped in a data URI that carries the given MIME type.
#[test]
fn test_image_block_from_base64() {
    let image = ImageBlock::from_base64("iVBORw0KGgoAAAA=", "image/jpeg").unwrap();
    let uri = image.url();

    assert!(uri.starts_with("data:image/jpeg;base64,"));
    assert!(matches!(image.detail(), ImageDetail::Auto));
}
/// `from_file_path` encodes a PNG file and rejects paths with a missing or
/// unsupported extension.
///
/// Files live in a per-process scratch directory that a drop guard removes,
/// so concurrent runs do not collide on fixed names in the shared temp
/// directory and cleanup still happens when an assertion panics.
#[test]
fn test_image_block_from_file_path() {
    use base64::{Engine as _, engine::general_purpose};

    /// Removes the wrapped directory (best effort) when dropped.
    struct ScratchDir(std::path::PathBuf);
    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    let scratch = ScratchDir(std::env::temp_dir().join(format!(
        "image_block_from_file_path_{}",
        std::process::id()
    )));
    std::fs::create_dir_all(&scratch.0).unwrap();

    let png_bytes = general_purpose::STANDARD
        .decode("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==")
        .unwrap();

    // Supported extension: encoded as a PNG data URI.
    let test_file = scratch.0.join("test_image.png");
    std::fs::write(&test_file, &png_bytes).unwrap();
    let block = ImageBlock::from_file_path(&test_file).unwrap();
    assert!(block.url().starts_with("data:image/png;base64,"));
    assert!(matches!(block.detail(), ImageDetail::Auto));

    // No extension at all.
    let no_ext_file = scratch.0.join("test_image_no_ext");
    std::fs::write(&no_ext_file, &png_bytes).unwrap();
    let result = ImageBlock::from_file_path(&no_ext_file);
    assert!(result.is_err());
    assert!(result.unwrap_err().to_string().contains("extension"));

    // Extension that is not an image type.
    let bad_ext_file = scratch.0.join("test_image.txt");
    std::fs::write(&bad_ext_file, &png_bytes).unwrap();
    let result = ImageBlock::from_file_path(&bad_ext_file);
    assert!(result.is_err());
    assert!(result.unwrap_err().to_string().contains("Unsupported"));
}
/// `with_detail` overrides the default detail level.
#[test]
fn test_image_block_with_detail() {
    let base = ImageBlock::from_url("https://example.com/image.jpg").unwrap();
    let detailed = base.with_detail(ImageDetail::High);

    assert!(matches!(detailed.detail(), ImageDetail::High));
}
/// Each detail level serializes to its lowercase JSON string.
#[test]
fn test_image_detail_serialization() {
    let expectations = [
        (ImageDetail::Low, "\"low\""),
        (ImageDetail::High, "\"high\""),
        (ImageDetail::Auto, "\"auto\""),
    ];

    for (detail, expected_json) in expectations.iter() {
        let serialized = serde_json::to_string(detail).unwrap();
        assert_eq!(serialized, *expected_json);
    }
}
/// Wrapping an image in `ContentBlock::Image` preserves its URL.
#[test]
fn test_content_block_image_variant() {
    let wrapped = ContentBlock::Image(ImageBlock::from_url("https://example.com/image.jpg").unwrap());

    if let ContentBlock::Image(inner) = wrapped {
        assert_eq!(inner.url(), "https://example.com/image.jpg");
    } else {
        panic!("Expected Image variant");
    }
}
/// `OpenAIContent::Text` serializes as a bare JSON string.
#[test]
fn test_openai_content_text_format() {
    let serialized = serde_json::to_value(&OpenAIContent::Text("Hello".to_string())).unwrap();
    let expected = serde_json::json!("Hello");

    assert_eq!(serialized, expected);
}
/// `OpenAIContent::Parts` must serialize to a JSON array whose entries carry
/// the expected `type` tag and payload fields.
#[test]
// NOTE(review): presumably required because `OpenAIContentPart::image_url` is deprecated — confirm.
#[allow(deprecated)]
fn test_openai_content_parts_format() {
    let text_part = OpenAIContentPart::text("What's in this image?");
    let image_part =
        OpenAIContentPart::image_url("https://example.com/img.jpg", ImageDetail::High);
    let json = serde_json::to_value(&OpenAIContent::Parts(vec![text_part, image_part])).unwrap();
    assert!(json.is_array());

    let first = &json[0];
    assert_eq!(first["type"], "text");
    assert_eq!(first["text"], "What's in this image?");

    let second = &json[1];
    assert_eq!(second["type"], "image_url");
    assert_eq!(second["image_url"]["url"], "https://example.com/img.jpg");
    assert_eq!(second["image_url"]["detail"], "high");
}
/// A text part must serialize with `type` and `text` fields and no `image_url` key.
#[test]
fn test_openai_content_part_text_serialization() {
    let serialized = serde_json::to_value(&OpenAIContentPart::text("Hello world")).unwrap();
    assert_eq!(serialized["type"], "text");
    assert_eq!(serialized["text"], "Hello world");
    assert!(serialized.get("image_url").is_none());
}
/// An image part must serialize with `type` and a nested `image_url` object
/// (`url`, `detail`) and no `text` key.
#[test]
// NOTE(review): presumably required because `OpenAIContentPart::image_url` is deprecated — confirm.
#[allow(deprecated)]
fn test_openai_content_part_image_serialization() {
    let image_part = OpenAIContentPart::image_url("https://example.com/img.jpg", ImageDetail::Low);
    let serialized = serde_json::to_value(&image_part).unwrap();
    assert_eq!(serialized["type"], "image_url");

    let payload = &serialized["image_url"];
    assert_eq!(payload["url"], "https://example.com/img.jpg");
    assert_eq!(payload["detail"], "low");

    assert!(serialized.get("text").is_none());
}
/// Matches both `OpenAIContentPart` constructors against every variant without a
/// wildcard arm, so each constructor must yield its own variant.
#[test]
// NOTE(review): presumably required because `OpenAIContentPart::image_url` is deprecated — confirm.
#[allow(deprecated)]
fn test_openai_content_part_enum_exhaustiveness() {
    let from_text = OpenAIContentPart::text("test");
    let from_image = OpenAIContentPart::image_url("url", ImageDetail::Auto);

    // Deliberately no `_` arm: the match itself lists every variant.
    match from_text {
        OpenAIContentPart::ImageUrl { .. } => {
            panic!("Text part should not match ImageUrl variant")
        }
        OpenAIContentPart::Text { .. } => {}
    }

    match from_image {
        OpenAIContentPart::ImageUrl { .. } => {}
        OpenAIContentPart::Text { .. } => {
            panic!("Image part should not match Text variant")
        }
    }
}
/// `to_string()` on each `ImageDetail` variant must yield its lowercase name.
#[test]
fn test_image_detail_display() {
    let cases = [
        (ImageDetail::Low, "low"),
        (ImageDetail::High, "high"),
        (ImageDetail::Auto, "auto"),
    ];
    for (detail, expected) in cases.iter() {
        assert_eq!(detail.to_string(), *expected);
    }
}
/// An empty URL must be rejected with an error message containing "empty".
#[test]
fn test_image_block_from_url_rejects_empty() {
    let outcome = ImageBlock::from_url("");
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("empty"));
}
/// An `ftp://` URL must be rejected, with an error message mentioning
/// either "scheme" or "http".
#[test]
fn test_image_block_from_url_rejects_invalid_scheme() {
    let outcome = ImageBlock::from_url("ftp://example.com/image.jpg");
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("scheme") || message.contains("http"));
}
/// A relative path (no scheme) must be rejected with `Error::InvalidInput`.
#[test]
fn test_image_block_from_url_rejects_relative_path() {
    let outcome = ImageBlock::from_url("/images/photo.jpg");
    assert!(outcome.is_err());
    let failure = outcome.unwrap_err();
    assert!(matches!(failure, crate::Error::InvalidInput(_)));
}
/// A plain `http://` URL must be accepted and returned unchanged by `url()`.
#[test]
fn test_image_block_from_url_accepts_http() {
    let outcome = ImageBlock::from_url("http://example.com/image.jpg");
    assert!(outcome.is_ok());
    let accepted = outcome.unwrap();
    assert_eq!(accepted.url(), "http://example.com/image.jpg");
}
/// An `https://` URL must be accepted and returned unchanged by `url()`.
#[test]
fn test_image_block_from_url_accepts_https() {
    let outcome = ImageBlock::from_url("https://example.com/image.jpg");
    assert!(outcome.is_ok());
    let accepted = outcome.unwrap();
    assert_eq!(accepted.url(), "https://example.com/image.jpg");
}
/// A well-formed base64 PNG data URI must be accepted and returned unchanged by `url()`.
#[test]
fn test_image_block_from_url_accepts_data_uri() {
    let png_data_uri = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";
    let outcome = ImageBlock::from_url(png_data_uri);
    assert!(outcome.is_ok());
    let accepted = outcome.unwrap();
    assert_eq!(accepted.url(), png_data_uri);
}
/// A `data:` URI without a proper image payload must be rejected with `Error::InvalidInput`.
#[test]
fn test_image_block_from_url_rejects_malformed_data_uri() {
    let outcome = ImageBlock::from_url("data:notanimage");
    assert!(outcome.is_err());
    let failure = outcome.unwrap_err();
    assert!(matches!(failure, crate::Error::InvalidInput(_)));
}
/// URLs containing a newline, tab, NUL, or carriage return must be rejected,
/// with an error message mentioning "control" or "character".
#[test]
fn test_from_url_rejects_control_characters() {
    let rejected_candidates = [
        "https://example.com\n/image.jpg",
        "https://example.com\t/image.jpg",
        "https://example.com\0/image.jpg",
        "https://example.com\r/image.jpg",
    ];
    rejected_candidates.iter().copied().for_each(|candidate| {
        let outcome = ImageBlock::from_url(candidate);
        assert!(
            outcome.is_err(),
            "Should reject URL with control characters: {:?}",
            candidate
        );
        let message = outcome.unwrap_err().to_string();
        assert!(
            message.contains("control") || message.contains("character"),
            "Error should mention control characters, got: {}",
            message
        );
    });
}
/// A 3000-byte URL must still be accepted and stored at full length.
/// Only acceptance and length are asserted here; any warning is not observed.
#[test]
fn test_from_url_warns_very_long_url() {
    // 20-byte prefix + 2980 filler bytes = 3000 bytes total.
    let mut long_url = String::from("https://example.com/");
    long_url.push_str(&"a".repeat(2980));
    let outcome = ImageBlock::from_url(&long_url);
    assert!(outcome.is_ok(), "Should accept long URL (with warning)");
    let accepted = outcome.unwrap();
    assert_eq!(accepted.url().len(), 3000);
}
/// Data URIs whose base64 payload is empty, contains invalid characters, or has
/// misplaced/missing padding must all be rejected.
#[test]
fn test_from_url_validates_data_uri_base64() {
    let rejected_uris = [
        "data:image/png;base64,",
        "data:image/png;base64,hello world",
        "data:image/png;base64,@@@",
        "data:image/png;base64,ABC",
        "data:image/png;base64,==abc",
        "data:image/png;base64,ab==cd",
    ];
    for candidate in rejected_uris.iter().copied() {
        let outcome = ImageBlock::from_url(candidate);
        assert!(
            outcome.is_err(),
            "Should reject data URI with invalid base64: {}",
            candidate
        );
    }
}
/// A `javascript:` URL must be rejected, with an error message mentioning
/// "http" or "scheme".
#[test]
fn test_from_url_rejects_javascript_scheme() {
    let outcome = ImageBlock::from_url("javascript:alert(1)");
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    let mentions_scheme = message.contains("http") || message.contains("scheme");
    assert!(
        mentions_scheme,
        "Error should mention scheme requirements, got: {}",
        message
    );
}
/// A `file://` URL must be rejected, with an error message mentioning
/// "http" or "scheme".
#[test]
fn test_from_url_rejects_file_scheme() {
    let outcome = ImageBlock::from_url("file:///etc/passwd");
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    let mentions_scheme = message.contains("http") || message.contains("scheme");
    assert!(
        mentions_scheme,
        "Error should mention scheme requirements, got: {}",
        message
    );
}
/// Empty base64 data must be rejected with an error message containing "empty".
#[test]
fn test_image_block_from_base64_rejects_empty() {
    let outcome = ImageBlock::from_base64("", "image/png");
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("empty"));
}
/// A non-image MIME type must be rejected, with an error message mentioning
/// "MIME" or "image".
#[test]
fn test_image_block_from_base64_rejects_invalid_mime() {
    let outcome = ImageBlock::from_base64("somedata", "text/plain");
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("MIME") || message.contains("image"));
}
/// Valid base64 PNG data must be accepted and exposed as a `data:image/png;base64,` URL.
#[test]
fn test_image_block_from_base64_accepts_valid_input() {
    let png_payload = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";
    let outcome = ImageBlock::from_base64(png_payload, "image/png");
    assert!(outcome.is_ok());
    let accepted = outcome.unwrap();
    assert!(accepted.url().starts_with("data:image/png;base64,"));
}
/// JPEG, PNG, GIF, and WebP MIME types must all be accepted, and the resulting
/// URL must start with the matching `data:<mime>;base64,` prefix.
#[test]
fn test_image_block_from_base64_accepts_all_image_types() {
    let payload = "iVBORw0KGgo=";
    let accepted_mimes = ["image/jpeg", "image/png", "image/gif", "image/webp"];
    for mime in accepted_mimes.iter().copied() {
        let outcome = ImageBlock::from_base64(payload, mime);
        assert!(outcome.is_ok(), "Should accept {}", mime);
        let accepted = outcome.unwrap();
        let expected_prefix = format!("data:{};base64,", mime);
        assert!(accepted.url().starts_with(&expected_prefix));
    }
}
/// Payloads containing characters outside the base64 alphabet (space, `@`, `#`, `$`,
/// `%`, newline) must be rejected, with an error message mentioning "base64" or "character".
#[test]
fn test_from_base64_rejects_invalid_characters() {
    let rejected_payloads = [
        "hello world",
        "test@data",
        "test#data",
        "test$data",
        "test%data",
        "abc\ndef",
    ];
    rejected_payloads.iter().copied().for_each(|payload| {
        let outcome = ImageBlock::from_base64(payload, "image/png");
        assert!(
            outcome.is_err(),
            "Should reject base64 with invalid characters: {}",
            payload
        );
        let message = outcome.unwrap_err().to_string();
        assert!(
            message.contains("base64") || message.contains("character"),
            "Error should mention base64 or character issue, got: {}",
            message
        );
    });
}
/// Payloads whose length or `=` padding is malformed must be rejected.
#[test]
fn test_from_base64_rejects_malformed_padding() {
    let rejected_payloads = ["A", "AB", "ABC", "ABCD==="];
    for payload in rejected_payloads.iter().copied() {
        let outcome = ImageBlock::from_base64(payload, "image/png");
        assert!(
            outcome.is_err(),
            "Should reject malformed padding: {}",
            payload
        );
    }
}
/// A MIME type carrying a `;charset=...` parameter must be rejected, with an error
/// message mentioning "MIME" or "character".
#[test]
fn test_from_base64_rejects_mime_with_semicolon() {
    let outcome = ImageBlock::from_base64("AAAA", "image/png;charset=utf-8");
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    let mentions_mime = message.contains("MIME") || message.contains("character");
    assert!(
        mentions_mime,
        "Error should mention MIME or character issue, got: {}",
        message
    );
}
/// MIME types containing CR/LF characters or a comma must be rejected.
#[test]
fn test_from_base64_rejects_mime_with_newline() {
    let rejected_mimes = [
        "image/png\n",
        "image/png\r",
        "image/png\r\n",
        "image/png,extra",
    ];
    rejected_mimes.iter().copied().for_each(|mime| {
        let outcome = ImageBlock::from_base64("AAAA", mime);
        assert!(
            outcome.is_err(),
            "Should reject MIME with control/injection chars: {:?}",
            mime
        );
    });
}
/// A 15,000,000-character base64 payload must still be accepted, and the resulting
/// URL must be longer than the payload. Only acceptance and length are asserted
/// here; any warning is not observed.
#[test]
fn test_from_base64_warns_large_data() {
    let payload_len = 15_000_000;
    let oversized_payload = "A".repeat(payload_len);
    let outcome = ImageBlock::from_base64(&oversized_payload, "image/png");
    assert!(outcome.is_ok(), "Should accept large base64 (with warning)");
    let accepted = outcome.unwrap();
    assert!(accepted.url().len() > payload_len);
}
/// Every listed `image/*` MIME type (JPEG, PNG, GIF, WebP, AVIF, BMP, TIFF)
/// must be accepted alongside a valid base64 payload.
#[test]
fn test_from_base64_accepts_all_image_mime_types() {
    let payload = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";
    let accepted_mimes = [
        "image/jpeg",
        "image/png",
        "image/gif",
        "image/webp",
        "image/avif",
        "image/bmp",
        "image/tiff",
    ];
    accepted_mimes.iter().copied().for_each(|mime| {
        let outcome = ImageBlock::from_base64(payload, mime);
        assert!(outcome.is_ok(), "Should accept valid MIME type: {}", mime);
    });
}
/// An empty MIME type must be rejected, with an error message mentioning
/// "MIME" or "empty".
#[test]
fn test_image_block_from_base64_rejects_empty_mime() {
    let outcome = ImageBlock::from_base64("somedata", "");
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("MIME") || message.contains("empty"));
}
}