gpt_rs/
client.rs

1use super::error::ErrorResponse;
2use anyhow::Result;
3use reqwest::{
4    header::{HeaderMap, HeaderValue, AUTHORIZATION},
5    Client as HttpClient, ClientBuilder, Response,
6};
7use serde::Serialize;
8use std::env;
9
/// Credentials used to authenticate against the API.
///
/// `Debug` is implemented by hand (instead of derived) so the secret API key
/// is never written out when a `Config` is logged or appears in a panic
/// message; `organization` is shown as-is.
#[derive(PartialEq, Eq, Clone)]
pub struct Config {
    /// Secret API key, sent by the client as a `Bearer` token.
    pub api_key: String,
    /// Optional organization id, sent by the client in the
    /// `OpenAI-Organization` header when present.
    pub organization: Option<String>,
}

impl std::fmt::Debug for Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Config")
            .field("api_key", &"<redacted>")
            .field("organization", &self.organization)
            .finish()
    }
}
15impl Config {
16    pub fn from_env() -> Result<Self> {
17        dotenv::dotenv().ok();
18        let api_key = match env::var("API_KEY") {
19            Ok(val) => val,
20            _ => return Err(anyhow::anyhow!("API_KEY must be set.")),
21        };
22        let organization = env::var("ORGANIZATION").ok();
23        Ok(Self {
24            api_key,
25            organization,
26        })
27    }
28}
29
30#[derive(Debug, Clone)]
31pub struct Client {
32    client: HttpClient,
33}
34
35impl Client {
36    pub fn new(config: Config) -> Result<Self> {
37        let mut headers = HeaderMap::new();
38        headers.insert(
39            AUTHORIZATION,
40            HeaderValue::from_str(format!("Bearer {}", config.api_key.as_str()).as_str())?,
41        );
42        if let Some(org) = config.organization {
43            headers.append(
44                "OpenAI-Organization",
45                HeaderValue::from_str(format!("{}", org.as_str()).as_str())?,
46            );
47        };
48        let client = ClientBuilder::new().default_headers(headers).build()?;
49        Ok(Self { client })
50    }
51
52    pub async fn get(&self, url: &str) -> Result<Response> {
53        let res = self.client.get(url).send().await?;
54        if !res.status().is_success() {
55            let error: ErrorResponse = serde_json::from_str(res.text().await?.as_str())?;
56            return Err(anyhow::anyhow!(format!(
57                "{}: {}",
58                error.error.code, error.error.message
59            )));
60        }
61        Ok(res)
62    }
63
64    pub async fn post<T: Serialize>(&self, url: &str, body: T) -> Result<Response> {
65        let res = self.client.post(url).json(&body).send().await?;
66        if !res.status().is_success() {
67            let error: ErrorResponse = serde_json::from_str(res.text().await?.as_str())?;
68            return Err(anyhow::anyhow!(format!(
69                "{}: {}",
70                error.error.code, error.error.message
71            )));
72        }
73        Ok(res)
74    }
75}