1use std::collections::{BTreeMap, HashMap};
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::{Duration, Instant};
7use std::{panic, panic::AssertUnwindSafe};
8
9use futures_util::FutureExt;
10use rmpv::Value;
11use rpc_runtime_activation::{
12 ACTIVATION_INSTANCE_ID_VALUE, ActivationMode, CREATE_INSTANCE_METHOD_ID,
13 CreateInstanceResponse, InstanceDescriptor, LIST_INSTANCES_METHOD_ID, ListInstancesResponse,
14 RELEASE_INSTANCE_METHOD_ID, RESOLVE_INSTANCE_IDS_METHOD_ID, ReleaseInstanceResponse,
15 ResolveInstanceIdsResponse, activation_instance_id, activation_service_guid,
16 decode_create_instance_request, decode_list_instances_request, decode_release_instance_request,
17 decode_resolve_instance_ids_request, encode_create_instance_response,
18 encode_list_instances_response, encode_release_instance_response,
19 encode_resolve_instance_ids_response,
20};
21use rpc_runtime_core::{
22 CapabilityFlags, Envelope, HelloAck, InstanceId, MethodId, Notification, Options,
23 RUNTIME_PROTOCOL_VERSION, Request, RequestId, ResponseError, ResponseOk, Role, ServiceGuid,
24};
25use rpc_runtime_errors::{RuntimeError, RuntimeErrorCode};
26pub use rpc_runtime_transport::ConnectionScope;
27use rpc_runtime_transport::{RpcConnection, RpcListener, RpcReceiver, RpcSender, TransportError};
28use tokio::sync::RwLock;
29use tracing::{debug, error, info, trace, warn};
30
31pub type HandlerFuture = Pin<Box<dyn Future<Output = Result<Value, RuntimeError>> + Send>>;
32
33pub trait RpcServiceHandler: Send + Sync {
34 fn call(&self, ctx: RpcCallContext, method_id: MethodId, payload: Value) -> HandlerFuture;
35}
36
37impl<F> RpcServiceHandler for F
38where
39 F: Send + Sync + 'static,
40 F: Fn(RpcCallContext, MethodId, Value) -> HandlerFuture,
41{
42 fn call(&self, ctx: RpcCallContext, method_id: MethodId, payload: Value) -> HandlerFuture {
43 self(ctx, method_id, payload)
44 }
45}
46
47pub type FactoryFuture =
48 Pin<Box<dyn Future<Output = Result<Arc<dyn RpcServiceHandler>, RuntimeError>> + Send>>;
49
50pub trait RpcServiceFactory: Send + Sync {
51 fn create(
52 &self,
53 ctx: RpcCallContext,
54 create_payload: Option<Vec<u8>>,
55 options: BTreeMap<String, String>,
56 ) -> FactoryFuture;
57}
58
59impl<F> RpcServiceFactory for F
60where
61 F: Send + Sync + 'static,
62 F: Fn(RpcCallContext, Option<Vec<u8>>, BTreeMap<String, String>) -> FactoryFuture,
63{
64 fn create<'a>(
65 &self,
66 ctx: RpcCallContext,
67 create_payload: Option<Vec<u8>>,
68 options: BTreeMap<String, String>,
69 ) -> FactoryFuture {
70 self(ctx, create_payload, options)
71 }
72}
73
74#[derive(Clone)]
75pub struct RpcCallContext {
76 connection_id: u64,
77 instance_id: InstanceId,
78 sender: RpcSender,
79}
80
81impl RpcCallContext {
82 pub fn connection_id(&self) -> u64 {
83 self.connection_id
84 }
85
86 pub fn instance_id(&self) -> InstanceId {
87 self.instance_id
88 }
89
90 pub async fn notify(
91 &self,
92 instance_id: Option<InstanceId>,
93 notification_id: u32,
94 payload: Value,
95 ) -> Result<(), RuntimeError> {
96 self.sender
97 .send_envelope(&Envelope::Notification(Notification {
98 instance_id,
99 notification_id: rpc_runtime_core::NotificationId::new(notification_id),
100 payload,
101 }))
102 .await
103 .map_err(|err| {
104 RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
105 })
106 }
107
108 pub async fn notify_bound(
109 &self,
110 notification_id: u32,
111 payload: Value,
112 ) -> Result<(), RuntimeError> {
113 self.notify(Some(self.instance_id), notification_id, payload)
114 .await
115 }
116}
117
/// A configured RPC server.
///
/// Cheap to clone: every clone shares the same server state through an `Arc`, which is how
/// listener and connection tasks each get their own handle.
#[derive(Clone)]
pub struct RpcServer {
    state: Arc<ServerState>,
}

/// Builder that registers instances, factories and configuration before producing an
/// [`RpcServer`] with `build`.
pub struct RpcServerBuilder {
    state: ServerState,
}

/// Handshake option key under which clients send the auth token when
/// `RpcServerAuthConfig::token` is used.
pub const DEFAULT_AUTH_TOKEN_OPTION_KEY: &str = "tripley.auth.token";
128
129pub trait RpcServerMetricsSink: Send + Sync {
130 fn record(&self, event: RpcServerMetricEvent);
131}
132
133impl<F> RpcServerMetricsSink for F
134where
135 F: Send + Sync + 'static + Fn(RpcServerMetricEvent),
136{
137 fn record(&self, event: RpcServerMetricEvent) {
138 self(event);
139 }
140}
141
/// Boxed, `Send` future returned by a connection cleanup hook; may borrow the sink for `'a`.
pub type ConnectionCleanupFuture<'a> = Pin<Box<dyn Future<Output = ()> + Send + 'a>>;

/// External hook for releasing per-connection resources held outside the server.
///
/// NOTE(review): presumably invoked through `ServerState::cleanup_external_connection` once a
/// connection has ended; confirm against `ServerState`, which is not visible here.
pub trait RpcConnectionCleanupSink: Send + Sync {
    /// Cleans up whatever is associated with `connection_id`.
    fn cleanup_connection<'a>(&'a self, connection_id: u64) -> ConnectionCleanupFuture<'a>;
}
147
/// A single observable event emitted by the server to its metrics sink.
#[derive(Debug, Clone, PartialEq)]
pub enum RpcServerMetricEvent {
    /// A connection was handed to `serve_connection` and assigned `connection_id`.
    ConnectionStarted {
        connection_id: u64,
    },
    /// A connection finished; `success` is false when it ended with an error.
    ConnectionEnded {
        connection_id: u64,
        success: bool,
    },
    /// The HELLO / HELLO_ACK handshake succeeded.
    HandshakeCompleted {
        connection_id: u64,
    },
    /// The handshake failed (transport, protocol, version or authentication error).
    HandshakeFailed {
        connection_id: u64,
        error_code: RuntimeErrorCode,
    },
    /// The listener's accept reported an access-denied error; the listener keeps running.
    ListenerConnectionRejected {
        error_code: RuntimeErrorCode,
    },
    /// A request envelope was received and is about to be dispatched.
    RequestStarted {
        connection_id: u64,
        request_id: RequestId,
        instance_id: InstanceId,
        method_id: MethodId,
        is_activation: bool,
    },
    /// Dispatch returned a successful payload; `elapsed` is the dispatch time.
    RequestCompleted {
        connection_id: u64,
        request_id: RequestId,
        instance_id: InstanceId,
        method_id: MethodId,
        is_activation: bool,
        elapsed: Duration,
    },
    /// Dispatch returned an error; `elapsed` is the dispatch time.
    RequestFailed {
        connection_id: u64,
        request_id: RequestId,
        instance_id: InstanceId,
        method_id: MethodId,
        is_activation: bool,
        elapsed: Duration,
        error_code: RuntimeErrorCode,
    },
    /// A successful request took at least `threshold` (the configured slow-call threshold).
    RequestSlow {
        connection_id: u64,
        request_id: RequestId,
        instance_id: InstanceId,
        method_id: MethodId,
        is_activation: bool,
        elapsed: Duration,
        threshold: Duration,
    },
    /// The response envelope for `request_id` could not be sent back to the client.
    ResponseSendFailed {
        connection_id: u64,
        request_id: RequestId,
    },
}
205
/// Point-in-time copy of the counters kept by [`RpcServerMetricsRecorder`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RpcServerMetricsSnapshot {
    pub connections_started: u64,
    pub connections_ended: u64,
    /// Subset of `connections_ended` whose connection ended without an error.
    pub connections_ended_successfully: u64,
    pub handshakes_completed: u64,
    pub handshakes_failed: u64,
    pub listener_connections_rejected: u64,
    pub requests_started: u64,
    pub requests_completed: u64,
    pub requests_failed: u64,
    pub requests_slow: u64,
    pub response_send_failures: u64,
    /// Sum of elapsed times over completed and failed requests (saturating).
    pub request_elapsed_total: Duration,
    /// Largest single elapsed time over completed and failed requests.
    pub request_elapsed_max: Duration,
}

