use std::io;
use std::path::{Path, PathBuf};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use crate::commands::daemon::error::{DaemonError, DaemonResult};
use crate::commands::daemon::pid::compute_hash;
use crate::commands::daemon::types::{DaemonCommand, DaemonResponse};
/// Largest message payload, in bytes, exchanged over the IPC channel (10 MiB).
pub const MAX_MESSAGE_SIZE: usize = 10 << 20;
/// Seconds allowed for a client connection attempt to complete.
pub const CONNECTION_TIMEOUT_SECS: u64 = 5;
/// Seconds allowed for a complete message line to arrive when reading.
pub const READ_TIMEOUT_SECS: u64 = 30;
/// Returns the socket path for the daemon serving `project`.
///
/// The path is `<temp_dir>/tldr-<hash>.sock`, where `<hash>` is
/// `compute_hash(project)`. The same project therefore always maps to the same
/// path. The body is identical on every platform, so it needs no `cfg` gating.
pub fn compute_socket_path(project: &Path) -> PathBuf {
    std::env::temp_dir().join(format!("tldr-{}.sock", compute_hash(project)))
}

/// Returns the loopback TCP port for the daemon serving `project`.
///
/// The port is derived from the project hash and always lies in
/// `49152..=59151`. 49152 is the start of the IANA dynamic/private range,
/// which is presumably why that base was chosen. The body is identical on
/// every platform.
pub fn compute_tcp_port(project: &Path) -> u16 {
    let hash = compute_hash(project);
    let hash: &str = &hash;
    // A u64 holds at most 16 hex digits. Parsing a longer hash in full
    // overflows, and the `unwrap_or(0)` fallback would then send every
    // project to the same port, so only the leading digits are used.
    // For hashes of 16 digits or fewer this is exactly the old behavior.
    let prefix = hash.get(..16).unwrap_or(hash);
    let hash_int = u64::from_str_radix(prefix, 16).unwrap_or(0);
    // `% 10_000` is < 10_000, so the cast to u16 cannot truncate.
    49152 + (hash_int % 10_000) as u16
}
/// Checks that `socket_path` names a plain file inside the system temp directory.
///
/// # Errors
///
/// Returns [`DaemonError::PermissionDenied`] when any of these holds:
/// - the path has no final file-name component (e.g. it ends in `..`),
/// - its parent directory does not resolve to a location under
///   [`std::env::temp_dir`], including an unresolvable parent that contains
///   `..` components,
/// - the file name contains `..`, `/` or `\`.
pub fn validate_socket_path(socket_path: &Path) -> DaemonResult<()> {
    let deny = || DaemonError::PermissionDenied {
        path: socket_path.to_path_buf(),
    };

    // Paths such as `<tmp>/..` have no file name. Their parent is inside tmp,
    // so without this check they would pass even though they cannot name a
    // socket file.
    let filename = match socket_path.file_name() {
        Some(name) => name,
        None => return Err(deny()),
    };

    let tmp_dir = std::env::temp_dir();
    let canonical_tmp = tmp_dir.canonicalize().unwrap_or(tmp_dir);
    let socket_parent = socket_path.parent().unwrap_or(socket_path);
    let canonical_parent = match socket_parent.canonicalize() {
        Ok(resolved) => resolved,
        Err(_) => {
            // The parent could not be resolved (e.g. it does not exist), so
            // fall back to the raw path. `Path::starts_with` compares
            // components lexically, so a raw `<tmp>/missing/../..` would still
            // "start with" tmp. Refuse any unresolved path that contains `..`.
            let has_parent_dir = socket_parent
                .components()
                .any(|c| matches!(c, std::path::Component::ParentDir));
            if has_parent_dir {
                return Err(deny());
            }
            socket_parent.to_path_buf()
        }
    };
    if !canonical_parent.starts_with(&canonical_tmp) {
        return Err(deny());
    }

    let filename_str = filename.to_string_lossy();
    if filename_str.contains("..") || filename_str.contains('/') || filename_str.contains('\\') {
        return Err(deny());
    }
    Ok(())
}
/// Rejects `path` if it is itself a symbolic link.
///
/// The path is inspected with [`std::fs::symlink_metadata`], which does not
/// follow links. A path whose metadata cannot be read (for example, because
/// it does not exist) is accepted. The same body works on every platform, so
/// no `cfg` split is needed.
///
/// # Errors
///
/// Returns [`DaemonError::PermissionDenied`] when `path` is a symlink.
pub fn check_not_symlink(path: &Path) -> DaemonResult<()> {
    let is_symlink = std::fs::symlink_metadata(path)
        .map(|meta| meta.file_type().is_symlink())
        .unwrap_or(false);
    if is_symlink {
        return Err(DaemonError::PermissionDenied {
            path: path.to_path_buf(),
        });
    }
    Ok(())
}
/// Daemon-side listening endpoint.
///
/// Wraps a `tokio::net::UnixListener` on Unix and a `tokio::net::TcpListener`
/// on Windows. Construct it with `IpcListener::bind` and obtain connections
/// with `IpcListener::accept`.
pub struct IpcListener {
    // Unix: Unix-domain socket listener bound at `socket_path`.
    #[cfg(unix)]
    inner: tokio::net::UnixListener,
    // Windows: TCP listener bound to 127.0.0.1 on the project's port.
    #[cfg(windows)]
    inner: tokio::net::TcpListener,
    // Path from `compute_socket_path`, recorded at bind time. Nothing in this
    // file reads it back, hence the `dead_code` allowance.
    #[allow(dead_code)]
    socket_path: PathBuf,
}
impl IpcListener {
    /// Binds the daemon's listening endpoint for `project`.
    ///
    /// Delegates to `bind_unix` on Unix and to `bind_tcp` on Windows.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the platform-specific helper.
    pub async fn bind(project: &Path) -> DaemonResult<Self> {
        #[cfg(unix)]
        {
            Self::bind_unix(project).await
        }
        #[cfg(windows)]
        {
            Self::bind_tcp(project).await
        }
    }

