1use std::env;
2
3#[cfg(all(feature = "reqwest", feature = "ureq"))]
4compile_error!("Features 'reqwest' and 'ureq' are mutually exclusive.");
5
6#[cfg(not(any(feature = "reqwest", feature = "ureq")))]
7compile_error!("One of the features 'reqwest' and 'ureq' must be enabled.");
8
9use serde::ser::{SerializeMap, SerializeSeq};
10#[cfg(feature = "ureq")]
11use ureq;
12
13#[cfg(feature = "reqwest")]
14use reqwest;
15
/// Base URI used when neither the caller nor `OPENAI_API_BASE` supplies one.
const DEFAULT_API_BASE: &str = "https://api.openai.com/v1";
/// Name of the environment variable holding the API token.
const OPENAI_API_KEY: &str = "OPENAI_API_KEY";
/// Name of the environment variable holding an alternative API base URI.
const OPENAI_API_BASE: &str = "OPENAI_API_BASE";
19
20#[derive(thiserror::Error, Debug)]
21pub enum Error {
22 #[error("The configuration contains errors: {0}")]
23 BadConfigurationError(String),
24
25 #[error("Failed to serialize response: {0}")]
26 SerializationError(serde_json::Error),
27
28 #[error("Failed to deserialize response: {0}")]
29 DeserializationError(String),
30
31 #[error("Network error: {0}")]
32 NetworkError(String),
33
34 #[error("API error: {0}")]
35 ApiError(String),
36}
37
/// Role string for system messages.
pub const ROLE_SYSTEM: &str = "system";
/// Role string for messages written by the user.
pub const ROLE_USER: &str = "user";
/// Role string for messages produced by the model.
pub const ROLE_ASSISTANT: &str = "assistant";

