1use std::{
2 collections::{HashMap, HashSet},
3 sync::Arc,
4};
5
6use async_stream::stream;
7use futures::{pin_mut, StreamExt};
8use tonic::codegen::{Body, Bytes, StdError};
9
10use crate::{
11 client::{
12 batch_check_single_result::CheckResult, BatchCheckItem, BatchCheckRequest, CheckRequest,
13 CheckRequestTupleKey, ConsistencyPreference, ContextualTupleKeys, ExpandRequest,
14 ExpandRequestTupleKey, ListObjectsRequest, ListObjectsResponse, OpenFgaServiceClient,
15 ReadRequest, ReadRequestTupleKey, ReadResponse, Tuple, TupleKey, TupleKeyWithoutCondition,
16 UsersetTree, WriteRequest, WriteRequestDeletes, WriteRequestWrites,
17 },
18 error::{Error, Result},
19};
20
21const DEFAULT_MAX_TUPLES_PER_WRITE: i32 = 100;
22
23#[derive(Clone, Debug)]
24pub struct OpenFgaClient<T> {
54 client: OpenFgaServiceClient<T>,
55 inner: Arc<ModelClientInner>,
56}
57
58#[derive(Debug, Clone)]
59struct ModelClientInner {
60 store_id: String,
61 authorization_model_id: String,
62 max_tuples_per_write: i32,
63 consistency: ConsistencyPreference,
64}
65
66#[cfg(feature = "auth-middle")]
67pub type BasicOpenFgaClient = OpenFgaClient<crate::client::BasicAuthLayer>;
72
73impl<T> OpenFgaClient<T>
74where
75 T: tonic::client::GrpcService<tonic::body::BoxBody>,
76 T::Error: Into<StdError>,
77 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
78 <T::ResponseBody as Body>::Error: Into<StdError> + Send,
79 T: Clone,
80{
81 #[must_use]
83 pub fn new(
84 client: OpenFgaServiceClient<T>,
85 store_id: &str,
86 authorization_model_id: &str,
87 ) -> Self {
88 OpenFgaClient {
89 client,
90 inner: Arc::new(ModelClientInner {
91 store_id: store_id.to_string(),
92 authorization_model_id: authorization_model_id.to_string(),
93 max_tuples_per_write: DEFAULT_MAX_TUPLES_PER_WRITE,
94 consistency: ConsistencyPreference::MinimizeLatency,
95 }),
96 }
97 }
98
99 #[must_use]
101 pub fn set_max_tuples_per_write(mut self, max_tuples_per_write: i32) -> Self {
102 let inner = Arc::unwrap_or_clone(self.inner);
103 self.inner = Arc::new(ModelClientInner {
104 store_id: inner.store_id,
105 authorization_model_id: inner.authorization_model_id,
106 max_tuples_per_write,
107 consistency: inner.consistency,
108 });
109 self
110 }
111
112 #[must_use]
114 pub fn set_consistency(mut self, consistency: impl Into<ConsistencyPreference>) -> Self {
115 let inner = Arc::unwrap_or_clone(self.inner);
116 self.inner = Arc::new(ModelClientInner {
117 store_id: inner.store_id,
118 authorization_model_id: inner.authorization_model_id,
119 max_tuples_per_write: inner.max_tuples_per_write,
120 consistency: consistency.into(),
121 });
122 self
123 }
124
125 pub fn store_id(&self) -> &str {
127 &self.inner.store_id
128 }
129
130 pub fn authorization_model_id(&self) -> &str {
132 &self.inner.authorization_model_id
133 }
134
135 pub fn max_tuples_per_write(&self) -> i32 {
137 self.inner.max_tuples_per_write
138 }
139
140 pub fn client(&self) -> OpenFgaServiceClient<T> {
142 self.client.clone()
143 }
144
145 pub fn consistency(&self) -> ConsistencyPreference {
147 self.inner.consistency
148 }
149
150 pub async fn write(
172 &self,
173 writes: impl Into<Option<Vec<TupleKey>>>,
174 deletes: impl Into<Option<Vec<TupleKeyWithoutCondition>>>,
175 ) -> Result<()> {
176 let writes = writes.into().and_then(|w| (!w.is_empty()).then_some(w));
177 let deletes = deletes.into().and_then(|d| (!d.is_empty()).then_some(d));
178
179 if writes.is_none() && deletes.is_none() {
180 return Ok(());
181 }
182
183 let num_writes_and_deletes = i32::try_from(
184 writes
185 .as_ref()
186 .map_or(0, Vec::len)
187 .checked_add(deletes.as_ref().map_or(0, Vec::len))
188 .unwrap_or(usize::MAX),
189 )
190 .unwrap_or(i32::MAX);
191
192 if num_writes_and_deletes > self.max_tuples_per_write() {
193 tracing::error!(
194 "Too many writes and deletes in single OpenFGA transaction (actual) {} > {} (max)",
195 num_writes_and_deletes,
196 self.max_tuples_per_write()
197 );
198 return Err(Error::TooManyWrites {
199 actual: num_writes_and_deletes,
200 max: self.max_tuples_per_write(),
201 });
202 }
203
204 let write_request = WriteRequest {
205 store_id: self.store_id().to_string(),
206 writes: writes.map(|writes| WriteRequestWrites { tuple_keys: writes }),
207 deletes: deletes.map(|deletes| WriteRequestDeletes {
208 tuple_keys: deletes,
209 }),
210 authorization_model_id: self.authorization_model_id().to_string(),
211 };
212
213 self.client
214 .clone()
215 .write(write_request.clone())
216 .await
217 .map_err(|e| {
218 let write_request_debug = format!("{write_request:?}");
219 tracing::error!(
220 "Write request failed with status {e}. Request: {write_request_debug}"
221 );
222 Error::RequestFailed(e)
223 })
224 .map(|_| ())
225 }
226
227 pub async fn read(
236 &self,
237 page_size: i32,
238 tuple_key: impl Into<ReadRequestTupleKey>,
239 continuation_token: impl Into<Option<String>>,
240 ) -> Result<tonic::Response<ReadResponse>> {
241 let read_request = ReadRequest {
242 store_id: self.store_id().to_string(),
243 page_size: Some(page_size),
244 continuation_token: continuation_token.into().unwrap_or_default(),
245 tuple_key: Some(tuple_key.into()),
246 consistency: self.consistency().into(),
247 };
248 self.client
249 .clone()
250 .read(read_request.clone())
251 .await
252 .map_err(|e| {
253 let read_request_debug = format!("{read_request:?}");
254 tracing::error!(
255 "Read request failed with status {e}. Request: {read_request_debug}"
256 );
257 Error::RequestFailed(e)
258 })
259 }
260
261 pub async fn read_all_pages(
269 &self,
270 tuple: impl Into<ReadRequestTupleKey>,
271 page_size: i32,
272 max_pages: u32,
273 ) -> Result<Vec<Tuple>> {
274 let store_id = self.store_id().to_string();
275 self.client
276 .clone()
277 .read_all_pages(&store_id, tuple, self.consistency(), page_size, max_pages)
278 .await
279 }
280
281 pub async fn check(
288 &self,
289 tuple_key: impl Into<CheckRequestTupleKey>,
290 contextual_tuples: impl Into<Option<Vec<TupleKey>>>,
291 context: impl Into<Option<prost_wkt_types::Struct>>,
292 trace: bool,
293 ) -> Result<bool> {
294 let contextual_tuples = contextual_tuples
295 .into()
296 .and_then(|c| (!c.is_empty()).then_some(c))
297 .map(|tuple_keys| ContextualTupleKeys { tuple_keys });
298
299 let check_request = CheckRequest {
300 store_id: self.store_id().to_string(),
301 tuple_key: Some(tuple_key.into()),
302 consistency: self.consistency().into(),
303 contextual_tuples,
304 authorization_model_id: self.authorization_model_id().to_string(),
305 context: context.into(),
306 trace,
307 };
308 let response = self
309 .client
310 .clone()
311 .check(check_request.clone())
312 .await
313 .map_err(|e| {
314 let check_request_debug = format!("{check_request:?}");
315 tracing::error!(
316 "Check request failed with status {e}. Request: {check_request_debug}"
317 );
318 Error::RequestFailed(e)
319 })?;
320 Ok(response.get_ref().allowed)
321 }
322
323 pub async fn batch_check<I>(
331 &self,
332 checks: impl IntoIterator<Item = I>,
333 ) -> Result<HashMap<String, CheckResult>>
334 where
335 I: Into<BatchCheckItem>,
336 {
337 let checks: Vec<BatchCheckItem> = checks.into_iter().map(Into::into).collect();
338 let request = BatchCheckRequest {
339 store_id: self.store_id().to_string(),
340 checks,
341 authorization_model_id: self.authorization_model_id().to_string(),
342 consistency: self.consistency().into(),
343 };
344
345 let response = self
346 .client
347 .clone()
348 .batch_check(request.clone())
349 .await
350 .map_err(|e| {
351 let request_debug = format!("{request:?}");
352 tracing::error!(
353 "Batch-Check request failed with status {e}. Request: {request_debug}"
354 );
355 Error::RequestFailed(e)
356 })?;
357
358 let mut map = HashMap::new();
359 for (k, v) in response.into_inner().result {
360 match v.check_result {
361 Some(v) => map.insert(k, v),
365 None => return Err(Error::ExpectedOneof),
366 };
367 }
368 Ok(map)
369 }
370
371 pub async fn expand(
378 &self,
379 tuple_key: impl Into<ExpandRequestTupleKey>,
380 contextual_tuples: impl Into<Option<Vec<TupleKey>>>,
381 ) -> Result<Option<UsersetTree>> {
382 let expand_request = ExpandRequest {
383 store_id: self.store_id().to_string(),
384 tuple_key: Some(tuple_key.into()),
385 authorization_model_id: self.authorization_model_id().to_string(),
386 consistency: self.consistency().into(),
387 contextual_tuples: contextual_tuples
388 .into()
389 .map(|tuple_keys| ContextualTupleKeys { tuple_keys }),
390 };
391 let response = self
392 .client
393 .clone()
394 .expand(expand_request.clone())
395 .await
396 .map_err(|e| {
397 tracing::error!(
398 "Expand request failed with status {e}. Request: {expand_request:?}"
399 );
400 Error::RequestFailed(e)
401 })?;
402 Ok(response.into_inner().tree)
403 }
404
405 pub async fn check_simple(&self, tuple_key: impl Into<CheckRequestTupleKey>) -> Result<bool> {
410 self.check(tuple_key, None, None, false).await
411 }
412
413 pub async fn list_objects(
418 &self,
419 r#type: impl Into<String>,
420 relation: impl Into<String>,
421 user: impl Into<String>,
422 contextual_tuples: impl Into<Option<Vec<TupleKey>>>,
423 context: impl Into<Option<prost_wkt_types::Struct>>,
424 ) -> Result<tonic::Response<ListObjectsResponse>> {
425 let request = ListObjectsRequest {
426 r#type: r#type.into(),
427 relation: relation.into(),
428 user: user.into(),
429 authorization_model_id: self.authorization_model_id().to_string(),
430 store_id: self.store_id().to_string(),
431 consistency: self.consistency().into(),
432 contextual_tuples: contextual_tuples
433 .into()
434 .map(|tuple_keys| ContextualTupleKeys { tuple_keys }),
435 context: context.into(),
436 };
437
438 self.client
439 .clone()
440 .list_objects(request.clone())
441 .await
442 .map_err(|e| {
443 tracing::error!(
444 "List-Objects request failed with status {e}. Request: {request:?}"
445 );
446 Error::RequestFailed(e)
447 })
448 }
449
450 pub async fn delete_relations_to_object(&self, object: &str) -> Result<()> {
462 loop {
463 self.delete_relations_to_object_inner(object)
464 .await
465 .inspect_err(|e| {
466 tracing::error!("Failed to delete relations to object {object}: {e}");
467 })?;
468
469 if self.exists_relation_to(object).await? {
470 tracing::debug!("Some tuples for object {object} are still present after first sweep. Performing another deletion.");
471 } else {
472 tracing::debug!("Successfully deleted all relations to object {object}");
473 break Ok(());
474 }
475 }
476 }
477
478 pub async fn exists_relation_to(&self, object: &str) -> Result<bool> {
484 let tuples = self.read_relations_to_object(object, None, 1).await?;
485 Ok(!tuples.tuples.is_empty())
486 }
487
488 async fn read_relations_to_object(
489 &self,
490 object: &str,
491 continuation_token: impl Into<Option<String>>,
492 page_size: i32,
493 ) -> Result<ReadResponse> {
494 self.read(
495 page_size,
496 TupleKeyWithoutCondition {
497 user: String::new(),
498 relation: String::new(),
499 object: object.to_string(),
500 },
501 continuation_token,
502 )
503 .await
504 .map(tonic::Response::into_inner)
505 }
506
507 async fn delete_relations_to_object_inner(&self, object: &str) -> Result<()> {
511 let read_stream = stream! {
512 let mut continuation_token = None;
513 let mut seen= HashSet::new();
516 while continuation_token != Some(String::new()) {
517 let response = self.read_relations_to_object(object, continuation_token, self.max_tuples_per_write()).await?;
518 let keys = response.tuples.into_iter().filter_map(|t| t.key).filter(|k| !seen.contains(&(k.user.clone(), k.relation.clone()))).collect::<Vec<_>>();
519 tracing::debug!("Read {} keys for object {object} that are up for deletion. Continuation token: {}", keys.len(), response.continuation_token);
520 continuation_token = Some(response.continuation_token);
521 seen.extend(keys.iter().map(|k| (k.user.clone(), k.relation.clone())));
522 yield Result::Ok(keys);
523 }
524 };
525 pin_mut!(read_stream);
526 let mut read_tuples: Option<Vec<TupleKey>> = None;
527
528 let delete_tuples = |t: Option<Vec<TupleKey>>| async {
529 match t {
530 Some(tuples) => {
531 tracing::debug!(
532 "Deleting {} tuples for object {object} that we haven't seen before.",
533 tuples.len()
534 );
535 self.write(
536 None,
537 Some(
538 tuples
539 .into_iter()
540 .map(|t| TupleKeyWithoutCondition {
541 user: t.user,
542 relation: t.relation,
543 object: t.object,
544 })
545 .collect(),
546 ),
547 )
548 .await
549 }
550 None => Ok(()),
551 }
552 };
553
554 loop {
555 let next_future = read_stream.next();
556 let deletion_future = delete_tuples(read_tuples.clone());
557
558 let (tuples, delete) = futures::join!(next_future, deletion_future);
559 delete?;
560
561 if let Some(tuples) = tuples.transpose()? {
562 read_tuples = (!tuples.is_empty()).then_some(tuples);
563 } else {
564 break Ok(());
565 }
566 }
567 }
568}
569
#[cfg(test)]
mod tests {
    use needs_env_var::needs_env_var;

