1use std::{
2 collections::{HashMap, HashSet},
3 sync::Arc,
4};
5
6use async_stream::stream;
7use futures::{pin_mut, StreamExt};
8use tonic::codegen::{Body, Bytes, StdError};
9
10use crate::{
11 client::{
12 batch_check_single_result::CheckResult, BatchCheckItem, BatchCheckRequest, CheckRequest,
13 CheckRequestTupleKey, ConsistencyPreference, ContextualTupleKeys, ExpandRequest,
14 ExpandRequestTupleKey, ListObjectsRequest, ListObjectsResponse, OpenFgaServiceClient,
15 ReadRequest, ReadRequestTupleKey, ReadResponse, Tuple, TupleKey, TupleKeyWithoutCondition,
16 UsersetTree, WriteRequest, WriteRequestDeletes, WriteRequestWrites,
17 },
18 error::{Error, Result},
19};
20
/// Upper bound on the number of writes plus deletes that [`OpenFgaClient::write`] sends in a
/// single request, unless overridden via [`OpenFgaClient::set_max_tuples_per_write`].
const DEFAULT_MAX_TUPLES_PER_WRITE: i32 = 100;
22
23#[derive(Clone, Debug)]
24pub struct OpenFgaClient<T> {
54 client: OpenFgaServiceClient<T>,
55 inner: Arc<ModelClientInner>,
56}
57
58#[derive(Debug, Clone)]
59struct ModelClientInner {
60 store_id: String,
61 authorization_model_id: String,
62 max_tuples_per_write: i32,
63 consistency: ConsistencyPreference,
64}
65
/// [`OpenFgaClient`] whose service type is `crate::client::BasicAuthLayer`.
///
/// Only available with the `auth-middle` feature.
#[cfg(feature = "auth-middle")]
pub type BasicOpenFgaClient = OpenFgaClient<crate::client::BasicAuthLayer>;
72
73impl<T> OpenFgaClient<T>
74where
75 T: tonic::client::GrpcService<tonic::body::BoxBody>,
76 T::Error: Into<StdError>,
77 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
78 <T::ResponseBody as Body>::Error: Into<StdError> + Send,
79 T: Clone,
80{
81 #[must_use]
83 pub fn new(
84 client: OpenFgaServiceClient<T>,
85 store_id: &str,
86 authorization_model_id: &str,
87 ) -> Self {
88 OpenFgaClient {
89 client,
90 inner: Arc::new(ModelClientInner {
91 store_id: store_id.to_string(),
92 authorization_model_id: authorization_model_id.to_string(),
93 max_tuples_per_write: DEFAULT_MAX_TUPLES_PER_WRITE,
94 consistency: ConsistencyPreference::MinimizeLatency,
95 }),
96 }
97 }
98
99 #[must_use]
101 pub fn set_max_tuples_per_write(mut self, max_tuples_per_write: i32) -> Self {
102 let inner = Arc::unwrap_or_clone(self.inner);
103 self.inner = Arc::new(ModelClientInner {
104 store_id: inner.store_id,
105 authorization_model_id: inner.authorization_model_id,
106 max_tuples_per_write,
107 consistency: inner.consistency,
108 });
109 self
110 }
111
112 #[must_use]
114 pub fn set_consistency(mut self, consistency: impl Into<ConsistencyPreference>) -> Self {
115 let inner = Arc::unwrap_or_clone(self.inner);
116 self.inner = Arc::new(ModelClientInner {
117 store_id: inner.store_id,
118 authorization_model_id: inner.authorization_model_id,
119 max_tuples_per_write: inner.max_tuples_per_write,
120 consistency: consistency.into(),
121 });
122 self
123 }
124
125 pub fn store_id(&self) -> &str {
127 &self.inner.store_id
128 }
129
130 pub fn authorization_model_id(&self) -> &str {
132 &self.inner.authorization_model_id
133 }
134
135 pub fn max_tuples_per_write(&self) -> i32 {
137 self.inner.max_tuples_per_write
138 }
139
140 pub fn client(&self) -> OpenFgaServiceClient<T> {
142 self.client.clone()
143 }
144
145 pub fn consistency(&self) -> ConsistencyPreference {
147 self.inner.consistency
148 }
149
150 pub async fn write(
172 &self,
173 writes: impl Into<Option<Vec<TupleKey>>>,
174 deletes: impl Into<Option<Vec<TupleKeyWithoutCondition>>>,
175 ) -> Result<()> {
176 let writes = writes.into().and_then(|w| (!w.is_empty()).then_some(w));
177 let deletes = deletes.into().and_then(|d| (!d.is_empty()).then_some(d));
178
179 if writes.is_none() && deletes.is_none() {
180 return Ok(());
181 }
182
183 let num_writes_and_deletes = i32::try_from(
184 #[allow(clippy::manual_saturating_arithmetic)]
185 writes
186 .as_ref()
187 .map_or(0, Vec::len)
188 .checked_add(deletes.as_ref().map_or(0, Vec::len))
189 .unwrap_or(usize::MAX),
190 )
191 .unwrap_or(i32::MAX);
192
193 if num_writes_and_deletes > self.max_tuples_per_write() {
194 tracing::error!(
195 "Too many writes and deletes in single OpenFGA transaction (actual) {} > {} (max)",
196 num_writes_and_deletes,
197 self.max_tuples_per_write()
198 );
199 return Err(Error::TooManyWrites {
200 actual: num_writes_and_deletes,
201 max: self.max_tuples_per_write(),
202 });
203 }
204
205 let write_request = WriteRequest {
206 store_id: self.store_id().to_string(),
207 writes: writes.map(|writes| WriteRequestWrites { tuple_keys: writes }),
208 deletes: deletes.map(|deletes| WriteRequestDeletes {
209 tuple_keys: deletes,
210 }),
211 authorization_model_id: self.authorization_model_id().to_string(),
212 };
213
214 self.client
215 .clone()
216 .write(write_request.clone())
217 .await
218 .map_err(|e| {
219 let write_request_debug = format!("{write_request:?}");
220 tracing::error!(
221 "Write request failed with status {e}. Request: {write_request_debug}"
222 );
223 Error::RequestFailed(Box::new(e))
224 })
225 .map(|_| ())
226 }
227
228 pub async fn read(
237 &self,
238 page_size: i32,
239 tuple_key: impl Into<ReadRequestTupleKey>,
240 continuation_token: impl Into<Option<String>>,
241 ) -> Result<tonic::Response<ReadResponse>> {
242 let read_request = ReadRequest {
243 store_id: self.store_id().to_string(),
244 page_size: Some(page_size),
245 continuation_token: continuation_token.into().unwrap_or_default(),
246 tuple_key: Some(tuple_key.into()),
247 consistency: self.consistency().into(),
248 };
249 self.client
250 .clone()
251 .read(read_request.clone())
252 .await
253 .map_err(|e| {
254 let read_request_debug = format!("{read_request:?}");
255 tracing::error!(
256 "Read request failed with status {e}. Request: {read_request_debug}"
257 );
258 Error::RequestFailed(Box::new(e))
259 })
260 }
261
262 pub async fn read_all_pages(
270 &self,
271 tuple: Option<impl Into<ReadRequestTupleKey>>,
272 page_size: i32,
273 max_pages: u32,
274 ) -> Result<Vec<Tuple>> {
275 let store_id = self.store_id().to_string();
276 self.client
277 .clone()
278 .read_all_pages(&store_id, tuple, self.consistency(), page_size, max_pages)
279 .await
280 }
281
282 pub async fn check(
289 &self,
290 tuple_key: impl Into<CheckRequestTupleKey>,
291 contextual_tuples: impl Into<Option<Vec<TupleKey>>>,
292 context: impl Into<Option<prost_wkt_types::Struct>>,
293 trace: bool,
294 ) -> Result<bool> {
295 let contextual_tuples = contextual_tuples
296 .into()
297 .and_then(|c| (!c.is_empty()).then_some(c))
298 .map(|tuple_keys| ContextualTupleKeys { tuple_keys });
299
300 let check_request = CheckRequest {
301 store_id: self.store_id().to_string(),
302 tuple_key: Some(tuple_key.into()),
303 consistency: self.consistency().into(),
304 contextual_tuples,
305 authorization_model_id: self.authorization_model_id().to_string(),
306 context: context.into(),
307 trace,
308 };
309 let response = self
310 .client
311 .clone()
312 .check(check_request.clone())
313 .await
314 .map_err(|e| {
315 let check_request_debug = format!("{check_request:?}");
316 tracing::error!(
317 "Check request failed with status {e}. Request: {check_request_debug}"
318 );
319 Error::RequestFailed(Box::new(e))
320 })?;
321 Ok(response.get_ref().allowed)
322 }
323
324 pub async fn batch_check<I>(
332 &self,
333 checks: impl IntoIterator<Item = I>,
334 ) -> Result<HashMap<String, CheckResult>>
335 where
336 I: Into<BatchCheckItem>,
337 {
338 let checks: Vec<BatchCheckItem> = checks.into_iter().map(Into::into).collect();
339 let request = BatchCheckRequest {
340 store_id: self.store_id().to_string(),
341 checks,
342 authorization_model_id: self.authorization_model_id().to_string(),
343 consistency: self.consistency().into(),
344 };
345
346 let response = self
347 .client
348 .clone()
349 .batch_check(request.clone())
350 .await
351 .map_err(|e| {
352 let request_debug = format!("{request:?}");
353 tracing::error!(
354 "Batch-Check request failed with status {e}. Request: {request_debug}"
355 );
356 Error::RequestFailed(Box::new(e))
357 })?;
358
359 let mut map = HashMap::new();
360 for (k, v) in response.into_inner().result {
361 match v.check_result {
362 Some(v) => map.insert(k, v),
366 None => return Err(Error::ExpectedOneof),
367 };
368 }
369 Ok(map)
370 }
371
372 pub async fn expand(
379 &self,
380 tuple_key: impl Into<ExpandRequestTupleKey>,
381 contextual_tuples: impl Into<Option<Vec<TupleKey>>>,
382 ) -> Result<Option<UsersetTree>> {
383 let expand_request = ExpandRequest {
384 store_id: self.store_id().to_string(),
385 tuple_key: Some(tuple_key.into()),
386 authorization_model_id: self.authorization_model_id().to_string(),
387 consistency: self.consistency().into(),
388 contextual_tuples: contextual_tuples
389 .into()
390 .map(|tuple_keys| ContextualTupleKeys { tuple_keys }),
391 };
392 let response = self
393 .client
394 .clone()
395 .expand(expand_request.clone())
396 .await
397 .map_err(|e| {
398 tracing::error!(
399 "Expand request failed with status {e}. Request: {expand_request:?}"
400 );
401 Error::RequestFailed(Box::new(e))
402 })?;
403 Ok(response.into_inner().tree)
404 }
405
406 pub async fn check_simple(&self, tuple_key: impl Into<CheckRequestTupleKey>) -> Result<bool> {
411 self.check(tuple_key, None, None, false).await
412 }
413
414 pub async fn list_objects(
419 &self,
420 r#type: impl Into<String>,
421 relation: impl Into<String>,
422 user: impl Into<String>,
423 contextual_tuples: impl Into<Option<Vec<TupleKey>>>,
424 context: impl Into<Option<prost_wkt_types::Struct>>,
425 ) -> Result<tonic::Response<ListObjectsResponse>> {
426 let request = ListObjectsRequest {
427 r#type: r#type.into(),
428 relation: relation.into(),
429 user: user.into(),
430 authorization_model_id: self.authorization_model_id().to_string(),
431 store_id: self.store_id().to_string(),
432 consistency: self.consistency().into(),
433 contextual_tuples: contextual_tuples
434 .into()
435 .map(|tuple_keys| ContextualTupleKeys { tuple_keys }),
436 context: context.into(),
437 };
438
439 self.client
440 .clone()
441 .list_objects(request.clone())
442 .await
443 .map_err(|e| {
444 tracing::error!(
445 "List-Objects request failed with status {e}. Request: {request:?}"
446 );
447 Error::RequestFailed(Box::new(e))
448 })
449 }
450
451 pub async fn delete_relations_to_object(&self, object: &str) -> Result<()> {
463 loop {
464 self.delete_relations_to_object_inner(object)
465 .await
466 .inspect_err(|e| {
467 tracing::error!("Failed to delete relations to object {object}: {e}");
468 })?;
469
470 if self.exists_relation_to(object).await? {
471 tracing::debug!("Some tuples for object {object} are still present after first sweep. Performing another deletion.");
472 } else {
473 tracing::debug!("Successfully deleted all relations to object {object}");
474 break Ok(());
475 }
476 }
477 }
478
479 pub async fn exists_relation_to(&self, object: &str) -> Result<bool> {
485 let tuples = self.read_relations_to_object(object, None, 1).await?;
486 Ok(!tuples.tuples.is_empty())
487 }
488
489 async fn read_relations_to_object(
490 &self,
491 object: &str,
492 continuation_token: impl Into<Option<String>>,
493 page_size: i32,
494 ) -> Result<ReadResponse> {
495 self.read(
496 page_size,
497 TupleKeyWithoutCondition {
498 user: String::new(),
499 relation: String::new(),
500 object: object.to_string(),
501 },
502 continuation_token,
503 )
504 .await
505 .map(tonic::Response::into_inner)
506 }
507
508 async fn delete_relations_to_object_inner(&self, object: &str) -> Result<()> {
512 let read_stream = stream! {
513 let mut continuation_token = None;
514 let mut seen= HashSet::new();
517 while continuation_token != Some(String::new()) {
518 let response = self.read_relations_to_object(object, continuation_token, self.max_tuples_per_write()).await?;
519 let keys = response.tuples.into_iter().filter_map(|t| t.key).filter(|k| !seen.contains(&(k.user.clone(), k.relation.clone()))).collect::<Vec<_>>();
520 tracing::debug!("Read {} keys for object {object} that are up for deletion. Continuation token: {}", keys.len(), response.continuation_token);
521 continuation_token = Some(response.continuation_token);
522 seen.extend(keys.iter().map(|k| (k.user.clone(), k.relation.clone())));
523 yield Result::Ok(keys);
524 }
525 };
526 pin_mut!(read_stream);
527 let mut read_tuples: Option<Vec<TupleKey>> = None;
528
529 let delete_tuples = |t: Option<Vec<TupleKey>>| async {
530 match t {
531 Some(tuples) => {
532 tracing::debug!(
533 "Deleting {} tuples for object {object} that we haven't seen before.",
534 tuples.len()
535 );
536 self.write(
537 None,
538 Some(
539 tuples
540 .into_iter()
541 .map(|t| TupleKeyWithoutCondition {
542 user: t.user,
543 relation: t.relation,
544 object: t.object,
545 })
546 .collect(),
547 ),
548 )
549 .await
550 }
551 None => Ok(()),
552 }
553 };
554
555 loop {
556 let next_future = read_stream.next();
557 let deletion_future = delete_tuples(read_tuples.clone());
558
559 let (tuples, delete) = futures::join!(next_future, deletion_future);
560 delete?;
561
562 if let Some(tuples) = tuples.transpose()? {
563 read_tuples = (!tuples.is_empty()).then_some(tuples);
564 } else {
565 break Ok(());
566 }
567 }
568 }
569}
570
#[cfg(test)]
mod tests {
    use needs_env_var::needs_env_var;

