inferd_daemon/
endpoint.rs1use std::io;
19use std::net::SocketAddr;
20use std::path::Path;
21use tokio::io::{AsyncRead, AsyncWrite};
22use tokio::net::{TcpListener, TcpStream};
23
/// Default TCP listen address: IPv4 loopback on port 47321.
///
/// Passed to [`bind_tcp`], which requires a literal `ip:port` (no hostnames).
pub const DEFAULT_TCP_ADDR: &str = "127.0.0.1:47321";
26
27pub fn default_admin_addr() -> std::path::PathBuf {
42 #[cfg(target_os = "linux")]
43 {
44 linux_runtime_path("admin.sock")
45 }
46 #[cfg(target_os = "macos")]
47 {
48 let mut p = std::env::temp_dir();
49 p.push("inferd");
50 p.push("admin.sock");
51 p
52 }
53 #[cfg(windows)]
54 {
55 std::path::PathBuf::from(DEFAULT_ADMIN_PIPE_PATH)
56 }
57 #[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
58 {
59 std::path::PathBuf::from("/tmp/inferd/admin.sock")
60 }
61}
62
63pub fn default_v2_addr() -> std::path::PathBuf {
70 #[cfg(target_os = "linux")]
71 {
72 linux_runtime_path("infer.v2.sock")
73 }
74 #[cfg(target_os = "macos")]
75 {
76 let mut p = std::env::temp_dir();
77 p.push("inferd");
78 p.push("infer.v2.sock");
79 p
80 }
81 #[cfg(windows)]
82 {
83 std::path::PathBuf::from(DEFAULT_PIPE_V2_PATH)
84 }
85 #[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
86 {
87 std::path::PathBuf::from("/tmp/inferd/infer.v2.sock")
88 }
89}
90
91pub fn default_embed_addr() -> std::path::PathBuf {
98 #[cfg(target_os = "linux")]
99 {
100 linux_runtime_path("infer.embed.sock")
101 }
102 #[cfg(target_os = "macos")]
103 {
104 let mut p = std::env::temp_dir();
105 p.push("inferd");
106 p.push("infer.embed.sock");
107 p
108 }
109 #[cfg(windows)]
110 {
111 std::path::PathBuf::from(DEFAULT_PIPE_EMBED_PATH)
112 }
113 #[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
114 {
115 std::path::PathBuf::from("/tmp/inferd/infer.embed.sock")
116 }
117}
118
119#[cfg(target_os = "linux")]
123pub fn linux_runtime_path(leaf: &str) -> std::path::PathBuf {
124 if let Some(xdg) = std::env::var_os("XDG_RUNTIME_DIR") {
125 let mut p = std::path::PathBuf::from(xdg);
126 if !p.as_os_str().is_empty() {
127 p.push("inferd");
128 p.push(leaf);
129 return p;
130 }
131 }
132 if let Some(home) = std::env::var_os("HOME") {
133 let mut p = std::path::PathBuf::from(home);
134 if !p.as_os_str().is_empty() {
135 p.push(".inferd");
136 p.push("run");
137 p.push(leaf);
138 return p;
139 }
140 }
141 let uid = nix::unistd::Uid::current().as_raw();
144 std::path::PathBuf::from(format!("/tmp/inferd-{uid}/{leaf}"))
145}
146
147pub trait Connection: AsyncRead + AsyncWrite + Unpin + Send {
150 fn transport(&self) -> &'static str;
153}
154
155impl Connection for TcpStream {
156 fn transport(&self) -> &'static str {
157 "tcp"
158 }
159}
160
161#[cfg(unix)]
162impl Connection for tokio::net::UnixStream {
163 fn transport(&self) -> &'static str {
164 "unix"
165 }
166}
167
168pub async fn bind_tcp(addr: &str) -> io::Result<TcpListener> {
176 let parsed: SocketAddr = addr
177 .parse()
178 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("bad tcp addr: {e}")))?;
179 TcpListener::bind(parsed).await
180}
181
182#[cfg(unix)]
188pub async fn bind_uds(path: &Path, group: Option<&str>) -> io::Result<tokio::net::UnixListener> {
189 use std::os::unix::fs::PermissionsExt;
190 if let Ok(meta) = std::fs::symlink_metadata(path) {
194 if meta.file_type().is_symlink() {
195 return Err(io::Error::new(
196 io::ErrorKind::InvalidInput,
197 format!("uds path is a symlink (refused): {}", path.display()),
198 ));
199 }
200 std::fs::remove_file(path)?;
201 }
202 let listener = tokio::net::UnixListener::bind(path)?;
203 let mut perms = std::fs::metadata(path)?.permissions();
204 perms.set_mode(0o660);
205 std::fs::set_permissions(path, perms)?;
206
207 if let Some(group_name) = group {
208 chown_to_group(path, group_name)?;
209 }
210 Ok(listener)
211}
212
213#[cfg(unix)]
218pub async fn bind_admin_uds(path: &Path) -> io::Result<tokio::net::UnixListener> {
219 use std::os::unix::fs::PermissionsExt;
220 if let Ok(meta) = std::fs::symlink_metadata(path) {
221 if meta.file_type().is_symlink() {
222 return Err(io::Error::new(
223 io::ErrorKind::InvalidInput,
224 format!("admin uds path is a symlink (refused): {}", path.display()),
225 ));
226 }
227 std::fs::remove_file(path)?;
228 }
229 if let Some(parent) = path.parent()
230 && !parent.as_os_str().is_empty()
231 {
232 std::fs::create_dir_all(parent)?;
233 }
234 let listener = tokio::net::UnixListener::bind(path)?;
235 let mut perms = std::fs::metadata(path)?.permissions();
236 perms.set_mode(0o600);
237 std::fs::set_permissions(path, perms)?;
238 Ok(listener)
239}
240
#[cfg(not(unix))]
/// Stub for platforms without Unix domain sockets. It always fails with
/// `ErrorKind::Unsupported`; use `bind_admin_pipe` instead.
pub async fn bind_admin_uds(_path: &Path) -> io::Result<()> {
    const REASON: &str =
        "Unix domain sockets are not supported on this platform; use bind_admin_pipe";
    Err(io::Error::new(io::ErrorKind::Unsupported, REASON))
}
250
#[cfg(windows)]
/// Create one instance of the admin named pipe at `path`.
///
/// `first` requests `first_pipe_instance`, so creation fails if another
/// process already owns a pipe with this name.
///
/// The admin pipe uses exactly the same current-user-only security descriptor
/// as the inference pipes. It therefore delegates to [`bind_named_pipe`]
/// instead of duplicating that function's raw, `unsafe` security-attributes call.
///
/// # Errors
///
/// Any error from building the security descriptor or creating the pipe.
pub fn bind_admin_pipe(
    path: &str,
    first: bool,
) -> io::Result<tokio::net::windows::named_pipe::NamedPipeServer> {
    bind_named_pipe(path, first)
}
275
#[cfg(not(unix))]
/// Stub for platforms without Unix domain sockets. It always fails with
/// `ErrorKind::Unsupported`; use `bind_named_pipe` or TCP instead.
pub async fn bind_uds(_path: &Path, _group: Option<&str>) -> io::Result<()> {
    const REASON: &str =
        "Unix domain sockets are not supported on this platform; use bind_named_pipe or TCP";
    Err(io::Error::new(io::ErrorKind::Unsupported, REASON))
}
285
/// Default named-pipe path for the inference endpoint (`\\.\pipe\inferd-infer`).
#[cfg(windows)]
pub const DEFAULT_PIPE_PATH: &str = concat!(r"\\.\pipe\", "inferd-infer");

/// Default named-pipe path for the v2 inference endpoint (`\\.\pipe\inferd-infer-v2`).
#[cfg(windows)]
pub const DEFAULT_PIPE_V2_PATH: &str = concat!(r"\\.\pipe\", "inferd-infer-v2");

/// Default named-pipe path for the admin endpoint (`\\.\pipe\inferd-admin`).
#[cfg(windows)]
pub const DEFAULT_ADMIN_PIPE_PATH: &str = concat!(r"\\.\pipe\", "inferd-admin");

/// Default named-pipe path for the embedding endpoint (`\\.\pipe\inferd-infer-embed`).
#[cfg(windows)]
pub const DEFAULT_PIPE_EMBED_PATH: &str = concat!(r"\\.\pipe\", "inferd-infer-embed");
308
#[cfg(windows)]
/// Create one server instance of the named pipe at `path`.
///
/// The pipe is created with the security descriptor from
/// `PipeSecurityDescriptor::current_user_only()`. `first` is forwarded to
/// `ServerOptions::first_pipe_instance`.
///
/// # Errors
///
/// Any error from building the security descriptor or creating the pipe.
// `unsafe_code` is allowed here only for the raw security-attributes call below.
#[allow(unsafe_code)] pub fn bind_named_pipe(
    path: &str,
    first: bool,
) -> io::Result<tokio::net::windows::named_pipe::NamedPipeServer> {
    use crate::windows_security::PipeSecurityDescriptor;
    use tokio::net::windows::named_pipe::ServerOptions;

    let mut sd = PipeSecurityDescriptor::current_user_only()?;
    let mut opts = ServerOptions::new();
    opts.first_pipe_instance(first);
    // SAFETY (review: confirm against PipeSecurityDescriptor): the pointer from
    // `sd.as_attrs_ptr()` borrows data owned by `sd`. `sd` is kept alive until
    // the explicit `drop(sd)` below, so it outlives this call.
    let server = unsafe { opts.create_with_security_attributes_raw(path, sd.as_attrs_ptr()) }?;
    // Dropped explicitly, only after the pipe has been created, so the
    // descriptor is not released while the raw pointer is in use.
    drop(sd);
    Ok(server)
}
348
/// Windows named-pipe server streams identify themselves as `"pipe"`.
#[cfg(windows)]
impl Connection for tokio::net::windows::named_pipe::NamedPipeServer {
    fn transport(&self) -> &'static str {
        "pipe"
    }
}
355
356#[cfg(unix)]
357fn chown_to_group(path: &Path, group_name: &str) -> io::Result<()> {
358 let group = nix::unistd::Group::from_name(group_name)
359 .map_err(|e| io::Error::other(format!("getgrnam: {e}")))?
360 .ok_or_else(|| {
361 io::Error::new(
362 io::ErrorKind::NotFound,
363 format!("group not found: {group_name}"),
364 )
365 })?;
366 nix::unistd::chown(path, None, Some(group.gid))
367 .map_err(|e| io::Error::other(format!("chown: {e}")))
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    #[tokio::test]
    async fn bind_tcp_accepts_a_connection() {
        let listener = bind_tcp("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut sock, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 4];
            sock.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"ping");
            sock.write_all(b"pong").await.unwrap();
        });

        let mut client = TcpStream::connect(addr).await.unwrap();
        client.write_all(b"ping").await.unwrap();
        let mut buf = [0u8; 4];
        client.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"pong");
        server.await.unwrap();
    }

    #[tokio::test]
    async fn bind_tcp_rejects_garbage_addr() {
        let err = bind_tcp("not-an-addr").await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn bind_uds_creates_socket_and_accepts() {
        use std::os::unix::fs::PermissionsExt;
        use tempfile::tempdir;
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.sock");
        let listener = bind_uds(&path, None).await.unwrap();

        // Group-accessible, closed to others.
        let mode = std::fs::metadata(&path).unwrap().permissions().mode();
        assert_eq!(mode & 0o777, 0o660);

        let server = tokio::spawn(async move {
            let (mut sock, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 4];
            sock.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"ping");
        });

        let mut client = tokio::net::UnixStream::connect(&path).await.unwrap();
        client.write_all(b"ping").await.unwrap();
        server.await.unwrap();
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn bind_uds_replaces_stale_socket() {
        use tempfile::tempdir;
        let dir = tempdir().unwrap();
        let path = dir.path().join("stale.sock");

        // Dropping a listener does not unlink its socket file, so the second
        // bind has to clean up the leftover socket.
        drop(bind_uds(&path, None).await.unwrap());
        assert!(std::fs::symlink_metadata(&path).is_ok());

        let listener = bind_uds(&path, None).await.unwrap();
        let server = tokio::spawn(async move {
            let _ = listener.accept().await.unwrap();
        });
        let _client = tokio::net::UnixStream::connect(&path).await.unwrap();
        server.await.unwrap();
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn bind_admin_uds_creates_parent_and_is_owner_only() {
        use std::os::unix::fs::PermissionsExt;
        use tempfile::tempdir;
        let dir = tempdir().unwrap();
        let path = dir.path().join("nested").join("admin.sock");

        let _listener = bind_admin_uds(&path).await.unwrap();

        let mode = std::fs::metadata(&path).unwrap().permissions().mode();
        assert_eq!(mode & 0o777, 0o600);
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn bind_admin_uds_refuses_symlink_path() {
        use tempfile::tempdir;
        let dir = tempdir().unwrap();
        let target = dir.path().join("real.sock");
        std::fs::write(&target, b"").unwrap();
        let symlink = dir.path().join("link.sock");
        std::os::unix::fs::symlink(&target, &symlink).unwrap();

        let err = bind_admin_uds(&symlink).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[cfg(windows)]
    #[tokio::test]
    async fn bind_named_pipe_accepts_a_connection() {
        use std::sync::atomic::{AtomicU64, Ordering};
        use tokio::net::windows::named_pipe::ClientOptions;

        // Pipe names are machine-global; combine pid, time and a per-process
        // counter so concurrent or repeated test runs never collide.
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let pid = std::process::id();
        let n = COUNTER.fetch_add(1, Ordering::Relaxed);
        let ts = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let path = format!(r"\\.\pipe\inferd-endpoint-test-{pid}-{ts}-{n}");

        let server = bind_named_pipe(&path, true).expect("bind named pipe");

        let server_task = tokio::spawn(async move {
            server.connect().await.expect("server connect");
            let mut s = server;
            let mut buf = [0u8; 4];
            s.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"ping");
            s.write_all(b"pong").await.unwrap();
        });

        let mut client = ClientOptions::new()
            .open(&path)
            .expect("client open named pipe");
        client.write_all(b"ping").await.unwrap();
        let mut buf = [0u8; 4];
        client.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"pong");
        server_task.await.unwrap();
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn bind_uds_refuses_symlink_path() {
        use tempfile::tempdir;
        let dir = tempdir().unwrap();
        let target = dir.path().join("real.sock");
        std::fs::write(&target, b"").unwrap();
        let symlink = dir.path().join("link.sock");
        std::os::unix::fs::symlink(&target, &symlink).unwrap();

        let err = bind_uds(&symlink, None).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }
}