Skip to main content

openfga_client/
client_ext.rs

1#![allow(unused_imports)]
2
3#[cfg(feature = "auth-middle")]
4use tonic::service::interceptor::InterceptedService;
5use tonic::{
6    codegen::{Body, Bytes, StdError},
7    service::interceptor::InterceptorLayer,
8    transport::{Channel, Endpoint},
9};
10#[cfg(feature = "auth-middle")]
11use tower::{ServiceBuilder, util::Either};
12
13use crate::{
14    client::{OpenFgaClient, OpenFgaServiceClient},
15    error::{Error, Result},
16    generated::{
17        ConsistencyPreference, CreateStoreRequest, ListStoresRequest, ReadRequest,
18        ReadRequestTupleKey, Store, Tuple,
19    },
20};
21
/// [`OpenFgaServiceClient`] specialized for the built-in authentication modes:
/// no authentication, pre-shared keys (Bearer tokens), or client credentials.
///
/// When you need more fine-grained control, construct an [`OpenFgaServiceClient`]
/// directly and supply your own authentication interceptors.
#[cfg(feature = "auth-middle")]
pub type BasicOpenFgaServiceClient = OpenFgaServiceClient<BasicAuthLayer>;
28
#[cfg(feature = "auth-middle")]
impl BasicOpenFgaServiceClient {
    /// Create a new client without authentication.
    ///
    /// The underlying channel is created with `connect_lazy`, so no connection
    /// is attempted until the first request.
    ///
    /// # Errors
    /// * [`Error::InvalidEndpoint`] if the endpoint is not a valid URL.
    /// * [`Error::TlsConfigurationFailed`] if the endpoint uses HTTPS and TLS
    ///   cannot be configured.
    pub fn new_unauthenticated(endpoint: impl Into<url::Url>) -> Result<Self> {
        let endpoint = get_tonic_endpoint_logged(&endpoint.into())?;
        let channel = endpoint.connect_lazy();
        let intercepted = InterceptedService::new(channel, NoOpInterceptor);
        let service = Either::Right(intercepted);
        Ok(BasicOpenFgaServiceClient::new(service))
    }

    /// Create a new client that authenticates with a pre-shared Bearer token.
    ///
    /// The token is attached to outgoing requests by a
    /// `middle::BearerTokenAuthorizer` interceptor.
    ///
    /// # Errors
    /// * [`Error::InvalidToken`] if the token is not valid ASCII.
    /// * [`Error::InvalidEndpoint`] if the endpoint is not a valid URL.
    /// * [`Error::TlsConfigurationFailed`] if the endpoint uses HTTPS and TLS
    ///   cannot be configured.
    pub fn new_with_basic_auth(endpoint: impl Into<url::Url>, token: &str) -> Result<Self> {
        let authorizer = middle::BearerTokenAuthorizer::new(token).map_err(|e| {
            tracing::error!("Could not construct OpenFGA client. Invalid token: {e}");
            Error::InvalidToken {
                reason: e.to_string(),
            }
        })?;
        let endpoint = get_tonic_endpoint_logged(&endpoint.into())?;
        let channel = endpoint.connect_lazy();
        let intercepted = InterceptedService::new(channel, authorizer);
        let service = Either::Left(Either::Right(intercepted));
        Ok(BasicOpenFgaServiceClient::new(service))
    }

