openai_rust2/lib.rs

#![doc = include_str!("../README.md")]
//#![feature(str_split_remainder)]
use anyhow::{anyhow, Result};
use lazy_static::lazy_static;
use reqwest;

pub extern crate futures_util;

lazy_static! {
    static ref DEFAULT_BASE_URL: reqwest::Url =
        reqwest::Url::parse("https://api.openai.com/v1/models").unwrap();
}

/// The main interface for interacting with the API.
pub struct Client {
    req_client: reqwest::Client,
    key: String,
    base_url: reqwest::Url,
}

pub mod chat;
pub mod completions;
pub mod edits;
pub mod embeddings;
pub mod images;
pub mod models;

impl Client {
    /// Create a new client.
    /// This will automatically build a [reqwest::Client] used internally.
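    ///
    /// A minimal construction sketch; no request is sent, and the key below
    /// is a placeholder:
    ///
    /// ```
    /// # use openai_rust2 as openai_rust;
    /// let client = openai_rust::Client::new("my-api-key");
    /// ```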
    pub fn new(api_key: &str) -> Client {
        let req_client = reqwest::ClientBuilder::new().build().unwrap();
        Client {
            req_client,
            key: api_key.to_owned(),
            base_url: DEFAULT_BASE_URL.clone(),
        }
    }

    /// Build a client using your own [reqwest::Client].
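    ///
    /// A sketch of injecting a pre-configured client; the timeout value here
    /// is illustrative, not a recommendation:
    ///
    /// ```
    /// # use openai_rust2 as openai_rust;
    /// # use std::time::Duration;
    /// let req_client = reqwest::ClientBuilder::new()
    ///     .timeout(Duration::from_secs(30))
    ///     .build()
    ///     .unwrap();
    /// let client = openai_rust::Client::new_with_client("my-api-key", req_client);
    /// ```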
    pub fn new_with_client(api_key: &str, req_client: reqwest::Client) -> Client {
        Client {
            req_client,
            key: api_key.to_owned(),
            base_url: DEFAULT_BASE_URL.clone(),
        }
    }

    /// Build a client with a custom base url. The default is `https://api.openai.com/v1/models`.
    /// Only the scheme, host, and port of the base url are used: each request
    /// replaces the path via [reqwest::Url::set_path].
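    ///
    /// For example, pointing at a local proxy (the URL below is hypothetical):
    ///
    /// ```
    /// # use openai_rust2 as openai_rust;
    /// let client = openai_rust::Client::new_with_base_url("my-api-key", "http://localhost:8080");
    /// ```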
    pub fn new_with_base_url(api_key: &str, base_url: &str) -> Client {
        let req_client = reqwest::ClientBuilder::new().build().unwrap();
        let base_url = reqwest::Url::parse(base_url).unwrap();
        Client {
            req_client,
            key: api_key.to_owned(),
            base_url,
        }
    }

    /// Build a client with both a custom [reqwest::Client] and a custom base url.
    pub fn new_with_client_and_base_url(
        api_key: &str,
        req_client: reqwest::Client,
        base_url: &str,
    ) -> Client {
        Client {
            req_client,
            key: api_key.to_owned(),
            base_url: reqwest::Url::parse(base_url).unwrap(),
        }
    }

    /// List and describe the various models available in the API. You can refer to the [Models](https://platform.openai.com/docs/models) documentation to understand what models are available and the differences between them.
    ///
    /// `opt_url_path` overrides the default request path (`/v1/models`).
    ///
    /// ```
    /// # use openai_rust2 as openai_rust;
    /// # let api_key = "";
    /// # tokio_test::block_on(async {
    /// let client = openai_rust::Client::new(api_key);
    /// let models = client.list_models(None).await.unwrap();
    /// # })
    /// ```
    ///
    /// See <https://platform.openai.com/docs/api-reference/models/list>.
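    ///
    /// With a custom request path (the proxy route below is hypothetical):
    ///
    /// ```no_run
    /// # use openai_rust2 as openai_rust;
    /// # tokio_test::block_on(async {
    /// # let client = openai_rust::Client::new("");
    /// let models = client.list_models(Some(String::from("/proxy/v1/models"))).await.unwrap();
    /// # })
    /// ```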
    pub async fn list_models(
        &self,
        opt_url_path: Option<String>,
    ) -> Result<Vec<models::Model>> {
        let mut url = self.base_url.clone();
        url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/models")));

        let res = self
            .req_client
            .get(url)
            .bearer_auth(&self.key)
            .send()
            .await?;

        if res.status() == 200 {
            Ok(res.json::<models::ListModelsResponse>().await?.data)
        } else {
            Err(anyhow!(res.text().await?))
        }
    }

    /// Given a list of messages comprising a conversation, the model will return a response.
    ///
    /// See <https://platform.openai.com/docs/api-reference/chat>.
    /// ```
    /// # use tokio_test;
    /// # tokio_test::block_on(async {
    /// # use openai_rust2 as openai_rust;
    /// # let api_key = "";
    /// let client = openai_rust::Client::new(api_key);
    /// let args = openai_rust::chat::ChatArguments::new("gpt-3.5-turbo", vec![
    ///    openai_rust::chat::Message {
    ///        role: "user".to_owned(),
    ///        content: "Hello GPT!".to_owned(),
    ///    }
    /// ]);
    /// let res = client.create_chat(args, None).await.unwrap();
    /// println!("{}", res.choices[0].message.content);
    /// # })
    /// ```
    pub async fn create_chat(
        &self,
        args: chat::ChatArguments,
        opt_url_path: Option<String>,
    ) -> Result<chat::ChatCompletion> {
        let mut url = self.base_url.clone();
        url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/chat/completions")));

        let res = self
            .req_client
            .post(url)
            .bearer_auth(&self.key)
            .json(&args)
            .send()
            .await?;

        if res.status() == 200 {
            Ok(res.json().await?)
        } else {
            Err(anyhow!(res.text().await?))
        }
    }

    /// Like [Client::create_chat] but with streaming.
    ///
    /// See <https://platform.openai.com/docs/api-reference/chat>.
    ///
    /// This method will return a stream of [chat::stream::ChatCompletionChunk]s. Use with [futures_util::StreamExt::next].
    ///
    /// ```
    /// # use tokio_test;
    /// # tokio_test::block_on(async {
    /// # use openai_rust2 as openai_rust;
    /// # use std::io::Write;
    /// # let client = openai_rust::Client::new("");
    /// # let args = openai_rust::chat::ChatArguments::new("gpt-3.5-turbo", vec![
    /// #    openai_rust::chat::Message {
    /// #        role: "user".to_owned(),
    /// #        content: "Hello GPT!".to_owned(),
    /// #    }
    /// # ]);
    /// use openai_rust::futures_util::StreamExt;
    /// let mut res = client.create_chat_stream(args, None).await.unwrap();
    /// while let Some(chunk) = res.next().await {
    ///     print!("{}", chunk.unwrap());
    ///     std::io::stdout().flush().unwrap();
    /// }
    /// # })
    /// ```
    ///
    pub async fn create_chat_stream(
        &self,
        mut args: chat::ChatArguments,
        opt_url_path: Option<String>,
    ) -> Result<chat::stream::ChatCompletionChunkStream> {
        let mut url = self.base_url.clone();
        url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/chat/completions")));

        // Enable streaming
        args.stream = Some(true);

        let res = self
            .req_client
            .post(url)
            .bearer_auth(&self.key)
            .json(&args)
            .send()
            .await?;

        if res.status() == 200 {
            Ok(chat::stream::ChatCompletionChunkStream::new(Box::pin(
                res.bytes_stream(),
            )))
        } else {
            Err(anyhow!(res.text().await?))
        }
    }

    /// Given a prompt, the model will return one or more predicted completions, and can also return the probabilities of alternative tokens at each position.
    ///
    /// See <https://platform.openai.com/docs/api-reference/completions>
    ///
    /// ```
    /// # use openai_rust2 as openai_rust;
    /// # use openai_rust::*;
    /// # use tokio_test;
    /// # tokio_test::block_on(async {
    /// # let api_key = "";
    /// let c = Client::new(api_key);
    /// let args = completions::CompletionArguments::new("text-davinci-003", "The quick brown fox".to_owned());
    /// println!("{}", c.create_completion(args, None).await.unwrap().choices[0].text);
    /// # })
    /// ```
    pub async fn create_completion(
        &self,
        args: completions::CompletionArguments,
        opt_url_path: Option<String>,
    ) -> Result<completions::CompletionResponse> {
        let mut url = self.base_url.clone();
        url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/completions")));

        let res = self
            .req_client
            .post(url)
            .bearer_auth(&self.key)
            .json(&args)
            .send()
            .await?;

        if res.status() == 200 {
            Ok(res.json().await?)
        } else {
            Err(anyhow!(res.text().await?))
        }
    }

    /// Get a vector representation of a given input that can be easily consumed by machine learning models and algorithms.
    ///
    /// See <https://platform.openai.com/docs/api-reference/embeddings>
    ///
    /// ```
    /// # use openai_rust2 as openai_rust;
    /// # use tokio_test;
    /// # tokio_test::block_on(async {
    /// # let api_key = "";
    /// let c = openai_rust::Client::new(api_key);
    /// let args = openai_rust::embeddings::EmbeddingsArguments::new("text-embedding-ada-002", "The food was delicious and the waiter...".to_owned());
    /// println!("{:?}", c.create_embeddings(args, None).await.unwrap().data);
    /// # })
    /// ```
    ///
    pub async fn create_embeddings(
        &self,
        args: embeddings::EmbeddingsArguments,
        opt_url_path: Option<String>,
    ) -> Result<embeddings::EmbeddingsResponse> {
        let mut url = self.base_url.clone();
        url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/embeddings")));

        let res = self
            .req_client
            .post(url)
            .bearer_auth(&self.key)
            .json(&args)
            .send()
            .await?;

        if res.status() == 200 {
            Ok(res.json().await?)
        } else {
            Err(anyhow!(res.text().await?))
        }
    }

    /// Creates an image given a prompt.
    /// Returns the generated images as strings: either URLs or base64-encoded
    /// data, depending on the requested response format.
    ///
    /// See <https://platform.openai.com/docs/api-reference/images>.
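    ///
    /// A sketch of the call shape, marked `ignore` because the exact
    /// [images::ImageArguments] constructor lives in the `images` module and
    /// is assumed here to take the prompt text:
    ///
    /// ```ignore
    /// # use openai_rust2 as openai_rust;
    /// # tokio_test::block_on(async {
    /// # let api_key = "";
    /// let client = openai_rust::Client::new(api_key);
    /// // Hypothetical constructor taking the prompt.
    /// let args = openai_rust::images::ImageArguments::new("A cute cat dancing on Christmas eve");
    /// let images = client.create_image(args, None).await.unwrap();
    /// # })
    /// ```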
    pub async fn create_image(
        &self,
        args: images::ImageArguments,
        opt_url_path: Option<String>,
    ) -> Result<Vec<String>> {
        let mut url = self.base_url.clone();
        url.set_path(&opt_url_path.unwrap_or_else(|| String::from("/v1/images/generations")));

        let res = self
            .req_client
            .post(url)
            .bearer_auth(&self.key)
            .json(&args)
            .send()
            .await?;

        if res.status() == 200 {
            Ok(res
                .json::<images::ImageResponse>()
                .await?
                .data
                .iter()
                .map(|o| match o {
                    images::ImageObject::Url(s) => s.to_string(),
                    images::ImageObject::Base64JSON(s) => s.to_string(),
                })
                .collect())
        } else {
            Err(anyhow!(res.text().await?))
        }
    }
}