openai_rust2/lib.rs
1#![doc = include_str!("../README.md")]
2//#![feature(str_split_remainder)]
3use anyhow::{anyhow, Result};
4use lazy_static::lazy_static;
5use reqwest;
6
7pub extern crate futures_util;
8
9lazy_static! {
10 static ref DEFAULT_BASE_URL: reqwest::Url =
11 reqwest::Url::parse("https://api.openai.com/v1/models").unwrap();
12}
13
14/// This is the main interface to interact with the api.
15pub struct Client {
16 req_client: reqwest::Client,
17 key: String,
18 base_url: reqwest::Url,
19}
20
21pub mod chat;
22pub mod completions;
23pub mod edits;
24pub mod embeddings;
25pub mod images;
26pub mod models;
27
28impl Client {
29 /// Create a new client.
30 /// This will automatically build a [reqwest::Client] used internally.
31 pub fn new(api_key: &str) -> Client {
32 let req_client = reqwest::ClientBuilder::new().build().unwrap();
33 Client {
34 req_client,
35 key: api_key.to_owned(),
36 base_url: DEFAULT_BASE_URL.clone(),
37 }
38 }
39
40 /// Build a client using your own [reqwest::Client].
41 pub fn new_with_client(api_key: &str, req_client: reqwest::Client) -> Client {
42 Client {
43 req_client,
44 key: api_key.to_owned(),
45 base_url: DEFAULT_BASE_URL.clone(),
46 }
47 }
48
49 // Build a client with a custom base url. The default is `https://api.openai.com/v1/models`
50 pub fn new_with_base_url(api_key: &str, base_url: &str) -> Client {
51 let req_client = reqwest::ClientBuilder::new().build().unwrap();
52 let base_url = reqwest::Url::parse(base_url).unwrap();
53 Client {
54 req_client,
55 key: api_key.to_owned(),
56 base_url,
57 }
58 }
59
60 pub fn new_with_client_and_base_url(
61 api_key: &str,
62 req_client: reqwest::Client,
63 base_url: &str,
64 ) -> Client {
65 Client {
66 req_client,
67 key: api_key.to_owned(),
68 base_url: reqwest::Url::parse(base_url).unwrap(),
69 }
70 }
71
72 /// List and describe the various models available in the API. You can refer to the [Models](https://platform.openai.com/docs/models) documentation to understand what models are available and the differences between them.
73 ///
74 /// ```
75 /// # use openai_rust2 as openai_rust;
76 /// # let api_key = "";
77 /// # tokio_test::block_on(async {
78 /// let client = openai_rust::Client::new(api_key);
79 /// let models = client.list_models(None).await.unwrap();
80 /// # })
81 /// ```
82 ///
83 /// See <https://platform.openai.com/docs/api-reference/models/list>.
84 pub async fn list_models(
85 &self,
86 opt_url_path: Option<String>,
87 ) -> Result<Vec<models::Model>, anyhow::Error> {
88 let mut url = self.base_url.clone();
89 url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/models")));
90
91 let res = self
92 .req_client
93 .get(url)
94 .bearer_auth(&self.key)
95 .send()
96 .await?;
97
98 if res.status() == 200 {
99 Ok(res.json::<models::ListModelsResponse>().await?.data)
100 } else {
101 Err(anyhow!(res.text().await?))
102 }
103 }
104
105 /// Given a list of messages comprising a conversation, the model will return a response.
106 ///
107 /// See <https://platform.openai.com/docs/api-reference/chat>.
108 /// ```
109 /// # use tokio_test;
110 /// # tokio_test::block_on(async {
111 /// # use openai_rust2 as openai_rust;
112 /// # let api_key = "";
113 /// let client = openai_rust::Client::new(api_key);
114 /// let args = openai_rust::chat::ChatArguments::new("gpt-3.5-turbo", vec![
115 /// openai_rust::chat::Message {
116 /// role: "user".to_owned(),
117 /// content: "Hello GPT!".to_owned(),
118 /// }
119 /// ]);
120 /// let res = client.create_chat(args, None).await.unwrap();
121 /// println!("{}", res.choices[0].message.content);
122 /// # })
123 /// ```
124 pub async fn create_chat(
125 &self,
126 args: chat::ChatArguments,
127 opt_url_path: Option<String>,
128 ) -> Result<chat::ChatCompletion, anyhow::Error> {
129 let mut url = self.base_url.clone();
130 url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/chat/completions")));
131
132 let res = self
133 .req_client
134 .post(url)
135 .bearer_auth(&self.key)
136 .json(&args)
137 .send()
138 .await?;
139
140 if res.status() == 200 {
141 Ok(res.json().await?)
142 } else {
143 Err(anyhow!(res.text().await?))
144 }
145 }
146
147 /// Like [Client::create_chat] but with streaming.
148 ///
149 /// See <https://platform.openai.com/docs/api-reference/chat>.
150 ///
151 /// This method will return a stream of [chat::stream::ChatCompletionChunk]s. Use with [futures_util::StreamExt::next].
152 ///
153 /// ```
154 /// # use tokio_test;
155 /// # tokio_test::block_on(async {
156 /// # use openai_rust2 as openai_rust;
157 /// # use std::io::Write;
158 /// # let client = openai_rust::Client::new("");
159 /// # let args = openai_rust::chat::ChatArguments::new("gpt-3.5-turbo", vec![
160 /// # openai_rust::chat::Message {
161 /// # role: "user".to_owned(),
162 /// # content: "Hello GPT!".to_owned(),
163 /// # }
164 /// # ]);
165 /// use openai_rust::futures_util::StreamExt;
166 /// let mut res = client.create_chat_stream(args, None).await.unwrap();
167 /// while let Some(chunk) = res.next().await {
168 /// print!("{}", chunk.unwrap());
169 /// std::io::stdout().flush().unwrap();
170 /// }
171 /// # })
172 /// ```
173 ///
174 pub async fn create_chat_stream(
175 &self,
176 args: chat::ChatArguments,
177 opt_url_path: Option<String>,
178 ) -> Result<chat::stream::ChatCompletionChunkStream> {
179 let mut url = self.base_url.clone();
180 url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/chat/completions")));
181
182 // Enable streaming
183 let mut args = args;
184 args.stream = Some(true);
185
186 let res = self
187 .req_client
188 .post(url)
189 .bearer_auth(&self.key)
190 .json(&args)
191 .send()
192 .await?;
193
194 if res.status() == 200 {
195 Ok(chat::stream::ChatCompletionChunkStream::new(Box::pin(
196 res.bytes_stream(),
197 )))
198 } else {
199 Err(anyhow!(res.text().await?))
200 }
201 }
202
203 /// Given a prompt, the model will return one or more predicted completions, and can also return the probabilities of alternative tokens at each position.
204 ///
205 /// See <https://platform.openai.com/docs/api-reference/completions>
206 ///
207 /// ```
208 /// # use openai_rust2 as openai_rust;
209 /// # use openai_rust::*;
210 /// # use tokio_test;
211 /// # tokio_test::block_on(async {
212 /// # let api_key = "";
213 /// let c = Client::new(api_key);
214 /// let args = completions::CompletionArguments::new("text-davinci-003", "The quick brown fox".to_owned());
215 /// println!("{}", c.create_completion(args, None).await.unwrap().choices[0].text);
216 /// # })
217 /// ```
218 pub async fn create_completion(
219 &self,
220 args: completions::CompletionArguments,
221 opt_url_path: Option<String>,
222 ) -> Result<completions::CompletionResponse> {
223 let mut url = self.base_url.clone();
224 url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/completions")));
225
226 let res = self
227 .req_client
228 .post(url)
229 .bearer_auth(&self.key)
230 .json(&args)
231 .send()
232 .await?;
233
234 if res.status() == 200 {
235 Ok(res.json().await?)
236 } else {
237 Err(anyhow!(res.text().await?))
238 }
239 }
240
241 /// Get a vector representation of a given input that can be easily consumed by machine learning models and algorithms.
242 ///
243 /// See <https://platform.openai.com/docs/api-reference/embeddings>
244 ///
245 /// ```
246 /// # use openai_rust2 as openai_rust;
247 /// # use tokio_test;
248 /// # tokio_test::block_on(async {
249 /// # let api_key = "";
250 /// let c = openai_rust::Client::new(api_key);
251 /// let args = openai_rust::embeddings::EmbeddingsArguments::new("text-embedding-ada-002", "The food was delicious and the waiter...".to_owned());
252 /// println!("{:?}", c.create_embeddings(args, None).await.unwrap().data);
253 /// # })
254 /// ```
255 ///
256 pub async fn create_embeddings(
257 &self,
258 args: embeddings::EmbeddingsArguments,
259 opt_url_path: Option<String>,
260 ) -> Result<embeddings::EmbeddingsResponse> {
261 let mut url = self.base_url.clone();
262 url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/embeddings")));
263
264 let res = self
265 .req_client
266 .post(url)
267 .bearer_auth(&self.key)
268 .json(&args)
269 .send()
270 .await?;
271
272 if res.status() == 200 {
273 Ok(res.json().await?)
274 } else {
275 Err(anyhow!(res.text().await?))
276 }
277 }
278
279 /// Creates an image given a prompt.
280 pub async fn create_image(
281 &self,
282 args: images::ImageArguments,
283 opt_url_path: Option<String>,
284 ) -> Result<Vec<String>> {
285 let mut url = self.base_url.clone();
286 url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/images/generations")));
287
288 let res = self
289 .req_client
290 .post(url)
291 .bearer_auth(&self.key)
292 .json(&args)
293 .send()
294 .await?;
295
296 if res.status() == 200 {
297 Ok(res
298 .json::<images::ImageResponse>()
299 .await?
300 .data
301 .iter()
302 .map(|o| match o {
303 images::ImageObject::Url(s) => s.to_string(),
304 images::ImageObject::Base64JSON(s) => s.to_string(),
305 })
306 .collect())
307 } else {
308 Err(anyhow!(res.text().await?))
309 }
310 }
311}