use crate::broker::backend_sdk::{FrameClient, FrameClientError};
use crate::broker::client::{
broker_disabled_by_env, connect_to_backend, BackendConnection, BackendConnectionRoute,
BrokerClientError, BrokerDisableEnvError, ConnectBackendRequest,
};
use crate::broker::protocol::{Frame, Negotiated};
/// A backend connection obtained through the broker, wrapped in a blocking
/// [`FrameClient`].
///
/// Besides the frame client, the session keeps the route, endpoint and
/// negotiation result taken from the `BackendConnection` it was built from.
pub struct BrokerSession {
/// Frame client built from the backend connection's stream.
client: FrameClient,
/// Route copied from `BackendConnection::route`.
route: BackendConnectionRoute,
/// Endpoint copied from `BackendConnection::endpoint`.
endpoint: String,
/// Negotiation result copied from `BackendConnection::negotiated`, if any.
negotiated: Option<Negotiated>,
}
impl BrokerSession {
pub fn adopt(request: ConnectBackendRequest<'_>) -> Result<Self, AdoptError> {
if broker_disabled_by_env()? {
return Err(AdoptError::BrokerDisabled);
}
Ok(Self::from_connection(connect_to_backend(request)?))
}
fn from_connection(connection: BackendConnection) -> Self {
Self {
client: FrameClient::from_stream(connection.stream),
route: connection.route,
endpoint: connection.endpoint,
negotiated: connection.negotiated,
}
}
pub fn route(&self) -> BackendConnectionRoute {
self.route
}
pub fn endpoint(&self) -> &str {
&self.endpoint
}
pub fn negotiated(&self) -> Option<&Negotiated> {
self.negotiated.as_ref()
}
pub fn request(
&mut self,
payload_protocol: u32,
payload: Vec<u8>,
) -> Result<Frame, FrameClientError> {
self.client.request(payload_protocol, payload)
}
pub fn client_mut(&mut self) -> &mut FrameClient {
&mut self.client
}
pub fn into_client(self) -> FrameClient {
self.client
}
pub fn into_backend_io(self) -> Result<OwnedBackendIo, IntoBackendIoError> {
let buffered = self.client.buffered_len();
if buffered != 0 {
return Err(IntoBackendIoError::BufferedResidual { buffered });
}
OwnedBackendIo::from_local_socket_stream(self.client.into_stream())
}
}
/// Raw backend socket handed off from a session's frame client.
///
/// On Unix it owns the file descriptor taken from the Unix-domain socket
/// stream. On Windows it has no fields, and its constructor in this file
/// always returns `IntoBackendIoError::WindowsUnsupported`.
#[derive(Debug)]
pub struct OwnedBackendIo {
/// Owned descriptor of the handed-off Unix-domain socket.
#[cfg(unix)]
fd: std::os::fd::OwnedFd,
}
impl OwnedBackendIo {
    /// Takes ownership of the descriptor behind a local-socket stream.
    ///
    /// On Unix the only stream variant is a Unix-domain socket, so this never
    /// fails; the `Result` mirrors the Windows signature.
    #[cfg(unix)]
    fn from_local_socket_stream(
        stream: interprocess::local_socket::Stream,
    ) -> Result<Self, IntoBackendIoError> {
        let interprocess::local_socket::Stream::UdSocket(socket) = stream;
        let fd: std::os::fd::OwnedFd = socket.into();
        Ok(Self { fd })
    }

    /// Windows stub. The owned-handle hand-off is not implemented, so this
    /// always returns `IntoBackendIoError::WindowsUnsupported` and drops the
    /// stream.
    #[cfg(windows)]
    fn from_local_socket_stream(
        _stream: interprocess::local_socket::Stream,
    ) -> Result<Self, IntoBackendIoError> {
        Err(IntoBackendIoError::WindowsUnsupported)
    }

