chronicle_proxy/
request.rs

1//! Utilities for retrying a provider request as needed.
2use std::time::Duration;
3
4use bytes::Bytes;
5use error_stack::{Report, ResultExt};
6use rand::Rng;
7use serde::{de::DeserializeOwned, Deserialize, Serialize};
8use serde_with::{serde_as, DurationMilliSeconds};
9use tracing::instrument;
10
11use crate::{
12    format::{ChatRequest, StreamingResponseSender},
13    provider_lookup::{ModelLookupChoice, ModelLookupResult},
14    providers::{ProviderError, ProviderErrorKind, SendRequestOptions},
15};
16
/// Configuration for retrying failed provider requests: initial delay, how the
/// delay grows, jitter, caps, and the total number of attempts. All durations
/// are (de)serialized as integer milliseconds.
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryOptions {
    /// How long to wait after the first failure.
    /// The default value is 200ms.
    #[serde_as(as = "DurationMilliSeconds")]
    #[serde(default = "default_initial_backoff")]
    initial_backoff: Duration,

    /// How to increase the backoff duration as additional retries occur. The default value
    /// is an exponential backoff with a multiplier of `2.0`.
    #[serde(default)]
    increase: RepeatBackoffBehavior,

    /// The number of times to try the request, including the first try.
    /// Defaults to 4.
    #[serde(default = "default_max_tries")]
    max_tries: u32,

    /// Maximum amount of jitter to add. The added jitter will be a random value between 0 and this
    /// value.
    /// Defaults to 100ms.
    #[serde_as(as = "DurationMilliSeconds")]
    #[serde(default = "default_jitter")]
    jitter: Duration,

    /// Never wait more than this amount of time. The behavior of this flag may be modified by the
    /// `fail_if_rate_limit_exceeds_max_backoff` flag.
    #[serde_as(as = "DurationMilliSeconds")]
    #[serde(default = "default_max_backoff")]
    max_backoff: Duration,

    /// If a rate limit response asks to wait longer than the `max_backoff`, then stop retrying.
    /// Otherwise it will wait for the requested time even if it is longer than max_backoff.
    /// Defaults to true.
    #[serde(default = "true_t")]
    fail_if_rate_limit_exceeds_max_backoff: bool,
}
55
56impl Default for RetryOptions {
57    fn default() -> Self {
58        Self {
59            initial_backoff: default_initial_backoff(),
60            increase: RepeatBackoffBehavior::default(),
61            max_backoff: default_max_backoff(),
62            max_tries: default_max_tries(),
63            jitter: default_jitter(),
64            fail_if_rate_limit_exceeds_max_backoff: true,
65        }
66    }
67}
68
/// Serde default: try the request up to 4 times total (1 try + 3 retries).
const fn default_max_tries() -> u32 {
    4
}

/// Serde default: wait 200ms after the first failure.
const fn default_initial_backoff() -> Duration {
    Duration::from_millis(200)
}

/// Serde default: add up to 100ms of random jitter to each wait.
const fn default_jitter() -> Duration {
    Duration::from_millis(100)
}

/// Serde default: never plan a wait longer than 5 seconds.
const fn default_max_backoff() -> Duration {
    Duration::from_millis(5_000)
}

