Skip to main content

rpc_runtime_server/
lib.rs

1use std::collections::{BTreeMap, HashMap};
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::{Duration, Instant};
7use std::{panic, panic::AssertUnwindSafe};
8
9use rmpv::Value;
10use rpc_runtime_activation::{
11    ACTIVATION_INSTANCE_ID_VALUE, ActivationMode, CREATE_INSTANCE_METHOD_ID,
12    CreateInstanceResponse, InstanceDescriptor, LIST_INSTANCES_METHOD_ID, ListInstancesResponse,
13    RELEASE_INSTANCE_METHOD_ID, RESOLVE_INSTANCE_IDS_METHOD_ID, ReleaseInstanceResponse,
14    ResolveInstanceIdsResponse, activation_instance_id, activation_service_guid,
15    decode_create_instance_request, decode_list_instances_request, decode_release_instance_request,
16    decode_resolve_instance_ids_request, encode_create_instance_response,
17    encode_list_instances_response, encode_release_instance_response,
18    encode_resolve_instance_ids_response,
19};
20use rpc_runtime_core::{
21    CapabilityFlags, Envelope, HelloAck, InstanceId, MethodId, Notification, Options,
22    RUNTIME_PROTOCOL_VERSION, Request, RequestId, ResponseError, ResponseOk, Role, ServiceGuid,
23};
24use rpc_runtime_errors::{RuntimeError, RuntimeErrorCode};
25pub use rpc_runtime_transport::ConnectionScope;
26use rpc_runtime_transport::{RpcConnection, RpcListener, RpcReceiver, RpcSender, TransportError};
27use tokio::sync::RwLock;
28use tracing::{debug, error, info, trace, warn};
29
30pub type HandlerFuture = Pin<Box<dyn Future<Output = Result<Value, RuntimeError>> + Send>>;
31
32pub trait RpcServiceHandler: Send + Sync {
33    fn call(&self, ctx: RpcCallContext, method_id: MethodId, payload: Value) -> HandlerFuture;
34}
35
36impl<F> RpcServiceHandler for F
37where
38    F: Send + Sync + 'static,
39    F: Fn(RpcCallContext, MethodId, Value) -> HandlerFuture,
40{
41    fn call(&self, ctx: RpcCallContext, method_id: MethodId, payload: Value) -> HandlerFuture {
42        self(ctx, method_id, payload)
43    }
44}
45
46pub type FactoryFuture =
47    Pin<Box<dyn Future<Output = Result<Arc<dyn RpcServiceHandler>, RuntimeError>> + Send>>;
48
49pub trait RpcServiceFactory: Send + Sync {
50    fn create(
51        &self,
52        ctx: RpcCallContext,
53        create_payload: Option<Vec<u8>>,
54        options: BTreeMap<String, String>,
55    ) -> FactoryFuture;
56}
57
58impl<F> RpcServiceFactory for F
59where
60    F: Send + Sync + 'static,
61    F: Fn(RpcCallContext, Option<Vec<u8>>, BTreeMap<String, String>) -> FactoryFuture,
62{
63    fn create<'a>(
64        &self,
65        ctx: RpcCallContext,
66        create_payload: Option<Vec<u8>>,
67        options: BTreeMap<String, String>,
68    ) -> FactoryFuture {
69        self(ctx, create_payload, options)
70    }
71}
72
73#[derive(Clone)]
74pub struct RpcCallContext {
75    connection_id: u64,
76    instance_id: InstanceId,
77    sender: RpcSender,
78}
79
80impl RpcCallContext {
81    pub fn connection_id(&self) -> u64 {
82        self.connection_id
83    }
84
85    pub fn instance_id(&self) -> InstanceId {
86        self.instance_id
87    }
88
89    pub async fn notify(
90        &self,
91        instance_id: Option<InstanceId>,
92        notification_id: u32,
93        payload: Value,
94    ) -> Result<(), RuntimeError> {
95        self.sender
96            .send_envelope(&Envelope::Notification(Notification {
97                instance_id,
98                notification_id: rpc_runtime_core::NotificationId::new(notification_id),
99                payload,
100            }))
101            .await
102            .map_err(|err| {
103                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
104            })
105    }
106
107    pub async fn notify_bound(
108        &self,
109        notification_id: u32,
110        payload: Value,
111    ) -> Result<(), RuntimeError> {
112        self.notify(Some(self.instance_id), notification_id, payload)
113            .await
114    }
115}
116
117#[derive(Clone)]
118pub struct RpcServer {
119    state: Arc<ServerState>,
120}
121
122pub struct RpcServerBuilder {
123    state: ServerState,
124}
125
/// Option key used for the handshake authentication token when no custom key is configured.
pub const DEFAULT_AUTH_TOKEN_OPTION_KEY: &str = "tripley.auth.token";
127
128pub trait RpcServerMetricsSink: Send + Sync {
129    fn record(&self, event: RpcServerMetricEvent);
130}
131
132impl<F> RpcServerMetricsSink for F
133where
134    F: Send + Sync + 'static + Fn(RpcServerMetricEvent),
135{
136    fn record(&self, event: RpcServerMetricEvent) {
137        self(event);
138    }
139}
140
141#[derive(Debug, Clone, PartialEq)]
142pub enum RpcServerMetricEvent {
143    ConnectionStarted {
144        connection_id: u64,
145    },
146    ConnectionEnded {
147        connection_id: u64,
148        success: bool,
149    },
150    HandshakeCompleted {
151        connection_id: u64,
152    },
153    HandshakeFailed {
154        connection_id: u64,
155        error_code: RuntimeErrorCode,
156    },
157    ListenerConnectionRejected {
158        error_code: RuntimeErrorCode,
159    },
160    RequestStarted {
161        connection_id: u64,
162        request_id: RequestId,
163        instance_id: InstanceId,
164        method_id: MethodId,
165        is_activation: bool,
166    },
167    RequestCompleted {
168        connection_id: u64,
169        request_id: RequestId,
170        instance_id: InstanceId,
171        method_id: MethodId,
172        is_activation: bool,
173        elapsed: Duration,
174    },
175    RequestFailed {
176        connection_id: u64,
177        request_id: RequestId,
178        instance_id: InstanceId,
179        method_id: MethodId,
180        is_activation: bool,
181        elapsed: Duration,
182        error_code: RuntimeErrorCode,
183    },
184    RequestSlow {
185        connection_id: u64,
186        request_id: RequestId,
187        instance_id: InstanceId,
188        method_id: MethodId,
189        is_activation: bool,
190        elapsed: Duration,
191        threshold: Duration,
192    },
193    ResponseSendFailed {
194        connection_id: u64,
195        request_id: RequestId,
196    },
197}
198
/// Plain-value copy of the counters kept by a metrics recorder.
///
/// Every field defaults to zero (or a zero [`Duration`]).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct RpcServerMetricsSnapshot {
    /// Number of `ConnectionStarted` events.
    pub connections_started: u64,
    /// Number of `ConnectionEnded` events.
    pub connections_ended: u64,
    /// Number of `ConnectionEnded` events whose `success` flag was set.
    pub connections_ended_successfully: u64,
    /// Number of `HandshakeCompleted` events.
    pub handshakes_completed: u64,
    /// Number of `HandshakeFailed` events.
    pub handshakes_failed: u64,
    /// Number of `ListenerConnectionRejected` events.
    pub listener_connections_rejected: u64,
    /// Number of `RequestStarted` events.
    pub requests_started: u64,
    /// Number of `RequestCompleted` events.
    pub requests_completed: u64,
    /// Number of `RequestFailed` events.
    pub requests_failed: u64,
    /// Number of `RequestSlow` events.
    pub requests_slow: u64,
    /// Number of `ResponseSendFailed` events.
    pub response_send_failures: u64,
    /// Accumulated elapsed time of completed and failed requests.
    pub request_elapsed_total: Duration,
    /// Largest elapsed time seen among completed and failed requests.
    pub request_elapsed_max: Duration,
}
215
216#[derive(Debug, Default)]
217pub struct RpcServerMetricsRecorder {
218    connections_started: AtomicU64,
219    connections_ended: AtomicU64,
220    connections_ended_successfully: AtomicU64,
221    handshakes_completed: AtomicU64,
222    handshakes_failed: AtomicU64,
223    listener_connections_rejected: AtomicU64,
224    requests_started: AtomicU64,
225    requests_completed: AtomicU64,
226    requests_failed: AtomicU64,
227    requests_slow: AtomicU64,
228    response_send_failures: AtomicU64,
229    request_elapsed_total_nanos: AtomicU64,
230    request_elapsed_max_nanos: AtomicU64,
231}
232
233impl RpcServerMetricsRecorder {
234    pub fn new() -> Self {
235        Self::default()
236    }
237
238    pub fn snapshot(&self) -> RpcServerMetricsSnapshot {
239        RpcServerMetricsSnapshot {
240            connections_started: self.connections_started.load(Ordering::Relaxed),
241            connections_ended: self.connections_ended.load(Ordering::Relaxed),
242            connections_ended_successfully: self
243                .connections_ended_successfully
244                .load(Ordering::Relaxed),
245            handshakes_completed: self.handshakes_completed.load(Ordering::Relaxed),
246            handshakes_failed: self.handshakes_failed.load(Ordering::Relaxed),
247            listener_connections_rejected: self
248                .listener_connections_rejected
249                .load(Ordering::Relaxed),
250            requests_started: self.requests_started.load(Ordering::Relaxed),
251            requests_completed: self.requests_completed.load(Ordering::Relaxed),
252            requests_failed: self.requests_failed.load(Ordering::Relaxed),
253            requests_slow: self.requests_slow.load(Ordering::Relaxed),
254            response_send_failures: self.response_send_failures.load(Ordering::Relaxed),
255            request_elapsed_total: Duration::from_nanos(
256                self.request_elapsed_total_nanos.load(Ordering::Relaxed),
257            ),
258            request_elapsed_max: Duration::from_nanos(
259                self.request_elapsed_max_nanos.load(Ordering::Relaxed),
260            ),
261        }
262    }
263
264    fn record_elapsed(&self, elapsed: Duration) {
265        let nanos = duration_nanos_u64(elapsed);
266        saturating_atomic_add(&self.request_elapsed_total_nanos, nanos);
267        update_atomic_max(&self.request_elapsed_max_nanos, nanos);
268    }
269}
270
271impl RpcServerMetricsSink for RpcServerMetricsRecorder {
272    fn record(&self, event: RpcServerMetricEvent) {
273        match event {
274            RpcServerMetricEvent::ConnectionStarted { .. } => {
275                self.connections_started.fetch_add(1, Ordering::Relaxed);
276            }
277            RpcServerMetricEvent::ConnectionEnded { success, .. } => {
278                self.connections_ended.fetch_add(1, Ordering::Relaxed);
279                if success {
280                    self.connections_ended_successfully
281                        .fetch_add(1, Ordering::Relaxed);
282                }
283            }
284            RpcServerMetricEvent::HandshakeCompleted { .. } => {
285                self.handshakes_completed.fetch_add(1, Ordering::Relaxed);
286            }
287            RpcServerMetricEvent::HandshakeFailed { .. } => {
288                self.handshakes_failed.fetch_add(1, Ordering::Relaxed);
289            }
290            RpcServerMetricEvent::ListenerConnectionRejected { .. } => {
291                self.listener_connections_rejected
292                    .fetch_add(1, Ordering::Relaxed);
293            }
294            RpcServerMetricEvent::RequestStarted { .. } => {
295                self.requests_started.fetch_add(1, Ordering::Relaxed);
296            }
297            RpcServerMetricEvent::RequestCompleted { elapsed, .. } => {
298                self.requests_completed.fetch_add(1, Ordering::Relaxed);
299                self.record_elapsed(elapsed);
300            }
301            RpcServerMetricEvent::RequestFailed { elapsed, .. } => {
302                self.requests_failed.fetch_add(1, Ordering::Relaxed);
303                self.record_elapsed(elapsed);
304            }
305            RpcServerMetricEvent::RequestSlow { .. } => {
306                self.requests_slow.fetch_add(1, Ordering::Relaxed);
307            }
308            RpcServerMetricEvent::ResponseSendFailed { .. } => {
309                self.response_send_failures.fetch_add(1, Ordering::Relaxed);
310            }
311        }
312    }
313}
314
/// Logging and metrics knobs for request handling.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RpcServerObservabilityConfig {
    /// Elapsed time at or above which a request is reported as slow.
    pub slow_call_threshold: Duration,
    /// Byte budget for request payload previews.
    // NOTE(review): how the budget is applied lives in the preview helper; confirm there.
    pub payload_preview_bytes: usize,
    /// Whether request payload previews are logged at all.
    pub log_payload_preview: bool,
}
321
322#[derive(Debug, Clone, PartialEq, Eq)]
323pub struct RpcServerSecurityConfig {
324    pub connection_scope: ConnectionScope,
325    pub auth: RpcServerAuthConfig,
326}
327
328impl RpcServerSecurityConfig {
329    pub fn remote_allowed(mut self) -> Self {
330        self.connection_scope = ConnectionScope::RemoteAllowed;
331        self
332    }
333
334    pub fn local_only(mut self) -> Self {
335        self.connection_scope = ConnectionScope::LocalOnly;
336        self
337    }
338
339    pub fn with_token(mut self, token: impl Into<String>) -> Self {
340        self.auth = RpcServerAuthConfig::token(token);
341        self
342    }
343
344    pub fn with_auth(mut self, auth: RpcServerAuthConfig) -> Self {
345        self.auth = auth;
346        self
347    }
348}
349
350impl Default for RpcServerSecurityConfig {
351    fn default() -> Self {
352        Self {
353            connection_scope: ConnectionScope::LocalOnly,
354            auth: RpcServerAuthConfig::Disabled,
355        }
356    }
357}
358
/// Handshake authentication mode.
///
/// `Debug` is implemented by hand so that the secret token never appears in debug output
/// (for example when a containing config is logged with `{:?}`).
#[derive(Clone, PartialEq, Eq)]
pub enum RpcServerAuthConfig {
    /// No authentication is required during the handshake.
    Disabled,
    /// The client must present `token` under the handshake option named `option_key`.
    Token { token: String, option_key: String },
}

