1#![allow(unused_imports)]
2
3#[cfg(feature = "auth-middle")]
4use tonic::service::interceptor::InterceptedService;
5use tonic::{
6 codegen::{Body, Bytes, StdError},
7 service::interceptor::InterceptorLayer,
8 transport::{Channel, Endpoint},
9};
10#[cfg(feature = "auth-middle")]
11use tower::{ServiceBuilder, util::Either};
12
13use crate::{
14 client::{OpenFgaClient, OpenFgaServiceClient},
15 error::{Error, Result},
16 generated::{
17 ConsistencyPreference, CreateStoreRequest, ListStoresRequest, ReadRequest,
18 ReadRequestTupleKey, Store, Tuple,
19 },
20};
21
/// OpenFGA gRPC service client whose transport is a [`BasicAuthLayer`]: one of
/// client-credentials auth, bearer-token auth, or no authentication, depending on
/// which constructor (`new_with_client_credentials`, `new_with_basic_auth`,
/// `new_unauthenticated`) was used.
#[cfg(feature = "auth-middle")]
pub type BasicOpenFgaServiceClient = OpenFgaServiceClient<BasicAuthLayer>;
28
#[cfg(feature = "auth-middle")]
impl BasicOpenFgaServiceClient {
    /// Creates a client that sends requests without any credentials.
    ///
    /// The channel is created with `connect_lazy`, so no connection is attempted
    /// until the first request is sent.
    ///
    /// # Errors
    /// - [`Error::InvalidEndpoint`] if tonic rejects `endpoint`.
    /// - [`Error::TlsConfigurationFailed`] if `endpoint` uses `https` and TLS cannot
    ///   be configured (including builds without the `tls-rustls` feature).
    pub fn new_unauthenticated(endpoint: impl Into<url::Url>) -> Result<Self> {
        let endpoint = get_tonic_endpoint_logged(&endpoint.into())?;
        let channel = endpoint.connect_lazy();
        let intercepted = InterceptedService::new(channel, NoOpInterceptor);
        let service = Either::Right(intercepted);
        Ok(BasicOpenFgaServiceClient::new(service))
    }

    /// Creates a client whose requests are authorized with `token` through
    /// `middle::BearerTokenAuthorizer`.
    ///
    /// The channel is created with `connect_lazy`.
    ///
    /// # Errors
    /// - [`Error::InvalidToken`] if `BearerTokenAuthorizer::new` rejects `token`.
    /// - [`Error::InvalidEndpoint`] if tonic rejects `endpoint`.
    /// - [`Error::TlsConfigurationFailed`] if `endpoint` uses `https` and TLS cannot
    ///   be configured.
    pub fn new_with_basic_auth(endpoint: impl Into<url::Url>, token: &str) -> Result<Self> {
        let authorizer = middle::BearerTokenAuthorizer::new(token).map_err(|e| {
            tracing::error!("Could not construct OpenFGA client. Invalid token: {e}");
            Error::InvalidToken {
                reason: e.to_string(),
            }
        })?;
        let endpoint = get_tonic_endpoint_logged(&endpoint.into())?;
        let channel = endpoint.connect_lazy();
        let intercepted = InterceptedService::new(channel, authorizer);
        let service = Either::Left(Either::Right(intercepted));
        Ok(BasicOpenFgaServiceClient::new(service))
    }

