use std::time::Duration;
use reqwest::Client;
use serde_json::{Value, json};
use crate::transport::context::{
ChatMessage, LlmProvider, LlmResponse, ProviderError, TextResponse, ToolCall, ToolDefinition,
};
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
pub struct OpenAiCompatibleProvider {
client: Client,
base_url: String,
api_key: String,
model: String,
temperature: f32,
max_tokens: Option<u32>,
}
impl OpenAiCompatibleProvider {
pub fn new(
api_key: String,
model: String,
base_url: Option<String>,
temperature: f32,
max_tokens: Option<u32>,
timeout_secs: u64,
) -> Result<Self, ProviderError> {
let client = Client::builder()
.timeout(Duration::from_secs(timeout_secs))
.build()
.map_err(ProviderError::Request)?;
Ok(Self {
client,
base_url: base_url.unwrap_or_else(|| DEFAULT_BASE_URL.to_string()),
api_key,
model,
temperature,
max_tokens,
})
}
fn serialize_messages(history: &[ChatMessage]) -> Vec<Value> {
let mut messages = Vec::with_capacity(history.len());
for msg in history {
match msg {
ChatMessage::System(text) => {
messages.push(json!({"role": "system", "content": text}));
}
ChatMessage::User(text) => {
messages.push(json!({"role": "user", "content": text}));
}
ChatMessage::AssistantText(text) => {
messages.push(json!({"role": "assistant", "content": text}));
}
ChatMessage::AssistantToolUse { tool_calls } => {
let tc: Vec<Value> = tool_calls
.iter()
.map(|tc| {
let mut obj = json!({
"id": tc.id,
"type": "function",
"function": {
"name": tc.name,
"arguments": serde_json::to_string(&tc.arguments)
.unwrap_or_default(),
}
});
if let Some(meta) = &tc.provider_metadata
&& let Some(map) = obj.as_object_mut()
{
for (k, v) in meta {
map.insert(k.clone(), v.clone());
}
}
obj
})
.collect();
messages.push(json!({"role": "assistant", "tool_calls": tc}));
}
ChatMessage::ToolResult {
tool_call_id,
content,
} => {
let content_str = match content {
Value::String(s) => s.clone(),
other => serde_json::to_string(other).unwrap_or_default(),
};
messages.push(json!({
"role": "tool",
"tool_call_id": tool_call_id,
"content": content_str,
}));
}
}
}
messages
}
fn serialize_tools(tools: &[ToolDefinition]) -> Vec<Value> {
tools
.iter()
.map(|t| {
json!({
"type": "function",
"function": {
"name": t.name,
"description": t.description,
"parameters": t.parameters,
}
})
})
.collect()
}
}
#[async_trait::async_trait]
#[allow(clippy::too_many_lines)]
impl LlmProvider for OpenAiCompatibleProvider {
async fn chat_completion(
&self,
history: &[ChatMessage],
tools: &[ToolDefinition],
) -> Result<LlmResponse, ProviderError> {
let messages = Self::serialize_messages(history);
let mut body = json!({
"model": self.model,
"messages": messages,
"temperature": self.temperature,
});
if let Some(max_tokens) = self.max_tokens {
body["max_completion_tokens"] = json!(max_tokens);
}
if !tools.is_empty() {
body["tools"] = json!(Self::serialize_tools(tools));
}
let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
let resp_body = super::retry::send_with_retry(|| {
self.client
.post(&url)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
})
.await?;
let choice = resp_body
.get("choices")
.and_then(Value::as_array)
.and_then(|c| c.first())
.ok_or_else(|| ProviderError::Parse("no choices in response".into()))?;
let finish_reason = choice
.get("finish_reason")
.and_then(Value::as_str)
.unwrap_or("stop");
let message = choice
.get("message")
.ok_or_else(|| ProviderError::Parse("no message in choice".into()))?;
if let Some(tool_calls) = message.get("tool_calls").and_then(Value::as_array)
&& !tool_calls.is_empty()
&& finish_reason != "length"
{
let calls: Vec<ToolCall> = tool_calls
.iter()
.filter_map(|tc| {
let id = tc.get("id")?.as_str()?.to_string();
let function = tc.get("function")?;
let name = function.get("name")?.as_str()?.to_string();
let args_str = function
.get("arguments")
.and_then(Value::as_str)
.unwrap_or("{}");
let arguments: Value = serde_json::from_str(args_str).unwrap_or_else(|e| {
let truncated: String = args_str.chars().take(200).collect();
tracing::warn!(
args_preview = %truncated,
args_len = args_str.len(),
error = %e,
"malformed tool call arguments from LLM provider, defaulting to {{}}"
);
json!({})
});
let known_keys: &[&str] = &["id", "type", "function", "index"];
let metadata: serde_json::Map<String, Value> = tc
.as_object()
.map(|obj| {
obj.iter()
.filter(|(k, _)| !known_keys.contains(&k.as_str()))
.map(|(k, v)| (k.clone(), v.clone()))
.collect()
})
.unwrap_or_default();
Some(ToolCall {
id,
name,
arguments,
provider_metadata: if metadata.is_empty() {
None
} else {
Some(metadata)
},
})
})
.collect();
if !calls.is_empty() {
return Ok(LlmResponse::ToolUse(calls));
}
}
let text = message
.get("content")
.and_then(Value::as_str)
.unwrap_or("")
.to_string();
let is_truncated = finish_reason == "length";
if is_truncated {
tracing::warn!(
text_len = text.len(),
"response truncated (finish_reason=length)"
);
}
Ok(LlmResponse::Text(TextResponse { text, is_truncated }))
}
fn provider_name(&self) -> &'static str {
"openai"
}
fn model_name(&self) -> &str {
&self.model
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_messages() {
let history = vec![
ChatMessage::System("You are helpful".into()),
ChatMessage::User("Hello".into()),
ChatMessage::AssistantText("Hi".into()),
];
let messages = OpenAiCompatibleProvider::serialize_messages(&history);
assert_eq!(messages.len(), 3);
assert_eq!(messages[0]["role"], "system");
assert_eq!(messages[1]["role"], "user");
assert_eq!(messages[2]["role"], "assistant");
}
#[test]
fn test_serialize_messages_tool_use() {
let history = vec![
ChatMessage::AssistantToolUse {
tool_calls: vec![ToolCall {
id: "tc1".into(),
name: "search".into(),
arguments: json!({"q": "test"}),
provider_metadata: None,
}],
},
ChatMessage::ToolResult {
tool_call_id: "tc1".into(),
content: json!("result text"),
},
];
let messages = OpenAiCompatibleProvider::serialize_messages(&history);
assert_eq!(messages.len(), 2);
assert!(messages[0]["tool_calls"].is_array());
assert_eq!(messages[1]["role"], "tool");
assert_eq!(messages[1]["tool_call_id"], "tc1");
}
#[test]
fn test_serialize_preserves_provider_metadata() {
let mut meta = serde_json::Map::new();
meta.insert("thought_signature".to_string(), json!("encrypted_sig_xyz"));
let history = vec![ChatMessage::AssistantToolUse {
tool_calls: vec![ToolCall {
id: "tc1".into(),
name: "calc".into(),
arguments: json!({"x": 1}),
provider_metadata: Some(meta),
}],
}];
let messages = OpenAiCompatibleProvider::serialize_messages(&history);
let tc = &messages[0]["tool_calls"][0];
assert_eq!(tc["thought_signature"], "encrypted_sig_xyz");
assert_eq!(tc["id"], "tc1");
assert_eq!(tc["type"], "function");
}
#[test]
fn test_serialize_tools() {
let tools = vec![ToolDefinition {
name: "search".into(),
description: "Search the web".into(),
parameters: json!({"type": "object", "properties": {"q": {"type": "string"}}}),
}];
let serialized = OpenAiCompatibleProvider::serialize_tools(&tools);
assert_eq!(serialized.len(), 1);
assert_eq!(serialized[0]["type"], "function");
assert_eq!(serialized[0]["function"]["name"], "search");
}
}