    /// Binds a Unix-domain socket at `compute_socket_path(project)` with mode `0o600`.
    ///
    /// A non-symlink file already present at that path is removed first.
    ///
    /// # Errors
    ///
    /// - [`DaemonError::PermissionDenied`] if the path fails validation or is a symlink.
    /// - [`DaemonError::SocketBindFailed`] if removing an existing file, binding,
    ///   or restricting permissions fails. When the permission change fails, the
    ///   freshly created socket file is removed before returning.
    #[cfg(unix)]
    async fn bind_unix(project: &Path) -> DaemonResult<Self> {
        use std::os::unix::fs::PermissionsExt;
        let socket_path = compute_socket_path(project);
        validate_socket_path(&socket_path)?;
        check_not_symlink(&socket_path)?;
        if socket_path.exists() {
            // Re-check right before deleting, to narrow the window in which the
            // path could have been swapped for a symlink.
            check_not_symlink(&socket_path)?;
            std::fs::remove_file(&socket_path).map_err(DaemonError::SocketBindFailed)?;
        }
        let listener = tokio::net::UnixListener::bind(&socket_path)
            .map_err(DaemonError::SocketBindFailed)?;
        let permissions = std::fs::Permissions::from_mode(0o600);
        if let Err(e) = std::fs::set_permissions(&socket_path, permissions) {
            // bind() has already created the socket file. Don't leave it
            // behind with unrestricted permissions. Removal is best-effort:
            // the permission error is what the caller needs to see.
            drop(listener);
            let _ = std::fs::remove_file(&socket_path);
            return Err(DaemonError::SocketBindFailed(e));
        }
        Ok(Self {
            inner: listener,
            socket_path,
        })
    }

    /// Binds a TCP listener on `127.0.0.1:compute_tcp_port(project)`.
    ///
    /// # Errors
    ///
    /// - [`DaemonError::AddressInUse`] if the port is already taken.
    /// - [`DaemonError::SocketBindFailed`] for any other bind failure.
    #[cfg(windows)]
    async fn bind_tcp(project: &Path) -> DaemonResult<Self> {
        let socket_path = compute_socket_path(project);
        let port = compute_tcp_port(project);
        let addr = format!("127.0.0.1:{}", port);
        let listener = tokio::net::TcpListener::bind(&addr).await.map_err(|e| {
            if e.kind() == io::ErrorKind::AddrInUse {
                DaemonError::AddressInUse { addr }
            } else {
                DaemonError::SocketBindFailed(e)
            }
        })?;
        Ok(Self {
            inner: listener,
            socket_path,
        })
    }