    /// Creates a client authorized via OAuth-style client credentials obtained from
    /// `token_endpoint` through `middle::BasicClientCredentialAuthorizer`.
    ///
    /// `scopes` are added to the credential request only when non-empty.
    ///
    /// # Errors
    /// - [`Error::InvalidEndpoint`] if tonic rejects `endpoint`.
    /// - [`Error::TlsConfigurationFailed`] if `endpoint` uses `https` and TLS cannot
    ///   be configured.
    /// - [`Error::CredentialRefreshError`] if building the authorizer fails.
    ///
    /// Endpoint errors are reported before any credential request is made.
    pub async fn new_with_client_credentials(
        endpoint: impl Into<url::Url>,
        client_id: &str,
        client_secret: &str,
        token_endpoint: impl Into<url::Url>,
        scopes: &[&str],
    ) -> Result<Self> {
        // Validate the OpenFGA endpoint first. This is a purely local check, whereas
        // `build().await` below is async and (per its error message) fetches
        // credentials. An invalid endpoint should not cost a round-trip to the
        // token endpoint.
        let endpoint = get_tonic_endpoint_logged(&endpoint.into())?;

        let builder = middle::BasicClientCredentialAuthorizer::basic_builder(
            client_id,
            client_secret,
            token_endpoint.into(),
        );
        let builder = if scopes.is_empty() {
            builder
        } else {
            builder.add_scopes(scopes)
        };
        let authorizer = builder.build().await.map_err(|e| {
            tracing::error!("Could not construct OpenFGA client. Failed to fetch or refresh Client Credentials: {e}");
            Error::CredentialRefreshError(e)
        })?;

        let channel = endpoint.connect_lazy();
        let intercepted = InterceptedService::new(channel, authorizer);
        let service = Either::Left(Either::Left(intercepted));
        Ok(BasicOpenFgaServiceClient::new(service))
    }
}
97
98impl<T> OpenFgaServiceClient<T>
99where
100 T: tonic::client::GrpcService<tonic::body::Body>,
101 T::Error: Into<StdError>,
102 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
103 <T::ResponseBody as Body>::Error: Into<StdError> + Send,
104 T: Clone,
105{
106 pub fn into_client(self, store_id: &str, authorization_model_id: &str) -> OpenFgaClient<T> {
108 OpenFgaClient::new(self, store_id, authorization_model_id)
109 }
110
111 pub async fn get_store_by_name(&mut self, store_name: &str) -> Result<Option<Store>> {
118 let stores = self
119 .list_stores(ListStoresRequest {
120 page_size: Some(2),
121 continuation_token: String::new(),
122 name: store_name.to_string(),
123 })
124 .await
125 .map_err(|e| {
126 tracing::error!("Failed to list stores in OpenFGA: {e}");
127 Error::RequestFailed(Box::new(e))
128 })?
129 .into_inner();
130 let num_stores = stores.stores.len();
131
132 match stores.stores.first() {
133 Some(store) => {
134 if num_stores > 1 {
135 tracing::error!("Multiple stores with the name `{}` found", store_name);
136 Err(Error::AmbiguousStoreName(store_name.to_string()))
137 } else {
138 Ok(Some(store.clone()))
139 }
140 }
141 None => Ok(None),
142 }
143 }
144
145 pub async fn get_or_create_store(&mut self, store_name: &str) -> Result<Store> {
152 let store = self.get_store_by_name(store_name).await?;
153 match store {
154 None => {
155 tracing::debug!("OpenFGA Store {} not found. Creating it.", store_name);
156 let store = self
157 .create_store(CreateStoreRequest {
158 name: store_name.to_owned(),
159 })
160 .await
161 .map_err(|e| {
162 tracing::error!("Failed to create store in OpenFGA: {e}");
163 Error::RequestFailed(Box::new(e))
164 })?
165 .into_inner();
166 Ok(Store {
167 id: store.id,
168 name: store.name,
169 created_at: store.created_at,
170 updated_at: store.updated_at,
171 deleted_at: None,
172 })
173 }
174 Some(store) => Ok(store),
175 }
176 }
177
178 pub async fn read_all_pages(
197 &mut self,
198 store_id: &str,
199 tuple: Option<impl Into<ReadRequestTupleKey>>,
200 consistency: impl Into<ConsistencyPreference>,
201 page_size: i32,
202 max_pages: u32,
203 ) -> Result<Vec<Tuple>> {
204 let mut continuation_token = String::new();
205 let tuple = tuple.map(Into::into);
206 let mut tuples = Vec::new();
207 let mut count = 0;
208 let consistency = consistency.into();
209
210 loop {
211 let read_request = ReadRequest {
212 store_id: store_id.to_owned(),
213 tuple_key: tuple.clone(),
214 page_size: Some(page_size),
215 continuation_token: continuation_token.clone(),
216 consistency: consistency.into(),
217 };
218 let response = self
219 .read(read_request.clone())
220 .await
221 .map_err(|e| {
222 tracing::error!(
223 "Failed to read from OpenFGA: {e}. Request: {:?}",
224 read_request
225 );
226 Error::RequestFailed(Box::new(e))
227 })?
228 .into_inner();
229 tuples.extend(response.tuples);
230 continuation_token.clone_from(&response.continuation_token);
231 count += 1;
232 if count > max_pages {
233 return Err(Error::TooManyPages { max_pages, tuple });
234 }
235 if continuation_token.is_empty() {
236 break;
237 }
238 }
239
240 Ok(tuples)
241 }
242}
243
/// Tower service stack used by [`BasicOpenFgaServiceClient`].
///
/// Each arm corresponds to one constructor:
/// - `Left(Left(_))`: client-credentials authorizer (`new_with_client_credentials`)
/// - `Left(Right(_))`: bearer-token authorizer (`new_with_basic_auth`)
/// - `Right(_)`: pass-through [`NoOpInterceptor`] (`new_unauthenticated`)
#[cfg(feature = "auth-middle")]
pub type BasicAuthLayer = Either<
    Either<
        InterceptedService<Channel, middle::BasicClientCredentialAuthorizer>,
        InterceptedService<Channel, middle::BearerTokenAuthorizer>,
    >,
    InterceptedService<Channel, NoOpInterceptor>,
