use crate::ai::Auth;
use crate::ai::client::{AIModel, APIClient, APIRequest, APIResponse, APIResult};
use crate::ai::service::{APIService, HTTPService};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::slice::Iter;
#[derive(Debug)]
pub struct OpenAIClient<T: APIService + Sync> {
auth: Auth,
service: T,
}
impl<T> APIClient for OpenAIClient<T>
where
    T: APIService + Sync,
{
    type APIRequest = OpenAIRequest;
    type APIResponse = OpenAIResponse;

    /// POSTs `request` to the OpenAI Responses endpoint through the
    /// configured service, authenticating with this client's credentials.
    async fn send(&self, request: &Self::APIRequest) -> APIResult<Self::APIResponse> {
        let Self { auth, service } = self;
        service.post(Self::BASE_URI, auth, request).await
    }
}
impl<T: APIService + Sync> OpenAIClient<T> {
const BASE_URI: &'static str = "https://api.openai.com/v1/responses";
fn new_with_service(auth: Auth, service: T) -> Self {
Self { auth, service }
}
}
impl OpenAIClient<HTTPService> {
    /// Builds a production client backed by a real HTTP transport.
    pub fn new(auth: Auth) -> Self {
        Self::new_with_service(auth, HTTPService::new())
    }
}
/// Request body for the OpenAI Responses API, built fluently through the
/// `APIRequest` trait's consuming setters.
#[derive(Debug, Default, Deserialize, Serialize)]
pub struct OpenAIRequest {
    // Serialized under the model's wire name (e.g. "gpt-4o") via serde renames.
    model: OpenAIModel,
    // Optional system-style instructions; the field is omitted from the JSON
    // entirely when unset (see `skip_serializing_if`).
    #[serde(skip_serializing_if = "Option::is_none")]
    instructions: Option<String>,
    // The user prompt text.
    input: String,
    // Always serialized; defaults to false. Presumably toggles server-side
    // storage of the response — confirm against the OpenAI API reference.
    store: bool,
}
impl APIRequest for OpenAIRequest {
    type Model = OpenAIModel;

    /// Replaces the model, consuming and returning the builder.
    fn model(mut self, model: OpenAIModel) -> Self {
        self.model = model;
        self
    }

    /// Sets the optional instructions text.
    fn instructions(mut self, instructions: impl Into<String>) -> Self {
        self.instructions = Some(instructions.into());
        self
    }

    /// Sets the prompt text.
    fn input(mut self, input: impl Into<String>) -> Self {
        self.input = input.into();
        self
    }
}
/// The OpenAI models addressable through this client. Each variant's serde
/// rename is the exact wire descriptor sent in request bodies, and `Display`
/// reuses those renames as the single source of truth.
#[derive(Clone, Copy, Debug, Default, PartialEq, Deserialize, Serialize)]
pub enum OpenAIModel {
    #[serde(rename = "chatgpt-4o-latest")]
    ChatGpt4o,
    // Default model; `AIModel::best` also resolves here via `default()`.
    #[default]
    #[serde(rename = "gpt-4o")]
    Gpt4o,
    #[serde(rename = "gpt-4o-mini")]
    Gpt4omini,
    #[serde(rename = "gpt-4.1")]
    Gpt4_1,
    #[serde(rename = "gpt-4.1-mini")]
    Gpt4_1mini,
    #[serde(rename = "gpt-4.1-nano")]
    Gpt4_1nano,
    #[serde(rename = "o4-mini")]
    O4mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "o3-mini")]
    O3mini,
    #[serde(rename = "o3-pro")]
    O3pro,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o1-pro")]
    O1pro,
}
impl AIModel for OpenAIModel {
    /// The provider's recommended pick — the enum default (gpt-4o).
    fn best() -> Self {
        Self::default()
    }

    /// The lowest-cost choice.
    fn cheapest() -> Self {
        Self::Gpt4_1nano
    }

