use std::collections::{BTreeMap, HashMap};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use rmpv::Value;
use rpc_runtime_activation::{
ACTIVATION_INSTANCE_ID_VALUE, ActivationMode, CREATE_INSTANCE_METHOD_ID,
CreateInstanceResponse, InstanceDescriptor, LIST_INSTANCES_METHOD_ID, ListInstancesResponse,
RELEASE_INSTANCE_METHOD_ID, RESOLVE_INSTANCE_IDS_METHOD_ID, ReleaseInstanceResponse,
ResolveInstanceIdsResponse, activation_instance_id, activation_service_guid,
decode_create_instance_request, decode_list_instances_request, decode_release_instance_request,
decode_resolve_instance_ids_request, encode_create_instance_response,
encode_list_instances_response, encode_release_instance_response,
encode_resolve_instance_ids_response,
};
use rpc_runtime_core::{
CapabilityFlags, Envelope, HelloAck, InstanceId, MethodId, Notification,
RUNTIME_PROTOCOL_VERSION, Request, RequestId, ResponseError, ResponseOk, Role, ServiceGuid,
};
use rpc_runtime_errors::{RuntimeError, RuntimeErrorCode};
use rpc_runtime_transport::{RpcConnection, RpcListener, RpcReceiver, RpcSender};
use tokio::sync::RwLock;
/// Boxed, `Send` future produced by [`RpcServiceHandler::call`].
///
/// It resolves to the response payload for the request or to a [`RuntimeError`] that the
/// server turns into an error response.
pub type HandlerFuture = Pin<Box<dyn Future<Output = Result<Value, RuntimeError>> + Send + 'static>>;

/// A service implementation that answers requests addressed to one registered instance.
///
/// Handlers are stored behind `Arc` and invoked from spawned request tasks, hence the
/// `Send + Sync` supertraits.
pub trait RpcServiceHandler: Send + Sync {
    /// Handles one request for `method_id`, with `payload` as sent by the client.
    fn call(&self, ctx: RpcCallContext, method_id: MethodId, payload: Value) -> HandlerFuture;
}
/// Lets any thread-safe closure or function with the `call` signature serve as a handler.
impl<F> RpcServiceHandler for F
where
    F: Fn(RpcCallContext, MethodId, Value) -> HandlerFuture + Send + Sync + 'static,
{
    fn call(&self, ctx: RpcCallContext, method_id: MethodId, payload: Value) -> HandlerFuture {
        (self)(ctx, method_id, payload)
    }
}
/// Boxed, `Send` future produced by [`RpcServiceFactory::create`].
///
/// It resolves to the handler that will back a newly created instance.
pub type FactoryFuture = Pin<
    Box<dyn Future<Output = Result<Arc<dyn RpcServiceHandler>, RuntimeError>> + Send + 'static>,
>;

/// Builds handlers on demand when a client asks the activation service to create an
/// instance of a registered service GUID.
pub trait RpcServiceFactory: Send + Sync {
    /// Creates a handler from the client-supplied `create_payload` and `options`.
    fn create(
        &self,
        ctx: RpcCallContext,
        create_payload: Option<Vec<u8>>,
        options: BTreeMap<String, String>,
    ) -> FactoryFuture;
}
/// Lets any thread-safe closure or function with the `create` signature serve as a
/// factory.
impl<F> RpcServiceFactory for F
where
    F: Send + Sync + 'static,
    F: Fn(RpcCallContext, Option<Vec<u8>>, BTreeMap<String, String>) -> FactoryFuture,
{
    // No generic lifetime here: the trait method declares none and nothing borrows
    // beyond `&self`.
    fn create(
        &self,
        ctx: RpcCallContext,
        create_payload: Option<Vec<u8>>,
        options: BTreeMap<String, String>,
    ) -> FactoryFuture {
        self(ctx, create_payload, options)
    }
}
/// Per-request context handed to handlers and factories.
///
/// It identifies where the request came from and carries the connection's outbound
/// sender, so the callee can push notifications back to the same client.
#[derive(Clone)]
pub struct RpcCallContext {
    /// Server-assigned id of the connection the request arrived on.
    connection_id: u64,
    /// Instance id the request was addressed to.
    instance_id: InstanceId,
    /// Outbound half of the connection, used by `notify`.
    sender: RpcSender,
}
impl RpcCallContext {
    /// Returns the server-assigned id of the connection this call arrived on.
    pub fn connection_id(&self) -> u64 {
        self.connection_id
    }

