use crate::tools::ToolSpec;
use async_trait::async_trait;
use futures_util::{stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::fmt::Write;
/// A single message in a chat conversation, tagged with the speaker's role.
///
/// Role strings follow the OpenAI-style convention used throughout this
/// module: `"system"`, `"user"`, `"assistant"`, and `"tool"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Speaker role (`"system"`, `"user"`, `"assistant"`, or `"tool"`).
    pub role: String,
    /// The message body.
    pub content: String,
}

impl ChatMessage {
    /// Shared constructor backing the role-specific helpers below.
    fn with_role(role: &str, content: impl Into<String>) -> Self {
        Self {
            role: role.into(),
            content: content.into(),
        }
    }

    /// Creates a message with the `"system"` role.
    pub fn system(content: impl Into<String>) -> Self {
        Self::with_role("system", content)
    }

    /// Creates a message with the `"user"` role.
    pub fn user(content: impl Into<String>) -> Self {
        Self::with_role("user", content)
    }

    /// Creates a message with the `"assistant"` role.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::with_role("assistant", content)
    }

    /// Creates a message with the `"tool"` role.
    pub fn tool(content: impl Into<String>) -> Self {
        Self::with_role("tool", content)
    }
}
/// A single tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Call identifier; echoed back via `ToolResultMessage::tool_call_id`.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Tool arguments kept as a raw JSON string (not parsed here).
    pub arguments: String,
}
/// Token accounting for a chat exchange; either side may be absent when
/// the backend does not report usage.
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
    /// Prompt-side token count, if reported by the provider.
    pub input_tokens: Option<u64>,
    /// Completion-side token count, if reported by the provider.
    pub output_tokens: Option<u64>,
}
/// The outcome of a single chat turn: optional text, any tool calls the
/// model requested, and optional usage / reasoning metadata.
#[derive(Debug, Clone)]
pub struct ChatResponse {
    /// Assistant text, if any was produced.
    pub text: Option<String>,
    /// Tool invocations requested by the model (empty when none).
    pub tool_calls: Vec<ToolCall>,
    /// Token usage, when the provider reports it.
    pub usage: Option<TokenUsage>,
    /// Reasoning text, for providers that expose it separately.
    pub reasoning_content: Option<String>,
}
impl ChatResponse {
    /// True when the model requested at least one tool invocation.
    pub fn has_tool_calls(&self) -> bool {
        self.tool_calls.first().is_some()
    }

    /// Borrows the response text, yielding `""` when no text was produced.
    pub fn text_or_empty(&self) -> &str {
        match self.text.as_deref() {
            Some(text) => text,
            None => "",
        }
    }
}
/// Borrowed view of a chat request: the message history plus optional
/// tool specifications. Cheap to copy; owns nothing.
#[derive(Debug, Clone, Copy)]
pub struct ChatRequest<'a> {
    /// Conversation messages for this request.
    pub messages: &'a [ChatMessage],
    /// Tool specs to expose to the model, if any.
    pub tools: Option<&'a [ToolSpec]>,
}
/// The result of executing one tool call, keyed back to the request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultMessage {
    /// Matches `ToolCall::id` of the originating call.
    pub tool_call_id: String,
    /// Tool output, as text.
    pub content: String,
}
/// A conversation entry as persisted/serialized.
///
/// Uses an adjacently-tagged serde representation
/// (`{"type": "...", "data": ...}`); renaming variants changes the
/// serialized format, so treat variant names as a wire contract.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum ConversationMessage {
    /// Plain role/content message.
    Chat(ChatMessage),
    /// Assistant turn that requested tool calls, with any accompanying
    /// text and reasoning captured alongside.
    AssistantToolCalls {
        text: Option<String>,
        tool_calls: Vec<ToolCall>,
        reasoning_content: Option<String>,
    },
    /// Results for a batch of tool calls.
    ToolResults(Vec<ToolResultMessage>),
}
/// One increment of a streamed chat response.
#[derive(Debug, Clone)]
pub struct StreamChunk {
    /// Text carried by this chunk (may be empty).
    pub delta: String,
    /// True for the last chunk of a stream.
    pub is_final: bool,
    /// Estimated token count for `delta`; 0 unless estimated.
    pub token_count: usize,
}

