use std::os::unix::fs::PermissionsExt;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use async_trait::async_trait;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::net::{UnixListener, UnixStream};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use crate::protocol::{
EndMarker, Request, Response, ResponseOutcome, WireError, WireErrorKind, encode_line,
};
/// Largest accepted NDJSON request line, in bytes (1 MiB), measured after the
/// `\n` / `\r\n` terminator is stripped. Longer lines are answered with a
/// `BadArgs` error frame and the connection is closed.
pub const MAX_NDJSON_LINE_BYTES: usize = 1 << 20;
/// Source of events for a streaming response.
///
/// The connection loop writes every value yielded here as an `Event` frame
/// tagged with the originating request id, and writes a terminating `End`
/// frame once this returns `None`.
#[async_trait]
pub trait EventStream: Send {
    /// Yield the next event, or `None` once the stream is exhausted.
    async fn next_event(&mut self) -> Option<serde_json::Value>;
}

/// Result of dispatching a single request through a [`Handler`].
pub enum DispatchOutcome {
    /// One reply: `Ok` is sent as a `Result` frame, `Err` as an `Error` frame.
    OneShot(Result<serde_json::Value, WireError>),
    /// Zero or more `Event` frames followed by an `End` frame. After the `End`
    /// frame the connection loop stops serving that connection.
    Stream(Box<dyn EventStream + Send>),
}

/// Request dispatcher shared by every management connection.
#[async_trait]
pub trait Handler: Send + Sync + 'static {
    /// Process one decoded request and describe how to answer it.
    async fn dispatch(&self, req: Request) -> DispatchOutcome;
}
/// Scope guard that puts the previous process umask back when dropped.
///
/// `umask` is process-wide: while a guard is alive, files created by *any*
/// thread in the process get the tightened mask, so keep guards short-lived.
#[must_use = "the previous umask is restored as soon as the guard is dropped"]
struct UmaskRestore {
    /// Mask that was in effect before `tighten` replaced it.
    prev: libc::mode_t,
}

impl UmaskRestore {
    /// Install `mask` as the process umask and return a guard that restores
    /// the previous mask on drop.
    // libc exposes `umask` as an `unsafe extern fn`, hence the allow.
    #[allow(unsafe_code)]
    fn tighten(mask: libc::mode_t) -> Self {
        // SAFETY: umask(2) takes a plain integer, always succeeds, and reads
        // or writes no Rust-owned memory; it only swaps the creation mask.
        let prev = unsafe { libc::umask(mask) };
        Self { prev }
    }
}

