1use std::{
2 collections::{HashMap, HashSet},
3 sync::Arc,
4};
5
6use async_stream::stream;
7use futures::{pin_mut, StreamExt};
8use tonic::codegen::{Body, Bytes, StdError};
9
10use crate::{
11 client::{
12 batch_check_single_result::CheckResult, BatchCheckItem, BatchCheckRequest, CheckRequest,
13 CheckRequestTupleKey, ConsistencyPreference, ContextualTupleKeys, ExpandRequest,
14 ExpandRequestTupleKey, ListObjectsRequest, ListObjectsResponse, OpenFgaServiceClient,
15 ReadRequest, ReadRequestTupleKey, ReadResponse, Tuple, TupleKey, TupleKeyWithoutCondition,
16 UsersetTree, WriteRequest, WriteRequestDeletes, WriteRequestWrites,
17 },
18 error::{Error, Result},
19};
20
/// Default limit on writes plus deletes accepted by a single `OpenFgaClient::write` call.
///
/// Used by `OpenFgaClient::new`; override it with `OpenFgaClient::set_max_tuples_per_write`.
const DEFAULT_MAX_TUPLES_PER_WRITE: i32 = 100;
22
/// High-level OpenFGA client bound to a single store and authorization model.
///
/// Wraps an [`OpenFgaServiceClient`] and fills in the configured store id, authorization
/// model id and consistency preference on the requests it builds (where the request type
/// has those fields). The store/model configuration is kept behind an [`Arc`], so clones
/// of the client share it instead of copying the strings.
#[derive(Clone, Debug)]
pub struct OpenFgaClient<T> {
    /// Underlying generated gRPC client; cloned for each request.
    client: OpenFgaServiceClient<T>,
    /// Store / model configuration shared between clones.
    inner: Arc<ModelClientInner>,
}
57
/// Per-client configuration shared (via `Arc`) between clones of an [`OpenFgaClient`].
#[derive(Debug, Clone)]
struct ModelClientInner {
    /// Store every request targets.
    store_id: String,
    /// Model id sent with write, check, batch-check, expand and list-objects requests.
    authorization_model_id: String,
    /// Upper bound on writes plus deletes accepted by `OpenFgaClient::write`.
    max_tuples_per_write: i32,
    /// Preference sent with read, check, batch-check, expand and list-objects requests.
    consistency: ConsistencyPreference,
}
65
/// [`OpenFgaClient`] using `crate::client::BasicAuthLayer` as its gRPC service type.
///
/// Only available with the `auth-middle` feature.
#[cfg(feature = "auth-middle")]
pub type BasicOpenFgaClient = OpenFgaClient<crate::client::BasicAuthLayer>;
72
73impl<T> OpenFgaClient<T>
74where
75 T: tonic::client::GrpcService<tonic::body::BoxBody>,
76 T::Error: Into<StdError>,
77 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
78 <T::ResponseBody as Body>::Error: Into<StdError> + Send,
79 T: Clone,
80{
81 #[must_use]
83 pub fn new(
84 client: OpenFgaServiceClient<T>,
85 store_id: &str,
86 authorization_model_id: &str,
87 ) -> Self {
88 OpenFgaClient {
89 client,
90 inner: Arc::new(ModelClientInner {
91 store_id: store_id.to_string(),
92 authorization_model_id: authorization_model_id.to_string(),
93 max_tuples_per_write: DEFAULT_MAX_TUPLES_PER_WRITE,
94 consistency: ConsistencyPreference::MinimizeLatency,
95 }),
96 }
97 }
98
99 #[must_use]
101 pub fn set_max_tuples_per_write(mut self, max_tuples_per_write: i32) -> Self {
102 let inner = Arc::unwrap_or_clone(self.inner);
103 self.inner = Arc::new(ModelClientInner {
104 store_id: inner.store_id,
105 authorization_model_id: inner.authorization_model_id,
106 max_tuples_per_write,
107 consistency: inner.consistency,
108 });
109 self
110 }
111
112 #[must_use]
114 pub fn set_consistency(mut self, consistency: impl Into<ConsistencyPreference>) -> Self {
115 let inner = Arc::unwrap_or_clone(self.inner);
116 self.inner = Arc::new(ModelClientInner {
117 store_id: inner.store_id,
118 authorization_model_id: inner.authorization_model_id,
119 max_tuples_per_write: inner.max_tuples_per_write,
120 consistency: consistency.into(),
121 });
122 self
123 }
124
125 pub fn store_id(&self) -> &str {
127 &self.inner.store_id
128 }
129
130 pub fn authorization_model_id(&self) -> &str {
132 &self.inner.authorization_model_id
133 }
134
135 pub fn max_tuples_per_write(&self) -> i32 {
137 self.inner.max_tuples_per_write
138 }
139
140 pub fn client(&self) -> OpenFgaServiceClient<T> {
142 self.client.clone()
143 }
144
145 pub fn consistency(&self) -> ConsistencyPreference {
147 self.inner.consistency
148 }
149
150 pub async fn write(
172 &self,
173 writes: impl Into<Option<Vec<TupleKey>>>,
174 deletes: impl Into<Option<Vec<TupleKeyWithoutCondition>>>,
175 ) -> Result<()> {
176 let writes = writes.into().and_then(|w| (!w.is_empty()).then_some(w));
177 let deletes = deletes.into().and_then(|d| (!d.is_empty()).then_some(d));
178
179 if writes.is_none() && deletes.is_none() {
180 return Ok(());
181 }
182
183 let num_writes_and_deletes = i32::try_from(
184 writes
185 .as_ref()
186 .map_or(0, Vec::len)
187 .checked_add(deletes.as_ref().map_or(0, Vec::len))
188 .unwrap_or(usize::MAX),
189 )
190 .unwrap_or(i32::MAX);
191
192 if num_writes_and_deletes > self.max_tuples_per_write() {
193 tracing::error!(
194 "Too many writes and deletes in single OpenFGA transaction (actual) {} > {} (max)",
195 num_writes_and_deletes,
196 self.max_tuples_per_write()
197 );
198 return Err(Error::TooManyWrites {
199 actual: num_writes_and_deletes,
200 max: self.max_tuples_per_write(),
201 });
202 }
203
204 let write_request = WriteRequest {
205 store_id: self.store_id().to_string(),
206 writes: writes.map(|writes| WriteRequestWrites { tuple_keys: writes }),
207 deletes: deletes.map(|deletes| WriteRequestDeletes {
208 tuple_keys: deletes,
209 }),
210 authorization_model_id: self.authorization_model_id().to_string(),
211 };
212
213 self.client
214 .clone()
215 .write(write_request.clone())
216 .await
217 .map_err(|e| {
218 let write_request_debug = format!("{write_request:?}");
219 tracing::error!(
220 "Write request failed with status {e}. Request: {write_request_debug}"
221 );
222 Error::RequestFailed(e)
223 })
224 .map(|_| ())
225 }
226
227 pub async fn read(
236 &self,
237 page_size: i32,
238 tuple_key: impl Into<ReadRequestTupleKey>,
239 continuation_token: impl Into<Option<String>>,
240 ) -> Result<tonic::Response<ReadResponse>> {
241 let read_request = ReadRequest {
242 store_id: self.store_id().to_string(),
243 page_size: Some(page_size),
244 continuation_token: continuation_token.into().unwrap_or_default(),
245 tuple_key: Some(tuple_key.into()),
246 consistency: self.consistency().into(),
247 };
248 self.client
249 .clone()
250 .read(read_request.clone())
251 .await
252 .map_err(|e| {
253 let read_request_debug = format!("{read_request:?}");
254 tracing::error!(
255 "Read request failed with status {e}. Request: {read_request_debug}"
256 );
257 Error::RequestFailed(e)
258 })
259 }
260
261 pub async fn read_all_pages(
269 &self,
270 tuple: impl Into<ReadRequestTupleKey>,
271 page_size: i32,
272 max_pages: u32,
273 ) -> Result<Vec<Tuple>> {
274 let store_id = self.store_id().to_string();
275 self.client
276 .clone()
277 .read_all_pages(&store_id, tuple, self.consistency(), page_size, max_pages)
278 .await
279 }
280
281 pub async fn check(
288 &self,
289 tuple_key: impl Into<CheckRequestTupleKey>,
290 contextual_tuples: impl Into<Option<Vec<TupleKey>>>,
291 context: impl Into<Option<prost_wkt_types::Struct>>,
292 trace: bool,
293 ) -> Result<bool> {
294 let contextual_tuples = contextual_tuples
295 .into()
296 .and_then(|c| (!c.is_empty()).then_some(c))
297 .map(|tuple_keys| ContextualTupleKeys { tuple_keys });
298
299 let check_request = CheckRequest {
300 store_id: self.store_id().to_string(),
301 tuple_key: Some(tuple_key.into()),
302 consistency: self.consistency().into(),
303 contextual_tuples,
304 authorization_model_id: self.authorization_model_id().to_string(),
305 context: context.into(),
306 trace,
307 };
308 let response = self
309 .client
310 .clone()
311 .check(check_request.clone())
312 .await
313 .map_err(|e| {
314 let check_request_debug = format!("{check_request:?}");
315 tracing::error!(
316 "Check request failed with status {e}. Request: {check_request_debug}"
317 );
318 Error::RequestFailed(e)
319 })?;
320 Ok(response.get_ref().allowed)
321 }
322
323 pub async fn batch_check<I>(
329 &self,
330 checks: impl IntoIterator<Item = I>,
331 ) -> Result<HashMap<String, Option<CheckResult>>>
332 where
333 I: Into<BatchCheckItem>,
334 {
335 let checks: Vec<BatchCheckItem> = checks.into_iter().map(Into::into).collect();
336 let request = BatchCheckRequest {
337 store_id: self.store_id().to_string(),
338 checks,
339 authorization_model_id: self.authorization_model_id().to_string(),
340 consistency: self.consistency().into(),
341 };
342 let response = self
343 .client
344 .clone()
345 .batch_check(request.clone())
346 .await
347 .map_err(|e| {
348 let request_debug = format!("{request:?}");
349 tracing::error!(
350 "Batch-Check request failed with status {e}. Request: {request_debug}"
351 );
352 Error::RequestFailed(e)
353 })?;
354
355 Ok(response
356 .into_inner()
357 .result
358 .into_iter()
359 .map(|(k, v)| (k, v.check_result))
360 .collect())
361 }
362
363 pub async fn expand(
370 &self,
371 tuple_key: impl Into<ExpandRequestTupleKey>,
372 contextual_tuples: impl Into<Option<Vec<TupleKey>>>,
373 ) -> Result<Option<UsersetTree>> {
374 let expand_request = ExpandRequest {
375 store_id: self.store_id().to_string(),
376 tuple_key: Some(tuple_key.into()),
377 authorization_model_id: self.authorization_model_id().to_string(),
378 consistency: self.consistency().into(),
379 contextual_tuples: contextual_tuples
380 .into()
381 .map(|tuple_keys| ContextualTupleKeys { tuple_keys }),
382 };
383 let response = self
384 .client
385 .clone()
386 .expand(expand_request.clone())
387 .await
388 .map_err(|e| {
389 tracing::error!(
390 "Expand request failed with status {e}. Request: {expand_request:?}"
391 );
392 Error::RequestFailed(e)
393 })?;
394 Ok(response.into_inner().tree)
395 }
396
397 pub async fn check_simple(&self, tuple_key: impl Into<CheckRequestTupleKey>) -> Result<bool> {
402 self.check(tuple_key, None, None, false).await
403 }
404
405 pub async fn list_objects(
410 &self,
411 r#type: impl Into<String>,
412 relation: impl Into<String>,
413 user: impl Into<String>,
414 contextual_tuples: impl Into<Option<Vec<TupleKey>>>,
415 context: impl Into<Option<prost_wkt_types::Struct>>,
416 ) -> Result<tonic::Response<ListObjectsResponse>> {
417 let request = ListObjectsRequest {
418 r#type: r#type.into(),
419 relation: relation.into(),
420 user: user.into(),
421 authorization_model_id: self.authorization_model_id().to_string(),
422 store_id: self.store_id().to_string(),
423 consistency: self.consistency().into(),
424 contextual_tuples: contextual_tuples
425 .into()
426 .map(|tuple_keys| ContextualTupleKeys { tuple_keys }),
427 context: context.into(),
428 };
429
430 self.client
431 .clone()
432 .list_objects(request.clone())
433 .await
434 .map_err(|e| {
435 tracing::error!(
436 "List-Objects request failed with status {e}. Request: {request:?}"
437 );
438 Error::RequestFailed(e)
439 })
440 }
441
442 pub async fn delete_relations_to_object(&self, object: &str) -> Result<()> {
454 loop {
455 self.delete_relations_to_object_inner(object)
456 .await
457 .inspect_err(|e| {
458 tracing::error!("Failed to delete relations to object {object}: {e}");
459 })?;
460
461 if self.exists_relation_to(object).await? {
462 tracing::debug!("Some tuples for object {object} are still present after first sweep. Performing another deletion.");
463 } else {
464 tracing::debug!("Successfully deleted all relations to object {object}");
465 break Ok(());
466 }
467 }
468 }
469
470 pub async fn exists_relation_to(&self, object: &str) -> Result<bool> {
476 let tuples = self.read_relations_to_object(object, None, 1).await?;
477 Ok(!tuples.tuples.is_empty())
478 }
479
480 async fn read_relations_to_object(
481 &self,
482 object: &str,
483 continuation_token: impl Into<Option<String>>,
484 page_size: i32,
485 ) -> Result<ReadResponse> {
486 self.read(
487 page_size,
488 TupleKeyWithoutCondition {
489 user: String::new(),
490 relation: String::new(),
491 object: object.to_string(),
492 },
493 continuation_token,
494 )
495 .await
496 .map(tonic::Response::into_inner)
497 }
498
499 async fn delete_relations_to_object_inner(&self, object: &str) -> Result<()> {
503 let read_stream = stream! {
504 let mut continuation_token = None;
505 let mut seen= HashSet::new();
508 while continuation_token != Some(String::new()) {
509 let response = self.read_relations_to_object(object, continuation_token, self.max_tuples_per_write()).await?;
510 let keys = response.tuples.into_iter().filter_map(|t| t.key).filter(|k| !seen.contains(&(k.user.clone(), k.relation.clone()))).collect::<Vec<_>>();
511 tracing::debug!("Read {} keys for object {object} that are up for deletion. Continuation token: {}", keys.len(), response.continuation_token);
512 continuation_token = Some(response.continuation_token);
513 seen.extend(keys.iter().map(|k| (k.user.clone(), k.relation.clone())));
514 yield Result::Ok(keys);
515 }
516 };
517 pin_mut!(read_stream);
518 let mut read_tuples: Option<Vec<TupleKey>> = None;
519
520 let delete_tuples = |t: Option<Vec<TupleKey>>| async {
521 match t {
522 Some(tuples) => {
523 tracing::debug!(
524 "Deleting {} tuples for object {object} that we haven't seen before.",
525 tuples.len()
526 );
527 self.write(
528 None,
529 Some(
530 tuples
531 .into_iter()
532 .map(|t| TupleKeyWithoutCondition {
533 user: t.user,
534 relation: t.relation,
535 object: t.object,
536 })
537 .collect(),
538 ),
539 )
540 .await
541 }
542 None => Ok(()),
543 }
544 };
545
546 loop {
547 let next_future = read_stream.next();
548 let deletion_future = delete_tuples(read_tuples.clone());
549
550 let (tuples, delete) = futures::join!(next_future, deletion_future);
551 delete?;
552
553 if let Some(tuples) = tuples.transpose()? {
554 read_tuples = (!tuples.is_empty()).then_some(tuples);
555 } else {
556 break Ok(());
557 }
558 }
559 }
560}
561
#[cfg(test)]
mod tests {
    use needs_env_var::needs_env_var;