impl StreamChunk {
    /// A non-final chunk carrying `text`.
    pub fn delta(text: impl Into<String>) -> Self {
        StreamChunk {
            delta: text.into(),
            is_final: false,
            token_count: 0,
        }
    }

    /// The terminating chunk of a stream; carries no text.
    pub fn final_chunk() -> Self {
        StreamChunk {
            is_final: true,
            ..Self::delta(String::new())
        }
    }

    /// A terminal chunk whose delta holds an error message.
    ///
    /// NOTE(review): structurally identical to a final chunk that happens
    /// to carry text — consumers cannot distinguish the two cases.
    pub fn error(message: impl Into<String>) -> Self {
        StreamChunk {
            is_final: true,
            ..Self::delta(message)
        }
    }

    /// Fills `token_count` with a rough heuristic: ~4 bytes per token,
    /// rounded up.
    pub fn with_token_estimate(mut self) -> Self {
        let byte_len = self.delta.len();
        self.token_count = byte_len.div_ceil(4);
        self
    }
}
/// Flags controlling streaming behavior for a request.
#[derive(Debug, Clone, Copy, Default)]
pub struct StreamOptions {
    /// Whether streaming is requested at all.
    pub enabled: bool,
    /// Whether chunks should carry token estimates.
    pub count_tokens: bool,
}

impl StreamOptions {
    /// Options with streaming toggled by `enabled` and token counting off.
    pub fn new(enabled: bool) -> Self {
        StreamOptions {
            enabled,
            ..Self::default()
        }
    }