/// Lock-free, in-process metrics sink that counts events in atomics.
///
/// Use `snapshot` to read the current totals.
#[derive(Debug, Default)]
pub struct RpcServerMetricsRecorder {
    connections_started: AtomicU64,
    connections_ended: AtomicU64,
    connections_ended_successfully: AtomicU64,
    handshakes_completed: AtomicU64,
    handshakes_failed: AtomicU64,
    listener_connections_rejected: AtomicU64,
    requests_started: AtomicU64,
    requests_completed: AtomicU64,
    requests_failed: AtomicU64,
    requests_slow: AtomicU64,
    response_send_failures: AtomicU64,
    // Durations are stored as nanoseconds so they fit in an `AtomicU64`.
    request_elapsed_total_nanos: AtomicU64,
    request_elapsed_max_nanos: AtomicU64,
}
239
240impl RpcServerMetricsRecorder {
241 pub fn new() -> Self {
242 Self::default()
243 }
244
245 pub fn snapshot(&self) -> RpcServerMetricsSnapshot {
246 RpcServerMetricsSnapshot {
247 connections_started: self.connections_started.load(Ordering::Relaxed),
248 connections_ended: self.connections_ended.load(Ordering::Relaxed),
249 connections_ended_successfully: self
250 .connections_ended_successfully
251 .load(Ordering::Relaxed),
252 handshakes_completed: self.handshakes_completed.load(Ordering::Relaxed),
253 handshakes_failed: self.handshakes_failed.load(Ordering::Relaxed),
254 listener_connections_rejected: self
255 .listener_connections_rejected
256 .load(Ordering::Relaxed),
257 requests_started: self.requests_started.load(Ordering::Relaxed),
258 requests_completed: self.requests_completed.load(Ordering::Relaxed),
259 requests_failed: self.requests_failed.load(Ordering::Relaxed),
260 requests_slow: self.requests_slow.load(Ordering::Relaxed),
261 response_send_failures: self.response_send_failures.load(Ordering::Relaxed),
262 request_elapsed_total: Duration::from_nanos(
263 self.request_elapsed_total_nanos.load(Ordering::Relaxed),
264 ),
265 request_elapsed_max: Duration::from_nanos(
266 self.request_elapsed_max_nanos.load(Ordering::Relaxed),
267 ),
268 }
269 }
270
271 fn record_elapsed(&self, elapsed: Duration) {
272 let nanos = duration_nanos_u64(elapsed);
273 saturating_atomic_add(&self.request_elapsed_total_nanos, nanos);
274 update_atomic_max(&self.request_elapsed_max_nanos, nanos);
275 }
276}
277
278impl RpcServerMetricsSink for RpcServerMetricsRecorder {
279 fn record(&self, event: RpcServerMetricEvent) {
280 match event {
281 RpcServerMetricEvent::ConnectionStarted { .. } => {
282 self.connections_started.fetch_add(1, Ordering::Relaxed);
283 }
284 RpcServerMetricEvent::ConnectionEnded { success, .. } => {
285 self.connections_ended.fetch_add(1, Ordering::Relaxed);
286 if success {
287 self.connections_ended_successfully
288 .fetch_add(1, Ordering::Relaxed);
289 }
290 }
291 RpcServerMetricEvent::HandshakeCompleted { .. } => {
292 self.handshakes_completed.fetch_add(1, Ordering::Relaxed);
293 }
294 RpcServerMetricEvent::HandshakeFailed { .. } => {
295 self.handshakes_failed.fetch_add(1, Ordering::Relaxed);
296 }
297 RpcServerMetricEvent::ListenerConnectionRejected { .. } => {
298 self.listener_connections_rejected
299 .fetch_add(1, Ordering::Relaxed);
300 }
301 RpcServerMetricEvent::RequestStarted { .. } => {
302 self.requests_started.fetch_add(1, Ordering::Relaxed);
303 }
304 RpcServerMetricEvent::RequestCompleted { elapsed, .. } => {
305 self.requests_completed.fetch_add(1, Ordering::Relaxed);
306 self.record_elapsed(elapsed);
307 }
308 RpcServerMetricEvent::RequestFailed { elapsed, .. } => {
309 self.requests_failed.fetch_add(1, Ordering::Relaxed);
310 self.record_elapsed(elapsed);
311 }
312 RpcServerMetricEvent::RequestSlow { .. } => {
313 self.requests_slow.fetch_add(1, Ordering::Relaxed);
314 }
315 RpcServerMetricEvent::ResponseSendFailed { .. } => {
316 self.response_send_failures.fetch_add(1, Ordering::Relaxed);
317 }
318 }
319 }
320}
321
/// Logging and slow-call settings applied to every request the server handles.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RpcServerObservabilityConfig {
    /// Successful requests whose dispatch takes at least this long are logged at `warn` and
    /// recorded as `RequestSlow`.
    pub slow_call_threshold: Duration,
    /// Maximum number of bytes of the payload's `Debug` rendering included in trace logs.
    pub payload_preview_bytes: usize,
    /// Whether request payload previews are rendered at all.
    pub log_payload_preview: bool,
}
328
329#[derive(Debug, Clone, PartialEq, Eq)]
330pub struct RpcServerSecurityConfig {
331 pub connection_scope: ConnectionScope,
332 pub auth: RpcServerAuthConfig,
333}
334
335impl RpcServerSecurityConfig {
336 pub fn remote_allowed(mut self) -> Self {
337 self.connection_scope = ConnectionScope::RemoteAllowed;
338 self
339 }
340
341 pub fn local_only(mut self) -> Self {
342 self.connection_scope = ConnectionScope::LocalOnly;
343 self
344 }
345
346 pub fn with_token(mut self, token: impl Into<String>) -> Self {
347 self.auth = RpcServerAuthConfig::token(token);
348 self
349 }
350
351 pub fn with_auth(mut self, auth: RpcServerAuthConfig) -> Self {
352 self.auth = auth;
353 self
354 }
355}
356
357impl Default for RpcServerSecurityConfig {
358 fn default() -> Self {
359 Self {
360 connection_scope: ConnectionScope::LocalOnly,
361 auth: RpcServerAuthConfig::Disabled,
362 }
363 }
364}
365
366#[derive(Debug, Clone, PartialEq, Eq)]
367pub enum RpcServerAuthConfig {
368 Disabled,
369 Token { token: String, option_key: String },
370}
371
372impl RpcServerAuthConfig {
373 pub fn token(token: impl Into<String>) -> Self {
374 Self::Token {
375 token: token.into(),
376 option_key: DEFAULT_AUTH_TOKEN_OPTION_KEY.to_string(),
377 }
378 }
379
380 pub fn token_with_option_key(token: impl Into<String>, option_key: impl Into<String>) -> Self {
381 Self::Token {
382 token: token.into(),
383 option_key: option_key.into(),
384 }
385 }
386}
387
388impl RpcServerObservabilityConfig {
389 pub fn with_slow_call_threshold(mut self, threshold: Duration) -> Self {
390 self.slow_call_threshold = threshold;
391 self
392 }
393
394 pub fn with_payload_preview(mut self, bytes: usize) -> Self {
395 self.payload_preview_bytes = bytes;
396 self.log_payload_preview = bytes > 0;
397 self
398 }
399}
400
401impl Default for RpcServerObservabilityConfig {
402 fn default() -> Self {
403 Self {
404 slow_call_threshold: Duration::from_millis(500),
405 payload_preview_bytes: 0,
406 log_payload_preview: false,
407 }
408 }
409}
410
411impl RpcServerBuilder {
412 pub fn new() -> Self {
413 let mut state = ServerState::new();
414 state.insert_activation_instance();
415 Self { state }
416 }
417
418 pub fn observability(mut self, config: RpcServerObservabilityConfig) -> Self {
419 self.state.observability = config;
420 self
421 }
422
423 pub fn set_observability(&mut self, config: RpcServerObservabilityConfig) -> &mut Self {
424 self.state.observability = config;
425 self
426 }
427
428 pub fn metrics_sink(mut self, sink: Arc<dyn RpcServerMetricsSink>) -> Self {
429 self.state.metrics_sink = Some(sink);
430 self
431 }
432
433 pub fn set_metrics_sink(&mut self, sink: Arc<dyn RpcServerMetricsSink>) -> &mut Self {
434 self.state.metrics_sink = Some(sink);
435 self
436 }
437
438 pub fn connection_cleanup_sink(mut self, sink: Arc<dyn RpcConnectionCleanupSink>) -> Self {
439 self.state.connection_cleanup_sink = Some(sink);
440 self
441 }
442
443 pub fn set_connection_cleanup_sink(
444 &mut self,
445 sink: Arc<dyn RpcConnectionCleanupSink>,
446 ) -> &mut Self {
447 self.state.connection_cleanup_sink = Some(sink);
448 self
449 }
450
451 pub fn security(mut self, config: RpcServerSecurityConfig) -> Self {
452 self.state.security = config;
453 self
454 }
455
456 pub fn set_security(&mut self, config: RpcServerSecurityConfig) -> &mut Self {
457 self.state.security = config;
458 self
459 }
460
461 pub fn register_named_instance(
462 &mut self,
463 name: impl Into<String>,
464 service_guid: ServiceGuid,
465 methods: impl IntoIterator<Item = u32>,
466 handler: Arc<dyn RpcServiceHandler>,
467 ) -> InstanceId {
468 self.state.insert_instance(NewInstance {
469 service_guid,
470 name: Some(name.into()),
471 activation_mode: ActivationMode::NamedPrecreated,
472 releasable: false,
473 owner_connection_id: None,
474 methods: methods.into_iter().collect(),
475 handler,
476 })
477 }
478
479 pub fn register_singleton(
480 &mut self,
481 service_guid: ServiceGuid,
482 methods: impl IntoIterator<Item = u32>,
483 handler: Arc<dyn RpcServiceHandler>,
484 ) -> InstanceId {
485 self.state.insert_instance(NewInstance {
486 service_guid,
487 name: None,
488 activation_mode: ActivationMode::Singleton,
489 releasable: false,
490 owner_connection_id: None,
491 methods: methods.into_iter().collect(),
492 handler,
493 })
494 }
495
496 pub fn register_factory(
497 &mut self,
498 service_guid: ServiceGuid,
499 methods: impl IntoIterator<Item = u32>,
500 factory: Arc<dyn RpcServiceFactory>,
501 ) {
502 self.state.factories.insert(
503 service_guid.get(),
504 FactoryEntry {
505 methods: methods.into_iter().collect(),
506 factory,
507 },
508 );
509 }
510
511 pub fn build(self) -> RpcServer {
512 if self.state.security.connection_scope == ConnectionScope::RemoteAllowed
513 && self.state.security.auth == RpcServerAuthConfig::Disabled
514 {
515 warn!("rpc server allows remote connections without token authentication");
516 }
517 RpcServer {
518 state: Arc::new(self.state),
519 }
520 }
521}
522
523impl Default for RpcServerBuilder {
524 fn default() -> Self {
525 Self::new()
526 }
527}
528
529impl RpcServer {
530 pub async fn serve_connection<C>(&self, connection: C) -> Result<(), RuntimeError>
531 where
532 C: Into<RpcConnection>,
533 {
534 let connection_id = self
535 .state
536 .next_connection_id
537 .fetch_add(1, Ordering::Relaxed);
538 self.state
539 .record_metric(RpcServerMetricEvent::ConnectionStarted { connection_id });
540 info!(connection_id, "rpc server connection started");
541 let (sender, mut receiver) = connection.into().split();
542
543 let result = async {
544 if let Err(error) = self
545 .perform_handshake(connection_id, &sender, &mut receiver)
546 .await
547 {
548 self.state
549 .record_metric(RpcServerMetricEvent::HandshakeFailed {
550 connection_id,
551 error_code: error.code,
552 });
553 return Err(error);
554 }
555 self.state
556 .record_metric(RpcServerMetricEvent::HandshakeCompleted { connection_id });
557
558 loop {
559 let envelope = match receiver.recv_envelope().await {
560 Ok(Some(envelope)) => envelope,
561 Ok(None) => {
562 debug!(connection_id, "rpc server connection closed by peer");
563 break;
564 }
565 Err(err) => {
566 let error = RuntimeError::transport(
567 RuntimeErrorCode::InternalRuntimeError,
568 err.to_string(),
569 );
570 warn!(
571 connection_id,
572 error_code = error.code.as_i32(),
573 error_kind = error.kind.as_u8(),
574 error_message = %error.message,
575 "rpc server failed to receive envelope"
576 );
577 return Err(error);
578 }
579 };
580
581 match envelope {
582 Envelope::Request(request) => {
583 let state = Arc::clone(&self.state);
584 let sender = sender.clone();
585 let observability = self.state.observability;
586 tokio::spawn(async move {
587 handle_request(state, sender, connection_id, request, observability)
588 .await;
589 });
590 }
591 Envelope::Goodbye(goodbye) => {
592 info!(
593 connection_id,
594 reason_code = goodbye.reason_code,
595 message = goodbye.message.as_deref().unwrap_or(""),
596 "rpc server received goodbye"
597 );
598 break;
599 }
600 envelope => {
601 let error = RuntimeError::protocol(
602 RuntimeErrorCode::InvalidEnvelope,
603 "server expected request envelope",
604 );
605 warn!(
606 connection_id,
607 envelope_kind = envelope_name(&envelope),
608 error_code = error.code.as_i32(),
609 error_kind = error.kind.as_u8(),
610 error_message = %error.message,
611 "rpc server received invalid envelope"
612 );
613 return Err(error);
614 }
615 }
616 }
617
618 Ok(())
619 }
620 .await;
621
622 self.state.cleanup_connection(connection_id).await;
623 self.state.cleanup_external_connection(connection_id).await;
624 debug!(connection_id, "rpc server connection cleanup completed");
625 self.state
626 .record_metric(RpcServerMetricEvent::ConnectionEnded {
627 connection_id,
628 success: result.is_ok(),
629 });
630 if let Err(error) = &result {
631 warn!(
632 connection_id,
633 error_code = error.code.as_i32(),
634 error_kind = error.kind.as_u8(),
635 error_message = %error.message,
636 "rpc server connection ended with error"
637 );
638 } else {
639 info!(connection_id, "rpc server connection ended");
640 }
641 result
642 }
643
644 pub async fn serve_listener<L>(&self, mut listener: L) -> Result<(), RuntimeError>
645 where
646 L: RpcListener + Send,
647 {
648 listener.set_connection_scope(self.state.security.connection_scope);
649 loop {
650 let connection = match listener.accept().await {
651 Ok(connection) => connection,
652 Err(err) => {
653 let access_denied = is_transport_access_denied(&err);
654 let error = RuntimeError::transport(
655 if access_denied {
656 RuntimeErrorCode::AccessDenied
657 } else {
658 RuntimeErrorCode::InternalRuntimeError
659 },
660 err.to_string(),
661 );
662 if access_denied {
663 self.state.record_metric(
664 RpcServerMetricEvent::ListenerConnectionRejected {
665 error_code: RuntimeErrorCode::AccessDenied,
666 },
667 );
668 warn!(
669 error_code = error.code.as_i32(),
670 error_kind = error.kind.as_u8(),
671 error_message = %error.message,
672 "rpc server listener rejected connection"
673 );
674 continue;
675 }
676 error!(
677 error_code = error.code.as_i32(),
678 error_kind = error.kind.as_u8(),
679 error_message = %error.message,
680 "rpc server listener accept failed"
681 );
682 return Err(error);
683 }
684 };
685 let server = self.clone();
686 tokio::spawn(async move {
687 if let Err(error) = server.serve_connection(connection).await {
688 warn!(
689 error_code = error.code.as_i32(),
690 error_kind = error.kind.as_u8(),
691 error_message = %error.message,
692 "rpc server listener connection task failed"
693 );
694 }
695 });
696 }
697 }
698
699 pub fn spawn_listener<L>(
700 &self,
701 listener: L,
702 ) -> tokio::task::JoinHandle<Result<(), RuntimeError>>
703 where
704 L: RpcListener + Send + 'static,
705 {
706 let server = self.clone();
707 tokio::spawn(async move { server.serve_listener(listener).await })
708 }
709
710 async fn perform_handshake(
711 &self,
712 connection_id: u64,
713 sender: &RpcSender,
714 receiver: &mut RpcReceiver,
715 ) -> Result<(), RuntimeError> {
716 let Some(envelope) = receiver.recv_envelope().await.map_err(|err| {
717 let error =
718 RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string());
719 warn!(
720 connection_id,
721 error_code = error.code.as_i32(),
722 error_kind = error.kind.as_u8(),
723 error_message = %error.message,
724 "rpc server handshake receive failed"
725 );
726 error
727 })?
728 else {
729 let error = RuntimeError::transport(
730 RuntimeErrorCode::InternalRuntimeError,
731 "client disconnected during handshake",
732 );
733 warn!(
734 connection_id,
735 error_code = error.code.as_i32(),
736 error_kind = error.kind.as_u8(),
737 error_message = %error.message,
738 "rpc server handshake disconnected"
739 );
740 return Err(error);
741 };
742 let Envelope::Hello(hello) = envelope else {
743 let error = RuntimeError::protocol(
744 RuntimeErrorCode::InvalidEnvelope,
745 "expected HELLO during handshake",
746 );
747 warn!(
748 connection_id,
749 envelope_kind = envelope_name(&envelope),
750 error_code = error.code.as_i32(),
751 error_kind = error.kind.as_u8(),
752 error_message = %error.message,
753 "rpc server handshake received invalid envelope"
754 );
755 return Err(error);
756 };
757 if hello.protocol_version != RUNTIME_PROTOCOL_VERSION || hello.role != Role::Client {
758 let error = RuntimeError::protocol(
759 RuntimeErrorCode::UnsupportedProtocolVersion,
760 "unsupported client handshake",
761 );
762 warn!(
763 connection_id,
764 protocol_version = hello.protocol_version,
765 role = ?hello.role,
766 capability_bits = hello.capability_bits.bits(),
767 max_message_size = hello.max_message_size,
768 error_code = error.code.as_i32(),
769 error_kind = error.kind.as_u8(),
770 error_message = %error.message,
771 "rpc server handshake rejected"
772 );
773 return Err(error);
774 }
775 self.validate_handshake_auth(connection_id, &hello.options)?;
776 sender
777 .send_envelope(&Envelope::HelloAck(HelloAck {
778 protocol_version: RUNTIME_PROTOCOL_VERSION,
779 accepted_capability_bits: server_capabilities() & hello.capability_bits,
780 max_message_size: hello.max_message_size,
781 options: Vec::new(),
782 }))
783 .await
784 .map_err(|err| {
785 let error = RuntimeError::transport(
786 RuntimeErrorCode::InternalRuntimeError,
787 err.to_string(),
788 );
789 warn!(
790 connection_id,
791 error_code = error.code.as_i32(),
792 error_kind = error.kind.as_u8(),
793 error_message = %error.message,
794 "rpc server handshake ack send failed"
795 );
796 error
797 })?;
798 info!(
799 connection_id,
800 protocol_version = hello.protocol_version,
801 accepted_capability_bits = (server_capabilities() & hello.capability_bits).bits(),
802 max_message_size = hello.max_message_size,
803 "rpc server handshake completed"
804 );
805 Ok(())
806 }
807
808 fn validate_handshake_auth(
809 &self,
810 connection_id: u64,
811 options: &Options,
812 ) -> Result<(), RuntimeError> {
813 let RpcServerAuthConfig::Token { token, option_key } = &self.state.security.auth else {
814 return Ok(());
815 };
816
817 let value = options
818 .iter()
819 .rev()
820 .find_map(|(key, value)| (key == option_key).then_some(value));
821 let Some(value) = value else {
822 let error = RuntimeError::protocol(
823 RuntimeErrorCode::AccessDenied,
824 "missing handshake authentication token",
825 );
826 warn!(
827 connection_id,
828 auth_option_key = %option_key,
829 error_code = error.code.as_i32(),
830 error_kind = error.kind.as_u8(),
831 error_message = %error.message,
832 "rpc server handshake rejected authentication"
833 );
834 return Err(error);
835 };
836 let Some(received) = value.as_str() else {
837 let error = RuntimeError::protocol(
838 RuntimeErrorCode::AccessDenied,
839 "handshake authentication token must be a string",
840 );
841 warn!(
842 connection_id,
843 auth_option_key = %option_key,
844 error_code = error.code.as_i32(),
845 error_kind = error.kind.as_u8(),
846 error_message = %error.message,
847 "rpc server handshake rejected authentication"
848 );
849 return Err(error);
850 };
851 if received != token {
852 let error = RuntimeError::protocol(
853 RuntimeErrorCode::AccessDenied,
854 "invalid handshake authentication token",
855 );
856 warn!(
857 connection_id,
858 auth_option_key = %option_key,
859 error_code = error.code.as_i32(),
860 error_kind = error.kind.as_u8(),
861 error_message = %error.message,
862 "rpc server handshake rejected authentication"
863 );
864 return Err(error);
865 }
866 debug!(
867 connection_id,
868 auth_option_key = %option_key,
869 "rpc server handshake authentication accepted"
870 );
871 Ok(())
872 }
873
874 pub async fn list_instances(&self) -> Vec<InstanceDescriptor> {
875 self.state.list_instances(None).await
876 }
877}
878
/// Runs one request end to end: logs it, dispatches it, records outcome metrics, and sends the
/// response envelope back on `sender`.
///
/// Emission order:
/// 1. `RequestStarted`, plus an optional trace-level payload preview.
/// 2. On success, `RequestSlow` when elapsed >= `slow_call_threshold`, then
///    `RequestCompleted`. On failure, `RequestFailed`.
/// 3. `ResponseSendFailed` if the response envelope could not be written.
///
/// NOTE(review): the slow-call check only runs on the success path, so slow *failing* requests
/// are never reported as slow. Confirm this is intended.
async fn handle_request(
    state: Arc<ServerState>,
    sender: RpcSender,
    connection_id: u64,
    request: Request,
    observability: RpcServerObservabilityConfig,
) {
    let request_id = request.request_id;
    let instance_id = request.instance_id;
    let method_id = request.method_id;
    // Requests addressed to the reserved activation instance id go to the activation service.
    let is_activation = instance_id.get() == ACTIVATION_INSTANCE_ID_VALUE;
    // Rendered up front because `request` (and its payload) is moved into dispatch below.
    let payload_preview = payload_preview(&request.payload, observability);

    debug!(
        connection_id,
        request_id = request_id.get(),
        instance_id = instance_id.get(),
        method_id = method_id.get(),
        is_activation,
        "rpc server request received"
    );
    state.record_metric(RpcServerMetricEvent::RequestStarted {
        connection_id,
        request_id,
        instance_id,
        method_id,
        is_activation,
    });
    if let Some(payload_preview) = payload_preview {
        trace!(
            connection_id,
            request_id = request_id.get(),
            payload_preview,
            "rpc server request payload preview"
        );
    }

    // `elapsed` measures dispatch only; sending the response is not included.
    let started = Instant::now();
    let response = dispatch_request(state.clone(), sender.clone(), connection_id, request).await;
    let elapsed = started.elapsed();
    let elapsed_ms = elapsed.as_secs_f64() * 1000.0;

    let envelope = match response {
        Ok(payload) => {
            if elapsed >= observability.slow_call_threshold {
                state.record_metric(RpcServerMetricEvent::RequestSlow {
                    connection_id,
                    request_id,
                    instance_id,
                    method_id,
                    is_activation,
                    elapsed,
                    threshold: observability.slow_call_threshold,
                });
                warn!(
                    connection_id,
                    request_id = request_id.get(),
                    instance_id = instance_id.get(),
                    method_id = method_id.get(),
                    is_activation,
                    elapsed_ms,
                    slow_call_threshold_ms =
                        observability.slow_call_threshold.as_secs_f64() * 1000.0,
                    "rpc server request completed slowly"
                );
            } else {
                info!(
                    connection_id,
                    request_id = request_id.get(),
                    instance_id = instance_id.get(),
                    method_id = method_id.get(),
                    is_activation,
                    elapsed_ms,
                    "rpc server request completed"
                );
            }
            state.record_metric(RpcServerMetricEvent::RequestCompleted {
                connection_id,
                request_id,
                instance_id,
                method_id,
                is_activation,
                elapsed,
            });
            Envelope::ResponseOk(ResponseOk {
                request_id,
                payload,
            })
        }
        Err(error) => {
            state.record_metric(RpcServerMetricEvent::RequestFailed {
                connection_id,
                request_id,
                instance_id,
                method_id,
                is_activation,
                elapsed,
                error_code: error.code,
            });
            warn!(
                connection_id,
                request_id = request_id.get(),
                instance_id = instance_id.get(),
                method_id = method_id.get(),
                is_activation,
                elapsed_ms,
                error_code = error.code.as_i32(),
                error_kind = error.kind.as_u8(),
                error_message = %error.message,
                "rpc server request failed"
            );
            runtime_error_response(request_id, error)
        }
    };

    // A send failure is logged and counted only; there is no retry.
    if let Err(err) = sender.send_envelope(&envelope).await {
        state.record_metric(RpcServerMetricEvent::ResponseSendFailed {
            connection_id,
            request_id,
        });
        error!(
            connection_id,
            request_id = request_id.get(),
            error = %err,
            "rpc server failed to send response"
        );
    } else {
        trace!(
            connection_id,
            request_id = request_id.get(),
            response_kind = envelope_name(&envelope),
            "rpc server response sent"
        );
    }
}
1014
1015async fn dispatch_request(
1016 state: Arc<ServerState>,
1017 sender: RpcSender,
1018 connection_id: u64,
1019 request: Request,
1020) -> Result<Value, RuntimeError> {
1021 if request.instance_id.get() == ACTIVATION_INSTANCE_ID_VALUE {
1022 return dispatch_activation(state, sender, connection_id, request).await;
1023 }
1024
1025 let instance = state.get_instance(request.instance_id).await?;
1026 if !instance.methods.contains(&request.method_id.get()) {
1027 return Err(RuntimeError::runtime(
1028 RuntimeErrorCode::MethodNotFound,
1029 format!("method id `{}` was not found", request.method_id.get()),
1030 ));
1031 }
1032 let ctx = RpcCallContext {
1033 connection_id,
1034 instance_id: request.instance_id,
1035 sender,
1036 };
1037 let call = panic::catch_unwind(AssertUnwindSafe(|| {
1038 instance
1039 .handler
1040 .call(ctx, request.method_id, request.payload)
1041 }))
1042 .map_err(|_| {
1043 RuntimeError::runtime(
1044 RuntimeErrorCode::InternalRuntimeError,
1045 "service handler panicked before returning a future",
1046 )
1047 })?;
1048
1049 AssertUnwindSafe(call).catch_unwind().await.map_err(|_| {
1050 RuntimeError::runtime(
1051 RuntimeErrorCode::InternalRuntimeError,
1052 "service handler panicked",
1053 )
1054 })?
1055}
1056
1057async fn dispatch_activation(
1058 state: Arc<ServerState>,
1059 sender: RpcSender,
1060 connection_id: u64,
1061 request: Request,
1062) -> Result<Value, RuntimeError> {
1063 let ctx = RpcCallContext {
1064 connection_id,
1065 instance_id: request.instance_id,
1066 sender,
1067 };
1068 match request.method_id.get() {
1069 RESOLVE_INSTANCE_IDS_METHOD_ID => {
1070 let request = decode_resolve_instance_ids_request(&request.payload)?;
1071 let ids = state.resolve_instance_ids(&request.instance_names).await;
1072 Ok(encode_resolve_instance_ids_response(
1073 &ResolveInstanceIdsResponse { instance_ids: ids },
1074 ))
1075 }
1076 CREATE_INSTANCE_METHOD_ID => {
1077 let request = decode_create_instance_request(&request.payload)?;
1078 let factory = state.get_factory(request.service_guid).ok_or_else(|| {
1079 RuntimeError::runtime(
1080 RuntimeErrorCode::ServiceGuidNotFound,
1081 "service factory was not found",
1082 )
1083 })?;
1084 let handler = factory
1085 .factory
1086 .create(ctx, request.create_payload, request.options)
1087 .await?;
1088 let instance_id = state
1089 .insert_client_instance(
1090 request.service_guid,
1091 connection_id,
1092 factory.methods.clone(),
1093 handler,
1094 )
1095 .await;
1096 Ok(encode_create_instance_response(&CreateInstanceResponse {
1097 instance_id,
1098 }))
1099 }
1100 RELEASE_INSTANCE_METHOD_ID => {
1101 let request = decode_release_instance_request(&request.payload)?;
1102 state
1103 .release_instance(connection_id, request.instance_id)
1104 .await?;
1105 Ok(encode_release_instance_response(&ReleaseInstanceResponse))
1106 }
1107 LIST_INSTANCES_METHOD_ID => {
1108 let request = decode_list_instances_request(&request.payload)?;
1109 let instances = state.list_instances(request.service_guid).await;
1110 Ok(encode_list_instances_response(&ListInstancesResponse {
1111 instances,
1112 }))
1113 }
1114 _ => Err(RuntimeError::runtime(
1115 RuntimeErrorCode::MethodNotFound,
1116 "activation method was not found",
1117 )),
1118 }
1119}
1120
1121fn runtime_error_response(request_id: RequestId, error: RuntimeError) -> Envelope {
1122 Envelope::ResponseError(ResponseError {
1123 request_id,
1124 error_code: error.code.as_i32(),
1125 error_kind: error.kind.as_u8(),
1126 error_message: Some(error.message),
1127 error_details: Value::Nil,
1128 })
1129}
1130
1131fn server_capabilities() -> CapabilityFlags {
1132 CapabilityFlags::SERVER_TO_CLIENT_NOTIFICATION
1133 | CapabilityFlags::NAMED_INSTANCE_RESOLUTION
1134 | CapabilityFlags::SERVICE_ACTIVATION
1135 | CapabilityFlags::GOODBYE
1136}
1137
1138fn envelope_name(envelope: &Envelope) -> &'static str {
1139 match envelope {
1140 Envelope::Hello(_) => "hello",
1141 Envelope::HelloAck(_) => "hello_ack",
1142 Envelope::Request(_) => "request",
1143 Envelope::ResponseOk(_) => "response_ok",
1144 Envelope::ResponseError(_) => "response_error",
1145 Envelope::Notification(_) => "notification",
1146 Envelope::Goodbye(_) => "goodbye",
1147 }
1148}
1149
1150fn payload_preview(payload: &Value, config: RpcServerObservabilityConfig) -> Option<String> {
1151 if !config.log_payload_preview || config.payload_preview_bytes == 0 {
1152 return None;
1153 }
1154 let mut preview = format!("{payload:?}");
1155 if preview.len() > config.payload_preview_bytes {
1156 preview.truncate(config.payload_preview_bytes);
1157 preview.push_str("...");
1158 }
1159 Some(preview)
1160}
1161
1162fn is_transport_access_denied(error: &TransportError) -> bool {
1163 matches!(
1164 error,
1165 TransportError::Runtime(error) if error.code == RuntimeErrorCode::AccessDenied
1166 )
1167}
1168
/// Converts `duration` to whole nanoseconds, saturating at `u64::MAX`.
///
/// `u64::MAX` nanoseconds is roughly 584 years. Saturation therefore only
/// affects pathological durations, but it avoids the silent wrap-around of a
/// bare `as` cast.
fn duration_nanos_u64(duration: Duration) -> u64 {
    u64::try_from(duration.as_nanos()).unwrap_or(u64::MAX)
}
1172
/// Atomically raises `value` to `candidate` if `candidate` is larger, keeping
/// a running maximum.
///
/// This delegates to the standard `AtomicU64::fetch_max` instead of a
/// hand-rolled compare-exchange loop. `Relaxed` ordering is enough because
/// only the stored number matters, not its ordering relative to other memory
/// operations.
fn update_atomic_max(value: &AtomicU64, candidate: u64) {
    value.fetch_max(candidate, Ordering::Relaxed);
}
1183
/// Atomically adds `increment` to `value`, clamping at `u64::MAX` instead of
/// wrapping.
///
/// `fetch_update` runs the same compare-exchange retry loop internally. Both
/// orderings are `Relaxed` because only the counter value matters.
fn saturating_atomic_add(value: &AtomicU64, increment: u64) {
    // The closure always returns `Some`, so this cannot fail; the previous
    // value reported back is not needed.
    let _ = value.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
        Some(current.saturating_add(increment))
    });
}
1194
/// Server-wide state: configuration, optional sinks, the instance registry,
/// the instance-name index and the service factory table.
///
/// Only `instances` and `names` sit behind (tokio) `RwLock`s. `instances` is
/// written at runtime by `insert_client_instance`, `release_instance` and
/// `cleanup_connection`. `factories` is a plain `HashMap` that is only read
/// here (`get_factory`), so it is presumably fully populated before serving
/// starts — confirm against the builder.
struct ServerState {
    // Counter for connection ids; `new` initializes it to 1.
    next_connection_id: AtomicU64,
    // Next raw id handed out (via `fetch_add`) by `insert_instance` and
    // `insert_client_instance`; `new` initializes it to 2.
    next_instance_id: AtomicU64,
    // Observability settings (slow-call threshold, payload previews).
    observability: RpcServerObservabilityConfig,
    // Connection-scope and authentication settings.
    security: RpcServerSecurityConfig,
    // Optional metrics sink; `record_metric` forwards events to it and
    // contains any panic raised by the sink.
    metrics_sink: Option<Arc<dyn RpcServerMetricsSink>>,
    // Optional hook that `cleanup_external_connection` awaits with the id of
    // the connection being cleaned up.
    connection_cleanup_sink: Option<Arc<dyn RpcConnectionCleanupSink>>,
    // Registered instances keyed by raw (`u64`) instance id.
    instances: RwLock<HashMap<u64, InstanceEntry>>,
    // Instance name -> raw instance id. Filled by `insert_instance` and
    // consulted by `resolve_instance_ids`.
    // NOTE(review): `release_instance` and `cleanup_connection` do not touch
    // this map, so a name can outlive the instance it points to.
    names: RwLock<HashMap<String, u64>>,
    // Service factories keyed by the `Uuid` of their service GUID.
    factories: HashMap<uuid::Uuid, FactoryEntry>,
}
1206
1207impl ServerState {
1208 fn new() -> Self {
1209 Self {
1210 next_connection_id: AtomicU64::new(1),
1211 next_instance_id: AtomicU64::new(2),
1212 observability: RpcServerObservabilityConfig::default(),
1213 security: RpcServerSecurityConfig::default(),
1214 metrics_sink: None,
1215 connection_cleanup_sink: None,
1216 instances: RwLock::new(HashMap::new()),
1217 names: RwLock::new(HashMap::new()),
1218 factories: HashMap::new(),
1219 }
1220 }
1221
1222 fn record_metric(&self, event: RpcServerMetricEvent) {
1223 let Some(sink) = &self.metrics_sink else {
1224 return;
1225 };
1226 let result = panic::catch_unwind(AssertUnwindSafe(|| sink.record(event)));
1227 if result.is_err() {
1228 error!("rpc server metrics sink panicked while recording event");
1229 }
1230 }
1231
1232 fn insert_activation_instance(&mut self) {
1233 self.instances.get_mut().insert(
1234 ACTIVATION_INSTANCE_ID_VALUE,
1235 InstanceEntry {
1236 instance_id: activation_instance_id(),
1237 service_guid: activation_service_guid(),
1238 instance_name: Some("rpc.runtime.Activation".to_string()),
1239 activation_mode: ActivationMode::Singleton,
1240 releasable: false,
1241 owner_connection_id: None,
1242 methods: vec![
1243 RESOLVE_INSTANCE_IDS_METHOD_ID,
1244 CREATE_INSTANCE_METHOD_ID,
1245 RELEASE_INSTANCE_METHOD_ID,
1246 LIST_INSTANCES_METHOD_ID,
1247 ],
1248 handler: Arc::new(ActivationMarker),
1249 },
1250 );
1251 }
1252
1253 fn insert_instance(&mut self, instance: NewInstance) -> InstanceId {
1254 let id = self.next_instance_id.fetch_add(1, Ordering::Relaxed);
1255 let instance_id = InstanceId::new(id).expect("generated instance id is non-zero");
1256 if let Some(name) = &instance.name {
1257 self.names.get_mut().insert(name.clone(), id);
1258 }
1259 self.instances.get_mut().insert(
1260 id,
1261 InstanceEntry {
1262 instance_id,
1263 service_guid: instance.service_guid,
1264 instance_name: instance.name,
1265 activation_mode: instance.activation_mode,
1266 releasable: instance.releasable,
1267 owner_connection_id: instance.owner_connection_id,
1268 methods: instance.methods,
1269 handler: instance.handler,
1270 },
1271 );
1272 instance_id
1273 }
1274
1275 async fn insert_client_instance(
1276 &self,
1277 service_guid: ServiceGuid,
1278 connection_id: u64,
1279 methods: Vec<u32>,
1280 handler: Arc<dyn RpcServiceHandler>,
1281 ) -> InstanceId {
1282 let id = self.next_instance_id.fetch_add(1, Ordering::Relaxed);
1283 let instance_id = InstanceId::new(id).expect("generated instance id is non-zero");
1284 self.instances.write().await.insert(
1285 id,
1286 InstanceEntry {
1287 instance_id,
1288 service_guid,
1289 instance_name: None,
1290 activation_mode: ActivationMode::Instantiable,
1291 releasable: true,
1292 owner_connection_id: Some(connection_id),
1293 methods,
1294 handler,
1295 },
1296 );
1297 instance_id
1298 }
1299
1300 async fn get_instance(&self, instance_id: InstanceId) -> Result<InstanceEntry, RuntimeError> {
1301 self.instances
1302 .read()
1303 .await
1304 .get(&instance_id.get())
1305 .cloned()
1306 .ok_or_else(|| {
1307 RuntimeError::runtime(RuntimeErrorCode::InstanceNotFound, "instance was not found")
1308 })
1309 }
1310
1311 fn get_factory(&self, service_guid: ServiceGuid) -> Option<FactoryEntry> {
1312 self.factories.get(&service_guid.get()).cloned()
1313 }
1314
1315 async fn resolve_instance_ids(&self, names: &[String]) -> Vec<u64> {
1316 let index = self.names.read().await;
1317 names
1318 .iter()
1319 .map(|name| index.get(name).copied().unwrap_or(0))
1320 .collect()
1321 }
1322
1323 async fn release_instance(
1324 &self,
1325 connection_id: u64,
1326 instance_id: InstanceId,
1327 ) -> Result<(), RuntimeError> {
1328 let mut instances = self.instances.write().await;
1329 let entry = instances.get(&instance_id.get()).ok_or_else(|| {
1330 RuntimeError::runtime(RuntimeErrorCode::InstanceNotFound, "instance was not found")
1331 })?;
1332 if !entry.releasable {
1333 return Err(RuntimeError::runtime(
1334 RuntimeErrorCode::InstanceReleaseNotAllowed,
1335 "instance is not releasable",
1336 ));
1337 }
1338 if entry.owner_connection_id != Some(connection_id) {
1339 return Err(RuntimeError::runtime(
1340 RuntimeErrorCode::AccessDenied,
1341 "instance is owned by another connection",
1342 ));
1343 }
1344 instances.remove(&instance_id.get());
1345 Ok(())
1346 }
1347
1348 async fn cleanup_connection(&self, connection_id: u64) {
1349 self.instances
1350 .write()
1351 .await
1352 .retain(|_, entry| entry.owner_connection_id != Some(connection_id));
1353 }
1354
1355 async fn cleanup_external_connection(&self, connection_id: u64) {
1356 let Some(sink) = &self.connection_cleanup_sink else {
1357 return;
1358 };
1359 sink.cleanup_connection(connection_id).await;
1360 }
1361
1362 async fn list_instances(&self, service_guid: Option<ServiceGuid>) -> Vec<InstanceDescriptor> {
1363 let mut values = self
1364 .instances
1365 .read()
1366 .await
1367 .values()
1368 .filter(|entry| service_guid.is_none_or(|guid| guid == entry.service_guid))
1369 .map(InstanceEntry::descriptor)
1370 .collect::<Vec<_>>();
1371 values.sort_by_key(|entry| entry.instance_id.get());
1372 values
1373 }
1374}
1375
/// Registration parameters consumed by `ServerState::insert_instance`.
///
/// The instance id is not part of this struct; it is allocated at insertion
/// time.
struct NewInstance {
    // Service this instance implements.
    service_guid: ServiceGuid,
    // Optional name; when present it is indexed for `resolve_instance_ids`.
    name: Option<String>,
    activation_mode: ActivationMode,
    // Whether `release_instance` may remove the instance.
    releasable: bool,
    // Owning connection, if any; checked on release and used by
    // `cleanup_connection`.
    owner_connection_id: Option<u64>,
    // Method ids registered for this instance.
    methods: Vec<u32>,
    // Handler that serves calls to this instance.
    handler: Arc<dyn RpcServiceHandler>,
}
1385
/// A registered instance as stored in `ServerState::instances`.
///
/// Cloning is cheap for the handler (an `Arc` bump), but it copies the name
/// and the method list.
#[derive(Clone)]
struct InstanceEntry {
    instance_id: InstanceId,
    service_guid: ServiceGuid,
    instance_name: Option<String>,
    activation_mode: ActivationMode,
    // Whether `release_instance` may remove this entry.
    releasable: bool,
    // Owning connection, if any. `release_instance` requires the caller to be
    // the owner, and `cleanup_connection` drops entries owned by a
    // disconnecting connection.
    owner_connection_id: Option<u64>,
    // Method ids registered for this instance.
    methods: Vec<u32>,
    // Handler that serves calls to this instance.
    handler: Arc<dyn RpcServiceHandler>,
}
1397
1398impl InstanceEntry {
1399 fn descriptor(&self) -> InstanceDescriptor {
1400 InstanceDescriptor {
1401 instance_id: self.instance_id,
1402 instance_name: self.instance_name.clone(),
1403 service_guid: self.service_guid,
1404 activation_mode: self.activation_mode,
1405 releasable: self.releasable,
1406 }
1407 }
1408}
1409
/// A registered service factory.
///
/// `methods` is the method-id list given to every instance the factory
/// creates: the create-instance path passes a clone of it to
/// `insert_client_instance`.
#[derive(Clone)]
struct FactoryEntry {
    // Method ids copied into each created instance's entry.
    methods: Vec<u32>,
    // Factory invoked to build a handler for a new client instance.
    factory: Arc<dyn RpcServiceFactory>,
}
1415
1416struct ActivationMarker;
1417
1418impl RpcServiceHandler for ActivationMarker {
1419 fn call(&self, _: RpcCallContext, _: MethodId, _: Value) -> HandlerFuture {
1420 Box::pin(async {
1421 Err(RuntimeError::runtime(
1422 RuntimeErrorCode::InternalRuntimeError,
1423 "activation marker should not be dispatched directly",
1424 ))
1425 })
1426 }
1427}
1428
// Unit and in-memory integration tests for server configuration, metrics and
// handshake authentication. The integration tests run `serve_connection` over
// a `tokio::io::duplex` pipe wrapped in `IpcConnection`.
#[cfg(test)]
mod tests {
    use super::*;
    use rpc_runtime_core::{Goodbye, Hello, Request, Role};
    use rpc_runtime_transport_ipc::{FrameConfig, IpcConnection};
    use tokio::io::duplex;

