use std::io::{Read, Write};
use std::num::NonZeroUsize;
use interprocess::local_socket::traits::Listener;
use prost::Message;
use crate::broker::protocol::{
read_frame, write_frame, AdminReply, ErrorCode, Frame, FramingError, HelloReply,
MAX_HELLO_BYTES,
};
use super::admin::{handle_admin_frame, AdminFrameError, AdminSnapshot, ADMIN_PAYLOAD_PROTOCOL};
use super::connection::{
bind_local_socket, peer_identity_from_stream, refused_reply, reply_for_framing_error,
write_response_frame, BrokerConnectionError, HelloResponder, LocalSocketCleanup,
PeerCredentialPolicy,
};
use super::fd_pressure::{FdPressureDecision, FdPressureGuard};
use super::hello_handler::PeerIdentity;
/// Outcome of handling a single control-socket connection.
#[derive(Clone, Debug, PartialEq)]
pub enum ControlSocketReply {
    /// The peer was rejected by the credential policy; nothing was read from or written
    /// to the stream.
    DroppedPeer,
    /// A Hello-path reply that was written to the peer. This is also the variant used for
    /// refusals sent when the request frame could not be read or decoded.
    Hello(HelloReply),
    /// The admin reply (decoded back from the response payload) that was written to the
    /// peer.
    Admin(AdminReply),
}
/// How many iterations the control-socket accept loop may run before it stops.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ControlSocketConnectionLimit {
    /// Stop once this many iterations have been counted.
    Bounded(NonZeroUsize),
    /// Never stop because of the count; the loop only ends on an error.
    Unbounded,
}

