frankensearch_core/
daemon.rs1use std::fmt;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::time::Duration;
9
10#[derive(Debug, Clone)]
12pub struct DaemonRetryConfig {
13 pub max_attempts: u32,
15 pub base_delay: Duration,
17 pub max_delay: Duration,
19 pub jitter_pct: f64,
21}
22
23impl Default for DaemonRetryConfig {
24 fn default() -> Self {
25 Self {
26 max_attempts: 2,
27 base_delay: Duration::from_millis(200),
28 max_delay: Duration::from_secs(5),
29 jitter_pct: 0.2,
30 }
31 }
32}
33
34impl DaemonRetryConfig {
35 #[must_use]
37 pub fn from_env() -> Self {
38 let mut cfg = Self::default();
39
40 if let Ok(val) = std::env::var("CASS_DAEMON_RETRY_MAX")
41 && let Ok(parsed) = val.parse::<u32>()
42 {
43 cfg.max_attempts = parsed.max(1);
44 }
45
46 if let Ok(val) = std::env::var("CASS_DAEMON_BACKOFF_BASE_MS")
47 && let Ok(parsed) = val.parse::<u64>()
48 {
49 cfg.base_delay = Duration::from_millis(parsed.max(1));
50 }
51
52 if let Ok(val) = std::env::var("CASS_DAEMON_BACKOFF_MAX_MS")
53 && let Ok(parsed) = val.parse::<u64>()
54 {
55 cfg.max_delay = Duration::from_millis(parsed.max(1));
56 }
57
58 if let Ok(val) = std::env::var("CASS_DAEMON_JITTER_PCT")
59 && let Ok(parsed) = val.parse::<f64>()
60 {
61 cfg.jitter_pct = parsed.clamp(0.0, 1.0);
62 }
63
64 cfg
65 }
66
67 #[must_use]
69 pub fn backoff_for_attempt(&self, attempt: u32, retry_after: Option<Duration>) -> Duration {
70 if let Some(explicit) = retry_after {
71 return explicit.min(self.max_delay);
72 }
73
74 let exp = 2u32.saturating_pow(attempt.saturating_sub(1));
75 let base = self.base_delay.checked_mul(exp).unwrap_or(self.max_delay);
76 apply_jitter(base.min(self.max_delay), self.jitter_pct)
77 }
78}
79
/// Error produced by a daemon request.
///
/// Every variant carries a human-readable message. `Display` renders each one
/// with a `daemon <kind>: ` prefix.
#[derive(Debug, Clone)]
pub enum DaemonError {
    /// Displayed as `daemon unavailable: <msg>`.
    Unavailable(String),
    /// Displayed as `daemon timeout: <msg>`.
    Timeout(String),
    /// Displayed as `daemon overloaded: <message>`; `retry_after` is not shown.
    Overloaded {
        /// Optional wait hint; `DaemonRetryConfig::backoff_for_attempt` accepts
        /// a value of this shape as its explicit override.
        retry_after: Option<Duration>,
        /// Human-readable description.
        message: String,
    },
    /// Displayed as `daemon failed: <msg>`.
    Failed(String),
    /// Displayed as `daemon invalid input: <msg>`.
    InvalidInput(String),
}

impl fmt::Display for DaemonError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            DaemonError::Unavailable(ref msg) => write!(f, "daemon unavailable: {msg}"),
            DaemonError::Timeout(ref msg) => write!(f, "daemon timeout: {msg}"),
            DaemonError::Overloaded {
                ref message,
                retry_after: _,
            } => write!(f, "daemon overloaded: {message}"),
            DaemonError::Failed(ref msg) => write!(f, "daemon failed: {msg}"),
            DaemonError::InvalidInput(ref msg) => write!(f, "daemon invalid input: {msg}"),
        }
    }
}