    /// Waits for the next incoming connection and wraps it in an [`IpcStream`].
    ///
    /// # Errors
    ///
    /// Returns [`DaemonError::Io`] if the underlying `accept` fails.
    pub async fn accept(&self) -> DaemonResult<IpcStream> {
        #[cfg(unix)]
        {
            let (stream, _addr) = self.inner.accept().await.map_err(DaemonError::Io)?;
            Ok(IpcStream {
                inner: IpcStreamInner::Unix(stream),
            })
        }
        #[cfg(windows)]
        {
            let (stream, _addr) = self.inner.accept().await.map_err(DaemonError::Io)?;
            Ok(IpcStream {
                inner: IpcStreamInner::Tcp(stream),
            })
        }
    }
}
/// Platform-specific transport behind an [`IpcStream`].
enum IpcStreamInner {
    /// Unix-domain socket connection (Unix targets).
    #[cfg(unix)]
    Unix(tokio::net::UnixStream),
    /// Loopback TCP connection (Windows targets).
    #[cfg(windows)]
    Tcp(tokio::net::TcpStream),
    /// Placeholder variant for targets that are neither Unix nor Windows:
    /// sends on it do nothing and receives report `NotRunning`.
    #[cfg(all(not(unix), not(windows)))]
    Dummy,
}
/// A bidirectional connection carrying newline-terminated JSON messages.
///
/// Obtained from `IpcStream::connect` on the client side or from
/// `IpcListener::accept` on the daemon side.
pub struct IpcStream {
    inner: IpcStreamInner,
}
impl IpcStream {
    /// Connects to the running daemon for `project`.
    ///
    /// Uses a Unix-domain socket on Unix and loopback TCP on Windows. The
    /// attempt is bounded by [`CONNECTION_TIMEOUT_SECS`].
    ///
    /// # Errors
    ///
    /// Returns `NotRunning`, `ConnectionRefused`, `ConnectionTimeout`, or a
    /// wrapped I/O error, as produced by the platform helper.
    pub async fn connect(project: &Path) -> DaemonResult<Self> {
        #[cfg(unix)]
        {
            Self::connect_unix(project).await
        }
        #[cfg(windows)]
        {
            Self::connect_tcp(project).await
        }
    }

    /// Unix implementation of `connect`.
    #[cfg(unix)]
    async fn connect_unix(project: &Path) -> DaemonResult<Self> {
        let socket_path = compute_socket_path(project);
        validate_socket_path(&socket_path)?;
        // No socket file means no daemon has bound for this project.
        if !socket_path.exists() {
            return Err(DaemonError::NotRunning);
        }
        check_not_symlink(&socket_path)?;
        let connect_future = tokio::net::UnixStream::connect(&socket_path);
        let timeout = tokio::time::Duration::from_secs(CONNECTION_TIMEOUT_SECS);
        match tokio::time::timeout(timeout, connect_future).await {
            Ok(Ok(stream)) => Ok(Self {
                inner: IpcStreamInner::Unix(stream),
            }),
            // The file exists but nothing accepts on it (presumably a stale
            // socket left by a daemon that exited).
            Ok(Err(e)) if e.kind() == io::ErrorKind::ConnectionRefused => {
                Err(DaemonError::ConnectionRefused)
            }
            // The file vanished between the existence check and connect().
            Ok(Err(e)) if e.kind() == io::ErrorKind::NotFound => Err(DaemonError::NotRunning),
            Ok(Err(e)) => Err(DaemonError::Io(e)),
            Err(_) => Err(DaemonError::ConnectionTimeout {
                timeout_secs: CONNECTION_TIMEOUT_SECS,
            }),
        }
    }

