Skip to main content

frankensearch_core/
daemon.rs

1//! Daemon client abstraction for warm embedding and reranking.
2//!
3//! This module defines the protocol-agnostic daemon interfaces shared by
4//! host applications and fusion-layer fallback wrappers.
5
6use std::fmt;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::time::Duration;
9
/// Tunable retry and backoff policy for daemon requests.
///
/// All fields are public, so a policy can be built with a struct literal as
/// well as through the provided constructors.
#[derive(Clone, Debug)]
pub struct DaemonRetryConfig {
    /// Total number of tries allowed for one request; the initial try counts.
    pub max_attempts: u32,
    /// Delay used after the first failure, before any growth is applied.
    pub base_delay: Duration,
    /// Ceiling for the backoff delay.
    pub max_delay: Duration,
    /// Fraction of the delay used as jitter; expected to lie in `0.0..=1.0`.
    pub jitter_pct: f64,
}
22
23impl Default for DaemonRetryConfig {
24    fn default() -> Self {
25        Self {
26            max_attempts: 2,
27            base_delay: Duration::from_millis(200),
28            max_delay: Duration::from_secs(5),
29            jitter_pct: 0.2,
30        }
31    }
32}
33
34impl DaemonRetryConfig {
35    /// Load retry config from environment variables; fall back to defaults.
36    #[must_use]
37    pub fn from_env() -> Self {
38        let mut cfg = Self::default();
39
40        if let Ok(val) = std::env::var("CASS_DAEMON_RETRY_MAX")
41            && let Ok(parsed) = val.parse::<u32>()
42        {
43            cfg.max_attempts = parsed.max(1);
44        }
45
46        if let Ok(val) = std::env::var("CASS_DAEMON_BACKOFF_BASE_MS")
47            && let Ok(parsed) = val.parse::<u64>()
48        {
49            cfg.base_delay = Duration::from_millis(parsed.max(1));
50        }
51
52        if let Ok(val) = std::env::var("CASS_DAEMON_BACKOFF_MAX_MS")
53            && let Ok(parsed) = val.parse::<u64>()
54        {
55            cfg.max_delay = Duration::from_millis(parsed.max(1));
56        }
57
58        if let Ok(val) = std::env::var("CASS_DAEMON_JITTER_PCT")
59            && let Ok(parsed) = val.parse::<f64>()
60        {
61            cfg.jitter_pct = parsed.clamp(0.0, 1.0);
62        }
63
64        cfg
65    }
66
67    /// Compute backoff for the given failure attempt.
68    #[must_use]
69    pub fn backoff_for_attempt(&self, attempt: u32, retry_after: Option<Duration>) -> Duration {
70        if let Some(explicit) = retry_after {
71            return explicit.min(self.max_delay);
72        }
73
74        let exp = 2u32.saturating_pow(attempt.saturating_sub(1));
75        let base = self.base_delay.checked_mul(exp).unwrap_or(self.max_delay);
76        apply_jitter(base.min(self.max_delay), self.jitter_pct)
77    }
78}
79
/// Ways a daemon request can fail.
///
/// Every variant carries a human-readable detail message.
#[derive(Clone, Debug)]
pub enum DaemonError {
    /// The daemon was unavailable.
    Unavailable(String),
    /// The request timed out.
    Timeout(String),
    /// The daemon reported that it is overloaded.
    Overloaded {
        /// Optional delay attached to the error; presumably the wait the
        /// daemon suggests before retrying (confirm against the transport
        /// that produces it).
        retry_after: Option<Duration>,
        /// Detail message for the overload condition.
        message: String,
    },
    /// The request failed for a reason not covered by the other variants.
    Failed(String),
    /// The request input was rejected as invalid.
    InvalidInput(String),
}
92
93impl fmt::Display for DaemonError {
94    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95        match self {
96            Self::Unavailable(msg) => write!(f, "daemon unavailable: {msg}"),
97            Self::Timeout(msg) => write!(f, "daemon timeout: {msg}"),
98            Self::Overloaded { message, .. } => write!(f, "daemon overloaded: {message}"),
99            Self::Failed(msg) => write!(f, "daemon failed: {msg}"),
100            Self::InvalidInput(msg) => write!(f, "daemon invalid input: {msg}"),
101        }
102    }
103}
104
// Marker impl: the body is empty, so the trait's default methods apply and no
// underlying `source()` error is reported.
impl std::error::Error for DaemonError {}
106
/// Abstract daemon client.
///
/// Concrete transports (e.g. UDS/HTTP) are implemented by host applications.
///
/// Implementors must be `Send + Sync`. Each request method takes a
/// caller-supplied `request_id` and reports failure as a [`DaemonError`].
#[allow(clippy::missing_errors_doc)]
pub trait DaemonClient: Send + Sync {
    /// Returns an identifier string for this client.
    fn id(&self) -> &str;
    /// Reports whether the daemon is considered available.
    ///
    /// NOTE(review): whether this probes the daemon or reads cached state is
    /// left to the implementor.
    fn is_available(&self) -> bool;

