mod types;
pub use types::*;
use std::collections::VecDeque;
use std::pin::Pin;
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use serde_json::{Value, json};
use crate::error::{Result, TinyAgentsError};
use crate::harness::message::{AssistantMessage, ContentBlock, Message, MessageDelta};
use crate::harness::model::{
ChatModel, Modalities, ModelProfile, ModelRequest, ModelResponse, ModelStatus, ModelStream,
ModelStreamItem, ProviderError, ResponseFormat, ToolChoice,
};
use crate::harness::tool::{ToolCall, ToolDelta};
use crate::harness::usage::Usage;
use super::ProviderSpec;
/// Root of the official OpenAI REST API; compatible providers replace it via `with_base_url`.
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
/// Model id used by `OpenAiModel::new` until a caller chooses another one.
const DEFAULT_MODEL: &str = "gpt-4.1-mini";
/// Chat model backed by an OpenAI-style `/chat/completions` endpoint.
///
/// The same client talks to OpenAI itself and to any compatible provider; `provider` is a
/// label that feeds the capability profile, error messages and the `provider()` accessor.
pub struct OpenAiModel {
    /// Label for the backing service, e.g. `"openai"` or `"groq"`.
    provider: String,
    /// Model id sent when a request does not name its own.
    model: String,
    /// API root without a trailing slash; `/chat/completions` is appended per request.
    base_url: String,
    /// Credential sent as a `Bearer` token in the `Authorization` header.
    api_key: String,
    /// HTTP client used for every request.
    client: reqwest::Client,
    /// Capability profile derived from `provider` and `model`.
    profile: ModelProfile,
}
/// Builds the capability profile advertised for `model` served by `provider`.
///
/// Capabilities are inferred case-insensitively from the model id:
/// - ids starting with `o1`, `o3` or `o4` are treated as reasoning models;
/// - reasoning models, plus ids containing `gpt-4o` or `gpt-4.1`, get native structured output.
///
/// Router-style ids that carry a vendor prefix (e.g. `openai/o3-mini`) are also matched on
/// their final `/`-separated segment, so prefixed o-series ids are recognised too.
fn derive_profile(provider: &str, model: &str) -> ModelProfile {
    let lower = model.to_ascii_lowercase();
    // `rsplit` always yields at least one segment; the fallback only satisfies the type.
    let family = lower.rsplit('/').next().unwrap_or(lower.as_str());
    let is_o_series = |id: &str| {
        ["o1", "o3", "o4"]
            .iter()
            .any(|&prefix| id.starts_with(prefix))
    };
    // Checking the full id as well keeps every previously-recognised id recognised.
    let reasoning = is_o_series(&lower) || is_o_series(family);
    let native_structured =
        reasoning || lower.contains("gpt-4o") || lower.contains("gpt-4.1");
    ModelProfile {
        provider: Some(provider.to_string()),
        model: Some(model.to_string()),
        status: ModelStatus::Stable,
        modalities: Modalities {
            image_in: true,
            ..Modalities::default()
        },
        tool_calling: true,
        parallel_tool_calls: true,
        streaming: true,
        streaming_tool_chunks: true,
        native_structured_output: native_structured,
        json_schema: true,
        reasoning,
        ..ModelProfile::default()
    }
}
impl OpenAiModel {
    /// Creates a client for the official OpenAI API using the default model.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            client: reqwest::Client::new(),
            api_key: api_key.into(),
            model: DEFAULT_MODEL.to_string(),
            provider: "openai".to_string(),
            base_url: DEFAULT_BASE_URL.to_string(),
            profile: derive_profile("openai", DEFAULT_MODEL),
        }
    }

    /// Sets the default model id and re-derives the capability profile.
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self.profile = derive_profile(&self.provider, &self.model);
        self
    }

    /// Sets the provider label and re-derives the capability profile.
    pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
        self.provider = provider.into();
        self.profile = derive_profile(&self.provider, &self.model);
        self
    }

    /// Sets the API root. Trailing slashes are stripped so `/chat/completions` can be
    /// appended without producing `//`.
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into().trim_end_matches('/').to_string();
        self
    }

    /// Builds a client from `OPENAI_API_KEY`. Non-blank `OPENAI_MODEL` and
    /// `OPENAI_BASE_URL` values override the default model and API root.
    ///
    /// # Errors
    /// Returns [`TinyAgentsError::Validation`] when `OPENAI_API_KEY` is unset or blank.
    pub fn from_env() -> Result<Self> {
        let api_key = std::env::var("OPENAI_API_KEY")
            .ok()
            .filter(|k| !k.trim().is_empty())
            .ok_or_else(|| {
                TinyAgentsError::Validation(
                    "OPENAI_API_KEY is not set; export it or add it to a .env file".to_string(),
                )
            })?;
        let mut model = Self::new(api_key);
        if let Ok(name) = std::env::var("OPENAI_MODEL")
            && !name.trim().is_empty()
        {
            model = model.with_model(name);
        }
        if let Ok(url) = std::env::var("OPENAI_BASE_URL")
            && !url.trim().is_empty()
        {
            model = model.with_base_url(url);
        }
        Ok(model)
    }

    /// Builds a client from a [`ProviderSpec`] and an explicit API key.
    ///
    /// # Errors
    /// Returns [`TinyAgentsError::Validation`] when the spec's `model` or `base_url` is blank.
    pub fn from_spec(spec: ProviderSpec, api_key: impl Into<String>) -> Result<Self> {
        if spec.model.trim().is_empty() {
            return Err(TinyAgentsError::Validation(
                "provider spec model must not be empty".to_string(),
            ));
        }
        if spec.base_url.trim().is_empty() {
            return Err(TinyAgentsError::Validation(
                "provider spec base_url must not be empty".to_string(),
            ));
        }
        Ok(Self::compatible_provider(
            spec.provider,
            api_key,
            spec.base_url,
            spec.model,
        ))
    }

    /// Like [`Self::from_spec`], but takes the API key from the environment.
    ///
    /// When `spec.requires_api_key` is set the key is read from the variable named by
    /// `spec.api_key_env`; otherwise the placeholder key `"local"` is used.
    ///
    /// # Errors
    /// Returns [`TinyAgentsError::Validation`] when a required key variable is not named,
    /// is unset or blank, or when [`Self::from_spec`] rejects the spec.
    pub fn from_spec_env(spec: ProviderSpec) -> Result<Self> {
        let api_key = if spec.requires_api_key {
            let env = spec.api_key_env.as_deref().ok_or_else(|| {
                TinyAgentsError::Validation(format!(
                    "{} requires an api_key_env in ProviderSpec",
                    spec.provider
                ))
            })?;
            std::env::var(env)
                .ok()
                .filter(|k| !k.trim().is_empty())
                .ok_or_else(|| {
                    TinyAgentsError::Validation(format!(
                        "{env} is not set; export it or provide an explicit API key"
                    ))
                })?
        } else {
            "local".to_string()
        };
        Self::from_spec(spec, api_key)
    }

    /// Client for an OpenAI-compatible endpoint that keeps the `"openai"` provider label.
    pub fn compatible(
        api_key: impl Into<String>,
        base_url: impl Into<String>,
        model: impl Into<String>,
    ) -> Self {
        Self::new(api_key).with_base_url(base_url).with_model(model)
    }

    /// Client for an OpenAI-compatible endpoint under its own provider label.
    pub fn compatible_provider(
        provider: impl Into<String>,
        api_key: impl Into<String>,
        base_url: impl Into<String>,
        model: impl Into<String>,
    ) -> Self {
        Self::new(api_key)
            .with_provider(provider)
            .with_base_url(base_url)
            .with_model(model)
    }

    /// DeepSeek (`https://api.deepseek.com/v1`) with `deepseek-chat`.
    pub fn deepseek(api_key: impl Into<String>) -> Self {
        Self::compatible_provider(
            "deepseek",
            api_key,
            "https://api.deepseek.com/v1",
            "deepseek-chat",
        )
    }

    /// Anthropic via `https://api.anthropic.com/v1` using the OpenAI wire format.
    pub fn anthropic(api_key: impl Into<String>) -> Self {
        Self::compatible_provider(
            "anthropic",
            api_key,
            "https://api.anthropic.com/v1",
            "claude-3-5-sonnet-latest",
        )
    }

    /// Groq (`https://api.groq.com/openai/v1`) with `llama-3.3-70b-versatile`.
    pub fn groq(api_key: impl Into<String>) -> Self {
        Self::compatible_provider(
            "groq",
            api_key,
            "https://api.groq.com/openai/v1",
            "llama-3.3-70b-versatile",
        )
    }

    /// xAI (`https://api.x.ai/v1`) with `grok-2-latest`.
    pub fn xai(api_key: impl Into<String>) -> Self {
        Self::compatible_provider("xai", api_key, "https://api.x.ai/v1", "grok-2-latest")
    }

    /// OpenRouter (`https://openrouter.ai/api/v1`) with `openai/gpt-4o-mini`.
    pub fn openrouter(api_key: impl Into<String>) -> Self {
        Self::compatible_provider(
            "openrouter",
            api_key,
            "https://openrouter.ai/api/v1",
            "openai/gpt-4o-mini",
        )
    }

    /// Together (`https://api.together.xyz/v1`) with Llama 3.3 70B Instruct Turbo.
    pub fn together(api_key: impl Into<String>) -> Self {
        Self::compatible_provider(
            "together",
            api_key,
            "https://api.together.xyz/v1",
            "meta-llama/Llama-3.3-70B-Instruct-Turbo",
        )
    }

    /// Mistral (`https://api.mistral.ai/v1`) with `mistral-small-latest`.
    pub fn mistral(api_key: impl Into<String>) -> Self {
        Self::compatible_provider(
            "mistral",
            api_key,
            "https://api.mistral.ai/v1",
            "mistral-small-latest",
        )
    }

    /// Local Ollama server at `http://localhost:11434/v1`; the API key is a placeholder.
    pub fn ollama() -> Self {
        Self::compatible_provider("ollama", "ollama", "http://localhost:11434/v1", "llama3.2")
    }

    /// Default model id used when a request does not name one.
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Provider label for this client.
    pub fn provider(&self) -> &str {
        &self.provider
    }

    /// API root, without a trailing slash.
    pub fn base_url(&self) -> &str {
        &self.base_url
    }

    /// Converts a harness [`ModelRequest`] into the non-streaming wire request.
    ///
    /// `tool_choice` is only sent when at least one tool is offered. The request's own
    /// `model` takes precedence over the client default.
    ///
    /// # Errors
    /// Fails if any message cannot be translated (see `translate_message`).
    // NOTE(review): the official API rejects `max_tokens` for o-series reasoning models
    // (they expect `max_completion_tokens`); confirm whether the wire type should map it.
    fn translate_request(&self, request: &ModelRequest) -> Result<ChatCompletionRequest> {
        let messages = request
            .messages
            .iter()
            .map(translate_message)
            .collect::<Result<Vec<_>>>()?;
        let tools: Vec<ToolWire> = request
            .tools
            .iter()
            .map(|schema| ToolWire {
                kind: "function".to_string(),
                function: FunctionSchemaWire {
                    name: schema.name.clone(),
                    description: schema.description.clone(),
                    parameters: schema.parameters.clone(),
                },
            })
            .collect();
        let tool_choice = if tools.is_empty() {
            None
        } else {
            Some(translate_tool_choice(&request.tool_choice))
        };
        let response_format = request
            .response_format
            .as_ref()
            .and_then(translate_response_format);
        Ok(ChatCompletionRequest {
            model: request.model.clone().unwrap_or_else(|| self.model.clone()),
            messages,
            tools,
            tool_choice,
            response_format,
            temperature: request.temperature,
            max_tokens: request.max_tokens,
            stream: false,
            stream_options: None,
        })
    }

    /// Builds a [`ProviderError`] tagged with this client's provider and default model.
    ///
    /// The error is retryable for HTTP 408, 409, 429 and any status >= 500; errors without
    /// a status are not retryable.
    fn provider_error(
        &self,
        message: impl Into<String>,
        status: Option<u16>,
        code: Option<String>,
        raw: Option<Value>,
    ) -> ProviderError {
        let retryable = status.is_some_and(|s| s == 408 || s == 409 || s == 429 || s >= 500);
        ProviderError {
            provider: self.provider.clone(),
            model: Some(self.model.clone()),
            status,
            code,
            message: message.into(),
            retryable,
            raw,
        }
    }

    /// Formats an error as `<provider> returned[ HTTP <status>][ (<code>)]: <message>`.
    fn provider_failure_message(&self, error: &ProviderError) -> String {
        format!(
            "{} returned{}{}: {}",
            error.provider,
            error
                .status
                .map(|status| format!(" HTTP {status}"))
                .unwrap_or_default(),
            error
                .code
                .as_deref()
                .map(|code| format!(" ({code})"))
                .unwrap_or_default(),
            error.message
        )
    }

    /// Extracts a [`ProviderError`] from the body of a non-success response.
    ///
    /// The message is the first non-blank value among `error.message`, a bare-string
    /// `error`, and a top-level `message`. If none is present, the raw body text is used,
    /// or a placeholder when the body itself is blank. The code prefers `error.code` and
    /// falls back to `error.type`. `null` or blank values are skipped so they cannot hide
    /// the fallback, and numeric codes are kept as their decimal text.
    fn parse_error_body(&self, status: u16, text: &str) -> ProviderError {
        fn non_blank(s: &&str) -> bool {
            !s.trim().is_empty()
        }
        let raw = serde_json::from_str::<Value>(text).ok();
        let error_field = raw.as_ref().and_then(|value| value.get("error"));
        let error_obj = error_field.filter(|error| error.is_object());
        let message = error_obj
            .and_then(|error| error.get("message"))
            .and_then(Value::as_str)
            .filter(non_blank)
            // Some compatible servers reply with `{"error": "..."}` (a bare string).
            .or_else(|| error_field.and_then(Value::as_str).filter(non_blank))
            .or_else(|| {
                raw.as_ref()
                    .and_then(|value| value.get("message"))
                    .and_then(Value::as_str)
                    .filter(non_blank)
            })
            .map(str::to_string)
            .unwrap_or_else(|| {
                if text.trim().is_empty() {
                    "empty error response body".to_string()
                } else {
                    text.to_string()
                }
            });
        let code = error_obj.and_then(|error| {
            ["code", "type"]
                .iter()
                .find_map(|key| match error.get(*key)? {
                    Value::String(s) if !s.trim().is_empty() => Some(s.clone()),
                    Value::Number(n) => Some(n.to_string()),
                    _ => None,
                })
        });
        self.provider_error(message, Some(status), code, raw)
    }
}
/// Converts one harness [`Message`] into its chat-completions wire form.
///
/// System, user and tool messages carry their flattened text. Assistant messages also
/// carry their tool calls, with arguments re-serialised to JSON strings. Their `content`
/// is omitted only when the text is empty and at least one tool call is present.
///
/// # Errors
/// Fails if a tool call's arguments cannot be serialised to JSON.
fn translate_message(message: &Message) -> Result<ChatMessageWire> {
    let text_only = |role: &str, tool_call_id: Option<String>| ChatMessageWire {
        role: role.to_string(),
        content: Some(message.text()),
        tool_calls: Vec::new(),
        tool_call_id,
    };
    let assistant = match message {
        Message::System(_) => return Ok(text_only("system", None)),
        Message::User(_) => return Ok(text_only("user", None)),
        Message::Tool(tool) => return Ok(text_only("tool", Some(tool.tool_call_id.clone()))),
        Message::Assistant(assistant) => assistant,
    };

    let text = message.text();
    let mut tool_calls = Vec::with_capacity(assistant.tool_calls.len());
    for call in &assistant.tool_calls {
        let arguments = serde_json::to_string(&call.arguments)?;
        tool_calls.push(ToolCallWire {
            id: call.id.clone(),
            kind: "function".to_string(),
            function: FunctionCallWire {
                name: call.name.clone(),
                arguments,
            },
        });
    }
    // Send the text (even if empty) unless the turn consists purely of tool calls.
    let content = (!text.is_empty() || tool_calls.is_empty()).then_some(text);
    Ok(ChatMessageWire {
        role: "assistant".to_string(),
        content,
        tool_calls,
        tool_call_id: None,
    })
}
fn translate_tool_choice(choice: &ToolChoice) -> Value {
match choice {
ToolChoice::Auto => json!("auto"),
ToolChoice::None => json!("none"),
ToolChoice::Required => json!("required"),
ToolChoice::Tool(name) => json!({
"type": "function",
"function": { "name": name }
}),
}
}
/// Maps a [`ResponseFormat`] onto the chat-completions `response_format` field.
///
/// Plain text needs no field (`None`). `JsonObject` requests generic JSON mode. Both
/// `JsonSchema` and `Auto` request a strict `json_schema` response built from the supplied
/// name and schema.
fn translate_response_format(format: &ResponseFormat) -> Option<Value> {
    let (name, schema) = match format {
        ResponseFormat::Text => return None,
        ResponseFormat::JsonObject => return Some(json!({ "type": "json_object" })),
        ResponseFormat::JsonSchema { name, schema } | ResponseFormat::Auto { name, schema } => {
            (name, schema)
        }
    };
    let json_schema = json!({
        "name": name,
        "schema": schema,
        "strict": true,
    });
    Some(json!({
        "type": "json_schema",
        "json_schema": json_schema,
    }))
}
fn parse_response(value: Value) -> Result<ModelResponse> {
let parsed: ChatCompletionResponse = serde_json::from_value(value.clone())?;
let choice = parsed.choices.into_iter().next().ok_or_else(|| {
TinyAgentsError::Model("openai response contained no choices".to_string())
})?;
let mut content = Vec::new();
if let Some(text) = choice.message.content.filter(|t| !t.is_empty()) {
content.push(ContentBlock::Text(text));
}
let tool_calls = choice
.message
.tool_calls
.into_iter()
.map(|call| ToolCall {
id: call.id,
name: call.function.name,
arguments: serde_json::from_str(&call.function.arguments).unwrap_or(Value::Null),
})
.collect();
let usage = parsed.usage.map(convert_usage);
let message = AssistantMessage {
id: parsed.id,
content,
tool_calls,
usage,
};
Ok(ModelResponse {
message,
usage,
finish_reason: choice.finish_reason,
raw: Some(value),
resolved_model: None,
})
}
/// Converts the wire `usage` object into harness [`Usage`] counters.
///
/// Cached prompt tokens, when the provider reports them, become `cache_read_tokens`
/// (otherwise 0). All counters not set here keep their `Usage::default()` value.
fn convert_usage(wire: UsageWire) -> Usage {
    let cache_read_tokens = wire
        .prompt_tokens_details
        .map_or(0, |details| details.cached_tokens);
    Usage {
        input_tokens: wire.prompt_tokens,
        output_tokens: wire.completion_tokens,
        total_tokens: wire.total_tokens,
        cache_read_tokens,
        ..Usage::default()
    }
}
/// A tool call being reassembled from streamed fragments.
///
/// Fragments that carry an id or a name overwrite the stored value. Argument pieces are
/// appended in arrival order and decoded as JSON only when the stream completes.
#[derive(Clone, Debug, Default)]
struct ToolCallBuild {
    /// Provider-assigned call id; empty until a fragment supplies one.
    id: String,
    /// Function name; empty until a fragment supplies one.
    name: String,
    /// Concatenated argument text received so far.
    args: String,
}
/// Running state of one streamed chat completion.
///
/// Text and tool-call fragments are accumulated as chunks arrive, so the final
/// [`ModelResponse`] can be built once the stream ends.
#[derive(Clone, Debug, Default)]
struct OpenAiStreamAcc {
    /// Completion id taken from the first chunk that carries one.
    id: Option<String>,
    /// All text deltas concatenated in arrival order.
    text: String,
    /// Tool calls, one slot per fragment index reported by the provider.
    tool_calls: Vec<ToolCallBuild>,
    /// Most recent usage report, if any chunk carried one.
    usage: Option<Usage>,
    /// Latest finish reason reported by any choice.
    finish_reason: Option<String>,
}
impl OpenAiStreamAcc {
fn ingest(&mut self, chunk: ChatCompletionChunk, pending: &mut VecDeque<ModelStreamItem>) {
if let Some(id) = chunk.id
&& self.id.is_none()
{
self.id = Some(id);
}
if let Some(usage_wire) = chunk.usage {
let usage = convert_usage(usage_wire);
self.usage = Some(usage);
pending.push_back(ModelStreamItem::UsageDelta(usage));
}
for choice in chunk.choices {
if let Some(reason) = choice.finish_reason {
self.finish_reason = Some(reason);
}
if let Some(content) = choice.delta.content.filter(|c| !c.is_empty()) {
self.text.push_str(&content);
pending.push_back(ModelStreamItem::MessageDelta(MessageDelta {
text: content,
tool_call: None,
}));
}
for fragment in choice.delta.tool_calls {
let idx = fragment.index as usize;
while self.tool_calls.len() <= idx {
self.tool_calls.push(ToolCallBuild::default());
}
let slot = &mut self.tool_calls[idx];
if let Some(id) = fragment.id.filter(|id| !id.is_empty()) {
slot.id = id;
}
if let Some(function) = fragment.function {
if let Some(name) = function.name.filter(|n| !n.is_empty()) {
slot.name = name;
}
if let Some(args) = function.arguments.filter(|a| !a.is_empty()) {
slot.args.push_str(&args);
let call_id = if slot.id.is_empty() {
format!("tool-{idx}")
} else {
slot.id.clone()
};
pending.push_back(ModelStreamItem::ToolCallDelta(ToolDelta {
call_id,
content: args,
}));
}
}
}
}
}
fn into_response(self) -> ModelResponse {
let mut content = Vec::new();
if !self.text.is_empty() {
content.push(ContentBlock::Text(self.text));
}
let tool_calls = self
.tool_calls
.into_iter()
.filter(|b| !b.name.is_empty() || !b.args.is_empty())
.enumerate()
.map(|(idx, b)| ToolCall {
id: if b.id.is_empty() {
format!("tool-{idx}")
} else {
b.id
},
name: b.name,
arguments: serde_json::from_str(&b.args).unwrap_or(Value::Null),
})
.collect();
let message = AssistantMessage {
id: self.id,
content,
tool_calls,
usage: self.usage,
};
ModelResponse {
message,
usage: self.usage,
finish_reason: self.finish_reason,
raw: None,
resolved_model: None,
}
}
}
/// State threaded through [`sse_next`] while decoding a server-sent-events response body.
struct SseState {
    /// Raw response body chunks.
    bytes: Pin<Box<dyn Stream<Item = Result<Vec<u8>>> + Send>>,
    /// Bytes received but not yet split into complete lines.
    ///
    /// Kept as raw bytes because a network chunk can end in the middle of a multi-byte
    /// UTF-8 character. Text is only decoded once a whole `\n`-terminated line is present.
    buf: Vec<u8>,
    /// Items decoded but not yet yielded.
    pending: VecDeque<ModelStreamItem>,
    /// Response accumulated from the decoded chunks.
    acc: OpenAiStreamAcc,
    /// Provider label reported in stream-level errors.
    provider: String,
    /// Model id reported in stream-level errors.
    model: String,
    /// Whether `Started` has been yielded.
    started: bool,
    /// Whether reading has stopped (`[DONE]`, end of body, or a transport error).
    finished: bool,
    /// Whether the terminal item (`Completed` or `ProviderFailed`) has been yielded.
    terminal_emitted: bool,
}