    /// Create a new client using client credentials.
    ///
    /// The credentials are exchanged for a token at `token_endpoint` while the
    /// client is being constructed. `scopes` are only added to the request when
    /// the slice is non-empty.
    ///
    /// # Errors
    /// * [`Error::InvalidEndpoint`] if the endpoint is not a valid URL.
    /// * [`Error::TlsConfigurationFailed`] if the endpoint uses HTTPS and TLS
    ///   cannot be configured.
    /// * [`Error::CredentialRefreshError`] if the client credentials could not be exchanged for a token.
    pub async fn new_with_client_credentials(
        endpoint: impl Into<url::Url>,
        client_id: &str,
        client_secret: &str,
        token_endpoint: impl Into<url::Url>,
        scopes: &[&str],
    ) -> Result<Self> {
        // Validate the OpenFGA endpoint first: it is a cheap local check, so a
        // misconfigured URL fails fast instead of after a token round-trip.
        let endpoint = get_tonic_endpoint_logged(&endpoint.into())?;

        let builder = middle::BasicClientCredentialAuthorizer::basic_builder(
            client_id,
            client_secret,
            token_endpoint.into(),
        );
        let authorizer = if scopes.is_empty() {
            builder
        } else {
            builder.add_scopes(scopes)
        }
        .build()
        .await
        .map_err(|e| {
            tracing::error!("Could not construct OpenFGA client. Failed to fetch or refresh Client Credentials: {e}");
            Error::CredentialRefreshError(e)
        })?;

        let channel = endpoint.connect_lazy();
        let intercepted = InterceptedService::new(channel, authorizer);
        let service = Either::Left(Either::Left(intercepted));
        Ok(BasicOpenFgaServiceClient::new(service))
    }
}
97
98impl<T> OpenFgaServiceClient<T>
99where
100    T: tonic::client::GrpcService<tonic::body::Body>,
101    T::Error: Into<StdError>,
102    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
103    <T::ResponseBody as Body>::Error: Into<StdError> + Send,
104    T: Clone,
105{
106    /// Transform this service client into a higher-level [`OpenFgaClient`].
107    pub fn into_client(self, store_id: &str, authorization_model_id: &str) -> OpenFgaClient<T> {
108        OpenFgaClient::new(self, store_id, authorization_model_id)
109    }
110
111    /// Fetch a store by name.
112    /// If no store is found, returns `Ok(None)`.
113    ///
114    /// # Errors
115    /// * [`Error::AmbiguousStoreName`] if multiple stores with the same name are found.
116    /// * [`Error::RequestFailed`] if the request to OpenFGA fails.
117    pub async fn get_store_by_name(&mut self, store_name: &str) -> Result<Option<Store>> {
118        let stores = self
119            .list_stores(ListStoresRequest {
120                page_size: Some(2),
121                continuation_token: String::new(),
122                name: store_name.to_string(),
123            })
124            .await
125            .map_err(|e| {
126                tracing::error!("Failed to list stores in OpenFGA: {e}");
127                Error::RequestFailed(Box::new(e))
128            })?
129            .into_inner();
130        let num_stores = stores.stores.len();
131
132        match stores.stores.first() {
133            Some(store) => {
134                if num_stores > 1 {
135                    tracing::error!("Multiple stores with the name `{}` found", store_name);
136                    Err(Error::AmbiguousStoreName(store_name.to_string()))
137                } else {
138                    Ok(Some(store.clone()))
139                }
140            }
141            None => Ok(None),
142        }
143    }
144
145    /// Get a store by name or create it if it doesn't exist.
146    /// Returns information about the store, including its ID.
147    ///
148    /// # Errors
149    /// * [`Error::RequestFailed`] If a request to OpenFGA fails.
150    /// * [`Error::AmbiguousStoreName`] If multiple stores with the same name are found.
151    pub async fn get_or_create_store(&mut self, store_name: &str) -> Result<Store> {
152        let store = self.get_store_by_name(store_name).await?;
153        match store {
154            None => {
155                tracing::debug!("OpenFGA Store {} not found. Creating it.", store_name);
156                let store = self
157                    .create_store(CreateStoreRequest {
158                        name: store_name.to_owned(),
159                    })
160                    .await
161                    .map_err(|e| {
162                        tracing::error!("Failed to create store in OpenFGA: {e}");
163                        Error::RequestFailed(Box::new(e))
164                    })?
165                    .into_inner();
166                Ok(Store {
167                    id: store.id,
168                    name: store.name,
169                    created_at: store.created_at,
170                    updated_at: store.updated_at,
171                    deleted_at: None,
172                })
173            }
174            Some(store) => Ok(store),
175        }
176    }
177
178    /// Wrapper around [`Self::read`] that reads all pages of the result, handling pagination.
179    ///
180    /// `tuple` may be:
181    ///
182    /// * `Some(filter)` — returns tuples matching the filter. The OpenFGA
183    ///   server requires `filter.object` to specify at least an object type,
184    ///   AND requires either a non-empty `filter.user` or a non-empty object
185    ///   id; a bare `"type:"` prefix on its own is rejected.
186    /// * `None` — **enumerates every tuple in the store**, paginating to
187    ///   completion. This is the supported global-tuple-enumeration primitive
188    ///   and is what the OpenFGA CLI's `fga store export` uses internally.
189    ///
190    /// `page_size` is capped at 100 by the OpenFGA Read RPC (proto-level
191    /// validation, not configurable).
192    ///
193    /// # Errors
194    /// * [`Error::RequestFailed`] If a request to OpenFGA fails.
195    /// * [`Error::TooManyPages`] If the number of pages read exceeds `max_pages`.
196    pub async fn read_all_pages(
197        &mut self,
198        store_id: &str,
199        tuple: Option<impl Into<ReadRequestTupleKey>>,
200        consistency: impl Into<ConsistencyPreference>,
201        page_size: i32,
202        max_pages: u32,
203    ) -> Result<Vec<Tuple>> {
204        let mut continuation_token = String::new();
205        let tuple = tuple.map(Into::into);
206        let mut tuples = Vec::new();
207        let mut count = 0;
208        let consistency = consistency.into();
209
210        loop {
211            let read_request = ReadRequest {
212                store_id: store_id.to_owned(),
213                tuple_key: tuple.clone(),
214                page_size: Some(page_size),
215                continuation_token: continuation_token.clone(),
216                consistency: consistency.into(),
217            };
218            let response = self
219                .read(read_request.clone())
220                .await
221                .map_err(|e| {
222                    tracing::error!(
223                        "Failed to read from OpenFGA: {e}. Request: {:?}",
224                        read_request
225                    );
226                    Error::RequestFailed(Box::new(e))
227                })?
228                .into_inner();
229            tuples.extend(response.tuples);
230            continuation_token.clone_from(&response.continuation_token);
231            count += 1;
232            if count > max_pages {
233                return Err(Error::TooManyPages { max_pages, tuple });
234            }
235            if continuation_token.is_empty() {
236                break;
237            }
238        }
239
240        Ok(tuples)
241    }
242}
243
/// Service stack behind [`BasicOpenFgaServiceClient`]: one `Either` arm per
/// built-in constructor.
///
/// * `Left(Left(_))` — client-credentials authorizer (`new_with_client_credentials`).
/// * `Left(Right(_))` — pre-shared Bearer token (`new_with_basic_auth`).
/// * `Right(_)` — no authentication via [`NoOpInterceptor`] (`new_unauthenticated`).
#[cfg(feature = "auth-middle")]
pub type BasicAuthLayer = Either<
    Either<
        InterceptedService<Channel, middle::BasicClientCredentialAuthorizer>,
        InterceptedService<Channel, middle::BearerTokenAuthorizer>,
    >,
    InterceptedService<Channel, NoOpInterceptor>,
