use std::path::{Path, PathBuf};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::{UnixListener, UnixStream};
use tokio::sync::mpsc;
use tracing::{debug, error, info, warn};
use crate::error::HooksError;
use crate::hooks::HookEvent;
/// Directory holding hook sockets, joined onto the user's home directory.
const SOCKET_DIR: &str = ".tazuna/sockets";
/// File name of the PID-independent socket returned by `default_socket_path`.
const SOCKET_NAME: &str = "main.sock";
/// Returns the PID-independent socket path, `<home>/.tazuna/sockets/main.sock`.
///
/// `<home>` is whatever `dirs::home_dir` reports; if it reports nothing, the
/// path is built relative to the current directory instead.
#[must_use]
pub fn default_socket_path() -> PathBuf {
    let mut path = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
    path.push(SOCKET_DIR);
    path.push(SOCKET_NAME);
    path
}
/// Returns the per-process socket path, `<home>/.tazuna/sockets/<pid>.sock`.
///
/// `<home>` is whatever `dirs::home_dir` reports; if it reports nothing, the
/// path is built relative to the current directory instead.
#[must_use]
pub fn socket_path_for_pid(pid: u32) -> PathBuf {
    let file_name = format!("{pid}.sock");
    let mut path = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
    path.push(SOCKET_DIR);
    path.push(file_name);
    path
}
/// Returns the socket path for the calling process (see `socket_path_for_pid`).
#[must_use]
pub fn current_socket_path() -> PathBuf {
    let own_pid = std::process::id();
    socket_path_for_pid(own_pid)
}
/// Parses the owning PID out of a socket path of the form `.../<pid>.sock`.
///
/// Returns `None` when the path has no file stem, the stem is not valid
/// UTF-8, or the stem does not parse as a `u32` (e.g. `main.sock`).
fn extract_pid_from_socket_path(path: &Path) -> Option<u32> {
    let stem = path.file_stem()?.to_str()?;
    stem.parse::<u32>().ok()
}
/// Reports whether a process with the given PID appears to exist.
///
/// On Unix this probes with `kill(pid, 0)`, which runs the existence and
/// permission checks without delivering a signal. An `EPERM` failure means the
/// target exists but the caller may not signal it, so it counts as alive.
///
/// PIDs that cannot name a single process are reported as not alive:
/// `0` (which `kill` treats as "the caller's process group") and values that
/// do not fit in `pid_t` (which would wrap negative and be treated as a
/// process-group target, or as "every process" for `-1`).
///
/// On non-Unix targets there is no probe, so every PID is conservatively
/// reported as alive.
fn is_process_alive(pid: u32) -> bool {
    #[cfg(unix)]
    {
        // Only strictly positive values that fit in `pid_t` address exactly
        // one process; anything else would make `kill` probe a group.
        let Some(pid) = libc::pid_t::try_from(pid).ok().filter(|&p| p > 0) else {
            return false;
        };
        // SAFETY: `kill` takes plain integers and touches no memory owned by
        // this program; signal 0 performs error checking only and delivers
        // nothing to the target.
        let result = unsafe { libc::kill(pid, 0) };
        if result == 0 {
            return true;
        }
        let errno = std::io::Error::last_os_error().raw_os_error().unwrap_or(0);
        errno == libc::EPERM
    }
    #[cfg(not(unix))]
    {
        let _ = pid;
        true
    }
}
pub struct HooksServer {
socket_path: PathBuf,
listener: UnixListener,
event_tx: mpsc::Sender<HookEvent>,
}
impl HooksServer {
pub fn new(socket_path: &Path, event_tx: mpsc::Sender<HookEvent>) -> Result<Self, HooksError> {
if let Some(parent) = socket_path.parent() {
std::fs::create_dir_all(parent).map_err(HooksError::IpcFailed)?;
}
if socket_path.exists() {
if let Some(pid) = extract_pid_from_socket_path(socket_path) {
if is_process_alive(pid) {
return Err(HooksError::SocketInUse(socket_path.to_path_buf()));
}
info!("Removing stale socket from dead process {pid}");
}
std::fs::remove_file(socket_path).map_err(HooksError::IpcFailed)?;
}
let listener = UnixListener::bind(socket_path).map_err(HooksError::IpcFailed)?;
info!("HooksServer listening on {}", socket_path.display());
Ok(Self {
socket_path: socket_path.to_path_buf(),
listener,
event_tx,
})
}
pub fn with_default_path(event_tx: mpsc::Sender<HookEvent>) -> Result<Self, HooksError> {
Self::new(&default_socket_path(), event_tx)
}
pub async fn run(self) {
loop {
match self.listener.accept().await {
Ok((stream, _addr)) => {
let event_tx = self.event_tx.clone();
tokio::spawn(async move {
if let Err(e) = handle_connection(stream, event_tx).await {
warn!("Connection handler error: {e}");
}
});
}
Err(e) => {
error!("Failed to accept connection: {e}");
}
}
}
}
fn cleanup(&self) {
if self.socket_path.exists()
&& let Err(e) = std::fs::remove_file(&self.socket_path)
{
warn!("Failed to remove socket file: {e}");
}
}
}
impl Drop for HooksServer {
fn drop(&mut self) {
self.cleanup();
}
}
async fn handle_connection(
stream: UnixStream,
event_tx: mpsc::Sender<HookEvent>,
) -> Result<(), HooksError> {
let mut reader = BufReader::new(stream);
let mut line = String::new();
let bytes_read = reader
.read_line(&mut line)
.await
.map_err(HooksError::IpcFailed)?;
if bytes_read == 0 {
return Ok(()); }
let event: HookEvent =
serde_json::from_str(line.trim()).map_err(|e| HooksError::ParseFailed(e.to_string()))?;
debug!("Received hook event: {:?}", event.event_type);
event_tx
.send(event)
.await
.map_err(|e| HooksError::IpcFailed(std::io::Error::other(e.to_string())))?;
Ok(())
}
/// Client that delivers [`HookEvent`]s to a [`HooksServer`] over its Unix socket.
pub struct HooksClient {
    /// Path of the server's socket.
    socket_path: PathBuf,
}