    /// Builder step: turns token counting on.
    pub fn with_token_count(self) -> Self {
        StreamOptions {
            count_tokens: true,
            ..self
        }
    }
}
pub type StreamResult<T> = std::result::Result<T, StreamError>;
#[derive(Debug, thiserror::Error)]
pub enum StreamError {
#[error("HTTP error: {0}")]
Http(reqwest::Error),
#[error("JSON parse error: {0}")]
Json(serde_json::Error),
#[error("Invalid SSE format: {0}")]
InvalidSse(String),
#[error("Provider error: {0}")]
Provider(String),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
}
/// Structured "provider lacks capability X" error.
///
/// The Display output is a `key=value` line intended to be machine-greppable;
/// keep the format string stable for log consumers.
// NOTE(review): not referenced elsewhere in this file — presumably raised by
// concrete provider implementations; confirm against callers.
#[derive(Debug, Clone, thiserror::Error)]
#[error("provider_capability_error provider={provider} capability={capability} message={message}")]
pub struct ProviderCapabilityError {
    /// Name of the provider that rejected the request.
    pub provider: String,
    /// The capability that was requested but unsupported.
    pub capability: String,
    /// Human-readable detail.
    pub message: String,
}
/// Feature flags advertised by a provider; `Default` means "supports none".
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ProviderCapabilities {
    /// Provider accepts structured tool definitions natively.
    pub native_tool_calling: bool,
    /// Provider accepts image inputs.
    pub vision: bool,
}
/// Provider-specific wire representation of tool definitions, as produced
/// by `Provider::convert_tools`.
#[derive(Debug, Clone)]
pub enum ToolsPayload {
    /// Gemini-style `function_declarations` array.
    Gemini {
        function_declarations: Vec<serde_json::Value>,
    },
    /// Anthropic-style `tools` array.
    Anthropic { tools: Vec<serde_json::Value> },
    /// OpenAI-style `tools` array.
    OpenAI { tools: Vec<serde_json::Value> },
    /// Plain-text instructions injected into the system prompt for
    /// providers without native tool calling.
    PromptGuided { instructions: String },
}
#[async_trait]
pub trait Provider: Send + Sync {
fn capabilities(&self) -> ProviderCapabilities {
ProviderCapabilities::default()
}
fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload {
ToolsPayload::PromptGuided {
instructions: build_tool_instructions_text(tools),
}
}
async fn simple_chat(
&self,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
self.chat_with_system(None, message, model, temperature)
.await
}
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String>;
async fn chat_with_history(
&self,
messages: &[ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let system = messages
.iter()
.find(|m| m.role == "system")
.map(|m| m.content.as_str());
let last_user = messages
.iter()
.rfind(|m| m.role == "user")
.map(|m| m.content.as_str())
.unwrap_or("");
self.chat_with_system(system, last_user, model, temperature)
.await
}
async fn chat(
&self,
request: ChatRequest<'_>,
model: &str,
temperature: f64,
) -> anyhow::Result<ChatResponse> {
if let Some(tools) = request.tools {
if !tools.is_empty() && !self.supports_native_tools() {
let tool_instructions = match self.convert_tools(tools) {
ToolsPayload::PromptGuided { instructions } => instructions,
payload => {
anyhow::bail!(
"Provider returned non-prompt-guided tools payload ({payload:?}) while supports_native_tools() is false"
)
}
};
let mut modified_messages = request.messages.to_vec();
if let Some(system_message) =
modified_messages.iter_mut().find(|m| m.role == "system")
{
if !system_message.content.is_empty() {
system_message.content.push_str("\n\n");
}
system_message.content.push_str(&tool_instructions);
} else {
modified_messages.insert(0, ChatMessage::system(tool_instructions));
}
let text = self
.chat_with_history(&modified_messages, model, temperature)
.await?;
return Ok(ChatResponse {
text: Some(text),
tool_calls: Vec::new(),
usage: None,
reasoning_content: None,
});
}
}
let text = self
.chat_with_history(request.messages, model, temperature)
.await?;
Ok(ChatResponse {
text: Some(text),
tool_calls: Vec::new(),
usage: None,
reasoning_content: None,
})
}
fn supports_native_tools(&self) -> bool {
self.capabilities().native_tool_calling
}
fn supports_vision(&self) -> bool {
self.capabilities().vision
}
async fn warmup(&self) -> anyhow::Result<()> {
Ok(())
}
async fn chat_with_tools(
&self,
messages: &[ChatMessage],
_tools: &[serde_json::Value],
model: &str,
temperature: f64,
) -> anyhow::Result<ChatResponse> {
let text = self.chat_with_history(messages, model, temperature).await?;
Ok(ChatResponse {
text: Some(text),
tool_calls: Vec::new(),
usage: None,
reasoning_content: None,
})
}
fn supports_streaming(&self) -> bool {
false
}
fn stream_chat_with_system(
&self,
_system_prompt: Option<&str>,
_message: &str,
_model: &str,
_temperature: f64,
_options: StreamOptions,
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
stream::empty().boxed()
}
fn stream_chat_with_history(
&self,
_messages: &[ChatMessage],
_model: &str,
_temperature: f64,
_options: StreamOptions,
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
let provider_name = "unknown".to_string();
let chunk = StreamChunk::error(format!("{} does not support streaming", provider_name));
stream::once(async move { Ok(chunk) }).boxed()
}
}
/// Renders prompt-guided tool-use instructions for providers without
/// native tool calling: a fixed protocol preamble followed by one entry
/// (name, description, JSON parameter schema) per tool.
pub fn build_tool_instructions_text(tools: &[ToolSpec]) -> String {
    // Protocol preamble, listed in emission order.
    let preamble = [
        "## Tool Use Protocol\n\n",
        "To use a tool, wrap a JSON object in <tool_call></tool_call> tags:\n\n",
        "<tool_call>\n",
        r#"{"name": "tool_name", "arguments": {"param": "value"}}"#,
        "\n</tool_call>\n\n",
        "You may use multiple tool calls in a single response. ",
        "After tool execution, results appear in <tool_result> tags. ",
        "Continue reasoning with the results until you can give a final answer.\n\n",
        "### Available Tools\n\n",
    ];
    let mut out: String = preamble.concat();
    for spec in tools {
        let params =
            serde_json::to_string(&spec.parameters).unwrap_or_else(|_| "{}".to_string());
        // Writing to a String is infallible, so the result can be ignored.
        let _ = writeln!(
            &mut out,
            "**{}**: {}\nParameters: `{}`\n",
            spec.name, spec.description, params
        );
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    // Mock whose capabilities() advertises both native tools and vision,
    // used to check the default supports_* accessors.
    struct CapabilityMockProvider;

    #[async_trait]
    impl Provider for CapabilityMockProvider {
        fn capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities {
                native_tool_calling: true,
                vision: true,
            }
        }
        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok("ok".into())
        }
    }

