1use super::support::*;
2
/// Request timeout applied when none is configured: 5 minutes, in milliseconds.
pub const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 5 * 60 * 1_000;
/// Per-chunk stream timeout applied when none (or 0) is configured: 2 minutes, in milliseconds.
pub const DEFAULT_CHUNK_TIMEOUT_MS: u64 = 2 * 60 * 1_000;
/// Default throttle wait budget: 90 seconds, in milliseconds.
pub const DEFAULT_THROTTLE_WAIT_BUDGET_MS: u64 = 90 * 1_000;

/// Smallest amount charged against the throttle wait budget.
// NOTE(review): presumably applied per throttled wait — confirm at the call site.
pub(crate) const MIN_THROTTLE_BUDGET_CHARGE: Duration = Duration::from_millis(1_000);
15
16#[derive(Clone, Copy, Debug, PartialEq, Eq)]
17pub struct LlmTimeouts {
18 pub request_timeout: Option<Duration>,
19 pub chunk_timeout: Duration,
20}
21
22impl Default for LlmTimeouts {
23 fn default() -> Self {
24 Self {
25 request_timeout: Some(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS)),
26 chunk_timeout: Duration::from_millis(DEFAULT_CHUNK_TIMEOUT_MS),
27 }
28 }
29}
30
/// Configured request timeout: explicitly disabled, or a duration in milliseconds.
///
/// Serialized as `false` for [`RequestTimeout::Disabled`] and as a bare integer for
/// [`RequestTimeout::Millis`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestTimeout {
    /// No request timeout is applied.
    Disabled,
    /// Timeout in milliseconds.
    Millis(u64),
}
36
37impl Serialize for RequestTimeout {
38 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
39 where
40 S: Serializer,
41 {
42 match self {
43 Self::Disabled => serializer.serialize_bool(false),
44 Self::Millis(value) => serializer.serialize_u64(*value),
45 }
46 }
47}
48
49impl<'de> Deserialize<'de> for RequestTimeout {
50 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
51 where
52 D: Deserializer<'de>,
53 {
54 struct RequestTimeoutVisitor;
55
56 impl Visitor<'_> for RequestTimeoutVisitor {
57 type Value = RequestTimeout;
58
59 fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 formatter.write_str("a positive timeout in milliseconds or false")
61 }
62
63 fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
64 where
65 E: de::Error,
66 {
67 if value {
68 return Err(E::custom("timeout must be a positive integer or false"));
69 }
70 Ok(RequestTimeout::Disabled)
71 }
72
73 fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
74 where
75 E: de::Error,
76 {
77 if value == 0 {
78 return Err(E::custom("timeout must be greater than 0"));
79 }
80 Ok(RequestTimeout::Millis(value))
81 }
82
83 fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
84 where
85 E: de::Error,
86 {
87 if value <= 0 {
88 return Err(E::custom("timeout must be greater than 0"));
89 }
90 Ok(RequestTimeout::Millis(value as u64))
91 }
92 }
93
94 deserializer.deserialize_any(RequestTimeoutVisitor)
95 }
96}
97
98#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
105#[serde(rename_all = "kebab-case")]
106pub enum CacheRetention {
107 None,
109 #[default]
111 Short,
112 Long,
114}
115
116impl CacheRetention {
117 pub fn is_default(&self) -> bool {
118 matches!(self, CacheRetention::Short)
119 }
120}
121
122#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
123#[serde(deny_unknown_fields)]
124pub struct ProviderOptions {
125 #[serde(default)]
126 pub reliability: ProviderReliability,
127 #[serde(default, skip_serializing_if = "std::ops::Not::not")]
129 pub expose_thinking: bool,
130 #[serde(default, skip_serializing_if = "Option::is_none")]
134 pub max_output_tokens: Option<u64>,
135 #[serde(default, skip_serializing_if = "CacheRetention::is_default")]
137 pub cache_retention: CacheRetention,
138}
139
140impl ProviderOptions {
141 pub fn is_default(&self) -> bool {
142 self.reliability == ProviderReliability::default()
143 && !self.expose_thinking
144 && self.max_output_tokens.is_none()
145 && self.cache_retention.is_default()
146 }
147
148 pub fn llm_timeouts(&self) -> LlmTimeouts {
149 self.reliability.llm_timeouts()
150 }
151}
152
153#[derive(Clone, Debug, PartialEq, Eq)]
154pub struct ResolvedGenerationPolicy<TThinking> {
155 pub max_output_tokens: u64,
156 pub cache_retention: CacheRetention,
157 pub expose_thinking: bool,
158 pub thinking: TThinking,
159}
160
161pub fn resolve_generation_policy<TThinking>(
162 generation: &crate::GenerationOptions,
163 options: &ProviderOptions,
164 provider_default_max_output_tokens: u64,
165 thinking: TThinking,
166) -> ResolvedGenerationPolicy<TThinking> {
167 let max_output_tokens = generation
168 .output_token_cap_u64()
169 .or(options.max_output_tokens)
170 .unwrap_or(provider_default_max_output_tokens);
171 ResolvedGenerationPolicy {
172 max_output_tokens,
173 cache_retention: options.cache_retention,
174 expose_thinking: options.expose_thinking,
175 thinking,
176 }
177}
178
179#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
180pub struct ProviderReliability {
181 #[serde(default, skip_serializing_if = "Option::is_none")]
184 pub request_timeout: Option<RequestTimeout>,
185 #[serde(default, skip_serializing_if = "Option::is_none")]
188 pub chunk_timeout: Option<u64>,
189 #[serde(default)]
190 pub retry: ProviderRetryPolicy,
191 #[serde(default)]
192 pub rate_limits: ProviderRateLimitPolicy,
193}
194
195impl ProviderReliability {
196 pub fn codex() -> Self {
197 Self {
198 retry: ProviderRetryPolicy {
199 max_attempts: 4,
200 base_delay_ms: 1_000,
201 max_delay_ms: 4_000,
202 jitter_ms: 0,
203 retry_after_cap_ms: Some(60_000),
204 throttle_wait_budget_ms: DEFAULT_THROTTLE_WAIT_BUDGET_MS,
205 enabled: true,
206 },
207 ..Self::default()
208 }
209 }
210
211 pub fn disabled() -> Self {
212 Self {
213 retry: ProviderRetryPolicy::disabled(),
214 ..Self::default()
215 }
216 }
217
218 pub fn llm_timeouts(&self) -> LlmTimeouts {
219 let request_timeout = match self.request_timeout {
220 Some(RequestTimeout::Disabled) => None,
221 Some(RequestTimeout::Millis(ms)) => Some(Duration::from_millis(ms)),
222 None => Some(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS)),
223 };
224 let chunk_timeout_ms = self
225 .chunk_timeout
226 .filter(|value| *value > 0)
227 .unwrap_or(DEFAULT_CHUNK_TIMEOUT_MS);
228 LlmTimeouts {
229 request_timeout,
230 chunk_timeout: Duration::from_millis(chunk_timeout_ms),
231 }
232 }
233
234 pub fn request_timeout(mut self, timeout: Option<RequestTimeout>) -> Self {
235 self.request_timeout = timeout;
236 self
237 }
238
239 pub fn stream_chunk_timeout_ms(mut self, timeout_ms: Option<u64>) -> Self {
240 self.chunk_timeout = timeout_ms;
241 self
242 }
243
244 pub fn max_attempts(mut self, attempts: u32) -> Self {
245 self.retry.max_attempts = attempts.max(1);
246 self
247 }
248
249 pub fn base_delay_ms(mut self, delay_ms: u64) -> Self {
250 self.retry.base_delay_ms = delay_ms;
251 self
252 }
253
254 pub fn max_delay_ms(mut self, delay_ms: u64) -> Self {
255 self.retry.max_delay_ms = delay_ms;
256 self
257 }
258
259 pub fn retry_after_cap_ms(mut self, cap_ms: Option<u64>) -> Self {
260 self.retry.retry_after_cap_ms = cap_ms;
261 self
262 }
263
264 pub fn throttle_wait_budget_ms(mut self, budget_ms: u64) -> Self {
265 self.retry.throttle_wait_budget_ms = budget_ms;
266 self
267 }
268
269 pub fn max_concurrency(mut self, value: Option<usize>) -> Self {
270 self.rate_limits.max_concurrency = value;
271 self
272 }
273
274 pub fn requests_per_window(mut self, requests: Option<u32>, window_ms: Option<u64>) -> Self {
275 self.rate_limits.requests_per_window = requests;
276 self.rate_limits.request_window_ms = window_ms;
277 self
278 }
279
280 pub fn tokens_per_window(mut self, tokens: Option<u32>, window_ms: Option<u64>) -> Self {
281 self.rate_limits.tokens_per_window = tokens;
282 self.rate_limits.token_window_ms = window_ms;
283 self
284 }
285}
286
287#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
288pub struct ProviderRetryPolicy {
289 pub enabled: bool,
290 pub max_attempts: u32,
291 pub base_delay_ms: u64,
292 pub max_delay_ms: u64,
293 pub jitter_ms: u64,
294 #[serde(default, skip_serializing_if = "Option::is_none")]
295 pub retry_after_cap_ms: Option<u64>,
296 #[serde(
304 default = "default_throttle_wait_budget_ms",
305 skip_serializing_if = "is_default_throttle_wait_budget_ms"
306 )]
307 pub throttle_wait_budget_ms: u64,
308}
309
310fn default_throttle_wait_budget_ms() -> u64 {
311 DEFAULT_THROTTLE_WAIT_BUDGET_MS
312}
313
314fn is_default_throttle_wait_budget_ms(budget_ms: &u64) -> bool {
315 *budget_ms == DEFAULT_THROTTLE_WAIT_BUDGET_MS
316}
317
318impl Default for ProviderRetryPolicy {
319 fn default() -> Self {
320 Self {
321 enabled: true,
322 max_attempts: 4,
323 base_delay_ms: 2_000,
324 max_delay_ms: 10_000,
325 jitter_ms: 0,
326 retry_after_cap_ms: Some(60_000),
327 throttle_wait_budget_ms: DEFAULT_THROTTLE_WAIT_BUDGET_MS,
328 }
329 }
330}
331
332impl ProviderRetryPolicy {
333 pub fn disabled() -> Self {
334 Self {
335 enabled: false,
336 max_attempts: 1,
337 base_delay_ms: 0,
338 max_delay_ms: 0,
339 jitter_ms: 0,
340 retry_after_cap_ms: None,
341 throttle_wait_budget_ms: 0,
342 }
343 }
344
345 pub(crate) fn attempts(&self) -> u32 {
346 if self.enabled {
347 self.max_attempts.max(1)
348 } else {
349 1
350 }
351 }
352
353 pub(crate) fn cap_retry_after(&self, retry_after: Duration) -> Duration {
356 self.retry_after_cap_ms
357 .map(Duration::from_millis)
358 .map(|cap| retry_after.min(cap))
359 .unwrap_or(retry_after)
360 }
361
362 pub(crate) fn delay_for_attempt(
363 &self,
364 retry_index: u32,
365 retry_after: Option<Duration>,
366 ) -> Duration {
367 if let Some(retry_after) = retry_after {
368 return self.cap_retry_after(retry_after);
369 }
370 let multiplier = 1u64.checked_shl(retry_index).unwrap_or(u64::MAX);
371 let delay_ms = self
372 .base_delay_ms
373 .saturating_mul(multiplier)
374 .min(self.max_delay_ms);
375 Duration::from_millis(delay_ms.saturating_add(self.jitter_ms))
376 }
377}
378
379#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
380pub struct ProviderRateLimitPolicy {
381 #[serde(default, skip_serializing_if = "Option::is_none")]
382 pub max_concurrency: Option<usize>,
383 #[serde(default, skip_serializing_if = "Option::is_none")]
384 pub requests_per_window: Option<u32>,
385 #[serde(default, skip_serializing_if = "Option::is_none")]
386 pub request_window_ms: Option<u64>,
387 #[serde(default, skip_serializing_if = "Option::is_none")]
388 pub tokens_per_window: Option<u32>,
389 #[serde(default, skip_serializing_if = "Option::is_none")]
390 pub token_window_ms: Option<u64>,
391}