1use std::time::Duration;
3
4use bytes::Bytes;
5use error_stack::{Report, ResultExt};
6use rand::Rng;
7use serde::{de::DeserializeOwned, Deserialize, Serialize};
8use serde_with::{serde_as, DurationMilliSeconds};
9use tracing::instrument;
10
11use crate::{
12 format::{ChatRequest, StreamingResponseSender},
13 provider_lookup::{ModelLookupChoice, ModelLookupResult},
14 providers::{ProviderError, ProviderErrorKind, SendRequestOptions},
15};
16
17#[serde_as]
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct RetryOptions {
20 #[serde_as(as = "DurationMilliSeconds")]
23 #[serde(default = "default_initial_backoff")]
24 initial_backoff: Duration,
25
26 #[serde(default)]
29 increase: RepeatBackoffBehavior,
30
31 #[serde(default = "default_max_tries")]
34 max_tries: u32,
35
36 #[serde_as(as = "DurationMilliSeconds")]
40 #[serde(default = "default_jitter")]
41 jitter: Duration,
42
43 #[serde_as(as = "DurationMilliSeconds")]
46 #[serde(default = "default_max_backoff")]
47 max_backoff: Duration,
48
49 #[serde(default = "true_t")]
53 fail_if_rate_limit_exceeds_max_backoff: bool,
54}
55
56impl Default for RetryOptions {
57 fn default() -> Self {
58 Self {
59 initial_backoff: default_initial_backoff(),
60 increase: RepeatBackoffBehavior::default(),
61 max_backoff: default_max_backoff(),
62 max_tries: default_max_tries(),
63 jitter: default_jitter(),
64 fail_if_rate_limit_exceeds_max_backoff: true,
65 }
66 }
67}
68
/// Default `max_tries`: four attempts in total.
const fn default_max_tries() -> u32 {
    4
}

/// Default `initial_backoff`: 200 milliseconds.
const fn default_initial_backoff() -> Duration {
    Duration::from_millis(200)
}

/// Default `jitter`: up to 100 milliseconds of random extra delay.
const fn default_jitter() -> Duration {
    Duration::from_millis(100)
}

/// Default `max_backoff`: five seconds.
const fn default_max_backoff() -> Duration {
    Duration::from_secs(5)
}

