Skip to main content

llmix_rs/
resilience.rs

1use crate::error::{CircuitOpenError, KillSwitchActiveError};
2use crate::{LlmixError, LlmixResult};
3use fs2::FileExt;
4use sha2::{Digest, Sha256};
5use std::collections::HashMap;
6use std::fs::{self, File};
7use std::future::Future;
8use std::io;
9use std::path::{Path, PathBuf};
10use std::sync::{Arc, Condvar, Mutex};
11use std::time::{Duration, Instant, SystemTime};
12use tokio::sync::Notify;
13
/// Failure threshold used by `CircuitBreaker::new`: consecutive retryable failures before opening.
const DEFAULT_FAILURE_THRESHOLD: u32 = 3;
/// Initial open-state cooldown used by `CircuitBreaker::new`.
const DEFAULT_COOLDOWN: Duration = Duration::from_secs(30);
/// Number of probe calls `CircuitBreaker::new` admits per half-open window.
const DEFAULT_PERMITTED_HALF_OPEN_CALLS: u32 = 10;
/// Default `RetryPolicyOptions::base_ms` (milliseconds).
const DEFAULT_BASE_DELAY_MS: u64 = 1_000;
/// Default `RetryPolicyOptions::max_delay_ms` (milliseconds).
const DEFAULT_MAX_DELAY_MS: u64 = 30_000;
/// Default `RetryPolicyOptions::jitter_ms` (milliseconds).
const DEFAULT_JITTER_MS: u64 = 1_000;
/// Default `RetryPolicyOptions::max_retry_after_ms` (milliseconds).
const DEFAULT_MAX_RETRY_AFTER_MS: u64 = 60_000;
/// Cap applied when a failed half-open window doubles the breaker cooldown.
const MAX_COOLDOWN: Duration = Duration::from_secs(300);
/// File name of the kill-switch marker inside the state directory.
const KILLSWITCH_FILENAME: &str = "killswitch";
/// Sub-directory appended to the XDG / HOME state root by `resolve_state_dir`.
const STATE_SUBDIR: &str = "llmix";
24
/// Reports whether an HTTP status code is treated as a transient failure worth retrying.
///
/// Retryable codes are `408`, `429`, and every `5xx` status; everything else is not.
pub fn is_retryable(status_code: u16) -> bool {
    matches!(status_code, 408 | 429 | 500..=599)
}
28
29pub fn resolve_state_dir() -> PathBuf {
30    if let Ok(value) = std::env::var("LLMIX_STATE_DIR") {
31        return PathBuf::from(value);
32    }
33    if let Ok(xdg) = std::env::var("XDG_STATE_HOME") {
34        return PathBuf::from(xdg).join(STATE_SUBDIR);
35    }
36    let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_owned());
37    PathBuf::from(home).join(".local/state").join(STATE_SUBDIR)
38}
39
/// Externally visible state of a `CircuitBreaker`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Calls are admitted; consecutive retryable failures are being counted.
    Closed,
    /// Calls are rejected by `CircuitBreaker::check` until the cooldown has elapsed.
    Open,
    /// A limited number of probe calls is admitted to decide whether to close or re-open.
    HalfOpen,
}
46
/// Mutable breaker state; `CircuitBreaker` keeps it behind a `Mutex`.
#[derive(Debug)]
struct CircuitInner {
    /// Current position in the closed / open / half-open state machine.
    state: CircuitState,
    /// Running count of retryable failures, compared against the failure threshold.
    consecutive_failures: u32,
    /// Instant at which the breaker last opened; cleared when it leaves the open state.
    opened_at: Option<Instant>,
    /// Cooldown currently in force; doubled after a failed half-open window and restored
    /// to the base value when the breaker closes or is reset.
    cooldown: Duration,
    /// Probe calls admitted by `check` during the current half-open window.
    half_open_active: u32,
    /// Probe outcomes counted as successes in the current half-open window.
    half_open_successes: u32,
    /// Probe outcomes counted as failures in the current half-open window.
    half_open_failures: u32,
}
57
58#[derive(Debug)]
59pub struct CircuitBreaker {
60    provider: String,
61    base_url: String,
62    failure_threshold: u32,
63    permitted_half_open_calls: u32,
64    base_cooldown: Duration,
65    inner: Mutex<CircuitInner>,
66}
67
68impl CircuitBreaker {
69    pub fn new(provider: impl Into<String>, base_url: impl Into<String>) -> Self {
70        Self::with_options(
71            provider,
72            base_url,
73            DEFAULT_FAILURE_THRESHOLD,
74            DEFAULT_COOLDOWN,
75            DEFAULT_PERMITTED_HALF_OPEN_CALLS,
76        )
77    }
78
79    pub fn with_options(
80        provider: impl Into<String>,
81        base_url: impl Into<String>,
82        failure_threshold: u32,
83        cooldown: Duration,
84        permitted_half_open_calls: u32,
85    ) -> Self {
86        Self {
87            provider: provider.into(),
88            base_url: base_url.into(),
89            failure_threshold,
90            permitted_half_open_calls: permitted_half_open_calls.max(1),
91            base_cooldown: cooldown,
92            inner: Mutex::new(CircuitInner {
93                state: CircuitState::Closed,
94                consecutive_failures: 0,
95                opened_at: None,
96                cooldown,
97                half_open_active: 0,
98                half_open_successes: 0,
99                half_open_failures: 0,
100            }),
101        }
102    }
103
104    pub fn state(&self) -> CircuitState {
105        let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
106        transition_open_to_half_open(&mut inner);
107        inner.state
108    }
109
110    pub fn cooldown(&self) -> Duration {
111        self.inner
112            .lock()
113            .unwrap_or_else(|e| e.into_inner())
114            .cooldown
115    }
116
117    pub fn check(&self) -> Result<(), CircuitOpenError> {
118        let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
119        transition_open_to_half_open(&mut inner);
120
121        match inner.state {
122            CircuitState::Closed => Ok(()),
123            CircuitState::HalfOpen => {
124                if inner.half_open_active >= self.permitted_half_open_calls {
125                    Err(CircuitOpenError {
126                        provider: self.provider.clone(),
127                        base_url: self.base_url.clone(),
128                    })
129                } else {
130                    inner.half_open_active += 1;
131                    Ok(())
132                }
133            }
134            CircuitState::Open => Err(CircuitOpenError {
135                provider: self.provider.clone(),
136                base_url: self.base_url.clone(),
137            }),
138        }
139    }
140
141    pub fn on_success(&self) {
142        let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
143        match inner.state {
144            CircuitState::HalfOpen => {
145                inner.half_open_successes += 1;
146                evaluate_half_open(
147                    &mut inner,
148                    self.base_cooldown,
149                    self.permitted_half_open_calls,
150                );
151            }
152            CircuitState::Open => {}
153            CircuitState::Closed => {
154                inner.consecutive_failures = 0;
155                inner.opened_at = None;
156            }
157        }
158    }
159
160    pub fn on_failure(&self, status_code: Option<u16>, network_error: bool) {
161        let retryable = network_error || status_code.is_some_and(is_retryable);
162        let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
163
164        if inner.state == CircuitState::HalfOpen {
165            if retryable {
166                inner.half_open_failures += 1;
167            } else {
168                inner.half_open_successes += 1;
169            }
170            evaluate_half_open(
171                &mut inner,
172                self.base_cooldown,
173                self.permitted_half_open_calls,
174            );
175            return;
176        }
177
178        if matches!(status_code, Some(401 | 403)) {
179            inner.consecutive_failures = 0;
180            return;
181        }
182
183        if !retryable {
184            inner.consecutive_failures = 0;
185            return;
186        }
187
188        inner.consecutive_failures += 1;
189        if inner.consecutive_failures >= self.failure_threshold {
190            inner.state = CircuitState::Open;
191            inner.opened_at = Some(Instant::now());
192        }
193    }
194
195    pub fn cancel_probe(&self) {
196        let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
197        if inner.state != CircuitState::HalfOpen {
198            return;
199        }
200        let total_finalized = inner.half_open_successes + inner.half_open_failures;
201        if total_finalized >= inner.half_open_active {
202            return;
203        }
204        inner.half_open_failures += 1;
205        evaluate_half_open(
206            &mut inner,
207            self.base_cooldown,
208            self.permitted_half_open_calls,
209        );
210    }
211
212    pub fn reset(&self) {
213        let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
214        inner.state = CircuitState::Closed;
215        inner.consecutive_failures = 0;
216        inner.opened_at = None;
217        inner.cooldown = self.base_cooldown;
218        inner.half_open_active = 0;
219        inner.half_open_successes = 0;
220        inner.half_open_failures = 0;
221    }
222}
223
224fn transition_open_to_half_open(inner: &mut CircuitInner) {
225    if inner.state != CircuitState::Open {
226        return;
227    }
228    let Some(opened_at) = inner.opened_at else {
229        return;
230    };
231    if opened_at.elapsed() >= inner.cooldown {
232        inner.state = CircuitState::HalfOpen;
233        inner.opened_at = None;
234        inner.half_open_active = 0;
235        inner.half_open_successes = 0;
236        inner.half_open_failures = 0;
237    }
238}
239
240fn evaluate_half_open(
241    inner: &mut CircuitInner,
242    base_cooldown: Duration,
243    permitted_half_open_calls: u32,
244) {
245    let total_completed = inner.half_open_successes + inner.half_open_failures;
246    if total_completed < permitted_half_open_calls {
247        return;
248    }
249
250    if inner.half_open_successes > inner.half_open_failures {
251        inner.state = CircuitState::Closed;
252        inner.consecutive_failures = 0;
253        inner.opened_at = None;
254        inner.cooldown = base_cooldown;
255    } else {
256        inner.state = CircuitState::Open;
257        inner.opened_at = Some(Instant::now());
258        inner.cooldown = (inner.cooldown * 2).min(MAX_COOLDOWN);
259    }
260}
261
262#[derive(Debug)]
263pub struct KillSwitch {
264    path: PathBuf,
265}
266
267impl KillSwitch {
268    pub fn new() -> io::Result<Self> {
269        Self::with_state_dir(resolve_state_dir())
270    }
271
272    pub fn with_state_dir(path: impl AsRef<Path>) -> io::Result<Self> {
273        Ok(Self {
274            path: path.as_ref().join(KILLSWITCH_FILENAME),
275        })
276    }
277
278    pub fn path(&self) -> &Path {
279        &self.path
280    }
281
282    pub fn check(&self) -> LlmixResult<()> {
283        match fs::metadata(&self.path) {
284            Ok(_) => Err(KillSwitchActiveError {
285                path: self.path.display().to_string(),
286            }
287            .into()),
288            Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(()),
289            Err(err) => Err(err.into()),
290        }
291    }
292
293    pub fn is_active(&self) -> LlmixResult<bool> {
294        match fs::metadata(&self.path) {
295            Ok(_) => Ok(true),
296            Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(false),
297            Err(err) => Err(err.into()),
298        }
299    }
300
301    pub async fn check_async(&self) -> LlmixResult<()> {
302        match tokio::fs::metadata(&self.path).await {
303            Ok(_) => Err(KillSwitchActiveError {
304                path: self.path.display().to_string(),
305            }
306            .into()),
307            Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(()),
308            Err(err) => Err(err.into()),
309        }
310    }
311
312    pub async fn is_active_async(&self) -> LlmixResult<bool> {
313        match tokio::fs::metadata(&self.path).await {
314            Ok(_) => Ok(true),
315            Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(false),
316            Err(err) => Err(err.into()),
317        }
318    }
319}
320
321pub type SharedCallResult<T, E> = Result<Arc<T>, Arc<E>>;
322
323#[derive(Debug)]
324struct FlightEntry<T, E> {
325    notify: Notify,
326    result: Mutex<Option<SharedCallResult<T, E>>>,
327}
328
329impl<T, E> FlightEntry<T, E> {
330    fn new() -> Self {
331        Self {
332            notify: Notify::new(),
333            result: Mutex::new(None),
334        }
335    }
336}
337
338#[derive(Debug, Default)]
339pub struct Singleflight<T, E> {
340    in_flight: Mutex<HashMap<String, Arc<FlightEntry<T, E>>>>,
341}
342
343impl<T, E> Singleflight<T, E>
344where
345    T: Send + Sync + 'static,
346    E: Send + Sync + 'static,
347{
348    pub fn new() -> Self {
349        Self {
350            in_flight: Mutex::new(HashMap::new()),
351        }
352    }
353
354    pub fn make_key(data: &str) -> String {
355        format!("{:x}", Sha256::digest(data.as_bytes()))
356    }
357
358    pub async fn do_call<F, Fut>(&self, key: impl Into<String>, func: F) -> SharedCallResult<T, E>
359    where
360        F: FnOnce() -> Fut,
361        Fut: Future<Output = Result<T, E>> + Send,
362    {
363        let key = key.into();
364        let (entry, is_leader) = {
365            let mut in_flight = self.in_flight.lock().unwrap_or_else(|e| e.into_inner());
366            if let Some(existing) = in_flight.get(&key) {
367                (existing.clone(), false)
368            } else {
369                let entry = Arc::new(FlightEntry::new());
370                in_flight.insert(key.clone(), entry.clone());
371                (entry, true)
372            }
373        };
374
375        if is_leader {
376            let result = func().await.map(Arc::new).map_err(Arc::new);
377            {
378                let mut slot = entry
379                    .result
380                    .lock()
381                    .expect("singleflight result mutex poisoned");
382                *slot = Some(result.clone());
383            }
384            self.in_flight
385                .lock()
386                .unwrap_or_else(|e| e.into_inner())
387                .remove(&key);
388            entry.notify.notify_waiters();
389            return result;
390        }
391
392        loop {
393            let notified = entry.notify.notified();
394            if let Some(result) = entry
395                .result
396                .lock()
397                .expect("singleflight result mutex poisoned")
398                .clone()
399            {
400                return result;
401            }
402            notified.await;
403        }
404    }
405
406    pub fn in_flight_count(&self) -> usize {
407        self.in_flight
408            .lock()
409            .unwrap_or_else(|e| e.into_inner())
410            .len()
411    }
412}
413
/// Computes an exponential backoff delay in milliseconds.
///
/// The deterministic part is `base_ms * 2^attempt`, saturating on overflow and capped at
/// `max_delay_ms`. When `jitter_ms` is non-zero, a jitter in `0..=jitter_ms` is added on
/// top of the capped value, so the result may exceed `max_delay_ms` by up to `jitter_ms`.
///
/// The jitter is taken from the sub-second nanoseconds of the system clock, so it can
/// never be larger than 999_999_999 even when `jitter_ms` is. If the clock reads earlier
/// than the Unix epoch, the jitter is 0.
pub fn calculate_delay(attempt: u32, base_ms: u64, max_delay_ms: u64, jitter_ms: u64) -> u64 {
    let factor = 1_u64.checked_shl(attempt.min(63)).unwrap_or(u64::MAX);
    let exponential = base_ms.saturating_mul(factor).min(max_delay_ms);
    let jitter = if jitter_ms == 0 {
        0
    } else {
        let nanos = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .map(|duration| u64::from(duration.subsec_nanos()))
            .unwrap_or(0);
        // `jitter_ms + 1` would overflow for `u64::MAX` (panic in debug builds, modulo
        // by zero in release builds), so the modulus saturates instead.
        nanos % jitter_ms.saturating_add(1)
    };
    exponential.saturating_add(jitter)
}
428
429pub fn parse_retry_after(header_value: Option<&str>, max_ms: u64) -> Option<u64> {
430    let value = header_value?.trim();
431
432    if let Ok(seconds) = value.parse::<u64>() {
433        return Some(seconds.saturating_mul(1_000).min(max_ms));
434    }
435
436    let parsed = httpdate::parse_http_date(value).ok()?;
437    let delta = parsed.duration_since(SystemTime::now()).ok()?;
438    Some(delta.as_millis().min(max_ms as u128) as u64)
439}
440
/// Tunables for `RetryPolicy`. All durations are in milliseconds.
#[derive(Debug, Clone, Copy)]
pub struct RetryPolicyOptions {
    /// Number of retries after the initial attempt (total attempts = `max_retries + 1`).
    pub max_retries: u32,
    /// Base delay that is doubled for each successive attempt.
    pub base_ms: u64,
    /// Cap on the exponential part of the delay (jitter is added on top).
    pub max_delay_ms: u64,
    /// Upper bound of the jitter added to each computed delay; 0 disables jitter.
    pub jitter_ms: u64,
    /// Cap applied to delays taken from a `Retry-After` header.
    pub max_retry_after_ms: u64,
}

