1use super::support::*;
2
/// Request timeout, in milliseconds, used when none is configured (five minutes).
pub const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 5 * 60 * 1_000;
/// Chunk timeout, in milliseconds, used when none is configured (two minutes).
pub const DEFAULT_CHUNK_TIMEOUT_MS: u64 = 2 * 60 * 1_000;
5
/// Concrete timeout values resolved from a timeout policy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LlmTimeouts {
    /// Overall request timeout; `None` when the request timeout is disabled.
    pub request_timeout: Option<Duration>,
    /// Timeout applied per chunk; always present.
    pub chunk_timeout: Duration,
}
11
12impl Default for LlmTimeouts {
13 fn default() -> Self {
14 Self {
15 request_timeout: Some(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS)),
16 chunk_timeout: Duration::from_millis(DEFAULT_CHUNK_TIMEOUT_MS),
17 }
18 }
19}
20
/// Configured request timeout: either switched off or a millisecond count.
///
/// The hand-written serde impls in this file represent `Disabled` as the
/// boolean `false` and `Millis` as a bare integer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestTimeout {
    /// No request timeout is applied.
    Disabled,
    /// Request timeout in milliseconds.
    Millis(u64),
}
26
27impl Serialize for RequestTimeout {
28 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
29 where
30 S: Serializer,
31 {
32 match self {
33 Self::Disabled => serializer.serialize_bool(false),
34 Self::Millis(value) => serializer.serialize_u64(*value),
35 }
36 }
37}
38
39impl<'de> Deserialize<'de> for RequestTimeout {
40 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
41 where
42 D: Deserializer<'de>,
43 {
44 struct RequestTimeoutVisitor;
45
46 impl Visitor<'_> for RequestTimeoutVisitor {
47 type Value = RequestTimeout;
48
49 fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 formatter.write_str("a positive timeout in milliseconds or false")
51 }
52
53 fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
54 where
55 E: de::Error,
56 {
57 if value {
58 return Err(E::custom("timeout must be a positive integer or false"));
59 }
60 Ok(RequestTimeout::Disabled)
61 }
62
63 fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
64 where
65 E: de::Error,
66 {
67 if value == 0 {
68 return Err(E::custom("timeout must be greater than 0"));
69 }
70 Ok(RequestTimeout::Millis(value))
71 }
72
73 fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
74 where
75 E: de::Error,
76 {
77 if value <= 0 {
78 return Err(E::custom("timeout must be greater than 0"));
79 }
80 Ok(RequestTimeout::Millis(value as u64))
81 }
82 }
83
84 deserializer.deserialize_any(RequestTimeoutVisitor)
85 }
86}
87
88#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
95#[serde(rename_all = "kebab-case")]
96pub enum CacheRetention {
97 None,
99 #[default]
101 Short,
102 Long,
104}
105
106impl CacheRetention {
107 pub fn is_default(&self) -> bool {
108 matches!(self, CacheRetention::Short)
109 }
110}
111
112#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
113#[serde(deny_unknown_fields)]
114pub struct ProviderOptions {
115 #[serde(default)]
116 pub reliability: ProviderReliability,
117 #[serde(default, skip_serializing_if = "ProviderThinkingPolicy::is_default")]
118 pub thinking: ProviderThinkingPolicy,
119 #[serde(default, skip_serializing_if = "Option::is_none")]
123 pub max_output_tokens: Option<u64>,
124 #[serde(default, skip_serializing_if = "CacheRetention::is_default")]
126 pub cache_retention: CacheRetention,
127}
128
129impl ProviderOptions {
130 pub fn is_default(&self) -> bool {
131 self.reliability == ProviderReliability::default_llm()
132 && self.thinking == ProviderThinkingPolicy::default()
133 && self.max_output_tokens.is_none()
134 && self.cache_retention.is_default()
135 }
136
137 pub fn llm_timeouts(&self) -> LlmTimeouts {
138 self.reliability.timeouts.llm_timeouts()
139 }
140}
141
142#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
143#[serde(deny_unknown_fields)]
144pub struct ProviderThinkingPolicy {
145 #[serde(default)]
146 pub expose: bool,
147}
148
149impl ProviderThinkingPolicy {
150 pub fn is_default(&self) -> bool {
151 !self.expose
152 }
153}
154
155#[derive(Clone, Debug, PartialEq, Eq)]
156pub struct ResolvedGenerationPolicy<TThinking> {
157 pub max_output_tokens: u64,
158 pub cache_retention: CacheRetention,
159 pub expose_thinking: bool,
160 pub thinking: TThinking,
161}
162
163pub fn resolve_generation_policy<TThinking>(
164 generation: &crate::GenerationOptions,
165 options: &ProviderOptions,
166 provider_default_max_output_tokens: u64,
167 thinking: TThinking,
168) -> ResolvedGenerationPolicy<TThinking> {
169 let max_output_tokens = generation
170 .output_token_cap_u64()
171 .or(options.max_output_tokens)
172 .unwrap_or(provider_default_max_output_tokens);
173 ResolvedGenerationPolicy {
174 max_output_tokens,
175 cache_retention: options.cache_retention,
176 expose_thinking: options.thinking.expose,
177 thinking,
178 }
179}
180
181#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
182pub struct ProviderReliability {
183 #[serde(default)]
184 pub timeouts: ProviderTimeoutPolicy,
185 #[serde(default)]
186 pub retry: ProviderRetryPolicy,
187 #[serde(default)]
188 pub rate_limits: ProviderRateLimitPolicy,
189}
190
191impl ProviderReliability {
192 pub fn default_llm() -> Self {
193 Self {
194 timeouts: ProviderTimeoutPolicy::default(),
195 retry: ProviderRetryPolicy::default(),
196 rate_limits: ProviderRateLimitPolicy::default(),
197 }
198 }
199
200 pub fn codex() -> Self {
201 Self {
202 retry: ProviderRetryPolicy {
203 max_attempts: 4,
204 base_delay_ms: 1_000,
205 max_delay_ms: 4_000,
206 jitter_ms: 0,
207 retry_after_cap_ms: Some(60_000),
208 enabled: true,
209 },
210 ..Self::default_llm()
211 }
212 }
213
214 pub fn disabled() -> Self {
215 Self {
216 retry: ProviderRetryPolicy::disabled(),
217 rate_limits: ProviderRateLimitPolicy::default(),
218 timeouts: ProviderTimeoutPolicy::default(),
219 }
220 }
221
222 pub fn builder() -> ProviderReliabilityBuilder {
223 ProviderReliabilityBuilder {
224 reliability: Self::default_llm(),
225 }
226 }
227}
228
229impl Default for ProviderReliability {
230 fn default() -> Self {
231 Self::default_llm()
232 }
233}
234
235#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
236pub struct ProviderTimeoutPolicy {
237 #[serde(default, skip_serializing_if = "Option::is_none")]
238 pub request_timeout: Option<RequestTimeout>,
239 #[serde(default, skip_serializing_if = "Option::is_none")]
240 pub chunk_timeout: Option<u64>,
241}
242
243impl ProviderTimeoutPolicy {
244 pub fn llm_timeouts(&self) -> LlmTimeouts {
245 let request_timeout = match self.request_timeout {
246 Some(RequestTimeout::Disabled) => None,
247 Some(RequestTimeout::Millis(ms)) => Some(Duration::from_millis(ms)),
248 None => Some(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS)),
249 };
250 let chunk_timeout_ms = self
251 .chunk_timeout
252 .filter(|value| *value > 0)
253 .unwrap_or(DEFAULT_CHUNK_TIMEOUT_MS);
254 LlmTimeouts {
255 request_timeout,
256 chunk_timeout: Duration::from_millis(chunk_timeout_ms),
257 }
258 }
259}
260
261#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
262pub struct ProviderRetryPolicy {
263 pub enabled: bool,
264 pub max_attempts: u32,
265 pub base_delay_ms: u64,
266 pub max_delay_ms: u64,
267 pub jitter_ms: u64,
268 #[serde(default, skip_serializing_if = "Option::is_none")]
269 pub retry_after_cap_ms: Option<u64>,
270}
271
272impl Default for ProviderRetryPolicy {
273 fn default() -> Self {
274 Self {
275 enabled: true,
276 max_attempts: 4,
277 base_delay_ms: 2_000,
278 max_delay_ms: 10_000,
279 jitter_ms: 0,
280 retry_after_cap_ms: Some(60_000),
281 }
282 }
283}
284
285impl ProviderRetryPolicy {
286 pub fn disabled() -> Self {
287 Self {
288 enabled: false,
289 max_attempts: 1,
290 base_delay_ms: 0,
291 max_delay_ms: 0,
292 jitter_ms: 0,
293 retry_after_cap_ms: None,
294 }
295 }
296
297 pub(crate) fn attempts(&self) -> u32 {
298 if self.enabled {
299 self.max_attempts.max(1)
300 } else {
301 1
302 }
303 }
304
305 pub(crate) fn delay_for_attempt(
306 &self,
307 retry_index: u32,
308 retry_after: Option<Duration>,
309 ) -> Duration {
310 if let Some(retry_after) = retry_after {
311 return self
312 .retry_after_cap_ms
313 .map(Duration::from_millis)
314 .map(|cap| retry_after.min(cap))
315 .unwrap_or(retry_after);
316 }
317 let multiplier = 1u64.checked_shl(retry_index).unwrap_or(u64::MAX);
318 let delay_ms = self
319 .base_delay_ms
320 .saturating_mul(multiplier)
321 .min(self.max_delay_ms);
322 Duration::from_millis(delay_ms.saturating_add(self.jitter_ms))
323 }
324}
325
326#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
327pub struct ProviderRateLimitPolicy {
328 #[serde(default, skip_serializing_if = "Option::is_none")]
329 pub max_concurrency: Option<usize>,
330 #[serde(default, skip_serializing_if = "Option::is_none")]
331 pub requests_per_window: Option<u32>,
332 #[serde(default, skip_serializing_if = "Option::is_none")]
333 pub request_window_ms: Option<u64>,
334 #[serde(default, skip_serializing_if = "Option::is_none")]
335 pub tokens_per_window: Option<u32>,
336 #[serde(default, skip_serializing_if = "Option::is_none")]
337 pub token_window_ms: Option<u64>,
338}
339
340pub struct ProviderReliabilityBuilder {
341 reliability: ProviderReliability,
342}
343
344impl ProviderReliabilityBuilder {
345 pub fn request_timeout(mut self, timeout: Option<RequestTimeout>) -> Self {
346 self.reliability.timeouts.request_timeout = timeout;
347 self
348 }
349
350 pub fn stream_chunk_timeout_ms(mut self, timeout_ms: Option<u64>) -> Self {
351 self.reliability.timeouts.chunk_timeout = timeout_ms;
352 self
353 }
354
355 pub fn max_attempts(mut self, attempts: u32) -> Self {
356 self.reliability.retry.max_attempts = attempts.max(1);
357 self
358 }
359
360 pub fn base_delay_ms(mut self, delay_ms: u64) -> Self {
361 self.reliability.retry.base_delay_ms = delay_ms;
362 self
363 }
364
365 pub fn max_delay_ms(mut self, delay_ms: u64) -> Self {
366 self.reliability.retry.max_delay_ms = delay_ms;
367 self
368 }
369
370 pub fn retry_after_cap_ms(mut self, cap_ms: Option<u64>) -> Self {
371 self.reliability.retry.retry_after_cap_ms = cap_ms;
372 self
373 }
374
375 pub fn max_concurrency(mut self, value: Option<usize>) -> Self {
376 self.reliability.rate_limits.max_concurrency = value;
377 self
378 }
379
380 pub fn requests_per_window(mut self, requests: Option<u32>, window_ms: Option<u64>) -> Self {
381 self.reliability.rate_limits.requests_per_window = requests;
382 self.reliability.rate_limits.request_window_ms = window_ms;
383 self
384 }
385
386 pub fn tokens_per_window(mut self, tokens: Option<u32>, window_ms: Option<u64>) -> Self {
387 self.reliability.rate_limits.tokens_per_window = tokens;
388 self.reliability.rate_limits.token_window_ms = window_ms;
389 self
390 }
391
392 pub fn build(self) -> ProviderReliability {
393 self.reliability
394 }
395}