1#[cfg(test)]
5use crate::client::Client;
6use crate::client::ClientError;
7use std::future::Future;
8use std::path::PathBuf;
9use std::time::Duration;
10use tokio::time::Instant;
11
12#[derive(Debug, thiserror::Error)]
14pub enum WaitError {
15 #[error("timed out after {0:?} waiting for inferd to become ready")]
17 Timeout(Duration),
18 #[error("permanent connect error: {0}")]
21 Permanent(ClientError),
22}
23
24pub async fn dial_and_wait_ready<C, F, Fut>(
51 timeout: Duration,
52 mut dial_fn: F,
53) -> Result<C, WaitError>
54where
55 F: FnMut() -> Fut,
56 Fut: Future<Output = Result<C, ClientError>>,
57{
58 let deadline = Instant::now() + timeout;
59 let mut delay = Duration::from_millis(100);
60 let max_delay = Duration::from_secs(5);
61
62 loop {
63 match dial_fn().await {
64 Ok(c) => return Ok(c),
65 Err(e) if !is_transient_dial_error(&e) => {
66 return Err(WaitError::Permanent(e));
67 }
68 Err(_) => {
69 if Instant::now() >= deadline {
70 return Err(WaitError::Timeout(timeout));
71 }
72 tokio::time::sleep(delay).await;
73 delay = (delay * 2).min(max_delay);
74 }
75 }
76 }
77}
78
79pub fn is_transient_dial_error(err: &ClientError) -> bool {
85 let ClientError::Io(io_err) = err else {
86 return false;
87 };
88 use std::io::ErrorKind;
89 matches!(
90 io_err.kind(),
91 ErrorKind::ConnectionRefused
92 | ErrorKind::NotFound
93 | ErrorKind::TimedOut
94 | ErrorKind::AddrNotAvailable
95 ) || {
96 let msg = io_err.to_string().to_ascii_lowercase();
99 msg.contains("all pipe instances are busy")
100 || msg.contains("the system cannot find")
101 || msg.contains("target machine actively refused")
102 }
103}
104
/// Returns the platform-default address of the inferd admin endpoint.
///
/// * Linux: `$XDG_RUNTIME_DIR/inferd/admin.sock` if `XDG_RUNTIME_DIR` holds an
///   absolute path. Otherwise `$HOME/.inferd/run/admin.sock` if `HOME` is set
///   and non-empty. Otherwise `/tmp/inferd/admin.sock`.
/// * macOS: `inferd/admin.sock` inside [`std::env::temp_dir`].
/// * Windows: the named pipe `\\.\pipe\inferd-admin`.
/// * Any other target: `/tmp/inferd/admin.sock`.
pub fn default_admin_addr() -> PathBuf {
    #[cfg(target_os = "linux")]
    {
        // The XDG Base Directory spec requires these variables to hold
        // absolute paths and says relative values must be ignored as invalid.
        // `is_absolute` also rejects an empty value.
        if let Some(xdg) = std::env::var_os("XDG_RUNTIME_DIR") {
            let mut p = PathBuf::from(xdg);
            if p.is_absolute() {
                p.push("inferd");
                p.push("admin.sock");
                return p;
            }
        }
        if let Some(home) = std::env::var_os("HOME") {
            let mut p = PathBuf::from(home);
            if !p.as_os_str().is_empty() {
                p.push(".inferd");
                p.push("run");
                p.push("admin.sock");
                return p;
            }
        }
        PathBuf::from("/tmp/inferd/admin.sock")
    }
    #[cfg(target_os = "macos")]
    {
        let mut p = std::env::temp_dir();
        p.push("inferd");
        p.push("admin.sock");
        p
    }
    #[cfg(windows)]
    {
        PathBuf::from(r"\\.\pipe\inferd-admin")
    }
    #[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
    {
        PathBuf::from("/tmp/inferd/admin.sock")
    }
}
151
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Shorthand for a `ClientError::Io` with the given kind and message.
    fn io_err(kind: io::ErrorKind, msg: &str) -> ClientError {
        ClientError::Io(io::Error::new(kind, msg))
    }

    /// A `Client` wrapped around one half of an in-memory duplex pipe. The
    /// other half is dropped when this function returns; the tests below only
    /// need a value to hand back, never a working connection.
    fn in_memory_client() -> Client {
        let (near, _far) = tokio::io::duplex(64);
        let (reader, writer) = tokio::io::split(near);
        Client::wrap_for_test(Box::new(reader), Box::new(writer))
    }

    // ---- is_transient_dial_error -----------------------------------------

    #[test]
    fn refused_is_transient() {
        let err = io_err(io::ErrorKind::ConnectionRefused, "refused");
        assert!(is_transient_dial_error(&err));
    }

    #[test]
    fn notfound_is_transient() {
        let err = io_err(io::ErrorKind::NotFound, "no such file");
        assert!(is_transient_dial_error(&err));
    }

    #[test]
    fn permission_denied_is_permanent() {
        let err = io_err(io::ErrorKind::PermissionDenied, "denied");
        assert!(!is_transient_dial_error(&err));
    }

    #[test]
    fn pipe_busy_message_recognised_as_transient() {
        // The kind is `Other`, so only the message text can mark it transient.
        let err = io_err(io::ErrorKind::Other, "All pipe instances are busy.");
        assert!(is_transient_dial_error(&err));
    }

    #[test]
    fn decode_error_is_permanent() {
        let json_err = serde_json::from_str::<u32>("not a number").unwrap_err();
        let err = ClientError::Decode(json_err);
        assert!(!is_transient_dial_error(&err));
    }

    // ---- dial_and_wait_ready ----------------------------------------------

    #[tokio::test]
    async fn dial_and_wait_ready_succeeds_first_try() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let dial = {
            let attempts = Arc::clone(&attempts);
            move || {
                attempts.fetch_add(1, Ordering::SeqCst);
                async { Ok(in_memory_client()) }
            }
        };

        let _ = dial_and_wait_ready(Duration::from_secs(1), dial).await.unwrap();
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn dial_and_wait_ready_retries_transient() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let dial = {
            let attempts = Arc::clone(&attempts);
            move || {
                // Attempts #0 and #1 are refused; attempt #2 connects.
                let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                async move {
                    if attempt >= 2 {
                        Ok(in_memory_client())
                    } else {
                        Err(io_err(io::ErrorKind::ConnectionRefused, "refused"))
                    }
                }
            }
        };

        let _ = dial_and_wait_ready(Duration::from_secs(5), dial).await.unwrap();
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn dial_and_wait_ready_returns_permanent_immediately() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let dial = {
            let attempts = Arc::clone(&attempts);
            move || {
                attempts.fetch_add(1, Ordering::SeqCst);
                async { Err::<Client, _>(io_err(io::ErrorKind::PermissionDenied, "denied")) }
            }
        };

        let err = dial_and_wait_ready(Duration::from_secs(5), dial)
            .await
            .unwrap_err();
        assert!(
            matches!(err, WaitError::Permanent(_)),
            "expected Permanent, got {err:?}"
        );
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn dial_and_wait_ready_times_out() {
        let always_refused =
            || async { Err::<Client, _>(io_err(io::ErrorKind::ConnectionRefused, "refused")) };

        let err = dial_and_wait_ready(Duration::from_millis(250), always_refused)
            .await
            .unwrap_err();
        assert!(matches!(err, WaitError::Timeout(_)));
    }
}