#![allow(unused_imports)]
#[cfg(feature = "auth-middle")]
use tonic::service::interceptor::InterceptedService;
use tonic::{
codegen::{Body, Bytes, StdError},
service::interceptor::InterceptorLayer,
transport::{Channel, Endpoint},
};
#[cfg(feature = "auth-middle")]
use tower::{ServiceBuilder, util::Either};
use crate::{
client::{OpenFgaClient, OpenFgaServiceClient},
error::{Error, Result},
generated::{
ConsistencyPreference, CreateStoreRequest, ListStoresRequest, ReadRequest,
ReadRequestTupleKey, Store, Tuple,
},
};
/// An [`OpenFgaServiceClient`] running over [`BasicAuthLayer`].
///
/// The concrete authentication strategy (client credentials, static token, or none)
/// is chosen by the constructor used to build the client.
#[cfg(feature = "auth-middle")]
pub type BasicOpenFgaServiceClient = crate::client::OpenFgaServiceClient<BasicAuthLayer>;
#[cfg(feature = "auth-middle")]
impl BasicOpenFgaServiceClient {
    /// Creates a client that attaches no credentials to its requests.
    ///
    /// The underlying channel is created with `connect_lazy`, so the connection is
    /// established on first use rather than here.
    ///
    /// # Errors
    /// Fails if `endpoint` cannot be turned into a tonic endpoint or TLS cannot be
    /// configured for it.
    pub fn new_unauthenticated(endpoint: impl Into<url::Url>) -> Result<Self> {
        let channel = get_tonic_endpoint_logged(&endpoint.into())?.connect_lazy();
        let service = Either::Right(InterceptedService::new(channel, NoOpInterceptor));
        Ok(Self::new(service))
    }

    /// Creates a client that authorizes every request with the fixed `token`.
    ///
    /// Despite the method name, the token is handled by
    /// `middle::BearerTokenAuthorizer` (presumably sent as a bearer token; see that
    /// type for the exact header format).
    ///
    /// # Errors
    /// - [`Error::InvalidToken`] if the authorizer rejects `token`.
    /// - An endpoint error if `endpoint` is invalid or TLS cannot be configured.
    pub fn new_with_basic_auth(endpoint: impl Into<url::Url>, token: &str) -> Result<Self> {
        let authorizer = middle::BearerTokenAuthorizer::new(token).map_err(|e| {
            tracing::error!("Could not construct OpenFGA client. Invalid token: {e}");
            Error::InvalidToken {
                reason: e.to_string(),
            }
        })?;
        let channel = get_tonic_endpoint_logged(&endpoint.into())?.connect_lazy();
        let service = Either::Left(Either::Right(InterceptedService::new(channel, authorizer)));
        Ok(Self::new(service))
    }

