use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use crate::error::Result;
use futures::stream::BoxStream;
/// A tool the model is allowed to call, in OpenAI-compatible wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool kind; serialized as `"type"`. Currently always `"function"`
    /// (see [`ToolDefinition::function`]).
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function signature exposed to the model.
    pub function: FunctionDefinition,
}
impl ToolDefinition {
    /// Builds a `"function"`-type tool definition with strict schema
    /// validation enabled (`strict: Some(true)`).
    pub fn function(
        name: impl Into<String>,
        description: impl Into<String>,
        parameters: JsonValue,
    ) -> Self {
        let function = FunctionDefinition {
            name: name.into(),
            description: description.into(),
            parameters,
            strict: Some(true),
        };
        Self {
            tool_type: String::from("function"),
            function,
        }
    }
}
/// The callable-function half of a [`ToolDefinition`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Function name the model will reference in tool calls.
    pub name: String,
    /// Natural-language description shown to the model.
    pub description: String,
    /// JSON Schema describing the expected arguments.
    pub parameters: JsonValue,
    /// Opt-in strict schema validation; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
/// A tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned call id, echoed back in the matching tool result.
    pub id: String,
    /// Call kind; serialized as `"type"`.
    #[serde(rename = "type")]
    pub call_type: String,
    /// The function name and raw JSON argument string.
    pub function: FunctionCall,
    /// Opaque reasoning signature some providers attach to tool calls;
    /// tolerated when absent and omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}