>;
252
/// Interceptor that lets every request through unchanged.
///
/// It gives the unauthenticated client the same `InterceptedService` shape as
/// the authenticated variants of [`BasicAuthLayer`].
#[cfg(feature = "auth-middle")]
#[derive(Clone, Copy, Debug)]
pub struct NoOpInterceptor;

#[cfg(feature = "auth-middle")]
impl tonic::service::Interceptor for NoOpInterceptor {
    /// Returns the incoming request untouched; this never fails.
    fn call(
        &mut self,
        req: tonic::Request<()>,
    ) -> std::result::Result<tonic::Request<()>, tonic::Status> {
        let passthrough = req;
        Ok(passthrough)
    }
}
266
/// Build a tonic [`Endpoint`] from `endpoint`, logging every failure via `tracing`.
///
/// For `https` URLs the endpoint is also configured for TLS using
/// `ClientTlsConfig::with_enabled_roots`. This requires the `tls-rustls`
/// feature. Other schemes are returned without TLS configuration.
///
/// # Errors
/// * [`Error::InvalidEndpoint`] if tonic rejects the URL.
/// * [`Error::TlsConfigurationFailed`] if the URL uses HTTPS and either TLS
///   configuration fails or the crate was built without `tls-rustls`.
#[cfg(feature = "auth-middle")]
fn get_tonic_endpoint_logged(endpoint: &url::Url) -> Result<Endpoint> {
    let ep = Endpoint::new(endpoint.to_string()).map_err(|e| {
        tracing::error!("Could not construct OpenFGA client. Invalid endpoint `{endpoint}`: {e}");
        Error::InvalidEndpoint(endpoint.to_string())
    })?;

    // Configure TLS if the endpoint uses HTTPS
    if endpoint.scheme() == "https" {
        #[cfg(feature = "tls-rustls")]
        {
            use tonic::transport::ClientTlsConfig;
            let tls_config = ClientTlsConfig::new().with_enabled_roots();
            return ep.tls_config(tls_config).map_err(|e| {
                tracing::error!(
                    "Could not configure TLS for OpenFGA client endpoint `{endpoint}`: {e}"
                );
                Error::TlsConfigurationFailed {
                    endpoint: endpoint.to_string(),
                    reason: e.to_string(),
                }
            });
        }
        #[cfg(not(feature = "tls-rustls"))]
        {
            // Log this failure as well: callers rely on this helper to log
            // every error it returns.
            let reason = "HTTPS endpoint requires the `tls-rustls` feature to be enabled";
            tracing::error!(
                "Could not configure TLS for OpenFGA client endpoint `{endpoint}`: {reason}"
            );
            return Err(Error::TlsConfigurationFailed {
                endpoint: endpoint.to_string(),
                reason: reason.to_string(),
            });
        }
    }

    Ok(ep)
}
302
#[cfg(test)]
pub(crate) mod test {
    use needs_env_var::needs_env_var;

