use std::{
io::{self, Read, Write},
path::{Path, PathBuf},
};
use tracing::{info, instrument, trace};
#[cfg(windows)]
use std::ffi::OsString;
#[cfg(windows)]
use named_pipe::PipeClient;
#[cfg(unix)]
use std::os::unix::net::UnixStream;
/// Bidirectional text channel used to exchange messages with a peer.
///
/// The trait itself does not define framing. Both implementations in this
/// file pass the text through unchanged: nothing appends a line terminator on
/// send and nothing strips one on receive.
pub trait Transport {
/// Sends `line` to the peer, returning any I/O error from the underlying channel.
fn send_line(&mut self, line: &str) -> io::Result<()>;
/// Returns the next message from the peer as an owned `String`, or an I/O
/// error if none can be obtained.
fn read_line(&mut self) -> io::Result<String>;
}
/// Private marker trait bundling `Read + Write + Send` under one name so a
/// stream can be stored as a single trait object (`Box<dyn ReadWrite>`).
trait ReadWrite: Read + Write + Send {}
// Blanket impl: every owned (`'static`) type that is `Read + Write + Send`
// automatically implements `ReadWrite`; nothing implements it by hand.
impl<T> ReadWrite for T where T: Read + Write + Send + 'static {}
/// [`Transport`] implementation backed by a boxed byte stream.
///
/// The constructors in this file wrap a Unix domain socket (`cfg(unix)`), a
/// Windows named pipe (`cfg(windows)`), or any caller-supplied stream via
/// `from_stream`.
pub struct NativeTransport {
// Type-erased stream; all reads and writes go through this box.
stream: Box<dyn ReadWrite>,
}
impl NativeTransport {
pub fn from_stream<T>(stream: T) -> Self
where
T: Read + Write + Send + 'static,
{
Self {
stream: Box::new(stream),
}
}
#[cfg(unix)]
#[instrument(level = "debug", skip_all, err)]
pub fn connect_path(path: impl AsRef<Path>) -> io::Result<Self> {
let path_buf = path.as_ref().to_path_buf();
let stream = UnixStream::connect(&path_buf)?;
stream.set_nonblocking(false)?;
info!(path = %path_buf.display(), "connected to KeePassXC socket");
Ok(Self::from_stream(stream))
}
#[cfg(unix)]
#[instrument(level = "debug", skip_all, err)]
pub fn connect_default() -> io::Result<Self> {
let mut last_err = None;
for candidate in socket_candidates() {
trace!(path = %candidate.display(), "trying KeePassXC socket candidate");
match UnixStream::connect(&candidate) {
Ok(stream) => {
stream.set_nonblocking(false)?;
info!(path = %candidate.display(), "connected to KeePassXC socket");
return Ok(Self::from_stream(stream));
}
Err(err) => last_err = Some(err),
}
}
Err(last_err.unwrap_or_else(|| {
io::Error::new(
io::ErrorKind::NotFound,
"Unable to locate KeePassXC browser socket",
)
}))
}
#[cfg(windows)]
#[instrument(level = "debug", skip_all, err)]
pub fn connect_pipe(name: impl Into<OsString>) -> io::Result<Self> {
let os_name: OsString = name.into();
let client = PipeClient::connect(&os_name)?;
info!(pipe = %os_name.to_string_lossy(), "connected to KeePassXC pipe");
Ok(Self::from_stream(client))
}
#[cfg(windows)]
#[instrument(level = "debug", skip_all, err)]
pub fn connect_default() -> io::Result<Self> {
let name = std::env::var("KEEPASSXC_PIPE")
.unwrap_or_else(|_| String::from(r"\\.\pipe\keepassxc-browser"));
Self::connect_pipe(name)
}
#[cfg(not(any(unix, windows)))]
pub fn connect_default() -> io::Result<Self> {
Err(io::Error::new(
io::ErrorKind::Unsupported,
"KeePassXC native transport not supported on this platform",
))
}
}
impl Transport for NativeTransport {
    /// Writes the bytes of `line` to the stream exactly as given and flushes.
    ///
    /// No terminator or length prefix is added; callers that need one must
    /// include it in `line`.
    fn send_line(&mut self, line: &str) -> io::Result<()> {
        self.stream.write_all(line.as_bytes())?;
        self.stream.flush()
    }