/// Model used by `ChatCompletions::default()`.
pub const DEFAULT_CHAT_MODEL: &str = "gpt-4o-mini";
/// Model used by `Embeddings::default()`.
pub const DEFAULT_EMBEDDING_MODEL: &str = "text-embedding-ada-002";
44
45#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
46pub struct Message {
47 pub content: String,
48 pub role: String,
49}
50
51#[derive(Clone, Debug)]
52pub enum ResponseFormat {
53 JsonObject,
54 JsonSchema(serde_json::Value),
55}
56
57impl serde::Serialize for ResponseFormat {
58 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
59 where
60 S: serde::Serializer,
61 {
62 match self {
63 ResponseFormat::JsonObject => {
64 let mut map = serializer.serialize_map(Some(1))?;
65 map.serialize_entry("type", "json_object")?;
66 map.end()
67 }
68 ResponseFormat::JsonSchema(schema) => {
69 let mut map = serializer.serialize_map(Some(2))?;
70 map.serialize_entry("type", "json_schema")?;
71 map.serialize_entry("json_schema", schema)?;
72 map.end()
73 }
74 }
75 }
76}
77
78#[derive(Clone, Debug)]
79pub enum Stop {
80 String(String),
81 Array(Vec<String>),
82}
83
84impl serde::Serialize for Stop {
85 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
86 where
87 S: serde::Serializer,
88 {
89 match self {
90 Stop::String(string) => serializer.serialize_str(&string),
91 Stop::Array(strings) => {
92 let mut array = serializer.serialize_seq(Some(strings.len()))?;
93
94 for string in strings {
95 array.serialize_element(string)?;
96 }
97
98 array.end()
99 }
100 }
101 }
102}
103
/// `skip_serializing_if` predicate: omit boolean fields that are `false`.
///
/// Serde passes the field by reference, hence the `&bool` parameter.
fn is_false(value: &bool) -> bool {
    !*value
}
107
108#[derive(Clone, Debug, serde::Serialize)]
130pub struct ChatCompletions {
131 pub messages: Vec<Message>,
132 pub model: String,
133 #[serde(skip_serializing_if = "is_false")]
134 pub store: bool,
135 #[serde(skip_serializing_if = "Option::is_none")]
136 pub metadata: Option<serde_json::Value>,
137 #[serde(skip_serializing_if = "Option::is_none")]
138 pub logit_bias: Option<serde_json::Value>,
139 #[serde(skip_serializing_if = "is_false")]
140 pub logprobs: bool,
141 #[serde(skip_serializing_if = "Option::is_none")]
142 pub top_logprobs: Option<usize>,
143 #[serde(skip_serializing_if = "Option::is_none")]
144 pub max_tokens: Option<usize>,
145 #[serde(skip_serializing_if = "Option::is_none")]
146 pub max_completion_tokens: Option<usize>,
147 #[serde(skip_serializing_if = "Option::is_none")]
148 pub n: Option<usize>,
149 #[serde(skip_serializing_if = "Option::is_none")]
150 pub presence_penalty: Option<f32>,
151 #[serde(skip_serializing_if = "Option::is_none")]
152 pub response_format: Option<ResponseFormat>,
153 #[serde(skip_serializing_if = "Option::is_none")]
154 pub seed: Option<u32>,
155 #[serde(skip_serializing_if = "Option::is_none")]
156 pub service_tier: Option<String>,
157 #[serde(skip_serializing_if = "Option::is_none")]
158 pub stop: Option<Stop>,
159 pub stream: bool,
161 #[serde(skip_serializing_if = "Option::is_none")]
166 pub user: Option<String>,
167}
168
169impl Default for ChatCompletions {
170 fn default() -> Self {
171 Self {
172 messages: Default::default(),
173 model: DEFAULT_CHAT_MODEL.into(),
174 store: false,
175 metadata: None,
176 logit_bias: None,
177 logprobs: false,
178 top_logprobs: None,
179 max_tokens: None,
180 max_completion_tokens: None,
181 n: None,
182 presence_penalty: None,
183 response_format: None,
184 seed: None,
185 service_tier: None,
186 stop: None,
187 stream: false,
188 user: None,
189 }
190 }
191}
192
193#[derive(Clone, Debug, serde::Deserialize)]
194pub struct Choice {
195 pub index: usize,
196 pub message: Message,
197 pub finish_reason: String,
199}
200
201#[derive(Clone, Debug, serde::Deserialize)]
202pub struct ChatCompletionsResponse {
203 pub id: String,
204 pub object: String,
205 pub created: usize,
206 pub model: String,
207 pub choices: Vec<Choice>,
208 }
210
211#[derive(Clone, Debug)]
212pub enum Input {
213 String(String),
214 Array(Vec<String>),
215}
216
217impl From<String> for Input {
218 fn from(value: String) -> Self {
219 Self::String(value)
220 }
221}
222
223impl From<&str> for Input {
224 fn from(value: &str) -> Self {
225 Self::String(value.to_string())
226 }
227}
228
229impl From<Vec<String>> for Input {
230 fn from(values: Vec<String>) -> Self {
231 Self::Array(values)
232 }
233}
234
235impl From<&[String]> for Input {
236 fn from(values: &[String]) -> Self {
237 Self::Array(values.to_vec())
238 }
239}
240
241impl serde::Serialize for Input {
242 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
243 where
244 S: serde::Serializer,
245 {
246 match self {
247 Input::String(string) => serializer.serialize_str(string),
248 Input::Array(array) => {
249 let mut seq = serializer.serialize_seq(Some(array.len()))?;
250 for s in array {
251 seq.serialize_element(s)?;
252 }
253 seq.end()
254 }
255 }
256 }
257}
258
259#[derive(Clone, Debug, serde::Serialize)]
271pub struct Embeddings {
272 pub input: Input,
273 pub model: String,
274 #[serde(skip_serializing_if = "Option::is_none")]
276 pub dimensions: Option<usize>,
277 #[serde(skip_serializing_if = "Option::is_none")]
278 pub user: Option<String>,
279}
280
281impl Default for Embeddings {
282 fn default() -> Self {
283 Self {
284 input: Input::String("".into()),
285 model: DEFAULT_EMBEDDING_MODEL.into(),
286 dimensions: None,
287 user: None,
288 }
289 }
290}
291
292#[derive(Clone, Debug, serde::Deserialize)]
293pub struct EmbeddingsResponse {
294 pub data: Vec<Embedding>,
295 pub model: String,
296 pub usage: Option<Usage>, }
298
299#[derive(Clone, Debug, serde::Deserialize)]
300pub struct Embedding {
301 pub index: u64,
302 pub embedding: Vec<f32>,
303}
304
305#[derive(Clone, Debug, serde::Deserialize)]
306pub struct Usage {
307 pub prompt_tokens: u32,
308 pub total_tokens: u32,
309}
310
/// Blocking HTTP backend built on `ureq`, compiled with the `ureq` feature.
#[cfg(feature = "ureq")]
struct ClientImpl {
    /// Agent used for every request.
    client: ureq::Agent,
    /// Token sent as `Authorization: Bearer <token>` on each request, if any.
    token: Option<String>,
}

