Skip to main content

inferd_client/
wait.rs

1//! Connect-and-retry helpers per `docs/protocol-v1.md` §"Client
2//! connection lifecycle".
3
4#[cfg(test)]
5use crate::client::Client;
6use crate::client::ClientError;
7use std::future::Future;
8use std::path::PathBuf;
9use std::time::Duration;
10use tokio::time::Instant;
11
12/// Errors produced by `dial_and_wait_ready`.
13#[derive(Debug, thiserror::Error)]
14pub enum WaitError {
15    /// Deadline elapsed before any successful connect.
16    #[error("timed out after {0:?} waiting for inferd to become ready")]
17    Timeout(Duration),
18    /// A non-transient error surfaced — daemon is broken or
19    /// permissions are wrong; not worth retrying.
20    #[error("permanent connect error: {0}")]
21    Permanent(ClientError),
22}
23
24/// Pattern A passive readiness: retry connect against the inference
25/// transport until success or `timeout` elapses. Successful connect
26/// is the ready signal — the daemon's inference socket only exists
27/// when the backend is `ready` per THREAT_MODEL F-13.
28///
29/// Backoff schedule: 100ms initial, doubling each attempt, capped
30/// at 5s. Permanent errors (permission denied, malformed addr,
31/// decode failure) bubble up immediately as
32/// `WaitError::Permanent`.
33///
34/// `dial_fn` is a closure producing a fresh dial future on each
35/// attempt. This indirection lets callers swap transports without
36/// us duplicating the retry loop:
37///
38/// ```no_run
39/// use inferd_client::{dial_and_wait_ready, Client};
40/// use std::time::Duration;
41///
42/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
43/// let client = dial_and_wait_ready(
44///     Duration::from_secs(30),
45///     || Client::dial_tcp("127.0.0.1:47321"),
46/// )
47/// .await?;
48/// # Ok(()) }
49/// ```
50pub async fn dial_and_wait_ready<C, F, Fut>(
51    timeout: Duration,
52    mut dial_fn: F,
53) -> Result<C, WaitError>
54where
55    F: FnMut() -> Fut,
56    Fut: Future<Output = Result<C, ClientError>>,
57{
58    let deadline = Instant::now() + timeout;
59    let mut delay = Duration::from_millis(100);
60    let max_delay = Duration::from_secs(5);
61
62    loop {
63        match dial_fn().await {
64            Ok(c) => return Ok(c),
65            Err(e) if !is_transient_dial_error(&e) => {
66                return Err(WaitError::Permanent(e));
67            }
68            Err(_) => {
69                if Instant::now() >= deadline {
70                    return Err(WaitError::Timeout(timeout));
71                }
72                tokio::time::sleep(delay).await;
73                delay = (delay * 2).min(max_delay);
74            }
75        }
76    }
77}
78
79/// Returns `true` if `err` is the kind of transient connect failure
80/// that the daemon's F-13 ready-gating produces during bring-up
81/// (the inference socket doesn't exist yet). Permanent errors
82/// (permission denied, malformed addr) return `false` and bubble
83/// up immediately rather than spamming retries.
84pub fn is_transient_dial_error(err: &ClientError) -> bool {
85    let ClientError::Io(io_err) = err else {
86        return false;
87    };
88    use std::io::ErrorKind;
89    matches!(
90        io_err.kind(),
91        ErrorKind::ConnectionRefused
92            | ErrorKind::NotFound
93            | ErrorKind::TimedOut
94            | ErrorKind::AddrNotAvailable
95    ) || {
96        // Windows pipe-busy comes through as raw os error; check
97        // the message as a fallback. Rare but real.
98        let msg = io_err.to_string().to_ascii_lowercase();
99        msg.contains("all pipe instances are busy")
100            || msg.contains("the system cannot find")
101            || msg.contains("target machine actively refused")
102    }
103}
104
/// Platform default for the daemon's admin endpoint. Mirrors the
/// daemon-side `endpoint::default_admin_addr` so clients can reach the
/// spec'd location without hard-coding it.
///
/// On Linux (per `docs/protocol-v1.md` §"Admin endpoint") the first
/// candidate whose environment variable is set *and non-empty* wins:
///
/// 1. `$XDG_RUNTIME_DIR/inferd/admin.sock`
/// 2. `$HOME/.inferd/run/admin.sock`
/// 3. `/tmp/inferd/admin.sock`
///
/// Other platforms:
///
/// - macOS: `inferd/admin.sock` under [`std::env::temp_dir`].
/// - Windows: the named pipe `\\.\pipe\inferd-admin`.
/// - Anything else: `/tmp/inferd/admin.sock`.
pub fn default_admin_addr() -> PathBuf {
    #[cfg(target_os = "linux")]
    {
        // A variable that is set but empty is treated like an unset one.
        let env_dir = |name: &str| {
            std::env::var_os(name)
                .filter(|value| !value.is_empty())
                .map(PathBuf::from)
        };
        if let Some(runtime_dir) = env_dir("XDG_RUNTIME_DIR") {
            return runtime_dir.join("inferd").join("admin.sock");
        }
        if let Some(home) = env_dir("HOME") {
            return home.join(".inferd").join("run").join("admin.sock");
        }
        PathBuf::from("/tmp/inferd/admin.sock")
    }
    #[cfg(target_os = "macos")]
    {
        std::env::temp_dir().join("inferd").join("admin.sock")
    }
    #[cfg(windows)]
    {
        PathBuf::from(r"\\.\pipe\inferd-admin")
    }
    #[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
    {
        PathBuf::from("/tmp/inferd/admin.sock")
    }
}
151
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a `ClientError::Io` carrying a synthetic error of `kind`.
    fn io_err(kind: io::ErrorKind, msg: &str) -> ClientError {
        ClientError::Io(io::Error::new(kind, msg))
    }

    /// A `Client` backed by one half of an in-memory duplex pipe. The
    /// other half is dropped on return; these tests never do I/O on it.
    fn in_memory_client() -> Client {
        let (near, _far) = tokio::io::duplex(64);
        let (reader, writer) = tokio::io::split(near);
        Client::wrap_for_test(Box::new(reader), Box::new(writer))
    }

    #[test]
    fn refused_is_transient() {
        let err = io_err(io::ErrorKind::ConnectionRefused, "refused");
        assert!(is_transient_dial_error(&err));
    }

    #[test]
    fn notfound_is_transient() {
        let err = io_err(io::ErrorKind::NotFound, "no such file");
        assert!(is_transient_dial_error(&err));
    }

    #[test]
    fn permission_denied_is_permanent() {
        let err = io_err(io::ErrorKind::PermissionDenied, "denied");
        assert!(!is_transient_dial_error(&err));
    }

    #[test]
    fn pipe_busy_message_recognised_as_transient() {
        let err = io_err(io::ErrorKind::Other, "All pipe instances are busy.");
        assert!(is_transient_dial_error(&err));
    }

    #[test]
    fn decode_error_is_permanent() {
        // Produce a genuine serde_json error by parsing a non-number.
        let parse_failure = serde_json::from_str::<u32>("not a number").unwrap_err();
        assert!(!is_transient_dial_error(&ClientError::Decode(parse_failure)));
    }

    #[tokio::test]
    async fn dial_and_wait_ready_succeeds_first_try() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let dial = {
            let attempts = Arc::clone(&attempts);
            move || {
                attempts.fetch_add(1, Ordering::SeqCst);
                async move { Ok(in_memory_client()) }
            }
        };
        let _ = dial_and_wait_ready(Duration::from_secs(1), dial)
            .await
            .unwrap();
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn dial_and_wait_ready_retries_transient() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let dial = {
            let attempts = Arc::clone(&attempts);
            move || {
                let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                async move {
                    // The first two attempts are refused; the third connects.
                    match attempt {
                        0 | 1 => Err(io_err(io::ErrorKind::ConnectionRefused, "refused")),
                        _ => Ok(in_memory_client()),
                    }
                }
            }
        };
        let _ = dial_and_wait_ready(Duration::from_secs(5), dial)
            .await
            .unwrap();
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn dial_and_wait_ready_returns_permanent_immediately() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let dial = {
            let attempts = Arc::clone(&attempts);
            move || {
                attempts.fetch_add(1, Ordering::SeqCst);
                async move { Err::<Client, _>(io_err(io::ErrorKind::PermissionDenied, "denied")) }
            }
        };
        let err = dial_and_wait_ready(Duration::from_secs(5), dial)
            .await
            .unwrap_err();
        assert!(
            matches!(err, WaitError::Permanent(_)),
            "expected Permanent, got {err:?}"
        );
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn dial_and_wait_ready_times_out() {
        let always_refused =
            || async { Err::<Client, _>(io_err(io::ErrorKind::ConnectionRefused, "refused")) };
        let err = dial_and_wait_ready(Duration::from_millis(250), always_refused)
            .await
            .unwrap_err();
        assert!(matches!(err, WaitError::Timeout(_)));
    }
}