impl ToolCall {
pub fn parse_arguments<T: serde::de::DeserializeOwned>(&self) -> Result<T> {
serde_json::from_str(&self.function.arguments).map_err(|e| {
crate::error::LlmError::InvalidRequest(format!("Failed to parse tool arguments: {}", e))
})
}
pub fn name(&self) -> &str {
&self.function.name
}
pub fn arguments(&self) -> &str {
&self.function.arguments
}
}
/// Function name plus its arguments as a raw JSON string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
    /// Name matching a [`FunctionDefinition::name`].
    pub name: String,
    /// Arguments as a JSON-encoded string (not a parsed value).
    pub arguments: String,
}
/// How the model should choose among available tools.
///
/// Serialized `untagged`: `Auto`/`Required` become bare strings
/// ("auto", "required", "none"), `Function` becomes an object.
///
/// NOTE(review): with `untagged`, *any* bare string deserializes into the
/// first matching variant, `Auto` — including "required". Harmless if these
/// values are only ever serialized, but confirm no round-trip relies on
/// distinguishing `Auto` from `Required` after deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    /// Bare-string choice, e.g. "auto" or "none".
    Auto(String),
    /// Bare-string "required".
    Required(String),
    /// Force a call to one named function.
    Function {
        #[serde(rename = "type")]
        choice_type: String,
        function: ToolChoiceFunction,
    },
}
impl ToolChoice {
    /// Let the model decide whether to call a tool (serializes as `"auto"`).
    pub fn auto() -> Self {
        Self::Auto(String::from("auto"))
    }
    /// Require the model to call at least one tool (serializes as `"required"`).
    pub fn required() -> Self {
        Self::Required(String::from("required"))
    }
    /// Force a call to the named function.
    pub fn function(name: impl Into<String>) -> Self {
        let function = ToolChoiceFunction { name: name.into() };
        Self::Function {
            choice_type: String::from("function"),
            function,
        }
    }
    /// Disable tool calling (serializes as the bare string `"none"`).
    pub fn none() -> Self {
        Self::Auto(String::from("none"))
    }
}
/// Target of a forced [`ToolChoice::Function`] choice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceFunction {
    /// Name of the function the model must call.
    pub name: String,
}
/// Result of executing a tool call, sent back to the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// Id of the [`ToolCall`] this result answers.
    pub tool_call_id: String,
    /// Message role; always "tool" (see constructors).
    pub role: String,
    /// Tool output (or an "Error: …" string, see [`ToolResult::error`]).
    pub content: String,
}
impl ToolResult {
    /// Successful tool result for the given call id.
    pub fn new(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
        Self {
            tool_call_id: tool_call_id.into(),
            role: String::from("tool"),
            content: content.into(),
        }
    }
    /// Failed tool result; the error is rendered as `"Error: {error}"` so the
    /// model sees the failure as ordinary tool output.
    pub fn error(tool_call_id: impl Into<String>, error: impl std::fmt::Display) -> Self {
        let content = format!("Error: {}", error);
        Self {
            tool_call_id: tool_call_id.into(),
            role: String::from("tool"),
            content,
        }
    }
}
/// One incremental event from a streaming chat/tool-call response.
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// A fragment of assistant text.
    Content(String),
    /// A fragment of model "thinking"/reasoning output.
    ThinkingContent {
        /// The reasoning text fragment.
        text: String,
        /// Thinking tokens consumed so far, when the provider reports it.
        tokens_used: Option<usize>,
        /// Total thinking-token budget, when the provider reports it.
        budget_total: Option<usize>,
    },
    /// A partial tool call; fields arrive incrementally and must be
    /// accumulated by `index`.
    ToolCallDelta {
        /// Position of this tool call within the response's call list.
        index: usize,
        /// Call id — typically present only on the first delta for a call.
        id: Option<String>,
        /// Function name fragment, if any in this delta.
        function_name: Option<String>,
        /// Argument JSON fragment, if any in this delta.
        function_arguments: Option<String>,
        /// Provider-specific reasoning signature, if any.
        thought_signature: Option<String>,
    },
    /// Terminal chunk: the stream is complete.
    Finished {
        /// Provider-reported finish reason (e.g. "stop").
        reason: String,
        /// Time-to-first-token in milliseconds, when measured.
        #[allow(dead_code)]
        ttft_ms: Option<f64>,
    },
}
/// A complete (non-streaming) response from an LLM provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    /// The assistant's text output (may be empty when only tool calls were made).
    pub content: String,
    /// Tokens consumed by the prompt.
    pub prompt_tokens: usize,
    /// Tokens generated in the completion.
    pub completion_tokens: usize,
    /// `prompt_tokens + completion_tokens` (see [`LLMResponse::with_usage`]).
    pub total_tokens: usize,
    /// Identifier of the model that produced the response.
    pub model: String,
    /// Provider-reported finish reason (e.g. "stop"), if any.
    pub finish_reason: Option<String>,
    /// Tool calls requested by the model; omitted from JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Provider-specific extras. `default` keeps deserialization tolerant of
    /// payloads that omit the field, matching `tool_calls` above (previously a
    /// missing `metadata` key made the whole response fail to deserialize).
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
    /// Prompt tokens served from a provider-side cache, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_hit_tokens: Option<usize>,
    /// Thinking/reasoning tokens used, when reported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thinking_tokens: Option<usize>,
    /// Raw thinking/reasoning text, when the provider exposes it.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thinking_content: Option<String>,
}
impl LLMResponse {
    /// Minimal response: content and model set, all counters zero, everything
    /// else empty/`None`. Builder methods below fill in the rest.
    pub fn new(content: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            content: content.into(),
            model: model.into(),
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
            finish_reason: None,
            tool_calls: Vec::new(),
            metadata: HashMap::new(),
            cache_hit_tokens: None,
            thinking_tokens: None,
            thinking_content: None,
        }
    }
    /// Sets prompt/completion token counts; `total_tokens` is their sum.
    pub fn with_usage(mut self, prompt: usize, completion: usize) -> Self {
        self.prompt_tokens = prompt;
        self.completion_tokens = completion;
        self.total_tokens = prompt + completion;
        self
    }
    /// Sets the finish reason.
    pub fn with_finish_reason(mut self, reason: impl Into<String>) -> Self {
        self.finish_reason = Some(reason.into());
        self
    }
    /// Replaces the tool-call list.
    pub fn with_tool_calls(mut self, calls: Vec<ToolCall>) -> Self {
        self.tool_calls = calls;
        self
    }
    /// Records prompt tokens served from cache.
    pub fn with_cache_hit_tokens(mut self, tokens: usize) -> Self {
        self.cache_hit_tokens = Some(tokens);
        self
    }
    /// Inserts one metadata key/value pair.
    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.metadata.insert(key.into(), value);
        self
    }
    /// Records thinking-token usage.
    pub fn with_thinking_tokens(mut self, tokens: usize) -> Self {
        self.thinking_tokens = Some(tokens);
        self
    }
    /// Records raw thinking text.
    pub fn with_thinking_content(mut self, content: impl Into<String>) -> Self {
        self.thinking_content = Some(content.into());
        self
    }
    /// True when the model requested at least one tool call.
    pub fn has_tool_calls(&self) -> bool {
        !self.tool_calls.is_empty()
    }
    /// True when any thinking information (tokens or text) is present.
    pub fn has_thinking(&self) -> bool {
        self.thinking_tokens.is_some() || self.thinking_content.is_some()
    }
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CompletionOptions {
pub max_tokens: Option<usize>,
pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub stop: Option<Vec<String>>,
pub frequency_penalty: Option<f32>,
pub presence_penalty: Option<f32>,
pub response_format: Option<String>,
pub system_prompt: Option<String>,
}
impl CompletionOptions {
pub fn with_temperature(temperature: f32) -> Self {
Self {
temperature: Some(temperature),
..Default::default()
}
}
pub fn json_mode() -> Self {
Self {
response_format: Some("json_object".to_string()),
..Default::default()
}
}
}
/// A chat/completion LLM backend.
///
/// Only `complete`, `complete_with_options`, and `chat` are required;
/// tool calling and streaming have graceful default implementations
/// (fall back to plain chat / return `NotSupported`), paired with
/// `supports_*` capability probes that default to `false`.
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Human-readable provider name.
    fn name(&self) -> &str;
    /// Model identifier this provider is configured with.
    fn model(&self) -> &str;
    /// Maximum context window in tokens.
    fn max_context_length(&self) -> usize;
    /// Single-prompt completion with default options.
    async fn complete(&self, prompt: &str) -> Result<LLMResponse>;
    /// Single-prompt completion with explicit options.
    async fn complete_with_options(
        &self,
        prompt: &str,
        options: &CompletionOptions,
    ) -> Result<LLMResponse>;
    /// Multi-turn chat completion.
    async fn chat(
        &self,
        messages: &[ChatMessage],
        options: Option<&CompletionOptions>,
    ) -> Result<LLMResponse>;
    /// Chat with tool definitions. Default implementation ignores the tools
    /// and delegates to [`LLMProvider::chat`] — providers with real function
    /// calling should override (and report it via `supports_function_calling`).
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
        tool_choice: Option<ToolChoice>,
        options: Option<&CompletionOptions>,
    ) -> Result<LLMResponse> {
        // Explicitly discard unused params to keep the default impl warning-free.
        let _ = (tools, tool_choice);
        self.chat(messages, options).await
    }
    /// Streamed single-prompt completion. Default: `NotSupported` error.
    async fn stream(&self, _prompt: &str) -> Result<BoxStream<'static, Result<String>>> {
        Err(crate::error::LlmError::NotSupported(
            "Streaming not supported".to_string(),
        ))
    }
    /// Streamed chat with tools. Default: `NotSupported` error.
    async fn chat_with_tools_stream(
        &self,
        _messages: &[ChatMessage],
        _tools: &[ToolDefinition],
        _tool_choice: Option<ToolChoice>,
        _options: Option<&CompletionOptions>,
    ) -> Result<BoxStream<'static, Result<StreamChunk>>> {
        Err(crate::error::LlmError::NotSupported(
            "Streaming tool calls not supported by this provider".to_string(),
        ))
    }
    /// Whether [`LLMProvider::stream`] is implemented.
    fn supports_streaming(&self) -> bool {
        false
    }
    /// Whether [`LLMProvider::chat_with_tools_stream`] is implemented.
    fn supports_tool_streaming(&self) -> bool {
        false
    }
    /// Whether JSON-mode response formatting is honored.
    fn supports_json_mode(&self) -> bool {
        false
    }
    /// Whether tool/function calling is actually executed (not the fallback).
    fn supports_function_calling(&self) -> bool {
        false
    }
    /// Owned model name, or `None` when [`LLMProvider::model`] is empty.
    fn model_name(&self) -> Option<String> {
        let m = self.model();
        if m.is_empty() {
            None
        } else {
            Some(m.to_string())
        }
    }
}
/// An image attachment: either base64 payload + MIME type, or a remote URL
/// (encoded as `mime_type == "url"`, see [`ImageData::from_url`]).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ImageData {
    /// Base64 payload, or the URL itself when `mime_type` is "url".
    pub data: String,
    /// MIME type such as "image/png", or the sentinel "url".
    pub mime_type: String,
    /// Optional detail hint (e.g. "low"/"high"); omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