    // Each constructor sets the expected role string.
    #[test]
    fn chat_message_constructors() {
        let sys = ChatMessage::system("Be helpful");
        assert_eq!(sys.role, "system");
        assert_eq!(sys.content, "Be helpful");
        let user = ChatMessage::user("Hello");
        assert_eq!(user.role, "user");
        let asst = ChatMessage::assistant("Hi there");
        assert_eq!(asst.role, "assistant");
        let tool = ChatMessage::tool("{}");
        assert_eq!(tool.role, "tool");
    }

    // has_tool_calls / text_or_empty on both an empty and a populated response.
    #[test]
    fn chat_response_helpers() {
        let empty = ChatResponse {
            text: None,
            tool_calls: vec![],
            usage: None,
            reasoning_content: None,
        };
        assert!(!empty.has_tool_calls());
        assert_eq!(empty.text_or_empty(), "");
        let with_tools = ChatResponse {
            text: Some("Let me check".into()),
            tool_calls: vec![ToolCall {
                id: "1".into(),
                name: "shell".into(),
                arguments: "{}".into(),
            }],
            usage: None,
            reasoning_content: None,
        };
        assert!(with_tools.has_tool_calls());
        assert_eq!(with_tools.text_or_empty(), "Let me check");
    }

    // Default usage has no counts on either side.
    #[test]
    fn token_usage_default_is_none() {
        let usage = TokenUsage::default();
        assert!(usage.input_tokens.is_none());
        assert!(usage.output_tokens.is_none());
    }

    // Usage numbers round-trip through the response struct.
    #[test]
    fn chat_response_with_usage() {
        let resp = ChatResponse {
            text: Some("Hello".into()),
            tool_calls: vec![],
            usage: Some(TokenUsage {
                input_tokens: Some(100),
                output_tokens: Some(50),
            }),
            reasoning_content: None,
        };
        assert_eq!(resp.usage.as_ref().unwrap().input_tokens, Some(100));
        assert_eq!(resp.usage.as_ref().unwrap().output_tokens, Some(50));
    }

    // ToolCall serializes its fields into JSON.
    #[test]
    fn tool_call_serialization() {
        let tc = ToolCall {
            id: "call_123".into(),
            name: "file_read".into(),
            arguments: r#"{"path":"test.txt"}"#.into(),
        };
        let json = serde_json::to_string(&tc).unwrap();
        assert!(json.contains("call_123"));
        assert!(json.contains("file_read"));
    }

    // The adjacently-tagged serde representation emits a "type" field.
    #[test]
    fn conversation_message_variants() {
        let chat = ConversationMessage::Chat(ChatMessage::user("hi"));
        let json = serde_json::to_string(&chat).unwrap();
        assert!(json.contains("\"type\":\"Chat\""));
        let tool_result = ConversationMessage::ToolResults(vec![ToolResultMessage {
            tool_call_id: "1".into(),
            content: "done".into(),
        }]);
        let json = serde_json::to_string(&tool_result).unwrap();
        assert!(json.contains("\"type\":\"ToolResults\""));
    }

    // Default capabilities mean "supports nothing".
    #[test]
    fn provider_capabilities_default() {
        let caps = ProviderCapabilities::default();
        assert!(!caps.native_tool_calling);
        assert!(!caps.vision);
    }

    // Derived PartialEq compares field-by-field.
    #[test]
    fn provider_capabilities_equality() {
        let caps1 = ProviderCapabilities {
            native_tool_calling: true,
            vision: false,
        };
        let caps2 = ProviderCapabilities {
            native_tool_calling: true,
            vision: false,
        };
        let caps3 = ProviderCapabilities {
            native_tool_calling: false,
            vision: false,
        };
        assert_eq!(caps1, caps2);
        assert_ne!(caps1, caps3);
    }

