mini_openai/
lib.rs

1use std::env;
2
3#[cfg(all(feature = "reqwest", feature = "ureq"))]
4compile_error!("Features 'reqwest' and 'ureq' are mutually exclusive.");
5
6#[cfg(not(any(feature = "reqwest", feature = "ureq")))]
7compile_error!("One of the features 'reqwest' and 'ureq' must be enabled.");
8
9use serde::ser::{SerializeMap, SerializeSeq};
10#[cfg(feature = "ureq")]
11use ureq;
12
13#[cfg(feature = "reqwest")]
14use reqwest;
15
// Name of the environment variable holding the API key.
const OPENAI_API_KEY: &str = "OPENAI_API_KEY";
// Name of the environment variable holding the API base URI.
const OPENAI_API_BASE: &str = "OPENAI_API_BASE";
// Fallback base URI used when neither parameter nor environment provides one.
const DEFAULT_API_BASE: &str = "https://api.openai.com/v1";
19
/// Errors returned by this crate.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// The client was constructed with an invalid or incomplete configuration.
    #[error("The configuration contains errors: {0}")]
    BadConfigurationError(String),

    /// A request body could not be serialized to JSON.
    #[error("Failed to serialize response: {0}")]
    SerializationError(serde_json::Error),

    /// A response body could not be deserialized into the expected type.
    #[error("Failed to deserialize response: {0}")]
    DeserializationError(String),

    /// The HTTP request itself failed (connection, timeout, I/O, ...).
    #[error("Network error: {0}")]
    NetworkError(String),

    /// The server answered with a non-success status.
    #[error("API error: {0}")]
    ApiError(String),
}
37
/// Model used by `ChatCompletions::default()`.
pub const DEFAULT_CHAT_MODEL: &str = "gpt-4o-mini";
/// Model used by `Embeddings::default()`.
pub const DEFAULT_EMBEDDING_MODEL: &str = "text-embedding-ada-002";