    /// Creates a client that authorizes requests via
    /// `middle::BasicClientCredentialAuthorizer` (client-credentials flow).
    ///
    /// Credentials are fetched from `token_endpoint` while building the authorizer,
    /// so this constructor performs network I/O. `scopes` are only added when the
    /// slice is non-empty.
    ///
    /// # Errors
    /// - An endpoint error if `endpoint` is invalid or TLS cannot be configured.
    ///   This is checked first, so no token request is made for a bad endpoint.
    /// - [`Error::CredentialRefreshError`] if the credentials cannot be obtained.
    pub async fn new_with_client_credentials(
        endpoint: impl Into<url::Url>,
        client_id: &str,
        client_secret: &str,
        token_endpoint: impl Into<url::Url>,
        scopes: &[&str],
    ) -> Result<Self> {
        // Validate the endpoint (cheap, local) before the token exchange (networked),
        // so a misconfiguration fails fast without sending credentials anywhere.
        let endpoint = get_tonic_endpoint_logged(&endpoint.into())?;

        let builder = middle::BasicClientCredentialAuthorizer::basic_builder(
            client_id,
            client_secret,
            token_endpoint.into(),
        );
        let builder = if scopes.is_empty() {
            builder
        } else {
            builder.add_scopes(scopes)
        };
        let authorizer = builder.build().await.map_err(|e| {
            tracing::error!("Could not construct OpenFGA client. Failed to fetch or refresh Client Credentials: {e}");
            Error::CredentialRefreshError(e)
        })?;

        let channel = endpoint.connect_lazy();
        let service = Either::Left(Either::Left(InterceptedService::new(channel, authorizer)));
        Ok(Self::new(service))
    }
}
impl<T> OpenFgaServiceClient<T>
where
T: tonic::client::GrpcService<tonic::body::Body>,
T::Error: Into<StdError>,
T::ResponseBody: Body<Data = Bytes> + Send + 'static,
<T::ResponseBody as Body>::Error: Into<StdError> + Send,
T: Clone,
{
pub fn into_client(self, store_id: &str, authorization_model_id: &str) -> OpenFgaClient<T> {
OpenFgaClient::new(self, store_id, authorization_model_id)
}
pub async fn get_store_by_name(&mut self, store_name: &str) -> Result<Option<Store>> {
let stores = self
.list_stores(ListStoresRequest {
page_size: Some(2),
continuation_token: String::new(),
name: store_name.to_string(),
})
.await
.map_err(|e| {
tracing::error!("Failed to list stores in OpenFGA: {e}");
Error::RequestFailed(Box::new(e))
})?
.into_inner();
let num_stores = stores.stores.len();
match stores.stores.first() {
Some(store) => {
if num_stores > 1 {
tracing::error!("Multiple stores with the name `{}` found", store_name);
Err(Error::AmbiguousStoreName(store_name.to_string()))
} else {
Ok(Some(store.clone()))
}
}
None => Ok(None),
}
}
pub async fn get_or_create_store(&mut self, store_name: &str) -> Result<Store> {
let store = self.get_store_by_name(store_name).await?;
match store {
None => {
tracing::debug!("OpenFGA Store {} not found. Creating it.", store_name);
let store = self
.create_store(CreateStoreRequest {
name: store_name.to_owned(),
})
.await
.map_err(|e| {
tracing::error!("Failed to create store in OpenFGA: {e}");
Error::RequestFailed(Box::new(e))
})?
.into_inner();
Ok(Store {
id: store.id,
name: store.name,
created_at: store.created_at,
updated_at: store.updated_at,
deleted_at: None,
})
}
Some(store) => Ok(store),
}
}
pub async fn read_all_pages(
&mut self,
store_id: &str,
tuple: Option<impl Into<ReadRequestTupleKey>>,
consistency: impl Into<ConsistencyPreference>,
page_size: i32,
max_pages: u32,
) -> Result<Vec<Tuple>> {
let mut continuation_token = String::new();
let tuple = tuple.map(Into::into);
let mut tuples = Vec::new();
let mut count = 0;
let consistency = consistency.into();
loop {
let read_request = ReadRequest {
store_id: store_id.to_owned(),
tuple_key: tuple.clone(),
page_size: Some(page_size),
continuation_token: continuation_token.clone(),
consistency: consistency.into(),
};
let response = self
.read(read_request.clone())
.await
.map_err(|e| {
tracing::error!(
"Failed to read from OpenFGA: {e}. Request: {:?}",
read_request
);
Error::RequestFailed(Box::new(e))
})?
.into_inner();
tuples.extend(response.tuples);
continuation_token.clone_from(&response.continuation_token);
count += 1;
if count > max_pages {
return Err(Error::TooManyPages { max_pages, tuple });
}
if continuation_token.is_empty() {
break;
}
}
Ok(tuples)
}
}
/// Service stack behind every [`BasicOpenFgaServiceClient`].
///
/// Each constructor selects exactly one branch:
/// - `Left(Left(_))`: client-credentials authorizer (`new_with_client_credentials`)
/// - `Left(Right(_))`: static token authorizer (`new_with_basic_auth`)
/// - `Right(_)`: pass-through, no authorization (`new_unauthenticated`)
#[cfg(feature = "auth-middle")]
pub type BasicAuthLayer = Either<
    Either<
        InterceptedService<Channel, middle::BasicClientCredentialAuthorizer>,
        InterceptedService<Channel, middle::BearerTokenAuthorizer>,
    >,
    InterceptedService<Channel, NoOpInterceptor>,
