use crate::error::{ServiceError, ServiceResult};
use axum::body::Body;
use axum::response::{IntoResponse, Response};
use connectrpc::{ConnectError, ErrorCode};
use serde::{Deserialize, Serialize};
use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context as TaskContext, Poll},
};
use tower::{Layer, Service};
use super::AuthContext;
/// Shared, thread-safe callback that derives the Keto object identifier from a
/// body-less view of an incoming HTTP request.
pub type ObjectExtractor = Arc<dyn Fn(&http::Request<()>) -> String + Send + Sync>;
/// Connection settings for the Keto gRPC APIs.
#[derive(Debug, Clone)]
pub struct KetoConfig {
    /// Endpoint URI used for the check and read services.
    pub grpc_endpoint: String,
    /// Endpoint URI used for the write service.
    pub write_grpc_endpoint: String,
}

impl Default for KetoConfig {
    /// Targets a local instance: port 4466 for checks/reads and port 4467 for writes.
    fn default() -> Self {
        let grpc_endpoint = String::from("http://localhost:4466");
        let write_grpc_endpoint = String::from("http://localhost:4467");
        Self {
            grpc_endpoint,
            write_grpc_endpoint,
        }
    }
}
/// Client for the Keto permission service, configured by a [`KetoConfig`].
#[derive(Debug, Clone)]
pub struct KetoClient {
/// Endpoints used when opening gRPC channels.
config: KetoConfig,
}
impl KetoClient {
pub fn new(config: KetoConfig) -> Self {
Self { config }
}
pub fn with_defaults() -> Self {
Self::new(KetoConfig::default())
}
pub fn config(&self) -> &KetoConfig {
&self.config
}
fn channel(&self) -> tonic::transport::Channel {
tonic::transport::Channel::from_shared(self.config.grpc_endpoint.clone())
.expect("invalid keto grpc_endpoint URI")
.connect_lazy()
}
fn write_channel(&self) -> tonic::transport::Channel {
tonic::transport::Channel::from_shared(self.config.write_grpc_endpoint.clone())
.expect("invalid keto write_grpc_endpoint URI")
.connect_lazy()
}
pub async fn check_permission(
&self,
namespace: &str,
object: &str,
relation: &str,
subject: &str,
) -> ServiceResult<bool> {
use super::keto_proto::{
RelationTuple, check_service_client::CheckServiceClient,
};
let mut client = CheckServiceClient::new(self.channel());
let tuple = RelationTuple {
namespace: namespace.to_string(),
object: object.to_string(),
relation: relation.to_string(),
subject: Some(parse_subject(subject)),
};
let request = tonic::Request::new(super::keto_proto::CheckRequest {
tuple: Some(tuple),
..Default::default()
});
match client.check(request).await {
Ok(resp) => Ok(resp.into_inner().allowed),
Err(status) if status.code() == tonic::Code::PermissionDenied => Ok(false),
Err(status) => Err(ServiceError::Internal(format!(
"keto check_permission failed: {}",
status
))),
}
}
pub async fn grant(
&self,
namespace: &str,
object: &str,
relation: &str,
subject: &str,
) -> ServiceResult<()> {
use super::keto_proto::{
RelationTuple, RelationTupleDelta, TransactRelationTuplesRequest,
relation_tuple_delta::Action, write_service_client::WriteServiceClient,
};
let mut client = WriteServiceClient::new(self.write_channel());
let tuple = RelationTuple {
namespace: namespace.to_string(),
object: object.to_string(),
relation: relation.to_string(),
subject: Some(parse_subject(subject)),
};
let delta = RelationTupleDelta {
action: Action::Insert as i32,
relation_tuple: Some(tuple),
};
let request = tonic::Request::new(TransactRelationTuplesRequest {
relation_tuple_deltas: vec![delta],
});
client
.transact_relation_tuples(request)
.await
.map(|_| ())
.map_err(|status| ServiceError::Internal(format!("keto grant failed: {}", status)))
}
pub async fn list_relation_tuples(
&self,
namespace: &str,
object: Option<&str>,
relation: Option<&str>,
subject: Option<&str>,
) -> ServiceResult<Vec<RelationTuple>> {
use super::keto_proto::{
read_service_client::ReadServiceClient, ListRelationTuplesRequest, RelationQuery,
Subject,
};
let mut client = ReadServiceClient::new(self.channel());
let relation_query = Some(RelationQuery {
namespace: Some(namespace.to_string()),
object: object.map(ToString::to_string),
relation: relation.map(ToString::to_string),
subject: subject.map(|id| Subject {
r#ref: Some(super::keto_proto::subject::Ref::Id(id.to_string())),
}),
});
#[allow(deprecated)]
let request = tonic::Request::new(ListRelationTuplesRequest {
query: None,
relation_query,
snaptoken: String::new(),
page_size: 0,
page_token: String::new(),
});
let response = client
.list_relation_tuples(request)
.await
.map_err(|status| ServiceError::Internal(format!("keto list_relation_tuples failed: {}", status)))?;
Ok(response.into_inner().relation_tuples.into_iter().map(Into::into).collect())
}
pub async fn delete_relation_tuples(
&self,
namespace: &str,
object: Option<&str>,
relation: Option<&str>,
subject: Option<&str>,
) -> ServiceResult<()> {
use super::keto_proto::{
write_service_client::WriteServiceClient, DeleteRelationTuplesRequest, RelationQuery,
Subject,
};
let mut client = WriteServiceClient::new(self.write_channel());
let relation_query = Some(RelationQuery {
namespace: Some(namespace.to_string()),
object: object.map(ToString::to_string),
relation: relation.map(ToString::to_string),
subject: subject.map(|id| Subject {
r#ref: Some(super::keto_proto::subject::Ref::Id(id.to_string())),
}),
});
#[allow(deprecated)]
let request = tonic::Request::new(DeleteRelationTuplesRequest {
query: None,
relation_query,
});
client
.delete_relation_tuples(request)
.await
.map(|_| ())
.map_err(|status| ServiceError::Internal(format!("keto delete_relation_tuples failed: {}", status)))
}
pub async fn get_roles(&self, _subject: &str) -> ServiceResult<Vec<String>> {
Ok(vec!["user".to_string()])
}
pub async fn get_permissions(&self, _subject: &str) -> ServiceResult<Vec<Permission>> {
Ok(vec![])
}
}
/// Converts a textual subject into its Keto protobuf form.
///
/// A string without `:` becomes a plain subject ID. Otherwise the text before the first `:`
/// is the subject-set namespace and the remainder is split at its first `#` into object and
/// relation; the relation is empty when there is no `#`.
fn parse_subject(subject: &str) -> super::keto_proto::Subject {
    use super::keto_proto::{SubjectSet, subject::Ref as SubjectRef};

    let reference = match subject.split_once(':') {
        None => SubjectRef::Id(subject.to_string()),
        Some((namespace, remainder)) => {
            let (object, relation) = match remainder.split_once('#') {
                Some(parts) => parts,
                None => (remainder, ""),
            };
            SubjectRef::Set(SubjectSet {
                namespace: namespace.to_string(),
                object: object.to_string(),
                relation: relation.to_string(),
            })
        }
    };
    super::keto_proto::Subject {
        r#ref: Some(reference),
    }
}
impl Default for KetoClient {
fn default() -> Self {
Self::with_defaults()
}
}
/// A permission entry, the element type returned by `KetoClient::get_permissions`.
///
/// NOTE(review): nothing in this file populates these fields; the meanings below are read
/// from the field names only — confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Permission {
/// Presumably the resource the permission applies to.
pub resource: String,
/// Presumably the action on that resource.
pub action: String,
/// Presumably whether the action is permitted.
pub allowed: bool,
}
/// Subject of a [`RelationTuple`], mirroring the two protobuf subject reference variants.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Subject {
/// A plain subject identifier.
Id(String),
/// A subject set, identified by namespace, object and relation.
Set {
namespace: String,
object: String,
relation: String,
},
}
/// Serializable form of a Keto relation tuple, converted from the protobuf message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelationTuple {
/// Namespace the tuple lives in.
pub namespace: String,
/// Object the relation applies to.
pub object: String,
/// Relation between the subject and the object.
pub relation: String,
/// Subject of the tuple: either a plain ID or a subject set.
pub subject: Subject,
}
impl From<super::keto_proto::RelationTuple> for RelationTuple {
    /// Converts the protobuf tuple into the serializable form.
    ///
    /// A tuple without a subject, or with a subject lacking a reference, is mapped to
    /// `Subject::Id` holding an empty string.
    fn from(proto: super::keto_proto::RelationTuple) -> Self {
        use super::keto_proto::subject::Ref as ProtoRef;

        let subject = match proto.subject.and_then(|s| s.r#ref) {
            Some(ProtoRef::Id(id)) => Subject::Id(id),
            Some(ProtoRef::Set(set)) => Subject::Set {
                namespace: set.namespace,
                object: set.object,
                relation: set.relation,
            },
            None => Subject::Id(String::new()),
        };
        Self {
            namespace: proto.namespace,
            object: proto.object,
            relation: proto.relation,
            subject,
        }
    }
}
fn unauthorized(message: &str) -> Response {
ConnectError::new(ErrorCode::Unauthenticated, message).into_response()
}
fn forbidden(message: &str) -> Response {
ConnectError::new(ErrorCode::PermissionDenied, message).into_response()
}
fn internal(message: impl std::fmt::Display) -> Response {
ConnectError::new(ErrorCode::Internal, message.to_string()).into_response()
}
/// Tower layer that wraps a service in [`KetoService`], which runs a Keto permission check
/// before forwarding each request.
#[derive(Clone)]
pub struct KetoLayer {
/// Shared client used to run the permission checks.
client: Arc<KetoClient>,
/// Namespace passed to every permission check.
namespace: Arc<String>,
/// Relation passed to every permission check.
relation: Arc<String>,
/// Path prefixes that bypass the permission check.
skip_paths: Arc<Vec<String>>,
/// Derives the object to check from the request.
object_extractor: ObjectExtractor,
}
impl std::fmt::Debug for KetoLayer {
    /// Prints the namespace, relation and skip paths; the client and the extractor closure
    /// are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("KetoLayer");
        out.field("namespace", &self.namespace);
        out.field("relation", &self.relation);
        out.field("skip_paths", &self.skip_paths);
        out.finish()
    }
}
impl KetoLayer {
pub fn new(
client: KetoClient,
namespace: impl Into<String>,
relation: impl Into<String>,
) -> Self {
Self {
client: Arc::new(client),
namespace: Arc::new(namespace.into()),
relation: Arc::new(relation.into()),
skip_paths: Arc::new(vec![]),
object_extractor: Arc::new(|req| req.uri().path().to_string()),
}
}
pub fn from_config(
config: KetoConfig,
namespace: impl Into<String>,
relation: impl Into<String>,
) -> Self {
Self::new(KetoClient::new(config), namespace, relation)
}
pub fn skip_path(mut self, prefix: impl Into<String>) -> Self {
Arc::make_mut(&mut self.skip_paths).push(prefix.into());
self
}
pub fn with_object_extractor(
mut self,
f: ObjectExtractor,
) -> Self {
self.object_extractor = f;
self
}
}
impl<S> Layer<S> for KetoLayer {
    type Service = KetoService<S>;

