1use super::error::ErrorResponse;
2use anyhow::Result;
3use reqwest::{
4 header::{HeaderMap, HeaderValue, AUTHORIZATION},
5 Client as HttpClient, ClientBuilder, Response,
6};
7use serde::Serialize;
8use std::env;
9
/// Credentials used to authenticate every request made by [`Client`].
///
/// `Debug` is implemented by hand so that the API key is never printed
/// (e.g. in logs or panic messages); only the organization is shown.
#[derive(PartialEq, Eq, Clone)]
pub struct Config {
    /// Secret API key, sent as `Authorization: Bearer <api_key>`.
    pub api_key: String,
    /// Optional organization id, sent as the `OpenAI-Organization` header when present.
    pub organization: Option<String>,
}

impl std::fmt::Debug for Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Config")
            // Never leak the secret through debug formatting.
            .field("api_key", &"<redacted>")
            .field("organization", &self.organization)
            .finish()
    }
}
15impl Config {
16 pub fn from_env() -> Result<Self> {
17 dotenv::dotenv().ok();
18 let api_key = match env::var("API_KEY") {
19 Ok(val) => val,
20 _ => return Err(anyhow::anyhow!("API_KEY must be set.")),
21 };
22 let organization = env::var("ORGANIZATION").ok();
23 Ok(Self {
24 api_key,
25 organization,
26 })
27 }
28}
29
30#[derive(Debug, Clone)]
31pub struct Client {
32 client: HttpClient,
33}
34
35impl Client {
36 pub fn new(config: Config) -> Result<Self> {
37 let mut headers = HeaderMap::new();
38 headers.insert(
39 AUTHORIZATION,
40 HeaderValue::from_str(format!("Bearer {}", config.api_key.as_str()).as_str())?,
41 );
42 if let Some(org) = config.organization {
43 headers.append(
44 "OpenAI-Organization",
45 HeaderValue::from_str(format!("{}", org.as_str()).as_str())?,
46 );
47 };
48 let client = ClientBuilder::new().default_headers(headers).build()?;
49 Ok(Self { client })
50 }
51
52 pub async fn get(&self, url: &str) -> Result<Response> {
53 let res = self.client.get(url).send().await?;
54 if !res.status().is_success() {
55 let error: ErrorResponse = serde_json::from_str(res.text().await?.as_str())?;
56 return Err(anyhow::anyhow!(format!(
57 "{}: {}",
58 error.error.code, error.error.message
59 )));
60 }
61 Ok(res)
62 }
63
64 pub async fn post<T: Serialize>(&self, url: &str, body: T) -> Result<Response> {
65 let res = self.client.post(url).json(&body).send().await?;
66 if !res.status().is_success() {
67 let error: ErrorResponse = serde_json::from_str(res.text().await?.as_str())?;
68 return Err(anyhow::anyhow!(format!(
69 "{}: {}",
70 error.error.code, error.error.message
71 )));
72 }
73 Ok(res)
74 }
75}