impl SseState {
    /// Parses every complete line in `buf`, leaving a trailing partial line buffered.
    ///
    /// Only `data:` lines are considered. `data: [DONE]` marks the stream finished. Other
    /// payloads that parse as a `ChatCompletionChunk` are fed to the accumulator, and
    /// anything else is skipped.
    fn drain_lines(&mut self) {
        // The `\n` byte never occurs inside a multi-byte UTF-8 sequence, so every drained
        // line holds only complete characters.
        while let Some(pos) = self.buf.iter().position(|&b| b == b'\n') {
            let raw: Vec<u8> = self.buf.drain(..=pos).collect();
            let line = String::from_utf8_lossy(&raw);
            let line = line.trim();
            if line.is_empty() {
                continue;
            }
            let Some(rest) = line.strip_prefix("data:") else {
                continue;
            };
            let payload = rest.trim();
            if payload == "[DONE]" {
                self.finished = true;
                continue;
            }
            if let Ok(chunk) = serde_json::from_str::<ChatCompletionChunk>(payload) {
                self.acc.ingest(chunk, &mut self.pending);
            }
        }
    }

    /// Processes a final line that the server did not terminate with a newline.
    fn flush_trailing_line(&mut self) {
        if !self.buf.is_empty() {
            self.buf.push(b'\n');
            self.drain_lines();
        }
    }
}

