use std::io;
use std::sync::mpsc;
use std::thread;
use std::time::Duration;
use interprocess::local_socket::prelude::*;
use prost::Message;
use crate::broker::capabilities::{handoff_transport_available, CAP_HANDLE_PASSING};
use crate::broker::protocol::{
hello_reply::Result as HelloReplyResult, read_frame, validate_frame_envelope, write_frame,
AdminReply, AdminRequest, ErrorCode, Frame, FrameKind, FrameValidationError, FramingError,
HandoffAck, Hello, HelloReply, Negotiated, PayloadEncoding, ADMIN_PAYLOAD_PROTOCOL,
CONTROL_PAYLOAD_PROTOCOL, PROTOCOL_VERSION,
};
use crate::broker::server::handoff::validate_handoff_frame;
use crate::broker::server::local_socket_name;
/// Default for `ConnectBackendRequest::handoff_ready_timeout`: how long to wait for the
/// broker's handoff-ready event before falling back to a direct backend connect.
pub const DEFAULT_HANDOFF_READY_TIMEOUT: Duration = Duration::from_millis(2_000);
/// Environment variable inspected by `broker_disabled_by_env`.
pub const RUNNING_PROCESS_DISABLE_ENV: &str = "RUNNING_PROCESS_DISABLE";
/// The only accepted value of [`RUNNING_PROCESS_DISABLE_ENV`]; it means "disabled".
pub const RUNNING_PROCESS_DISABLE_VALUE: &str = "1";
/// Environment variable naming a fake backend endpoint (only honoured with the
/// `test-seams` feature).
pub const RUNNING_PROCESS_FAKE_BACKEND_ENV: &str = "RUNNING_PROCESS_FAKE_BACKEND";
/// Reports whether the broker has been switched off through `RUNNING_PROCESS_DISABLE`.
///
/// * unset                → `Ok(false)`
/// * exactly `"1"`        → `Ok(true)`
/// * anything else (even an empty string) → `Err` carrying the offending value
///
/// Non-UTF-8 contents are converted lossily before the comparison.
///
/// # Errors
///
/// Returns [`BrokerDisableEnvError`] when the variable is set to an unsupported value.
pub fn broker_disabled_by_env() -> Result<bool, BrokerDisableEnvError> {
    match std::env::var_os(RUNNING_PROCESS_DISABLE_ENV) {
        None => Ok(false),
        Some(raw) => {
            let text = raw.to_string_lossy();
            if text != RUNNING_PROCESS_DISABLE_VALUE {
                return Err(BrokerDisableEnvError {
                    value: text.into_owned(),
                });
            }
            Ok(true)
        }
    }
}
/// Parameters for `connect_to_backend`.
///
/// Build one with [`ConnectBackendRequest::new`] and then override whichever public
/// fields need non-default values.
#[derive(Clone, Debug)]
pub struct ConnectBackendRequest<'a> {
    /// Broker local-socket endpoint, used whenever a Hello round-trip is required.
    pub broker_endpoint: &'a str,
    /// Service name sent to the broker in the Hello.
    pub service_name: &'a str,
    /// Backend version the caller asks for; sent in the Hello.
    pub wanted_version: &'a str,
    /// Version of the calling process. When it equals `wanted_version`, a cached backend
    /// endpoint may be used without contacting the broker.
    pub self_version: &'a str,
    /// Previously learned backend endpoint, tried first when hello-skip is allowed.
    pub cached_backend_endpoint: Option<&'a str>,
    /// Client version reported in the Hello (empty by default).
    pub client_version: &'a str,
    /// Client library name reported in the Hello.
    pub client_lib_name: &'a str,
    /// Client library version reported in the Hello.
    pub client_lib_version: &'a str,
    /// Keepalive value reported in the Hello (0 by default).
    pub client_keepalive_secs: u64,
    /// Whether to wait for and adopt a broker handoff when one is negotiated.
    pub adopt_handed_off_connection: bool,
    /// How long to wait for the handoff-ready event before falling back to a direct
    /// backend connect.
    pub handoff_ready_timeout: Duration,
}

impl<'a> ConnectBackendRequest<'a> {
    /// Creates a request from the four required fields.
    ///
    /// Every optional field takes its default:
    /// * no cached endpoint and an empty client version;
    /// * library name `running-process` and this crate's package version;
    /// * keepalive 0;
    /// * no handoff adoption, with [`DEFAULT_HANDOFF_READY_TIMEOUT`].
    pub fn new(
        broker_endpoint: &'a str,
        service_name: &'a str,
        wanted_version: &'a str,
        self_version: &'a str,
    ) -> Self {
        Self {
            broker_endpoint,
            service_name,
            wanted_version,
            self_version,
            cached_backend_endpoint: None,
            client_version: "",
            client_lib_name: "running-process",
            client_lib_version: env!("CARGO_PKG_VERSION"),
            client_keepalive_secs: 0,
            adopt_handed_off_connection: false,
            handoff_ready_timeout: DEFAULT_HANDOFF_READY_TIMEOUT,
        }
    }

