Skip to main content

openai_oxide/
config.rs

1// Client configuration
2
3use std::env;
4
5use reqwest::header::HeaderMap;
6
7use crate::request_options::RequestOptions;
8
/// Base URL assigned by `ClientConfig::new` (the public OpenAI v1 API).
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
/// Timeout in seconds used when a config does not override it (10 minutes).
const DEFAULT_TIMEOUT_SECS: u64 = 10 * 60;
/// Retry count used when a config does not override it.
const DEFAULT_MAX_RETRIES: u32 = 2;
12
13/// Configuration trait for API clients.
14///
15/// Allows implementing custom configurations for different providers
16/// (e.g., standard OpenAI, Azure OpenAI, OpenRouter).
17pub trait Config: Send + Sync + std::fmt::Debug {
18    /// Returns the base URL for the API.
19    fn base_url(&self) -> &str;
20
21    /// Returns the API key.
22    fn api_key(&self) -> &str;
23
24    /// Hook to modify or add provider-specific headers and auth to a request.
25    fn build_request(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder;
26
27    /// Returns the organization ID, if any.
28    fn organization(&self) -> Option<&str> {
29        None
30    }
31
32    /// Returns the project ID, if any.
33    fn project(&self) -> Option<&str> {
34        None
35    }
36
37    /// Returns the timeout in seconds.
38    fn timeout_secs(&self) -> u64 {
39        DEFAULT_TIMEOUT_SECS
40    }
41
42    /// Returns the maximum number of retries.
43    fn max_retries(&self) -> u32 {
44        DEFAULT_MAX_RETRIES
45    }
46
47    /// Returns default headers to append to all requests.
48    fn default_headers(&self) -> Option<&HeaderMap> {
49        None
50    }
51
52    /// Returns default query parameters to append to all requests.
53    fn default_query(&self) -> Option<&[(String, String)]> {
54        None
55    }
56
57    /// Build the initial `RequestOptions` from config-level defaults.
58    fn initial_options(&self) -> RequestOptions {
59        let mut opts = RequestOptions::new();
60        if let Some(h) = self.default_headers() {
61            opts.headers = Some(h.clone());
62        }
63        if let Some(q) = self.default_query() {
64            opts.query = Some(q.to_vec());
65        }
66        opts
67    }
68}
69
70/// Configuration for the OpenAI client.
71#[derive(Debug, Clone)]
72pub struct ClientConfig {
73    pub api_key: String,
74    pub base_url: String,
75    pub organization: Option<String>,
76    pub project: Option<String>,
77    pub timeout_secs: u64,
78    pub max_retries: u32,
79    /// Default headers sent with every request.
80    pub default_headers: Option<HeaderMap>,
81    /// Default query parameters appended to every request URL.
82    pub default_query: Option<Vec<(String, String)>>,
83    /// When true, use `api-key` header instead of `Authorization: Bearer` for auth.
84    /// This is used by Azure OpenAI deployments.
85    pub(crate) use_azure_api_key_header: bool,
86}
87
88impl ClientConfig {
89    /// Create a new config with the given API key.
90    pub fn new(api_key: impl Into<String>) -> Self {
91        Self {
92            api_key: api_key.into(),
93            base_url: DEFAULT_BASE_URL.to_string(),
94            organization: None,
95            project: None,
96            timeout_secs: DEFAULT_TIMEOUT_SECS,
97            max_retries: DEFAULT_MAX_RETRIES,
98            default_headers: None,
99            default_query: None,
100            use_azure_api_key_header: false,
101        }
102    }
103
104    /// Create config from the `OPENAI_API_KEY` environment variable.
105    pub fn from_env() -> Result<Self, crate::error::OpenAIError> {
106        let api_key = env::var("OPENAI_API_KEY").map_err(|_| {
107            crate::error::OpenAIError::InvalidArgument(
108                "OPENAI_API_KEY environment variable not set".to_string(),
109            )
110        })?;
111        Ok(Self::new(api_key))
112    }
113
114    pub fn base_url(mut self, url: impl Into<String>) -> Self {
115        self.base_url = url.into();
116        self
117    }
118
119    pub fn organization(mut self, org: impl Into<String>) -> Self {
120        self.organization = Some(org.into());
121        self
122    }
123
124    pub fn project(mut self, project: impl Into<String>) -> Self {
125        self.project = Some(project.into());
126        self
127    }
128
129    pub fn timeout_secs(mut self, secs: u64) -> Self {
130        self.timeout_secs = secs;
131        self
132    }
133
134    pub fn max_retries(mut self, retries: u32) -> Self {
135        self.max_retries = retries;
136        self
137    }
138
139    /// Set default headers sent with every request.
140    pub fn default_headers(mut self, headers: HeaderMap) -> Self {
141        self.default_headers = Some(headers);
142        self
143    }
144
145    /// Set default query parameters appended to every request URL.
146    pub fn default_query(mut self, query: Vec<(String, String)>) -> Self {
147        self.default_query = Some(query);
148        self
149    }
150
151    /// Use Azure `api-key` header instead of `Authorization: Bearer` for auth.
152    pub(crate) fn use_azure_api_key_header(mut self, enabled: bool) -> Self {
153        self.use_azure_api_key_header = enabled;
154        self
155    }
156}
157
158impl Config for ClientConfig {
159    fn base_url(&self) -> &str {
160        &self.base_url
161    }
162
163    fn api_key(&self) -> &str {
164        &self.api_key
165    }
166
167    fn build_request(&self, mut req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
168        if self.use_azure_api_key_header {
169            req = req.header("api-key", &self.api_key);
170        } else {
171            req = req.bearer_auth(&self.api_key);
172        }
173
174        if let Some(ref org) = self.organization {
175            req = req.header("OpenAI-Organization", org);
176        }
177        if let Some(ref project) = self.project {
178            req = req.header("OpenAI-Project", project);
179        }
180
181        req
182    }
183
184    fn organization(&self) -> Option<&str> {
185        self.organization.as_deref()
186    }
187
188    fn project(&self) -> Option<&str> {
189        self.project.as_deref()
190    }
191
192    fn timeout_secs(&self) -> u64 {
193        self.timeout_secs
194    }
195
196    fn max_retries(&self) -> u32 {
197        self.max_retries
198    }
199
200    fn default_headers(&self) -> Option<&HeaderMap> {
201        self.default_headers.as_ref()
202    }
203
204    fn default_query(&self) -> Option<&[(String, String)]> {
205        self.default_query.as_deref()
206    }
207}