use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use rmpv::Value;
use rpc_runtime_activation::{
CREATE_INSTANCE_METHOD_ID, CreateInstanceRequest, LIST_INSTANCES_METHOD_ID,
ListInstancesRequest, RELEASE_INSTANCE_METHOD_ID, RESOLVE_INSTANCE_IDS_METHOD_ID,
ReleaseInstanceRequest, ResolveInstanceIdsRequest, activation_instance_id,
decode_create_instance_response, decode_list_instances_response,
decode_release_instance_response, decode_resolve_instance_ids_response,
encode_create_instance_request, encode_list_instances_request, encode_release_instance_request,
encode_resolve_instance_ids_request,
};
use rpc_runtime_core::{
CapabilityFlags, Envelope, Hello, InstanceId, MethodId, Notification, Options,
RUNTIME_PROTOCOL_VERSION, Request, RequestId, Role, ServiceGuid,
};
use rpc_runtime_errors::{ErrorKind, RuntimeError, RuntimeErrorCode};
use rpc_runtime_transport::{RpcConnection, RpcReceiver, RpcSender};
use rpc_runtime_transport_ipc::{FrameConfig, IpcConnection, IpcEndpoint};
use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
/// Handle to an RPC connection on which the HELLO / HELLO_ACK handshake has succeeded.
///
/// Cloning is cheap: every clone shares the same connection state through an `Arc`. The
/// shared state uses an atomic request-id counter and a mutex-guarded pending table, so
/// clones may issue requests concurrently.
#[derive(Clone)]
pub struct RpcClient {
    inner: Arc<ClientInner>,
}
/// HELLO option key under which the auth token is sent unless overridden with
/// `RpcClientHandshakeConfig::with_auth_option_key`.
pub const DEFAULT_AUTH_TOKEN_OPTION_KEY: &str = "tripley.auth.token";
/// Options controlling what the client sends in its HELLO envelope.
///
/// `Debug` is implemented by hand so the auth token never appears in logs or panic
/// messages; only whether a token is present is shown.
#[derive(Clone, PartialEq, Eq)]
pub struct RpcClientHandshakeConfig {
    /// Auth token to include in the HELLO options; nothing is sent when `None`.
    pub auth_token: Option<String>,
    /// HELLO option key the auth token is sent under.
    pub auth_option_key: String,
}

impl std::fmt::Debug for RpcClientHandshakeConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RpcClientHandshakeConfig")
            // The token is a credential: show presence only, never the value.
            .field("auth_token", &self.auth_token.as_ref().map(|_| "<redacted>"))
            .field("auth_option_key", &self.auth_option_key)
            .finish()
    }
}
impl RpcClientHandshakeConfig {
    /// Returns this config with `token` as the auth token sent in HELLO.
    ///
    /// Consumes `self`; the returned value must be used or the setting is lost.
    #[must_use]
    pub fn with_auth_token(mut self, token: impl Into<String>) -> Self {
        self.auth_token = Some(token.into());
        self
    }

    /// Returns this config with `key` as the HELLO option key for the auth token
    /// (the default is [`DEFAULT_AUTH_TOKEN_OPTION_KEY`]).
    ///
    /// Consumes `self`; the returned value must be used or the setting is lost.
    #[must_use]
    pub fn with_auth_option_key(mut self, key: impl Into<String>) -> Self {
        self.auth_option_key = key.into();
        self
    }

