use std::sync::Arc;
use tonic::{Request, Response, Status};
use crate::ids::RequestId;
use crate::middleware::{PasetoAuth, TokenValidator};
#[cfg(feature = "jwt")]
use crate::middleware::JwtAuth;
/// Ensures every incoming request carries an `x-request-id`.
///
/// Behaviour depends on the incoming `x-request-id` metadata:
///
/// - If it is present, valid ASCII and not blank, the (trimmed) value is reused.
/// - Otherwise a fresh [`RequestId`] is generated.
///
/// The chosen ID is written back into the request metadata, so downstream layers see it.
/// It is also stored in the request extensions as a [`RequestIdExtension`].
///
/// # Errors
///
/// Returns `Status::internal` if the chosen ID cannot be encoded as a metadata value.
pub fn request_id_interceptor<T>(mut req: Request<T>) -> Result<Request<T>, Status> {
    let request_id = req
        .metadata()
        .get("x-request-id")
        .and_then(|v| v.to_str().ok())
        .map(str::trim)
        // A present-but-blank header would otherwise propagate as an empty request ID.
        .filter(|s| !s.is_empty())
        .map(str::to_owned)
        .unwrap_or_else(|| RequestId::new().to_string());
    req.metadata_mut().insert(
        "x-request-id",
        request_id
            .parse()
            .map_err(|_| Status::internal("Failed to parse request ID"))?,
    );
    req.extensions_mut().insert(RequestIdExtension(request_id));
    Ok(req)
}
/// Request extension carrying the request ID chosen by [`request_id_interceptor`].
///
/// Handlers can retrieve it with `req.extensions().get::<RequestIdExtension>()`.
/// It derives `PartialEq`, `Eq` and `Hash`, so IDs can be compared and used as map keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct RequestIdExtension(pub String);
/// Sets the `x-request-id` metadata entry on an outgoing response.
///
/// Any existing `x-request-id` value on the response is replaced.
///
/// # Errors
///
/// Returns `Status::internal` if `request_id` cannot be encoded as an ASCII metadata
/// value. In that case the response is left untouched.
pub fn add_request_id_to_response<B>(
    response: &mut Response<B>,
    request_id: &str,
) -> Result<(), Status> {
    let header_value = match request_id.parse() {
        Ok(value) => value,
        Err(_) => return Err(Status::internal("Failed to add request ID to response")),
    };
    response.metadata_mut().insert("x-request-id", header_value);
    Ok(())
}
/// Builds a gRPC interceptor that authenticates requests with a bearer token.
///
/// The returned closure works in four steps:
///
/// 1. Reads the `authorization` metadata entry.
/// 2. Expects the form `<scheme> <token>`, where the scheme is `Bearer`. The scheme is
///    matched case-insensitively, as HTTP auth-schemes are (RFC 7235 §2.1).
/// 3. Validates the token with `validator`.
/// 4. On success, inserts the resulting claims into the request extensions for
///    downstream handlers.
///
/// # Errors
///
/// The interceptor yields `Status::unauthenticated` in these cases:
///
/// - the header is missing or not valid ASCII;
/// - the header does not use the `Bearer` scheme;
/// - the token is empty;
/// - `validator` rejects the token.
pub fn token_auth_interceptor<V: TokenValidator + 'static>(
    validator: Arc<V>,
) -> impl Fn(Request<()>) -> Result<Request<()>, Status> + Clone {
    // Splits an `Authorization` value into scheme and credentials. Returns the
    // credentials only when the scheme is `Bearer` (any case) and they are non-blank.
    fn bearer_token(header: &str) -> Option<&str> {
        let (scheme, token) = header
            .trim()
            .split_once(|c: char| c.is_ascii_whitespace())?;
        if !scheme.eq_ignore_ascii_case("bearer") {
            return None;
        }
        let token = token.trim();
        if token.is_empty() {
            None
        } else {
            Some(token)
        }
    }

    move |mut req: Request<()>| {
        let token = req
            .metadata()
            .get("authorization")
            .and_then(|v| v.to_str().ok())
            .and_then(bearer_token)
            .ok_or_else(|| Status::unauthenticated("Missing or invalid authorization token"))?;
        let claims = validator
            .validate_token(token)
            .map_err(|e| Status::unauthenticated(format!("Invalid token: {}", e)))?;
        tracing::debug!(
            sub = %claims.sub,
            roles = ?claims.roles,
            "gRPC request authenticated"
        );
        req.extensions_mut().insert(claims);
        Ok(req)
    }
}
/// Builds a bearer-token interceptor that uses `PasetoAuth` as its token validator.
///
/// This is a convenience wrapper that delegates to [`token_auth_interceptor`]. The
/// resulting interceptor reads the `authorization` metadata. It rejects the request with
/// `Status::unauthenticated` when the bearer token is missing or fails validation.
pub fn paseto_auth_interceptor(
    paseto_auth: Arc<PasetoAuth>,
) -> impl Fn(Request<()>) -> Result<Request<()>, Status> + Clone {
    token_auth_interceptor::<PasetoAuth>(paseto_auth)
}
/// Builds a bearer-token interceptor that uses `JwtAuth` as its token validator.
///
/// This is only compiled with the `jwt` feature enabled. It is a convenience wrapper that
/// delegates to [`token_auth_interceptor`]. The resulting interceptor reads the
/// `authorization` metadata. It rejects the request with `Status::unauthenticated` when
/// the bearer token is missing or fails validation.
#[cfg(feature = "jwt")]
pub fn jwt_auth_interceptor(
    jwt_auth: Arc<JwtAuth>,
) -> impl Fn(Request<()>) -> Result<Request<()>, Status> + Clone {
    token_auth_interceptor::<JwtAuth>(jwt_auth)
}
/// Legacy interceptor that only checks that a bearer token is *present*.
///
/// **Warning:** the token is never validated. Any non-blank `Bearer` credential is
/// accepted, and no claims are attached to the request. Prefer
/// [`paseto_auth_interceptor`], or `jwt_auth_interceptor` (behind the `jwt` feature).
///
/// # Errors
///
/// Returns `Status::unauthenticated` in these cases:
///
/// - the `authorization` metadata is missing or not valid ASCII;
/// - the value is not of the form `Bearer <token>` (scheme matched case-insensitively);
/// - the token is blank.
#[deprecated(
    since = "0.2.0",
    note = "Use paseto_auth_interceptor or jwt_auth_interceptor instead"
)]
pub fn auth_interceptor<T>(req: Request<T>) -> Result<Request<T>, Status> {
    let token_present = req
        .metadata()
        .get("authorization")
        .and_then(|v| v.to_str().ok())
        .and_then(|header| header.trim().split_once(|c: char| c.is_ascii_whitespace()))
        // Previously `"Bearer "` with nothing after it counted as a token; reject blanks.
        .filter(|(scheme, token)| {
            scheme.eq_ignore_ascii_case("bearer") && !token.trim().is_empty()
        })
        .is_some();
    if !token_present {
        return Err(Status::unauthenticated("Missing authentication token"));
    }
    Ok(req)
}
/// Logs an info-level "gRPC request received" event tagged with the request ID.
///
/// The ID is taken from the `x-request-id` metadata entry. If the entry is absent or
/// not valid ASCII, `"unknown"` is logged instead. The request is passed through
/// unchanged and never rejected.
pub fn tracing_interceptor<T>(req: Request<T>) -> Result<Request<T>, Status> {
    let header = req.metadata().get("x-request-id");
    let request_id = match header.map(|value| value.to_str()) {
        Some(Ok(id)) => id,
        _ => "unknown",
    };
    tracing::info!(request_id = %request_id, "gRPC request received");
    Ok(req)
}