Skip to main content

spikard_http/grpc/
mod.rs

1//! gRPC runtime support for Spikard
2//!
3//! This module provides gRPC server infrastructure using Tonic, enabling
4//! Spikard to handle both HTTP/1.1 REST requests and HTTP/2 gRPC requests.
5//!
6//! # Architecture
7//!
8//! The gRPC support follows the same language-agnostic pattern as the HTTP handler:
9//!
10//! 1. **GrpcHandler trait**: Language-agnostic interface for handling gRPC requests
11//! 2. **Service bridge**: Converts between Tonic's types and our internal representation
12//! 3. **Streaming support**: Utilities for handling streaming RPCs
13//! 4. **Server integration**: Multiplexes HTTP/1.1 and HTTP/2 traffic
14//!
15//! # Example
16//!
17//! ```ignore
//! use spikard_http::grpc::{GrpcConfig, GrpcHandler, GrpcHandlerResult, GrpcRequestData, GrpcResponseData};
//! use std::future::Future;
//! use std::pin::Pin;
//! use std::sync::Arc;
20//!
21//! // Implement GrpcHandler for your language binding
22//! struct MyGrpcHandler;
23//!
24//! impl GrpcHandler for MyGrpcHandler {
25//!     fn call(&self, request: GrpcRequestData) -> Pin<Box<dyn Future<Output = GrpcHandlerResult> + Send>> {
26//!         Box::pin(async move {
27//!             // Handle the gRPC request
28//!             Ok(GrpcResponseData {
29//!                 payload: bytes::Bytes::from("response"),
30//!                 metadata: tonic::metadata::MetadataMap::new(),
31//!             })
32//!         })
33//!     }
34//!
35//!     fn service_name(&self) -> &str {
36//!         "mypackage.MyService"
37//!     }
38//! }
39//!
40//! // Register with the server
41//! let handler = Arc::new(MyGrpcHandler);
42//! let config = GrpcConfig::default();
43//! ```
44
45pub(crate) mod framing;
46pub(crate) mod handler;
47pub(crate) mod service;
48pub(crate) mod streaming;
49
50// Re-export types used from integration tests and binding crates
51pub use handler::{GrpcHandler, GrpcHandlerResult, GrpcRequestData, GrpcResponseData};
52// Internal implementation details only
53pub(crate) use handler::RpcMode;
54pub(crate) use service::{GenericGrpcService, parse_grpc_path};
55
56use serde::{Deserialize, Serialize};
57use std::collections::HashMap;
58#[cfg(test)]
59use std::collections::HashSet;
60use std::sync::Arc;
61
62/// Configuration for gRPC support
63///
64/// Controls how the server handles gRPC requests, including compression,
65/// timeouts, and protocol settings.
66///
67/// # Stream Limits
68///
69/// This configuration enforces message-level size limits but delegates
70/// concurrent stream limiting to the HTTP/2 transport layer:
71///
72/// - **Message Size Limits**: The `max_message_size` field is enforced per
73///   individual message (request or response) in both unary and streaming RPCs.
74///   When a single message exceeds this limit, the request is rejected with
75///   `PAYLOAD_TOO_LARGE` (HTTP 413).
76///
77/// - **Concurrent Stream Limits**: The `max_concurrent_streams` is an advisory
78///   configuration passed to the HTTP/2 layer for connection-level stream
79///   negotiation. The HTTP/2 transport automatically enforces this limit and
80///   returns GOAWAY frames when exceeded. Applications should not rely on
81///   custom enforcement of this limit.
82///
83/// - **Stream Response Size Limits**: The `max_stream_response_bytes` field caps the
84///   total encoded bytes emitted across a server-streaming or bidi-streaming response.
85///   When the cumulative size exceeds the limit, the stream is terminated with
86///   `tonic::Status::resource_exhausted`. Defaults to `None` (unbounded).
87///
88/// # Example
89///
90/// ```ignore
91/// let mut config = GrpcConfig::default();
92/// config.max_message_size = 10 * 1024 * 1024; // 10MB per message
93/// config.max_concurrent_streams = 50; // Advised to HTTP/2 layer
94/// ```
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct GrpcConfig {
97    /// Enable gRPC support
98    #[serde(default = "default_true")]
99    pub enabled: bool,
100
101    /// Maximum message size in bytes (for both sending and receiving)
102    ///
103    /// This limit applies to individual messages in both unary and streaming RPCs.
104    /// When a single message exceeds this size, the request is rejected with HTTP 413
105    /// (Payload Too Large).
106    ///
107    /// Default: 4MB (4194304 bytes)
108    ///
109    /// # Note
110    /// This limit does NOT apply to the total response size in streaming RPCs.
111    /// For multi-message streams, the total response can exceed this limit as long
112    /// as each individual message stays within the limit.
113    #[serde(default = "default_max_message_size")]
114    pub max_message_size: usize,
115
116    /// Enable gzip compression for gRPC messages
117    #[serde(default = "default_true")]
118    pub enable_compression: bool,
119
120    /// Timeout for gRPC requests in seconds (None = no timeout)
121    #[serde(default)]
122    pub request_timeout: Option<u64>,
123
124    /// Maximum number of concurrent streams per connection (HTTP/2 advisory)
125    ///
126    /// This value is communicated to HTTP/2 clients as the server's flow control limit.
127    /// The HTTP/2 transport layer enforces this limit automatically via SETTINGS frames
128    /// and GOAWAY responses. Applications should NOT implement custom enforcement.
129    ///
130    /// Default: 100 streams per connection
131    ///
132    /// # Stream Limiting Strategy
133    /// - **Per Connection**: This limit applies per HTTP/2 connection, not globally
134    /// - **Transport Enforcement**: HTTP/2 handles all stream limiting; applications
135    ///   need not implement custom checks
136    /// - **Streaming Requests**: In server streaming or bidi streaming, each logical
137    ///   RPC consumes one stream slot. Message ordering within a stream follows
138    ///   HTTP/2 frame ordering.
139    #[serde(default = "default_max_concurrent_streams")]
140    pub max_concurrent_streams: u32,
141
142    /// Enable HTTP/2 keepalive
143    #[serde(default = "default_true")]
144    pub enable_keepalive: bool,
145
146    /// HTTP/2 keepalive interval in seconds
147    #[serde(default = "default_keepalive_interval")]
148    pub keepalive_interval: u64,
149
150    /// HTTP/2 keepalive timeout in seconds
151    #[serde(default = "default_keepalive_timeout")]
152    pub keepalive_timeout: u64,
153
154    /// Total byte cap across an entire streaming response.
155    ///
156    /// When `Some(n)`, the streaming adapter aborts the stream with
157    /// `tonic::Status::resource_exhausted` once the cumulative encoded message
158    /// bytes exceed `n`. The stream yields the error item and then terminates.
159    ///
160    /// Per-message cap remains `max_message_size`. This limit applies to
161    /// server-streaming and bidirectional-streaming RPCs only; unary RPCs are
162    /// governed solely by `max_message_size`.
163    ///
164    /// Default: `None` (unbounded total response size).
165    #[serde(default)]
166    pub max_stream_response_bytes: Option<usize>,
167}
168
169impl Default for GrpcConfig {
170    fn default() -> Self {
171        Self {
172            enabled: true,
173            max_message_size: default_max_message_size(),
174            enable_compression: true,
175            request_timeout: None,
176            max_concurrent_streams: default_max_concurrent_streams(),
177            enable_keepalive: true,
178            keepalive_interval: default_keepalive_interval(),
179            keepalive_timeout: default_keepalive_timeout(),
180            max_stream_response_bytes: None,
181        }
182    }
183}
184
185const fn default_true() -> bool {
186    true
187}
188
189const fn default_max_message_size() -> usize {
190    4 * 1024 * 1024 // 4MB
191}
192
193const fn default_max_concurrent_streams() -> u32 {
194    100
195}
196
197const fn default_keepalive_interval() -> u64 {
198    75 // seconds
199}
200
201const fn default_keepalive_timeout() -> u64 {
202    20 // seconds
203}
204
205/// Registry for gRPC handlers
206///
207/// Maps service and method names to their handlers and RPC modes. Used by the
208/// server to route incoming gRPC requests to the appropriate handler method based
209/// on the parsed gRPC path.
210///
211/// # Example
212///
213/// ```ignore
214/// use spikard_http::grpc::{GrpcRegistry, RpcMode};
215/// use std::sync::Arc;
216///
217/// let mut registry = GrpcRegistry::new();
218/// registry.register("mypackage.UserService", "GetUser", Arc::new(user_handler), RpcMode::Unary);
219/// registry.register(
220///     "mypackage.StreamService",
221///     "StreamUsers",
222///     Arc::new(stream_handler),
223///     RpcMode::ServerStreaming,
224/// );
225/// ```
226type GrpcHandlerEntry = (Arc<dyn GrpcHandler>, RpcMode);
227const WILDCARD_METHOD: &str = "*";
228
229#[derive(Clone)]
230pub(crate) struct GrpcRegistry {
231    handlers: Arc<HashMap<(String, String), GrpcHandlerEntry>>,
232}
233
234impl GrpcRegistry {
235    /// Create a new empty gRPC handler registry
236    pub fn new() -> Self {
237        Self {
238            handlers: Arc::new(HashMap::new()),
239        }
240    }
241
242    /// Get a handler and its RPC mode by service and method name
243    ///
244    /// Returns both the handler and the RPC mode it was registered with. Exact
245    /// method matches take precedence over service-level wildcard handlers.
246    pub fn get(&self, service_name: &str, method_name: &str) -> Option<(Arc<dyn GrpcHandler>, RpcMode)> {
247        self.handlers
248            .get(&(service_name.to_owned(), method_name.to_owned()))
249            .or_else(|| {
250                self.handlers
251                    .get(&(service_name.to_owned(), WILDCARD_METHOD.to_owned()))
252            })
253            .cloned()
254    }
255
256    /// Check if a service has any registered handlers.
257    pub fn contains_service(&self, service_name: &str) -> bool {
258        self.handlers
259            .keys()
260            .any(|(registered_service, _)| registered_service == service_name)
261    }
262
263    /// Check if the registry is empty
264    pub fn is_empty(&self) -> bool {
265        self.handlers.is_empty()
266    }
267}
268
269#[cfg(test)]
270impl GrpcRegistry {
271    /// Register a gRPC handler for a specific service method (test helper)
272    pub fn register(
273        &mut self,
274        service_name: impl Into<String>,
275        method_name: impl Into<String>,
276        handler: Arc<dyn GrpcHandler>,
277        rpc_mode: RpcMode,
278    ) {
279        let handlers = Arc::make_mut(&mut self.handlers);
280        handlers.insert((service_name.into(), method_name.into()), (handler, rpc_mode));
281    }
282
283    /// Register a gRPC handler for an entire service (test helper)
284    pub fn register_service(
285        &mut self,
286        service_name: impl Into<String>,
287        handler: Arc<dyn GrpcHandler>,
288        rpc_mode: RpcMode,
289    ) {
290        self.register(service_name, WILDCARD_METHOD, handler, rpc_mode);
291    }
292
293    /// Get all registered service names (test helper)
294    pub fn service_names(&self) -> Vec<String> {
295        self.handlers
296            .keys()
297            .map(|(service_name, _)| service_name.clone())
298            .collect::<HashSet<_>>()
299            .into_iter()
300            .collect()
301    }
302
303    /// Get all explicitly registered method names for a service (test helper)
304    pub fn method_names(&self, service_name: &str) -> Vec<String> {
305        self.handlers
306            .keys()
307            .filter(|(registered_service, method_name)| {
308                registered_service == service_name && method_name.as_str() != WILDCARD_METHOD
309            })
310            .map(|(_, method_name)| method_name.clone())
311            .collect()
312    }
313
314    /// Check if a specific service method is registered (test helper)
315    pub fn contains(&self, service_name: &str, method_name: &str) -> bool {
316        self.handlers
317            .contains_key(&(service_name.to_owned(), method_name.to_owned()))
318    }
319
320    /// Get the number of registered services (test helper)
321    pub fn len(&self) -> usize {
322        self.handlers.len()
323    }
324}
325
326impl Default for GrpcRegistry {
327    fn default() -> Self {
328        Self::new()
329    }
330}
331
332#[cfg(test)]
333mod tests {
334    use super::*;
335    use crate::grpc::handler::{GrpcHandler, GrpcHandlerResult, GrpcRequestData};
336    use std::future::Future;
337    use std::pin::Pin;
338
339    struct TestHandler;
340
341    impl GrpcHandler for TestHandler {
342        fn call(&self, _request: GrpcRequestData) -> Pin<Box<dyn Future<Output = GrpcHandlerResult> + Send>> {
343            Box::pin(async {
344                Ok(GrpcResponseData {
345                    payload: bytes::Bytes::new(),
346                    metadata: tonic::metadata::MetadataMap::new(),
347                })
348            })
349        }
350
351        fn service_name(&self) -> &'static str {
352            // Since we can't return a reference to self.0 with 'static lifetime,
353            // we need to use a workaround. In real usage, service names should be static.
354            "test.Service"
355        }
356    }
357
358    #[test]
359    fn test_grpc_config_default() {
360        let config = GrpcConfig::default();
361        assert!(config.enabled);
362        assert_eq!(config.max_message_size, 4 * 1024 * 1024);
363        assert!(config.enable_compression);
364        assert!(config.request_timeout.is_none());
365        assert_eq!(config.max_concurrent_streams, 100);
366        assert!(config.enable_keepalive);
367        assert_eq!(config.keepalive_interval, 75);
368        assert_eq!(config.keepalive_timeout, 20);
369        assert!(config.max_stream_response_bytes.is_none());
370    }
371
372    #[test]
373    fn test_grpc_config_serialization() {
374        let config = GrpcConfig::default();
375        let json = serde_json::to_string(&config).unwrap();
376        let deserialized: GrpcConfig = serde_json::from_str(&json).unwrap();
377
378        assert_eq!(config.enabled, deserialized.enabled);
379        assert_eq!(config.max_message_size, deserialized.max_message_size);
380        assert_eq!(config.enable_compression, deserialized.enable_compression);
381    }
382
383    #[test]
384    fn test_grpc_registry_new() {
385        let registry = GrpcRegistry::new();
386        assert!(registry.is_empty());
387        assert_eq!(registry.len(), 0);
388    }
389
390    #[test]
391    fn test_grpc_registry_register() {
392        let mut registry = GrpcRegistry::new();
393        let handler = Arc::new(TestHandler);
394
395        registry.register("test.Service", "TestMethod", handler, RpcMode::Unary);
396
397        assert!(!registry.is_empty());
398        assert_eq!(registry.len(), 1);
399        assert!(registry.contains("test.Service", "TestMethod"));
400    }
401
402    #[test]
403    fn test_grpc_registry_get() {
404        let mut registry = GrpcRegistry::new();
405        let handler = Arc::new(TestHandler);
406
407        registry.register("test.Service", "TestMethod", handler, RpcMode::Unary);
408
409        let retrieved = registry.get("test.Service", "TestMethod");
410        assert!(retrieved.is_some());
411        let (handler, rpc_mode) = retrieved.unwrap();
412        assert_eq!(handler.service_name(), "test.Service");
413        assert_eq!(rpc_mode, RpcMode::Unary);
414    }
415
416    #[test]
417    fn test_grpc_registry_get_nonexistent() {
418        let registry = GrpcRegistry::new();
419        let result = registry.get("nonexistent.Service", "MissingMethod");
420        assert!(result.is_none());
421    }
422
423    #[test]
424    fn test_grpc_registry_service_names() {
425        let mut registry = GrpcRegistry::new();
426
427        registry.register("service1", "Method1", Arc::new(TestHandler), RpcMode::Unary);
428        registry.register("service2", "Method2", Arc::new(TestHandler), RpcMode::ServerStreaming);
429        registry.register("service3", "Method3", Arc::new(TestHandler), RpcMode::Unary);
430
431        let mut names = registry.service_names();
432        names.sort();
433
434        assert_eq!(names, vec!["service1", "service2", "service3"]);
435    }
436
437    #[test]
438    fn test_grpc_registry_contains() {
439        let mut registry = GrpcRegistry::new();
440        registry.register("test.Service", "TestMethod", Arc::new(TestHandler), RpcMode::Unary);
441
442        assert!(registry.contains("test.Service", "TestMethod"));
443        assert!(!registry.contains("other.Service", "TestMethod"));
444    }
445
446    #[test]
447    fn test_grpc_registry_multiple_services() {
448        let mut registry = GrpcRegistry::new();
449
450        registry.register("user.Service", "GetUser", Arc::new(TestHandler), RpcMode::Unary);
451        registry.register(
452            "post.Service",
453            "ListPosts",
454            Arc::new(TestHandler),
455            RpcMode::ServerStreaming,
456        );
457
458        assert_eq!(registry.len(), 2);
459        assert!(registry.contains("user.Service", "GetUser"));
460        assert!(registry.contains("post.Service", "ListPosts"));
461    }
462
463    #[test]
464    fn test_grpc_registry_clone() {
465        let mut registry = GrpcRegistry::new();
466        registry.register("test.Service", "TestMethod", Arc::new(TestHandler), RpcMode::Unary);
467
468        let cloned = registry.clone();
469
470        assert_eq!(cloned.len(), 1);
471        assert!(cloned.contains("test.Service", "TestMethod"));
472    }
473
474    #[test]
475    fn test_grpc_registry_default() {
476        let registry = GrpcRegistry::default();
477        assert!(registry.is_empty());
478    }
479
480    #[test]
481    fn test_grpc_registry_rpc_mode_storage() {
482        let mut registry = GrpcRegistry::new();
483
484        registry.register("unary.Service", "UnaryMethod", Arc::new(TestHandler), RpcMode::Unary);
485        registry.register(
486            "server_stream.Service",
487            "StreamMethod",
488            Arc::new(TestHandler),
489            RpcMode::ServerStreaming,
490        );
491        registry.register(
492            "client_stream.Service",
493            "UploadMethod",
494            Arc::new(TestHandler),
495            RpcMode::ClientStreaming,
496        );
497        registry.register(
498            "bidi.Service",
499            "ChatMethod",
500            Arc::new(TestHandler),
501            RpcMode::BidirectionalStreaming,
502        );
503
504        let (_, mode) = registry.get("unary.Service", "UnaryMethod").unwrap();
505        assert_eq!(mode, RpcMode::Unary);
506
507        let (_, mode) = registry.get("server_stream.Service", "StreamMethod").unwrap();
508        assert_eq!(mode, RpcMode::ServerStreaming);
509
510        let (_, mode) = registry.get("client_stream.Service", "UploadMethod").unwrap();
511        assert_eq!(mode, RpcMode::ClientStreaming);
512
513        let (_, mode) = registry.get("bidi.Service", "ChatMethod").unwrap();
514        assert_eq!(mode, RpcMode::BidirectionalStreaming);
515    }
516
517    #[test]
518    fn test_grpc_registry_service_fallback() {
519        let mut registry = GrpcRegistry::new();
520        registry.register_service("test.Service", Arc::new(TestHandler), RpcMode::Unary);
521
522        assert!(registry.contains_service("test.Service"));
523        assert!(registry.get("test.Service", "AnyMethod").is_some());
524        assert!(registry.method_names("test.Service").is_empty());
525    }
526
527    #[test]
528    fn test_grpc_registry_prefers_method_specific_handler() {
529        struct MethodSpecificHandler;
530
531        impl GrpcHandler for MethodSpecificHandler {
532            fn call(&self, _request: GrpcRequestData) -> Pin<Box<dyn Future<Output = GrpcHandlerResult> + Send>> {
533                Box::pin(async {
534                    Ok(GrpcResponseData {
535                        payload: bytes::Bytes::from("method-specific"),
536                        metadata: tonic::metadata::MetadataMap::new(),
537                    })
538                })
539            }
540
541            fn service_name(&self) -> &str {
542                "test.Service"
543            }
544        }
545
546        let mut registry = GrpcRegistry::new();
547        registry.register_service("test.Service", Arc::new(TestHandler), RpcMode::Unary);
548        registry.register(
549            "test.Service",
550            "GetThing",
551            Arc::new(MethodSpecificHandler),
552            RpcMode::ServerStreaming,
553        );
554
555        let (_, mode) = registry.get("test.Service", "GetThing").unwrap();
556        assert_eq!(mode, RpcMode::ServerStreaming);
557        let (_, fallback_mode) = registry.get("test.Service", "OtherThing").unwrap();
558        assert_eq!(fallback_mode, RpcMode::Unary);
559    }
560}