pub mod handler;
pub mod streaming;
use std::{convert::Infallible, sync::Arc};
use fraiseql_core::{
db::traits::DatabaseAdapter,
schema::CompiledSchema,
security::{OidcValidator, SecurityContext},
};
use fraiseql_error::FraiseQLError;
use handler::{RpcDispatchTable, build_dispatch_table};
use prost::Message as _;
use prost_reflect::DescriptorPool;
use tonic::{body::Body as TonicBody, server::NamedService};
use tracing::{Instrument as _, debug, info, info_span, warn};
use crate::middleware::RateLimiter;
/// Everything produced by [`build_grpc_service`] when gRPC is enabled.
pub struct GrpcServices<A: DatabaseAdapter> {
    /// The tower service that handles incoming gRPC requests.
    pub service: DynamicGrpcService<A>,
    /// Raw bytes of the configured descriptor file; `Some` only when
    /// reflection is enabled in the gRPC config.
    pub reflection_descriptor_bytes: Option<Vec<u8>>,
    /// Fully-qualified name of the service taken from the descriptor pool.
    pub service_name: String,
}
/// gRPC service that serves RPCs described by a protobuf descriptor loaded at
/// runtime, dispatching each method to a query, server stream, or mutation
/// executed through `adapter`.
///
/// Every field is an `Arc` (or an `Option` of one), so the per-request clone
/// made in `tower::Service::call` is cheap.
pub struct DynamicGrpcService<A: DatabaseAdapter> {
    /// Database adapter used to execute queries, streams and mutations.
    adapter: Arc<A>,
    /// Compiled schema; used to resolve type definitions and the gRPC config.
    schema: Arc<CompiledSchema>,
    /// RPC operations keyed by the request URI path.
    dispatch: Arc<RpcDispatchTable>,
    /// Descriptor pool used to look up method input types and decode requests.
    pool: Arc<DescriptorPool>,
    /// Fully-qualified name of the served service within `pool`.
    service_name: Arc<str>,
    /// OIDC token validator; `None` disables authentication entirely.
    oidc_validator: Option<Arc<OidcValidator>>,
    /// Optional rate limiter, applied per user or per client IP.
    rate_limiter: Option<Arc<RateLimiter>>,
}
/// Hand-written so that cloning does not require `A: Clone`: only the shared
/// `Arc` handles are duplicated, never the adapter itself.
impl<A: DatabaseAdapter> Clone for DynamicGrpcService<A> {
    fn clone(&self) -> Self {
        let Self {
            adapter,
            schema,
            dispatch,
            pool,
            service_name,
            oidc_validator,
            rate_limiter,
        } = self;
        Self {
            adapter: Arc::clone(adapter),
            schema: Arc::clone(schema),
            dispatch: Arc::clone(dispatch),
            pool: Arc::clone(pool),
            service_name: Arc::clone(service_name),
            // `Option<Arc<_>>::clone` only bumps the reference count.
            oidc_validator: oidc_validator.clone(),
            rate_limiter: rate_limiter.clone(),
        }
    }
}
impl<A: DatabaseAdapter> NamedService for DynamicGrpcService<A> {
    /// Service name this implementation is registered under.
    ///
    /// NOTE(review): this is a compile-time constant, while `build_grpc_service`
    /// reads the actual service name from the descriptor file. If the descriptor
    /// declares a differently named service, routing by this name and the
    /// dispatch-table keys may disagree — confirm the deployment guarantees they match.
    const NAME: &'static str = "fraiseql.v1.FraiseQLService";
}
impl<A: DatabaseAdapter + Clone + Send + Sync + 'static> DynamicGrpcService<A> {
async fn handle_request(
&self,
method: &str,
req: http::Request<TonicBody>,
) -> http::Response<TonicBody> {
use http_body_util::BodyExt as _;
let Some(op) = self.dispatch.get(method) else {
return grpc_error_response(
tonic::Code::Unimplemented,
&format!("Method not found: {method}"),
);
};
let auth_header = req
.headers()
.get(http::header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.map(String::from);
let request_id = req
.headers()
.get("x-request-id")
.and_then(|v| v.to_str().ok())
.unwrap_or("grpc")
.to_string();
let client_ip = req
.headers()
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.split(',').next())
.map(|s| s.trim().to_string())
.or_else(|| {
req.headers().get("x-real-ip").and_then(|v| v.to_str().ok()).map(String::from)
})
.unwrap_or_else(|| "unknown".to_string());
let security_context: Option<SecurityContext> =
match self.authenticate(auth_header, request_id).await {
Ok(ctx) => ctx,
Err(resp) => return resp,
};
if let Some(ref ctx) = security_context {
tracing::Span::current().record("user_id", ctx.user_id.as_str());
}
if let Some(ref limiter) = self.rate_limiter {
let result = if let Some(ref ctx) = security_context {
limiter.check_user_limit(&ctx.user_id).await
} else {
limiter.check_ip_limit(&client_ip).await
};
if !result.allowed {
let user_id = security_context.as_ref().map(|c| c.user_id.as_str());
warn!(
ip = %client_ip,
user_id = ?user_id,
retry_after_secs = result.retry_after_secs,
method = %method,
"gRPC rate limit exceeded"
);
return grpc_error_response(tonic::Code::ResourceExhausted, "Rate limit exceeded");
}
}
let body_bytes: bytes::Bytes = match req.into_body().collect().await {
Ok(collected) => collected.to_bytes(),
Err(e) => {
return grpc_error_response(
tonic::Code::Internal,
&format!("Failed to read request body: {e}"),
);
},
};
if body_bytes.len() < 5 {
return grpc_error_response(tonic::Code::InvalidArgument, "Request body too short");
}
let msg_bytes = &body_bytes[5..];
let Some(service_desc) = self.pool.get_service_by_name(&self.service_name) else {
return grpc_error_response(tonic::Code::Internal, "Service descriptor not found");
};
let method_name = method.rsplit('/').next().unwrap_or(method);
let Some(method_desc) = service_desc.methods().find(|m| m.name() == method_name) else {
return grpc_error_response(
tonic::Code::Unimplemented,
&format!("Method not found: {method_name}"),
);
};
let request_desc = method_desc.input();
let request_msg = match prost_reflect::DynamicMessage::decode(request_desc, msg_bytes) {
Ok(m) => m,
Err(e) => {
return grpc_error_response(
tonic::Code::InvalidArgument,
&format!("Failed to decode request: {e}"),
);
},
};
if let handler::RpcKind::ServerStream {
view_name,
columns,
row_descriptor,
} = &op.kind
{
let Some(type_def) = self.schema.find_type(&op.type_name) else {
return grpc_error_response(
tonic::Code::Internal,
&format!("Type '{}' not found in schema", op.type_name),
);
};
let batch_size = self.schema.grpc_config.as_ref().map_or(500, |c| c.stream_batch_size);
debug!(method = %method, batch_size, "Starting gRPC server-streaming response");
let body_stream = streaming::build_streaming_body(
Arc::clone(&self.adapter),
view_name.clone(),
columns.clone(),
row_descriptor.clone(),
type_def,
&request_msg,
security_context.as_ref(),
batch_size,
);
let body = http_body_util::StreamBody::new(body_stream);
let mut response = http::Response::new(TonicBody::new(body));
response
.headers_mut()
.insert("content-type", http::HeaderValue::from_static("application/grpc"));
return response;
}
let response_msg = match &op.kind {
handler::RpcKind::Query {
view_name,
returns_list,
columns,
row_descriptor,
} => {
let Some(type_def) = self.schema.find_type(&op.type_name) else {
return grpc_error_response(
tonic::Code::Internal,
&format!("Type '{}' not found in schema", op.type_name),
);
};
let rows = match handler::execute_grpc_query(
self.adapter.as_ref(),
view_name,
columns,
*returns_list,
&request_msg,
type_def,
security_context.as_ref(),
)
.await
{
Ok(rows) => rows,
Err(FraiseQLError::Validation { message, .. }) => {
return grpc_error_response(tonic::Code::InvalidArgument, &message);
},
Err(FraiseQLError::Unsupported { message }) => {
return grpc_error_response(tonic::Code::Unimplemented, &message);
},
Err(e) => return grpc_error_response(tonic::Code::Internal, &e.to_string()),
};
debug!(method = %method, row_count = rows.len(), "gRPC query returned results");
handler::encode_response(
rows,
columns,
*returns_list,
row_descriptor,
&op.response_descriptor,
)
},
handler::RpcKind::ServerStream { .. } => {
unreachable!("ServerStream handled above");
},
handler::RpcKind::Mutation { function_name } => {
let result = match handler::execute_grpc_mutation(
self.adapter.as_ref(),
function_name,
&request_msg,
)
.await
{
Ok(r) => r,
Err(FraiseQLError::Validation { message, .. }) => {
return grpc_error_response(tonic::Code::InvalidArgument, &message);
},
Err(FraiseQLError::Unsupported { message }) => {
return grpc_error_response(tonic::Code::Unimplemented, &message);
},
Err(e) => return grpc_error_response(tonic::Code::Internal, &e.to_string()),
};
debug!(method = %method, success = result.success, "gRPC mutation completed");
handler::encode_mutation_response(&result, &op.response_descriptor)
},
};
let response_bytes = response_msg.encode_to_vec();
let mut framed = Vec::with_capacity(5 + response_bytes.len());
framed.push(0); framed.extend_from_slice(
&(u32::try_from(response_bytes.len()).unwrap_or(u32::MAX)).to_be_bytes(),
);
framed.extend_from_slice(&response_bytes);
let mut response = http::Response::new(TonicBody::new(axum::body::Body::from(framed)));
response
.headers_mut()
.insert("content-type", http::HeaderValue::from_static("application/grpc"));
response
.headers_mut()
.insert("grpc-status", http::HeaderValue::from_static("0"));
response
}
}
impl<A: DatabaseAdapter + Clone + Send + Sync + 'static> DynamicGrpcService<A> {
    /// Resolves the caller's [`SecurityContext`] from the `Authorization` header.
    ///
    /// Outcomes:
    /// - No OIDC validator configured → `Ok(None)` (authentication disabled).
    /// - Header absent → `Ok(None)`, unless the validator requires
    ///   authentication, in which case an `Unauthenticated` response.
    /// - Header present but not a `Bearer <token>` value → `Unauthenticated`.
    /// - Token rejected by the validator → `Unauthenticated`.
    /// - Token accepted → `Ok(Some(ctx))`, built from the validated user and `request_id`.
    ///
    /// The `Err` variant is a ready-to-send gRPC error response.
    async fn authenticate(
        &self,
        auth_header: Option<String>,
        request_id: String,
    ) -> std::result::Result<Option<SecurityContext>, http::Response<TonicBody>> {
        let Some(validator) = self.oidc_validator.as_ref() else {
            return Ok(None);
        };

        let token = match auth_header.as_deref() {
            Some(header) => {
                let Some(token) = bearer_token(header) else {
                    debug!("gRPC request has invalid Authorization header format");
                    return Err(grpc_error_response(
                        tonic::Code::Unauthenticated,
                        "Invalid Authorization header format",
                    ));
                };
                token.to_owned()
            },
            None => {
                if validator.is_required() {
                    debug!("gRPC request missing required Authorization header");
                    return Err(grpc_error_response(
                        tonic::Code::Unauthenticated,
                        "Authentication required",
                    ));
                }
                return Ok(None);
            },
        };

        match validator.validate_token(&token).await {
            Ok(user) => {
                debug!(user_id = %user.user_id, "gRPC user authenticated");
                Ok(Some(SecurityContext::from_user(&user, request_id)))
            },
            Err(e) => {
                warn!(error = %e, "gRPC token validation failed");
                Err(grpc_error_response(tonic::Code::Unauthenticated, "Invalid or expired token"))
            },
        }
    }
}

