use std::sync::Arc;
use tokio::net::{UnixListener, UnixStream};
use tokio::sync::mpsc;
use super::ipc::{self, ErrPayload, Frame, Hello, Welcome, PROTOCOL_VERSION};
use super::ops;
use super::paths::CachePaths;
use super::state::DaemonState;
use super::{DaemonError, Result};
pub async fn serve(paths: CachePaths) -> Result<()> {
if paths.socket.exists() {
let _ = std::fs::remove_file(&paths.socket);
}
let listener = UnixListener::bind(&paths.socket)?;
super::paths::ensure_file_600(&paths.socket)?;
tracing::info!(socket = %paths.socket.display(), "listening");
let state = DaemonState::new(paths.clone())?;
if let Err(e) = state.fs_watcher.start(Arc::clone(&state)) {
tracing::warn!(?e, "fsnotify watcher failed to start; running degraded");
}
let shutdown = tokio::sync::Notify::new();
let shutdown = Arc::new(shutdown);
let shutdown_signals = Arc::clone(&shutdown);
tokio::spawn(async move {
let mut sigterm =
match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
Ok(s) => s,
Err(e) => {
tracing::warn!(?e, "failed to install SIGTERM handler");
return;
}
};
let mut sigint =
match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt()) {
Ok(s) => s,
Err(e) => {
tracing::warn!(?e, "failed to install SIGINT handler");
return;
}
};
tokio::select! {
_ = sigterm.recv() => tracing::info!("received SIGTERM"),
_ = sigint.recv() => tracing::info!("received SIGINT"),
}
shutdown_signals.notify_waiters();
});
let accept_state = Arc::clone(&state);
let accept_shutdown = Arc::clone(&shutdown);
let accept_loop = async move {
loop {
let (stream, _addr) = match listener.accept().await {
Ok(p) => p,
Err(e) => {
tracing::warn!(?e, "accept failed");
continue;
}
};
let state = Arc::clone(&accept_state);
let shutdown = Arc::clone(&accept_shutdown);
tokio::spawn(async move {
if let Err(e) = handle_connection(stream, state, shutdown).await {
match e {
DaemonError::Io(io) if io.kind() == std::io::ErrorKind::UnexpectedEof => {
}
other => {
tracing::warn!(?other, "connection ended with error");
}
}
}
});
}
};
tokio::select! {
_ = accept_loop => {},
_ = shutdown.notified() => {
tracing::info!("shutdown notified, draining");
}
}
let _ = std::fs::remove_file(&paths.socket);
Ok(())
}
async fn handle_connection(
stream: UnixStream,
state: Arc<DaemonState>,
_shutdown: Arc<tokio::sync::Notify>,
) -> Result<()> {
let (read_half, write_half) = stream.into_split();
let mut reader = tokio::io::BufReader::new(read_half);
let mut writer = write_half;
let first = ipc::read_frame(&mut reader).await?;
let hello = match first {
Frame::Hello { hello } => hello,
_ => {
let err = Frame::Response {
id: 0,
ok: false,
payload: serde_json::json!({
"err": ErrPayload::new("bad_handshake", "expected Hello as first frame")
}),
};
let _ = ipc::write_frame(&mut writer, &err).await;
return Err(DaemonError::BadHandshake);
}
};
if hello.version != PROTOCOL_VERSION {
let err = serde_json::json!({
"welcome": null,
"err": ErrPayload::new(
"version_mismatch",
format!("client v{}, daemon v{}", hello.version, PROTOCOL_VERSION),
),
});
let frame: Frame = serde_json::from_value(err).map_err(DaemonError::Json)?;
ipc::write_frame(&mut writer, &frame).await?;
return Err(DaemonError::ProtocolMismatch {
client: hello.version,
daemon: PROTOCOL_VERSION,
});
}
let (out_tx, mut out_rx) = mpsc::unbounded_channel::<Frame>();
let (client_id, session_id) = state.register_session(
hello.client_pid,
hello.tty.clone(),
hello.cwd.clone(),
hello.argv0.clone(),
out_tx.clone(),
);
let welcome = Welcome {
version: PROTOCOL_VERSION,
client_id,
session_id: session_id.clone(),
daemon_pid: state.pid,
daemon_uptime_ms: state.uptime_ms(),
};
if out_tx.send(Frame::welcome(welcome)).is_err() {
state.unregister_session(client_id);
return Ok(());
}
tracing::info!(
client_id, pid = hello.client_pid, tty = ?hello.tty, cwd = ?hello.cwd,
"client registered"
);
drop(out_tx);
let pump = async move {
while let Some(frame) = out_rx.recv().await {
if let Err(e) = ipc::write_frame(&mut writer, &frame).await {
tracing::debug!(?e, "outbound write failed; closing");
break;
}
}
};
let req_state = Arc::clone(&state);
let request_loop = async move {
loop {
match ipc::read_frame(&mut reader).await {
Ok(Frame::Request { id, op, args }) => {
let response = ops::dispatch(&req_state, client_id, &op, args).await;
let frame = match response {
Ok(payload) => Frame::ok_response(id, payload),
Err(err) => Frame::err_response(id, err),
};
if !req_state.send_to(client_id, frame) {
break;
}
}
Ok(other) => {
tracing::debug!(?other, "ignoring unexpected post-handshake frame kind");
}
Err(DaemonError::Io(e))
if matches!(
e.kind(),
std::io::ErrorKind::UnexpectedEof
| std::io::ErrorKind::BrokenPipe
| std::io::ErrorKind::ConnectionReset
) =>
{
break;
}
Err(e) => {
tracing::warn!(?e, "frame read error; closing");
break;
}
}
}
req_state.unregister_session(client_id);
};
tokio::join!(pump, request_loop);
tracing::info!(client_id, "client unregistered");
Ok(())
}