    // Integration tests against a live OpenFGA server, gated via `needs_env_var` on the
    // `TEST_OPENFGA_CLIENT_GRPC_URL` environment variable.
    #[needs_env_var(TEST_OPENFGA_CLIENT_GRPC_URL)]
    mod openfga {
        use tracing_test::traced_test;

        use super::super::*;
        use crate::{
            client::{AuthorizationModel, Store},
            migration::test::openfga::service_client_with_store,
        };

        /// Write the sample custom-roles authorization model to `store` and return its id.
        async fn write_custom_roles_model(
            client: &OpenFgaServiceClient<tonic::transport::Channel>,
            store: &Store,
        ) -> String {
            let model: AuthorizationModel = serde_json::from_str(include_str!(
                "../tests/sample-store/custom-roles/schema.json"
            ))
            .unwrap();
            client
                .clone()
                .write_authorization_model(model.into_write_request(store.id.clone()))
                .await
                .unwrap()
                .into_inner()
                .authorization_model_id
        }

        /// Build an [`OpenFgaClient`] for the store returned by `service_client_with_store`,
        /// using the custom-roles model.
        async fn get_client_with_custom_roles_model() -> OpenFgaClient<tonic::transport::Channel> {
            let (service_client, store) = service_client_with_store().await;
            let auth_model_id = write_custom_roles_model(&service_client, &store).await;
            OpenFgaClient::new(service_client, &store.id, auth_model_id.as_str())
        }