>;
252
/// Interceptor that forwards every request unchanged.
///
/// Used by [`BasicOpenFgaServiceClient::new_unauthenticated`]. The unauthenticated
/// arm of [`BasicAuthLayer`] is an `InterceptedService<Channel, NoOpInterceptor>`.
#[cfg(feature = "auth-middle")]
#[derive(Clone, Copy, Debug)]
pub struct NoOpInterceptor;

#[cfg(feature = "auth-middle")]
impl tonic::service::Interceptor for NoOpInterceptor {
    /// Hands `req` back as-is; this interceptor never rejects a request.
    fn call(
        &mut self,
        req: tonic::Request<()>,
    ) -> std::result::Result<tonic::Request<()>, tonic::Status> {
        let passthrough = req;
        Ok(passthrough)
    }
}
266
/// Builds a tonic [`Endpoint`] for `endpoint`, logging any failure before returning it.
///
/// Non-`https` URLs are returned without TLS. `https` URLs are passed to
/// `configure_tls_logged`, whose behavior depends on the `tls-rustls` feature.
///
/// # Errors
/// - [`Error::InvalidEndpoint`] if tonic rejects the URL.
/// - [`Error::TlsConfigurationFailed`] if the URL is `https` and TLS cannot be set up,
///   including builds without the `tls-rustls` feature.
#[cfg(feature = "auth-middle")]
fn get_tonic_endpoint_logged(endpoint: &url::Url) -> Result<Endpoint> {
    let ep = Endpoint::new(endpoint.to_string()).map_err(|e| {
        tracing::error!("Could not construct OpenFGA client. Invalid endpoint `{endpoint}`: {e}");
        Error::InvalidEndpoint(endpoint.to_string())
    })?;

    if endpoint.scheme() == "https" {
        configure_tls_logged(ep, endpoint)
    } else {
        Ok(ep)
    }
}

/// Enables TLS on `ep` using the root certificates enabled via tonic's
/// `ClientTlsConfig::with_enabled_roots`.
#[cfg(all(feature = "auth-middle", feature = "tls-rustls"))]
fn configure_tls_logged(ep: Endpoint, endpoint: &url::Url) -> Result<Endpoint> {
    use tonic::transport::ClientTlsConfig;

    let tls_config = ClientTlsConfig::new().with_enabled_roots();
    ep.tls_config(tls_config).map_err(|e| {
        tracing::error!("Could not configure TLS for OpenFGA client endpoint `{endpoint}`: {e}");
        Error::TlsConfigurationFailed {
            endpoint: endpoint.to_string(),
            reason: e.to_string(),
        }
    })
}

/// Rejects `https` endpoints because TLS support was not compiled in.
///
/// Logs like the TLS-enabled variant so every failure of `get_tonic_endpoint_logged`
/// is traced.
#[cfg(all(feature = "auth-middle", not(feature = "tls-rustls")))]
fn configure_tls_logged(_ep: Endpoint, endpoint: &url::Url) -> Result<Endpoint> {
    let reason = "HTTPS endpoint requires the `tls-rustls` feature to be enabled";
    tracing::error!("Could not configure TLS for OpenFGA client endpoint `{endpoint}`: {reason}");
    Err(Error::TlsConfigurationFailed {
        endpoint: endpoint.to_string(),
        reason: reason.to_string(),
    })
}
302
#[cfg(test)]
pub(crate) mod test {
    use needs_env_var::needs_env_var;

