#![cfg_attr(feature = "fail-on-warnings", deny(warnings))]
#![warn(clippy::all, clippy::pedantic, clippy::nursery, clippy::cargo)]
#![allow(clippy::multiple_crate_versions)]
use bmux_attach_image_protocol::CompressionId;
use bmux_attach_layout_protocol::{
AttachPaneChunk, AttachPaneInputMode, AttachPaneMouseProtocol, AttachScene, PaneLayoutNode,
PaneSummary,
};
use bmux_config::{BmuxConfig, ConfigPaths};
pub use bmux_ipc::Event as ServerEvent;
use bmux_ipc::transport::{
ErasedIpcStream, ErasedIpcStreamReader, ErasedIpcStreamWriter, IpcStreamReader,
IpcStreamWriter, IpcTransportError, IpcWriteTiming, LocalIpcStream,
};
use bmux_ipc::{
Envelope, EnvelopeKind, ErrorCode, IncompatibilityReason, InvokeServiceKind, IpcEndpoint,
NegotiatedProtocol, ProtocolContract, Request, Response, ResponsePayload,
ServicePipelineRequest, ServicePipelineStepResult, decode, default_supported_capabilities,
encode,
};
use bmux_perf_telemetry::{PhaseChannel, PhasePayload, PhaseTimer, emit as emit_phase_timing};
use bmux_plugin_sdk::{TypedDispatchClient, TypedDispatchClientError, TypedDispatchClientResult};
use std::collections::BTreeMap;
use std::sync::{Arc, Mutex as StdMutex};
use std::time::Duration;
use thiserror::Error;
use tracing::{debug, trace, warn};
use uuid::Uuid;
/// Result alias used by this client crate; the error type is always [`ClientError`].
pub type Result<T> = std::result::Result<T, ClientError>;
/// Plain-data summary of an opened attach.
///
/// NOTE(review): nothing in this part of the file constructs or reads this type, so the
/// field notes below are inferred from the field names only — confirm against the callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AttachOpenInfo {
/// Optional context identifier (presumably `None` when no context applies — TODO confirm).
pub context_id: Option<Uuid>,
/// Identifier of the session (presumably the one attached to — TODO confirm).
pub session_id: Uuid,
/// Presumably whether this attach may send input — TODO confirm.
pub can_write: bool,
}
/// Plain-data layout description for an attached session.
///
/// NOTE(review): not constructed or read in this part of the file; field notes are
/// inferred from names and types only — confirm against the callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AttachLayoutState {
/// Optional context identifier.
pub context_id: Option<Uuid>,
/// Identifier of the session the layout belongs to (presumably — TODO confirm).
pub session_id: Uuid,
/// Identifier of the pane that currently has focus (presumably — TODO confirm).
pub focused_pane_id: Uuid,
/// Per-pane summaries.
pub panes: Vec<PaneSummary>,
/// Root node of the pane layout tree.
pub layout_root: PaneLayoutNode,
/// Scene description from the attach layout protocol.
pub scene: AttachScene,
/// Presumably whether a pane is currently zoomed — TODO confirm.
pub zoomed: bool,
}
/// One batch of pane output chunks.
///
/// NOTE(review): not constructed or read in this part of the file; field notes are
/// inferred from names only — confirm against the callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PaneOutputBatchResult {
/// Output chunks contained in this batch.
pub chunks: Vec<AttachPaneChunk>,
/// Presumably `true` when more output remains to be fetched — TODO confirm.
pub output_still_pending: bool,
}
/// Plain-data snapshot combining layout information with per-pane data.
///
/// Carries the same layout fields as [`AttachLayoutState`] plus output chunks and
/// per-pane mouse/input mode lists.
///
/// NOTE(review): not constructed or read in this part of the file; field notes are
/// inferred from names and types only — confirm against the callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AttachSnapshotState {
/// Optional context identifier.
pub context_id: Option<Uuid>,
/// Identifier of the session the snapshot belongs to (presumably — TODO confirm).
pub session_id: Uuid,
/// Identifier of the pane that currently has focus (presumably — TODO confirm).
pub focused_pane_id: Uuid,
/// Per-pane summaries.
pub panes: Vec<PaneSummary>,
/// Root node of the pane layout tree.
pub layout_root: PaneLayoutNode,
/// Scene description from the attach layout protocol.
pub scene: AttachScene,
/// Pane output chunks included in the snapshot.
pub chunks: Vec<AttachPaneChunk>,
/// Mouse protocol entries, one list element per reported pane.
pub pane_mouse_protocols: Vec<AttachPaneMouseProtocol>,
/// Input mode entries, one list element per reported pane.
pub pane_input_modes: Vec<AttachPaneInputMode>,
/// Presumably whether a pane is currently zoomed — TODO confirm.
pub zoomed: bool,
}
/// Pane-only portion of a snapshot: output chunks plus mouse/input mode lists.
///
/// NOTE(review): not constructed or read in this part of the file — confirm the exact
/// semantics against the callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AttachPaneSnapshotState {
/// Pane output chunks.
pub chunks: Vec<AttachPaneChunk>,
/// Mouse protocol entries for the reported panes.
pub pane_mouse_protocols: Vec<AttachPaneMouseProtocol>,
/// Input mode entries for the reported panes.
pub pane_input_modes: Vec<AttachPaneInputMode>,
}
/// Server status as returned by `BmuxClient::server_status`.
///
/// Each field is copied verbatim from the `ResponsePayload::ServerStatus` reply.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerStatusInfo {
/// `running` flag reported by the server.
pub running: bool,
/// Principal id reported by the server in the status reply.
pub principal_id: Uuid,
/// Server-control principal id reported by the server.
pub server_control_principal_id: Uuid,
}
/// Principal identity as returned by the `whoami_principal` client methods.
///
/// Each field is copied verbatim from the `ResponsePayload::PrincipalIdentity` reply.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PrincipalIdentityInfo {
/// Principal id reported by the server.
pub principal_id: Uuid,
/// Server-control principal id reported by the server.
pub server_control_principal_id: Uuid,
/// `force_local_permitted` flag reported by the server.
pub force_local_permitted: bool,
}
/// Every failure the clients in this crate can report.
#[derive(Debug, Error)]
pub enum ClientError {
/// The underlying IPC transport failed (converted automatically via `?`).
#[error("transport error: {0}")]
Transport(#[from] IpcTransportError),
/// Encoding a request or decoding a response failed.
#[error("serialization error: {0}")]
Serialization(#[from] bmux_codec::Error),
/// A send or receive did not complete within the carried timeout.
#[error("request timed out after {0:?}")]
Timeout(Duration),
/// The received envelope carried a different request id than the one sent.
#[error("request id mismatch (expected {expected}, got {actual})")]
RequestIdMismatch { expected: u64, actual: u64 },
/// The received envelope was not of the expected kind.
#[error("unexpected envelope kind: expected {expected:?}, got {actual:?}")]
UnexpectedEnvelopeKind {
expected: EnvelopeKind,
actual: EnvelopeKind,
},
/// The server answered with `Response::Err`.
#[error("server returned error {code:?}: {message}")]
ServerError { code: ErrorCode, message: String },
/// The server answered successfully but with a payload variant the caller did not expect.
#[error("unexpected response payload: {0}")]
UnexpectedResponse(&'static str),
/// The handshake was answered with `HelloIncompatible`.
#[error("protocol negotiation failed: {reason:?}")]
ProtocolIncompatible { reason: IncompatibilityReason },
/// Loading the bmux configuration failed.
#[error("failed loading config: {0}")]
ConfigLoad(#[from] bmux_config::ConfigError),
/// Reading the principal id file failed.
/// NOTE(review): produced outside this part of the file — presumably by
/// `load_or_create_principal_id`; confirm.
#[error("failed reading principal id file {path}: {source}")]
PrincipalIdRead {
path: String,
source: std::io::Error,
},
/// Writing the principal id file failed (same caveat as `PrincipalIdRead`).
#[error("failed writing principal id file {path}: {source}")]
PrincipalIdWrite {
path: String,
source: std::io::Error,
},
/// The principal id file held a value that is not a valid id (same caveat as above).
#[error("invalid principal id in {path}: {value}")]
PrincipalIdParse { path: String, value: String },
}
/// Request/response client that performs one exchange at a time over a single stream.
#[derive(Debug)]
pub struct BmuxClient {
// Transport used for every exchange (local socket or bridged stream).
stream: ClientStream,
// Limit applied separately to the send and to the receive of each request.
timeout: Duration,
// Id assigned to the next outgoing request; starts at 1 and never takes the value 0.
next_request_id: u64,
// Principal id sent in the handshake.
principal_id: Uuid,
// Set once the server answers the handshake with `HelloNegotiated`; `None` before that.
negotiated_protocol: Option<NegotiatedProtocol>,
}
/// The two transports a [`BmuxClient`] can run over.
#[derive(Debug)]
enum ClientStream {
// Stream obtained from `LocalIpcStream::connect`.
Local(LocalIpcStream),
// Type-erased stream supplied by the caller of `connect_with_bridge_stream`.
Bridge(ErasedIpcStream),
}
impl ClientStream {
    /// Writes `envelope` on the underlying transport and returns the write timing
    /// reported by that transport.
    async fn send_envelope_with_timing(
        &mut self,
        envelope: &Envelope,
    ) -> std::result::Result<IpcWriteTiming, IpcTransportError> {
        match self {
            Self::Bridge(bridge) => bridge.send_envelope_with_timing(envelope).await,
            Self::Local(local) => local.send_envelope_with_timing(envelope).await,
        }
    }

    /// Reads the next envelope from the underlying transport.
    async fn recv_envelope(&mut self) -> std::result::Result<Envelope, IpcTransportError> {
        match self {
            Self::Bridge(bridge) => bridge.recv_envelope().await,
            Self::Local(local) => local.recv_envelope().await,
        }
    }
}
impl BmuxClient {
/// Returns the protocol negotiated during the handshake, or `None` if none is stored.
#[must_use]
pub const fn negotiated_protocol(&self) -> Option<&NegotiatedProtocol> {
self.negotiated_protocol.as_ref()
}
/// Connects to `endpoint` under a freshly generated random (v4) principal id.
///
/// # Errors
/// Propagates every error of [`Self::connect_with_principal`].
pub async fn connect(
    endpoint: &IpcEndpoint,
    timeout: Duration,
    client_name: impl Into<String>,
) -> Result<Self> {
    let ephemeral_principal = Uuid::new_v4();
    Self::connect_with_principal(endpoint, timeout, client_name, ephemeral_principal).await
}
/// Opens a local IPC stream to `endpoint` and performs the handshake as `principal_id`.
///
/// # Errors
/// Returns [`ClientError::Transport`] if the stream cannot be opened, otherwise whatever
/// the handshake in `connect_with_stream` reports.
pub async fn connect_with_principal(
    endpoint: &IpcEndpoint,
    timeout: Duration,
    client_name: impl Into<String>,
    principal_id: Uuid,
) -> Result<Self> {
    let local_stream = LocalIpcStream::connect(endpoint).await?;
    let transport = ClientStream::Local(local_stream);
    Self::connect_with_stream(transport, timeout, client_name, principal_id).await
}
/// Performs the handshake over an already-established, type-erased `stream`.
///
/// # Errors
/// Returns whatever the handshake in `connect_with_stream` reports.
pub async fn connect_with_bridge_stream(
    stream: ErasedIpcStream,
    timeout: Duration,
    client_name: impl Into<String>,
    principal_id: Uuid,
) -> Result<Self> {
    let transport = ClientStream::Bridge(stream);
    Self::connect_with_stream(transport, timeout, client_name, principal_id).await
}
/// Builds the client around `stream` and runs the `Hello` handshake.
///
/// The negotiated protocol is stored on the client only when the server answers with
/// `HelloNegotiated`.
///
/// # Errors
/// - [`ClientError::ProtocolIncompatible`] when the server answers `HelloIncompatible`.
/// - [`ClientError::UnexpectedResponse`] for any other successful payload.
/// - Any error produced by the request itself.
async fn connect_with_stream(
    stream: ClientStream,
    timeout: Duration,
    client_name: impl Into<String>,
    principal_id: Uuid,
) -> Result<Self> {
    let mut client = Self {
        stream,
        timeout,
        // Request ids start at 1.
        next_request_id: 1,
        principal_id,
        negotiated_protocol: None,
    };
    let hello = Request::Hello {
        contract: ProtocolContract::current(default_supported_capabilities()),
        client_name: client_name.into(),
        principal_id,
    };
    match client.request(hello).await? {
        ResponsePayload::HelloNegotiated { negotiated } => {
            client.negotiated_protocol = Some(negotiated);
            Ok(client)
        }
        ResponsePayload::HelloIncompatible { reason } => {
            Err(ClientError::ProtocolIncompatible { reason })
        }
        _ => Err(ClientError::UnexpectedResponse(
            "handshake expected hello negotiation response",
        )),
    }
}
/// Connects using the endpoint and principal id derived from `paths`.
///
/// The timeout comes from the loaded configuration's `general.server_timeout`
/// (milliseconds), raised to at least 1.
///
/// # Errors
/// Fails if the configuration cannot be loaded, if the principal id cannot be obtained,
/// or if [`Self::connect_with_principal`] fails.
pub async fn connect_with_paths(
    paths: &ConfigPaths,
    client_name: impl Into<String>,
) -> Result<Self> {
    let timeout_ms = BmuxConfig::load()?.general.server_timeout.max(1);
    let timeout = Duration::from_millis(timeout_ms);
    let endpoint = endpoint_from_paths(paths);
    let principal_id = load_or_create_principal_id(paths)?;
    Self::connect_with_principal(&endpoint, timeout, client_name, principal_id).await
}
/// Connects using `ConfigPaths::default()`.
///
/// # Errors
/// Propagates every error of [`Self::connect_with_paths`].
pub async fn connect_default(client_name: impl Into<String>) -> Result<Self> {
    let default_paths = ConfigPaths::default();
    Self::connect_with_paths(&default_paths, client_name).await
}
/// Sends `Ping` and expects `Pong`.
///
/// # Errors
/// Returns [`ClientError::UnexpectedResponse`] for any other payload, or the request error.
pub async fn ping(&mut self) -> Result<()> {
    let reply = self.request(Request::Ping).await?;
    if matches!(reply, ResponsePayload::Pong) {
        Ok(())
    } else {
        Err(ClientError::UnexpectedResponse("expected pong"))
    }
}
/// Returns the principal id this client was created with.
#[must_use]
pub const fn principal_id(&self) -> Uuid {
self.principal_id
}
/// Asks the server for the principal identity of this connection.
///
/// # Errors
/// Returns [`ClientError::UnexpectedResponse`] if the reply is not a principal identity,
/// or the request error.
pub async fn whoami_principal(&mut self) -> Result<PrincipalIdentityInfo> {
    let reply = self.request(Request::WhoAmIPrincipal).await?;
    let ResponsePayload::PrincipalIdentity {
        principal_id,
        server_control_principal_id,
        force_local_permitted,
    } = reply
    else {
        return Err(ClientError::UnexpectedResponse(
            "expected principal identity",
        ));
    };
    Ok(PrincipalIdentityInfo {
        principal_id,
        server_control_principal_id,
        force_local_permitted,
    })
}
/// Asks the server for its status.
///
/// # Errors
/// Returns [`ClientError::UnexpectedResponse`] if the reply is not a server status,
/// or the request error.
pub async fn server_status(&mut self) -> Result<ServerStatusInfo> {
    let reply = self.request(Request::ServerStatus).await?;
    let ResponsePayload::ServerStatus {
        running,
        principal_id,
        server_control_principal_id,
    } = reply
    else {
        return Err(ClientError::UnexpectedResponse("expected server status"));
    };
    Ok(ServerStatusInfo {
        running,
        principal_id,
        server_control_principal_id,
    })
}
/// Invokes a service operation with an opaque byte payload and returns the reply bytes.
///
/// On success a `service.client_invoke` phase timing record is emitted on the service
/// channel; nothing is emitted on failure.
///
/// # Errors
/// Returns [`ClientError::UnexpectedResponse`] if the reply is not `ServiceInvoked`,
/// or the request error.
pub async fn invoke_service_raw(
    &mut self,
    capability: impl Into<String>,
    kind: InvokeServiceKind,
    interface_id: impl Into<String>,
    operation: impl Into<String>,
    payload: Vec<u8>,
) -> Result<Vec<u8>> {
    let capability = capability.into();
    let interface_id = interface_id.into();
    let operation = operation.into();
    let request_len = payload.len();
    let timer = PhaseTimer::start();
    // The names are cloned into the request because they are needed again for telemetry.
    let invocation = Request::InvokeService {
        capability: capability.clone(),
        kind,
        interface_id: interface_id.clone(),
        operation: operation.clone(),
        payload,
    };
    let ResponsePayload::ServiceInvoked { payload: reply } = self.request(invocation).await?
    else {
        return Err(ClientError::UnexpectedResponse("expected service invoked"));
    };
    let phase = service_client_invoke_phase_payload(
        &capability,
        kind,
        &interface_id,
        &operation,
        request_len,
        reply.len(),
        timer.elapsed_us(),
    );
    emit_phase_timing(PhaseChannel::Service, &phase);
    Ok(reply)
}
/// Submits a service pipeline and returns the per-step results.
///
/// On success a `service_pipeline.client_request` phase timing record is emitted.
///
/// # Errors
/// Returns [`ClientError::UnexpectedResponse`] if the reply is not
/// `ServicePipelineInvoked`, or the request error.
pub async fn invoke_service_pipeline_raw(
    &mut self,
    pipeline: ServicePipelineRequest,
) -> Result<Vec<ServicePipelineStepResult>> {
    let submitted_steps = pipeline.steps.len();
    let timer = PhaseTimer::start();
    let reply = self
        .request(Request::InvokeServicePipeline { pipeline })
        .await?;
    let ResponsePayload::ServicePipelineInvoked { results } = reply else {
        return Err(ClientError::UnexpectedResponse(
            "expected service pipeline invoked",
        ));
    };
    let phase =
        service_client_pipeline_phase_payload(submitted_steps, results.len(), timer.elapsed_us());
    emit_phase_timing(PhaseChannel::Service, &phase);
    Ok(results)
}
/// Asks the server to emit `payload` on the plugin bus under `kind`.
///
/// Returns the `emitted` flag from the server's reply.
///
/// # Errors
/// Returns [`ClientError::UnexpectedResponse`] if the reply is not `PluginBusEmitted`,
/// or the request error.
pub async fn emit_on_plugin_bus(
    &mut self,
    kind: impl Into<String>,
    payload: Vec<u8>,
) -> Result<bool> {
    let emission = Request::EmitOnPluginBus {
        kind: kind.into(),
        payload,
    };
    let ResponsePayload::PluginBusEmitted { emitted } = self.request(emission).await? else {
        return Err(ClientError::UnexpectedResponse(
            "expected plugin bus emitted",
        ));
    };
    Ok(emitted)
}
/// Sends `request` and waits for the next envelope on the same stream as its response.
///
/// The configured timeout is applied twice and independently: once to writing the request
/// envelope and once to reading the response envelope. Per-phase durations (encode, send,
/// receive, decode) are measured and passed to `emit_ipc_request_timing` on success.
///
/// # Errors
/// - [`ClientError::Serialization`] if the request cannot be encoded or the response
///   payload cannot be decoded.
/// - [`ClientError::Timeout`] if the send or the receive exceeds the timeout.
/// - [`ClientError::Transport`] if the stream fails.
/// - [`ClientError::RequestIdMismatch`] / [`ClientError::UnexpectedEnvelopeKind`] if the
///   received envelope is rejected by `validate_response_envelope`.
///
/// NOTE(review): nothing here discards a response that arrives after a receive timeout, so
/// the next call would presumably read that stale envelope and fail the id check — confirm
/// whether callers are expected to reconnect after a timeout.
pub async fn request_raw(&mut self, request: Request) -> Result<Response> {
let request_id = self.take_request_id();
let request_kind = request_kind_name(&request);
let timeout_ms = self.timeout.as_millis();
// Reference point for every `duration_ms` log field and for the total timing.
let started_at = std::time::Instant::now();
debug!(
request_id,
request = request_kind,
timeout_ms,
"ipc.request.start"
);
// Phase 1: encode the request into the envelope payload.
let encode_started = std::time::Instant::now();
let payload = encode(&request)?;
let encode_us = encode_started.elapsed().as_micros();
let envelope = Envelope::new(request_id, EnvelopeKind::Request, payload);
// Phase 2: write the envelope, bounded by the timeout.
let send_started = std::time::Instant::now();
let write_timing = tokio::time::timeout(
self.timeout,
self.stream.send_envelope_with_timing(&envelope),
)
.await
.map_err(|_| {
warn!(
request_id,
request = request_kind,
timeout_ms,
phase = "send",
duration_ms = started_at.elapsed().as_millis(),
"ipc.request.timeout"
);
ClientError::Timeout(self.timeout)
// First `?` surfaces the timeout, second `?` surfaces the transport error.
})??;
let send_us = send_started.elapsed().as_micros();
trace!(
request_id,
request = request_kind,
duration_ms = started_at.elapsed().as_millis(),
"ipc.request.sent"
);
// Phase 3: read the next envelope, bounded by a fresh timeout.
let recv_started = std::time::Instant::now();
let response_envelope = tokio::time::timeout(self.timeout, self.stream.recv_envelope())
.await
.map_err(|_| {
warn!(
request_id,
request = request_kind,
timeout_ms,
phase = "recv",
duration_ms = started_at.elapsed().as_millis(),
"ipc.request.timeout"
);
ClientError::Timeout(self.timeout)
})??;
let recv_us = recv_started.elapsed().as_micros();
// The envelope must echo our request id and be of kind `Response`.
validate_response_envelope(&response_envelope, request_id, request_kind, &started_at)?;
// Phase 4: decode the response payload.
let decode_started = std::time::Instant::now();
let response: Response = decode(&response_envelope.payload).map_err(ClientError::from)?;
let decode_us = decode_started.elapsed().as_micros();
debug!(
request_id,
request = request_kind,
response = response_kind_name(&response),
duration_ms = started_at.elapsed().as_millis(),
"ipc.request.done"
);
emit_ipc_request_timing(
&request,
request_id,
response_kind_name(&response),
IpcClientTiming {
encode: encode_us,
send: send_us,
frame_encode: write_timing.frame_encode_us,
socket_write: write_timing.socket_write_us,
recv: recv_us,
decode: decode_us,
total: started_at.elapsed().as_micros(),
},
);
Ok(response)
}
pub async fn force_local_permitted(&mut self) -> Result<bool> {
match self.request(Request::WhoAmIPrincipal).await? {
ResponsePayload::PrincipalIdentity {
force_local_permitted,
..
} => Ok(force_local_permitted),
_ => Err(ClientError::UnexpectedResponse(
"expected principal identity",
)),
}
}
/// Sends `ServerStop` and expects the `ServerStopping` acknowledgement.
///
/// # Errors
/// Returns [`ClientError::UnexpectedResponse`] for any other payload, or the request error.
pub async fn stop_server(&mut self) -> Result<()> {
    let reply = self.request(Request::ServerStop).await?;
    if matches!(reply, ResponsePayload::ServerStopping) {
        Ok(())
    } else {
        Err(ClientError::UnexpectedResponse("expected server stopping"))
    }
}
/// Sends `SubscribeEvents` and expects the `EventsSubscribed` acknowledgement.
///
/// # Errors
/// Returns [`ClientError::UnexpectedResponse`] for any other payload, or the request error.
pub async fn subscribe_events(&mut self) -> Result<()> {
    let reply = self.request(Request::SubscribeEvents).await?;
    if matches!(reply, ResponsePayload::EventsSubscribed) {
        Ok(())
    } else {
        Err(ClientError::UnexpectedResponse(
            "expected events subscribed response",
        ))
    }
}
/// Sends `PollEvents` with the given `max_events` and returns the events in the reply.
///
/// # Errors
/// Returns [`ClientError::UnexpectedResponse`] if the reply is not an event batch,
/// or the request error.
pub async fn poll_events(&mut self, max_events: usize) -> Result<Vec<ServerEvent>> {
    let reply = self.request(Request::PollEvents { max_events }).await?;
    let ResponsePayload::EventBatch { events } = reply else {
        return Err(ClientError::UnexpectedResponse(
            "expected event batch response",
        ));
    };
    Ok(events)
}
async fn request(&mut self, request: Request) -> Result<ResponsePayload> {
let response = self.request_raw(request).await?;
match response {
Response::Ok(payload) => Ok(payload),
Response::Err(error) => {
debug!("server returned error {:?}: {}", error.code, error.message);
Err(ClientError::ServerError {
code: error.code,
message: error.message,
})
}
}
}
/// Returns the current request id and advances the counter.
///
/// The counter wraps on overflow and is then clamped to 1, so it never becomes 0.
fn take_request_id(&mut self) -> u64 {
    let following = self.next_request_id.wrapping_add(1).max(1);
    std::mem::replace(&mut self.next_request_id, following)
}
}
/// Converts a [`ClientError`] into the typed-dispatch error for `interface`/`operation`.
///
/// Server errors and unexpected responses keep their category; every other error is
/// reported as a transport error carrying its display text.
fn map_client_error(
    interface: &str,
    operation: &str,
    err: ClientError,
) -> TypedDispatchClientError {
    match err {
        ClientError::UnexpectedResponse(details) => {
            TypedDispatchClientError::unexpected_response(interface, operation, details)
        }
        ClientError::ServerError { code, message } => {
            let description = format!("{code:?}: {message}");
            TypedDispatchClientError::server(interface, operation, description)
        }
        transport_like => {
            let description = transport_like.to_string();
            TypedDispatchClientError::transport(interface, operation, description)
        }
    }
}
impl TypedDispatchClient for BmuxClient {
    /// Typed-dispatch entry point: invokes a service operation and maps every failure
    /// into a [`TypedDispatchClientError`] tagged with the interface and operation.
    ///
    /// All borrowed strings are copied before the future is created, so the returned
    /// future does not hold on to them.
    fn invoke_service_raw(
        &mut self,
        capability: &str,
        kind: InvokeServiceKind,
        interface_id: &str,
        operation: &str,
        payload: Vec<u8>,
    ) -> impl std::future::Future<Output = TypedDispatchClientResult<Vec<u8>>> + Send {
        let invocation = Request::InvokeService {
            capability: capability.to_owned(),
            kind,
            interface_id: interface_id.to_owned(),
            operation: operation.to_owned(),
            payload,
        };
        // Separate copies used only to label errors.
        let error_interface = interface_id.to_owned();
        let error_operation = operation.to_owned();
        async move {
            match self.request(invocation).await {
                Ok(ResponsePayload::ServiceInvoked { payload }) => Ok(payload),
                Ok(_) => Err(TypedDispatchClientError::unexpected_response(
                    error_interface,
                    error_operation,
                    "expected service invoked",
                )),
                Err(err) => Err(map_client_error(&error_interface, &error_operation, err)),
            }
        }
    }
}
// In-flight streaming requests keyed by request id. The reader task removes an entry and
// sends the decoded response (or an error) through its oneshot sender.
type PendingMap =
Arc<tokio::sync::Mutex<BTreeMap<u64, tokio::sync::oneshot::Sender<Result<Response>>>>>;
// Requests that timed out, keyed by request id, kept so that a late response can be
// recognised and logged by the reader task.
type TimedOutMap = Arc<tokio::sync::Mutex<BTreeMap<u64, TimedOutRequest>>>;
/// Bookkeeping record for a streaming request whose response did not arrive in time.
#[derive(Debug, Clone)]
struct TimedOutRequest {
// Request kind name, as produced by `request_kind_name`.
request: &'static str,
// Milliseconds elapsed since the request started when the timeout was recorded.
elapsed_ms: u128,
}
/// Records `reason` in `reason_slot` unless a reason is already stored.
///
/// The first stored reason wins; later calls leave it untouched. If the mutex is
/// poisoned the call does nothing.
fn store_stream_disconnect_reason(reason_slot: &Arc<StdMutex<Option<String>>>, reason: String) {
    let Ok(mut stored) = reason_slot.lock() else {
        return;
    };
    if stored.is_none() {
        *stored = Some(reason);
    }
}
/// Renders a transport error as a human-readable reason for the stream ending.
///
/// I/O errors get a dedicated message for unexpected EOF and for connection reset; all
/// other I/O kinds share a generic message.
fn format_stream_disconnect_reason(error: &IpcTransportError) -> String {
    match error {
        IpcTransportError::Io(io_error) => match io_error.kind() {
            std::io::ErrorKind::UnexpectedEof => {
                format!("stream closed with unexpected EOF: {io_error}")
            }
            std::io::ErrorKind::ConnectionReset => {
                format!("stream connection reset by peer: {io_error}")
            }
            _ => format!("stream I/O failure: {io_error}"),
        },
        IpcTransportError::FrameDecode(decode_error) => {
            format!("stream frame decode failure: {decode_error}")
        }
        IpcTransportError::FrameEncode(encode_error) => {
            format!("stream frame encode failure: {encode_error}")
        }
        IpcTransportError::UnsupportedEndpoint => {
            String::from("stream failed: unsupported endpoint")
        }
    }
}
/// Client that splits the stream into a writer (owned here) and a reader driven by a
/// spawned task, which routes responses to waiting requests and forwards events.
#[derive(Debug)]
pub struct StreamingBmuxClient {
// Write half of the stream; requests are sent through it.
writer: StreamingClientWriter,
// Limit applied separately to sending a request and to waiting for its response.
timeout: Duration,
// Id assigned to the next outgoing request; continues from the wrapped `BmuxClient`.
next_request_id: u64,
// Principal id carried over from the wrapped `BmuxClient`.
principal_id: Uuid,
// Protocol negotiated by the wrapped `BmuxClient`, if any.
negotiated_protocol: Option<NegotiatedProtocol>,
// In-flight requests awaiting a response; shared with the reader task.
pending: PendingMap,
// Requests that timed out; shared with the reader task for late-response logging.
timed_out: TimedOutMap,
// Receiving end of the channel the reader task pushes decoded events into.
event_rx: tokio::sync::mpsc::UnboundedReceiver<ServerEvent>,
// First recorded reason for the reader task stopping; `None` until the reader fails.
disconnect_reason: Arc<StdMutex<Option<String>>>,
// Handle of the spawned reader task; stored but not otherwise used in this file.
_reader_task: tokio::task::JoinHandle<()>,
}
/// Write half of a split stream, for either transport.
#[derive(Debug)]
enum StreamingClientWriter {
// Write half obtained by splitting a `LocalIpcStream`.
Local(IpcStreamWriter),
// Write half obtained by splitting an `ErasedIpcStream`.
Bridge(ErasedIpcStreamWriter),
}
impl StreamingClientWriter {
    /// Writes `envelope` on whichever write half this value wraps.
    async fn send_envelope(
        &mut self,
        envelope: &Envelope,
    ) -> std::result::Result<(), IpcTransportError> {
        match self {
            Self::Bridge(bridge) => bridge.send_envelope(envelope).await,
            Self::Local(local) => local.send_envelope(envelope).await,
        }
    }
}
/// Read half of a split stream, for either transport.
#[derive(Debug)]
enum StreamingClientReader {
// Read half obtained by splitting a `LocalIpcStream`.
Local(IpcStreamReader),
// Read half obtained by splitting an `ErasedIpcStream`.
Bridge(ErasedIpcStreamReader),
}
impl StreamingClientReader {
    /// Reads the next envelope from whichever read half this value wraps.
    async fn recv_envelope(&mut self) -> std::result::Result<Envelope, IpcTransportError> {
        match self {
            Self::Bridge(bridge) => bridge.recv_envelope().await,
            Self::Local(local) => local.recv_envelope().await,
        }
    }

    /// Turns on frame compression handling on the wrapped read half.
    const fn enable_frame_compression(&mut self) {
        match self {
            Self::Bridge(bridge) => bridge.enable_frame_compression(),
            Self::Local(local) => local.enable_frame_compression(),
        }
    }
}
impl StreamingBmuxClient {
pub fn from_client(client: BmuxClient) -> Result<Self> {
let BmuxClient {
stream,
timeout,
next_request_id,
principal_id,
negotiated_protocol,
} = client;
let (mut reader, mut writer) = match stream {
ClientStream::Local(local_stream) => {
let (reader, writer) = local_stream.into_split();
(
StreamingClientReader::Local(reader),
StreamingClientWriter::Local(writer),
)
}
ClientStream::Bridge(bridge_stream) => {
let (reader, writer) = bridge_stream.into_split();
(
StreamingClientReader::Bridge(reader),
StreamingClientWriter::Bridge(writer),
)
}
};
if let Some(ref negotiated) = negotiated_protocol
&& let Some(codec) = resolve_frame_codec_from_capabilities(&negotiated.capabilities)
{
match &mut writer {
StreamingClientWriter::Local(writer) => {
writer.enable_frame_compression(codec.clone());
}
StreamingClientWriter::Bridge(writer) => {
writer.enable_frame_compression(codec.clone());
}
}
reader.enable_frame_compression();
}
let pending: PendingMap = Arc::new(tokio::sync::Mutex::new(BTreeMap::new()));
let timed_out: TimedOutMap = Arc::new(tokio::sync::Mutex::new(BTreeMap::new()));
let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel();
let disconnect_reason = Arc::new(StdMutex::new(None));
let reader_pending = Arc::clone(&pending);
let reader_timed_out = Arc::clone(&timed_out);
let reader_disconnect_reason = Arc::clone(&disconnect_reason);
let reader_task = tokio::spawn(async move {
Self::reader_loop(
reader,
reader_pending,
reader_timed_out,
event_tx,
reader_disconnect_reason,
)
.await;
});
Ok(Self {
writer,
timeout,
next_request_id,
principal_id,
negotiated_protocol,
pending,
timed_out,
event_rx,
disconnect_reason,
_reader_task: reader_task,
})
}
/// Returns the protocol carried over from the wrapped client, or `None` if none is stored.
#[must_use]
pub const fn negotiated_protocol(&self) -> Option<&NegotiatedProtocol> {
self.negotiated_protocol.as_ref()
}
async fn reader_loop(
mut reader: StreamingClientReader,
pending: PendingMap,
timed_out: TimedOutMap,
event_tx: tokio::sync::mpsc::UnboundedSender<ServerEvent>,
disconnect_reason: Arc<StdMutex<Option<String>>>,
) {
loop {
let envelope = match reader.recv_envelope().await {
Ok(envelope) => envelope,
Err(error) => {
let reason = format_stream_disconnect_reason(&error);
store_stream_disconnect_reason(&disconnect_reason, reason.clone());
let pending_requests = std::mem::take(&mut *pending.lock().await);
for (_, tx) in pending_requests {
let io_error_kind = match &error {
IpcTransportError::Io(io_error) => io_error.kind(),
_ => std::io::ErrorKind::BrokenPipe,
};
let io_error = std::io::Error::new(io_error_kind, reason.clone());
let _ =
tx.send(Err(ClientError::Transport(IpcTransportError::Io(io_error))));
}
return;
}
};
match envelope.kind {
EnvelopeKind::Response => {
let response_tx = pending.lock().await.remove(&envelope.request_id);
if let Some(tx) = response_tx {
match decode::<Response>(&envelope.payload) {
Ok(response) => {
let _ = tx.send(Ok(response));
}
Err(e) => {
let _ = tx.send(Err(ClientError::Serialization(e)));
}
}
} else {
let timed_out_response =
timed_out.lock().await.remove(&envelope.request_id);
if let Some(timed_out) = timed_out_response {
warn!(
request_id = envelope.request_id,
request = timed_out.request,
timed_out_elapsed_ms = timed_out.elapsed_ms,
"streaming client received late response after timeout"
);
} else {
trace!(
request_id = envelope.request_id,
"streaming client received response for unknown request id"
);
}
}
}
EnvelopeKind::Event => match decode::<ServerEvent>(&envelope.payload) {
Ok(event) => {
let _ = event_tx.send(event);
}
Err(e) => {
warn!("streaming client failed to decode event: {e:#}");
}
},
EnvelopeKind::Request => {
warn!("streaming client received unexpected request envelope");
}
}
}
}
/// Gives mutable access to the receiver of events forwarded by the reader task.
pub const fn event_receiver(
&mut self,
) -> &mut tokio::sync::mpsc::UnboundedReceiver<ServerEvent> {
&mut self.event_rx
}
/// Returns a copy of the recorded disconnect reason, if any.
///
/// Yields `None` when no reason has been stored or when the mutex is poisoned.
#[must_use]
pub fn disconnect_reason(&self) -> Option<String> {
    let stored = self.disconnect_reason.lock().ok()?;
    stored.clone()
}
/// Returns the principal id carried over from the wrapped client.
#[must_use]
pub const fn principal_id(&self) -> Uuid {
self.principal_id
}
pub async fn request_raw(&mut self, request: Request) -> Result<Response> {
let request_id = self.take_request_id();
let request_kind = request_kind_name(&request);
let started_at = std::time::Instant::now();
debug!(
request_id,
request = request_kind,
"streaming_ipc.request.start"
);
let encode_started = std::time::Instant::now();
let payload = encode(&request)?;
let encode_us = encode_started.elapsed().as_micros();
let envelope = Envelope::new(request_id, EnvelopeKind::Request, payload);
let (tx, rx) = tokio::sync::oneshot::channel();
{
let mut map = self.pending.lock().await;
map.insert(request_id, tx);
}
let send_started = std::time::Instant::now();
if let Err(e) = tokio::time::timeout(self.timeout, self.writer.send_envelope(&envelope))
.await
.map_err(|_| ClientError::Timeout(self.timeout))?
{
self.pending.lock().await.remove(&request_id);
return Err(ClientError::Transport(e));
}
let send_us = send_started.elapsed().as_micros();
let recv_started = std::time::Instant::now();
let response = tokio::time::timeout(self.timeout, rx)
.await
.map_err(|_| {
let elapsed_ms = started_at.elapsed().as_millis();
warn!(
request_id,
request = request_kind,
timeout_ms = self.timeout.as_millis(),
elapsed_ms,
"streaming_ipc.request.timeout"
);
let pending = Arc::clone(&self.pending);
let timed_out = Arc::clone(&self.timed_out);
tokio::spawn(async move {
pending.lock().await.remove(&request_id);
timed_out.lock().await.insert(
request_id,
TimedOutRequest {
request: request_kind,
elapsed_ms,
},
);
});
ClientError::Timeout(self.timeout)
})?
.map_err(|_| {
ClientError::Transport(IpcTransportError::Io(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"reader task dropped before response",
)))
})??;
let recv_us = recv_started.elapsed().as_micros();
debug!(
request_id,
request = request_kind,
response = response_kind_name(&response),
duration_ms = started_at.elapsed().as_millis(),
"streaming_ipc.request.done"
);
emit_ipc_request_timing(
&request,
request_id,
response_kind_name(&response),
IpcClientTiming {
encode: encode_us,
send: send_us,
frame_encode: 0,
socket_write: 0,
recv: recv_us,
decode: 0,
total: started_at.elapsed().as_micros(),
},
);
Ok(response)
}
/// Performs `request_raw` and unwraps the response into its success payload.
///
/// # Errors
/// Turns `Response::Err` into [`ClientError::ServerError`] and propagates every error of
/// `request_raw`.
async fn request(&mut self, request: Request) -> Result<ResponsePayload> {
    match self.request_raw(request).await? {
        Response::Err(failure) => Err(ClientError::ServerError {
            code: failure.code,
            message: failure.message,
        }),
        Response::Ok(payload) => Ok(payload),
    }
}
/// Returns the current request id and advances the counter.
///
/// The counter wraps on overflow and is then clamped to 1, so it never becomes 0.
fn take_request_id(&mut self) -> u64 {
    let following = self.next_request_id.wrapping_add(1).max(1);
    std::mem::replace(&mut self.next_request_id, following)
}
/// Sends `request` without registering a waiter for its response.
///
/// A request id is still consumed. The write is not bounded by the client timeout.
///
/// # Errors
/// Returns [`ClientError::Serialization`] if encoding fails and
/// [`ClientError::Transport`] if the write fails.
pub async fn send_one_way(&mut self, request: Request) -> Result<()> {
    let request_id = self.take_request_id();
    let request_kind = request_kind_name(&request);
    trace!(
        request_id,
        request = request_kind,
        "streaming_ipc.one_way.send"
    );
    let encoded = encode(&request)?;
    let envelope = Envelope::new(request_id, EnvelopeKind::Request, encoded);
    match self.writer.send_envelope(&envelope).await {
        Ok(()) => Ok(()),
        Err(transport_error) => Err(ClientError::Transport(transport_error)),
    }
}
/// Sends `EnableEventPush` and expects the `EventPushEnabled` acknowledgement.
///
/// # Errors
/// Returns [`ClientError::UnexpectedResponse`] for any other payload, or the request error.
pub async fn enable_event_push(&mut self) -> Result<()> {
    let reply = self.request(Request::EnableEventPush).await?;
    if matches!(reply, ResponsePayload::EventPushEnabled) {
        Ok(())
    } else {
        Err(ClientError::UnexpectedResponse(
            "expected event push enabled",
        ))
    }
}
/// Sends `Ping` and expects `Pong`.
///
/// # Errors
/// Returns [`ClientError::UnexpectedResponse`] for any other payload, or the request error.
pub async fn ping(&mut self) -> Result<()> {
    let reply = self.request(Request::Ping).await?;
    if matches!(reply, ResponsePayload::Pong) {
        Ok(())
    } else {
        Err(ClientError::UnexpectedResponse("expected pong"))
    }
}
/// Asks the server for the principal identity of this connection.
///
/// # Errors
/// Returns [`ClientError::UnexpectedResponse`] if the reply is not a principal identity,
/// or the request error.
pub async fn whoami_principal(&mut self) -> Result<PrincipalIdentityInfo> {
    let reply = self.request(Request::WhoAmIPrincipal).await?;
    let ResponsePayload::PrincipalIdentity {
        principal_id,
        server_control_principal_id,
        force_local_permitted,
    } = reply
    else {
        return Err(ClientError::UnexpectedResponse(
            "expected principal identity",
        ));
    };
    Ok(PrincipalIdentityInfo {
        principal_id,
        server_control_principal_id,
        force_local_permitted,
    })
}
/// Sends `SubscribeEvents` and expects the `EventsSubscribed` acknowledgement.
///
/// # Errors
/// Returns [`ClientError::UnexpectedResponse`] for any other payload, or the request error.
pub async fn subscribe_events(&mut self) -> Result<()> {
    let reply = self.request(Request::SubscribeEvents).await?;
    if matches!(reply, ResponsePayload::EventsSubscribed) {
        Ok(())
    } else {
        Err(ClientError::UnexpectedResponse(
            "expected events subscribed",
        ))
    }
}
/// Sends `PollEvents` with the given `max_events` and returns the events in the reply.
///
/// # Errors
/// Returns [`ClientError::UnexpectedResponse`] if the reply is not an event batch,
/// or the request error.
pub async fn poll_events(&mut self, max_events: usize) -> Result<Vec<ServerEvent>> {
    let reply = self.request(Request::PollEvents { max_events }).await?;
    let ResponsePayload::EventBatch { events } = reply else {
        return Err(ClientError::UnexpectedResponse(
            "expected event batch response",
        ));
    };
    Ok(events)
}
/// Invokes a service operation with an opaque byte payload and returns the reply bytes.
///
/// On success a `service.client_invoke` phase timing record is emitted on the service
/// channel; nothing is emitted on failure.
///
/// # Errors
/// Returns [`ClientError::UnexpectedResponse`] if the reply is not `ServiceInvoked`,
/// or the request error.
pub async fn invoke_service_raw(
    &mut self,
    capability: impl Into<String>,
    kind: InvokeServiceKind,
    interface_id: impl Into<String>,
    operation: impl Into<String>,
    payload: Vec<u8>,
) -> Result<Vec<u8>> {
    let capability = capability.into();
    let interface_id = interface_id.into();
    let operation = operation.into();
    let request_len = payload.len();
    let timer = PhaseTimer::start();
    // The names are cloned into the request because they are needed again for telemetry.
    let invocation = Request::InvokeService {
        capability: capability.clone(),
        kind,
        interface_id: interface_id.clone(),
        operation: operation.clone(),
        payload,
    };
    let ResponsePayload::ServiceInvoked { payload: reply } = self.request(invocation).await?
    else {
        return Err(ClientError::UnexpectedResponse("expected service invoked"));
    };
    let phase = service_client_invoke_phase_payload(
        &capability,
        kind,
        &interface_id,
        &operation,
        request_len,
        reply.len(),
        timer.elapsed_us(),
    );
    emit_phase_timing(PhaseChannel::Service, &phase);
    Ok(reply)
}
/// Submits a service pipeline and returns the per-step results.
///
/// On success a `service_pipeline.client_request` phase timing record is emitted.
///
/// # Errors
/// Returns [`ClientError::UnexpectedResponse`] if the reply is not
/// `ServicePipelineInvoked`, or the request error.
pub async fn invoke_service_pipeline_raw(
    &mut self,
    pipeline: ServicePipelineRequest,
) -> Result<Vec<ServicePipelineStepResult>> {
    let submitted_steps = pipeline.steps.len();
    let timer = PhaseTimer::start();
    let reply = self
        .request(Request::InvokeServicePipeline { pipeline })
        .await?;
    let ResponsePayload::ServicePipelineInvoked { results } = reply else {
        return Err(ClientError::UnexpectedResponse(
            "expected service pipeline invoked",
        ));
    };
    let phase =
        service_client_pipeline_phase_payload(submitted_steps, results.len(), timer.elapsed_us());
    emit_phase_timing(PhaseChannel::Service, &phase);
    Ok(results)
}
/// Asks the server to emit `payload` on the plugin bus under `kind`.
///
/// Returns the `emitted` flag from the server's reply.
///
/// # Errors
/// Returns [`ClientError::UnexpectedResponse`] if the reply is not `PluginBusEmitted`,
/// or the request error.
pub async fn emit_on_plugin_bus(
    &mut self,
    kind: impl Into<String>,
    payload: Vec<u8>,
) -> Result<bool> {
    let emission = Request::EmitOnPluginBus {
        kind: kind.into(),
        payload,
    };
    let ResponsePayload::PluginBusEmitted { emitted } = self.request(emission).await? else {
        return Err(ClientError::UnexpectedResponse(
            "expected plugin bus emitted",
        ));
    };
    Ok(emitted)
}
}
impl TypedDispatchClient for StreamingBmuxClient {
    /// Typed-dispatch entry point: invokes a service operation through `request_raw` and
    /// maps every failure into a [`TypedDispatchClientError`] tagged with the interface
    /// and operation.
    ///
    /// All borrowed strings are copied before the future is created, so the returned
    /// future does not hold on to them.
    fn invoke_service_raw(
        &mut self,
        capability: &str,
        kind: InvokeServiceKind,
        interface_id: &str,
        operation: &str,
        payload: Vec<u8>,
    ) -> impl std::future::Future<Output = TypedDispatchClientResult<Vec<u8>>> + Send {
        let invocation = Request::InvokeService {
            capability: capability.to_owned(),
            kind,
            interface_id: interface_id.to_owned(),
            operation: operation.to_owned(),
            payload,
        };
        // Separate copies used only to label errors.
        let error_interface = interface_id.to_owned();
        let error_operation = operation.to_owned();
        async move {
            match self.request_raw(invocation).await {
                Ok(Response::Ok(ResponsePayload::ServiceInvoked { payload })) => Ok(payload),
                Ok(Response::Ok(_)) => Err(TypedDispatchClientError::unexpected_response(
                    error_interface,
                    error_operation,
                    "expected service invoked",
                )),
                Ok(Response::Err(failure)) => Err(TypedDispatchClientError::server(
                    error_interface,
                    error_operation,
                    format!("{:?}: {}", failure.code, failure.message),
                )),
                Err(err) => Err(map_client_error(&error_interface, &error_operation, err)),
            }
        }
    }
}
/// Maps a request to the fixed snake_case name used in log fields and telemetry.
///
/// The match is exhaustive on purpose: adding a `Request` variant forces a name to be
/// chosen here.
const fn request_kind_name(request: &Request) -> &'static str {
match request {
Request::Hello { .. } => "hello",
Request::Ping => "ping",
Request::WhoAmIPrincipal => "whoami_principal",
Request::ServerStatus => "server_status",
Request::ServerStop => "server_stop",
Request::InvokeService { .. } => "invoke_service",
Request::InvokeServicePipeline { .. } => "invoke_service_pipeline",
Request::EmitOnPluginBus { .. } => "emit_on_plugin_bus",
Request::SubscribeEvents => "subscribe_events",
Request::PollEvents { .. } => "poll_events",
Request::EnableEventPush => "enable_event_push",
}
}
/// Checks that `response_envelope` answers the request identified by `request_id`.
///
/// The request id is checked first, then the envelope kind. Each rejection is logged at
/// warn level with the time elapsed since `started_at`.
///
/// # Errors
/// - [`ClientError::RequestIdMismatch`] if the envelope carries a different request id.
/// - [`ClientError::UnexpectedEnvelopeKind`] if the envelope is not a response.
fn validate_response_envelope(
    response_envelope: &Envelope,
    request_id: u64,
    request_kind: &str,
    started_at: &std::time::Instant,
) -> Result<()> {
    let actual_request_id = response_envelope.request_id;
    if actual_request_id != request_id {
        warn!(
            request_id,
            request = request_kind,
            actual_request_id,
            duration_ms = started_at.elapsed().as_millis(),
            "ipc.request.id_mismatch"
        );
        return Err(ClientError::RequestIdMismatch {
            expected: request_id,
            actual: actual_request_id,
        });
    }

    let actual_kind = response_envelope.kind;
    if actual_kind == EnvelopeKind::Response {
        return Ok(());
    }
    warn!(
        request_id,
        request = request_kind,
        ?actual_kind,
        duration_ms = started_at.elapsed().as_millis(),
        "ipc.request.unexpected_envelope_kind"
    );
    Err(ClientError::UnexpectedEnvelopeKind {
        expected: EnvelopeKind::Response,
        actual: actual_kind,
    })
}
/// Per-phase durations of one client request, all in microseconds.
///
/// The streaming client reports 0 for `frame_encode`, `socket_write` and `decode`.
#[derive(Debug, Clone, Copy)]
struct IpcClientTiming {
// Time spent encoding the request payload.
encode: u128,
// Wall time of the send step as measured by the client.
send: u128,
// Frame encoding time reported by the transport's write timing.
frame_encode: u128,
// Socket write time reported by the transport's write timing.
socket_write: u128,
// Time spent waiting for the response.
recv: u128,
// Time spent decoding the response payload.
decode: u128,
// Time from request start to completion.
total: u128,
}
/// Builds the `ipc.client_request` phase payload and emits it on the IPC channel.
fn emit_ipc_request_timing(
    request: &Request,
    request_id: u64,
    response: &'static str,
    timing: IpcClientTiming,
) {
    let phase = ipc_client_request_phase_payload(request, request_id, response, timing);
    emit_phase_timing(PhaseChannel::Ipc, &phase);
}
/// Builds the `service.client_invoke` telemetry payload for one service invocation.
///
/// `kind` is recorded using its `Debug` rendering.
fn service_client_invoke_phase_payload(
    capability: &str,
    kind: InvokeServiceKind,
    interface_id: &str,
    operation: &str,
    payload_len: usize,
    response_len: usize,
    total_us: u128,
) -> serde_json::Value {
    let kind_label = format!("{kind:?}");
    let builder = PhasePayload::new("service.client_invoke")
        .service_fields(capability, kind_label, interface_id, operation);
    let builder = builder.field("payload_len", payload_len);
    let builder = builder.field("response_len", response_len);
    let builder = builder.field("total_us", total_us);
    builder.finish()
}
/// Builds the `service_pipeline.client_request` telemetry payload for one pipeline call.
fn service_client_pipeline_phase_payload(
    step_count: usize,
    result_count: usize,
    total_us: u128,
) -> serde_json::Value {
    let builder = PhasePayload::new("service_pipeline.client_request");
    let builder = builder.field("step_count", step_count);
    let builder = builder.field("result_count", result_count);
    let builder = builder.field("total_us", total_us);
    builder.finish()
}
/// Builds the `ipc.client_request` telemetry payload for one request/response exchange.
///
/// Encode, receive and decode durations are each written under two field names.
/// Service invocations additionally carry the service identification and payload length;
/// pipeline requests carry their step count.
fn ipc_client_request_phase_payload(
    request: &Request,
    request_id: u64,
    response: &'static str,
    timing: IpcClientTiming,
) -> serde_json::Value {
    let common = PhasePayload::new("ipc.client_request")
        .field("request", request_kind_name(request))
        .field("request_id", request_id)
        .field("response", response)
        .field("encode_us", timing.encode)
        .field("request_encode_us", timing.encode)
        .field("send_us", timing.send)
        .field("frame_encode_us", timing.frame_encode)
        .field("socket_write_us", timing.socket_write)
        .field("recv_us", timing.recv)
        .field("response_read_us", timing.recv)
        .field("decode_us", timing.decode)
        .field("response_decode_us", timing.decode)
        .field("total_us", timing.total);
    let complete = match request {
        Request::InvokeService {
            capability,
            kind,
            interface_id,
            operation,
            payload: service_payload,
        } => common
            .service_fields(capability, format!("{kind:?}"), interface_id, operation)
            .field("service_payload_len", service_payload.len()),
        Request::InvokeServicePipeline { pipeline } => {
            common.field("pipeline_step_count", pipeline.steps.len())
        }
        _ => common,
    };
    complete.finish()
}
/// Maps a response to a static snake_case label.
///
/// Successful responses are labelled after their payload variant; any error
/// response is labelled `"error"`. The match is deliberately exhaustive (no
/// wildcard arm) so that adding a payload variant forces this table to be
/// updated.
const fn response_kind_name(response: &Response) -> &'static str {
    match response {
        Response::Err(_) => "error",
        Response::Ok(ResponsePayload::Pong) => "pong",
        Response::Ok(ResponsePayload::PrincipalIdentity { .. }) => "principal_identity",
        Response::Ok(ResponsePayload::HelloNegotiated { .. }) => "hello_negotiated",
        Response::Ok(ResponsePayload::HelloIncompatible { .. }) => "hello_incompatible",
        Response::Ok(ResponsePayload::ServerStatus { .. }) => "server_status",
        Response::Ok(ResponsePayload::ServerStopping) => "server_stopping",
        Response::Ok(ResponsePayload::ServiceInvoked { .. }) => "service_invoked",
        Response::Ok(ResponsePayload::ServicePipelineInvoked { .. }) => {
            "service_pipeline_invoked"
        }
        Response::Ok(ResponsePayload::EventsSubscribed) => "events_subscribed",
        Response::Ok(ResponsePayload::EventBatch { .. }) => "event_batch",
        Response::Ok(ResponsePayload::EventPushEnabled) => "event_push_enabled",
        Response::Ok(ResponsePayload::PluginBusEmitted { .. }) => "plugin_bus_emitted",
    }
}
fn resolve_frame_codec_from_capabilities(
capabilities: &[String],
) -> Option<Arc<dyn bmux_ipc::compression::CompressionCodec>> {
use bmux_ipc::compression;
if capabilities
.iter()
.any(|c| c == bmux_ipc::CAPABILITY_COMPRESSION_FRAME_LZ4)
{
compression::resolve_codec(CompressionId::Lz4).map(Arc::from)
} else if capabilities
.iter()
.any(|c| c == bmux_ipc::CAPABILITY_COMPRESSION_FRAME_ZSTD)
{
compression::resolve_codec(CompressionId::Zstd).map(Arc::from)
} else {
None
}
}
/// Builds the server IPC endpoint from the configured paths.
///
/// The endpoint flavour is chosen at compile time: Unix targets use the
/// server socket path, Windows targets use the server named pipe.
fn endpoint_from_paths(paths: &ConfigPaths) -> IpcEndpoint {
// Exactly one of the cfg-gated blocks below survives compilation on a given
// target and becomes the function's tail expression.
#[cfg(unix)]
{
IpcEndpoint::unix_socket(paths.server_socket())
}
#[cfg(windows)]
{
IpcEndpoint::windows_named_pipe(paths.server_named_pipe())
}
}
/// Loads the persisted principal id, creating and storing a fresh one if the
/// file does not exist yet.
///
/// The parent directory of the principal id file is created first when the
/// path has one. An existing file is read, trimmed and parsed as a UUID. When
/// the file is missing, a new v4 UUID is generated, written to the file and
/// returned.
///
/// # Errors
///
/// * [`ClientError::PrincipalIdWrite`] if the parent directory or the new
///   file cannot be written.
/// * [`ClientError::PrincipalIdParse`] if the file exists but its trimmed
///   contents are not a valid UUID.
/// * [`ClientError::PrincipalIdRead`] if reading fails for any reason other
///   than the file not being found.
fn load_or_create_principal_id(paths: &ConfigPaths) -> Result<Uuid> {
    let id_file = paths.principal_id_file();
    // Every error variant reports the path of the id file itself.
    let shown_path = || id_file.display().to_string();

    if let Some(dir) = id_file.parent() {
        std::fs::create_dir_all(dir).map_err(|source| ClientError::PrincipalIdWrite {
            path: shown_path(),
            source,
        })?;
    }

    let read_error = match std::fs::read_to_string(&id_file) {
        Ok(text) => {
            let candidate = text.trim();
            return Uuid::parse_str(candidate).map_err(|_| ClientError::PrincipalIdParse {
                path: shown_path(),
                value: candidate.to_owned(),
            });
        }
        Err(error) => error,
    };

    // Anything other than "file does not exist" is a genuine read failure.
    if read_error.kind() != std::io::ErrorKind::NotFound {
        return Err(ClientError::PrincipalIdRead {
            path: shown_path(),
            source: read_error,
        });
    }

    let fresh_id = Uuid::new_v4();
    std::fs::write(&id_file, fresh_id.to_string()).map_err(|source| {
        ClientError::PrincipalIdWrite {
            path: shown_path(),
            source,
        }
    })?;
    Ok(fresh_id)
}
#[cfg(test)]
mod tests {
    use super::{
        BmuxClient, ClientStream, ConfigPaths, StreamingBmuxClient, load_or_create_principal_id,
    };
    use bmux_ipc::transport::ErasedIpcStream;
    use std::fs;
    use std::time::Duration;
    use tempfile::TempDir;
    use uuid::Uuid;