    /// Returns the instance id the request was addressed to.
    pub fn instance_id(&self) -> InstanceId {
        self.instance_id
    }

    /// Sends a notification envelope to the client over this call's connection.
    ///
    /// `instance_id` is copied into the envelope unchanged and may be `None`.
    ///
    /// # Errors
    ///
    /// A failed send is reported as a transport-kind `InternalRuntimeError` carrying the
    /// underlying error's message.
    pub async fn notify(
        &self,
        instance_id: Option<InstanceId>,
        notification_id: u32,
        payload: Value,
    ) -> Result<(), RuntimeError> {
        let notification = Notification {
            instance_id,
            notification_id: rpc_runtime_core::NotificationId::new(notification_id),
            payload,
        };
        let outcome = self
            .sender
            .send_envelope(&Envelope::Notification(notification))
            .await;
        outcome.map_err(|err| {
            RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
        })
    }

    /// Like [`RpcCallContext::notify`], but tags the notification with this context's
    /// own instance id.
    pub async fn notify_bound(
        &self,
        notification_id: u32,
        payload: Value,
    ) -> Result<(), RuntimeError> {
        let bound_instance = Some(self.instance_id);
        self.notify(bound_instance, notification_id, payload).await
    }
}
/// A configured RPC server, produced by [`RpcServerBuilder::build`].
///
/// Cloning is cheap: all clones share one registry of instances and factories through
/// an `Arc`.
#[derive(Clone)]
pub struct RpcServer {
    state: Arc<ServerState>,
}
/// Collects instance and factory registrations, then freezes them into an
/// [`RpcServer`] with [`RpcServerBuilder::build`].
pub struct RpcServerBuilder {
    state: ServerState,
}
impl RpcServerBuilder {
    /// Creates a builder whose registry already holds the built-in activation instance.
    pub fn new() -> Self {
        let mut builder = Self {
            state: ServerState::new(),
        };
        builder.state.insert_activation_instance();
        builder
    }

    /// Registers a pre-created instance that clients can look up by `name`.
    ///
    /// The instance is not releasable and has no owning connection. It answers only the
    /// method ids listed in `methods`. Registering a name that is already taken points
    /// the name at the new instance.
    pub fn register_named_instance(
        &mut self,
        name: impl Into<String>,
        service_guid: ServiceGuid,
        methods: impl IntoIterator<Item = u32>,
        handler: Arc<dyn RpcServiceHandler>,
    ) -> InstanceId {
        let entry = NewInstance {
            service_guid,
            name: Some(name.into()),
            activation_mode: ActivationMode::NamedPrecreated,
            releasable: false,
            owner_connection_id: None,
            methods: methods.into_iter().collect(),
            handler,
        };
        self.state.insert_instance(entry)
    }

    /// Registers an unnamed, non-releasable instance.
    ///
    /// Clients reach it through the returned id or through instance listing.
    pub fn register_singleton(
        &mut self,
        service_guid: ServiceGuid,
        methods: impl IntoIterator<Item = u32>,
        handler: Arc<dyn RpcServiceHandler>,
    ) -> InstanceId {
        let entry = NewInstance {
            service_guid,
            name: None,
            activation_mode: ActivationMode::Singleton,
            releasable: false,
            owner_connection_id: None,
            methods: methods.into_iter().collect(),
            handler,
        };
        self.state.insert_instance(entry)
    }

    /// Registers the factory used when a client asks to create an instance of
    /// `service_guid`.
    ///
    /// Every instance the factory produces exposes `methods`. A later registration for
    /// the same GUID replaces the earlier one.
    pub fn register_factory(
        &mut self,
        service_guid: ServiceGuid,
        methods: impl IntoIterator<Item = u32>,
        factory: Arc<dyn RpcServiceFactory>,
    ) {
        let entry = FactoryEntry {
            methods: methods.into_iter().collect(),
            factory,
        };
        self.state.factories.insert(service_guid.get(), entry);
    }

