//! tripley-rpc-runtime-client 0.1.0
//!
//! Client runtime for Tripley RPC generated Rust clients.
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;

use rmpv::Value;
use rpc_runtime_activation::{
    CREATE_INSTANCE_METHOD_ID, CreateInstanceRequest, LIST_INSTANCES_METHOD_ID,
    ListInstancesRequest, RELEASE_INSTANCE_METHOD_ID, RESOLVE_INSTANCE_IDS_METHOD_ID,
    ReleaseInstanceRequest, ResolveInstanceIdsRequest, activation_instance_id,
    decode_create_instance_response, decode_list_instances_response,
    decode_release_instance_response, decode_resolve_instance_ids_response,
    encode_create_instance_request, encode_list_instances_request, encode_release_instance_request,
    encode_resolve_instance_ids_request,
};
use rpc_runtime_core::{
    CapabilityFlags, Envelope, Hello, InstanceId, MethodId, Notification, RUNTIME_PROTOCOL_VERSION,
    Request, RequestId, Role, ServiceGuid,
};
use rpc_runtime_errors::{ErrorKind, RuntimeError, RuntimeErrorCode};
use rpc_runtime_transport::{RpcConnection, RpcReceiver, RpcSender};
use rpc_runtime_transport_ipc::{FrameConfig, IpcConnection, IpcEndpoint};
use tokio::sync::{Mutex, broadcast, mpsc, oneshot};

/// Handle to a Tripley RPC connection.
///
/// Cloning is cheap: every clone shares the same underlying state (connection sender,
/// request-id counter, pending-request table and notification broadcast).
#[derive(Clone)]
pub struct RpcClient {
    inner: Arc<ClientInner>,
}

impl std::fmt::Debug for RpcClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `ClientInner` does not implement `Debug`, so its internals are elided.
        f.debug_struct("RpcClient").finish_non_exhaustive()
    }
}

/// State shared by every clone of `RpcClient` and handed to the receive loop.
struct ClientInner {
    // Sending half of the connection; used to send requests and GOODBYE.
    sender: RpcSender,
    // Source of request ids; initialised to 1 and advanced with `fetch_add`, so ids are
    // not reused in practice.
    next_request_id: AtomicU64,
    // In-flight requests keyed by request id; each oneshot sender is resolved by the
    // receive loop with the response payload or an error.
    pending: Mutex<HashMap<u64, oneshot::Sender<Result<Value, RuntimeError>>>>,
    // Fan-out of server notifications to `subscribe_notifications` forwarding tasks.
    notifications: broadcast::Sender<Notification>,
}

impl RpcClient {
    /// Opens an IPC connection to `endpoint` and performs the client handshake.
    ///
    /// # Errors
    ///
    /// Returns a transport error if the IPC connection cannot be established, and
    /// otherwise any error produced by [`RpcClient::from_connection`].
    pub async fn connect(endpoint: IpcEndpoint, config: FrameConfig) -> Result<Self, RuntimeError> {
        let connection = IpcConnection::connect(endpoint, config)
            .await
            .map_err(|err| {
                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
            })?;
        Self::from_connection(connection).await
    }

    /// Performs the HELLO / HELLO_ACK handshake over an existing connection and starts
    /// the background task that dispatches responses and notifications.
    ///
    /// # Errors
    ///
    /// - a transport error if sending HELLO or receiving the reply fails, or if the
    ///   server closes the connection before replying;
    /// - a protocol error (`InvalidEnvelope`) if the first envelope is not HELLO_ACK;
    /// - a protocol error (`UnsupportedProtocolVersion`) if the server's protocol
    ///   version differs from `RUNTIME_PROTOCOL_VERSION`.
    pub async fn from_connection<C>(connection: C) -> Result<Self, RuntimeError>
    where
        C: Into<RpcConnection>,
    {
        let (sender, mut receiver) = connection.into().split();
        sender
            .send_envelope(&Envelope::Hello(Hello {
                protocol_version: RUNTIME_PROTOCOL_VERSION,
                role: Role::Client,
                capability_bits: client_capabilities(),
                max_message_size: rpc_runtime_codec_msgpack::DEFAULT_MAX_MESSAGE_SIZE as u64,
                options: Vec::new(),
            }))
            .await
            .map_err(|err| {
                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
            })?;

        let Some(envelope) = receiver.recv_envelope().await.map_err(|err| {
            RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
        })?
        else {
            return Err(RuntimeError::transport(
                RuntimeErrorCode::InternalRuntimeError,
                "server disconnected during handshake",
            ));
        };
        let Envelope::HelloAck(ack) = envelope else {
            return Err(RuntimeError::protocol(
                RuntimeErrorCode::InvalidEnvelope,
                "expected HELLO_ACK during handshake",
            ));
        };
        if ack.protocol_version != RUNTIME_PROTOCOL_VERSION {
            return Err(RuntimeError::protocol(
                RuntimeErrorCode::UnsupportedProtocolVersion,
                "server returned unsupported protocol version",
            ));
        }

        // Subscribers that fall more than this many notifications behind observe a lag.
        let (notifications, _) = broadcast::channel(128);
        let inner = Arc::new(ClientInner {
            sender,
            next_request_id: AtomicU64::new(1),
            pending: Mutex::new(HashMap::new()),
            notifications,
        });
        spawn_receive_loop(Arc::clone(&inner), receiver);
        Ok(Self { inner })
    }