    // Default supports_native_tools() reads capabilities().
    #[test]
    fn supports_native_tools_reflects_capabilities_default_mapping() {
        let provider = CapabilityMockProvider;
        assert!(provider.supports_native_tools());
    }

    // Default supports_vision() reads capabilities().
    #[test]
    fn supports_vision_reflects_capabilities_default_mapping() {
        let provider = CapabilityMockProvider;
        assert!(provider.supports_vision());
    }

    // All four payload variants construct and pattern-match.
    #[test]
    fn tools_payload_variants() {
        let gemini = ToolsPayload::Gemini {
            function_declarations: vec![serde_json::json!({"name": "test"})],
        };
        assert!(matches!(gemini, ToolsPayload::Gemini { .. }));
        let anthropic = ToolsPayload::Anthropic {
            tools: vec![serde_json::json!({"name": "test"})],
        };
        assert!(matches!(anthropic, ToolsPayload::Anthropic { .. }));
        let openai = ToolsPayload::OpenAI {
            tools: vec![serde_json::json!({"type": "function"})],
        };
        assert!(matches!(openai, ToolsPayload::OpenAI { .. }));
        let prompt_guided = ToolsPayload::PromptGuided {
            instructions: "Use tools...".to_string(),
        };
        assert!(matches!(prompt_guided, ToolsPayload::PromptGuided { .. }));
    }

