1use std::collections::{BTreeMap, HashMap, VecDeque};
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::sync::{Arc, Mutex as StdMutex};
4use std::time::Duration;
5
6use rmpv::Value;
7use rpc_runtime_activation::{
8 CREATE_INSTANCE_METHOD_ID, CreateInstanceRequest, LIST_INSTANCES_METHOD_ID,
9 ListInstancesRequest, RELEASE_INSTANCE_METHOD_ID, RESOLVE_INSTANCE_IDS_METHOD_ID,
10 ReleaseInstanceRequest, ResolveInstanceIdsRequest, activation_instance_id,
11 decode_create_instance_response, decode_list_instances_response,
12 decode_release_instance_response, decode_resolve_instance_ids_response,
13 encode_create_instance_request, encode_list_instances_request, encode_release_instance_request,
14 encode_resolve_instance_ids_request,
15};
16use rpc_runtime_core::{
17 CapabilityFlags, Envelope, Hello, InstanceId, MethodId, Notification, Options,
18 RUNTIME_PROTOCOL_VERSION, Request, RequestId, Role, ServiceGuid,
19};
20use rpc_runtime_errors::{ErrorKind, RuntimeError, RuntimeErrorCode};
21use rpc_runtime_transport::{RpcConnection, RpcReceiver, RpcSender};
22use rpc_runtime_transport_ipc::{FrameConfig, IpcConnection, IpcEndpoint};
23use tokio::sync::{Mutex, Notify, broadcast, oneshot};
24
25#[derive(Clone)]
26pub struct RpcClient {
27 inner: Arc<ClientInner>,
28}
29
/// HELLO option key under which the handshake auth token is sent by default.
pub const DEFAULT_AUTH_TOKEN_OPTION_KEY: &str = "tripley.auth.token";
/// Default notification buffer capacity used by `RpcClientNotificationConfig::default()`.
pub const DEFAULT_NOTIFICATION_BUFFER_SIZE: usize = 128;

/// What a full notification buffer does when another notification arrives.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NotificationOverflowPolicy {
    /// Discard the oldest buffered notification to make room for the new one.
    DropOldest,
}

/// Buffering settings for server notifications.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RpcClientNotificationConfig {
    /// Maximum number of buffered notifications. The constructor and builder
    /// clamp it to at least 1, but the field itself is public.
    pub buffer_size: usize,
    /// Policy applied when the buffer is full.
    pub overflow_policy: NotificationOverflowPolicy,
}

impl RpcClientNotificationConfig {
    /// Creates a config holding `buffer_size` (clamped to at least 1) and the
    /// `DropOldest` overflow policy.
    pub fn new(buffer_size: usize) -> Self {
        let minimal = Self {
            buffer_size: 1,
            overflow_policy: NotificationOverflowPolicy::DropOldest,
        };
        minimal.with_buffer_size(buffer_size)
    }

    /// Returns this config with `buffer_size` replaced, clamped to at least 1.
    pub fn with_buffer_size(self, buffer_size: usize) -> Self {
        Self {
            buffer_size: buffer_size.max(1),
            ..self
        }
    }

    /// Returns this config with `overflow_policy` replaced.
    pub fn with_overflow_policy(self, overflow_policy: NotificationOverflowPolicy) -> Self {
        Self {
            overflow_policy,
            ..self
        }
    }
}

impl Default for RpcClientNotificationConfig {
    /// Buffers `DEFAULT_NOTIFICATION_BUFFER_SIZE` notifications and drops the oldest on overflow.
    fn default() -> Self {
        Self::new(DEFAULT_NOTIFICATION_BUFFER_SIZE)
    }
}
68
/// Settings controlling what the client sends in its HELLO envelope.
///
/// `Debug` is implemented by hand so the auth token is never rendered; a
/// derived `Debug` would print the secret verbatim into any log that formats
/// the config with `{:?}`.
#[derive(Clone, PartialEq, Eq)]
pub struct RpcClientHandshakeConfig {
    /// Token sent during the handshake; `None` sends no auth option.
    pub auth_token: Option<String>,
    /// HELLO option key the token is sent under.
    pub auth_option_key: String,
}

