// kpx/transport.rs

1use std::{
2    io::{self, Read, Write},
3    path::{Path, PathBuf},
4};
5
6use tracing::{info, instrument, trace};
7
8#[cfg(windows)]
9use std::ffi::OsString;
10
11#[cfg(windows)]
12use named_pipe::PipeClient;
13
14#[cfg(unix)]
15use std::os::unix::net::UnixStream;
16
/// Minimal synchronous channel for exchanging JSON messages with KeePassXC.
///
/// Each call moves exactly one whole message. Despite the method names, "line" here means
/// "one message": the text carries no delimiter.
pub trait Transport {
    /// Send one raw JSON message to KeePassXC.
    ///
    /// `line` is the bare JSON text, without delimiters.
    ///
    /// # Errors
    /// Propagates any I/O error raised by the underlying channel.
    fn send_line(&mut self, line: &str) -> io::Result<()>;

    /// Receive one response frame from KeePassXC.
    ///
    /// # Errors
    /// Propagates any I/O error raised by the underlying channel.
    fn read_line(&mut self) -> io::Result<String>;
}
24
/// A bidirectional byte stream that can be moved across threads.
///
/// Exists so `Read`, `Write` and `Send` can be combined into a single trait object.
trait ReadWrite: Read + Write + Send {}