    /// Builds the HELLO options.
    ///
    /// The result is a single `(auth_option_key, token)` pair when a token is configured.
    /// Otherwise it is empty.
    fn hello_options(&self) -> Options {
        self.auth_token
            .as_ref()
            .map(|token| vec![(self.auth_option_key.clone(), Value::from(token.as_str()))])
            .unwrap_or_default()
    }
}
impl Default for RpcClientHandshakeConfig {
    /// A config with no auth token, using [`DEFAULT_AUTH_TOKEN_OPTION_KEY`] as the key.
    fn default() -> Self {
        let auth_option_key = String::from(DEFAULT_AUTH_TOKEN_OPTION_KEY);
        Self {
            auth_option_key,
            auth_token: Option::None,
        }
    }
}
/// Connection state shared by all `RpcClient` clones and the background receive loop.
struct ClientInner {
    /// Sending half of the split connection; requests and GOODBYE go out through it.
    sender: RpcSender,
    /// Source of request ids; starts at 1 and is incremented once per `call`.
    next_request_id: AtomicU64,
    /// In-flight requests keyed by request id, each waiting for its response or failure.
    pending: Mutex<HashMap<u64, oneshot::Sender<Result<Value, RuntimeError>>>>,
    /// Fan-out of server notifications to `subscribe_notifications` subscribers.
    notifications: broadcast::Sender<Notification>,
}
impl RpcClient {
    /// Connects to `endpoint` over IPC and performs the HELLO handshake with
    /// `RpcClientHandshakeConfig::default()` (no auth token).
    ///
    /// # Errors
    ///
    /// Returns a transport error if the IPC connection cannot be established. Any handshake
    /// error from [`RpcClient::from_connection_with_handshake_config`] is also returned.
    pub async fn connect(endpoint: IpcEndpoint, config: FrameConfig) -> Result<Self, RuntimeError> {
        Self::connect_with_handshake_config(endpoint, config, RpcClientHandshakeConfig::default())
            .await
    }

    /// Connects to `endpoint` over IPC and performs the HELLO handshake using `handshake`.
    ///
    /// # Errors
    ///
    /// Returns a transport error (`InternalRuntimeError`) if the IPC connection cannot be
    /// established. Any handshake error from
    /// [`RpcClient::from_connection_with_handshake_config`] is also returned.
    pub async fn connect_with_handshake_config(
        endpoint: IpcEndpoint,
        config: FrameConfig,
        handshake: RpcClientHandshakeConfig,
    ) -> Result<Self, RuntimeError> {
        let connection = IpcConnection::connect(endpoint, config)
            .await
            .map_err(|err| {
                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
            })?;
        Self::from_connection_with_handshake_config(connection, handshake).await
    }

    /// Performs the HELLO handshake on an existing connection with
    /// `RpcClientHandshakeConfig::default()`.
    ///
    /// # Errors
    ///
    /// See [`RpcClient::from_connection_with_handshake_config`].
    pub async fn from_connection<C>(connection: C) -> Result<Self, RuntimeError>
    where
        C: Into<RpcConnection>,
    {
        Self::from_connection_with_handshake_config(connection, RpcClientHandshakeConfig::default())
            .await
    }