impl std::fmt::Debug for RpcServerAuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Disabled => f.write_str("Disabled"),
            // The option key is not secret; the token is, so only a placeholder is printed.
            Self::Token { option_key, .. } => f
                .debug_struct("Token")
                .field("token", &"<redacted>")
                .field("option_key", option_key)
                .finish(),
        }
    }
}
364
365impl RpcServerAuthConfig {
366    pub fn token(token: impl Into<String>) -> Self {
367        Self::Token {
368            token: token.into(),
369            option_key: DEFAULT_AUTH_TOKEN_OPTION_KEY.to_string(),
370        }
371    }
372
373    pub fn token_with_option_key(token: impl Into<String>, option_key: impl Into<String>) -> Self {
374        Self::Token {
375            token: token.into(),
376            option_key: option_key.into(),
377        }
378    }
379}
380
381impl RpcServerObservabilityConfig {
382    pub fn with_slow_call_threshold(mut self, threshold: Duration) -> Self {
383        self.slow_call_threshold = threshold;
384        self
385    }
386
387    pub fn with_payload_preview(mut self, bytes: usize) -> Self {
388        self.payload_preview_bytes = bytes;
389        self.log_payload_preview = bytes > 0;
390        self
391    }
392}
393
394impl Default for RpcServerObservabilityConfig {
395    fn default() -> Self {
396        Self {
397            slow_call_threshold: Duration::from_millis(500),
398            payload_preview_bytes: 0,
399            log_payload_preview: false,
400        }
401    }
402}
403
404impl RpcServerBuilder {
405    pub fn new() -> Self {
406        let mut state = ServerState::new();
407        state.insert_activation_instance();
408        Self { state }
409    }
410
411    pub fn observability(mut self, config: RpcServerObservabilityConfig) -> Self {
412        self.state.observability = config;
413        self
414    }
415
416    pub fn set_observability(&mut self, config: RpcServerObservabilityConfig) -> &mut Self {
417        self.state.observability = config;
418        self
419    }
420
421    pub fn metrics_sink(mut self, sink: Arc<dyn RpcServerMetricsSink>) -> Self {
422        self.state.metrics_sink = Some(sink);
423        self
424    }
425
426    pub fn set_metrics_sink(&mut self, sink: Arc<dyn RpcServerMetricsSink>) -> &mut Self {
427        self.state.metrics_sink = Some(sink);
428        self
429    }
430
431    pub fn security(mut self, config: RpcServerSecurityConfig) -> Self {
432        self.state.security = config;
433        self
434    }
435
436    pub fn set_security(&mut self, config: RpcServerSecurityConfig) -> &mut Self {
437        self.state.security = config;
438        self
439    }
440
441    pub fn register_named_instance(
442        &mut self,
443        name: impl Into<String>,
444        service_guid: ServiceGuid,
445        methods: impl IntoIterator<Item = u32>,
446        handler: Arc<dyn RpcServiceHandler>,
447    ) -> InstanceId {
448        self.state.insert_instance(NewInstance {
449            service_guid,
450            name: Some(name.into()),
451            activation_mode: ActivationMode::NamedPrecreated,
452            releasable: false,
453            owner_connection_id: None,
454            methods: methods.into_iter().collect(),
455            handler,
456        })
457    }
458
459    pub fn register_singleton(
460        &mut self,
461        service_guid: ServiceGuid,
462        methods: impl IntoIterator<Item = u32>,
463        handler: Arc<dyn RpcServiceHandler>,
464    ) -> InstanceId {
465        self.state.insert_instance(NewInstance {
466            service_guid,
467            name: None,
468            activation_mode: ActivationMode::Singleton,
469            releasable: false,
470            owner_connection_id: None,
471            methods: methods.into_iter().collect(),
472            handler,
473        })
474    }
475
476    pub fn register_factory(
477        &mut self,
478        service_guid: ServiceGuid,
479        methods: impl IntoIterator<Item = u32>,
480        factory: Arc<dyn RpcServiceFactory>,
481    ) {
482        self.state.factories.insert(
483            service_guid.get(),
484            FactoryEntry {
485                methods: methods.into_iter().collect(),
486                factory,
487            },
488        );
489    }
490
491    pub fn build(self) -> RpcServer {
492        if self.state.security.connection_scope == ConnectionScope::RemoteAllowed
493            && self.state.security.auth == RpcServerAuthConfig::Disabled
494        {
495            warn!("rpc server allows remote connections without token authentication");
496        }
497        RpcServer {
498            state: Arc::new(self.state),
499        }
500    }
501}
502
503impl Default for RpcServerBuilder {
504    fn default() -> Self {
505        Self::new()
506    }
507}
508
509impl RpcServer {
510    pub async fn serve_connection<C>(&self, connection: C) -> Result<(), RuntimeError>
511    where
512        C: Into<RpcConnection>,
513    {
514        let connection_id = self
515            .state
516            .next_connection_id
517            .fetch_add(1, Ordering::Relaxed);
518        self.state
519            .record_metric(RpcServerMetricEvent::ConnectionStarted { connection_id });
520        info!(connection_id, "rpc server connection started");
521        let (sender, mut receiver) = connection.into().split();
522
523        let result = async {
524            if let Err(error) = self
525                .perform_handshake(connection_id, &sender, &mut receiver)
526                .await
527            {
528                self.state
529                    .record_metric(RpcServerMetricEvent::HandshakeFailed {
530                        connection_id,
531                        error_code: error.code,
532                    });
533                return Err(error);
534            }
535            self.state
536                .record_metric(RpcServerMetricEvent::HandshakeCompleted { connection_id });
537
538            loop {
539                let envelope = match receiver.recv_envelope().await {
540                    Ok(Some(envelope)) => envelope,
541                    Ok(None) => {
542                        debug!(connection_id, "rpc server connection closed by peer");
543                        break;
544                    }
545                    Err(err) => {
546                        let error = RuntimeError::transport(
547                            RuntimeErrorCode::InternalRuntimeError,
548                            err.to_string(),
549                        );
550                        warn!(
551                            connection_id,
552                            error_code = error.code.as_i32(),
553                            error_kind = error.kind.as_u8(),
554                            error_message = %error.message,
555                            "rpc server failed to receive envelope"
556                        );
557                        return Err(error);
558                    }
559                };
560
561                match envelope {
562                    Envelope::Request(request) => {
563                        let state = Arc::clone(&self.state);
564                        let sender = sender.clone();
565                        let observability = self.state.observability;
566                        tokio::spawn(async move {
567                            handle_request(state, sender, connection_id, request, observability)
568                                .await;
569                        });
570                    }
571                    Envelope::Goodbye(goodbye) => {
572                        info!(
573                            connection_id,
574                            reason_code = goodbye.reason_code,
575                            message = goodbye.message.as_deref().unwrap_or(""),
576                            "rpc server received goodbye"
577                        );
578                        break;
579                    }
580                    envelope => {
581                        let error = RuntimeError::protocol(
582                            RuntimeErrorCode::InvalidEnvelope,
583                            "server expected request envelope",
584                        );
585                        warn!(
586                            connection_id,
587                            envelope_kind = envelope_name(&envelope),
588                            error_code = error.code.as_i32(),
589                            error_kind = error.kind.as_u8(),
590                            error_message = %error.message,
591                            "rpc server received invalid envelope"
592                        );
593                        return Err(error);
594                    }
595                }
596            }
597
598            Ok(())
599        }
600        .await;
601
602        self.state.cleanup_connection(connection_id).await;
603        debug!(connection_id, "rpc server connection cleanup completed");
604        self.state
605            .record_metric(RpcServerMetricEvent::ConnectionEnded {
606                connection_id,
607                success: result.is_ok(),
608            });
609        if let Err(error) = &result {
610            warn!(
611                connection_id,
612                error_code = error.code.as_i32(),
613                error_kind = error.kind.as_u8(),
614                error_message = %error.message,
615                "rpc server connection ended with error"
616            );
617        } else {
618            info!(connection_id, "rpc server connection ended");
619        }
620        result
621    }
622
623    pub async fn serve_listener<L>(&self, mut listener: L) -> Result<(), RuntimeError>
624    where
625        L: RpcListener + Send,
626    {
627        listener.set_connection_scope(self.state.security.connection_scope);
628        loop {
629            let connection = match listener.accept().await {
630                Ok(connection) => connection,
631                Err(err) => {
632                    let access_denied = is_transport_access_denied(&err);
633                    let error = RuntimeError::transport(
634                        if access_denied {
635                            RuntimeErrorCode::AccessDenied
636                        } else {
637                            RuntimeErrorCode::InternalRuntimeError
638                        },
639                        err.to_string(),
640                    );
641                    if access_denied {
642                        self.state.record_metric(
643                            RpcServerMetricEvent::ListenerConnectionRejected {
644                                error_code: RuntimeErrorCode::AccessDenied,
645                            },
646                        );
647                        warn!(
648                            error_code = error.code.as_i32(),
649                            error_kind = error.kind.as_u8(),
650                            error_message = %error.message,
651                            "rpc server listener rejected connection"
652                        );
653                        continue;
654                    }
655                    error!(
656                        error_code = error.code.as_i32(),
657                        error_kind = error.kind.as_u8(),
658                        error_message = %error.message,
659                        "rpc server listener accept failed"
660                    );
661                    return Err(error);
662                }
663            };
664            let server = self.clone();
665            tokio::spawn(async move {
666                if let Err(error) = server.serve_connection(connection).await {
667                    warn!(
668                        error_code = error.code.as_i32(),
669                        error_kind = error.kind.as_u8(),
670                        error_message = %error.message,
671                        "rpc server listener connection task failed"
672                    );
673                }
674            });
675        }
676    }
677
678    pub fn spawn_listener<L>(
679        &self,
680        listener: L,
681    ) -> tokio::task::JoinHandle<Result<(), RuntimeError>>
682    where
683        L: RpcListener + Send + 'static,
684    {
685        let server = self.clone();
686        tokio::spawn(async move { server.serve_listener(listener).await })
687    }
688
689    async fn perform_handshake(
690        &self,
691        connection_id: u64,
692        sender: &RpcSender,
693        receiver: &mut RpcReceiver,
694    ) -> Result<(), RuntimeError> {
695        let Some(envelope) = receiver.recv_envelope().await.map_err(|err| {
696            let error =
697                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string());
698            warn!(
699                connection_id,
700                error_code = error.code.as_i32(),
701                error_kind = error.kind.as_u8(),
702                error_message = %error.message,
703                "rpc server handshake receive failed"
704            );
705            error
706        })?
707        else {
708            let error = RuntimeError::transport(
709                RuntimeErrorCode::InternalRuntimeError,
710                "client disconnected during handshake",
711            );
712            warn!(
713                connection_id,
714                error_code = error.code.as_i32(),
715                error_kind = error.kind.as_u8(),
716                error_message = %error.message,
717                "rpc server handshake disconnected"
718            );
719            return Err(error);
720        };
721        let Envelope::Hello(hello) = envelope else {
722            let error = RuntimeError::protocol(
723                RuntimeErrorCode::InvalidEnvelope,
724                "expected HELLO during handshake",
725            );
726            warn!(
727                connection_id,
728                envelope_kind = envelope_name(&envelope),
729                error_code = error.code.as_i32(),
730                error_kind = error.kind.as_u8(),
731                error_message = %error.message,
732                "rpc server handshake received invalid envelope"
733            );
734            return Err(error);
735        };
736        if hello.protocol_version != RUNTIME_PROTOCOL_VERSION || hello.role != Role::Client {
737            let error = RuntimeError::protocol(
738                RuntimeErrorCode::UnsupportedProtocolVersion,
739                "unsupported client handshake",
740            );
741            warn!(
742                connection_id,
743                protocol_version = hello.protocol_version,
744                role = ?hello.role,
745                capability_bits = hello.capability_bits.bits(),
746                max_message_size = hello.max_message_size,
747                error_code = error.code.as_i32(),
748                error_kind = error.kind.as_u8(),
749                error_message = %error.message,
750                "rpc server handshake rejected"
751            );
752            return Err(error);
753        }
754        self.validate_handshake_auth(connection_id, &hello.options)?;
755        sender
756            .send_envelope(&Envelope::HelloAck(HelloAck {
757                protocol_version: RUNTIME_PROTOCOL_VERSION,
758                accepted_capability_bits: server_capabilities() & hello.capability_bits,
759                max_message_size: hello.max_message_size,
760                options: Vec::new(),
761            }))
762            .await
763            .map_err(|err| {
764                let error = RuntimeError::transport(
765                    RuntimeErrorCode::InternalRuntimeError,
766                    err.to_string(),
767                );
768                warn!(
769                    connection_id,
770                    error_code = error.code.as_i32(),
771                    error_kind = error.kind.as_u8(),
772                    error_message = %error.message,
773                    "rpc server handshake ack send failed"
774                );
775                error
776            })?;
777        info!(
778            connection_id,
779            protocol_version = hello.protocol_version,
780            accepted_capability_bits = (server_capabilities() & hello.capability_bits).bits(),
781            max_message_size = hello.max_message_size,
782            "rpc server handshake completed"
783        );
784        Ok(())
785    }
786
787    fn validate_handshake_auth(
788        &self,
789        connection_id: u64,
790        options: &Options,
791    ) -> Result<(), RuntimeError> {
792        let RpcServerAuthConfig::Token { token, option_key } = &self.state.security.auth else {
793            return Ok(());
794        };
795
796        let value = options
797            .iter()
798            .rev()
799            .find_map(|(key, value)| (key == option_key).then_some(value));
800        let Some(value) = value else {
801            let error = RuntimeError::protocol(
802                RuntimeErrorCode::AccessDenied,
803                "missing handshake authentication token",
804            );
805            warn!(
806                connection_id,
807                auth_option_key = %option_key,
808                error_code = error.code.as_i32(),
809                error_kind = error.kind.as_u8(),
810                error_message = %error.message,
811                "rpc server handshake rejected authentication"
812            );
813            return Err(error);
814        };
815        let Some(received) = value.as_str() else {
816            let error = RuntimeError::protocol(
817                RuntimeErrorCode::AccessDenied,
818                "handshake authentication token must be a string",
819            );
820            warn!(
821                connection_id,
822                auth_option_key = %option_key,
823                error_code = error.code.as_i32(),
824                error_kind = error.kind.as_u8(),
825                error_message = %error.message,
826                "rpc server handshake rejected authentication"
827            );
828            return Err(error);
829        };
830        if received != token {
831            let error = RuntimeError::protocol(
832                RuntimeErrorCode::AccessDenied,
833                "invalid handshake authentication token",
834            );
835            warn!(
836                connection_id,
837                auth_option_key = %option_key,
838                error_code = error.code.as_i32(),
839                error_kind = error.kind.as_u8(),
840                error_message = %error.message,
841                "rpc server handshake rejected authentication"
842            );
843            return Err(error);
844        }
845        debug!(
846            connection_id,
847            auth_option_key = %option_key,
848            "rpc server handshake authentication accepted"
849        );
850        Ok(())
851    }
852
853    pub async fn list_instances(&self) -> Vec<InstanceDescriptor> {
854        self.state.list_instances(None).await
855    }
856}
857
858async fn handle_request(
859    state: Arc<ServerState>,
860    sender: RpcSender,
861    connection_id: u64,
862    request: Request,
863    observability: RpcServerObservabilityConfig,
864) {
865    let request_id = request.request_id;
866    let instance_id = request.instance_id;
867    let method_id = request.method_id;
868    let is_activation = instance_id.get() == ACTIVATION_INSTANCE_ID_VALUE;
869    let payload_preview = payload_preview(&request.payload, observability);
870
871    debug!(
872        connection_id,
873        request_id = request_id.get(),
874        instance_id = instance_id.get(),
875        method_id = method_id.get(),
876        is_activation,
877        "rpc server request received"
878    );
879    state.record_metric(RpcServerMetricEvent::RequestStarted {
880        connection_id,
881        request_id,
882        instance_id,
883        method_id,
884        is_activation,
885    });
886    if let Some(payload_preview) = payload_preview {
887        trace!(
888            connection_id,
889            request_id = request_id.get(),
890            payload_preview,
891            "rpc server request payload preview"
892        );
893    }
894
895    let started = Instant::now();
896    let response = dispatch_request(state.clone(), sender.clone(), connection_id, request).await;
897    let elapsed = started.elapsed();
898    let elapsed_ms = elapsed.as_secs_f64() * 1000.0;
899
900    let envelope = match response {
901        Ok(payload) => {
902            if elapsed >= observability.slow_call_threshold {
903                state.record_metric(RpcServerMetricEvent::RequestSlow {
904                    connection_id,
905                    request_id,
906                    instance_id,
907                    method_id,
908                    is_activation,
909                    elapsed,
910                    threshold: observability.slow_call_threshold,
911                });
912                warn!(
913                    connection_id,
914                    request_id = request_id.get(),
915                    instance_id = instance_id.get(),
916                    method_id = method_id.get(),
917                    is_activation,
918                    elapsed_ms,
919                    slow_call_threshold_ms =
920                        observability.slow_call_threshold.as_secs_f64() * 1000.0,
921                    "rpc server request completed slowly"
922                );
923            } else {
924                info!(
925                    connection_id,
926                    request_id = request_id.get(),
927                    instance_id = instance_id.get(),
928                    method_id = method_id.get(),
929                    is_activation,
930                    elapsed_ms,
931                    "rpc server request completed"
932                );
933            }
934            state.record_metric(RpcServerMetricEvent::RequestCompleted {
935                connection_id,
936                request_id,
937                instance_id,
938                method_id,
939                is_activation,
940                elapsed,
941            });
942            Envelope::ResponseOk(ResponseOk {
943                request_id,
944                payload,
945            })
946        }
947        Err(error) => {
948            state.record_metric(RpcServerMetricEvent::RequestFailed {
949                connection_id,
950                request_id,
951                instance_id,
952                method_id,
953                is_activation,
954                elapsed,
955                error_code: error.code,
956            });
957            warn!(
958                connection_id,
959                request_id = request_id.get(),
960                instance_id = instance_id.get(),
961                method_id = method_id.get(),
962                is_activation,
963                elapsed_ms,
964                error_code = error.code.as_i32(),
965                error_kind = error.kind.as_u8(),
966                error_message = %error.message,
967                "rpc server request failed"
968            );
969            runtime_error_response(request_id, error)
970        }
971    };
972
973    if let Err(err) = sender.send_envelope(&envelope).await {
974        state.record_metric(RpcServerMetricEvent::ResponseSendFailed {
975            connection_id,
976            request_id,
977        });
978        error!(
979            connection_id,
980            request_id = request_id.get(),
981            error = %err,
982            "rpc server failed to send response"
983        );
984    } else {
985        trace!(
986            connection_id,
987            request_id = request_id.get(),
988            response_kind = envelope_name(&envelope),
989            "rpc server response sent"
990        );
991    }
992}
993
994async fn dispatch_request(
995    state: Arc<ServerState>,
996    sender: RpcSender,
997    connection_id: u64,
998    request: Request,
999) -> Result<Value, RuntimeError> {
1000    if request.instance_id.get() == ACTIVATION_INSTANCE_ID_VALUE {
1001        return dispatch_activation(state, sender, connection_id, request).await;
1002    }
1003
1004    let instance = state.get_instance(request.instance_id).await?;
1005    if !instance.methods.contains(&request.method_id.get()) {
1006        return Err(RuntimeError::runtime(
1007            RuntimeErrorCode::MethodNotFound,
1008            format!("method id `{}` was not found", request.method_id.get()),
1009        ));
1010    }
1011    let ctx = RpcCallContext {
1012        connection_id,
1013        instance_id: request.instance_id,
1014        sender,
1015    };
1016    instance
1017        .handler
1018        .call(ctx, request.method_id, request.payload)
1019        .await
1020}
1021
1022async fn dispatch_activation(
1023    state: Arc<ServerState>,
1024    sender: RpcSender,
1025    connection_id: u64,
1026    request: Request,
1027) -> Result<Value, RuntimeError> {
1028    let ctx = RpcCallContext {
1029        connection_id,
1030        instance_id: request.instance_id,
1031        sender,
1032    };
1033    match request.method_id.get() {
1034        RESOLVE_INSTANCE_IDS_METHOD_ID => {
1035            let request = decode_resolve_instance_ids_request(&request.payload)?;
1036            let ids = state.resolve_instance_ids(&request.instance_names).await;
1037            Ok(encode_resolve_instance_ids_response(
1038                &ResolveInstanceIdsResponse { instance_ids: ids },
1039            ))
1040        }
1041        CREATE_INSTANCE_METHOD_ID => {
1042            let request = decode_create_instance_request(&request.payload)?;
1043            let factory = state.get_factory(request.service_guid).ok_or_else(|| {
1044                RuntimeError::runtime(
1045                    RuntimeErrorCode::ServiceGuidNotFound,
1046                    "service factory was not found",
1047                )
1048            })?;
1049            let handler = factory
1050                .factory
1051                .create(ctx, request.create_payload, request.options)
1052                .await?;
1053            let instance_id = state
1054                .insert_client_instance(
1055                    request.service_guid,
1056                    connection_id,
1057                    factory.methods.clone(),
1058                    handler,
1059                )
1060                .await;
1061            Ok(encode_create_instance_response(&CreateInstanceResponse {
1062                instance_id,
1063            }))
1064        }
1065        RELEASE_INSTANCE_METHOD_ID => {
1066            let request = decode_release_instance_request(&request.payload)?;
1067            state
1068                .release_instance(connection_id, request.instance_id)
1069                .await?;
1070            Ok(encode_release_instance_response(&ReleaseInstanceResponse))
1071        }
1072        LIST_INSTANCES_METHOD_ID => {
1073            let request = decode_list_instances_request(&request.payload)?;
1074            let instances = state.list_instances(request.service_guid).await;
1075            Ok(encode_list_instances_response(&ListInstancesResponse {
1076                instances,
1077            }))
1078        }
1079        _ => Err(RuntimeError::runtime(
1080            RuntimeErrorCode::MethodNotFound,
1081            "activation method was not found",
1082        )),
1083    }
1084}
1085
1086fn runtime_error_response(request_id: RequestId, error: RuntimeError) -> Envelope {
1087    Envelope::ResponseError(ResponseError {
1088        request_id,
1089        error_code: error.code.as_i32(),
1090        error_kind: error.kind.as_u8(),
1091        error_message: Some(error.message),
1092        error_details: Value::Nil,
1093    })
1094}
1095
1096fn server_capabilities() -> CapabilityFlags {
1097    CapabilityFlags::SERVER_TO_CLIENT_NOTIFICATION
1098        | CapabilityFlags::NAMED_INSTANCE_RESOLUTION
1099        | CapabilityFlags::SERVICE_ACTIVATION
1100        | CapabilityFlags::GOODBYE
1101}
1102
1103fn envelope_name(envelope: &Envelope) -> &'static str {
1104    match envelope {
1105        Envelope::Hello(_) => "hello",
1106        Envelope::HelloAck(_) => "hello_ack",
1107        Envelope::Request(_) => "request",
1108        Envelope::ResponseOk(_) => "response_ok",
1109        Envelope::ResponseError(_) => "response_error",
1110        Envelope::Notification(_) => "notification",
1111        Envelope::Goodbye(_) => "goodbye",
1112    }
1113}
1114
1115fn payload_preview(payload: &Value, config: RpcServerObservabilityConfig) -> Option<String> {
1116    if !config.log_payload_preview || config.payload_preview_bytes == 0 {
1117        return None;
1118    }
1119    let mut preview = format!("{payload:?}");
1120    if preview.len() > config.payload_preview_bytes {
1121        preview.truncate(config.payload_preview_bytes);
1122        preview.push_str("...");
1123    }
1124    Some(preview)
1125}
1126
1127fn is_transport_access_denied(error: &TransportError) -> bool {
1128    matches!(
1129        error,
1130        TransportError::Runtime(error) if error.code == RuntimeErrorCode::AccessDenied
1131    )
1132}
1133
/// Converts a duration to whole nanoseconds, saturating at `u64::MAX` when
/// the value does not fit in 64 bits.
fn duration_nanos_u64(duration: Duration) -> u64 {
    let nanos: u128 = duration.as_nanos();
    // The conversion fails only when the value exceeds `u64::MAX`.
    u64::try_from(nanos).unwrap_or(u64::MAX)
}
1137
/// Raises `value` to `candidate` when `candidate` is larger, and leaves it
/// unchanged otherwise.
///
/// Uses the standard `AtomicU64::fetch_max` read-modify-write instead of a
/// hand-written compare-exchange loop. Ordering stays `Relaxed`, as before.
fn update_atomic_max(value: &AtomicU64, candidate: u64) {
    value.fetch_max(candidate, Ordering::Relaxed);
}
1148
/// Atomically adds `increment` to `value`, clamping at `u64::MAX` instead of
/// wrapping on overflow.
///
/// Uses the standard `AtomicU64::fetch_update` retry loop instead of a
/// hand-written compare-exchange loop. Ordering stays `Relaxed`, as before.
fn saturating_atomic_add(value: &AtomicU64, increment: u64) {
    // The closure always returns `Some`, so `fetch_update` cannot return
    // `Err`; the previous value it reports is not needed.
    let _previous = value.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
        Some(current.saturating_add(increment))
    });
}
1159
1160struct ServerState {
1161    next_connection_id: AtomicU64,
1162    next_instance_id: AtomicU64,
1163    observability: RpcServerObservabilityConfig,
1164    security: RpcServerSecurityConfig,
1165    metrics_sink: Option<Arc<dyn RpcServerMetricsSink>>,
1166    instances: RwLock<HashMap<u64, InstanceEntry>>,
1167    names: RwLock<HashMap<String, u64>>,
1168    factories: HashMap<uuid::Uuid, FactoryEntry>,
1169}
1170
1171impl ServerState {
1172    fn new() -> Self {
1173        Self {
1174            next_connection_id: AtomicU64::new(1),
1175            next_instance_id: AtomicU64::new(2),
1176            observability: RpcServerObservabilityConfig::default(),
1177            security: RpcServerSecurityConfig::default(),
1178            metrics_sink: None,
1179            instances: RwLock::new(HashMap::new()),
1180            names: RwLock::new(HashMap::new()),
1181            factories: HashMap::new(),
1182        }
1183    }
1184
1185    fn record_metric(&self, event: RpcServerMetricEvent) {
1186        let Some(sink) = &self.metrics_sink else {
1187            return;
1188        };
1189        let result = panic::catch_unwind(AssertUnwindSafe(|| sink.record(event)));
1190        if result.is_err() {
1191            error!("rpc server metrics sink panicked while recording event");
1192        }
1193    }
1194
1195    fn insert_activation_instance(&mut self) {
1196        self.instances.get_mut().insert(
1197            ACTIVATION_INSTANCE_ID_VALUE,
1198            InstanceEntry {
1199                instance_id: activation_instance_id(),
1200                service_guid: activation_service_guid(),
1201                instance_name: Some("rpc.runtime.Activation".to_string()),
1202                activation_mode: ActivationMode::Singleton,
1203                releasable: false,
1204                owner_connection_id: None,
1205                methods: vec![
1206                    RESOLVE_INSTANCE_IDS_METHOD_ID,
1207                    CREATE_INSTANCE_METHOD_ID,
1208                    RELEASE_INSTANCE_METHOD_ID,
1209                    LIST_INSTANCES_METHOD_ID,
1210                ],
1211                handler: Arc::new(ActivationMarker),
1212            },
1213        );
1214    }
1215
1216    fn insert_instance(&mut self, instance: NewInstance) -> InstanceId {
1217        let id = self.next_instance_id.fetch_add(1, Ordering::Relaxed);
1218        let instance_id = InstanceId::new(id).expect("generated instance id is non-zero");
1219        if let Some(name) = &instance.name {
1220            self.names.get_mut().insert(name.clone(), id);
1221        }
1222        self.instances.get_mut().insert(
1223            id,
1224            InstanceEntry {
1225                instance_id,
1226                service_guid: instance.service_guid,
1227                instance_name: instance.name,
1228                activation_mode: instance.activation_mode,
1229                releasable: instance.releasable,
1230                owner_connection_id: instance.owner_connection_id,
1231                methods: instance.methods,
1232                handler: instance.handler,
1233            },
1234        );
1235        instance_id
1236    }
1237
1238    async fn insert_client_instance(
1239        &self,
1240        service_guid: ServiceGuid,
1241        connection_id: u64,
1242        methods: Vec<u32>,
1243        handler: Arc<dyn RpcServiceHandler>,
1244    ) -> InstanceId {
1245        let id = self.next_instance_id.fetch_add(1, Ordering::Relaxed);
1246        let instance_id = InstanceId::new(id).expect("generated instance id is non-zero");
1247        self.instances.write().await.insert(
1248            id,
1249            InstanceEntry {
1250                instance_id,
1251                service_guid,
1252                instance_name: None,
1253                activation_mode: ActivationMode::Instantiable,
1254                releasable: true,
1255                owner_connection_id: Some(connection_id),
1256                methods,
1257                handler,
1258            },
1259        );
1260        instance_id
1261    }
1262
1263    async fn get_instance(&self, instance_id: InstanceId) -> Result<InstanceEntry, RuntimeError> {
1264        self.instances
1265            .read()
1266            .await
1267            .get(&instance_id.get())
1268            .cloned()
1269            .ok_or_else(|| {
1270                RuntimeError::runtime(RuntimeErrorCode::InstanceNotFound, "instance was not found")
1271            })
1272    }
1273
1274    fn get_factory(&self, service_guid: ServiceGuid) -> Option<FactoryEntry> {
1275        self.factories.get(&service_guid.get()).cloned()
1276    }
1277
1278    async fn resolve_instance_ids(&self, names: &[String]) -> Vec<u64> {
1279        let index = self.names.read().await;
1280        names
1281            .iter()
1282            .map(|name| index.get(name).copied().unwrap_or(0))
1283            .collect()
1284    }
1285
1286    async fn release_instance(
1287        &self,
1288        connection_id: u64,
1289        instance_id: InstanceId,
1290    ) -> Result<(), RuntimeError> {
1291        let mut instances = self.instances.write().await;
1292        let entry = instances.get(&instance_id.get()).ok_or_else(|| {
1293            RuntimeError::runtime(RuntimeErrorCode::InstanceNotFound, "instance was not found")
1294        })?;
1295        if !entry.releasable {
1296            return Err(RuntimeError::runtime(
1297                RuntimeErrorCode::InstanceReleaseNotAllowed,
1298                "instance is not releasable",
1299            ));
1300        }
1301        if entry.owner_connection_id != Some(connection_id) {
1302            return Err(RuntimeError::runtime(
1303                RuntimeErrorCode::AccessDenied,
1304                "instance is owned by another connection",
1305            ));
1306        }
1307        instances.remove(&instance_id.get());
1308        Ok(())
1309    }
1310
1311    async fn cleanup_connection(&self, connection_id: u64) {
1312        self.instances
1313            .write()
1314            .await
1315            .retain(|_, entry| entry.owner_connection_id != Some(connection_id));
1316    }
1317
1318    async fn list_instances(&self, service_guid: Option<ServiceGuid>) -> Vec<InstanceDescriptor> {
1319        let mut values = self
1320            .instances
1321            .read()
1322            .await
1323            .values()
1324            .filter(|entry| service_guid.is_none_or(|guid| guid == entry.service_guid))
1325            .map(InstanceEntry::descriptor)
1326            .collect::<Vec<_>>();
1327        values.sort_by_key(|entry| entry.instance_id.get());
1328        values
1329    }
1330}
1331
/// Everything `ServerState::insert_instance` needs to register an instance.
/// The instance id is not part of this struct; `insert_instance` generates it
/// and copies the remaining fields into an `InstanceEntry`.
1332struct NewInstance {
    /// Service GUID stored on the resulting entry.
1333    service_guid: ServiceGuid,
    /// Optional name; when present, `insert_instance` also adds it to the
    /// name index.
1334    name: Option<String>,
1335    activation_mode: ActivationMode,
    /// Whether `ServerState::release_instance` may remove the instance.
1336    releasable: bool,
    /// Owning connection, if any; `release_instance` only accepts the owner,
    /// and `cleanup_connection` removes the owner's instances.
1337    owner_connection_id: Option<u64>,
    /// Method ids that `dispatch_request` accepts for this instance.
1338    methods: Vec<u32>,
    /// Handler invoked for accepted requests.
1339    handler: Arc<dyn RpcServiceHandler>,
1340}
1341
1342#[derive(Clone)]
1343struct InstanceEntry {
1344    instance_id: InstanceId,
1345    service_guid: ServiceGuid,
1346    instance_name: Option<String>,
1347    activation_mode: ActivationMode,
1348    releasable: bool,
1349    owner_connection_id: Option<u64>,
1350    methods: Vec<u32>,
1351    handler: Arc<dyn RpcServiceHandler>,
1352}
1353
1354impl InstanceEntry {
1355    fn descriptor(&self) -> InstanceDescriptor {
1356        InstanceDescriptor {
1357            instance_id: self.instance_id,
1358            instance_name: self.instance_name.clone(),
1359            service_guid: self.service_guid,
1360            activation_mode: self.activation_mode,
1361            releasable: self.releasable,
1362        }
1363    }
1364}
1365
/// A registered service factory together with the method ids of the instances
/// it creates. `ServerState::get_factory` returns clones of this entry.
1366#[derive(Clone)]
1367struct FactoryEntry {
    /// Method ids copied onto every client instance created through this
    /// factory (see `dispatch_activation`).
1368    methods: Vec<u32>,
    /// Factory whose `create` produces the handler for a new instance.
1369    factory: Arc<dyn RpcServiceFactory>,
1370}
1371
1372struct ActivationMarker;
1373
1374impl RpcServiceHandler for ActivationMarker {
1375    fn call(&self, _: RpcCallContext, _: MethodId, _: Value) -> HandlerFuture {
1376        Box::pin(async {
1377            Err(RuntimeError::runtime(
1378                RuntimeErrorCode::InternalRuntimeError,
1379                "activation marker should not be dispatched directly",
1380            ))
1381        })
1382    }
1383}
1384
#[cfg(test)]
mod tests {
    use super::*;
    use rpc_runtime_core::{Goodbye, Hello, Request, Role};
    use rpc_runtime_transport_ipc::{FrameConfig, IpcConnection};
    use tokio::io::duplex;

