1pub mod chat;
2pub mod completion;
3pub mod edit;
4pub mod image;
5pub mod files;
6pub mod embeddings;
7pub mod fine_tunes;
8pub mod moderations;
9pub mod audio;
10pub mod model;
11mod conversions;
12
13use anyhow::Result;
14use std::io;
15use std::path::PathBuf;
16use std::pin::Pin;
17use async_trait::async_trait;
18use bytes::Bytes;
19use futures_util::TryFutureExt;
20use reqwest::{Body, Client, multipart, RequestBuilder, Response};
21use reqwest::multipart::Part;
22use serde::de::DeserializeOwned;
23use tokio::fs::File;
24use tokio::io::AsyncWriteExt;
25use tokio::try_join;
26use tokio_stream::{Stream, StreamExt};
27use tokio_util::codec::{BytesCodec, FramedRead};
28use with_id::WithRefId;
29use std::fmt::{Debug, Display, Formatter};
30use serde::{Serialize, Deserialize};
31use crate::conversions::AsyncTryInto;
32
33
34#[derive(Debug, Clone)]
44pub struct OpenAiClient {
45 url:String,
46 key:String,
47 client:Client
48}
49
50impl OpenAiClient {
51
52 const URL: &'static str = "https://api.openai.com/v1";
53
54 pub fn new(key: &str)->Self{
57 let client = Client::new();
58 OpenAiClient::with_client(key,&client)
59 }
60
61 pub fn with_client(key: &str, client: &Client)->Self{
66 OpenAiClient::with_url_and_client(key,OpenAiClient::URL,client)
67 }
68
69
70 pub fn with_url(key: &str, url: &str) -> Self {
72 let client = Client::new();
73 OpenAiClient::with_url_and_client(key,url,&client)
74 }
75
76
77 pub fn with_url_and_client(key: &str, url: &str, client: &Client)->Self{
80 OpenAiClient {
81 url: url.to_string(),
82 key: key.to_string(),
83 client: client.clone()
84 }
85 }
86}
87
88#[derive(Debug)]
91pub struct Error{
92 pub(crate) response:ApiError,
93 pub(crate) inner:reqwest::Error
94}
95
96impl Display for Error {
97 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
98 write!(f,"{}",self.response)
99 }
100}
101
102impl std::error::Error for Error {
103 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
104 Some(&self.inner)
105 }
106}
107
108
109#[derive(Serialize, Deserialize, Debug, Clone)]
111pub struct ApiError {
112 pub error: ApiErrorDetails
113}
114
115#[derive(Serialize, Deserialize, Debug, Clone)]
116#[serde(rename(serialize = "error"))]
117#[serde(rename(deserialize = "error"))]
118pub struct ApiErrorDetails {
119 pub message: String,
120 #[serde(rename = "type")]
121 pub kind: String,
122 pub param: Option<String>,
123 pub code: Option<String>
124}
125
126impl Display for ApiError{
127 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
128 match &self.error.param {
129 None => match &self.error.code {
130 None => write!(f,"{}",self.error.message),
131 Some(code) => write!(f,"{}, code:{}",self.error.message,code)
132 }
133 Some(param) => match &self.error.code {
134 None => write!(f,"{}, param:{}",self.error.message,param),
135 Some(code) => write!(f,"{}, param:{}, code: {}",self.error.message,param,code)
136 }
137 }
138 }
139}
140
141#[derive(Clone, Serialize, Deserialize, Debug)]
144#[serde(untagged)]
145pub enum Input {
146 String(String),
147 StringArray(Vec<String>)
148}
149
150impl From<String> for Input{
151 fn from(value:String) -> Self {
152 Input::String(value)
153 }
154}
155
156impl From<&str> for Input{
157 fn from(value:&str) -> Self {
158 Input::String(value.to_string())
159 }
160}
161
162impl From<Vec<String>> for Input{
163 fn from(value: Vec<String>) -> Self {
164 Input::StringArray(value)
165 }
166}
167
168#[derive(Serialize, Deserialize, Debug, Clone)]
170pub struct DeleteResponse {
171 pub id: String,
172 pub object: String,
173 pub deleted: bool,
174}
175
176#[derive(Serialize, Deserialize, Debug, Clone)]
178pub struct Usage{
179 pub prompt_tokens: u64,
180 pub completion_tokens: u64,
181 pub total_tokens: u64
182}
183
184#[async_trait]
185pub trait JsonRequest<TRes: DeserializeOwned>: Serialize + Sized + Sync{
186
187 const ENDPOINT: &'static str;
188
189 async fn run(&self, client:&OpenAiClient) -> Result<TRes>{
190 let final_url = client.url.to_owned()+Self::ENDPOINT;
191 let res = client.client.post(final_url)
192 .bearer_auth(client.key.clone())
193 .json(self)
194 .send()
195 .await?;
196 process_response::<TRes>(res).await
197 }
198}
199
200#[async_trait]
201pub trait ByUrlRequest<TRes: DeserializeOwned>:WithRefId<str>+Sync{
202
203 const ENDPOINT: &'static str;
204 const SUFFIX: &'static str;
205
206 fn builder(client:&OpenAiClient,final_url:String)->RequestBuilder{
207 client.client.get(final_url)
208 }
209
210 async fn run(&self, client:&OpenAiClient)-> Result<TRes>{
211 let final_url = client.url.to_owned()+Self::ENDPOINT+self.id()+Self::SUFFIX;
212 let res = Self::builder(client,final_url)
213 .bearer_auth(client.key.clone())
214 .send()
215 .await?;
216 process_response::<TRes>(res).await
217 }
218}
219
220
221#[async_trait]
222pub trait GetRequest:DeserializeOwned {
223
224 const ENDPOINT: &'static str;
225
226 async fn get(client:&OpenAiClient)-> Result<Self>{
227 let final_url = client.url.to_owned()+Self::ENDPOINT;
228 let res = client.client.get(final_url)
229 .bearer_auth(client.key.clone())
230 .send()
231 .await?;
232 process_response::<Self>(res).await
233 }
234}
235
236#[async_trait]
237pub trait FormRequest<TRes: DeserializeOwned> : AsyncTryInto<multipart::Form>+Clone+Sync+Send {
238
239 const ENDPOINT: &'static str;
240
241 async fn get_response(&self,
242 client:&Client,
243 final_url:String,
244 key:&str
245 ) -> Result<Response> {
246 client.post(final_url)
247 .bearer_auth(key.clone())
248 .multipart(AsyncTryInto::try_into(self.clone()).await?)
249 .send()
250 .await.map_err(anyhow::Error::new)
251 }
252
253 async fn run(&self, client:&OpenAiClient)-> Result<TRes>{
254 let final_url = client.url.to_owned()+Self::ENDPOINT;
255 let res = self.get_response(&client.client,final_url,&client.key).await?;
256 process_response::<TRes>(res).await
257 }
258}
259
260#[async_trait(?Send)]
261pub trait DownloadRequest: WithRefId<str>{
262
263 const ENDPOINT: &'static str;
264 const SUFFIX: &'static str = "";
265
266 async fn download(&self, client:&OpenAiClient) -> Result<Pin<Box<dyn Stream<Item=Result<Bytes, reqwest::Error>>>>>{
267 let final_url = client.url.to_owned()+Self::ENDPOINT+self.id()+Self::SUFFIX;
268 let res = client.client.get(final_url)
269 .bearer_auth(client.key.clone())
270 .send()
271 .await?;
272 let code = res.error_for_status_ref();
273 return match code {
274 Ok(_) => Ok(Box::pin(res.bytes_stream())),
275 Err(err) =>
276 Err(Error {
277 response: res.json::<ApiError>().await?,
278 inner: err
279 })?
280 }
281 }
282
283 async fn download_to_file(&self, client:&OpenAiClient, target_path:&str) -> Result<()>{
284 let file = File::create(target_path).map_err(anyhow::Error::new);
285 let stream = self.download(client);
286 let (mut file, mut stream) = try_join!(file, stream)?;
287 while let Some(chunk) = stream.next().await {
288 file.write_all(&chunk?).await?;
289 }
290 Ok(())
291 }
292
293}
294
295pub(crate) async fn process_response<T:DeserializeOwned>(response: Response) ->Result<T>{
296 let code = response.error_for_status_ref();
297 match code {
298 Ok(_) =>{
299 let full = response.text().await?;
300 dbg!(&full);
301 serde_json::from_str(&full)
302 .map_err(|err| anyhow::Error::new(err).context(full))
303 }
304 Err(err) =>
305 Err(Error {
306 response: response.json::<ApiError>().await?,
307 inner: err
308 })?
309 }
310}
311
312
313pub(crate) async fn process_text_response(response: Response) ->Result<String>{
314 let code = response.error_for_status_ref();
315 match code {
316 Ok(_) =>{
317 response.text().await.map_err(anyhow::Error::new)
318 }
319 Err(err) =>
320 Err(Error {
321 response: response.json::<ApiError>().await?,
322 inner: err
323 })?
324 }
325}
326
327
328
329
330pub(crate) async fn file_to_part(path: &PathBuf) -> io::Result<Part> {
331 let name = path.file_name()
332 .ok_or(io::Error::new(io::ErrorKind::InvalidInput,"filename is not full"))?
333 .to_str()
334 .ok_or(io::Error::new(io::ErrorKind::InvalidData,"non unicode filename"))?
335 .to_owned();
336 let file = File::open(path).await?;
337 let size = file.metadata().await?.len();
338 let stream = FramedRead::new(file, BytesCodec::new());
339 let body = Body::wrap_stream(stream);
340 Ok(Part::stream_with_length(body,size).file_name(name))
341}