// openai_api_wrapper/client.rs

1use eyre::WrapErr;
2use reqwest::{IntoUrl, RequestBuilder};
3use serde::Serialize;
4
5#[derive(Clone)]
6pub struct Client(reqwest::Client);
7
8#[derive(thiserror::Error, Debug)]
9#[error("Failed to build OpenAI client: {0}")]
10pub struct ClientBuildError(#[from] eyre::Report);
11
12#[derive(thiserror::Error, Debug)]
13#[error("Failed to make OpenAI request: {0}")]
14pub struct RequestError(#[from] eyre::Report);
15
16impl Client {
17    pub fn new(
18        access_token: impl AsRef<str>,
19        organization: impl AsRef<str>,
20    ) -> Result<Self, ClientBuildError> {
21        let header_map = reqwest::header::HeaderMap::from_iter([
22            (
23                reqwest::header::AUTHORIZATION,
24                format!("Bearer {}", access_token.as_ref())
25                    .try_into()
26                    .wrap_err("Unable to convert access token to header value")?,
27            ),
28            (
29                "OpenAI-Organization".try_into().unwrap(),
30                organization
31                    .as_ref()
32                    .try_into()
33                    .wrap_err("Unable to parse organization id into header value")?,
34            ),
35        ]);
36
37        Ok(Self(
38            reqwest::Client::builder()
39                .default_headers(header_map)
40                .build()
41                .wrap_err("Unable to build OpenAI client")?,
42        ))
43    }
44
45    pub async fn request<T: OpenAIRequest>(&self, req: T) -> Result<T::Response, RequestError> {
46        let res = self
47            .0
48            .request(T::method(), T::url())
49            .json(&req)
50            .send()
51            .await
52            .wrap_err("Unable to build request")?;
53
54        let output = res.text().await.wrap_err("Unable to get response text")?;
55
56        Ok(serde_json::from_str(&output)
57            .wrap_err_with(|| format!("Unable to parse response as JSON: {}", output))?)
58    }
59}
60
61pub trait OpenAIRequest: Serialize {
62    type Response: serde::de::DeserializeOwned;
63
64    fn method() -> reqwest::Method;
65    fn url() -> &'static str;
66}
67
#[cfg(test)]
mod test {
    /// `Client::new` must install the bearer-token and organization headers as defaults,
    /// so a plain request through the inner `reqwest::Client` carries both.
    #[tokio::test]
    async fn client_builder() {
        let client = super::Client::new("test", "org name").unwrap();

        // NOTE(review): with mockito >= 1.0 the async constructors
        // (`Server::new_async().await`, `create_async().await`) are the ones intended for
        // use inside an async runtime — confirm against the mockito version in Cargo.toml.
        let mut server = mockito::Server::new();

        let mock = server
            .mock("GET", "/")
            .match_header("Authorization", "Bearer test")
            .match_header("OpenAI-Organization", "org name")
            .create();

        // Check the send result so transport failures surface directly rather than only
        // as an unmatched mock.
        client
            .0
            .get(server.url())
            .send()
            .await
            .expect("request to the mock server failed");
        mock.assert();
    }
}
85}