use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use async_stream::stream;
use futures::{StreamExt, pin_mut};
use tonic::codegen::{Body, Bytes, StdError};
use crate::{
client::{
BatchCheckItem, BatchCheckRequest, CheckRequest, CheckRequestTupleKey,
ConsistencyPreference, ContextualTupleKeys, ExpandRequest, ExpandRequestTupleKey,
ListObjectsRequest, ListObjectsResponse, OpenFgaServiceClient, ReadRequest,
ReadRequestTupleKey, ReadResponse, Tuple, TupleKey, TupleKeyWithoutCondition, UsersetTree,
WriteRequest, WriteRequestDeletes, WriteRequestWrites,
batch_check_single_result::CheckResult,
},
error::{Error, Result},
};
/// Default limit on the combined number of tuple writes and deletes accepted by a single write
/// call. Used by `OpenFgaClient::new`; can be overridden per client with
/// `OpenFgaClient::set_max_tuples_per_write`.
const DEFAULT_MAX_TUPLES_PER_WRITE: i32 = 100;
/// Selects how a write request reacts when a tuple operation conflicts with the stored state.
///
/// The same enum is used for duplicate writes (`on_duplicate`) and for deletes of tuples that do
/// not exist (`on_missing`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ConflictBehavior {
/// Default variant; serialized as an empty string in the request.
/// NOTE(review): presumably the server then rejects the conflicting request — confirm against
/// the OpenFGA API documentation.
#[default]
Fail,
/// Serialized as `"ignore"` in the request.
Ignore,
}
impl ConflictBehavior {
fn as_str(&self) -> &str {
match self {
ConflictBehavior::Fail => "",
ConflictBehavior::Ignore => "ignore",
}
}
}
/// Per-request conflict handling for `OpenFgaClient::write_with_options`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WriteOptions {
/// Behavior applied to the tuples being written; forwarded as the `on_duplicate` request field.
pub on_duplicate: ConflictBehavior,
/// Behavior applied to the tuples being deleted; forwarded as the `on_missing` request field.
pub on_missing: ConflictBehavior,
}
impl WriteOptions {
    /// Builds options in which both duplicate writes and deletes of missing tuples use
    /// [`ConflictBehavior::Ignore`].
    #[must_use]
    pub fn new_idempotent() -> Self {
        let ignore = ConflictBehavior::Ignore;
        Self {
            on_duplicate: ignore,
            on_missing: ignore,
        }
    }
}
impl Default for WriteOptions {
    /// Both duplicate writes and deletes of missing tuples use [`ConflictBehavior::Fail`].
    fn default() -> Self {
        let fail = ConflictBehavior::Fail;
        Self {
            on_duplicate: fail,
            on_missing: fail,
        }
    }
}
/// Wrapper around an [`OpenFgaServiceClient`] that carries a fixed store ID, authorization model
/// ID, write-size limit and consistency preference.
///
/// The settings live behind an [`Arc`], so cloning the client shares them instead of copying
/// the strings.
#[derive(Clone, Debug)]
pub struct OpenFgaClient<T> {
// Underlying generated gRPC client; cloned for every request.
client: OpenFgaServiceClient<T>,
// Shared, immutable per-client settings.
inner: Arc<ModelClientInner>,
}
/// Settings shared by all clones of an [`OpenFgaClient`].
#[derive(Debug, Clone)]
struct ModelClientInner {
// Store every request is addressed to.
store_id: String,
// Authorization model ID attached to write, check, batch-check, expand and list-objects requests.
authorization_model_id: String,
// Upper bound on writes + deletes accepted by a single write call.
max_tuples_per_write: i32,
// Consistency preference attached to read, check, batch-check, expand and list-objects requests.
consistency: ConsistencyPreference,
}
/// [`OpenFgaClient`] specialised to the crate's `BasicAuthLayer` transport.
/// Only available when the `auth-middle` feature is enabled.
#[cfg(feature = "auth-middle")]
pub type BasicOpenFgaClient = OpenFgaClient<crate::client::BasicAuthLayer>;
impl<T> OpenFgaClient<T>
where
T: tonic::client::GrpcService<tonic::body::Body>,
T::Error: Into<StdError>,
T::ResponseBody: Body<Data = Bytes> + Send + 'static,
<T::ResponseBody as Body>::Error: Into<StdError> + Send,
T: Clone,
{
#[must_use]
pub fn new(
client: OpenFgaServiceClient<T>,
store_id: &str,
authorization_model_id: &str,
) -> Self {
OpenFgaClient {
client,
inner: Arc::new(ModelClientInner {
store_id: store_id.to_string(),
authorization_model_id: authorization_model_id.to_string(),
max_tuples_per_write: DEFAULT_MAX_TUPLES_PER_WRITE,
consistency: ConsistencyPreference::MinimizeLatency,
}),
}
}
#[must_use]
pub fn set_max_tuples_per_write(mut self, max_tuples_per_write: i32) -> Self {
let inner = Arc::unwrap_or_clone(self.inner);
self.inner = Arc::new(ModelClientInner {
store_id: inner.store_id,
authorization_model_id: inner.authorization_model_id,
max_tuples_per_write,
consistency: inner.consistency,
});
self
}
#[must_use]
pub fn set_consistency(mut self, consistency: impl Into<ConsistencyPreference>) -> Self {
let inner = Arc::unwrap_or_clone(self.inner);
self.inner = Arc::new(ModelClientInner {
store_id: inner.store_id,
authorization_model_id: inner.authorization_model_id,
max_tuples_per_write: inner.max_tuples_per_write,
consistency: consistency.into(),
});
self
}
/// Store ID this client sends with every request.
pub fn store_id(&self) -> &str {
    self.inner.store_id.as_str()
}
/// Authorization model ID this client was configured with.
pub fn authorization_model_id(&self) -> &str {
    self.inner.authorization_model_id.as_str()
}
pub fn max_tuples_per_write(&self) -> i32 {
self.inner.max_tuples_per_write
}
/// Returns a clone of the underlying service client.
pub fn client(&self) -> OpenFgaServiceClient<T> {
    Clone::clone(&self.client)
}
pub fn consistency(&self) -> ConsistencyPreference {
self.inner.consistency
}
/// Writes and/or deletes tuples using [`WriteOptions::default`].
///
/// # Errors
/// Propagates every error of [`Self::write_with_options`].
pub async fn write(
    &self,
    writes: impl Into<Option<Vec<TupleKey>>>,
    deletes: impl Into<Option<Vec<TupleKeyWithoutCondition>>>,
) -> Result<()> {
    let options = WriteOptions::default();
    self.write_with_options(writes, deletes, options).await
}
pub async fn write_with_options(
&self,
writes: impl Into<Option<Vec<TupleKey>>>,
deletes: impl Into<Option<Vec<TupleKeyWithoutCondition>>>,
options: WriteOptions,
) -> Result<()> {
let writes = writes.into().and_then(|w| (!w.is_empty()).then_some(w));
let deletes = deletes.into().and_then(|d| (!d.is_empty()).then_some(d));
if writes.is_none() && deletes.is_none() {
return Ok(());
}
let num_writes_and_deletes = i32::try_from(
#[allow(clippy::manual_saturating_arithmetic)]
writes
.as_ref()
.map_or(0, Vec::len)
.checked_add(deletes.as_ref().map_or(0, Vec::len))
.unwrap_or(usize::MAX),
)
.unwrap_or(i32::MAX);
if num_writes_and_deletes > self.max_tuples_per_write() {
tracing::error!(
"Too many writes and deletes in single OpenFGA transaction (actual) {} > {} (max)",
num_writes_and_deletes,
self.max_tuples_per_write()
);
return Err(Error::TooManyWrites {
actual: num_writes_and_deletes,
max: self.max_tuples_per_write(),
});
}
let write_request = WriteRequest {
store_id: self.store_id().to_string(),
writes: writes.map(|writes| WriteRequestWrites {
tuple_keys: writes,
on_duplicate: options.on_duplicate.as_str().to_string(),
}),
deletes: deletes.map(|deletes| WriteRequestDeletes {
on_missing: options.on_missing.as_str().to_string(),
tuple_keys: deletes,
}),
authorization_model_id: self.authorization_model_id().to_string(),
};
self.client
.clone()
.write(write_request.clone())
.await
.map_err(|e| {
let write_request_debug = format!("{write_request:?}");
tracing::error!(
"Write request failed with status {e}. Request: {write_request_debug}"
);
Error::RequestFailed(Box::new(e))
})
.map(|_| ())
}
pub async fn read(
&self,
page_size: i32,
tuple_key: impl Into<Option<ReadRequestTupleKey>>,
continuation_token: impl Into<Option<String>>,
) -> Result<tonic::Response<ReadResponse>> {
let read_request = ReadRequest {
store_id: self.store_id().to_string(),
page_size: Some(page_size),
continuation_token: continuation_token.into().unwrap_or_default(),
tuple_key: tuple_key.into(),
consistency: self.consistency().into(),
};
self.client
.clone()
.read(read_request.clone())
.await
.map_err(|e| {
let read_request_debug = format!("{read_request:?}");
tracing::error!(
"Read request failed with status {e}. Request: {read_request_debug}"
);
Error::RequestFailed(Box::new(e))
})
}
/// Delegates to the service client's `read_all_pages`, supplying this client's store ID and
/// consistency preference.
///
/// # Errors
/// Propagates whatever error the underlying `read_all_pages` call returns.
pub async fn read_all_pages(
    &self,
    tuple: Option<impl Into<ReadRequestTupleKey>>,
    page_size: i32,
    max_pages: u32,
) -> Result<Vec<Tuple>> {
    let consistency = self.consistency();
    let store_id = String::from(self.store_id());
    self.client
        .clone()
        .read_all_pages(&store_id, tuple, consistency, page_size, max_pages)
        .await
}
pub async fn check(
&self,
tuple_key: impl Into<CheckRequestTupleKey>,
contextual_tuples: impl Into<Option<Vec<TupleKey>>>,
context: impl Into<Option<prost_wkt_types::Struct>>,
trace: bool,
) -> Result<bool> {
let contextual_tuples = contextual_tuples
.into()
.and_then(|c| (!c.is_empty()).then_some(c))
.map(|tuple_keys| ContextualTupleKeys { tuple_keys });
let check_request = CheckRequest {
store_id: self.store_id().to_string(),
tuple_key: Some(tuple_key.into()),
consistency: self.consistency().into(),
contextual_tuples,
authorization_model_id: self.authorization_model_id().to_string(),
context: context.into(),
trace,
};
let response = self
.client
.clone()
.check(check_request.clone())
.await
.map_err(|e| {
let check_request_debug = format!("{check_request:?}");
tracing::error!(
"Check request failed with status {e}. Request: {check_request_debug}"
);
Error::RequestFailed(Box::new(e))
})?;
Ok(response.get_ref().allowed)
}
pub async fn batch_check<I>(
&self,
checks: impl IntoIterator<Item = I>,
) -> Result<HashMap<String, CheckResult>>
where
I: Into<BatchCheckItem>,
{
let checks: Vec<BatchCheckItem> = checks.into_iter().map(Into::into).collect();
let request = BatchCheckRequest {
store_id: self.store_id().to_string(),
checks,
authorization_model_id: self.authorization_model_id().to_string(),
consistency: self.consistency().into(),
};
let response = self
.client
.clone()
.batch_check(request.clone())
.await
.map_err(|e| {
let request_debug = format!("{request:?}");
tracing::error!(
"Batch-Check request failed with status {e}. Request: {request_debug}"
);
Error::RequestFailed(Box::new(e))
})?;
let mut map = HashMap::new();
for (k, v) in response.into_inner().result {
match v.check_result {
Some(v) => map.insert(k, v),
None => return Err(Error::ExpectedOneof),
};
}
Ok(map)
}
pub async fn expand(
&self,
tuple_key: impl Into<ExpandRequestTupleKey>,
contextual_tuples: impl Into<Option<Vec<TupleKey>>>,
) -> Result<Option<UsersetTree>> {
let expand_request = ExpandRequest {
store_id: self.store_id().to_string(),
tuple_key: Some(tuple_key.into()),
authorization_model_id: self.authorization_model_id().to_string(),
consistency: self.consistency().into(),
contextual_tuples: contextual_tuples
.into()
.map(|tuple_keys| ContextualTupleKeys { tuple_keys }),
};
let response = self
.client
.clone()
.expand(expand_request.clone())
.await
.map_err(|e| {
tracing::error!(
"Expand request failed with status {e}. Request: {expand_request:?}"
);
Error::RequestFailed(Box::new(e))
})?;
Ok(response.into_inner().tree)
}
pub async fn check_simple(&self, tuple_key: impl Into<CheckRequestTupleKey>) -> Result<bool> {
self.check(tuple_key, None, None, false).await
}
pub async fn list_objects(
&self,
r#type: impl Into<String>,
relation: impl Into<String>,
user: impl Into<String>,
contextual_tuples: impl Into<Option<Vec<TupleKey>>>,
context: impl Into<Option<prost_wkt_types::Struct>>,
) -> Result<tonic::Response<ListObjectsResponse>> {
let request = ListObjectsRequest {
r#type: r#type.into(),
relation: relation.into(),
user: user.into(),
authorization_model_id: self.authorization_model_id().to_string(),
store_id: self.store_id().to_string(),
consistency: self.consistency().into(),
contextual_tuples: contextual_tuples
.into()
.map(|tuple_keys| ContextualTupleKeys { tuple_keys }),
context: context.into(),
};
self.client
.clone()
.list_objects(request.clone())
.await
.map_err(|e| {
tracing::error!(
"List-Objects request failed with status {e}. Request: {request:?}"
);
Error::RequestFailed(Box::new(e))
})
}
/// Deletes every tuple whose object is `object`.
///
/// After each deletion sweep the store is queried again; sweeps repeat until no tuple for
/// `object` can be read any more.
///
/// # Errors
/// Returns the first error produced by a sweep or by the follow-up existence check.
pub async fn delete_relations_to_object(&self, object: &str) -> Result<()> {
    loop {
        if let Err(e) = self.delete_relations_to_object_inner(object).await {
            tracing::error!("Failed to delete relations to object {object}: {e}");
            return Err(e);
        }

        if !self.exists_relation_to(object).await? {
            tracing::debug!("Successfully deleted all relations to object {object}");
            return Ok(());
        }

        tracing::debug!(
            "Some tuples for object {object} are still present after first sweep. Performing another deletion."
        );
    }
}
/// Returns `true` if at least one tuple with the given `object` can be read from the store.
///
/// Only a single tuple is requested (page size 1).
///
/// # Errors
/// Propagates the error of the underlying read request.
pub async fn exists_relation_to(&self, object: &str) -> Result<bool> {
    let first_page = self.read_relations_to_object(object, None, 1).await?;
    let found_any = !first_page.tuples.is_empty();
    Ok(found_any)
}
/// Reads one page of tuples for `object`, leaving user and relation empty in the filter.
///
/// # Errors
/// Propagates the error of [`Self::read`].
async fn read_relations_to_object(
    &self,
    object: &str,
    continuation_token: impl Into<Option<String>>,
    page_size: i32,
) -> Result<ReadResponse> {
    let object_filter = TupleKeyWithoutCondition {
        user: String::new(),
        relation: String::new(),
        object: object.to_owned(),
    };
    let response = self
        .read(page_size, object_filter, continuation_token)
        .await?;
    Ok(response.into_inner())
}
/// Performs one sweep over the tuples of `object`: reads them page by page and deletes each
/// page with a single write request.
///
/// Pages are requested with a page size of [`Self::max_tuples_per_write`]. Reading the next
/// page and deleting the previously read page are driven concurrently.
///
/// # Errors
/// Returns the first error produced by a read or by a delete request.
async fn delete_relations_to_object_inner(&self, object: &str) -> Result<()> {
    let read_stream = stream! {
        let mut continuation_token = None;
        // (user, relation) pairs already yielded for deletion; a tuple that shows up again on
        // a later page is skipped instead of being deleted a second time.
        let mut seen = HashSet::new();
        // The first request carries no token; an empty token in a response ends the loop.
        while continuation_token != Some(String::new()) {
            let response = self
                .read_relations_to_object(object, continuation_token, self.max_tuples_per_write())
                .await?;
            let keys = response
                .tuples
                .into_iter()
                .filter_map(|t| t.key)
                .filter(|k| !seen.contains(&(k.user.clone(), k.relation.clone())))
                .collect::<Vec<_>>();
            tracing::debug!("Read {} keys for object {object} that are up for deletion. Continuation token: {}", keys.len(), response.continuation_token);
            continuation_token = Some(response.continuation_token);
            seen.extend(keys.iter().map(|k| (k.user.clone(), k.relation.clone())));
            yield Result::Ok(keys);
        }
    };
    pin_mut!(read_stream);

    // Page read in the previous iteration and not yet deleted (`None` if there is nothing to
    // delete).
    let mut read_tuples: Option<Vec<TupleKey>> = None;
    let delete_tuples = |t: Option<Vec<TupleKey>>| async {
        match t {
            Some(tuples) => {
                tracing::debug!(
                    "Deleting {} tuples for object {object} that we haven't seen before.",
                    tuples.len()
                );
                self.write(
                    None,
                    Some(
                        tuples
                            .into_iter()
                            .map(|t| TupleKeyWithoutCondition {
                                user: t.user,
                                relation: t.relation,
                                object: t.object,
                            })
                            .collect(),
                    ),
                )
                .await
            }
            None => Ok(()),
        }
    };

    loop {
        let next_future = read_stream.next();
        // Move the pending page into the deletion future instead of cloning it: the slot is
        // either overwritten below or never read again once the loop exits.
        let deletion_future = delete_tuples(read_tuples.take());
        let (tuples, delete) = futures::join!(next_future, deletion_future);
        delete?;
        if let Some(tuples) = tuples.transpose()? {
            read_tuples = (!tuples.is_empty()).then_some(tuples);
        } else {
            // Stream exhausted; the last page was deleted by the join above.
            break Ok(());
        }
    }
}
}
#[cfg(test)]
mod tests {
use needs_env_var::needs_env_var;
#[needs_env_var(TEST_OPENFGA_CLIENT_GRPC_URL)]
mod openfga {
use tracing_test::traced_test;
use super::super::*;
use crate::{
client::{AuthorizationModel, Store},
migration::test::openfga::service_client_with_store,
};
/// Uploads the bundled `custom-roles` sample authorization model to `store` and returns the ID
/// assigned to it.
async fn write_custom_roles_model(
    client: &OpenFgaServiceClient<tonic::transport::Channel>,
    store: &Store,
) -> String {
    let schema_json = include_str!("../tests/sample-store/custom-roles/schema.json");
    let model: AuthorizationModel = serde_json::from_str(schema_json).unwrap();
    let request = model.into_write_request(store.id.clone());
    let response = client
        .clone()
        .write_authorization_model(request)
        .await
        .unwrap();
    response.into_inner().authorization_model_id
}
/// Creates a store via `service_client_with_store`, installs the custom-roles model in it and
/// returns an [`OpenFgaClient`] bound to both.
async fn get_client_with_custom_roles_model() -> OpenFgaClient<tonic::transport::Channel> {
    let (service_client, store) = service_client_with_store().await;
    let model_id = write_custom_roles_model(&service_client, &store).await;
    OpenFgaClient::new(service_client, store.id.as_str(), &model_id)
}
/// Reading without a tuple filter returns all tuples, and the continuation token is set
/// exactly when further pages remain.
#[tokio::test]
#[traced_test]
async fn test_read_single_page_unfiltered() {
    let client = get_client_with_custom_roles_model().await;
    let total = 75;
    let member_of_team1 = |i: usize| TupleKey {
        user: format!("user:user{i}"),
        relation: "member".to_string(),
        object: "team:team1".to_string(),
        condition: None,
    };
    for i in 0..total {
        client
            .write(vec![member_of_team1(i)], None)
            .await
            .unwrap();
    }

    // Page size above the tuple count: everything arrives at once.
    let single_page = client
        .read(100, None, None::<String>)
        .await
        .expect("read with None tuple_key must succeed")
        .into_inner();
    assert_eq!(single_page.tuples.len(), total);
    assert!(
        single_page.continuation_token.is_empty(),
        "continuation token must be empty when all results fit in one page"
    );

    // Page size below the tuple count: a second page is required.
    let first_page = client
        .read(50, None, None::<String>)
        .await
        .expect("read with None tuple_key must succeed")
        .into_inner();
    assert_eq!(first_page.tuples.len(), 50);
    assert!(
        !first_page.continuation_token.is_empty(),
        "continuation token must be set when more pages are available"
    );

    let second_page = client
        .read(50, None, Some(first_page.continuation_token))
        .await
        .expect("read with continuation token must succeed")
        .into_inner();
    assert_eq!(second_page.tuples.len(), total - 50);
    assert!(second_page.continuation_token.is_empty());
}
/// Passing a `ReadRequestTupleKey` directly to `read` filters the result to matching tuples.
#[tokio::test]
#[traced_test]
async fn test_read_single_page_filtered_backward_compat() {
    let client = get_client_with_custom_roles_model().await;
    let member = |user: &str, team: &str| TupleKey {
        user: user.to_string(),
        relation: "member".to_string(),
        object: team.to_string(),
        condition: None,
    };
    let seed = vec![
        member("user:alice", "team:team1"),
        member("user:bob", "team:team2"),
    ];
    client.write(seed, None).await.unwrap();

    let team1_members = ReadRequestTupleKey {
        user: String::new(),
        relation: "member".to_string(),
        object: "team:team1".to_string(),
    };
    let page = client
        .read(100, team1_members, None::<String>)
        .await
        .unwrap()
        .into_inner();

    assert_eq!(page.tuples.len(), 1);
    let only_key = page.tuples[0].key.as_ref().unwrap();
    assert_eq!(only_key.user, "user:alice");
}
/// `read_all_pages` without a tuple filter returns every tuple of the store across pages.
#[tokio::test]
#[traced_test]
async fn test_read_all_pages_empty_tuple() {
    let client = get_client_with_custom_roles_model().await;
    let loop_count = 100;
    let tuples_per_loop = 3;

    for i in 0..loop_count {
        let team_member = TupleKey {
            user: format!("user:user{i}"),
            relation: "member".to_string(),
            object: "team:team1".to_string(),
            condition: None,
        };
        let role_assigner = TupleKey {
            user: format!("role:role{i}#assignee"),
            relation: "role_assigner".to_string(),
            object: "org:org1".to_string(),
            condition: None,
        };
        let category_org = TupleKey {
            user: format!("org:org{i}"),
            relation: "org".to_string(),
            object: format!("asset-category:ac{i}"),
            condition: None,
        };
        client
            .write(vec![team_member, role_assigner, category_org], None)
            .await
            .unwrap();
    }

    let everything = client
        .read_all_pages(None::<ReadRequestTupleKey>, 50, u32::MAX)
        .await
        .unwrap();
    assert_eq!(everything.len(), loop_count * tuples_per_loop);
}
/// A single direct tuple on an object is removed by `delete_relations_to_object`.
#[tokio::test]
#[traced_test]
async fn test_delete_relations_to_object() {
    let client = get_client_with_custom_roles_model().await;
    let object = "team:team1";
    assert!(!client.exists_relation_to(object).await.unwrap());

    let membership = TupleKey {
        user: "user:user1".to_string(),
        relation: "member".to_string(),
        object: object.to_string(),
        condition: None,
    };
    client.write(vec![membership], None).await.unwrap();
    assert!(client.exists_relation_to(object).await.unwrap());

    client.delete_relations_to_object(object).await.unwrap();
    assert!(!client.exists_relation_to(object).await.unwrap());
}
/// A tuple whose user is a userset (`team:team1#member`) is removed as well.
#[tokio::test]
#[traced_test]
async fn test_delete_relations_to_object_usersets() {
    let client = get_client_with_custom_roles_model().await;
    let object: &str = "role:admin";
    assert!(!client.exists_relation_to(object).await.unwrap());

    let userset_assignment = TupleKey {
        user: "team:team1#member".to_string(),
        relation: "assignee".to_string(),
        object: object.to_string(),
        condition: None,
    };
    client.write(vec![userset_assignment], None).await.unwrap();
    assert!(client.exists_relation_to(object).await.unwrap());

    client.delete_relations_to_object(object).await.unwrap();
    assert!(!client.exists_relation_to(object).await.unwrap());
}
/// Deleting the relations of an object that has none succeeds and leaves it without relations.
#[tokio::test]
#[traced_test]
async fn test_delete_relations_to_object_empty() {
    let client = get_client_with_custom_roles_model().await;
    let object = "team:team1";

    let existed_before = client.exists_relation_to(object).await.unwrap();
    assert!(!existed_before);

    client.delete_relations_to_object(object).await.unwrap();

    let exists_after = client.exists_relation_to(object).await.unwrap();
    assert!(!exists_after);
}
/// Deleting an object with more tuples than fit into several pages removes all of them while
/// leaving tuples of another object untouched.
#[tokio::test]
#[traced_test]
async fn test_delete_relations_to_object_many() {
    let client = get_client_with_custom_roles_model().await;
    let object = "org:org1";
    assert!(!client.exists_relation_to(object).await.unwrap());

    for i in 0..502 {
        let member = TupleKey {
            user: format!("user:user{i}"),
            relation: "member".to_string(),
            object: object.to_string(),
            condition: None,
        };
        let role_assigner = TupleKey {
            user: format!("role:role{i}#assignee"),
            relation: "role_assigner".to_string(),
            object: object.to_string(),
            condition: None,
        };
        client
            .write(vec![member, role_assigner], None)
            .await
            .unwrap();
    }

    // A tuple on a second object that must survive the deletion.
    let object_2 = "org:org2";
    let owner_of_org2 = TupleKey {
        user: "user:user1".to_string(),
        relation: "owner".to_string(),
        object: object_2.to_string(),
        condition: None,
    };
    client.write(vec![owner_of_org2], None).await.unwrap();

    assert!(client.exists_relation_to(object).await.unwrap());
    assert!(client.exists_relation_to(object_2).await.unwrap());

    client.delete_relations_to_object(object).await.unwrap();

    assert!(!client.exists_relation_to(object).await.unwrap());
    assert!(client.exists_relation_to(object_2).await.unwrap());

    // NOTE(review): presumably `owner` implies `role_assigner` in the sample model — confirm
    // against the schema file.
    let still_role_assigner = client
        .check_simple(TupleKeyWithoutCondition {
            user: "user:user1".to_string(),
            relation: "role_assigner".to_string(),
            object: object_2.to_string(),
        })
        .await
        .unwrap();
    assert!(still_role_assigner);
}
/// Writing an existing tuple fails with default options and succeeds when duplicates are
/// ignored.
#[tokio::test]
#[traced_test]
async fn test_write_with_options_ignore_duplicate() {
    let client = get_client_with_custom_roles_model().await;
    let tuple = TupleKey {
        user: "user:user1".to_string(),
        relation: "member".to_string(),
        object: "team:team1".to_string(),
        condition: None,
    };
    let strict = WriteOptions::default();

    // First write creates the tuple.
    client
        .write_with_options(vec![tuple.clone()], None, strict)
        .await
        .unwrap();

    // Second write of the same tuple is rejected.
    let duplicate_attempt = client
        .write_with_options(vec![tuple.clone()], None, strict)
        .await;
    assert!(duplicate_attempt.is_err());

    let ignore_duplicates = WriteOptions {
        on_duplicate: ConflictBehavior::Ignore,
        on_missing: ConflictBehavior::Fail,
    };
    client
        .write_with_options(vec![tuple], None, ignore_duplicates)
        .await
        .unwrap();
}
/// Deleting a missing tuple fails with default options and succeeds when missing tuples are
/// ignored.
#[tokio::test]
#[traced_test]
async fn test_write_with_options_ignore_missing() {
    let client = get_client_with_custom_roles_model().await;
    let tuple_key = TupleKeyWithoutCondition {
        user: "user:user1".to_string(),
        relation: "member".to_string(),
        object: "team:team1".to_string(),
    };

    let strict_delete = client
        .write_with_options(None, vec![tuple_key.clone()], WriteOptions::default())
        .await;
    assert!(strict_delete.is_err());

    let ignore_missing = WriteOptions {
        on_duplicate: ConflictBehavior::Fail,
        on_missing: ConflictBehavior::Ignore,
    };
    client
        .write_with_options(None, vec![tuple_key], ignore_missing)
        .await
        .unwrap();
}
/// With `WriteOptions::new_idempotent`, repeating a write and deleting a missing tuple both
/// succeed.
#[tokio::test]
#[traced_test]
async fn test_write_with_options_idempotent() {
    let client = get_client_with_custom_roles_model().await;
    let options = WriteOptions::new_idempotent();

    let tuple = TupleKey {
        user: "user:user1".to_string(),
        relation: "member".to_string(),
        object: "team:team1".to_string(),
        condition: None,
    };
    // Same tuple written twice.
    client
        .write_with_options(vec![tuple.clone()], None, options)
        .await
        .unwrap();
    client
        .write_with_options(vec![tuple], None, options)
        .await
        .unwrap();

    // Delete of a tuple that was never written.
    let never_written = TupleKeyWithoutCondition {
        user: "user:nonexistent".to_string(),
        relation: "member".to_string(),
        object: "team:team1".to_string(),
    };
    client
        .write_with_options(None, vec![never_written], options)
        .await
        .unwrap();
}
/// A single request may both write one tuple and delete another; afterwards only the newly
/// written tuple remains.
#[tokio::test]
#[traced_test]
#[allow(clippy::similar_names)]
async fn test_write_with_options_mixed_operations() {
    let client = get_client_with_custom_roles_model().await;
    let team_member = |user: &str| TupleKey {
        user: user.to_string(),
        relation: "member".to_string(),
        object: "team:team1".to_string(),
        condition: None,
    };

    let tuple1 = team_member("user:user1");
    client.write(vec![tuple1.clone()], None).await.unwrap();

    let tuple2 = team_member("user:user2");
    let delete_key = TupleKeyWithoutCondition {
        user: tuple1.user,
        relation: tuple1.relation,
        object: tuple1.object,
    };
    client
        .write_with_options(vec![tuple2], vec![delete_key], WriteOptions::default())
        .await
        .unwrap();

    let team1_members = TupleKeyWithoutCondition {
        user: String::new(),
        relation: "member".to_string(),
        object: "team:team1".to_string(),
    };
    let remaining = client
        .read_all_pages(Some(team1_members), 10, 10)
        .await
        .unwrap();
    assert_eq!(remaining.len(), 1);
    let remaining_key = remaining[0].key.as_ref().unwrap();
    assert_eq!(remaining_key.user, "user:user2");
}
/// A write call with neither writes nor deletes succeeds.
#[tokio::test]
#[traced_test]
async fn test_write_with_options_empty_operations() {
    let client = get_client_with_custom_roles_model().await;
    let no_writes: Option<Vec<TupleKey>> = None;
    let no_deletes: Option<Vec<TupleKeyWithoutCondition>> = None;

    let outcome = client
        .write_with_options(no_writes, no_deletes, WriteOptions::default())
        .await;
    assert!(outcome.is_ok());
}
}
}