impl ImageData {
    /// Base64 image payload with an explicit MIME type.
    pub fn new(data: impl Into<String>, mime_type: impl Into<String>) -> Self {
        Self {
            data: data.into(),
            mime_type: mime_type.into(),
            detail: None,
        }
    }
    /// Sets the detail hint (builder style).
    pub fn with_detail(mut self, detail: impl Into<String>) -> Self {
        self.detail = Some(detail.into());
        self
    }
    /// Renders the payload as a `data:` URI.
    pub fn to_data_uri(&self) -> String {
        format!("data:{};base64,{}", self.mime_type, self.data)
    }
    /// Remote image referenced by URL; uses the "url" MIME-type sentinel.
    pub fn from_url(url: impl Into<String>) -> Self {
        Self {
            data: url.into(),
            mime_type: String::from("url"),
            detail: None,
        }
    }
    /// True when this is a URL reference rather than inline data.
    pub fn is_url(&self) -> bool {
        self.mime_type == "url"
    }
    /// The string to send to an API: the raw URL, or a data URI for inline data.
    pub fn to_api_url(&self) -> String {
        match self.is_url() {
            true => self.data.clone(),
            false => self.to_data_uri(),
        }
    }
    /// True when the MIME type is one the supported backends accept
    /// (or the "url" sentinel).
    pub fn is_supported_mime(&self) -> bool {
        matches!(
            self.mime_type.as_str(),
            "image/png" | "image/jpeg" | "image/gif" | "image/webp" | "url"
        )
    }
}
/// Prompt-caching marker attached to a [`ChatMessage`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct CacheControl {
    /// Cache kind; serialized as `"type"` (currently always "ephemeral").
    #[serde(rename = "type")]
    pub cache_type: String,
}
impl CacheControl {
    /// Ephemeral cache marker (`{"type": "ephemeral"}`).
    pub fn ephemeral() -> Self {
        Self {
            cache_type: String::from("ephemeral"),
        }
    }
}
/// One message in a chat conversation. All optional fields are omitted from
/// JSON when `None`; use the constructors on the `impl` block to build
/// role-appropriate messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Speaker role.
    pub role: ChatRole,
    /// Message text.
    pub content: String,
    /// Optional participant name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool calls attached to an assistant message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// For tool-role messages: the id of the call being answered.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Optional prompt-caching marker.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_control: Option<CacheControl>,
    /// Images attached to a user message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub images: Option<Vec<ImageData>>,
}
impl ChatMessage {
pub fn system(content: impl Into<String>) -> Self {
Self {
role: ChatRole::System,
content: content.into(),
name: None,
tool_calls: None,
tool_call_id: None,
cache_control: None,
images: None,
}
}
pub fn user(content: impl Into<String>) -> Self {
Self {
role: ChatRole::User,
content: content.into(),
name: None,
tool_calls: None,
tool_call_id: None,
cache_control: None,
images: None,
}
}
pub fn user_with_images(content: impl Into<String>, images: Vec<ImageData>) -> Self {
Self {
role: ChatRole::User,
content: content.into(),
name: None,
tool_calls: None,
tool_call_id: None,
cache_control: None,
images: if images.is_empty() {
None
} else {
Some(images)
},
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: ChatRole::Assistant,
content: content.into(),
name: None,
tool_calls: None,
tool_call_id: None,
cache_control: None,
images: None,
}
}
pub fn assistant_with_tools(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
Self {
role: ChatRole::Assistant,
content: content.into(),
name: None,
tool_calls: if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
},
tool_call_id: None,
cache_control: None,
images: None,
}
}
pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
Self {
role: ChatRole::Tool,
content: content.into(),
name: None,
tool_calls: None,
tool_call_id: Some(tool_call_id.into()),
cache_control: None,
images: None,
}
}
pub fn has_images(&self) -> bool {
self.images.as_ref().map(|v| !v.is_empty()).unwrap_or(false)
}
}
/// Speaker role for a [`ChatMessage`]; serialized in lowercase
/// ("system", "user", "assistant", "tool", "function").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChatRole {
    System,
    User,
    Assistant,
    Tool,
    /// Legacy OpenAI function-role (predecessor of `Tool`).
    Function,
}
impl ChatRole {
pub fn as_str(&self) -> &'static str {
match self {
ChatRole::System => "system",
ChatRole::User => "user",
ChatRole::Assistant => "assistant",
ChatRole::Tool => "tool",
ChatRole::Function => "function",
}
}
}
/// A text-embedding backend.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Human-readable provider name.
    fn name(&self) -> &str;
    /// Embedding model identifier.
    fn model(&self) -> &str;
    /// Dimensionality of the returned vectors.
    fn dimension(&self) -> usize;
    /// Maximum input length in tokens.
    fn max_tokens(&self) -> usize;
    /// Embeds a batch of texts, one vector per input in order.
    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
    /// Convenience wrapper: embeds a single text via [`EmbeddingProvider::embed`].
    ///
    /// # Errors
    /// Returns `LlmError::Unknown` if the provider returns an empty batch.
    async fn embed_one(&self, text: &str) -> Result<Vec<f32>> {
        let results = self.embed(&[text.to_string()]).await?;
        results
            .into_iter()
            .next()
            .ok_or_else(|| crate::error::LlmError::Unknown("Empty embedding result".to_string()))
    }
}
/// Unit tests for the provider types: builders, serde representations, and
/// constructor edge cases (empty tool/image lists, missing optional fields).
#[cfg(test)]
mod tests {
    use super::*;

