openai_rust/
lib.rs

1#![doc = include_str!("../README.md")]
2//#![feature(str_split_remainder)]
3use anyhow::{anyhow, Result};
4use lazy_static::lazy_static;
5use reqwest;
6
7pub extern crate futures_util;
8
9lazy_static! {
10    static ref BASE_URL: reqwest::Url =
11        reqwest::Url::parse("https://api.openai.com/v1/models").unwrap();
12}
13
14/// This is the main interface to interact with the api.
15pub struct Client {
16    req_client: reqwest::Client,
17    key: String,
18}
19
20pub mod models;
21pub mod chat;
22pub mod completions;
23pub mod edits;
24pub mod embeddings;
25pub mod images;
26
27impl Client {
28    /// Create a new client.
29    /// This will automatically build a [reqwest::Client] used internally.
30    pub fn new(api_key: &str) -> Client {
31        let req_client = reqwest::ClientBuilder::new().build().unwrap();
32        Client {
33            req_client,
34            key: api_key.to_owned(),
35        }
36    }
37
38    /// Build a client using your own [reqwest::Client].
39    pub fn new_with_client(api_key: &str, req_client: reqwest::Client) -> Client {
40        Client {
41            req_client,
42            key: api_key.to_owned(),
43        }
44    }
45
46    /// List and describe the various models available in the API. You can refer to the [Models](https://platform.openai.com/docs/models) documentation to understand what models are available and the differences between them.
47    ///
48    /// ```no_run
49    /// # let api_key = "";
50    /// # tokio_test::block_on(async {
51    /// let client = openai_rust::Client::new(api_key);
52    /// let models = client.list_models().await.unwrap();
53    /// # })
54    /// ```
55    ///
56    /// See <https://platform.openai.com/docs/api-reference/models/list>.
57    pub async fn list_models(&self) -> Result<Vec<models::Model>, anyhow::Error> {
58        let mut url = BASE_URL.clone();
59        url.set_path("/v1/models");
60
61        let res = self
62            .req_client
63            .get(url)
64            .bearer_auth(&self.key)
65            .send()
66            .await?;
67
68        if res.status() == 200 {
69            Ok(res.json::<models::ListModelsResponse>().await?.data)
70        } else {
71            Err(anyhow!(res.text().await?))
72        }
73    }
74
75    /// Given a list of messages comprising a conversation, the model will return a response.
76    ///
77    /// See <https://platform.openai.com/docs/api-reference/chat>.
78    /// ```no_run
79    /// # use tokio_test;
80    /// # tokio_test::block_on(async {
81    /// # use openai_rust;
82    /// # let api_key = "";
83    /// let client = openai_rust::Client::new(api_key);
84    /// let args = openai_rust::chat::ChatArguments::new("gpt-3.5-turbo", vec![
85    ///    openai_rust::chat::Message {
86    ///        role: "user".to_owned(),
87    ///        content: "Hello GPT!".to_owned(),
88    ///    }
89    /// ]);
90    /// let res = client.create_chat(args).await.unwrap();
91    /// println!("{}", res.choices[0].message.content);
92    /// # })
93    /// ```
94    pub async fn create_chat(
95        &self,
96        args: chat::ChatArguments,
97    ) -> Result<chat::ChatCompletion, anyhow::Error> {
98        let mut url = BASE_URL.clone();
99        url.set_path("/v1/chat/completions");
100
101        let res = self
102            .req_client
103            .post(url)
104            .bearer_auth(&self.key)
105            .json(&args)
106            .send()
107            .await?;
108
109        if res.status() == 200 {
110            Ok(res.json().await?)
111        } else {
112            Err(anyhow!(res.text().await?))
113        }
114    }
115
116    /// Like [Client::create_chat] but with streaming.
117    ///
118    /// See <https://platform.openai.com/docs/api-reference/chat>.
119    ///
120    /// This method will return a stream of [chat::stream::ChatCompletionChunk]s. Use with [futures_util::StreamExt::next].
121    ///
122    /// ```no_run
123    /// # use tokio_test;
124    /// # tokio_test::block_on(async {
125    /// # use openai_rust;
126    /// # use std::io::Write;
127    /// # let client = openai_rust::Client::new("");
128    /// # let args = openai_rust::chat::ChatArguments::new("gpt-3.5-turbo", vec![
129    /// #    openai_rust::chat::Message {
130    /// #        role: "user".to_owned(),
131    /// #        content: "Hello GPT!".to_owned(),
132    /// #    }
133    /// # ]);
134    /// use openai_rust::futures_util::StreamExt;
135    /// let mut res = client.create_chat_stream(args).await.unwrap();
136    /// while let Some(chunk) = res.next().await {
137    ///     print!("{}", chunk.unwrap());
138    ///     std::io::stdout().flush().unwrap();
139    /// }
140    /// # })
141    /// ```
142    ///
143    pub async fn create_chat_stream(
144        &self,
145        args: chat::ChatArguments,
146    ) -> Result<chat::stream::ChatCompletionChunkStream> {
147        let mut url = BASE_URL.clone();
148        url.set_path("/v1/chat/completions");
149
150        // Enable streaming
151        let mut args = args;
152        args.stream = Some(true);
153
154        let res = self
155            .req_client
156            .post(url)
157            .bearer_auth(&self.key)
158            .json(&args)
159            .send()
160            .await?;
161
162        if res.status() == 200 {
163            Ok(chat::stream::ChatCompletionChunkStream::new(Box::pin(res.bytes_stream())))
164        } else {
165            Err(anyhow!(res.text().await?))
166        }
167    }
168
169    /// Given a prompt, the model will return one or more predicted completions, and can also return the probabilities of alternative tokens at each position.
170    ///
171    /// See <https://platform.openai.com/docs/api-reference/completions>
172    ///
173    /// ```no_run
174    /// # use openai_rust::*;
175    /// # use tokio_test;
176    /// # tokio_test::block_on(async {
177    /// # let api_key = "";
178    /// let c = openai_rust::Client::new(api_key);
179    /// let args = openai_rust::completions::CompletionArguments::new("text-davinci-003", "The quick brown fox".to_owned());
180    /// println!("{}", c.create_completion(args).await.unwrap().choices[0].text);
181    /// # })
182    /// ```
183    pub async fn create_completion(
184        &self,
185        args: completions::CompletionArguments,
186    ) -> Result<completions::CompletionResponse> {
187        let mut url = BASE_URL.clone();
188        url.set_path("/v1/completions");
189
190        let res = self
191            .req_client
192            .post(url)
193            .bearer_auth(&self.key)
194            .json(&args)
195            .send()
196            .await?;
197
198        if res.status() == 200 {
199            Ok(res.json().await?)
200        } else {
201            Err(anyhow!(res.text().await?))
202        }
203    }
204
205    /// Given a prompt and an instruction, the model will return an edited version of the prompt.
206    ///
207    /// See <https://platform.openai.com/docs/api-reference/edits>
208    ///
209    /// ```no_run
210    /// # use openai_rust;
211    /// # use tokio_test;
212    /// # tokio_test::block_on(async {
213    /// # let api_key = "";
214    /// let c = openai_rust::Client::new(api_key);
215    /// let args = openai_rust::edits::EditArguments::new("text-davinci-edit-001", "The quick brown fox".to_owned(), "Complete this sentence.".to_owned());
216    /// println!("{}", c.create_edit(args).await.unwrap().to_string());
217    /// # })
218    /// ```
219    ///
220    #[deprecated = "Use the chat api instead"]
221    pub async fn create_edit(&self, args: edits::EditArguments) -> Result<edits::EditResponse> {
222        let mut url = BASE_URL.clone();
223        url.set_path("/v1/edits");
224
225        let res = self
226            .req_client
227            .post(url)
228            .bearer_auth(&self.key)
229            .json(&args)
230            .send()
231            .await?;
232
233        if res.status() == 200 {
234            Ok(res.json().await?)
235        } else {
236            Err(anyhow!(res.text().await?))
237        }
238    }
239
240    /// Get a vector representation of a given input that can be easily consumed by machine learning models and algorithms.
241    ///
242    /// See <https://platform.openai.com/docs/api-reference/embeddings>
243    ///
244    /// ```no_run
245    /// # use openai_rust;
246    /// # use tokio_test;
247    /// # tokio_test::block_on(async {
248    /// # let api_key = "";
249    /// let c = openai_rust::Client::new(api_key);
250    /// let args = openai_rust::embeddings::EmbeddingsArguments::new("text-embedding-ada-002", "The food was delicious and the waiter...".to_owned());
251    /// println!("{:?}", c.create_embeddings(args).await.unwrap().data);
252    /// # })
253    /// ```
254    ///
255    pub async fn create_embeddings(
256        &self,
257        args: embeddings::EmbeddingsArguments,
258    ) -> Result<embeddings::EmbeddingsResponse> {
259        let mut url = BASE_URL.clone();
260        url.set_path("/v1/embeddings");
261
262        let res = self
263            .req_client
264            .post(url)
265            .bearer_auth(&self.key)
266            .json(&args)
267            .send()
268            .await?;
269
270        if res.status() == 200 {
271            Ok(res.json().await?)
272        } else {
273            Err(anyhow!(res.text().await?))
274        }
275    }
276
277    /// Creates an image given a prompt.
278    pub async fn create_image(
279        &self,
280        args: images::ImageArguments,
281    ) -> Result<Vec<String>> {
282        let mut url = BASE_URL.clone();
283        url.set_path("/v1/images/generations");
284
285        let res = self
286            .req_client
287            .post(url)
288            .bearer_auth(&self.key)
289            .json(&args)
290            .send()
291            .await?;
292
293        if res.status() == 200 {
294            Ok(res.json::<images::ImageResponse>().await?.data.iter().map(|o|
295                match o {
296                    images::ImageObject::Url(s) => s.to_string(),
297                    images::ImageObject::Base64JSON(s) => s.to_string(),
298                }
299            ).collect())
300        } else {
301            Err(anyhow!(res.text().await?))
302        }
303    }
304}