    /// Freezes the registrations into a shareable [`RpcServer`].
    pub fn build(self) -> RpcServer {
        let state = Arc::new(self.state);
        RpcServer { state }
    }
}
impl Default for RpcServerBuilder {
fn default() -> Self {
Self::new()
}
}
impl RpcServer {
    /// Serves one client connection until it sends GOODBYE, reaches end of stream, or
    /// fails.
    ///
    /// The HELLO/HELLO_ACK handshake runs first. Each later request is dispatched on its
    /// own spawned task, so a slow handler does not stall reading further requests, and
    /// responses are written as each task finishes. When serving ends, every instance
    /// this connection owns is removed from the registry, on the clean exit path and on
    /// the error paths alike.
    ///
    /// # Errors
    ///
    /// - Returns a transport error when receiving fails or the client disconnects during
    ///   the handshake.
    /// - Returns a protocol error when the handshake is rejected, or when an envelope
    ///   other than REQUEST or GOODBYE arrives afterwards.
    pub async fn serve_connection<C>(&self, connection: C) -> Result<(), RuntimeError>
    where
        C: Into<RpcConnection>,
    {
        let connection_id = self
            .state
            .next_connection_id
            .fetch_add(1, Ordering::Relaxed);
        let (sender, mut receiver) = connection.into().split();
        self.perform_handshake(&sender, &mut receiver).await?;
        let result = self
            .run_request_loop(connection_id, &sender, &mut receiver)
            .await;
        // Clean up unconditionally. Previously an early return (receive error or
        // unexpected envelope) skipped this and leaked the connection's instances.
        // NOTE(review): request tasks still in flight are not awaited, so a
        // CREATE_INSTANCE that completes after this point can still register an
        // instance for this connection.
        self.state.cleanup_connection(connection_id).await;
        result
    }

    /// Reads envelopes until GOODBYE or end of stream, spawning one task per request.
    async fn run_request_loop(
        &self,
        connection_id: u64,
        sender: &RpcSender,
        receiver: &mut RpcReceiver,
    ) -> Result<(), RuntimeError> {
        while let Some(envelope) = receiver.recv_envelope().await.map_err(|err| {
            RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
        })? {
            match envelope {
                Envelope::Request(request) => {
                    let state = Arc::clone(&self.state);
                    let sender = sender.clone();
                    tokio::spawn(async move {
                        let request_id = request.request_id;
                        let response =
                            dispatch_request(state, sender.clone(), connection_id, request).await;
                        let envelope = match response {
                            Ok(payload) => Envelope::ResponseOk(ResponseOk {
                                request_id,
                                payload,
                            }),
                            Err(error) => runtime_error_response(request_id, error),
                        };
                        // Best effort: a failed write presumably means the connection is
                        // already closed, and there is nobody to report that to.
                        let _ = sender.send_envelope(&envelope).await;
                    });
                }
                Envelope::Goodbye(_) => break,
                _ => {
                    return Err(RuntimeError::protocol(
                        RuntimeErrorCode::InvalidEnvelope,
                        "server expected request envelope",
                    ));
                }
            }
        }
        Ok(())
    }

    /// Accepts connections forever, serving each one on its own spawned task.
    ///
    /// Per-connection results are discarded. The loop ends only when `accept` fails.
    ///
    /// # Errors
    ///
    /// Returns a transport error if the listener fails to accept a connection.
    pub async fn serve_listener<L>(&self, mut listener: L) -> Result<(), RuntimeError>
    where
        L: RpcListener + Send,
    {
        loop {
            let connection = listener.accept().await.map_err(|err| {
                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
            })?;
            let server = self.clone();
            tokio::spawn(async move {
                let _ = server.serve_connection(connection).await;
            });
        }
    }

    /// Runs [`RpcServer::serve_listener`] on a spawned task and returns its handle.
    pub fn spawn_listener<L>(
        &self,
        listener: L,
    ) -> tokio::task::JoinHandle<Result<(), RuntimeError>>
    where
        L: RpcListener + Send + 'static,
    {
        let server = self.clone();
        tokio::spawn(async move { server.serve_listener(listener).await })
    }

