Skip to main content

lash_core/provider/
options.rs

1use super::support::*;
2
/// Fallback whole-request timeout, in milliseconds (five minutes).
pub const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 5 * 60 * 1_000;
/// Fallback per-chunk timeout, in milliseconds (two minutes).
pub const DEFAULT_CHUNK_TIMEOUT_MS: u64 = 2 * 60 * 1_000;
5
/// Resolved timeout values expressed as [`Duration`]s.
///
/// `request_timeout` is optional, while `chunk_timeout` is always present.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct LlmTimeouts {
    /// Limit for the whole request; `None` when no request-level limit is set.
    pub request_timeout: Option<Duration>,
    /// Limit associated with a single chunk.
    pub chunk_timeout: Duration,
}
11
12impl Default for LlmTimeouts {
13    fn default() -> Self {
14        Self {
15            request_timeout: Some(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS)),
16            chunk_timeout: Duration::from_millis(DEFAULT_CHUNK_TIMEOUT_MS),
17        }
18    }
19}
20
/// Request-shaping parameters produced by a provider for a given model and
/// variant.
///
/// Variant strings are provider-specific: every concrete provider crate
/// interprets its own and emits whichever of these parameters its wire
/// protocol requires.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum VariantRequestConfig {
    /// A reasoning-effort label carried as a free-form string.
    ReasoningEffort(String),
    /// Google thinking level, given by name.
    GoogleThinkingLevel { level: String },
    /// Google thinking budget, expressed in tokens.
    GoogleThinkingBudget { budget_tokens: i32 },
    /// Anthropic adaptive thinking with an effort label.
    AnthropicAdaptiveThinking { effort: String },
    /// Anthropic thinking budget, expressed in tokens.
    AnthropicThinkingBudget { budget_tokens: i32 },
}
32
/// A model name paired with an optional variant, as returned by provider
/// model policies for agent tiers.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct AgentModelSelection {
    /// Name of the selected model.
    pub model: String,
    /// Variant of the model, when one was chosen.
    pub variant: Option<String>,
}
39
/// Request-timeout setting: either switched off or a millisecond count.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum RequestTimeout {
    /// No request timeout.
    Disabled,
    /// Request timeout in milliseconds.
    Millis(u64),
}
45
46impl Serialize for RequestTimeout {
47    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
48    where
49        S: Serializer,
50    {
51        match self {
52            Self::Disabled => serializer.serialize_bool(false),
53            Self::Millis(value) => serializer.serialize_u64(*value),
54        }
55    }
56}
57
58impl<'de> Deserialize<'de> for RequestTimeout {
59    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
60    where
61        D: Deserializer<'de>,
62    {
63        struct RequestTimeoutVisitor;
64
65        impl Visitor<'_> for RequestTimeoutVisitor {
66            type Value = RequestTimeout;
67
68            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69                formatter.write_str("a positive timeout in milliseconds or false")
70            }
71
72            fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
73            where
74                E: de::Error,
75            {
76                if value {
77                    return Err(E::custom("timeout must be a positive integer or false"));
78                }
79                Ok(RequestTimeout::Disabled)
80            }
81
82            fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
83            where
84                E: de::Error,
85            {
86                if value == 0 {
87                    return Err(E::custom("timeout must be greater than 0"));
88                }
89                Ok(RequestTimeout::Millis(value))
90            }
91
92            fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
93            where
94                E: de::Error,
95            {
96                if value <= 0 {
97                    return Err(E::custom("timeout must be greater than 0"));
98                }
99                Ok(RequestTimeout::Millis(value as u64))
100            }
101        }
102
103        deserializer.deserialize_any(RequestTimeoutVisitor)
104    }
105}
106
107/// Prompt-cache lifetime hint. Providers translate this into their own
108/// wire dialect (Anthropic `cache_control` TTL, OpenRouter-Claude
109/// `cache_control` markers via Chat Completions, OpenAI Responses
110/// `prompt_cache_key` / `prompt_cache_retention`). Providers without a
111/// cache-control concept (Google, Codex) read the value but emit nothing
112/// for it.
113#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
114#[serde(rename_all = "kebab-case")]
115pub enum CacheRetention {
116    /// Do not emit any cache_control markers.
117    None,
118    /// Default Anthropic ephemeral window (5 minutes).
119    #[default]
120    Short,
121    /// Extend to a 1-hour TTL where the API supports it.
122    Long,
123}
124
125impl CacheRetention {
126    pub fn is_default(&self) -> bool {
127        matches!(self, CacheRetention::Short)
128    }
129}
130
131#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
132#[serde(deny_unknown_fields)]
133pub struct ProviderOptions {
134    #[serde(default)]
135    pub reliability: ProviderReliability,
136    #[serde(default, skip_serializing_if = "ProviderThinkingPolicy::is_default")]
137    pub thinking: ProviderThinkingPolicy,
138    /// Per-request output-token cap. `None` lets each provider apply its
139    /// own default. Providers translate to their wire-specific field
140    /// (`max_tokens`, `max_output_tokens`, `maxOutputTokens`, …).
141    #[serde(default, skip_serializing_if = "Option::is_none")]
142    pub max_output_tokens: Option<u64>,
143    /// Prompt-cache lifetime hint; see [`CacheRetention`].
144    #[serde(default, skip_serializing_if = "CacheRetention::is_default")]
145    pub cache_retention: CacheRetention,
146}
147
148impl ProviderOptions {
149    pub fn is_default(&self) -> bool {
150        self.reliability == ProviderReliability::default_llm()
151            && self.thinking == ProviderThinkingPolicy::default()
152            && self.max_output_tokens.is_none()
153            && self.cache_retention.is_default()
154    }
155
156    pub fn llm_timeouts(&self) -> LlmTimeouts {
157        self.reliability.timeouts.llm_timeouts()
158    }
159}
160
161#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
162#[serde(deny_unknown_fields)]
163pub struct ProviderThinkingPolicy {
164    #[serde(default)]
165    pub expose: bool,
166}
167
168impl ProviderThinkingPolicy {
169    pub fn is_default(&self) -> bool {
170        !self.expose
171    }
172}
173
174#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
175pub struct ProviderReliability {
176    #[serde(default)]
177    pub timeouts: ProviderTimeoutPolicy,
178    #[serde(default)]
179    pub retry: ProviderRetryPolicy,
180    #[serde(default)]
181    pub rate_limits: ProviderRateLimitPolicy,
182}
183
184impl ProviderReliability {
185    pub fn default_llm() -> Self {
186        Self {
187            timeouts: ProviderTimeoutPolicy::default(),
188            retry: ProviderRetryPolicy::default(),
189            rate_limits: ProviderRateLimitPolicy::default(),
190        }
191    }
192
193    pub fn codex() -> Self {
194        Self {
195            retry: ProviderRetryPolicy {
196                max_attempts: 4,
197                base_delay_ms: 1_000,
198                max_delay_ms: 4_000,
199                jitter_ms: 0,
200                retry_after_cap_ms: Some(60_000),
201                enabled: true,
202            },
203            ..Self::default_llm()
204        }
205    }
206
207    pub fn disabled() -> Self {
208        Self {
209            retry: ProviderRetryPolicy::disabled(),
210            rate_limits: ProviderRateLimitPolicy::default(),
211            timeouts: ProviderTimeoutPolicy::default(),
212        }
213    }
214
215    pub fn builder() -> ProviderReliabilityBuilder {
216        ProviderReliabilityBuilder {
217            reliability: Self::default_llm(),
218        }
219    }
220}
221
222impl Default for ProviderReliability {
223    fn default() -> Self {
224        Self::default_llm()
225    }
226}
227
228#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
229pub struct ProviderTimeoutPolicy {
230    #[serde(default, skip_serializing_if = "Option::is_none")]
231    pub request_timeout: Option<RequestTimeout>,
232    #[serde(default, skip_serializing_if = "Option::is_none")]
233    pub chunk_timeout: Option<u64>,
234}
235
236impl ProviderTimeoutPolicy {
237    pub fn llm_timeouts(&self) -> LlmTimeouts {
238        let request_timeout = match self.request_timeout {
239            Some(RequestTimeout::Disabled) => None,
240            Some(RequestTimeout::Millis(ms)) => Some(Duration::from_millis(ms)),
241            None => Some(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS)),
242        };
243        let chunk_timeout_ms = self
244            .chunk_timeout
245            .filter(|value| *value > 0)
246            .unwrap_or(DEFAULT_CHUNK_TIMEOUT_MS);
247        LlmTimeouts {
248            request_timeout,
249            chunk_timeout: Duration::from_millis(chunk_timeout_ms),
250        }
251    }
252}
253
254#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
255pub struct ProviderRetryPolicy {
256    pub enabled: bool,
257    pub max_attempts: u32,
258    pub base_delay_ms: u64,
259    pub max_delay_ms: u64,
260    pub jitter_ms: u64,
261    #[serde(default, skip_serializing_if = "Option::is_none")]
262    pub retry_after_cap_ms: Option<u64>,
263}
264
265impl Default for ProviderRetryPolicy {
266    fn default() -> Self {
267        Self {
268            enabled: true,
269            max_attempts: 4,
270            base_delay_ms: 2_000,
271            max_delay_ms: 10_000,
272            jitter_ms: 0,
273            retry_after_cap_ms: Some(60_000),
274        }
275    }
276}
277
278impl ProviderRetryPolicy {
279    pub fn disabled() -> Self {
280        Self {
281            enabled: false,
282            max_attempts: 1,
283            base_delay_ms: 0,
284            max_delay_ms: 0,
285            jitter_ms: 0,
286            retry_after_cap_ms: None,
287        }
288    }
289
290    pub(crate) fn attempts(&self) -> u32 {
291        if self.enabled {
292            self.max_attempts.max(1)
293        } else {
294            1
295        }
296    }
297
298    pub(crate) fn delay_for_attempt(
299        &self,
300        retry_index: u32,
301        retry_after: Option<Duration>,
302    ) -> Duration {
303        if let Some(retry_after) = retry_after {
304            return self
305                .retry_after_cap_ms
306                .map(Duration::from_millis)
307                .map(|cap| retry_after.min(cap))
308                .unwrap_or(retry_after);
309        }
310        let multiplier = 1u64.checked_shl(retry_index).unwrap_or(u64::MAX);
311        let delay_ms = self
312            .base_delay_ms
313            .saturating_mul(multiplier)
314            .min(self.max_delay_ms);
315        Duration::from_millis(delay_ms.saturating_add(self.jitter_ms))
316    }
317}
318
319#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
320pub struct ProviderRateLimitPolicy {
321    #[serde(default, skip_serializing_if = "Option::is_none")]
322    pub max_concurrency: Option<usize>,
323    #[serde(default, skip_serializing_if = "Option::is_none")]
324    pub requests_per_window: Option<u32>,
325    #[serde(default, skip_serializing_if = "Option::is_none")]
326    pub request_window_ms: Option<u64>,
327    #[serde(default, skip_serializing_if = "Option::is_none")]
328    pub tokens_per_window: Option<u32>,
329    #[serde(default, skip_serializing_if = "Option::is_none")]
330    pub token_window_ms: Option<u64>,
331}
332
333pub struct ProviderReliabilityBuilder {
334    reliability: ProviderReliability,
335}
336
337impl ProviderReliabilityBuilder {
338    pub fn request_timeout(mut self, timeout: Option<RequestTimeout>) -> Self {
339        self.reliability.timeouts.request_timeout = timeout;
340        self
341    }
342
343    pub fn stream_chunk_timeout_ms(mut self, timeout_ms: Option<u64>) -> Self {
344        self.reliability.timeouts.chunk_timeout = timeout_ms;
345        self
346    }
347
348    pub fn max_attempts(mut self, attempts: u32) -> Self {
349        self.reliability.retry.max_attempts = attempts.max(1);
350        self
351    }
352
353    pub fn base_delay_ms(mut self, delay_ms: u64) -> Self {
354        self.reliability.retry.base_delay_ms = delay_ms;
355        self
356    }
357
358    pub fn max_delay_ms(mut self, delay_ms: u64) -> Self {
359        self.reliability.retry.max_delay_ms = delay_ms;
360        self
361    }
362
363    pub fn retry_after_cap_ms(mut self, cap_ms: Option<u64>) -> Self {
364        self.reliability.retry.retry_after_cap_ms = cap_ms;
365        self
366    }
367
368    pub fn max_concurrency(mut self, value: Option<usize>) -> Self {
369        self.reliability.rate_limits.max_concurrency = value;
370        self
371    }
372
373    pub fn requests_per_window(mut self, requests: Option<u32>, window_ms: Option<u64>) -> Self {
374        self.reliability.rate_limits.requests_per_window = requests;
375        self.reliability.rate_limits.request_window_ms = window_ms;
376        self
377    }
378
379    pub fn tokens_per_window(mut self, tokens: Option<u32>, window_ms: Option<u64>) -> Self {
380        self.reliability.rate_limits.tokens_per_window = tokens;
381        self.reliability.rate_limits.token_window_ms = window_ms;
382        self
383    }
384
385    pub fn build(self) -> ProviderReliability {
386        self.reliability
387    }
388}