        #[tokio::test]
        #[traced_test]
        async fn test_delete_relations_to_object() {
            let client = get_client_with_custom_roles_model().await;
            let object = "team:team1";

            assert!(!client.exists_relation_to(object).await.unwrap());

            client
                .write(
                    vec![TupleKey {
                        user: "user:user1".to_string(),
                        relation: "member".to_string(),
                        object: object.to_string(),
                        condition: None,
                    }],
                    None,
                )
                .await
                .unwrap();
            assert!(client.exists_relation_to(object).await.unwrap());
            client.delete_relations_to_object(object).await.unwrap();
            assert!(!client.exists_relation_to(object).await.unwrap());
        }

        #[tokio::test]
        #[traced_test]
        async fn test_delete_relations_to_object_usersets() {
            let client = get_client_with_custom_roles_model().await;
            let object: &str = "role:admin";

            assert!(!client.exists_relation_to(object).await.unwrap());

            // Userset (`type:id#relation`) users must be deleted as well.
            client
                .write(
                    vec![TupleKey {
                        user: "team:team1#member".to_string(),
                        relation: "assignee".to_string(),
                        object: object.to_string(),
                        condition: None,
                    }],
                    None,
                )
                .await
                .unwrap();
            assert!(client.exists_relation_to(object).await.unwrap());
            client.delete_relations_to_object(object).await.unwrap();
            assert!(!client.exists_relation_to(object).await.unwrap());
        }