    // Default observability config: 500ms slow-call threshold, payload
    // previews disabled and a zero preview budget.
    #[test]
    fn observability_defaults_are_safe() {
        let config = RpcServerObservabilityConfig::default();

        assert_eq!(config.slow_call_threshold, Duration::from_millis(500));
        assert_eq!(config.payload_preview_bytes, 0);
        assert!(!config.log_payload_preview);
    }

    // Previews are off by default. With a 5-byte budget the preview is cut
    // and ends in "...", so it is at most 8 bytes (5-byte budget plus the
    // 3-byte marker).
    #[test]
    fn payload_preview_is_opt_in_and_bounded() {
        let payload = Value::from("1234567890");

        assert_eq!(
            payload_preview(&payload, RpcServerObservabilityConfig::default()),
            None
        );
        let preview = payload_preview(
            &payload,
            RpcServerObservabilityConfig::default().with_payload_preview(5),
        )
        .expect("preview");
        assert!(preview.len() <= 8);
        assert!(preview.ends_with("..."));
    }

    // Feeds events straight into the recorder. Total latency sums the
    // completed and failed requests (3ms + 5ms); max keeps the larger (5ms).
    #[test]
    fn metrics_recorder_counts_events_and_latency() {
        let recorder = RpcServerMetricsRecorder::new();
        recorder.record(RpcServerMetricEvent::ConnectionStarted { connection_id: 1 });
        recorder.record(RpcServerMetricEvent::ConnectionEnded {
            connection_id: 1,
            success: true,
        });
        recorder.record(RpcServerMetricEvent::RequestCompleted {
            connection_id: 1,
            request_id: RequestId::new(7),
            instance_id: activation_instance_id(),
            method_id: MethodId::new(1),
            is_activation: true,
            elapsed: Duration::from_millis(3),
        });
        recorder.record(RpcServerMetricEvent::RequestFailed {
            connection_id: 1,
            request_id: RequestId::new(8),
            instance_id: activation_instance_id(),
            method_id: MethodId::new(2),
            is_activation: true,
            elapsed: Duration::from_millis(5),
            error_code: RuntimeErrorCode::InternalRuntimeError,
        });

        let snapshot = recorder.snapshot();
        assert_eq!(snapshot.connections_started, 1);
        assert_eq!(snapshot.connections_ended, 1);
        assert_eq!(snapshot.connections_ended_successfully, 1);
        assert_eq!(snapshot.requests_completed, 1);
        assert_eq!(snapshot.requests_failed, 1);
        assert_eq!(snapshot.request_elapsed_total, Duration::from_millis(8));
        assert_eq!(snapshot.request_elapsed_max, Duration::from_millis(5));
    }