    /// Integration tests against a live OpenFGA server; only compiled when
    /// `TEST_OPENFGA_CLIENT_GRPC_URL` is set.
    #[needs_env_var(TEST_OPENFGA_CLIENT_GRPC_URL)]
    mod openfga {
        use tracing_test::traced_test;

        use super::super::*;
        use crate::{
            client::{AuthorizationModel, Store},
            migration::test::openfga::service_client_with_store,
        };

        /// Write the sample "custom-roles" model into `store` and return its id.
        async fn write_custom_roles_model(
            client: &OpenFgaServiceClient<tonic::transport::Channel>,
            store: &Store,
        ) -> String {
            let model: AuthorizationModel = serde_json::from_str(include_str!(
                "../tests/sample-store/custom-roles/schema.json"
            ))
            .unwrap();
            client
                .clone()
                .write_authorization_model(model.into_write_request(store.id.clone()))
                .await
                .unwrap()
                .into_inner()
                .authorization_model_id
        }

        /// Fresh store with the custom-roles model, wrapped in an [`OpenFgaClient`].
        async fn get_client_with_custom_roles_model() -> OpenFgaClient<tonic::transport::Channel> {
            let (service_client, store) = service_client_with_store().await;
            let auth_model_id = write_custom_roles_model(&service_client, &store).await;
            OpenFgaClient::new(service_client, &store.id, auth_model_id.as_str())
        }