    /// True when the broker round-trip may be skipped: the caller wants exactly its own
    /// version and a cached backend endpoint is available.
    fn can_hello_skip(&self) -> bool {
        let same_version = self.wanted_version == self.self_version;
        same_version && self.cached_backend_endpoint.is_some()
    }

    /// Builds the Hello message sent to the broker for this request.
    ///
    /// Only a single protocol version is offered. Auth, attestation and capability-token
    /// fields are left empty.
    fn hello(&self) -> Hello {
        let Self {
            service_name,
            wanted_version,
            client_version,
            client_lib_name,
            client_lib_version,
            client_keepalive_secs,
            ..
        } = *self;
        Hello {
            client_min_protocol: PROTOCOL_VERSION,
            client_max_protocol: PROTOCOL_VERSION,
            service_name: service_name.into(),
            wanted_version: wanted_version.into(),
            client_version: client_version.into(),
            client_capabilities: client_capabilities(),
            auth_token: Vec::new(),
            request_id: "hello".into(),
            connection_id: 0,
            peer_pid: std::process::id(),
            client_lib_name: client_lib_name.into(),
            client_lib_version: client_lib_version.into(),
            peer_attestation_nonce: Vec::new(),
            capability_token: Vec::new(),
            client_keepalive_secs,
        }
    }
}
/// Capability bits this client advertises in its Hello.
///
/// `CAP_HANDLE_PASSING` is set only when the handoff transport is available on this
/// platform/build.
fn client_capabilities() -> u64 {
    let mut caps = 0;
    if handoff_transport_available() {
        caps |= CAP_HANDLE_PASSING;
    }
    caps
}
/// How a [`BackendConnection`] was established.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BackendConnectionRoute {
    /// Connected straight to a backend endpoint without a broker Hello.
    ///
    /// This covers the cached-endpoint path and the `test-seams` fake backend.
    HelloSkip,
    /// A broker Hello succeeded and the client connected to the negotiated backend
    /// endpoint itself.
    BrokerNegotiated,
    /// The broker handed the connection off: the broker stream was adopted as the backend
    /// stream after a valid handoff-ready event.
    HandlePassed,
}
/// An open stream to a backend plus how it was obtained.
#[derive(Debug)]
pub struct BackendConnection {
    /// Stream to the backend. On the `HandlePassed` route this is the adopted broker
    /// connection.
    pub stream: interprocess::local_socket::Stream,
    /// Backend endpoint.
    ///
    /// On the `HandlePassed` route this is the negotiated `backend_pipe`, which is not
    /// checked for emptiness there.
    pub endpoint: String,
    /// Which connection path produced this stream.
    pub route: BackendConnectionRoute,
    /// The broker's negotiation result; `None` when no broker Hello took place.
    pub negotiated: Option<Negotiated>,
}