    /// Performs the client side of the handshake on `connection` and starts the background
    /// receive loop.
    ///
    /// HELLO carries:
    /// - the client role;
    /// - `client_capabilities()`;
    /// - the msgpack codec's default maximum message size;
    /// - the options built from `handshake`.
    ///
    /// The server's first envelope must then be HELLO_ACK with the same protocol version.
    ///
    /// # Errors
    ///
    /// - A transport error if HELLO cannot be sent, receiving fails, or the server
    ///   disconnects before replying.
    /// - An `InvalidEnvelope` protocol error if the first envelope is not HELLO_ACK.
    /// - An `UnsupportedProtocolVersion` protocol error if the protocol versions differ.
    pub async fn from_connection_with_handshake_config<C>(
        connection: C,
        handshake: RpcClientHandshakeConfig,
    ) -> Result<Self, RuntimeError>
    where
        C: Into<RpcConnection>,
    {
        let (sender, mut receiver) = connection.into().split();
        sender
            .send_envelope(&Envelope::Hello(Hello {
                protocol_version: RUNTIME_PROTOCOL_VERSION,
                role: Role::Client,
                capability_bits: client_capabilities(),
                max_message_size: rpc_runtime_codec_msgpack::DEFAULT_MAX_MESSAGE_SIZE as u64,
                options: handshake.hello_options(),
            }))
            .await
            .map_err(|err| {
                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
            })?;
        // NOTE(review): there is no deadline on waiting for HELLO_ACK, so a silent server
        // may block this call indefinitely.
        let Some(envelope) = receiver.recv_envelope().await.map_err(|err| {
            RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
        })?
        else {
            return Err(RuntimeError::transport(
                RuntimeErrorCode::InternalRuntimeError,
                "server disconnected during handshake",
            ));
        };
        let Envelope::HelloAck(ack) = envelope else {
            return Err(RuntimeError::protocol(
                RuntimeErrorCode::InvalidEnvelope,
                "expected HELLO_ACK during handshake",
            ));
        };
        if ack.protocol_version != RUNTIME_PROTOCOL_VERSION {
            return Err(RuntimeError::protocol(
                RuntimeErrorCode::UnsupportedProtocolVersion,
                "server returned unsupported protocol version",
            ));
        }
        // The capacity bounds how far a slow subscriber may fall behind before it lags.
        let (notifications, _) = broadcast::channel(128);
        let inner = Arc::new(ClientInner {
            sender,
            next_request_id: AtomicU64::new(1),
            pending: Mutex::new(HashMap::new()),
            notifications,
        });
        spawn_receive_loop(Arc::clone(&inner), receiver);
        Ok(Self { inner })
    }

    /// Sends a request for `method_id` on `instance_id` and waits for its response.
    ///
    /// If this future is dropped before the response arrives (for example by
    /// [`RpcClient::call_timeout`] or inside a `select!`), the request's pending entry is
    /// removed. Abandoned requests therefore do not accumulate for the lifetime of the
    /// connection. A response that arrives afterwards is discarded.
    ///
    /// # Errors
    ///
    /// - A transport error if the request cannot be sent.
    /// - The error carried by a `ResponseError` envelope.
    /// - The error the receive loop reports if the connection ends before a response arrives.
    pub async fn call(
        &self,
        instance_id: InstanceId,
        method_id: MethodId,
        payload: Value,
    ) -> Result<Value, RuntimeError> {
        let request_id = self.inner.next_request_id.fetch_add(1, Ordering::Relaxed);
        let (tx, rx) = oneshot::channel();
        self.inner.pending.lock().await.insert(request_id, tx);
        // From here on, if this future is dropped before the entry has been taken out of
        // `pending`, the guard removes it. Otherwise the oneshot sender would stay in the
        // table until the connection ends.
        let mut cleanup = PendingCleanup::new(&self.inner, request_id);
        let send_result = self
            .inner
            .sender
            .send_envelope(&Envelope::Request(Request {
                request_id: RequestId::new(request_id),
                instance_id,
                method_id,
                payload,
            }))
            .await;
        if let Err(err) = send_result {
            self.inner.pending.lock().await.remove(&request_id);
            cleanup.disarm();
            return Err(RuntimeError::transport(
                RuntimeErrorCode::InternalRuntimeError,
                err.to_string(),
            ));
        }
        let outcome = rx.await;
        // By the time the oneshot resolves, `complete_pending` or `fail_pending` has
        // already taken its sender out of `pending`, so there is nothing left to clean up.
        cleanup.disarm();
        outcome.map_err(|_| {
            RuntimeError::transport(
                RuntimeErrorCode::InternalRuntimeError,
                "response channel closed before request completed",
            )
        })?
    }