>;
/// A [`tonic::service::Interceptor`] that forwards every request unchanged.
///
/// It fills the unauthenticated branch of [`BasicAuthLayer`], so that clients built
/// with `new_unauthenticated` share one service type with the authenticated ones.
#[cfg(feature = "auth-middle")]
#[derive(Clone, Copy, Debug)]
pub struct NoOpInterceptor;

#[cfg(feature = "auth-middle")]
impl tonic::service::Interceptor for NoOpInterceptor {
    /// Hands the request back untouched; this interceptor never rejects a call.
    fn call(
        &mut self,
        req: tonic::Request<()>,
    ) -> std::result::Result<tonic::Request<()>, tonic::Status> {
        Ok(req)
    }
}
/// Converts `endpoint` into a tonic [`Endpoint`] and logs every failure via `tracing`.
///
/// For `https` URLs, TLS is configured with `ClientTlsConfig::with_enabled_roots`
/// when the `tls-rustls` feature is enabled. When that feature is disabled,
/// `https` URLs are rejected. URLs with other schemes are returned without TLS
/// configuration.
///
/// # Errors
/// - [`Error::InvalidEndpoint`] if tonic cannot parse the URL.
/// - [`Error::TlsConfigurationFailed`] if TLS cannot be configured, or if the URL is
///   `https` but the `tls-rustls` feature is disabled.
#[cfg(feature = "auth-middle")]
fn get_tonic_endpoint_logged(endpoint: &url::Url) -> Result<Endpoint> {
    let ep = Endpoint::new(endpoint.to_string()).map_err(|e| {
        tracing::error!("Could not construct OpenFGA client. Invalid endpoint `{endpoint}`: {e}");
        Error::InvalidEndpoint(endpoint.to_string())
    })?;
    if endpoint.scheme() == "https" {
        #[cfg(feature = "tls-rustls")]
        {
            use tonic::transport::ClientTlsConfig;
            let tls_config = ClientTlsConfig::new().with_enabled_roots();
            return ep.tls_config(tls_config).map_err(|e| {
                tracing::error!(
                    "Could not configure TLS for OpenFGA client endpoint `{endpoint}`: {e}"
                );
                Error::TlsConfigurationFailed {
                    endpoint: endpoint.to_string(),
                    reason: e.to_string(),
                }
            });
        }
        #[cfg(not(feature = "tls-rustls"))]
        {
            let reason = "HTTPS endpoint requires the `tls-rustls` feature to be enabled";
            // Log like every other failure path of this function; previously this
            // error was returned silently.
            tracing::error!(
                "Could not configure TLS for OpenFGA client endpoint `{endpoint}`: {reason}"
            );
            return Err(Error::TlsConfigurationFailed {
                endpoint: endpoint.to_string(),
                reason: reason.to_string(),
            });
        }
    }
    Ok(ep)
}
#[cfg(test)]
pub(crate) mod test {
    use needs_env_var::needs_env_var;

    /// Integration tests that talk to a live OpenFGA server.
    ///
    /// Gated on the `TEST_OPENFGA_CLIENT_GRPC_URL` environment variable via
    /// `needs_env_var`. Without that gate, the tests panic whenever no server is
    /// configured.
    #[cfg(feature = "auth-middle")]
    #[needs_env_var(TEST_OPENFGA_CLIENT_GRPC_URL)]
    mod openfga {
        use std::collections::{HashMap, HashSet};