/// Serde default helper that always returns `true`.
const fn true_t() -> bool {
    true
}
88
89/// How to increase the backoff duration as additional retries occur.
90#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
91#[serde_as]
92#[serde(tag = "type", rename_all = "snake_case")]
93pub enum RepeatBackoffBehavior {
94    /// Use the initial backoff duration for additional retries as well.
95    Constant,
96    /// Add this duration to the backoff duration after each retry.
97    Additive {
98        #[serde_as(as = "DurationMilliSeconds")]
99        amount: Duration,
100    },
101    /// Multiply the backoff duration by this value after each retry.
102    Exponential { multiplier: f64 },
103}
104
105impl Default for RepeatBackoffBehavior {
106    fn default() -> Self {
107        Self::Exponential { multiplier: 2.0 }
108    }
109}
110
111impl RepeatBackoffBehavior {
112    fn next(&self, current: Duration) -> Duration {
113        match self {
114            RepeatBackoffBehavior::Constant => current,
115            RepeatBackoffBehavior::Additive { amount } => {
116                Duration::from_nanos(current.as_nanos() as u64 + amount.as_nanos() as u64)
117            }
118            RepeatBackoffBehavior::Exponential { multiplier } => {
119                Duration::from_nanos((current.as_nanos() as f64 * multiplier) as u64)
120            }
121        }
122    }
123}
124
/// Stateful generator of successive backoff delays driven by a [`RetryOptions`].
struct BackoffValue<'a> {
    /// The base delay (before jitter and capping) the next call to `next` will use.
    next_backoff: Duration,
    /// Configuration that controls growth, jitter, and the maximum delay.
    options: &'a RetryOptions,
}
129
130impl<'a> BackoffValue<'a> {
131    fn new(options: &'a RetryOptions) -> Self {
132        Self {
133            next_backoff: options.initial_backoff,
134            options,
135        }
136    }
137
138    /// Return the next duration to wait for
139    fn next(&mut self) -> Duration {
140        let mut backoff = self.next_backoff;
141        self.next_backoff = self.options.increase.next(backoff);
142
143        let max_jitter = self.options.jitter.as_secs_f64();
144        if max_jitter > 0.0 {
145            let jitter_value = rand::thread_rng().gen_range::<f64, _>(0.0..=1.0) * max_jitter;
146            backoff += Duration::from_secs_f64(jitter_value);
147        }
148
149        backoff.min(self.options.max_backoff)
150    }
151}
152
/// Information about the attempt that eventually succeeded in [`try_model_choices`].
#[derive(Debug, Clone)]
pub struct TryModelChoicesResult {
    /// The provider which was used for the successful request.
    pub provider: String,
    /// The model which was used for the successful request
    pub model: String,
    /// How many times we had to retry before we got a successful response.
    pub num_retries: u32,
    /// If we retried due to hitting a rate limit.
    pub was_rate_limited: bool,
    /// When the latest, successful request started
    pub start_time: tokio::time::Instant,
}
166
/// The failure returned when [`try_model_choices`] exhausts every attempt.
#[derive(Debug)]
pub struct TryModelChoicesError {
    /// The error from the final, failed attempt.
    pub error: Report<ProviderError>,
    /// How many retries were performed before giving up.
    pub num_retries: u32,
    /// If any attempt failed due to hitting a rate limit.
    pub was_rate_limited: bool,
}
173
174/// Run a provider request and retry on failure.
175#[instrument(level = "debug")]
176pub async fn try_model_choices(
177    ModelLookupResult {
178        alias,
179        random_order,
180        choices,
181    }: ModelLookupResult,
182    override_url: Option<String>,
183    options: RetryOptions,
184    timeout: Duration,
185    request: ChatRequest,
186    chunk_tx: StreamingResponseSender,
187) -> Result<TryModelChoicesResult, TryModelChoicesError> {
188    let single_choice = choices.len() == 1;
189    let start_choice = if random_order && !single_choice {
190        rand::thread_rng().gen_range(0..choices.len())
191    } else {
192        0
193    };
194
195    let mut current_choice = start_choice;
196
197    let mut on_final_model_choice = single_choice;
198    let mut backoff = BackoffValue::new(&options);
199
200    let mut was_rate_limited = false;
201    let mut current_try: u32 = 1;
202
203    loop {
204        let ModelLookupChoice {
205            model,
206            provider,
207            api_key,
208        } = &choices[current_choice];
209
210        let mut body = request.clone();
211        body.model = Some(model.to_string());
212        let start_time = tokio::time::Instant::now();
213        let result = provider
214            .send_request(
215                SendRequestOptions {
216                    override_url: override_url.clone(),
217                    timeout,
218                    api_key: api_key.clone(),
219                    body,
220                },
221                chunk_tx.clone(),
222            )
223            .await;
224
225        let provider_name = provider.name();
226        let error = match result {
227            Ok(_) => {
228                // The caller will stream the response from here.
229                return Ok(TryModelChoicesResult {
230                    was_rate_limited,
231                    num_retries: current_try - 1,
232                    provider: provider.name().to_string(),
233                    model: model.to_string(),
234                    start_time,
235                });
236            }
237            Err(e) => {
238                tracing::error!(err=?e, "llm.try"=current_try - 1, llm.vendor=provider_name, llm.request.model = model, llm.alias=alias);
239                e.attach_printable(format!(
240                    "Try {current_try}, Provider: {provider_name}, Model: {model}"
241                ))
242            }
243        };
244
245        let provider_error = error
246            .frames()
247            .find_map(|frame| frame.downcast_ref::<ProviderError>());
248
249        if let Some(ProviderErrorKind::RateLimit { .. }) = provider_error.map(|e| &e.kind) {
250            was_rate_limited = true;
251        }
252
253        // If we don't have any more fallback models, and this error is not retryable, then exit.
254        if current_try == options.max_tries
255            || (on_final_model_choice
256                && !provider_error.map(|e| e.kind.retryable()).unwrap_or(false))
257        {
258            return Err(TryModelChoicesError {
259                error,
260                num_retries: current_try - 1,
261                was_rate_limited,
262            });
263        }
264
265        if !on_final_model_choice {
266            if current_choice == choices.len() - 1 {
267                current_choice = 0;
268            } else {
269                current_choice = current_choice + 1;
270            }
271
272            if current_choice == start_choice {
273                // We looped around to the first choice again so enable backoff if it wasn't already
274                // on.
275                on_final_model_choice = true;
276            }
277        }
278
279        if on_final_model_choice {
280            // If we're on the final model choice then we need to backoff before the next retry.
281            let wait = backoff.next();
282            let wait = match provider_error.map(|e| &e.kind) {
283                // Rate limited, where the provider specified a time to wait
284                Some(ProviderErrorKind::RateLimit {
285                    retry_after: Some(retry_after),
286                }) => {
287                    if options.fail_if_rate_limit_exceeds_max_backoff
288                        && *retry_after > options.max_backoff
289                    {
290                        // Rate limited with a retry time that exceeds max backoff.
291                        return Err(TryModelChoicesError {
292                            error,
293                            num_retries: current_try - 1,
294                            was_rate_limited,
295                        });
296                    }
297
298                    // If the rate limit retry duration is more than the planned wait, then wait for
299                    // the rate limit duration instead.
300                    wait.max(*retry_after)
301                }
302                _ => wait,
303            };
304
305            tokio::time::sleep(wait).await;
306        }
307
308        current_try += 1;
309    }
310}
311
312/// Send an HTTP request with retries, and handle errors.
313/// Most providers can use this to handle sending their request and handling errors.
314#[instrument(level = "debug", skip(body, prepare, handle_rate_limit))]
315pub async fn send_standard_request(
316    timeout: Duration,
317    prepare: impl Fn() -> reqwest::RequestBuilder,
318    handle_rate_limit: impl Fn(&reqwest::Response) -> Option<Duration>,
319    body: Bytes,
320) -> Result<(reqwest::Response, Duration), Report<ProviderError>> {
321    let start = tokio::time::Instant::now();
322    let result = prepare()
323        .timeout(timeout)
324        .body(body)
325        .send()
326        .await
327        .change_context(ProviderError {
328            kind: ProviderErrorKind::Sending,
329            status_code: None,
330            body: None,
331            latency: start.elapsed(),
332        })?;
333
334    let status = result.status();
335    let error = ProviderErrorKind::from_status_code(status);
336
337    if let Some(mut e) = error {
338        match &mut e {
339            ProviderErrorKind::RateLimit { retry_after } => {
340                let value = handle_rate_limit(&result);
341                *retry_after = value;
342            }
343            _ => {}
344        };
345
346        let body_text = result.text().await.ok();
347
348        let body_json = body_text
349            .as_deref()
350            .and_then(|text| serde_json::from_str::<serde_json::Value>(&text).ok());
351        let latency = start.elapsed();
352
353        Err(Report::new(ProviderError {
354            kind: e,
355            status_code: Some(status),
356            body: body_json.or_else(|| body_text.map(serde_json::Value::String)),
357            latency,
358        }))
359    } else {
360        let latency = start.elapsed();
361        Ok::<_, Report<ProviderError>>((result, latency))
362    }
363}
364
365pub fn response_is_sse(response: &reqwest::Response) -> bool {
366    response
367        .headers()
368        .get(reqwest::header::CONTENT_TYPE)
369        .and_then(|ct| ct.to_str().ok())
370        .map(|ct| ct.starts_with("text/event-stream"))
371        .unwrap_or_default()
372}
373
374/// Parse a JSON response, with informative errors when the format does not match the expected
375/// structure.
376pub async fn parse_response_json<RESPONSE: DeserializeOwned>(
377    response: reqwest::Response,
378    latency: Duration,
379) -> Result<RESPONSE, Report<ProviderError>> {
380    let status = response.status();
381
382    // Get the result as text first so that we can save the entire response for better
383    // introspection if parsing fails.
384    let text = response.text().await.change_context(ProviderError {
385        kind: ProviderErrorKind::ParsingResponse,
386        status_code: Some(status),
387        body: None,
388        latency,
389    })?;
390
391    let jd = &mut serde_json::Deserializer::from_str(&text);
392    let body: RESPONSE = serde_path_to_error::deserialize(jd).change_context(ProviderError {
393        kind: ProviderErrorKind::ParsingResponse,
394        status_code: Some(status),
395        body: Some(serde_json::Value::String(text)),
396        latency,
397    })?;
398
399    Ok(body)
400}
401
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::TryModelChoicesError;
    use crate::{
        format::{ChatMessage, ChatRequest, StreamingResponse, StreamingResponseReceiver},
        provider_lookup::{ModelLookupChoice, ModelLookupResult},
        request::{try_model_choices, RetryOptions, TryModelChoicesResult},
    };

    /// Build a minimal chat request and run `try_model_choices` over the given
    /// choices with default retry options, returning the result along with the
    /// channel that receives the streamed response.
    async fn test_request(
        choices: Vec<ModelLookupChoice>,
    ) -> Result<(TryModelChoicesResult, StreamingResponseReceiver), TryModelChoicesError> {
        let (chunk_tx, chunk_rx) = flume::bounded(5);
        let res = try_model_choices(
            ModelLookupResult {
                alias: String::new(),
                random_order: false,
                choices,
            },
            None,
            RetryOptions::default(),
            Duration::from_secs(5),
            ChatRequest {
                messages: vec![ChatMessage {
                    role: Some("user".to_string()),
                    content: Some("Tell me a story".to_string()),
                    tool_calls: Vec::new(),
                    ..Default::default()
                }],
                ..Default::default()
            },
            chunk_tx,
        )
        .await?;
        Ok((res, chunk_rx))
    }

    /// Drain one chunk from the response channel and assert it is the single
    /// successful response that the test provider produces.
    async fn test_response(chunk_rx: StreamingResponseReceiver) {
        let chunk = chunk_rx.recv_async().await.unwrap().unwrap();
        match chunk {
            StreamingResponse::Single(res) => {
                assert_eq!(
                    res.choices[0].message.content.as_deref().unwrap(),
                    "A response"
                );
            }
            _ => panic!("Unexpected chunk {chunk:?}"),
        }
    }

    /// Retry behavior when only a single model choice is configured.
    mod single_choice {
        use std::sync::Arc;

        use super::test_request;
        use crate::{provider_lookup::ModelLookupChoice, testing::TestProvider};

        /// A request that succeeds on the first try reports zero retries.
        #[tokio::test(start_paused = true)]
        async fn success() {
            let (result, chunk_rx) = test_request(vec![ModelLookupChoice {
                model: "test-model".to_string(),
                provider: TestProvider::default().into(),
                api_key: None,
            }])
            .await
            .expect("Failed");

            assert_eq!(result.num_retries, 0);
            assert_eq!(result.was_rate_limited, false);
            assert_eq!(result.provider, "test");
            assert_eq!(result.model, "test-model");

            super::test_response(chunk_rx).await;
        }

        /// A non-retryable error (bad request) fails immediately with no retries.
        #[tokio::test(start_paused = true)]
        async fn nonretryable_failures() {
            let provider = Arc::new(TestProvider {
                fail: Some(crate::testing::TestFailure::BadRequest),
                ..Default::default()
            });
            let result = test_request(vec![ModelLookupChoice {
                model: "test-model".to_string(),
                provider: provider.clone(),
                api_key: None,
            }])
            .await
            .expect_err("Should have failed");

            assert_eq!(provider.calls.load(std::sync::atomic::Ordering::Relaxed), 1);
            assert_eq!(result.num_retries, 0);
            assert_eq!(result.was_rate_limited, false);
        }

        /// Transient failures are retried until the provider recovers.
        #[tokio::test(start_paused = true)]
        async fn transient_failure() {
            let provider = Arc::new(TestProvider {
                fail: Some(crate::testing::TestFailure::Transient),
                fail_times: 2,
                ..Default::default()
            });
            let (result, chunk_rx) = test_request(vec![ModelLookupChoice {
                model: "test-model".to_string(),
                provider: provider.clone(),
                api_key: None,
            }])
            .await
            .expect("Should succeed");

            assert_eq!(
                provider.calls.load(std::sync::atomic::Ordering::Relaxed),
                3,
                "Should succeed on third try"
            );
            assert_eq!(result.num_retries, 2);
            assert_eq!(result.was_rate_limited, false);
            assert_eq!(result.provider, "test");
            assert_eq!(result.model, "test-model");
            super::test_response(chunk_rx).await;
        }

        /// Rate-limit failures are retried and the result records that a rate
        /// limit was hit.
        #[tokio::test(start_paused = true)]
        async fn rate_limit() {
            let provider = Arc::new(TestProvider {
                fail: Some(crate::testing::TestFailure::RateLimit),
                fail_times: 2,
                ..Default::default()
            });
            let (result, chunk_rx) = test_request(vec![ModelLookupChoice {
                model: "test-model".to_string(),
                provider: provider.clone(),
                api_key: None,
            }])
            .await
            .expect("Should succeed");

            assert_eq!(
                provider.calls.load(std::sync::atomic::Ordering::Relaxed),
                3,
                "Should succeed on third try"
            );
            assert_eq!(result.num_retries, 2);
            assert_eq!(result.was_rate_limited, true);
            assert_eq!(result.provider, "test");
            assert_eq!(result.model, "test-model");
            super::test_response(chunk_rx).await;
        }

        /// A provider that always fails stops after the default of 4 total tries.
        #[tokio::test(start_paused = true)]
        async fn max_retries() {
            let provider = Arc::new(TestProvider {
                fail: Some(crate::testing::TestFailure::Transient),
                ..Default::default()
            });
            let response = test_request(vec![ModelLookupChoice {
                model: "test-model".to_string(),
                provider: provider.clone(),
                api_key: None,
            }])
            .await
            .expect_err("Should have failed");

            assert_eq!(
                provider.calls.load(std::sync::atomic::Ordering::Relaxed),
                4,
                "Should have tried 4 times"
            );
            assert_eq!(response.num_retries, 3);
            assert_eq!(response.was_rate_limited, false);
        }
    }

    /// Fallback behavior when multiple model choices are configured.
    mod multiple_choices {
        use std::sync::Arc;

        use super::test_request;
        use crate::{
            provider_lookup::ModelLookupChoice,
            testing::{TestFailure, TestProvider},
        };

        /// With multiple healthy choices, the first one is used and succeeds.
        #[tokio::test(start_paused = true)]
        async fn success() {
            let (result, chunk_rx) = test_request(vec![
                ModelLookupChoice {
                    model: "test-model".to_string(),
                    provider: TestProvider::default().into(),
                    api_key: None,
                },
                ModelLookupChoice {
                    model: "test-model-2".to_string(),
                    provider: TestProvider::default().into(),
                    api_key: None,
                },
            ])
            .await
            .expect("Failed");

            assert_eq!(result.num_retries, 0);
            assert_eq!(result.was_rate_limited, false);
            assert_eq!(result.provider, "test");
            assert_eq!(result.model, "test-model");
            super::test_response(chunk_rx).await;
        }

        /// Transient failures fall through to the next choice until one succeeds.
        #[tokio::test(start_paused = true)]
        async fn transient_failures() {
            let (result, chunk_rx) = test_request(vec![
                ModelLookupChoice {
                    model: "test-model".to_string(),
                    provider: TestProvider {
                        fail: Some(TestFailure::Transient),
                        ..Default::default()
                    }
                    .into(),
                    api_key: None,
                },
                ModelLookupChoice {
                    model: "test-model-2".to_string(),
                    provider: TestProvider {
                        fail: Some(TestFailure::Transient),
                        ..Default::default()
                    }
                    .into(),
                    api_key: None,
                },
                ModelLookupChoice {
                    model: "test-model-3".to_string(),
                    provider: TestProvider::default().into(),
                    api_key: None,
                },
            ])
            .await
            .expect("Failed");

            assert_eq!(result.num_retries, 2);
            assert_eq!(result.was_rate_limited, false);
            assert_eq!(result.provider, "test");
            assert_eq!(result.model, "test-model-3");
            super::test_response(chunk_rx).await;
        }

        /// A rate-limited first choice falls back to the second, and the rate
        /// limit flag is still reported on the successful result.
        #[tokio::test(start_paused = true)]
        async fn rate_limit() {
            let (result, chunk_rx) = test_request(vec![
                ModelLookupChoice {
                    model: "test-model".to_string(),
                    provider: TestProvider {
                        fail: Some(TestFailure::RateLimit),
                        ..Default::default()
                    }
                    .into(),
                    api_key: None,
                },
                ModelLookupChoice {
                    model: "test-model-2".to_string(),
                    provider: TestProvider::default().into(),
                    api_key: None,
                },
            ])
            .await
            .expect("Failed");

            assert_eq!(result.num_retries, 1);
            assert_eq!(result.was_rate_limited, true);
            assert_eq!(result.provider, "test");
            assert_eq!(result.model, "test-model-2");
            super::test_response(chunk_rx).await;
        }

        /// When every choice fails on every attempt, the final error reports all
        /// retries and that a rate limit was encountered along the way.
        #[tokio::test(start_paused = true)]
        async fn all_failed_every_time() {
            let response = test_request(vec![
                ModelLookupChoice {
                    model: "test-model".to_string(),
                    provider: TestProvider {
                        fail: Some(TestFailure::BadRequest),
                        ..Default::default()
                    }
                    .into(),
                    api_key: None,
                },
                ModelLookupChoice {
                    model: "test-model-2".to_string(),
                    provider: TestProvider {
                        fail: Some(TestFailure::RateLimit),
                        ..Default::default()
                    }
                    .into(),
                    api_key: None,
                },
                ModelLookupChoice {
                    model: "test-model-3".to_string(),
                    provider: TestProvider {
                        fail: Some(TestFailure::Transient),
                        ..Default::default()
                    }
                    .into(),
                    api_key: None,
                },
            ])
            .await
            .expect_err("Should have failed");

            assert_eq!(response.num_retries, 3);
            assert_eq!(response.was_rate_limited, true);
        }

        /// When each choice fails once, the rotation wraps back around to the
        /// first choice, which then succeeds.
        #[tokio::test(start_paused = true)]
        async fn all_failed_once() {
            let p1 = Arc::new(TestProvider {
                fail: Some(TestFailure::BadRequest),
                fail_times: 1,
                ..Default::default()
            });
            let p2 = Arc::new(TestProvider {
                fail: Some(TestFailure::RateLimit),
                ..Default::default()
            });
            let p3 = Arc::new(TestProvider {
                fail: Some(TestFailure::Transient),
                ..Default::default()
            });

            let (result, _) = test_request(vec![
                ModelLookupChoice {
                    model: "test-model".to_string(),
                    provider: p1.clone(),
                    api_key: None,
                },
                ModelLookupChoice {
                    model: "test-model-2".to_string(),
                    provider: p2.clone(),
                    api_key: None,
                },
                ModelLookupChoice {
                    model: "test-model-3".to_string(),
                    provider: p3.clone(),
                    api_key: None,
                },
            ])
            .await
            .expect("Should have succeeded");

            assert_eq!(result.num_retries, 3);
            assert_eq!(result.was_rate_limited, true);
            assert_eq!(result.provider, "test");
            // Should have wrapped around to the first one again.
            assert_eq!(result.model, "test-model");
            assert_eq!(p1.calls.load(std::sync::atomic::Ordering::Relaxed), 2);
            assert_eq!(p2.calls.load(std::sync::atomic::Ordering::Relaxed), 1);
            assert_eq!(p3.calls.load(std::sync::atomic::Ordering::Relaxed), 1);
        }
    }
}