Skip to main content

synaptic_models/
retry.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use synaptic_core::{ChatModel, ChatRequest, ChatResponse, ChatStream, SynapseError};
6
/// Controls how often and how patiently a failed model call is retried.
///
/// Policies are plain values and compare field-by-field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryPolicy {
    /// Total number of calls to make, including the first one (not the number
    /// of retries). A value of 0 means no call is attempted at all.
    pub max_attempts: usize,
    /// Wait before the first retry; the wait doubles after each further failed
    /// attempt (exponential backoff).
    pub base_delay: Duration,
}
12
13impl Default for RetryPolicy {
14    fn default() -> Self {
15        Self {
16            max_attempts: 3,
17            base_delay: Duration::from_millis(500),
18        }
19    }
20}
21
22pub struct RetryChatModel {
23    inner: Arc<dyn ChatModel>,
24    policy: RetryPolicy,
25}
26
27impl RetryChatModel {
28    pub fn new(inner: Arc<dyn ChatModel>, policy: RetryPolicy) -> Self {
29        Self { inner, policy }
30    }
31}
32
33fn is_retryable(err: &SynapseError) -> bool {
34    matches!(err, SynapseError::RateLimit(_) | SynapseError::Timeout(_))
35}
36
37#[async_trait]
38impl ChatModel for RetryChatModel {
39    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapseError> {
40        let mut last_error = None;
41        for attempt in 0..self.policy.max_attempts {
42            match self.inner.chat(request.clone()).await {
43                Ok(resp) => return Ok(resp),
44                Err(e) if is_retryable(&e) && attempt + 1 < self.policy.max_attempts => {
45                    let delay = self.policy.base_delay * 2u32.saturating_pow(attempt as u32);
46                    tokio::time::sleep(delay).await;
47                    last_error = Some(e);
48                }
49                Err(e) => return Err(e),
50            }
51        }
52        Err(last_error.unwrap_or_else(|| SynapseError::Model("retry exhausted".to_string())))
53    }
54
55    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
56        let inner = self.inner.clone();
57        let policy = self.policy.clone();
58
59        Box::pin(async_stream::stream! {
60            let mut last_error = None;
61            for attempt in 0..policy.max_attempts {
62                let mut stream = inner.stream_chat(request.clone());
63
64                use futures::StreamExt;
65                let mut chunks = Vec::new();
66                let mut had_error = false;
67
68                while let Some(result) = stream.next().await {
69                    match result {
70                        Ok(chunk) => chunks.push(chunk),
71                        Err(e) if is_retryable(&e) && attempt + 1 < policy.max_attempts => {
72                            let delay = policy.base_delay * 2u32.saturating_pow(attempt as u32);
73                            tokio::time::sleep(delay).await;
74                            last_error = Some(e);
75                            had_error = true;
76                            break;
77                        }
78                        Err(e) => {
79                            yield Err(e);
80                            return;
81                        }
82                    }
83                }
84
85                if !had_error {
86                    for chunk in chunks {
87                        yield Ok(chunk);
88                    }
89                    return;
90                }
91            }
92            if let Some(e) = last_error {
93                yield Err(e);
94            }
95        })
96    }
97}