    /// Like [`RpcClient::call`], but fails if sending plus waiting for the response takes
    /// longer than `timeout`.
    ///
    /// On timeout the in-flight `call` future is dropped, which removes its pending entry.
    /// A response that arrives later is discarded.
    ///
    /// # Errors
    ///
    /// Returns a `RequestTimeout` runtime error on timeout, otherwise the same errors as
    /// [`RpcClient::call`].
    pub async fn call_timeout(
        &self,
        instance_id: InstanceId,
        method_id: MethodId,
        payload: Value,
        timeout: Duration,
    ) -> Result<Value, RuntimeError> {
        tokio::time::timeout(timeout, self.call(instance_id, method_id, payload))
            .await
            .map_err(|_| {
                RuntimeError::runtime(RuntimeErrorCode::RequestTimeout, "request timed out")
            })?
    }

    /// Asks the activation service to resolve `names` to numeric instance ids.
    ///
    /// # Errors
    ///
    /// Returns any error from [`RpcClient::call`], or the error from decoding the response.
    pub async fn resolve_instance_ids(&self, names: Vec<String>) -> Result<Vec<u64>, RuntimeError> {
        let response = self
            .call(
                activation_instance_id(),
                MethodId::new(RESOLVE_INSTANCE_IDS_METHOD_ID),
                encode_resolve_instance_ids_request(&ResolveInstanceIdsRequest {
                    instance_names: names,
                }),
            )
            .await?;
        Ok(decode_resolve_instance_ids_response(&response)?.instance_ids)
    }

    /// Asks the activation service to create an instance of `service_guid` and returns the
    /// new instance id.
    ///
    /// `create_payload` and `options` are forwarded unchanged in the request.
    ///
    /// # Errors
    ///
    /// Returns any error from [`RpcClient::call`], or the error from decoding the response.
    pub async fn create_instance(
        &self,
        service_guid: ServiceGuid,
        create_payload: Option<Vec<u8>>,
        options: BTreeMap<String, String>,
    ) -> Result<InstanceId, RuntimeError> {
        let response = self
            .call(
                activation_instance_id(),
                MethodId::new(CREATE_INSTANCE_METHOD_ID),
                encode_create_instance_request(&CreateInstanceRequest {
                    service_guid,
                    create_payload,
                    options,
                }),
            )
            .await?;
        Ok(decode_create_instance_response(&response)?.instance_id)
    }

    /// Asks the activation service to release `instance_id`.
    ///
    /// # Errors
    ///
    /// Returns any error from [`RpcClient::call`], or the error from decoding the response.
    pub async fn release_instance(&self, instance_id: InstanceId) -> Result<(), RuntimeError> {
        let response = self
            .call(
                activation_instance_id(),
                MethodId::new(RELEASE_INSTANCE_METHOD_ID),
                encode_release_instance_request(&ReleaseInstanceRequest { instance_id }),
            )
            .await?;
        decode_release_instance_response(&response)?;
        Ok(())
    }

    /// Lists the instances reported by the activation service.
    ///
    /// `service_guid` is passed through in the request (presumably as a filter; confirm
    /// against the activation service).
    ///
    /// # Errors
    ///
    /// Returns any error from [`RpcClient::call`], or the error from decoding the response.
    pub async fn list_instances(
        &self,
        service_guid: Option<ServiceGuid>,
    ) -> Result<Vec<rpc_runtime_activation::InstanceDescriptor>, RuntimeError> {
        let response = self
            .call(
                activation_instance_id(),
                MethodId::new(LIST_INSTANCES_METHOD_ID),
                encode_list_instances_request(&ListInstancesRequest { service_guid }),
            )
            .await?;
        Ok(decode_list_instances_response(&response)?.instances)
    }