    // Default security: local-only connections, no authentication.
    #[test]
    fn security_defaults_are_local_auth_disabled() {
        let config = RpcServerSecurityConfig::default();

        assert_eq!(config.connection_scope, ConnectionScope::LocalOnly);
        assert_eq!(config.auth, RpcServerAuthConfig::Disabled);
    }

    // A token-protected server answers a Hello carrying the matching token
    // with a HelloAck.
    #[tokio::test]
    async fn token_auth_accepts_matching_token() {
        let server = RpcServerBuilder::new()
            .security(RpcServerSecurityConfig::default().with_token("secret"))
            .build();

        let ack = run_handshake(server, vec![auth_option("secret")])
            .await
            .expect("handshake");

        assert!(matches!(ack, Envelope::HelloAck(_)));
    }

    // A Hello with no auth option is rejected with AccessDenied.
    #[tokio::test]
    async fn token_auth_rejects_missing_token() {
        let server = RpcServerBuilder::new()
            .security(RpcServerSecurityConfig::default().with_token("secret"))
            .build();

        let err = run_handshake(server, Vec::new())
            .await
            .expect_err("must reject");

        assert_eq!(err.code, RuntimeErrorCode::AccessDenied);
    }

    // A Hello with the wrong token is rejected with AccessDenied.
    #[tokio::test]
    async fn token_auth_rejects_wrong_token() {
        let server = RpcServerBuilder::new()
            .security(RpcServerSecurityConfig::default().with_token("secret"))
            .build();

        let err = run_handshake(server, vec![auth_option("wrong")])
            .await
            .expect_err("must reject");

        assert_eq!(err.code, RuntimeErrorCode::AccessDenied);
    }

