Skip to main content

rpc_runtime_client/
lib.rs

1use std::collections::{BTreeMap, HashMap, VecDeque};
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::sync::{Arc, Mutex as StdMutex};
4use std::time::Duration;
5
6use rmpv::Value;
7use rpc_runtime_activation::{
8    CREATE_INSTANCE_METHOD_ID, CreateInstanceRequest, LIST_INSTANCES_METHOD_ID,
9    ListInstancesRequest, RELEASE_INSTANCE_METHOD_ID, RESOLVE_INSTANCE_IDS_METHOD_ID,
10    ReleaseInstanceRequest, ResolveInstanceIdsRequest, activation_instance_id,
11    decode_create_instance_response, decode_list_instances_response,
12    decode_release_instance_response, decode_resolve_instance_ids_response,
13    encode_create_instance_request, encode_list_instances_request, encode_release_instance_request,
14    encode_resolve_instance_ids_request,
15};
16use rpc_runtime_core::{
17    CapabilityFlags, Envelope, Hello, InstanceId, MethodId, Notification, Options,
18    RUNTIME_PROTOCOL_VERSION, Request, RequestId, Role, ServiceGuid,
19};
20use rpc_runtime_errors::{ErrorKind, RuntimeError, RuntimeErrorCode};
21use rpc_runtime_transport::{RpcConnection, RpcReceiver, RpcSender};
22use rpc_runtime_transport_ipc::{FrameConfig, IpcConnection, IpcEndpoint};
23use tokio::sync::{Mutex, Notify, broadcast, oneshot};
24
/// Cloneable handle to an RPC client connection.
///
/// All clones share the same connection state and the single background receive loop that is
/// spawned when the client is constructed.
#[derive(Clone)]
pub struct RpcClient {
    inner: Arc<ClientInner>,
}
29
/// Hello option key under which the handshake auth token is sent unless overridden.
pub const DEFAULT_AUTH_TOKEN_OPTION_KEY: &str = "tripley.auth.token";
/// Default capacity of the client's notification buffers.
pub const DEFAULT_NOTIFICATION_BUFFER_SIZE: usize = 128;

/// Behaviour of a notification buffer when a notification arrives while it is full.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NotificationOverflowPolicy {
    /// Discard the oldest buffered notification to make room for the new one.
    DropOldest,
}

/// Buffering settings applied to notification subscriptions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RpcClientNotificationConfig {
    /// Buffer capacity; the constructors and builders never let this drop below 1.
    pub buffer_size: usize,
    /// What to do when a buffer is full.
    pub overflow_policy: NotificationOverflowPolicy,
}

impl RpcClientNotificationConfig {
    /// Builds a config with the given capacity (zero is raised to one) and the drop-oldest
    /// overflow policy.
    pub fn new(buffer_size: usize) -> Self {
        let overflow_policy = NotificationOverflowPolicy::DropOldest;
        Self {
            buffer_size: usize::max(buffer_size, 1),
            overflow_policy,
        }
    }

    /// Returns a copy with a different capacity (zero is raised to one).
    pub fn with_buffer_size(self, buffer_size: usize) -> Self {
        Self {
            buffer_size: usize::max(buffer_size, 1),
            ..self
        }
    }

    /// Returns a copy with a different overflow policy.
    pub fn with_overflow_policy(self, overflow_policy: NotificationOverflowPolicy) -> Self {
        Self {
            overflow_policy,
            ..self
        }
    }
}