    // Instructions mention the protocol, the tags, and every tool.
    #[test]
    fn build_tool_instructions_text_format() {
        let tools = vec![
            ToolSpec {
                name: "shell".to_string(),
                description: "Execute commands".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "command": {"type": "string"}
                    }
                }),
            },
            ToolSpec {
                name: "file_read".to_string(),
                description: "Read files".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "path": {"type": "string"}
                    }
                }),
            },
        ];
        let instructions = build_tool_instructions_text(&tools);
        assert!(instructions.contains("Tool Use Protocol"));
        assert!(instructions.contains("<tool_call>"));
        assert!(instructions.contains("</tool_call>"));
        assert!(instructions.contains("**shell**"));
        assert!(instructions.contains("Execute commands"));
        assert!(instructions.contains("**file_read**"));
        assert!(instructions.contains("Read files"));
        assert!(instructions.contains("Parameters:"));
        assert!(instructions.contains(r#""type":"object""#));
    }

    // An empty tool list still yields the preamble and section header.
    #[test]
    fn build_tool_instructions_text_empty() {
        let instructions = build_tool_instructions_text(&[]);
        assert!(instructions.contains("Tool Use Protocol"));
        assert!(instructions.contains("Available Tools"));
    }

    // Mock with a configurable native-tools flag and a canned reply.
    struct MockProvider {
        supports_native: bool,
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn supports_native_tools(&self) -> bool {
            self.supports_native
        }
        async fn chat_with_system(
            &self,
            _system: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok("response".to_string())
        }
    }

    // The default convert_tools() produces prompt-guided instructions
    // that include each tool's name and description.
    #[test]
    fn provider_convert_tools_default() {
        let provider = MockProvider {
            supports_native: false,
        };
        let tools = vec![ToolSpec {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        }];
        let payload = provider.convert_tools(&tools);
        assert!(matches!(payload, ToolsPayload::PromptGuided { .. }));
        if let ToolsPayload::PromptGuided { instructions } = payload {
            assert!(instructions.contains("test_tool"));
            assert!(instructions.contains("A test tool"));
        }
    }

    // Non-native provider with tools still produces a text response
    // via the prompt-guided fallback.
    #[tokio::test]
    async fn provider_chat_prompt_guided_fallback() {
        let provider = MockProvider {
            supports_native: false,
        };
        let tools = vec![ToolSpec {
            name: "shell".to_string(),
            description: "Run commands".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        }];
        let request = ChatRequest {
            messages: &[ChatMessage::user("Hello")],
            tools: Some(&tools),
        };
        let response = provider.chat(request, "model", 0.7).await.unwrap();
        assert!(response.text.is_some());
    }

    // No tools: the default chat() forwards the history unchanged.
    #[tokio::test]
    async fn provider_chat_without_tools() {
        let provider = MockProvider {
            supports_native: true,
        };
        let request = ChatRequest {
            messages: &[ChatMessage::user("Hello")],
            tools: None,
        };
        let response = provider.chat(request, "model", 0.7).await.unwrap();
        assert!(response.text.is_some());
    }

    // Mock that echoes back the system prompt it received, so tests can
    // inspect how chat() rewrote the system message.
    struct EchoSystemProvider {
        supports_native: bool,
    }

    #[async_trait]
    impl Provider for EchoSystemProvider {
        fn supports_native_tools(&self) -> bool {
            self.supports_native
        }
        async fn chat_with_system(
            &self,
            system: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok(system.unwrap_or_default().to_string())
        }
    }

    // Mock with a custom convert_tools() override (still prompt-guided).
    struct CustomConvertProvider;

    #[async_trait]
    impl Provider for CustomConvertProvider {
        fn supports_native_tools(&self) -> bool {
            false
        }
        fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload {
            ToolsPayload::PromptGuided {
                instructions: "CUSTOM_TOOL_INSTRUCTIONS".to_string(),
            }
        }
        async fn chat_with_system(
            &self,
            system: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok(system.unwrap_or_default().to_string())
        }
    }

    // Mock that (incorrectly) returns a native payload while claiming no
    // native tool support — chat() must reject this.
    struct InvalidConvertProvider;

    #[async_trait]
    impl Provider for InvalidConvertProvider {
        fn supports_native_tools(&self) -> bool {
            false
        }
        fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload {
            ToolsPayload::OpenAI {
                tools: vec![serde_json::json!({"type": "function"})],
            }
        }
        async fn chat_with_system(
            &self,
            _system: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok("should_not_reach".to_string())
        }
    }

    // The fallback appends instructions to an existing system message even
    // when it is not the first message in the history.
    #[tokio::test]
    async fn provider_chat_prompt_guided_preserves_existing_system_not_first() {
        let provider = EchoSystemProvider {
            supports_native: false,
        };
        let tools = vec![ToolSpec {
            name: "shell".to_string(),
            description: "Run commands".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        }];
        let request = ChatRequest {
            messages: &[
                ChatMessage::user("Hello"),
                ChatMessage::system("BASE_SYSTEM_PROMPT"),
            ],
            tools: Some(&tools),
        };
        let response = provider.chat(request, "model", 0.7).await.unwrap();
        let text = response.text.unwrap_or_default();
        assert!(text.contains("BASE_SYSTEM_PROMPT"));
        assert!(text.contains("Tool Use Protocol"));
    }

    // chat() uses the provider's convert_tools() override, not the default.
    #[tokio::test]
    async fn provider_chat_prompt_guided_uses_convert_tools_override() {
        let provider = CustomConvertProvider;
        let tools = vec![ToolSpec {
            name: "shell".to_string(),
            description: "Run commands".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        }];
        let request = ChatRequest {
            messages: &[ChatMessage::system("BASE"), ChatMessage::user("Hello")],
            tools: Some(&tools),
        };
        let response = provider.chat(request, "model", 0.7).await.unwrap();
        let text = response.text.unwrap_or_default();
        assert!(text.contains("BASE"));
        assert!(text.contains("CUSTOM_TOOL_INSTRUCTIONS"));
    }

    // Non-PromptGuided payload from a non-native provider is an error.
    #[tokio::test]
    async fn provider_chat_prompt_guided_rejects_non_prompt_payload() {
        let provider = InvalidConvertProvider;
        let tools = vec![ToolSpec {
            name: "shell".to_string(),
            description: "Run commands".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        }];
        let request = ChatRequest {
            messages: &[ChatMessage::user("Hello")],
            tools: Some(&tools),
        };
        let err = provider.chat(request, "model", 0.7).await.unwrap_err();
        let message = err.to_string();
        assert!(message.contains("non-prompt-guided"));
    }
}