Skip to main content

agent_can/
ipc.rs

1use std::path::Path;
2
3use interprocess::local_socket::tokio::{Listener as TokioListener, Stream as TokioStream};
4use interprocess::local_socket::traits::tokio::{Listener as _, Stream as _};
5use interprocess::local_socket::{GenericFilePath, ListenerOptions, ToFsName};
6use tokio::io::{AsyncRead, AsyncWrite};
7
// Only Unix (filesystem sockets) and Windows (named pipes) backends exist in
// this module; fail the build early with a clear message on any other target.
#[cfg(not(any(windows, unix)))]
compile_error!("ipc is only implemented for Unix and Windows targets");
10
11pub trait LocalStream: AsyncRead + AsyncWrite + Send + Unpin {}
12
13impl<T> LocalStream for T where T: AsyncRead + AsyncWrite + Send + Unpin {}
14
15pub type BoxedLocalStream = Box<dyn LocalStream>;
16
17pub struct LocalListener {
18    inner: TokioListener,
19}
20
21impl LocalListener {
22    pub fn bind(endpoint: &Path) -> std::io::Result<Self> {
23        let name = endpoint_name(endpoint)?;
24        let inner = ListenerOptions::new()
25            .name(name)
26            .try_overwrite(false)
27            .create_tokio()?;
28        Ok(Self { inner })
29    }
30
31    pub async fn accept(&mut self) -> std::io::Result<BoxedLocalStream> {
32        let stream = self.inner.accept().await?;
33        Ok(Box::new(stream))
34    }
35}
36
37pub async fn bind_listener(endpoint: &Path) -> std::io::Result<LocalListener> {
38    match LocalListener::bind(endpoint) {
39        Ok(listener) => Ok(listener),
40        Err(err) if is_bind_conflict(&err) && connect(endpoint).await.is_err() => {
41            cleanup_endpoint(endpoint);
42            LocalListener::bind(endpoint)
43        }
44        Err(err) => Err(err),
45    }
46}
47
48pub async fn connect(endpoint: &Path) -> std::io::Result<BoxedLocalStream> {
49    let name = endpoint_name(endpoint)?;
50    Ok(Box::new(TokioStream::connect(name).await?))
51}
52
/// Best-effort removal of the file at `endpoint`.
///
/// On Unix this path is the socket itself. On Windows the pipe is addressed by
/// a name derived from the path, so the file here is the marker written by
/// `create_endpoint_marker`. All errors are deliberately ignored: a missing
/// file is already the desired end state, and callers cannot act on other
/// failures here.
pub fn cleanup_endpoint(endpoint: &Path) {
    // Remove directly instead of checking `exists()` first. The pre-check is
    // racy (the file can change in between) and follows symlinks, so a dangling
    // symlink at `endpoint` used to be skipped and never cleaned up.
    let _ = std::fs::remove_file(endpoint);
}
58
/// Ensures a file exists at `endpoint` on platforms where binding does not create one.
///
/// On Windows the listener is a named pipe addressed by a derived name, so an
/// empty file is written at `endpoint` (truncating any existing file). On Unix
/// the socket is bound at the path itself, so this is a no-op.
///
/// NOTE(review): presumably this lets callers probe `endpoint` on disk
/// uniformly across platforms — confirm against callers.
///
/// # Errors
/// On Windows, returns any error from creating the marker file.
pub fn create_endpoint_marker(endpoint: &Path) -> std::io::Result<()> {
    #[cfg(windows)]
    {
        // Creating (or truncating) the file is all that is needed; it is closed on drop.
        std::fs::File::create(endpoint).map(|_file| ())
    }

    #[cfg(unix)]
    {
        let _ = endpoint;
        Ok(())
    }
}
71
/// Returns `true` when a bind error means the endpoint is already taken.
fn is_bind_conflict(err: &std::io::Error) -> bool {
    const CONFLICT_KINDS: [std::io::ErrorKind; 2] = [
        std::io::ErrorKind::AddrInUse,
        std::io::ErrorKind::AlreadyExists,
    ];
    CONFLICT_KINDS.contains(&err.kind())
}
78
79fn endpoint_name(endpoint: &Path) -> std::io::Result<interprocess::local_socket::Name<'_>> {
80    #[cfg(unix)]
81    {
82        endpoint.to_fs_name::<GenericFilePath>()
83    }
84
85    #[cfg(windows)]
86    {
87        pipe_name(endpoint).to_fs_name::<GenericFilePath>()
88    }
89}
90
/// Derives a Windows named-pipe path for `endpoint`.
///
/// The shape is `{WINDOWS_PIPE_NAME_PREFIX}{stem}-{hash:016x}`:
/// - `stem` is the file stem with every character outside `[A-Za-z0-9_-]`
///   replaced by `_`, truncated to `WINDOWS_PIPE_STEM_MAX_LEN`. A stem that is
///   missing or not valid UTF-8 becomes `endpoint`.
/// - `hash` is the FNV-1a hash of the whole (lossily decoded) path, so
///   endpoints that share a stem in different directories get distinct pipes.
#[cfg(windows)]
fn pipe_name(endpoint: &Path) -> String {
    let hash = stable_hash64(&endpoint.to_string_lossy());
    let stem = endpoint
        .file_stem()
        .and_then(std::ffi::OsStr::to_str)
        .unwrap_or("endpoint");

    let mut name = String::from(WINDOWS_PIPE_NAME_PREFIX);
    // Every emitted character is ASCII, so the char limit is also a byte limit.
    name.extend(stem.chars().take(WINDOWS_PIPE_STEM_MAX_LEN).map(|ch| {
        if ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_') {
            ch
        } else {
            '_'
        }
    }));
    let suffix = format!("-{hash:016x}");
    name.push_str(&suffix);
    name
}
112
/// Namespace prefix shared by every pipe this module creates.
#[cfg(windows)]
const WINDOWS_PIPE_NAME_PREFIX: &str = r"\\.\pipe\agent-can-";
/// Upper bound on the full pipe name, prefix included (Windows allows 256 characters).
#[cfg(windows)]
const WINDOWS_PIPE_NAME_MAX_LEN: usize = 256;
/// Length of the `-{hash:016x}` suffix: one separator plus two hex digits per byte of a `u64`.
#[cfg(windows)]
const WINDOWS_PIPE_HASH_SUFFIX_LEN: usize = "-".len() + 2 * std::mem::size_of::<u64>();
/// Room left for the sanitized file stem once prefix and suffix are accounted for.
#[cfg(windows)]
const WINDOWS_PIPE_STEM_MAX_LEN: usize =
    WINDOWS_PIPE_NAME_MAX_LEN - (WINDOWS_PIPE_NAME_PREFIX.len() + WINDOWS_PIPE_HASH_SUFFIX_LEN);
