pub mod anthropic;
pub mod gemini;
pub mod openai;
pub mod registry;
pub mod retry;
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// One piece of a message's content.
///
/// Serialized by serde as an internally tagged object: the variant is carried
/// in a `"type"` field whose value is the snake_case variant name
/// (`text`, `image_url`, `tool_use`, `tool_result`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
/// Plain text.
Text { text: String },
/// An image described by an [`ImageUrl`].
ImageUrl { image_url: ImageUrl },
/// A tool invocation: an id, the tool name and an arbitrary JSON input value.
// NOTE(review): presumably mirrors provider-side "tool_use" blocks — confirm against the provider modules.
ToolUse {
id: String,
name: String,
input: serde_json::Value,
},
/// The outcome of a tool invocation, tied back to it through `tool_use_id`.
ToolResult {
tool_use_id: String,
content: String,
/// No serde default is declared, so this field must be present when deserializing.
is_error: bool,
},
}
impl ContentPart {
pub fn text(s: impl Into<String>) -> Self {
ContentPart::Text { text: s.into() }
}
pub fn as_text(&self) -> Option<&str> {
match self {
ContentPart::Text { text } => Some(text),
_ => None,
}
}
}
/// Image reference carried by [`ContentPart::ImageUrl`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImageUrl {
/// Location of the image; `image_to_content_part` in this file stores a
/// base64 `data:` URI here, and the tests also use a plain `https://` URL.
pub url: String,
/// Optional detail setting; left out of the serialized output when `None`.
// NOTE(review): the tests use "auto"; the full set of accepted values is not visible here.
#[serde(skip_serializing_if = "Option::is_none")]
pub detail: Option<String>,
}
/// Serializes message content in its most compact form.
///
/// A list holding exactly one `Text` part is written as a bare string. Every
/// other list (empty, multi-part, or a single non-text part) is written as a
/// sequence of tagged parts.
fn serialize_content_parts<S>(parts: &Vec<ContentPart>, s: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    match parts.as_slice() {
        // Lone text part: collapse to a plain string.
        [ContentPart::Text { text }] => s.serialize_str(text),
        _ => parts.serialize(s),
    }
}
fn deserialize_content_parts<'de, D>(d: D) -> Result<Vec<ContentPart>, D::Error>
where
D: Deserializer<'de>,
{
let value = serde_json::Value::deserialize(d)?;
match value {
serde_json::Value::Null => Ok(vec![]),
serde_json::Value::String(s) => Ok(vec![ContentPart::Text { text: s }]),
serde_json::Value::Array(_) => {
let parts: Vec<ContentPart> =
serde_json::from_value(value).map_err(serde::de::Error::custom)?;
Ok(parts)
}
other => Err(serde::de::Error::custom(format!(
"Expected string or array for message content, got: {}",
other
))),
}
}
/// Author of a [`Message`].
///
/// Serialized as the lowercase variant name: `system`, `user`, `assistant`
/// or `tool`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
System,
User,
Assistant,
Tool,
}
/// A single chat message: a role, its content parts, and optional tool metadata.
///
/// Only `Deserialize` is derived here; serialization goes through the
/// hand-written `impl Serialize for Message` in this file.
// NOTE(review): because `Serialize` is not derived, the `skip_serializing_if`
// attributes below look inert — confirm before relying on them.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Message {
pub role: Role,
/// Content parts. On input this accepts `null`, a bare string, or an array
/// of parts (see `deserialize_content_parts`); a missing field defaults to
/// an empty list.
#[serde(
default,
deserialize_with = "deserialize_content_parts",
skip_serializing_if = "Vec::is_empty"
)]
pub content: Vec<ContentPart>,
/// Tool calls attached to the message, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
/// Id of the tool call this message relates to; `Message::tool` sets it.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
/// Optional name; none of the constructors in this file set it.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}
impl Serialize for Message {
    /// Writes the message as a map containing `role` plus only those of
    /// `content`, `tool_calls`, `tool_call_id` and `name` that are populated.
    ///
    /// `content` goes through `ContentPartsSerializer`, so a lone text part
    /// is emitted as a plain string.
    fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        use serde::ser::SerializeMap;

        let has_content = !self.content.is_empty();

        // `role` is always written; each populated optional entry adds one
        // to the length hint handed to the serializer.
        let optional_entries = [
            has_content,
            self.tool_calls.is_some(),
            self.tool_call_id.is_some(),
            self.name.is_some(),
        ];
        let entry_count = 1 + optional_entries.iter().filter(|&&present| present).count();