impl ControlSocketConnectionLimit {
    /// Returns `true` while `accepted` (iterations counted so far) is still below the
    /// bound. [`Self::Unbounded`] always returns `true`.
    fn should_continue(self, accepted: usize) -> bool {
        match self {
            Self::Unbounded => true,
            Self::Bounded(limit) => limit.get() > accepted,
        }
    }
}
pub fn handle_control_connection_with_peer_policy<S, R, F>(
stream: &mut S,
hello_responder: &R,
snapshot_provider: &F,
peer: PeerIdentity,
peer_policy: &PeerCredentialPolicy,
) -> Result<ControlSocketReply, ControlSocketError>
where
S: Read + Write,
R: HelloResponder + ?Sized,
F: Fn() -> AdminSnapshot + ?Sized,
{
handle_control_connection_with_peer_policy_and_fd_guard(
stream,
hello_responder,
snapshot_provider,
peer,
peer_policy,
None,
)
}
/// Serves one control-socket request on an already-accepted `stream`, consulting an
/// optional fd-pressure guard before dispatching Hello requests.
///
/// Order of operations:
/// 1. Peers rejected by `peer_policy` are dropped without reading or writing anything.
/// 2. One frame is read. A framing failure is answered with the refusal produced by
///    `reply_for_framing_error`.
/// 3. The bytes are decoded as a [`Frame`]. Undecodable bytes get an
///    `ErrorPeerRejected` refusal.
/// 4. Frames tagged with [`ADMIN_PAYLOAD_PROTOCOL`] are answered from a freshly taken
///    admin snapshot. The Hello size limit and fd-pressure check are not applied to them.
/// 5. Everything else is treated as a Hello:
///    - frames larger than [`MAX_HELLO_BYTES`] are refused;
///    - if `fd_guard` is present and demoted, its refusal is sent;
///    - otherwise `hello_responder` builds the reply.
///
/// # Errors
/// Returns an error when writing a response fails, when `handle_admin_frame` fails, or
/// when the admin response cannot be encoded or decoded.
pub fn handle_control_connection_with_peer_policy_and_fd_guard<S, R, F>(
    stream: &mut S,
    hello_responder: &R,
    snapshot_provider: &F,
    peer: PeerIdentity,
    peer_policy: &PeerCredentialPolicy,
    fd_guard: Option<&FdPressureGuard>,
) -> Result<ControlSocketReply, ControlSocketError>
where
    S: Read + Write,
    R: HelloResponder + ?Sized,
    F: Fn() -> AdminSnapshot + ?Sized,
{
    // Credential check first: refused peers never see a single byte from us.
    if !peer_policy.allows(&peer) {
        return Ok(ControlSocketReply::DroppedPeer);
    }

    let raw_request = match read_frame(stream) {
        Ok(raw) => raw,
        Err(framing_err) => {
            let refusal = reply_for_framing_error(&framing_err);
            write_response_frame(stream, None, &refusal)?;
            return Ok(ControlSocketReply::Hello(refusal));
        }
    };

    let Ok(frame) = Frame::decode(raw_request.as_slice()) else {
        let refusal = refused_reply(ErrorCode::ErrorPeerRejected, "malformed broker Frame", 0);
        write_response_frame(stream, None, &refusal)?;
        return Ok(ControlSocketReply::Hello(refusal));
    };

    if frame.payload_protocol == ADMIN_PAYLOAD_PROTOCOL {
        let snapshot = snapshot_provider();
        let admin_response = handle_admin_frame(frame, &snapshot)?;
        return write_admin_response_frame(stream, &admin_response)
            .map(ControlSocketReply::Admin);
    }

    // The size check takes precedence; the guard is only consulted for in-limit frames.
    let hello_reply = if raw_request.len() > MAX_HELLO_BYTES {
        refused_reply(
            ErrorCode::ErrorPeerRejected,
            "initial Hello frame exceeds 64 KiB",
            0,
        )
    } else {
        match fd_guard {
            Some(guard) if guard.is_demoted() => guard.refusal_reply(),
            _ => hello_responder.handle_frame(frame.clone(), peer),
        }
    };

    write_response_frame(stream, Some(&frame), &hello_reply)?;
    Ok(ControlSocketReply::Hello(hello_reply))
}
/// Binds `socket_path` and serves at most `connection_count` control connections.
///
/// A `connection_count` of zero returns `Ok(())` immediately without binding the socket.
/// Otherwise this delegates to [`serve_control_socket_connections_with_limit_and_policy`]
/// with a [`ControlSocketConnectionLimit::Bounded`] limit.
pub fn serve_control_socket_connections_with_policy<R, F>(
    socket_path: &str,
    hello_responder: &R,
    snapshot_provider: F,
    connection_count: usize,
    peer_policy: &PeerCredentialPolicy,
) -> Result<(), ControlSocketError>
where
    R: HelloResponder + ?Sized,
    F: Fn() -> AdminSnapshot,
{
    match NonZeroUsize::new(connection_count) {
        Some(count) => serve_control_socket_connections_with_limit_and_policy(
            socket_path,
            hello_responder,
            snapshot_provider,
            ControlSocketConnectionLimit::Bounded(count),
            peer_policy,
        ),
        None => Ok(()),
    }
}
pub fn serve_control_socket_connections_with_limit_and_policy<R, F>(
socket_path: &str,
hello_responder: &R,
snapshot_provider: F,
connection_limit: ControlSocketConnectionLimit,
peer_policy: &PeerCredentialPolicy,
) -> Result<(), ControlSocketError>
where
R: HelloResponder + ?Sized,
F: Fn() -> AdminSnapshot,
{
serve_control_socket_connections_with_limit_policy_and_post_hello(
socket_path,
hello_responder,
snapshot_provider,
connection_limit,
peer_policy,
|_stream, _reply| {},
)
}
/// Binds `socket_path` and serves control connections using a freshly created default
/// [`FdPressureGuard`].
///
/// `post_hello` is invoked with the stream after every Hello reply has been written.
pub fn serve_control_socket_connections_with_limit_policy_and_post_hello<R, F, H>(
    socket_path: &str,
    hello_responder: &R,
    snapshot_provider: F,
    connection_limit: ControlSocketConnectionLimit,
    peer_policy: &PeerCredentialPolicy,
    post_hello: H,
) -> Result<(), ControlSocketError>
where
    R: HelloResponder + ?Sized,
    F: Fn() -> AdminSnapshot,
    H: FnMut(&mut interprocess::local_socket::Stream, &HelloReply),
{
    let fresh_guard = FdPressureGuard::default();
    serve_control_socket_connections_with_limit_policy_post_hello_and_fd_guard(
        socket_path,
        hello_responder,
        snapshot_provider,
        connection_limit,
        peer_policy,
        post_hello,
        &fresh_guard,
    )
}
/// Binds `socket_path` and runs the control-socket accept loop until `connection_limit`
/// says to stop or an error occurs.
///
/// For each accepted stream:
/// - the peer identity is read;
/// - the request is handled by [`handle_control_connection_with_peer_policy_and_fd_guard`]
///   with `fd_guard`, so Hello requests can be refused while demoted;
/// - peers dropped by the credential policy are logged to stderr;
/// - `post_hello` is called with the stream after every Hello reply.
///
/// Accept failures are reported to `fd_guard`. When it answers
/// [`FdPressureDecision::Demoted`], the loop:
/// - logs, but only on the transition into the demoted state;
/// - counts the failed attempt against `connection_limit`;
/// - sleeps 50 ms and retries.
///
/// Any other accept failure is returned as an error.
///
/// The listener is closed before the [`LocalSocketCleanup`] guard is dropped, on both the
/// success and the error path.
///
/// # Errors
/// Returns the first bind, accept, peer-identity or per-connection error. Per-connection
/// errors are propagated with `?`, so one failing connection ends the whole loop.
/// NOTE(review): for a long-running `Unbounded` server, a single peer whose I/O fails
/// (e.g. one that disconnects before its reply is written) may therefore stop serving.
/// Confirm this is intended.
#[allow(clippy::too_many_arguments)]
pub fn serve_control_socket_connections_with_limit_policy_post_hello_and_fd_guard<R, F, H>(
    socket_path: &str,
    hello_responder: &R,
    snapshot_provider: F,
    connection_limit: ControlSocketConnectionLimit,
    peer_policy: &PeerCredentialPolicy,
    mut post_hello: H,
    fd_guard: &FdPressureGuard,
) -> Result<(), ControlSocketError>
where
    R: HelloResponder + ?Sized,
    F: Fn() -> AdminSnapshot,
    H: FnMut(&mut interprocess::local_socket::Stream, &HelloReply),
{
    // Pause between accept retries while demoted, so a saturated fd table is not
    // retried in a tight loop.
    const FD_PRESSURE_ACCEPT_BACKOFF: std::time::Duration = std::time::Duration::from_millis(50);

    let listener = bind_local_socket(socket_path)?;
    let cleanup = LocalSocketCleanup(socket_path);

    // Run the loop in an immediately-invoked closure. Every exit path, including `?`,
    // then funnels through the explicit drop sequence below.
    let outcome = (|| -> Result<(), ControlSocketError> {
        let mut iterations: usize = 0;
        while connection_limit.should_continue(iterations) {
            let accept_result = listener.accept();
            let mut stream = match accept_result {
                Ok(stream) => {
                    fd_guard.on_accept_ok();
                    stream
                }
                Err(err) => {
                    // Sample the state before reporting the error, so we only log on the
                    // transition into demotion.
                    let previously_demoted = fd_guard.is_demoted();
                    let demoted_now =
                        fd_guard.on_accept_error(&err) == FdPressureDecision::Demoted;
                    if !demoted_now {
                        return Err(BrokerConnectionError::Io(err).into());
                    }
                    if !previously_demoted {
                        eprintln!(
                            "running-process-broker: accept on {socket_path} demoted \
                             under fd pressure: {err}"
                        );
                    }
                    // Failed accepts under fd pressure still count toward a bounded limit.
                    iterations += 1;
                    std::thread::sleep(FD_PRESSURE_ACCEPT_BACKOFF);
                    continue;
                }
            };
            iterations += 1;

            let peer = peer_identity_from_stream(&stream)?;
            let reply = handle_control_connection_with_peer_policy_and_fd_guard(
                &mut stream,
                hello_responder,
                &snapshot_provider,
                peer.clone(),
                peer_policy,
                Some(fd_guard),
            )?;

            match &reply {
                ControlSocketReply::DroppedPeer => {
                    eprintln!(
                        "running-process-broker: dropped connection on {socket_path} from peer \
                         pid={} uid_or_sid={:?}: credential policy refused",
                        peer.pid, peer.uid_or_sid
                    );
                }
                ControlSocketReply::Hello(hello_reply) => post_hello(&mut stream, hello_reply),
                ControlSocketReply::Admin(_) => {}
            }
        }
        Ok(())
    })();

    // Close the listener before the cleanup guard runs.
    drop(listener);
    drop(cleanup);
    outcome
}
/// Encodes `response_frame`, writes it to `writer` as one broker frame, and returns the
/// [`AdminReply`] carried in the frame's payload.
///
/// The reply is decoded only after the frame has been written. The peer therefore
/// receives the response even if decoding it back fails here.
///
/// # Errors
/// - [`ControlSocketError::EncodeFrame`] if protobuf encoding of the frame fails.
/// - Any error from [`write_frame`], converted into [`ControlSocketError`].
/// - [`ControlSocketError::DecodeAdminReply`] if the payload is not a valid `AdminReply`.
fn write_admin_response_frame<W: Write>(
    writer: &mut W,
    response_frame: &Frame,
) -> Result<AdminReply, ControlSocketError> {
    // Size the buffer exactly up front so encoding never has to reallocate while
    // serialising a (potentially large) admin snapshot.
    let mut response_bytes = Vec::with_capacity(response_frame.encoded_len());
    response_frame
        .encode(&mut response_bytes)
        .map_err(ControlSocketError::EncodeFrame)?;
    write_frame(writer, &response_bytes)?;
    AdminReply::decode(response_frame.payload.as_slice())
        .map_err(ControlSocketError::DecodeAdminReply)
}
/// Errors produced while serving the broker control socket.
#[derive(Debug, thiserror::Error)]
pub enum ControlSocketError {
    /// Connection-level failure, forwarded unchanged (e.g. accept failures wrapped in
    /// `BrokerConnectionError::Io` by the accept loop).
    #[error(transparent)]
    Connection(#[from] BrokerConnectionError),
    /// Failure reported by the broker framing layer, forwarded unchanged.
    #[error(transparent)]
    Framing(#[from] FramingError),
    /// Failure from `handle_admin_frame` while servicing an admin request.
    #[error(transparent)]
    AdminFrame(#[from] AdminFrameError),
    /// Protobuf encoding of the admin response `Frame` failed.
    #[error("failed to encode broker control response Frame: {0}")]
    EncodeFrame(prost::EncodeError),
    /// The admin response payload could not be decoded back into an `AdminReply`.
    #[error("failed to decode admin reply payload: {0}")]
    DecodeAdminReply(prost::DecodeError),
}