    // A token option whose value is not a string is rejected with
    // AccessDenied.
    #[tokio::test]
    async fn token_auth_rejects_non_string_token() {
        let server = RpcServerBuilder::new()
            .security(RpcServerSecurityConfig::default().with_token("secret"))
            .build();

        let err = run_handshake(
            server,
            vec![(
                DEFAULT_AUTH_TOKEN_OPTION_KEY.to_string(),
                Value::from(123_u64),
            )],
        )
        .await
        .expect_err("must reject");

        assert_eq!(err.code, RuntimeErrorCode::AccessDenied);
    }

    // A rejected handshake still counts as a started and ended connection,
    // but not as a successful connection or a completed handshake.
    #[tokio::test]
    async fn metrics_recorder_observes_handshake_failure() {
        let recorder = Arc::new(RpcServerMetricsRecorder::new());
        let server = RpcServerBuilder::new()
            .metrics_sink(recorder.clone())
            .security(RpcServerSecurityConfig::default().with_token("secret"))
            .build();

        let err = run_handshake(server, Vec::new())
            .await
            .expect_err("must reject");
        assert_eq!(err.code, RuntimeErrorCode::AccessDenied);

        let snapshot = recorder.snapshot();
        assert_eq!(snapshot.connections_started, 1);
        assert_eq!(snapshot.connections_ended, 1);
        assert_eq!(snapshot.connections_ended_successfully, 0);
        assert_eq!(snapshot.handshakes_completed, 0);
        assert_eq!(snapshot.handshakes_failed, 1);
    }

