openai_req/
lib.rs

1pub mod chat;
2pub mod completion;
3pub mod edit;
4pub mod image;
5pub mod files;
6pub mod embeddings;
7pub mod fine_tunes;
8pub mod moderations;
9pub mod audio;
10pub mod model;
11mod conversions;
12
13use anyhow::Result;
14use std::io;
15use std::path::PathBuf;
16use std::pin::Pin;
17use async_trait::async_trait;
18use bytes::Bytes;
19use futures_util::TryFutureExt;
20use reqwest::{Body, Client, multipart, RequestBuilder, Response};
21use reqwest::multipart::Part;
22use serde::de::DeserializeOwned;
23use tokio::fs::File;
24use tokio::io::AsyncWriteExt;
25use tokio::try_join;
26use tokio_stream::{Stream, StreamExt};
27use tokio_util::codec::{BytesCodec, FramedRead};
28use with_id::WithRefId;
29use std::fmt::{Debug, Display, Formatter};
30use serde::{Serialize, Deserialize};
31use crate::conversions::AsyncTryInto;
32
33
34/// This is main client structure required for all requests.
35/// It is passed as a reference parameter into all API operations.
36/// It is also holds actual `reqwest::Client` http client, that performs requests.
37/// # Usage example
38/// ```
39/// use openai_req::OpenAiClient;
40///
41/// let client = OpenAiClient::new("{YOUR_API_KEY}");
42/// ```
43#[derive(Debug, Clone)]
44pub struct OpenAiClient {
45    url:String,
46    key:String,
47    client:Client
48}
49
50impl OpenAiClient {
51
52    const URL: &'static str = "https://api.openai.com/v1";
53
54    ///simplest constructor, uses default https://api.openai.com/v1 url,
55    /// and creates new default client with connection pool for connections
56    pub fn new(key: &str)->Self{
57        let client = Client::new();
58        OpenAiClient::with_client(key,&client)
59    }
60
61    /// reqwest library recommends re-using single client,
62    /// so if you run access to multiple api-s, pass client into constructor.
63    /// Also use this constructor if you want to customize your client
64    /// (for example set different timeout, or use proxy)
65    pub fn with_client(key: &str, client: &Client)->Self{
66        OpenAiClient::with_url_and_client(key,OpenAiClient::URL,client)
67    }
68
69
70    ///if you want to change base url from https://api.openai.com/v1 to something else - you can
71    pub fn with_url(key: &str, url: &str) -> Self {
72        let client = Client::new();
73        OpenAiClient::with_url_and_client(key,url,&client)
74    }
75
76
77    /// this constructor allows you to customise everything:  client,
78    /// key and base url for all requests
79    pub fn with_url_and_client(key: &str, url: &str, client: &Client)->Self{
80        OpenAiClient {
81            url: url.to_string(),
82            key: key.to_string(),
83            client: client.clone()
84        }
85    }
86}
87
88///common error type used by api client traits, wraps underlying reqwest::Error,
89///but also tries to provide response body, so error is easier to debug
90#[derive(Debug)]
91pub struct Error{
92    pub(crate) response:ApiError,
93    pub(crate) inner:reqwest::Error
94}
95
96impl Display for Error {
97    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
98        write!(f,"{}",self.response)
99    }
100}
101
102impl std::error::Error for Error {
103    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
104        Some(&self.inner)
105    }
106}
107
108
109///structure returned by OpenAI for errors
110#[derive(Serialize, Deserialize, Debug, Clone)]
111pub struct ApiError {
112    pub error: ApiErrorDetails
113}
114
115#[derive(Serialize, Deserialize, Debug, Clone)]
116#[serde(rename(serialize = "error"))]
117#[serde(rename(deserialize = "error"))]
118pub struct ApiErrorDetails {
119    pub message: String,
120    #[serde(rename = "type")]
121    pub kind: String,
122    pub param: Option<String>,
123    pub code: Option<String>
124}
125
126impl Display for ApiError{
127    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
128        match &self.error.param {
129            None => match &self.error.code {
130                None => write!(f,"{}",self.error.message),
131                Some(code) => write!(f,"{}, code:{}",self.error.message,code)
132            }
133            Some(param) => match &self.error.code {
134                None => write!(f,"{}, param:{}",self.error.message,param),
135                Some(code) => write!(f,"{}, param:{}, code: {}",self.error.message,param,code)
136            }
137        }
138    }
139}
140
141///enum used by different requests,
142/// it is common for apis ot take either single string or array of tokens
143#[derive(Clone, Serialize, Deserialize, Debug)]
144#[serde(untagged)]
145pub enum Input {
146    String(String),
147    StringArray(Vec<String>)
148}
149
150impl From<String> for Input{
151    fn from(value:String) -> Self {
152        Input::String(value)
153    }
154}
155
156impl From<&str> for Input{
157    fn from(value:&str) -> Self {
158        Input::String(value.to_string())
159    }
160}
161
162impl From<Vec<String>> for Input{
163    fn from(value: Vec<String>) -> Self {
164        Input::StringArray(value)
165    }
166}
167
168///common response used by multiple delete API-s
169#[derive(Serialize, Deserialize, Debug, Clone)]
170pub struct DeleteResponse {
171    pub id: String,
172    pub object: String,
173    pub deleted: bool,
174}
175
176///common struct that comes up in responses
177#[derive(Serialize, Deserialize, Debug, Clone)]
178pub struct Usage{
179    pub prompt_tokens: u64,
180    pub completion_tokens: u64,
181    pub total_tokens: u64
182}
183
184#[async_trait]
185pub trait JsonRequest<TRes: DeserializeOwned>: Serialize + Sized + Sync{
186
187    const ENDPOINT: &'static str;
188
189    async fn run(&self, client:&OpenAiClient) -> Result<TRes>{
190        let final_url = client.url.to_owned()+Self::ENDPOINT;
191        let res = client.client.post(final_url)
192            .bearer_auth(client.key.clone())
193            .json(self)
194            .send()
195            .await?;
196        process_response::<TRes>(res).await
197    }
198}
199
200#[async_trait]
201pub trait ByUrlRequest<TRes: DeserializeOwned>:WithRefId<str>+Sync{
202
203    const ENDPOINT: &'static str;
204    const SUFFIX: &'static str;
205
206    fn builder(client:&OpenAiClient,final_url:String)->RequestBuilder{
207        client.client.get(final_url)
208    }
209
210    async fn run(&self, client:&OpenAiClient)-> Result<TRes>{
211        let final_url = client.url.to_owned()+Self::ENDPOINT+self.id()+Self::SUFFIX;
212        let res = Self::builder(client,final_url)
213            .bearer_auth(client.key.clone())
214            .send()
215            .await?;
216        process_response::<TRes>(res).await
217    }
218}
219
220
221#[async_trait]
222pub trait GetRequest:DeserializeOwned {
223
224    const ENDPOINT: &'static str;
225
226    async fn get(client:&OpenAiClient)-> Result<Self>{
227        let final_url = client.url.to_owned()+Self::ENDPOINT;
228        let res = client.client.get(final_url)
229            .bearer_auth(client.key.clone())
230            .send()
231            .await?;
232        process_response::<Self>(res).await
233    }
234}
235
236#[async_trait]
237pub trait FormRequest<TRes: DeserializeOwned> : AsyncTryInto<multipart::Form>+Clone+Sync+Send {
238
239    const ENDPOINT: &'static str;
240
241    async fn get_response(&self,
242                          client:&Client,
243                          final_url:String,
244                          key:&str
245    ) -> Result<Response> {
246        client.post(final_url)
247            .bearer_auth(key.clone())
248            .multipart(AsyncTryInto::try_into(self.clone()).await?)
249            .send()
250            .await.map_err(anyhow::Error::new)
251    }
252
253    async fn run(&self, client:&OpenAiClient)-> Result<TRes>{
254        let final_url =  client.url.to_owned()+Self::ENDPOINT;
255        let res = self.get_response(&client.client,final_url,&client.key).await?;
256        process_response::<TRes>(res).await
257    }
258}
259
260#[async_trait(?Send)]
261pub trait DownloadRequest: WithRefId<str>{
262
263    const ENDPOINT: &'static str;
264    const SUFFIX: &'static str = "";
265
266    async fn download(&self, client:&OpenAiClient) -> Result<Pin<Box<dyn Stream<Item=Result<Bytes, reqwest::Error>>>>>{
267        let final_url = client.url.to_owned()+Self::ENDPOINT+self.id()+Self::SUFFIX;
268        let res = client.client.get(final_url)
269            .bearer_auth(client.key.clone())
270            .send()
271            .await?;
272        let code = res.error_for_status_ref();
273        return match code {
274            Ok(_) => Ok(Box::pin(res.bytes_stream())),
275            Err(err) =>
276                Err(Error {
277                    response: res.json::<ApiError>().await?,
278                    inner: err
279                })?
280        }
281    }
282
283    async fn download_to_file(&self, client:&OpenAiClient, target_path:&str) -> Result<()>{
284        let file = File::create(target_path).map_err(anyhow::Error::new);
285        let stream = self.download(client);
286        let (mut file, mut stream) = try_join!(file, stream)?;
287        while let Some(chunk) = stream.next().await {
288            file.write_all(&chunk?).await?;
289        }
290        Ok(())
291    }
292
293}
294
295pub(crate) async fn process_response<T:DeserializeOwned>(response: Response) ->Result<T>{
296    let code = response.error_for_status_ref();
297    match code {
298        Ok(_) =>{
299            let full = response.text().await?;
300            dbg!(&full);
301            serde_json::from_str(&full)
302                .map_err(|err| anyhow::Error::new(err).context(full))
303        }
304        Err(err) =>
305            Err(Error {
306                response: response.json::<ApiError>().await?,
307                inner: err
308            })?
309    }
310}
311
312
313pub(crate) async fn process_text_response(response: Response) ->Result<String>{
314    let code = response.error_for_status_ref();
315    match code {
316        Ok(_) =>{
317            response.text().await.map_err(anyhow::Error::new)
318        }
319        Err(err) =>
320            Err(Error {
321                response: response.json::<ApiError>().await?,
322                inner: err
323            })?
324    }
325}
326
327
328
329
330pub(crate) async fn file_to_part(path: &PathBuf) -> io::Result<Part> {
331    let name = path.file_name()
332        .ok_or(io::Error::new(io::ErrorKind::InvalidInput,"filename is not full"))?
333        .to_str()
334        .ok_or(io::Error::new(io::ErrorKind::InvalidData,"non unicode filename"))?
335        .to_owned();
336    let file = File::open(path).await?;
337    let size = file.metadata().await?.len();
338    let stream = FramedRead::new(file, BytesCodec::new());
339    let body = Body::wrap_stream(stream);
340    Ok(Part::stream_with_length(body,size).file_name(name))
341}