    /// Windows implementation of `connect` (TCP on `127.0.0.1`).
    #[cfg(windows)]
    async fn connect_tcp(project: &Path) -> DaemonResult<Self> {
        let port = compute_tcp_port(project);
        let addr = format!("127.0.0.1:{}", port);
        let connect_future = tokio::net::TcpStream::connect(&addr);
        let timeout = tokio::time::Duration::from_secs(CONNECTION_TIMEOUT_SECS);
        match tokio::time::timeout(timeout, connect_future).await {
            Ok(Ok(stream)) => Ok(Self {
                inner: IpcStreamInner::Tcp(stream),
            }),
            Ok(Err(e)) if e.kind() == io::ErrorKind::ConnectionRefused => {
                Err(DaemonError::ConnectionRefused)
            }
            Ok(Err(e)) => Err(DaemonError::Io(e)),
            Err(_) => Err(DaemonError::ConnectionTimeout {
                timeout_secs: CONNECTION_TIMEOUT_SECS,
            }),
        }
    }

    /// Serializes `cmd` to JSON and sends it as a single line.
    pub async fn send_command(&mut self, cmd: &DaemonCommand) -> DaemonResult<()> {
        let json = serde_json::to_string(cmd)?;
        self.send_raw(&json).await
    }

    /// Sends `json` followed by a `'\n'` terminator, then flushes the stream.
    ///
    /// # Errors
    ///
    /// Returns [`DaemonError::InvalidMessage`] if `json` is longer than
    /// [`MAX_MESSAGE_SIZE`] bytes, or the converted I/O error from the write.
    pub async fn send_raw(&mut self, json: &str) -> DaemonResult<()> {
        if json.len() > MAX_MESSAGE_SIZE {
            return Err(DaemonError::InvalidMessage(format!(
                "message too large: {} bytes (max {})",
                json.len(),
                MAX_MESSAGE_SIZE
            )));
        }
        // Build the whole frame up front so it is handed to one write_all call.
        let mut message = String::with_capacity(json.len() + 1);
        message.push_str(json);
        message.push('\n');
        match &mut self.inner {
            #[cfg(unix)]
            IpcStreamInner::Unix(stream) => Self::write_frame(stream, message.as_bytes()).await,
            #[cfg(windows)]
            IpcStreamInner::Tcp(stream) => Self::write_frame(stream, message.as_bytes()).await,
            #[cfg(all(not(unix), not(windows)))]
            IpcStreamInner::Dummy => Ok(()),
        }
    }

    /// Receives one line and deserializes it as a [`DaemonResponse`].
    pub async fn recv_response(&mut self) -> DaemonResult<DaemonResponse> {
        let json = self.recv_raw().await?;
        let response: DaemonResponse = serde_json::from_str(&json)?;
        Ok(response)
    }

    /// Receives one newline-terminated message and returns it with trailing
    /// whitespace removed.
    ///
    /// At most [`MAX_MESSAGE_SIZE`] payload bytes, plus the terminator, are
    /// buffered, so an oversized or unterminated message cannot exhaust memory.
    /// The limit applies to the payload, which matches the check in `send_raw`.
    ///
    /// NOTE(review): a fresh `BufReader` is created on every call. Any bytes
    /// the peer sent after the first line that were already buffered are
    /// therefore dropped. This looks fine for one request/response per
    /// connection; confirm that no caller pipelines messages on one stream.
    ///
    /// # Errors
    ///
    /// - [`DaemonError::ConnectionRefused`] if the stream hits EOF before any byte arrives.
    /// - [`DaemonError::InvalidMessage`] if the payload exceeds [`MAX_MESSAGE_SIZE`].
    /// - [`DaemonError::ConnectionTimeout`] if no full line arrives within
    ///   [`READ_TIMEOUT_SECS`].
    /// - [`DaemonError::Io`] for read failures or data that is not UTF-8.
    pub async fn recv_raw(&mut self) -> DaemonResult<String> {
        match &mut self.inner {
            #[cfg(unix)]
            IpcStreamInner::Unix(stream) => Self::read_frame(stream).await,
            #[cfg(windows)]
            IpcStreamInner::Tcp(stream) => Self::read_frame(stream).await,
            #[cfg(all(not(unix), not(windows)))]
            IpcStreamInner::Dummy => Err(DaemonError::NotRunning),
        }
    }

    /// Writes one complete frame to `stream` and flushes it.
    async fn write_frame<S>(stream: &mut S, frame: &[u8]) -> DaemonResult<()>
    where
        S: tokio::io::AsyncWrite + Unpin,
    {
        stream.write_all(frame).await?;
        stream.flush().await?;
        Ok(())
    }

