openai_api_wrapper/
client.rs1use eyre::WrapErr;
2use reqwest::{IntoUrl, RequestBuilder};
3use serde::Serialize;
4
5#[derive(Clone)]
6pub struct Client(reqwest::Client);
7
8#[derive(thiserror::Error, Debug)]
9#[error("Failed to build OpenAI client: {0}")]
10pub struct ClientBuildError(#[from] eyre::Report);
11
12#[derive(thiserror::Error, Debug)]
13#[error("Failed to make OpenAI request: {0}")]
14pub struct RequestError(#[from] eyre::Report);
15
16impl Client {
17 pub fn new(
18 access_token: impl AsRef<str>,
19 organization: impl AsRef<str>,
20 ) -> Result<Self, ClientBuildError> {
21 let header_map = reqwest::header::HeaderMap::from_iter([
22 (
23 reqwest::header::AUTHORIZATION,
24 format!("Bearer {}", access_token.as_ref())
25 .try_into()
26 .wrap_err("Unable to convert access token to header value")?,
27 ),
28 (
29 "OpenAI-Organization".try_into().unwrap(),
30 organization
31 .as_ref()
32 .try_into()
33 .wrap_err("Unable to parse organization id into header value")?,
34 ),
35 ]);
36
37 Ok(Self(
38 reqwest::Client::builder()
39 .default_headers(header_map)
40 .build()
41 .wrap_err("Unable to build OpenAI client")?,
42 ))
43 }
44
45 pub async fn request<T: OpenAIRequest>(&self, req: T) -> Result<T::Response, RequestError> {
46 let res = self
47 .0
48 .request(T::method(), T::url())
49 .json(&req)
50 .send()
51 .await
52 .wrap_err("Unable to build request")?;
53
54 let output = res.text().await.wrap_err("Unable to get response text")?;
55
56 Ok(serde_json::from_str(&output)
57 .wrap_err_with(|| format!("Unable to parse response as JSON: {}", output))?)
58 }
59}
60
61pub trait OpenAIRequest: Serialize {
62 type Response: serde::de::DeserializeOwned;
63
64 fn method() -> reqwest::Method;
65 fn url() -> &'static str;
66}
67
#[cfg(test)]
mod test {
    /// Checks that `Client::new` installs the `Authorization` and `OpenAI-Organization`
    /// default headers on outgoing requests.
    #[tokio::test]
    async fn client_builder() {
        let client = super::Client::new("test", "org name").unwrap();

        // mockito provides async variants of its server/mock API for use inside an async
        // (Tokio) test; use them rather than the blocking constructors.
        let mut server = mockito::Server::new_async().await;

        let mock = server
            .mock("GET", "/")
            .match_header("Authorization", "Bearer test")
            .match_header("OpenAI-Organization", "org name")
            .create_async()
            .await;

        // A transport failure would otherwise surface only as a confusing mock-assert failure.
        client
            .0
            .get(server.url())
            .send()
            .await
            .expect("request to the mock server should be sent");

        mock.assert_async().await;
    }
}