use std::collections::{BTreeMap, HashMap};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use std::{panic, panic::AssertUnwindSafe};
use rmpv::Value;
use rpc_runtime_activation::{
ACTIVATION_INSTANCE_ID_VALUE, ActivationMode, CREATE_INSTANCE_METHOD_ID,
CreateInstanceResponse, InstanceDescriptor, LIST_INSTANCES_METHOD_ID, ListInstancesResponse,
RELEASE_INSTANCE_METHOD_ID, RESOLVE_INSTANCE_IDS_METHOD_ID, ReleaseInstanceResponse,
ResolveInstanceIdsResponse, activation_instance_id, activation_service_guid,
decode_create_instance_request, decode_list_instances_request, decode_release_instance_request,
decode_resolve_instance_ids_request, encode_create_instance_response,
encode_list_instances_response, encode_release_instance_response,
encode_resolve_instance_ids_response,
};
use rpc_runtime_core::{
CapabilityFlags, Envelope, HelloAck, InstanceId, MethodId, Notification, Options,
RUNTIME_PROTOCOL_VERSION, Request, RequestId, ResponseError, ResponseOk, Role, ServiceGuid,
};
use rpc_runtime_errors::{RuntimeError, RuntimeErrorCode};
pub use rpc_runtime_transport::ConnectionScope;
use rpc_runtime_transport::{RpcConnection, RpcListener, RpcReceiver, RpcSender, TransportError};
use tokio::sync::RwLock;
use tracing::{debug, error, info, trace, warn};
/// Boxed, pinned, `Send` future returned by [`RpcServiceHandler::call`].
///
/// It resolves to the `Value` payload of a successful call or to a [`RuntimeError`].
pub type HandlerFuture = Pin<Box<dyn Future<Output = Result<Value, RuntimeError>> + Send>>;
/// Handler that services calls addressed to a registered instance.
///
/// Implementations must be `Send + Sync` because they are stored behind `Arc<dyn RpcServiceHandler>`.
pub trait RpcServiceHandler: Send + Sync {
/// Starts handling one call of `method_id` with `payload` and returns the future that yields
/// the outcome. `ctx` identifies the connection and instance the call was addressed to.
fn call(&self, ctx: RpcCallContext, method_id: MethodId, payload: Value) -> HandlerFuture;
}
/// Lets any `Send + Sync + 'static` closure of the right shape be used as a service handler.
impl<F> RpcServiceHandler for F
where
    F: Fn(RpcCallContext, MethodId, Value) -> HandlerFuture + Send + Sync + 'static,
{
    /// Forwards the call to the wrapped closure unchanged.
    fn call(&self, ctx: RpcCallContext, method_id: MethodId, payload: Value) -> HandlerFuture {
        let handler: &F = self;
        handler(ctx, method_id, payload)
    }
}
/// Boxed, pinned, `Send` future returned by [`RpcServiceFactory::create`].
///
/// It resolves to the handler of the newly created instance or to a [`RuntimeError`].
pub type FactoryFuture =
Pin<Box<dyn Future<Output = Result<Arc<dyn RpcServiceHandler>, RuntimeError>> + Send>>;
/// Factory that produces a handler for a new client-created instance of a service.
///
/// Implementations must be `Send + Sync` because they are stored behind `Arc<dyn RpcServiceFactory>`.
pub trait RpcServiceFactory: Send + Sync {
/// Starts creating a handler.
///
/// `ctx` describes the call that requested the creation; `create_payload` and `options` are
/// taken from the decoded create-instance request.
fn create(
&self,
ctx: RpcCallContext,
create_payload: Option<Vec<u8>>,
options: BTreeMap<String, String>,
) -> FactoryFuture;
}
/// Lets any `Send + Sync + 'static` closure of the right shape be used as a service factory.
impl<F> RpcServiceFactory for F
where
    F: Send + Sync + 'static,
    F: Fn(RpcCallContext, Option<Vec<u8>>, BTreeMap<String, String>) -> FactoryFuture,
{
    /// Forwards the creation request to the wrapped closure unchanged.
    ///
    /// The signature mirrors the trait declaration exactly; it carries no extra lifetime
    /// parameter because nothing in the arguments or the return type borrows.
    fn create(
        &self,
        ctx: RpcCallContext,
        create_payload: Option<Vec<u8>>,
        options: BTreeMap<String, String>,
    ) -> FactoryFuture {
        self(ctx, create_payload, options)
    }
}
/// Per-call context handed to service handlers and factories.
#[derive(Clone)]
pub struct RpcCallContext {
// Id of the connection the call arrived on.
connection_id: u64,
// Instance id the request was addressed to.
instance_id: InstanceId,
// Sending half of the connection; used to push notifications to the peer.
sender: RpcSender,
}
impl RpcCallContext {
    /// Id of the connection this call arrived on.
    pub fn connection_id(&self) -> u64 {
        self.connection_id
    }

    /// Id of the instance this call was addressed to.
    pub fn instance_id(&self) -> InstanceId {
        self.instance_id
    }

    /// Sends a notification envelope to the peer over this call's connection.
    ///
    /// `instance_id` is placed in the notification as given (it may be `None`).
    ///
    /// # Errors
    /// A send failure is reported as a transport error with code `InternalRuntimeError`
    /// carrying the underlying error's text.
    pub async fn notify(
        &self,
        instance_id: Option<InstanceId>,
        notification_id: u32,
        payload: Value,
    ) -> Result<(), RuntimeError> {
        let envelope = Envelope::Notification(Notification {
            instance_id,
            notification_id: rpc_runtime_core::NotificationId::new(notification_id),
            payload,
        });
        let sent = self.sender.send_envelope(&envelope).await;
        sent.map_err(|err| {
            RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
        })
    }

