1use std::net::SocketAddr;
7use std::sync::Arc;
8
9use http_body_util::BodyExt;
10use prost::bytes::Bytes;
11use shaperail_core::{GrpcConfig, ResourceDefinition};
12use tokio::task::JoinHandle;
13use tonic::server::NamedService;
14use tonic::transport::Server;
15use tonic::Status;
16
17use super::service;
18use crate::auth::extractor::AuthenticatedUser;
19use crate::auth::jwt::JwtConfig;
20use crate::handlers::crud::AppState;
21
22pub struct GrpcServerHandle {
24 pub handle: JoinHandle<Result<(), tonic::transport::Error>>,
25 pub addr: SocketAddr,
26}
27
/// Resource-driven gRPC request handler: resolves a resource by name and dispatches the
/// request message to the matching CRUD handler in `service`.
#[derive(Clone)]
pub struct ShaperailGrpcService {
    /// Application state handed (as a cloned `Arc`) to every handler call.
    state: Arc<AppState>,
    /// Resource definitions that incoming requests are matched against by `resource` name.
    resources: Vec<ResourceDefinition>,
    /// JWT settings used to authenticate callers; when `None`, no bearer token is accepted
    /// and every request is treated as anonymous.
    jwt_config: Option<Arc<JwtConfig>>,
}
35
36impl ShaperailGrpcService {
37 pub fn new(
38 state: Arc<AppState>,
39 resources: Vec<ResourceDefinition>,
40 jwt_config: Option<Arc<JwtConfig>>,
41 ) -> Self {
42 Self {
43 state,
44 resources,
45 jwt_config,
46 }
47 }
48
49 pub fn parse_grpc_path(path: &str) -> Option<(String, String)> {
52 let path = path.strip_prefix('/')?;
53 let (service_part, method) = path.split_once('/')?;
54 let parts: Vec<&str> = service_part.split('.').collect();
55 if parts.len() >= 4 && parts[0] == "shaperail" {
56 let resource_name = parts[2].to_string();
57 Some((resource_name, method.to_string()))
58 } else {
59 None
60 }
61 }
62
63 async fn handle_request(
65 &self,
66 resource_name: &str,
67 method_name: &str,
68 user: Option<&AuthenticatedUser>,
69 body: &[u8],
70 ) -> Result<GrpcResponse, Status> {
71 let resource = self
72 .resources
73 .iter()
74 .find(|r| r.resource == resource_name)
75 .ok_or_else(|| Status::not_found(format!("Unknown resource: {resource_name}")))?;
76
77 if method_name.starts_with("Get") {
78 let data = service::handle_get(self.state.clone(), resource, user, body).await?;
79 Ok(GrpcResponse::Unary(data))
80 } else if method_name.starts_with("Stream") {
81 let items =
82 service::handle_stream_list(self.state.clone(), resource, user, body).await?;
83 Ok(GrpcResponse::Stream(items))
84 } else if method_name.starts_with("List") {
85 let data = service::handle_list(self.state.clone(), resource, user, body).await?;
86 Ok(GrpcResponse::Unary(data))
87 } else if method_name.starts_with("Create") {
88 let data = service::handle_create(self.state.clone(), resource, user, body).await?;
89 Ok(GrpcResponse::Unary(data))
90 } else if method_name.starts_with("Update") {
91 Err(Status::unimplemented("Update not yet implemented"))
92 } else if method_name.starts_with("Delete") {
93 let data = service::handle_delete(self.state.clone(), resource, user, body).await?;
94 Ok(GrpcResponse::Unary(data))
95 } else {
96 Err(Status::unimplemented(format!(
97 "Unknown method: {method_name}"
98 )))
99 }
100 }
101}
102
/// Outcome of `ShaperailGrpcService::handle_request`, before gRPC framing is applied.
enum GrpcResponse {
    /// A single response message payload.
    Unary(Bytes),
    /// Several response message payloads, one per streamed item.
    Stream(Vec<Bytes>),
}
107
/// Boxed HTTP body type used for both the requests and the responses exchanged with tonic.
type TonicBody = tonic::body::BoxBody;
110
/// Tower service adapter that lets tonic's `Server` host a `ShaperailGrpcService`; requests
/// arrive as raw HTTP and are unframed and dispatched by hand.
#[derive(Clone)]
struct ShaperailGrpcServiceServer {
    /// The handler that does the actual resource lookup and CRUD dispatch.
    inner: ShaperailGrpcService,
}
116
// NOTE(review): tonic's `Router::add_service` is understood to route requests by the
// `/{NAME}/...` path prefix. With NAME = "shaperail", paths of the shape that
// `parse_grpc_path` expects (e.g. `/shaperail.v1.users.UserService/GetUser`) do not start
// with `/shaperail/` and may never reach this service. Confirm against the tonic version in use.
impl NamedService for ShaperailGrpcServiceServer {
    /// Service name used by tonic when registering this service and by the health reporter
    /// (`set_serving::<ShaperailGrpcServiceServer>()`).
    const NAME: &'static str = "shaperail";
}
120
121impl tower::Service<http::Request<TonicBody>> for ShaperailGrpcServiceServer {
122 type Response = http::Response<TonicBody>;
123 type Error = std::convert::Infallible;
124 type Future = std::pin::Pin<
125 Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
126 >;
127
128 fn poll_ready(
129 &mut self,
130 _cx: &mut std::task::Context<'_>,
131 ) -> std::task::Poll<Result<(), Self::Error>> {
132 std::task::Poll::Ready(Ok(()))
133 }
134
135 fn call(&mut self, req: http::Request<TonicBody>) -> Self::Future {
136 let inner = self.inner.clone();
137
138 Box::pin(async move {
139 let path = req.uri().path().to_string();
140
141 let user = extract_user_from_headers(req.headers(), inner.jwt_config.as_deref());
143
144 let body_bytes = collect_body(req.into_body()).await;
146
147 let message_data = if body_bytes.len() >= 5 {
149 &body_bytes[5..]
150 } else {
151 &body_bytes[..]
152 };
153
154 let (resource_name, method_name) = match ShaperailGrpcService::parse_grpc_path(&path) {
156 Some(v) => v,
157 None => {
158 return Ok(grpc_error_response(
159 tonic::Code::Unimplemented,
160 &format!("Unknown path: {path}"),
161 ));
162 }
163 };
164
165 match inner
166 .handle_request(&resource_name, &method_name, user.as_ref(), message_data)
167 .await
168 {
169 Ok(GrpcResponse::Unary(data)) => Ok(grpc_data_response(&data)),
170 Ok(GrpcResponse::Stream(items)) => {
171 let mut combined = Vec::new();
172 for item in &items {
173 let len = item.len() as u32;
174 combined.push(0u8);
175 combined.extend_from_slice(&len.to_be_bytes());
176 combined.extend_from_slice(item);
177 }
178 Ok(grpc_data_response(&combined))
179 }
180 Err(status) => Ok(grpc_error_response(status.code(), status.message())),
181 }
182 })
183 }
184}
185
186fn extract_user_from_headers(
188 headers: &http::HeaderMap,
189 jwt_config: Option<&JwtConfig>,
190) -> Option<AuthenticatedUser> {
191 let auth_str = headers.get("authorization")?.to_str().ok()?;
192 let token = auth_str.strip_prefix("Bearer ")?;
193 let jwt = jwt_config?;
194 let claims = jwt.decode(token).ok()?;
195 if claims.token_type != "access" {
196 return None;
197 }
198 Some(AuthenticatedUser {
199 id: claims.sub,
200 role: claims.role,
201 tenant_id: None,
202 })
203}
204
205async fn collect_body(body: TonicBody) -> Bytes {
207 use http_body_util::BodyExt;
208 match body.collect().await {
209 Ok(collected) => collected.to_bytes(),
210 Err(_) => Bytes::new(),
211 }
212}
213
214fn grpc_data_response(data: &[u8]) -> http::Response<TonicBody> {
216 let mut frame = Vec::with_capacity(5 + data.len());
218 frame.push(0u8);
219 let len = data.len() as u32;
220 frame.extend_from_slice(&len.to_be_bytes());
221 frame.extend_from_slice(data);
222
223 let body = http_body_util::Full::new(Bytes::from(frame))
224 .map_err(|never: std::convert::Infallible| match never {});
225 let boxed = TonicBody::new(body);
226
227 http::Response::builder()
228 .status(200)
229 .header("content-type", "application/grpc")
230 .header("grpc-status", "0")
231 .body(boxed)
232 .unwrap_or_else(|_| empty_grpc_response(13, "Internal error"))
233}
234
235fn grpc_error_response(code: tonic::Code, message: &str) -> http::Response<TonicBody> {
237 empty_grpc_response(code as i32, message)
238}
239
240fn empty_grpc_response(code: i32, message: &str) -> http::Response<TonicBody> {
242 let body = http_body_util::Full::new(Bytes::new())
243 .map_err(|never: std::convert::Infallible| match never {});
244 let boxed = TonicBody::new(body);
245
246 http::Response::builder()
247 .status(200)
248 .header("content-type", "application/grpc")
249 .header("grpc-status", code.to_string())
250 .header("grpc-message", message)
251 .body(boxed)
252 .unwrap_or_else(|_| {
253 let fb = http_body_util::Full::new(Bytes::new())
255 .map_err(|never: std::convert::Infallible| match never {});
256 http::Response::new(TonicBody::new(fb))
257 })
258}
259
260pub async fn build_grpc_server(
265 state: Arc<AppState>,
266 resources: Vec<ResourceDefinition>,
267 jwt_config: Option<Arc<JwtConfig>>,
268 grpc_config: Option<&GrpcConfig>,
269) -> Result<GrpcServerHandle, Box<dyn std::error::Error + Send + Sync>> {
270 let port = grpc_config.map(|c| c.port).unwrap_or(50051);
271 let reflection_enabled = grpc_config.map(|c| c.reflection).unwrap_or(true);
272
273 let addr: SocketAddr = format!("0.0.0.0:{port}").parse()?;
274
275 let svc = ShaperailGrpcService::new(state, resources.clone(), jwt_config);
276 let grpc_service = ShaperailGrpcServiceServer { inner: svc };
277
278 let (mut health_reporter, health_service) = tonic_health::server::health_reporter();
280 health_reporter
281 .set_serving::<ShaperailGrpcServiceServer>()
282 .await;
283
284 for resource in &resources {
285 let pascal = to_pascal_case(&to_singular(&resource.resource));
286 let service_name = format!(
287 "shaperail.v{}.{}.{}Service",
288 resource.version, resource.resource, pascal
289 );
290 health_reporter
291 .set_service_status(&service_name, tonic_health::ServingStatus::Serving)
292 .await;
293 }
294
295 let mut builder = Server::builder();
296
297 let handle = if reflection_enabled {
298 let reflection_service = tonic_reflection::server::Builder::configure()
299 .build_v1()
300 .map_err(|e| format!("Failed to build reflection service: {e}"))?;
301
302 let router = builder
303 .add_service(health_service)
304 .add_service(reflection_service)
305 .add_service(grpc_service);
306
307 tokio::spawn(async move { router.serve(addr).await })
308 } else {
309 let router = builder
310 .add_service(health_service)
311 .add_service(grpc_service);
312
313 tokio::spawn(async move { router.serve(addr).await })
314 };
315
316 tracing::info!("gRPC server listening on {addr}");
317
318 Ok(GrpcServerHandle { handle, addr })
319}
320
/// Converts a `snake_case` identifier to `PascalCase`.
///
/// The first character of every `_`-separated word is uppercased, using Unicode uppercase
/// mapping, which may expand to several characters. Empty words from leading, trailing or
/// repeated underscores contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut rest = word.chars();
        if let Some(first) = rest.next() {
            pascal.extend(first.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
335
/// Heuristically converts an English plural resource name to its singular form.
///
/// Rules, applied in order:
/// 1. names ending in a known exception (`status`, `bus`, `alias`, `canvas`) are unchanged;
/// 2. `…ies` becomes `…y`;
/// 3. `…ses`, `…xes` and `…zes` lose the trailing `es`;
/// 4. a single trailing `s` is dropped, but `…ss` is left unchanged;
/// 5. anything else is unchanged.
///
/// NOTE(review): the result is used to build gRPC service names, so it presumably has to
/// agree with whatever generates the proto definitions. Confirm before changing these rules.
fn to_singular(s: &str) -> String {
    const EXCEPTIONS: &[&str] = &["status", "bus", "alias", "canvas"];
    const ES_SUFFIXES: [&str; 3] = ["ses", "xes", "zes"];

    if EXCEPTIONS.iter().any(|suffix| s.ends_with(suffix)) {
        return s.to_owned();
    }
    if let Some(stem) = s.strip_suffix("ies") {
        return stem.to_owned() + "y";
    }
    if ES_SUFFIXES.iter().any(|suffix| s.ends_with(suffix)) {
        // The matched suffix is ASCII, so `len - 2` is always a char boundary.
        return s[..s.len() - 2].to_owned();
    }
    match s.strip_suffix('s') {
        Some(stem) if !stem.ends_with('s') => stem.to_owned(),
        _ => s.to_owned(),
    }
}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Secret passed to `JwtConfig::new` in the token tests.
    const TEST_SECRET: &str = "test-secret-key-at-least-32-bytes-long!";

    /// The `(resource, method)` pair `parse_grpc_path` is expected to yield.
    fn route(resource: &str, method: &str) -> Option<(String, String)> {
        Some((resource.to_owned(), method.to_owned()))
    }

    /// A header map containing only `authorization: <value>`.
    fn authorization_headers(value: &str) -> http::HeaderMap {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            "authorization",
            http::HeaderValue::from_str(value).unwrap(),
        );
        headers
    }

    #[test]
    fn parse_grpc_path_valid() {
        let parsed =
            ShaperailGrpcService::parse_grpc_path("/shaperail.v1.users.UserService/GetUser");
        assert_eq!(parsed, route("users", "GetUser"));
    }

    #[test]
    fn parse_grpc_path_list() {
        let parsed =
            ShaperailGrpcService::parse_grpc_path("/shaperail.v1.orders.OrderService/ListOrders");
        assert_eq!(parsed, route("orders", "ListOrders"));
    }

    #[test]
    fn parse_grpc_path_invalid() {
        for path in ["/invalid", ""] {
            assert!(
                ShaperailGrpcService::parse_grpc_path(path).is_none(),
                "expected no route for {path:?}"
            );
        }
    }

    #[test]
    fn parse_grpc_path_stream() {
        let parsed =
            ShaperailGrpcService::parse_grpc_path("/shaperail.v1.users.UserService/StreamUsers");
        assert_eq!(parsed, route("users", "StreamUsers"));
    }

    #[test]
    fn pascal_and_singular() {
        for (input, expected) in [("user", "User"), ("blog_post", "BlogPost")] {
            assert_eq!(to_pascal_case(input), expected);
        }
        for (input, expected) in [("users", "user"), ("categories", "category")] {
            assert_eq!(to_singular(input), expected);
        }
    }

    #[test]
    fn extract_user_no_header() {
        assert!(extract_user_from_headers(&http::HeaderMap::new(), None).is_none());
    }

    #[test]
    fn extract_user_valid_token() {
        let jwt = JwtConfig::new(TEST_SECRET, 3600, 86400);
        let token = jwt.encode_access("user-1", "admin").unwrap();
        let headers = authorization_headers(&format!("Bearer {token}"));

        let user = extract_user_from_headers(&headers, Some(&jwt))
            .expect("a valid access token should authenticate");
        assert_eq!(user.id, "user-1");
        assert_eq!(user.role, "admin");
    }

    #[test]
    fn extract_user_invalid_token() {
        let jwt = JwtConfig::new(TEST_SECRET, 3600, 86400);
        let headers = authorization_headers("Bearer invalid.token.here");
        assert!(extract_user_from_headers(&headers, Some(&jwt)).is_none());
    }
}