/// Produces the next stream item, reading more body bytes only when nothing is queued.
///
/// The items come in this order:
/// - `Started`;
/// - the decoded deltas;
/// - exactly one terminal item: `Completed` with the accumulated response, or
///   `ProviderFailed` if reading the body fails.
///
/// Returns `None` after the terminal item.
async fn sse_next(mut state: SseState) -> Option<(ModelStreamItem, SseState)> {
    loop {
        if let Some(item) = state.pending.pop_front() {
            return Some((item, state));
        }
        if !state.started {
            state.started = true;
            return Some((ModelStreamItem::Started, state));
        }
        if state.finished {
            if state.terminal_emitted {
                return None;
            }
            state.terminal_emitted = true;
            let response = std::mem::take(&mut state.acc).into_response();
            return Some((ModelStreamItem::Completed(response), state));
        }
        match state.bytes.next().await {
            Some(Ok(chunk)) => {
                state.buf.extend_from_slice(&chunk);
                state.drain_lines();
            }
            Some(Err(error)) => {
                state.finished = true;
                state.terminal_emitted = true;
                let provider_error = ProviderError {
                    provider: state.provider.clone(),
                    model: Some(state.model.clone()),
                    message: error.to_string(),
                    retryable: true,
                    ..ProviderError::default()
                };
                return Some((ModelStreamItem::ProviderFailed(provider_error), state));
            }
            None => {
                // The body may end without a newline after its last event; don't lose it.
                state.flush_trailing_line();
                state.finished = true;
            }
        }
    }
}