impl HooksClient {
    /// Creates a client targeting `socket_path`. No connection is made until
    /// [`send`](Self::send) is called.
    #[must_use]
    pub fn new(socket_path: PathBuf) -> Self {
        Self { socket_path }
    }

    /// Creates a client from the `TAZUNA_SOCKET_PATH` environment variable.
    ///
    /// The value is read as an `OsString`, so non-UTF-8 paths are accepted.
    ///
    /// # Errors
    ///
    /// [`HooksError::MissingEnv`] if the variable is not set.
    pub fn from_env() -> Result<Self, HooksError> {
        std::env::var_os("TAZUNA_SOCKET_PATH")
            .map(|path| Self::new(PathBuf::from(path)))
            .ok_or_else(|| HooksError::MissingEnv("TAZUNA_SOCKET_PATH".to_string()))
    }

    /// Sends `event` as a single newline-terminated JSON line, then closes the
    /// connection.
    ///
    /// # Errors
    ///
    /// * [`HooksError::SocketNotFound`] if nothing exists at the socket path.
    ///   This is checked up front, and again when `connect` reports `NotFound`,
    ///   which covers the socket disappearing after the check.
    /// * [`HooksError::ParseFailed`] if `event` cannot be serialized.
    /// * [`HooksError::IpcFailed`] if connecting or writing fails.
    pub async fn send(&self, event: &HookEvent) -> Result<(), HooksError> {
        if !self.socket_path.exists() {
            return Err(HooksError::SocketNotFound(self.socket_path.clone()));
        }
        // Serialize before connecting, so a serialization failure does not
        // open a connection that the server would only see as empty.
        let mut line =
            serde_json::to_string(event).map_err(|e| HooksError::ParseFailed(e.to_string()))?;
        line.push('\n');

        let mut stream = UnixStream::connect(&self.socket_path)
            .await
            .map_err(|e| {
                if e.kind() == std::io::ErrorKind::NotFound {
                    HooksError::SocketNotFound(self.socket_path.clone())
                } else {
                    HooksError::IpcFailed(e)
                }
            })?;
        // Payload and terminator go out in one write.
        stream
            .write_all(line.as_bytes())
            .await
            .map_err(HooksError::IpcFailed)?;
        stream.flush().await.map_err(HooksError::IpcFailed)?;
        Ok(())
    }
}
// Tests for the socket-path helpers, the PID liveness probe, and the
// HooksServer/HooksClient round trip.
#[cfg(test)]
// `expect` is deliberate here: a failed setup step should abort the test.
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use crate::hooks::HookEventType;
    use crate::session::SessionId;
    use rstest::rstest;
    use serde_json::json;
    use std::time::Duration;
    use uuid::Uuid;

    // Fresh random session ID for each test event.
    fn test_session_id() -> SessionId {
        SessionId::from(Uuid::new_v4())
    }

    // Builds a `Notification` event whose message is "Test notification".
    fn test_event() -> HookEvent {
        HookEvent::from_payload(
            test_session_id(),
            json!({
                "hook_event_name": "Notification",
                "message": "Test notification"
            }),
        )
        .expect("create test event")
    }

    #[test]
    fn default_socket_path_structure() {
        let path = default_socket_path();
        assert!(path.to_string_lossy().contains(".tazuna"));
        assert!(path.ends_with("sockets/main.sock"));
    }

    // NOTE(review): this binds the real `main.sock` under the user's home
    // directory rather than a tempdir. The `if let Ok` makes it pass without
    // asserting anything whenever `with_default_path` returns an error.
    #[tokio::test]
    async fn server_with_default_path_creates_socket() {
        let (tx, _rx) = mpsc::channel(16);
        let expected_path = default_socket_path();
        let result = HooksServer::with_default_path(tx);
        if let Ok(server) = result {
            assert!(expected_path.exists());
            drop(server);
        }
    }

    // Dropping the server must delete its socket file.
    #[tokio::test]
    async fn server_drop_removes_socket() {
        let temp = tempfile::tempdir().expect("create tempdir");
        let socket_path = temp.path().join("cleanup-test.sock");
        let (tx, _rx) = mpsc::channel(16);
        let server = HooksServer::new(&socket_path, tx).expect("create server");
        assert!(socket_path.exists());
        drop(server);
        assert!(!socket_path.exists());
    }

    #[tokio::test]
    async fn server_creates_socket_file() {
        let temp = tempfile::tempdir().expect("create tempdir");
        let socket_path = temp.path().join("test.sock");
        let (tx, _rx) = mpsc::channel(16);
        let server = HooksServer::new(&socket_path, tx).expect("create server");
        assert!(socket_path.exists());
        drop(server);
        assert!(!socket_path.exists());
    }

    // NOTE(review): dropping `server1` already removes the socket file (via
    // `Drop for HooksServer`), so `server2` binds a fresh path and the
    // stale-file removal branch of `HooksServer::new` is not exercised here.
    #[tokio::test]
    async fn server_removes_existing_socket() {
        let temp = tempfile::tempdir().expect("create tempdir");
        let socket_path = temp.path().join("test.sock");
        let (tx, _rx) = mpsc::channel(16);
        let server1 = HooksServer::new(&socket_path, tx.clone()).expect("create server 1");
        drop(server1);
        let server2 = HooksServer::new(&socket_path, tx).expect("create server 2");
        assert!(socket_path.exists());
        drop(server2);
    }

    // End-to-end: an event sent by the client arrives intact on the channel.
    #[tokio::test]
    async fn client_server_communication() {
        let temp = tempfile::tempdir().expect("create tempdir");
        let socket_path = temp.path().join("test.sock");
        let (tx, mut rx) = mpsc::channel(16);
        let server = HooksServer::new(&socket_path, tx).expect("create server");
        let server_handle = tokio::spawn(server.run());
        // Presumably gives the spawned `run` task time to start; the listener
        // itself is already bound by `new`.
        tokio::time::sleep(Duration::from_millis(50)).await;
        let client = HooksClient::new(socket_path);
        let event = test_event();
        client.send(&event).await.expect("send event");
        let received = tokio::time::timeout(Duration::from_secs(1), rx.recv())
            .await
            .expect("timeout")
            .expect("receive event");
        assert_eq!(received.event_type, HookEventType::Notification);
        assert_eq!(received.message(), Some("Test notification".to_string()));
        server_handle.abort();
    }

    #[tokio::test]
    async fn client_socket_not_found() {
        let client = HooksClient::new(PathBuf::from("/nonexistent/socket.sock"));
        let event = test_event();
        let result = client.send(&event).await;
        assert!(matches!(result, Err(HooksError::SocketNotFound(_))));
    }

    // Connects and disconnects without sending data, which drives the
    // zero-byte read path in `handle_connection`. There is no assertion; the
    // test only exercises that path.
    #[tokio::test]
    async fn server_handles_empty_connection() {
        use tokio::net::UnixStream;
        let temp = tempfile::tempdir().expect("create tempdir");
        let socket_path = temp.path().join("eof-test.sock");
        let (tx, _rx) = mpsc::channel(16);
        let server = HooksServer::new(&socket_path, tx).expect("create server");
        let server_handle = tokio::spawn(server.run());
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        let stream = UnixStream::connect(&socket_path).await.expect("connect");
        drop(stream);
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        server_handle.abort();
    }

    // `serial` because these tests mutate the process-wide
    // `TAZUNA_SOCKET_PATH` variable. The variable is removed before the
    // assertion so a failing assert does not leak it into other tests.
    #[test]
    #[serial_test::serial]
    fn hooks_client_from_env_success() {
        unsafe {
            std::env::set_var("TAZUNA_SOCKET_PATH", "/tmp/test-socket.sock");
        }
        let result = HooksClient::from_env();
        unsafe {
            std::env::remove_var("TAZUNA_SOCKET_PATH");
        }
        assert!(result.is_ok());
    }

    #[test]
    #[serial_test::serial]
    fn hooks_client_from_env_missing() {
        unsafe {
            std::env::remove_var("TAZUNA_SOCKET_PATH");
        }
        let result = HooksClient::from_env();
        assert!(matches!(result, Err(HooksError::MissingEnv(_))));
    }

    #[rstest]
    #[case(1234, "1234.sock")]
    #[case(0, "0.sock")]
    fn socket_path_for_pid_structure(#[case] pid: u32, #[case] suffix: &str) {
        let path = socket_path_for_pid(pid);
        assert!(path.to_string_lossy().contains(".tazuna/sockets"));
        assert!(path.to_string_lossy().ends_with(suffix));
    }

    #[test]
    fn current_socket_path_uses_process_id() {
        let path = current_socket_path();
        let expected_suffix = format!("{}.sock", std::process::id());
        assert!(path.to_string_lossy().ends_with(&expected_suffix));
    }

    // Only file stems that parse entirely as a `u32` yield a PID.
    #[rstest]
    #[case("/tmp/1234.sock", Some(1234))]
    #[case("/tmp/main.sock", None)]
    #[case("/tmp/abc.sock", None)]
    #[case("/tmp/pid1234.sock", None)]
    fn extract_pid_from_socket_path_cases(#[case] path: &str, #[case] expected: Option<u32>) {
        assert_eq!(extract_pid_from_socket_path(&PathBuf::from(path)), expected);
    }

    // The test process itself is always alive.
    #[test]
    fn is_process_alive_self() {
        assert!(is_process_alive(std::process::id()));
    }

    // PID 1 exists on Unix. When not running as root, `kill` fails with
    // `EPERM`, which still counts as alive.
    #[test]
    #[cfg(unix)]
    fn is_process_alive_init() {
        assert!(is_process_alive(1));
    }

    // Assumes 999_999_999 is not a live PID (it is above Linux's maximum
    // `pid_max`).
    #[test]
    fn is_process_alive_dead_process() {
        assert!(!is_process_alive(999_999_999));
    }

    // A regular file named after a (presumably) dead PID stands in for a
    // stale socket; `new` must replace it with a bound socket.
    #[tokio::test]
    async fn server_cleans_up_stale_socket() {
        let temp = tempfile::tempdir().expect("create tempdir");
        let socket_path = temp.path().join("999999999.sock");
        let (tx, _rx) = mpsc::channel(16);
        std::fs::write(&socket_path, "").expect("create fake socket");
        assert!(socket_path.exists());
        let server = HooksServer::new(&socket_path, tx).expect("create server with stale cleanup");
        assert!(socket_path.exists());
        drop(server);
    }

    // Naming the file after this test process's own PID guarantees that the
    // liveness check reports the owner as alive.
    #[tokio::test]
    async fn server_rejects_socket_in_use() {
        let temp = tempfile::tempdir().expect("create tempdir");
        let socket_path = temp.path().join(format!("{}.sock", std::process::id()));
        let (tx, _rx) = mpsc::channel(16);
        std::fs::write(&socket_path, "").expect("create fake socket");
        assert!(socket_path.exists());
        let result = HooksServer::new(&socket_path, tx);
        assert!(matches!(result, Err(HooksError::SocketInUse(_))));
    }

    // A regular file whose name carries no numeric PID is reclaimed and
    // replaced by a bound socket.
    #[tokio::test]
    async fn server_allows_non_pid_socket_cleanup() {
        let temp = tempfile::tempdir().expect("create tempdir");
        let socket_path = temp.path().join("legacy.sock");
        let (tx, _rx) = mpsc::channel(16);
        std::fs::write(&socket_path, "").expect("create fake socket");
        assert!(socket_path.exists());
        let server = HooksServer::new(&socket_path, tx).expect("create server with legacy cleanup");
        assert!(socket_path.exists());
        drop(server);
    }
}