        use super::super::*;
        use crate::{
            client::{
                TupleKey, WriteAuthorizationModelRequest, WriteAuthorizationModelResponse,
                WriteRequest, WriteRequestWrites,
            },
            generated::AuthorizationModel,
        };

        /// Builds an unauthenticated client for the server in `TEST_OPENFGA_CLIENT_GRPC_URL`.
        fn get_basic_client() -> BasicOpenFgaServiceClient {
            let endpoint = std::env::var("TEST_OPENFGA_CLIENT_GRPC_URL")
                .expect("TEST_OPENFGA_CLIENT_GRPC_URL must be set");
            let endpoint = url::Url::parse(&endpoint)
                .expect("TEST_OPENFGA_CLIENT_GRPC_URL must be a valid URL");
            BasicOpenFgaServiceClient::new_unauthenticated(endpoint)
                .expect("Client can be created")
        }

        /// Creates a fresh, uniquely named store and returns it with its client.
        async fn new_store() -> (BasicOpenFgaServiceClient, Store) {
            let mut client = get_basic_client();
            let store_name = format!("store-{}", uuid::Uuid::now_v7());
            let store = client
                .get_or_create_store(&store_name)
                .await
                .expect("Store can be created");
            (client, store)
        }

        /// Writes the sample "entitlements" authorization model into `store`.
        async fn create_entitlements_model(
            client: &mut BasicOpenFgaServiceClient,
            store: &Store,
        ) -> WriteAuthorizationModelResponse {
            let schema = include_str!("../tests/sample-store/entitlements/schema.json");
            let model: AuthorizationModel =
                serde_json::from_str(schema).expect("Schema can be deserialized");
            let auth_model = client
                .write_authorization_model(WriteAuthorizationModelRequest {
                    store_id: store.id.clone(),
                    type_definitions: model.type_definitions,
                    schema_version: model.schema_version,
                    conditions: model.conditions,
                })
                .await
                .expect("Auth model can be written");
            auth_model.into_inner()
        }

        /// Writes a single `user` --member--> `object` tuple into `store`.
        async fn write_member_tuple(
            client: &mut BasicOpenFgaServiceClient,
            store: &Store,
            auth_model: &WriteAuthorizationModelResponse,
            user: String,
            object: String,
        ) {
            client
                .write(WriteRequest {
                    authorization_model_id: auth_model.authorization_model_id.clone(),
                    store_id: store.id.clone(),
                    writes: Some(WriteRequestWrites {
                        on_duplicate: String::new(),
                        tuple_keys: vec![TupleKey {
                            user,
                            relation: "member".to_string(),
                            object,
                            condition: None,
                        }],
                    }),
                    deletes: None,
                })
                .await
                .expect("Write can be done");
        }

        #[tokio::test]
        async fn test_get_store_by_name_many() {
            let mut client = get_basic_client();
            let mut stores = HashMap::new();
            // More stores than fit on one default page, so the name filter is exercised.
            for _i in 0..201 {
                let store_name = format!("store-{}", uuid::Uuid::now_v7());
                let r = client
                    .get_or_create_store(&store_name)
                    .await
                    .expect("Store can be created");
                assert_eq!(store_name, r.name);
                stores.insert(store_name, r.id);
            }
            for (store_name, store_id) in stores {
                let store = client
                    .get_store_by_name(&store_name)
                    .await
                    .expect("Store can be fetched")
                    .expect("Store exists");
                assert_eq!(store_id, store.id);
            }
        }

        #[tokio::test]
        async fn test_get_store_by_name_non_existent() {
            let mut client = get_basic_client();
            let store = client
                .get_store_by_name("non-existent-store")
                .await
                .expect("Store lookup can be done");
            assert!(store.is_none());
        }