        #[tokio::test]
        #[traced_test]
        async fn test_delete_relations_to_object_empty() {
            let client = get_client_with_custom_roles_model().await;
            let object = "team:team1";

            // Deleting from an object without tuples must succeed and change nothing.
            assert!(!client.exists_relation_to(object).await.unwrap());
            client.delete_relations_to_object(object).await.unwrap();
            assert!(!client.exists_relation_to(object).await.unwrap());
        }

        #[tokio::test]
        #[traced_test]
        async fn test_delete_relations_to_object_many() {
            let client = get_client_with_custom_roles_model().await;
            let object = "org:org1";

            assert!(!client.exists_relation_to(object).await.unwrap());

            // 502 iterations x 2 tuples = 1004 tuples, far more than one read page or one
            // delete batch, so pagination and batching are exercised.
            for i in 0..502 {
                client
                    .write(
                        vec![
                            TupleKey {
                                user: format!("user:user{i}"),
                                relation: "member".to_string(),
                                object: object.to_string(),
                                condition: None,
                            },
                            TupleKey {
                                user: format!("role:role{i}#assignee"),
                                relation: "role_assigner".to_string(),
                                object: object.to_string(),
                                condition: None,
                            },
                        ],
                        None,
                    )
                    .await
                    .unwrap();
            }

            let object_2 = "org:org2";
            // A tuple on an unrelated object that must survive the deletion below.
            client
                .write(
                    vec![TupleKey {
                        user: "user:user1".to_string(),
                        relation: "owner".to_string(),
                        object: object_2.to_string(),
                        condition: None,
                    }],
                    None,
                )
                .await
                .unwrap();

            assert!(client.exists_relation_to(object).await.unwrap());
            assert!(client.exists_relation_to(object_2).await.unwrap());

            client.delete_relations_to_object(object).await.unwrap();

            assert!(!client.exists_relation_to(object).await.unwrap());
            assert!(client.exists_relation_to(object_2).await.unwrap());
            // The surviving tuple must still take effect in checks on the other object.
            assert!(client
                .check_simple(TupleKeyWithoutCondition {
                    user: "user:user1".to_string(),
                    relation: "role_assigner".to_string(),
                    object: object_2.to_string(),
                })
                .await
                .unwrap());
        }
    }
}