    /// Sends a notification tagged with the instance this call was addressed to.
    ///
    /// # Errors
    /// Same as [`RpcCallContext::notify`].
    pub async fn notify_bound(
        &self,
        notification_id: u32,
        payload: Value,
    ) -> Result<(), RuntimeError> {
        let bound_instance = Some(self.instance_id);
        self.notify(bound_instance, notification_id, payload).await
    }
}
/// Handle to a built RPC server.
///
/// Cloning shares the same underlying state, since the only field is an `Arc`.
#[derive(Clone)]
pub struct RpcServer {
// Shared configuration, registry and counters.
state: Arc<ServerState>,
}
/// Builder that accumulates configuration and registrations before producing an [`RpcServer`].
pub struct RpcServerBuilder {
// State under construction; it is still exclusively owned, so it can be mutated without locking.
state: ServerState,
}
/// Handshake option key that [`RpcServerAuthConfig::token`] uses to look up the client's token.
pub const DEFAULT_AUTH_TOKEN_OPTION_KEY: &str = "tripley.auth.token";
/// Receiver of [`RpcServerMetricEvent`]s emitted by the server.
pub trait RpcServerMetricsSink: Send + Sync {
/// Records a single event. It is called through a shared reference, so any state the sink
/// keeps must use interior mutability.
fn record(&self, event: RpcServerMetricEvent);
}
/// Lets any `Send + Sync + 'static` closure taking an event be used as a metrics sink.
impl<F> RpcServerMetricsSink for F
where
    F: Fn(RpcServerMetricEvent) + Send + Sync + 'static,
{
    /// Passes the event to the wrapped closure.
    fn record(&self, event: RpcServerMetricEvent) {
        let sink: &F = self;
        sink(event);
    }
}
/// Events reported to an [`RpcServerMetricsSink`] over the life of connections and requests.
///
/// In the request variants, `is_activation` is true when the request was addressed to the
/// activation instance id, and `elapsed` is the measured dispatch time.
#[derive(Debug, Clone, PartialEq)]
pub enum RpcServerMetricEvent {
/// A connection was handed to the server and assigned `connection_id`.
ConnectionStarted {
connection_id: u64,
},
/// A connection finished and its cleanup ran; `success` is false when it ended with an error.
ConnectionEnded {
connection_id: u64,
success: bool,
},
/// The handshake on the connection succeeded.
HandshakeCompleted {
connection_id: u64,
},
/// The handshake on the connection failed with `error_code`.
HandshakeFailed {
connection_id: u64,
error_code: RuntimeErrorCode,
},
/// The listener refused an incoming connection (accept failed with access denied).
ListenerConnectionRejected {
error_code: RuntimeErrorCode,
},
/// A request envelope was received and is about to be dispatched.
RequestStarted {
connection_id: u64,
request_id: RequestId,
instance_id: InstanceId,
method_id: MethodId,
is_activation: bool,
},
/// Dispatch returned a payload successfully.
RequestCompleted {
connection_id: u64,
request_id: RequestId,
instance_id: InstanceId,
method_id: MethodId,
is_activation: bool,
elapsed: Duration,
},
/// Dispatch returned an error with `error_code`.
RequestFailed {
connection_id: u64,
request_id: RequestId,
instance_id: InstanceId,
method_id: MethodId,
is_activation: bool,
elapsed: Duration,
error_code: RuntimeErrorCode,
},
/// A successful request took at least `threshold` to dispatch.
RequestSlow {
connection_id: u64,
request_id: RequestId,
instance_id: InstanceId,
method_id: MethodId,
is_activation: bool,
elapsed: Duration,
threshold: Duration,
},
/// The response envelope for the request could not be sent.
ResponseSendFailed {
connection_id: u64,
request_id: RequestId,
},
}
/// Plain-value copy of the counters kept by [`RpcServerMetricsRecorder`].
///
/// Each field mirrors the recorder counter of the same name; the two `request_elapsed_*`
/// fields are the recorder's nanosecond counters converted to `Duration`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RpcServerMetricsSnapshot {
pub connections_started: u64,
pub connections_ended: u64,
pub connections_ended_successfully: u64,
pub handshakes_completed: u64,
pub handshakes_failed: u64,
pub listener_connections_rejected: u64,
pub requests_started: u64,
pub requests_completed: u64,
pub requests_failed: u64,
pub requests_slow: u64,
pub response_send_failures: u64,
/// Sum of elapsed time over completed and failed requests (saturating).
pub request_elapsed_total: Duration,
/// Largest elapsed time seen on a completed or failed request.
pub request_elapsed_max: Duration,
}
/// Metrics sink that aggregates events into atomic counters.
///
/// All counters start at zero (via `Default`) and are read out with `snapshot()`.
#[derive(Debug, Default)]
pub struct RpcServerMetricsRecorder {
connections_started: AtomicU64,
connections_ended: AtomicU64,
connections_ended_successfully: AtomicU64,
handshakes_completed: AtomicU64,
handshakes_failed: AtomicU64,
listener_connections_rejected: AtomicU64,
requests_started: AtomicU64,
requests_completed: AtomicU64,
requests_failed: AtomicU64,
requests_slow: AtomicU64,
response_send_failures: AtomicU64,
// Elapsed time is stored as nanoseconds so it fits in an atomic integer.
request_elapsed_total_nanos: AtomicU64,
request_elapsed_max_nanos: AtomicU64,
}
impl RpcServerMetricsRecorder {
pub fn new() -> Self {
Self::default()
}
pub fn snapshot(&self) -> RpcServerMetricsSnapshot {
RpcServerMetricsSnapshot {
connections_started: self.connections_started.load(Ordering::Relaxed),
connections_ended: self.connections_ended.load(Ordering::Relaxed),
connections_ended_successfully: self
.connections_ended_successfully
.load(Ordering::Relaxed),
handshakes_completed: self.handshakes_completed.load(Ordering::Relaxed),
handshakes_failed: self.handshakes_failed.load(Ordering::Relaxed),
listener_connections_rejected: self
.listener_connections_rejected
.load(Ordering::Relaxed),
requests_started: self.requests_started.load(Ordering::Relaxed),
requests_completed: self.requests_completed.load(Ordering::Relaxed),
requests_failed: self.requests_failed.load(Ordering::Relaxed),
requests_slow: self.requests_slow.load(Ordering::Relaxed),
response_send_failures: self.response_send_failures.load(Ordering::Relaxed),
request_elapsed_total: Duration::from_nanos(
self.request_elapsed_total_nanos.load(Ordering::Relaxed),
),
request_elapsed_max: Duration::from_nanos(
self.request_elapsed_max_nanos.load(Ordering::Relaxed),
),
}
}
fn record_elapsed(&self, elapsed: Duration) {
let nanos = duration_nanos_u64(elapsed);
saturating_atomic_add(&self.request_elapsed_total_nanos, nanos);
update_atomic_max(&self.request_elapsed_max_nanos, nanos);
}
}
impl RpcServerMetricsSink for RpcServerMetricsRecorder {
fn record(&self, event: RpcServerMetricEvent) {
match event {
RpcServerMetricEvent::ConnectionStarted { .. } => {
self.connections_started.fetch_add(1, Ordering::Relaxed);
}
RpcServerMetricEvent::ConnectionEnded { success, .. } => {
self.connections_ended.fetch_add(1, Ordering::Relaxed);
if success {
self.connections_ended_successfully
.fetch_add(1, Ordering::Relaxed);
}
}
RpcServerMetricEvent::HandshakeCompleted { .. } => {
self.handshakes_completed.fetch_add(1, Ordering::Relaxed);
}
RpcServerMetricEvent::HandshakeFailed { .. } => {
self.handshakes_failed.fetch_add(1, Ordering::Relaxed);
}
RpcServerMetricEvent::ListenerConnectionRejected { .. } => {
self.listener_connections_rejected
.fetch_add(1, Ordering::Relaxed);
}
RpcServerMetricEvent::RequestStarted { .. } => {
self.requests_started.fetch_add(1, Ordering::Relaxed);
}
RpcServerMetricEvent::RequestCompleted { elapsed, .. } => {
self.requests_completed.fetch_add(1, Ordering::Relaxed);
self.record_elapsed(elapsed);
}
RpcServerMetricEvent::RequestFailed { elapsed, .. } => {
self.requests_failed.fetch_add(1, Ordering::Relaxed);
self.record_elapsed(elapsed);
}
RpcServerMetricEvent::RequestSlow { .. } => {
self.requests_slow.fetch_add(1, Ordering::Relaxed);
}
RpcServerMetricEvent::ResponseSendFailed { .. } => {
self.response_send_failures.fetch_add(1, Ordering::Relaxed);
}
}
}
}
/// Logging and metrics tuning for request handling.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RpcServerObservabilityConfig {
/// Successful requests whose dispatch takes at least this long are reported as slow.
pub slow_call_threshold: Duration,
/// Maximum number of bytes of the Debug-formatted request payload to include in the preview.
pub payload_preview_bytes: usize,
/// Whether a payload preview is produced at all; a preview also requires a non-zero byte limit.
pub log_payload_preview: bool,
}
/// Connection-scope and handshake-authentication settings for the server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RpcServerSecurityConfig {
/// Scope applied to a listener before the server starts accepting from it.
pub connection_scope: ConnectionScope,
/// How the client's handshake is authenticated.
pub auth: RpcServerAuthConfig,
}
impl RpcServerSecurityConfig {
    /// Returns the config with the connection scope set to `RemoteAllowed`.
    pub fn remote_allowed(self) -> Self {
        Self {
            connection_scope: ConnectionScope::RemoteAllowed,
            ..self
        }
    }

    /// Returns the config with the connection scope set to `LocalOnly`.
    pub fn local_only(self) -> Self {
        Self {
            connection_scope: ConnectionScope::LocalOnly,
            ..self
        }
    }

    /// Returns the config with token authentication under the default option key.
    pub fn with_token(self, token: impl Into<String>) -> Self {
        self.with_auth(RpcServerAuthConfig::token(token))
    }

    /// Returns the config with its authentication settings replaced by `auth`.
    pub fn with_auth(self, auth: RpcServerAuthConfig) -> Self {
        Self { auth, ..self }
    }
}
impl Default for RpcServerSecurityConfig {
    /// Defaults to local-only connections with authentication disabled.
    fn default() -> Self {
        let connection_scope = ConnectionScope::LocalOnly;
        let auth = RpcServerAuthConfig::Disabled;
        Self {
            connection_scope,
            auth,
        }
    }
}
/// Handshake authentication mode.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RpcServerAuthConfig {
/// No authentication check is made during the handshake.
Disabled,
/// The handshake options must contain a string equal to `token` under `option_key`.
Token { token: String, option_key: String },
}
impl RpcServerAuthConfig {
    /// Token authentication that looks the token up under [`DEFAULT_AUTH_TOKEN_OPTION_KEY`].
    pub fn token(token: impl Into<String>) -> Self {
        Self::token_with_option_key(token, DEFAULT_AUTH_TOKEN_OPTION_KEY)
    }