    /// Integration tests that talk to a live OpenFGA server at
    /// `TEST_OPENFGA_CLIENT_GRPC_URL`.
    ///
    /// Gated with `needs_env_var` so the module is only compiled when that variable
    /// is set. Without a server, `cargo test` skips these tests instead of panicking
    /// in `get_basic_client`.
    #[cfg(feature = "auth-middle")]
    #[needs_env_var(TEST_OPENFGA_CLIENT_GRPC_URL)]
    mod openfga {
        use std::collections::{HashMap, HashSet};

        use super::super::*;
        use crate::{
            client::{
                TupleKey, WriteAuthorizationModelRequest, WriteAuthorizationModelResponse,
                WriteRequest, WriteRequestWrites,
            },
            generated::AuthorizationModel,
        };

        /// Builds an unauthenticated client for the URL in `TEST_OPENFGA_CLIENT_GRPC_URL`.
        fn get_basic_client() -> BasicOpenFgaServiceClient {
            let endpoint = std::env::var("TEST_OPENFGA_CLIENT_GRPC_URL").unwrap();
            BasicOpenFgaServiceClient::new_unauthenticated(url::Url::parse(&endpoint).unwrap())
                .expect("Client can be created")
        }

        /// Creates a store with a unique (UUIDv7-based) name and returns it together
        /// with the client that created it.
        async fn new_store() -> (BasicOpenFgaServiceClient, Store) {
            let mut client = get_basic_client();
            let store_name = format!("store-{}", uuid::Uuid::now_v7());
            let store = client
                .get_or_create_store(&store_name)
                .await
                .expect("Store can be created");
            (client, store)
        }

        /// Writes the sample "entitlements" authorization model into `store`.
        async fn create_entitlements_model(
            client: &mut BasicOpenFgaServiceClient,
            store: &Store,
        ) -> WriteAuthorizationModelResponse {
            let schema = include_str!("../tests/sample-store/entitlements/schema.json");
            let model: AuthorizationModel =
                serde_json::from_str(schema).expect("Schema can be deserialized");
            let auth_model = client
                .write_authorization_model(WriteAuthorizationModelRequest {
                    store_id: store.id.clone(),
                    type_definitions: model.type_definitions,
                    schema_version: model.schema_version,
                    conditions: model.conditions,
                })
                .await
                .expect("Auth model can be written");

            auth_model.into_inner()
        }

        /// Writes a single unconditional tuple `(user, "member", object)` into `store`
        /// under `auth_model`.
        async fn write_member_tuple(
            client: &mut BasicOpenFgaServiceClient,
            store: &Store,
            auth_model: &WriteAuthorizationModelResponse,
            user: &str,
            object: &str,
        ) {
            client
                .write(WriteRequest {
                    authorization_model_id: auth_model.authorization_model_id.clone(),
                    store_id: store.id.clone(),
                    writes: Some(WriteRequestWrites {
                        on_duplicate: String::new(),
                        tuple_keys: vec![TupleKey {
                            user: user.to_string(),
                            relation: "member".to_string(),
                            object: object.to_string(),
                            condition: None,
                        }],
                    }),
                    deletes: None,
                })
                .await
                .expect("Write can be done");
        }

        #[tokio::test]
        async fn test_get_store_by_name_many() {
            let mut client = get_basic_client();

            // Many stores exist so a lookup must rely on the name filter, not on the
            // target happening to appear in a small listing.
            let mut stores = HashMap::new();
            for _i in 0..201 {
                let store_name = format!("store-{}", uuid::Uuid::now_v7());
                let r = client
                    .get_or_create_store(&store_name)
                    .await
                    .expect("Store can be created");
                assert_eq!(store_name, r.name);
                stores.insert(store_name, r.id);
            }

            for (store_name, store_id) in stores {
                let store = client
                    .get_store_by_name(&store_name)
                    .await
                    .expect("Store can be fetched")
                    .expect("Store exists");
                assert_eq!(store_id, store.id);
            }
        }