    // Integration tests against a live OpenFGA server. The module is gated by
    // `needs_env_var` on `TEST_OPENFGA_CLIENT_GRPC_URL`; presumably the tests
    // are skipped when it is unset (see the `needs_env_var` crate).
    #[needs_env_var(TEST_OPENFGA_CLIENT_GRPC_URL)]
    mod openfga {
        use tracing_test::traced_test;

        use super::super::*;
        use crate::{
            client::{AuthorizationModel, Store},
            migration::test::openfga::service_client_with_store,
        };

        /// Writes the `custom-roles` sample model into `store` and returns the
        /// new authorization model ID.
        async fn write_custom_roles_model(
            client: &OpenFgaServiceClient<tonic::transport::Channel>,
            store: &Store,
        ) -> String {
            let model: AuthorizationModel = serde_json::from_str(include_str!(
                "../tests/sample-store/custom-roles/schema.json"
            ))
            .unwrap();
            client
                .clone()
                .write_authorization_model(model.into_write_request(store.id.clone()))
                .await
                .unwrap()
                .into_inner()
                .authorization_model_id
        }

        /// Returns a client bound to the store from `service_client_with_store`
        /// with the `custom-roles` model written to it.
        async fn get_client_with_custom_roles_model() -> OpenFgaClient<tonic::transport::Channel> {
            let (service_client, store) = service_client_with_store().await;
            let auth_model_id = write_custom_roles_model(&service_client, &store).await;
            OpenFgaClient::new(service_client, &store.id, auth_model_id.as_str())
        }

