1use crate::error::{CircuitOpenError, KillSwitchActiveError};
2use crate::{LlmixError, LlmixResult};
3use fs2::FileExt;
4use sha2::{Digest, Sha256};
5use std::collections::HashMap;
6use std::fs::{self, File};
7use std::future::Future;
8use std::io;
9use std::path::{Path, PathBuf};
10use std::sync::{Arc, Condvar, Mutex};
11use std::time::{Duration, Instant, SystemTime};
12use tokio::sync::Notify;
13
/// Retryable failures in a row (while closed) that trip a circuit breaker open.
const DEFAULT_FAILURE_THRESHOLD: u32 = 3;
/// Initial time an open breaker waits before admitting half-open probes.
const DEFAULT_COOLDOWN: Duration = Duration::from_millis(30_000);
/// Probe calls admitted and evaluated per half-open window.
const DEFAULT_PERMITTED_HALF_OPEN_CALLS: u32 = 10;
/// Retry backoff for attempt 0, in milliseconds.
const DEFAULT_BASE_DELAY_MS: u64 = 1000;
/// Ceiling on the exponential part of the retry backoff, in milliseconds.
const DEFAULT_MAX_DELAY_MS: u64 = 30 * 1000;
/// Upper bound of the random jitter added to each backoff, in milliseconds.
const DEFAULT_JITTER_MS: u64 = 1000;
/// Ceiling applied to server-provided `Retry-After` delays, in milliseconds.
const DEFAULT_MAX_RETRY_AFTER_MS: u64 = 60 * 1000;
/// Upper bound applied when the cooldown grows after a failed half-open window.
const MAX_COOLDOWN: Duration = Duration::from_secs(5 * 60);
/// Name of the marker file whose existence activates the kill switch.
const KILLSWITCH_FILENAME: &str = "killswitch";
/// Sub-directory appended to XDG / HOME based state locations.
const STATE_SUBDIR: &str = "llmix";
24
/// Reports whether an HTTP status code signals a transient failure worth retrying:
/// 408 Request Timeout, 429 Too Many Requests, or any 5xx server error.
pub fn is_retryable(status_code: u16) -> bool {
    matches!(status_code, 408 | 429 | 500..=599)
}
28
29pub fn resolve_state_dir() -> PathBuf {
30 if let Ok(value) = std::env::var("LLMIX_STATE_DIR") {
31 return PathBuf::from(value);
32 }
33 if let Ok(xdg) = std::env::var("XDG_STATE_HOME") {
34 return PathBuf::from(xdg).join(STATE_SUBDIR);
35 }
36 let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_owned());
37 PathBuf::from(home).join(".local/state").join(STATE_SUBDIR)
38}
39
/// Operating mode of a circuit breaker.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Calls flow normally; retryable failures are counted.
    Closed,
    /// Calls are rejected until the cooldown elapses.
    Open,
    /// A bounded number of probe calls are admitted to test recovery.
    HalfOpen,
}
46
/// Mutable state of a circuit breaker, kept behind the breaker's mutex.
#[derive(Debug)]
struct CircuitInner {
    /// Current mode of the breaker.
    state: CircuitState,
    /// Retryable failures seen in a row while closed; cleared by successes and
    /// by non-retryable outcomes.
    consecutive_failures: u32,
    /// When the breaker last opened; cleared when it moves to half-open,
    /// closes, or is reset.
    opened_at: Option<Instant>,
    /// Cooldown currently in effect: grows (capped) after a failed half-open
    /// window and returns to the base value when the breaker closes.
    cooldown: Duration,
    /// Probe calls admitted by `check` during the current half-open window.
    half_open_active: u32,
    /// Probe outcomes counted as good in the current window (non-retryable
    /// failures are counted here too).
    half_open_successes: u32,
    /// Probe outcomes counted as bad in the current window (retryable failures
    /// and cancelled probes).
    half_open_failures: u32,
}
57
/// Per-provider circuit breaker.
///
/// Opens after `failure_threshold` consecutive retryable failures and rejects
/// calls during a cooldown. It then admits up to `permitted_half_open_calls`
/// probes and closes again only if they yield more successes than failures.
#[derive(Debug)]
pub struct CircuitBreaker {
    /// Provider name, reported in `CircuitOpenError`.
    provider: String,
    /// Provider base URL, reported in `CircuitOpenError`.
    base_url: String,
    /// Consecutive retryable failures (while closed) that open the breaker.
    failure_threshold: u32,
    /// Probes admitted per half-open window (clamped to at least 1).
    permitted_half_open_calls: u32,
    /// Initial cooldown, restored whenever the breaker closes or is reset.
    base_cooldown: Duration,
    /// Mutable state shared by all callers.
    inner: Mutex<CircuitInner>,
}
67
68impl CircuitBreaker {
69 pub fn new(provider: impl Into<String>, base_url: impl Into<String>) -> Self {
70 Self::with_options(
71 provider,
72 base_url,
73 DEFAULT_FAILURE_THRESHOLD,
74 DEFAULT_COOLDOWN,
75 DEFAULT_PERMITTED_HALF_OPEN_CALLS,
76 )
77 }
78
79 pub fn with_options(
80 provider: impl Into<String>,
81 base_url: impl Into<String>,
82 failure_threshold: u32,
83 cooldown: Duration,
84 permitted_half_open_calls: u32,
85 ) -> Self {
86 Self {
87 provider: provider.into(),
88 base_url: base_url.into(),
89 failure_threshold,
90 permitted_half_open_calls: permitted_half_open_calls.max(1),
91 base_cooldown: cooldown,
92 inner: Mutex::new(CircuitInner {
93 state: CircuitState::Closed,
94 consecutive_failures: 0,
95 opened_at: None,
96 cooldown,
97 half_open_active: 0,
98 half_open_successes: 0,
99 half_open_failures: 0,
100 }),
101 }
102 }
103
104 pub fn state(&self) -> CircuitState {
105 let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
106 transition_open_to_half_open(&mut inner);
107 inner.state
108 }
109
110 pub fn cooldown(&self) -> Duration {
111 self.inner
112 .lock()
113 .unwrap_or_else(|e| e.into_inner())
114 .cooldown
115 }
116
117 pub fn check(&self) -> Result<(), CircuitOpenError> {
118 let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
119 transition_open_to_half_open(&mut inner);
120
121 match inner.state {
122 CircuitState::Closed => Ok(()),
123 CircuitState::HalfOpen => {
124 if inner.half_open_active >= self.permitted_half_open_calls {
125 Err(CircuitOpenError {
126 provider: self.provider.clone(),
127 base_url: self.base_url.clone(),
128 })
129 } else {
130 inner.half_open_active += 1;
131 Ok(())
132 }
133 }
134 CircuitState::Open => Err(CircuitOpenError {
135 provider: self.provider.clone(),
136 base_url: self.base_url.clone(),
137 }),
138 }
139 }
140
141 pub fn on_success(&self) {
142 let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
143 match inner.state {
144 CircuitState::HalfOpen => {
145 inner.half_open_successes += 1;
146 evaluate_half_open(
147 &mut inner,
148 self.base_cooldown,
149 self.permitted_half_open_calls,
150 );
151 }
152 CircuitState::Open => {}
153 CircuitState::Closed => {
154 inner.consecutive_failures = 0;
155 inner.opened_at = None;
156 }
157 }
158 }
159
160 pub fn on_failure(&self, status_code: Option<u16>, network_error: bool) {
161 let retryable = network_error || status_code.is_some_and(is_retryable);
162 let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
163
164 if inner.state == CircuitState::HalfOpen {
165 if retryable {
166 inner.half_open_failures += 1;
167 } else {
168 inner.half_open_successes += 1;
169 }
170 evaluate_half_open(
171 &mut inner,
172 self.base_cooldown,
173 self.permitted_half_open_calls,
174 );
175 return;
176 }
177
178 if matches!(status_code, Some(401 | 403)) {
179 inner.consecutive_failures = 0;
180 return;
181 }
182
183 if !retryable {
184 inner.consecutive_failures = 0;
185 return;
186 }
187
188 inner.consecutive_failures += 1;
189 if inner.consecutive_failures >= self.failure_threshold {
190 inner.state = CircuitState::Open;
191 inner.opened_at = Some(Instant::now());
192 }
193 }
194
195 pub fn cancel_probe(&self) {
196 let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
197 if inner.state != CircuitState::HalfOpen {
198 return;
199 }
200 let total_finalized = inner.half_open_successes + inner.half_open_failures;
201 if total_finalized >= inner.half_open_active {
202 return;
203 }
204 inner.half_open_failures += 1;
205 evaluate_half_open(
206 &mut inner,
207 self.base_cooldown,
208 self.permitted_half_open_calls,
209 );
210 }
211
212 pub fn reset(&self) {
213 let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
214 inner.state = CircuitState::Closed;
215 inner.consecutive_failures = 0;
216 inner.opened_at = None;
217 inner.cooldown = self.base_cooldown;
218 inner.half_open_active = 0;
219 inner.half_open_successes = 0;
220 inner.half_open_failures = 0;
221 }
222}
223
224fn transition_open_to_half_open(inner: &mut CircuitInner) {
225 if inner.state != CircuitState::Open {
226 return;
227 }
228 let Some(opened_at) = inner.opened_at else {
229 return;
230 };
231 if opened_at.elapsed() >= inner.cooldown {
232 inner.state = CircuitState::HalfOpen;
233 inner.opened_at = None;
234 inner.half_open_active = 0;
235 inner.half_open_successes = 0;
236 inner.half_open_failures = 0;
237 }
238}
239
240fn evaluate_half_open(
241 inner: &mut CircuitInner,
242 base_cooldown: Duration,
243 permitted_half_open_calls: u32,
244) {
245 let total_completed = inner.half_open_successes + inner.half_open_failures;
246 if total_completed < permitted_half_open_calls {
247 return;
248 }
249
250 if inner.half_open_successes > inner.half_open_failures {
251 inner.state = CircuitState::Closed;
252 inner.consecutive_failures = 0;
253 inner.opened_at = None;
254 inner.cooldown = base_cooldown;
255 } else {
256 inner.state = CircuitState::Open;
257 inner.opened_at = Some(Instant::now());
258 inner.cooldown = (inner.cooldown * 2).min(MAX_COOLDOWN);
259 }
260}
261
262#[derive(Debug)]
263pub struct KillSwitch {
264 path: PathBuf,
265}
266
267impl KillSwitch {
268 pub fn new() -> io::Result<Self> {
269 Self::with_state_dir(resolve_state_dir())
270 }
271
272 pub fn with_state_dir(path: impl AsRef<Path>) -> io::Result<Self> {
273 Ok(Self {
274 path: path.as_ref().join(KILLSWITCH_FILENAME),
275 })
276 }
277
278 pub fn path(&self) -> &Path {
279 &self.path
280 }
281
282 pub fn check(&self) -> LlmixResult<()> {
283 match fs::metadata(&self.path) {
284 Ok(_) => Err(KillSwitchActiveError {
285 path: self.path.display().to_string(),
286 }
287 .into()),
288 Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(()),
289 Err(err) => Err(err.into()),
290 }
291 }
292
293 pub fn is_active(&self) -> LlmixResult<bool> {
294 match fs::metadata(&self.path) {
295 Ok(_) => Ok(true),
296 Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(false),
297 Err(err) => Err(err.into()),
298 }
299 }
300
301 pub async fn check_async(&self) -> LlmixResult<()> {
302 match tokio::fs::metadata(&self.path).await {
303 Ok(_) => Err(KillSwitchActiveError {
304 path: self.path.display().to_string(),
305 }
306 .into()),
307 Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(()),
308 Err(err) => Err(err.into()),
309 }
310 }
311
312 pub async fn is_active_async(&self) -> LlmixResult<bool> {
313 match tokio::fs::metadata(&self.path).await {
314 Ok(_) => Ok(true),
315 Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(false),
316 Err(err) => Err(err.into()),
317 }
318 }
319}
320
321pub type SharedCallResult<T, E> = Result<Arc<T>, Arc<E>>;
322
323#[derive(Debug)]
324struct FlightEntry<T, E> {
325 notify: Notify,
326 result: Mutex<Option<SharedCallResult<T, E>>>,
327}
328
329impl<T, E> FlightEntry<T, E> {
330 fn new() -> Self {
331 Self {
332 notify: Notify::new(),
333 result: Mutex::new(None),
334 }
335 }
336}
337
338#[derive(Debug, Default)]
339pub struct Singleflight<T, E> {
340 in_flight: Mutex<HashMap<String, Arc<FlightEntry<T, E>>>>,
341}
342
343impl<T, E> Singleflight<T, E>
344where
345 T: Send + Sync + 'static,
346 E: Send + Sync + 'static,
347{
348 pub fn new() -> Self {
349 Self {
350 in_flight: Mutex::new(HashMap::new()),
351 }
352 }
353
354 pub fn make_key(data: &str) -> String {
355 format!("{:x}", Sha256::digest(data.as_bytes()))
356 }
357
358 pub async fn do_call<F, Fut>(&self, key: impl Into<String>, func: F) -> SharedCallResult<T, E>
359 where
360 F: FnOnce() -> Fut,
361 Fut: Future<Output = Result<T, E>> + Send,
362 {
363 let key = key.into();
364 let (entry, is_leader) = {
365 let mut in_flight = self.in_flight.lock().unwrap_or_else(|e| e.into_inner());
366 if let Some(existing) = in_flight.get(&key) {
367 (existing.clone(), false)
368 } else {
369 let entry = Arc::new(FlightEntry::new());
370 in_flight.insert(key.clone(), entry.clone());
371 (entry, true)
372 }
373 };
374
375 if is_leader {
376 let result = func().await.map(Arc::new).map_err(Arc::new);
377 {
378 let mut slot = entry
379 .result
380 .lock()
381 .expect("singleflight result mutex poisoned");
382 *slot = Some(result.clone());
383 }
384 self.in_flight
385 .lock()
386 .unwrap_or_else(|e| e.into_inner())
387 .remove(&key);
388 entry.notify.notify_waiters();
389 return result;
390 }
391
392 loop {
393 let notified = entry.notify.notified();
394 if let Some(result) = entry
395 .result
396 .lock()
397 .expect("singleflight result mutex poisoned")
398 .clone()
399 {
400 return result;
401 }
402 notified.await;
403 }
404 }
405
406 pub fn in_flight_count(&self) -> usize {
407 self.in_flight
408 .lock()
409 .unwrap_or_else(|e| e.into_inner())
410 .len()
411 }
412}
413
/// Exponential backoff with additive jitter, in milliseconds.
///
/// The exponential part is `base_ms * 2^attempt`, saturating on overflow and
/// capped at `max_delay_ms`. A pseudo-random jitter in `0..=jitter_ms` is then
/// added (saturating), so the result may exceed `max_delay_ms` by up to
/// `jitter_ms`.
///
/// Jitter comes from std's randomly keyed `RandomState` hasher rather than from
/// the clock's sub-second digits alone. On coarse clocks those digits are
/// multiples of the tick size, which collapses the jitter to a few values (for
/// example, always 0 for `jitter_ms = 999` on a microsecond clock).
pub fn calculate_delay(attempt: u32, base_ms: u64, max_delay_ms: u64, jitter_ms: u64) -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let factor = 1_u64.checked_shl(attempt.min(63)).unwrap_or(u64::MAX);
    let exponential = base_ms.saturating_mul(factor).min(max_delay_ms);
    if jitter_ms == 0 {
        return exponential;
    }

    let nanos = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .map(|duration| duration.subsec_nanos())
        .unwrap_or(0);
    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u32(nanos);
    let random = hasher.finish();

    // `jitter_ms + 1` overflowed (and panicked) for `u64::MAX`; in that case
    // every u64 is already within range.
    let jitter = match jitter_ms.checked_add(1) {
        Some(span) => random % span,
        None => random,
    };
    exponential.saturating_add(jitter)
}
428
429pub fn parse_retry_after(header_value: Option<&str>, max_ms: u64) -> Option<u64> {
430 let value = header_value?.trim();
431
432 if let Ok(seconds) = value.parse::<u64>() {
433 return Some(seconds.saturating_mul(1_000).min(max_ms));
434 }
435
436 let parsed = httpdate::parse_http_date(value).ok()?;
437 let delta = parsed.duration_since(SystemTime::now()).ok()?;
438 Some(delta.as_millis().min(max_ms as u128) as u64)
439}
440
441#[derive(Debug, Clone, Copy)]
442pub struct RetryPolicyOptions {
443 pub max_retries: u32,
444 pub base_ms: u64,
445 pub max_delay_ms: u64,
446 pub jitter_ms: u64,
447 pub max_retry_after_ms: u64,
448}
449
450impl Default for RetryPolicyOptions {
451 fn default() -> Self {
452 Self {
453 max_retries: 3,
454 base_ms: DEFAULT_BASE_DELAY_MS,
455 max_delay_ms: DEFAULT_MAX_DELAY_MS,
456 jitter_ms: DEFAULT_JITTER_MS,
457 max_retry_after_ms: DEFAULT_MAX_RETRY_AFTER_MS,
458 }
459 }
460}
461
462#[derive(Debug, Clone, Copy)]
463pub struct RetryPolicy {
464 options: RetryPolicyOptions,
465}
466
467impl RetryPolicy {
468 pub fn new(options: RetryPolicyOptions) -> LlmixResult<Self> {
469 if options.max_delay_ms < options.base_ms {
470 return Err(LlmixError::InvalidRetryPolicyConfig(
471 "max_delay_ms must be >= base_ms".to_owned(),
472 ));
473 }
474 Ok(Self { options })
475 }
476
477 pub fn with_defaults() -> Self {
478 Self::new(RetryPolicyOptions::default())
479 .expect("default retry policy configuration must be valid")
480 }
481
482 pub fn get_delay_ms(&self, attempt: u32, retry_after_header: Option<&str>) -> u64 {
483 parse_retry_after(retry_after_header, self.options.max_retry_after_ms).unwrap_or_else(
484 || {
485 calculate_delay(
486 attempt,
487 self.options.base_ms,
488 self.options.max_delay_ms,
489 self.options.jitter_ms,
490 )
491 },
492 )
493 }
494
495 pub async fn execute<T, E, F, Fut>(&self, mut func: F) -> Result<T, E>
496 where
497 F: FnMut() -> Fut,
498 Fut: Future<Output = Result<T, E>>,
499 {
500 self.execute_with_hooks(
501 &mut func,
502 None::<fn(&E) -> bool>,
503 None::<fn(&E) -> Option<String>>,
504 )
505 .await
506 }
507
508 pub async fn execute_with_hooks<T, E, F, Fut, P, H>(
509 &self,
510 mut func: F,
511 is_retryable_fn: Option<P>,
512 retry_after_header: Option<H>,
513 ) -> Result<T, E>
514 where
515 F: FnMut() -> Fut,
516 Fut: Future<Output = Result<T, E>>,
517 P: Fn(&E) -> bool,
518 H: Fn(&E) -> Option<String>,
519 {
520 for attempt in 0..=self.options.max_retries {
521 match func().await {
522 Ok(value) => return Ok(value),
523 Err(err) => {
524 if attempt >= self.options.max_retries {
525 return Err(err);
526 }
527 if let Some(predicate) = &is_retryable_fn {
528 if !predicate(&err) {
529 return Err(err);
530 }
531 }
532 let retry_after = retry_after_header
533 .as_ref()
534 .and_then(|extractor| extractor(&err));
535 let delay = self.get_delay_ms(attempt, retry_after.as_deref());
536 tokio::time::sleep(Duration::from_millis(delay)).await;
537 }
538 }
539 }
540
541 unreachable!("retry loop always returns or errors")
542 }
543}
544
/// Mutual-exclusion lock that combines an in-process condition variable with an
/// exclusive lock on a file, so at most one holder exists across threads and
/// across cooperating processes using the same lock file.
///
/// It is disabled (all operations are no-ops) unless `LLM_GLOBAL_CONCURRENCY`
/// is set.
#[derive(Debug)]
pub struct FileLock {
    /// `false` makes `acquire` / `release` no-ops.
    enabled: bool,
    /// Lock file location; `Some` exactly when `enabled`.
    lock_path: Option<PathBuf>,
    /// In-process ownership state.
    state: Mutex<FileLockState>,
    /// Signalled whenever `state.held` goes back to `false`.
    available: Condvar,
}

