Skip to main content

inferd_client/
wait.rs

1//! Connect-and-retry helpers per `docs/protocol-v1.md` §"Client
2//! connection lifecycle".
3
4use crate::client::{Client, ClientError};
5use std::future::Future;
6use std::path::PathBuf;
7use std::time::Duration;
8use tokio::time::Instant;
9
10/// Errors produced by `dial_and_wait_ready`.
11#[derive(Debug, thiserror::Error)]
12pub enum WaitError {
13    /// Deadline elapsed before any successful connect.
14    #[error("timed out after {0:?} waiting for inferd to become ready")]
15    Timeout(Duration),
16    /// A non-transient error surfaced — daemon is broken or
17    /// permissions are wrong; not worth retrying.
18    #[error("permanent connect error: {0}")]
19    Permanent(ClientError),
20}
21
22/// Pattern A passive readiness: retry connect against the inference
23/// transport until success or `timeout` elapses. Successful connect
24/// is the ready signal — the daemon's inference socket only exists
25/// when the backend is `ready` per THREAT_MODEL F-13.
26///
27/// Backoff schedule: 100ms initial, doubling each attempt, capped
28/// at 5s. Permanent errors (permission denied, malformed addr,
29/// decode failure) bubble up immediately as
30/// `WaitError::Permanent`.
31///
32/// `dial_fn` is a closure producing a fresh dial future on each
33/// attempt. This indirection lets callers swap transports without
34/// us duplicating the retry loop:
35///
36/// ```no_run
37/// use inferd_client::{dial_and_wait_ready, Client};
38/// use std::time::Duration;
39///
40/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
41/// let client = dial_and_wait_ready(
42///     Duration::from_secs(30),
43///     || Client::dial_tcp("127.0.0.1:47321"),
44/// )
45/// .await?;
46/// # Ok(()) }
47/// ```
48pub async fn dial_and_wait_ready<F, Fut>(
49    timeout: Duration,
50    mut dial_fn: F,
51) -> Result<Client, WaitError>
52where
53    F: FnMut() -> Fut,
54    Fut: Future<Output = Result<Client, ClientError>>,
55{
56    let deadline = Instant::now() + timeout;
57    let mut delay = Duration::from_millis(100);
58    let max_delay = Duration::from_secs(5);
59
60    loop {
61        match dial_fn().await {
62            Ok(c) => return Ok(c),
63            Err(e) if !is_transient_dial_error(&e) => {
64                return Err(WaitError::Permanent(e));
65            }
66            Err(_) => {
67                if Instant::now() >= deadline {
68                    return Err(WaitError::Timeout(timeout));
69                }
70                tokio::time::sleep(delay).await;
71                delay = (delay * 2).min(max_delay);
72            }
73        }
74    }
75}
76
77/// Returns `true` if `err` is the kind of transient connect failure
78/// that the daemon's F-13 ready-gating produces during bring-up
79/// (the inference socket doesn't exist yet). Permanent errors
80/// (permission denied, malformed addr) return `false` and bubble
81/// up immediately rather than spamming retries.
82pub fn is_transient_dial_error(err: &ClientError) -> bool {
83    let ClientError::Io(io_err) = err else {
84        return false;
85    };
86    use std::io::ErrorKind;
87    matches!(
88        io_err.kind(),
89        ErrorKind::ConnectionRefused
90            | ErrorKind::NotFound
91            | ErrorKind::TimedOut
92            | ErrorKind::AddrNotAvailable
93    ) || {
94        // Windows pipe-busy comes through as raw os error; check
95        // the message as a fallback. Rare but real.
96        let msg = io_err.to_string().to_ascii_lowercase();
97        msg.contains("all pipe instances are busy")
98            || msg.contains("the system cannot find")
99            || msg.contains("target machine actively refused")
100    }
101}
102
/// Default admin endpoint path for the current platform. Mirrors the
/// daemon's `endpoint::default_admin_addr` so clients can reach the
/// spec'd default without hard-coding it.
///
/// On Linux the first usable entry of this chain wins (per
/// `docs/protocol-v1.md` §"Admin endpoint"); an unset or empty
/// variable is skipped:
/// 1. `$XDG_RUNTIME_DIR/inferd/admin.sock`
/// 2. `$HOME/.inferd/run/admin.sock`
/// 3. `/tmp/inferd/admin.sock`
///
/// macOS uses `inferd/admin.sock` under [`std::env::temp_dir`],
/// Windows uses the named pipe `\\.\pipe\inferd-admin`, and any other
/// target falls back to `/tmp/inferd/admin.sock`.
pub fn default_admin_addr() -> PathBuf {
    #[cfg(target_os = "linux")]
    {
        // An empty variable is treated exactly like an unset one.
        let env_dir = |name: &str| std::env::var_os(name).filter(|val| !val.is_empty());

        if let Some(runtime_dir) = env_dir("XDG_RUNTIME_DIR") {
            return PathBuf::from(runtime_dir).join("inferd").join("admin.sock");
        }
        match env_dir("HOME") {
            Some(home) => PathBuf::from(home)
                .join(".inferd")
                .join("run")
                .join("admin.sock"),
            None => PathBuf::from("/tmp/inferd/admin.sock"),
        }
    }
    #[cfg(target_os = "macos")]
    {
        std::env::temp_dir().join("inferd").join("admin.sock")
    }
    #[cfg(windows)]
    {
        PathBuf::from(r"\\.\pipe\inferd-admin")
    }
    #[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
    {
        PathBuf::from("/tmp/inferd/admin.sock")
    }
}
149
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Wraps an `io::Error` of the given kind and message as a `ClientError`.
    fn io_err(kind: io::ErrorKind, msg: &str) -> ClientError {
        ClientError::Io(io::Error::new(kind, msg))
    }

    /// Builds a `Client` over one half of an in-memory duplex pipe. The
    /// peer half is dropped right away; these tests only need a value.
    fn fake_client() -> Client {
        let (near, _far) = tokio::io::duplex(64);
        let (reader, writer) = tokio::io::split(near);
        Client::wrap_for_test(Box::new(reader), Box::new(writer))
    }

    #[test]
    fn refused_is_transient() {
        let err = io_err(io::ErrorKind::ConnectionRefused, "refused");
        assert!(is_transient_dial_error(&err));
    }

    #[test]
    fn notfound_is_transient() {
        let err = io_err(io::ErrorKind::NotFound, "no such file");
        assert!(is_transient_dial_error(&err));
    }

    #[test]
    fn permission_denied_is_permanent() {
        let err = io_err(io::ErrorKind::PermissionDenied, "denied");
        assert!(!is_transient_dial_error(&err));
    }

    #[test]
    fn pipe_busy_message_recognised_as_transient() {
        let err = io_err(io::ErrorKind::Other, "All pipe instances are busy.");
        assert!(is_transient_dial_error(&err));
    }

    #[test]
    fn decode_error_is_permanent() {
        // Any serde failure will do; parsing garbage as a number is simplest.
        let serde_err = serde_json::from_str::<u32>("not a number").unwrap_err();
        assert!(!is_transient_dial_error(&ClientError::Decode(serde_err)));
    }

    #[tokio::test]
    async fn dial_and_wait_ready_succeeds_first_try() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&attempts);
        let dial = move || {
            counter.fetch_add(1, Ordering::SeqCst);
            let client = fake_client();
            async move { Ok(client) }
        };

        dial_and_wait_ready(Duration::from_secs(1), dial)
            .await
            .unwrap();
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn dial_and_wait_ready_retries_transient() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&attempts);
        let dial = move || {
            let attempt = counter.fetch_add(1, Ordering::SeqCst);
            async move {
                // First two attempts look like a daemon that isn't up yet.
                match attempt {
                    0 | 1 => Err(io_err(io::ErrorKind::ConnectionRefused, "refused")),
                    _ => Ok(fake_client()),
                }
            }
        };

        dial_and_wait_ready(Duration::from_secs(5), dial)
            .await
            .unwrap();
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn dial_and_wait_ready_returns_permanent_immediately() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&attempts);
        let dial = move || {
            counter.fetch_add(1, Ordering::SeqCst);
            async move { Err::<Client, _>(io_err(io::ErrorKind::PermissionDenied, "denied")) }
        };

        let err = dial_and_wait_ready(Duration::from_secs(5), dial)
            .await
            .unwrap_err();
        assert!(
            matches!(err, WaitError::Permanent(_)),
            "expected Permanent, got {err:?}"
        );
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn dial_and_wait_ready_times_out() {
        let always_refused = || async {
            Err::<Client, _>(io_err(io::ErrorKind::ConnectionRefused, "refused"))
        };

        let err = dial_and_wait_ready(Duration::from_millis(250), always_refused)
            .await
            .unwrap_err();
        assert!(
            matches!(err, WaitError::Timeout(_)),
            "expected Timeout, got {err:?}"
        );
    }
}