    // End-to-end session: handshake, one succeeding call (method 1), one
    // failing call (method 2), then Goodbye. The slow-call threshold is zero,
    // and the assertions expect exactly one request counted as slow.
    #[tokio::test]
    async fn metrics_recorder_observes_success_failure_and_slow_requests() {
        let recorder = Arc::new(RpcServerMetricsRecorder::new());
        let mut builder = RpcServerBuilder::new()
            .metrics_sink(recorder.clone())
            .observability(
                RpcServerObservabilityConfig::default()
                    .with_slow_call_threshold(Duration::from_nanos(0)),
            );
        let instance_id = builder.register_named_instance(
            "metrics",
            activation_service_guid(),
            [1, 2],
            Arc::new(MetricsTestHandler),
        );
        let server = builder.build();
        let (client_stream, server_stream) = duplex(4096);
        let server_connection = IpcConnection::from_stream(server_stream, FrameConfig::default());
        let server_task =
            tokio::spawn(async move { server.serve_connection(server_connection).await });
        let client_connection = IpcConnection::from_stream(client_stream, FrameConfig::default());
        let (sender, mut receiver) = client_connection.split();

        send_hello(&sender).await;
        assert!(matches!(
            receiver.recv_envelope().await.expect("recv ack"),
            Some(Envelope::HelloAck(_))
        ));
        sender
            .send_envelope(&Envelope::Request(Request {
                request_id: RequestId::new(11),
                instance_id,
                method_id: MethodId::new(1),
                payload: Value::from("ok"),
            }))
            .await
            .expect("send success request");
        assert!(matches!(
            receiver.recv_envelope().await.expect("recv response"),
            Some(Envelope::ResponseOk(_))
        ));
        sender
            .send_envelope(&Envelope::Request(Request {
                request_id: RequestId::new(12),
                instance_id,
                method_id: MethodId::new(2),
                payload: Value::Nil,
            }))
            .await
            .expect("send failing request");
        assert!(matches!(
            receiver.recv_envelope().await.expect("recv error"),
            Some(Envelope::ResponseError(_))
        ));
        sender
            .send_envelope(&Envelope::Goodbye(Goodbye {
                reason_code: 0,
                message: Some("done".to_string()),
            }))
            .await
            .expect("send goodbye");
        // Close both client halves before awaiting the server task, presumably
        // so the server sees end-of-stream and `serve_connection` returns.
        drop(sender);
        drop(receiver);
        server_task.await.expect("server task").expect("serve");

        let snapshot = recorder.snapshot();
        assert_eq!(snapshot.connections_started, 1);
        assert_eq!(snapshot.connections_ended_successfully, 1);
        assert_eq!(snapshot.handshakes_completed, 1);
        assert_eq!(snapshot.requests_started, 2);
        assert_eq!(snapshot.requests_completed, 1);
        assert_eq!(snapshot.requests_failed, 1);
        assert_eq!(snapshot.requests_slow, 1);
        assert!(snapshot.request_elapsed_total > Duration::ZERO);
    }

