1use super::support::*;
2
/// Default whole-request timeout in milliseconds (5 minutes).
pub const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 5 * 60 * 1_000;
/// Default stream-chunk timeout in milliseconds (2 minutes).
pub const DEFAULT_CHUNK_TIMEOUT_MS: u64 = 2 * 60 * 1_000;

/// Concrete timeout values, as produced by resolving a `ProviderTimeoutPolicy`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub struct LlmTimeouts {
    /// Upper bound for a whole request; `None` means the request timeout is disabled.
    pub request_timeout: Option<Duration>,
    /// Stream-chunk timeout; always present (this one cannot be disabled).
    pub chunk_timeout: Duration,
}

impl Default for LlmTimeouts {
    /// Both timeouts taken from the `DEFAULT_*_TIMEOUT_MS` constants above.
    fn default() -> Self {
        let chunk_timeout = Duration::from_millis(DEFAULT_CHUNK_TIMEOUT_MS);
        let request_timeout = Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS);
        LlmTimeouts {
            request_timeout: Some(request_timeout),
            chunk_timeout,
        }
    }
}
20
/// Request-level configuration carried by a model variant.
///
/// The variant names suggest provider-specific reasoning/"thinking" controls: a generic
/// reasoning-effort level, a Google thinking level or budget, and Anthropic adaptive
/// thinking or a thinking budget. NOTE(review): how each maps onto an outgoing request is
/// decided elsewhere — confirm there.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub enum VariantRequestConfig {
    /// A reasoning-effort string.
    ReasoningEffort(String),
    /// A Google thinking level string.
    GoogleThinkingLevel { level: String },
    /// A Google thinking budget in tokens (signed — presumably so sentinel values can be
    /// passed through; confirm).
    GoogleThinkingBudget { budget_tokens: i32 },
    /// Anthropic adaptive thinking with an effort string.
    AnthropicAdaptiveThinking { effort: String },
    /// An Anthropic thinking budget in tokens.
    AnthropicThinkingBudget { budget_tokens: i32 },
}
32
/// The model an agent should use, plus an optional variant name.
///
/// NOTE(review): `variant` presumably names one of the model's configured variants —
/// confirm against whatever resolves it.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct AgentModelSelection {
    /// Model identifier.
    pub model: String,
    /// Optional variant name; `None` when no variant was chosen.
    pub variant: Option<String>,
}
39
/// A configured request timeout: explicitly disabled, or a duration in milliseconds.
///
/// In serialized form this is `false` (disabled) or an integer millisecond count; the
/// hand-written serde impls reject `true`, zero and negative integers on input.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum RequestTimeout {
    /// No request timeout.
    Disabled,
    /// Timeout in milliseconds.
    Millis(u64),
}
45
46impl Serialize for RequestTimeout {
47 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
48 where
49 S: Serializer,
50 {
51 match self {
52 Self::Disabled => serializer.serialize_bool(false),
53 Self::Millis(value) => serializer.serialize_u64(*value),
54 }
55 }
56}
57
58impl<'de> Deserialize<'de> for RequestTimeout {
59 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
60 where
61 D: Deserializer<'de>,
62 {
63 struct RequestTimeoutVisitor;
64
65 impl Visitor<'_> for RequestTimeoutVisitor {
66 type Value = RequestTimeout;
67
68 fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 formatter.write_str("a positive timeout in milliseconds or false")
70 }
71
72 fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
73 where
74 E: de::Error,
75 {
76 if value {
77 return Err(E::custom("timeout must be a positive integer or false"));
78 }
79 Ok(RequestTimeout::Disabled)
80 }
81
82 fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
83 where
84 E: de::Error,
85 {
86 if value == 0 {
87 return Err(E::custom("timeout must be greater than 0"));
88 }
89 Ok(RequestTimeout::Millis(value))
90 }
91
92 fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
93 where
94 E: de::Error,
95 {
96 if value <= 0 {
97 return Err(E::custom("timeout must be greater than 0"));
98 }
99 Ok(RequestTimeout::Millis(value as u64))
100 }
101 }
102
103 deserializer.deserialize_any(RequestTimeoutVisitor)
104 }
105}
106
107#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
114#[serde(rename_all = "kebab-case")]
115pub enum CacheRetention {
116 None,
118 #[default]
120 Short,
121 Long,
123}
124
125impl CacheRetention {
126 pub fn is_default(&self) -> bool {
127 matches!(self, CacheRetention::Short)
128 }
129}
130
131#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
132#[serde(deny_unknown_fields)]
133pub struct ProviderOptions {
134 #[serde(default)]
135 pub reliability: ProviderReliability,
136 #[serde(default, skip_serializing_if = "ProviderThinkingPolicy::is_default")]
137 pub thinking: ProviderThinkingPolicy,
138 #[serde(default, skip_serializing_if = "Option::is_none")]
142 pub max_output_tokens: Option<u64>,
143 #[serde(default, skip_serializing_if = "CacheRetention::is_default")]
145 pub cache_retention: CacheRetention,
146}
147
148impl ProviderOptions {
149 pub fn is_default(&self) -> bool {
150 self.reliability == ProviderReliability::default_llm()
151 && self.thinking == ProviderThinkingPolicy::default()
152 && self.max_output_tokens.is_none()
153 && self.cache_retention.is_default()
154 }
155
156 pub fn llm_timeouts(&self) -> LlmTimeouts {
157 self.reliability.timeouts.llm_timeouts()
158 }
159}
160
161#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
162#[serde(deny_unknown_fields)]
163pub struct ProviderThinkingPolicy {
164 #[serde(default)]
165 pub expose: bool,
166}
167
168impl ProviderThinkingPolicy {
169 pub fn is_default(&self) -> bool {
170 !self.expose
171 }
172}
173
174#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
175pub struct ProviderReliability {
176 #[serde(default)]
177 pub timeouts: ProviderTimeoutPolicy,
178 #[serde(default)]
179 pub retry: ProviderRetryPolicy,
180 #[serde(default)]
181 pub rate_limits: ProviderRateLimitPolicy,
182}
183
184impl ProviderReliability {
185 pub fn default_llm() -> Self {
186 Self {
187 timeouts: ProviderTimeoutPolicy::default(),
188 retry: ProviderRetryPolicy::default(),
189 rate_limits: ProviderRateLimitPolicy::default(),
190 }
191 }
192
193 pub fn codex() -> Self {
194 Self {
195 retry: ProviderRetryPolicy {
196 max_attempts: 4,
197 base_delay_ms: 1_000,
198 max_delay_ms: 4_000,
199 jitter_ms: 0,
200 retry_after_cap_ms: Some(60_000),
201 enabled: true,
202 },
203 ..Self::default_llm()
204 }
205 }
206
207 pub fn disabled() -> Self {
208 Self {
209 retry: ProviderRetryPolicy::disabled(),
210 rate_limits: ProviderRateLimitPolicy::default(),
211 timeouts: ProviderTimeoutPolicy::default(),
212 }
213 }
214
215 pub fn builder() -> ProviderReliabilityBuilder {
216 ProviderReliabilityBuilder {
217 reliability: Self::default_llm(),
218 }
219 }
220}
221
222impl Default for ProviderReliability {
223 fn default() -> Self {
224 Self::default_llm()
225 }
226}
227
228#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
229pub struct ProviderTimeoutPolicy {
230 #[serde(default, skip_serializing_if = "Option::is_none")]
231 pub request_timeout: Option<RequestTimeout>,
232 #[serde(default, skip_serializing_if = "Option::is_none")]
233 pub chunk_timeout: Option<u64>,
234}
235
236impl ProviderTimeoutPolicy {
237 pub fn llm_timeouts(&self) -> LlmTimeouts {
238 let request_timeout = match self.request_timeout {
239 Some(RequestTimeout::Disabled) => None,
240 Some(RequestTimeout::Millis(ms)) => Some(Duration::from_millis(ms)),
241 None => Some(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS)),
242 };
243 let chunk_timeout_ms = self
244 .chunk_timeout
245 .filter(|value| *value > 0)
246 .unwrap_or(DEFAULT_CHUNK_TIMEOUT_MS);
247 LlmTimeouts {
248 request_timeout,
249 chunk_timeout: Duration::from_millis(chunk_timeout_ms),
250 }
251 }
252}
253
254#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
255pub struct ProviderRetryPolicy {
256 pub enabled: bool,
257 pub max_attempts: u32,
258 pub base_delay_ms: u64,
259 pub max_delay_ms: u64,
260 pub jitter_ms: u64,
261 #[serde(default, skip_serializing_if = "Option::is_none")]
262 pub retry_after_cap_ms: Option<u64>,
263}
264
265impl Default for ProviderRetryPolicy {
266 fn default() -> Self {
267 Self {
268 enabled: true,
269 max_attempts: 4,
270 base_delay_ms: 2_000,
271 max_delay_ms: 10_000,
272 jitter_ms: 0,
273 retry_after_cap_ms: Some(60_000),
274 }
275 }
276}
277
278impl ProviderRetryPolicy {
279 pub fn disabled() -> Self {
280 Self {
281 enabled: false,
282 max_attempts: 1,
283 base_delay_ms: 0,
284 max_delay_ms: 0,
285 jitter_ms: 0,
286 retry_after_cap_ms: None,
287 }
288 }
289
290 pub(crate) fn attempts(&self) -> u32 {
291 if self.enabled {
292 self.max_attempts.max(1)
293 } else {
294 1
295 }
296 }
297
298 pub(crate) fn delay_for_attempt(
299 &self,
300 retry_index: u32,
301 retry_after: Option<Duration>,
302 ) -> Duration {
303 if let Some(retry_after) = retry_after {
304 return self
305 .retry_after_cap_ms
306 .map(Duration::from_millis)
307 .map(|cap| retry_after.min(cap))
308 .unwrap_or(retry_after);
309 }
310 let multiplier = 1u64.checked_shl(retry_index).unwrap_or(u64::MAX);
311 let delay_ms = self
312 .base_delay_ms
313 .saturating_mul(multiplier)
314 .min(self.max_delay_ms);
315 Duration::from_millis(delay_ms.saturating_add(self.jitter_ms))
316 }
317}
318
319#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
320pub struct ProviderRateLimitPolicy {
321 #[serde(default, skip_serializing_if = "Option::is_none")]
322 pub max_concurrency: Option<usize>,
323 #[serde(default, skip_serializing_if = "Option::is_none")]
324 pub requests_per_window: Option<u32>,
325 #[serde(default, skip_serializing_if = "Option::is_none")]
326 pub request_window_ms: Option<u64>,
327 #[serde(default, skip_serializing_if = "Option::is_none")]
328 pub tokens_per_window: Option<u32>,
329 #[serde(default, skip_serializing_if = "Option::is_none")]
330 pub token_window_ms: Option<u64>,
331}
332
333pub struct ProviderReliabilityBuilder {
334 reliability: ProviderReliability,
335}
336
337impl ProviderReliabilityBuilder {
338 pub fn request_timeout(mut self, timeout: Option<RequestTimeout>) -> Self {
339 self.reliability.timeouts.request_timeout = timeout;
340 self
341 }
342
343 pub fn stream_chunk_timeout_ms(mut self, timeout_ms: Option<u64>) -> Self {
344 self.reliability.timeouts.chunk_timeout = timeout_ms;
345 self
346 }
347
348 pub fn max_attempts(mut self, attempts: u32) -> Self {
349 self.reliability.retry.max_attempts = attempts.max(1);
350 self
351 }
352
353 pub fn base_delay_ms(mut self, delay_ms: u64) -> Self {
354 self.reliability.retry.base_delay_ms = delay_ms;
355 self
356 }
357
358 pub fn max_delay_ms(mut self, delay_ms: u64) -> Self {
359 self.reliability.retry.max_delay_ms = delay_ms;
360 self
361 }
362
363 pub fn retry_after_cap_ms(mut self, cap_ms: Option<u64>) -> Self {
364 self.reliability.retry.retry_after_cap_ms = cap_ms;
365 self
366 }
367
368 pub fn max_concurrency(mut self, value: Option<usize>) -> Self {
369 self.reliability.rate_limits.max_concurrency = value;
370 self
371 }
372
373 pub fn requests_per_window(mut self, requests: Option<u32>, window_ms: Option<u64>) -> Self {
374 self.reliability.rate_limits.requests_per_window = requests;
375 self.reliability.rate_limits.request_window_ms = window_ms;
376 self
377 }
378
379 pub fn tokens_per_window(mut self, tokens: Option<u32>, window_ms: Option<u64>) -> Self {
380 self.reliability.rate_limits.tokens_per_window = tokens;
381 self.reliability.rate_limits.token_window_ms = window_ms;
382 self
383 }
384
385 pub fn build(self) -> ProviderReliability {
386 self.reliability
387 }
388}