impl Default for RpcClientNotificationConfig {
    /// [`DEFAULT_NOTIFICATION_BUFFER_SIZE`] slots with the drop-oldest policy.
    fn default() -> Self {
        RpcClientNotificationConfig::new(DEFAULT_NOTIFICATION_BUFFER_SIZE)
    }
}
68
69#[derive(Debug, Clone, PartialEq, Eq)]
70pub struct RpcClientHandshakeConfig {
71    pub auth_token: Option<String>,
72    pub auth_option_key: String,
73}
74
75impl RpcClientHandshakeConfig {
76    pub fn with_auth_token(mut self, token: impl Into<String>) -> Self {
77        self.auth_token = Some(token.into());
78        self
79    }
80
81    pub fn with_auth_option_key(mut self, key: impl Into<String>) -> Self {
82        self.auth_option_key = key.into();
83        self
84    }
85
86    fn hello_options(&self) -> Options {
87        self.auth_token
88            .as_ref()
89            .map(|token| vec![(self.auth_option_key.clone(), Value::from(token.as_str()))])
90            .unwrap_or_default()
91    }
92}
93
94impl Default for RpcClientHandshakeConfig {
95    fn default() -> Self {
96        Self {
97            auth_token: None,
98            auth_option_key: DEFAULT_AUTH_TOKEN_OPTION_KEY.to_string(),
99        }
100    }
101}
102
/// State shared by every clone of [`RpcClient`] and by the background receive loop.
struct ClientInner {
    /// Write half of the connection, used for requests and `Goodbye`.
    sender: RpcSender,
    /// Next request id to hand out; starts at 1 and is bumped once per call.
    next_request_id: AtomicU64,
    /// Outstanding calls keyed by request id. Entries are removed when a response arrives, when
    /// the connection fails, or when a send fails or a timeout fires.
    pending: Mutex<HashMap<u64, oneshot::Sender<Result<Value, RuntimeError>>>>,
    /// Fan-out of incoming server notifications to `subscribe_notifications` subscribers.
    notifications: broadcast::Sender<Notification>,
    /// Normalized notification settings (buffer size at least 1), used for new subscriber
    /// queues.
    notification_config: RpcClientNotificationConfig,
}
110
111impl RpcClient {
112    pub async fn connect(endpoint: IpcEndpoint, config: FrameConfig) -> Result<Self, RuntimeError> {
113        Self::connect_with_handshake_config(endpoint, config, RpcClientHandshakeConfig::default())
114            .await
115    }
116
117    pub async fn connect_with_handshake_config(
118        endpoint: IpcEndpoint,
119        config: FrameConfig,
120        handshake: RpcClientHandshakeConfig,
121    ) -> Result<Self, RuntimeError> {
122        Self::connect_with_configs(
123            endpoint,
124            config,
125            handshake,
126            RpcClientNotificationConfig::default(),
127        )
128        .await
129    }
130
131    pub async fn connect_with_configs(
132        endpoint: IpcEndpoint,
133        config: FrameConfig,
134        handshake: RpcClientHandshakeConfig,
135        notifications: RpcClientNotificationConfig,
136    ) -> Result<Self, RuntimeError> {
137        let connection = IpcConnection::connect(endpoint, config)
138            .await
139            .map_err(|err| {
140                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
141            })?;
142        Self::from_connection_with_configs(connection, handshake, notifications).await
143    }
144
145    pub async fn from_connection<C>(connection: C) -> Result<Self, RuntimeError>
146    where
147        C: Into<RpcConnection>,
148    {
149        Self::from_connection_with_handshake_config(connection, RpcClientHandshakeConfig::default())
150            .await
151    }
152
153    pub async fn from_connection_with_handshake_config<C>(
154        connection: C,
155        handshake: RpcClientHandshakeConfig,
156    ) -> Result<Self, RuntimeError>
157    where
158        C: Into<RpcConnection>,
159    {
160        Self::from_connection_with_configs(
161            connection,
162            handshake,
163            RpcClientNotificationConfig::default(),
164        )
165        .await
166    }
167
168    pub async fn from_connection_with_configs<C>(
169        connection: C,
170        handshake: RpcClientHandshakeConfig,
171        notifications: RpcClientNotificationConfig,
172    ) -> Result<Self, RuntimeError>
173    where
174        C: Into<RpcConnection>,
175    {
176        let (sender, mut receiver) = connection.into().split();
177        sender
178            .send_envelope(&Envelope::Hello(Hello {
179                protocol_version: RUNTIME_PROTOCOL_VERSION,
180                role: Role::Client,
181                capability_bits: client_capabilities(),
182                max_message_size: rpc_runtime_codec_msgpack::DEFAULT_MAX_MESSAGE_SIZE as u64,
183                options: handshake.hello_options(),
184            }))
185            .await
186            .map_err(|err| {
187                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
188            })?;
189
190        let Some(envelope) = receiver.recv_envelope().await.map_err(|err| {
191            RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
192        })?
193        else {
194            return Err(RuntimeError::transport(
195                RuntimeErrorCode::InternalRuntimeError,
196                "server disconnected during handshake",
197            ));
198        };
199        let Envelope::HelloAck(ack) = envelope else {
200            return Err(RuntimeError::protocol(
201                RuntimeErrorCode::InvalidEnvelope,
202                "expected HELLO_ACK during handshake",
203            ));
204        };
205        if ack.protocol_version != RUNTIME_PROTOCOL_VERSION {
206            return Err(RuntimeError::protocol(
207                RuntimeErrorCode::UnsupportedProtocolVersion,
208                "server returned unsupported protocol version",
209            ));
210        }
211
212        let notifications = RpcClientNotificationConfig {
213            buffer_size: notifications.buffer_size.max(1),
214            ..notifications
215        };
216        let (notification_tx, _) = broadcast::channel(notifications.buffer_size);
217        let inner = Arc::new(ClientInner {
218            sender,
219            next_request_id: AtomicU64::new(1),
220            pending: Mutex::new(HashMap::new()),
221            notifications: notification_tx,
222            notification_config: notifications,
223        });
224        spawn_receive_loop(Arc::clone(&inner), receiver);
225        Ok(Self { inner })
226    }
227
228    pub async fn call(
229        &self,
230        instance_id: InstanceId,
231        method_id: MethodId,
232        payload: Value,
233    ) -> Result<Value, RuntimeError> {
234        self.call_with_optional_timeout(instance_id, method_id, payload, None)
235            .await
236    }
237
238    async fn call_with_optional_timeout(
239        &self,
240        instance_id: InstanceId,
241        method_id: MethodId,
242        payload: Value,
243        timeout: Option<Duration>,
244    ) -> Result<Value, RuntimeError> {
245        let request_id = self.inner.next_request_id.fetch_add(1, Ordering::Relaxed);
246        let (tx, rx) = oneshot::channel();
247        self.inner.pending.lock().await.insert(request_id, tx);
248
249        let send_result = self
250            .inner
251            .sender
252            .send_envelope(&Envelope::Request(Request {
253                request_id: RequestId::new(request_id),
254                instance_id,
255                method_id,
256                payload,
257            }))
258            .await;
259        if let Err(err) = send_result {
260            self.inner.pending.lock().await.remove(&request_id);
261            return Err(RuntimeError::transport(
262                RuntimeErrorCode::InternalRuntimeError,
263                err.to_string(),
264            ));
265        }
266
267        let response = if let Some(timeout) = timeout {
268            match tokio::time::timeout(timeout, rx).await {
269                Ok(response) => response,
270                Err(_) => {
271                    self.inner.pending.lock().await.remove(&request_id);
272                    return Err(RuntimeError::runtime(
273                        RuntimeErrorCode::RequestTimeout,
274                        "request timed out",
275                    ));
276                }
277            }
278        } else {
279            rx.await
280        };
281
282        response.map_err(|_| {
283            RuntimeError::transport(
284                RuntimeErrorCode::InternalRuntimeError,
285                "response channel closed before request completed",
286            )
287        })?
288    }
289
290    pub async fn call_timeout(
291        &self,
292        instance_id: InstanceId,
293        method_id: MethodId,
294        payload: Value,
295        timeout: Duration,
296    ) -> Result<Value, RuntimeError> {
297        self.call_with_optional_timeout(instance_id, method_id, payload, Some(timeout))
298            .await
299    }
300
301    pub async fn resolve_instance_ids(&self, names: Vec<String>) -> Result<Vec<u64>, RuntimeError> {
302        let response = self
303            .call(
304                activation_instance_id(),
305                MethodId::new(RESOLVE_INSTANCE_IDS_METHOD_ID),
306                encode_resolve_instance_ids_request(&ResolveInstanceIdsRequest {
307                    instance_names: names,
308                }),
309            )
310            .await?;
311        Ok(decode_resolve_instance_ids_response(&response)?.instance_ids)
312    }
313
314    pub async fn create_instance(
315        &self,
316        service_guid: ServiceGuid,
317        create_payload: Option<Vec<u8>>,
318        options: BTreeMap<String, String>,
319    ) -> Result<InstanceId, RuntimeError> {
320        let response = self
321            .call(
322                activation_instance_id(),
323                MethodId::new(CREATE_INSTANCE_METHOD_ID),
324                encode_create_instance_request(&CreateInstanceRequest {
325                    service_guid,
326                    create_payload,
327                    options,
328                }),
329            )
330            .await?;
331        Ok(decode_create_instance_response(&response)?.instance_id)
332    }
333
334    pub async fn release_instance(&self, instance_id: InstanceId) -> Result<(), RuntimeError> {
335        let response = self
336            .call(
337                activation_instance_id(),
338                MethodId::new(RELEASE_INSTANCE_METHOD_ID),
339                encode_release_instance_request(&ReleaseInstanceRequest { instance_id }),
340            )
341            .await?;
342        decode_release_instance_response(&response)?;
343        Ok(())
344    }
345
346    pub async fn list_instances(
347        &self,
348        service_guid: Option<ServiceGuid>,
349    ) -> Result<Vec<rpc_runtime_activation::InstanceDescriptor>, RuntimeError> {
350        let response = self
351            .call(
352                activation_instance_id(),
353                MethodId::new(LIST_INSTANCES_METHOD_ID),
354                encode_list_instances_request(&ListInstancesRequest { service_guid }),
355            )
356            .await?;
357        Ok(decode_list_instances_response(&response)?.instances)
358    }
359
360    pub fn subscribe_notifications(
361        &self,
362        instance_id_filter: Option<InstanceId>,
363        notification_id_filter: Option<u32>,
364    ) -> RpcNotificationReceiver {
365        let mut source = self.inner.notifications.subscribe();
366        let queue = Arc::new(BoundedNotificationQueue::new(
367            self.inner.notification_config.buffer_size,
368            self.inner.notification_config.overflow_policy,
369        ));
370        let receiver = RpcNotificationReceiver {
371            queue: Arc::clone(&queue),
372        };
373        tokio::spawn(async move {
374            loop {
375                let Ok(notification) = source.recv().await else {
376                    break;
377                };
378                let instance_matches = instance_id_filter
379                    .is_none_or(|expected| notification.instance_id == Some(expected));
380                let notification_matches = notification_id_filter
381                    .is_none_or(|expected| notification.notification_id.get() == expected);
382                if instance_matches && notification_matches {
383                    queue.push(notification);
384                }
385            }
386            queue.close();
387        });
388        receiver
389    }
390
391    pub async fn goodbye(&self, message: impl Into<String>) -> Result<(), RuntimeError> {
392        self.inner
393            .sender
394            .send_envelope(&Envelope::Goodbye(rpc_runtime_core::Goodbye {
395                reason_code: 0,
396                message: Some(message.into()),
397            }))
398            .await
399            .map_err(|err| {
400                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
401            })
402    }
403}
404
405pub struct RpcNotificationReceiver {
406    queue: Arc<BoundedNotificationQueue>,
407}
408
409impl RpcNotificationReceiver {
410    pub async fn recv(&mut self) -> Option<Notification> {
411        self.queue.recv().await
412    }
413}
414
/// Fixed-capacity FIFO of notifications for one subscriber. A forwarding task pushes into it and
/// an [`RpcNotificationReceiver`] pops from it.
struct BoundedNotificationQueue {
    /// Items and closed flag behind a synchronous mutex; `recv` releases it before awaiting.
    state: StdMutex<BoundedNotificationQueueState>,
    /// Wakes a consumer waiting in `recv` after a push or close.
    notify: Notify,
    /// Maximum number of buffered items (always at least 1).
    capacity: usize,
    /// What `push` does when the queue is at capacity.
    overflow_policy: NotificationOverflowPolicy,
}