impl Drop for UmaskRestore {
    // libc exposes `umask` as an `unsafe extern fn`, hence the allow.
    #[allow(unsafe_code)]
    fn drop(&mut self) {
        // SAFETY: as in `tighten`, umask(2) has no memory-safety preconditions.
        unsafe {
            libc::umask(self.prev);
        }
    }
}
/// Bind the management socket at `socket_path` and spawn its accept loop.
///
/// A stale socket left at `socket_path` by a previous run is removed first.
/// The socket is created under a tightened umask and then set to mode `0600`.
/// A warning is logged if the parent directory grants any access to "other".
///
/// Each accepted connection is served by `handle_conn` on its own task, with a
/// child token of `cancel`. When `cancel` fires, the accept loop removes the
/// socket file and the returned task completes.
///
/// # Errors
/// * `AlreadyExists` if something other than a socket occupies `socket_path`.
/// * Any error from binding the socket.
/// * Any error from setting its permissions. In that case the freshly created
///   socket file is removed again.
pub async fn spawn_unix_server<H: Handler>(
    socket_path: &Path,
    handler: Arc<H>,
    cancel: CancellationToken,
) -> std::io::Result<JoinHandle<()>> {
    // Only unlink a leftover *socket*: a misconfigured path must never delete
    // a regular file. `symlink_metadata` does not follow symlinks.
    if let Ok(meta) = std::fs::symlink_metadata(socket_path) {
        if !std::os::unix::fs::FileTypeExt::is_socket(&meta.file_type()) {
            return Err(std::io::Error::new(
                std::io::ErrorKind::AlreadyExists,
                format!(
                    "{} exists and is not a socket; refusing to remove it",
                    socket_path.display()
                ),
            ));
        }
        let _ = std::fs::remove_file(socket_path);
    }

    // umask is process-wide, so keep the tightened mask in effect only around
    // bind(). 0o117 strips group-execute and every "other" bit from the socket
    // node at creation time; the guard restores the old mask at block end
    // (including on the `?` early return).
    let listener = {
        let _umask_restore = UmaskRestore::tighten(0o117);
        UnixListener::bind(socket_path)?
    };
    if let Err(e) = std::fs::set_permissions(socket_path, std::fs::Permissions::from_mode(0o600))
    {
        // Don't leave a socket with the looser creation mode lying around.
        drop(listener);
        let _ = std::fs::remove_file(socket_path);
        return Err(e);
    }

    if let Some(parent) = socket_path.parent()
        && let Ok(meta) = std::fs::metadata(parent)
    {
        let mode = meta.permissions().mode() & 0o777;
        // 0700 is a subset of 0770, so "broader than 0700/0770" means some
        // permission bit is granted to "other".
        if mode & 0o007 != 0 {
            tracing::warn!(
                dir = %parent.display(),
                mode = format!("{:#o}", mode),
                "mgmt socket parent dir is broader than 0700/0770; restrict perms or move the socket",
            );
        }
    }

    let socket_path: PathBuf = socket_path.to_path_buf();
    let handle = tokio::spawn(async move {
        loop {
            tokio::select! {
                biased;
                () = cancel.cancelled() => {
                    let _ = std::fs::remove_file(&socket_path);
                    return;
                }
                accepted = listener.accept() => {
                    let stream: UnixStream = match accepted {
                        Ok((s, _)) => s,
                        Err(e) => {
                            // NOTE(review): a persistent accept error (e.g. fd
                            // exhaustion) makes this loop spin; consider backoff.
                            tracing::warn!(?e, "mgmt accept failed");
                            continue;
                        }
                    };
                    let h = Arc::clone(&handler);
                    let conn_cancel = cancel.child_token();
                    tokio::spawn(async move {
                        let (read, write) = stream.into_split();
                        handle_conn(read, write, h, conn_cancel).await;
                    });
                }
            }
        }
    });
    Ok(handle)
}
/// Read one `\n`-terminated line from `reader` into `buf` while buffering at
/// most about `cap` bytes of it.
///
/// `AsyncBufReadExt::read_line` keeps reading until it sees `\n`, so checking
/// the length afterwards does not bound memory: a peer could stream an
/// arbitrarily long line with no newline. This reader pulls one buffered chunk
/// at a time and stops storing bytes as soon as the line is known to be too
/// long.
///
/// The trailing `\n` and a single `\r` before it are stripped. A final line
/// that ends at EOF without a newline is returned unstripped.
///
/// Returns `Ok(None)` on EOF before any byte of a new line, and `Ok(Some(()))`
/// when `buf` holds a line.
///
/// # Errors
/// * `InvalidData` if the line exceeds `cap` bytes. The rest of the oversized
///   line is consumed and discarded (not stored) first, so the peer's writes
///   are not cut off mid-line.
/// * `InvalidData` if the line is not valid UTF-8.
/// * Any I/O error from the underlying reader.
async fn read_line_bounded<R>(
    reader: &mut BufReader<R>,
    buf: &mut String,
    cap: usize,
) -> std::io::Result<Option<()>>
where
    R: AsyncRead + Unpin,
{
    // Reuse the caller's allocation for the raw bytes of the line.
    let mut bytes = std::mem::take(buf).into_bytes();
    bytes.clear();
    // Until a newline shows up, at most one trailing `\r` can still be
    // stripped, so a partial line longer than `cap + 1` is certainly too long.
    let partial_limit = cap.saturating_add(1);
    let mut read_any = false;
    let mut saw_newline = false;
    let mut overflowed = false;
    loop {
        let available = reader.fill_buf().await?;
        if available.is_empty() {
            break; // EOF
        }
        read_any = true;
        let newline_at = available.iter().position(|&b| b == b'\n');
        let used = newline_at.map_or(available.len(), |i| i + 1);
        if !overflowed {
            bytes.extend_from_slice(&available[..used]);
            if newline_at.is_none() && bytes.len() > partial_limit {
                // Stop storing, but keep consuming up to the line boundary.
                overflowed = true;
                bytes.clear();
            }
        }
        reader.consume(used);
        if newline_at.is_some() {
            saw_newline = true;
            break;
        }
    }

    if !read_any {
        return Ok(None);
    }
    if saw_newline {
        bytes.pop();
        if bytes.last() == Some(&b'\r') {
            bytes.pop();
        }
    }
    if overflowed || bytes.len() > cap {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("ndjson line exceeded {cap}-byte cap"),
        ));
    }
    *buf = String::from_utf8(bytes)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
    Ok(Some(()))
}
pub(crate) async fn handle_conn<R, W, H>(
read: R,
mut write: W,
handler: Arc<H>,
cancel: CancellationToken,
) where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
H: Handler,
{
let mut reader = BufReader::new(read);
let mut line = String::new();
loop {
let read_outcome = tokio::select! {
biased;
() = cancel.cancelled() => return,
res = read_line_bounded(&mut reader, &mut line, MAX_NDJSON_LINE_BYTES) => res,
};
match read_outcome {
Ok(None) => return,
Ok(Some(())) => {}
Err(e) if e.kind() == std::io::ErrorKind::InvalidData => {
let frame = Response {
id: 0,
outcome: ResponseOutcome::Error {
error: WireError::new(WireErrorKind::BadArgs, format!("line too long: {e}")),
},
};
let _ = write_frame(&mut write, &frame).await;
return;
}
Err(e) => {
tracing::debug!(?e, "mgmt read failed");
return;
}
}
if line.is_empty() {
continue;
}
match serde_json::from_str::<Request>(&line) {
Ok(req) => {
let id = req.id;
match handler.dispatch(req).await {
DispatchOutcome::OneShot(Ok(value)) => {
let frame = Response { id, outcome: ResponseOutcome::Result { result: value } };
if write_frame(&mut write, &frame).await.is_err() {
return;
}
}
DispatchOutcome::OneShot(Err(error)) => {
let frame = Response { id, outcome: ResponseOutcome::Error { error } };
if write_frame(&mut write, &frame).await.is_err() {
return;
}
}
DispatchOutcome::Stream(mut stream) => {
loop {
tokio::select! {
biased;
() = cancel.cancelled() => {
let end = Response {
id,
outcome: ResponseOutcome::End { end: EndMarker::default() },
};
let _ = write_frame(&mut write, &end).await;
return;
}
maybe = stream.next_event() => {
let Some(event) = maybe else {
let end = Response {
id,
outcome: ResponseOutcome::End { end: EndMarker::default() },
};
let _ = write_frame(&mut write, &end).await;
return;
};
let frame = Response { id, outcome: ResponseOutcome::Event { event } };
if write_frame(&mut write, &frame).await.is_err() {
return;
}
}
}
}
}
}
}
Err(e) => {
let frame = Response {
id: 0,
outcome: ResponseOutcome::Error {
error: WireError::new(WireErrorKind::BadArgs, format!("parse: {e}")),
},
};
if write_frame(&mut write, &frame).await.is_err() {
return;
}
}
}
}
}
/// Encode `frame` with `encode_line` and write the bytes to `write`.
///
/// The writer is flushed after every frame, so replies and streamed events
/// reach the peer promptly even when `W` buffers internally. For an unbuffered
/// socket half the flush is a no-op.
///
/// # Errors
/// * An encode failure, logged and wrapped via `std::io::Error::other`.
/// * Any write or flush error from `write`.
async fn write_frame<W: AsyncWrite + Unpin>(
    write: &mut W,
    frame: &Response,
) -> Result<(), std::io::Error> {
    let bytes = encode_line(frame).map_err(|e| {
        tracing::error!(?e, "mgmt response encode failed");
        std::io::Error::other(e)
    })?;
    write.write_all(&bytes).await?;
    write.flush().await
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Handler that records the last verb and answers a few fixed verbs.
    struct StubHandler {
        last_verb: Mutex<Option<String>>,
    }

    #[async_trait]
    impl Handler for StubHandler {
        async fn dispatch(&self, req: Request) -> DispatchOutcome {
            *self.last_verb.lock().unwrap() = Some(req.verb.clone());
            let result: Result<serde_json::Value, WireError> = match req.verb.as_str() {
                "ping" => Ok(serde_json::json!({ "pong": true })),
                "echo" => Ok(req.args),
                "stream2" => {
                    return DispatchOutcome::Stream(Box::new(MockStream::with_two_events()));
                }
                _ => Err(WireError::new(WireErrorKind::UnknownVerb, format!("unknown {}", req.verb))),
            };
            DispatchOutcome::OneShot(result)
        }
    }

    /// Event stream that pops its events from the back of a vector.
    struct MockStream {
        remaining: Vec<serde_json::Value>,
    }

    impl MockStream {
        fn with_two_events() -> Self {
            Self { remaining: vec![serde_json::json!({ "n": 1 }), serde_json::json!({ "n": 2 })] }
        }
    }

    #[async_trait]
    impl EventStream for MockStream {
        async fn next_event(&mut self) -> Option<serde_json::Value> {
            self.remaining.pop()
        }
    }

    /// Feed `requests` to a fresh connection, close the client side, and
    /// collect everything the server wrote.
    async fn drive(handler: Arc<StubHandler>, requests: &str) -> Vec<u8> {
        let (c2s_r, mut c2s_w) = tokio::io::duplex(8192);
        let (s2c_w, mut s2c_r) = tokio::io::duplex(8192);
        let req = requests.to_string();
        let server_task = tokio::spawn(handle_conn(c2s_r, s2c_w, handler, CancellationToken::new()));
        c2s_w.write_all(req.as_bytes()).await.expect("write requests");
        drop(c2s_w);
        server_task.await.expect("server task");
        let mut buf = Vec::new();
        tokio::io::AsyncReadExt::read_to_end(&mut s2c_r, &mut buf).await.expect("read responses");
        buf
    }

    fn parse_responses(bytes: &[u8]) -> Vec<Response> {
        std::str::from_utf8(bytes)
            .expect("utf8")
            .lines()
            .filter(|l| !l.is_empty())
            .map(|l| serde_json::from_str(l).expect("parse response"))
            .collect()
    }

    #[tokio::test]
    async fn server_stub_dispatches_known_verb_and_writes_result_line() {
        let handler = Arc::new(StubHandler { last_verb: Mutex::new(None) });
        let req = Request { id: 11, verb: "ping".to_string(), args: serde_json::Value::Null };
        let raw = serde_json::to_string(&req).unwrap() + "\n";
        let bytes = drive(Arc::clone(&handler), &raw).await;
        let responses = parse_responses(&bytes);
        assert_eq!(responses.len(), 1);
        assert_eq!(responses[0].id, 11);
        match &responses[0].outcome {
            ResponseOutcome::Result { result } => assert_eq!(result["pong"], true),
            other => panic!("unexpected outcome: {other:?}"),
        }
        assert_eq!(handler.last_verb.lock().unwrap().as_deref(), Some("ping"));
    }

    #[tokio::test]
    async fn server_stub_writes_error_for_unknown_verb() {
        let handler = Arc::new(StubHandler { last_verb: Mutex::new(None) });
        let req = Request { id: 5, verb: "wat".to_string(), args: serde_json::Value::Null };
        let raw = serde_json::to_string(&req).unwrap() + "\n";
        let bytes = drive(handler, &raw).await;
        let responses = parse_responses(&bytes);
        assert_eq!(responses.len(), 1);
        assert_eq!(responses[0].id, 5);
        match &responses[0].outcome {
            ResponseOutcome::Error { error } => {
                assert_eq!(error.kind, WireErrorKind::UnknownVerb);
                assert!(error.message.contains("wat"));
            }
            other => panic!("expected error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn server_stub_writes_bad_args_error_for_unparseable_request() {
        let handler = Arc::new(StubHandler { last_verb: Mutex::new(None) });
        let raw = "this is not json\n";
        let bytes = drive(handler, raw).await;
        let responses = parse_responses(&bytes);
        assert_eq!(responses.len(), 1);
        assert_eq!(responses[0].id, 0);
        match &responses[0].outcome {
            ResponseOutcome::Error { error } => assert_eq!(error.kind, WireErrorKind::BadArgs),
            other => panic!("expected error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn server_dispatches_streaming_verb_writes_event_then_end() {
        let handler = Arc::new(StubHandler { last_verb: Mutex::new(None) });
        let req = Request { id: 99, verb: "stream2".to_string(), args: serde_json::Value::Null };
        let raw = serde_json::to_string(&req).unwrap() + "\n";
        let bytes = drive(handler, &raw).await;
        let responses = parse_responses(&bytes);
        assert_eq!(responses.len(), 3, "two events plus a terminating End frame");
        for r in &responses {
            assert_eq!(r.id, 99, "streaming frames echo the request id");
        }
        assert!(matches!(responses[0].outcome, ResponseOutcome::Event { .. }));
        assert!(matches!(responses[1].outcome, ResponseOutcome::Event { .. }));
        assert!(matches!(responses[2].outcome, ResponseOutcome::End { .. }));
        if let ResponseOutcome::Event { event } = &responses[0].outcome {
            assert_eq!(event["n"], 2);
        }
        if let ResponseOutcome::Event { event } = &responses[1].outcome {
            assert_eq!(event["n"], 1);
        }
    }

    #[tokio::test]
    async fn server_rejects_line_exceeding_cap_with_bad_args() {
        let handler = Arc::new(StubHandler { last_verb: Mutex::new(None) });
        let huge_line = format!(
            "{{\"id\":1,\"verb\":\"x\",\"args\":\"{}\"}}\n",
            "A".repeat(MAX_NDJSON_LINE_BYTES + 1)
        );
        let bytes = drive(handler.clone(), &huge_line).await;
        let responses = parse_responses(&bytes);
        assert_eq!(responses.len(), 1);
        match &responses[0].outcome {
            ResponseOutcome::Error { error } => {
                assert_eq!(error.kind, WireErrorKind::BadArgs);
                assert!(error.message.contains("line too long"), "{}", error.message);
            }
            other => panic!("expected BadArgs error, got {other:?}"),
        }
        assert!(handler.last_verb.lock().unwrap().is_none());
    }

    #[tokio::test]
    async fn server_stub_handles_multiple_requests_serial_per_connection() {
        let handler = Arc::new(StubHandler { last_verb: Mutex::new(None) });
        let r1 =
            serde_json::to_string(&Request { id: 1, verb: "ping".into(), args: serde_json::Value::Null })
                .unwrap();
        let r2 = serde_json::to_string(&Request {
            id: 2,
            verb: "echo".into(),
            args: serde_json::json!({"x": 1}),
        })
        .unwrap();
        let r3 =
            serde_json::to_string(&Request { id: 3, verb: "nope".into(), args: serde_json::Value::Null })
                .unwrap();
        let raw = format!("{r1}\n{r2}\n\n{r3}\n");
        let bytes = drive(handler, &raw).await;
        let responses = parse_responses(&bytes);
        assert_eq!(responses.len(), 3, "blank line is skipped, not echoed back");
        assert_eq!(responses[0].id, 1);
        assert_eq!(responses[1].id, 2);
        assert_eq!(responses[2].id, 3);
        assert!(matches!(responses[0].outcome, ResponseOutcome::Result { .. }));
        assert!(matches!(responses[1].outcome, ResponseOutcome::Result { .. }));
        assert!(matches!(responses[2].outcome, ResponseOutcome::Error { .. }));
    }

    /// A `\r\n` terminator is stripped, and a final request with no newline
    /// before EOF is still dispatched.
    #[tokio::test]
    async fn server_accepts_crlf_and_unterminated_final_line() {
        let handler = Arc::new(StubHandler { last_verb: Mutex::new(None) });
        let r1 =
            serde_json::to_string(&Request { id: 1, verb: "ping".into(), args: serde_json::Value::Null })
                .unwrap();
        let r2 = serde_json::to_string(&Request {
            id: 2,
            verb: "echo".into(),
            args: serde_json::json!({"x": 1}),
        })
        .unwrap();
        let raw = format!("{r1}\r\n{r2}");
        let bytes = drive(Arc::clone(&handler), &raw).await;
        let responses = parse_responses(&bytes);
        assert_eq!(responses.len(), 2);
        assert_eq!(responses[0].id, 1);
        assert_eq!(responses[1].id, 2);
        assert!(matches!(responses[0].outcome, ResponseOutcome::Result { .. }));
        match &responses[1].outcome {
            ResponseOutcome::Result { result } => {
                assert_eq!(result, &serde_json::json!({"x": 1}));
            }
            other => panic!("unexpected outcome: {other:?}"),
        }
        assert_eq!(handler.last_verb.lock().unwrap().as_deref(), Some("echo"));
    }

    /// A parse error is answered, and the connection keeps serving requests.
    #[tokio::test]
    async fn server_recovers_after_unparseable_line() {
        let handler = Arc::new(StubHandler { last_verb: Mutex::new(None) });
        let ping =
            serde_json::to_string(&Request { id: 7, verb: "ping".into(), args: serde_json::Value::Null })
                .unwrap();
        let raw = format!("not json\n{ping}\n");
        let bytes = drive(handler, &raw).await;
        let responses = parse_responses(&bytes);
        assert_eq!(responses.len(), 2);
        assert_eq!(responses[0].id, 0);
        assert!(matches!(responses[0].outcome, ResponseOutcome::Error { .. }));
        assert_eq!(responses[1].id, 7);
        assert!(matches!(responses[1].outcome, ResponseOutcome::Result { .. }));
    }

    /// An oversized line that never gets a newline is still rejected at EOF.
    #[tokio::test]
    async fn server_rejects_oversized_unterminated_line() {
        let handler = Arc::new(StubHandler { last_verb: Mutex::new(None) });
        let huge = "A".repeat(MAX_NDJSON_LINE_BYTES + 1);
        let bytes = drive(handler.clone(), &huge).await;
        let responses = parse_responses(&bytes);
        assert_eq!(responses.len(), 1);
        match &responses[0].outcome {
            ResponseOutcome::Error { error } => {
                assert_eq!(error.kind, WireErrorKind::BadArgs);
                assert!(error.message.contains("line too long"), "{}", error.message);
            }
            other => panic!("expected BadArgs error, got {other:?}"),
        }
        assert!(handler.last_verb.lock().unwrap().is_none());
    }
}