/// Returns the credentials of a `Bearer` Authorization header value.
///
/// Returns `None` if the header uses another scheme or carries no token.
///
/// The scheme name is compared case-insensitively, as HTTP requires for
/// authentication schemes (RFC 9110 §11.1), so `bearer abc` is accepted just
/// like `Bearer abc`. Whitespace around the token is ignored.
fn bearer_token(header: &str) -> Option<&str> {
    let (scheme, credentials) = header.trim_start().split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = credentials.trim();
    (!token.is_empty()).then_some(token)
}
/// Builds a body-less gRPC error response carrying `code` and `message`.
///
/// `grpc-status` holds the numeric code. `grpc-message` holds `message`
/// percent-encoded as the gRPC HTTP/2 protocol requires:
/// - bytes outside the printable ASCII range `0x20..=0x7E` are escaped as `%XX`;
/// - `%` itself is escaped as `%XX`.
///
/// As a result, messages containing newlines, non-ASCII text or literal `%` signs
/// are delivered intact instead of being dropped or mis-decoded by the client.
fn grpc_error_response(code: tonic::Code, message: &str) -> http::Response<TonicBody> {
    let mut response = http::Response::new(TonicBody::empty());
    let headers = response.headers_mut();
    headers.insert("content-type", http::HeaderValue::from_static("application/grpc"));
    headers.insert("grpc-status", http::HeaderValue::from(code as i32));
    // The encoded form is printable ASCII only, so this conversion cannot fail;
    // the `if let` merely avoids an unwrap.
    if let Ok(msg) = http::HeaderValue::from_str(&percent_encode_grpc_message(message)) {
        headers.insert("grpc-message", msg);
    }
    response
}