        let mut out = s.serialize_map(Some(entry_count))?;
        out.serialize_entry("role", &self.role)?;
        if has_content {
            out.serialize_entry("content", &ContentPartsSerializer(&self.content))?;
        }
        if let Some(calls) = self.tool_calls.as_ref() {
            out.serialize_entry("tool_calls", calls)?;
        }
        if let Some(call_id) = self.tool_call_id.as_ref() {
            out.serialize_entry("tool_call_id", call_id)?;
        }
        if let Some(name) = self.name.as_ref() {
            out.serialize_entry("name", name)?;
        }
        out.end()
    }
}
struct ContentPartsSerializer<'a>(&'a Vec<ContentPart>);
impl<'a> Serialize for ContentPartsSerializer<'a> {
fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
serialize_content_parts(self.0, s)
}
}
impl Message {
pub fn system(content: impl Into<String>) -> Self {
Self {
role: Role::System,
content: vec![ContentPart::text(content)],
tool_calls: None,
tool_call_id: None,
name: None,
}
}
pub fn user(content: impl Into<String>) -> Self {
Self {
role: Role::User,
content: vec![ContentPart::text(content)],
tool_calls: None,
tool_call_id: None,
name: None,
}
}
pub fn assistant(content: Option<String>, tool_calls: Option<Vec<ToolCall>>) -> Self {
Self {
role: Role::Assistant,
content: content
.map(|s| vec![ContentPart::text(s)])
.unwrap_or_default(),
tool_calls,
tool_call_id: None,
name: None,
}
}
pub fn tool(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
Self {
role: Role::Tool,
content: vec![ContentPart::text(content)],
tool_calls: None,
tool_call_id: Some(tool_call_id.into()),
name: None,
}
}
pub fn text_content(&self) -> Option<String> {
let texts: Vec<&str> = self.content.iter().filter_map(|p| p.as_text()).collect();
if texts.is_empty() {
None
} else {
Some(texts.join("\n"))
}
}
}
/// A tool call attached to a [`Message`] or returned in an [`LlmResponse`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolCall {
pub id: String,
/// Stored under the JSON key `type`; the tests in this file use `"function"`.
#[serde(rename = "type")]
pub call_type: String,
/// Name and arguments of the function being called.
pub function: FunctionCall,
}
/// Function name and arguments carried by a [`ToolCall`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FunctionCall {
pub name: String,
/// Arguments kept as a raw string.
// NOTE(review): presumably JSON-encoded text (the tests use "{}") — confirm with the providers.
pub arguments: String,
}
/// Description of a tool passed to [`LlmProvider::chat_completion`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
/// Stored under the JSON key `type`.
#[serde(rename = "type")]
pub def_type: String,
/// The function this tool exposes.
pub function: FunctionDefinition,
}
/// Name, description and parameters of a function exposed through a
/// [`ToolDefinition`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
pub name: String,
pub description: String,
/// Arbitrary JSON describing the parameters.
// NOTE(review): presumably a JSON Schema object — confirm against the call sites.
pub parameters: serde_json::Value,
}
/// Token counters attached to an [`LlmResponse`].
///
/// `Default` is derived, so every counter starts at zero.
// NOTE(review): nothing here enforces `total_tokens == prompt_tokens + completion_tokens`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
/// Result of a single [`LlmProvider::chat_completion`] call.
///
/// Every field is optional.
#[derive(Debug, Clone)]
pub struct LlmResponse {
/// Text of the reply, if any.
pub content: Option<String>,
/// Tool calls included in the reply, if any.
pub tool_calls: Option<Vec<ToolCall>>,
/// Token counters; the tests describe `None` as "not provided by API".
pub usage: Option<Usage>,
}
/// Interface for chat-completion backends.
///
/// Implementors must be `Send + Sync`. Only `chat_completion` is required;
/// the remaining methods have default implementations.
#[async_trait]
pub trait LlmProvider: Send + Sync {
/// Takes the conversation in `messages` and the tool definitions in `tools`,
/// and resolves to an [`LlmResponse`] or an error.
async fn chat_completion(
&self,
messages: &[Message],
tools: &[ToolDefinition],
) -> Result<LlmResponse>;
/// The default implementation returns `false`.
// NOTE(review): presumably overridden to `true` by a Copilot-backed provider — confirm in the provider modules.
fn is_copilot(&self) -> bool {
false
}
/// The default implementation does nothing.
// NOTE(review): judging by the name, this toggles printing of streamed output; implementors are not visible here.
fn set_stream_print(&self, _enabled: bool) {}
/// The default implementation ignores the token.
async fn set_copilot_oauth_token(&self, _token: String) {}
}
/// Provider stub whose `chat_completion` ignores its arguments.
#[allow(dead_code)]
pub struct NullLlmProvider;

