1use std::net::SocketAddr;
7use std::sync::Arc;
8
9use http_body_util::BodyExt;
10use prost::bytes::Bytes;
11use shaperail_core::{GrpcConfig, ResourceDefinition};
12use tokio::task::JoinHandle;
13use tonic::server::NamedService;
14use tonic::transport::Server;
15use tonic::Status;
16
17use super::service;
18use crate::auth::extractor::AuthenticatedUser;
19use crate::auth::jwt::JwtConfig;
20use crate::handlers::crud::AppState;
21
22pub struct GrpcServerHandle {
24 pub handle: JoinHandle<Result<(), tonic::transport::Error>>,
25 pub addr: SocketAddr,
26}
27
28#[derive(Clone)]
30pub struct ShaperailGrpcService {
31 state: Arc<AppState>,
32 resources: Vec<ResourceDefinition>,
33 jwt_config: Option<Arc<JwtConfig>>,
34}
35
36impl ShaperailGrpcService {
37 pub fn new(
38 state: Arc<AppState>,
39 resources: Vec<ResourceDefinition>,
40 jwt_config: Option<Arc<JwtConfig>>,
41 ) -> Self {
42 Self {
43 state,
44 resources,
45 jwt_config,
46 }
47 }
48
49 pub fn parse_grpc_path(path: &str) -> Option<(String, String)> {
52 let path = path.strip_prefix('/')?;
53 let (service_part, method) = path.split_once('/')?;
54 let parts: Vec<&str> = service_part.split('.').collect();
55 if parts.len() >= 4 && parts[0] == "shaperail" {
56 let resource_name = parts[2].to_string();
57 Some((resource_name, method.to_string()))
58 } else {
59 None
60 }
61 }
62
63 async fn handle_request(
65 &self,
66 resource_name: &str,
67 method_name: &str,
68 user: Option<&AuthenticatedUser>,
69 body: &[u8],
70 ) -> Result<GrpcResponse, Status> {
71 let resource = self
72 .resources
73 .iter()
74 .find(|r| r.resource == resource_name)
75 .ok_or_else(|| Status::not_found(format!("Unknown resource: {resource_name}")))?;
76
77 if method_name.starts_with("Get") {
78 let data = service::handle_get(self.state.clone(), resource, user, body).await?;
79 Ok(GrpcResponse::Unary(data))
80 } else if method_name.starts_with("Stream") {
81 let items =
82 service::handle_stream_list(self.state.clone(), resource, user, body).await?;
83 Ok(GrpcResponse::Stream(items))
84 } else if method_name.starts_with("List") {
85 let data = service::handle_list(self.state.clone(), resource, user, body).await?;
86 Ok(GrpcResponse::Unary(data))
87 } else if method_name.starts_with("Create") {
88 let data = service::handle_create(self.state.clone(), resource, user, body).await?;
89 Ok(GrpcResponse::Unary(data))
90 } else if method_name.starts_with("Update") {
91 Err(Status::unimplemented("Update not yet implemented"))
92 } else if method_name.starts_with("Delete") {
93 let data = service::handle_delete(self.state.clone(), resource, user, body).await?;
94 Ok(GrpcResponse::Unary(data))
95 } else {
96 Err(Status::unimplemented(format!(
97 "Unknown method: {method_name}"
98 )))
99 }
100 }
101}
102
103enum GrpcResponse {
104 Unary(Bytes),
105 Stream(Vec<Bytes>),
106}
107
108type TonicBody = tonic::body::BoxBody;
110
111#[derive(Clone)]
113struct ShaperailGrpcServiceServer {
114 inner: ShaperailGrpcService,
115}
116
117impl NamedService for ShaperailGrpcServiceServer {
118 const NAME: &'static str = "shaperail";
119}
120
121impl tower::Service<http::Request<TonicBody>> for ShaperailGrpcServiceServer {
122 type Response = http::Response<TonicBody>;
123 type Error = std::convert::Infallible;
124 type Future = std::pin::Pin<
125 Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
126 >;
127
128 fn poll_ready(
129 &mut self,
130 _cx: &mut std::task::Context<'_>,
131 ) -> std::task::Poll<Result<(), Self::Error>> {
132 std::task::Poll::Ready(Ok(()))
133 }
134
135 fn call(&mut self, req: http::Request<TonicBody>) -> Self::Future {
136 let inner = self.inner.clone();
137
138 Box::pin(async move {
139 let path = req.uri().path().to_string();
140
141 let user = extract_user_from_headers(req.headers(), inner.jwt_config.as_deref());
143
144 let body_bytes = collect_body(req.into_body()).await;
146
147 let message_data = if body_bytes.len() >= 5 {
149 &body_bytes[5..]
150 } else {
151 &body_bytes[..]
152 };
153
154 let (resource_name, method_name) = match ShaperailGrpcService::parse_grpc_path(&path) {
156 Some(v) => v,
157 None => {
158 return Ok(grpc_error_response(
159 tonic::Code::Unimplemented,
160 &format!("Unknown path: {path}"),
161 ));
162 }
163 };
164
165 match inner
166 .handle_request(&resource_name, &method_name, user.as_ref(), message_data)
167 .await
168 {
169 Ok(GrpcResponse::Unary(data)) => Ok(grpc_data_response(&data)),
170 Ok(GrpcResponse::Stream(items)) => {
171 let mut combined = Vec::new();
172 for item in &items {
173 let len = item.len() as u32;
174 combined.push(0u8);
175 combined.extend_from_slice(&len.to_be_bytes());
176 combined.extend_from_slice(item);
177 }
178 Ok(grpc_data_response(&combined))
179 }
180 Err(status) => Ok(grpc_error_response(status.code(), status.message())),
181 }
182 })
183 }
184}
185
186fn extract_user_from_headers(
188 headers: &http::HeaderMap,
189 jwt_config: Option<&JwtConfig>,
190) -> Option<AuthenticatedUser> {
191 let auth_str = headers.get("authorization")?.to_str().ok()?;
192 let token = auth_str.strip_prefix("Bearer ")?;
193 let jwt = jwt_config?;
194 let claims = jwt.decode(token).ok()?;
195 if claims.token_type != "access" {
196 return None;
197 }
198 Some(AuthenticatedUser {
199 id: claims.sub,
200 role: claims.role,
201 })
202}
203
204async fn collect_body(body: TonicBody) -> Bytes {
206 use http_body_util::BodyExt;
207 match body.collect().await {
208 Ok(collected) => collected.to_bytes(),
209 Err(_) => Bytes::new(),
210 }
211}
212
213fn grpc_data_response(data: &[u8]) -> http::Response<TonicBody> {
215 let mut frame = Vec::with_capacity(5 + data.len());
217 frame.push(0u8);
218 let len = data.len() as u32;
219 frame.extend_from_slice(&len.to_be_bytes());
220 frame.extend_from_slice(data);
221
222 let body = http_body_util::Full::new(Bytes::from(frame))
223 .map_err(|never: std::convert::Infallible| match never {});
224 let boxed = TonicBody::new(body);
225
226 http::Response::builder()
227 .status(200)
228 .header("content-type", "application/grpc")
229 .header("grpc-status", "0")
230 .body(boxed)
231 .unwrap_or_else(|_| empty_grpc_response(13, "Internal error"))
232}
233
234fn grpc_error_response(code: tonic::Code, message: &str) -> http::Response<TonicBody> {
236 empty_grpc_response(code as i32, message)
237}
238
239fn empty_grpc_response(code: i32, message: &str) -> http::Response<TonicBody> {
241 let body = http_body_util::Full::new(Bytes::new())
242 .map_err(|never: std::convert::Infallible| match never {});
243 let boxed = TonicBody::new(body);
244
245 http::Response::builder()
246 .status(200)
247 .header("content-type", "application/grpc")
248 .header("grpc-status", code.to_string())
249 .header("grpc-message", message)
250 .body(boxed)
251 .unwrap_or_else(|_| {
252 let fb = http_body_util::Full::new(Bytes::new())
254 .map_err(|never: std::convert::Infallible| match never {});
255 http::Response::new(TonicBody::new(fb))
256 })
257}
258
259pub async fn build_grpc_server(
264 state: Arc<AppState>,
265 resources: Vec<ResourceDefinition>,
266 jwt_config: Option<Arc<JwtConfig>>,
267 grpc_config: Option<&GrpcConfig>,
268) -> Result<GrpcServerHandle, Box<dyn std::error::Error + Send + Sync>> {
269 let port = grpc_config.map(|c| c.port).unwrap_or(50051);
270 let reflection_enabled = grpc_config.map(|c| c.reflection).unwrap_or(true);
271
272 let addr: SocketAddr = format!("0.0.0.0:{port}").parse()?;
273
274 let svc = ShaperailGrpcService::new(state, resources.clone(), jwt_config);
275 let grpc_service = ShaperailGrpcServiceServer { inner: svc };
276
277 let (mut health_reporter, health_service) = tonic_health::server::health_reporter();
279 health_reporter
280 .set_serving::<ShaperailGrpcServiceServer>()
281 .await;
282
283 for resource in &resources {
284 let pascal = to_pascal_case(&to_singular(&resource.resource));
285 let service_name = format!(
286 "shaperail.v{}.{}.{}Service",
287 resource.version, resource.resource, pascal
288 );
289 health_reporter
290 .set_service_status(&service_name, tonic_health::ServingStatus::Serving)
291 .await;
292 }
293
294 let mut builder = Server::builder();
295
296 let handle = if reflection_enabled {
297 let reflection_service = tonic_reflection::server::Builder::configure()
298 .build_v1()
299 .map_err(|e| format!("Failed to build reflection service: {e}"))?;
300
301 let router = builder
302 .add_service(health_service)
303 .add_service(reflection_service)
304 .add_service(grpc_service);
305
306 tokio::spawn(async move { router.serve(addr).await })
307 } else {
308 let router = builder
309 .add_service(health_service)
310 .add_service(grpc_service);
311
312 tokio::spawn(async move { router.serve(addr).await })
313 };
314
315 tracing::info!("gRPC server listening on {addr}");
316
317 Ok(GrpcServerHandle { handle, addr })
318}
319
/// Converts a `snake_case` identifier to `PascalCase`.
///
/// Each `_`-separated segment has its first character uppercased and the rest
/// kept as-is. Empty segments (from leading, trailing or doubled underscores)
/// contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut chars = segment.chars();
        if let Some(first) = chars.next() {
            out.extend(first.to_uppercase());
            out.push_str(chars.as_str());
        }
    }
    out
}
334
/// Heuristically converts an English plural resource name to its singular form.
///
/// Rules, in order:
/// 1. Words ending in a known invariant suffix are returned unchanged.
/// 2. `-ies` becomes `-y`.
/// 3. `-ses`, `-xes` and `-zes` drop the trailing `es`.
/// 4. A single trailing `s` is dropped, unless it follows another `s` (e.g. `class`).
///
/// The output feeds service names, so the rules must stay stable.
fn to_singular(s: &str) -> String {
    const INVARIANT_SUFFIXES: &[&str] = &["status", "bus", "alias", "canvas"];
    if INVARIANT_SUFFIXES.iter().any(|suffix| s.ends_with(*suffix)) {
        return s.to_string();
    }

    if let Some(stem) = s.strip_suffix("ies") {
        return format!("{stem}y");
    }

    let sibilant_plural = ["ses", "xes", "zes"]
        .iter()
        .any(|suffix| s.ends_with(*suffix));
    if sibilant_plural {
        // The matched suffix is ASCII, so slicing off two bytes stays on a char boundary.
        return s[..s.len() - 2].to_string();
    }

    match s.strip_suffix('s') {
        Some(stem) if !stem.ends_with('s') => stem.to_string(),
        _ => s.to_string(),
    }
}
354
/// Unit tests for path parsing, naming helpers and bearer-token extraction.
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_SECRET: &str = "test-secret-key-at-least-32-bytes-long!";

    #[test]
    fn parse_grpc_path_valid() {
        let result =
            ShaperailGrpcService::parse_grpc_path("/shaperail.v1.users.UserService/GetUser");
        assert_eq!(result, Some(("users".to_string(), "GetUser".to_string())));
    }

    #[test]
    fn parse_grpc_path_list() {
        let result =
            ShaperailGrpcService::parse_grpc_path("/shaperail.v1.orders.OrderService/ListOrders");
        assert_eq!(
            result,
            Some(("orders".to_string(), "ListOrders".to_string()))
        );
    }

    #[test]
    fn parse_grpc_path_invalid() {
        assert!(ShaperailGrpcService::parse_grpc_path("/invalid").is_none());
        assert!(ShaperailGrpcService::parse_grpc_path("").is_none());
    }

    #[test]
    fn parse_grpc_path_rejects_malformed_paths() {
        // Missing leading slash.
        assert!(
            ShaperailGrpcService::parse_grpc_path("shaperail.v1.users.UserService/GetUser")
                .is_none()
        );
        // Missing method segment.
        assert!(ShaperailGrpcService::parse_grpc_path("/shaperail.v1.users.UserService").is_none());
        // Foreign package prefix.
        assert!(
            ShaperailGrpcService::parse_grpc_path("/other.v1.users.UserService/GetUser").is_none()
        );
        // Too few dotted components.
        assert!(ShaperailGrpcService::parse_grpc_path("/shaperail.v1.users/GetUser").is_none());
    }

    #[test]
    fn parse_grpc_path_stream() {
        let result =
            ShaperailGrpcService::parse_grpc_path("/shaperail.v1.users.UserService/StreamUsers");
        assert_eq!(
            result,
            Some(("users".to_string(), "StreamUsers".to_string()))
        );
    }

    #[test]
    fn pascal_and_singular() {
        assert_eq!(to_pascal_case("user"), "User");
        assert_eq!(to_pascal_case("blog_post"), "BlogPost");
        assert_eq!(to_singular("users"), "user");
        assert_eq!(to_singular("categories"), "category");
    }

    #[test]
    fn pascal_case_edge_cases() {
        assert_eq!(to_pascal_case(""), "");
        assert_eq!(to_pascal_case("order_line_item"), "OrderLineItem");
    }

    #[test]
    fn singular_edge_cases() {
        assert_eq!(to_singular("statuses"), "status");
        assert_eq!(to_singular("status"), "status");
        assert_eq!(to_singular("boxes"), "box");
        assert_eq!(to_singular("classes"), "class");
        assert_eq!(to_singular("address"), "address");
        assert_eq!(to_singular("data"), "data");
    }

    #[test]
    fn extract_user_no_header() {
        let headers = http::HeaderMap::new();
        assert!(extract_user_from_headers(&headers, None).is_none());
    }

    #[test]
    fn extract_user_valid_token() {
        let jwt = JwtConfig::new(TEST_SECRET, 3600, 86400);
        let token = jwt.encode_access("user-1", "admin").unwrap();

        let mut headers = http::HeaderMap::new();
        headers.insert(
            "authorization",
            http::HeaderValue::from_str(&format!("Bearer {token}")).unwrap(),
        );

        let user = extract_user_from_headers(&headers, Some(&jwt));
        assert!(user.is_some());
        let user = user.unwrap();
        assert_eq!(user.id, "user-1");
        assert_eq!(user.role, "admin");
    }

    #[test]
    fn extract_user_invalid_token() {
        let jwt = JwtConfig::new(TEST_SECRET, 3600, 86400);

        let mut headers = http::HeaderMap::new();
        headers.insert(
            "authorization",
            http::HeaderValue::from_str("Bearer invalid.token.here").unwrap(),
        );

        assert!(extract_user_from_headers(&headers, Some(&jwt)).is_none());
    }

    #[test]
    fn extract_user_requires_jwt_config() {
        let jwt = JwtConfig::new(TEST_SECRET, 3600, 86400);
        let token = jwt.encode_access("user-1", "admin").unwrap();

        let mut headers = http::HeaderMap::new();
        headers.insert(
            "authorization",
            http::HeaderValue::from_str(&format!("Bearer {token}")).unwrap(),
        );

        // A valid token is still ignored when no JWT configuration is supplied.
        assert!(extract_user_from_headers(&headers, None).is_none());
    }

    #[test]
    fn extract_user_rejects_non_bearer_scheme() {
        let jwt = JwtConfig::new(TEST_SECRET, 3600, 86400);

        let mut headers = http::HeaderMap::new();
        headers.insert(
            "authorization",
            http::HeaderValue::from_static("Basic dXNlcjpwYXNz"),
        );

        assert!(extract_user_from_headers(&headers, Some(&jwt)).is_none());
    }
}