/// In-process view of the lock.
#[derive(Debug)]
struct FileLockState {
    /// Whether some thread owns the lock or is in the middle of acquiring it.
    held: bool,
    /// The opened, OS-locked file while the lock is held.
    file: Option<File>,
}

/// Guard returned by `FileLock::acquire_guard`; releases the lock when dropped.
///
/// Call `release` explicitly to observe unlock errors, which are ignored on
/// drop. `#[must_use]` flags `lock.acquire_guard()?;`, which would otherwise
/// drop the guard and release the lock immediately.
#[derive(Debug)]
#[must_use = "if unused the file lock is released immediately"]
pub struct FileLockGuard<'a> {
    /// Lock this guard releases.
    file_lock: &'a FileLock,
    /// Set once the lock has been released, so it is never released twice.
    released: bool,
}
564
565impl FileLock {
566 pub fn new() -> LlmixResult<Self> {
567 Self::with_path(resolve_state_dir().join("llmix.lock"))
568 }
569
570 pub fn with_path(path: impl Into<PathBuf>) -> LlmixResult<Self> {
571 let concurrency = std::env::var("LLM_GLOBAL_CONCURRENCY").ok();
572 let enabled = concurrency
573 .as_ref()
574 .is_some_and(|value| !value.trim().is_empty());
575
576 if let Some(value) = concurrency
577 .as_deref()
578 .filter(|value| !value.trim().is_empty())
579 {
580 if value
581 .trim()
582 .parse::<u32>()
583 .ok()
584 .filter(|parsed| *parsed > 0)
585 .is_none()
586 {
587 return Err(LlmixError::InvalidFileLockConfig(format!(
588 "LLM_GLOBAL_CONCURRENCY must be a positive integer, got \"{value}\""
589 )));
590 }
591 }
592
593 Ok(Self {
594 enabled,
595 lock_path: enabled.then_some(path.into()),
596 state: Mutex::new(FileLockState {
597 held: false,
598 file: None,
599 }),
600 available: Condvar::new(),
601 })
602 }
603
604 pub fn enabled(&self) -> bool {
605 self.enabled
606 }
607
608 pub fn lock_path(&self) -> Option<&Path> {
609 self.lock_path.as_deref()
610 }
611
612 pub fn acquire(&self) -> LlmixResult<()> {
613 if !self.enabled {
614 return Ok(());
615 }
616
617 let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner());
618 while state.held {
619 state = self
620 .available
621 .wait(state)
622 .unwrap_or_else(|e| e.into_inner());
623 }
624 state.held = true;
625 drop(state);
626
627 let file_result = self.open_locked_file();
628 let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner());
629 match file_result {
630 Ok(file) => {
631 state.file = Some(file);
632 Ok(())
633 }
634 Err(error) => {
635 state.held = false;
636 self.available.notify_one();
637 Err(error)
638 }
639 }
640 }
641
642 fn open_locked_file(&self) -> LlmixResult<File> {
643 let path = self
644 .lock_path
645 .as_ref()
646 .expect("enabled file lock must have a path");
647 if let Some(parent) = path.parent() {
648 fs::create_dir_all(parent)?;
649 }
650 let file = File::options()
651 .create(true)
652 .truncate(false)
653 .read(true)
654 .write(true)
655 .open(path)?;
656 file.lock_exclusive()?;
657 Ok(file)
658 }
659
660 pub fn acquire_guard(&self) -> LlmixResult<FileLockGuard<'_>> {
661 self.acquire()?;
662 Ok(FileLockGuard {
663 file_lock: self,
664 released: false,
665 })
666 }
667
668 pub fn release(&self) -> LlmixResult<()> {
669 if !self.enabled {
670 return Ok(());
671 }
672
673 let maybe_file = {
674 self.state
675 .lock()
676 .unwrap_or_else(|e| e.into_inner())
677 .file
678 .take()
679 };
680 let result = if let Some(file) = maybe_file {
681 file.unlock().map_err(LlmixError::from)
682 } else {
683 Ok(())
684 };
685
686 let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner());
687 state.held = false;
688 self.available.notify_one();
689 result
690 }
691}
692
693impl FileLockGuard<'_> {
694 pub fn release(mut self) -> LlmixResult<()> {
695 if self.released {
696 return Ok(());
697 }
698 self.released = true;
699 self.file_lock.release()
700 }
701}
702
703impl Drop for FileLockGuard<'_> {
704 fn drop(&mut self) {
705 if !self.released {
706 let _ = self.file_lock.release();
707 self.released = true;
708 }
709 }
710}
711
712impl Drop for FileLock {
713 fn drop(&mut self) {
714 let state = self.state.get_mut().unwrap_or_else(|e| e.into_inner());
715 if let Some(file) = state.file.take() {
716 let _ = file.unlock();
717 }
718 state.held = false;
719 }
720}