    // --- LLMResponse builder and usage accounting ---

    #[test]
    fn test_llm_response_builder() {
        let response = LLMResponse::new("Hello, world!", "gpt-4")
            .with_usage(10, 5)
            .with_finish_reason("stop");
        assert_eq!(response.content, "Hello, world!");
        assert_eq!(response.model, "gpt-4");
        assert_eq!(response.prompt_tokens, 10);
        assert_eq!(response.completion_tokens, 5);
        assert_eq!(response.total_tokens, 15);
        assert_eq!(response.finish_reason, Some("stop".to_string()));
    }

    #[test]
    fn test_llm_response_with_cache_hit_tokens() {
        let response = LLMResponse::new("cached response", "gemini-pro")
            .with_usage(1000, 50)
            .with_cache_hit_tokens(800);
        assert_eq!(response.cache_hit_tokens, Some(800));
        assert_eq!(response.prompt_tokens, 1000);
        // Derived cache-hit rate: 800/1000 = 0.8.
        let cache_rate = response.cache_hit_tokens.unwrap() as f64 / response.prompt_tokens as f64;
        assert!((cache_rate - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_llm_response_no_cache_hit_tokens() {
        let response = LLMResponse::new("no cache", "gpt-4").with_usage(100, 20);
        assert_eq!(response.cache_hit_tokens, None);
    }

    // --- ChatMessage constructors ---

    #[test]
    fn test_chat_message_constructors() {
        let system = ChatMessage::system("You are helpful");
        assert_eq!(system.role, ChatRole::System);
        let user = ChatMessage::user("Hello");
        assert_eq!(user.role, ChatRole::User);
        let assistant = ChatMessage::assistant("Hi there!");
        assert_eq!(assistant.role, ChatRole::Assistant);
    }

    // --- CacheControl serde representation ---

    #[test]
    fn test_cache_control_ephemeral() {
        let cache = CacheControl::ephemeral();
        assert_eq!(cache.cache_type, "ephemeral");
    }

    #[test]
    fn test_cache_control_serialization() {
        let cache = CacheControl::ephemeral();
        let json = serde_json::to_value(&cache).unwrap();
        // Field is renamed to "type" on the wire.
        assert_eq!(json["type"], "ephemeral");
        assert!(!json.as_object().unwrap().contains_key("cache_type"));
    }

    #[test]
    fn test_message_with_cache_control() {
        let mut msg = ChatMessage::system("System prompt");
        msg.cache_control = Some(CacheControl::ephemeral());
        let json = serde_json::to_value(&msg).unwrap();
        assert!(json.as_object().unwrap().contains_key("cache_control"));
        assert_eq!(json["cache_control"]["type"], "ephemeral");
    }

    #[test]
    fn test_message_without_cache_control() {
        let msg = ChatMessage::user("Hello");
        let json = serde_json::to_value(&msg).unwrap();
        // None is skipped entirely, not serialized as null.
        assert!(!json.as_object().unwrap().contains_key("cache_control"));
    }

    #[test]
    fn test_cache_control_roundtrip() {
        let original = CacheControl {
            cache_type: "ephemeral".to_string(),
        };
        let json_str = serde_json::to_string(&original).unwrap();
        let deserialized: CacheControl = serde_json::from_str(&json_str).unwrap();
        assert_eq!(original.cache_type, deserialized.cache_type);
    }

    // --- ImageData ---

    #[test]
    fn test_image_data_new() {
        let image = ImageData::new("iVBORw0KGgo...", "image/png");
        assert_eq!(image.mime_type, "image/png");
        assert_eq!(image.data, "iVBORw0KGgo...");
        assert_eq!(image.detail, None);
    }

    #[test]
    fn test_image_data_with_detail() {
        let image = ImageData::new("data123", "image/jpeg").with_detail("high");
        assert_eq!(image.detail, Some("high".to_string()));
    }

    #[test]
    fn test_image_data_to_data_uri() {
        let image = ImageData::new("base64data", "image/png");
        assert_eq!(image.to_data_uri(), "data:image/png;base64,base64data");
    }

    #[test]
    fn test_image_data_supported_mime() {
        assert!(ImageData::new("", "image/png").is_supported_mime());
        assert!(ImageData::new("", "image/jpeg").is_supported_mime());
        assert!(ImageData::new("", "image/gif").is_supported_mime());
        assert!(ImageData::new("", "image/webp").is_supported_mime());
        assert!(!ImageData::new("", "image/bmp").is_supported_mime());
        assert!(!ImageData::new("", "text/plain").is_supported_mime());
    }

    #[test]
    fn test_chat_message_user_with_images() {
        let images = vec![ImageData::new("data1", "image/png")];
        let msg = ChatMessage::user_with_images("What's this?", images);
        assert_eq!(msg.role, ChatRole::User);
        assert_eq!(msg.content, "What's this?");
        assert!(msg.has_images());
        assert_eq!(msg.images.as_ref().unwrap().len(), 1);
    }

    #[test]
    fn test_chat_message_user_with_empty_images() {
        // Empty vec is normalized to None.
        let msg = ChatMessage::user_with_images("Hello", vec![]);
        assert!(!msg.has_images());
        assert!(msg.images.is_none());
    }

    #[test]
    fn test_image_data_serialization() {
        let image = ImageData::new("base64", "image/png").with_detail("low");
        let json = serde_json::to_value(&image).unwrap();
        assert_eq!(json["data"], "base64");
        assert_eq!(json["mime_type"], "image/png");
        assert_eq!(json["detail"], "low");
    }

    // --- Tool definitions and calls ---

    #[test]
    fn test_tool_definition_function_constructor() {
        let tool = ToolDefinition::function(
            "my_func",
            "Does something",
            serde_json::json!({"type": "object"}),
        );
        assert_eq!(tool.tool_type, "function");
        assert_eq!(tool.function.name, "my_func");
        assert_eq!(tool.function.description, "Does something");
        assert_eq!(tool.function.strict, Some(true));
    }

    #[test]
    fn test_tool_definition_serialization() {
        let tool = ToolDefinition::function(
            "search",
            "Search the web",
            serde_json::json!({"type": "object", "properties": {}}),
        );
        let json = serde_json::to_value(&tool).unwrap();
        assert_eq!(json["type"], "function");
        assert_eq!(json["function"]["name"], "search");
    }

    #[test]
    fn test_tool_call_name_and_arguments() {
        let tc = ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "get_weather".to_string(),
                arguments: r#"{"city": "Paris"}"#.to_string(),
            },
            thought_signature: None,
        };
        assert_eq!(tc.name(), "get_weather");
        assert_eq!(tc.arguments(), r#"{"city": "Paris"}"#);
    }

    #[test]
    fn test_tool_call_parse_arguments() {
        let tc = ToolCall {
            id: "call_2".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "add".to_string(),
                arguments: r#"{"a": 1, "b": 2}"#.to_string(),
            },
            thought_signature: None,
        };
        let parsed: serde_json::Value = tc.parse_arguments().unwrap();
        assert_eq!(parsed["a"], 1);
        assert_eq!(parsed["b"], 2);
    }

    #[test]
    fn test_tool_call_parse_arguments_invalid() {
        let tc = ToolCall {
            id: "call_3".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "bad".to_string(),
                arguments: "not json".to_string(),
            },
            thought_signature: None,
        };
        let result: std::result::Result<serde_json::Value, _> = tc.parse_arguments();
        assert!(result.is_err());
    }

