use crate::error::{ServiceError, ServiceResult};
use axum::body::Body;
use axum::response::{IntoResponse, Response};
use connectrpc::{ConnectError, ErrorCode};
use serde::{Deserialize, Serialize};
use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context as TaskContext, Poll},
};
use tower::{Layer, Service};
use super::AuthContext;
/// Endpoints of the Keto gRPC APIs used by [`KetoClient`].
#[derive(Clone, Debug)]
pub struct KetoConfig {
    /// URI of the read API; permission checks are sent here.
    pub grpc_endpoint: String,
    /// URI of the write API; relation-tuple inserts (`grant`) are sent here.
    pub write_grpc_endpoint: String,
}
impl Default for KetoConfig {
fn default() -> Self {
Self {
grpc_endpoint: "http://localhost:4466".to_string(),
write_grpc_endpoint: "http://localhost:4467".to_string(),
}
}
}
/// gRPC client for Keto permission checks and relation-tuple writes.
#[derive(Clone, Debug)]
pub struct KetoClient {
    /// Endpoints this client sends its read and write calls to.
    config: KetoConfig,
}
impl KetoClient {
    /// Creates a client for the Keto endpoints described by `config`.
    ///
    /// No network activity happens here; a lazily-connecting gRPC channel is
    /// built on each `check_permission` / `grant` call.
    pub fn new(config: KetoConfig) -> Self {
        Self { config }
    }

    /// Creates a client using [`KetoConfig::default`].
    pub fn with_defaults() -> Self {
        Self::new(KetoConfig::default())
    }

    /// Returns the configuration this client was built with.
    pub fn config(&self) -> &KetoConfig {
        &self.config
    }

    /// Builds a lazily-connecting channel to `uri`; `field` names the config
    /// field in the error message.
    ///
    /// The URI comes from runtime configuration, so an unparsable value is
    /// reported as `ServiceError::Internal` instead of panicking inside a
    /// request future.
    // NOTE(review): a new `Channel` is created per call, so connections are not
    // reused across requests; caching one in `KetoClient` would allow reuse.
    fn lazy_channel(uri: &str, field: &str) -> ServiceResult<tonic::transport::Channel> {
        tonic::transport::Channel::from_shared(uri.to_string())
            .map(|endpoint| endpoint.connect_lazy())
            .map_err(|e| {
                ServiceError::Internal(format!("invalid keto {} URI {:?}: {}", field, uri, e))
            })
    }

    /// Channel to the read API (`config.grpc_endpoint`).
    fn channel(&self) -> ServiceResult<tonic::transport::Channel> {
        Self::lazy_channel(&self.config.grpc_endpoint, "grpc_endpoint")
    }

    /// Channel to the write API (`config.write_grpc_endpoint`).
    fn write_channel(&self) -> ServiceResult<tonic::transport::Channel> {
        Self::lazy_channel(&self.config.write_grpc_endpoint, "write_grpc_endpoint")
    }

    /// Builds a relation tuple whose subject is the plain subject ID `subject`.
    fn relation_tuple(
        namespace: &str,
        object: &str,
        relation: &str,
        subject: &str,
    ) -> super::keto_proto::RelationTuple {
        use super::keto_proto::{RelationTuple, Subject, subject::Ref as SubjectRef};
        RelationTuple {
            namespace: namespace.to_string(),
            object: object.to_string(),
            relation: relation.to_string(),
            subject: Some(Subject {
                r#ref: Some(SubjectRef::Id(subject.to_string())),
            }),
        }
    }

