async_llm/providers/
config.rs1use derive_builder::Builder;
2use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
3use secrecy::{ExposeSecret, SecretString};
4use std::fmt::Debug;
5
6use crate::error::Error;
7
8use super::openai::OPENAI_BASE_URL;
9
/// Header name carrying the organization id (sent when `org_id` is set).
pub const OPENAI_ORGANIZATION: &str = "OpenAI-Organization";
/// Header name carrying the project id (sent when `project_id` is set).
pub const OPENAI_PROJECT: &str = "OpenAI-Project";
/// Header name carrying the beta-feature opt-in value (sent when `beta` is set).
pub const OPENAI_BETA: &str = "OpenAI-Beta";
13
14pub trait Config: Debug + Clone + Send + Sync {
15 fn headers(&self) -> Result<HeaderMap, Error>;
16 fn url(&self, path: &str) -> String;
17 fn query(&self) -> Vec<(&str, &str)>;
18
19 fn base_url(&self) -> &str;
20
21 fn api_key(&self) -> Option<&SecretString>;
22
23 fn stream_done_message(&self) -> &'static str {
24 "[DONE]"
25 }
26}
27
28#[derive(Debug, Clone, Builder)]
29#[builder(derive(Debug))]
30#[builder(build_fn(error = Error))]
31pub struct OpenAIConfig {
32 pub(crate) base_url: String,
33 pub(crate) api_key: Option<SecretString>,
34 pub(crate) org_id: Option<String>,
35 pub(crate) project_id: Option<String>,
36 pub(crate) beta: Option<String>,
37}
38
/// Normalizes a base URL so request paths can be appended to it directly.
///
/// Leading whitespace is removed. Any trailing run of `/` characters and
/// whitespace (spaces, tabs, newlines, in any interleaving) is also removed, so
/// `"https://host/v1/ \n"` becomes `"https://host/v1"`. Stray whitespace is easy
/// to pick up from environment variables or `.env` files and would otherwise end
/// up in the middle of every request URL.
fn sanitize_base_url(input: impl Into<String>) -> String {
    let input: String = input.into();
    input
        .trim_start()
        .trim_end_matches(|c: char| c == '/' || c.is_whitespace())
        .to_string()
}
42}
43
44impl OpenAIConfig {
45 pub fn new(base_url: impl Into<String>, api_key: Option<SecretString>) -> Self {
46 Self {
47 base_url: sanitize_base_url(base_url),
48 api_key: api_key.into(),
49 beta: Some("assistants=v2".into()),
50 ..Default::default()
51 }
52 }
53}
54
55impl Default for OpenAIConfig {
56 fn default() -> Self {
57 Self {
58 base_url: sanitize_base_url(
59 std::env::var("OPENAI_BASE_URL").unwrap_or_else(|_| OPENAI_BASE_URL.to_string()),
60 ),
61 api_key: std::env::var("OPENAI_API_KEY").map(|v| v.into()).ok(),
62 org_id: Default::default(),
63 project_id: Default::default(),
64 beta: Some("assistants=v2".into()),
65 }
66 }
67}
68
69impl Config for OpenAIConfig {
70 fn headers(&self) -> Result<reqwest::header::HeaderMap, Error> {
71 let mut headers = HeaderMap::new();
72
73 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
74
75 if let Some(api_key) = &self.api_key {
76 let bearer = format!("Bearer {}", api_key.expose_secret());
77 headers.insert(
78 AUTHORIZATION,
79 bearer.parse().map_err(|e| {
80 Error::InvalidConfig(format!(
81 "Failed to convert api key id to header value. {:?}",
82 e
83 ))
84 })?,
85 );
86 }
87
88 if let Some(org_id) = &self.org_id {
89 headers.insert(
90 OPENAI_ORGANIZATION,
91 org_id.parse().map_err(|e| {
92 Error::InvalidConfig(format!(
93 "Failed to convert organization id to header value. {:?}",
94 e
95 ))
96 })?,
97 );
98 }
99 if let Some(project_id) = &self.project_id {
100 headers.insert(
101 OPENAI_PROJECT,
102 project_id.parse().map_err(|e| {
103 Error::InvalidConfig(format!(
104 "Failed to convert project id to header value. {:?}",
105 e
106 ))
107 })?,
108 );
109 }
110
111 if let Some(beta) = &self.beta {
113 headers.insert(
114 OPENAI_BETA,
115 beta.parse().map_err(|e| {
116 Error::InvalidConfig(format!("Failed to convert beta to header. {:?}", e))
117 })?,
118 );
119 }
120 Ok(headers)
121 }
122
123 fn url(&self, path: &str) -> String {
124 format!("{}{}", self.base_url, path)
125 }
126
127 fn query(&self) -> Vec<(&str, &str)> {
128 vec![]
129 }
130
131 fn base_url(&self) -> &str {
132 &self.base_url
133 }
134
135 fn api_key(&self) -> Option<&SecretString> {
136 self.api_key.as_ref()
137 }
138}