    /// Embeds a single `text` and returns the resulting `f32` vector.
    fn embed(&self, text: &str, request_id: &str) -> Result<Vec<f32>, DaemonError>;
    /// Embeds several `texts` in one request and returns a list of vectors
    /// (presumably one per input, in input order; confirm with implementors).
    fn embed_batch(&self, texts: &[&str], request_id: &str) -> Result<Vec<Vec<f32>>, DaemonError>;
    /// Reranks `documents` against `query` and returns `f32` scores
    /// (presumably one per document, in input order; confirm with
    /// implementors).
    fn rerank(
        &self,
        query: &str,
        documents: &[&str],
        request_id: &str,
    ) -> Result<Vec<f32>, DaemonError>;
}
124
125/// Apply bounded symmetric jitter to a duration.
126#[must_use]
127pub fn apply_jitter(duration: Duration, jitter_pct: f64) -> Duration {
128    if jitter_pct <= 0.0 {
129        return duration;
130    }
131    let unit = next_jitter_unit();
132    let delta = unit.mul_add(2.0, -1.0) * jitter_pct;
133    #[allow(clippy::cast_precision_loss)]
134    let base_ms = duration.as_millis() as f64;
135    let jittered = (base_ms * (1.0 + delta)).max(1.0);
136    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
137    Duration::from_millis(jittered.round() as u64)
138}
139
/// Produce the next daemon request id, for use in tracing and retries.
///
/// Ids have the form `daemon-<n>`, where `<n>` comes from a process-wide
/// counter that starts at 1 and grows by one on every call.
#[must_use]
pub fn next_request_id() -> String {
    static NEXT_ID: AtomicU64 = AtomicU64::new(1);
    // Relaxed is enough: the atomic increment alone makes each value unique.
    let id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
    format!("daemon-{id}")
}
147
/// Draw the next pseudo-random value in `[0, 1)` from a process-wide
/// generator.
///
/// The shared 64-bit state is advanced as `state * MULTIPLIER + 1` (wrapping)
/// with a compare-exchange loop, so each successful call consumes exactly one
/// step of the sequence. The sequence is deterministic from a fixed seed.
fn next_jitter_unit() -> f64 {
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const MANTISSA_BITS: u32 = 53;
    static STATE: AtomicU64 = AtomicU64::new(0x9e37_79b9_7f4a_7c15);

    let mut observed = STATE.load(Ordering::Relaxed);
    let advanced = loop {
        let candidate = observed.wrapping_mul(MULTIPLIER).wrapping_add(1);
        match STATE.compare_exchange_weak(
            observed,
            candidate,
            Ordering::Relaxed,
            Ordering::Relaxed,
        ) {
            Ok(_) => break candidate,
            // Lost the race (or failed spuriously): retry from the fresh value.
            Err(actual) => observed = actual,
        }
    };

    // Keep the top 53 bits so the quotient is an exactly representable f64
    // strictly below 1.0.
    let mantissa = advanced >> (64 - MANTISSA_BITS);
    #[allow(clippy::cast_precision_loss)]
    let unit = (mantissa as f64) / ((1_u64 << MANTISSA_BITS) as f64);
    unit
}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// An explicit retry-after below the delay ceiling is returned verbatim,
    /// whatever the attempt number.
    #[test]
    fn backoff_respects_retry_after() {
        let hint = Duration::from_secs(1);
        let delay = DaemonRetryConfig::default().backoff_for_attempt(4, Some(hint));
        assert_eq!(delay, hint);
    }

    /// Jittered delays never drop below one millisecond.
    #[test]
    fn jitter_stays_positive() {
        let base = Duration::from_millis(50);
        for jittered in (0..100).map(|_| apply_jitter(base, 0.2)) {
            assert!(jittered.as_millis() >= 1);
        }
    }
}
186}