Skip to main content

liter_llm/client/
config.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use secrecy::SecretString;
5
6use crate::auth::CredentialProvider;
7#[cfg(any(feature = "native-http", feature = "wasm-http"))]
8use crate::error::{LiterLlmError, Result};
9#[cfg(feature = "tower")]
10use crate::tower::{BudgetConfig, CacheConfig, CacheStore, LlmHook, RateLimitConfig};
11
12/// Configuration for an LLM client.
13///
14/// `api_key` is stored as a [`SecretString`] so it is zeroed on drop and never
15/// printed accidentally.  Access it via [`secrecy::ExposeSecret`].
16#[derive(Clone)]
17pub struct ClientConfig {
18    /// API key for authentication (stored as a secret).
19    pub api_key: SecretString,
20    /// Override base URL.  When set, all requests go here regardless of model
21    /// name, and provider auto-detection is skipped.
22    pub base_url: Option<String>,
23    /// Request timeout.
24    pub timeout: Duration,
25    /// Maximum number of retries on 429 / 5xx responses.
26    pub max_retries: u32,
27    /// Extra headers sent on every request.
28    ///
29    /// Use `Vec<(String, String)>` rather than `HashMap` to preserve insertion
30    /// order and avoid non-deterministic iteration when building the reqwest
31    /// `HeaderMap`.  Access via [`ClientConfig::headers`]; do not mutate
32    /// directly from outside this crate.
33    pub(crate) extra_headers: Vec<(String, String)>,
34    /// Optional dynamic credential provider for token-based auth
35    /// (Azure AD, Vertex OAuth2) or refreshable credentials (AWS STS).
36    ///
37    /// When set, the client calls `resolve()` before each request to obtain
38    /// a fresh credential.  When `None`, the static `api_key` is used.
39    pub credential_provider: Option<Arc<dyn CredentialProvider>>,
40
41    /// Configuration for the response cache Tower middleware layer.
42    ///
43    /// When set, bindings and advanced Rust users can use this to construct
44    /// a [`CacheLayer`](crate::tower::CacheLayer) in their Tower stack.
45    #[cfg(feature = "tower")]
46    pub cache_config: Option<CacheConfig>,
47
48    /// Custom cache store backend for the cache Tower middleware layer.
49    ///
50    /// When set alongside `cache_config`, the cache layer will use this
51    /// store instead of the default in-memory LRU.
52    #[cfg(feature = "tower")]
53    pub cache_store: Option<Arc<dyn CacheStore>>,
54
55    /// Configuration for the budget enforcement Tower middleware layer.
56    ///
57    /// When set, bindings and advanced Rust users can use this to construct
58    /// a [`BudgetLayer`](crate::tower::BudgetLayer) in their Tower stack.
59    #[cfg(feature = "tower")]
60    pub budget_config: Option<BudgetConfig>,
61
62    /// User-defined hooks for the hooks Tower middleware layer.
63    ///
64    /// These hooks are invoked at request lifecycle points (pre-request,
65    /// post-response, on-error) when a
66    /// [`HooksLayer`](crate::tower::HooksLayer) is constructed from this
67    /// config.
68    #[cfg(feature = "tower")]
69    pub hooks: Vec<Arc<dyn LlmHook>>,
70
71    /// Cooldown duration after transient errors (rate limit, timeout, server error).
72    /// When set, the client rejects requests with `ServiceUnavailable` during cooldown.
73    #[cfg(feature = "tower")]
74    pub cooldown_duration: Option<Duration>,
75
76    /// Per-model rate limiting configuration (RPM/TPM).
77    #[cfg(feature = "tower")]
78    pub rate_limit_config: Option<RateLimitConfig>,
79
80    /// Background health check interval. When set, periodically probes the provider
81    /// and rejects requests when the provider is unhealthy.
82    #[cfg(feature = "tower")]
83    pub health_check_interval: Option<Duration>,
84
85    /// Enable per-request cost tracking. Costs are accumulated atomically and
86    /// logged via `tracing::info`.
87    #[cfg(feature = "tower")]
88    pub enable_cost_tracking: bool,
89
90    /// Enable OpenTelemetry-compatible tracing spans for every request.
91    #[cfg(feature = "tower")]
92    pub enable_tracing: bool,
93
94    /// Automatically load the API key from the provider's environment variable
95    /// when no explicit key is provided.
96    ///
97    /// When `true` (the default) and `api_key` is empty, [`DefaultClient::new`]
98    /// reads the provider's designated environment variable (e.g.
99    /// `OPENAI_API_KEY` for OpenAI).  Set to `false` to suppress this behaviour
100    /// and require the caller to supply the key explicitly.
101    ///
102    /// Has no effect on WASM targets, where `std::env::var` is unavailable.
103    pub load_env: bool,
104}
105
106impl ClientConfig {
107    /// Create a config with the given API key and sensible defaults.
108    pub fn new(api_key: impl Into<String>) -> Self {
109        Self {
110            api_key: SecretString::from(api_key.into()),
111            base_url: None,
112            timeout: Duration::from_secs(60),
113            max_retries: 3,
114            extra_headers: Vec::new(),
115            credential_provider: None,
116            load_env: true,
117            #[cfg(feature = "tower")]
118            cache_config: None,
119            #[cfg(feature = "tower")]
120            cache_store: None,
121            #[cfg(feature = "tower")]
122            budget_config: None,
123            #[cfg(feature = "tower")]
124            hooks: Vec::new(),
125            #[cfg(feature = "tower")]
126            cooldown_duration: None,
127            #[cfg(feature = "tower")]
128            rate_limit_config: None,
129            #[cfg(feature = "tower")]
130            health_check_interval: None,
131            #[cfg(feature = "tower")]
132            enable_cost_tracking: false,
133            #[cfg(feature = "tower")]
134            enable_tracing: false,
135        }
136    }
137
138    /// Return the extra headers as an ordered slice of `(name, value)` pairs.
139    pub fn headers(&self) -> &[(String, String)] {
140        &self.extra_headers
141    }
142}
143
144/// Note: intentionally does *not* implement `Debug` so the secret key is never
145/// accidentally logged via `{:?}`.
146impl std::fmt::Debug for ClientConfig {
147    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148        // Redact all header values — they may contain API keys or secrets.
149        let redacted_headers: Vec<(&str, &str)> = self
150            .extra_headers
151            .iter()
152            .map(|(k, _v)| (k.as_str(), "[redacted]"))
153            .collect();
154        let mut dbg = f.debug_struct("ClientConfig");
155        dbg.field("api_key", &"[redacted]")
156            .field("base_url", &self.base_url)
157            .field("timeout", &self.timeout)
158            .field("max_retries", &self.max_retries)
159            .field("extra_headers", &redacted_headers)
160            .field("load_env", &self.load_env)
161            .field(
162                "credential_provider",
163                &self.credential_provider.as_ref().map(|_| "[configured]"),
164            );
165
166        #[cfg(feature = "tower")]
167        {
168            dbg.field("cache_config", &self.cache_config)
169                .field("cache_store", &self.cache_store.as_ref().map(|_| "[configured]"))
170                .field("budget_config", &self.budget_config)
171                .field("hooks_count", &self.hooks.len())
172                .field("cooldown_duration", &self.cooldown_duration)
173                .field("rate_limit_config", &self.rate_limit_config)
174                .field("health_check_interval", &self.health_check_interval)
175                .field("enable_cost_tracking", &self.enable_cost_tracking)
176                .field("enable_tracing", &self.enable_tracing);
177        }
178
179        dbg.finish()
180    }
181}
182
183/// Builder for [`ClientConfig`].
184///
185/// Construct with [`ClientConfigBuilder::new`] and call builder methods to
186/// customise the configuration, then call [`ClientConfigBuilder::build`] to
187/// obtain a [`ClientConfig`].
188#[must_use]
189pub struct ClientConfigBuilder {
190    pub(crate) config: ClientConfig,
191}
192
193impl ClientConfigBuilder {
194    /// Create a new builder with the given API key and sensible defaults.
195    pub fn new(api_key: impl Into<String>) -> Self {
196        Self {
197            config: ClientConfig::new(api_key),
198        }
199    }
200
201    /// Create a builder with no explicit API key.
202    ///
203    /// `load_env` is `true` by default, so the key will be read from the
204    /// provider's environment variable (e.g. `OPENAI_API_KEY`) at client
205    /// construction time.  Call `.load_env(false)` to opt out.
206    pub fn from_env() -> Self {
207        Self {
208            config: ClientConfig::new(""),
209        }
210    }
211
212    /// Enable or disable automatic API key loading from environment variables.
213    ///
214    /// When `true` (the default) and no explicit `api_key` was provided,
215    /// [`DefaultClient::new`] reads the provider's designated environment
216    /// variable.  Set to `false` to require an explicit key.
217    ///
218    /// Has no effect on WASM targets.
219    pub fn load_env(mut self, enabled: bool) -> Self {
220        self.config.load_env = enabled;
221        self
222    }
223
224    /// Override the provider base URL for all requests.
225    pub fn base_url(mut self, url: impl Into<String>) -> Self {
226        self.config.base_url = Some(url.into());
227        self
228    }
229
230    /// Set the per-request timeout (default: 60 s).
231    pub fn timeout(mut self, timeout: Duration) -> Self {
232        self.config.timeout = timeout;
233        self
234    }
235
236    /// Set the maximum number of retries on 429 / 5xx responses (default: 3).
237    pub fn max_retries(mut self, retries: u32) -> Self {
238        self.config.max_retries = retries;
239        self
240    }
241
242    /// Set a dynamic credential provider for token-based or refreshable auth.
243    ///
244    /// When configured, the client calls `resolve()` before each request
245    /// instead of using the static `api_key` for authentication.
246    pub fn credential_provider(mut self, provider: Arc<dyn CredentialProvider>) -> Self {
247        self.config.credential_provider = Some(provider);
248        self
249    }
250
251    /// Add a custom header sent on every request.
252    ///
253    /// Returns an error if either `key` or `value` is not a valid HTTP header
254    /// name / value.
255    ///
256    /// This method is only available when the `native-http` feature is enabled
257    /// because header validation relies on `reqwest`'s header types.
258    #[cfg(any(feature = "native-http", feature = "wasm-http"))]
259    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Result<Self> {
260        let key = key.into();
261        let value = value.into();
262
263        // Validate header name.
264        reqwest::header::HeaderName::from_bytes(key.as_bytes()).map_err(|e| LiterLlmError::InvalidHeader {
265            name: key.clone(),
266            reason: e.to_string(),
267        })?;
268
269        // Validate header value.
270        reqwest::header::HeaderValue::from_str(&value).map_err(|e| LiterLlmError::InvalidHeader {
271            name: key.clone(),
272            reason: e.to_string(),
273        })?;
274
275        self.config.extra_headers.push((key, value));
276        Ok(self)
277    }
278
279    /// Set the response cache configuration for the Tower middleware stack.
280    ///
281    /// When set, bindings and advanced Rust users can read this from the
282    /// built [`ClientConfig`] to construct a
283    /// [`CacheLayer`](crate::tower::CacheLayer).
284    #[cfg(feature = "tower")]
285    pub fn cache(mut self, config: CacheConfig) -> Self {
286        self.config.cache_config = Some(config);
287        self
288    }
289
290    /// Set a custom cache store backend for the Tower cache middleware.
291    ///
292    /// When set alongside [`cache`](Self::cache), the cache layer will use
293    /// this store instead of the default in-memory LRU.
294    #[cfg(feature = "tower")]
295    pub fn cache_store(mut self, store: Arc<dyn CacheStore>) -> Self {
296        self.config.cache_store = Some(store);
297        self
298    }
299
300    /// Set the budget enforcement configuration for the Tower middleware stack.
301    ///
302    /// When set, bindings and advanced Rust users can read this from the
303    /// built [`ClientConfig`] to construct a
304    /// [`BudgetLayer`](crate::tower::BudgetLayer).
305    #[cfg(feature = "tower")]
306    pub fn budget(mut self, config: BudgetConfig) -> Self {
307        self.config.budget_config = Some(config);
308        self
309    }
310
311    /// Add a single hook to the Tower hooks middleware stack.
312    ///
313    /// Hooks are invoked sequentially in registration order at request
314    /// lifecycle points (pre-request, post-response, on-error).
315    #[cfg(feature = "tower")]
316    pub fn hook(mut self, hook: Arc<dyn LlmHook>) -> Self {
317        self.config.hooks.push(hook);
318        self
319    }
320
321    /// Set the full list of hooks for the Tower hooks middleware stack,
322    /// replacing any previously registered hooks.
323    ///
324    /// Hooks are invoked sequentially in registration order.
325    #[cfg(feature = "tower")]
326    pub fn hooks(mut self, hooks: Vec<Arc<dyn LlmHook>>) -> Self {
327        self.config.hooks = hooks;
328        self
329    }
330
331    /// Set the cooldown duration after transient errors.
332    ///
333    /// When set, the client rejects requests with `ServiceUnavailable` for
334    /// the given duration after a transient error (rate limit, timeout,
335    /// server error).
336    #[cfg(feature = "tower")]
337    pub fn cooldown(mut self, duration: Duration) -> Self {
338        self.config.cooldown_duration = Some(duration);
339        self
340    }
341
342    /// Set per-model rate limiting configuration.
343    ///
344    /// When set, requests exceeding the configured RPM or TPM limits are
345    /// rejected with [`LiterLlmError::RateLimited`](crate::error::LiterLlmError::RateLimited).
346    #[cfg(feature = "tower")]
347    pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
348        self.config.rate_limit_config = Some(config);
349        self
350    }
351
352    /// Set the background health check interval.
353    ///
354    /// When set, the client periodically probes the provider and rejects
355    /// requests when the provider is unhealthy.
356    #[cfg(feature = "tower")]
357    pub fn health_check(mut self, interval: Duration) -> Self {
358        self.config.health_check_interval = Some(interval);
359        self
360    }
361
362    /// Enable or disable per-request cost tracking.
363    ///
364    /// When enabled, estimated USD cost is recorded on the current tracing
365    /// span as `gen_ai.usage.cost`.
366    #[cfg(feature = "tower")]
367    pub fn cost_tracking(mut self, enabled: bool) -> Self {
368        self.config.enable_cost_tracking = enabled;
369        self
370    }
371
372    /// Enable or disable OpenTelemetry-compatible tracing spans.
373    ///
374    /// When enabled, every request is wrapped in a `gen_ai` tracing span
375    /// with semantic convention attributes.
376    #[cfg(feature = "tower")]
377    pub fn tracing(mut self, enabled: bool) -> Self {
378        self.config.enable_tracing = enabled;
379        self
380    }
381
382    /// Consume the builder and return the completed [`ClientConfig`].
383    #[must_use]
384    pub fn build(self) -> ClientConfig {
385        self.config
386    }
387}