    /// Sends a request to `method_id` on `instance_id` and waits for its response.
    ///
    /// If the returned future is dropped before it completes (for example by
    /// [`RpcClient::call_timeout`], `select!`, or task abort), the request is unregistered
    /// so its pending entry does not leak; a late response for it is then ignored.
    ///
    /// # Errors
    ///
    /// Returns a transport error if the request cannot be sent or the connection fails
    /// before a response arrives, and the server-reported error for a RESPONSE_ERROR.
    pub async fn call(
        &self,
        instance_id: InstanceId,
        method_id: MethodId,
        payload: Value,
    ) -> Result<Value, RuntimeError> {
        let request_id = self.inner.next_request_id.fetch_add(1, Ordering::Relaxed);
        let (tx, rx) = oneshot::channel();
        self.inner.pending.lock().await.insert(request_id, tx);
        // The request is registered from here on; if this future is dropped before it
        // finishes, the guard removes the entry again.
        let guard = PendingGuard::new(&self.inner, request_id);

        let send_result = self
            .inner
            .sender
            .send_envelope(&Envelope::Request(Request {
                request_id: RequestId::new(request_id),
                instance_id,
                method_id,
                payload,
            }))
            .await;
        if let Err(err) = send_result {
            self.inner.pending.lock().await.remove(&request_id);
            guard.disarm();
            return Err(RuntimeError::transport(
                RuntimeErrorCode::InternalRuntimeError,
                err.to_string(),
            ));
        }

        let outcome = rx.await;
        // The receive loop takes the entry out of `pending` before resolving the channel,
        // so there is nothing left to clean up.
        guard.disarm();
        outcome.map_err(|_| {
            RuntimeError::transport(
                RuntimeErrorCode::InternalRuntimeError,
                "response channel closed before request completed",
            )
        })?
    }

    /// Like [`RpcClient::call`], but gives up after `timeout`.
    ///
    /// On timeout the in-flight request is abandoned and unregistered locally. The server
    /// is not told, so it may still execute the request.
    ///
    /// # Errors
    ///
    /// Returns a `RequestTimeout` runtime error on timeout, otherwise the errors of `call`.
    pub async fn call_timeout(
        &self,
        instance_id: InstanceId,
        method_id: MethodId,
        payload: Value,
        timeout: Duration,
    ) -> Result<Value, RuntimeError> {
        tokio::time::timeout(timeout, self.call(instance_id, method_id, payload))
            .await
            .map_err(|_| {
                RuntimeError::runtime(RuntimeErrorCode::RequestTimeout, "request timed out")
            })?
    }

    /// Resolves instance names to instance ids through the activation service.
    ///
    /// Returns the `instance_ids` of the decoded response. Presumably they come back in
    /// the same order as `names`; confirm against the server.
    ///
    /// # Errors
    ///
    /// Propagates errors from `call` and from decoding the response.
    pub async fn resolve_instance_ids(&self, names: Vec<String>) -> Result<Vec<u64>, RuntimeError> {
        let response = self
            .call(
                activation_instance_id(),
                MethodId::new(RESOLVE_INSTANCE_IDS_METHOD_ID),
                encode_resolve_instance_ids_request(&ResolveInstanceIdsRequest {
                    instance_names: names,
                }),
            )
            .await?;
        Ok(decode_resolve_instance_ids_response(&response)?.instance_ids)
    }

    /// Asks the activation service to create an instance of `service_guid`.
    ///
    /// `create_payload` and `options` are forwarded unchanged in the request.
    ///
    /// # Errors
    ///
    /// Propagates errors from `call` and from decoding the response.
    pub async fn create_instance(
        &self,
        service_guid: ServiceGuid,
        create_payload: Option<Vec<u8>>,
        options: BTreeMap<String, String>,
    ) -> Result<InstanceId, RuntimeError> {
        let response = self
            .call(
                activation_instance_id(),
                MethodId::new(CREATE_INSTANCE_METHOD_ID),
                encode_create_instance_request(&CreateInstanceRequest {
                    service_guid,
                    create_payload,
                    options,
                }),
            )
            .await?;
        Ok(decode_create_instance_response(&response)?.instance_id)
    }