    /// Token authentication that looks the token up under a caller-chosen option key.
    pub fn token_with_option_key(token: impl Into<String>, option_key: impl Into<String>) -> Self {
        let token = token.into();
        let option_key = option_key.into();
        Self::Token { token, option_key }
    }
}
impl RpcServerObservabilityConfig {
    /// Returns the config with the slow-call threshold replaced.
    pub fn with_slow_call_threshold(self, threshold: Duration) -> Self {
        Self {
            slow_call_threshold: threshold,
            ..self
        }
    }

    /// Returns the config with the payload preview limit set to `bytes`.
    ///
    /// Previews are switched on when `bytes` is non-zero and off when it is zero.
    pub fn with_payload_preview(self, bytes: usize) -> Self {
        Self {
            payload_preview_bytes: bytes,
            log_payload_preview: bytes != 0,
            ..self
        }
    }
}
impl Default for RpcServerObservabilityConfig {
    /// Defaults to a 500 ms slow-call threshold with payload previews turned off.
    fn default() -> Self {
        let slow_call_threshold = Duration::from_millis(500);
        Self {
            slow_call_threshold,
            payload_preview_bytes: 0,
            log_payload_preview: false,
        }
    }
}
impl RpcServerBuilder {
    /// Creates a builder with default settings and the built-in activation instance registered.
    pub fn new() -> Self {
        let mut builder = Self {
            state: ServerState::new(),
        };
        builder.state.insert_activation_instance();
        builder
    }

    /// Consuming form of [`RpcServerBuilder::set_observability`].
    pub fn observability(mut self, config: RpcServerObservabilityConfig) -> Self {
        self.set_observability(config);
        self
    }

    /// Replaces the observability configuration.
    pub fn set_observability(&mut self, config: RpcServerObservabilityConfig) -> &mut Self {
        self.state.observability = config;
        self
    }

    /// Consuming form of [`RpcServerBuilder::set_metrics_sink`].
    pub fn metrics_sink(mut self, sink: Arc<dyn RpcServerMetricsSink>) -> Self {
        self.set_metrics_sink(sink);
        self
    }

    /// Installs the sink that will receive metric events, replacing any previous one.
    pub fn set_metrics_sink(&mut self, sink: Arc<dyn RpcServerMetricsSink>) -> &mut Self {
        self.state.metrics_sink = Some(sink);
        self
    }

    /// Consuming form of [`RpcServerBuilder::set_security`].
    pub fn security(mut self, config: RpcServerSecurityConfig) -> Self {
        self.set_security(config);
        self
    }

    /// Replaces the security configuration.
    pub fn set_security(&mut self, config: RpcServerSecurityConfig) -> &mut Self {
        self.state.security = config;
        self
    }

    /// Registers a pre-created instance under `name` and returns its newly allocated id.
    ///
    /// The instance is not releasable and has no owning connection.
    pub fn register_named_instance(
        &mut self,
        name: impl Into<String>,
        service_guid: ServiceGuid,
        methods: impl IntoIterator<Item = u32>,
        handler: Arc<dyn RpcServiceHandler>,
    ) -> InstanceId {
        let name = Some(name.into());
        let methods: Vec<u32> = methods.into_iter().collect();
        let registration = NewInstance {
            service_guid,
            name,
            activation_mode: ActivationMode::NamedPrecreated,
            releasable: false,
            owner_connection_id: None,
            methods,
            handler,
        };
        self.state.insert_instance(registration)
    }

    /// Registers an unnamed singleton instance and returns its newly allocated id.
    ///
    /// The instance is not releasable and has no owning connection.
    pub fn register_singleton(
        &mut self,
        service_guid: ServiceGuid,
        methods: impl IntoIterator<Item = u32>,
        handler: Arc<dyn RpcServiceHandler>,
    ) -> InstanceId {
        let methods: Vec<u32> = methods.into_iter().collect();
        let registration = NewInstance {
            service_guid,
            name: None,
            activation_mode: ActivationMode::Singleton,
            releasable: false,
            owner_connection_id: None,
            methods,
            handler,
        };
        self.state.insert_instance(registration)
    }

    /// Registers the factory used to create instances of `service_guid` on client request.
    ///
    /// A factory previously registered for the same GUID is replaced.
    pub fn register_factory(
        &mut self,
        service_guid: ServiceGuid,
        methods: impl IntoIterator<Item = u32>,
        factory: Arc<dyn RpcServiceFactory>,
    ) {
        let key = service_guid.get();
        let entry = FactoryEntry {
            methods: methods.into_iter().collect(),
            factory,
        };
        self.state.factories.insert(key, entry);
    }