    /// Reads one `'\n'`-terminated frame under a size cap and a read timeout.
    async fn read_frame<S>(stream: &mut S) -> DaemonResult<String>
    where
        S: tokio::io::AsyncRead + Unpin,
    {
        let timeout = tokio::time::Duration::from_secs(READ_TIMEOUT_SECS);
        // Allow MAX_MESSAGE_SIZE payload bytes plus the '\n' terminator, and
        // one byte beyond that so an oversized line is detected without
        // buffering all of it. (`usize` -> `u64` is lossless on supported targets.)
        let limit = MAX_MESSAGE_SIZE as u64 + 2;
        let mut reader = tokio::io::AsyncReadExt::take(BufReader::new(stream), limit);
        let mut buf = Vec::new();
        let outcome = tokio::time::timeout(timeout, reader.read_until(b'\n', &mut buf)).await;
        match outcome {
            Ok(Ok(0)) => Err(DaemonError::ConnectionRefused),
            Ok(Ok(_)) => Self::decode_frame(buf),
            Ok(Err(e)) => Err(DaemonError::Io(e)),
            Err(_) => Err(DaemonError::ConnectionTimeout {
                timeout_secs: READ_TIMEOUT_SECS,
            }),
        }
    }

    /// Enforces the payload size limit on a raw frame and decodes it as UTF-8.
    fn decode_frame(buf: Vec<u8>) -> DaemonResult<String> {
        let payload_len = if buf.last() == Some(&b'\n') {
            buf.len() - 1
        } else {
            buf.len()
        };
        if payload_len > MAX_MESSAGE_SIZE {
            // The read is capped, so the size reported here is a lower bound.
            return Err(DaemonError::InvalidMessage(format!(
                "response too large: at least {} bytes (max {})",
                payload_len, MAX_MESSAGE_SIZE
            )));
        }
        let line = String::from_utf8(buf)
            .map_err(|e| DaemonError::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        Ok(line.trim_end().to_string())
    }
}
/// Reads one line from `stream` and deserializes it as a [`DaemonCommand`].
///
/// # Errors
///
/// Propagates any error from [`IpcStream::recv_raw`]. Returns
/// [`DaemonError::InvalidMessage`] if the received text exceeds
/// [`MAX_MESSAGE_SIZE`] bytes, and the converted `serde_json` error if the
/// text is not a valid command.
pub async fn read_command(stream: &mut IpcStream) -> DaemonResult<DaemonCommand> {
    let payload = stream.recv_raw().await?;
    let size = payload.len();
    if size <= MAX_MESSAGE_SIZE {
        let command = serde_json::from_str::<DaemonCommand>(&payload)?;
        return Ok(command);
    }
    Err(DaemonError::InvalidMessage(format!(
        "command too large: {} bytes (max {})",
        size, MAX_MESSAGE_SIZE
    )))
}
/// Serializes `response` to JSON and writes it to `stream` as a single line.
pub async fn send_response(stream: &mut IpcStream, response: &DaemonResponse) -> DaemonResult<()> {
    let encoded = serde_json::to_string(response)?;
    stream.send_raw(encoded.as_str()).await
}
pub fn cleanup_socket(project: &Path) -> DaemonResult<()> {
let socket_path = compute_socket_path(project);
if socket_path.exists() {
check_not_symlink(&socket_path)?;
std::fs::remove_file(&socket_path)?;
}
Ok(())
}
/// Returns `true` when a connection to the daemon for `project` can be opened.
pub async fn check_socket_alive(project: &Path) -> bool {
    matches!(IpcStream::connect(project).await, Ok(_))
}
/// Opens a connection to the daemon for `project`, sends `cmd`, and returns
/// the daemon's single-line response.
pub async fn send_command(project: &Path, cmd: &DaemonCommand) -> DaemonResult<DaemonResponse> {
    let mut connection = IpcStream::connect(project).await?;
    connection.send_command(cmd).await?;
    let response = connection.recv_response().await?;
    Ok(response)
}
/// Opens a connection to the daemon for `project`, sends the raw JSON `json`,
/// and returns the raw reply line.
pub async fn send_raw_command(project: &Path, json: &str) -> DaemonResult<String> {
    let mut connection = IpcStream::connect(project).await?;
    connection.send_raw(json).await?;
    let reply = connection.recv_raw().await?;
    Ok(reply)
}
#[cfg(test)]
mod tests {
    // Unit tests for socket-path/port derivation, path validation, symlink
    // checks, socket cleanup and connection failure handling.
    use super::*;
    use std::path::PathBuf;
    use tempfile::TempDir;

