openai_rs/
client.rs

1use hyper::Client as HyperClient;
2use hyper::client::HttpConnector;
3use hyper_openssl::HttpsConnector;
4use crate::endpoints::{Response, ResponseError};
5use crate::endpoints::request::Endpoint;
6
7pub(crate) type HttpsHyperClient = HyperClient<HttpsConnector<HttpConnector>>;
8
9#[derive(Debug)]
10pub struct Client {
11    pub(crate) api_key: String,
12    pub(crate) https: HttpsHyperClient,
13}
14
15impl Client {
16    /// Returns a new response from the OpenAI API.
17    ///
18    /// # Arguments
19    ///
20    /// * `engine_id` - The engine id to use. Due to few endpoints this can be optional.
21    /// * `model` - The model to use. Each Model in the endpoints module is a corresponding model.
22    ///
23    /// # Example
24    ///
25    /// ```
26    /// use std::borrow::Cow;
27    /// use openai_rs::client::Client;
28    /// use openai_rs::endpoints::edits::Edit;
29    /// use openai_rs::endpoints::{Response, ResponseError};
30    /// use openai_rs::openai;
31    ///
32    /// // Create the Client with your API key.
33    /// let client: Client = openai::new("api_key");
34    ///
35    /// // Create the Edit struct with the input and instruction.
36    /// let edit = Edit {
37    ///      input: Cow::Borrowed("What day of the wek is it?"),
38    ///      instruction: Cow::Borrowed("Fix the spelling mistakes"),
39    ///      ..Default::default()
40    ///  };
41    ///
42    /// // Send the request to the OpenAI API.
43    /// let response: Result<Response, ResponseError> = client.create(
44    ///     Some("text-davinci-edit-001"), &edit
45    /// ).await;
46    /// ```
47    pub async fn create<T>(
48        &self,
49        engine_id: Option<&str>,
50        model: &T
51    ) -> Result<Response, ResponseError>
52        where T: Endpoint {
53        match self.https.request(model.request(&*self.api_key, engine_id)).await {
54            Ok(response) => {
55                if response.status().is_success() {
56                    let body = hyper::body::to_bytes(response.into_body()).await?;
57                    let deserialized = serde_json::from_slice(&body)
58                        .map_err(ResponseError::from)?;
59                    trace!("Requesting: {:#?}", deserialized);
60
61                    Ok(deserialized)
62                } else {
63                    Err(ResponseError::ErrorCode(response.status()))
64                }
65            },
66            Err(error) => Err(error.into())
67        }
68    }
69}