    /// Returns a stream of server notifications that match both filters.
    ///
    /// A `None` filter matches everything. `instance_id_filter` matches notifications whose
    /// `instance_id` equals it, and `notification_id_filter` matches on the notification id.
    /// Only notifications received after this call are delivered.
    ///
    /// If the subscriber falls more than the broadcast capacity behind, the oldest
    /// notifications are skipped and delivery continues. The forwarding task ends when the
    /// notification channel closes or the returned receiver is dropped.
    pub fn subscribe_notifications(
        &self,
        instance_id_filter: Option<InstanceId>,
        notification_id_filter: Option<u32>,
    ) -> mpsc::UnboundedReceiver<Notification> {
        let mut source = self.inner.notifications.subscribe();
        let (tx, rx) = mpsc::unbounded_channel();
        tokio::spawn(async move {
            loop {
                let notification = match source.recv().await {
                    Ok(notification) => notification,
                    // Lagging only means some notifications were overwritten; the channel is
                    // still live, so keep forwarding instead of ending the stream.
                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
                    Err(broadcast::error::RecvError::Closed) => break,
                };
                // Stop as soon as the subscriber is gone, even if nothing would match.
                if tx.is_closed() {
                    break;
                }
                let instance_matches = instance_id_filter
                    .is_none_or(|expected| notification.instance_id == Some(expected));
                let notification_matches = notification_id_filter
                    .is_none_or(|expected| notification.notification_id.get() == expected);
                if instance_matches && notification_matches && tx.send(notification).is_err() {
                    break;
                }
            }
        });
        rx
    }

    /// Sends a GOODBYE envelope with reason code 0 and `message`.
    ///
    /// This only sends the envelope; it does not close the connection or stop the receive
    /// loop.
    ///
    /// # Errors
    ///
    /// Returns a transport error (`InternalRuntimeError`) if the envelope cannot be sent.
    pub async fn goodbye(&self, message: impl Into<String>) -> Result<(), RuntimeError> {
        self.inner
            .sender
            .send_envelope(&Envelope::Goodbye(rpc_runtime_core::Goodbye {
                reason_code: 0,
                message: Some(message.into()),
            }))
            .await
            .map_err(|err| {
                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
            })
    }
}

/// Drop guard that removes a request's entry from `ClientInner::pending`.
///
/// The removal happens only if the future awaiting the response is dropped before the
/// entry has been taken out by the receive loop (or by the send-failure path in
/// `RpcClient::call`).
struct PendingCleanup<'a> {
    inner: &'a Arc<ClientInner>,
    request_id: u64,
    armed: bool,
}

impl<'a> PendingCleanup<'a> {
    /// Creates an armed guard for `request_id`.
    fn new(inner: &'a Arc<ClientInner>, request_id: u64) -> Self {
        Self {
            inner,
            request_id,
            armed: true,
        }
    }

    /// Records that the entry is already gone, so dropping the guard does nothing.
    fn disarm(&mut self) {
        self.armed = false;
    }
}

