1use super::support::*;
2
/// Fallback request timeout in milliseconds (300 000 ms, i.e. 5 minutes).
pub const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 5 * 60 * 1_000;
/// Fallback chunk timeout in milliseconds (120 000 ms, i.e. 2 minutes).
pub const DEFAULT_CHUNK_TIMEOUT_MS: u64 = 2 * 60 * 1_000;
5
/// Resolved timeouts for an LLM call, expressed as `Duration`s.
///
/// Values of this type are produced by `ProviderReliability::llm_timeouts`
/// and by the `Default` implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LlmTimeouts {
    /// Request timeout; `None` means no request timeout is set.
    pub request_timeout: Option<Duration>,
    /// Chunk timeout.
    // NOTE(review): presumably the longest wait for the next streamed chunk —
    // the enforcement site is not in this file, confirm there.
    pub chunk_timeout: Duration,
}
11
12impl Default for LlmTimeouts {
13 fn default() -> Self {
14 Self {
15 request_timeout: Some(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS)),
16 chunk_timeout: Duration::from_millis(DEFAULT_CHUNK_TIMEOUT_MS),
17 }
18 }
19}
20
/// Configured request timeout: either switched off or a millisecond budget.
///
/// The serialized form is the boolean `false` for [`RequestTimeout::Disabled`]
/// and an integer for [`RequestTimeout::Millis`] (see the hand-written
/// `Serialize` / `Deserialize` implementations in this file).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestTimeout {
    /// No request timeout.
    Disabled,
    /// Request timeout in milliseconds.
    Millis(u64),
}
26
27impl Serialize for RequestTimeout {
28 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
29 where
30 S: Serializer,
31 {
32 match self {
33 Self::Disabled => serializer.serialize_bool(false),
34 Self::Millis(value) => serializer.serialize_u64(*value),
35 }
36 }
37}
38
39impl<'de> Deserialize<'de> for RequestTimeout {
40 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
41 where
42 D: Deserializer<'de>,
43 {
44 struct RequestTimeoutVisitor;
45
46 impl Visitor<'_> for RequestTimeoutVisitor {
47 type Value = RequestTimeout;
48
49 fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 formatter.write_str("a positive timeout in milliseconds or false")
51 }
52
53 fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
54 where
55 E: de::Error,
56 {
57 if value {
58 return Err(E::custom("timeout must be a positive integer or false"));
59 }
60 Ok(RequestTimeout::Disabled)
61 }
62
63 fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
64 where
65 E: de::Error,
66 {
67 if value == 0 {
68 return Err(E::custom("timeout must be greater than 0"));
69 }
70 Ok(RequestTimeout::Millis(value))
71 }
72
73 fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
74 where
75 E: de::Error,
76 {
77 if value <= 0 {
78 return Err(E::custom("timeout must be greater than 0"));
79 }
80 Ok(RequestTimeout::Millis(value as u64))
81 }
82 }
83
84 deserializer.deserialize_any(RequestTimeoutVisitor)
85 }
86}
87
88#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
95#[serde(rename_all = "kebab-case")]
96pub enum CacheRetention {
97 None,
99 #[default]
101 Short,
102 Long,
104}
105
106impl CacheRetention {
107 pub fn is_default(&self) -> bool {
108 matches!(self, CacheRetention::Short)
109 }
110}
111
112#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
113#[serde(deny_unknown_fields)]
114pub struct ProviderOptions {
115 #[serde(default)]
116 pub reliability: ProviderReliability,
117 #[serde(default, skip_serializing_if = "std::ops::Not::not")]
119 pub expose_thinking: bool,
120 #[serde(default, skip_serializing_if = "Option::is_none")]
124 pub max_output_tokens: Option<u64>,
125 #[serde(default, skip_serializing_if = "CacheRetention::is_default")]
127 pub cache_retention: CacheRetention,
128}
129
130impl ProviderOptions {
131 pub fn is_default(&self) -> bool {
132 self.reliability == ProviderReliability::default()
133 && !self.expose_thinking
134 && self.max_output_tokens.is_none()
135 && self.cache_retention.is_default()
136 }
137
138 pub fn llm_timeouts(&self) -> LlmTimeouts {
139 self.reliability.llm_timeouts()
140 }
141}
142
143#[derive(Clone, Debug, PartialEq, Eq)]
144pub struct ResolvedGenerationPolicy<TThinking> {
145 pub max_output_tokens: u64,
146 pub cache_retention: CacheRetention,
147 pub expose_thinking: bool,
148 pub thinking: TThinking,
149}
150
151pub fn resolve_generation_policy<TThinking>(
152 generation: &crate::GenerationOptions,
153 options: &ProviderOptions,
154 provider_default_max_output_tokens: u64,
155 thinking: TThinking,
156) -> ResolvedGenerationPolicy<TThinking> {
157 let max_output_tokens = generation
158 .output_token_cap_u64()
159 .or(options.max_output_tokens)
160 .unwrap_or(provider_default_max_output_tokens);
161 ResolvedGenerationPolicy {
162 max_output_tokens,
163 cache_retention: options.cache_retention,
164 expose_thinking: options.expose_thinking,
165 thinking,
166 }
167}
168
169#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
170pub struct ProviderReliability {
171 #[serde(default, skip_serializing_if = "Option::is_none")]
174 pub request_timeout: Option<RequestTimeout>,
175 #[serde(default, skip_serializing_if = "Option::is_none")]
178 pub chunk_timeout: Option<u64>,
179 #[serde(default)]
180 pub retry: ProviderRetryPolicy,
181 #[serde(default)]
182 pub rate_limits: ProviderRateLimitPolicy,
183}
184
185impl ProviderReliability {
186 pub fn codex() -> Self {
187 Self {
188 retry: ProviderRetryPolicy {
189 max_attempts: 4,
190 base_delay_ms: 1_000,
191 max_delay_ms: 4_000,
192 jitter_ms: 0,
193 retry_after_cap_ms: Some(60_000),
194 enabled: true,
195 },
196 ..Self::default()
197 }
198 }
199
200 pub fn disabled() -> Self {
201 Self {
202 retry: ProviderRetryPolicy::disabled(),
203 ..Self::default()
204 }
205 }
206
207 pub fn llm_timeouts(&self) -> LlmTimeouts {
208 let request_timeout = match self.request_timeout {
209 Some(RequestTimeout::Disabled) => None,
210 Some(RequestTimeout::Millis(ms)) => Some(Duration::from_millis(ms)),
211 None => Some(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS)),
212 };
213 let chunk_timeout_ms = self
214 .chunk_timeout
215 .filter(|value| *value > 0)
216 .unwrap_or(DEFAULT_CHUNK_TIMEOUT_MS);
217 LlmTimeouts {
218 request_timeout,
219 chunk_timeout: Duration::from_millis(chunk_timeout_ms),
220 }
221 }
222
223 pub fn request_timeout(mut self, timeout: Option<RequestTimeout>) -> Self {
224 self.request_timeout = timeout;
225 self
226 }
227
228 pub fn stream_chunk_timeout_ms(mut self, timeout_ms: Option<u64>) -> Self {
229 self.chunk_timeout = timeout_ms;
230 self
231 }
232
233 pub fn max_attempts(mut self, attempts: u32) -> Self {
234 self.retry.max_attempts = attempts.max(1);
235 self
236 }
237
238 pub fn base_delay_ms(mut self, delay_ms: u64) -> Self {
239 self.retry.base_delay_ms = delay_ms;
240 self
241 }
242
243 pub fn max_delay_ms(mut self, delay_ms: u64) -> Self {
244 self.retry.max_delay_ms = delay_ms;
245 self
246 }
247
248 pub fn retry_after_cap_ms(mut self, cap_ms: Option<u64>) -> Self {
249 self.retry.retry_after_cap_ms = cap_ms;
250 self
251 }
252
253 pub fn max_concurrency(mut self, value: Option<usize>) -> Self {
254 self.rate_limits.max_concurrency = value;
255 self
256 }
257
258 pub fn requests_per_window(mut self, requests: Option<u32>, window_ms: Option<u64>) -> Self {
259 self.rate_limits.requests_per_window = requests;
260 self.rate_limits.request_window_ms = window_ms;
261 self
262 }
263
264 pub fn tokens_per_window(mut self, tokens: Option<u32>, window_ms: Option<u64>) -> Self {
265 self.rate_limits.tokens_per_window = tokens;
266 self.rate_limits.token_window_ms = window_ms;
267 self
268 }
269}
270
271#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
272pub struct ProviderRetryPolicy {
273 pub enabled: bool,
274 pub max_attempts: u32,
275 pub base_delay_ms: u64,
276 pub max_delay_ms: u64,
277 pub jitter_ms: u64,
278 #[serde(default, skip_serializing_if = "Option::is_none")]
279 pub retry_after_cap_ms: Option<u64>,
280}
281
282impl Default for ProviderRetryPolicy {
283 fn default() -> Self {
284 Self {
285 enabled: true,
286 max_attempts: 4,
287 base_delay_ms: 2_000,
288 max_delay_ms: 10_000,
289 jitter_ms: 0,
290 retry_after_cap_ms: Some(60_000),
291 }
292 }
293}
294
295impl ProviderRetryPolicy {
296 pub fn disabled() -> Self {
297 Self {
298 enabled: false,
299 max_attempts: 1,
300 base_delay_ms: 0,
301 max_delay_ms: 0,
302 jitter_ms: 0,
303 retry_after_cap_ms: None,
304 }
305 }
306
307 pub(crate) fn attempts(&self) -> u32 {
308 if self.enabled {
309 self.max_attempts.max(1)
310 } else {
311 1
312 }
313 }
314
315 pub(crate) fn delay_for_attempt(
316 &self,
317 retry_index: u32,
318 retry_after: Option<Duration>,
319 ) -> Duration {
320 if let Some(retry_after) = retry_after {
321 return self
322 .retry_after_cap_ms
323 .map(Duration::from_millis)
324 .map(|cap| retry_after.min(cap))
325 .unwrap_or(retry_after);
326 }
327 let multiplier = 1u64.checked_shl(retry_index).unwrap_or(u64::MAX);
328 let delay_ms = self
329 .base_delay_ms
330 .saturating_mul(multiplier)
331 .min(self.max_delay_ms);
332 Duration::from_millis(delay_ms.saturating_add(self.jitter_ms))
333 }
334}
335
336#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
337pub struct ProviderRateLimitPolicy {
338 #[serde(default, skip_serializing_if = "Option::is_none")]
339 pub max_concurrency: Option<usize>,
340 #[serde(default, skip_serializing_if = "Option::is_none")]
341 pub requests_per_window: Option<u32>,
342 #[serde(default, skip_serializing_if = "Option::is_none")]
343 pub request_window_ms: Option<u64>,
344 #[serde(default, skip_serializing_if = "Option::is_none")]
345 pub tokens_per_window: Option<u32>,
346 #[serde(default, skip_serializing_if = "Option::is_none")]
347 pub token_window_ms: Option<u64>,
348}