    /// Test handler: method 1 echoes its payload, every other method fails.
    struct MetricsTestHandler;

    impl RpcServiceHandler for MetricsTestHandler {
        fn call(&self, _: RpcCallContext, method_id: MethodId, payload: Value) -> HandlerFuture {
            Box::pin(async move {
                if method_id.get() == 1 {
                    Ok(payload)
                } else {
                    Err(RuntimeError::runtime(
                        RuntimeErrorCode::InternalRuntimeError,
                        "test failure",
                    ))
                }
            })
        }
    }

    /// Builds a server whose handshake requires the token `"secret"`.
    fn secret_token_server() -> RpcServer {
        RpcServerBuilder::new()
            .security(RpcServerSecurityConfig::default().with_token("secret"))
            .build()
    }

    /// Client `Hello` envelope carrying the given handshake options.
    fn hello_envelope(options: Options) -> Envelope {
        let hello = Hello {
            protocol_version: RUNTIME_PROTOCOL_VERSION,
            role: Role::Client,
            capability_bits: CapabilityFlags::empty(),
            max_message_size: 16 * 1024 * 1024,
            options,
        };
        Envelope::Hello(hello)
    }

    /// Handshake option carrying `token` under the default auth option key.
    fn auth_option(token: &str) -> (String, Value) {
        let key = DEFAULT_AUTH_TOKEN_OPTION_KEY.to_string();
        (key, Value::from(token))
    }

