clawdb_server/grpc/
mod.rs1pub mod service;
2
3use std::{
4 future::Future,
5 pin::Pin,
6 sync::Arc,
7 task::{Context, Poll},
8};
9
10use anyhow::{Context as _, Result};
11use http::{Request, Response, StatusCode};
12use tokio::net::TcpListener;
13use tokio_util::sync::CancellationToken;
14use tonic::{
15 body::empty_body,
16 transport::{Identity, Server, ServerTlsConfig},
17};
18use tower::{Layer, Service, ServiceBuilder};
19
20use crate::{
21 grpc::service::{
22 proto::{claw_db_service_server::ClawDbServiceServer, FILE_DESCRIPTOR_SET},
23 ClawDbServiceImpl,
24 },
25 state::AppState,
26};
27
28#[derive(Clone)]
29struct GrpcRateLimitLayer {
30 state: Arc<AppState>,
31}
32
33impl<S> Layer<S> for GrpcRateLimitLayer {
34 type Service = GrpcRateLimitService<S>;
35
36 fn layer(&self, inner: S) -> Self::Service {
37 GrpcRateLimitService {
38 inner,
39 state: self.state.clone(),
40 }
41 }
42}
43
44#[derive(Clone)]
45struct GrpcRateLimitService<S> {
46 inner: S,
47 state: Arc<AppState>,
48}
49
50impl<S, B> Service<Request<B>> for GrpcRateLimitService<S>
51where
52 S: Service<Request<B>, Response = Response<tonic::body::BoxBody>> + Send + Clone + 'static,
53 S::Future: Send + 'static,
54 S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
55 B: Send + 'static,
56{
57 type Response = Response<tonic::body::BoxBody>;
58 type Error = S::Error;
59 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
60
61 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
62 self.inner.poll_ready(cx)
63 }
64
65 fn call(&mut self, request: Request<B>) -> Self::Future {
66 let token = request
67 .headers()
68 .get("x-claw-session")
69 .and_then(|value| value.to_str().ok())
70 .unwrap_or("anonymous")
71 .to_string();
72
73 if let Err(not_until) = self.state.grpc_limiter.check_key(&token) {
74 let mut response = Response::new(empty_body());
75 *response.status_mut() = StatusCode::OK;
76 response.headers_mut().insert(
77 http::header::CONTENT_TYPE,
78 http::HeaderValue::from_static("application/grpc"),
79 );
80 response.headers_mut().insert(
81 http::HeaderName::from_static("grpc-status"),
82 http::HeaderValue::from_static("8"),
83 );
84 response.headers_mut().insert(
85 http::HeaderName::from_static("grpc-message"),
86 http::HeaderValue::from_static("rate limit exceeded"),
87 );
88 if let Ok(value) =
89 http::HeaderValue::from_str(&AppState::retry_after_seconds(¬_until).to_string())
90 {
91 response
92 .headers_mut()
93 .insert(http::HeaderName::from_static("retry-delay"), value);
94 }
95 return Box::pin(async move { Ok(response) });
96 }
97
98 Box::pin(self.inner.call(request))
99 }
100}
101
102pub async fn serve(
103 listener: TcpListener,
104 state: Arc<AppState>,
105 shutdown: CancellationToken,
106) -> Result<()> {
107 let reflection = tonic_reflection::server::Builder::configure()
108 .register_encoded_file_descriptor_set(FILE_DESCRIPTOR_SET)
109 .build_v1()
110 .context("failed to build gRPC reflection service")?;
111
112 let cert_path = std::env::var("CLAW_TLS_CERT_PATH").ok();
113 let key_path = std::env::var("CLAW_TLS_KEY_PATH").ok();
114
115 let builder = match (cert_path, key_path) {
116 (Some(cert_path), Some(key_path)) => {
117 let cert = tokio::fs::read(cert_path)
118 .await
119 .context("failed to read TLS certificate")?;
120 let key = tokio::fs::read(key_path)
121 .await
122 .context("failed to read TLS key")?;
123 let identity = Identity::from_pem(cert, key);
124 Server::builder()
125 .tls_config(ServerTlsConfig::new().identity(identity))
126 .context("failed to configure gRPC TLS")?
127 }
128 _ => {
129 tracing::warn!(
130 "TLS not configured — use CLAW_TLS_CERT_PATH and CLAW_TLS_KEY_PATH for production."
131 );
132 Server::builder()
133 }
134 };
135
136 builder
137 .layer(ServiceBuilder::new().layer(GrpcRateLimitLayer {
138 state: state.clone(),
139 }))
140 .add_service(reflection)
141 .add_service(ClawDbServiceServer::new(ClawDbServiceImpl::new(state)))
142 .serve_with_incoming_shutdown(
143 tokio_stream::wrappers::TcpListenerStream::new(listener),
144 async move {
145 shutdown.cancelled().await;
146 },
147 )
148 .await
149 .context("gRPC server failed")
150}