        #[tokio::test]
        #[traced_test]
        async fn test_read_all_pages_empty_tuple() {
            let client = get_client_with_custom_roles_model().await;

            let loop_count = 100;
            let tuples_per_loop = 3;
            for i in 0..loop_count {
                client
                    .write(
                        vec![
                            TupleKey {
                                user: format!("user:user{i}"),
                                relation: "member".to_string(),
                                object: "team:team1".to_string(),
                                condition: None,
                            },
                            TupleKey {
                                user: format!("role:role{i}#assignee"),
                                relation: "role_assigner".to_string(),
                                object: "org:org1".to_string(),
                                condition: None,
                            },
                            TupleKey {
                                user: format!("org:org{i}"),
                                relation: "org".to_string(),
                                object: "asset-category:ac{i}".to_string(),
                                condition: None,
                            },
                        ],
                        None,
                    )
                    .await
                    .unwrap();
            }

            let tuples = client
                .read_all_pages(None::<ReadRequestTupleKey>, 50, u32::MAX)
                .await
                .unwrap();
            assert_eq!(tuples.len(), loop_count * tuples_per_loop);
        }

        #[tokio::test]
        #[traced_test]
        async fn test_delete_relations_to_object() {
            let client = get_client_with_custom_roles_model().await;
            let object = "team:team1";

            assert!(!client.exists_relation_to(object).await.unwrap());

            client
                .write(
                    vec![TupleKey {
                        user: "user:user1".to_string(),
                        relation: "member".to_string(),
                        object: object.to_string(),
                        condition: None,
                    }],
                    None,
                )
                .await
                .unwrap();
            assert!(client.exists_relation_to(object).await.unwrap());
            client.delete_relations_to_object(object).await.unwrap();
            assert!(!client.exists_relation_to(object).await.unwrap());
        }