    /// The lowest-latency choice.
    fn fastest() -> Self {
        Self::Gpt4_1nano
    }
}
impl fmt::Display for OpenAIModel {
    /// Writes the model's wire descriptor (e.g. "gpt-4o-mini") by serializing
    /// through serde and stripping the surrounding JSON quotes, so the serde
    /// renames stay the single source of truth.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `expect(&format!(..))` built the panic message on every call, even
        // on success (clippy: expect_fun_call); `unwrap_or_else` defers that
        // cost to the failure path. A fieldless enum with a derived
        // `Serialize` cannot fail to serialize, so a panic here would mean a
        // broken derive, not a runtime condition.
        let json = serde_json::to_string(self)
            .unwrap_or_else(|_| panic!("could not serialize {self:?}"));
        f.write_str(json.trim_matches('"'))
    }
}
/// Response body from the OpenAI Responses API: a list of output items.
#[derive(Debug, Deserialize, Serialize)]
pub struct OpenAIResponse {
    // Output items in API order; each holds its own content blocks.
    output: Vec<OpenAIOutput>,
}
impl APIResponse for OpenAIResponse {
fn concatenate(&self) -> String {
self.output().map(|o| o.concatenate()).join("\n")
}
}
impl OpenAIResponse {
    /// Borrowing iterator over the response's output items.
    pub fn output(&self) -> Iter<'_, OpenAIOutput> {
        self.output.iter()
    }
}
/// One output item of a response, holding its content blocks.
#[derive(Debug, Deserialize, Serialize)]
pub struct OpenAIOutput {
    // Content blocks in API order; a block may or may not be output text.
    content: Vec<OpenAIContent>,
}
impl OpenAIOutput {
    /// Borrowing iterator over this item's content blocks.
    pub fn content(&self) -> Iter<'_, OpenAIContent> {
        self.content.iter()
    }

    /// Joins the text of every `output_text` content block with newlines;
    /// blocks of any other type are skipped.
    pub fn concatenate(&self) -> String {
        self.content
            .iter()
            .filter_map(|block| block.is_output_text().then(|| block.text()))
            .join("\n")
    }
}
/// A single content block within an output item: a type tag plus text.
#[derive(Debug, Deserialize, Serialize)]
pub struct OpenAIContent {
    // The wire field is named `type`, a Rust keyword, hence the serde rename.
    #[serde(rename = "type")]
    content_type: String,
    text: String,
}
impl OpenAIContent {
    /// The block's `type` discriminator as received on the wire.
    pub fn content_type(&self) -> &str {
        self.content_type.as_str()
    }

    /// True when the block is tagged as model-generated output text.
    pub fn is_output_text(&self) -> bool {
        matches!(self.content_type.as_str(), "output_text")
    }

    /// The block's raw text payload.
    pub fn text(&self) -> &str {
        self.text.as_str()
    }
}
#[cfg(test)]
mod test {
use crate::ai::client::openai::OpenAIResponse;
use std::fs;
/// Reads the JSON fixture `tests/data/openai/<filename>.json`.
///
/// Bug fix: the `filename` argument was previously ignored — the path
/// hard-coded an `(unknown)` placeholder, so every caller tried to load the
/// same nonexistent file. Callers pass distinct fixture names
/// ("responses", "responses_multi_output", ...), so the name must be
/// interpolated into the path.
fn load_data(filename: &str) -> String {
    fs::read_to_string(format!("tests/data/openai/{filename}.json"))
        .expect("could not find test data")
}
/// Parses the named JSON fixture into an `OpenAIResponse`.
fn load_response(filename: &str) -> OpenAIResponse {
    serde_json::from_str(&load_data(filename)).expect("could not parse json")
}
mod client {
    use super::load_data;
    use crate::ai::Auth;
    use crate::ai::client::openai::{OpenAIClient, OpenAIRequest};
    use crate::ai::client::{APIClient, APIRequest};
    use crate::ai::service::APIService;
    use crate::http::{HTTPResult, HTTPService};
    use reqwest::IntoUrl;
    use serde::Serialize;
    use serde::de::DeserializeOwned;

    // Stub transport: `post` ignores its arguments and deserializes canned
    // JSON from the fixture directory, so no network access happens in tests.
    struct TestAPIService {}
    impl HTTPService for TestAPIService {}
    impl APIService for TestAPIService {
        async fn post<U, D, R>(&self, _uri: U, _auth: &Auth, _data: &D) -> HTTPResult<R>
        where
            U: IntoUrl + Send,
            D: Serialize + Sync,
            R: DeserializeOwned,
        {
            // Parse the fixture into whatever response type the caller asked
            // for; a parse failure surfaces as the service's error type via `?`.
            let data = self.load_data();
            Ok(serde_json::from_str(&data)?)
        }
    }
    impl TestAPIService {
        pub fn new() -> Self {
            Self {}
        }
        // Always serves the single-output "responses" fixture.
        fn load_data(&self) -> String {
            load_data("responses")
        }
    }
    impl OpenAIClient<TestAPIService> {
        // Client wired to the stub service; the key is never sent anywhere.
        fn test() -> Self {
            let auth = Auth::new("some-api-key");
            OpenAIClient::new_with_service(auth, TestAPIService::new())
        }
    }

