Skip to main content

lash_core/provider/
options.rs

1use super::support::*;
2
/// Whole-request timeout used when none is configured: 5 minutes, in milliseconds.
pub const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 5 * 60 * 1_000;
/// Inter-chunk stream timeout used when none is configured: 2 minutes, in milliseconds.
pub const DEFAULT_CHUNK_TIMEOUT_MS: u64 = 2 * 60 * 1_000;
5
/// Effective timeouts for a single LLM call, expressed as [`Duration`]s.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct LlmTimeouts {
    /// Limit for the request as a whole; `None` means no limit is applied.
    pub request_timeout: Option<Duration>,
    /// Longest permitted gap between streamed chunks.
    pub chunk_timeout: Duration,
}
11
12impl Default for LlmTimeouts {
13    fn default() -> Self {
14        Self {
15            request_timeout: Some(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS)),
16            chunk_timeout: Duration::from_millis(DEFAULT_CHUNK_TIMEOUT_MS),
17        }
18    }
19}
20
/// Configured whole-request timeout: either switched off or a millisecond limit.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum RequestTimeout {
    /// No whole-request timeout.
    Disabled,
    /// Timeout length in milliseconds.
    Millis(u64),
}
26
27impl Serialize for RequestTimeout {
28    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
29    where
30        S: Serializer,
31    {
32        match self {
33            Self::Disabled => serializer.serialize_bool(false),
34            Self::Millis(value) => serializer.serialize_u64(*value),
35        }
36    }
37}
38
39impl<'de> Deserialize<'de> for RequestTimeout {
40    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
41    where
42        D: Deserializer<'de>,
43    {
44        struct RequestTimeoutVisitor;
45
46        impl Visitor<'_> for RequestTimeoutVisitor {
47            type Value = RequestTimeout;
48
49            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50                formatter.write_str("a positive timeout in milliseconds or false")
51            }
52
53            fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
54            where
55                E: de::Error,
56            {
57                if value {
58                    return Err(E::custom("timeout must be a positive integer or false"));
59                }
60                Ok(RequestTimeout::Disabled)
61            }
62
63            fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
64            where
65                E: de::Error,
66            {
67                if value == 0 {
68                    return Err(E::custom("timeout must be greater than 0"));
69                }
70                Ok(RequestTimeout::Millis(value))
71            }
72
73            fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
74            where
75                E: de::Error,
76            {
77                if value <= 0 {
78                    return Err(E::custom("timeout must be greater than 0"));
79                }
80                Ok(RequestTimeout::Millis(value as u64))
81            }
82        }
83
84        deserializer.deserialize_any(RequestTimeoutVisitor)
85    }
86}
87
88/// Prompt-cache lifetime hint. Providers translate this into their own
89/// wire dialect (Anthropic `cache_control` TTL, OpenRouter-Claude
90/// `cache_control` markers via Chat Completions, OpenAI Responses
91/// `prompt_cache_key` / `prompt_cache_retention`). Providers without a
92/// cache-control concept (Google, Codex) read the value but emit nothing
93/// for it.
94#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
95#[serde(rename_all = "kebab-case")]
96pub enum CacheRetention {
97    /// Do not emit any cache_control markers.
98    None,
99    /// Default Anthropic ephemeral window (5 minutes).
100    #[default]
101    Short,
102    /// Extend to a 1-hour TTL where the API supports it.
103    Long,
104}
105
106impl CacheRetention {
107    pub fn is_default(&self) -> bool {
108        matches!(self, CacheRetention::Short)
109    }
110}
111
112#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
113#[serde(deny_unknown_fields)]
114pub struct ProviderOptions {
115    #[serde(default)]
116    pub reliability: ProviderReliability,
117    /// Surface provider reasoning/thinking output in responses.
118    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
119    pub expose_thinking: bool,
120    /// Per-request output-token cap. `None` lets each provider apply its
121    /// own default. Providers translate to their wire-specific field
122    /// (`max_tokens`, `max_output_tokens`, `maxOutputTokens`, …).
123    #[serde(default, skip_serializing_if = "Option::is_none")]
124    pub max_output_tokens: Option<u64>,
125    /// Prompt-cache lifetime hint; see [`CacheRetention`].
126    #[serde(default, skip_serializing_if = "CacheRetention::is_default")]
127    pub cache_retention: CacheRetention,
128}
129
130impl ProviderOptions {
131    pub fn is_default(&self) -> bool {
132        self.reliability == ProviderReliability::default()
133            && !self.expose_thinking
134            && self.max_output_tokens.is_none()
135            && self.cache_retention.is_default()
136    }
137
138    pub fn llm_timeouts(&self) -> LlmTimeouts {
139        self.reliability.llm_timeouts()
140    }
141}
142
143#[derive(Clone, Debug, PartialEq, Eq)]
144pub struct ResolvedGenerationPolicy<TThinking> {
145    pub max_output_tokens: u64,
146    pub cache_retention: CacheRetention,
147    pub expose_thinking: bool,
148    pub thinking: TThinking,
149}
150
151pub fn resolve_generation_policy<TThinking>(
152    generation: &crate::GenerationOptions,
153    options: &ProviderOptions,
154    provider_default_max_output_tokens: u64,
155    thinking: TThinking,
156) -> ResolvedGenerationPolicy<TThinking> {
157    let max_output_tokens = generation
158        .output_token_cap_u64()
159        .or(options.max_output_tokens)
160        .unwrap_or(provider_default_max_output_tokens);
161    ResolvedGenerationPolicy {
162        max_output_tokens,
163        cache_retention: options.cache_retention,
164        expose_thinking: options.expose_thinking,
165        thinking,
166    }
167}
168
169#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
170pub struct ProviderReliability {
171    /// Whole-request timeout. `None` applies [`DEFAULT_REQUEST_TIMEOUT_MS`];
172    /// use [`RequestTimeout::Disabled`] to wait indefinitely.
173    #[serde(default, skip_serializing_if = "Option::is_none")]
174    pub request_timeout: Option<RequestTimeout>,
175    /// Inter-chunk stream timeout in milliseconds. `None` (or `0`) applies
176    /// [`DEFAULT_CHUNK_TIMEOUT_MS`].
177    #[serde(default, skip_serializing_if = "Option::is_none")]
178    pub chunk_timeout: Option<u64>,
179    #[serde(default)]
180    pub retry: ProviderRetryPolicy,
181    #[serde(default)]
182    pub rate_limits: ProviderRateLimitPolicy,
183}
184
185impl ProviderReliability {
186    pub fn codex() -> Self {
187        Self {
188            retry: ProviderRetryPolicy {
189                max_attempts: 4,
190                base_delay_ms: 1_000,
191                max_delay_ms: 4_000,
192                jitter_ms: 0,
193                retry_after_cap_ms: Some(60_000),
194                enabled: true,
195            },
196            ..Self::default()
197        }
198    }
199
200    pub fn disabled() -> Self {
201        Self {
202            retry: ProviderRetryPolicy::disabled(),
203            ..Self::default()
204        }
205    }
206
207    pub fn llm_timeouts(&self) -> LlmTimeouts {
208        let request_timeout = match self.request_timeout {
209            Some(RequestTimeout::Disabled) => None,
210            Some(RequestTimeout::Millis(ms)) => Some(Duration::from_millis(ms)),
211            None => Some(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS)),
212        };
213        let chunk_timeout_ms = self
214            .chunk_timeout
215            .filter(|value| *value > 0)
216            .unwrap_or(DEFAULT_CHUNK_TIMEOUT_MS);
217        LlmTimeouts {
218            request_timeout,
219            chunk_timeout: Duration::from_millis(chunk_timeout_ms),
220        }
221    }
222
223    pub fn request_timeout(mut self, timeout: Option<RequestTimeout>) -> Self {
224        self.request_timeout = timeout;
225        self
226    }
227
228    pub fn stream_chunk_timeout_ms(mut self, timeout_ms: Option<u64>) -> Self {
229        self.chunk_timeout = timeout_ms;
230        self
231    }
232
233    pub fn max_attempts(mut self, attempts: u32) -> Self {
234        self.retry.max_attempts = attempts.max(1);
235        self
236    }
237
238    pub fn base_delay_ms(mut self, delay_ms: u64) -> Self {
239        self.retry.base_delay_ms = delay_ms;
240        self
241    }
242
243    pub fn max_delay_ms(mut self, delay_ms: u64) -> Self {
244        self.retry.max_delay_ms = delay_ms;
245        self
246    }
247
248    pub fn retry_after_cap_ms(mut self, cap_ms: Option<u64>) -> Self {
249        self.retry.retry_after_cap_ms = cap_ms;
250        self
251    }
252
253    pub fn max_concurrency(mut self, value: Option<usize>) -> Self {
254        self.rate_limits.max_concurrency = value;
255        self
256    }
257
258    pub fn requests_per_window(mut self, requests: Option<u32>, window_ms: Option<u64>) -> Self {
259        self.rate_limits.requests_per_window = requests;
260        self.rate_limits.request_window_ms = window_ms;
261        self
262    }
263
264    pub fn tokens_per_window(mut self, tokens: Option<u32>, window_ms: Option<u64>) -> Self {
265        self.rate_limits.tokens_per_window = tokens;
266        self.rate_limits.token_window_ms = window_ms;
267        self
268    }
269}
270
271#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
272pub struct ProviderRetryPolicy {
273    pub enabled: bool,
274    pub max_attempts: u32,
275    pub base_delay_ms: u64,
276    pub max_delay_ms: u64,
277    pub jitter_ms: u64,
278    #[serde(default, skip_serializing_if = "Option::is_none")]
279    pub retry_after_cap_ms: Option<u64>,
280}
281
282impl Default for ProviderRetryPolicy {
283    fn default() -> Self {
284        Self {
285            enabled: true,
286            max_attempts: 4,
287            base_delay_ms: 2_000,
288            max_delay_ms: 10_000,
289            jitter_ms: 0,
290            retry_after_cap_ms: Some(60_000),
291        }
292    }
293}
294
295impl ProviderRetryPolicy {
296    pub fn disabled() -> Self {
297        Self {
298            enabled: false,
299            max_attempts: 1,
300            base_delay_ms: 0,
301            max_delay_ms: 0,
302            jitter_ms: 0,
303            retry_after_cap_ms: None,
304        }
305    }
306
307    pub(crate) fn attempts(&self) -> u32 {
308        if self.enabled {
309            self.max_attempts.max(1)
310        } else {
311            1
312        }
313    }
314
315    pub(crate) fn delay_for_attempt(
316        &self,
317        retry_index: u32,
318        retry_after: Option<Duration>,
319    ) -> Duration {
320        if let Some(retry_after) = retry_after {
321            return self
322                .retry_after_cap_ms
323                .map(Duration::from_millis)
324                .map(|cap| retry_after.min(cap))
325                .unwrap_or(retry_after);
326        }
327        let multiplier = 1u64.checked_shl(retry_index).unwrap_or(u64::MAX);
328        let delay_ms = self
329            .base_delay_ms
330            .saturating_mul(multiplier)
331            .min(self.max_delay_ms);
332        Duration::from_millis(delay_ms.saturating_add(self.jitter_ms))
333    }
334}
335
336#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
337pub struct ProviderRateLimitPolicy {
338    #[serde(default, skip_serializing_if = "Option::is_none")]
339    pub max_concurrency: Option<usize>,
340    #[serde(default, skip_serializing_if = "Option::is_none")]
341    pub requests_per_window: Option<u32>,
342    #[serde(default, skip_serializing_if = "Option::is_none")]
343    pub request_window_ms: Option<u64>,
344    #[serde(default, skip_serializing_if = "Option::is_none")]
345    pub tokens_per_window: Option<u32>,
346    #[serde(default, skip_serializing_if = "Option::is_none")]
347    pub token_window_ms: Option<u64>,
348}