Skip to main content

clawdb_server/grpc/
mod.rs

1pub mod service;
2
3use std::{
4    future::Future,
5    pin::Pin,
6    sync::Arc,
7    task::{Context, Poll},
8};
9
10use anyhow::{Context as _, Result};
11use http::{Request, Response, StatusCode};
12use tokio::net::TcpListener;
13use tokio_util::sync::CancellationToken;
14use tonic::{
15    body::empty_body,
16    transport::{Identity, Server, ServerTlsConfig},
17};
18use tower::{Layer, Service, ServiceBuilder};
19
20use crate::{
21    grpc::service::{
22        proto::{claw_db_service_server::ClawDbServiceServer, FILE_DESCRIPTOR_SET},
23        ClawDbServiceImpl,
24    },
25    state::AppState,
26};
27
28#[derive(Clone)]
29struct GrpcRateLimitLayer {
30    state: Arc<AppState>,
31}
32
33impl<S> Layer<S> for GrpcRateLimitLayer {
34    type Service = GrpcRateLimitService<S>;
35
36    fn layer(&self, inner: S) -> Self::Service {
37        GrpcRateLimitService {
38            inner,
39            state: self.state.clone(),
40        }
41    }
42}
43
44#[derive(Clone)]
45struct GrpcRateLimitService<S> {
46    inner: S,
47    state: Arc<AppState>,
48}
49
50impl<S, B> Service<Request<B>> for GrpcRateLimitService<S>
51where
52    S: Service<Request<B>, Response = Response<tonic::body::BoxBody>> + Send + Clone + 'static,
53    S::Future: Send + 'static,
54    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
55    B: Send + 'static,
56{
57    type Response = Response<tonic::body::BoxBody>;
58    type Error = S::Error;
59    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
60
61    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
62        self.inner.poll_ready(cx)
63    }
64
65    fn call(&mut self, request: Request<B>) -> Self::Future {
66        let token = request
67            .headers()
68            .get("x-claw-session")
69            .and_then(|value| value.to_str().ok())
70            .unwrap_or("anonymous")
71            .to_string();
72
73        if let Err(not_until) = self.state.grpc_limiter.check_key(&token) {
74            let mut response = Response::new(empty_body());
75            *response.status_mut() = StatusCode::OK;
76            response.headers_mut().insert(
77                http::header::CONTENT_TYPE,
78                http::HeaderValue::from_static("application/grpc"),
79            );
80            response.headers_mut().insert(
81                http::HeaderName::from_static("grpc-status"),
82                http::HeaderValue::from_static("8"),
83            );
84            response.headers_mut().insert(
85                http::HeaderName::from_static("grpc-message"),
86                http::HeaderValue::from_static("rate limit exceeded"),
87            );
88            if let Ok(value) =
89                http::HeaderValue::from_str(&AppState::retry_after_seconds(&not_until).to_string())
90            {
91                response
92                    .headers_mut()
93                    .insert(http::HeaderName::from_static("retry-delay"), value);
94            }
95            return Box::pin(async move { Ok(response) });
96        }
97
98        Box::pin(self.inner.call(request))
99    }
100}
101
102pub async fn serve(
103    listener: TcpListener,
104    state: Arc<AppState>,
105    shutdown: CancellationToken,
106) -> Result<()> {
107    let reflection = tonic_reflection::server::Builder::configure()
108        .register_encoded_file_descriptor_set(FILE_DESCRIPTOR_SET)
109        .build_v1()
110        .context("failed to build gRPC reflection service")?;
111
112    let cert_path = std::env::var("CLAW_TLS_CERT_PATH").ok();
113    let key_path = std::env::var("CLAW_TLS_KEY_PATH").ok();
114
115    let builder = match (cert_path, key_path) {
116        (Some(cert_path), Some(key_path)) => {
117            let cert = tokio::fs::read(cert_path)
118                .await
119                .context("failed to read TLS certificate")?;
120            let key = tokio::fs::read(key_path)
121                .await
122                .context("failed to read TLS key")?;
123            let identity = Identity::from_pem(cert, key);
124            Server::builder()
125                .tls_config(ServerTlsConfig::new().identity(identity))
126                .context("failed to configure gRPC TLS")?
127        }
128        _ => {
129            tracing::warn!(
130                "TLS not configured — use CLAW_TLS_CERT_PATH and CLAW_TLS_KEY_PATH for production."
131            );
132            Server::builder()
133        }
134    };
135
136    builder
137        .layer(ServiceBuilder::new().layer(GrpcRateLimitLayer {
138            state: state.clone(),
139        }))
140        .add_service(reflection)
141        .add_service(ClawDbServiceServer::new(ClawDbServiceImpl::new(state)))
142        .serve_with_incoming_shutdown(
143            tokio_stream::wrappers::TcpListenerStream::new(listener),
144            async move {
145                shutdown.cancelled().await;
146            },
147        )
148        .await
149        .context("gRPC server failed")
150}