openai_rust2/
lib.rs

1pub extern crate futures_util;
2use anyhow::{anyhow, Result};
3use lazy_static::lazy_static;
4
5lazy_static! {
6    static ref DEFAULT_BASE_URL: reqwest::Url =
7        reqwest::Url::parse("https://api.openai.com/v1/models").unwrap();
8}
9
10pub struct Client {
11    req_client: reqwest::Client,
12    key: String,
13    base_url: reqwest::Url,
14}
15
16pub mod chat;
17pub mod completions;
18pub mod edits;
19pub mod embeddings;
20pub mod images;
21pub mod models;
22
23impl Client {
24    pub fn new(api_key: &str) -> Client {
25        let req_client = reqwest::ClientBuilder::new().build().unwrap();
26        Client {
27            req_client,
28            key: api_key.to_owned(),
29            base_url: DEFAULT_BASE_URL.clone(),
30        }
31    }
32
33    pub fn new_with_client(api_key: &str, req_client: reqwest::Client) -> Client {
34        Client {
35            req_client,
36            key: api_key.to_owned(),
37            base_url: DEFAULT_BASE_URL.clone(),
38        }
39    }
40
41    pub fn new_with_base_url(api_key: &str, base_url: &str) -> Client {
42        let req_client = reqwest::ClientBuilder::new().build().unwrap();
43        let base_url = reqwest::Url::parse(base_url).unwrap();
44        Client {
45            req_client,
46            key: api_key.to_owned(),
47            base_url,
48        }
49    }
50
51    pub fn new_with_client_and_base_url(
52        api_key: &str,
53        req_client: reqwest::Client,
54        base_url: &str,
55    ) -> Client {
56        Client {
57            req_client,
58            key: api_key.to_owned(),
59            base_url: reqwest::Url::parse(base_url).unwrap(),
60        }
61    }
62
63    pub async fn list_models(
64        &self,
65        opt_url_path: Option<String>,
66    ) -> Result<Vec<models::Model>, anyhow::Error> {
67        let mut url = self.base_url.clone();
68        url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/models")));
69
70        let res = self
71            .req_client
72            .get(url)
73            .bearer_auth(&self.key)
74            .send()
75            .await?;
76
77        if res.status() == 200 {
78            Ok(res.json::<models::ListModelsResponse>().await?.data)
79        } else {
80            Err(anyhow!(res.text().await?))
81        }
82    }
83
84    pub async fn create_chat(
85        &self,
86        args: chat::ChatArguments,
87        opt_url_path: Option<String>,
88    ) -> Result<chat::ChatCompletion, anyhow::Error> {
89        let mut url = self.base_url.clone();
90        url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/chat/completions")));
91
92        let res = self
93            .req_client
94            .post(url)
95            .bearer_auth(&self.key)
96            .json(&args)
97            .send()
98            .await?;
99
100        if res.status() == 200 {
101            Ok(res.json().await?)
102        } else {
103            Err(anyhow!(res.text().await?))
104        }
105    }
106
107    pub async fn create_chat_stream(
108        &self,
109        args: chat::ChatArguments,
110        opt_url_path: Option<String>,
111    ) -> Result<chat::stream::ChatCompletionChunkStream> {
112        let mut url = self.base_url.clone();
113        url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/chat/completions")));
114
115        let mut args = args;
116        args.stream = Some(true);
117
118        let res = self
119            .req_client
120            .post(url)
121            .bearer_auth(&self.key)
122            .json(&args)
123            .send()
124            .await?;
125
126        if res.status() == 200 {
127            Ok(chat::stream::ChatCompletionChunkStream::new(Box::pin(
128                res.bytes_stream(),
129            )))
130        } else {
131            Err(anyhow!(res.text().await?))
132        }
133    }
134
135    pub async fn create_completion(
136        &self,
137        args: completions::CompletionArguments,
138        opt_url_path: Option<String>,
139    ) -> Result<completions::CompletionResponse> {
140        let mut url = self.base_url.clone();
141        url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/completions")));
142
143        let res = self
144            .req_client
145            .post(url)
146            .bearer_auth(&self.key)
147            .json(&args)
148            .send()
149            .await?;
150
151        if res.status() == 200 {
152            Ok(res.json().await?)
153        } else {
154            Err(anyhow!(res.text().await?))
155        }
156    }
157
158    pub async fn create_embeddings(
159        &self,
160        args: embeddings::EmbeddingsArguments,
161        opt_url_path: Option<String>,
162    ) -> Result<embeddings::EmbeddingsResponse> {
163        let mut url = self.base_url.clone();
164        url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/embeddings")));
165
166        let res = self
167            .req_client
168            .post(url)
169            .bearer_auth(&self.key)
170            .json(&args)
171            .send()
172            .await?;
173
174        if res.status() == 200 {
175            Ok(res.json().await?)
176        } else {
177            Err(anyhow!(res.text().await?))
178        }
179    }
180
181    pub async fn create_image_old(
182        &self,
183        args: images::ImageArguments,
184        opt_url_path: Option<String>,
185    ) -> Result<Vec<String>> {
186        let mut url = self.base_url.clone();
187        url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/images/generations")));
188
189        let res = self
190            .req_client
191            .post(url)
192            .bearer_auth(&self.key)
193            .json(&args)
194            .send()
195            .await?;
196
197        if res.status() == 200 {
198            Ok(res
199                .json::<images::ImageResponse>()
200                .await?
201                .data
202                .iter()
203                .map(|o| match o {
204                    images::ImageObject::Url(s) => s.to_string(),
205                    images::ImageObject::Base64JSON(s) => s.to_string(),
206                })
207                .collect())
208        } else {
209            Err(anyhow!(res.text().await?))
210        }
211    }
212
213    pub async fn create_image(
214        &self,
215        args: images::ImageArguments,
216        opt_url_path: Option<String>,
217    ) -> Result<Vec<String>> {
218        let mut url = self.base_url.clone();
219        url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/images/generations")));
220
221        let image_args = images::ImageArguments {
222            prompt: args.prompt,
223            model: Some("gpt-image-1".to_string()),
224            n: Some(1),
225            size: Some("1024x1024".to_string()),
226            quality: Some("auto".to_string()), // valid quality values are 'low', 'medium', 'high' and 'auto'
227            //TODO: Make this an enum parameter to create_image
228            user: None,
229        };
230
231        let res = self
232            .req_client
233            .post(url)
234            .bearer_auth(&self.key)
235            .json(&image_args)
236            .send()
237            .await?;
238
239        if res.status() == 200 {
240            Ok(res
241                .json::<images::ImageResponse>()
242                .await?
243                .data
244                .iter()
245                .map(|o| match o {
246                    images::ImageObject::Url(s) => s.to_string(),
247                    images::ImageObject::Base64JSON(s) => s.to_string(),
248                })
249                .collect())
250        } else {
251            Err(anyhow!(res.text().await?))
252        }
253    }
254}