    // NOTE(review): the `needs_env_var` gate below is commented out, so these
    // integration tests always compile and run. They need a reachable OpenFGA
    // server at `TEST_OPENFGA_CLIENT_GRPC_URL`. Confirm this is intentional.
    // #[needs_env_var(TEST_OPENFGA_CLIENT_GRPC_URL)]
    #[cfg(feature = "auth-middle")]
    mod openfga {
        use std::collections::{HashMap, HashSet};

        use super::super::*;
        use crate::{
            client::{
                TupleKey, WriteAuthorizationModelRequest, WriteAuthorizationModelResponse,
                WriteRequest, WriteRequestWrites,
            },
            generated::AuthorizationModel,
        };

        /// Unauthenticated client for the server named by `TEST_OPENFGA_CLIENT_GRPC_URL`.
        fn get_basic_client() -> BasicOpenFgaServiceClient {
            let endpoint = std::env::var("TEST_OPENFGA_CLIENT_GRPC_URL").expect(
                "TEST_OPENFGA_CLIENT_GRPC_URL must be set to run the OpenFGA integration tests",
            );
            let endpoint = url::Url::parse(&endpoint)
                .expect("TEST_OPENFGA_CLIENT_GRPC_URL must be a valid URL");
            BasicOpenFgaServiceClient::new_unauthenticated(endpoint)
                .expect("Client can be created")
        }

        /// Fresh client plus a newly created store with a unique (UUIDv7) name.
        async fn new_store() -> (BasicOpenFgaServiceClient, Store) {
            let mut client = get_basic_client();
            let store_name = format!("store-{}", uuid::Uuid::now_v7());
            let store = client
                .get_or_create_store(&store_name)
                .await
                .expect("Store can be created");
            (client, store)
        }

        /// Write the sample "entitlements" authorization model into `store`.
        async fn create_entitlements_model(
            client: &mut BasicOpenFgaServiceClient,
            store: &Store,
        ) -> WriteAuthorizationModelResponse {
            let schema = include_str!("../tests/sample-store/entitlements/schema.json");
            let model: AuthorizationModel =
                serde_json::from_str(schema).expect("Schema can be deserialized");
            let auth_model = client
                .write_authorization_model(WriteAuthorizationModelRequest {
                    store_id: store.id.clone(),
                    type_definitions: model.type_definitions,
                    schema_version: model.schema_version,
                    conditions: model.conditions,
                })
                .await
                .expect("Auth model can be written");

            auth_model.into_inner()
        }

        /// Write the single tuple `(user, "member", object)` into `store_id`.
        async fn write_member_tuple(
            client: &mut BasicOpenFgaServiceClient,
            store_id: &str,
            authorization_model_id: &str,
            user: String,
            object: String,
        ) {
            client
                .write(WriteRequest {
                    authorization_model_id: authorization_model_id.to_string(),
                    store_id: store_id.to_string(),
                    writes: Some(WriteRequestWrites {
                        on_duplicate: String::new(),
                        tuple_keys: vec![TupleKey {
                            user,
                            relation: "member".to_string(),
                            object,
                            condition: None,
                        }],
                    }),
                    deletes: None,
                })
                .await
                .expect("Write can be done");
        }

        #[tokio::test]
        async fn test_get_store_by_name_many() {
            let mut client = get_basic_client();

            let mut stores = HashMap::new();
            for _i in 0..201 {
                let store_name = format!("store-{}", uuid::Uuid::now_v7());
                let r = client
                    .get_or_create_store(&store_name)
                    .await
                    .expect("Store can be created");
                assert_eq!(store_name, r.name);
                stores.insert(store_name, r.id);
            }

            for (store_name, store_id) in stores {
                let store = client
                    .get_store_by_name(&store_name)
                    .await
                    .expect("Store can be fetched")
                    .expect("Store exists");
                assert_eq!(store_id, store.id);
            }
        }

        #[tokio::test]
        async fn test_get_store_by_name_non_existent() {
            let mut client = get_basic_client();
            let store = client
                .get_store_by_name("non-existent-store")
                .await
                .unwrap();
            assert!(store.is_none());
        }

        #[tokio::test]
        async fn test_get_or_create_store() {
            let mut client = get_basic_client();
            let store_name = format!("store-{}", uuid::Uuid::now_v7());
            let store = client.get_or_create_store(&store_name).await.unwrap();
            let store2 = client.get_or_create_store(&store_name).await.unwrap();
            assert_eq!(store.id, store2.id);
        }

