mod types;
pub use types::*;
use async_trait::async_trait;
use serde_json::{Value, json};
use crate::error::{Result, RustAgentsError};
use crate::harness::message::{AssistantMessage, ContentBlock, Message};
use crate::harness::model::{ChatModel, ModelRequest, ModelResponse, ResponseFormat, ToolChoice};
use crate::harness::tool::ToolCall;
use crate::harness::usage::Usage;
/// Model name a freshly constructed [`OpenAiModel`] uses until `with_model` overrides it.
const DEFAULT_MODEL: &str = "gpt-4.1-mini";
/// Base URL a freshly constructed [`OpenAiModel`] uses until `with_base_url` overrides it.
/// `/chat/completions` is appended to this value when a request is sent.
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
/// Chat model client that talks to an OpenAI-style `/chat/completions` endpoint.
///
/// Build one with [`OpenAiModel::new`], [`OpenAiModel::from_env`], or one of the
/// provider presets, then adjust it with `with_model` / `with_base_url`.
pub struct OpenAiModel {
/// HTTP client used to send every request.
client: reqwest::Client,
/// Credential sent as the bearer token in the `Authorization` header.
api_key: String,
/// Model name used when a request does not specify its own.
model: String,
/// Endpoint prefix; `/chat/completions` is appended when a request is sent.
base_url: String,
}
impl OpenAiModel {
    /// Creates a client with the given API key, `DEFAULT_MODEL`, and `DEFAULT_BASE_URL`.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            client: reqwest::Client::new(),
            api_key: api_key.into(),
            model: DEFAULT_MODEL.to_string(),
            base_url: DEFAULT_BASE_URL.to_string(),
        }
    }

    /// Replaces the model name used when a request does not name one itself.
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Replaces the base URL, dropping any trailing `/` so that joining it with
    /// `/chat/completions` never produces a double slash.
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into().trim_end_matches('/').to_string();
        self
    }

    /// Reads an environment variable and returns its value with surrounding
    /// whitespace removed, or `None` when it is unset, not valid Unicode, or blank.
    fn non_blank_env(name: &str) -> Option<String> {
        let raw = std::env::var(name).ok()?;
        let trimmed = raw.trim();
        (!trimmed.is_empty()).then(|| trimmed.to_owned())
    }

    /// Builds a client from environment variables.
    ///
    /// * `OPENAI_API_KEY` — required.
    /// * `OPENAI_MODEL` — optional model override.
    /// * `OPENAI_BASE_URL` — optional base URL override.
    ///
    /// Values are trimmed before use, so stray whitespace or a trailing newline
    /// around a value does not leak into the `Authorization` header or the
    /// request URL. Blank optional values are ignored.
    ///
    /// # Errors
    ///
    /// Returns [`RustAgentsError::Validation`] when `OPENAI_API_KEY` is unset,
    /// not valid Unicode, or blank.
    pub fn from_env() -> Result<Self> {
        let api_key = Self::non_blank_env("OPENAI_API_KEY").ok_or_else(|| {
            RustAgentsError::Validation(
                "OPENAI_API_KEY is not set; export it or add it to a .env file".to_string(),
            )
        })?;
        let mut model = Self::new(api_key);
        if let Some(name) = Self::non_blank_env("OPENAI_MODEL") {
            model = model.with_model(name);
        }
        if let Some(url) = Self::non_blank_env("OPENAI_BASE_URL") {
            model = model.with_base_url(url);
        }
        Ok(model)
    }

    /// Builds a client for an arbitrary endpoint from an explicit key, base URL,
    /// and model name. All provider presets below delegate to this.
    pub fn compatible(
        api_key: impl Into<String>,
        base_url: impl Into<String>,
        model: impl Into<String>,
    ) -> Self {
        Self::new(api_key).with_base_url(base_url).with_model(model)
    }

    /// Preset: `https://api.deepseek.com/v1` with model `deepseek-chat`.
    pub fn deepseek(api_key: impl Into<String>) -> Self {
        Self::compatible(api_key, "https://api.deepseek.com/v1", "deepseek-chat")
    }

    /// Preset: `https://api.anthropic.com/v1` with model `claude-3-5-sonnet-latest`.
    pub fn anthropic(api_key: impl Into<String>) -> Self {
        Self::compatible(
            api_key,
            "https://api.anthropic.com/v1",
            "claude-3-5-sonnet-latest",
        )
    }

    /// Preset: `https://api.groq.com/openai/v1` with model `llama-3.3-70b-versatile`.
    pub fn groq(api_key: impl Into<String>) -> Self {
        Self::compatible(
            api_key,
            "https://api.groq.com/openai/v1",
            "llama-3.3-70b-versatile",
        )
    }

    /// Preset: `https://api.x.ai/v1` with model `grok-2-latest`.
    pub fn xai(api_key: impl Into<String>) -> Self {
        Self::compatible(api_key, "https://api.x.ai/v1", "grok-2-latest")
    }

    /// Preset: `https://openrouter.ai/api/v1` with model `openai/gpt-4o-mini`.
    pub fn openrouter(api_key: impl Into<String>) -> Self {
        Self::compatible(
            api_key,
            "https://openrouter.ai/api/v1",
            "openai/gpt-4o-mini",
        )
    }

    /// Preset: `https://api.together.xyz/v1` with model
    /// `meta-llama/Llama-3.3-70B-Instruct-Turbo`.
    pub fn together(api_key: impl Into<String>) -> Self {
        Self::compatible(
            api_key,
            "https://api.together.xyz/v1",
            "meta-llama/Llama-3.3-70B-Instruct-Turbo",
        )
    }

    /// Preset: `https://api.mistral.ai/v1` with model `mistral-small-latest`.
    pub fn mistral(api_key: impl Into<String>) -> Self {
        Self::compatible(api_key, "https://api.mistral.ai/v1", "mistral-small-latest")
    }

    /// Preset for a local server at `http://localhost:11434/v1` with model
    /// `llama3.2`; the fixed string `"ollama"` is sent as the API key.
    pub fn ollama() -> Self {
        Self::compatible("ollama", "http://localhost:11434/v1", "llama3.2")
    }

    /// Returns the configured default model name.
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Returns the configured base URL (without a trailing slash when it was set
    /// through `with_base_url`).
    pub fn base_url(&self) -> &str {
        &self.base_url
    }

    /// Converts a harness [`ModelRequest`] into the chat-completions wire request.
    ///
    /// * Every message is translated with `translate_message`; the first failure
    ///   aborts the whole translation.
    /// * `tool_choice` is sent only when at least one tool is declared.
    /// * A model named on the request takes precedence over the client's default.
    ///
    /// # Errors
    ///
    /// Propagates any error produced while translating a message.
    fn translate_request(&self, request: &ModelRequest) -> Result<ChatCompletionRequest> {
        let messages = request
            .messages
            .iter()
            .map(translate_message)
            .collect::<Result<Vec<_>>>()?;
        let tools: Vec<ToolWire> = request
            .tools
            .iter()
            .map(|schema| ToolWire {
                kind: "function".to_string(),
                function: FunctionSchemaWire {
                    name: schema.name.clone(),
                    description: schema.description.clone(),
                    parameters: schema.parameters.clone(),
                },
            })
            .collect();
        // A tool choice is meaningless without tools, so it is left out entirely.
        let tool_choice = if tools.is_empty() {
            None
        } else {
            Some(translate_tool_choice(&request.tool_choice))
        };
        let response_format = request
            .response_format
            .as_ref()
            .and_then(translate_response_format);
        Ok(ChatCompletionRequest {
            model: request.model.clone().unwrap_or_else(|| self.model.clone()),
            messages,
            tools,
            tool_choice,
            response_format,
            temperature: request.temperature,
            max_tokens: request.max_tokens,
        })
    }
}
/// Maps one harness [`Message`] onto its chat-completions wire representation.
///
/// System, user, and tool messages carry only their text (tool messages also
/// carry the id of the call they answer). Assistant messages carry their tool
/// calls with JSON-encoded arguments, and omit `content` when there is no text
/// but at least one tool call.
///
/// # Errors
///
/// Fails when a tool call's arguments cannot be serialized to a JSON string.
fn translate_message(message: &Message) -> Result<ChatMessageWire> {
    // Shared shape for every role that is just "role + text".
    let text_only = |role: &str| ChatMessageWire {
        role: role.to_string(),
        content: Some(message.text()),
        tool_calls: Vec::new(),
        tool_call_id: None,
    };

    let wire = match message {
        Message::System(_) => text_only("system"),
        Message::User(_) => text_only("user"),
        Message::Tool(tool) => ChatMessageWire {
            tool_call_id: Some(tool.tool_call_id.clone()),
            ..text_only("tool")
        },
        Message::Assistant(assistant) => {
            let text = message.text();

            let mut tool_calls = Vec::new();
            for call in assistant.tool_calls.iter() {
                let arguments = serde_json::to_string(&call.arguments)?;
                tool_calls.push(ToolCallWire {
                    id: call.id.clone(),
                    kind: "function".to_string(),
                    function: FunctionCallWire {
                        name: call.name.clone(),
                        arguments,
                    },
                });
            }

            // Keep the text unless it is empty AND the message is a pure tool-call turn.
            let content = if !text.is_empty() || assistant.tool_calls.is_empty() {
                Some(text)
            } else {
                None
            };

            ChatMessageWire {
                role: "assistant".to_string(),
                content,
                tool_calls,
                tool_call_id: None,
            }
        }
    };
    Ok(wire)
}
fn translate_tool_choice(choice: &ToolChoice) -> Value {
match choice {
ToolChoice::Auto => json!("auto"),
ToolChoice::None => json!("none"),
ToolChoice::Required => json!("required"),
ToolChoice::Tool(name) => json!({
"type": "function",
"function": { "name": name }
}),
}
}
/// Encodes a [`ResponseFormat`] as the JSON value for the `response_format` field.
///
/// Returns `None` for plain text so the field can be left out of the request.
/// JSON-schema formats are always sent with `"strict": true`.
fn translate_response_format(format: &ResponseFormat) -> Option<Value> {
    let wire = match format {
        ResponseFormat::Text => return None,
        ResponseFormat::JsonObject => json!({ "type": "json_object" }),
        ResponseFormat::JsonSchema { name, schema } => {
            let spec = json!({
                "name": name,
                "schema": schema,
                "strict": true,
            });
            json!({
                "type": "json_schema",
                "json_schema": spec,
            })
        }
    };
    Some(wire)
}
/// Turns a raw chat-completions JSON body into a [`ModelResponse`].
///
/// Only the first choice is used. Non-empty text becomes a single text content
/// block; each tool call's argument string is parsed as JSON, falling back to
/// `Value::Null` when it does not parse. The untouched input value is kept in
/// `raw`.
///
/// # Errors
///
/// Fails when the body does not deserialize into the expected response shape,
/// or when it contains no choices.
fn parse_response(value: Value) -> Result<ModelResponse> {
    // Deserialize from a copy: the original value is returned to the caller as `raw`.
    let parsed: ChatCompletionResponse = serde_json::from_value(value.clone())?;

    let Some(choice) = parsed.choices.into_iter().next() else {
        return Err(RustAgentsError::Model(
            "openai response contained no choices".to_string(),
        ));
    };
    let reply = choice.message;

    let content = match reply.content {
        Some(text) if !text.is_empty() => vec![ContentBlock::Text(text)],
        _ => Vec::new(),
    };

    let tool_calls = reply
        .tool_calls
        .into_iter()
        .map(|call| {
            // Best effort: argument text that is not valid JSON is represented as null.
            let arguments = match serde_json::from_str(&call.function.arguments) {
                Ok(parsed_args) => parsed_args,
                Err(_) => Value::Null,
            };
            ToolCall {
                id: call.id,
                name: call.function.name,
                arguments,
            }
        })
        .collect();

    let usage = parsed.usage.map(|reported| {
        let cache_read_tokens = reported
            .prompt_tokens_details
            .map_or(0, |details| details.cached_tokens);
        Usage {
            input_tokens: reported.prompt_tokens,
            output_tokens: reported.completion_tokens,
            total_tokens: reported.total_tokens,
            cache_read_tokens,
            ..Usage::default()
        }
    });

    // The same usage figures are attached to both the message and the response.
    let message = AssistantMessage {
        id: parsed.id,
        content,
        tool_calls,
        usage,
    };

    Ok(ModelResponse {
        message,
        usage,
        finish_reason: choice.finish_reason,
        raw: Some(value),
        resolved_model: None,
    })
}
#[async_trait]
impl<State: Send + Sync> ChatModel<State> for OpenAiModel {
    /// Sends `request` to `{base_url}/chat/completions` and parses the reply.
    ///
    /// The state argument is not used by this implementation.
    ///
    /// # Errors
    ///
    /// * Any error from translating the request.
    /// * [`RustAgentsError::Model`] when the request cannot be sent, the body
    ///   cannot be read, or the server answers with a non-success status (the
    ///   status and body text are included in the message).
    /// * Any error from decoding the body as JSON or from `parse_response`.
    async fn invoke(&self, _state: &State, request: ModelRequest) -> Result<ModelResponse> {
        let body = self.translate_request(&request)?;
        let url = format!("{}/chat/completions", self.base_url);
        let response = self
            .client
            .post(&url)
            // `bearer_auth` emits the same `Authorization: Bearer …` header but
            // flags the value as sensitive, keeping the key out of header debug output.
            .bearer_auth(&self.api_key)
            .json(&body)
            .send()
            .await
            .map_err(|e| RustAgentsError::Model(format!("openai request to {url} failed: {e}")))?;
        let status = response.status();
        // Read the body before checking the status so error replies can be reported verbatim.
        let text = response.text().await.map_err(|e| {
            RustAgentsError::Model(format!("openai response body read failed: {e}"))
        })?;
        if !status.is_success() {
            return Err(RustAgentsError::Model(format!(
                "openai returned HTTP {status}: {text}"
            )));
        }
        let value: Value = serde_json::from_str(&text)?;
        parse_response(value)
    }
}
#[cfg(test)]
mod test;