pub mod messages;
use super::{
messages::AbstractMessage, LLMProvider, LLMToolUsage, MultiModelLLMProvider,
StructuredLLMProvider, Tool, ToolChoice, Toolkit,
};
use anyhow::Result;
use log::{debug, info, warn};
use messages::OpenAIMessage;
use reqwest::blocking::Client;
use schemars::{
schema::{ObjectValidation, RootSchema, Schema},
schema_for, JsonSchema,
};
use serde::{Deserialize, Serialize};
/// Blocking client for the OpenAI chat-completions HTTP API.
pub struct OpenAIClient {
    // Bearer token sent in the `Authorization` header of every request.
    api_key: String,
    // Reusable blocking HTTP client shared across requests.
    client: Client,
    // Model targeted by tool-use requests; see `MultiModelLLMProvider::with_model`.
    model: OpenAIModel,
}
/// Models this client can target; each variant serializes to the OpenAI
/// wire name via its `serde(rename)` attribute.
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
pub enum OpenAIModel {
    #[serde(rename = "gpt-4o")]
    Gpt4o,
    #[serde(rename = "o1-preview")]
    O1Preview,
}
/// JSON body for a plain (tool-free) `/v1/chat/completions` request.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionRequest {
    // Model the completion is requested from.
    model: OpenAIModel,
    // Full conversation history sent to the API.
    messages: Vec<OpenAIMessage>,
}
impl CompletionRequest {
fn body(model: OpenAIModel, messages: Vec<OpenAIMessage>) -> Self {
Self { model, messages }
}
}
/// One candidate completion inside a `CompletionResponse`.
#[derive(Debug, Deserialize)]
pub struct CompletionChoice {
    // Reason generation stopped, as reported by the API.
    finish_reason: String,
    // Position of this choice within the response.
    index: u64,
    // The generated message itself; callers read this field.
    message: OpenAIMessage,
}
/// Top-level response payload from `/v1/chat/completions`.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    // Response identifier assigned by the API.
    id: String,
    // Object-type tag from the API (deserialized, not read here).
    object: String,
    // Creation timestamp as reported by the API.
    created: u64,
    // Candidate completions; this client reads the first entry.
    choices: Vec<CompletionChoice>,
}
fn set_additional_properties_false(root_schema: &mut RootSchema) {
if root_schema.schema.object.is_none() {
root_schema.schema.object = Some(Box::new(ObjectValidation::default()));
}
root_schema
.schema
.object
.as_mut()
.unwrap()
.additional_properties = Some(Box::new(Schema::Bool(false)));
if let Some(props) = &mut root_schema.schema.object {
for schema in props.properties.values_mut() {
if let Schema::Object(obj) = schema {
if obj.object.is_none() {
obj.object = Some(Box::new(ObjectValidation::default()));
}
obj.object.as_mut().unwrap().additional_properties =
Some(Box::new(Schema::Bool(false)));
}
}
}
for schema in root_schema.definitions.values_mut() {
if let Schema::Object(obj) = schema {
if obj.object.is_none() {
obj.object = Some(Box::new(ObjectValidation::default()));
}
obj.object.as_mut().unwrap().additional_properties =
Some(Box::new(Schema::Bool(false)));
}
}
}
impl LLMProvider<OpenAIMessage> for OpenAIClient {
    /// Sends the conversation to `/v1/chat/completions` and returns the
    /// history with the model's first-choice reply appended.
    ///
    /// # Errors
    /// Fails on transport errors, non-2xx responses, or a response
    /// containing no choices.
    fn get_completion(&self, messages: Vec<OpenAIMessage>) -> Result<Vec<OpenAIMessage>> {
        debug!(
            "Getting completion from OpenAI with {} messages",
            messages.len()
        );
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            "Authorization",
            format!("Bearer {}", self.api_key)
                .parse()
                .expect("Invalid API key"),
        );
        headers.insert(
            "Content-Type",
            "application/json".parse().expect("Invalid content type"),
        );
        // Use the configured model — this was hard-coded to gpt-4o, which
        // silently defeated `MultiModelLLMProvider::with_model`. Moving
        // `messages` into the body (instead of cloning) also avoids copying
        // the whole history; it is reclaimed from `request_body` below.
        let request_body = CompletionRequest::body(self.model, messages);
        debug!("Sending request to OpenAI API");
        let result = self
            .client
            .post("https://api.openai.com/v1/chat/completions")
            .headers(headers)
            .json(&request_body)
            .send()?;
        if !result.status().is_success() {
            let status = result.status();
            let error_text = result.text()?;
            warn!("OpenAI API error: {} - {}", status, error_text);
            return Err(anyhow::anyhow!(
                "Failed to get completion: {:?} {:?}",
                status,
                error_text
            ));
        }
        let completion_response: CompletionResponse = result.json()?;
        let last_message = completion_response
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices returned in the OpenAI response"))?;
        debug!("Last message: {:?}", last_message.message);
        Ok(request_body
            .messages
            .into_iter()
            .chain(std::iter::once(last_message.message.clone()))
            .collect())
    }

    /// Streaming completions are not implemented for this client yet.
    fn stream_completion(
        &self,
        messages: Vec<OpenAIMessage>,
    ) -> Result<Box<dyn Iterator<Item = OpenAIMessage>>> {
        todo!("Implement streaming for the OpenAI client")
    }
}
impl MultiModelLLMProvider<OpenAIModel> for OpenAIClient {
fn with_model(&self, model: OpenAIModel) -> Self {
Self {
api_key: self.api_key.clone(),
client: self.client.clone(),
model,
}
}
fn get_model(&self) -> OpenAIModel {
self.model
}
}
/// Builds the `Authorization` + `Content-Type` headers used by every
/// OpenAI request in this impl (previously duplicated in each method).
fn openai_headers(api_key: &str) -> reqwest::header::HeaderMap {
    let mut headers = reqwest::header::HeaderMap::new();
    headers.insert(
        "Authorization",
        format!("Bearer {}", api_key)
            .parse()
            .expect("Invalid API key"),
    );
    headers.insert(
        "Content-Type",
        "application/json".parse().expect("Invalid content type"),
    );
    headers
}