    /// Performs the server side of the handshake.
    ///
    /// Expects the first envelope to be a HELLO from a client role, speaking exactly
    /// `RUNTIME_PROTOCOL_VERSION`. The HELLO_ACK echoes the client's
    /// `max_message_size` and accepts the intersection of the server's and the client's
    /// capability bits. On rejection nothing is sent back; the error is only returned.
    async fn perform_handshake(
        &self,
        sender: &RpcSender,
        receiver: &mut RpcReceiver,
    ) -> Result<(), RuntimeError> {
        let Some(envelope) = receiver.recv_envelope().await.map_err(|err| {
            RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
        })?
        else {
            return Err(RuntimeError::transport(
                RuntimeErrorCode::InternalRuntimeError,
                "client disconnected during handshake",
            ));
        };
        let Envelope::Hello(hello) = envelope else {
            return Err(RuntimeError::protocol(
                RuntimeErrorCode::InvalidEnvelope,
                "expected HELLO during handshake",
            ));
        };
        if hello.protocol_version != RUNTIME_PROTOCOL_VERSION || hello.role != Role::Client {
            return Err(RuntimeError::protocol(
                RuntimeErrorCode::UnsupportedProtocolVersion,
                "unsupported client handshake",
            ));
        }
        sender
            .send_envelope(&Envelope::HelloAck(HelloAck {
                protocol_version: RUNTIME_PROTOCOL_VERSION,
                accepted_capability_bits: server_capabilities() & hello.capability_bits,
                max_message_size: hello.max_message_size,
                options: Vec::new(),
            }))
            .await
            .map_err(|err| {
                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
            })
    }

