Skip to main content

shaperail_runtime/grpc/
server.rs

1//! gRPC server builder (M16).
2//!
3//! Builds a Tonic gRPC server with dynamic resource services, JWT auth
4//! interceptor, server reflection, and health check.
5
6use std::net::SocketAddr;
7use std::sync::Arc;
8
9use http_body_util::BodyExt;
10use prost::bytes::Bytes;
11use shaperail_core::{GrpcConfig, ResourceDefinition};
12use tokio::task::JoinHandle;
13use tonic::server::NamedService;
14use tonic::transport::Server;
15use tonic::Status;
16
17use super::service;
18use crate::auth::extractor::AuthenticatedUser;
19use crate::auth::jwt::JwtConfig;
20use crate::handlers::crud::AppState;
21
22/// Handle to a running gRPC server — can be used to await or abort.
23pub struct GrpcServerHandle {
24    pub handle: JoinHandle<Result<(), tonic::transport::Error>>,
25    pub addr: SocketAddr,
26}
27
28/// Dynamic gRPC service that routes to resource handlers based on path.
29#[derive(Clone)]
30pub struct ShaperailGrpcService {
31    state: Arc<AppState>,
32    resources: Vec<ResourceDefinition>,
33    jwt_config: Option<Arc<JwtConfig>>,
34}
35
36impl ShaperailGrpcService {
37    pub fn new(
38        state: Arc<AppState>,
39        resources: Vec<ResourceDefinition>,
40        jwt_config: Option<Arc<JwtConfig>>,
41    ) -> Self {
42        Self {
43            state,
44            resources,
45            jwt_config,
46        }
47    }
48
49    /// Parse a gRPC path like `/shaperail.v1.users.UserService/GetUser`
50    /// into (resource_name, method_name).
51    pub fn parse_grpc_path(path: &str) -> Option<(String, String)> {
52        let path = path.strip_prefix('/')?;
53        let (service_part, method) = path.split_once('/')?;
54        let parts: Vec<&str> = service_part.split('.').collect();
55        if parts.len() >= 4 && parts[0] == "shaperail" {
56            let resource_name = parts[2].to_string();
57            Some((resource_name, method.to_string()))
58        } else {
59            None
60        }
61    }
62
63    /// Handle a unary or server-streaming gRPC call.
64    async fn handle_request(
65        &self,
66        resource_name: &str,
67        method_name: &str,
68        user: Option<&AuthenticatedUser>,
69        body: &[u8],
70    ) -> Result<GrpcResponse, Status> {
71        let resource = self
72            .resources
73            .iter()
74            .find(|r| r.resource == resource_name)
75            .ok_or_else(|| Status::not_found(format!("Unknown resource: {resource_name}")))?;
76
77        if method_name.starts_with("Get") {
78            let data = service::handle_get(self.state.clone(), resource, user, body).await?;
79            Ok(GrpcResponse::Unary(data))
80        } else if method_name.starts_with("Stream") {
81            let items =
82                service::handle_stream_list(self.state.clone(), resource, user, body).await?;
83            Ok(GrpcResponse::Stream(items))
84        } else if method_name.starts_with("List") {
85            let data = service::handle_list(self.state.clone(), resource, user, body).await?;
86            Ok(GrpcResponse::Unary(data))
87        } else if method_name.starts_with("Create") {
88            let data = service::handle_create(self.state.clone(), resource, user, body).await?;
89            Ok(GrpcResponse::Unary(data))
90        } else if method_name.starts_with("Update") {
91            Err(Status::unimplemented("Update not yet implemented"))
92        } else if method_name.starts_with("Delete") {
93            let data = service::handle_delete(self.state.clone(), resource, user, body).await?;
94            Ok(GrpcResponse::Unary(data))
95        } else {
96            Err(Status::unimplemented(format!(
97                "Unknown method: {method_name}"
98            )))
99        }
100    }
101}
102
103enum GrpcResponse {
104    Unary(Bytes),
105    Stream(Vec<Bytes>),
106}
107
108/// The tonic body type used in 0.12.
109type TonicBody = tonic::body::BoxBody;
110
111/// Wrapper implementing tonic's Service trait for dynamic dispatch.
112#[derive(Clone)]
113struct ShaperailGrpcServiceServer {
114    inner: ShaperailGrpcService,
115}
116
117impl NamedService for ShaperailGrpcServiceServer {
118    const NAME: &'static str = "shaperail";
119}
120
121impl tower::Service<http::Request<TonicBody>> for ShaperailGrpcServiceServer {
122    type Response = http::Response<TonicBody>;
123    type Error = std::convert::Infallible;
124    type Future = std::pin::Pin<
125        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
126    >;
127
128    fn poll_ready(
129        &mut self,
130        _cx: &mut std::task::Context<'_>,
131    ) -> std::task::Poll<Result<(), Self::Error>> {
132        std::task::Poll::Ready(Ok(()))
133    }
134
135    fn call(&mut self, req: http::Request<TonicBody>) -> Self::Future {
136        let inner = self.inner.clone();
137
138        Box::pin(async move {
139            let path = req.uri().path().to_string();
140
141            // Extract auth from headers
142            let user = extract_user_from_headers(req.headers(), inner.jwt_config.as_deref());
143
144            // Collect body bytes
145            let body_bytes = collect_body(req.into_body()).await;
146
147            // Strip gRPC framing: 1 byte compression + 4 bytes length
148            let message_data = if body_bytes.len() >= 5 {
149                &body_bytes[5..]
150            } else {
151                &body_bytes[..]
152            };
153
154            // Parse path and dispatch
155            let (resource_name, method_name) = match ShaperailGrpcService::parse_grpc_path(&path) {
156                Some(v) => v,
157                None => {
158                    return Ok(grpc_error_response(
159                        tonic::Code::Unimplemented,
160                        &format!("Unknown path: {path}"),
161                    ));
162                }
163            };
164
165            match inner
166                .handle_request(&resource_name, &method_name, user.as_ref(), message_data)
167                .await
168            {
169                Ok(GrpcResponse::Unary(data)) => Ok(grpc_data_response(&data)),
170                Ok(GrpcResponse::Stream(items)) => {
171                    let mut combined = Vec::new();
172                    for item in &items {
173                        let len = item.len() as u32;
174                        combined.push(0u8);
175                        combined.extend_from_slice(&len.to_be_bytes());
176                        combined.extend_from_slice(item);
177                    }
178                    Ok(grpc_data_response(&combined))
179                }
180                Err(status) => Ok(grpc_error_response(status.code(), status.message())),
181            }
182        })
183    }
184}
185
186/// Extract a user from HTTP headers (for JWT auth via gRPC metadata).
187fn extract_user_from_headers(
188    headers: &http::HeaderMap,
189    jwt_config: Option<&JwtConfig>,
190) -> Option<AuthenticatedUser> {
191    let auth_str = headers.get("authorization")?.to_str().ok()?;
192    let token = auth_str.strip_prefix("Bearer ")?;
193    let jwt = jwt_config?;
194    let claims = jwt.decode(token).ok()?;
195    if claims.token_type != "access" {
196        return None;
197    }
198    Some(AuthenticatedUser {
199        id: claims.sub,
200        role: claims.role,
201        tenant_id: None,
202    })
203}
204
205/// Collect body bytes from a tonic BoxBody.
206async fn collect_body(body: TonicBody) -> Bytes {
207    use http_body_util::BodyExt;
208    match body.collect().await {
209        Ok(collected) => collected.to_bytes(),
210        Err(_) => Bytes::new(),
211    }
212}
213
214/// Build a successful gRPC response with data.
215fn grpc_data_response(data: &[u8]) -> http::Response<TonicBody> {
216    // gRPC frame: 0 (no compression) + 4 byte big-endian length + data
217    let mut frame = Vec::with_capacity(5 + data.len());
218    frame.push(0u8);
219    let len = data.len() as u32;
220    frame.extend_from_slice(&len.to_be_bytes());
221    frame.extend_from_slice(data);
222
223    let body = http_body_util::Full::new(Bytes::from(frame))
224        .map_err(|never: std::convert::Infallible| match never {});
225    let boxed = TonicBody::new(body);
226
227    http::Response::builder()
228        .status(200)
229        .header("content-type", "application/grpc")
230        .header("grpc-status", "0")
231        .body(boxed)
232        .unwrap_or_else(|_| empty_grpc_response(13, "Internal error"))
233}
234
235/// Build a gRPC error response.
236fn grpc_error_response(code: tonic::Code, message: &str) -> http::Response<TonicBody> {
237    empty_grpc_response(code as i32, message)
238}
239
240/// Build an empty gRPC response with status and message headers.
241fn empty_grpc_response(code: i32, message: &str) -> http::Response<TonicBody> {
242    let body = http_body_util::Full::new(Bytes::new())
243        .map_err(|never: std::convert::Infallible| match never {});
244    let boxed = TonicBody::new(body);
245
246    http::Response::builder()
247        .status(200)
248        .header("content-type", "application/grpc")
249        .header("grpc-status", code.to_string())
250        .header("grpc-message", message)
251        .body(boxed)
252        .unwrap_or_else(|_| {
253            // Last resort fallback
254            let fb = http_body_util::Full::new(Bytes::new())
255                .map_err(|never: std::convert::Infallible| match never {});
256            http::Response::new(TonicBody::new(fb))
257        })
258}
259
260/// Build and start the gRPC server.
261///
262/// Returns a `GrpcServerHandle` that can be awaited or aborted.
263/// The server runs on a separate port from the HTTP REST/GraphQL server.
264pub async fn build_grpc_server(
265    state: Arc<AppState>,
266    resources: Vec<ResourceDefinition>,
267    jwt_config: Option<Arc<JwtConfig>>,
268    grpc_config: Option<&GrpcConfig>,
269) -> Result<GrpcServerHandle, Box<dyn std::error::Error + Send + Sync>> {
270    let port = grpc_config.map(|c| c.port).unwrap_or(50051);
271    let reflection_enabled = grpc_config.map(|c| c.reflection).unwrap_or(true);
272
273    let addr: SocketAddr = format!("0.0.0.0:{port}").parse()?;
274
275    let svc = ShaperailGrpcService::new(state, resources.clone(), jwt_config);
276    let grpc_service = ShaperailGrpcServiceServer { inner: svc };
277
278    // Health service
279    let (mut health_reporter, health_service) = tonic_health::server::health_reporter();
280    health_reporter
281        .set_serving::<ShaperailGrpcServiceServer>()
282        .await;
283
284    for resource in &resources {
285        let pascal = to_pascal_case(&to_singular(&resource.resource));
286        let service_name = format!(
287            "shaperail.v{}.{}.{}Service",
288            resource.version, resource.resource, pascal
289        );
290        health_reporter
291            .set_service_status(&service_name, tonic_health::ServingStatus::Serving)
292            .await;
293    }
294
295    let mut builder = Server::builder();
296
297    let handle = if reflection_enabled {
298        let reflection_service = tonic_reflection::server::Builder::configure()
299            .build_v1()
300            .map_err(|e| format!("Failed to build reflection service: {e}"))?;
301
302        let router = builder
303            .add_service(health_service)
304            .add_service(reflection_service)
305            .add_service(grpc_service);
306
307        tokio::spawn(async move { router.serve(addr).await })
308    } else {
309        let router = builder
310            .add_service(health_service)
311            .add_service(grpc_service);
312
313        tokio::spawn(async move { router.serve(addr).await })
314    };
315
316    tracing::info!("gRPC server listening on {addr}");
317
318    Ok(GrpcServerHandle { handle, addr })
319}
320
/// Convert a `snake_case` identifier to `PascalCase`.
///
/// The input is split on `_`; the first character of each piece is
/// upper-cased and the rest is kept unchanged. Empty pieces (from leading,
/// trailing or doubled underscores) contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut rest = word.chars();
        if let Some(first) = rest.next() {
            // Upper-casing can expand one char into several, hence `extend`.
            pascal.extend(first.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
335
/// Derive a singular form from a plural resource name using suffix rules.
///
/// Rules, applied in order:
/// 1. names ending in `status`, `bus`, `alias` or `canvas` are returned unchanged;
/// 2. a trailing `ies` becomes `y`;
/// 3. names ending in `ses`, `xes` or `zes` lose their last two characters;
/// 4. a single trailing `s` is dropped, unless that would leave a name that
///    still ends in `s` (e.g. `class`), in which case the name is unchanged;
/// 5. anything else is returned unchanged.
fn to_singular(s: &str) -> String {
    let is_exception = ["status", "bus", "alias", "canvas"]
        .iter()
        .any(|word| s.ends_with(word));
    if is_exception {
        return s.to_owned();
    }

    if let Some(stem) = s.strip_suffix("ies") {
        return format!("{stem}y");
    }

    let drops_es = ["ses", "xes", "zes"].iter().any(|suffix| s.ends_with(suffix));
    if drops_es {
        // The suffix is ASCII, so cutting two bytes stays on a char boundary.
        return s[..s.len() - 2].to_owned();
    }

    match s.strip_suffix('s') {
        Some(stem) if !stem.ends_with('s') => stem.to_owned(),
        _ => s.to_owned(),
    }
}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Secret shared by the JWT tests.
    const TEST_SECRET: &str = "test-secret-key-at-least-32-bytes-long!";

    /// Build a header map holding only the given `authorization` value.
    fn headers_with_authorization(value: &str) -> http::HeaderMap {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            "authorization",
            http::HeaderValue::from_str(value).unwrap(),
        );
        headers
    }

    /// Shorthand for the expected `(resource, method)` pair.
    fn parsed(resource: &str, method: &str) -> Option<(String, String)> {
        Some((resource.to_string(), method.to_string()))
    }

    #[test]
    fn parse_grpc_path_valid() {
        assert_eq!(
            ShaperailGrpcService::parse_grpc_path("/shaperail.v1.users.UserService/GetUser"),
            parsed("users", "GetUser")
        );
    }

    #[test]
    fn parse_grpc_path_list() {
        assert_eq!(
            ShaperailGrpcService::parse_grpc_path("/shaperail.v1.orders.OrderService/ListOrders"),
            parsed("orders", "ListOrders")
        );
    }

    #[test]
    fn parse_grpc_path_invalid() {
        for path in ["/invalid", ""] {
            assert!(ShaperailGrpcService::parse_grpc_path(path).is_none());
        }
    }

    #[test]
    fn parse_grpc_path_stream() {
        assert_eq!(
            ShaperailGrpcService::parse_grpc_path("/shaperail.v1.users.UserService/StreamUsers"),
            parsed("users", "StreamUsers")
        );
    }

    #[test]
    fn pascal_and_singular() {
        assert_eq!(to_pascal_case("user"), "User");
        assert_eq!(to_pascal_case("blog_post"), "BlogPost");
        assert_eq!(to_singular("users"), "user");
        assert_eq!(to_singular("categories"), "category");
    }

    #[test]
    fn extract_user_no_header() {
        let headers = http::HeaderMap::new();
        assert!(extract_user_from_headers(&headers, None).is_none());
    }

    #[test]
    fn extract_user_valid_token() {
        let jwt = JwtConfig::new(TEST_SECRET, 3600, 86400);
        let token = jwt.encode_access("user-1", "admin").unwrap();
        let headers = headers_with_authorization(&format!("Bearer {token}"));

        let user = extract_user_from_headers(&headers, Some(&jwt));
        assert!(user.is_some());
        let user = user.unwrap();
        assert_eq!(user.id, "user-1");
        assert_eq!(user.role, "admin");
    }

    #[test]
    fn extract_user_invalid_token() {
        let jwt = JwtConfig::new(TEST_SECRET, 3600, 86400);
        let headers = headers_with_authorization("Bearer invalid.token.here");

        assert!(extract_user_from_headers(&headers, Some(&jwt)).is_none());
    }
}
437}