    // --- ToolChoice: untagged serialization to bare strings / objects ---

    #[test]
    fn test_tool_choice_auto() {
        let tc = ToolChoice::auto();
        let json = serde_json::to_value(&tc).unwrap();
        assert_eq!(json, "auto");
    }

    #[test]
    fn test_tool_choice_required() {
        let tc = ToolChoice::required();
        let json = serde_json::to_value(&tc).unwrap();
        assert_eq!(json, "required");
    }

    #[test]
    fn test_tool_choice_none() {
        let tc = ToolChoice::none();
        let json = serde_json::to_value(&tc).unwrap();
        assert_eq!(json, "none");
    }

    #[test]
    fn test_tool_choice_function() {
        let tc = ToolChoice::function("get_weather");
        if let ToolChoice::Function {
            choice_type,
            function,
        } = tc
        {
            assert_eq!(choice_type, "function");
            assert_eq!(function.name, "get_weather");
        } else {
            panic!("Expected ToolChoice::Function");
        }
    }

    // --- ToolResult ---

    #[test]
    fn test_tool_result_new() {
        let tr = ToolResult::new("call_1", "sunny, 20C");
        assert_eq!(tr.tool_call_id, "call_1");
        assert_eq!(tr.role, "tool");
        assert_eq!(tr.content, "sunny, 20C");
    }