    /// Asks the activation service to release `instance_id`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `call` and from decoding the response.
    pub async fn release_instance(&self, instance_id: InstanceId) -> Result<(), RuntimeError> {
        let response = self
            .call(
                activation_instance_id(),
                MethodId::new(RELEASE_INSTANCE_METHOD_ID),
                encode_release_instance_request(&ReleaseInstanceRequest { instance_id }),
            )
            .await?;
        decode_release_instance_response(&response)?;
        Ok(())
    }

    /// Lists instances known to the activation service, optionally limited to one service.
    ///
    /// # Errors
    ///
    /// Propagates errors from `call` and from decoding the response.
    pub async fn list_instances(
        &self,
        service_guid: Option<ServiceGuid>,
    ) -> Result<Vec<rpc_runtime_activation::InstanceDescriptor>, RuntimeError> {
        let response = self
            .call(
                activation_instance_id(),
                MethodId::new(LIST_INSTANCES_METHOD_ID),
                encode_list_instances_request(&ListInstancesRequest { service_guid }),
            )
            .await?;
        Ok(decode_list_instances_response(&response)?.instances)
    }

    /// Returns a receiver of server notifications, optionally filtered.
    ///
    /// - `instance_id_filter`: when set, only notifications whose `instance_id` is
    ///   `Some(filter)` are delivered.
    /// - `notification_id_filter`: when set, only notifications with that id are delivered.
    ///
    /// A spawned task forwards matching notifications. It stops when the returned receiver
    /// is dropped (noticed on the next incoming notification) or when the notification
    /// broadcast closes. If the consumer falls behind the broadcast buffer, the
    /// overwritten notifications are skipped and newer ones are still delivered.
    pub fn subscribe_notifications(
        &self,
        instance_id_filter: Option<InstanceId>,
        notification_id_filter: Option<u32>,
    ) -> mpsc::UnboundedReceiver<Notification> {
        let mut source = self.inner.notifications.subscribe();
        let (tx, rx) = mpsc::unbounded_channel();
        tokio::spawn(async move {
            loop {
                let notification = match source.recv().await {
                    Ok(notification) => notification,
                    // Missed some notifications; keep forwarding newer ones instead of
                    // silently ending the subscription.
                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
                    // The broadcast sender is gone.
                    Err(_) => break,
                };
                // Stop once the consumer is gone, even if this notification would have
                // been filtered out.
                if tx.is_closed() {
                    break;
                }
                let instance_matches = instance_id_filter
                    .is_none_or(|expected| notification.instance_id == Some(expected));
                let notification_matches = notification_id_filter
                    .is_none_or(|expected| notification.notification_id.get() == expected);
                if instance_matches && notification_matches && tx.send(notification).is_err() {
                    break;
                }
            }
        });
        rx
    }

    /// Sends a GOODBYE envelope with reason code 0 and the given message.
    ///
    /// This only sends the envelope; it does not itself close the connection.
    ///
    /// # Errors
    ///
    /// Returns a transport error if the envelope cannot be sent.
    pub async fn goodbye(&self, message: impl Into<String>) -> Result<(), RuntimeError> {
        self.inner
            .sender
            .send_envelope(&Envelope::Goodbye(rpc_runtime_core::Goodbye {
                reason_code: 0,
                message: Some(message.into()),
            }))
            .await
            .map_err(|err| {
                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
            })
    }
}

/// Unregisters an in-flight request from `ClientInner::pending` when the `call` future
/// that owns it is dropped before the request is resolved.
struct PendingGuard<'a> {
    inner: &'a Arc<ClientInner>,
    request_id: u64,
    armed: bool,
}

impl<'a> PendingGuard<'a> {
    fn new(inner: &'a Arc<ClientInner>, request_id: u64) -> Self {
        Self {
            inner,
            request_id,
            armed: true,
        }
    }

    /// Consumes the guard without touching `pending` (the entry is already gone).
    fn disarm(mut self) {
        self.armed = false;
    }
}

impl Drop for PendingGuard<'_> {
    fn drop(&mut self) {
        if !self.armed {
            return;
        }
        let request_id = self.request_id;
        if let Ok(mut pending) = self.inner.pending.try_lock() {
            pending.remove(&request_id);
            return;
        }
        // `drop` cannot await the async mutex. When it is contended, finish the removal
        // on a task; request ids are never reused, so a late removal is harmless.
        if let Ok(handle) = tokio::runtime::Handle::try_current() {
            let inner = Arc::clone(self.inner);
            handle.spawn(async move {
                inner.pending.lock().await.remove(&request_id);
            });
        }
    }
}