impl BackendConnection {
    /// Returns the broker-issued handoff token, if any.
    ///
    /// Yields `None` when there was no negotiation or the token is empty. The token is
    /// returned for any route that carries a `Negotiated`, not only `HandlePassed`.
    pub fn handoff_token(&self) -> Option<&[u8]> {
        let token = self.negotiated.as_ref()?.handle_passed_token.as_slice();
        (!token.is_empty()).then_some(token)
    }
}
/// Opens a stream to the backend serving `request.service_name`.
///
/// Routes are attempted in this order:
/// 1. With the `test-seams` feature, the endpoint from `RUNNING_PROCESS_FAKE_BACKEND` is
///    connected to directly. This is skipped if the variable is empty or the broker is
///    disabled.
/// 2. When hello-skip is allowed, the cached endpoint is tried. Hello-skip is allowed when
///    a cached endpoint exists and the wanted version equals our own; a failed connect
///    falls through.
/// 3. Otherwise the broker Hello is performed. If the caller opted into adoption and a
///    handoff was negotiated, the broker stream is adopted once a valid handoff-ready
///    event arrives within `handoff_ready_timeout`.
/// 4. Failing that, the client connects to the negotiated backend endpoint.
///
/// # Errors
///
/// * Broker Hello failures are propagated.
/// * [`BrokerClientError::EmptyBackendPipe`] is returned when the direct route has no
///   endpoint.
/// * [`BrokerClientError::BackendConnect`] is returned when connecting to the fake or
///   negotiated backend fails.
pub fn connect_to_backend(
    request: ConnectBackendRequest<'_>,
) -> Result<BackendConnection, BrokerClientError> {
    #[cfg(feature = "test-seams")]
    if let Some(endpoint) = fake_backend_endpoint_from_env() {
        let stream = connect_local_socket(&endpoint).map_err(BrokerClientError::BackendConnect)?;
        return Ok(BackendConnection {
            stream,
            endpoint,
            route: BackendConnectionRoute::HelloSkip,
            negotiated: None,
        });
    }

    // Hello-skip: a failed connect to the cached endpoint is not an error, we just
    // continue with the broker.
    let hello_skip_endpoint = request
        .cached_backend_endpoint
        .filter(|_| request.can_hello_skip());
    if let Some(endpoint) = hello_skip_endpoint {
        if let Ok(stream) = connect_local_socket(endpoint) {
            return Ok(BackendConnection {
                stream,
                endpoint: endpoint.to_owned(),
                route: BackendConnectionRoute::HelloSkip,
                negotiated: None,
            });
        }
    }

    let (broker_stream, negotiated) = broker_hello(&request)?;

    let try_adopt = request.adopt_handed_off_connection && handoff_negotiated(&negotiated);
    if try_adopt {
        let adopted = await_handoff_ready(
            broker_stream,
            negotiated.handle_passed_token.clone(),
            request.handoff_ready_timeout,
        );
        if let Some(stream) = adopted {
            return Ok(BackendConnection {
                endpoint: negotiated.backend_pipe.clone(),
                stream,
                route: BackendConnectionRoute::HandlePassed,
                negotiated: Some(negotiated),
            });
        }
    }

    // Direct route: either no handoff was attempted or it did not complete in time.
    if negotiated.backend_pipe.is_empty() {
        return Err(BrokerClientError::EmptyBackendPipe);
    }
    let stream = connect_local_socket(&negotiated.backend_pipe)
        .map_err(BrokerClientError::BackendConnect)?;
    let endpoint = negotiated.backend_pipe.clone();
    Ok(BackendConnection {
        endpoint,
        stream,
        route: BackendConnectionRoute::BrokerNegotiated,
        negotiated: Some(negotiated),
    })
}
/// Test seam: returns the fake backend endpoint from `RUNNING_PROCESS_FAKE_BACKEND`.
///
/// Yields `None` when the variable is unset or empty, or when `broker_disabled_by_env`
/// reports `Ok(true)`. A malformed disable value does not suppress the fake backend.
#[cfg(feature = "test-seams")]
fn fake_backend_endpoint_from_env() -> Option<String> {
    let raw = std::env::var_os(RUNNING_PROCESS_FAKE_BACKEND_ENV)?;
    let endpoint = raw.to_string_lossy();
    if endpoint.is_empty() || matches!(broker_disabled_by_env(), Ok(true)) {
        None
    } else {
        Some(endpoint.into_owned())
    }
}
/// True when the broker both advertised `CAP_HANDLE_PASSING` and issued a non-empty
/// handoff token.
fn handoff_negotiated(negotiated: &Negotiated) -> bool {
    let server_supports =
        (negotiated.server_capabilities & CAP_HANDLE_PASSING) == CAP_HANDLE_PASSING;
    let has_token = !negotiated.handle_passed_token.is_empty();
    server_supports && has_token
}
/// Waits up to `timeout` for the broker's handoff-ready event on `stream`.
///
/// If the event arrives and validates, the same stream is returned for the client to
/// adopt. The blocking read runs on a helper thread so the wait can be bounded.
///
/// Returns `None` when any of these happen; the caller then falls back to a direct
/// backend connect:
/// * the event is malformed or refused, or echoes the wrong token;
/// * the event does not arrive in time;
/// * the helper thread cannot be started.
///
/// NOTE(review): on timeout the helper thread stays detached, blocked in `read_frame`
/// and holding `stream` open, until the broker writes to or closes its end.
fn await_handoff_ready(
    stream: interprocess::local_socket::Stream,
    expected_token: Vec<u8>,
    timeout: Duration,
) -> Option<interprocess::local_socket::Stream> {
    let (result_tx, result_rx) = mpsc::channel();
    // `thread::spawn` panics if the OS refuses to create a thread; `Builder::spawn`
    // reports that as an error so this best-effort path can fall back instead of crashing.
    let spawned = thread::Builder::new()
        .name("broker-handoff-ready".into())
        .spawn(move || {
            let mut stream = stream;
            let outcome = read_handoff_ready(&mut stream, &expected_token).map(|()| stream);
            // The receiver is gone if the caller already timed out; nothing left to report.
            let _ = result_tx.send(outcome);
        });
    if spawned.is_err() {
        // The closure (and the stream it captured) was dropped with the failed spawn.
        return None;
    }
    // Timeout, disconnect and a rejected handoff all collapse to `None`.
    result_rx.recv_timeout(timeout).ok()?.ok()
}
/// Reads and validates the broker's handoff-ready EVENT frame from `stream`.
///
/// Succeeds only if all of the following hold:
/// * the frame decodes and passes `validate_handoff_frame` as an `Event`;
/// * its payload decodes as a `HandoffAck`;
/// * the ack echoes `expected_token`;
/// * the ack is marked accepted.
///
/// A token mismatch is reported before a refusal.
fn read_handoff_ready(
    stream: &mut interprocess::local_socket::Stream,
    expected_token: &[u8],
) -> Result<(), &'static str> {
    let frame_bytes = read_frame(stream).map_err(|_| "failed to read handoff-ready frame")?;
    let frame = Frame::decode(frame_bytes.as_slice())
        .map_err(|_| "failed to decode handoff-ready Frame")?;
    validate_handoff_frame(&frame, FrameKind::Event)?;
    let ack = HandoffAck::decode(frame.payload.as_slice())
        .map_err(|_| "failed to decode handoff-ready HandoffAck payload")?;
    match (ack.token == expected_token, ack.accepted) {
        (false, _) => Err("handoff-ready token echo does not match the negotiated token"),
        (true, false) => Err("broker relayed a refused handoff"),
        (true, true) => Ok(()),
    }
}
pub fn send_admin_request(
broker_endpoint: &str,
request: AdminRequest,
) -> Result<AdminReply, BrokerClientError> {
let mut stream =
connect_local_socket(broker_endpoint).map_err(BrokerClientError::BrokerConnect)?;
let request_frame = Frame {
envelope_version: PROTOCOL_VERSION,
kind: FrameKind::Request as i32,
payload_protocol: ADMIN_PAYLOAD_PROTOCOL,
payload: request.encode_to_vec(),
request_id: 1,
payload_encoding: PayloadEncoding::None as i32,
deadline_unix_ms: 0,
traceparent: String::new(),
tracestate: String::new(),
};
write_frame(&mut stream, &request_frame.encode_to_vec())?;
let response_bytes = read_frame(&mut stream)?;
let response_frame =
Frame::decode(response_bytes.as_slice()).map_err(BrokerClientError::DecodeFrame)?;
validate_response_frame(
&response_frame,
ADMIN_PAYLOAD_PROTOCOL,
"payload_protocol is not admin",
)?;
AdminReply::decode(response_frame.payload.as_slice())
.map_err(BrokerClientError::DecodeAdminReply)
}
pub fn connect_local_socket(endpoint: &str) -> io::Result<interprocess::local_socket::Stream> {
let name = local_socket_name(endpoint)?;
LocalSocketStream::connect(name)
}
fn broker_hello(
request: &ConnectBackendRequest<'_>,
) -> Result<(interprocess::local_socket::Stream, Negotiated), BrokerClientError> {
let mut stream =
connect_local_socket(request.broker_endpoint).map_err(BrokerClientError::BrokerConnect)?;
let hello = request.hello();
let request_frame = Frame {
envelope_version: PROTOCOL_VERSION,
kind: FrameKind::Request as i32,
payload_protocol: CONTROL_PAYLOAD_PROTOCOL,
payload: hello.encode_to_vec(),
request_id: 1,
payload_encoding: PayloadEncoding::None as i32,
deadline_unix_ms: 0,
traceparent: String::new(),
tracestate: String::new(),
};
write_frame(&mut stream, &request_frame.encode_to_vec())?;
let response_bytes = read_frame(&mut stream)?;
let response_frame =
Frame::decode(response_bytes.as_slice()).map_err(BrokerClientError::DecodeFrame)?;
validate_response_frame(
&response_frame,
CONTROL_PAYLOAD_PROTOCOL,
"payload_protocol is not control-plane",
)?;
let reply = HelloReply::decode(response_frame.payload.as_slice())
.map_err(BrokerClientError::DecodeHelloReply)?;
match reply
.result
.ok_or(BrokerClientError::MissingHelloReplyResult)?
{
HelloReplyResult::Negotiated(negotiated) => Ok((stream, negotiated)),
HelloReplyResult::Refused(refused) => Err(BrokerClientError::Refused {
code: ErrorCode::try_from(refused.code).unwrap_or(ErrorCode::Unspecified),
reason: refused.reason,
retry_after_ms: refused.retry_after_ms,
}),
}
}
fn validate_response_frame(
frame: &Frame,
expected_payload_protocol: u32,
payload_protocol_error: &'static str,
) -> Result<(), BrokerClientError> {
validate_frame_envelope(frame, FrameKind::Response, expected_payload_protocol).map_err(
|error| {
BrokerClientError::UnexpectedResponseFrame(match error {
FrameValidationError::EnvelopeVersion { .. } => "envelope_version is not v1",
FrameValidationError::Kind { .. } => "kind is not RESPONSE",
FrameValidationError::PayloadProtocol { .. } => payload_protocol_error,
FrameValidationError::PayloadEncoding { .. } => "payload is compressed",
})
},
)
}
/// Error from `broker_disabled_by_env`: `RUNNING_PROCESS_DISABLE` was set to something
/// other than `1`.
#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
#[error("RUNNING_PROCESS_DISABLE must be unset or 1, got {value:?}")]
pub struct BrokerDisableEnvError {
    /// The rejected value, lossily converted to UTF-8.
    pub value: String,
}
/// Failures from talking to the broker or opening the backend connection.
#[derive(Debug, thiserror::Error)]
pub enum BrokerClientError {
    /// Connecting to the broker's local socket failed.
    #[error("failed to connect to broker: {0}")]
    BrokerConnect(io::Error),
    /// Connecting to the backend endpoint failed.
    #[error("failed to connect to negotiated backend: {0}")]
    BackendConnect(io::Error),
    /// `read_frame` / `write_frame` failed.
    #[error(transparent)]
    Framing(#[from] FramingError),
    /// The broker's response bytes were not a valid `Frame`.
    #[error("failed to decode broker response Frame: {0}")]
    DecodeFrame(prost::DecodeError),
    /// The response payload was not a valid `HelloReply`.
    #[error("failed to decode broker HelloReply: {0}")]
    DecodeHelloReply(prost::DecodeError),
    /// The response payload was not a valid `AdminReply`.
    #[error("failed to decode broker AdminReply: {0}")]
    DecodeAdminReply(prost::DecodeError),
    /// The response envelope failed validation; the message names the failed check.
    #[error("unexpected broker response frame: {0}")]
    UnexpectedResponseFrame(&'static str),
    /// The `HelloReply` had neither a negotiated nor a refused result.
    #[error("broker HelloReply did not contain a result")]
    MissingHelloReplyResult,
    /// The broker refused the Hello.
    ///
    /// Fields are as reported by the broker; unknown codes become
    /// `ErrorCode::Unspecified`.
    #[error("broker refused Hello: {reason} ({code:?}, retry_after_ms={retry_after_ms})")]
    Refused {
        code: ErrorCode,
        reason: String,
        retry_after_ms: u64,
    },
    /// The broker negotiated successfully but returned an empty backend endpoint.
    #[error("broker negotiated an empty backend endpoint")]
    EmptyBackendPipe,
}

impl BrokerClientError {
    /// Classifies a [`BrokerClientError::Refused`] error; every other variant yields `None`.
    pub fn refusal_kind(&self) -> Option<RefusalKind> {
        if let Self::Refused { code, .. } = self {
            Some(RefusalKind::from_code(*code))
        } else {
            None
        }
    }
}
/// Coarse classification of a broker refusal, derived from its [`ErrorCode`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RefusalKind {
    /// `ErrorCode::ErrorVersionUnsupported`.
    VersionUnsupported,
    /// `ErrorCode::ErrorVersionBlocked`.
    VersionBlocked,
    /// `ErrorCode::ErrorServiceUnknown`.
    ServiceUnknown,
    /// `ErrorCode::ErrorRateLimited`.
    RateLimited,
    /// `ErrorCode::ErrorShuttingDown`.
    ShuttingDown,
    /// Any other code, preserved verbatim.
    Other(ErrorCode),
}

impl RefusalKind {
    /// Maps a broker [`ErrorCode`] onto a [`RefusalKind`].
    ///
    /// Codes without a dedicated variant are kept in [`RefusalKind::Other`].
    pub fn from_code(code: ErrorCode) -> Self {
        match code {
            ErrorCode::ErrorVersionUnsupported => Self::VersionUnsupported,
            ErrorCode::ErrorVersionBlocked => Self::VersionBlocked,
            ErrorCode::ErrorServiceUnknown => Self::ServiceUnknown,
            ErrorCode::ErrorRateLimited => Self::RateLimited,
            ErrorCode::ErrorShuttingDown => Self::ShuttingDown,
            unmapped => Self::Other(unmapped),
        }
    }
}