#[cfg(feature = "ureq")]
impl ClientImpl {
    /// Creates the backend.
    ///
    /// This never fails; the `Result` mirrors the `reqwest` backend's constructor.
    fn new(token: Option<String>) -> Result<ClientImpl, Error> {
        Ok(Self {
            client: ureq::Agent::new(),
            token,
        })
    }

    /// POSTs `body` as JSON to `url` and returns the response body as text.
    ///
    /// # Errors
    ///
    /// - `Error::ApiError` when the server answers with any status other than 200.
    /// - `Error::NetworkError` for transport failures or an unreadable body.
    fn do_request(&self, url: String, body: String) -> Result<String, Error> {
        let mut request = self
            .client
            .post(&url)
            .set("Content-Type", "application/json");

        if let Some(token) = self.token.as_ref() {
            request = request.set("Authorization", &format!("Bearer {}", token));
        }

        let response = match request.send_string(&body) {
            Ok(response) => response,
            // ureq reports 4xx/5xx statuses as `Error::Status`. Classify them as
            // API errors, like the reqwest backend's `error_for_status` does,
            // instead of as network failures.
            Err(ureq::Error::Status(code, response)) => {
                return Err(Error::ApiError(format!(
                    "{} {}",
                    code,
                    response.status_text()
                )));
            }
            Err(e) => return Err(Error::NetworkError(e.to_string())),
        };

        if response.status() != 200 {
            let text = format!("{} {}", response.status(), response.status_text());
            return Err(Error::ApiError(text));
        }

        response
            .into_string()
            .map_err(|e| Error::NetworkError(e.to_string()))
    }
}
351
/// Asynchronous HTTP backend built on `reqwest`, compiled with the `reqwest` feature.
#[cfg(feature = "reqwest")]
struct ClientImpl {
    /// Client whose default headers carry the `Authorization` header when a token was given.
    client: reqwest::Client,
}

#[cfg(feature = "reqwest")]
impl ClientImpl {
    /// Builds a `reqwest::Client` that sends `Authorization: Bearer <token>`
    /// on every request when `token` is present. The header value is marked
    /// sensitive with `set_sensitive`.
    ///
    /// # Errors
    ///
    /// `Error::BadConfigurationError` if the token is not a valid header value
    /// or the client cannot be built.
    fn new(token: Option<String>) -> Result<ClientImpl, Error> {
        let mut default_headers = reqwest::header::HeaderMap::new();

        if let Some(secret) = token {
            let bearer = format!("Bearer {}", secret);
            let mut authorization = reqwest::header::HeaderValue::from_str(&bearer)
                .map_err(|e| Error::BadConfigurationError(e.to_string()))?;
            authorization.set_sensitive(true);
            default_headers.insert(reqwest::header::AUTHORIZATION, authorization);
        }

        reqwest::Client::builder()
            .default_headers(default_headers)
            .build()
            .map(|client| ClientImpl { client })
            .map_err(|e| Error::BadConfigurationError(e.to_string()))
    }