    /// Sends a `Hello` without any options.
    async fn send_hello(sender: &RpcSender) {
        let hello = hello_envelope(Vec::new());
        sender.send_envelope(&hello).await.expect("send hello");
    }

    /// Serves one in-memory connection, sends a `Hello` with `options`, and
    /// returns the first envelope the server answers with. When the server
    /// closes the stream without answering, the server's error is returned.
    async fn run_handshake(server: RpcServer, options: Options) -> Result<Envelope, RuntimeError> {
        let (client_stream, server_stream) = duplex(4096);
        let server_side = IpcConnection::from_stream(server_stream, FrameConfig::default());
        let server_task = tokio::spawn(async move { server.serve_connection(server_side).await });

        let client_side = IpcConnection::from_stream(client_stream, FrameConfig::default());
        let (sender, mut receiver) = client_side.split();
        sender
            .send_envelope(&hello_envelope(options))
            .await
            .expect("send hello");

        let reply = receiver.recv_envelope().await;
        // Close the client side so the server task can finish.
        drop(sender);
        drop(receiver);
        let served = server_task.await.expect("server task");
        reply
            .expect("recv hello ack")
            .ok_or_else(|| served.expect_err("server should return handshake error"))
    }

    #[test]
    fn observability_defaults_are_safe() {
        let defaults = RpcServerObservabilityConfig::default();

        assert_eq!(defaults.slow_call_threshold, Duration::from_millis(500));
        assert_eq!(defaults.payload_preview_bytes, 0);
        assert!(!defaults.log_payload_preview);
    }