    /// Asks Keto whether `subject` has `relation` on `object` in `namespace`.
    ///
    /// Returns `Ok(false)` both when Keto answers "not allowed" and when the
    /// call fails with a `PermissionDenied` gRPC status.
    ///
    /// # Errors
    ///
    /// `ServiceError::Internal` if the read endpoint URI is invalid or the call
    /// fails with any other gRPC status.
    pub async fn check_permission(
        &self,
        namespace: &str,
        object: &str,
        relation: &str,
        subject: &str,
    ) -> ServiceResult<bool> {
        use super::keto_proto::{CheckRequest, check_service_client::CheckServiceClient};
        let mut client = CheckServiceClient::new(self.channel()?);
        // Only `tuple` carries the query. The other fields are set to empty/zero
        // values; `allow(deprecated)` presumably covers the legacy top-level
        // fields that a struct literal still has to name.
        #[allow(deprecated)]
        let request = tonic::Request::new(CheckRequest {
            namespace: String::new(),
            object: String::new(),
            relation: String::new(),
            subject: None,
            tuple: Some(Self::relation_tuple(namespace, object, relation, subject)),
            latest: false,
            snaptoken: String::new(),
            max_depth: 0,
        });
        match client.check(request).await {
            Ok(resp) => Ok(resp.into_inner().allowed),
            Err(status) if status.code() == tonic::Code::PermissionDenied => Ok(false),
            Err(status) => Err(ServiceError::Internal(format!(
                "keto check_permission failed: {}",
                status
            ))),
        }
    }

    /// Inserts the relation tuple (`namespace`, `object`, `relation`, `subject`)
    /// through Keto's write API.
    ///
    /// # Errors
    ///
    /// `ServiceError::Internal` if the write endpoint URI is invalid or the
    /// transaction call fails.
    pub async fn grant(
        &self,
        namespace: &str,
        object: &str,
        relation: &str,
        subject: &str,
    ) -> ServiceResult<()> {
        use super::keto_proto::{
            RelationTupleDelta, TransactRelationTuplesRequest, relation_tuple_delta::Action,
            write_service_client::WriteServiceClient,
        };
        let mut client = WriteServiceClient::new(self.write_channel()?);
        let delta = RelationTupleDelta {
            action: Action::Insert as i32,
            relation_tuple: Some(Self::relation_tuple(namespace, object, relation, subject)),
        };
        let request = tonic::Request::new(TransactRelationTuplesRequest {
            relation_tuple_deltas: vec![delta],
        });
        client
            .transact_relation_tuples(request)
            .await
            .map(|_| ())
            .map_err(|status| ServiceError::Internal(format!("keto grant failed: {}", status)))
    }

    /// Placeholder: does not contact Keto and always returns `["user"]`,
    /// whatever the subject.
    pub async fn get_roles(&self, _subject: &str) -> ServiceResult<Vec<String>> {
        Ok(vec!["user".to_string()])
    }