        #[tokio::test]
        #[traced_test]
        async fn test_delete_relations_to_object_usersets() {
            let client = get_client_with_custom_roles_model().await;
            let object: &str = "role:admin";

            assert!(!client.exists_relation_to(object).await.unwrap());

            client
                .write(
                    vec![TupleKey {
                        user: "team:team1#member".to_string(),
                        relation: "assignee".to_string(),
                        object: object.to_string(),
                        condition: None,
                    }],
                    None,
                )
                .await
                .unwrap();
            assert!(client.exists_relation_to(object).await.unwrap());
            client.delete_relations_to_object(object).await.unwrap();
            assert!(!client.exists_relation_to(object).await.unwrap());
        }

        #[tokio::test]
        #[traced_test]
        async fn test_delete_relations_to_object_empty() {
            let client = get_client_with_custom_roles_model().await;
            let object = "team:team1";

            assert!(!client.exists_relation_to(object).await.unwrap());
            client.delete_relations_to_object(object).await.unwrap();
            assert!(!client.exists_relation_to(object).await.unwrap());
        }

        #[tokio::test]
        #[traced_test]
        async fn test_delete_relations_to_object_many() {
            let client = get_client_with_custom_roles_model().await;
            let object = "org:org1";

            assert!(!client.exists_relation_to(object).await.unwrap());

            for i in 0..502 {
                client
                    .write(
                        vec![
                            TupleKey {
                                user: format!("user:user{i}"),
                                relation: "member".to_string(),
                                object: object.to_string(),
                                condition: None,
                            },
                            TupleKey {
                                user: format!("role:role{i}#assignee"),
                                relation: "role_assigner".to_string(),
                                object: object.to_string(),
                                condition: None,
                            },
                        ],
                        None,
                    )
                    .await
                    .unwrap();
            }

            // A tuple on a different object must survive the deletion of `object`.
            let object_2 = "org:org2";
            client
                .write(
                    vec![TupleKey {
                        user: "user:user1".to_string(),
                        relation: "owner".to_string(),
                        object: object_2.to_string(),
                        condition: None,
                    }],
                    None,
                )
                .await
                .unwrap();

            assert!(client.exists_relation_to(object).await.unwrap());
            assert!(client.exists_relation_to(object_2).await.unwrap());

            client.delete_relations_to_object(object).await.unwrap();

            assert!(!client.exists_relation_to(object).await.unwrap());
            assert!(client.exists_relation_to(object_2).await.unwrap());
            assert!(client
                .check_simple(TupleKeyWithoutCondition {
                    user: "user:user1".to_string(),
                    relation: "role_assigner".to_string(),
                    object: object_2.to_string(),
                })
                .await
                .unwrap());
        }
    }
}