    /// Wraps `inner` in a [`KetoService`] that shares this layer's settings.
    fn layer(&self, inner: S) -> Self::Service {
        let Self {
            client,
            namespace,
            relation,
            skip_paths,
            object_extractor,
        } = self;
        KetoService {
            inner,
            client: Arc::clone(client),
            namespace: Arc::clone(namespace),
            relation: Arc::clone(relation),
            skip_paths: Arc::clone(skip_paths),
            object_extractor: Arc::clone(object_extractor),
        }
    }
}
/// Middleware service produced by [`KetoLayer`]: runs a Keto permission check and only then
/// forwards the request to the wrapped service.
#[derive(Clone)]
pub struct KetoService<S> {
/// The wrapped service.
inner: S,
/// Shared client used to run the permission checks.
client: Arc<KetoClient>,
/// Namespace passed to every permission check.
namespace: Arc<String>,
/// Relation passed to every permission check.
relation: Arc<String>,
/// Path prefixes that bypass the permission check.
skip_paths: Arc<Vec<String>>,
/// Derives the object to check from the request.
object_extractor: ObjectExtractor,
}
impl<S> Service<http::Request<Body>> for KetoService<S>
where
    S: Service<http::Request<Body>, Response = Response> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Send + 'static,
{
    type Response = Response;
    type Error = S::Error;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    /// Readiness is delegated to the wrapped service.
    fn poll_ready(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Authorizes the request and forwards it to the wrapped service when allowed.
    ///
    /// * Paths starting with a configured skip prefix are forwarded without any check.
    /// * A missing or unauthenticated `AuthContext` yields an `Unauthenticated` response.
    /// * Otherwise the subject (empty string when the context has none) is checked against
    ///   the extracted object: allowed requests are forwarded, denied requests yield a
    ///   `PermissionDenied` response, and a failing check yields an `Internal` response.
    fn call(&mut self, req: http::Request<Body>) -> Self::Future {
        let path = req.uri().path();
        if self
            .skip_paths
            .iter()
            .any(|prefix| path.starts_with(prefix.as_str()))
        {
            return Box::pin(self.inner.call(req));
        }

        let auth_ctx = req.extensions().get::<AuthContext>().cloned();
        let subject = match auth_ctx {
            None
            | Some(AuthContext {
                is_authenticated: false,
                ..
            }) => {
                let resp = unauthorized("unauthenticated");
                return Box::pin(async move { Ok(resp) });
            }
            Some(ctx) => ctx.subject.unwrap_or_default(),
        };

        // Lend the request head to the extractor as a body-less request, then reassemble the
        // original request; this avoids cloning the headers and extensions.
        let (parts, body) = req.into_parts();
        let head = http::Request::from_parts(parts, ());
        let object = (self.object_extractor)(&head);
        let (parts, ()) = head.into_parts();
        let req = http::Request::from_parts(parts, body);

        let client = Arc::clone(&self.client);
        let namespace = Arc::clone(&self.namespace);
        let relation = Arc::clone(&self.relation);

        // `poll_ready` made `self.inner` ready, not a clone of it. Move that ready instance
        // into the future and leave a fresh clone behind for the next `poll_ready`/`call`.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);

        Box::pin(async move {
            match client
                .check_permission(&namespace, &object, &relation, &subject)
                .await
            {
                Ok(true) => inner.call(req).await,
                Ok(false) => Ok(forbidden("permission denied")),
                Err(e) => Ok(internal(e)),
            }
        })
    }
}
/// Extracts the bearer credential from the `Authorization` header.
///
/// Returns the text that follows the `Bearer ` scheme prefix, or `None` when the header is
/// absent, is not visible ASCII, or uses a different scheme. The scheme is matched
/// case-insensitively because HTTP authentication scheme names are case-insensitive.
pub fn extract_subject(headers: &http::HeaderMap) -> Option<String> {
    const SCHEME_PREFIX: &str = "Bearer ";

    let value = headers.get("Authorization")?.to_str().ok()?;
    // `get` yields `None` when the value is shorter than the prefix, which also makes the
    // slice below safe.
    let scheme = value.get(..SCHEME_PREFIX.len())?;
    if scheme.eq_ignore_ascii_case(SCHEME_PREFIX) {
        Some(value[SCHEME_PREFIX.len()..].to_string())
    } else {
        None
    }
}
// Unit tests for the client configuration, the layer builder, bearer-token extraction,
// subject parsing, and the middleware paths that return before any Keto call is made.
#[cfg(test)]
mod tests {
use super::*;
// The default configuration targets the local check/read and write ports.
#[test]
fn test_keto_config_default() {
let config = KetoConfig::default();
assert_eq!(config.grpc_endpoint, "http://localhost:4466");
assert_eq!(config.write_grpc_endpoint, "http://localhost:4467");
}
// A client built from an explicit configuration keeps both endpoints unchanged.
#[test]
fn test_keto_client_new() {
let config = KetoConfig {
grpc_endpoint: "http://keto:4466".to_string(),
write_grpc_endpoint: "http://keto:4467".to_string(),
};
let client = KetoClient::new(config);
assert_eq!(client.config.grpc_endpoint, "http://keto:4466");
assert_eq!(client.config.write_grpc_endpoint, "http://keto:4467");
}
// `with_defaults` uses the default configuration.
#[test]
fn test_keto_client_with_defaults() {
let client = KetoClient::with_defaults();
assert_eq!(client.config().grpc_endpoint, "http://localhost:4466");
}
// `skip_path` appends prefixes in call order.
#[test]
fn test_keto_layer_skip_path_builder() {
let layer = KetoLayer::new(KetoClient::with_defaults(), "ns", "read")
.skip_path("/health")
.skip_path("/metrics");
assert_eq!(layer.skip_paths.len(), 2);
assert_eq!(layer.skip_paths[0], "/health");
assert_eq!(layer.skip_paths[1], "/metrics");
}
// The namespace and relation given to `new` are stored as-is.
#[test]
fn test_keto_layer_namespace_relation() {
let layer = KetoLayer::new(KetoClient::with_defaults(), "docs", "write");
assert_eq!(layer.namespace.as_str(), "docs");
assert_eq!(layer.relation.as_str(), "write");
}
// A `Bearer` header yields the text after the scheme prefix.
#[test]
fn test_extract_subject_bearer() {
let mut headers = http::HeaderMap::new();
headers.insert(
"Authorization",
http::HeaderValue::from_static("Bearer mytoken"),
);
assert_eq!(extract_subject(&headers), Some("mytoken".to_string()));
}
// A header using another scheme yields nothing.
#[test]
fn test_extract_subject_non_bearer() {
let mut headers = http::HeaderMap::new();
headers.insert(
"Authorization",
http::HeaderValue::from_static("Basic credentials"),
);
assert_eq!(extract_subject(&headers), None);
}
// A request without an `Authorization` header yields nothing.
#[test]
fn test_extract_subject_missing() {
let headers = http::HeaderMap::new();
assert_eq!(extract_subject(&headers), None);
}
use axum::body::Body;
use axum::response::{IntoResponse, Response};
use tower::{ServiceBuilder, ServiceExt};
// Inner service stub that answers every request with an empty `200 OK`.
fn ok_service() -> impl Service<
http::Request<Body>,
Response = Response,
Error = std::convert::Infallible,
Future = impl Future<Output = Result<Response, std::convert::Infallible>>,
> + Clone {
tower::service_fn(|_req: http::Request<Body>| async {
Ok::<_, std::convert::Infallible>(
http::Response::builder()
.status(http::StatusCode::OK)
.body(Body::empty())
.unwrap()
.into_response(),
)
})
}
// A request on a skipped path reaches the inner service; the middleware returns before
// the permission check, so the configured endpoints are never used.
#[tokio::test]
async fn test_keto_layer_skip_path_forwards() {
let layer = KetoLayer::new(
KetoClient::new(KetoConfig {
grpc_endpoint: "http://127.0.0.1:1".to_string(), write_grpc_endpoint: "http://127.0.0.1:1".to_string(),
}),
"ns",
"read",
)
.skip_path("/health");
let mut svc = ServiceBuilder::new().layer(layer).service(ok_service());
let req = http::Request::builder()
.uri("/health")
.body(Body::empty())
.unwrap();
let resp = svc.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(resp.status(), http::StatusCode::OK);
}
// Without an `AuthContext` extension the middleware answers 401.
#[tokio::test]
async fn test_keto_layer_missing_auth_context_returns_401() {
let layer = KetoLayer::new(KetoClient::with_defaults(), "ns", "read");
let mut svc = ServiceBuilder::new().layer(layer).service(ok_service());
let req = http::Request::builder()
.uri("/protected")
.body(Body::empty())
.unwrap();
let resp = svc.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(resp.status(), http::StatusCode::UNAUTHORIZED);
}
// An `AuthContext` that is present but unauthenticated also answers 401.
#[tokio::test]
async fn test_keto_layer_unauthenticated_context_returns_401() {
let layer = KetoLayer::new(KetoClient::with_defaults(), "ns", "read");
let mut svc = ServiceBuilder::new().layer(layer).service(ok_service());
let mut req = http::Request::builder()
.uri("/protected")
.body(Body::empty())
.unwrap();
req.extensions_mut().insert(AuthContext::unauthenticated());
let resp = svc.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(resp.status(), http::StatusCode::UNAUTHORIZED);
}
// A subject without `:` is parsed as a plain ID.
#[test]
fn test_parse_subject_plain_id() {
use crate::middleware::auth::keto_proto::subject::Ref as SubjectRef;
let subject = parse_subject("user-123");
let r#ref = subject.r#ref.expect("subject ref is present");
assert!(matches!(r#ref, SubjectRef::Id(id) if id == "user-123"));
}
// `namespace:object` is parsed as a subject set with an empty relation.
#[test]
fn test_parse_subject_set_without_relation() {
use crate::middleware::auth::keto_proto::subject::Ref as SubjectRef;
let subject = parse_subject("groups:admins");
let r#ref = subject.r#ref.expect("subject ref is present");
match r#ref {
SubjectRef::Set(set) => {
assert_eq!(set.namespace, "groups");
assert_eq!(set.object, "admins");
assert_eq!(set.relation, "");
}
_ => panic!("expected SubjectSet"),
}
}
// `namespace:object#relation` is parsed as a subject set with all three parts.
#[test]
fn test_parse_subject_set_with_relation() {
use crate::middleware::auth::keto_proto::subject::Ref as SubjectRef;
let subject = parse_subject("groups:admins#members");
let r#ref = subject.r#ref.expect("subject ref is present");
match r#ref {
SubjectRef::Set(set) => {
assert_eq!(set.namespace, "groups");
assert_eq!(set.object, "admins");
assert_eq!(set.relation, "members");
}
_ => panic!("expected SubjectSet"),
}
}
}