/// Conventional values for the `Message::role` field.
pub const ROLE_SYSTEM: &str = "system";
pub const ROLE_USER: &str = "user";
pub const ROLE_ASSISTANT: &str = "assistant";
44
/// A single chat message, as sent to and received from the API.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Message {
    /// The text content of the message.
    pub content: String,
    /// The speaker role; see `ROLE_SYSTEM`, `ROLE_USER`, `ROLE_ASSISTANT`.
    pub role: String,
}
50
/// Requested output format for chat completions (the `response_format` option).
#[derive(Clone, Debug)]
pub enum ResponseFormat {
    /// Ask the model to emit a free-form JSON object.
    JsonObject,
    /// Ask the model to emit JSON conforming to the given schema value.
    JsonSchema(serde_json::Value),
}
56
57impl serde::Serialize for ResponseFormat {
58    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
59    where
60        S: serde::Serializer,
61    {
62        match self {
63            ResponseFormat::JsonObject => {
64                let mut map = serializer.serialize_map(Some(1))?;
65                map.serialize_entry("type", "json_object")?;
66                map.end()
67            }
68            ResponseFormat::JsonSchema(schema) => {
69                let mut map = serializer.serialize_map(Some(2))?;
70                map.serialize_entry("type", "json_schema")?;
71                map.serialize_entry("json_schema", schema)?;
72                map.end()
73            }
74        }
75    }
76}
77
/// Stop sequences for chat completions: either a single string or a list.
#[derive(Clone, Debug)]
pub enum Stop {
    /// A single stop sequence.
    String(String),
    /// Several stop sequences.
    Array(Vec<String>),
}
83
84impl serde::Serialize for Stop {
85    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
86    where
87        S: serde::Serializer,
88    {
89        match self {
90            Stop::String(string) => serializer.serialize_str(&string),
91            Stop::Array(strings) => {
92                let mut array = serializer.serialize_seq(Some(strings.len()))?;
93
94                for string in strings {
95                    array.serialize_element(string)?;
96                }
97
98                array.end()
99            }
100        }
101    }
102}
103
/// Serde `skip_serializing_if` helper: true when the flag is unset,
/// so `false` booleans are omitted from request bodies entirely.
fn is_false(value: &bool) -> bool {
    // Idiomatic negation instead of comparing against a bool literal.
    !*value
}
107
108// NOTE: As we're supporting non-OpenAI API implementations, we should only
109// send options in requests that are set. Some implementations don't like if
110// they see options they don't know, even if they're "null".
111
112/// Chat completions structure.
113///
114/// For reference, see: https://platform.openai.com/docs/api-reference/chat
115///
116/// To construct this structure easily use the default trait:
117///
118/// ```rust
119/// let request = mini_openai::ChatCompletions {
120///   messages: vec![
121///     mini_openai::Message{
122///         role: mini_openai::ROLE_SYSTEM.into(),
123///         content: "Who are you?".into()
124///     }
125///   ],
126///   ..Default::default()
127/// };
128/// ```
#[derive(Clone, Debug, serde::Serialize)]
pub struct ChatCompletions {
    /// The conversation history to complete.
    pub messages: Vec<Message>,
    /// Model name; defaults to `DEFAULT_CHAT_MODEL`.
    pub model: String,
    /// Ask the server to store the completion (omitted from JSON when false).
    #[serde(skip_serializing_if = "is_false")]
    pub store: bool,
    /// Arbitrary metadata attached to the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
    /// Token-id to bias mapping.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<serde_json::Value>,
    /// Request log-probabilities (omitted from JSON when false).
    #[serde(skip_serializing_if = "is_false")]
    pub logprobs: bool,
    /// How many top log-probabilities to return per token.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<usize>,
    /// Legacy completion-length limit.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,
    /// Newer completion-length limit.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<usize>,
    /// Number of choices to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<usize>,
    /// Presence penalty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// Desired output format; see `ResponseFormat`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    /// Sampling seed for reproducibility.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u32>,
    /// Service tier selector.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub service_tier: Option<String>,
    /// Stop sequence(s); see `Stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Stop>,
    /// Must be 'false': Only non-streaming is supported.
    pub stream: bool,
    // pub stream_options: Option<>,
    // pub tools: Option<>,
    // pub tool_choice: Option<>,
    // pub parallel_tool_calls: bool,
    /// End-user identifier forwarded to the API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}
168
169impl Default for ChatCompletions {
170    fn default() -> Self {
171        Self {
172            messages: Default::default(),
173            model: DEFAULT_CHAT_MODEL.into(),
174            store: false,
175            metadata: None,
176            logit_bias: None,
177            logprobs: false,
178            top_logprobs: None,
179            max_tokens: None,
180            max_completion_tokens: None,
181            n: None,
182            presence_penalty: None,
183            response_format: None,
184            seed: None,
185            service_tier: None,
186            stop: None,
187            stream: false,
188            user: None,
189        }
190    }
191}
192
/// One generated completion choice in a chat response.
#[derive(Clone, Debug, serde::Deserialize)]
pub struct Choice {
    /// Position of this choice in the response's choice list.
    pub index: usize,
    /// The generated message.
    pub message: Message,
    //pub logprobs: Option<Logprobs>,
    /// Why generation stopped, as reported by the API.
    pub finish_reason: String,
}
200
/// Response body of a chat completions request.
#[derive(Clone, Debug, serde::Deserialize)]
pub struct ChatCompletionsResponse {
    /// Server-assigned identifier of this completion.
    pub id: String,
    /// Object type tag returned by the API.
    pub object: String,
    /// Creation timestamp (seconds, as sent by the server).
    pub created: usize,
    /// The model that actually served the request.
    pub model: String,
    /// Generated choices; usually one unless `n` was set.
    pub choices: Vec<Choice>,
    //pub usage: Usage,
}
210
/// Embeddings input: a single string or a batch of strings.
#[derive(Clone, Debug)]
pub enum Input {
    /// A single input text.
    String(String),
    /// A batch of input texts.
    Array(Vec<String>),
}
216
217impl From<String> for Input {
218    fn from(value: String) -> Self {
219        Self::String(value)
220    }
221}
222
223impl From<&str> for Input {
224    fn from(value: &str) -> Self {
225        Self::String(value.to_string())
226    }
227}
228
229impl From<Vec<String>> for Input {
230    fn from(values: Vec<String>) -> Self {
231        Self::Array(values)
232    }
233}
234
235impl From<&[String]> for Input {
236    fn from(values: &[String]) -> Self {
237        Self::Array(values.to_vec())
238    }
239}
240
241impl serde::Serialize for Input {
242    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
243    where
244        S: serde::Serializer,
245    {
246        match self {
247            Input::String(string) => serializer.serialize_str(string),
248            Input::Array(array) => {
249                let mut seq = serializer.serialize_seq(Some(array.len()))?;
250                for s in array {
251                    seq.serialize_element(s)?;
252                }
253                seq.end()
254            }
255        }
256    }
257}
258
259/// Embeddings request structure.
260///
261/// You can easily construct the input using .into():
262///
263/// ```rust
264/// let embeddings = mini_openai::Embeddings {
265///     input: "Hello".into(),
266///     ..Default::default()
267/// };
268/// ```
269///
#[derive(Clone, Debug, serde::Serialize)]
pub struct Embeddings {
    /// Text(s) to embed; see `Input` and its `From` conversions.
    pub input: Input,
    /// Model name; defaults to `DEFAULT_EMBEDDING_MODEL`.
    pub model: String,
    //pub encoding_format: EncodingFormat,
    /// Requested embedding dimensionality, when the model supports it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dimensions: Option<usize>,
    /// End-user identifier forwarded to the API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}
280
281impl Default for Embeddings {
282    fn default() -> Self {
283        Self {
284            input: Input::String("".into()),
285            model: DEFAULT_EMBEDDING_MODEL.into(),
286            dimensions: None,
287            user: None,
288        }
289    }
290}
291
/// Response body of an embeddings request.
#[derive(Clone, Debug, serde::Deserialize)]
pub struct EmbeddingsResponse {
    /// One embedding per input text.
    pub data: Vec<Embedding>,
    /// The model that actually served the request.
    pub model: String,
    pub usage: Option<Usage>, // Not all implementations may return this
}
298
/// A single embedding vector.
#[derive(Clone, Debug, serde::Deserialize)]
pub struct Embedding {
    /// Position of the corresponding input in the request.
    pub index: u64,
    /// The embedding values.
    pub embedding: Vec<f32>,
}
304
/// Token accounting returned by the API.
#[derive(Clone, Debug, serde::Deserialize)]
pub struct Usage {
    /// Tokens consumed by the prompt/input.
    pub prompt_tokens: u32,
    /// Total tokens billed for the request.
    pub total_tokens: u32,
}
310
/// Blocking HTTP backend built on ureq.
#[cfg(feature = "ureq")]
struct ClientImpl {
    // Reusable agent (connection pooling).
    client: ureq::Agent,
    // Bearer token added per-request, since ureq has no default headers.
    token: Option<String>,
}
316
#[cfg(feature = "ureq")]
impl ClientImpl {
    /// Creates the blocking backend, storing the optional bearer token for
    /// later use in `do_request`.
    fn new(token: Option<String>) -> Result<ClientImpl, Error> {
        Ok(Self {
            client: ureq::Agent::new(),
            token,
        })
    }