#[async_trait]
impl LlmProvider for NullLlmProvider {
    /// Always succeeds with empty-string content, no tool calls and no usage.
    async fn chat_completion(
        &self,
        _messages: &[Message],
        _tools: &[ToolDefinition],
    ) -> Result<LlmResponse> {
        let empty_reply = LlmResponse {
            content: Some(String::new()),
            tool_calls: None,
            usage: None,
        };
        Ok(empty_reply)
    }
}
/// Reads an image from disk and wraps it in a [`ContentPart::ImageUrl`] whose
/// URL is a base64 `data:` URI.
///
/// The MIME type is chosen from the file extension, compared
/// case-insensitively: `jpg`/`jpeg`, `png`, `gif` and `webp` are supported, so
/// `photo.PNG` is treated the same as `photo.png`. `detail` is always `None`.
///
/// # Errors
///
/// Returns an error when
/// * the path has no extension (or the extension is not valid UTF-8),
/// * the extension is not one of the supported formats,
/// * the file cannot be opened or read.
#[allow(dead_code)]
pub fn image_to_content_part(path: &std::path::Path) -> anyhow::Result<ContentPart> {
    use anyhow::Context as _;
    use base64::Engine as _;
    use std::io::Read as _;

    let ext = match path.extension().and_then(|e| e.to_str()) {
        Some(ext) => ext,
        None => anyhow::bail!("Cannot determine image format: no file extension"),
    };

    // Match on a lowercased copy so `.PNG` / `.Jpg` are accepted, but keep the
    // original spelling for the error message.
    let mime = match ext.to_ascii_lowercase().as_str() {
        "jpg" | "jpeg" => "image/jpeg",
        "png" => "image/png",
        "gif" => "image/gif",
        "webp" => "image/webp",
        _ => anyhow::bail!("Unsupported image format: .{}", ext),
    };

    let mut file = std::fs::File::open(path)
        .with_context(|| format!("Cannot open image file: {}", path.display()))?;
    let mut bytes = Vec::new();
    file.read_to_end(&mut bytes)
        .with_context(|| format!("Cannot read image file: {}", path.display()))?;

    let b64 = base64::engine::general_purpose::STANDARD.encode(&bytes);
    let data_uri = format!("data:{};base64,{}", mime, b64);

    Ok(ContentPart::ImageUrl {
        image_url: ImageUrl {
            url: data_uri,
            detail: None,
        },
    })
}
// Unit tests for the shared message / usage types and the image helper.
#[cfg(test)]
mod tests {
use super::*;
// A default `Usage` has every counter at zero.
#[test]
fn test_usage_default_is_zero() {
let u = Usage::default();
assert_eq!(u.prompt_tokens, 0, "prompt_tokens should start at 0");
assert_eq!(
u.completion_tokens, 0,
"completion_tokens should start at 0"
);
assert_eq!(u.total_tokens, 0, "total_tokens should start at 0");
}
// Two `Usage` values with identical counters compare equal.
#[test]
fn test_usage_equality() {
let a = Usage {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
};
let b = Usage {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
};
assert_eq!(a, b);
}
// An `LlmResponse` can be built without usage information.
#[test]
fn test_llm_response_usage_is_none_by_default() {
let resp = LlmResponse {
content: Some("hello".to_string()),
tool_calls: None,
usage: None,
};
assert!(
resp.usage.is_none(),
"usage should be None when not provided by API"
);
}
// Usage counters stored in an `LlmResponse` are read back unchanged.
#[test]
fn test_llm_response_usage_can_be_some() {
let resp = LlmResponse {
content: Some("answer".to_string()),
tool_calls: None,
usage: Some(Usage {
prompt_tokens: 100,
completion_tokens: 50,
total_tokens: 150,
}),
};
let u = resp.usage.unwrap();
assert_eq!(u.prompt_tokens, 100);
assert_eq!(u.completion_tokens, 50);
assert_eq!(u.total_tokens, 150);
}
// A text part serializes with `"type":"text"` and round-trips through JSON.
#[test]
fn test_content_part_text_serde() {
let part = ContentPart::text("hello world");
let json = serde_json::to_string(&part).unwrap();
assert!(json.contains("\"type\":\"text\""));
assert!(json.contains("\"text\":\"hello world\""));
let back: ContentPart = serde_json::from_str(&json).unwrap();
assert_eq!(part, back);
}
// An image part serializes with `"type":"image_url"` and round-trips through JSON.
#[test]
fn test_content_part_image_url_serde() {
let part = ContentPart::ImageUrl {
image_url: ImageUrl {
url: "https://example.com/img.png".to_string(),
detail: Some("auto".to_string()),
},
};
let json = serde_json::to_string(&part).unwrap();
assert!(json.contains("\"type\":\"image_url\""));
let back: ContentPart = serde_json::from_str(&json).unwrap();
assert_eq!(part, back);
}
// A message with a single text part writes `content` as a plain string.
#[test]
fn test_message_single_text_serializes_as_string() {
let msg = Message::user("Hello!");
let v: serde_json::Value = serde_json::to_value(&msg).unwrap();
assert_eq!(
v["content"],
serde_json::Value::String("Hello!".to_string())
);
}
// A message with several parts writes `content` as an array of parts.
#[test]
fn test_message_multipart_serializes_as_array() {
let msg = Message {
role: Role::User,
content: vec![
ContentPart::text("describe this image"),
ContentPart::ImageUrl {
image_url: ImageUrl {
url: "data:image/png;base64,abc".to_string(),
detail: None,
},
},
],
tool_calls: None,
tool_call_id: None,
name: None,
};
let v: serde_json::Value = serde_json::to_value(&msg).unwrap();
assert!(
v["content"].is_array(),
"Expected array for multi-part content"
);
assert_eq!(v["content"].as_array().unwrap().len(), 2);
}
// String-valued `content` deserializes into a single text part.
#[test]
fn test_backwards_compat_string_content() {
let old_json = r#"{"role":"user","content":"Write a hello world program"}"#;
let msg: Message = serde_json::from_str(old_json).unwrap();
assert_eq!(msg.role, Role::User);
assert_eq!(msg.content.len(), 1);
assert_eq!(
msg.text_content(),
Some("Write a hello world program".to_string())
);
}
// `"content": null` deserializes to an empty list while `tool_calls` is kept.
#[test]
fn test_null_content_deserializes_to_empty_vec() {
let json = r#"{"role":"assistant","content":null,"tool_calls":[{"id":"c1","type":"function","function":{"name":"file_write","arguments":"{}"}}]}"#;
let msg: Message = serde_json::from_str(json).unwrap();
assert!(msg.content.is_empty());
assert!(msg.tool_calls.is_some());
}
// An assistant text message survives a JSON round trip unchanged.
#[test]
fn test_assistant_message_roundtrip() {
let msg = Message::assistant(Some("Here is the code.".to_string()), None);
let json = serde_json::to_string(&msg).unwrap();
let back: Message = serde_json::from_str(&json).unwrap();
assert_eq!(msg, back);
assert_eq!(back.text_content(), Some("Here is the code.".to_string()));
}
// A tool message (including its `tool_call_id`) survives a JSON round trip.
#[test]
fn test_tool_message_roundtrip() {
let msg = Message::tool("call_1", "File written successfully");
let json = serde_json::to_string(&msg).unwrap();
let back: Message = serde_json::from_str(&json).unwrap();
assert_eq!(msg, back);
}
// `text_content` is `None` when the message has no text parts.
#[test]
fn test_text_content_no_text_parts() {
let msg = Message {
role: Role::User,
content: vec![ContentPart::ImageUrl {
image_url: ImageUrl {
url: "data:image/png;base64,abc".to_string(),
detail: None,
},
}],
tool_calls: None,
tool_call_id: None,
name: None,
};
assert_eq!(msg.text_content(), None);
}
// Multiple text parts are joined with a newline.
#[test]
fn test_text_content_multiple_parts() {
let msg = Message {
role: Role::User,
content: vec![ContentPart::text("part one"), ContentPart::text("part two")],
tool_calls: None,
tool_call_id: None,
name: None,
};
assert_eq!(msg.text_content(), Some("part one\npart two".to_string()));
}
// An extension outside the supported set is rejected with a descriptive error.
#[test]
fn test_image_to_content_part_unsupported_ext() {
let result = image_to_content_part(std::path::Path::new("file.bmp"));
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Unsupported image format"));
}
// A path without any extension is rejected with a descriptive error.
#[test]
fn test_image_to_content_part_no_ext() {
let result = image_to_content_part(std::path::Path::new("noextension"));
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Cannot determine image format"));
}
// A supported extension on a nonexistent file still yields an error.
#[test]
fn test_image_to_content_part_missing_file() {
let result = image_to_content_part(std::path::Path::new("/nonexistent/path/img.png"));
assert!(result.is_err());
}
// A real file on disk is converted into a `data:image/png;base64,` URI with no detail.
#[test]
fn test_image_to_content_part_from_disk() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("test.png");
// The eight bytes of the PNG file signature are enough for this test.
let png_bytes: &[u8] = &[0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a];
std::fs::write(&path, png_bytes).unwrap();
let part = image_to_content_part(&path).unwrap();
match part {
ContentPart::ImageUrl { image_url } => {
assert!(
image_url.url.starts_with("data:image/png;base64,"),
"Expected data URI prefix, got: {}",
&image_url.url[..50.min(image_url.url.len())]
);
assert!(image_url.detail.is_none());
}
other => panic!("Expected ImageUrl variant, got {:?}", other),
}
}
}