    /// Consumes the wrapper and returns the owned socket descriptor.
    #[cfg(unix)]
    pub fn into_owned_fd(self) -> std::os::fd::OwnedFd {
        let Self { fd } = self;
        fd
    }
}
#[cfg(unix)]
impl std::os::fd::AsFd for OwnedBackendIo {
    /// Borrows the owned socket descriptor without giving up ownership.
    fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
        std::os::fd::AsFd::as_fd(&self.fd)
    }
}
/// Errors from `into_backend_io()` on [`BrokerSession`] and, with the
/// `client-async` feature, on `AsyncBrokerSession`.
#[derive(Debug, thiserror::Error)]
pub enum IntoBackendIoError {
/// The frame client still reports buffered response bytes. Handing off the
/// raw socket at this point would lose them, so the conversion is refused
/// before the stream is touched.
#[error(
"frame client has {buffered} buffered response byte(s); cannot hand off the raw socket without losing them"
)]
BufferedResidual {
/// Byte count reported by the frame client's `buffered_len()`.
buffered: usize,
},
/// `AsyncFrameClient::into_blocking()` returned `None`. The error message
/// attributes this to a panic during an earlier request.
#[cfg(feature = "client-async")]
#[error("async frame client was poisoned by a prior request panic")]
Poisoned,
/// Returned on Windows, where handing off the socket as an owned handle is
/// not implemented yet.
#[cfg(windows)]
#[error("into_backend_io() is not yet supported on Windows; the OwnedHandle path is deferred (#720)")]
WindowsUnsupported,
}
/// Errors from [`BrokerSession::adopt`] and, with the `client-async` feature,
/// from `AsyncBrokerSession::adopt`.
#[derive(Debug, thiserror::Error)]
pub enum AdoptError {
/// `broker_disabled_by_env()` reported the broker as disabled. The message
/// tells callers to use the direct path instead.
#[error("broker disabled via RUNNING_PROCESS_DISABLE=1; use the direct path")]
BrokerDisabled,
/// `broker_disabled_by_env()` itself failed.
#[error(transparent)]
DisableEnv(#[from] BrokerDisableEnvError),
/// `connect_to_backend` failed.
#[error(transparent)]
Connect(#[from] BrokerClientError),
/// The blocking worker task used by the async adopt path failed to join.
/// Carries the join error rendered with `to_string()`.
#[cfg(feature = "client-async")]
#[error("async adopt worker failed to join: {0}")]
AsyncJoin(String),
}
/// Owned counterpart of [`ConnectBackendRequest`].
///
/// Every field is forwarded under the same name by `as_request`. Owning the
/// data lets the request be moved into the blocking worker spawned by
/// `AsyncBrokerSession::adopt`.
#[cfg(feature = "client-async")]
#[derive(Clone, Debug)]
pub struct OwnedConnectRequest {
/// Forwarded as `broker_endpoint`; set by the first argument of `new`.
pub broker_endpoint: String,
/// Forwarded as `service_name`; set by the second argument of `new`.
pub service_name: String,
/// Forwarded as `wanted_version`; set by the third argument of `new`.
pub wanted_version: String,
/// Forwarded as `self_version`; set by the fourth argument of `new`.
pub self_version: String,
/// Forwarded via `as_deref()` as `cached_backend_endpoint`; `new` leaves it `None`.
pub cached_backend_endpoint: Option<String>,
/// Forwarded as `client_version`; `new` leaves it empty.
pub client_version: String,
/// Forwarded as `client_lib_name`; `new` sets `"running-process"`.
pub client_lib_name: String,
/// Forwarded as `client_lib_version`; `new` uses this crate's `CARGO_PKG_VERSION`.
pub client_lib_version: String,
/// Forwarded as `client_keepalive_secs`; `new` sets 0. Presumably seconds, per the
/// name. NOTE(review): what 0 means is decided by the broker client — confirm.
pub client_keepalive_secs: u64,
/// Forwarded as `adopt_handed_off_connection`; `new` sets `false`.
pub adopt_handed_off_connection: bool,
/// Forwarded as `handoff_ready_timeout`; `new` uses `DEFAULT_HANDOFF_READY_TIMEOUT`.
pub handoff_ready_timeout: std::time::Duration,
}
#[cfg(feature = "client-async")]
impl OwnedConnectRequest {
    /// Creates a request from the four required values. All other fields get
    /// fixed defaults:
    ///
    /// - no cached backend endpoint;
    /// - an empty client version;
    /// - client library `"running-process"` at this crate's version;
    /// - keepalive 0 and no handed-off-connection adoption;
    /// - the broker client's default handoff-ready timeout.
    pub fn new(
        broker_endpoint: impl Into<String>,
        service_name: impl Into<String>,
        wanted_version: impl Into<String>,
        self_version: impl Into<String>,
    ) -> Self {
        let broker_endpoint: String = broker_endpoint.into();
        let service_name: String = service_name.into();
        let wanted_version: String = wanted_version.into();
        let self_version: String = self_version.into();
        Self {
            broker_endpoint,
            service_name,
            wanted_version,
            self_version,
            cached_backend_endpoint: None,
            client_version: String::new(),
            client_lib_name: String::from("running-process"),
            client_lib_version: String::from(env!("CARGO_PKG_VERSION")),
            client_keepalive_secs: 0,
            adopt_handed_off_connection: false,
            handoff_ready_timeout: crate::broker::client::DEFAULT_HANDOFF_READY_TIMEOUT,
        }
    }

    /// Borrows this owned request as a [`ConnectBackendRequest`], field for
    /// field.
    fn as_request(&self) -> ConnectBackendRequest<'_> {
        let Self {
            broker_endpoint,
            service_name,
            wanted_version,
            self_version,
            cached_backend_endpoint,
            client_version,
            client_lib_name,
            client_lib_version,
            client_keepalive_secs,
            adopt_handed_off_connection,
            handoff_ready_timeout,
        } = self;
        ConnectBackendRequest {
            broker_endpoint,
            service_name,
            wanted_version,
            self_version,
            cached_backend_endpoint: cached_backend_endpoint.as_deref(),
            client_version,
            client_lib_name,
            client_lib_version,
            client_keepalive_secs: *client_keepalive_secs,
            adopt_handed_off_connection: *adopt_handed_off_connection,
            handoff_ready_timeout: *handoff_ready_timeout,
        }
    }
}
/// Async counterpart of [`BrokerSession`].
///
/// It wraps an `AsyncFrameClient` created with `AsyncFrameClient::from_blocking`
/// from the blocking session's frame client, and keeps the same connection
/// metadata.
#[cfg(feature = "client-async")]
pub struct AsyncBrokerSession {
/// Async frame client wrapping the blocking client from the adopted session.
client: crate::broker::backend_sdk::AsyncFrameClient,
/// Route taken from the adopted blocking session.
route: BackendConnectionRoute,
/// Endpoint taken from the adopted blocking session.
endpoint: String,
/// Negotiation result taken from the adopted blocking session, if any.
negotiated: Option<Negotiated>,
}
#[cfg(feature = "client-async")]
impl AsyncBrokerSession {
    /// Runs [`BrokerSession::adopt`] on tokio's blocking pool via
    /// `spawn_blocking`, then wraps the resulting frame client for async use.
    ///
    /// # Errors
    ///
    /// - Every error [`BrokerSession::adopt`] can return.
    /// - [`AdoptError::AsyncJoin`] if the blocking worker task fails to join.
    pub async fn adopt(request: OwnedConnectRequest) -> Result<Self, AdoptError> {
        let worker = tokio::task::spawn_blocking(move || {
            let session = BrokerSession::adopt(request.as_request())?;
            Ok::<_, AdoptError>((
                session.route,
                session.endpoint,
                session.negotiated,
                session.client,
            ))
        });
        let outcome = match worker.await {
            Ok(outcome) => outcome,
            Err(join_error) => return Err(AdoptError::AsyncJoin(join_error.to_string())),
        };
        let (route, endpoint, negotiated, blocking_client) = outcome?;
        let client = crate::broker::backend_sdk::AsyncFrameClient::from_blocking(blocking_client);
        Ok(Self {
            client,
            route,
            endpoint,
            negotiated,
        })
    }