/// Serde default helper that always yields `true`.
const fn true_t() -> bool {
    true
}
88
89#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
91#[serde_as]
92#[serde(tag = "type", rename_all = "snake_case")]
93pub enum RepeatBackoffBehavior {
94 Constant,
96 Additive {
98 #[serde_as(as = "DurationMilliSeconds")]
99 amount: Duration,
100 },
101 Exponential { multiplier: f64 },
103}
104
105impl Default for RepeatBackoffBehavior {
106 fn default() -> Self {
107 Self::Exponential { multiplier: 2.0 }
108 }
109}
110
111impl RepeatBackoffBehavior {
112 fn next(&self, current: Duration) -> Duration {
113 match self {
114 RepeatBackoffBehavior::Constant => current,
115 RepeatBackoffBehavior::Additive { amount } => {
116 Duration::from_nanos(current.as_nanos() as u64 + amount.as_nanos() as u64)
117 }
118 RepeatBackoffBehavior::Exponential { multiplier } => {
119 Duration::from_nanos((current.as_nanos() as f64 * multiplier) as u64)
120 }
121 }
122 }
123}
124
125struct BackoffValue<'a> {
126 next_backoff: Duration,
127 options: &'a RetryOptions,
128}
129
130impl<'a> BackoffValue<'a> {
131 fn new(options: &'a RetryOptions) -> Self {
132 Self {
133 next_backoff: options.initial_backoff,
134 options,
135 }
136 }
137
138 fn next(&mut self) -> Duration {
140 let mut backoff = self.next_backoff;
141 self.next_backoff = self.options.increase.next(backoff);
142
143 let max_jitter = self.options.jitter.as_secs_f64();
144 if max_jitter > 0.0 {
145 let jitter_value = rand::thread_rng().gen_range::<f64, _>(0.0..=1.0) * max_jitter;
146 backoff += Duration::from_secs_f64(jitter_value);
147 }
148
149 backoff.min(self.options.max_backoff)
150 }
151}
152
153#[derive(Debug, Clone)]
154pub struct TryModelChoicesResult {
155 pub provider: String,
157 pub model: String,
159 pub num_retries: u32,
161 pub was_rate_limited: bool,
163 pub start_time: tokio::time::Instant,
165}
166
167#[derive(Debug)]
168pub struct TryModelChoicesError {
169 pub error: Report<ProviderError>,
170 pub num_retries: u32,
171 pub was_rate_limited: bool,
172}
173
174#[instrument(level = "debug")]
176pub async fn try_model_choices(
177 ModelLookupResult {
178 alias,
179 random_order,
180 choices,
181 }: ModelLookupResult,
182 override_url: Option<String>,
183 options: RetryOptions,
184 timeout: Duration,
185 request: ChatRequest,
186 chunk_tx: StreamingResponseSender,
187) -> Result<TryModelChoicesResult, TryModelChoicesError> {
188 let single_choice = choices.len() == 1;
189 let start_choice = if random_order && !single_choice {
190 rand::thread_rng().gen_range(0..choices.len())
191 } else {
192 0
193 };
194
195 let mut current_choice = start_choice;
196
197 let mut on_final_model_choice = single_choice;
198 let mut backoff = BackoffValue::new(&options);
199
200 let mut was_rate_limited = false;
201 let mut current_try: u32 = 1;
202
203 loop {
204 let ModelLookupChoice {
205 model,
206 provider,
207 api_key,
208 } = &choices[current_choice];
209
210 let mut body = request.clone();
211 body.model = Some(model.to_string());
212 let start_time = tokio::time::Instant::now();
213 let result = provider
214 .send_request(
215 SendRequestOptions {
216 override_url: override_url.clone(),
217 timeout,
218 api_key: api_key.clone(),
219 body,
220 },
221 chunk_tx.clone(),
222 )
223 .await;
224
225 let provider_name = provider.name();
226 let error = match result {
227 Ok(_) => {
228 return Ok(TryModelChoicesResult {
230 was_rate_limited,
231 num_retries: current_try - 1,
232 provider: provider.name().to_string(),
233 model: model.to_string(),
234 start_time,
235 });
236 }
237 Err(e) => {
238 tracing::error!(err=?e, "llm.try"=current_try - 1, llm.vendor=provider_name, llm.request.model = model, llm.alias=alias);
239 e.attach_printable(format!(
240 "Try {current_try}, Provider: {provider_name}, Model: {model}"
241 ))
242 }
243 };
244
245 let provider_error = error
246 .frames()
247 .find_map(|frame| frame.downcast_ref::<ProviderError>());
248
249 if let Some(ProviderErrorKind::RateLimit { .. }) = provider_error.map(|e| &e.kind) {
250 was_rate_limited = true;
251 }
252
253 if current_try == options.max_tries
255 || (on_final_model_choice
256 && !provider_error.map(|e| e.kind.retryable()).unwrap_or(false))
257 {
258 return Err(TryModelChoicesError {
259 error,
260 num_retries: current_try - 1,
261 was_rate_limited,
262 });
263 }
264
265 if !on_final_model_choice {
266 if current_choice == choices.len() - 1 {
267 current_choice = 0;
268 } else {
269 current_choice = current_choice + 1;
270 }
271
272 if current_choice == start_choice {
273 on_final_model_choice = true;
276 }
277 }
278
279 if on_final_model_choice {
280 let wait = backoff.next();
282 let wait = match provider_error.map(|e| &e.kind) {
283 Some(ProviderErrorKind::RateLimit {
285 retry_after: Some(retry_after),
286 }) => {
287 if options.fail_if_rate_limit_exceeds_max_backoff
288 && *retry_after > options.max_backoff
289 {
290 return Err(TryModelChoicesError {
292 error,
293 num_retries: current_try - 1,
294 was_rate_limited,
295 });
296 }
297
298 wait.max(*retry_after)
301 }
302 _ => wait,
303 };
304
305 tokio::time::sleep(wait).await;
306 }
307
308 current_try += 1;
309 }
310}
311
312#[instrument(level = "debug", skip(body, prepare, handle_rate_limit))]
315pub async fn send_standard_request(
316 timeout: Duration,
317 prepare: impl Fn() -> reqwest::RequestBuilder,
318 handle_rate_limit: impl Fn(&reqwest::Response) -> Option<Duration>,
319 body: Bytes,
320) -> Result<(reqwest::Response, Duration), Report<ProviderError>> {
321 let start = tokio::time::Instant::now();
322 let result = prepare()
323 .timeout(timeout)
324 .body(body)
325 .send()
326 .await
327 .change_context(ProviderError {
328 kind: ProviderErrorKind::Sending,
329 status_code: None,
330 body: None,
331 latency: start.elapsed(),
332 })?;
333
334 let status = result.status();
335 let error = ProviderErrorKind::from_status_code(status);
336
337 if let Some(mut e) = error {
338 match &mut e {
339 ProviderErrorKind::RateLimit { retry_after } => {
340 let value = handle_rate_limit(&result);
341 *retry_after = value;
342 }
343 _ => {}
344 };
345
346 let body_text = result.text().await.ok();
347
348 let body_json = body_text
349 .as_deref()
350 .and_then(|text| serde_json::from_str::<serde_json::Value>(&text).ok());
351 let latency = start.elapsed();
352
353 Err(Report::new(ProviderError {
354 kind: e,
355 status_code: Some(status),
356 body: body_json.or_else(|| body_text.map(serde_json::Value::String)),
357 latency,
358 }))
359 } else {
360 let latency = start.elapsed();
361 Ok::<_, Report<ProviderError>>((result, latency))
362 }
363}
364
365pub fn response_is_sse(response: &reqwest::Response) -> bool {
366 response
367 .headers()
368 .get(reqwest::header::CONTENT_TYPE)
369 .and_then(|ct| ct.to_str().ok())
370 .map(|ct| ct.starts_with("text/event-stream"))
371 .unwrap_or_default()
372}
373
374pub async fn parse_response_json<RESPONSE: DeserializeOwned>(
377 response: reqwest::Response,
378 latency: Duration,
379) -> Result<RESPONSE, Report<ProviderError>> {
380 let status = response.status();
381
382 let text = response.text().await.change_context(ProviderError {
385 kind: ProviderErrorKind::ParsingResponse,
386 status_code: Some(status),
387 body: None,
388 latency,
389 })?;
390
391 let jd = &mut serde_json::Deserializer::from_str(&text);
392 let body: RESPONSE = serde_path_to_error::deserialize(jd).change_context(ProviderError {
393 kind: ProviderErrorKind::ParsingResponse,
394 status_code: Some(status),
395 body: Some(serde_json::Value::String(text)),
396 latency,
397 })?;
398
399 Ok(body)
400}
401
#[cfg(test)]
mod test {
    use std::{sync::Arc, time::Duration};