    /// Placeholder: does not contact Keto and always returns an empty list.
    pub async fn get_permissions(&self, _subject: &str) -> ServiceResult<Vec<Permission>> {
        Ok(vec![])
    }
}
impl Default for KetoClient {
fn default() -> Self {
Self::with_defaults()
}
}
/// One permission entry, as returned by `KetoClient::get_permissions`.
///
/// Field declaration order is also the serde serialization order.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Permission {
    /// Resource the entry refers to (format not defined in this module).
    pub resource: String,
    /// Action on `resource` that the entry describes.
    pub action: String,
    /// Whether `action` on `resource` is permitted.
    pub allowed: bool,
}
fn unauthorized(message: &str) -> Response {
ConnectError::new(ErrorCode::Unauthenticated, message).into_response()
}
fn forbidden(message: &str) -> Response {
ConnectError::new(ErrorCode::PermissionDenied, message).into_response()
}
/// Builds a Connect `Internal` error response from any displayable value.
///
/// NOTE(review): the full `Display` text (for example a gRPC status from Keto)
/// is sent to the client. Confirm that this does not leak internal details.
fn internal(message: impl std::fmt::Display) -> Response {
    let text = message.to_string();
    ConnectError::new(ErrorCode::Internal, text).into_response()
}
/// [`Layer`] that checks each request against Keto before forwarding it.
///
/// All configuration is reference-counted, so cloning the layer or applying
/// it to many services is cheap.
#[derive(Clone)]
pub struct KetoLayer {
    /// Keto namespace used in every permission check.
    namespace: Arc<String>,
    /// Relation the subject must hold on the extracted object.
    relation: Arc<String>,
    /// Maps a request (without body) to the Keto object being accessed.
    object_extractor: Arc<dyn Fn(&http::Request<()>) -> String + Send + Sync>,
    /// Path prefixes that bypass authorization.
    skip_paths: Arc<Vec<String>>,
    /// Client shared by every service this layer produces.
    client: Arc<KetoClient>,
}
impl std::fmt::Debug for KetoLayer {
    /// Formats the layer's configuration.
    ///
    /// The object extractor is a closure with no `Debug` form, so the output
    /// ends with `..` (`finish_non_exhaustive`) to show that a field is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("KetoLayer")
            .field("client", &self.client)
            .field("namespace", &self.namespace)
            .field("relation", &self.relation)
            .field("skip_paths", &self.skip_paths)
            .finish_non_exhaustive()
    }
}
impl KetoLayer {
    /// Creates a layer that requires the request's subject to hold `relation`
    /// on an object in `namespace`.
    ///
    /// By default the object is the request's URI path and no paths are skipped.
    #[must_use]
    pub fn new(
        client: KetoClient,
        namespace: impl Into<String>,
        relation: impl Into<String>,
    ) -> Self {
        Self {
            client: Arc::new(client),
            namespace: Arc::new(namespace.into()),
            relation: Arc::new(relation.into()),
            skip_paths: Arc::new(vec![]),
            object_extractor: Arc::new(|req| req.uri().path().to_string()),
        }
    }

    /// Like [`KetoLayer::new`], but builds the client from `config`.
    #[must_use]
    pub fn from_config(
        config: KetoConfig,
        namespace: impl Into<String>,
        relation: impl Into<String>,
    ) -> Self {
        Self::new(KetoClient::new(config), namespace, relation)
    }

    /// Exempts every request whose path starts with `prefix` from authorization.
    ///
    /// This is a plain string-prefix match, so `"/health"` also exempts
    /// `"/healthz"` or `"/health-admin"`. Can be called repeatedly.
    #[must_use]
    pub fn skip_path(mut self, prefix: impl Into<String>) -> Self {
        // Copy-on-write: clones the list only if another layer still shares it.
        Arc::make_mut(&mut self.skip_paths).push(prefix.into());
        self
    }

    /// Replaces how the Keto object is derived from a request.
    ///
    /// The extractor sees the request head (method, URI, headers, extensions)
    /// with a `()` body.
    #[must_use]
    pub fn with_object_extractor(
        mut self,
        f: Arc<dyn Fn(&http::Request<()>) -> String + Send + Sync>,
    ) -> Self {
        self.object_extractor = f;
        self
    }
}
impl<S> Layer<S> for KetoLayer {
    type Service = KetoService<S>;