    /// Returns descriptors for every registered instance, including the activation
    /// instance, ordered by instance id.
    pub async fn list_instances(&self) -> Vec<InstanceDescriptor> {
        self.state.list_instances(None).await
    }
}
/// Routes one request to its target and awaits the result.
///
/// Requests for the activation instance go to [`dispatch_activation`]. Any other
/// request is looked up in the registry, and its method id is checked against the
/// instance's declared methods before the handler is invoked.
///
/// # Errors
///
/// - `InstanceNotFound` when the target id is not registered.
/// - `MethodNotFound` when the instance does not declare the method.
/// - Otherwise, whatever the handler returns.
async fn dispatch_request(
    state: Arc<ServerState>,
    sender: RpcSender,
    connection_id: u64,
    request: Request,
) -> Result<Value, RuntimeError> {
    let target = request.instance_id;
    if target.get() == ACTIVATION_INSTANCE_ID_VALUE {
        return dispatch_activation(state, sender, connection_id, request).await;
    }
    let instance = state.get_instance(target).await?;
    let method = request.method_id.get();
    if instance.methods.contains(&method) {
        let ctx = RpcCallContext {
            connection_id,
            instance_id: target,
            sender,
        };
        return instance
            .handler
            .call(ctx, request.method_id, request.payload)
            .await;
    }
    Err(RuntimeError::runtime(
        RuntimeErrorCode::MethodNotFound,
        format!("method id `{method}` was not found"),
    ))
}
/// Handles requests addressed to the built-in activation instance.
///
/// Supported methods:
/// - resolve instance names to ids;
/// - create an instance through a registered factory, owned by this connection;
/// - release an instance this connection owns;
/// - list registered instances, optionally filtered by service GUID.
///
/// # Errors
///
/// - The decoder's error for a malformed payload.
/// - `ServiceGuidNotFound` when no factory exists for the requested GUID.
/// - Any error from the factory or from `release_instance`.
/// - `MethodNotFound` for an unknown activation method id.
async fn dispatch_activation(
    state: Arc<ServerState>,
    sender: RpcSender,
    connection_id: u64,
    request: Request,
) -> Result<Value, RuntimeError> {
    match request.method_id.get() {
        RESOLVE_INSTANCE_IDS_METHOD_ID => {
            let decoded = decode_resolve_instance_ids_request(&request.payload)?;
            let instance_ids = state.resolve_instance_ids(&decoded.instance_names).await;
            Ok(encode_resolve_instance_ids_response(
                &ResolveInstanceIdsResponse { instance_ids },
            ))
        }
        CREATE_INSTANCE_METHOD_ID => {
            let decoded = decode_create_instance_request(&request.payload)?;
            let factory = state.get_factory(decoded.service_guid).ok_or_else(|| {
                RuntimeError::runtime(
                    RuntimeErrorCode::ServiceGuidNotFound,
                    "service factory was not found",
                )
            })?;
            // Only the factory needs a call context. It carries the id this request was
            // addressed to (the activation instance, as routed by `dispatch_request`),
            // because the new instance has no id until it is inserted below.
            let ctx = RpcCallContext {
                connection_id,
                instance_id: request.instance_id,
                sender,
            };
            let handler = factory
                .factory
                .create(ctx, decoded.create_payload, decoded.options)
                .await?;
            // `factory` is already an owned copy from `get_factory`, so its method list is
            // moved rather than cloned a second time.
            let instance_id = state
                .insert_client_instance(
                    decoded.service_guid,
                    connection_id,
                    factory.methods,
                    handler,
                )
                .await;
            Ok(encode_create_instance_response(&CreateInstanceResponse {
                instance_id,
            }))
        }
        RELEASE_INSTANCE_METHOD_ID => {
            let decoded = decode_release_instance_request(&request.payload)?;
            state
                .release_instance(connection_id, decoded.instance_id)
                .await?;
            Ok(encode_release_instance_response(&ReleaseInstanceResponse))
        }
        LIST_INSTANCES_METHOD_ID => {
            let decoded = decode_list_instances_request(&request.payload)?;
            let instances = state.list_instances(decoded.service_guid).await;
            Ok(encode_list_instances_response(&ListInstancesResponse {
                instances,
            }))
        }
        _ => Err(RuntimeError::runtime(
            RuntimeErrorCode::MethodNotFound,
            "activation method was not found",
        )),
    }
}
/// Converts a [`RuntimeError`] into a RESPONSE_ERROR envelope for `request_id`.
///
/// The envelope carries the error's numeric code, its kind and its message.
/// `error_details` is always `Nil`.
fn runtime_error_response(request_id: RequestId, error: RuntimeError) -> Envelope {
    let error_code = error.code.as_i32();
    let error_kind = error.kind.as_u8();
    let error_message = Some(error.message);
    Envelope::ResponseError(ResponseError {
        request_id,
        error_code,
        error_kind,
        error_message,
        error_details: Value::Nil,
    })
}
/// Capability bits this server supports.
///
/// During the handshake these are intersected with the bits the client advertises.
fn server_capabilities() -> CapabilityFlags {
    let messaging = CapabilityFlags::SERVER_TO_CLIENT_NOTIFICATION
        | CapabilityFlags::NAMED_INSTANCE_RESOLUTION;
    messaging | CapabilityFlags::SERVICE_ACTIVATION | CapabilityFlags::GOODBYE
}
/// Shared registry behind an [`RpcServer`].
///
/// `instances` is behind an async `RwLock` because client-created instances are inserted
/// and removed while connections are served. `factories` is only written while building
/// and is read without a lock afterwards.
struct ServerState {
    /// Source of connection ids.
    next_connection_id: AtomicU64,
    /// Source of instance ids for everything except the pre-inserted activation
    /// instance.
    next_instance_id: AtomicU64,
    /// Every live instance, including the activation instance, keyed by raw instance id.
    instances: RwLock<HashMap<u64, InstanceEntry>>,
    /// Name index used by instance-name resolution; values are raw instance ids.
    /// NOTE(review): in the code shown this is only written at build time via `get_mut`.
    names: RwLock<HashMap<String, u64>>,
    /// Registered factories keyed by service GUID.
    factories: HashMap<uuid::Uuid, FactoryEntry>,
}
impl ServerState {
    /// Creates an empty registry.
    ///
    /// Connection ids start at 1 and instance ids at 2. Presumably id 1 is
    /// `ACTIVATION_INSTANCE_ID_VALUE`; TODO confirm.
    fn new() -> Self {
        Self {
            next_connection_id: AtomicU64::new(1),
            next_instance_id: AtomicU64::new(2),
            instances: RwLock::new(HashMap::new()),
            names: RwLock::new(HashMap::new()),
            factories: HashMap::new(),
        }
    }