    // Sends one Hello with `options` to `server` over an in-memory pipe.
    // Returns the first envelope received. If the server closed the stream
    // without replying, returns the error that `serve_connection` returned.
    async fn run_handshake(server: RpcServer, options: Options) -> Result<Envelope, RuntimeError> {
        let (client_stream, server_stream) = duplex(4096);
        let server_connection = IpcConnection::from_stream(server_stream, FrameConfig::default());
        let server_task =
            tokio::spawn(async move { server.serve_connection(server_connection).await });

        let client_connection = IpcConnection::from_stream(client_stream, FrameConfig::default());
        let (sender, mut receiver) = client_connection.split();
        sender
            .send_envelope(&hello_envelope(options))
            .await
            .expect("send hello");

        let envelope = receiver.recv_envelope().await;
        // Release the client side before joining the server task.
        drop(sender);
        drop(receiver);
        let server_result = server_task.await.expect("server task");
        match envelope.expect("recv hello ack") {
            Some(envelope) => Ok(envelope),
            None => Err(server_result.expect_err("server should return handshake error")),
        }
    }

    // Builds the Hello option that carries an auth token under the default
    // option key.
    fn auth_option(token: &str) -> (String, Value) {
        (
            DEFAULT_AUTH_TOKEN_OPTION_KEY.to_string(),
            Value::from(token),
        )
    }

    // Sends a Hello with no options.
    async fn send_hello(sender: &RpcSender) {
        sender
            .send_envelope(&hello_envelope(Vec::new()))
            .await
            .expect("send hello");
    }

    // Client Hello for the current protocol version: no capability bits, a
    // 16 MiB max message size, and the given options.
    fn hello_envelope(options: Options) -> Envelope {
        Envelope::Hello(Hello {
            protocol_version: RUNTIME_PROTOCOL_VERSION,
            role: Role::Client,
            capability_bits: CapabilityFlags::empty(),
            max_message_size: 16 * 1024 * 1024,
            options,
        })
    }

    // Test handler: method 1 echoes the payload; any other method fails with
    // InternalRuntimeError.
    struct MetricsTestHandler;

    impl RpcServiceHandler for MetricsTestHandler {
        fn call(&self, _: RpcCallContext, method_id: MethodId, payload: Value) -> HandlerFuture {
            Box::pin(async move {
                match method_id.get() {
                    1 => Ok(payload),
                    _ => Err(RuntimeError::runtime(
                        RuntimeErrorCode::InternalRuntimeError,
                        "test failure",
                    )),
                }
            })
        }
    }
}