impl std::error::Error for DaemonError {}
106
107#[allow(clippy::missing_errors_doc)]
111pub trait DaemonClient: Send + Sync {
112 fn id(&self) -> &str;
113 fn is_available(&self) -> bool;
114
115 fn embed(&self, text: &str, request_id: &str) -> Result<Vec<f32>, DaemonError>;
116 fn embed_batch(&self, texts: &[&str], request_id: &str) -> Result<Vec<Vec<f32>>, DaemonError>;
117 fn rerank(
118 &self,
119 query: &str,
120 documents: &[&str],
121 request_id: &str,
122 ) -> Result<Vec<f32>, DaemonError>;
123}
124
125#[must_use]
127pub fn apply_jitter(duration: Duration, jitter_pct: f64) -> Duration {
128 if jitter_pct <= 0.0 {
129 return duration;
130 }
131 let unit = next_jitter_unit();
132 let delta = unit.mul_add(2.0, -1.0) * jitter_pct;
133 #[allow(clippy::cast_precision_loss)]
134 let base_ms = duration.as_millis() as f64;
135 let jittered = (base_ms * (1.0 + delta)).max(1.0);
136 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
137 Duration::from_millis(jittered.round() as u64)
138}
139
/// Returns a new request identifier of the form `daemon-<n>`.
///
/// `<n>` comes from a process-wide counter that starts at 1 and grows by one per
/// call. `Relaxed` ordering is sufficient because only the fetched value itself
/// matters, not its ordering relative to other memory operations.
#[must_use]
pub fn next_request_id() -> String {
    static NEXT_ID: AtomicU64 = AtomicU64::new(1);

    let id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
    format!("daemon-{id}")
}
147
/// Draws the next pseudo-random value in `[0.0, 1.0)` for backoff jitter.
///
/// Uses a 64-bit linear congruential generator (multiplier
/// 6364136223846793005, increment 1). Its state lives in a process-wide atomic
/// and is advanced with a compare-and-swap loop, so concurrent callers each
/// consume a distinct step. The top 53 bits of the new state are scaled into
/// `[0, 1)`. This is not a cryptographic RNG.
///
/// The state is seeded once per process (see [`initial_jitter_seed`]). A fixed
/// compile-time seed would make every process draw the same jitter sequence,
/// so clients that start together would retry in lockstep — exactly what
/// jitter is meant to prevent.
#[allow(clippy::cast_precision_loss)] // values below 2^53 convert to f64 exactly
fn next_jitter_unit() -> f64 {
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    static STATE: std::sync::OnceLock<AtomicU64> = std::sync::OnceLock::new();

    let state = STATE.get_or_init(|| AtomicU64::new(initial_jitter_seed()));
    let mut current = state.load(Ordering::Relaxed);
    loop {
        let next = current.wrapping_mul(MULTIPLIER).wrapping_add(1);
        match state.compare_exchange_weak(current, next, Ordering::Relaxed, Ordering::Relaxed) {
            // Keep the top 53 bits: exactly the precision of an f64 mantissa.
            Ok(_) => return (next >> 11) as f64 / (1_u64 << 53) as f64,
            Err(actual) => current = actual,
        }
    }
}

/// Produces the per-process seed for [`next_jitter_unit`].
///
/// Hashes the process id with a freshly keyed `RandomState`; the standard
/// library keys these randomly. The historical constant is XORed in as well.
fn initial_jitter_seed() -> u64 {
    const LEGACY_SEED: u64 = 0x9e37_79b9_7f4a_7c15;
    let keyed = std::hash::BuildHasher::hash_one(
        &std::collections::hash_map::RandomState::new(),
        std::process::id(),
    );
    keyed ^ LEGACY_SEED
}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// An explicit `retry_after` within the cap is returned as-is.
    #[test]
    fn backoff_respects_retry_after() {
        let explicit = Duration::from_secs(1);
        let delay = DaemonRetryConfig::default().backoff_for_attempt(4, Some(explicit));
        assert_eq!(delay, explicit);
    }

    /// Repeated jitter draws never produce a delay shorter than 1 ms.
    #[test]
    fn jitter_stays_positive() {
        let base = Duration::from_millis(50);
        for _round in 0..100 {
            assert!(apply_jitter(base, 0.2) >= Duration::from_millis(1));
        }
    }
}