    /// Inserts the built-in activation instance under `ACTIVATION_INSTANCE_ID_VALUE`.
    ///
    /// Its handler is an [`ActivationMarker`]. `dispatch_request` intercepts requests to
    /// this id before the handler is looked up.
    ///
    /// NOTE(review): the name `rpc.runtime.Activation` is reported by `list_instances`
    /// but is not added to `names`, so `resolve_instance_ids` returns 0 for it. Confirm
    /// this is intended.
    fn insert_activation_instance(&mut self) {
        self.instances.get_mut().insert(
            ACTIVATION_INSTANCE_ID_VALUE,
            InstanceEntry {
                instance_id: activation_instance_id(),
                service_guid: activation_service_guid(),
                instance_name: Some("rpc.runtime.Activation".to_string()),
                activation_mode: ActivationMode::Singleton,
                releasable: false,
                owner_connection_id: None,
                methods: vec![
                    RESOLVE_INSTANCE_IDS_METHOD_ID,
                    CREATE_INSTANCE_METHOD_ID,
                    RELEASE_INSTANCE_METHOD_ID,
                    LIST_INSTANCES_METHOD_ID,
                ],
                handler: Arc::new(ActivationMarker),
            },
        );
    }

    /// Reserves the next instance id and returns it both raw and typed.
    fn allocate_instance_id(&self) -> (u64, InstanceId) {
        let id = self.next_instance_id.fetch_add(1, Ordering::Relaxed);
        // The counter starts at 2 and only increases, so it can only yield 0 after
        // wrapping around the full u64 range.
        let instance_id = InstanceId::new(id).expect("generated instance id is non-zero");
        (id, instance_id)
    }

    /// Registers a builder-time instance and indexes its name, if it has one.
    ///
    /// An existing name mapping is overwritten. The instance it pointed to stays
    /// registered but can no longer be resolved by name.
    fn insert_instance(&mut self, instance: NewInstance) -> InstanceId {
        let (id, instance_id) = self.allocate_instance_id();
        if let Some(name) = &instance.name {
            self.names.get_mut().insert(name.clone(), id);
        }
        self.instances.get_mut().insert(
            id,
            InstanceEntry {
                instance_id,
                service_guid: instance.service_guid,
                instance_name: instance.name,
                activation_mode: instance.activation_mode,
                releasable: instance.releasable,
                owner_connection_id: instance.owner_connection_id,
                methods: instance.methods,
                handler: instance.handler,
            },
        );
        instance_id
    }

    /// Registers a factory-created instance owned by `connection_id`.
    ///
    /// The instance is unnamed and releasable.
    async fn insert_client_instance(
        &self,
        service_guid: ServiceGuid,
        connection_id: u64,
        methods: Vec<u32>,
        handler: Arc<dyn RpcServiceHandler>,
    ) -> InstanceId {
        let (id, instance_id) = self.allocate_instance_id();
        self.instances.write().await.insert(
            id,
            InstanceEntry {
                instance_id,
                service_guid,
                instance_name: None,
                activation_mode: ActivationMode::Instantiable,
                releasable: true,
                owner_connection_id: Some(connection_id),
                methods,
                handler,
            },
        );
        instance_id
    }

    /// Returns a clone of the entry for `instance_id`, or `InstanceNotFound`.
    async fn get_instance(&self, instance_id: InstanceId) -> Result<InstanceEntry, RuntimeError> {
        self.instances
            .read()
            .await
            .get(&instance_id.get())
            .cloned()
            .ok_or_else(|| {
                RuntimeError::runtime(RuntimeErrorCode::InstanceNotFound, "instance was not found")
            })
    }

    /// Returns a clone of the factory registered for `service_guid`, if any.
    fn get_factory(&self, service_guid: ServiceGuid) -> Option<FactoryEntry> {
        self.factories.get(&service_guid.get()).cloned()
    }

    /// Maps each name to its raw instance id, in input order.
    ///
    /// Unknown names map to 0.
    async fn resolve_instance_ids(&self, names: &[String]) -> Vec<u64> {
        let index = self.names.read().await;
        names
            .iter()
            .map(|name| index.get(name).copied().unwrap_or(0))
            .collect()
    }

    /// Removes `instance_id` on behalf of `connection_id`.
    ///
    /// # Errors
    ///
    /// - `InstanceNotFound` when the id is unknown.
    /// - `InstanceReleaseNotAllowed` when the instance is not releasable.
    /// - `AccessDenied` when another connection owns the instance.
    async fn release_instance(
        &self,
        connection_id: u64,
        instance_id: InstanceId,
    ) -> Result<(), RuntimeError> {
        let mut instances = self.instances.write().await;
        let entry = instances.get(&instance_id.get()).ok_or_else(|| {
            RuntimeError::runtime(RuntimeErrorCode::InstanceNotFound, "instance was not found")
        })?;
        if !entry.releasable {
            return Err(RuntimeError::runtime(
                RuntimeErrorCode::InstanceReleaseNotAllowed,
                "instance is not releasable",
            ));
        }
        if entry.owner_connection_id != Some(connection_id) {
            return Err(RuntimeError::runtime(
                RuntimeErrorCode::AccessDenied,
                "instance is owned by another connection",
            ));
        }
        instances.remove(&instance_id.get());
        Ok(())
    }

