// async_llm/providers/config.rs

1use derive_builder::Builder;
2use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
3use secrecy::{ExposeSecret, SecretString};
4use std::fmt::Debug;
5
6use crate::error::Error;
7
8use super::openai::OPENAI_BASE_URL;
9
/// Name of the HTTP header that carries the configured organization id.
pub const OPENAI_ORGANIZATION: &str = "OpenAI-Organization";

/// Name of the HTTP header that carries the configured project id.
pub const OPENAI_PROJECT: &str = "OpenAI-Project";

/// Name of the HTTP header that carries the configured beta feature value.
pub const OPENAI_BETA: &str = "OpenAI-Beta";
13
14pub trait Config: Debug + Clone + Send + Sync {
15    fn headers(&self) -> Result<HeaderMap, Error>;
16    fn url(&self, path: &str) -> String;
17    fn query(&self) -> Vec<(&str, &str)>;
18
19    fn base_url(&self) -> &str;
20
21    fn api_key(&self) -> Option<&SecretString>;
22
23    fn stream_done_message(&self) -> &'static str {
24        "[DONE]"
25    }
26}
27
28#[derive(Debug, Clone, Builder)]
29#[builder(derive(Debug))]
30#[builder(build_fn(error = Error))]
31pub struct OpenAIConfig {
32    pub(crate) base_url: String,
33    pub(crate) api_key: Option<SecretString>,
34    pub(crate) org_id: Option<String>,
35    pub(crate) project_id: Option<String>,
36    pub(crate) beta: Option<String>,
37}
38
/// Normalizes a base URL so that a path starting with `/` can be appended to it.
///
/// Leading whitespace is removed, and any trailing run of whitespace and `/`
/// characters is stripped. Covering all whitespace (not only `' '`) matters for
/// values read from environment variables or config files, which often end in
/// a newline or tab.
///
/// Returns an empty string when `input` contains nothing but whitespace and slashes.
fn sanitize_base_url(input: impl Into<String>) -> String {
    let input: String = input.into();
    input
        .trim_start()
        .trim_end_matches(|c: char| c == '/' || c.is_whitespace())
        .to_string()
}
43
44impl OpenAIConfig {
45    pub fn new(base_url: impl Into<String>, api_key: Option<SecretString>) -> Self {
46        Self {
47            base_url: sanitize_base_url(base_url),
48            api_key: api_key.into(),
49            beta: Some("assistants=v2".into()),
50            ..Default::default()
51        }
52    }
53}
54
55impl Default for OpenAIConfig {
56    fn default() -> Self {
57        Self {
58            base_url: sanitize_base_url(
59                std::env::var("OPENAI_BASE_URL").unwrap_or_else(|_| OPENAI_BASE_URL.to_string()),
60            ),
61            api_key: std::env::var("OPENAI_API_KEY").map(|v| v.into()).ok(),
62            org_id: Default::default(),
63            project_id: Default::default(),
64            beta: Some("assistants=v2".into()),
65        }
66    }
67}
68
69impl Config for OpenAIConfig {
70    fn headers(&self) -> Result<reqwest::header::HeaderMap, Error> {
71        let mut headers = HeaderMap::new();
72
73        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
74
75        if let Some(api_key) = &self.api_key {
76            let bearer = format!("Bearer {}", api_key.expose_secret());
77            headers.insert(
78                AUTHORIZATION,
79                bearer.parse().map_err(|e| {
80                    Error::InvalidConfig(format!(
81                        "Failed to convert api key id to header value. {:?}",
82                        e
83                    ))
84                })?,
85            );
86        }
87
88        if let Some(org_id) = &self.org_id {
89            headers.insert(
90                OPENAI_ORGANIZATION,
91                org_id.parse().map_err(|e| {
92                    Error::InvalidConfig(format!(
93                        "Failed to convert organization id to header value. {:?}",
94                        e
95                    ))
96                })?,
97            );
98        }
99        if let Some(project_id) = &self.project_id {
100            headers.insert(
101                OPENAI_PROJECT,
102                project_id.parse().map_err(|e| {
103                    Error::InvalidConfig(format!(
104                        "Failed to convert project id to header value. {:?}",
105                        e
106                    ))
107                })?,
108            );
109        }
110
111        // See: https://github.com/64bit/async-openai/blob/bd7a87e335630d5d2f3e6cef30d15633048937b3/async-openai/src/config.rs#L111
112        if let Some(beta) = &self.beta {
113            headers.insert(
114                OPENAI_BETA,
115                beta.parse().map_err(|e| {
116                    Error::InvalidConfig(format!("Failed to convert beta to header. {:?}", e))
117                })?,
118            );
119        }
120        Ok(headers)
121    }
122
123    fn url(&self, path: &str) -> String {
124        format!("{}{}", self.base_url, path)
125    }
126
127    fn query(&self) -> Vec<(&str, &str)> {
128        vec![]
129    }
130
131    fn base_url(&self) -> &str {
132        &self.base_url
133    }
134
135    fn api_key(&self) -> Option<&SecretString> {
136        self.api_key.as_ref()
137    }
138}