1use crate::client::{Client, ClientError};
5use std::future::Future;
6use std::path::PathBuf;
7use std::time::Duration;
8use tokio::time::Instant;
9
/// Failure modes of [`dial_and_wait_ready`].
#[derive(Debug, thiserror::Error)]
pub enum WaitError {
    /// Every dial attempt failed with a transient error until the deadline
    /// passed. Carries the timeout that the caller originally requested.
    #[error("timed out after {0:?} waiting for inferd to become ready")]
    Timeout(Duration),
    /// A dial attempt failed with an error that [`is_transient_dial_error`]
    /// does not classify as transient, so retrying was abandoned. Carries the
    /// error returned by that attempt.
    #[error("permanent connect error: {0}")]
    Permanent(ClientError),
}
21
22pub async fn dial_and_wait_ready<F, Fut>(
49 timeout: Duration,
50 mut dial_fn: F,
51) -> Result<Client, WaitError>
52where
53 F: FnMut() -> Fut,
54 Fut: Future<Output = Result<Client, ClientError>>,
55{
56 let deadline = Instant::now() + timeout;
57 let mut delay = Duration::from_millis(100);
58 let max_delay = Duration::from_secs(5);
59
60 loop {
61 match dial_fn().await {
62 Ok(c) => return Ok(c),
63 Err(e) if !is_transient_dial_error(&e) => {
64 return Err(WaitError::Permanent(e));
65 }
66 Err(_) => {
67 if Instant::now() >= deadline {
68 return Err(WaitError::Timeout(timeout));
69 }
70 tokio::time::sleep(delay).await;
71 delay = (delay * 2).min(max_delay);
72 }
73 }
74 }
75}
76
77pub fn is_transient_dial_error(err: &ClientError) -> bool {
83 let ClientError::Io(io_err) = err else {
84 return false;
85 };
86 use std::io::ErrorKind;
87 matches!(
88 io_err.kind(),
89 ErrorKind::ConnectionRefused
90 | ErrorKind::NotFound
91 | ErrorKind::TimedOut
92 | ErrorKind::AddrNotAvailable
93 ) || {
94 let msg = io_err.to_string().to_ascii_lowercase();
97 msg.contains("all pipe instances are busy")
98 || msg.contains("the system cannot find")
99 || msg.contains("target machine actively refused")
100 }
101}
102
/// Returns the platform's default location of the inferd admin endpoint.
///
/// * Linux: `$XDG_RUNTIME_DIR/inferd/admin.sock` when that variable is set and
///   non-empty; otherwise `$HOME/.inferd/run/admin.sock` when `HOME` is set and
///   non-empty; otherwise `/tmp/inferd/admin.sock`.
/// * macOS: `inferd/admin.sock` under [`std::env::temp_dir`].
/// * Windows: the named pipe `\\.\pipe\inferd-admin`.
/// * Any other target: `/tmp/inferd/admin.sock`.
pub fn default_admin_addr() -> PathBuf {
    #[cfg(target_os = "linux")]
    {
        // Reads a directory from the environment; unset and empty both count
        // as absent.
        let dir_from_env = |key: &str| {
            std::env::var_os(key)
                .filter(|value| !value.is_empty())
                .map(PathBuf::from)
        };

        dir_from_env("XDG_RUNTIME_DIR")
            .map(|runtime_dir| runtime_dir.join("inferd").join("admin.sock"))
            .or_else(|| {
                dir_from_env("HOME")
                    .map(|home_dir| home_dir.join(".inferd").join("run").join("admin.sock"))
            })
            .unwrap_or_else(|| PathBuf::from("/tmp/inferd/admin.sock"))
    }
    #[cfg(target_os = "macos")]
    {
        std::env::temp_dir().join("inferd").join("admin.sock")
    }
    #[cfg(windows)]
    {
        PathBuf::from(r"\\.\pipe\inferd-admin")
    }
    #[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
    {
        PathBuf::from("/tmp/inferd/admin.sock")
    }
}
149
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;

    /// Wraps a fresh `io::Error` with the given kind and message in
    /// `ClientError::Io`.
    fn io_err(kind: io::ErrorKind, msg: &str) -> ClientError {
        ClientError::Io(io::Error::new(kind, msg))
    }

    /// Builds a `Client` over one end of an in-memory duplex stream; the peer
    /// end is dropped immediately because these tests never exchange data.
    fn fake_client() -> Client {
        let (local, _peer) = tokio::io::duplex(64);
        let (reader, writer) = tokio::io::split(local);
        Client::wrap_for_test(Box::new(reader), Box::new(writer))
    }

    // --- is_transient_dial_error -------------------------------------------

    #[test]
    fn refused_is_transient() {
        let err = io_err(io::ErrorKind::ConnectionRefused, "refused");
        assert!(is_transient_dial_error(&err));
    }

    #[test]
    fn notfound_is_transient() {
        let err = io_err(io::ErrorKind::NotFound, "no such file");
        assert!(is_transient_dial_error(&err));
    }

    #[test]
    fn permission_denied_is_permanent() {
        let err = io_err(io::ErrorKind::PermissionDenied, "denied");
        assert!(!is_transient_dial_error(&err));
    }

    #[test]
    fn pipe_busy_message_recognised_as_transient() {
        // The kind alone is not transient; only the message text qualifies.
        let err = io_err(io::ErrorKind::Other, "All pipe instances are busy.");
        assert!(is_transient_dial_error(&err));
    }

    #[test]
    fn decode_error_is_permanent() {
        let parse_failure = serde_json::from_str::<u32>("not a number").unwrap_err();
        assert!(!is_transient_dial_error(&ClientError::Decode(parse_failure)));
    }

    // --- dial_and_wait_ready -----------------------------------------------

    #[tokio::test]
    async fn dial_and_wait_ready_succeeds_first_try() {
        let mut attempts = 0usize;
        let dial = || {
            attempts += 1;
            async { Ok::<_, ClientError>(fake_client()) }
        };

        let _ = dial_and_wait_ready(Duration::from_secs(1), dial)
            .await
            .unwrap();

        assert_eq!(attempts, 1);
    }

    #[tokio::test]
    async fn dial_and_wait_ready_retries_transient() {
        let mut attempts = 0usize;
        let dial = || {
            // Zero-based index of this attempt; the first two are refused.
            let attempt = attempts;
            attempts += 1;
            async move {
                if attempt < 2 {
                    Err(io_err(io::ErrorKind::ConnectionRefused, "refused"))
                } else {
                    Ok(fake_client())
                }
            }
        };

        let _ = dial_and_wait_ready(Duration::from_secs(5), dial)
            .await
            .unwrap();

        assert_eq!(attempts, 3);
    }

    #[tokio::test]
    async fn dial_and_wait_ready_returns_permanent_immediately() {
        let mut attempts = 0usize;
        let dial = || {
            attempts += 1;
            async { Err::<Client, _>(io_err(io::ErrorKind::PermissionDenied, "denied")) }
        };

        let err = dial_and_wait_ready(Duration::from_secs(5), dial)
            .await
            .unwrap_err();

        assert!(
            matches!(err, WaitError::Permanent(_)),
            "expected Permanent, got {err:?}"
        );
        assert_eq!(attempts, 1);
    }

    #[tokio::test]
    async fn dial_and_wait_ready_times_out() {
        let always_refused = || async {
            Err::<Client, _>(io_err(io::ErrorKind::ConnectionRefused, "refused"))
        };

        let outcome = dial_and_wait_ready(Duration::from_millis(250), always_refused).await;

        match outcome.unwrap_err() {
            WaitError::Timeout(_) => {}
            other => panic!("expected Timeout, got {other:?}"),
        }
    }
}