Skip to main content

spikard_http/grpc/
mod.rs

1//! gRPC runtime support for Spikard
2//!
3//! This module provides gRPC server infrastructure using Tonic, enabling
4//! Spikard to handle both HTTP/1.1 REST requests and HTTP/2 gRPC requests.
5//!
6//! # Architecture
7//!
8//! The gRPC support follows the same language-agnostic pattern as the HTTP handler:
9//!
10//! 1. **GrpcHandler trait**: Language-agnostic interface for handling gRPC requests
11//! 2. **Service bridge**: Converts between Tonic's types and our internal representation
12//! 3. **Streaming support**: Utilities for handling streaming RPCs
13//! 4. **Server integration**: Multiplexes HTTP/1.1 and HTTP/2 traffic
14//!
15//! # Example
16//!
17//! ```ignore
//! use spikard_http::grpc::{
//!     GrpcConfig, GrpcHandler, GrpcHandlerResult, GrpcRegistry, GrpcRequestData, GrpcResponseData,
//!     RpcMode,
//! };
//! use std::future::Future;
//! use std::pin::Pin;
//! use std::sync::Arc;
//!
//! // Implement GrpcHandler for your language binding
//! struct MyGrpcHandler;
//!
//! impl GrpcHandler for MyGrpcHandler {
//!     fn call(&self, _request: GrpcRequestData) -> Pin<Box<dyn Future<Output = GrpcHandlerResult> + Send>> {
//!         Box::pin(async move {
//!             // Handle the gRPC request
//!             Ok(GrpcResponseData {
//!                 payload: bytes::Bytes::from("response"),
//!                 metadata: tonic::metadata::MetadataMap::new(),
//!             })
//!         })
//!     }
//!
//!     fn service_name(&self) -> &str {
//!         "mypackage.MyService"
//!     }
//! }
//!
//! // Register the handler for every method of its service, then build the config
//! let handler = Arc::new(MyGrpcHandler);
//! let mut registry = GrpcRegistry::new();
//! registry.register_service("mypackage.MyService", handler, RpcMode::Unary);
//! let config = GrpcConfig::default();
43//! ```
44
45pub mod framing;
46pub mod handler;
47pub mod service;
48pub mod streaming;
49
50// Re-export main types
51pub use framing::parse_grpc_client_stream;
52pub use handler::{GrpcHandler, GrpcHandlerResult, GrpcRequestData, GrpcResponseData, RpcMode};
53pub use service::{GenericGrpcService, copy_metadata, is_grpc_request, parse_grpc_path};
54pub use streaming::{MessageStream, StreamingRequest, StreamingResponse};
55
56use serde::{Deserialize, Serialize};
57use std::collections::{HashMap, HashSet};
58use std::sync::Arc;
59
60/// Configuration for gRPC support
61///
62/// Controls how the server handles gRPC requests, including compression,
63/// timeouts, and protocol settings.
64///
65/// # Stream Limits
66///
67/// This configuration enforces message-level size limits but delegates
68/// concurrent stream limiting to the HTTP/2 transport layer:
69///
70/// - **Message Size Limits**: The `max_message_size` field is enforced per
71///   individual message (request or response) in both unary and streaming RPCs.
72///   When a single message exceeds this limit, the request is rejected with
73///   `PAYLOAD_TOO_LARGE` (HTTP 413).
74///
75/// - **Concurrent Stream Limits**: The `max_concurrent_streams` is an advisory
76///   configuration passed to the HTTP/2 layer for connection-level stream
77///   negotiation. The HTTP/2 transport automatically enforces this limit and
78///   returns GOAWAY frames when exceeded. Applications should not rely on
79///   custom enforcement of this limit.
80///
81/// - **Stream Length Limits**: There is currently no built-in limit on the
82///   total number of messages in a stream. Handlers should implement their own
83///   message counting if needed. Future versions may add a `max_stream_response_bytes`
84///   field to limit total response size per stream.
85///
86/// # Example
87///
88/// ```ignore
89/// let mut config = GrpcConfig::default();
90/// config.max_message_size = 10 * 1024 * 1024; // 10MB per message
91/// config.max_concurrent_streams = 50; // Advised to HTTP/2 layer
92/// ```
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct GrpcConfig {
95    /// Enable gRPC support
96    #[serde(default = "default_true")]
97    pub enabled: bool,
98
99    /// Maximum message size in bytes (for both sending and receiving)
100    ///
101    /// This limit applies to individual messages in both unary and streaming RPCs.
102    /// When a single message exceeds this size, the request is rejected with HTTP 413
103    /// (Payload Too Large).
104    ///
105    /// Default: 4MB (4194304 bytes)
106    ///
107    /// # Note
108    /// This limit does NOT apply to the total response size in streaming RPCs.
109    /// For multi-message streams, the total response can exceed this limit as long
110    /// as each individual message stays within the limit.
111    #[serde(default = "default_max_message_size")]
112    pub max_message_size: usize,
113
114    /// Enable gzip compression for gRPC messages
115    #[serde(default = "default_true")]
116    pub enable_compression: bool,
117
118    /// Timeout for gRPC requests in seconds (None = no timeout)
119    #[serde(default)]
120    pub request_timeout: Option<u64>,
121
122    /// Maximum number of concurrent streams per connection (HTTP/2 advisory)
123    ///
124    /// This value is communicated to HTTP/2 clients as the server's flow control limit.
125    /// The HTTP/2 transport layer enforces this limit automatically via SETTINGS frames
126    /// and GOAWAY responses. Applications should NOT implement custom enforcement.
127    ///
128    /// Default: 100 streams per connection
129    ///
130    /// # Stream Limiting Strategy
131    /// - **Per Connection**: This limit applies per HTTP/2 connection, not globally
132    /// - **Transport Enforcement**: HTTP/2 handles all stream limiting; applications
133    ///   need not implement custom checks
134    /// - **Streaming Requests**: In server streaming or bidi streaming, each logical
135    ///   RPC consumes one stream slot. Message ordering within a stream follows
136    ///   HTTP/2 frame ordering.
137    ///
138    /// # Future Enhancement
139    /// A future `max_stream_response_bytes` field may be added to limit the total
140    /// response size in streaming RPCs (separate from per-message limits).
141    #[serde(default = "default_max_concurrent_streams")]
142    pub max_concurrent_streams: u32,
143
144    /// Enable HTTP/2 keepalive
145    #[serde(default = "default_true")]
146    pub enable_keepalive: bool,
147
148    /// HTTP/2 keepalive interval in seconds
149    #[serde(default = "default_keepalive_interval")]
150    pub keepalive_interval: u64,
151
152    /// HTTP/2 keepalive timeout in seconds
153    #[serde(default = "default_keepalive_timeout")]
154    pub keepalive_timeout: u64,
155    // Future extension point:
156    // pub max_stream_response_bytes: Option<usize>,  // Total bytes per streaming response
157}
158
159impl Default for GrpcConfig {
160    fn default() -> Self {
161        Self {
162            enabled: true,
163            max_message_size: default_max_message_size(),
164            enable_compression: true,
165            request_timeout: None,
166            max_concurrent_streams: default_max_concurrent_streams(),
167            enable_keepalive: true,
168            keepalive_interval: default_keepalive_interval(),
169            keepalive_timeout: default_keepalive_timeout(),
170        }
171    }
172}
173
/// Serde default for flags that are on unless explicitly disabled.
const fn default_true() -> bool {
    true
}