    use super::TryModelChoicesError;
    use crate::{
        format::{ChatMessage, ChatRequest, StreamingResponse, StreamingResponseReceiver},
        provider_lookup::{ModelLookupChoice, ModelLookupResult},
        request::{try_model_choices, RetryOptions, TryModelChoicesResult},
        testing::TestProvider,
    };

    /// Wraps `provider` in a model choice for `model` with no API key.
    fn choice(model: &str, provider: Arc<TestProvider>) -> ModelLookupChoice {
        ModelLookupChoice {
            model: model.to_string(),
            provider,
            api_key: None,
        }
    }

    /// Runs `try_model_choices` over `choices` in order, with default retry
    /// options, a 5s timeout and a one-message user request. Returns the
    /// receiving end of the response channel alongside the result.
    async fn test_request(
        choices: Vec<ModelLookupChoice>,
    ) -> Result<(TryModelChoicesResult, StreamingResponseReceiver), TryModelChoicesError> {
        let (chunk_tx, chunk_rx) = flume::bounded(5);

        let lookup = ModelLookupResult {
            alias: String::new(),
            random_order: false,
            choices,
        };
        let request = ChatRequest {
            messages: vec![ChatMessage {
                role: Some("user".to_string()),
                content: Some("Tell me a story".to_string()),
                tool_calls: Vec::new(),
                ..Default::default()
            }],
            ..Default::default()
        };

        try_model_choices(
            lookup,
            None,
            RetryOptions::default(),
            Duration::from_secs(5),
            request,
            chunk_tx,
        )
        .await
        .map(|result| (result, chunk_rx))
    }