    /// Removes every instance owned by `connection_id`.
    async fn cleanup_connection(&self, connection_id: u64) {
        self.instances
            .write()
            .await
            .retain(|_, entry| entry.owner_connection_id != Some(connection_id));
    }

    /// Returns descriptors of all instances, optionally restricted to one service GUID,
    /// ordered by instance id.
    async fn list_instances(&self, service_guid: Option<ServiceGuid>) -> Vec<InstanceDescriptor> {
        let mut values = self
            .instances
            .read()
            .await
            .values()
            .filter(|entry| service_guid.is_none_or(|guid| guid == entry.service_guid))
            .map(InstanceEntry::descriptor)
            .collect::<Vec<_>>();
        // Instance ids are unique map keys, so an unstable sort gives the same order.
        values.sort_unstable_by_key(|entry| entry.instance_id.get());
        values
    }
}
/// Parameters for registering an instance at build time via
/// `ServerState::insert_instance`.
struct NewInstance {
    /// Service the instance implements.
    service_guid: ServiceGuid,
    /// Optional lookup name; `Some` adds an entry to the name index.
    name: Option<String>,
    /// How the instance came to exist, as reported in listings.
    activation_mode: ActivationMode,
    /// Whether clients may release the instance.
    releasable: bool,
    /// Connection that owns the instance, if any.
    owner_connection_id: Option<u64>,
    /// Method ids the instance accepts.
    methods: Vec<u32>,
    /// Handler invoked for accepted requests.
    handler: Arc<dyn RpcServiceHandler>,
}
/// A registered instance as stored in the registry.
///
/// Cloning is cheap for the handler (an `Arc` bump) but copies the name and the method
/// list.
#[derive(Clone)]
struct InstanceEntry {
    /// Typed id of the instance.
    instance_id: InstanceId,
    /// Service the instance implements.
    service_guid: ServiceGuid,
    /// Lookup name, if the instance has one.
    instance_name: Option<String>,
    /// How the instance came to exist.
    activation_mode: ActivationMode,
    /// Whether clients may release the instance.
    releasable: bool,
    /// Owning connection; instances with `Some` are removed when that connection ends.
    owner_connection_id: Option<u64>,
    /// Method ids the instance accepts.
    methods: Vec<u32>,
    /// Handler invoked for accepted requests.
    handler: Arc<dyn RpcServiceHandler>,
}
impl InstanceEntry {
    /// Builds the public descriptor for this instance.
    ///
    /// The descriptor omits the method list, the handler and the owning connection.
    fn descriptor(&self) -> InstanceDescriptor {
        let instance_name = self.instance_name.clone();
        InstanceDescriptor {
            instance_id: self.instance_id,
            instance_name,
            service_guid: self.service_guid,
            activation_mode: self.activation_mode,
            releasable: self.releasable,
        }
    }
}
/// A registered factory, together with the method ids its instances will accept.
#[derive(Clone)]
struct FactoryEntry {
    /// Method ids granted to every instance this factory creates.
    methods: Vec<u32>,
    /// The factory itself.
    factory: Arc<dyn RpcServiceFactory>,
}
/// Placeholder handler stored for the activation instance.
///
/// Activation requests are routed to `dispatch_activation` before any handler lookup,
/// so this should never run. If it does, it fails with `InternalRuntimeError`.
struct ActivationMarker;

impl RpcServiceHandler for ActivationMarker {
    fn call(&self, _ctx: RpcCallContext, _method_id: MethodId, _payload: Value) -> HandlerFuture {
        let unreachable_dispatch = async {
            Err(RuntimeError::runtime(
                RuntimeErrorCode::InternalRuntimeError,
                "activation marker should not be dispatched directly",
            ))
        };
        Box::pin(unreachable_dispatch)
    }
}