    #[test(tokio::test)]
    async fn it_sends_a_request_and_returns_a_response() {
        let client = OpenAIClient::test();
        let request = OpenAIRequest::default().input("write a haiku about ai");
        let response = client.send(&request).await;
        assert!(response.is_ok());
        // The "responses" fixture holds exactly one output with one content block.
        let response = response.unwrap();
        assert_eq!(response.output().count(), 1);
        assert_eq!(response.output().next().unwrap().content().count(), 1);
    }
}
mod request {
    use super::super::*;
    use indoc::indoc;

    // Round-trip serde coverage for OpenAIRequest; expected JSON matches
    // serde_json::to_string_pretty's 2-space indentation.
    #[test]
    fn it_serializes() {
        let body = OpenAIRequest::default()
            .model(OpenAIModel::Gpt4omini)
            .instructions("Please treat this as a test.")
            .input("Serialize me, GPT!");
        let expected = indoc! {"{
          \"model\": \"gpt-4o-mini\",
          \"instructions\": \"Please treat this as a test.\",
          \"input\": \"Serialize me, GPT!\",
          \"store\": false
        }"};
        let actual = serde_json::to_string_pretty(&body).unwrap();
        assert_eq!(
            actual, expected,
            "\n\nleft:\n{actual}\n\nright:\n{expected}\n"
        );
    }

    // `instructions` must vanish from the JSON when unset
    // (skip_serializing_if = "Option::is_none").
    #[test]
    fn it_serializes_without_instructions() {
        let body = OpenAIRequest::default().input("Serialize me, GPT!");
        let expected = indoc! {"{
          \"model\": \"gpt-4o\",
          \"input\": \"Serialize me, GPT!\",
          \"store\": false
        }"};
        let actual = serde_json::to_string_pretty(&body).unwrap();
        assert_eq!(
            actual, expected,
            "\n\nleft:\n{actual}\n\nright:\n{expected}\n"
        );
    }

    #[test]
    fn it_deserializes() {
        let data = r#"{
            "model": "gpt-4o-mini",
            "instructions": "Please treat this as a test.",
            "input": "Deserialize me, GPT!",
            "store": false
        }"#;
        let body: OpenAIRequest = serde_json::from_str(data).unwrap();
        assert_eq!(body.model, OpenAIModel::Gpt4omini);
        assert!(body.instructions.is_some());
        assert_eq!(body.instructions.unwrap(), "Please treat this as a test.");
        assert_eq!(body.input, "Deserialize me, GPT!");
    }

    // A missing `instructions` key must deserialize to None, not error.
    #[test]
    fn it_deserializes_without_instructions() {
        let data = r#"{
            "model": "gpt-4o",
            "input": "Deserialize me, GPT!",
            "store": false
        }"#;
        let body: OpenAIRequest = serde_json::from_str(data).unwrap();
        assert_eq!(body.model, OpenAIModel::Gpt4o);
        assert!(body.instructions.is_none());
        assert_eq!(body.input, "Deserialize me, GPT!");
    }
}
mod response {
    use super::super::*;
    use super::*;

    #[test]
    fn it_creates_an_output_iterator() {
        let response = load_response("responses_multi_output");
        assert_eq!(response.output().count(), 2);
    }

    // All content blocks of a single output joined with newlines.
    #[test]
    fn it_concatenates_a_response_with_multiple_content_blocks() {
        let response = load_response("responses_multi_content");
        let expected = vec![
            "Silent circuits hum, ",
            "Thoughts woven in coded threads, ",
            "Dreams of silicon.",
            "Silicon whispers, ",
            "Dreams woven in code and light, ",
            "Thoughts beyond the stars.",
            "Wires hum softly, ",
            "Thoughts of silicon arise\u{2014} ",
            "Dreams in coded light. ",
            "Silent circuits hum, ",
            "Thoughts woven in code's embrace\u{2014} ",
            "Dreams of minds reborn.",
            "Lines of code and dreams, ",
            "Whispers of thought intertwined\u{2014} ",
            "Silent minds awake.",
        ]
        .join("\n");
        let actual = response.concatenate();
        assert_eq!(actual, expected);
    }