    #[test]
    fn test_compute_socket_path_format() {
        let project = PathBuf::from("/test/project");
        let socket_path = compute_socket_path(&project);
        let filename = socket_path.file_name().unwrap().to_str().unwrap();
        assert!(filename.starts_with("tldr-"));
        assert!(filename.ends_with(".sock"));
    }

    #[test]
    fn test_compute_socket_path_deterministic() {
        let project = PathBuf::from("/test/project");
        let path1 = compute_socket_path(&project);
        let path2 = compute_socket_path(&project);
        assert_eq!(path1, path2);
    }

    #[test]
    fn test_compute_socket_path_different_projects() {
        let project1 = PathBuf::from("/test/project1");
        let project2 = PathBuf::from("/test/project2");
        let path1 = compute_socket_path(&project1);
        let path2 = compute_socket_path(&project2);
        assert_ne!(path1, path2);
    }

    #[test]
    fn test_compute_tcp_port_range() {
        let project = PathBuf::from("/test/project");
        let port = compute_tcp_port(&project);
        assert!(port >= 49152);
        assert!(port < 59152);
    }

    #[test]
    fn test_compute_tcp_port_deterministic() {
        let project = PathBuf::from("/test/project");
        let port1 = compute_tcp_port(&project);
        let port2 = compute_tcp_port(&project);
        assert_eq!(port1, port2);
    }

    #[test]
    fn test_validate_socket_path_valid() {
        let tmp_dir = std::env::temp_dir();
        let socket_path = tmp_dir.join("tldr-test.sock");
        assert!(validate_socket_path(&socket_path).is_ok());
    }

    #[test]
    fn test_validate_socket_path_traversal() {
        let tmp_dir = std::env::temp_dir();
        let socket_path = tmp_dir.join("../etc/passwd");
        // `Path::starts_with` is lexical, so `socket_path` always "starts with"
        // `tmp_dir`. The only meaningful check is that validation rejects it.
        assert!(validate_socket_path(&socket_path).is_err());
    }

    #[test]
    fn test_validate_socket_path_bad_filename() {
        let tmp_dir = std::env::temp_dir();
        let socket_path = tmp_dir.join("test..sock");
        assert!(validate_socket_path(&socket_path).is_err());
    }

    #[test]
    fn test_max_message_size_constant() {
        assert_eq!(MAX_MESSAGE_SIZE, 10 * 1024 * 1024);
    }

    #[test]
    fn test_cleanup_socket_nonexistent() {
        let temp = TempDir::new().unwrap();
        let project = temp.path().join("nonexistent");
        let result = cleanup_socket(&project);
        assert!(result.is_ok());
    }

    #[cfg(unix)]
    #[test]
    fn test_check_not_symlink_regular_file() {
        let temp = TempDir::new().unwrap();
        let file_path = temp.path().join("regular.txt");
        std::fs::write(&file_path, "test").unwrap();
        assert!(check_not_symlink(&file_path).is_ok());
    }

    #[cfg(unix)]
    #[test]
    fn test_check_not_symlink_symlink() {
        let temp = TempDir::new().unwrap();
        let file_path = temp.path().join("regular.txt");
        let link_path = temp.path().join("symlink.txt");
        std::fs::write(&file_path, "test").unwrap();
        std::os::unix::fs::symlink(&file_path, &link_path).unwrap();
        assert!(check_not_symlink(&link_path).is_err());
    }

    #[cfg(unix)]
    #[test]
    fn test_check_not_symlink_nonexistent() {
        let temp = TempDir::new().unwrap();
        let path = temp.path().join("nonexistent");
        assert!(check_not_symlink(&path).is_ok());
    }

    #[tokio::test]
    async fn test_connect_nonexistent_daemon() {
        let temp = TempDir::new().unwrap();
        let project = temp.path();
        let result = IpcStream::connect(project).await;
        assert!(matches!(result, Err(DaemonError::NotRunning)));
    }
}