    /// Finishes configuration and returns the server.
    ///
    /// A warning is logged when remote connections are allowed while authentication is disabled.
    pub fn build(self) -> RpcServer {
        let security = &self.state.security;
        let remote_allowed = security.connection_scope == ConnectionScope::RemoteAllowed;
        if remote_allowed && matches!(security.auth, RpcServerAuthConfig::Disabled) {
            warn!("rpc server allows remote connections without token authentication");
        }
        let state = Arc::new(self.state);
        RpcServer { state }
    }
}
impl Default for RpcServerBuilder {
fn default() -> Self {
Self::new()
}
}
impl RpcServer {
/// Serves a single connection until the peer closes it, says goodbye, or an error occurs.
///
/// Steps, in order: allocate a connection id and record `ConnectionStarted`; split the
/// connection; perform the handshake (recording `HandshakeFailed` or `HandshakeCompleted`);
/// then receive envelopes in a loop, spawning one task per request. Whatever the outcome,
/// instances owned by the connection are removed and `ConnectionEnded` is recorded.
///
/// # Errors
/// Returns the handshake error, a transport error (`InternalRuntimeError`) when receiving
/// fails, or a protocol error (`InvalidEnvelope`) when the peer sends anything other than a
/// request or goodbye after the handshake.
pub async fn serve_connection<C>(&self, connection: C) -> Result<(), RuntimeError>
where
C: Into<RpcConnection>,
{
let connection_id = self
.state
.next_connection_id
.fetch_add(1, Ordering::Relaxed);
self.state
.record_metric(RpcServerMetricEvent::ConnectionStarted { connection_id });
info!(connection_id, "rpc server connection started");
let (sender, mut receiver) = connection.into().split();
// The session runs inside an inner async block so that every early `return Err(..)` still
// falls through to the cleanup and metrics below.
let result = async {
if let Err(error) = self
.perform_handshake(connection_id, &sender, &mut receiver)
.await
{
self.state
.record_metric(RpcServerMetricEvent::HandshakeFailed {
connection_id,
error_code: error.code,
});
return Err(error);
}
self.state
.record_metric(RpcServerMetricEvent::HandshakeCompleted { connection_id });
loop {
let envelope = match receiver.recv_envelope().await {
Ok(Some(envelope)) => envelope,
Ok(None) => {
debug!(connection_id, "rpc server connection closed by peer");
break;
}
Err(err) => {
let error = RuntimeError::transport(
RuntimeErrorCode::InternalRuntimeError,
err.to_string(),
);
warn!(
connection_id,
error_code = error.code.as_i32(),
error_kind = error.kind.as_u8(),
error_message = %error.message,
"rpc server failed to receive envelope"
);
return Err(error);
}
};
match envelope {
Envelope::Request(request) => {
// Each request runs on its own task so a slow handler does not block the receive loop.
// NOTE(review): the JoinHandle is dropped, so nothing here waits for in-flight requests
// before `cleanup_connection` runs below — confirm a late create-instance request cannot
// register an instance for a connection that has already been cleaned up.
let state = Arc::clone(&self.state);
let sender = sender.clone();
let observability = self.state.observability;
tokio::spawn(async move {
handle_request(state, sender, connection_id, request, observability)
.await;
});
}
Envelope::Goodbye(goodbye) => {
info!(
connection_id,
reason_code = goodbye.reason_code,
message = goodbye.message.as_deref().unwrap_or(""),
"rpc server received goodbye"
);
break;
}
envelope => {
let error = RuntimeError::protocol(
RuntimeErrorCode::InvalidEnvelope,
"server expected request envelope",
);
warn!(
connection_id,
envelope_kind = envelope_name(&envelope),
error_code = error.code.as_i32(),
error_kind = error.kind.as_u8(),
error_message = %error.message,
"rpc server received invalid envelope"
);
return Err(error);
}
}
}
Ok(())
}
.await;
// Drop every instance whose owner is this connection, on success and on failure alike.
self.state.cleanup_connection(connection_id).await;
debug!(connection_id, "rpc server connection cleanup completed");
self.state
.record_metric(RpcServerMetricEvent::ConnectionEnded {
connection_id,
success: result.is_ok(),
});
if let Err(error) = &result {
warn!(
connection_id,
error_code = error.code.as_i32(),
error_kind = error.kind.as_u8(),
error_message = %error.message,
"rpc server connection ended with error"
);
} else {
info!(connection_id, "rpc server connection ended");
}
result
}
pub async fn serve_listener<L>(&self, mut listener: L) -> Result<(), RuntimeError>
where
L: RpcListener + Send,
{
listener.set_connection_scope(self.state.security.connection_scope);
loop {
let connection = match listener.accept().await {
Ok(connection) => connection,
Err(err) => {
let access_denied = is_transport_access_denied(&err);
let error = RuntimeError::transport(
if access_denied {
RuntimeErrorCode::AccessDenied
} else {
RuntimeErrorCode::InternalRuntimeError
},
err.to_string(),
);
if access_denied {
self.state.record_metric(
RpcServerMetricEvent::ListenerConnectionRejected {
error_code: RuntimeErrorCode::AccessDenied,
},
);
warn!(
error_code = error.code.as_i32(),
error_kind = error.kind.as_u8(),
error_message = %error.message,
"rpc server listener rejected connection"
);
continue;
}
error!(
error_code = error.code.as_i32(),
error_kind = error.kind.as_u8(),
error_message = %error.message,
"rpc server listener accept failed"
);
return Err(error);
}
};
let server = self.clone();
tokio::spawn(async move {
if let Err(error) = server.serve_connection(connection).await {
warn!(
error_code = error.code.as_i32(),
error_kind = error.kind.as_u8(),
error_message = %error.message,
"rpc server listener connection task failed"
);
}
});
}
}
pub fn spawn_listener<L>(
&self,
listener: L,
) -> tokio::task::JoinHandle<Result<(), RuntimeError>>
where
L: RpcListener + Send + 'static,
{
let server = self.clone();
tokio::spawn(async move { server.serve_listener(listener).await })
}
/// Performs the server side of the connection handshake.
///
/// Reads the first envelope and requires it to be a `Hello` carrying
/// `RUNTIME_PROTOCOL_VERSION` and `Role::Client`, then validates the handshake options against
/// the configured authentication, and finally replies with a `HelloAck`. The ack accepts the
/// intersection of the server's capabilities with the client's and echoes the client's
/// `max_message_size`.
///
/// # Errors
/// - transport error (`InternalRuntimeError`) if receiving fails, the peer disconnects before
///   sending anything, or the ack cannot be sent;
/// - protocol error (`InvalidEnvelope`) if the first envelope is not a `Hello`;
/// - protocol error (`UnsupportedProtocolVersion`) if the version or role does not match;
/// - whatever `validate_handshake_auth` returns.
async fn perform_handshake(
&self,
connection_id: u64,
sender: &RpcSender,
receiver: &mut RpcReceiver,
) -> Result<(), RuntimeError> {
let Some(envelope) = receiver.recv_envelope().await.map_err(|err| {
let error =
RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string());
warn!(
connection_id,
error_code = error.code.as_i32(),
error_kind = error.kind.as_u8(),
error_message = %error.message,
"rpc server handshake receive failed"
);
error
})?
else {
let error = RuntimeError::transport(
RuntimeErrorCode::InternalRuntimeError,
"client disconnected during handshake",
);
warn!(
connection_id,
error_code = error.code.as_i32(),
error_kind = error.kind.as_u8(),
error_message = %error.message,
"rpc server handshake disconnected"
);
return Err(error);
};
let Envelope::Hello(hello) = envelope else {
let error = RuntimeError::protocol(
RuntimeErrorCode::InvalidEnvelope,
"expected HELLO during handshake",
);
warn!(
connection_id,
envelope_kind = envelope_name(&envelope),
error_code = error.code.as_i32(),
error_kind = error.kind.as_u8(),
error_message = %error.message,
"rpc server handshake received invalid envelope"
);
return Err(error);
};
// NOTE(review): a wrong role is reported with the same `UnsupportedProtocolVersion` code as a
// wrong version — confirm that is intended.
if hello.protocol_version != RUNTIME_PROTOCOL_VERSION || hello.role != Role::Client {
let error = RuntimeError::protocol(
RuntimeErrorCode::UnsupportedProtocolVersion,
"unsupported client handshake",
);
warn!(
connection_id,
protocol_version = hello.protocol_version,
role = ?hello.role,
capability_bits = hello.capability_bits.bits(),
max_message_size = hello.max_message_size,
error_code = error.code.as_i32(),
error_kind = error.kind.as_u8(),
error_message = %error.message,
"rpc server handshake rejected"
);
return Err(error);
}
// Authentication is checked before any ack is sent, so a rejected client never sees a HelloAck.
self.validate_handshake_auth(connection_id, &hello.options)?;
sender
.send_envelope(&Envelope::HelloAck(HelloAck {
protocol_version: RUNTIME_PROTOCOL_VERSION,
accepted_capability_bits: server_capabilities() & hello.capability_bits,
max_message_size: hello.max_message_size,
options: Vec::new(),
}))
.await
.map_err(|err| {
let error = RuntimeError::transport(
RuntimeErrorCode::InternalRuntimeError,
err.to_string(),
);
warn!(
connection_id,
error_code = error.code.as_i32(),
error_kind = error.kind.as_u8(),
error_message = %error.message,
"rpc server handshake ack send failed"
);
error
})?;
info!(
connection_id,
protocol_version = hello.protocol_version,
accepted_capability_bits = (server_capabilities() & hello.capability_bits).bits(),
max_message_size = hello.max_message_size,
"rpc server handshake completed"
);
Ok(())
}
fn validate_handshake_auth(
&self,
connection_id: u64,
options: &Options,
) -> Result<(), RuntimeError> {
let RpcServerAuthConfig::Token { token, option_key } = &self.state.security.auth else {
return Ok(());
};
let value = options
.iter()
.rev()
.find_map(|(key, value)| (key == option_key).then_some(value));
let Some(value) = value else {
let error = RuntimeError::protocol(
RuntimeErrorCode::AccessDenied,
"missing handshake authentication token",
);
warn!(
connection_id,
auth_option_key = %option_key,
error_code = error.code.as_i32(),
error_kind = error.kind.as_u8(),
error_message = %error.message,
"rpc server handshake rejected authentication"
);
return Err(error);
};
let Some(received) = value.as_str() else {
let error = RuntimeError::protocol(
RuntimeErrorCode::AccessDenied,
"handshake authentication token must be a string",
);
warn!(
connection_id,
auth_option_key = %option_key,
error_code = error.code.as_i32(),
error_kind = error.kind.as_u8(),
error_message = %error.message,
"rpc server handshake rejected authentication"
);
return Err(error);
};
if received != token {
let error = RuntimeError::protocol(
RuntimeErrorCode::AccessDenied,
"invalid handshake authentication token",
);
warn!(
connection_id,
auth_option_key = %option_key,
error_code = error.code.as_i32(),
error_kind = error.kind.as_u8(),
error_message = %error.message,
"rpc server handshake rejected authentication"
);
return Err(error);
}
debug!(
connection_id,
auth_option_key = %option_key,
"rpc server handshake authentication accepted"
);
Ok(())
}
/// Returns descriptors for every registered instance, without filtering by service GUID.
pub async fn list_instances(&self) -> Vec<InstanceDescriptor> {
    let no_filter = None;
    self.state.list_instances(no_filter).await
}
}
/// Handles one request end to end: dispatch, metrics, logging, and sending the response.
///
/// Order of effects: log receipt, record `RequestStarted`, optionally trace a payload
/// preview, time the dispatch, then
/// - on success: record `RequestSlow` (only if the elapsed time reached the threshold) and
///   `RequestCompleted`, and reply with `ResponseOk`;
/// - on failure: record `RequestFailed` and reply with `ResponseError`.
///
/// A failure to send the reply is recorded as `ResponseSendFailed` and logged; it is not
/// propagated because this function returns nothing.
///
/// NOTE(review): nothing here catches a panic from the dispatched handler; presumably the
/// spawned task then ends without any response being sent — confirm that is acceptable.
async fn handle_request(
state: Arc<ServerState>,
sender: RpcSender,
connection_id: u64,
request: Request,
observability: RpcServerObservabilityConfig,
) {
// Copy the identifying fields out now: `request` is moved into `dispatch_request` below.
let request_id = request.request_id;
let instance_id = request.instance_id;
let method_id = request.method_id;
let is_activation = instance_id.get() == ACTIVATION_INSTANCE_ID_VALUE;
let payload_preview = payload_preview(&request.payload, observability);
debug!(
connection_id,
request_id = request_id.get(),
instance_id = instance_id.get(),
method_id = method_id.get(),
is_activation,
"rpc server request received"
);
state.record_metric(RpcServerMetricEvent::RequestStarted {
connection_id,
request_id,
instance_id,
method_id,
is_activation,
});
if let Some(payload_preview) = payload_preview {
trace!(
connection_id,
request_id = request_id.get(),
payload_preview,
"rpc server request payload preview"
);
}
// Only the dispatch itself is timed; sending the response is not included.
let started = Instant::now();
let response = dispatch_request(state.clone(), sender.clone(), connection_id, request).await;
let elapsed = started.elapsed();
let elapsed_ms = elapsed.as_secs_f64() * 1000.0;
let envelope = match response {
Ok(payload) => {
if elapsed >= observability.slow_call_threshold {
state.record_metric(RpcServerMetricEvent::RequestSlow {
connection_id,
request_id,
instance_id,
method_id,
is_activation,
elapsed,
threshold: observability.slow_call_threshold,
});
warn!(
connection_id,
request_id = request_id.get(),
instance_id = instance_id.get(),
method_id = method_id.get(),
is_activation,
elapsed_ms,
slow_call_threshold_ms =
observability.slow_call_threshold.as_secs_f64() * 1000.0,
"rpc server request completed slowly"
);
} else {
info!(
connection_id,
request_id = request_id.get(),
instance_id = instance_id.get(),
method_id = method_id.get(),
is_activation,
elapsed_ms,
"rpc server request completed"
);
}
state.record_metric(RpcServerMetricEvent::RequestCompleted {
connection_id,
request_id,
instance_id,
method_id,
is_activation,
elapsed,
});
Envelope::ResponseOk(ResponseOk {
request_id,
payload,
})
}
Err(error) => {
state.record_metric(RpcServerMetricEvent::RequestFailed {
connection_id,
request_id,
instance_id,
method_id,
is_activation,
elapsed,
error_code: error.code,
});
warn!(
connection_id,
request_id = request_id.get(),
instance_id = instance_id.get(),
method_id = method_id.get(),
is_activation,
elapsed_ms,
error_code = error.code.as_i32(),
error_kind = error.kind.as_u8(),
error_message = %error.message,
"rpc server request failed"
);
runtime_error_response(request_id, error)
}
};
if let Err(err) = sender.send_envelope(&envelope).await {
state.record_metric(RpcServerMetricEvent::ResponseSendFailed {
connection_id,
request_id,
});
error!(
connection_id,
request_id = request_id.get(),
error = %err,
"rpc server failed to send response"
);
} else {
trace!(
connection_id,
request_id = request_id.get(),
response_kind = envelope_name(&envelope),
"rpc server response sent"
);
}
}
/// Routes a request to the activation service or to the handler of the addressed instance.
///
/// # Errors
/// - `InstanceNotFound` (from the instance lookup) when the instance id is not registered;
/// - `MethodNotFound` when the instance does not list the requested method id;
/// - otherwise whatever the activation service or the handler returns.
async fn dispatch_request(
    state: Arc<ServerState>,
    sender: RpcSender,
    connection_id: u64,
    request: Request,
) -> Result<Value, RuntimeError> {
    let target = request.instance_id;
    if target.get() == ACTIVATION_INSTANCE_ID_VALUE {
        return dispatch_activation(state, sender, connection_id, request).await;
    }

    let instance = state.get_instance(target).await?;
    let method = request.method_id;
    let is_known_method = instance.methods.iter().any(|id| *id == method.get());
    if !is_known_method {
        let message = format!("method id `{}` was not found", method.get());
        return Err(RuntimeError::runtime(
            RuntimeErrorCode::MethodNotFound,
            message,
        ));
    }

    let ctx = RpcCallContext {
        connection_id,
        instance_id: target,
        sender,
    };
    let call = instance.handler.call(ctx, method, request.payload);
    call.await
}
/// Serves the built-in activation methods: resolve names, create, release, and list instances.
///
/// Each arm decodes its request payload, performs the operation on `state`, and encodes the
/// reply. Instances made through the create arm are owned by `connection_id`.
///
/// # Errors
/// - decode errors from the per-method request decoders;
/// - `ServiceGuidNotFound` when no factory is registered for the requested GUID;
/// - errors returned by the factory or by `release_instance`;
/// - `MethodNotFound` for any other method id.
async fn dispatch_activation(
    state: Arc<ServerState>,
    sender: RpcSender,
    connection_id: u64,
    request: Request,
) -> Result<Value, RuntimeError> {
    let ctx = RpcCallContext {
        connection_id,
        instance_id: request.instance_id,
        sender,
    };

    let reply = match request.method_id.get() {
        RESOLVE_INSTANCE_IDS_METHOD_ID => {
            let decoded = decode_resolve_instance_ids_request(&request.payload)?;
            let instance_ids = state.resolve_instance_ids(&decoded.instance_names).await;
            encode_resolve_instance_ids_response(&ResolveInstanceIdsResponse { instance_ids })
        }
        CREATE_INSTANCE_METHOD_ID => {
            let decoded = decode_create_instance_request(&request.payload)?;
            let Some(entry) = state.get_factory(decoded.service_guid) else {
                return Err(RuntimeError::runtime(
                    RuntimeErrorCode::ServiceGuidNotFound,
                    "service factory was not found",
                ));
            };
            let creation = entry
                .factory
                .create(ctx, decoded.create_payload, decoded.options);
            let handler = creation.await?;
            let methods = entry.methods.clone();
            let instance_id = state
                .insert_client_instance(decoded.service_guid, connection_id, methods, handler)
                .await;
            encode_create_instance_response(&CreateInstanceResponse { instance_id })
        }
        RELEASE_INSTANCE_METHOD_ID => {
            let decoded = decode_release_instance_request(&request.payload)?;
            state
                .release_instance(connection_id, decoded.instance_id)
                .await?;
            encode_release_instance_response(&ReleaseInstanceResponse)
        }
        LIST_INSTANCES_METHOD_ID => {
            let decoded = decode_list_instances_request(&request.payload)?;
            let instances = state.list_instances(decoded.service_guid).await;
            encode_list_instances_response(&ListInstancesResponse { instances })
        }
        _ => {
            return Err(RuntimeError::runtime(
                RuntimeErrorCode::MethodNotFound,
                "activation method was not found",
            ));
        }
    };
    Ok(reply)
}
/// Converts a [`RuntimeError`] into the `ResponseError` envelope for `request_id`.
///
/// The numeric code and kind and the message are copied over; `error_details` is always `Nil`.
fn runtime_error_response(request_id: RequestId, error: RuntimeError) -> Envelope {
    let error_code = error.code.as_i32();
    let error_kind = error.kind.as_u8();
    let error_message = Some(error.message);
    Envelope::ResponseError(ResponseError {
        request_id,
        error_code,
        error_kind,
        error_message,
        error_details: Value::Nil,
    })
}
/// Capability flags this server offers during the handshake.
fn server_capabilities() -> CapabilityFlags {
    let notifications = CapabilityFlags::SERVER_TO_CLIENT_NOTIFICATION;
    let named_resolution = CapabilityFlags::NAMED_INSTANCE_RESOLUTION;
    let activation = CapabilityFlags::SERVICE_ACTIVATION;
    let goodbye = CapabilityFlags::GOODBYE;
    notifications | named_resolution | activation | goodbye
}
/// Short, stable label for an envelope's variant, used as a structured log field.
fn envelope_name(envelope: &Envelope) -> &'static str {
    let label = match *envelope {
        Envelope::Hello(_) => "hello",
        Envelope::HelloAck(_) => "hello_ack",
        Envelope::Request(_) => "request",
        Envelope::ResponseOk(_) => "response_ok",
        Envelope::ResponseError(_) => "response_error",
        Envelope::Notification(_) => "notification",
        Envelope::Goodbye(_) => "goodbye",
    };
    label
}
/// Builds a bounded, Debug-formatted preview of a request payload for trace logging.
///
/// Returns `None` when previews are disabled or the byte limit is zero. Otherwise the
/// payload's `Debug` text is returned, cut to at most `config.payload_preview_bytes` bytes
/// with `"..."` appended when it was cut.
///
/// The cut point is moved back to the nearest UTF-8 character boundary, because
/// `String::truncate` panics when asked to cut inside a multi-byte character.
fn payload_preview(payload: &Value, config: RpcServerObservabilityConfig) -> Option<String> {
    if !config.log_payload_preview || config.payload_preview_bytes == 0 {
        return None;
    }
    let mut preview = format!("{payload:?}");
    if preview.len() > config.payload_preview_bytes {
        let mut cut = config.payload_preview_bytes;
        // Index 0 is always a boundary, so this loop terminates.
        while !preview.is_char_boundary(cut) {
            cut -= 1;
        }
        preview.truncate(cut);
        preview.push_str("...");
    }
    Some(preview)
}
fn is_transport_access_denied(error: &TransportError) -> bool {
matches!(
error,
TransportError::Runtime(error) if error.code == RuntimeErrorCode::AccessDenied
)
}
/// Converts a duration to whole nanoseconds, clamping to `u64::MAX` when it does not fit.
fn duration_nanos_u64(duration: Duration) -> u64 {
    let nanos = duration.as_nanos();
    let ceiling = u128::from(u64::MAX);
    if nanos >= ceiling {
        u64::MAX
    } else {
        // Lossless: `nanos` was just checked to be below `u64::MAX`.
        nanos as u64
    }
}
/// Raises `value` to `candidate` if `candidate` is larger; otherwise leaves it unchanged.
///
/// Uses the standard `fetch_max` read-modify-write with relaxed ordering in place of a
/// hand-written compare-exchange loop.
fn update_atomic_max(value: &AtomicU64, candidate: u64) {
    value.fetch_max(candidate, Ordering::Relaxed);
}
/// Adds `increment` to `value`, saturating at `u64::MAX` instead of wrapping.
///
/// Uses the standard `fetch_update` retry loop with relaxed ordering in place of a
/// hand-written compare-exchange loop.
fn saturating_atomic_add(value: &AtomicU64, increment: u64) {
    // The closure always returns `Some`, so `fetch_update` cannot return `Err` here.
    let _ = value.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
        Some(current.saturating_add(increment))
    });
}
/// Shared server state: id counters, configuration, metrics sink, and the instance registry.
struct ServerState {
// Source of connection ids; the first id handed out is 1.
next_connection_id: AtomicU64,
// Source of instance ids for registered and client-created instances; starts at 2.
// NOTE(review): presumably id 1 is reserved for the activation instance — confirm against
// ACTIVATION_INSTANCE_ID_VALUE.
next_instance_id: AtomicU64,
observability: RpcServerObservabilityConfig,
security: RpcServerSecurityConfig,
// Optional sink; when absent, metric events are dropped.
metrics_sink: Option<Arc<dyn RpcServerMetricsSink>>,
// Registered instances keyed by raw instance id.
instances: RwLock<HashMap<u64, InstanceEntry>>,
// Instance name -> raw instance id, filled for instances registered with a name.
names: RwLock<HashMap<String, u64>>,
// Factories keyed by service GUID; only written through the builder.
factories: HashMap<uuid::Uuid, FactoryEntry>,
}
impl ServerState {
fn new() -> Self {
Self {
next_connection_id: AtomicU64::new(1),
next_instance_id: AtomicU64::new(2),
observability: RpcServerObservabilityConfig::default(),
security: RpcServerSecurityConfig::default(),
metrics_sink: None,
instances: RwLock::new(HashMap::new()),
names: RwLock::new(HashMap::new()),
factories: HashMap::new(),
}
}
fn record_metric(&self, event: RpcServerMetricEvent) {
let Some(sink) = &self.metrics_sink else {
return;
};
let result = panic::catch_unwind(AssertUnwindSafe(|| sink.record(event)));
if result.is_err() {
error!("rpc server metrics sink panicked while recording event");
}
}
fn insert_activation_instance(&mut self) {
self.instances.get_mut().insert(
ACTIVATION_INSTANCE_ID_VALUE,
InstanceEntry {
instance_id: activation_instance_id(),
service_guid: activation_service_guid(),
instance_name: Some("rpc.runtime.Activation".to_string()),
activation_mode: ActivationMode::Singleton,
releasable: false,
owner_connection_id: None,
methods: vec![
RESOLVE_INSTANCE_IDS_METHOD_ID,
CREATE_INSTANCE_METHOD_ID,
RELEASE_INSTANCE_METHOD_ID,
LIST_INSTANCES_METHOD_ID,
],
handler: Arc::new(ActivationMarker),
},
);
}
fn insert_instance(&mut self, instance: NewInstance) -> InstanceId {
let id = self.next_instance_id.fetch_add(1, Ordering::Relaxed);
let instance_id = InstanceId::new(id).expect("generated instance id is non-zero");
if let Some(name) = &instance.name {
self.names.get_mut().insert(name.clone(), id);
}
self.instances.get_mut().insert(
id,
InstanceEntry {
instance_id,
service_guid: instance.service_guid,
instance_name: instance.name,
activation_mode: instance.activation_mode,
releasable: instance.releasable,
owner_connection_id: instance.owner_connection_id,
methods: instance.methods,
handler: instance.handler,
},
);
instance_id
}
async fn insert_client_instance(
&self,
service_guid: ServiceGuid,
connection_id: u64,
methods: Vec<u32>,
handler: Arc<dyn RpcServiceHandler>,
) -> InstanceId {
let id = self.next_instance_id.fetch_add(1, Ordering::Relaxed);
let instance_id = InstanceId::new(id).expect("generated instance id is non-zero");
self.instances.write().await.insert(
id,
InstanceEntry {
instance_id,
service_guid,
instance_name: None,
activation_mode: ActivationMode::Instantiable,
releasable: true,
owner_connection_id: Some(connection_id),
methods,
handler,
},
);
instance_id
}
async fn get_instance(&self, instance_id: InstanceId) -> Result<InstanceEntry, RuntimeError> {
self.instances
.read()
.await
.get(&instance_id.get())
.cloned()
.ok_or_else(|| {
RuntimeError::runtime(RuntimeErrorCode::InstanceNotFound, "instance was not found")
})
}
fn get_factory(&self, service_guid: ServiceGuid) -> Option<FactoryEntry> {
self.factories.get(&service_guid.get()).cloned()
}
async fn resolve_instance_ids(&self, names: &[String]) -> Vec<u64> {
let index = self.names.read().await;
names
.iter()
.map(|name| index.get(name).copied().unwrap_or(0))
.collect()
}
/// Removes `instance_id` from the registry on behalf of `connection_id`.
///
/// # Errors
///
/// Checked in this order:
/// - `InstanceNotFound` when no entry exists for the id;
/// - `InstanceReleaseNotAllowed` when the entry is not releasable;
/// - `AccessDenied` when the entry's owner is not `connection_id`.
async fn release_instance(
    &self,
    connection_id: u64,
    instance_id: InstanceId,
) -> Result<(), RuntimeError> {
    let key = instance_id.get();
    let mut registry = self.instances.write().await;

    // Classify the refusal (if any) while only borrowing the entry, so the
    // map can be mutated afterwards under the same write guard.
    let refusal = match registry.get(&key) {
        None => Some((RuntimeErrorCode::InstanceNotFound, "instance was not found")),
        Some(entry) if !entry.releasable => Some((
            RuntimeErrorCode::InstanceReleaseNotAllowed,
            "instance is not releasable",
        )),
        Some(entry) if entry.owner_connection_id != Some(connection_id) => Some((
            RuntimeErrorCode::AccessDenied,
            "instance is owned by another connection",
        )),
        Some(_) => None,
    };

    if let Some((code, message)) = refusal {
        return Err(RuntimeError::runtime(code, message));
    }

    registry.remove(&key);
    Ok(())
}
/// Drops every registered instance whose owner is `connection_id`.
///
/// Entries without an owner, or owned by a different connection, are kept.
async fn cleanup_connection(&self, connection_id: u64) {
    let departed = Some(connection_id);
    let mut registry = self.instances.write().await;
    registry.retain(|_, entry| entry.owner_connection_id != departed);
}
async fn list_instances(&self, service_guid: Option<ServiceGuid>) -> Vec<InstanceDescriptor> {
let mut values = self
.instances
.read()
.await
.values()
.filter(|entry| service_guid.is_none_or(|guid| guid == entry.service_guid))
.map(InstanceEntry::descriptor)
.collect::<Vec<_>>();
values.sort_by_key(|entry| entry.instance_id.get());
values
}
}
/// Registration request consumed by `insert_instance`.
///
/// Every field is moved into the resulting `InstanceEntry`; the instance id
/// itself is allocated by `insert_instance`, not supplied here.
struct NewInstance {
/// GUID of the service this instance belongs to (used as the filter key
/// by `list_instances`).
service_guid: ServiceGuid,
/// Optional name; when present, `insert_instance` also indexes it in the
/// name-to-id map.
name: Option<String>,
/// Activation mode copied to the entry and reported in its descriptor.
activation_mode: ActivationMode,
/// Whether `release_instance` may remove the instance.
releasable: bool,
/// Connection that owns the instance, if any. `release_instance` requires
/// it to match the caller, and `cleanup_connection` removes entries owned
/// by a given connection.
owner_connection_id: Option<u64>,
// NOTE(review): presumably the method ids this instance accepts — confirm
// against the dispatch code, which is outside this chunk.
methods: Vec<u32>,
/// Handler stored on the entry for this instance.
handler: Arc<dyn RpcServiceHandler>,
}
/// One registered instance as stored in the instance registry.
///
/// Cloneable so that `get_instance` can hand out a copy; the handler is
/// shared through its `Arc`.
#[derive(Clone)]
struct InstanceEntry {
/// Id under which the entry is registered.
instance_id: InstanceId,
/// GUID of the service this instance belongs to.
service_guid: ServiceGuid,
/// Name of the instance, if it was registered with one.
instance_name: Option<String>,
/// Activation mode reported in the instance descriptor.
activation_mode: ActivationMode,
/// Whether `release_instance` may remove this entry.
releasable: bool,
/// Owning connection, if any; checked by `release_instance` and used by
/// `cleanup_connection` to drop a connection's instances.
owner_connection_id: Option<u64>,
// NOTE(review): presumably the method ids this instance accepts — confirm
// against the dispatch code, which is outside this chunk.
methods: Vec<u32>,
/// Handler associated with this instance.
handler: Arc<dyn RpcServiceHandler>,
}
impl InstanceEntry {
    /// Builds the `InstanceDescriptor` view of this entry.
    ///
    /// Only id, name, service GUID, activation mode and the releasable flag
    /// are exposed; owner, method list and handler are left out.
    fn descriptor(&self) -> InstanceDescriptor {
        let Self {
            instance_id,
            service_guid,
            instance_name,
            activation_mode,
            releasable,
            ..
        } = self;
        InstanceDescriptor {
            instance_id: *instance_id,
            instance_name: instance_name.clone(),
            service_guid: *service_guid,
            activation_mode: *activation_mode,
            releasable: *releasable,
        }
    }
}
/// A registered service factory together with its method list.
///
/// Cloneable so that `get_factory` can return a copy; the factory itself is
/// shared through its `Arc`.
#[derive(Clone)]
struct FactoryEntry {
// NOTE(review): presumably the method ids given to instances created by
// this factory — confirm against the activation code outside this chunk.
methods: Vec<u32>,
/// Factory registered for the service.
factory: Arc<dyn RpcServiceFactory>,
}
/// Placeholder handler whose `call` always fails.
///
/// It is installed as the `handler` of an instance entry; any call that does
/// reach it resolves to an `InternalRuntimeError`.
struct ActivationMarker;

