Skip to main content

lash_core/provider/
options.rs

1use super::support::*;
2
/// Request timeout, in milliseconds, applied when none is configured
/// (five minutes).
pub const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 5 * 60 * 1_000;
/// Chunk timeout, in milliseconds, applied when none is configured
/// (two minutes).
pub const DEFAULT_CHUNK_TIMEOUT_MS: u64 = 2 * 60 * 1_000;
5
/// Concrete timeout durations resolved from a timeout policy.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct LlmTimeouts {
    /// Request timeout; `None` when the request timeout is disabled.
    pub request_timeout: Option<Duration>,
    /// Chunk timeout; always present.
    pub chunk_timeout: Duration,
}
11
12impl Default for LlmTimeouts {
13    fn default() -> Self {
14        Self {
15            request_timeout: Some(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS)),
16            chunk_timeout: Duration::from_millis(DEFAULT_CHUNK_TIMEOUT_MS),
17        }
18    }
19}
20
/// Request-timeout setting: either switched off or a millisecond count.
///
/// On the wire this is the boolean `false` for [`RequestTimeout::Disabled`]
/// and an unsigned integer for [`RequestTimeout::Millis`].
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum RequestTimeout {
    /// No request timeout.
    Disabled,
    /// Timeout length in milliseconds.
    Millis(u64),
}
26
27impl Serialize for RequestTimeout {
28    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
29    where
30        S: Serializer,
31    {
32        match self {
33            Self::Disabled => serializer.serialize_bool(false),
34            Self::Millis(value) => serializer.serialize_u64(*value),
35        }
36    }
37}
38
39impl<'de> Deserialize<'de> for RequestTimeout {
40    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
41    where
42        D: Deserializer<'de>,
43    {
44        struct RequestTimeoutVisitor;
45
46        impl Visitor<'_> for RequestTimeoutVisitor {
47            type Value = RequestTimeout;
48
49            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50                formatter.write_str("a positive timeout in milliseconds or false")
51            }
52
53            fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
54            where
55                E: de::Error,
56            {
57                if value {
58                    return Err(E::custom("timeout must be a positive integer or false"));
59                }
60                Ok(RequestTimeout::Disabled)
61            }
62
63            fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
64            where
65                E: de::Error,
66            {
67                if value == 0 {
68                    return Err(E::custom("timeout must be greater than 0"));
69                }
70                Ok(RequestTimeout::Millis(value))
71            }
72
73            fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
74            where
75                E: de::Error,
76            {
77                if value <= 0 {
78                    return Err(E::custom("timeout must be greater than 0"));
79                }
80                Ok(RequestTimeout::Millis(value as u64))
81            }
82        }
83
84        deserializer.deserialize_any(RequestTimeoutVisitor)
85    }
86}
87
88/// Prompt-cache lifetime hint. Providers translate this into their own
89/// wire dialect (Anthropic `cache_control` TTL, OpenRouter-Claude
90/// `cache_control` markers via Chat Completions, OpenAI Responses
91/// `prompt_cache_key` / `prompt_cache_retention`). Providers without a
92/// cache-control concept (Google, Codex) read the value but emit nothing
93/// for it.
94#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
95#[serde(rename_all = "kebab-case")]
96pub enum CacheRetention {
97    /// Do not emit any cache_control markers.
98    None,
99    /// Default Anthropic ephemeral window (5 minutes).
100    #[default]
101    Short,
102    /// Extend to a 1-hour TTL where the API supports it.
103    Long,
104}
105
106impl CacheRetention {
107    pub fn is_default(&self) -> bool {
108        matches!(self, CacheRetention::Short)
109    }
110}
111
112#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
113#[serde(deny_unknown_fields)]
114pub struct ProviderOptions {
115    #[serde(default)]
116    pub reliability: ProviderReliability,
117    #[serde(default, skip_serializing_if = "ProviderThinkingPolicy::is_default")]
118    pub thinking: ProviderThinkingPolicy,
119    /// Per-request output-token cap. `None` lets each provider apply its
120    /// own default. Providers translate to their wire-specific field
121    /// (`max_tokens`, `max_output_tokens`, `maxOutputTokens`, …).
122    #[serde(default, skip_serializing_if = "Option::is_none")]
123    pub max_output_tokens: Option<u64>,
124    /// Prompt-cache lifetime hint; see [`CacheRetention`].
125    #[serde(default, skip_serializing_if = "CacheRetention::is_default")]
126    pub cache_retention: CacheRetention,
127}
128
129impl ProviderOptions {
130    pub fn is_default(&self) -> bool {
131        self.reliability == ProviderReliability::default_llm()
132            && self.thinking == ProviderThinkingPolicy::default()
133            && self.max_output_tokens.is_none()
134            && self.cache_retention.is_default()
135    }
136
137    pub fn llm_timeouts(&self) -> LlmTimeouts {
138        self.reliability.timeouts.llm_timeouts()
139    }
140}
141
142#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
143#[serde(deny_unknown_fields)]
144pub struct ProviderThinkingPolicy {
145    #[serde(default)]
146    pub expose: bool,
147}
148
149impl ProviderThinkingPolicy {
150    pub fn is_default(&self) -> bool {
151        !self.expose
152    }
153}
154
155#[derive(Clone, Debug, PartialEq, Eq)]
156pub struct ResolvedGenerationPolicy<TThinking> {
157    pub max_output_tokens: u64,
158    pub cache_retention: CacheRetention,
159    pub expose_thinking: bool,
160    pub thinking: TThinking,
161}
162
163pub fn resolve_generation_policy<TThinking>(
164    generation: &crate::GenerationOptions,
165    options: &ProviderOptions,
166    provider_default_max_output_tokens: u64,
167    thinking: TThinking,
168) -> ResolvedGenerationPolicy<TThinking> {
169    let max_output_tokens = generation
170        .output_token_cap_u64()
171        .or(options.max_output_tokens)
172        .unwrap_or(provider_default_max_output_tokens);
173    ResolvedGenerationPolicy {
174        max_output_tokens,
175        cache_retention: options.cache_retention,
176        expose_thinking: options.thinking.expose,
177        thinking,
178    }
179}
180
181#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
182pub struct ProviderReliability {
183    #[serde(default)]
184    pub timeouts: ProviderTimeoutPolicy,
185    #[serde(default)]
186    pub retry: ProviderRetryPolicy,
187    #[serde(default)]
188    pub rate_limits: ProviderRateLimitPolicy,
189}
190
191impl ProviderReliability {
192    pub fn default_llm() -> Self {
193        Self {
194            timeouts: ProviderTimeoutPolicy::default(),
195            retry: ProviderRetryPolicy::default(),
196            rate_limits: ProviderRateLimitPolicy::default(),
197        }
198    }
199
200    pub fn codex() -> Self {
201        Self {
202            retry: ProviderRetryPolicy {
203                max_attempts: 4,
204                base_delay_ms: 1_000,
205                max_delay_ms: 4_000,
206                jitter_ms: 0,
207                retry_after_cap_ms: Some(60_000),
208                enabled: true,
209            },
210            ..Self::default_llm()
211        }
212    }
213
214    pub fn disabled() -> Self {
215        Self {
216            retry: ProviderRetryPolicy::disabled(),
217            rate_limits: ProviderRateLimitPolicy::default(),
218            timeouts: ProviderTimeoutPolicy::default(),
219        }
220    }
221
222    pub fn builder() -> ProviderReliabilityBuilder {
223        ProviderReliabilityBuilder {
224            reliability: Self::default_llm(),
225        }
226    }
227}
228
229impl Default for ProviderReliability {
230    fn default() -> Self {
231        Self::default_llm()
232    }
233}
234
235#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
236pub struct ProviderTimeoutPolicy {
237    #[serde(default, skip_serializing_if = "Option::is_none")]
238    pub request_timeout: Option<RequestTimeout>,
239    #[serde(default, skip_serializing_if = "Option::is_none")]
240    pub chunk_timeout: Option<u64>,
241}
242
243impl ProviderTimeoutPolicy {
244    pub fn llm_timeouts(&self) -> LlmTimeouts {
245        let request_timeout = match self.request_timeout {
246            Some(RequestTimeout::Disabled) => None,
247            Some(RequestTimeout::Millis(ms)) => Some(Duration::from_millis(ms)),
248            None => Some(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS)),
249        };
250        let chunk_timeout_ms = self
251            .chunk_timeout
252            .filter(|value| *value > 0)
253            .unwrap_or(DEFAULT_CHUNK_TIMEOUT_MS);
254        LlmTimeouts {
255            request_timeout,
256            chunk_timeout: Duration::from_millis(chunk_timeout_ms),
257        }
258    }
259}
260
261#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
262pub struct ProviderRetryPolicy {
263    pub enabled: bool,
264    pub max_attempts: u32,
265    pub base_delay_ms: u64,
266    pub max_delay_ms: u64,
267    pub jitter_ms: u64,
268    #[serde(default, skip_serializing_if = "Option::is_none")]
269    pub retry_after_cap_ms: Option<u64>,
270}
271
272impl Default for ProviderRetryPolicy {
273    fn default() -> Self {
274        Self {
275            enabled: true,
276            max_attempts: 4,
277            base_delay_ms: 2_000,
278            max_delay_ms: 10_000,
279            jitter_ms: 0,
280            retry_after_cap_ms: Some(60_000),
281        }
282    }
283}
284
285impl ProviderRetryPolicy {
286    pub fn disabled() -> Self {
287        Self {
288            enabled: false,
289            max_attempts: 1,
290            base_delay_ms: 0,
291            max_delay_ms: 0,
292            jitter_ms: 0,
293            retry_after_cap_ms: None,
294        }
295    }
296
297    pub(crate) fn attempts(&self) -> u32 {
298        if self.enabled {
299            self.max_attempts.max(1)
300        } else {
301            1
302        }
303    }
304
305    pub(crate) fn delay_for_attempt(
306        &self,
307        retry_index: u32,
308        retry_after: Option<Duration>,
309    ) -> Duration {
310        if let Some(retry_after) = retry_after {
311            return self
312                .retry_after_cap_ms
313                .map(Duration::from_millis)
314                .map(|cap| retry_after.min(cap))
315                .unwrap_or(retry_after);
316        }
317        let multiplier = 1u64.checked_shl(retry_index).unwrap_or(u64::MAX);
318        let delay_ms = self
319            .base_delay_ms
320            .saturating_mul(multiplier)
321            .min(self.max_delay_ms);
322        Duration::from_millis(delay_ms.saturating_add(self.jitter_ms))
323    }
324}
325
326#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
327pub struct ProviderRateLimitPolicy {
328    #[serde(default, skip_serializing_if = "Option::is_none")]
329    pub max_concurrency: Option<usize>,
330    #[serde(default, skip_serializing_if = "Option::is_none")]
331    pub requests_per_window: Option<u32>,
332    #[serde(default, skip_serializing_if = "Option::is_none")]
333    pub request_window_ms: Option<u64>,
334    #[serde(default, skip_serializing_if = "Option::is_none")]
335    pub tokens_per_window: Option<u32>,
336    #[serde(default, skip_serializing_if = "Option::is_none")]
337    pub token_window_ms: Option<u64>,
338}
339
340pub struct ProviderReliabilityBuilder {
341    reliability: ProviderReliability,
342}
343
344impl ProviderReliabilityBuilder {
345    pub fn request_timeout(mut self, timeout: Option<RequestTimeout>) -> Self {
346        self.reliability.timeouts.request_timeout = timeout;
347        self
348    }
349
350    pub fn stream_chunk_timeout_ms(mut self, timeout_ms: Option<u64>) -> Self {
351        self.reliability.timeouts.chunk_timeout = timeout_ms;
352        self
353    }
354
355    pub fn max_attempts(mut self, attempts: u32) -> Self {
356        self.reliability.retry.max_attempts = attempts.max(1);
357        self
358    }
359
360    pub fn base_delay_ms(mut self, delay_ms: u64) -> Self {
361        self.reliability.retry.base_delay_ms = delay_ms;
362        self
363    }
364
365    pub fn max_delay_ms(mut self, delay_ms: u64) -> Self {
366        self.reliability.retry.max_delay_ms = delay_ms;
367        self
368    }
369
370    pub fn retry_after_cap_ms(mut self, cap_ms: Option<u64>) -> Self {
371        self.reliability.retry.retry_after_cap_ms = cap_ms;
372        self
373    }
374
375    pub fn max_concurrency(mut self, value: Option<usize>) -> Self {
376        self.reliability.rate_limits.max_concurrency = value;
377        self
378    }
379
380    pub fn requests_per_window(mut self, requests: Option<u32>, window_ms: Option<u64>) -> Self {
381        self.reliability.rate_limits.requests_per_window = requests;
382        self.reliability.rate_limits.request_window_ms = window_ms;
383        self
384    }
385
386    pub fn tokens_per_window(mut self, tokens: Option<u32>, window_ms: Option<u64>) -> Self {
387        self.reliability.rate_limits.tokens_per_window = tokens;
388        self.reliability.rate_limits.token_window_ms = window_ms;
389        self
390    }
391
392    pub fn build(self) -> ProviderReliability {
393        self.reliability
394    }
395}