/// Percent-encodes `message` for the `grpc-message` header.
///
/// Every byte outside `0x20..=0x7E` is written as `%` plus two uppercase hex
/// digits, and so is `%` itself.
fn percent_encode_grpc_message(message: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(message.len());
    for &byte in message.as_bytes() {
        if (0x20..=0x7E).contains(&byte) && byte != b'%' {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
/// Tower entry point.
///
/// Each HTTP request becomes one `handle_request` call wrapped in a
/// `grpc_request` tracing span. The span records:
/// - the request path;
/// - the authenticated user, when known;
/// - the `grpc-status` header of the response.
///
/// No readiness state is tracked, so the service is always ready.
impl<A: DatabaseAdapter + Clone + Send + Sync + 'static> tower::Service<http::Request<TonicBody>>
    for DynamicGrpcService<A>
{
    type Response = http::Response<TonicBody>;
    type Error = Infallible;
    type Future = std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
    >;

    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: http::Request<TonicBody>) -> Self::Future {
        // Clone the Arc-backed state so the boxed future owns everything it uses.
        let this = self.clone();
        let path = req.uri().path().to_owned();
        let fut = async move {
            let request_span = info_span!(
                "grpc_request",
                method = %path,
                grpc.status = tracing::field::Empty,
                user_id = tracing::field::Empty,
            );
            let resp = this
                .handle_request(&path, req)
                .instrument(request_span.clone())
                .await;
            // The streaming path does not set a `grpc-status` header, so it records "unknown".
            let status = match resp.headers().get("grpc-status") {
                Some(value) => value.to_str().ok(),
                None => None,
            };
            request_span.record("grpc.status", status.unwrap_or("unknown"));
            Ok(resp)
        };
        Box::pin(fut)
    }
}
pub fn build_grpc_service<A: DatabaseAdapter + Clone + Send + Sync + 'static>(
schema: Arc<CompiledSchema>,
adapter: Arc<A>,
oidc_validator: Option<Arc<OidcValidator>>,
rate_limiter: Option<Arc<RateLimiter>>,
) -> Result<Option<GrpcServices<A>>, FraiseQLError> {
let grpc_config = match schema.grpc_config.as_ref() {
Some(cfg) if cfg.enabled => cfg,
_ => return Ok(None),
};
let descriptor_path = &grpc_config.descriptor_path;
let descriptor_bytes = std::fs::read(descriptor_path).map_err(|e| {
FraiseQLError::validation(format!(
"Failed to read gRPC descriptor file '{descriptor_path}': {e}"
))
})?;
let pool = DescriptorPool::decode(descriptor_bytes.as_slice()).map_err(|e| {
FraiseQLError::validation(format!(
"Failed to decode gRPC descriptor file '{descriptor_path}': {e}"
))
})?;
let service_name =
pool.services().next().map(|s| s.full_name().to_string()).ok_or_else(|| {
FraiseQLError::validation("No gRPC service found in descriptor pool".to_string())
})?;
info!(
service = %service_name,
descriptor_path = %descriptor_path,
"Building gRPC dispatch table"
);
let dispatch = build_dispatch_table(&schema, &service_name, &pool)?;
info!(
service = %service_name,
rpc_count = dispatch.len(),
"gRPC dispatch table built"
);
for (method, op) in &dispatch {
match &op.kind {
handler::RpcKind::Query {
view_name,
columns,
returns_list,
..
} => {
debug!(
method = %method,
view = %view_name,
columns = columns.len(),
list = returns_list,
"Registered gRPC query RPC"
);
},
handler::RpcKind::ServerStream {
view_name, columns, ..
} => {
debug!(
method = %method,
view = %view_name,
columns = columns.len(),
"Registered gRPC server-streaming RPC"
);
},
handler::RpcKind::Mutation { function_name } => {
debug!(
method = %method,
function = %function_name,
"Registered gRPC mutation RPC"
);
},
}
}
if oidc_validator.is_some() {
info!("gRPC transport: OIDC authentication enabled");
}
if rate_limiter.is_some() {
info!("gRPC transport: rate limiting enabled");
}
let reflection_descriptor_bytes = if grpc_config.reflection {
info!("gRPC server reflection enabled");
Some(descriptor_bytes)
} else {
None
};
let service = DynamicGrpcService {
adapter,
schema,
dispatch: Arc::new(dispatch),
pool: Arc::new(pool),
service_name: service_name.clone().into(),
oidc_validator,
rate_limiter,
};
Ok(Some(GrpcServices {
service,
reflection_descriptor_bytes,
service_name,
}))
}