impl Drop for PendingCleanup<'_> {
    fn drop(&mut self) {
        if !self.armed {
            return;
        }
        let request_id = self.request_id;
        // `Drop` cannot await, so take the lock only if it is free right now...
        if let Ok(mut pending) = self.inner.pending.try_lock() {
            pending.remove(&request_id);
            return;
        }
        // ...and otherwise finish the removal on the runtime. Without a runtime (e.g. during
        // shutdown) the entry is left for `fail_pending` to clear when the receive loop exits.
        if let Ok(handle) = tokio::runtime::Handle::try_current() {
            let inner = Arc::clone(self.inner);
            handle.spawn(async move {
                inner.pending.lock().await.remove(&request_id);
            });
        }
    }
}
fn spawn_receive_loop(inner: Arc<ClientInner>, mut receiver: RpcReceiver) {
tokio::spawn(async move {
loop {
let envelope = match receiver.recv_envelope().await {
Ok(Some(envelope)) => envelope,
Ok(None) => {
fail_pending(
&inner,
RuntimeError::transport(
RuntimeErrorCode::InternalRuntimeError,
"server disconnected",
),
)
.await;
break;
}
Err(err) => {
fail_pending(
&inner,
RuntimeError::transport(
RuntimeErrorCode::InternalRuntimeError,
err.to_string(),
),
)
.await;
break;
}
};
match envelope {
Envelope::ResponseOk(response) => {
complete_pending(&inner, response.request_id.get(), Ok(response.payload)).await;
}
Envelope::ResponseError(response) => {
complete_pending(
&inner,
response.request_id.get(),
Err(RuntimeError::new(
runtime_error_code(response.error_code),
error_kind(response.error_kind),
response.error_message.unwrap_or_default(),
)),
)
.await;
}
Envelope::Notification(notification) => {
let _ = inner.notifications.send(notification);
}
_ => {
fail_pending(
&inner,
RuntimeError::protocol(
RuntimeErrorCode::InvalidEnvelope,
"client received invalid envelope kind",
),
)
.await;
break;
}
}
}
});
}
/// Delivers `result` to the request registered under `request_id`, if any.
///
/// Ids with no registered waiter are ignored. So are waiters whose receiving side has
/// already been dropped.
async fn complete_pending(
    inner: &ClientInner,
    request_id: u64,
    result: Result<Value, RuntimeError>,
) {
    let Some(waiter) = inner.pending.lock().await.remove(&request_id) else {
        return;
    };
    // The caller may have stopped waiting; there is nobody left to notify in that case.
    let _ = waiter.send(result);
}
/// Fails every in-flight request with a clone of `error`.
///
/// The pending table is swapped for an empty one while the lock is held. The waiters are
/// notified after the lock is released, and waiters whose caller has gone away are skipped.
async fn fail_pending(inner: &ClientInner, error: RuntimeError) {
    let drained = {
        let mut table = inner.pending.lock().await;
        std::mem::take(&mut *table)
    };
    for waiter in drained.into_values() {
        let _ = waiter.send(Err(error.clone()));
    }
}
/// Capability bits the client advertises in its HELLO envelope.
fn client_capabilities() -> CapabilityFlags {
    // Server-to-client features the receive loop understands.
    let session = CapabilityFlags::SERVER_TO_CLIENT_NOTIFICATION | CapabilityFlags::GOODBYE;
    // Activation-service helpers exposed on `RpcClient`.
    let activation =
        CapabilityFlags::NAMED_INSTANCE_RESOLUTION | CapabilityFlags::SERVICE_ACTIVATION;
    session | activation
}
fn runtime_error_code(value: i32) -> RuntimeErrorCode {
match value {
1001 => RuntimeErrorCode::UnknownMessageKind,
1002 => RuntimeErrorCode::UnsupportedProtocolVersion,
1003 => RuntimeErrorCode::InvalidEnvelope,
1004 => RuntimeErrorCode::InvalidRequestId,
1005 => RuntimeErrorCode::InvalidInstanceId,
1006 => RuntimeErrorCode::InstanceNotFound,
1007 => RuntimeErrorCode::MethodNotFound,
1008 => RuntimeErrorCode::NotificationNotFound,
1009 => RuntimeErrorCode::PayloadDecodeFailed,
1010 => RuntimeErrorCode::PayloadEncodeFailed,
1011 => RuntimeErrorCode::ServiceActivationNotSupported,
1012 => RuntimeErrorCode::ServiceGuidNotFound,
1013 => RuntimeErrorCode::InstanceReleaseNotAllowed,
1014 => RuntimeErrorCode::RequestTimeout,
1015 => RuntimeErrorCode::UnsupportedCapability,
1016 => RuntimeErrorCode::BusinessErrorDeclared,
1017 => RuntimeErrorCode::DuplicateRequestId,
1018 => RuntimeErrorCode::RequestCancelUnsupported,
1019 => RuntimeErrorCode::AccessDenied,
_ => RuntimeErrorCode::InternalRuntimeError,
}
}
fn error_kind(value: u8) -> ErrorKind {
match value {
1 => ErrorKind::Transport,
2 => ErrorKind::Protocol,
3 => ErrorKind::Runtime,
4 => ErrorKind::Business,
5 => ErrorKind::Timeout,
6 => ErrorKind::Cancelled,
_ => ErrorKind::Runtime,
}
}