Skip to main content

lash_core/provider/
options.rs

1use super::support::*;
2
/// Whole-request timeout, in milliseconds, applied when
/// `ProviderReliability::request_timeout` is `None`.
pub const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 300_000;
/// Inter-chunk stream timeout, in milliseconds, applied when
/// `ProviderReliability::chunk_timeout` is `None` or `0`.
pub const DEFAULT_CHUNK_TIMEOUT_MS: u64 = 120_000;
/// Default value, in milliseconds, of
/// [`ProviderRetryPolicy::throttle_wait_budget_ms`].
pub const DEFAULT_THROTTLE_WAIT_BUDGET_MS: u64 = 90_000;

/// Minimum amount a single deferred throttle wait charges against
/// [`ProviderRetryPolicy::throttle_wait_budget_ms`]. Charging at least this
/// much per deference bounds the courtesy in count as well as time: a zero or
/// near-zero `Retry-After` (or a zero [`retry_after_cap_ms`]) cannot spin the
/// attempt-free retry loop more than `budget / charge` times.
///
/// [`retry_after_cap_ms`]: ProviderRetryPolicy::retry_after_cap_ms
pub(crate) const MIN_THROTTLE_BUDGET_CHARGE: Duration = Duration::from_secs(1);
15
16#[derive(Clone, Copy, Debug, PartialEq, Eq)]
17pub struct LlmTimeouts {
18    pub request_timeout: Option<Duration>,
19    pub chunk_timeout: Duration,
20}
21
22impl Default for LlmTimeouts {
23    fn default() -> Self {
24        Self {
25            request_timeout: Some(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS)),
26            chunk_timeout: Duration::from_millis(DEFAULT_CHUNK_TIMEOUT_MS),
27        }
28    }
29}
30
/// Whole-request timeout setting.
///
/// On the wire this is the boolean `false` for [`RequestTimeout::Disabled`]
/// and an unsigned integer for [`RequestTimeout::Millis`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RequestTimeout {
    /// No whole-request timeout.
    Disabled,
    /// Timeout in milliseconds. Deserialization rejects values `<= 0`.
    Millis(u64),
}
36
37impl Serialize for RequestTimeout {
38    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
39    where
40        S: Serializer,
41    {
42        match self {
43            Self::Disabled => serializer.serialize_bool(false),
44            Self::Millis(value) => serializer.serialize_u64(*value),
45        }
46    }
47}
48
49impl<'de> Deserialize<'de> for RequestTimeout {
50    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
51    where
52        D: Deserializer<'de>,
53    {
54        struct RequestTimeoutVisitor;
55
56        impl Visitor<'_> for RequestTimeoutVisitor {
57            type Value = RequestTimeout;
58
59            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60                formatter.write_str("a positive timeout in milliseconds or false")
61            }
62
63            fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
64            where
65                E: de::Error,
66            {
67                if value {
68                    return Err(E::custom("timeout must be a positive integer or false"));
69                }
70                Ok(RequestTimeout::Disabled)
71            }
72
73            fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
74            where
75                E: de::Error,
76            {
77                if value == 0 {
78                    return Err(E::custom("timeout must be greater than 0"));
79                }
80                Ok(RequestTimeout::Millis(value))
81            }
82
83            fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
84            where
85                E: de::Error,
86            {
87                if value <= 0 {
88                    return Err(E::custom("timeout must be greater than 0"));
89                }
90                Ok(RequestTimeout::Millis(value as u64))
91            }
92        }
93
94        deserializer.deserialize_any(RequestTimeoutVisitor)
95    }
96}
97
/// Prompt-cache lifetime hint. Providers translate this into their own
/// wire dialect (Anthropic `cache_control` TTL, OpenRouter-Claude
/// `cache_control` markers via Chat Completions, OpenAI Responses
/// `prompt_cache_key` / `prompt_cache_retention`). Providers without a
/// cache-control concept (Google, Codex) read the value but emit nothing
/// for it.
///
/// Variants are (de)serialized in kebab-case: `none`, `short`, `long`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum CacheRetention {
    /// Do not emit any cache_control markers.
    None,
    /// Default Anthropic ephemeral window (5 minutes). This is the
    /// [`Default`] variant.
    #[default]
    Short,
    /// Extend to a 1-hour TTL where the API supports it.
    Long,
}
115
116impl CacheRetention {
117    pub fn is_default(&self) -> bool {
118        matches!(self, CacheRetention::Short)
119    }
120}
121
/// Provider-level request options.
///
/// Deserialization rejects unknown fields; every field is optional in the
/// input and falls back to its default.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ProviderOptions {
    /// Timeout, retry, and rate-limit settings.
    #[serde(default)]
    pub reliability: ProviderReliability,
    /// Surface provider reasoning/thinking output in responses.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub expose_thinking: bool,
    /// Per-request output-token cap. `None` lets each provider apply its
    /// own default. Providers translate to their wire-specific field
    /// (`max_tokens`, `max_output_tokens`, `maxOutputTokens`, …).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<u64>,
    /// Prompt-cache lifetime hint; see [`CacheRetention`].
    #[serde(default, skip_serializing_if = "CacheRetention::is_default")]
    pub cache_retention: CacheRetention,
}
139
140impl ProviderOptions {
141    pub fn is_default(&self) -> bool {
142        self.reliability == ProviderReliability::default()
143            && !self.expose_thinking
144            && self.max_output_tokens.is_none()
145            && self.cache_retention.is_default()
146    }
147
148    pub fn llm_timeouts(&self) -> LlmTimeouts {
149        self.reliability.llm_timeouts()
150    }
151}
152
/// Generation settings produced by [`resolve_generation_policy`].
///
/// `TThinking` is the caller-chosen representation of the thinking
/// configuration; it is stored exactly as supplied.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ResolvedGenerationPolicy<TThinking> {
    /// Output-token cap after applying the precedence rules of
    /// [`resolve_generation_policy`].
    pub max_output_tokens: u64,
    /// Copied from [`ProviderOptions::cache_retention`].
    pub cache_retention: CacheRetention,
    /// Copied from [`ProviderOptions::expose_thinking`].
    pub expose_thinking: bool,
    /// Caller-supplied thinking configuration.
    pub thinking: TThinking,
}
160
161pub fn resolve_generation_policy<TThinking>(
162    generation: &crate::GenerationOptions,
163    options: &ProviderOptions,
164    provider_default_max_output_tokens: u64,
165    thinking: TThinking,
166) -> ResolvedGenerationPolicy<TThinking> {
167    let max_output_tokens = generation
168        .output_token_cap_u64()
169        .or(options.max_output_tokens)
170        .unwrap_or(provider_default_max_output_tokens);
171    ResolvedGenerationPolicy {
172        max_output_tokens,
173        cache_retention: options.cache_retention,
174        expose_thinking: options.expose_thinking,
175        thinking,
176    }
177}
178
/// Timeout, retry, and rate-limit settings for a provider.
#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
pub struct ProviderReliability {
    /// Whole-request timeout. `None` applies [`DEFAULT_REQUEST_TIMEOUT_MS`];
    /// use [`RequestTimeout::Disabled`] to wait indefinitely.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub request_timeout: Option<RequestTimeout>,
    /// Inter-chunk stream timeout in milliseconds. `None` (or `0`) applies
    /// [`DEFAULT_CHUNK_TIMEOUT_MS`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub chunk_timeout: Option<u64>,
    /// Retry and backoff policy; defaults to [`ProviderRetryPolicy::default`]
    /// when absent from the input.
    #[serde(default)]
    pub retry: ProviderRetryPolicy,
    /// Rate-limit settings; every limit is unset by default.
    #[serde(default)]
    pub rate_limits: ProviderRateLimitPolicy,
}
194
195impl ProviderReliability {
196    pub fn codex() -> Self {
197        Self {
198            retry: ProviderRetryPolicy {
199                max_attempts: 4,
200                base_delay_ms: 1_000,
201                max_delay_ms: 4_000,
202                jitter_ms: 0,
203                retry_after_cap_ms: Some(60_000),
204                throttle_wait_budget_ms: DEFAULT_THROTTLE_WAIT_BUDGET_MS,
205                enabled: true,
206            },
207            ..Self::default()
208        }
209    }
210
211    pub fn disabled() -> Self {
212        Self {
213            retry: ProviderRetryPolicy::disabled(),
214            ..Self::default()
215        }
216    }
217
218    pub fn llm_timeouts(&self) -> LlmTimeouts {
219        let request_timeout = match self.request_timeout {
220            Some(RequestTimeout::Disabled) => None,
221            Some(RequestTimeout::Millis(ms)) => Some(Duration::from_millis(ms)),
222            None => Some(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS)),
223        };
224        let chunk_timeout_ms = self
225            .chunk_timeout
226            .filter(|value| *value > 0)
227            .unwrap_or(DEFAULT_CHUNK_TIMEOUT_MS);
228        LlmTimeouts {
229            request_timeout,
230            chunk_timeout: Duration::from_millis(chunk_timeout_ms),
231        }
232    }
233
234    pub fn request_timeout(mut self, timeout: Option<RequestTimeout>) -> Self {
235        self.request_timeout = timeout;
236        self
237    }
238
239    pub fn stream_chunk_timeout_ms(mut self, timeout_ms: Option<u64>) -> Self {
240        self.chunk_timeout = timeout_ms;
241        self
242    }
243
244    pub fn max_attempts(mut self, attempts: u32) -> Self {
245        self.retry.max_attempts = attempts.max(1);
246        self
247    }
248
249    pub fn base_delay_ms(mut self, delay_ms: u64) -> Self {
250        self.retry.base_delay_ms = delay_ms;
251        self
252    }
253
254    pub fn max_delay_ms(mut self, delay_ms: u64) -> Self {
255        self.retry.max_delay_ms = delay_ms;
256        self
257    }
258
259    pub fn retry_after_cap_ms(mut self, cap_ms: Option<u64>) -> Self {
260        self.retry.retry_after_cap_ms = cap_ms;
261        self
262    }
263
264    pub fn throttle_wait_budget_ms(mut self, budget_ms: u64) -> Self {
265        self.retry.throttle_wait_budget_ms = budget_ms;
266        self
267    }
268
269    pub fn max_concurrency(mut self, value: Option<usize>) -> Self {
270        self.rate_limits.max_concurrency = value;
271        self
272    }
273
274    pub fn requests_per_window(mut self, requests: Option<u32>, window_ms: Option<u64>) -> Self {
275        self.rate_limits.requests_per_window = requests;
276        self.rate_limits.request_window_ms = window_ms;
277        self
278    }
279
280    pub fn tokens_per_window(mut self, tokens: Option<u32>, window_ms: Option<u64>) -> Self {
281        self.rate_limits.tokens_per_window = tokens;
282        self.rate_limits.token_window_ms = window_ms;
283        self
284    }
285}
286
/// Retry and backoff policy for provider calls.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ProviderRetryPolicy {
    /// Master switch. When `false`, `attempts()` reports a single attempt
    /// regardless of `max_attempts`.
    pub enabled: bool,
    /// Total attempts allowed when enabled; `attempts()` raises `0` to `1`.
    pub max_attempts: u32,
    /// Backoff delay for retry index `0`, in milliseconds; doubled for each
    /// further retry index.
    pub base_delay_ms: u64,
    /// Ceiling, in milliseconds, on the exponential backoff (applied before
    /// `jitter_ms` is added).
    pub max_delay_ms: u64,
    /// Milliseconds added to the clamped backoff by `delay_for_attempt`.
    pub jitter_ms: u64,
    /// Upper bound, in milliseconds, on a provider-stated `Retry-After`;
    /// `None` leaves it unbounded.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub retry_after_cap_ms: Option<u64>,
    /// Cumulative time [`ProviderHandle::complete`] may spend honoring
    /// provider throttle waits — a retryable [`ProviderFailureKind::Quota`]
    /// failure carrying `Retry-After` — without consuming retry attempts.
    /// Each deferred wait charges what it actually waits, with a one-second
    /// floor per deference, so the courtesy is bounded in count as well as
    /// time. Once the budget is spent, throttled failures consume attempts
    /// like any other retryable failure. `0` disables the deference entirely.
    #[serde(
        default = "default_throttle_wait_budget_ms",
        skip_serializing_if = "is_default_throttle_wait_budget_ms"
    )]
    pub throttle_wait_budget_ms: u64,
}
309
310fn default_throttle_wait_budget_ms() -> u64 {
311    DEFAULT_THROTTLE_WAIT_BUDGET_MS
312}
313
314fn is_default_throttle_wait_budget_ms(budget_ms: &u64) -> bool {
315    *budget_ms == DEFAULT_THROTTLE_WAIT_BUDGET_MS
316}
317
318impl Default for ProviderRetryPolicy {
319    fn default() -> Self {
320        Self {
321            enabled: true,
322            max_attempts: 4,
323            base_delay_ms: 2_000,
324            max_delay_ms: 10_000,
325            jitter_ms: 0,
326            retry_after_cap_ms: Some(60_000),
327            throttle_wait_budget_ms: DEFAULT_THROTTLE_WAIT_BUDGET_MS,
328        }
329    }
330}
331
332impl ProviderRetryPolicy {
333    pub fn disabled() -> Self {
334        Self {
335            enabled: false,
336            max_attempts: 1,
337            base_delay_ms: 0,
338            max_delay_ms: 0,
339            jitter_ms: 0,
340            retry_after_cap_ms: None,
341            throttle_wait_budget_ms: 0,
342        }
343    }
344
345    pub(crate) fn attempts(&self) -> u32 {
346        if self.enabled {
347            self.max_attempts.max(1)
348        } else {
349            1
350        }
351    }
352
353    /// A provider-stated `Retry-After`, bounded by
354    /// [`retry_after_cap_ms`](Self::retry_after_cap_ms) when one is set.
355    pub(crate) fn cap_retry_after(&self, retry_after: Duration) -> Duration {
356        self.retry_after_cap_ms
357            .map(Duration::from_millis)
358            .map(|cap| retry_after.min(cap))
359            .unwrap_or(retry_after)
360    }
361
362    pub(crate) fn delay_for_attempt(
363        &self,
364        retry_index: u32,
365        retry_after: Option<Duration>,
366    ) -> Duration {
367        if let Some(retry_after) = retry_after {
368            return self.cap_retry_after(retry_after);
369        }
370        let multiplier = 1u64.checked_shl(retry_index).unwrap_or(u64::MAX);
371        let delay_ms = self
372            .base_delay_ms
373            .saturating_mul(multiplier)
374            .min(self.max_delay_ms);
375        Duration::from_millis(delay_ms.saturating_add(self.jitter_ms))
376    }
377}
378
/// Optional rate-limit settings. Every field defaults to `None` and is
/// omitted from serialized output while unset.
///
/// NOTE(review): how these limits are enforced is not visible in this file;
/// the per-field notes below are inferred from the field names and the
/// `ProviderReliability` builder methods — confirm against the enforcing code.
#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
pub struct ProviderRateLimitPolicy {
    /// Presumably the maximum number of concurrent requests.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_concurrency: Option<usize>,
    /// Request-count limit; set together with `request_window_ms` by
    /// `ProviderReliability::requests_per_window`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub requests_per_window: Option<u32>,
    /// Window length, in milliseconds, for `requests_per_window`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub request_window_ms: Option<u64>,
    /// Token-count limit; set together with `token_window_ms` by
    /// `ProviderReliability::tokens_per_window`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tokens_per_window: Option<u32>,
    /// Window length, in milliseconds, for `tokens_per_window`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub token_window_ms: Option<u64>,
}