    #[test]
    fn test_tool_result_error() {
        let tr = ToolResult::error("call_2", "City not found");
        assert_eq!(tr.tool_call_id, "call_2");
        assert_eq!(tr.content, "Error: City not found");
    }

    // --- LLMResponse extras: tool calls, metadata, thinking ---

    #[test]
    fn test_llm_response_with_tool_calls() {
        let tc = vec![ToolCall {
            id: "c1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: "{}".to_string(),
            },
            thought_signature: None,
        }];
        let resp = LLMResponse::new("", "gpt-4").with_tool_calls(tc);
        assert!(resp.has_tool_calls());
        assert_eq!(resp.tool_calls.len(), 1);
    }

    #[test]
    fn test_llm_response_no_tool_calls() {
        let resp = LLMResponse::new("hello", "gpt-4");
        assert!(!resp.has_tool_calls());
    }

    #[test]
    fn test_llm_response_with_metadata() {
        let resp =
            LLMResponse::new("hi", "gpt-4").with_metadata("id", serde_json::json!("resp_123"));
        assert_eq!(
            resp.metadata.get("id"),
            Some(&serde_json::json!("resp_123"))
        );
    }

    #[test]
    fn test_llm_response_with_thinking() {
        let resp = LLMResponse::new("answer", "claude-3")
            .with_thinking_tokens(500)
            .with_thinking_content("Let me think...");
        assert!(resp.has_thinking());
        assert_eq!(resp.thinking_tokens, Some(500));
        assert_eq!(resp.thinking_content, Some("Let me think...".to_string()));
    }

    #[test]
    fn test_llm_response_has_thinking_tokens_only() {
        let resp = LLMResponse::new("x", "o1").with_thinking_tokens(100);
        assert!(resp.has_thinking());
    }

    #[test]
    fn test_llm_response_has_thinking_content_only() {
        let resp = LLMResponse::new("x", "claude").with_thinking_content("hmm");
        assert!(resp.has_thinking());
    }

    #[test]
    fn test_llm_response_no_thinking() {
        let resp = LLMResponse::new("x", "gpt-4");
        assert!(!resp.has_thinking());
    }

    // --- CompletionOptions ---

    #[test]
    fn test_completion_options_default() {
        let opts = CompletionOptions::default();
        assert!(opts.max_tokens.is_none());
        assert!(opts.temperature.is_none());
        assert!(opts.response_format.is_none());
    }

    #[test]
    fn test_completion_options_with_temperature() {
        let opts = CompletionOptions::with_temperature(0.7);
        assert_eq!(opts.temperature, Some(0.7));
        assert!(opts.max_tokens.is_none());
    }

    #[test]
    fn test_completion_options_json_mode() {
        let opts = CompletionOptions::json_mode();
        assert_eq!(opts.response_format, Some("json_object".to_string()));
    }

    // --- ChatRole ---

    #[test]
    fn test_chat_role_as_str() {
        assert_eq!(ChatRole::System.as_str(), "system");
        assert_eq!(ChatRole::User.as_str(), "user");
        assert_eq!(ChatRole::Assistant.as_str(), "assistant");
        assert_eq!(ChatRole::Tool.as_str(), "tool");
        assert_eq!(ChatRole::Function.as_str(), "function");
    }

    #[test]
    fn test_chat_role_serialization() {
        let json = serde_json::to_value(ChatRole::User).unwrap();
        assert_eq!(json, "user");
        let json = serde_json::to_value(ChatRole::Tool).unwrap();
        assert_eq!(json, "tool");
    }

    // --- ChatMessage tool-related constructors ---

    #[test]
    fn test_chat_message_assistant_with_tools() {
        let tc = vec![ToolCall {
            id: "c1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: "{}".to_string(),
            },
            thought_signature: None,
        }];
        let msg = ChatMessage::assistant_with_tools("I'll search", tc);
        assert_eq!(msg.role, ChatRole::Assistant);
        assert!(msg.tool_calls.is_some());
        assert_eq!(msg.tool_calls.as_ref().unwrap().len(), 1);
    }

    #[test]
    fn test_chat_message_assistant_with_empty_tools() {
        // Empty vec is normalized to None.
        let msg = ChatMessage::assistant_with_tools("just text", vec![]);
        assert!(msg.tool_calls.is_none());
    }

    #[test]
    fn test_chat_message_tool_result() {
        let msg = ChatMessage::tool_result("call_1", "result data");
        assert_eq!(msg.role, ChatRole::Tool);
        assert_eq!(msg.tool_call_id, Some("call_1".to_string()));
        assert_eq!(msg.content, "result data");
    }

    #[test]
    fn test_chat_message_has_images_false() {
        let msg = ChatMessage::user("hello");
        assert!(!msg.has_images());
    }

    #[test]
    fn test_image_data_equality() {
        let a = ImageData::new("data", "image/png");
        let b = ImageData::new("data", "image/png");
        assert_eq!(a, b);
        let c = ImageData::new("data", "image/jpeg");
        assert_ne!(a, c);
    }

    // --- StreamChunk variants ---

    #[test]
    fn test_stream_chunk_content() {
        let chunk = StreamChunk::Content("hello".to_string());
        if let StreamChunk::Content(text) = chunk {
            assert_eq!(text, "hello");
        } else {
            panic!("Expected Content");
        }
    }

    #[test]
    fn test_stream_chunk_thinking() {
        let chunk = StreamChunk::ThinkingContent {
            text: "reasoning...".to_string(),
            tokens_used: Some(50),
            budget_total: Some(10000),
        };
        if let StreamChunk::ThinkingContent {
            text,
            tokens_used,
            budget_total,
        } = chunk
        {
            assert_eq!(text, "reasoning...");
            assert_eq!(tokens_used, Some(50));
            assert_eq!(budget_total, Some(10000));
        }
    }

    #[test]
    fn test_stream_chunk_finished() {
        let chunk = StreamChunk::Finished {
            reason: "stop".to_string(),
            ttft_ms: Some(120.5),
        };
        if let StreamChunk::Finished { reason, ttft_ms } = chunk {
            assert_eq!(reason, "stop");
            assert_eq!(ttft_ms, Some(120.5));
        }
    }

    #[test]
    fn test_stream_chunk_tool_call_delta() {
        let chunk = StreamChunk::ToolCallDelta {
            index: 0,
            id: Some("call_1".to_string()),
            function_name: Some("search".to_string()),
            function_arguments: Some(r#"{"q":"#.to_string()),
            thought_signature: None,
        };
        if let StreamChunk::ToolCallDelta {
            index,
            id,
            function_name,
            function_arguments,
            ..
        } = chunk
        {
            assert_eq!(index, 0);
            assert_eq!(id, Some("call_1".to_string()));
            assert_eq!(function_name, Some("search".to_string()));
            assert!(function_arguments.is_some());
        }
    }
}