fn spawn_receive_loop(inner: Arc<ClientInner>, mut receiver: RpcReceiver) {
    tokio::spawn(async move {
        loop {
            let envelope = match receiver.recv_envelope().await {
                Ok(Some(envelope)) => envelope,
                Ok(None) => {
                    fail_pending(
                        &inner,
                        RuntimeError::transport(
                            RuntimeErrorCode::InternalRuntimeError,
                            "server disconnected",
                        ),
                    )
                    .await;
                    break;
                }
                Err(err) => {
                    fail_pending(
                        &inner,
                        RuntimeError::transport(
                            RuntimeErrorCode::InternalRuntimeError,
                            err.to_string(),
                        ),
                    )
                    .await;
                    break;
                }
            };
            match envelope {
                Envelope::ResponseOk(response) => {
                    complete_pending(&inner, response.request_id.get(), Ok(response.payload)).await;
                }
                Envelope::ResponseError(response) => {
                    complete_pending(
                        &inner,
                        response.request_id.get(),
                        Err(RuntimeError::new(
                            runtime_error_code(response.error_code),
                            error_kind(response.error_kind),
                            response.error_message.unwrap_or_default(),
                        )),
                    )
                    .await;
                }
                Envelope::Notification(notification) => {
                    let _ = inner.notifications.send(notification);
                }
                _ => {
                    fail_pending(
                        &inner,
                        RuntimeError::protocol(
                            RuntimeErrorCode::InvalidEnvelope,
                            "client received invalid envelope kind",
                        ),
                    )
                    .await;
                    break;
                }
            }
        }
    });
}

/// Hands `result` to the caller waiting on `request_id`, if that id is still registered.
///
/// Ids that are not in the table (already completed or unregistered) are ignored, and
/// so is a caller that has already dropped its receiver.
async fn complete_pending(
    inner: &ClientInner,
    request_id: u64,
    result: Result<Value, RuntimeError>,
) {
    let waiter = inner.pending.lock().await.remove(&request_id);
    if let Some(waiter) = waiter {
        let _ = waiter.send(result);
    }
}

/// Empties the pending-request table and resolves every waiting caller with a clone
/// of `error`.
async fn fail_pending(inner: &ClientInner, error: RuntimeError) {
    let drained = {
        let mut table = inner.pending.lock().await;
        std::mem::take(&mut *table)
    };
    for waiter in drained.into_values() {
        // A caller that already dropped its receiver just misses the error.
        let _ = waiter.send(Err(error.clone()));
    }
}

/// Capability bits this client advertises in its HELLO envelope.
fn client_capabilities() -> CapabilityFlags {
    let notifications = CapabilityFlags::SERVER_TO_CLIENT_NOTIFICATION;
    let activation =
        CapabilityFlags::NAMED_INSTANCE_RESOLUTION | CapabilityFlags::SERVICE_ACTIVATION;
    notifications | activation | CapabilityFlags::GOODBYE
}

fn runtime_error_code(value: i32) -> RuntimeErrorCode {
    match value {
        1001 => RuntimeErrorCode::UnknownMessageKind,
        1002 => RuntimeErrorCode::UnsupportedProtocolVersion,
        1003 => RuntimeErrorCode::InvalidEnvelope,
        1004 => RuntimeErrorCode::InvalidRequestId,
        1005 => RuntimeErrorCode::InvalidInstanceId,
        1006 => RuntimeErrorCode::InstanceNotFound,
        1007 => RuntimeErrorCode::MethodNotFound,
        1008 => RuntimeErrorCode::NotificationNotFound,
        1009 => RuntimeErrorCode::PayloadDecodeFailed,
        1010 => RuntimeErrorCode::PayloadEncodeFailed,
        1011 => RuntimeErrorCode::ServiceActivationNotSupported,
        1012 => RuntimeErrorCode::ServiceGuidNotFound,
        1013 => RuntimeErrorCode::InstanceReleaseNotAllowed,
        1014 => RuntimeErrorCode::RequestTimeout,
        1015 => RuntimeErrorCode::UnsupportedCapability,
        1016 => RuntimeErrorCode::BusinessErrorDeclared,
        1017 => RuntimeErrorCode::DuplicateRequestId,
        1018 => RuntimeErrorCode::RequestCancelUnsupported,
        1019 => RuntimeErrorCode::AccessDenied,
        _ => RuntimeErrorCode::InternalRuntimeError,
    }
}

fn error_kind(value: u8) -> ErrorKind {
    match value {
        1 => ErrorKind::Transport,
        2 => ErrorKind::Protocol,
        3 => ErrorKind::Runtime,
        4 => ErrorKind::Business,
        5 => ErrorKind::Timeout,
        6 => ErrorKind::Cancelled,
        _ => ErrorKind::Runtime,
    }
}