        #[tokio::test]
        #[traced_test]
        async fn test_delete_relations_to_object() {
            let client = get_client_with_custom_roles_model().await;
            let object = "team:team1";

            assert!(!client.exists_relation_to(object).await.unwrap());

            client
                .write(
                    vec![TupleKey {
                        user: "user:user1".to_string(),
                        relation: "member".to_string(),
                        object: object.to_string(),
                        condition: None,
                    }],
                    None,
                )
                .await
                .unwrap();
            assert!(client.exists_relation_to(object).await.unwrap());
            client.delete_relations_to_object(object).await.unwrap();
            assert!(!client.exists_relation_to(object).await.unwrap());
        }

        #[tokio::test]
        #[traced_test]
        async fn test_delete_relations_to_object_usersets() {
            let client = get_client_with_custom_roles_model().await;
            let object = "role:admin";

            assert!(!client.exists_relation_to(object).await.unwrap());

            // Userset subject (`type:id#relation`) rather than a plain user.
            client
                .write(
                    vec![TupleKey {
                        user: "team:team1#member".to_string(),
                        relation: "assignee".to_string(),
                        object: object.to_string(),
                        condition: None,
                    }],
                    None,
                )
                .await
                .unwrap();
            assert!(client.exists_relation_to(object).await.unwrap());
            client.delete_relations_to_object(object).await.unwrap();
            assert!(!client.exists_relation_to(object).await.unwrap());
        }