    /// POSTs `body` as JSON to `url` and returns the raw response body.
    ///
    /// Returns `Error::NetworkError` on transport failures,
    /// `Error::ApiError` on non-2xx status codes.
    fn do_request(&self, url: String, body: String) -> Result<String, Error> {
        let mut request = self
            .client
            .post(&url)
            .set("Content-Type", "application/json");

        if let Some(token) = self.token.as_ref() {
            request = request.set("Authorization", &format!("Bearer {}", token));
        }

        let response = request
            .send_string(&body)
            .map_err(|e| Error::NetworkError(e.to_string()))?;

        // Accept any 2xx status: previously only exactly 200 was accepted,
        // which wrongly rejected other success codes such as 201 or 204.
        if !(200..300).contains(&response.status()) {
            return Err(Error::ApiError(format!(
                "{} {}",
                response.status(),
                response.status_text()
            )));
        }

        response
            .into_string()
            .map_err(|e| Error::NetworkError(e.to_string()))
    }
}
351
/// Async HTTP backend built on reqwest.
#[cfg(feature = "reqwest")]
struct ClientImpl {
    // Bearer token, when given, is baked into the client's default headers.
    client: reqwest::Client,
}
356
#[cfg(feature = "reqwest")]
impl ClientImpl {
    /// Creates the async backend. If a token is given, it is installed as a
    /// default `Authorization` header (marked sensitive so it is redacted
    /// from debug output).
    fn new(token: Option<String>) -> Result<ClientImpl, Error> {
        let mut headers = reqwest::header::HeaderMap::new();

        if let Some(token) = token {
            let mut value = reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token))
                .map_err(|e| Error::BadConfigurationError(e.to_string()))?;
            value.set_sensitive(true);
            headers.insert(reqwest::header::AUTHORIZATION, value);
        }

        let client = reqwest::ClientBuilder::new()
            .default_headers(headers)
            .build()
            .map_err(|e| Error::BadConfigurationError(e.to_string()))?;