        #[tokio::test]
        async fn test_read_all_pages_many() {
            let (mut client, store) = new_store().await;
            let auth_model = create_entitlements_model(&mut client, &store).await;
            let object = "organization:org-1";

            let users = (0..501)
                .map(|i| format!("user:u-{i}"))
                .collect::<Vec<String>>();

            for user in &users {
                write_member_tuple(
                    &mut client,
                    &store.id,
                    &auth_model.authorization_model_id,
                    user.clone(),
                    object.to_string(),
                )
                .await;
            }

            let tuples = client
                .read_all_pages(
                    &store.id,
                    Some(ReadRequestTupleKey {
                        user: String::new(),
                        relation: "member".to_string(),
                        object: object.to_string(),
                    }),
                    ConsistencyPreference::HigherConsistency,
                    100,
                    6,
                )
                .await
                .expect("Read can be done");

            assert_eq!(tuples.len(), 501);
            assert_eq!(
                tuples
                    .iter()
                    .map(|t| t.key.clone().unwrap().user)
                    .collect::<HashSet<String>>(),
                HashSet::from_iter(users)
            );
        }

        #[tokio::test]
        async fn test_read_all_pages_empty() {
            let (mut client, store) = new_store().await;
            let tuples = client
                .read_all_pages(
                    &store.id,
                    Some(ReadRequestTupleKey {
                        user: String::new(),
                        relation: "member".to_string(),
                        object: "organization:org-1".to_string(),
                    }),
                    ConsistencyPreference::HigherConsistency,
                    100,
                    5,
                )
                .await
                .expect("Read can be done");

            assert!(tuples.is_empty());
        }

        /// Direct low-level verification that `read_all_pages` with no filter
        /// (`tuple=None`) returns *every* tuple in the store, across multiple
        /// pages. Mirrors the high-level test in
        /// `model_client::tests::openfga::test_read_all_pages_empty_tuple` but
        /// exercises the [`OpenFgaServiceClient::read_all_pages`] entry point
        /// directly so it doesn't regress if the high-level wrapper changes.
        #[tokio::test]
        async fn test_read_all_pages_unfiltered() {
            let (mut client, store) = new_store().await;
            let auth_model = create_entitlements_model(&mut client, &store).await;

            // 250 distinct (user, relation, object) tuples spread across multiple
            // objects so no single-key filter could fetch them. With page_size=100,
            // this forces 3 pages of pagination.
            let total = 250;
            for i in 0..total {
                write_member_tuple(
                    &mut client,
                    &store.id,
                    &auth_model.authorization_model_id,
                    format!("user:u-{i}"),
                    format!("organization:org-{}", i % 5),
                )
                .await;
            }

            let tuples = client
                .read_all_pages(
                    &store.id,
                    None::<ReadRequestTupleKey>,
                    ConsistencyPreference::HigherConsistency,
                    100,
                    10,
                )
                .await
                .expect("unfiltered read_all_pages must succeed");

            assert_eq!(
                tuples.len(),
                total,
                "unfiltered read_all_pages must return every tuple in the store"
            );
        }

        /// Regression test for the pagination cap.
        ///
        /// `max_pages = N` means the read errors if the result would require
        /// more than N pages. An earlier implementation checked the counter
        /// before incrementing it, so the limit fired late and extra pages of
        /// data could come back successfully.
        ///
        /// We write enough data to require 3 pages and ask for `max_pages = 1`.
        /// The call must error with [`Error::TooManyPages`], not return silently.
        #[tokio::test]
        async fn test_read_all_pages_max_pages_enforced() {
            let (mut client, store) = new_store().await;
            let auth_model = create_entitlements_model(&mut client, &store).await;

            // 3 pages worth at page_size=100 → 250 tuples.
            for i in 0..250 {
                write_member_tuple(
                    &mut client,
                    &store.id,
                    &auth_model.authorization_model_id,
                    format!("user:u-{i}"),
                    "organization:org-1".to_string(),
                )
                .await;
            }

            let result = client
                .read_all_pages(
                    &store.id,
                    None::<ReadRequestTupleKey>,
                    ConsistencyPreference::HigherConsistency,
                    100,
                    1, // strictly fewer than the 3 pages of data
                )
                .await;

            match result {
                Err(Error::TooManyPages { max_pages, .. }) => {
                    assert_eq!(max_pages, 1);
                }
                Err(other) => panic!("expected TooManyPages, got {other:?}"),
                Ok(tuples) => panic!(
                    "expected TooManyPages error, got Ok with {} tuples",
                    tuples.len()
                ),
            }
        }
    }
}