    #[test]
    fn payload_preview_is_opt_in_and_bounded() {
        let payload = Value::from("1234567890");

        // Disabled by default.
        let disabled = payload_preview(&payload, RpcServerObservabilityConfig::default());
        assert_eq!(disabled, None);

        // Enabled with a 5-byte budget: at most 5 bytes plus the "..." marker.
        let bounded = RpcServerObservabilityConfig::default().with_payload_preview(5);
        let preview = payload_preview(&payload, bounded).expect("preview");
        assert!(preview.len() <= 8);
        assert!(preview.ends_with("..."));
    }

    #[test]
    fn metrics_recorder_counts_events_and_latency() {
        let recorder = RpcServerMetricsRecorder::new();
        let events = [
            RpcServerMetricEvent::ConnectionStarted { connection_id: 1 },
            RpcServerMetricEvent::ConnectionEnded {
                connection_id: 1,
                success: true,
            },
            RpcServerMetricEvent::RequestCompleted {
                connection_id: 1,
                request_id: RequestId::new(7),
                instance_id: activation_instance_id(),
                method_id: MethodId::new(1),
                is_activation: true,
                elapsed: Duration::from_millis(3),
            },
            RpcServerMetricEvent::RequestFailed {
                connection_id: 1,
                request_id: RequestId::new(8),
                instance_id: activation_instance_id(),
                method_id: MethodId::new(2),
                is_activation: true,
                elapsed: Duration::from_millis(5),
                error_code: RuntimeErrorCode::InternalRuntimeError,
            },
        ];
        for event in events {
            recorder.record(event);
        }

        let totals = recorder.snapshot();
        assert_eq!(totals.connections_started, 1);
        assert_eq!(totals.connections_ended, 1);
        assert_eq!(totals.connections_ended_successfully, 1);
        assert_eq!(totals.requests_completed, 1);
        assert_eq!(totals.requests_failed, 1);
        assert_eq!(totals.request_elapsed_total, Duration::from_millis(8));
        assert_eq!(totals.request_elapsed_max, Duration::from_millis(5));
    }

