use std::env;
use openai_api_rs::v1::{api::OpenAIClient, chat_completion};
use serde_json::{json, Value};
use crate::{datamodel::DataModel, json::export::to_json_schema};
/// Asks an OpenAI chat model for a structured JSON answer.
///
/// `pre_prompt` and `prompt` are sent, in that order, as two user messages.
/// The response format is the JSON schema derived from `data_model` for the
/// object named `root`; with `multiple` set, the schema is wrapped so the
/// model returns an object holding an `items` array.
///
/// The API key is taken from `api_key` when given, otherwise from the
/// `OPENAI_API_KEY` environment variable.
///
/// # Errors
///
/// Returns an error when the schema cannot be generated, the client cannot be
/// built, the API call fails, the first choice carries no content, or that
/// content is not valid JSON.
pub async fn query_openai(
    prompt: &str,
    pre_prompt: &str,
    data_model: &DataModel,
    root: &str,
    model: &str,
    multiple: bool,
    api_key: Option<String>,
) -> Result<Value, Box<dyn std::error::Error>> {
    let response_format = prepare_response_format(data_model, root, multiple)?;
    let mut client = prepare_client(api_key)?;

    let mut messages = Vec::with_capacity(2);
    messages.push(create_chat_message(pre_prompt));
    messages.push(create_chat_message(prompt));

    let model_type = ModelType::from_str(model)?;

    // Both model types share the same base request; only generation models
    // additionally get a temperature of zero.
    let base_request = chat_completion::ChatCompletionRequest::new(model.to_string(), messages)
        .response_format(response_format);
    let request = match model_type {
        ModelType::Reasoning => base_request,
        ModelType::Generation => base_request.temperature(0.0),
    };

    let response = client.chat_completion(request).await?;

    // Only the first choice is inspected.
    let first_choice = response.choices.first();
    let content = match first_choice.and_then(|choice| choice.message.content.as_ref()) {
        Some(text) => text,
        None => return Err(format!("No content in response from {model}").into()),
    };

    let parsed: Value = serde_json::from_str(content)?;
    Ok(parsed)
}
/// Builds the `response_format` value for a structured-output request.
///
/// The JSON schema for `root` is generated from `model`. When `multiple` is
/// `false` the schema is used as-is. When it is `true` the schema becomes the
/// element type of a required `items` array inside a wrapper object, and the
/// schema's definitions are moved to the wrapper's `$defs`.
///
/// # Errors
///
/// Propagates any error returned by `to_json_schema`.
fn prepare_response_format(
    model: &DataModel,
    root: &str,
    multiple: bool,
) -> Result<Value, Box<dyn std::error::Error>> {
    let mut schema = to_json_schema(model, root, true)?;

    // Single-object case: hand the generated schema over unchanged.
    if !multiple {
        return Ok(json!({
            "type": "json_schema",
            "json_schema": {
                "name": root,
                "strict": true,
                "schema": schema
            }
        }));
    }

    // Copy the definitions out and empty them on the element schema before it
    // is serialized, so they are attached to the wrapper object instead.
    let definitions = schema.definitions.clone();
    schema.definitions.clear();

    let wrapper_schema = json!({
        "type": "object",
        "additionalProperties": false,
        "required": ["items"],
        "properties": {
            "items": {
                "type": "array",
                "items": schema
            }
        },
        "$defs": definitions
    });

    Ok(json!({
        "type": "json_schema",
        "json_schema": {
            "strict": true,
            "name": root,
            "schema": wrapper_schema
        }
    }))
}
/// Builds an OpenAI client from an explicit key or the environment.
///
/// `api_key` is used when it is `Some`; otherwise the key is read from the
/// `OPENAI_API_KEY` environment variable.
///
/// # Errors
///
/// Returns an error naming the environment variable when no key was passed
/// and the variable cannot be read, or whatever error the client builder
/// reports.
fn prepare_client(api_key: Option<String>) -> Result<OpenAIClient, Box<dyn std::error::Error>> {
    const API_KEY_VAR: &str = "OPENAI_API_KEY";

    let api_key = match api_key {
        Some(api_key) => api_key,
        // The bare `VarError` message does not say which variable is missing,
        // so add that context before propagating.
        None => env::var(API_KEY_VAR).map_err(|err| {
            format!(
                "no API key was provided and {} could not be read: {}",
                API_KEY_VAR, err
            )
        })?,
    };

    OpenAIClient::builder().with_api_key(api_key).build()
}
/// Wraps `content` in a plain-text chat message with the `user` role.
///
/// The name and tool-call fields are left unset.
fn create_chat_message(content: &str) -> chat_completion::ChatCompletionMessage {
    let role = chat_completion::MessageRole::user;
    let text = chat_completion::Content::Text(content.to_owned());

    chat_completion::ChatCompletionMessage {
        role,
        content: text,
        name: None,
        tool_calls: None,
        tool_call_id: None,
    }
}
/// Broad category of a chat model, used to decide which request options to send.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelType {
    /// Models whose name starts with `o` followed by a digit (`o1`, `o3-mini`, ...).
    Reasoning,
    /// Every other model name.
    Generation,
}

impl ModelType {
    /// Classifies a model by its name.
    ///
    /// A name is `Reasoning` only when it begins with the letter `o`
    /// immediately followed by an ASCII digit. The check is anchored at the
    /// start of the name on purpose: an unanchored `o\d*` pattern also matches
    /// names that merely contain an `o`, such as `gpt-4o` or `gpt-4-turbo`.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept so callers can keep using `?`.
    fn from_str(s: &str) -> Result<Self, String> {
        let mut chars = s.chars();
        let starts_with_o = chars.next() == Some('o');
        let followed_by_digit = matches!(chars.next(), Some(c) if c.is_ascii_digit());

        Ok(if starts_with_o && followed_by_digit {
            ModelType::Reasoning
        } else {
            ModelType::Generation
        })
    }
}