        #[tokio::test]
        #[traced_test]
        async fn test_delete_relations_to_object_empty() {
            let client = get_client_with_custom_roles_model().await;
            let object = "team:team1";

            // Deleting when nothing exists must succeed and change nothing.
            assert!(!client.exists_relation_to(object).await.unwrap());
            client.delete_relations_to_object(object).await.unwrap();
            assert!(!client.exists_relation_to(object).await.unwrap());
        }

        #[tokio::test]
        #[traced_test]
        async fn test_delete_relations_to_object_many() {
            let client = get_client_with_custom_roles_model().await;
            let object = "org:org1";

            assert!(!client.exists_relation_to(object).await.unwrap());

            // 1004 tuples in total, so deletion has to span many pages.
            for i in 0..502 {
                client
                    .write(
                        vec![
                            TupleKey {
                                user: format!("user:user{i}"),
                                relation: "member".to_string(),
                                object: object.to_string(),
                                condition: None,
                            },
                            TupleKey {
                                user: format!("role:role{i}#assignee"),
                                relation: "role_assigner".to_string(),
                                object: object.to_string(),
                                condition: None,
                            },
                        ],
                        None,
                    )
                    .await
                    .unwrap();
            }

            // A tuple on a different object must survive the deletion.
            let object_2 = "org:org2";
            client
                .write(
                    vec![TupleKey {
                        user: "user:user1".to_string(),
                        relation: "owner".to_string(),
                        object: object_2.to_string(),
                        condition: None,
                    }],
                    None,
                )
                .await
                .unwrap();

            assert!(client.exists_relation_to(object).await.unwrap());
            assert!(client.exists_relation_to(object_2).await.unwrap());

            client.delete_relations_to_object(object).await.unwrap();

            assert!(!client.exists_relation_to(object).await.unwrap());
            assert!(client.exists_relation_to(object_2).await.unwrap());
            // Presumably `owner` implies `role_assigner` in the custom-roles model.
            assert!(client
                .check_simple(TupleKeyWithoutCondition {
                    user: "user:user1".to_string(),
                    relation: "role_assigner".to_string(),
                    object: object_2.to_string(),
                })
                .await
                .unwrap());
        }
    }
}