    /// Asserts that the first streamed chunk is a single message reading
    /// "A response".
    async fn test_response(chunk_rx: StreamingResponseReceiver) {
        match chunk_rx.recv_async().await.unwrap().unwrap() {
            StreamingResponse::Single(res) => assert_eq!(
                res.choices[0].message.content.as_deref().unwrap(),
                "A response"
            ),
            chunk => panic!("Unexpected chunk {chunk:?}"),
        }
    }

    mod single_choice {
        use std::sync::{atomic::Ordering, Arc};

        use super::{choice, test_request, test_response};
        use crate::testing::{TestFailure, TestProvider};

        #[tokio::test(start_paused = true)]
        async fn success() {
            let provider = Arc::new(TestProvider::default());
            let (result, chunk_rx) = test_request(vec![choice("test-model", provider)])
                .await
                .expect("Failed");

            assert_eq!(result.num_retries, 0);
            assert!(!result.was_rate_limited);
            assert_eq!(result.provider, "test");
            assert_eq!(result.model, "test-model");

            test_response(chunk_rx).await;
        }

        #[tokio::test(start_paused = true)]
        async fn nonretryable_failures() {
            let provider = Arc::new(TestProvider {
                fail: Some(TestFailure::BadRequest),
                ..Default::default()
            });
            let error = test_request(vec![choice("test-model", provider.clone())])
                .await
                .expect_err("Should have failed");

            assert_eq!(provider.calls.load(Ordering::Relaxed), 1);
            assert_eq!(error.num_retries, 0);
            assert!(!error.was_rate_limited);
        }

        #[tokio::test(start_paused = true)]
        async fn transient_failure() {
            let provider = Arc::new(TestProvider {
                fail: Some(TestFailure::Transient),
                fail_times: 2,
                ..Default::default()
            });
            let (result, chunk_rx) = test_request(vec![choice("test-model", provider.clone())])
                .await
                .expect("Should succeed");

            assert_eq!(
                provider.calls.load(Ordering::Relaxed),
                3,
                "Should succeed on third try"
            );
            assert_eq!(result.num_retries, 2);
            assert!(!result.was_rate_limited);
            assert_eq!(result.provider, "test");
            assert_eq!(result.model, "test-model");
            test_response(chunk_rx).await;
        }

        #[tokio::test(start_paused = true)]
        async fn rate_limit() {
            let provider = Arc::new(TestProvider {
                fail: Some(TestFailure::RateLimit),
                fail_times: 2,
                ..Default::default()
            });
            let (result, chunk_rx) = test_request(vec![choice("test-model", provider.clone())])
                .await
                .expect("Should succeed");

            assert_eq!(
                provider.calls.load(Ordering::Relaxed),
                3,
                "Should succeed on third try"
            );
            assert_eq!(result.num_retries, 2);
            assert!(result.was_rate_limited);
            assert_eq!(result.provider, "test");
            assert_eq!(result.model, "test-model");
            test_response(chunk_rx).await;
        }