    #[test]
    fn security_defaults_are_local_auth_disabled() {
        let defaults = RpcServerSecurityConfig::default();

        assert_eq!(defaults.connection_scope, ConnectionScope::LocalOnly);
        assert_eq!(defaults.auth, RpcServerAuthConfig::Disabled);
    }

    #[tokio::test]
    async fn token_auth_accepts_matching_token() {
        let options = vec![auth_option("secret")];

        let ack = run_handshake(secret_token_server(), options)
            .await
            .expect("handshake");

        assert!(matches!(ack, Envelope::HelloAck(_)));
    }

    #[tokio::test]
    async fn token_auth_rejects_missing_token() {
        let rejection = run_handshake(secret_token_server(), Vec::new())
            .await
            .expect_err("must reject");

        assert_eq!(rejection.code, RuntimeErrorCode::AccessDenied);
    }

    #[tokio::test]
    async fn token_auth_rejects_wrong_token() {
        let options = vec![auth_option("wrong")];

        let rejection = run_handshake(secret_token_server(), options)
            .await
            .expect_err("must reject");

        assert_eq!(rejection.code, RuntimeErrorCode::AccessDenied);
    }

    #[tokio::test]
    async fn token_auth_rejects_non_string_token() {
        // The option key is right, but the value is an integer.
        let options = vec![(
            DEFAULT_AUTH_TOKEN_OPTION_KEY.to_string(),
            Value::from(123_u64),
        )];

        let rejection = run_handshake(secret_token_server(), options)
            .await
            .expect_err("must reject");

        assert_eq!(rejection.code, RuntimeErrorCode::AccessDenied);
    }