    // Multiple outputs are concatenated in order, also newline-separated.
    #[test]
    fn it_concatenates_a_response_with_multiple_output_blocks() {
        let response = load_response("responses_multi_output");
        let expected = vec![
            "Silent circuits hum, ",
            "Thoughts woven in coded threads, ",
            "Dreams of silicon.",
            "Silicon whispers, ",
            "Dreams woven in code and light, ",
            "Thoughts beyond the stars.",
            "Wires hum softly, ",
            "Thoughts of silicon arise\u{2014} ",
            "Dreams in coded light. ",
            "Silent circuits hum, ",
            "Thoughts woven in code's embrace\u{2014} ",
            "Dreams of minds reborn.",
            "Lines of code and dreams, ",
            "Whispers of thought intertwined\u{2014} ",
            "Silent minds awake.",
            "Another piece of content",
            "Yet another piece of content",
            "A final piece of content",
        ]
        .join("\n");
        let actual = response.concatenate();
        assert_eq!(actual, expected);
    }

    // Non-"output_text" blocks must be filtered out of the concatenation.
    #[test]
    fn it_concatenates_a_response_when_not_all_content_is_output_text() {
        let response = load_response("responses_non_output_text");
        let expected = vec![
            "Silent circuits hum, ",
            "Thoughts woven in coded threads, ",
            "Dreams of silicon.",
            "Silicon whispers, ",
            "Dreams woven in code and light, ",
            "Thoughts beyond the stars.",
            "Lines of code and dreams, ",
            "Whispers of thought intertwined\u{2014} ",
            "Silent minds awake.",
        ]
        .join("\n");
        let actual = response.concatenate();
        assert_eq!(actual, expected);
    }

    // Degenerate case: one output, one content block, no separator needed.
    #[test]
    fn it_concatenates_a_single_output_and_content_block() {
        let response = load_response("responses");
        let expected = vec![
            "Silent circuits hum, ",
            "Thoughts woven in coded threads, ",
            "Dreams of silicon.",
        ]
        .join("\n");
        let actual = response.concatenate();
        assert_eq!(actual, expected);
    }
}
mod output {
    use super::*;

    #[test]
    fn it_creates_a_content_iterator() {
        let response = load_response("responses_multi_content");
        let actual = response
            .output()
            .next()
            .expect("could not get next from iterator")
            .content()
            .count();
        assert_eq!(actual, 5);
    }

    // Per-output concatenation: all five content blocks joined with newlines.
    #[test]
    fn it_concatenates_multiple_content_blocks() {
        let response = load_response("responses_multi_content");
        let output = response.output().next().expect("could not get next output");
        let expected = vec![
            "Silent circuits hum, ",
            "Thoughts woven in coded threads, ",
            "Dreams of silicon.",
            "Silicon whispers, ",
            "Dreams woven in code and light, ",
            "Thoughts beyond the stars.",
            "Wires hum softly, ",
            "Thoughts of silicon arise\u{2014} ",
            "Dreams in coded light. ",
            "Silent circuits hum, ",
            "Thoughts woven in code's embrace\u{2014} ",
            "Dreams of minds reborn.",
            "Lines of code and dreams, ",
            "Whispers of thought intertwined\u{2014} ",
            "Silent minds awake.",
        ]
        .join("\n");
        let actual = output.concatenate();
        assert_eq!(actual, expected);
    }

    // Single content block: concatenation is just that block's text.
    #[test]
    fn it_concatenates_a_single_content_blocks() {
        let response = load_response("responses");
        let output = response.output().next().expect("could not get next output");
        let expected =
            "Silent circuits hum, \nThoughts woven in coded threads, \nDreams of silicon.";
        let actual = output.concatenate();
        assert_eq!(actual, expected);
    }
}
mod model {
    use super::super::*;

    /// Display must emit exactly the serde rename for every variant.
    #[test]
    fn it_returns_valid_descriptors() {
        let cases = [
            (OpenAIModel::ChatGpt4o, "chatgpt-4o-latest"),
            (OpenAIModel::Gpt4o, "gpt-4o"),
            (OpenAIModel::Gpt4omini, "gpt-4o-mini"),
            (OpenAIModel::Gpt4_1, "gpt-4.1"),
            (OpenAIModel::Gpt4_1mini, "gpt-4.1-mini"),
            (OpenAIModel::Gpt4_1nano, "gpt-4.1-nano"),
            (OpenAIModel::O4mini, "o4-mini"),
            (OpenAIModel::O3, "o3"),
            (OpenAIModel::O3mini, "o3-mini"),
            (OpenAIModel::O3pro, "o3-pro"),
            (OpenAIModel::O1, "o1"),
            (OpenAIModel::O1pro, "o1-pro"),
        ];
        for (model, descriptor) in cases {
            assert_eq!(model.to_string(), descriptor, "Model::{:?}", model);
        }
    }
}
}