impl Default for RetryPolicyOptions {
    /// Three retries with the module's default base delay, cap, jitter and
    /// `Retry-After` limit.
    fn default() -> Self {
        Self {
            max_retries: 3,
            base_ms: DEFAULT_BASE_DELAY_MS,
            max_delay_ms: DEFAULT_MAX_DELAY_MS,
            jitter_ms: DEFAULT_JITTER_MS,
            max_retry_after_ms: DEFAULT_MAX_RETRY_AFTER_MS,
        }
    }
}
461
462#[derive(Debug, Clone, Copy)]
463pub struct RetryPolicy {
464    options: RetryPolicyOptions,
465}
466
467impl RetryPolicy {
468    pub fn new(options: RetryPolicyOptions) -> LlmixResult<Self> {
469        if options.max_delay_ms < options.base_ms {
470            return Err(LlmixError::InvalidRetryPolicyConfig(
471                "max_delay_ms must be >= base_ms".to_owned(),
472            ));
473        }
474        Ok(Self { options })
475    }
476
477    pub fn with_defaults() -> Self {
478        Self::new(RetryPolicyOptions::default())
479            .expect("default retry policy configuration must be valid")
480    }
481
482    pub fn get_delay_ms(&self, attempt: u32, retry_after_header: Option<&str>) -> u64 {
483        parse_retry_after(retry_after_header, self.options.max_retry_after_ms).unwrap_or_else(
484            || {
485                calculate_delay(
486                    attempt,
487                    self.options.base_ms,
488                    self.options.max_delay_ms,
489                    self.options.jitter_ms,
490                )
491            },
492        )
493    }
494
495    pub async fn execute<T, E, F, Fut>(&self, mut func: F) -> Result<T, E>
496    where
497        F: FnMut() -> Fut,
498        Fut: Future<Output = Result<T, E>>,
499    {
500        self.execute_with_hooks(
501            &mut func,
502            None::<fn(&E) -> bool>,
503            None::<fn(&E) -> Option<String>>,
504        )
505        .await
506    }
507
508    pub async fn execute_with_hooks<T, E, F, Fut, P, H>(
509        &self,
510        mut func: F,
511        is_retryable_fn: Option<P>,
512        retry_after_header: Option<H>,
513    ) -> Result<T, E>
514    where
515        F: FnMut() -> Fut,
516        Fut: Future<Output = Result<T, E>>,
517        P: Fn(&E) -> bool,
518        H: Fn(&E) -> Option<String>,
519    {
520        for attempt in 0..=self.options.max_retries {
521            match func().await {
522                Ok(value) => return Ok(value),
523                Err(err) => {
524                    if attempt >= self.options.max_retries {
525                        return Err(err);
526                    }
527                    if let Some(predicate) = &is_retryable_fn {
528                        if !predicate(&err) {
529                            return Err(err);
530                        }
531                    }
532                    let retry_after = retry_after_header
533                        .as_ref()
534                        .and_then(|extractor| extractor(&err));
535                    let delay = self.get_delay_ms(attempt, retry_after.as_deref());
536                    tokio::time::sleep(Duration::from_millis(delay)).await;
537                }
538            }
539        }
540
541        unreachable!("retry loop always returns or errors")
542    }
543}
544
/// Optional lock combining an in-process gate (`Mutex` + `Condvar`) with a lock file.
///
/// The lock is only enabled when `LLM_GLOBAL_CONCURRENCY` is set to a non-empty value;
/// otherwise `acquire` and `release` return immediately.
#[derive(Debug)]
pub struct FileLock {
    /// Whether locking is active for this instance.
    enabled: bool,
    /// Path of the lock file; `Some` exactly when `enabled` is true.
    lock_path: Option<PathBuf>,
    /// In-process ownership flag and the currently locked file, if any.
    state: Mutex<FileLockState>,
    /// Signalled when the in-process hold is given up.
    available: Condvar,
}
552
/// Mutable part of a `FileLock`, kept behind its `Mutex`.
#[derive(Debug)]
struct FileLockState {
    /// True while some caller in this process owns the lock (or is acquiring the file).
    held: bool,
    /// The open, locked file while the lock is held.
    file: Option<File>,
}
558
/// Guard returned by `FileLock::acquire_guard`; releases the lock when dropped.
#[derive(Debug)]
pub struct FileLockGuard<'a> {
    /// The lock this guard will release.
    file_lock: &'a FileLock,
    /// Set once the lock has been released so it is never released twice.
    released: bool,
}
564
565impl FileLock {
566    pub fn new() -> LlmixResult<Self> {
567        Self::with_path(resolve_state_dir().join("llmix.lock"))
568    }
569
570    pub fn with_path(path: impl Into<PathBuf>) -> LlmixResult<Self> {
571        let concurrency = std::env::var("LLM_GLOBAL_CONCURRENCY").ok();
572        let enabled = concurrency
573            .as_ref()
574            .is_some_and(|value| !value.trim().is_empty());
575
576        if let Some(value) = concurrency
577            .as_deref()
578            .filter(|value| !value.trim().is_empty())
579        {
580            if value
581                .trim()
582                .parse::<u32>()
583                .ok()
584                .filter(|parsed| *parsed > 0)
585                .is_none()
586            {
587                return Err(LlmixError::InvalidFileLockConfig(format!(
588                    "LLM_GLOBAL_CONCURRENCY must be a positive integer, got \"{value}\""
589                )));
590            }
591        }
592
593        Ok(Self {
594            enabled,
595            lock_path: enabled.then_some(path.into()),
596            state: Mutex::new(FileLockState {
597                held: false,
598                file: None,
599            }),
600            available: Condvar::new(),
601        })
602    }
603
604    pub fn enabled(&self) -> bool {
605        self.enabled
606    }
607
608    pub fn lock_path(&self) -> Option<&Path> {
609        self.lock_path.as_deref()
610    }
611
612    pub fn acquire(&self) -> LlmixResult<()> {
613        if !self.enabled {
614            return Ok(());
615        }
616
617        let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner());
618        while state.held {
619            state = self
620                .available
621                .wait(state)
622                .unwrap_or_else(|e| e.into_inner());
623        }
624        state.held = true;
625        drop(state);
626
627        let file_result = self.open_locked_file();
628        let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner());
629        match file_result {
630            Ok(file) => {
631                state.file = Some(file);
632                Ok(())
633            }
634            Err(error) => {
635                state.held = false;
636                self.available.notify_one();
637                Err(error)
638            }
639        }
640    }
641
642    fn open_locked_file(&self) -> LlmixResult<File> {
643        let path = self
644            .lock_path
645            .as_ref()
646            .expect("enabled file lock must have a path");
647        if let Some(parent) = path.parent() {
648            fs::create_dir_all(parent)?;
649        }
650        let file = File::options()
651            .create(true)
652            .truncate(false)
653            .read(true)
654            .write(true)
655            .open(path)?;
656        file.lock_exclusive()?;
657        Ok(file)
658    }
659
660    pub fn acquire_guard(&self) -> LlmixResult<FileLockGuard<'_>> {
661        self.acquire()?;
662        Ok(FileLockGuard {
663            file_lock: self,
664            released: false,
665        })
666    }
667
668    pub fn release(&self) -> LlmixResult<()> {
669        if !self.enabled {
670            return Ok(());
671        }
672
673        let maybe_file = {
674            self.state
675                .lock()
676                .unwrap_or_else(|e| e.into_inner())
677                .file
678                .take()
679        };
680        let result = if let Some(file) = maybe_file {
681            file.unlock().map_err(LlmixError::from)
682        } else {
683            Ok(())
684        };
685
686        let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner());
687        state.held = false;
688        self.available.notify_one();
689        result
690    }
691}
692
693impl FileLockGuard<'_> {
694    pub fn release(mut self) -> LlmixResult<()> {
695        if self.released {
696            return Ok(());
697        }
698        self.released = true;
699        self.file_lock.release()
700    }
701}
702
703impl Drop for FileLockGuard<'_> {
704    fn drop(&mut self) {
705        if !self.released {
706            let _ = self.file_lock.release();
707            self.released = true;
708        }
709    }
710}
711
712impl Drop for FileLock {
713    fn drop(&mut self) {
714        let state = self.state.get_mut().unwrap_or_else(|e| e.into_inner());
715        if let Some(file) = state.file.take() {
716            let _ = file.unlock();
717        }
718        state.held = false;
719    }
720}