        #[tokio::test]
        async fn test_get_or_create_store() {
            let mut client = get_basic_client();
            let store_name = format!("store-{}", uuid::Uuid::now_v7());
            let store = client
                .get_or_create_store(&store_name)
                .await
                .expect("Store can be created");
            let store2 = client
                .get_or_create_store(&store_name)
                .await
                .expect("Store can be fetched");
            assert_eq!(store.id, store2.id);
        }

        #[tokio::test]
        async fn test_read_all_pages_many() {
            let (mut client, store) = new_store().await;
            let auth_model = create_entitlements_model(&mut client, &store).await;
            let object = "organization:org-1";
            let users = (0..501)
                .map(|i| format!("user:u-{i}"))
                .collect::<Vec<String>>();
            for user in &users {
                write_member_tuple(
                    &mut client,
                    &store,
                    &auth_model,
                    user.clone(),
                    object.to_string(),
                )
                .await;
            }
            let tuples = client
                .read_all_pages(
                    &store.id,
                    Some(ReadRequestTupleKey {
                        user: String::new(),
                        relation: "member".to_string(),
                        object: object.to_string(),
                    }),
                    ConsistencyPreference::HigherConsistency,
                    100,
                    6,
                )
                .await
                .expect("Read can be done");
            assert_eq!(tuples.len(), 501);
            assert_eq!(
                tuples
                    .iter()
                    .map(|t| t.key.clone().expect("Tuple has a key").user)
                    .collect::<HashSet<String>>(),
                users.into_iter().collect::<HashSet<String>>()
            );
        }

        #[tokio::test]
        async fn test_read_all_pages_empty() {
            let (mut client, store) = new_store().await;
            let tuples = client
                .read_all_pages(
                    &store.id,
                    Some(ReadRequestTupleKey {
                        user: String::new(),
                        relation: "member".to_string(),
                        object: "organization:org-1".to_string(),
                    }),
                    ConsistencyPreference::HigherConsistency,
                    100,
                    5,
                )
                .await
                .expect("Read can be done");
            assert!(tuples.is_empty());
        }

        #[tokio::test]
        async fn test_read_all_pages_unfiltered() {
            let (mut client, store) = new_store().await;
            let auth_model = create_entitlements_model(&mut client, &store).await;
            let total = 250;
            for i in 0..total {
                write_member_tuple(
                    &mut client,
                    &store,
                    &auth_model,
                    format!("user:u-{i}"),
                    format!("organization:org-{}", i % 5),
                )
                .await;
            }
            let tuples = client
                .read_all_pages(
                    &store.id,
                    None::<ReadRequestTupleKey>,
                    ConsistencyPreference::HigherConsistency,
                    100,
                    10,
                )
                .await
                .expect("unfiltered read_all_pages must succeed");
            assert_eq!(
                tuples.len(),
                total,
                "unfiltered read_all_pages must return every tuple in the store"
            );
        }

        #[tokio::test]
        async fn test_read_all_pages_max_pages_enforced() {
            let (mut client, store) = new_store().await;
            let auth_model = create_entitlements_model(&mut client, &store).await;
            for i in 0..250 {
                write_member_tuple(
                    &mut client,
                    &store,
                    &auth_model,
                    format!("user:u-{i}"),
                    "organization:org-1".to_string(),
                )
                .await;
            }
            // 250 tuples at 100 per page need 3 pages; allow only 1.
            let result = client
                .read_all_pages(
                    &store.id,
                    None::<ReadRequestTupleKey>,
                    ConsistencyPreference::HigherConsistency,
                    100,
                    1,
                )
                .await;
            match result {
                Err(Error::TooManyPages { max_pages, .. }) => {
                    assert_eq!(max_pages, 1);
                }
                Err(other) => panic!("expected TooManyPages, got {other:?}"),
                Ok(tuples) => panic!(
                    "expected TooManyPages error, got Ok with {} tuples",
                    tuples.len()
                ),
            }
        }
    }
}