/// Mutable part of [`BoundedNotificationQueue`].
struct BoundedNotificationQueueState {
    /// Buffered notifications, oldest first.
    items: VecDeque<Notification>,
    /// Set by `close`. Once set, pushes are ignored and `recv` returns `None` after draining.
    closed: bool,
}
426
427impl BoundedNotificationQueue {
428    fn new(capacity: usize, overflow_policy: NotificationOverflowPolicy) -> Self {
429        Self {
430            state: StdMutex::new(BoundedNotificationQueueState {
431                items: VecDeque::new(),
432                closed: false,
433            }),
434            notify: Notify::new(),
435            capacity: capacity.max(1),
436            overflow_policy,
437        }
438    }
439
440    fn push(&self, notification: Notification) {
441        let mut state = self
442            .state
443            .lock()
444            .expect("notification queue mutex poisoned");
445        if state.closed {
446            return;
447        }
448        if state.items.len() == self.capacity {
449            match self.overflow_policy {
450                NotificationOverflowPolicy::DropOldest => {
451                    state.items.pop_front();
452                }
453            }
454        }
455        state.items.push_back(notification);
456        drop(state);
457        self.notify.notify_one();
458    }
459
460    fn close(&self) {
461        let mut state = self
462            .state
463            .lock()
464            .expect("notification queue mutex poisoned");
465        state.closed = true;
466        drop(state);
467        self.notify.notify_waiters();
468    }
469
470    async fn recv(&self) -> Option<Notification> {
471        loop {
472            let notified = self.notify.notified();
473            {
474                let mut state = self
475                    .state
476                    .lock()
477                    .expect("notification queue mutex poisoned");
478                if let Some(notification) = state.items.pop_front() {
479                    return Some(notification);
480                }
481                if state.closed {
482                    return None;
483                }
484            }
485            notified.await;
486        }
487    }
488}
489
490fn spawn_receive_loop(inner: Arc<ClientInner>, mut receiver: RpcReceiver) {
491    tokio::spawn(async move {
492        loop {
493            let envelope = match receiver.recv_envelope().await {
494                Ok(Some(envelope)) => envelope,
495                Ok(None) => {
496                    fail_pending(
497                        &inner,
498                        RuntimeError::transport(
499                            RuntimeErrorCode::InternalRuntimeError,
500                            "server disconnected",
501                        ),
502                    )
503                    .await;
504                    break;
505                }
506                Err(err) => {
507                    fail_pending(
508                        &inner,
509                        RuntimeError::transport(
510                            RuntimeErrorCode::InternalRuntimeError,
511                            err.to_string(),
512                        ),
513                    )
514                    .await;
515                    break;
516                }
517            };
518            match envelope {
519                Envelope::ResponseOk(response) => {
520                    complete_pending(&inner, response.request_id.get(), Ok(response.payload)).await;
521                }
522                Envelope::ResponseError(response) => {
523                    complete_pending(
524                        &inner,
525                        response.request_id.get(),
526                        Err(RuntimeError::new(
527                            runtime_error_code(response.error_code),
528                            error_kind(response.error_kind),
529                            response.error_message.unwrap_or_default(),
530                        )),
531                    )
532                    .await;
533                }
534                Envelope::Notification(notification) => {
535                    let _ = inner.notifications.send(notification);
536                }
537                _ => {
538                    fail_pending(
539                        &inner,
540                        RuntimeError::protocol(
541                            RuntimeErrorCode::InvalidEnvelope,
542                            "client received invalid envelope kind",
543                        ),
544                    )
545                    .await;
546                    break;
547                }
548            }
549        }
550    });
551}
552
553async fn complete_pending(
554    inner: &ClientInner,
555    request_id: u64,
556    result: Result<Value, RuntimeError>,
557) {
558    if let Some(sender) = inner.pending.lock().await.remove(&request_id) {
559        let _ = sender.send(result);
560    }
561}
562
563async fn fail_pending(inner: &ClientInner, error: RuntimeError) {
564    let pending = std::mem::take(&mut *inner.pending.lock().await);
565    for (_, sender) in pending {
566        let _ = sender.send(Err(error.clone()));
567    }
568}
569
570fn client_capabilities() -> CapabilityFlags {
571    CapabilityFlags::SERVER_TO_CLIENT_NOTIFICATION
572        | CapabilityFlags::NAMED_INSTANCE_RESOLUTION
573        | CapabilityFlags::SERVICE_ACTIVATION
574        | CapabilityFlags::GOODBYE
575}
576
577fn runtime_error_code(value: i32) -> RuntimeErrorCode {
578    match value {
579        1001 => RuntimeErrorCode::UnknownMessageKind,
580        1002 => RuntimeErrorCode::UnsupportedProtocolVersion,
581        1003 => RuntimeErrorCode::InvalidEnvelope,
582        1004 => RuntimeErrorCode::InvalidRequestId,
583        1005 => RuntimeErrorCode::InvalidInstanceId,
584        1006 => RuntimeErrorCode::InstanceNotFound,
585        1007 => RuntimeErrorCode::MethodNotFound,
586        1008 => RuntimeErrorCode::NotificationNotFound,
587        1009 => RuntimeErrorCode::PayloadDecodeFailed,
588        1010 => RuntimeErrorCode::PayloadEncodeFailed,
589        1011 => RuntimeErrorCode::ServiceActivationNotSupported,
590        1012 => RuntimeErrorCode::ServiceGuidNotFound,
591        1013 => RuntimeErrorCode::InstanceReleaseNotAllowed,
592        1014 => RuntimeErrorCode::RequestTimeout,
593        1015 => RuntimeErrorCode::UnsupportedCapability,
594        1016 => RuntimeErrorCode::BusinessErrorDeclared,
595        1017 => RuntimeErrorCode::DuplicateRequestId,
596        1018 => RuntimeErrorCode::RequestCancelUnsupported,
597        1019 => RuntimeErrorCode::AccessDenied,
598        _ => RuntimeErrorCode::InternalRuntimeError,
599    }
600}
601
602fn error_kind(value: u8) -> ErrorKind {
603    match value {
604        1 => ErrorKind::Transport,
605        2 => ErrorKind::Protocol,
606        3 => ErrorKind::Runtime,
607        4 => ErrorKind::Business,
608        5 => ErrorKind::Timeout,
609        6 => ErrorKind::Cancelled,
610        _ => ErrorKind::Runtime,
611    }
612}
613
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rpc_runtime_core::{CapabilityFlags, HelloAck};
    use rpc_runtime_transport::{
        EnvelopeReader, EnvelopeWriter, RpcConnection, RpcReceiver, RpcSender, TransportError,
        TransportFuture,
    };
    use tokio::sync::mpsc;

    use super::*;

    /// A call that never gets a response must fail with `RequestTimeout` and leave no entry in
    /// the pending-request map.
    #[tokio::test]
    async fn call_timeout_removes_pending_request() {
        // Preload the HELLO_ACK so the handshake in `from_connection` completes immediately.
        let (tx, rx) = mpsc::unbounded_channel();
        tx.send(Some(Envelope::HelloAck(HelloAck {
            protocol_version: RUNTIME_PROTOCOL_VERSION,
            accepted_capability_bits: CapabilityFlags::GOODBYE,
            max_message_size: rpc_runtime_codec_msgpack::DEFAULT_MAX_MESSAGE_SIZE as u64,
            options: Vec::new(),
        })))
        .expect("preload handshake ack");

        // Writes succeed and are discarded, so the request is "sent" but never answered.
        let connection = RpcConnection::new(
            RpcSender::new(Arc::new(NoopWriter)),
            RpcReceiver::new(Box::new(ChannelReader { rx })),
        );
        let client = RpcClient::from_connection(connection)
            .await
            .expect("client handshake");

        let err = client
            .call_timeout(
                InstanceId::new(1).expect("instance id"),
                MethodId::new(1),
                Value::Nil,
                Duration::from_millis(1),
            )
            .await
            .expect_err("call must time out");

        assert_eq!(err.code, RuntimeErrorCode::RequestTimeout);
        assert_eq!(client.inner.pending.lock().await.len(), 0);

        // Keep `tx` alive until here. Dropping it earlier closes the reader, and the receive
        // loop could then fail the call with a transport error instead of a timeout.
        drop(tx);
    }

    /// With capacity 2 and the drop-oldest policy, pushing three items keeps the last two.
    /// After close, the receiver drains them and then reports end-of-stream.
    #[tokio::test]
    async fn notification_receiver_drops_oldest_when_full() {
        let queue = Arc::new(BoundedNotificationQueue::new(
            2,
            NotificationOverflowPolicy::DropOldest,
        ));
        let mut receiver = RpcNotificationReceiver {
            queue: Arc::clone(&queue),
        };

        for value in 1..=3 {
            queue.push(Notification {
                instance_id: None,
                notification_id: rpc_runtime_core::NotificationId::new(7),
                payload: Value::from(value),
            });
        }
        queue.close();

        let first = receiver.recv().await.expect("first notification");
        let second = receiver.recv().await.expect("second notification");
        assert_eq!(first.payload, Value::from(2));
        assert_eq!(second.payload, Value::from(3));
        assert!(receiver.recv().await.is_none());
    }

    /// Writer that accepts and discards every envelope.
    struct NoopWriter;

    impl EnvelopeWriter for NoopWriter {
        fn send_envelope<'a>(&'a self, _: &'a Envelope) -> TransportFuture<'a, ()> {
            Box::pin(async { Ok(()) })
        }

        fn shutdown<'a>(&'a self) -> TransportFuture<'a, ()> {
            Box::pin(async { Ok(()) })
        }
    }

    /// Reader fed by a test channel.
    ///
    /// Each queued `Some(envelope)` is delivered as-is and a queued `None` as a clean
    /// disconnect. A closed channel becomes an `UnexpectedEof` I/O error.
    struct ChannelReader {
        rx: mpsc::UnboundedReceiver<Option<Envelope>>,
    }

    impl EnvelopeReader for ChannelReader {
        fn recv_envelope<'a>(&'a mut self) -> TransportFuture<'a, Option<Envelope>> {
            Box::pin(async move {
                Ok(self.rx.recv().await.ok_or_else(|| {
                    TransportError::Io(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "test channel closed",
                    ))
                })?)
            })
        }
    }
}