122
/// 64-bit FNV-1a hash of `raw`'s UTF-8 bytes.
///
/// A fixed, published algorithm is used instead of `DefaultHasher` (whose
/// algorithm is unspecified) so pipe names stay stable across Rust upgrades.
#[cfg(windows)]
fn stable_hash64(raw: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    raw.bytes().fold(FNV_OFFSET_BASIS, |hash, byte| {
        (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
136
#[cfg(test)]
mod tests {
    use super::{cleanup_endpoint, is_bind_conflict};
    use std::io::{Error, ErrorKind};

    #[cfg(windows)]
    use super::{WINDOWS_PIPE_NAME_MAX_LEN, WINDOWS_PIPE_STEM_MAX_LEN, pipe_name, stable_hash64};
    #[cfg(windows)]
    use std::path::Path;

    // Platform-independent tests: previously every test here was Windows-only,
    // leaving Unix builds with no coverage at all.

    #[test]
    fn bind_conflict_kinds_are_recognized() {
        assert!(is_bind_conflict(&Error::from(ErrorKind::AddrInUse)));
        assert!(is_bind_conflict(&Error::from(ErrorKind::AlreadyExists)));
        assert!(!is_bind_conflict(&Error::from(ErrorKind::NotFound)));
        assert!(!is_bind_conflict(&Error::from(ErrorKind::PermissionDenied)));
    }

    #[test]
    fn cleanup_endpoint_removes_file_and_ignores_missing() {
        // Process id keeps concurrent test runs from sharing the scratch file.
        let path = std::env::temp_dir().join(format!(
            "agent-can-ipc-cleanup-test-{}",
            std::process::id()
        ));
        std::fs::write(&path, b"").expect("create scratch file");
        cleanup_endpoint(&path);
        assert!(!path.exists());
        // A second call on the now-missing path must be a silent no-op.
        cleanup_endpoint(&path);
        assert!(!path.exists());
    }

    #[cfg(windows)]
    #[test]
    fn stable_hash64_matches_known_fnv1a_values() {
        assert_eq!(stable_hash64(""), 0xcbf29ce484222325);
        assert_eq!(stable_hash64("agent-can"), 0xdfa658c032d0c9a1);
        assert_eq!(
            stable_hash64(r"C:/Users/alice/.agent-can/demo.sock"),
            0x9f1bb071fc9f9589
        );
    }

    #[cfg(windows)]
    #[test]
    fn pipe_name_uses_stable_hash_suffix() {
        let endpoint = Path::new(r"C:/Users/alice/.agent-can/demo.sock");
        assert_eq!(
            pipe_name(endpoint),
            r"\\.\pipe\agent-can-demo-9f1bb071fc9f9589"
        );
    }

    #[cfg(windows)]
    #[test]
    fn pipe_name_stem_length_is_bounded() {
        let long_stem = "a".repeat(WINDOWS_PIPE_STEM_MAX_LEN * 2);
        let endpoint = format!(r"C:/Users/alice/.agent-can/{long_stem}.sock");
        let name = pipe_name(Path::new(&endpoint));
        assert!(
            name.len() <= WINDOWS_PIPE_NAME_MAX_LEN,
            "pipe name too long: {}",
            name.len()
        );
    }
}