impl LLMToolUsage<OpenAIMessage> for OpenAIClient {
    /// Forces the model to call `tool` once and returns the tool's output
    /// as a single `Tool` message.
    ///
    /// # Errors
    /// Fails on transport/API errors, when the reply is not an assistant
    /// message carrying tool calls, or when the tool itself fails.
    fn do_work_with_tool(
        &self,
        messages: Vec<OpenAIMessage>,
        tool: &dyn Tool,
    ) -> Result<Vec<OpenAIMessage>> {
        debug!("Executing tool '{}' with OpenAI", tool.name());
        let request_body = serde_json::json!({
            "model": self.model,
            "messages": messages,
            "tools": [{
                "type": "function",
                "function": {
                    "name": tool.name(),
                    "description": tool.description(),
                    "parameters": tool.schema()
                }
            }],
            // Force the model to call exactly this tool.
            "tool_choice": {
                "type": "function",
                "function": { "name": tool.name() }
            }
        });
        debug!("Sending tool execution request to OpenAI API");
        let result = self
            .client
            .post("https://api.openai.com/v1/chat/completions")
            .headers(openai_headers(&self.api_key))
            .json(&request_body)
            .send()?;
        if !result.status().is_success() {
            let status = result.status();
            let error_text = result.text()?;
            warn!(
                "OpenAI API error during tool execution: {} - {}",
                status, error_text
            );
            return Err(anyhow::anyhow!("Failed to use tool: {}", error_text));
        }
        let response: CompletionResponse = result.json()?;
        // Was a stray `println!` debug leftover; route through the logger.
        debug!("Raw response from tool use ask: {:#?}", response);
        let message = response
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices returned in the OpenAI response"))?;
        debug!("Last message: {:?}", message.message);
        match &message.message {
            OpenAIMessage::Assistant {
                tool_calls: Some(tool_calls),
                ..
            } => {
                let tool_call = tool_calls
                    .first()
                    .ok_or_else(|| anyhow::anyhow!("No tool calls in assistant message"))?;
                let args = serde_json::from_str(&tool_call.function.arguments)?;
                let result = tool.execute(args)?;
                Ok(vec![OpenAIMessage::Tool {
                    content: serde_json::to_string(&result)?,
                    tool_call_id: tool_call.id.clone(),
                }])
            }
            _ => Err(anyhow::anyhow!(
                "Expected assistant message with tool calls"
            )),
        }
    }

    /// Sends the conversation with every tool in `tool_kit`, mapping
    /// `force_tool_use` onto OpenAI's `tool_choice` wire format, and
    /// returns the history with the assistant reply appended.
    ///
    /// # Errors
    /// Fails on transport/API errors or a response with no choices.
    fn get_chat_with_tools(
        &self,
        messages: Vec<OpenAIMessage>,
        tool_kit: &Toolkit,
        force_tool_use: &ToolChoice,
    ) -> Result<Vec<OpenAIMessage>> {
        debug!("Messages: {:?}", messages);
        let tool_defs: Vec<serde_json::Value> = tool_kit
            .tools()
            .iter()
            .map(|tool| {
                serde_json::json!({
                    "type": "function",
                    "function": {
                        "name": tool.name(),
                        "description": tool.description(),
                        "parameters": tool.schema()
                    }
                })
            })
            .collect();
        debug!("Tool definitions: {:?}", tool_defs);
        // Map our ToolChoice onto OpenAI's `tool_choice` values.
        let tool_choice = match force_tool_use {
            ToolChoice::Specific(name) => serde_json::json!({
                "type": "function",
                "function": {
                    "name": name
                }
            }),
            ToolChoice::Any => serde_json::json!("required"),
            ToolChoice::SelfSelect => serde_json::json!("auto"),
        };
        let request_body = serde_json::json!({
            "model": self.model,
            "messages": messages,
            "tools": tool_defs,
            "tool_choice": tool_choice
        });
        let result = self
            .client
            .post("https://api.openai.com/v1/chat/completions")
            .headers(openai_headers(&self.api_key))
            .json(&request_body)
            .send()?;
        if !result.status().is_success() {
            let status = result.status();
            let error_text = result.text()?;
            warn!(
                "OpenAI API error during chat with tools: {} - {}",
                status, error_text
            );
            return Err(anyhow::anyhow!("Failed to chat with tools: {}", error_text));
        }
        let response: CompletionResponse = result.json()?;
        let message = response
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices returned in the OpenAI response"))?;
        debug!("Last message: {:?}", message.message);
        Ok(messages
            .into_iter()
            .chain(std::iter::once(message.message.clone()))
            .collect())
    }