        #[tokio::test]
        async fn test_get_store_by_name_non_existent() {
            let mut client = get_basic_client();
            let store = client
                .get_store_by_name("non-existent-store")
                .await
                .unwrap();
            assert!(store.is_none());
        }

        #[tokio::test]
        async fn test_get_or_create_store() {
            let mut client = get_basic_client();
            let store_name = format!("store-{}", uuid::Uuid::now_v7());
            let store = client.get_or_create_store(&store_name).await.unwrap();
            let store2 = client.get_or_create_store(&store_name).await.unwrap();
            assert_eq!(store.id, store2.id);
        }

        #[tokio::test]
        async fn test_read_all_pages_many() {
            let (mut client, store) = new_store().await;
            let auth_model = create_entitlements_model(&mut client, &store).await;
            let object = "organization:org-1";

            let users = (0..501)
                .map(|i| format!("user:u-{i}"))
                .collect::<Vec<String>>();

            for user in &users {
                write_member_tuple(&mut client, &store, &auth_model, user, object).await;
            }

            let tuples = client
                .read_all_pages(
                    &store.id,
                    Some(ReadRequestTupleKey {
                        user: String::new(),
                        relation: "member".to_string(),
                        object: object.to_string(),
                    }),
                    ConsistencyPreference::HigherConsistency,
                    100,
                    6,
                )
                .await
                .expect("Read can be done");

            assert_eq!(tuples.len(), 501);
            assert_eq!(
                tuples
                    .iter()
                    .map(|t| t.key.clone().unwrap().user)
                    .collect::<HashSet<String>>(),
                HashSet::from_iter(users)
            );
        }

        #[tokio::test]
        async fn test_read_all_pages_empty() {
            let (mut client, store) = new_store().await;
            let tuples = client
                .read_all_pages(
                    &store.id,
                    Some(ReadRequestTupleKey {
                        user: String::new(),
                        relation: "member".to_string(),
                        object: "organization:org-1".to_string(),
                    }),
                    ConsistencyPreference::HigherConsistency,
                    100,
                    5,
                )
                .await
                .expect("Read can be done");

            assert!(tuples.is_empty());
        }

        #[tokio::test]
        async fn test_read_all_pages_unfiltered() {
            let (mut client, store) = new_store().await;
            let auth_model = create_entitlements_model(&mut client, &store).await;

            // Spread tuples over several objects so an unfiltered read must span them.
            let total = 250;
            for i in 0..total {
                let user = format!("user:u-{i}");
                let object = format!("organization:org-{}", i % 5);
                write_member_tuple(&mut client, &store, &auth_model, &user, &object).await;
            }

            let tuples = client
                .read_all_pages(
                    &store.id,
                    None::<ReadRequestTupleKey>,
                    ConsistencyPreference::HigherConsistency,
                    100,
                    10,
                )
                .await
                .expect("unfiltered read_all_pages must succeed");

            assert_eq!(
                tuples.len(),
                total,
                "unfiltered read_all_pages must return every tuple in the store"
            );
        }

        #[tokio::test]
        async fn test_read_all_pages_max_pages_enforced() {
            let (mut client, store) = new_store().await;
            let auth_model = create_entitlements_model(&mut client, &store).await;

            // 250 tuples at 100 per page need more than the single page allowed below.
            for i in 0..250 {
                let user = format!("user:u-{i}");
                write_member_tuple(&mut client, &store, &auth_model, &user, "organization:org-1")
                    .await;
            }

            let max_pages_allowed = 1;
            let result = client
                .read_all_pages(
                    &store.id,
                    None::<ReadRequestTupleKey>,
                    ConsistencyPreference::HigherConsistency,
                    100,
                    max_pages_allowed,
                )
                .await;

            match result {
                Err(Error::TooManyPages { max_pages, .. }) => {
                    assert_eq!(max_pages, max_pages_allowed);
                }
                Err(other) => panic!("expected TooManyPages, got {other:?}"),
                Ok(tuples) => panic!(
                    "expected TooManyPages error, got Ok with {} tuples",
                    tuples.len()
                ),
            }
        }
    }
}