    /// Returns whatever a single successful `read` on the stream delivers
    /// (at most 1 MiB), decoded as UTF-8.
    ///
    /// Despite the name, no newline is searched for: one read is treated as
    /// one message.
    ///
    /// # Errors
    /// * `UnexpectedEof` when the peer has closed the connection (read of 0 bytes).
    /// * `InvalidData` when the received bytes are not valid UTF-8.
    /// * Any other error reported by the underlying stream.
    fn read_line(&mut self) -> io::Result<String> {
        const BUFFER_SIZE: usize = 1024 * 1024;
        let mut buf = vec![0u8; BUFFER_SIZE];

        // `Interrupted` is a transient condition (e.g. a signal arrived while
        // blocked); the `Read` contract says to retry rather than fail.
        let bytes_read = loop {
            match self.stream.read(&mut buf) {
                Ok(count) => break count,
                Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
                Err(err) => return Err(err),
            }
        };

        if bytes_read == 0 {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "KeePassXC connection closed",
            ));
        }

        buf.truncate(bytes_read);
        String::from_utf8(buf).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
    }
}
/// Lists the socket paths to try, in priority order.
///
/// 1. The `KEEPASSXC_SOCKET` override (with a leading `~` expanded by
///    `normalize_path`), if set and non-empty.
/// 2. Two locations under `$HOME` (`.cache` and `.config`).
/// 3. Two locations under `$XDG_RUNTIME_DIR`, if it is set to an absolute path.
/// 4. On macOS only, two locations under `$HOME/Library`.
///
/// Environment values are read with `var_os`, so paths that are not valid
/// UTF-8 are still honoured. The paths are not checked for existence.
#[cfg(unix)]
fn socket_candidates() -> Vec<PathBuf> {
    let mut result = Vec::new();

    // An empty override can never name a socket; treat it like "unset".
    let explicit = std::env::var_os("KEEPASSXC_SOCKET").filter(|value| !value.is_empty());
    if let Some(explicit) = explicit {
        result.push(normalize_path(PathBuf::from(explicit)));
    }

    let home = home_dir();
    if let Some(home) = home.as_ref() {
        result.push(home.join(".cache/keepassxc/keepassxc-browser.socket"));
        result.push(home.join(".config/keepassxc/keepassxc-browser.socket"));
    }

    // The XDG Base Directory spec requires absolute paths and says relative
    // (including empty) values must be ignored; honouring them would probe
    // sockets relative to the current working directory.
    let runtime_dir = std::env::var_os("XDG_RUNTIME_DIR")
        .map(PathBuf::from)
        .filter(|dir| dir.is_absolute());
    if let Some(runtime_path) = runtime_dir {
        result.push(runtime_path.join("keepassxc/keepassxc-browser.socket"));
        result.push(
            runtime_path.join("app/org.keepassxc.KeePassXC/org.keepassxc.KeePassXC.BrowserServer"),
        );
    }

    #[cfg(target_os = "macos")]
    if let Some(home) = home.as_ref() {
        result.push(home.join("Library/Application Support/KeepassXC/keepassxc-browser.socket"));
        result.push(home.join("Library/Caches/keepassxc/keepassxc-browser.socket"));
    }

    result
}
/// Returns the value of the `HOME` environment variable as a path, or `None`
/// when the variable is not set. The value is not validated in any way.
#[cfg(unix)]
fn home_dir() -> Option<PathBuf> {
    let raw_home = std::env::var_os("HOME")?;
    Some(PathBuf::from(raw_home))
}
/// Expands a leading `~` path component to the current user's home directory.
///
/// Only `~` on its own or `~/…` is expanded. A path such as `~name` is
/// returned unchanged, because its first component is `~name`, not `~`.
/// Redundant separators after the tilde (`~//x`) cannot turn the result into
/// an unrelated absolute path. The comparison is done on path components, so
/// paths that are not valid UTF-8 are expanded too.
///
/// The input is returned untouched when it does not start with a `~`
/// component or when `home_dir()` yields `None`.
#[cfg(unix)]
fn normalize_path(path: PathBuf) -> PathBuf {
    // `strip_prefix` matches whole components and yields the remainder
    // without leading separators, so `join` always appends below `home`.
    let expanded = match path.strip_prefix("~") {
        Ok(rest) => home_dir().map(|home| home.join(rest)),
        Err(_) => None,
    };
    expanded.unwrap_or(path)
}
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};

    /// In-memory [`Transport`] for tests.
    ///
    /// Everything passed to `send_line` is appended to `sent`; `read_line`
    /// pops from the front of `incoming`. Both fields sit behind
    /// `Arc<Mutex<_>>`, so clones of a mock share the same log and queue.
    #[derive(Clone, Default)]
    pub struct MockTransport {
        pub sent: Arc<Mutex<Vec<String>>>,
        pub incoming: Arc<Mutex<VecDeque<String>>>,
    }

    impl MockTransport {
        /// Creates a mock whose `read_line` yields `responses` in order.
        pub fn with_responses(responses: Vec<String>) -> Self {
            let queued = VecDeque::from(responses);
            Self {
                sent: Arc::default(),
                incoming: Arc::new(Mutex::new(queued)),
            }
        }

        /// Appends one more response to the back of the queue.
        pub fn push_response(&self, response: String) {
            let mut queue = self.incoming.lock().unwrap();
            queue.push_back(response);
        }
    }

    impl Transport for MockTransport {
        /// Records `line` verbatim; never fails.
        fn send_line(&mut self, line: &str) -> io::Result<()> {
            let mut log = self.sent.lock().unwrap();
            log.push(line.to_owned());
            Ok(())
        }

        /// Returns the oldest queued response, or `UnexpectedEof` when the
        /// queue is empty.
        fn read_line(&mut self) -> io::Result<String> {
            let mut queue = self.incoming.lock().unwrap();
            match queue.pop_front() {
                Some(response) => Ok(response),
                None => Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "no response queued",
                )),
            }
        }
    }
}