    /// POSTs `body` as JSON to `url` and returns the response body as text.
    ///
    /// # Errors
    ///
    /// - `Error::NetworkError` when sending fails or the body cannot be read.
    /// - `Error::ApiError` when `error_for_status` rejects the status (4xx/5xx).
    async fn do_request(&self, url: String, body: String) -> Result<String, Error> {
        let sent = self
            .client
            .post(url)
            .header(reqwest::header::CONTENT_TYPE, "application/json")
            .body(body)
            .send()
            .await;
        let response = sent.map_err(|e| Error::NetworkError(e.to_string()))?;

        let accepted = response
            .error_for_status()
            .map_err(|e| Error::ApiError(e.to_string()))?;

        accepted
            .text()
            .await
            .map_err(|e| Error::NetworkError(e.to_string()))
    }
}
395
396pub struct Client {
397 inner: ClientImpl,
398 base_uri: String,
399}
400
401impl Client {
402 pub fn new(base_uri: Option<String>, token: Option<String>) -> Result<Client, Error> {
419 let env_base_uri = env::var(OPENAI_API_BASE).unwrap_or_default();
420 let env_token = env::var(OPENAI_API_KEY).unwrap_or_default();
421
422 let base_uri = if env_base_uri.is_empty() {
423 if let Some(uri) = base_uri {
424 uri
425 } else {
426 DEFAULT_API_BASE.to_string()
427 }
428 } else {
429 env_base_uri
430 };
431
432 let token = if env_token.is_empty() {
433 token
434 } else {
435 Some(env_token)
436 };
437
438 Self::new_without_environment(base_uri, token)
439 }
440
441 pub fn new_without_environment(
456 base_uri: String,
457 token: Option<String>,
458 ) -> Result<Client, Error> {
459 if base_uri.is_empty() {
460 return Err(Error::BadConfigurationError("No base URI given".into()));
461 }
462
463 if base_uri == DEFAULT_API_BASE && token.is_none() {
466 return Err(Error::BadConfigurationError("Missing api token".into()));
467 }
468
469 let inner = ClientImpl::new(token)?;
470 Ok(Self { inner, base_uri })
471 }
472
473 pub fn new_from_environment() -> Result<Client, Error> {
481 let env_base_uri =
482 env::var(OPENAI_API_BASE).map_err(|e| Error::BadConfigurationError(e.to_string()))?;
483 let env_token = env::var(OPENAI_API_KEY).unwrap_or_default();
484
485 let token = if env_token.is_empty() {
486 None
487 } else {
488 Some(env_token)
489 };
490
491 Self::new_without_environment(env_base_uri, token)
492 }
493
494 #[cfg(feature = "reqwest")]
532 pub async fn chat_completions(
533 &self,
534 request: &ChatCompletions,
535 ) -> Result<ChatCompletionsResponse, Error> {
536 let url = format!("{}/chat/completions", self.base_uri);
537 let body = serde_json::to_string(request).map_err(Error::SerializationError)?;
538 let response = self.inner.do_request(url, body).await?;
539
540 serde_json::from_str(&response).map_err(|e| Error::DeserializationError(e.to_string()))
541 }
542
543 #[cfg(feature = "ureq")]
581 pub fn chat_completions(
582 &self,
583 request: &ChatCompletions,
584 ) -> Result<ChatCompletionsResponse, Error> {
585 let url = format!("{}/chat/completions", self.base_uri);
586 let body = serde_json::to_string(request).map_err(Error::SerializationError)?;
587 let response = self.inner.do_request(url, body)?;
588
589 serde_json::from_str(&response).map_err(|e| Error::DeserializationError(e.to_string()))
590 }
591
592 #[cfg(feature = "reqwest")]
642 pub async fn chat_completions_into<F, T, E>(
643 &self,
644 request: &ChatCompletions,
645 max_tries: usize,
646 converter: F,
647 ) -> Result<T, Error>
648 where
649 F: Fn(String) -> Result<T, E>,
650 E: ToString,
651 {
652 let mut error: Option<Error> = None;
653
654 for _ in 1..=max_tries {
655 match self.chat_completions(request).await {
656 Ok(mut response) => {
657 let choice = response.choices.swap_remove(0);
658 match converter(choice.message.content) {
659 Ok(result) => return Ok(result),
660 Err(e) => error = Some(Error::DeserializationError(e.to_string())),
661 }
662 }
663 Err(e) => {
664 error = Some(e);
665 }
666 }
667 }
668
669 Err(error.unwrap())
670 }
671
672 #[cfg(feature = "ureq")]
722 pub fn chat_completions_into<F, T, E>(
723 &self,
724 request: &ChatCompletions,
725 max_tries: usize,
726 converter: F,
727 ) -> Result<T, Error>
728 where
729 F: Fn(String) -> Result<T, E>,
730 E: ToString,
731 {
732 let mut error: Option<Error> = None;
733
734 for _ in 1..=max_tries {
735 match self.chat_completions(request) {
736 Ok(mut response) => {
737 let choice = response.choices.swap_remove(0);
738 match converter(choice.message.content) {
739 Ok(result) => return Ok(result),
740 Err(e) => error = Some(Error::DeserializationError(e.to_string())),
741 }
742 }
743 Err(e) => {
744 error = Some(e);
745 }
746 }
747 }
748
749 Err(error.unwrap())
750 }
751
752 #[cfg(feature = "reqwest")]
783 pub async fn embeddings(&self, request: &Embeddings) -> Result<EmbeddingsResponse, Error> {
784 let url = format!("{}/embeddings", self.base_uri);
785 let body = serde_json::to_string(request).map_err(Error::SerializationError)?;
786 let response = self.inner.do_request(url, body).await?;
787
788 serde_json::from_str(&response).map_err(|e| Error::DeserializationError(e.to_string()))
789 }
790
791 #[cfg(feature = "ureq")]
822 pub fn embeddings(&self, request: &Embeddings) -> Result<EmbeddingsResponse, Error> {
823 let url = format!("{}/embeddings", self.base_uri);
824 let body = serde_json::to_string(request).map_err(Error::SerializationError)?;
825 let response = self.inner.do_request(url, body)?;
826
827 serde_json::from_str(&response).map_err(|e| Error::DeserializationError(e.to_string()))
828 }
829}
830
831pub fn parse_json_lenient<T>(text: String) -> Result<T, String>
837where
838 T: serde::de::DeserializeOwned,
839{
840 let found = (text.find('{'), text.rfind('}'));
841 if let (Some(begin), Some(end)) = found {
842 let json = &text[begin..=end];
843 serde_json::from_str(json).map_err(|e| e.to_string())
844 } else {
845 Err("The text doesn't contain a JSON object".into())
846 }
847}
848
#[cfg(test)]
mod tests {
    use super::*;