        #[tokio::test(start_paused = true)]
        async fn max_retries() {
            let provider = Arc::new(TestProvider {
                fail: Some(TestFailure::Transient),
                ..Default::default()
            });
            let error = test_request(vec![choice("test-model", provider.clone())])
                .await
                .expect_err("Should have failed");

            assert_eq!(
                provider.calls.load(Ordering::Relaxed),
                4,
                "Should have tried 4 times"
            );
            assert_eq!(error.num_retries, 3);
            assert!(!error.was_rate_limited);
        }
    }

    mod multiple_choices {
        use std::sync::{atomic::Ordering, Arc};

        use super::{choice, test_request, test_response};
        use crate::testing::{TestFailure, TestProvider};

        /// A provider configured with `fail: Some(failure)` and otherwise
        /// default settings.
        fn failing(failure: TestFailure) -> Arc<TestProvider> {
            Arc::new(TestProvider {
                fail: Some(failure),
                ..Default::default()
            })
        }

        /// A provider with default settings.
        fn healthy() -> Arc<TestProvider> {
            Arc::new(TestProvider::default())
        }

        #[tokio::test(start_paused = true)]
        async fn success() {
            let (result, chunk_rx) = test_request(vec![
                choice("test-model", healthy()),
                choice("test-model-2", healthy()),
            ])
            .await
            .expect("Failed");

            assert_eq!(result.num_retries, 0);
            assert!(!result.was_rate_limited);
            assert_eq!(result.provider, "test");
            assert_eq!(result.model, "test-model");
            test_response(chunk_rx).await;
        }

        #[tokio::test(start_paused = true)]
        async fn transient_failures() {
            let (result, chunk_rx) = test_request(vec![
                choice("test-model", failing(TestFailure::Transient)),
                choice("test-model-2", failing(TestFailure::Transient)),
                choice("test-model-3", healthy()),
            ])
            .await
            .expect("Failed");

            assert_eq!(result.num_retries, 2);
            assert!(!result.was_rate_limited);
            assert_eq!(result.provider, "test");
            assert_eq!(result.model, "test-model-3");
            test_response(chunk_rx).await;
        }

        #[tokio::test(start_paused = true)]
        async fn rate_limit() {
            let (result, chunk_rx) = test_request(vec![
                choice("test-model", failing(TestFailure::RateLimit)),
                choice("test-model-2", healthy()),
            ])
            .await
            .expect("Failed");

            assert_eq!(result.num_retries, 1);
            assert!(result.was_rate_limited);
            assert_eq!(result.provider, "test");
            assert_eq!(result.model, "test-model-2");
            test_response(chunk_rx).await;
        }

        #[tokio::test(start_paused = true)]
        async fn all_failed_every_time() {
            let error = test_request(vec![
                choice("test-model", failing(TestFailure::BadRequest)),
                choice("test-model-2", failing(TestFailure::RateLimit)),
                choice("test-model-3", failing(TestFailure::Transient)),
            ])
            .await
            .expect_err("Should have failed");

            assert_eq!(error.num_retries, 3);
            assert!(error.was_rate_limited);
        }

        #[tokio::test(start_paused = true)]
        async fn all_failed_once() {
            let p1 = Arc::new(TestProvider {
                fail: Some(TestFailure::BadRequest),
                fail_times: 1,
                ..Default::default()
            });
            let p2 = failing(TestFailure::RateLimit);
            let p3 = failing(TestFailure::Transient);

            let (result, _) = test_request(vec![
                choice("test-model", p1.clone()),
                choice("test-model-2", p2.clone()),
                choice("test-model-3", p3.clone()),
            ])
            .await
            .expect("Should have succeeded");

            assert_eq!(result.num_retries, 3);
            assert!(result.was_rate_limited);
            assert_eq!(result.provider, "test");
            assert_eq!(result.model, "test-model");

            assert_eq!(p1.calls.load(Ordering::Relaxed), 2);
            assert_eq!(p2.calls.load(Ordering::Relaxed), 1);
            assert_eq!(p3.calls.load(Ordering::Relaxed), 1);
        }
    }
}