// Blanket impl: any owned `Read + Write + Send` type is automatically a `ReadWrite`.
impl<S: Read + Write + Send + 'static> ReadWrite for S {}
28
29/// Native KeePassXC transport (Unix domain socket on Unix, named pipe on Windows).
30pub struct NativeTransport {
31    stream: Box<dyn ReadWrite>,
32}
33
34impl NativeTransport {
35    /// Wrap an arbitrary read/write stream as a native KeePassXC transport.
36    pub fn from_stream<T>(stream: T) -> Self
37    where
38        T: Read + Write + Send + 'static,
39    {
40        Self {
41            stream: Box::new(stream),
42        }
43    }
44
45    #[cfg(unix)]
46    #[instrument(level = "debug", skip_all, err)]
47    /// Connect to the KeePassXC Unix domain socket at a specific path.
48    pub fn connect_path(path: impl AsRef<Path>) -> io::Result<Self> {
49        let path_buf = path.as_ref().to_path_buf();
50        let stream = UnixStream::connect(&path_buf)?;
51        stream.set_nonblocking(false)?;
52        info!(path = %path_buf.display(), "connected to KeePassXC socket");
53        Ok(Self::from_stream(stream))
54    }
55
56    #[cfg(unix)]
57    #[instrument(level = "debug", skip_all, err)]
58    /// Attempt to connect to the first reachable KeePassXC Unix domain socket.
59    pub fn connect_default() -> io::Result<Self> {
60        let mut last_err = None;
61        for candidate in socket_candidates() {
62            trace!(path = %candidate.display(), "trying KeePassXC socket candidate");
63            match UnixStream::connect(&candidate) {
64                Ok(stream) => {
65                    stream.set_nonblocking(false)?;
66                    info!(path = %candidate.display(), "connected to KeePassXC socket");
67                    return Ok(Self::from_stream(stream));
68                }
69                Err(err) => last_err = Some(err),
70            }
71        }
72        Err(last_err.unwrap_or_else(|| {
73            io::Error::new(
74                io::ErrorKind::NotFound,
75                "Unable to locate KeePassXC browser socket",
76            )
77        }))
78    }
79
80    #[cfg(windows)]
81    #[instrument(level = "debug", skip_all, err)]
82    /// Connect to the KeePassXC named pipe using an explicit pipe name.
83    pub fn connect_pipe(name: impl Into<OsString>) -> io::Result<Self> {
84        let os_name: OsString = name.into();
85        let client = PipeClient::connect(&os_name)?;
86        info!(pipe = %os_name.to_string_lossy(), "connected to KeePassXC pipe");
87        Ok(Self::from_stream(client))
88    }
89
90    #[cfg(windows)]
91    #[instrument(level = "debug", skip_all, err)]
92    /// Connect to the KeePassXC named pipe using the default well-known name.
93    pub fn connect_default() -> io::Result<Self> {
94        let name = std::env::var("KEEPASSXC_PIPE")
95            .unwrap_or_else(|_| String::from(r"\\.\pipe\keepassxc-browser"));
96        Self::connect_pipe(name)
97    }
98
99    #[cfg(not(any(unix, windows)))]
100    /// Return an error on unsupported platforms.
101    pub fn connect_default() -> io::Result<Self> {
102        Err(io::Error::new(
103            io::ErrorKind::Unsupported,
104            "KeePassXC native transport not supported on this platform",
105        ))
106    }
107}
108
109impl Transport for NativeTransport {
110    fn send_line(&mut self, line: &str) -> io::Result<()> {
111        // KeePassXC expects raw JSON without newlines
112        self.stream.write_all(line.as_bytes())?;
113        self.stream.flush()
114    }
115
116    fn read_line(&mut self) -> io::Result<String> {
117        // KeePassXC sends complete JSON messages in chunks
118        // Read until we have a complete JSON object
119        const BUFFER_SIZE: usize = 1024 * 1024; // 1 MB as per native messaging spec
120        let mut buf = vec![0u8; BUFFER_SIZE];
121
122        let bytes_read = self.stream.read(&mut buf)?;
123        if bytes_read == 0 {
124            return Err(io::Error::new(
125                io::ErrorKind::UnexpectedEof,
126                "KeePassXC connection closed",
127            ));
128        }
129
130        // Trim the buffer to the actual bytes read
131        buf.truncate(bytes_read);
132
133        String::from_utf8(buf).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
134    }
135}
136
137#[cfg(unix)]
138fn socket_candidates() -> Vec<PathBuf> {
139    let mut result = Vec::new();
140
141    if let Ok(path) = std::env::var("KEEPASSXC_SOCKET") {
142        result.push(normalize_path(path.into()));
143    }
144
145    if let Some(home) = home_dir() {
146        result.push(home.join(".cache/keepassxc/keepassxc-browser.socket"));
147        result.push(home.join(".config/keepassxc/keepassxc-browser.socket"));
148    }
149
150    if let Ok(runtime_dir) = std::env::var("XDG_RUNTIME_DIR") {
151        let runtime_path = PathBuf::from(&runtime_dir);
152        result.push(runtime_path.join("keepassxc/keepassxc-browser.socket"));
153        result.push(
154            runtime_path.join("app/org.keepassxc.KeePassXC/org.keepassxc.KeePassXC.BrowserServer"),
155        );
156    }
157
158    #[cfg(target_os = "macos")]
159    if let Some(home) = home_dir() {
160        result.push(home.join("Library/Application Support/KeepassXC/keepassxc-browser.socket"));
161        result.push(home.join("Library/Caches/keepassxc/keepassxc-browser.socket"));
162    }
163
164    result
165}
166
/// Return the current user's home directory, taken from `$HOME`.
///
/// Returns `None` when `HOME` is unset, empty, or not an absolute path. Joining socket
/// candidates or `~` expansions onto such a value would silently resolve relative to the
/// current working directory.
#[cfg(unix)]
fn home_dir() -> Option<PathBuf> {
    std::env::var_os("HOME")
        .map(PathBuf::from)
        .filter(|home| home.is_absolute())
}
171
172#[cfg(unix)]
173fn normalize_path(path: PathBuf) -> PathBuf {
174    if let Some(str_path) = path.to_str() {
175        if str_path.starts_with('~') {
176            if let Some(home) = home_dir() {
177                let without_tilde = &str_path[1..];
178                return if without_tilde.starts_with('/') {
179                    home.join(&without_tilde[1..])
180                } else {
181                    home.join(without_tilde)
182                };
183            }
184        }
185    }
186    path
187}
188
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};

    /// In-memory [`Transport`] double for tests.
    ///
    /// Every message passed to `send_line` is recorded in `sent`. `read_line` is served
    /// from the front of `incoming`. Cloning shares both buffers (they sit behind `Arc`).
    #[derive(Clone, Default)]
    pub struct MockTransport {
        /// Messages written through `send_line`, oldest first.
        pub sent: Arc<Mutex<Vec<String>>>,
        /// Responses still to be returned by `read_line`, next one first.
        pub incoming: Arc<Mutex<VecDeque<String>>>,
    }

    impl MockTransport {
        /// Create a mock whose `read_line` yields `responses` in order.
        pub fn with_responses(responses: Vec<String>) -> Self {
            let queue = VecDeque::from(responses);
            Self {
                incoming: Arc::new(Mutex::new(queue)),
                ..Self::default()
            }
        }

        /// Queue one more response after the existing ones.
        pub fn push_response(&self, response: String) {
            let mut queue = self.incoming.lock().unwrap();
            queue.push_back(response);
        }
    }

    impl Transport for MockTransport {
        fn send_line(&mut self, line: &str) -> io::Result<()> {
            let mut log = self.sent.lock().unwrap();
            log.push(line.to_owned());
            Ok(())
        }

        fn read_line(&mut self) -> io::Result<String> {
            let next = self.incoming.lock().unwrap().pop_front();
            match next {
                Some(response) => Ok(response),
                None => Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "no response queued",
                )),
            }
        }
    }
}