    // Offline checks pinning the JSON wire format of request/response types.

    #[test]
    fn test_serialize_stop() -> Result<(), serde_json::Error> {
        assert_eq!(serde_json::to_string(&Stop::String("END".into()))?, r#""END""#);
        assert_eq!(
            serde_json::to_string(&Stop::Array(vec!["a".into(), "b".into()]))?,
            r#"["a","b"]"#
        );
        Ok(())
    }

    #[test]
    fn test_serialize_input() -> Result<(), serde_json::Error> {
        assert_eq!(serde_json::to_string(&Input::from("Hello"))?, r#""Hello""#);
        assert_eq!(
            serde_json::to_string(&Input::from(vec!["a".to_string(), "b".to_string()]))?,
            r#"["a","b"]"#
        );
        Ok(())
    }

    #[test]
    fn test_serialize_response_format() -> Result<(), serde_json::Error> {
        assert_eq!(
            serde_json::to_string(&ResponseFormat::JsonObject)?,
            r#"{"type":"json_object"}"#
        );
        let schema = ResponseFormat::JsonSchema(serde_json::json!({ "name": "test" }));
        assert_eq!(
            serde_json::to_string(&schema)?,
            r#"{"type":"json_schema","json_schema":{"name":"test"}}"#
        );
        Ok(())
    }

    #[test]
    fn test_serialize_default_requests() -> Result<(), serde_json::Error> {
        assert_eq!(
            serde_json::to_string(&ChatCompletions::default())?,
            r#"{"messages":[],"model":"gpt-4o-mini","stream":false}"#
        );
        assert_eq!(
            serde_json::to_string(&Embeddings::default())?,
            r#"{"input":"","model":"text-embedding-ada-002"}"#
        );
        Ok(())
    }

    #[test]
    fn test_deserialize_message() -> Result<(), serde_json::Error> {
        let message: Message = serde_json::from_str(r#"{"role":"assistant","content":"hi"}"#)?;
        assert_eq!(message.role, ROLE_ASSISTANT);
        assert_eq!(message.content, "hi");
        Ok(())
    }

    // Live API tests: `Client::new(None, None)` fails without OPENAI_API_KEY
    // unless OPENAI_API_BASE points at a server other than the default one.

    #[cfg(feature = "ureq")]
    #[test]
    fn test_chat_completions() -> Result<(), Error> {
        let client = Client::new(None, None)?;
        let request = ChatCompletions {
            messages: vec![Message {
                role: ROLE_SYSTEM.into(),
                content: "Just say OK.".into(),
            }],
            ..Default::default()
        };

        let response: ChatCompletionsResponse = client.chat_completions(&request)?;

        assert_eq!(response.choices.len(), 1);
        assert!(response.choices[0].message.content.contains("OK"));

        Ok(())
    }

    #[cfg(feature = "reqwest")]
    #[tokio::test]
    async fn test_chat_completions() -> Result<(), Error> {
        let client = Client::new(None, None)?;
        let request = ChatCompletions {
            messages: vec![Message {
                role: ROLE_SYSTEM.into(),
                content: "Just say OK.".into(),
            }],
            ..Default::default()
        };

        let response: ChatCompletionsResponse = client.chat_completions(&request).await?;

        assert_eq!(response.choices.len(), 1);
        assert!(response.choices[0].message.content.contains("OK"));

        Ok(())
    }

    #[cfg(feature = "ureq")]
    #[test]
    fn test_chat_completions_into() -> Result<(), Error> {
        #[derive(serde::Deserialize)]
        struct Test {
            hello: String,
        }

        let client = Client::new(None, None)?;
        let request = ChatCompletions {
            messages: vec![Message {
                role: ROLE_SYSTEM.into(),
                content: r#"Respond with this JSON: {"hello": "a word of your choosing"}."#.into(),
            }],
            ..Default::default()
        };

        let response: Test = client.chat_completions_into(&request, 3, parse_json_lenient)?;
        assert!(!response.hello.is_empty());

        Ok(())
    }

    #[cfg(feature = "reqwest")]
    #[tokio::test]
    async fn test_chat_completions_into() -> Result<(), Error> {
        #[derive(serde::Deserialize)]
        struct Test {
            hello: String,
        }

        let client = Client::new(None, None)?;
        let request = ChatCompletions {
            messages: vec![Message {
                role: ROLE_SYSTEM.into(),
                content: r#"Respond with this JSON: {"hello": "a word of your choosing"}."#.into(),
            }],
            ..Default::default()
        };

        let response: Test = client
            .chat_completions_into(&request, 3, parse_json_lenient)
            .await?;
        assert!(!response.hello.is_empty());

        Ok(())
    }

    #[cfg(feature = "ureq")]
    #[test]
    fn test_embeddings() -> Result<(), Error> {
        let client = Client::new(None, None)?;
        let request = Embeddings {
            input: "Hello".into(),
            ..Default::default()
        };

        let response: EmbeddingsResponse = client.embeddings(&request)?;

        assert_eq!(response.data.len(), 1);
        assert!(!response.data[0].embedding.is_empty());

        Ok(())
    }

    #[cfg(feature = "reqwest")]
    #[tokio::test]
    async fn test_embeddings() -> Result<(), Error> {
        let client = Client::new(None, None)?;
        let request = Embeddings {
            input: "Hello".into(),
            ..Default::default()
        };

        let response: EmbeddingsResponse = client.embeddings(&request).await?;

        assert_eq!(response.data.len(), 1);
        assert!(!response.data[0].embedding.is_empty());

        Ok(())
    }

    #[test]
    fn test_parse_json_lenient() -> Result<(), String> {
        #[derive(serde::Deserialize)]
        struct Test {
            hello: String,
        }

        let test: Test = parse_json_lenient(r#"Here's your JSON: {"hello": "world"}"#.into())?;
        assert_eq!(test.hello, "world");

        let test: Result<Test, String> =
            parse_json_lenient(r#"JSON is a great choice for your request!"#.into());
        assert!(test.is_err());

        let test: Result<Test, String> = parse_json_lenient(r#"Unclosed { "hello": "#.into());
        assert!(test.is_err());

        Ok(())
    }
}