    #[tokio::test]
    async fn metrics_recorder_observes_handshake_failure() {
        let recorder = Arc::new(RpcServerMetricsRecorder::new());
        let server = RpcServerBuilder::new()
            .metrics_sink(recorder.clone())
            .security(RpcServerSecurityConfig::default().with_token("secret"))
            .build();

        let rejection = run_handshake(server, Vec::new())
            .await
            .expect_err("must reject");
        assert_eq!(rejection.code, RuntimeErrorCode::AccessDenied);

        let totals = recorder.snapshot();
        assert_eq!(totals.connections_started, 1);
        assert_eq!(totals.connections_ended, 1);
        assert_eq!(totals.connections_ended_successfully, 0);
        assert_eq!(totals.handshakes_completed, 0);
        assert_eq!(totals.handshakes_failed, 1);
    }

    #[tokio::test]
    async fn metrics_recorder_observes_success_failure_and_slow_requests() {
        let recorder = Arc::new(RpcServerMetricsRecorder::new());
        // A zero threshold makes every successful request count as slow.
        let observability = RpcServerObservabilityConfig::default()
            .with_slow_call_threshold(Duration::from_nanos(0));
        let mut builder = RpcServerBuilder::new()
            .metrics_sink(recorder.clone())
            .observability(observability);
        let instance_id = builder.register_named_instance(
            "metrics",
            activation_service_guid(),
            [1, 2],
            Arc::new(MetricsTestHandler),
        );
        let server = builder.build();

        let (client_stream, server_stream) = duplex(4096);
        let server_side = IpcConnection::from_stream(server_stream, FrameConfig::default());
        let server_task = tokio::spawn(async move { server.serve_connection(server_side).await });
        let client_side = IpcConnection::from_stream(client_stream, FrameConfig::default());
        let (sender, mut receiver) = client_side.split();

        send_hello(&sender).await;
        assert!(matches!(
            receiver.recv_envelope().await.expect("recv ack"),
            Some(Envelope::HelloAck(_))
        ));

        // Method 1 succeeds.
        let succeeding = Envelope::Request(Request {
            request_id: RequestId::new(11),
            instance_id,
            method_id: MethodId::new(1),
            payload: Value::from("ok"),
        });
        sender
            .send_envelope(&succeeding)
            .await
            .expect("send success request");
        assert!(matches!(
            receiver.recv_envelope().await.expect("recv response"),
            Some(Envelope::ResponseOk(_))
        ));

        // Method 2 fails inside the handler.
        let failing = Envelope::Request(Request {
            request_id: RequestId::new(12),
            instance_id,
            method_id: MethodId::new(2),
            payload: Value::Nil,
        });
        sender
            .send_envelope(&failing)
            .await
            .expect("send failing request");
        assert!(matches!(
            receiver.recv_envelope().await.expect("recv error"),
            Some(Envelope::ResponseError(_))
        ));

        let goodbye = Envelope::Goodbye(Goodbye {
            reason_code: 0,
            message: Some("done".to_string()),
        });
        sender.send_envelope(&goodbye).await.expect("send goodbye");
        drop(sender);
        drop(receiver);
        server_task.await.expect("server task").expect("serve");

        let totals = recorder.snapshot();
        assert_eq!(totals.connections_started, 1);
        assert_eq!(totals.connections_ended_successfully, 1);
        assert_eq!(totals.handshakes_completed, 1);
        assert_eq!(totals.requests_started, 2);
        assert_eq!(totals.requests_completed, 1);
        assert_eq!(totals.requests_failed, 1);
        assert_eq!(totals.requests_slow, 1);
        assert!(totals.request_elapsed_total > Duration::ZERO);
    }
}