        Ok(Self { client })
    }

    /// POSTs `body` as JSON to `url` and returns the raw response body.
    ///
    /// Returns `Error::NetworkError` on transport failures,
    /// `Error::ApiError` on non-success status codes.
    async fn do_request(&self, url: String, body: String) -> Result<String, Error> {
        let response = self
            .client
            .post(url)
            .header(reqwest::header::CONTENT_TYPE, "application/json")
            .body(body)
            .send()
            .await
            .map_err(|e| Error::NetworkError(e.to_string()))?;

        // Read the body before checking the status: `error_for_status()`
        // would drop the body, losing the server's error message (OpenAI
        // returns a JSON error document on failures).
        let status = response.status();
        let text = response
            .text()
            .await
            .map_err(|e| Error::NetworkError(e.to_string()))?;

        if !status.is_success() {
            return Err(Error::ApiError(format!("{}: {}", status, text)));
        }

        Ok(text)
    }
}
395
/// A minimal OpenAI-compatible API client.
///
/// Wraps the HTTP backend selected at compile time (reqwest or ureq) and
/// the base URI that every request path is appended to.
pub struct Client {
    inner: ClientImpl,
    base_uri: String,
}
400
401impl Client {
402    /// Creates a new `Client` instance.
403    ///
404    /// This function will first check for environment variables `OPENAI_API_BASE` and `OPENAI_API_KEY`.
405    /// If they are not set, it will use the provided `base_uri` and `token` parameters. If neither are set,
406    /// it will use the default API base URI.
407    ///
408    /// If a `token` is not provided and `base_uri` is set to the OpenAI API base URI, an error will be returned.
409    ///
410    /// # Arguments
411    ///
412    /// * `base_uri`: The base URI of the API, or `None` to use the environment variable or default.
413    /// * `token`: The API token, or `None` to use the environment variable.
414    ///
415    /// # Returns
416    ///
417    /// A `Result` containing the new `Client` instance, or an `Error` if the configuration is invalid.
418    pub fn new(base_uri: Option<String>, token: Option<String>) -> Result<Client, Error> {
419        let env_base_uri = env::var(OPENAI_API_BASE).unwrap_or_default();
420        let env_token = env::var(OPENAI_API_KEY).unwrap_or_default();
421
422        let base_uri = if env_base_uri.is_empty() {
423            if let Some(uri) = base_uri {
424                uri
425            } else {
426                DEFAULT_API_BASE.to_string()
427            }
428        } else {
429            env_base_uri
430        };
431
432        let token = if env_token.is_empty() {
433            token
434        } else {
435            Some(env_token)
436        };
437
438        Self::new_without_environment(base_uri, token)
439    }
440
441    /// Creates a new `Client` instance without checking environment variables.
442    ///
443    /// This function is used internally by `new` to create a client without checking for environment variables.
444    ///
445    /// # Arguments
446    ///
447    /// * `base_uri`: The base URI of the API.
448    /// * `token`: The API token, or `None` if not required.
449    ///
450    /// # Returns
451    ///
452    /// If `base_uri` is empty, an error will be returned.
453    /// If `base_uri` is set to the OpenAI API base URI and `token` is `None`, an error will be returned.
454    /// A `Result` containing the new `Client` instance, or an `Error` if the configuration is invalid.
455    pub fn new_without_environment(
456        base_uri: String,
457        token: Option<String>,
458    ) -> Result<Client, Error> {
459        if base_uri.is_empty() {
460            return Err(Error::BadConfigurationError("No base URI given".into()));
461        }
462
463        // Only check if there's a token if we're connecting to OpenAI.
464        // Custom endpoints may not require it, so don't enforce it for them.
465        if base_uri == DEFAULT_API_BASE && token.is_none() {
466            return Err(Error::BadConfigurationError("Missing api token".into()));
467        }
468
469        let inner = ClientImpl::new(token)?;
470        Ok(Self { inner, base_uri })
471    }
472
473    /// Creates a new `Client` instance from environment variables.
474    ///
475    /// This function will read the `OPENAI_API_BASE` and `OPENAI_API_KEY` environment variables and use them to create a client.
476    ///
477    /// # Returns
478    ///
479    /// A `Result` containing the new `Client` instance, or an `Error` if the environment variables are not set.
480    pub fn new_from_environment() -> Result<Client, Error> {
481        let env_base_uri =
482            env::var(OPENAI_API_BASE).map_err(|e| Error::BadConfigurationError(e.to_string()))?;
483        let env_token = env::var(OPENAI_API_KEY).unwrap_or_default();
484
485        let token = if env_token.is_empty() {
486            None
487        } else {
488            Some(env_token)
489        };
490
491        Self::new_without_environment(env_base_uri, token)
492    }
493
494    /// Sends a request to the OpenAI API to generate a completion for a chat conversation.
495    ///
496    /// This function takes a `ChatCompletions` struct as input, which defines the parameters of the completion request,
497    /// including the chat history, model to use, and desired response format.
498    ///
499    /// The function returns a `ChatCompletionsResponse` struct, which contains the generated completion.
500    ///
501    /// # Arguments
502    ///
503    /// * `request`: The `ChatCompletions` struct containing the request parameters.
504    ///
505    /// # Returns
506    ///
507    /// A `Result` containing the `ChatCompletionsResponse` struct, or an `Error` if the request fails.
508    ///
509    /// # Example
510    ///
511    /// ```rust,ignore
512    /// use mini_openai::{Client, ChatCompletions, Message, ROLE_USER};
513    ///
514    /// let client = Client::new(None, None).unwrap();
515    ///
516    /// // Create a new chat completion request
517    /// let mut request = ChatCompletions::default();
518    ///
519    /// // Add a message to the chat history
520    /// request.messages.push(Message {
521    ///     content: "Hello!".to_string(),
522    ///     role: ROLE_USER.to_string(),
523    /// });
524    ///
525    /// // Send the request to the OpenAI API
526    /// let response = client.chat_completions(&request).await.unwrap();
527    ///
528    /// // Print the generated completion
529    /// println!("{}", response.choices[0].message.content);
530    /// ```
531    #[cfg(feature = "reqwest")]
532    pub async fn chat_completions(
533        &self,
534        request: &ChatCompletions,
535    ) -> Result<ChatCompletionsResponse, Error> {
536        let url = format!("{}/chat/completions", self.base_uri);
537        let body = serde_json::to_string(request).map_err(Error::SerializationError)?;
538        let response = self.inner.do_request(url, body).await?;
539
540        serde_json::from_str(&response).map_err(|e| Error::DeserializationError(e.to_string()))
541    }
542
543    /// Sends a request to the OpenAI API to generate a completion for a chat conversation.
544    ///
545    /// This function takes a `ChatCompletions` struct as input, which defines the parameters of the completion request,
546    /// including the chat history, model to use, and desired response format.
547    ///
548    /// The function returns a `ChatCompletionsResponse` struct, which contains the generated completion.
549    ///
550    /// # Arguments
551    ///
552    /// * `request`: The `ChatCompletions` struct containing the request parameters.
553    ///
554    /// # Returns
555    ///
556    /// A `Result` containing the `ChatCompletionsResponse` struct, or an `Error` if the request fails.
557    ///
558    /// # Example
559    ///
560    /// ```rust
561    /// use mini_openai::{Client, ChatCompletions, Message, ROLE_USER};
562    ///
563    /// let client = Client::new(None, None).unwrap();
564    ///
565    /// // Create a new chat completion request
566    /// let mut request = ChatCompletions::default();
567    ///
568    /// // Add a message to the chat history
569    /// request.messages.push(Message {
570    ///     content: "Hello!".to_string(),
571    ///     role: ROLE_USER.to_string(),
572    /// });
573    ///
574    /// // Send the request to the OpenAI API
575    /// let response = client.chat_completions(&request).unwrap();
576    ///
577    /// // Print the generated completion
578    /// println!("{}", response.choices[0].message.content);
579    /// ```
580    #[cfg(feature = "ureq")]
581    pub fn chat_completions(
582        &self,
583        request: &ChatCompletions,
584    ) -> Result<ChatCompletionsResponse, Error> {
585        let url = format!("{}/chat/completions", self.base_uri);
586        let body = serde_json::to_string(request).map_err(Error::SerializationError)?;
587        let response = self.inner.do_request(url, body)?;
588
589        serde_json::from_str(&response).map_err(|e| Error::DeserializationError(e.to_string()))
590    }
591
592    /// Attempts to retrieve a chat completion and deserializes the response into a custom type.
593    ///
594    /// This function makes multiple attempts to retrieve a chat completion, up to a specified maximum number of tries.
595    /// If a successful response is received, it will attempt to deserialize the response into the desired type using
596    /// a provided converter function. If an error is caused anywhere the whole chain is retried up to *max_tries* times.
597    ///
598    /// If all attempts fail, the error that was last received will be returned.
599    ///
600    /// # Arguments
601    ///
602    /// * `request`: The chat completion request to send.
603    /// * `max_tries`: The maximum number of attempts to make.
604    /// * `converter`: A function that takes the content of the chat completion response and attempts to deserialize it
605    ///               into the desired type.
606    ///
607    /// # Returns
608    ///
609    /// * `Result<T, Error>`: The deserialized result if successful, or the final error if all attempts fail.
610    ///
611    /// # Errors
612    ///
613    /// * Any errors that occur during the network request itself.
614    /// * Any errors that occur during deserialization.
615    ///
616    /// # Example
617    ///
618    /// For the likely case that you want to parse JSON, you can use the `parse_json_lenient` helper function.
619    /// Here's how to use it:
620    ///
621    /// ```rust,ignore
622    /// #[derive(Debug, serde::Deserialize)]
623    /// struct Hello {
624    ///     hello: String,
625    /// }
626    ///
627    /// let client = mini_openai::Client::new(None, None).unwrap();
628    /// let request = mini_openai::ChatCompletions {
629    ///     messages: vec![
630    ///         mini_openai::Message {
631    ///             content: r#"Respond with {"hello": "world"}"#.into(),
632    ///             role: mini_openai::ROLE_SYSTEM.into(),
633    ///         }
634    ///     ],
635    ///     ..Default::default()
636    /// };
637    ///
638    /// let hello: Hello = client.chat_completions_into(&request, 3, mini_openai::parse_json_lenient).await.unwrap();
639    /// println!("Result: {:?}", hello);
640    /// ```
641    #[cfg(feature = "reqwest")]
642    pub async fn chat_completions_into<F, T, E>(
643        &self,
644        request: &ChatCompletions,
645        max_tries: usize,
646        converter: F,
647    ) -> Result<T, Error>
648    where
649        F: Fn(String) -> Result<T, E>,
650        E: ToString,
651    {
652        let mut error: Option<Error> = None;
653
654        for _ in 1..=max_tries {
655            match self.chat_completions(request).await {
656                Ok(mut response) => {
657                    let choice = response.choices.swap_remove(0);
658                    match converter(choice.message.content) {
659                        Ok(result) => return Ok(result),
660                        Err(e) => error = Some(Error::DeserializationError(e.to_string())),
661                    }
662                }
663                Err(e) => {
664                    error = Some(e);
665                }
666            }
667        }
668
669        Err(error.unwrap())
670    }
671
672    /// Attempts to retrieve a chat completion and deserializes the response into a custom type.
673    ///
674    /// This function makes multiple attempts to retrieve a chat completion, up to a specified maximum number of tries.
675    /// If a successful response is received, it will attempt to deserialize the response into the desired type using
676    /// a provided converter function. If an error is caused anywhere the whole chain is retried up to *max_tries* times.
677    ///
678    /// If all attempts fail, the error that was last received will be returned.
679    ///
680    /// # Arguments
681    ///
682    /// * `request`: The chat completion request to send.
683    /// * `max_tries`: The maximum number of attempts to make.
684    /// * `converter`: A function that takes the content of the chat completion response and attempts to deserialize it
685    ///               into the desired type.
686    ///
687    /// # Returns
688    ///
689    /// * `Result<T, Error>`: The deserialized result if successful, or the final error if all attempts fail.
690    ///
691    /// # Errors
692    ///
693    /// * Any errors that occur during the network request itself.
694    /// * Any errors that occur during deserialization.
695    ///
696    /// # Example
697    ///
698    /// For the likely case that you want to parse JSON, you can use the `parse_json_lenient` helper function.
699    /// Here's how to use it:
700    ///
701    /// ```rust
702    /// #[derive(Debug, serde::Deserialize)]
703    /// struct Hello {
704    ///     hello: String,
705    /// }
706    ///
707    /// let client = mini_openai::Client::new(None, None).unwrap();
708    /// let request = mini_openai::ChatCompletions {
709    ///     messages: vec![
710    ///         mini_openai::Message {
711    ///             content: r#"Respond with {"hello": "world"}"#.into(),
712    ///             role: mini_openai::ROLE_SYSTEM.into(),
713    ///         }
714    ///     ],
715    ///     ..Default::default()
716    /// };
717    ///
718    /// let hello: Hello = client.chat_completions_into(&request, 3, mini_openai::parse_json_lenient).unwrap();
719    /// println!("Result: {:?}", hello);
720    /// ```
721    #[cfg(feature = "ureq")]
722    pub fn chat_completions_into<F, T, E>(
723        &self,
724        request: &ChatCompletions,
725        max_tries: usize,
726        converter: F,
727    ) -> Result<T, Error>
728    where
729        F: Fn(String) -> Result<T, E>,
730        E: ToString,
731    {
732        let mut error: Option<Error> = None;
733
734        for _ in 1..=max_tries {
735            match self.chat_completions(request) {
736                Ok(mut response) => {
737                    let choice = response.choices.swap_remove(0);
738                    match converter(choice.message.content) {
739                        Ok(result) => return Ok(result),
740                        Err(e) => error = Some(Error::DeserializationError(e.to_string())),
741                    }
742                }
743                Err(e) => {
744                    error = Some(e);
745                }
746            }
747        }
748
749        Err(error.unwrap())
750    }
751
752    /// Sends a request to the OpenAI API to generate embeddings of text.
753    ///
754    /// This function takes a `Embeddings` struct as input..
755    ///
756    /// The function returns a `EmbeddingsResponse` struct, which contains the generated embeddings.
757    ///
758    /// # Arguments
759    ///
760    /// * `request`: The `Embeddings` struct containing the request parameters.
761    ///
762    /// # Returns
763    ///
764    /// A `Result` containing the `EmbeddingsResponse` struct, or an `Error` if the request fails.
765    ///
766    /// # Example
767    ///
768    /// ```rust,ignore
769    /// use mini_openai::{Client, Embeddings, Message, ROLE_USER};
770    ///
771    /// let client = Client::new(None, None).unwrap();
772    ///
773    /// // Create a new chat completion request
774    /// let request = Embeddings { input: "Hello".into(), ..Default::default() };
775    ///
776    /// // Send the request to the OpenAI API
777    /// let response = client.embeddings(&request).await.unwrap();
778    ///
779    /// // Print the generated completion
780    /// println!("{}", response.data[0].embedding);
781    /// ```
782    #[cfg(feature = "reqwest")]
783    pub async fn embeddings(&self, request: &Embeddings) -> Result<EmbeddingsResponse, Error> {
784        let url = format!("{}/embeddings", self.base_uri);
785        let body = serde_json::to_string(request).map_err(Error::SerializationError)?;
786        let response = self.inner.do_request(url, body).await?;
787
788        serde_json::from_str(&response).map_err(|e| Error::DeserializationError(e.to_string()))
789    }
790
791    /// Sends a request to the OpenAI API to generate embeddings of text.
792    ///
793    /// This function takes a `Embeddings` struct as input..
794    ///
795    /// The function returns a `EmbeddingsResponse` struct, which contains the generated embeddings.
796    ///
797    /// # Arguments
798    ///
799    /// * `request`: The `Embeddings` struct containing the request parameters.
800    ///
801    /// # Returns
802    ///
803    /// A `Result` containing the `EmbeddingsResponse` struct, or an `Error` if the request fails.
804    ///
805    /// # Example
806    ///
807    /// ```rust
808    /// use mini_openai::{Client, Embeddings, Message, ROLE_USER};
809    ///
810    /// let client = Client::new(None, None).unwrap();
811    ///
812    /// // Create a new chat completion request
813    /// let request = Embeddings { input: "Hello".into(), ..Default::default() };
814    ///
815    /// // Send the request to the OpenAI API
816    /// let response = client.embeddings(&request).unwrap();
817    ///
818    /// // Print the generated completion
819    /// println!("{:?}", response.data[0].embedding);
820    /// ```
821    #[cfg(feature = "ureq")]
822    pub fn embeddings(&self, request: &Embeddings) -> Result<EmbeddingsResponse, Error> {
823        let url = format!("{}/embeddings", self.base_uri);
824        let body = serde_json::to_string(request).map_err(Error::SerializationError)?;
825        let response = self.inner.do_request(url, body)?;
826
827        serde_json::from_str(&response).map_err(|e| Error::DeserializationError(e.to_string()))
828    }
829}
830
831/// Helper function to be used with Client::chat_completions_into().
832///
833/// Pass this function to chat_completions_into() to let it parse a JSON
834/// document. This function allows for some blabber emitted by the LLM,
835/// making things like explanations or markdown-style fences a non-issue.
836pub fn parse_json_lenient<T>(text: String) -> Result<T, String>
837where
838    T: serde::de::DeserializeOwned,
839{
840    let found = (text.find('{'), text.rfind('}'));
841    if let (Some(begin), Some(end)) = found {
842        let json = &text[begin..=end];
843        serde_json::from_str(json).map_err(|e| e.to_string())
844    } else {
845        Err("The text doesn't contain a JSON object".into())
846    }
847}
848
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: these are integration tests; they require a reachable endpoint
    // configured via OPENAI_API_BASE / OPENAI_API_KEY.

    #[cfg(feature = "ureq")]
    #[test]
    fn test_chat_completions() -> Result<(), Error> {
        let client = Client::new(None, None)?;
        let request = ChatCompletions {
            messages: vec![Message {
                role: ROLE_SYSTEM.into(),
                content: "Just say OK.".into(),
            }],
            ..Default::default()
        };

        let response: ChatCompletionsResponse = client.chat_completions(&request)?;

        assert_eq!(response.choices.len(), 1);
        assert!(response.choices[0].message.content.contains("OK"));

        Ok(())
    }

    #[cfg(feature = "reqwest")]
    #[tokio::test]
    async fn test_chat_completions() -> Result<(), Error> {
        let client = Client::new(None, None)?;
        let request = ChatCompletions {
            messages: vec![Message {
                role: ROLE_SYSTEM.into(),
                content: "Just say OK.".into(),
            }],
            ..Default::default()
        };

        let response: ChatCompletionsResponse = client.chat_completions(&request).await?;

        assert_eq!(response.choices.len(), 1);
        assert!(response.choices[0].message.content.contains("OK"));

        Ok(())
    }

    #[cfg(feature = "ureq")]
    #[test]
    fn test_chat_completions_into() -> Result<(), Error> {
        #[derive(serde::Deserialize)]
        struct Test {
            hello: String,
        }

        let client = Client::new(None, None)?;
        let request = ChatCompletions {
            messages: vec![Message {
                role: ROLE_SYSTEM.into(),
                content: r#"Respond with this JSON: {"hello": "a word of your choosing"}."#.into(),
            }],
            ..Default::default()
        };

        let response: Test = client.chat_completions_into(&request, 3, parse_json_lenient)?;
        assert!(!response.hello.is_empty());

        Ok(())
    }

    #[cfg(feature = "reqwest")]
    #[tokio::test]
    async fn test_chat_completions_into() -> Result<(), Error> {
        #[derive(serde::Deserialize)]
        struct Test {
            hello: String,
        }

        let client = Client::new(None, None)?;
        let request = ChatCompletions {
            messages: vec![Message {
                role: ROLE_SYSTEM.into(),
                content: r#"Respond with this JSON: {"hello": "a word of your choosing"}."#.into(),
            }],
            ..Default::default()
        };

        let response: Test = client
            .chat_completions_into(&request, 3, parse_json_lenient)
            .await?;
        assert!(!response.hello.is_empty());

        Ok(())
    }

    #[cfg(feature = "ureq")]
    #[test]
    fn test_embeddings() -> Result<(), Error> {
        let client = Client::new(None, None)?;
        let request = Embeddings {
            input: "Hello".into(),
            ..Default::default()
        };

        let response: EmbeddingsResponse = client.embeddings(&request)?;

        assert_eq!(response.data.len(), 1);
        assert!(!response.data[0].embedding.is_empty());

        Ok(())
    }

    #[cfg(feature = "reqwest")]
    #[tokio::test]
    async fn test_embeddings() -> Result<(), Error> {
        let client = Client::new(None, None)?;
        let request = Embeddings {
            input: "Hello".into(),
            ..Default::default()
        };

        let response: EmbeddingsResponse = client.embeddings(&request).await?;

        assert_eq!(response.data.len(), 1);
        assert!(!response.data[0].embedding.is_empty());

        Ok(())
    }

    #[test]
    fn test_parse_json_lenient() -> Result<(), String> {
        #[derive(serde::Deserialize)]
        struct Test {
            hello: String,
        }

        let test: Test = parse_json_lenient(r#"Here's your JSON: {"hello": "world"}"#.into())?;
        assert_eq!(test.hello, "world");

        let test: Result<Test, String> =
            parse_json_lenient(r#"JSON is a great choice for your request!"#.into());
        assert!(test.is_err());

        Ok(())
    }
}