    /// Runs the tool-use loop: asks the model, executes any tool calls it
    /// makes, appends the results, and recurses until the reply carries no
    /// tool call.
    ///
    /// `Any` and `SelfSelect` share the loop and differ only in how a reply
    /// *without* tool calls is treated: `Any` makes it an error, while
    /// `SelfSelect` treats it as the finished conversation.
    ///
    /// # Errors
    /// Fails on transport/API errors, unknown tool names, tool execution
    /// failures, or (for `Any`) a reply without tool calls.
    fn get_work_result(
        &self,
        messages: Vec<OpenAIMessage>,
        tool_kit: &Toolkit,
        tool_choice: &ToolChoice,
    ) -> Result<Vec<OpenAIMessage>> {
        info!("Getting work result with tool choice: {:?}", tool_choice);
        match tool_choice {
            ToolChoice::Specific(name) => {
                debug!("Using specific tool: {}", name);
                self.do_work_with_tool(
                    messages,
                    tool_kit
                        .get(name)
                        .ok_or_else(|| anyhow::anyhow!("Tool not found: {}", name))?,
                )
            }
            ToolChoice::Any | ToolChoice::SelfSelect => {
                if matches!(tool_choice, ToolChoice::Any) {
                    debug!("Getting chat with any tool allowed");
                } else {
                    debug!("Letting model select tool usage");
                }
                let response = self.get_chat_with_tools(messages, tool_kit, tool_choice)?;
                debug!("Response from chat with tools: {:?}", response);
                // Copy out only the pending tool calls so the history can be
                // extended in place (the old code cloned the entire history
                // just to inspect its last element).
                let pending_calls = match response.last() {
                    Some(OpenAIMessage::Assistant {
                        tool_calls: Some(tool_calls),
                        ..
                    }) => Some(tool_calls.clone()),
                    _ => None,
                };
                match pending_calls {
                    Some(tool_calls) => {
                        let mut result_messages = response;
                        for tool_call in &tool_calls {
                            debug!("Processing tool call: {:?}", tool_call);
                            let tool =
                                tool_kit.get(&tool_call.function.name).ok_or_else(|| {
                                    anyhow::anyhow!(
                                        "Tool not found: {}",
                                        tool_call.function.name
                                    )
                                })?;
                            let args = serde_json::from_str(&tool_call.function.arguments)?;
                            let result = tool.execute(args)?;
                            result_messages.push(OpenAIMessage::Tool {
                                content: serde_json::to_string(&result)?,
                                tool_call_id: tool_call.id.clone(),
                            });
                        }
                        debug!("Result messages: {:?}", result_messages);
                        self.get_work_result(result_messages, tool_kit, tool_choice)
                    }
                    None if matches!(tool_choice, ToolChoice::Any) => {
                        Err(anyhow::anyhow!("No tool calls in assistant message"))
                    }
                    // SelfSelect: the model chose not to call a tool — done.
                    None => Ok(response),
                }
            }
        }
    }
}
impl StructuredLLMProvider<OpenAIMessage> for OpenAIClient {
fn get_structured_response<
DesiredSchema: Serialize + serde::de::DeserializeOwned + JsonSchema,
>(
&self,
messages: Vec<OpenAIMessage>,
) -> Result<DesiredSchema> {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Authorization",
format!("Bearer {}", self.api_key)
.parse()
.expect("Invalid API key"),
);
headers.insert(
"Content-Type",
"application/json".parse().expect("Invalid content type"),
);
let mut schema = schema_for!(DesiredSchema);
set_additional_properties_false(&mut schema);
println!("{}", serde_json::to_string(&schema).unwrap());
let request_body = serde_json::json!({
"model": OpenAIModel::Gpt4o,
"messages": messages,
"response_format": {
"type": "json_schema",
"json_schema": {
"name": "desired_schema",
"strict": true,
"schema": schema
}
}
});
let result = self
.client
.post("https://api.openai.com/v1/chat/completions")
.headers(headers)
.json(&request_body)
.send()?;
if !result.status().is_success() {
return Err(anyhow::anyhow!(
"Failed to get structured response: {:?} {:?}",
result.status(),
result.text()
));
}
let response: CompletionResponse = result.json()?;
let content = response.choices[0]
.message
.get_content()
.map_err(|_| anyhow::anyhow!("Failed to get message content"))?;
Ok(serde_json::from_str(&content)?)
}
}
impl Default for OpenAIClient {
fn default() -> Self {
Self::new()
}
}
impl OpenAIClient {
pub fn new() -> Self {
Self {
api_key: std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"),
client: Client::new(),
model: OpenAIModel::Gpt4o,
}
}
}