    /// Route recorded when this session was created.
    pub fn route(&self) -> BackendConnectionRoute {
        self.route
    }

    /// Backend endpoint recorded when this session was created.
    pub fn endpoint(&self) -> &str {
        self.endpoint.as_str()
    }

    /// Negotiation result recorded when this session was created, if any.
    pub fn negotiated(&self) -> Option<&Negotiated> {
        self.negotiated.as_ref()
    }

    /// Forwards `payload_protocol` and `payload` to the async frame client's
    /// `request` and returns its result unchanged.
    pub async fn request(
        &mut self,
        payload_protocol: u32,
        payload: Vec<u8>,
    ) -> Result<Frame, FrameClientError> {
        let client = &mut self.client;
        client.request(payload_protocol, payload).await
    }

    /// Consumes the session and returns only the async frame client.
    pub fn into_client(self) -> crate::broker::backend_sdk::AsyncFrameClient {
        let Self { client, .. } = self;
        client
    }

    /// Consumes the session and hands off the raw backend socket.
    ///
    /// # Errors
    ///
    /// - [`IntoBackendIoError::Poisoned`] if the async client cannot be turned
    ///   back into a blocking one (`into_blocking()` returned `None`).
    /// - [`IntoBackendIoError::BufferedResidual`] if the blocking client still
    ///   reports buffered bytes.
    /// - Any error from the platform-specific socket conversion.
    pub fn into_backend_io(self) -> Result<OwnedBackendIo, IntoBackendIoError> {
        let Some(blocking) = self.client.into_blocking() else {
            return Err(IntoBackendIoError::Poisoned);
        };
        match blocking.buffered_len() {
            0 => OwnedBackendIo::from_local_socket_stream(blocking.into_stream()),
            buffered => Err(IntoBackendIoError::BufferedResidual { buffered }),
        }
    }
}