use reqwest::blocking::Client;
use serde_json::{json, Value};
use crate::{
EmbeddingModel, EmbeddingResult, Error, FinishReason, LanguageModel, ModelMessage,
ModelRequest, ModelResponse, Part, ProviderRegistration, Result, Role, ToolChoice,
ToolDefinition, ToolSchema, Usage,
};
/// Handle for a single OpenAI chat model.
///
/// The only state is the model identifier; the [`LanguageModel`] impl in this
/// file forwards it in the `"model"` field of each chat-completions request.
#[derive(Clone, Debug)]
pub struct OpenAiLanguageModel {
/// Model identifier sent verbatim to the API.
model_id: String,
}
/// Handle for a single OpenAI embedding model.
///
/// The only state is the model identifier; the [`EmbeddingModel`] impl in this
/// file forwards it in the `"model"` field of each embeddings request.
#[derive(Clone, Debug)]
pub struct OpenAiEmbeddingModel {
/// Model identifier sent verbatim to the API.
model_id: String,
}
impl LanguageModel for OpenAiLanguageModel {
fn model_id(&self) -> &str {
&self.model_id
}
fn generate(&self, request: &ModelRequest) -> Result<ModelResponse> {
let (status, body) = openai_post_json(
"https://api.openai.com/v1/chat/completions",
openai_chat_request(&self.model_id, request),
)?;
if !(200..300).contains(&status) {
return Err(Error::Api(openai_error_message(&body)));
}
openai_chat_response_to_model_response(&self.model_id, &body)
}
}
impl EmbeddingModel for OpenAiEmbeddingModel {
fn model_id(&self) -> &str {
&self.model_id
}
fn embed(&self, value: &str) -> Result<EmbeddingResult> {
let (status, body) = openai_post_json(
"https://api.openai.com/v1/embeddings",
json!({
"model": self.model_id,
"input": value,
}),
)?;
if !(200..300).contains(&status) {
return Err(Error::Api(openai_error_message(&body)));
}
let embedding = body
.get("data")
.and_then(Value::as_array)
.and_then(|items| items.first())
.and_then(|item| item.get("embedding"))
.and_then(Value::as_array)
.ok_or_else(|| Error::Parse("missing embedding data".to_string()))?
.iter()
.map(|value| {
value
.as_f64()
.map(|value| value as f32)
.ok_or_else(|| Error::Parse("invalid embedding value".to_string()))
})
.collect::<Result<Vec<_>>>()?;
Ok(EmbeddingResult {
embedding,
usage: openai_usage(&body),
})
}
}
fn openai_language_model(model_id: &str) -> Result<Box<dyn LanguageModel>> {
if model_id.is_empty() {
return Err(Error::UnsupportedModel("openai/".to_string()));
}
Ok(Box::new(OpenAiLanguageModel {
model_id: model_id.to_string(),
}))
}
fn openai_embedding_model(model_id: &str) -> Result<Box<dyn EmbeddingModel>> {
if model_id.is_empty() {
return Err(Error::UnsupportedModel("openai/".to_string()));
}
Ok(Box::new(OpenAiEmbeddingModel {
model_id: model_id.to_string(),
}) as Box<dyn EmbeddingModel>)
}
// Submits this provider's registration record through the `inventory` crate.
// The record carries the id "openai" and the two factory functions defined
// above for language and embedding models.
// NOTE(review): how the registry resolves "openai/<model>" strings is not
// visible in this file; confirm against the code that collects these records.
inventory::submit! {
ProviderRegistration {
id: "openai",
language_model: openai_language_model,
embedding_model: openai_embedding_model,
}
}
fn openai_api_key() -> Result<String> {
std::env::var("OPENAI_API_KEY").map_err(|_| Error::MissingEnvironmentVariable("OPENAI_API_KEY"))
}
fn openai_post_json(url: &'static str, body: Value) -> Result<(u16, Value)> {
let api_key = openai_api_key()?;
std::thread::spawn(move || {
let response = Client::builder()
.build()
.map_err(|error| Error::Http(error.to_string()))?
.post(url)
.bearer_auth(api_key)
.json(&body)
.send()
.map_err(|error| Error::Http(error.to_string()))?;
let status = response.status().as_u16();
let body = response
.json()
.map_err(|error| Error::Json(error.to_string()))?;
Ok((status, body))
})
.join()
.map_err(|_| Error::Http("openai request thread panicked".to_string()))?
}
fn openai_chat_request(model_id: &str, request: &ModelRequest) -> Value {
let mut body = json!({
"model": model_id,
"messages": openai_messages(&request.messages),
});
if let Some(temperature) = request.settings.temperature {
body["temperature"] = json!(temperature);
}
if let Some(max_tokens) = request.settings.max_output_tokens {
body["max_tokens"] = json!(max_tokens);
}
if !request.tools.is_empty() {
body["tools"] = Value::Array(request.tools.iter().map(openai_tool_definition).collect());
body["tool_choice"] = openai_tool_choice(&request.tool_choice);
}
body
}
fn openai_messages(messages: &[ModelMessage]) -> Vec<Value> {
messages
.iter()
.map(|message| match message.role {
Role::System => json!({ "role": "system", "content": message.text() }),
Role::User => json!({ "role": "user", "content": message.text() }),
Role::Assistant => {
let content = message.text();
let tool_calls = message
.parts
.iter()
.filter_map(|part| match part {
Part::ToolCall(call) => Some(json!({
"id": call.id,
"type": "function",
"function": {
"name": call.name,
"arguments": call.input,
}
})),
_ => None,
})
.collect::<Vec<_>>();
if tool_calls.is_empty() {
json!({ "role": "assistant", "content": content })
} else {
json!({
"role": "assistant",
"content": if content.is_empty() { Value::Null } else { Value::String(content) },
"tool_calls": tool_calls,
})
}
}
Role::Tool => {
let result = message.parts.iter().find_map(|part| match part {
Part::ToolResult(result) => Some(result),
_ => None,
});
match result {
Some(result) => json!({
"role": "tool",
"tool_call_id": result.call_id,
"content": result.output,
}),
None => json!({ "role": "tool", "content": message.text() }),
}
}
})
.collect()
}
/// Wraps a tool definition in OpenAI's `{"type": "function", "function": …}`
/// envelope, with the input schema rendered by `tool_schema_json`.
fn openai_tool_definition(tool: &ToolDefinition) -> Value {
    let function = json!({
        "name": tool.name,
        "description": tool.description,
        "parameters": tool_schema_json(&tool.input_schema),
    });
    json!({
        "type": "function",
        "function": function,
    })
}
fn openai_tool_choice(tool_choice: &ToolChoice) -> Value {
match tool_choice {
ToolChoice::Auto => json!("auto"),
ToolChoice::None => json!("none"),
ToolChoice::Required(name) => json!({
"type": "function",
"function": { "name": name }
}),
}
}
/// Renders a [`ToolSchema`] as a JSON-Schema-style value.
///
/// Scalars map to `{"type": …}`; arrays add a recursively rendered `items`;
/// objects list every field under `properties`, collect the names of required
/// fields under `required`, and set `additionalProperties` to `false`.
/// A schema-level description, when present, is attached as `description`;
/// for object fields, the field's own description (when present) overwrites
/// the one produced by the field's schema.
pub(crate) fn tool_schema_json(schema: &ToolSchema) -> Value {
    let (shape, description) = match schema {
        ToolSchema::String { description } => (json!({ "type": "string" }), description),
        ToolSchema::Integer { description } => (json!({ "type": "integer" }), description),
        ToolSchema::Number { description } => (json!({ "type": "number" }), description),
        ToolSchema::Boolean { description } => (json!({ "type": "boolean" }), description),
        ToolSchema::Array { description, items } => (
            json!({ "type": "array", "items": tool_schema_json(items) }),
            description,
        ),
        ToolSchema::Object(object) => {
            let mut properties = serde_json::Map::new();
            let mut required = Vec::new();
            for field in object.fields.iter() {
                let mut field_schema = tool_schema_json(&field.schema);
                if let Some(text) = &field.description {
                    field_schema["description"] = json!(text);
                }
                properties.insert(field.name.clone(), field_schema);
                if field.required {
                    required.push(Value::String(field.name.clone()));
                }
            }
            (
                json!({
                    "type": "object",
                    "properties": properties,
                    "required": required,
                    "additionalProperties": false,
                }),
                &object.description,
            )
        }
    };
    json_with_description(shape, description)
}
fn json_with_description(mut value: Value, description: &Option<String>) -> Value {
if let Some(description) = description {
value["description"] = json!(description);
}
value
}
fn openai_chat_response_to_model_response(model_id: &str, body: &Value) -> Result<ModelResponse> {
let choice = body
.get("choices")
.and_then(Value::as_array)
.and_then(|choices| choices.first())
.ok_or_else(|| Error::Parse("missing choice".to_string()))?;
let message = choice
.get("message")
.ok_or_else(|| Error::Parse("missing message".to_string()))?;
let mut parts = Vec::new();
if let Some(content) = message.get("content") {
let text = openai_text_content(content);
if !text.is_empty() {
parts.push(Part::Text(text));
}
}
if let Some(tool_calls) = message.get("tool_calls").and_then(Value::as_array) {
for tool_call in tool_calls {
let id = tool_call
.get("id")
.and_then(Value::as_str)
.ok_or_else(|| Error::Parse("missing tool call id".to_string()))?;
let function = tool_call
.get("function")
.ok_or_else(|| Error::Parse("missing tool call function".to_string()))?;
let name = function
.get("name")
.and_then(Value::as_str)
.ok_or_else(|| Error::Parse("missing tool call name".to_string()))?;
let input = function
.get("arguments")
.and_then(Value::as_str)
.ok_or_else(|| Error::Parse("missing tool call arguments".to_string()))?;
parts.push(Part::ToolCall(crate::ToolCall {
id: id.to_string(),
name: name.to_string(),
input: input.to_string(),
}));
}
}
Ok(ModelResponse {
parts,
finish_reason: openai_finish_reason(choice.get("finish_reason").and_then(Value::as_str)),
usage: openai_usage(body),
response_metadata: crate::metadata_with_provider("openai", model_id),
})
}
/// Extracts plain text from a message `content` value.
///
/// A JSON string is returned as-is; a JSON array yields the concatenation of
/// every element's string `text` field (elements without one are skipped);
/// any other value yields an empty string.
fn openai_text_content(content: &Value) -> String {
    if let Some(text) = content.as_str() {
        return text.to_owned();
    }

    let mut joined = String::new();
    if let Some(parts) = content.as_array() {
        for part in parts {
            if let Some(piece) = part.get("text").and_then(Value::as_str) {
                joined.push_str(piece);
            }
        }
    }
    joined
}
/// Maps OpenAI's `finish_reason` string onto [`FinishReason`].
///
/// A missing reason is treated like `"stop"`; any value other than `"stop"`,
/// `"length"` or `"tool_calls"` maps to `FinishReason::Error`.
fn openai_finish_reason(reason: Option<&str>) -> FinishReason {
    match reason.unwrap_or("stop") {
        "stop" => FinishReason::Stop,
        "length" => FinishReason::Length,
        "tool_calls" => FinishReason::ToolCalls,
        _ => FinishReason::Error,
    }
}
/// Reads token counts from the `usage` object of a response body.
///
/// `prompt_tokens` becomes `input_tokens` and `completion_tokens` becomes
/// `output_tokens`; a count that is missing or not an unsigned integer is
/// reported as `0`.
pub(crate) fn openai_usage(body: &Value) -> Usage {
    let count = |key: &str| -> usize {
        body.get("usage")
            .and_then(|usage| usage.get(key))
            .and_then(Value::as_u64)
            .unwrap_or(0) as usize
    };

    Usage {
        input_tokens: count("prompt_tokens"),
        output_tokens: count("completion_tokens"),
    }
}
/// Extracts `error.message` from an error response body, falling back to
/// `"unknown OpenAI error"` when that string field is absent.
fn openai_error_message(body: &Value) -> String {
    let message = body
        .get("error")
        .and_then(|error| error.get("message"))
        .and_then(Value::as_str);
    match message {
        Some(text) => text.to_owned(),
        None => "unknown OpenAI error".to_owned(),
    }
}