    /// Creates a scratch directory that is removed when the handle is dropped.
    fn temp_dir() -> TempDir {
        tempfile::Builder::new()
            .prefix("bmux-client-test-")
            .tempdir()
            .expect("temp dir should be created")
    }

    /// Builds a `ConfigPaths` whose four roots all live under `root`.
    fn paths_under(root: &TempDir) -> ConfigPaths {
        let base = root.path();
        ConfigPaths::new(
            base.join("config"),
            base.join("runtime"),
            base.join("data"),
            base.join("state"),
        )
    }

    /// A second call must return the id persisted by the first call.
    #[test]
    fn load_or_create_principal_id_creates_and_persists_value() {
        let root = temp_dir();
        let paths = paths_under(&root);

        let created = load_or_create_principal_id(&paths).expect("principal id should be created");
        let reloaded = load_or_create_principal_id(&paths).expect("principal id should be reused");

        assert_eq!(created, reloaded);
    }

    /// A file that does not contain a UUID must surface a parse error.
    #[test]
    fn load_or_create_principal_id_rejects_invalid_file_contents() {
        let root = temp_dir();
        let paths = paths_under(&root);

        let id_file = paths.principal_id_file();
        if let Some(dir) = id_file.parent() {
            fs::create_dir_all(dir).expect("principal parent should exist");
        }
        fs::write(&id_file, "not-a-uuid").expect("principal file should be written");

        let failure =
            load_or_create_principal_id(&paths).expect_err("invalid principal should fail");
        assert!(failure.to_string().contains("invalid principal id"));
    }

    /// Upgrading a client backed by a bridge stream keeps its principal id.
    #[tokio::test]
    async fn streaming_client_upgrade_accepts_bridge_stream() {
        // The peer half is kept alive for the duration of the test.
        let (bridge_stream, _peer_stream) = tokio::io::duplex(8 * 1024);
        let principal_id = Uuid::new_v4();
        let stream = ClientStream::Bridge(ErasedIpcStream::new(Box::new(bridge_stream)));
        let client = BmuxClient {
            stream,
            timeout: Duration::from_millis(500),
            next_request_id: 1,
            principal_id,
            negotiated_protocol: None,
        };

        let streaming =
            StreamingBmuxClient::from_client(client).expect("bridge stream upgrade should work");

        assert_eq!(streaming.principal_id(), principal_id);
    }
}