impl RpcServiceHandler for ActivationMarker {
    fn call(&self, _ctx: RpcCallContext, _method_id: MethodId, _payload: Value) -> HandlerFuture {
        Box::pin(async {
            let misuse = RuntimeError::runtime(
                RuntimeErrorCode::InternalRuntimeError,
                "activation marker should not be dispatched directly",
            );
            Err(misuse)
        })
    }
}
#[cfg(test)]
mod tests {
use super::*;
use rpc_runtime_core::{Goodbye, Hello, Request, Role};
use rpc_runtime_transport_ipc::{FrameConfig, IpcConnection};
use tokio::io::duplex;
/// The default observability config uses a 500 ms slow-call threshold and
/// keeps payload previews switched off.
#[test]
fn observability_defaults_are_safe() {
    let defaults = RpcServerObservabilityConfig::default();

    let expected_threshold = Duration::from_millis(500);
    assert_eq!(defaults.slow_call_threshold, expected_threshold);
    assert_eq!(defaults.payload_preview_bytes, 0);
    assert!(!defaults.log_payload_preview);
}
/// `payload_preview` yields nothing under the default config and, with a
/// 5-byte preview enabled, a string of at most 8 bytes ending in `"..."`.
#[test]
fn payload_preview_is_opt_in_and_bounded() {
    let sample = Value::from("1234567890");

    // Default config: previews are off.
    let disabled = payload_preview(&sample, RpcServerObservabilityConfig::default());
    assert_eq!(disabled, None);

    // Opt in with a 5-byte budget.
    let bounded_config = RpcServerObservabilityConfig::default().with_payload_preview(5);
    let bounded = payload_preview(&sample, bounded_config).expect("preview");
    assert!(bounded.len() <= 8);
    assert!(bounded.ends_with("..."));
}
/// Feeds four events directly into a recorder and checks the resulting
/// counters plus the total and maximum request latency.
#[test]
fn metrics_recorder_counts_events_and_latency() {
    let recorder = RpcServerMetricsRecorder::new();

    let events = [
        RpcServerMetricEvent::ConnectionStarted { connection_id: 1 },
        RpcServerMetricEvent::ConnectionEnded {
            connection_id: 1,
            success: true,
        },
        RpcServerMetricEvent::RequestCompleted {
            connection_id: 1,
            request_id: RequestId::new(7),
            instance_id: activation_instance_id(),
            method_id: MethodId::new(1),
            is_activation: true,
            elapsed: Duration::from_millis(3),
        },
        RpcServerMetricEvent::RequestFailed {
            connection_id: 1,
            request_id: RequestId::new(8),
            instance_id: activation_instance_id(),
            method_id: MethodId::new(2),
            is_activation: true,
            elapsed: Duration::from_millis(5),
            error_code: RuntimeErrorCode::InternalRuntimeError,
        },
    ];
    for event in events {
        recorder.record(event);
    }

    let totals = recorder.snapshot();
    assert_eq!(totals.connections_started, 1);
    assert_eq!(totals.connections_ended, 1);
    assert_eq!(totals.connections_ended_successfully, 1);
    assert_eq!(totals.requests_completed, 1);
    assert_eq!(totals.requests_failed, 1);
    // 3 ms (completed) + 5 ms (failed) in total; the failed request is the slowest.
    assert_eq!(totals.request_elapsed_total, Duration::from_millis(8));
    assert_eq!(totals.request_elapsed_max, Duration::from_millis(5));
}
/// The default security config is scoped to local connections and has
/// authentication disabled.
#[test]
fn security_defaults_are_local_auth_disabled() {
    let defaults = RpcServerSecurityConfig::default();

    assert_eq!(defaults.connection_scope, ConnectionScope::LocalOnly);
    assert_eq!(defaults.auth, RpcServerAuthConfig::Disabled);
}
/// A client presenting the configured token receives a `HelloAck`.
#[tokio::test]
async fn token_auth_accepts_matching_token() {
    let security = RpcServerSecurityConfig::default().with_token("secret");
    let server = RpcServerBuilder::new().security(security).build();

    let outcome = run_handshake(server, vec![auth_option("secret")]).await;
    let reply = outcome.expect("handshake");

    assert!(matches!(reply, Envelope::HelloAck(_)));
}
/// A client that sends no options at all is rejected with `AccessDenied`.
#[tokio::test]
async fn token_auth_rejects_missing_token() {
    let security = RpcServerSecurityConfig::default().with_token("secret");
    let server = RpcServerBuilder::new().security(security).build();

    let outcome = run_handshake(server, Vec::new()).await;
    let failure = outcome.expect_err("must reject");

    assert_eq!(failure.code, RuntimeErrorCode::AccessDenied);
}
/// A client presenting a token other than the configured one is rejected
/// with `AccessDenied`.
#[tokio::test]
async fn token_auth_rejects_wrong_token() {
    let security = RpcServerSecurityConfig::default().with_token("secret");
    let server = RpcServerBuilder::new().security(security).build();

    let outcome = run_handshake(server, vec![auth_option("wrong")]).await;
    let failure = outcome.expect_err("must reject");

    assert_eq!(failure.code, RuntimeErrorCode::AccessDenied);
}
/// A token option whose value is an integer rather than a string is
/// rejected with `AccessDenied`.
#[tokio::test]
async fn token_auth_rejects_non_string_token() {
    let security = RpcServerSecurityConfig::default().with_token("secret");
    let server = RpcServerBuilder::new().security(security).build();

    let numeric_token = (
        DEFAULT_AUTH_TOKEN_OPTION_KEY.to_string(),
        Value::from(123_u64),
    );
    let outcome = run_handshake(server, vec![numeric_token]).await;
    let failure = outcome.expect_err("must reject");

    assert_eq!(failure.code, RuntimeErrorCode::AccessDenied);
}
/// A rejected handshake is visible in the metrics: the connection is counted
/// as started and ended (not successfully), with one failed handshake and
/// none completed.
#[tokio::test]
async fn metrics_recorder_observes_handshake_failure() {
    let recorder = Arc::new(RpcServerMetricsRecorder::new());
    let security = RpcServerSecurityConfig::default().with_token("secret");
    let server = RpcServerBuilder::new()
        .metrics_sink(recorder.clone())
        .security(security)
        .build();

    // No token is supplied, so the handshake must be refused.
    let outcome = run_handshake(server, Vec::new()).await;
    let failure = outcome.expect_err("must reject");
    assert_eq!(failure.code, RuntimeErrorCode::AccessDenied);

    let totals = recorder.snapshot();
    assert_eq!(totals.connections_started, 1);
    assert_eq!(totals.connections_ended, 1);
    assert_eq!(totals.connections_ended_successfully, 0);
    assert_eq!(totals.handshakes_completed, 0);
    assert_eq!(totals.handshakes_failed, 1);
}
/// End-to-end metrics check over an in-memory duplex connection.
///
/// Performs a handshake, sends one request that succeeds (method 1) and one
/// that fails (method 2) to a named instance backed by `MetricsTestHandler`,
/// says goodbye, and then verifies the counters captured by the recorder.
#[tokio::test]
async fn metrics_recorder_observes_success_failure_and_slow_requests() {
let recorder = Arc::new(RpcServerMetricsRecorder::new());
// Slow-call threshold is set to zero nanoseconds for this test.
let mut builder = RpcServerBuilder::new()
.metrics_sink(recorder.clone())
.observability(
RpcServerObservabilityConfig::default()
.with_slow_call_threshold(Duration::from_nanos(0)),
);
// Register the test handler as named instance "metrics" with methods 1 and 2.
let instance_id = builder.register_named_instance(
"metrics",
activation_service_guid(),
[1, 2],
Arc::new(MetricsTestHandler),
);
let server = builder.build();
// Server side of the in-memory pipe runs on its own task.
let (client_stream, server_stream) = duplex(4096);
let server_connection = IpcConnection::from_stream(server_stream, FrameConfig::default());
let server_task =
tokio::spawn(async move { server.serve_connection(server_connection).await });
let client_connection = IpcConnection::from_stream(client_stream, FrameConfig::default());
let (sender, mut receiver) = client_connection.split();
// Handshake: Hello -> HelloAck.
send_hello(&sender).await;
assert!(matches!(
receiver.recv_envelope().await.expect("recv ack"),
Some(Envelope::HelloAck(_))
));
// Request 11 targets method 1, which the test handler answers successfully.
sender
.send_envelope(&Envelope::Request(Request {
request_id: RequestId::new(11),
instance_id,
method_id: MethodId::new(1),
payload: Value::from("ok"),
}))
.await
.expect("send success request");
assert!(matches!(
receiver.recv_envelope().await.expect("recv response"),
Some(Envelope::ResponseOk(_))
));
// Request 12 targets method 2, which the test handler fails.
sender
.send_envelope(&Envelope::Request(Request {
request_id: RequestId::new(12),
instance_id,
method_id: MethodId::new(2),
payload: Value::Nil,
}))
.await
.expect("send failing request");
assert!(matches!(
receiver.recv_envelope().await.expect("recv error"),
Some(Envelope::ResponseError(_))
));
// Say goodbye, then drop both client halves before joining the server task.
sender
.send_envelope(&Envelope::Goodbye(Goodbye {
reason_code: 0,
message: Some("done".to_string()),
}))
.await
.expect("send goodbye");
drop(sender);
drop(receiver);
// The server task must neither panic nor return an error.
server_task.await.expect("server task").expect("serve");
let snapshot = recorder.snapshot();
assert_eq!(snapshot.connections_started, 1);
assert_eq!(snapshot.connections_ended_successfully, 1);
assert_eq!(snapshot.handshakes_completed, 1);
assert_eq!(snapshot.requests_started, 2);
assert_eq!(snapshot.requests_completed, 1);
assert_eq!(snapshot.requests_failed, 1);
// NOTE(review): two requests ran with a zero threshold but only one is
// expected to be counted as slow — confirm which path records slow calls.
assert_eq!(snapshot.requests_slow, 1);
assert!(snapshot.request_elapsed_total > Duration::ZERO);
}
/// Drives a client-side handshake against `server` over an in-memory duplex
/// pipe and reports how it went.
///
/// Sends a `Hello` carrying `options`, reads a single envelope, drops both
/// client halves, and then waits for the server task to finish.
///
/// Returns `Ok(envelope)` when the client received an envelope. When the
/// stream ended without one (`None`), returns the error that
/// `serve_connection` returned.
///
/// # Panics
///
/// Panics if sending the hello fails, if receiving yields an error, if the
/// server task panicked, or if the stream ended while the server returned `Ok`.
async fn run_handshake(server: RpcServer, options: Options) -> Result<Envelope, RuntimeError> {
let (client_stream, server_stream) = duplex(4096);
let server_connection = IpcConnection::from_stream(server_stream, FrameConfig::default());
let server_task =
tokio::spawn(async move { server.serve_connection(server_connection).await });
let client_connection = IpcConnection::from_stream(client_stream, FrameConfig::default());
let (sender, mut receiver) = client_connection.split();
sender
.send_envelope(&hello_envelope(options))
.await
.expect("send hello");
// Keep the raw receive result; it is only unwrapped after the server task has been joined.
let envelope = receiver.recv_envelope().await;
// Drop both client halves before waiting for the server task to finish.
drop(sender);
drop(receiver);
let server_result = server_task.await.expect("server task");
match envelope.expect("recv hello ack") {
Some(envelope) => Ok(envelope),
// Nothing was received: surface the server-side error instead.
None => Err(server_result.expect_err("server should return handshake error")),
}
}
/// Builds the hello option pair that carries `token` under
/// `DEFAULT_AUTH_TOKEN_OPTION_KEY`.
fn auth_option(token: &str) -> (String, Value) {
    let key = DEFAULT_AUTH_TOKEN_OPTION_KEY.to_string();
    let value = Value::from(token);
    (key, value)
}
/// Sends a `Hello` envelope with no options through `sender`.
///
/// # Panics
///
/// Panics if the envelope cannot be sent.
async fn send_hello(sender: &RpcSender) {
    let hello = hello_envelope(Vec::new());
    let sent = sender.send_envelope(&hello).await;
    sent.expect("send hello");
}
/// Builds a client `Hello` envelope carrying `options`.
///
/// It uses the runtime protocol version, no capability flags and a 16 MiB
/// maximum message size.
fn hello_envelope(options: Options) -> Envelope {
    let hello = Hello {
        protocol_version: RUNTIME_PROTOCOL_VERSION,
        role: Role::Client,
        capability_bits: CapabilityFlags::empty(),
        max_message_size: 16 * 1024 * 1024,
        options,
    };
    Envelope::Hello(hello)
}
/// Test handler: method 1 echoes the payload back, every other method fails
/// with `InternalRuntimeError`.
struct MetricsTestHandler;

impl RpcServiceHandler for MetricsTestHandler {
    fn call(&self, _ctx: RpcCallContext, method_id: MethodId, payload: Value) -> HandlerFuture {
        Box::pin(async move {
            if method_id.get() == 1 {
                Ok(payload)
            } else {
                Err(RuntimeError::runtime(
                    RuntimeErrorCode::InternalRuntimeError,
                    "test failure",
                ))
            }
        })
    }
}
}