1use reqwest::header::{HeaderMap, AUTHORIZATION};
3use secrecy::{ExposeSecret, SecretString};
4use serde::Deserialize;
5
/// Default base URL of the public OpenAI REST API (`v1`).
pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1";

/// Header that carries the OpenAI organization id.
pub const OPENAI_ORGANIZATION_HEADER: &str = "OpenAI-Organization";

/// Header that carries the OpenAI project id.
pub const OPENAI_PROJECT_HEADER: &str = "OpenAI-Project";

/// Header used to select beta API versions (e.g. `assistants=v2`).
pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
15
16pub trait Config: Clone {
19 fn headers(&self) -> HeaderMap;
20 fn url(&self, path: &str) -> String;
21 fn query(&self) -> Vec<(&str, &str)>;
22
23 fn api_base(&self) -> &str;
24
25 fn api_key(&self) -> &SecretString;
26}
27
28#[derive(Clone, Debug, Deserialize)]
30#[serde(default)]
31pub struct OpenAIConfig {
32 api_base: String,
33 api_key: SecretString,
34 org_id: String,
35 project_id: String,
36}
37
38impl Default for OpenAIConfig {
39 fn default() -> Self {
40 Self {
41 api_base: OPENAI_API_BASE.to_string(),
42 api_key: std::env::var("OPENAI_API_KEY")
43 .unwrap_or_else(|_| "".to_string())
44 .into(),
45 org_id: Default::default(),
46 project_id: Default::default(),
47 }
48 }
49}
50
51impl OpenAIConfig {
52 pub fn new() -> Self {
54 Default::default()
55 }
56
57 pub fn with_org_id<S: Into<String>>(mut self, org_id: S) -> Self {
59 self.org_id = org_id.into();
60 self
61 }
62
63 pub fn with_project_id<S: Into<String>>(mut self, project_id: S) -> Self {
65 self.project_id = project_id.into();
66 self
67 }
68
69 pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
71 self.api_key = SecretString::from(api_key.into());
72 self
73 }
74
75 pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
77 self.api_base = api_base.into();
78 self
79 }
80
81 pub fn org_id(&self) -> &str {
82 &self.org_id
83 }
84}
85
86impl Config for OpenAIConfig {
87 fn headers(&self) -> HeaderMap {
88 let mut headers = HeaderMap::new();
89 if !self.org_id.is_empty() {
90 headers.insert(
91 OPENAI_ORGANIZATION_HEADER,
92 self.org_id.as_str().parse().unwrap(),
93 );
94 }
95
96 if !self.project_id.is_empty() {
97 headers.insert(
98 OPENAI_PROJECT_HEADER,
99 self.project_id.as_str().parse().unwrap(),
100 );
101 }
102
103 headers.insert(
104 AUTHORIZATION,
105 format!("Bearer {}", self.api_key.expose_secret())
106 .as_str()
107 .parse()
108 .unwrap(),
109 );
110
111 headers.insert(OPENAI_BETA_HEADER, "assistants=v2".parse().unwrap());
114
115 headers
116 }
117
118 fn url(&self, path: &str) -> String {
119 format!("{}{}", self.api_base, path)
120 }
121
122 fn api_base(&self) -> &str {
123 &self.api_base
124 }
125
126 fn api_key(&self) -> &SecretString {
127 &self.api_key
128 }
129
130 fn query(&self) -> Vec<(&str, &str)> {
131 vec![]
132 }
133}
134
135#[derive(Clone, Debug, Deserialize)]
137#[serde(default)]
138pub struct AzureConfig {
139 api_version: String,
140 deployment_id: String,
141 api_base: String,
142 api_key: SecretString,
143}
144
145impl Default for AzureConfig {
146 fn default() -> Self {
147 Self {
148 api_base: Default::default(),
149 api_key: std::env::var("OPENAI_API_KEY")
150 .unwrap_or_else(|_| "".to_string())
151 .into(),
152 deployment_id: Default::default(),
153 api_version: Default::default(),
154 }
155 }
156}
157
158impl AzureConfig {
159 pub fn new() -> Self {
160 Default::default()
161 }
162
163 pub fn with_api_version<S: Into<String>>(mut self, api_version: S) -> Self {
164 self.api_version = api_version.into();
165 self
166 }
167
168 pub fn with_deployment_id<S: Into<String>>(mut self, deployment_id: S) -> Self {
169 self.deployment_id = deployment_id.into();
170 self
171 }
172
173 pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
175 self.api_key = SecretString::from(api_key.into());
176 self
177 }
178
179 pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
181 self.api_base = api_base.into();
182 self
183 }
184}
185
186impl Config for AzureConfig {
187 fn headers(&self) -> HeaderMap {
188 let mut headers = HeaderMap::new();
189
190 headers.insert("api-key", self.api_key.expose_secret().parse().unwrap());
191
192 headers
193 }
194
195 fn url(&self, path: &str) -> String {
196 format!(
197 "{}/openai/deployments/{}{}",
198 self.api_base, self.deployment_id, path
199 )
200 }
201
202 fn api_base(&self) -> &str {
203 &self.api_base
204 }
205
206 fn api_key(&self) -> &SecretString {
207 &self.api_key
208 }
209
210 fn query(&self) -> Vec<(&str, &str)> {
211 vec![("api-version", &self.api_version)]
212 }
213}