#[async_trait]
impl<State: Send + Sync> ChatModel<State> for OpenAiModel {
    /// Returns the profile derived from the configured provider and model.
    fn profile(&self) -> Option<&ModelProfile> {
        Some(&self.profile)
    }

    /// Sends a single non-streaming `POST {base_url}/chat/completions` request.
    ///
    /// # Errors
    /// Transport failures, non-success statuses and unreadable bodies become
    /// [`TinyAgentsError::Model`]. JSON decoding failures are converted with `?`.
    async fn invoke(&self, _state: &State, request: ModelRequest) -> Result<ModelResponse> {
        let body = self.translate_request(&request)?;
        let url = format!("{}/chat/completions", self.base_url);
        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&body)
            .send()
            .await
            .map_err(|e| {
                let error =
                    self.provider_error(format!("request to {url} failed: {e}"), None, None, None);
                TinyAgentsError::Model(self.provider_failure_message(&error))
            })?;
        let status = response.status();
        let text = response.text().await.map_err(|e| {
            // Name the actual provider rather than always saying "openai".
            TinyAgentsError::Model(format!(
                "{} response body read failed: {e}",
                self.provider
            ))
        })?;
        if !status.is_success() {
            let error = self.parse_error_body(status.as_u16(), &text);
            return Err(TinyAgentsError::Model(
                self.provider_failure_message(&error),
            ));
        }
        let value: Value = serde_json::from_str(&text)?;
        parse_response(value)
    }

    /// Opens a streaming request (with `stream_options.include_usage` set) and returns
    /// the decoded item stream.
    ///
    /// # Errors
    /// Transport failures and non-success statuses become [`TinyAgentsError::Model`].
    /// Failures after the stream has started are yielded as `ProviderFailed` items.
    async fn stream(&self, _state: &State, request: ModelRequest) -> Result<ModelStream> {
        let mut body = self.translate_request(&request)?;
        body.stream = true;
        body.stream_options = Some(json!({ "include_usage": true }));
        let url = format!("{}/chat/completions", self.base_url);
        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&body)
            .send()
            .await
            .map_err(|e| {
                let error = self.provider_error(
                    format!("stream request to {url} failed: {e}"),
                    None,
                    None,
                    None,
                );
                TinyAgentsError::Model(self.provider_failure_message(&error))
            })?;
        let status = response.status();
        if !status.is_success() {
            let text = response.text().await.unwrap_or_default();
            let error = self.parse_error_body(status.as_u16(), &text);
            return Err(TinyAgentsError::Model(
                self.provider_failure_message(&error),
            ));
        }
        let bytes = response.bytes_stream().map(|chunk| {
            chunk
                .map(|b| b.to_vec())
                .map_err(|e| TinyAgentsError::Model(format!("stream chunk failed: {e}")))
        });
        let state = SseState {
            bytes: Box::pin(bytes),
            buf: Vec::new(),
            pending: VecDeque::new(),
            acc: OpenAiStreamAcc::default(),
            provider: self.provider.clone(),
            model: self.model.clone(),
            started: false,
            finished: false,
            terminal_emitted: false,
        };
        Ok(Box::pin(futures::stream::unfold(state, sse_next)))
    }
}
#[cfg(test)]
mod test;