impl std::fmt::Debug for RpcClientHandshakeConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether a token is configured, never its value.
        let token = self.auth_token.as_ref().map(|_| "<redacted>");
        f.debug_struct("RpcClientHandshakeConfig")
            .field("auth_token", &token)
            .field("auth_option_key", &self.auth_option_key)
            .finish()
    }
}
74
75impl RpcClientHandshakeConfig {
76 pub fn with_auth_token(mut self, token: impl Into<String>) -> Self {
77 self.auth_token = Some(token.into());
78 self
79 }
80
81 pub fn with_auth_option_key(mut self, key: impl Into<String>) -> Self {
82 self.auth_option_key = key.into();
83 self
84 }
85
86 fn hello_options(&self) -> Options {
87 self.auth_token
88 .as_ref()
89 .map(|token| vec![(self.auth_option_key.clone(), Value::from(token.as_str()))])
90 .unwrap_or_default()
91 }
92}
93
94impl Default for RpcClientHandshakeConfig {
95 fn default() -> Self {
96 Self {
97 auth_token: None,
98 auth_option_key: DEFAULT_AUTH_TOKEN_OPTION_KEY.to_string(),
99 }
100 }
101}
102
103struct ClientInner {
104 sender: RpcSender,
105 next_request_id: AtomicU64,
106 pending: Mutex<HashMap<u64, oneshot::Sender<Result<Value, RuntimeError>>>>,
107 notifications: broadcast::Sender<Notification>,
108 notification_config: RpcClientNotificationConfig,
109}
110
111impl RpcClient {
112 pub async fn connect(endpoint: IpcEndpoint, config: FrameConfig) -> Result<Self, RuntimeError> {
113 Self::connect_with_handshake_config(endpoint, config, RpcClientHandshakeConfig::default())
114 .await
115 }
116
117 pub async fn connect_with_handshake_config(
118 endpoint: IpcEndpoint,
119 config: FrameConfig,
120 handshake: RpcClientHandshakeConfig,
121 ) -> Result<Self, RuntimeError> {
122 Self::connect_with_configs(
123 endpoint,
124 config,
125 handshake,
126 RpcClientNotificationConfig::default(),
127 )
128 .await
129 }
130
131 pub async fn connect_with_configs(
132 endpoint: IpcEndpoint,
133 config: FrameConfig,
134 handshake: RpcClientHandshakeConfig,
135 notifications: RpcClientNotificationConfig,
136 ) -> Result<Self, RuntimeError> {
137 let connection = IpcConnection::connect(endpoint, config)
138 .await
139 .map_err(|err| {
140 RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
141 })?;
142 Self::from_connection_with_configs(connection, handshake, notifications).await
143 }
144
145 pub async fn from_connection<C>(connection: C) -> Result<Self, RuntimeError>
146 where
147 C: Into<RpcConnection>,
148 {
149 Self::from_connection_with_handshake_config(connection, RpcClientHandshakeConfig::default())
150 .await
151 }
152
153 pub async fn from_connection_with_handshake_config<C>(
154 connection: C,
155 handshake: RpcClientHandshakeConfig,
156 ) -> Result<Self, RuntimeError>
157 where
158 C: Into<RpcConnection>,
159 {
160 Self::from_connection_with_configs(
161 connection,
162 handshake,
163 RpcClientNotificationConfig::default(),
164 )
165 .await
166 }
167
168 pub async fn from_connection_with_configs<C>(
169 connection: C,
170 handshake: RpcClientHandshakeConfig,
171 notifications: RpcClientNotificationConfig,
172 ) -> Result<Self, RuntimeError>
173 where
174 C: Into<RpcConnection>,
175 {
176 let (sender, mut receiver) = connection.into().split();
177 sender
178 .send_envelope(&Envelope::Hello(Hello {
179 protocol_version: RUNTIME_PROTOCOL_VERSION,
180 role: Role::Client,
181 capability_bits: client_capabilities(),
182 max_message_size: rpc_runtime_codec_msgpack::DEFAULT_MAX_MESSAGE_SIZE as u64,
183 options: handshake.hello_options(),
184 }))
185 .await
186 .map_err(|err| {
187 RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
188 })?;
189
190 let Some(envelope) = receiver.recv_envelope().await.map_err(|err| {
191 RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
192 })?
193 else {
194 return Err(RuntimeError::transport(
195 RuntimeErrorCode::InternalRuntimeError,
196 "server disconnected during handshake",
197 ));
198 };
199 let Envelope::HelloAck(ack) = envelope else {
200 return Err(RuntimeError::protocol(
201 RuntimeErrorCode::InvalidEnvelope,
202 "expected HELLO_ACK during handshake",
203 ));
204 };
205 if ack.protocol_version != RUNTIME_PROTOCOL_VERSION {
206 return Err(RuntimeError::protocol(
207 RuntimeErrorCode::UnsupportedProtocolVersion,
208 "server returned unsupported protocol version",
209 ));
210 }
211
212 let notifications = RpcClientNotificationConfig {
213 buffer_size: notifications.buffer_size.max(1),
214 ..notifications
215 };
216 let (notification_tx, _) = broadcast::channel(notifications.buffer_size);
217 let inner = Arc::new(ClientInner {
218 sender,
219 next_request_id: AtomicU64::new(1),
220 pending: Mutex::new(HashMap::new()),
221 notifications: notification_tx,
222 notification_config: notifications,
223 });
224 spawn_receive_loop(Arc::clone(&inner), receiver);
225 Ok(Self { inner })
226 }
227
228 pub async fn call(
229 &self,
230 instance_id: InstanceId,
231 method_id: MethodId,
232 payload: Value,
233 ) -> Result<Value, RuntimeError> {
234 self.call_with_optional_timeout(instance_id, method_id, payload, None)
235 .await
236 }
237
238 async fn call_with_optional_timeout(
239 &self,
240 instance_id: InstanceId,
241 method_id: MethodId,
242 payload: Value,
243 timeout: Option<Duration>,
244 ) -> Result<Value, RuntimeError> {
245 let request_id = self.inner.next_request_id.fetch_add(1, Ordering::Relaxed);
246 let (tx, rx) = oneshot::channel();
247 self.inner.pending.lock().await.insert(request_id, tx);
248
249 let send_result = self
250 .inner
251 .sender
252 .send_envelope(&Envelope::Request(Request {
253 request_id: RequestId::new(request_id),
254 instance_id,
255 method_id,
256 payload,
257 }))
258 .await;
259 if let Err(err) = send_result {
260 self.inner.pending.lock().await.remove(&request_id);
261 return Err(RuntimeError::transport(
262 RuntimeErrorCode::InternalRuntimeError,
263 err.to_string(),
264 ));
265 }
266
267 let response = if let Some(timeout) = timeout {
268 match tokio::time::timeout(timeout, rx).await {
269 Ok(response) => response,
270 Err(_) => {
271 self.inner.pending.lock().await.remove(&request_id);
272 return Err(RuntimeError::runtime(
273 RuntimeErrorCode::RequestTimeout,
274 "request timed out",
275 ));
276 }
277 }
278 } else {
279 rx.await
280 };
281
282 response.map_err(|_| {
283 RuntimeError::transport(
284 RuntimeErrorCode::InternalRuntimeError,
285 "response channel closed before request completed",
286 )
287 })?
288 }
289
290 pub async fn call_timeout(
291 &self,
292 instance_id: InstanceId,
293 method_id: MethodId,
294 payload: Value,
295 timeout: Duration,
296 ) -> Result<Value, RuntimeError> {
297 self.call_with_optional_timeout(instance_id, method_id, payload, Some(timeout))
298 .await
299 }
300
301 pub async fn resolve_instance_ids(&self, names: Vec<String>) -> Result<Vec<u64>, RuntimeError> {
302 let response = self
303 .call(
304 activation_instance_id(),
305 MethodId::new(RESOLVE_INSTANCE_IDS_METHOD_ID),
306 encode_resolve_instance_ids_request(&ResolveInstanceIdsRequest {
307 instance_names: names,
308 }),
309 )
310 .await?;
311 Ok(decode_resolve_instance_ids_response(&response)?.instance_ids)
312 }
313
314 pub async fn create_instance(
315 &self,
316 service_guid: ServiceGuid,
317 create_payload: Option<Vec<u8>>,
318 options: BTreeMap<String, String>,
319 ) -> Result<InstanceId, RuntimeError> {
320 let response = self
321 .call(
322 activation_instance_id(),
323 MethodId::new(CREATE_INSTANCE_METHOD_ID),
324 encode_create_instance_request(&CreateInstanceRequest {
325 service_guid,
326 create_payload,
327 options,
328 }),
329 )
330 .await?;
331 Ok(decode_create_instance_response(&response)?.instance_id)
332 }
333
334 pub async fn release_instance(&self, instance_id: InstanceId) -> Result<(), RuntimeError> {
335 let response = self
336 .call(
337 activation_instance_id(),
338 MethodId::new(RELEASE_INSTANCE_METHOD_ID),
339 encode_release_instance_request(&ReleaseInstanceRequest { instance_id }),
340 )
341 .await?;
342 decode_release_instance_response(&response)?;
343 Ok(())
344 }
345
346 pub async fn list_instances(
347 &self,
348 service_guid: Option<ServiceGuid>,
349 ) -> Result<Vec<rpc_runtime_activation::InstanceDescriptor>, RuntimeError> {
350 let response = self
351 .call(
352 activation_instance_id(),
353 MethodId::new(LIST_INSTANCES_METHOD_ID),
354 encode_list_instances_request(&ListInstancesRequest { service_guid }),
355 )
356 .await?;
357 Ok(decode_list_instances_response(&response)?.instances)
358 }
359
360 pub fn subscribe_notifications(
361 &self,
362 instance_id_filter: Option<InstanceId>,
363 notification_id_filter: Option<u32>,
364 ) -> RpcNotificationReceiver {
365 let mut source = self.inner.notifications.subscribe();
366 let queue = Arc::new(BoundedNotificationQueue::new(
367 self.inner.notification_config.buffer_size,
368 self.inner.notification_config.overflow_policy,
369 ));
370 let receiver = RpcNotificationReceiver {
371 queue: Arc::clone(&queue),
372 };
373 tokio::spawn(async move {
374 loop {
375 let Ok(notification) = source.recv().await else {
376 break;
377 };
378 let instance_matches = instance_id_filter
379 .is_none_or(|expected| notification.instance_id == Some(expected));
380 let notification_matches = notification_id_filter
381 .is_none_or(|expected| notification.notification_id.get() == expected);
382 if instance_matches && notification_matches {
383 queue.push(notification);
384 }
385 }
386 queue.close();
387 });
388 receiver
389 }
390
391 pub async fn goodbye(&self, message: impl Into<String>) -> Result<(), RuntimeError> {
392 self.inner
393 .sender
394 .send_envelope(&Envelope::Goodbye(rpc_runtime_core::Goodbye {
395 reason_code: 0,
396 message: Some(message.into()),
397 }))
398 .await
399 .map_err(|err| {
400 RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
401 })
402 }
403}
404
405pub struct RpcNotificationReceiver {
406 queue: Arc<BoundedNotificationQueue>,
407}
408
409impl RpcNotificationReceiver {
410 pub async fn recv(&mut self) -> Option<Notification> {
411 self.queue.recv().await
412 }
413}
414
415struct BoundedNotificationQueue {
416 state: StdMutex<BoundedNotificationQueueState>,
417 notify: Notify,
418 capacity: usize,
419 overflow_policy: NotificationOverflowPolicy,
420}
421
422struct BoundedNotificationQueueState {
423 items: VecDeque<Notification>,
424 closed: bool,
425}
426
427impl BoundedNotificationQueue {
428 fn new(capacity: usize, overflow_policy: NotificationOverflowPolicy) -> Self {
429 Self {
430 state: StdMutex::new(BoundedNotificationQueueState {
431 items: VecDeque::new(),
432 closed: false,
433 }),
434 notify: Notify::new(),
435 capacity: capacity.max(1),
436 overflow_policy,
437 }
438 }
439
440 fn push(&self, notification: Notification) {
441 let mut state = self
442 .state
443 .lock()
444 .expect("notification queue mutex poisoned");
445 if state.closed {
446 return;
447 }
448 if state.items.len() == self.capacity {
449 match self.overflow_policy {
450 NotificationOverflowPolicy::DropOldest => {
451 state.items.pop_front();
452 }
453 }
454 }
455 state.items.push_back(notification);
456 drop(state);
457 self.notify.notify_one();
458 }
459
460 fn close(&self) {
461 let mut state = self
462 .state
463 .lock()
464 .expect("notification queue mutex poisoned");
465 state.closed = true;
466 drop(state);
467 self.notify.notify_waiters();
468 }
469
470 async fn recv(&self) -> Option<Notification> {
471 loop {
472 let notified = self.notify.notified();
473 {
474 let mut state = self
475 .state
476 .lock()
477 .expect("notification queue mutex poisoned");
478 if let Some(notification) = state.items.pop_front() {
479 return Some(notification);
480 }
481 if state.closed {
482 return None;
483 }
484 }
485 notified.await;
486 }
487 }
488}
489
490fn spawn_receive_loop(inner: Arc<ClientInner>, mut receiver: RpcReceiver) {
491 tokio::spawn(async move {
492 loop {
493 let envelope = match receiver.recv_envelope().await {
494 Ok(Some(envelope)) => envelope,
495 Ok(None) => {
496 fail_pending(
497 &inner,
498 RuntimeError::transport(
499 RuntimeErrorCode::InternalRuntimeError,
500 "server disconnected",
501 ),
502 )
503 .await;
504 break;
505 }
506 Err(err) => {
507 fail_pending(
508 &inner,
509 RuntimeError::transport(
510 RuntimeErrorCode::InternalRuntimeError,
511 err.to_string(),
512 ),
513 )
514 .await;
515 break;
516 }
517 };
518 match envelope {
519 Envelope::ResponseOk(response) => {
520 complete_pending(&inner, response.request_id.get(), Ok(response.payload)).await;
521 }
522 Envelope::ResponseError(response) => {
523 complete_pending(
524 &inner,
525 response.request_id.get(),
526 Err(RuntimeError::new(
527 runtime_error_code(response.error_code),
528 error_kind(response.error_kind),
529 response.error_message.unwrap_or_default(),
530 )),
531 )
532 .await;
533 }
534 Envelope::Notification(notification) => {
535 let _ = inner.notifications.send(notification);
536 }
537 _ => {
538 fail_pending(
539 &inner,
540 RuntimeError::protocol(
541 RuntimeErrorCode::InvalidEnvelope,
542 "client received invalid envelope kind",
543 ),
544 )
545 .await;
546 break;
547 }
548 }
549 }
550 });
551}
552
553async fn complete_pending(
554 inner: &ClientInner,
555 request_id: u64,
556 result: Result<Value, RuntimeError>,
557) {
558 if let Some(sender) = inner.pending.lock().await.remove(&request_id) {
559 let _ = sender.send(result);
560 }
561}
562
563async fn fail_pending(inner: &ClientInner, error: RuntimeError) {
564 let pending = std::mem::take(&mut *inner.pending.lock().await);
565 for (_, sender) in pending {
566 let _ = sender.send(Err(error.clone()));
567 }
568}
569
570fn client_capabilities() -> CapabilityFlags {
571 CapabilityFlags::SERVER_TO_CLIENT_NOTIFICATION
572 | CapabilityFlags::NAMED_INSTANCE_RESOLUTION
573 | CapabilityFlags::SERVICE_ACTIVATION
574 | CapabilityFlags::GOODBYE
575}
576
577fn runtime_error_code(value: i32) -> RuntimeErrorCode {
578 match value {
579 1001 => RuntimeErrorCode::UnknownMessageKind,
580 1002 => RuntimeErrorCode::UnsupportedProtocolVersion,
581 1003 => RuntimeErrorCode::InvalidEnvelope,
582 1004 => RuntimeErrorCode::InvalidRequestId,
583 1005 => RuntimeErrorCode::InvalidInstanceId,
584 1006 => RuntimeErrorCode::InstanceNotFound,
585 1007 => RuntimeErrorCode::MethodNotFound,
586 1008 => RuntimeErrorCode::NotificationNotFound,
587 1009 => RuntimeErrorCode::PayloadDecodeFailed,
588 1010 => RuntimeErrorCode::PayloadEncodeFailed,
589 1011 => RuntimeErrorCode::ServiceActivationNotSupported,
590 1012 => RuntimeErrorCode::ServiceGuidNotFound,
591 1013 => RuntimeErrorCode::InstanceReleaseNotAllowed,
592 1014 => RuntimeErrorCode::RequestTimeout,
593 1015 => RuntimeErrorCode::UnsupportedCapability,
594 1016 => RuntimeErrorCode::BusinessErrorDeclared,
595 1017 => RuntimeErrorCode::DuplicateRequestId,
596 1018 => RuntimeErrorCode::RequestCancelUnsupported,
597 1019 => RuntimeErrorCode::AccessDenied,
598 _ => RuntimeErrorCode::InternalRuntimeError,
599 }
600}
601
602fn error_kind(value: u8) -> ErrorKind {
603 match value {
604 1 => ErrorKind::Transport,
605 2 => ErrorKind::Protocol,
606 3 => ErrorKind::Runtime,
607 4 => ErrorKind::Business,
608 5 => ErrorKind::Timeout,
609 6 => ErrorKind::Cancelled,
610 _ => ErrorKind::Runtime,
611 }
612}
613
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rpc_runtime_core::{CapabilityFlags, HelloAck};
    use rpc_runtime_transport::{
        EnvelopeReader, EnvelopeWriter, RpcConnection, RpcReceiver, RpcSender, TransportError,
        TransportFuture,
    };
    use tokio::sync::mpsc;

    use super::*;

    /// Builds a connection whose reader already holds a valid HELLO_ACK.
    ///
    /// The returned sender keeps the reader open; dropping it ends the stream
    /// with an error.
    fn preloaded_connection() -> (mpsc::UnboundedSender<Option<Envelope>>, RpcConnection) {
        let (feed, rx) = mpsc::unbounded_channel();
        feed.send(Some(Envelope::HelloAck(HelloAck {
            protocol_version: RUNTIME_PROTOCOL_VERSION,
            accepted_capability_bits: CapabilityFlags::GOODBYE,
            max_message_size: rpc_runtime_codec_msgpack::DEFAULT_MAX_MESSAGE_SIZE as u64,
            options: Vec::new(),
        })))
        .expect("preload handshake ack");

        let connection = RpcConnection::new(
            RpcSender::new(Arc::new(NoopWriter)),
            RpcReceiver::new(Box::new(ChannelReader { rx })),
        );
        (feed, connection)
    }

    #[tokio::test]
    async fn call_timeout_removes_pending_request() {
        let (feed, connection) = preloaded_connection();
        let client = RpcClient::from_connection(connection)
            .await
            .expect("client handshake");

        let err = client
            .call_timeout(
                InstanceId::new(1).expect("instance id"),
                MethodId::new(1),
                Value::Nil,
                Duration::from_millis(1),
            )
            .await
            .expect_err("call must time out");

        assert_eq!(err.code, RuntimeErrorCode::RequestTimeout);
        assert!(client.inner.pending.lock().await.is_empty());

        // Keep the reader open until now so the receive loop cannot fail the call early.
        drop(feed);
    }

    #[tokio::test]
    async fn notification_receiver_drops_oldest_when_full() {
        let queue = Arc::new(BoundedNotificationQueue::new(
            2,
            NotificationOverflowPolicy::DropOldest,
        ));
        let mut receiver = RpcNotificationReceiver {
            queue: Arc::clone(&queue),
        };

        for value in [1, 2, 3] {
            queue.push(Notification {
                instance_id: None,
                notification_id: rpc_runtime_core::NotificationId::new(7),
                payload: Value::from(value),
            });
        }
        queue.close();

        let payloads = [
            receiver.recv().await.expect("first notification").payload,
            receiver.recv().await.expect("second notification").payload,
        ];
        assert_eq!(payloads, [Value::from(2), Value::from(3)]);
        assert!(receiver.recv().await.is_none());
    }

    /// Writer that accepts and discards every envelope.
    struct NoopWriter;

    impl EnvelopeWriter for NoopWriter {
        fn send_envelope<'a>(&'a self, _: &'a Envelope) -> TransportFuture<'a, ()> {
            Box::pin(async { Ok(()) })
        }

        fn shutdown<'a>(&'a self) -> TransportFuture<'a, ()> {
            Box::pin(async { Ok(()) })
        }
    }

    /// Reader fed from an mpsc channel. A closed channel surfaces as an
    /// `UnexpectedEof` I/O error.
    struct ChannelReader {
        rx: mpsc::UnboundedReceiver<Option<Envelope>>,
    }

    impl EnvelopeReader for ChannelReader {
        fn recv_envelope<'a>(&'a mut self) -> TransportFuture<'a, Option<Envelope>> {
            Box::pin(async move {
                let next = self.rx.recv().await.ok_or_else(|| {
                    TransportError::Io(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "test channel closed",
                    ))
                })?;
                Ok(next)
            })
        }
    }
}