openai_rust2/
lib.rs

1pub extern crate futures_util;
2use anyhow::{anyhow, Result};
3use base64::Engine as _;
4use lazy_static::lazy_static;
5use reqwest;
6use serde::{Deserialize, Serialize};
7
8lazy_static! {
9    static ref DEFAULT_BASE_URL: reqwest::Url =
10        reqwest::Url::parse("https://api.openai.com/v1/models").unwrap();
11}
12
13pub struct Client {
14    req_client: reqwest::Client,
15    key: String,
16    base_url: reqwest::Url,
17}
18
19pub mod chat;
20pub mod completions;
21pub mod edits;
22pub mod embeddings;
23pub mod images;
24pub mod models;
25
26impl Client {
27    pub fn new(api_key: &str) -> Client {
28        let req_client = reqwest::ClientBuilder::new().build().unwrap();
29        Client {
30            req_client,
31            key: api_key.to_owned(),
32            base_url: DEFAULT_BASE_URL.clone(),
33        }
34    }
35
36    pub fn new_with_client(api_key: &str, req_client: reqwest::Client) -> Client {
37        Client {
38            req_client,
39            key: api_key.to_owned(),
40            base_url: DEFAULT_BASE_URL.clone(),
41        }
42    }
43
44    pub fn new_with_base_url(api_key: &str, base_url: &str) -> Client {
45        let req_client = reqwest::ClientBuilder::new().build().unwrap();
46        let base_url = reqwest::Url::parse(base_url).unwrap();
47        Client {
48            req_client,
49            key: api_key.to_owned(),
50            base_url,
51        }
52    }
53
54    pub fn new_with_client_and_base_url(
55        api_key: &str,
56        req_client: reqwest::Client,
57        base_url: &str,
58    ) -> Client {
59        Client {
60            req_client,
61            key: api_key.to_owned(),
62            base_url: reqwest::Url::parse(base_url).unwrap(),
63        }
64    }
65
66    pub async fn list_models(
67        &self,
68        opt_url_path: Option<String>,
69    ) -> Result<Vec<models::Model>, anyhow::Error> {
70        let mut url = self.base_url.clone();
71        url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/models")));
72
73        let res = self
74            .req_client
75            .get(url)
76            .bearer_auth(&self.key)
77            .send()
78            .await?;
79
80        if res.status() == 200 {
81            Ok(res.json::<models::ListModelsResponse>().await?.data)
82        } else {
83            Err(anyhow!(res.text().await?))
84        }
85    }
86
87    pub async fn create_chat(
88        &self,
89        args: chat::ChatArguments,
90        opt_url_path: Option<String>,
91    ) -> Result<chat::ChatCompletion, anyhow::Error> {
92        let mut url = self.base_url.clone();
93        url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/chat/completions")));
94
95        let res = self
96            .req_client
97            .post(url)
98            .bearer_auth(&self.key)
99            .json(&args)
100            .send()
101            .await?;
102
103        if res.status() == 200 {
104            Ok(res.json().await?)
105        } else {
106            Err(anyhow!(res.text().await?))
107        }
108    }
109
110    pub async fn create_chat_stream(
111        &self,
112        args: chat::ChatArguments,
113        opt_url_path: Option<String>,
114    ) -> Result<chat::stream::ChatCompletionChunkStream> {
115        let mut url = self.base_url.clone();
116        url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/chat/completions")));
117
118        let mut args = args;
119        args.stream = Some(true);
120
121        let res = self
122            .req_client
123            .post(url)
124            .bearer_auth(&self.key)
125            .json(&args)
126            .send()
127            .await?;
128
129        if res.status() == 200 {
130            Ok(chat::stream::ChatCompletionChunkStream::new(Box::pin(
131                res.bytes_stream(),
132            )))
133        } else {
134            Err(anyhow!(res.text().await?))
135        }
136    }
137
138    pub async fn create_completion(
139        &self,
140        args: completions::CompletionArguments,
141        opt_url_path: Option<String>,
142    ) -> Result<completions::CompletionResponse> {
143        let mut url = self.base_url.clone();
144        url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/completions")));
145
146        let res = self
147            .req_client
148            .post(url)
149            .bearer_auth(&self.key)
150            .json(&args)
151            .send()
152            .await?;
153
154        if res.status() == 200 {
155            Ok(res.json().await?)
156        } else {
157            Err(anyhow!(res.text().await?))
158        }
159    }
160
161    pub async fn create_embeddings(
162        &self,
163        args: embeddings::EmbeddingsArguments,
164        opt_url_path: Option<String>,
165    ) -> Result<embeddings::EmbeddingsResponse> {
166        let mut url = self.base_url.clone();
167        url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/embeddings")));
168
169        let res = self
170            .req_client
171            .post(url)
172            .bearer_auth(&self.key)
173            .json(&args)
174            .send()
175            .await?;
176
177        if res.status() == 200 {
178            Ok(res.json().await?)
179        } else {
180            Err(anyhow!(res.text().await?))
181        }
182    }
183
184    pub async fn create_image_old(
185        &self,
186        args: images::ImageArguments,
187        opt_url_path: Option<String>,
188    ) -> Result<Vec<String>> {
189        let mut url = self.base_url.clone();
190        url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/images/generations")));
191
192        let res = self
193            .req_client
194            .post(url)
195            .bearer_auth(&self.key)
196            .json(&args)
197            .send()
198            .await?;
199
200        if res.status() == 200 {
201            Ok(res
202                .json::<images::ImageResponse>()
203                .await?
204                .data
205                .iter()
206                .map(|o| match o {
207                    images::ImageObject::Url(s) => s.to_string(),
208                    images::ImageObject::Base64JSON(s) => s.to_string(),
209                })
210                .collect())
211        } else {
212            Err(anyhow!(res.text().await?))
213        }
214    }
215
216    pub async fn create_image(
217        &self,
218        args: images::ImageArguments,
219        opt_url_path: Option<String>,
220    ) -> Result<Vec<String>> {
221        let mut url = self.base_url.clone();
222        url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/images/generations")));
223
224        let image_args = images::ImageArguments {
225            prompt: args.prompt,
226            model: Some("gpt-image-1".to_string()),
227            n: Some(1),
228            size: Some("1024x1024".to_string()),
229            quality: Some("auto".to_string()), // valid quality values are 'low', 'medium', 'high' and 'auto'
230            //TODO: Make this an enum parameter to create_image
231            user: None,
232        };
233
234        let res = self
235            .req_client
236            .post(url)
237            .bearer_auth(&self.key)
238            .json(&image_args)
239            .send()
240            .await?;
241
242        if res.status() == 200 {
243            Ok(res
244                .json::<images::ImageResponse>()
245                .await?
246                .data
247                .iter()
248                .map(|o| match o {
249                    images::ImageObject::Url(s) => s.to_string(),
250                    images::ImageObject::Base64JSON(s) => s.to_string(),
251                })
252                .collect())
253        } else {
254            Err(anyhow!(res.text().await?))
255        }
256    }
257}