/// Default per-message size limit: 4 MiB.
const fn default_max_message_size() -> usize {
    const MIB: usize = 1 << 20;
    4 * MIB
}

/// Default `SETTINGS_MAX_CONCURRENT_STREAMS` value per HTTP/2 connection.
const fn default_max_concurrent_streams() -> u32 {
    const STREAMS_PER_CONNECTION: u32 = 100;
    STREAMS_PER_CONNECTION
}

/// Default interval between HTTP/2 keepalive pings, in seconds.
const fn default_keepalive_interval() -> u64 {
    const INTERVAL_SECS: u64 = 75;
    INTERVAL_SECS
}

/// Default time to wait for a keepalive acknowledgement, in seconds.
const fn default_keepalive_timeout() -> u64 {
    const TIMEOUT_SECS: u64 = 20;
    TIMEOUT_SECS
}
193
194/// Registry for gRPC handlers
195///
196/// Maps service and method names to their handlers and RPC modes. Used by the
197/// server to route incoming gRPC requests to the appropriate handler method based
198/// on the parsed gRPC path.
199///
200/// # Example
201///
202/// ```ignore
203/// use spikard_http::grpc::{GrpcRegistry, RpcMode};
204/// use std::sync::Arc;
205///
206/// let mut registry = GrpcRegistry::new();
207/// registry.register("mypackage.UserService", "GetUser", Arc::new(user_handler), RpcMode::Unary);
208/// registry.register(
209///     "mypackage.StreamService",
210///     "StreamUsers",
211///     Arc::new(stream_handler),
212///     RpcMode::ServerStreaming,
213/// );
214/// ```
215type GrpcHandlerEntry = (Arc<dyn GrpcHandler>, RpcMode);
216const WILDCARD_METHOD: &str = "*";
217
218#[derive(Clone)]
219pub struct GrpcRegistry {
220    handlers: Arc<HashMap<(String, String), GrpcHandlerEntry>>,
221}
222
223impl GrpcRegistry {
224    /// Create a new empty gRPC handler registry
225    pub fn new() -> Self {
226        Self {
227            handlers: Arc::new(HashMap::new()),
228        }
229    }
230
231    /// Register a gRPC handler for a specific service method
232    ///
233    /// # Arguments
234    ///
235    /// * `service_name` - Fully qualified service name (e.g., "mypackage.MyService")
236    /// * `method_name` - Method name (e.g., "GetUser")
237    /// * `handler` - Handler implementation for this service
238    /// * `rpc_mode` - The RPC mode this handler supports (Unary, ServerStreaming, etc.)
239    pub fn register(
240        &mut self,
241        service_name: impl Into<String>,
242        method_name: impl Into<String>,
243        handler: Arc<dyn GrpcHandler>,
244        rpc_mode: RpcMode,
245    ) {
246        let handlers = Arc::make_mut(&mut self.handlers);
247        handlers.insert((service_name.into(), method_name.into()), (handler, rpc_mode));
248    }
249
250    /// Register a gRPC handler for an entire service.
251    ///
252    /// This is a service-level fallback for bindings that route methods inside a
253    /// single handler object. Method-specific registrations take precedence over
254    /// these wildcard entries during request dispatch.
255    pub fn register_service(
256        &mut self,
257        service_name: impl Into<String>,
258        handler: Arc<dyn GrpcHandler>,
259        rpc_mode: RpcMode,
260    ) {
261        self.register(service_name, WILDCARD_METHOD, handler, rpc_mode);
262    }
263
264    /// Get a handler and its RPC mode by service and method name
265    ///
266    /// Returns both the handler and the RPC mode it was registered with. Exact
267    /// method matches take precedence over service-level wildcard handlers.
268    pub fn get(&self, service_name: &str, method_name: &str) -> Option<(Arc<dyn GrpcHandler>, RpcMode)> {
269        self.handlers
270            .get(&(service_name.to_owned(), method_name.to_owned()))
271            .or_else(|| {
272                self.handlers
273                    .get(&(service_name.to_owned(), WILDCARD_METHOD.to_owned()))
274            })
275            .cloned()
276    }
277
278    /// Get all registered service names
279    pub fn service_names(&self) -> Vec<String> {
280        self.handlers
281            .keys()
282            .map(|(service_name, _)| service_name.clone())
283            .collect::<HashSet<_>>()
284            .into_iter()
285            .collect()
286    }
287
288    /// Get all explicitly registered method names for a service.
289    pub fn method_names(&self, service_name: &str) -> Vec<String> {
290        self.handlers
291            .keys()
292            .filter(|(registered_service, method_name)| {
293                registered_service == service_name && method_name.as_str() != WILDCARD_METHOD
294            })
295            .map(|(_, method_name)| method_name.clone())
296            .collect()
297    }
298
299    /// Check if a specific service method is registered.
300    pub fn contains(&self, service_name: &str, method_name: &str) -> bool {
301        self.handlers
302            .contains_key(&(service_name.to_owned(), method_name.to_owned()))
303    }
304
305    /// Check if a service has any registered handlers.
306    pub fn contains_service(&self, service_name: &str) -> bool {
307        self.handlers
308            .keys()
309            .any(|(registered_service, _)| registered_service == service_name)
310    }
311
312    /// Get the number of registered services
313    pub fn len(&self) -> usize {
314        self.handlers.len()
315    }
316
317    /// Check if the registry is empty
318    pub fn is_empty(&self) -> bool {
319        self.handlers.is_empty()
320    }
321}
322
323impl Default for GrpcRegistry {
324    fn default() -> Self {
325        Self::new()
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;

    /// Minimal handler used to populate the registry; it always returns an empty payload.
    struct TestHandler;

    impl GrpcHandler for TestHandler {
        fn call(&self, _request: GrpcRequestData) -> Pin<Box<dyn Future<Output = GrpcHandlerResult> + Send>> {
            Box::pin(async {
                Ok(GrpcResponseData {
                    payload: bytes::Bytes::new(),
                    metadata: tonic::metadata::MetadataMap::new(),
                })
            })
        }

        fn service_name(&self) -> &str {
            "test.Service"
        }
    }

    #[test]
    fn test_grpc_config_default() {
        let config = GrpcConfig::default();
        assert!(config.enabled);
        assert_eq!(config.max_message_size, 4 * 1024 * 1024);
        assert!(config.enable_compression);
        assert!(config.request_timeout.is_none());
        assert_eq!(config.max_concurrent_streams, 100);
        assert!(config.enable_keepalive);
        assert_eq!(config.keepalive_interval, 75);
        assert_eq!(config.keepalive_timeout, 20);
    }

    #[test]
    fn test_grpc_config_serialization() {
        let config = GrpcConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: GrpcConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.enabled, deserialized.enabled);
        assert_eq!(config.max_message_size, deserialized.max_message_size);
        assert_eq!(config.enable_compression, deserialized.enable_compression);
        assert_eq!(config.request_timeout, deserialized.request_timeout);
        assert_eq!(config.max_concurrent_streams, deserialized.max_concurrent_streams);
        assert_eq!(config.enable_keepalive, deserialized.enable_keepalive);
        assert_eq!(config.keepalive_interval, deserialized.keepalive_interval);
        assert_eq!(config.keepalive_timeout, deserialized.keepalive_timeout);
    }

    #[test]
    fn test_grpc_config_empty_object_uses_defaults() {
        let config: GrpcConfig = serde_json::from_str("{}").unwrap();
        let defaults = GrpcConfig::default();

        assert_eq!(config.enabled, defaults.enabled);
        assert_eq!(config.max_message_size, defaults.max_message_size);
        assert_eq!(config.enable_compression, defaults.enable_compression);
        assert_eq!(config.request_timeout, defaults.request_timeout);
        assert_eq!(config.max_concurrent_streams, defaults.max_concurrent_streams);
        assert_eq!(config.enable_keepalive, defaults.enable_keepalive);
        assert_eq!(config.keepalive_interval, defaults.keepalive_interval);
        assert_eq!(config.keepalive_timeout, defaults.keepalive_timeout);
    }

    #[test]
    fn test_grpc_registry_new() {
        let registry = GrpcRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_grpc_registry_register() {
        let mut registry = GrpcRegistry::new();
        let handler = Arc::new(TestHandler);

        registry.register("test.Service", "TestMethod", handler, RpcMode::Unary);

        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("test.Service", "TestMethod"));
    }

    #[test]
    fn test_grpc_registry_register_replaces_existing_entry() {
        let mut registry = GrpcRegistry::new();
        registry.register("test.Service", "TestMethod", Arc::new(TestHandler), RpcMode::Unary);
        registry.register(
            "test.Service",
            "TestMethod",
            Arc::new(TestHandler),
            RpcMode::ServerStreaming,
        );

        assert_eq!(registry.len(), 1);
        let (_, mode) = registry.get("test.Service", "TestMethod").unwrap();
        assert_eq!(mode, RpcMode::ServerStreaming);
    }

    #[test]
    fn test_grpc_registry_get() {
        let mut registry = GrpcRegistry::new();
        let handler = Arc::new(TestHandler);

        registry.register("test.Service", "TestMethod", handler, RpcMode::Unary);

        let retrieved = registry.get("test.Service", "TestMethod");
        assert!(retrieved.is_some());
        let (handler, rpc_mode) = retrieved.unwrap();
        assert_eq!(handler.service_name(), "test.Service");
        assert_eq!(rpc_mode, RpcMode::Unary);
    }

    #[test]
    fn test_grpc_registry_get_nonexistent() {
        let registry = GrpcRegistry::new();
        let result = registry.get("nonexistent.Service", "MissingMethod");
        assert!(result.is_none());
    }

    #[test]
    fn test_grpc_registry_service_names() {
        let mut registry = GrpcRegistry::new();

        registry.register("service1", "Method1", Arc::new(TestHandler), RpcMode::Unary);
        registry.register("service2", "Method2", Arc::new(TestHandler), RpcMode::ServerStreaming);
        registry.register("service3", "Method3", Arc::new(TestHandler), RpcMode::Unary);

        let mut names = registry.service_names();
        names.sort();

        assert_eq!(names, vec!["service1", "service2", "service3"]);
    }

    #[test]
    fn test_grpc_registry_service_names_deduplicates() {
        let mut registry = GrpcRegistry::new();
        registry.register("test.Service", "MethodA", Arc::new(TestHandler), RpcMode::Unary);
        registry.register("test.Service", "MethodB", Arc::new(TestHandler), RpcMode::Unary);
        registry.register_service("test.Service", Arc::new(TestHandler), RpcMode::Unary);

        assert_eq!(registry.service_names(), vec!["test.Service"]);
    }

    #[test]
    fn test_grpc_registry_contains() {
        let mut registry = GrpcRegistry::new();
        registry.register("test.Service", "TestMethod", Arc::new(TestHandler), RpcMode::Unary);

        assert!(registry.contains("test.Service", "TestMethod"));
        assert!(!registry.contains("other.Service", "TestMethod"));
    }

    #[test]
    fn test_grpc_registry_contains_ignores_wildcard() {
        let mut registry = GrpcRegistry::new();
        registry.register_service("test.Service", Arc::new(TestHandler), RpcMode::Unary);

        // `contains` only reports exact registrations, while `get` falls back.
        assert!(!registry.contains("test.Service", "AnyMethod"));
        assert!(registry.get("test.Service", "AnyMethod").is_some());
        assert!(!registry.contains_service("other.Service"));
        assert!(registry.get("other.Service", "AnyMethod").is_none());
    }

    #[test]
    fn test_grpc_registry_method_names() {
        let mut registry = GrpcRegistry::new();
        registry.register_service("test.Service", Arc::new(TestHandler), RpcMode::Unary);
        registry.register("test.Service", "MethodB", Arc::new(TestHandler), RpcMode::Unary);
        registry.register("test.Service", "MethodA", Arc::new(TestHandler), RpcMode::Unary);
        registry.register("other.Service", "MethodC", Arc::new(TestHandler), RpcMode::Unary);

        let mut names = registry.method_names("test.Service");
        names.sort();
        assert_eq!(names, vec!["MethodA", "MethodB"]);
        assert!(registry.method_names("missing.Service").is_empty());

        // The service-level registration is an entry of its own.
        assert_eq!(registry.len(), 4);
    }

    #[test]
    fn test_grpc_registry_multiple_services() {
        let mut registry = GrpcRegistry::new();

        registry.register("user.Service", "GetUser", Arc::new(TestHandler), RpcMode::Unary);
        registry.register(
            "post.Service",
            "ListPosts",
            Arc::new(TestHandler),
            RpcMode::ServerStreaming,
        );

        assert_eq!(registry.len(), 2);
        assert!(registry.contains("user.Service", "GetUser"));
        assert!(registry.contains("post.Service", "ListPosts"));
    }

    #[test]
    fn test_grpc_registry_clone() {
        let mut registry = GrpcRegistry::new();
        registry.register("test.Service", "TestMethod", Arc::new(TestHandler), RpcMode::Unary);

        let cloned = registry.clone();

        assert_eq!(cloned.len(), 1);
        assert!(cloned.contains("test.Service", "TestMethod"));
    }

    #[test]
    fn test_grpc_registry_clone_is_isolated_from_later_registrations() {
        let mut registry = GrpcRegistry::new();
        registry.register("test.Service", "TestMethod", Arc::new(TestHandler), RpcMode::Unary);

        let snapshot = registry.clone();
        registry.register("other.Service", "OtherMethod", Arc::new(TestHandler), RpcMode::Unary);

        assert_eq!(snapshot.len(), 1);
        assert!(!snapshot.contains_service("other.Service"));
        assert_eq!(registry.len(), 2);
        assert!(registry.contains("other.Service", "OtherMethod"));
    }

    #[test]
    fn test_grpc_registry_default() {
        let registry = GrpcRegistry::default();
        assert!(registry.is_empty());
    }

    #[test]
    fn test_grpc_registry_rpc_mode_storage() {
        let mut registry = GrpcRegistry::new();

        registry.register("unary.Service", "UnaryMethod", Arc::new(TestHandler), RpcMode::Unary);
        registry.register(
            "server_stream.Service",
            "StreamMethod",
            Arc::new(TestHandler),
            RpcMode::ServerStreaming,
        );
        registry.register(
            "client_stream.Service",
            "UploadMethod",
            Arc::new(TestHandler),
            RpcMode::ClientStreaming,
        );
        registry.register(
            "bidi.Service",
            "ChatMethod",
            Arc::new(TestHandler),
            RpcMode::BidirectionalStreaming,
        );

        let (_, mode) = registry.get("unary.Service", "UnaryMethod").unwrap();
        assert_eq!(mode, RpcMode::Unary);

        let (_, mode) = registry.get("server_stream.Service", "StreamMethod").unwrap();
        assert_eq!(mode, RpcMode::ServerStreaming);

        let (_, mode) = registry.get("client_stream.Service", "UploadMethod").unwrap();
        assert_eq!(mode, RpcMode::ClientStreaming);

        let (_, mode) = registry.get("bidi.Service", "ChatMethod").unwrap();
        assert_eq!(mode, RpcMode::BidirectionalStreaming);
    }

    #[test]
    fn test_grpc_registry_service_fallback() {
        let mut registry = GrpcRegistry::new();
        registry.register_service("test.Service", Arc::new(TestHandler), RpcMode::Unary);

        assert!(registry.contains_service("test.Service"));
        assert!(registry.get("test.Service", "AnyMethod").is_some());
        assert!(registry.method_names("test.Service").is_empty());
    }

    #[test]
    fn test_grpc_registry_prefers_method_specific_handler() {
        struct MethodSpecificHandler;

        impl GrpcHandler for MethodSpecificHandler {
            fn call(&self, _request: GrpcRequestData) -> Pin<Box<dyn Future<Output = GrpcHandlerResult> + Send>> {
                Box::pin(async {
                    Ok(GrpcResponseData {
                        payload: bytes::Bytes::from("method-specific"),
                        metadata: tonic::metadata::MetadataMap::new(),
                    })
                })
            }

            fn service_name(&self) -> &str {
                "test.Service"
            }
        }

        let mut registry = GrpcRegistry::new();
        registry.register_service("test.Service", Arc::new(TestHandler), RpcMode::Unary);
        registry.register(
            "test.Service",
            "GetThing",
            Arc::new(MethodSpecificHandler),
            RpcMode::ServerStreaming,
        );

        let (_, mode) = registry.get("test.Service", "GetThing").unwrap();
        assert_eq!(mode, RpcMode::ServerStreaming);
        let (_, fallback_mode) = registry.get("test.Service", "OtherThing").unwrap();
        assert_eq!(fallback_mode, RpcMode::Unary);
    }
}