    /// Wraps `inner`. The new service shares this layer's configuration through
    /// `Arc` handles instead of copying it.
    fn layer(&self, inner: S) -> Self::Service {
        let Self {
            client,
            namespace,
            relation,
            skip_paths,
            object_extractor,
        } = self;
        KetoService {
            client: Arc::clone(client),
            namespace: Arc::clone(namespace),
            relation: Arc::clone(relation),
            skip_paths: Arc::clone(skip_paths),
            object_extractor: Arc::clone(object_extractor),
            inner,
        }
    }
}
/// Service produced by [`KetoLayer`]. It authorizes each request with Keto
/// before delegating to the wrapped service.
#[derive(Clone)]
pub struct KetoService<S> {
    /// Client used for permission checks (shared with the layer).
    client: Arc<KetoClient>,
    /// Keto namespace used in every check.
    namespace: Arc<String>,
    /// Relation the subject must hold on the extracted object.
    relation: Arc<String>,
    /// Derives the Keto object from the request head.
    object_extractor: Arc<dyn Fn(&http::Request<()>) -> String + Send + Sync>,
    /// Path prefixes forwarded without any check.
    skip_paths: Arc<Vec<String>>,
    /// Wrapped service that handles allowed or skipped requests.
    inner: S,
}
impl<S> Service<http::Request<Body>> for KetoService<S>
where
    S: Service<http::Request<Body>, Response = Response> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Send + 'static,
{
    type Response = Response;
    type Error = S::Error;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    /// Readiness is entirely that of the wrapped service.
    fn poll_ready(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Authorizes `req` and forwards it to the wrapped service when allowed.
    ///
    /// - Paths matching a skip prefix are forwarded unchecked.
    /// - A missing or unauthenticated `AuthContext`, or one without a
    ///   non-empty subject, gets a Connect `Unauthenticated` response.
    /// - A Keto "not allowed" answer gets `PermissionDenied`.
    /// - A Keto or client error gets `Internal`.
    fn call(&mut self, req: http::Request<Body>) -> Self::Future {
        // `self.inner` is the instance `poll_ready` just reported ready, so the
        // skip branch may call it directly.
        let skipped = {
            let path = req.uri().path();
            self.skip_paths
                .iter()
                .any(|prefix| path.starts_with(prefix.as_str()))
        };
        if skipped {
            return Box::pin(self.inner.call(req));
        }

        // An authenticated context without a subject is rejected rather than
        // checked against Keto as the empty subject ID.
        let subject = match req.extensions().get::<AuthContext>() {
            Some(ctx) if ctx.is_authenticated => {
                ctx.subject.clone().filter(|s| !s.is_empty())
            }
            _ => None,
        };
        let subject = match subject {
            Some(subject) => subject,
            None => {
                let resp = unauthorized("unauthenticated");
                return Box::pin(async move { Ok(resp) });
            }
        };

        // Hand the extractor a body-less view of the request, then reattach the
        // body. The head is moved, not cloned, so headers and extensions are
        // not copied.
        let (parts, body) = req.into_parts();
        let head = http::Request::from_parts(parts, ());
        let object = (self.object_extractor)(&head);
        let (parts, ()) = head.into_parts();
        let req = http::Request::from_parts(parts, body);

        // The tower readiness contract: only the instance that was polled ready
        // may be called. Move that instance into the future and leave a fresh
        // clone behind for the next poll_ready/call cycle.
        let replacement = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, replacement);

        let client = Arc::clone(&self.client);
        let namespace = Arc::clone(&self.namespace);
        let relation = Arc::clone(&self.relation);
        Box::pin(async move {
            match client
                .check_permission(&namespace, &object, &relation, &subject)
                .await
            {
                Ok(true) => inner.call(req).await,
                Ok(false) => Ok(forbidden("permission denied")),
                Err(e) => Ok(internal(e)),
            }
        })
    }
}
pub fn extract_subject(headers: &http::HeaderMap) -> Option<String> {
headers
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| {
if v.starts_with("Bearer ") {
Some(v[7..].to_string())
} else {
None
}
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::response::{IntoResponse, Response};
    use tower::{ServiceBuilder, ServiceExt};

    /// Builds a GET-style request for `uri` with an empty body.
    fn empty_request(uri: &str) -> http::Request<Body> {
        http::Request::builder()
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    /// Inner service that answers every request with `200 OK`.
    fn ok_service() -> impl Service<
        http::Request<Body>,
        Response = Response,
        Error = std::convert::Infallible,
        Future = impl Future<Output = Result<Response, std::convert::Infallible>>,
    > + Clone {
        tower::service_fn(|_req: http::Request<Body>| async {
            let ok = http::Response::builder()
                .status(http::StatusCode::OK)
                .body(Body::empty())
                .unwrap();
            Ok::<_, std::convert::Infallible>(ok.into_response())
        })
    }

    /// Sends `req` through `layer` wrapped around [`ok_service`] and returns the status.
    async fn status_through(layer: KetoLayer, req: http::Request<Body>) -> http::StatusCode {
        let mut svc = ServiceBuilder::new().layer(layer).service(ok_service());
        svc.ready().await.unwrap().call(req).await.unwrap().status()
    }

    #[test]
    fn test_keto_config_default() {
        let KetoConfig {
            grpc_endpoint,
            write_grpc_endpoint,
        } = KetoConfig::default();
        assert_eq!(grpc_endpoint, "http://localhost:4466");
        assert_eq!(write_grpc_endpoint, "http://localhost:4467");
    }

    #[test]
    fn test_keto_client_new() {
        let client = KetoClient::new(KetoConfig {
            grpc_endpoint: "http://keto:4466".to_string(),
            write_grpc_endpoint: "http://keto:4467".to_string(),
        });
        let cfg = &client.config;
        assert_eq!(cfg.grpc_endpoint, "http://keto:4466");
        assert_eq!(cfg.write_grpc_endpoint, "http://keto:4467");
    }

    #[test]
    fn test_keto_client_with_defaults() {
        let client = KetoClient::with_defaults();
        assert_eq!(client.config().grpc_endpoint, "http://localhost:4466");
    }

    #[test]
    fn test_keto_layer_skip_path_builder() {
        let layer = KetoLayer::new(KetoClient::with_defaults(), "ns", "read")
            .skip_path("/health")
            .skip_path("/metrics");
        assert_eq!(*layer.skip_paths, ["/health", "/metrics"]);
    }

    #[test]
    fn test_keto_layer_namespace_relation() {
        let layer = KetoLayer::new(KetoClient::with_defaults(), "docs", "write");
        assert_eq!(layer.namespace.as_str(), "docs");
        assert_eq!(layer.relation.as_str(), "write");
    }

    #[test]
    fn test_extract_subject_bearer() {
        let mut headers = http::HeaderMap::new();
        let value = http::HeaderValue::from_static("Bearer mytoken");
        headers.insert("Authorization", value);
        assert_eq!(extract_subject(&headers), Some("mytoken".to_string()));
    }

    #[test]
    fn test_extract_subject_non_bearer() {
        let mut headers = http::HeaderMap::new();
        let value = http::HeaderValue::from_static("Basic credentials");
        headers.insert("Authorization", value);
        assert_eq!(extract_subject(&headers), None);
    }

    #[test]
    fn test_extract_subject_missing() {
        assert_eq!(extract_subject(&http::HeaderMap::new()), None);
    }

    #[tokio::test]
    async fn test_keto_layer_skip_path_forwards() {
        // Port 1 is unreachable, so the request can only succeed if Keto is never called.
        let unreachable = KetoConfig {
            grpc_endpoint: "http://127.0.0.1:1".to_string(),
            write_grpc_endpoint: "http://127.0.0.1:1".to_string(),
        };
        let layer = KetoLayer::new(KetoClient::new(unreachable), "ns", "read").skip_path("/health");
        let status = status_through(layer, empty_request("/health")).await;
        assert_eq!(status, http::StatusCode::OK);
    }

    #[tokio::test]
    async fn test_keto_layer_missing_auth_context_returns_401() {
        let layer = KetoLayer::new(KetoClient::with_defaults(), "ns", "read");
        let status = status_through(layer, empty_request("/protected")).await;
        assert_eq!(status, http::StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_keto_layer_unauthenticated_context_returns_401() {
        let layer = KetoLayer::new(KetoClient::with_defaults(), "ns", "read");
        let mut req = empty_request("/protected");
        req.extensions_mut().insert(AuthContext::unauthenticated());
        let status = status_through(layer, req).await;
        assert_eq!(status, http::StatusCode::UNAUTHORIZED);
    }
}