openfga_client/
client_ext.rs

1#![allow(unused_imports)]
2
3use tonic::{
4    codegen::{Body, Bytes, StdError},
5    service::interceptor::InterceptorLayer,
6    transport::{Channel, Endpoint},
7};
8#[cfg(feature = "auth-middle")]
9use tower::{util::Either, ServiceBuilder};
10
11use crate::{
12    client::{OpenFgaClient, OpenFgaServiceClient},
13    error::{Error, Result},
14    generated::{
15        ConsistencyPreference, CreateStoreRequest, ListStoresRequest, ReadRequest,
16        ReadRequestTupleKey, Store, Tuple,
17    },
18};
19
/// Ready-made [`OpenFgaServiceClient`] whose transport is a [`BasicAuthLayer`].
///
/// The underlying gRPC [`Channel`] is used as-is (no authentication), wrapped in an
/// interceptor that sends a pre-shared Bearer token, or wrapped in an interceptor that
/// uses client credentials exchanged at a token endpoint.
///
/// When finer-grained control over authentication is required, construct an
/// [`OpenFgaServiceClient`] directly with your own interceptors instead.
#[cfg(feature = "auth-middle")]
pub type BasicOpenFgaServiceClient = OpenFgaServiceClient<BasicAuthLayer>;
26
#[cfg(feature = "auth-middle")]
impl BasicOpenFgaServiceClient {
    /// Create a new client without authentication.
    ///
    /// The connection to `endpoint` is established lazily, on the first request.
    ///
    /// # Errors
    /// * [`Error::InvalidEndpoint`] if the endpoint is not a valid URL.
    pub fn new_unauthenticated(endpoint: impl Into<url::Url>) -> Result<Self> {
        let endpoint = get_tonic_endpoint_logged(&endpoint.into())?;
        Ok(Self::from_endpoint_and_auth(endpoint, None))
    }

    /// Create a new client that authenticates every request with a pre-shared
    /// Bearer `token`.
    ///
    /// The connection to `endpoint` is established lazily, on the first request.
    ///
    /// # Errors
    /// * [`Error::InvalidEndpoint`] if the endpoint is not a valid URL.
    /// * [`Error::InvalidToken`] if the token is not valid ASCII.
    pub fn new_with_basic_auth(endpoint: impl Into<url::Url>, token: &str) -> Result<Self> {
        let endpoint = get_tonic_endpoint_logged(&endpoint.into())?;
        let authorizer = middle::BearerTokenAuthorizer::new(token).map_err(|e| {
            tracing::error!("Could not construct OpenFGA client. Invalid token: {e}");
            Error::InvalidToken {
                reason: e.to_string(),
            }
        })?;
        let auth: EitherOrOption = Some(Either::Right(tonic::service::interceptor(authorizer)));
        Ok(Self::from_endpoint_and_auth(endpoint, auth))
    }

    /// Create a new client using client credentials.
    ///
    /// `client_id` / `client_secret` are exchanged for a token at `token_endpoint`,
    /// requesting `scopes` if any are given. The endpoint is validated first, so an
    /// invalid endpoint fails without contacting the token endpoint.
    ///
    /// # Errors
    /// * [`Error::InvalidEndpoint`] if the endpoint is not a valid URL.
    /// * [`Error::CredentialRefreshError`] if the client credentials could not be exchanged for a token.
    pub async fn new_with_client_credentials(
        endpoint: impl Into<url::Url>,
        client_id: &str,
        client_secret: &str,
        token_endpoint: impl Into<url::Url>,
        scopes: &[&str],
    ) -> Result<Self> {
        let endpoint = get_tonic_endpoint_logged(&endpoint.into())?;

        let builder = middle::BasicClientCredentialAuthorizer::basic_builder(
            client_id,
            client_secret,
            token_endpoint.into(),
        );
        // Scopes are only added when some were given; otherwise the builder is used unchanged.
        let builder = if scopes.is_empty() {
            builder
        } else {
            builder.add_scopes(scopes)
        };
        let authorizer = builder.build().await.map_err(|e| {
            tracing::error!("Could not construct OpenFGA client. Failed to fetch or refresh Client Credentials: {e}");
            Error::CredentialRefreshError(e)
        })?;

        let auth: EitherOrOption = Some(Either::Left(tonic::service::interceptor(authorizer)));
        Ok(Self::from_endpoint_and_auth(endpoint, auth))
    }

    /// Wrap a lazily-connected channel to `endpoint` in the optional authentication
    /// layer `auth` (`None` = unauthenticated) and build the service client from it.
    fn from_endpoint_and_auth(endpoint: Endpoint, auth: EitherOrOption) -> Self {
        let service = ServiceBuilder::new()
            .layer(tower::util::option_layer(auth))
            .service(endpoint.connect_lazy());
        Self::new(service)
    }
}
106
107impl<T> OpenFgaServiceClient<T>
108where
109    T: tonic::client::GrpcService<tonic::body::BoxBody>,
110    T::Error: Into<StdError>,
111    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
112    <T::ResponseBody as Body>::Error: Into<StdError> + Send,
113    T: Clone,
114{
115    /// Transform this service client into a higher-level [`OpenFgaClient`].
116    pub fn into_client(self, store_id: &str, authorization_model_id: &str) -> OpenFgaClient<T> {
117        OpenFgaClient::new(self, store_id, authorization_model_id)
118    }
119
120    /// Fetch a store by name.
121    /// If no store is found, returns `Ok(None)`.
122    ///
123    /// # Errors
124    /// * [`Error::AmbiguousStoreName`] if multiple stores with the same name are found.
125    /// * [`Error::RequestFailed`] if the request to OpenFGA fails.
126    pub async fn get_store_by_name(&mut self, store_name: &str) -> Result<Option<Store>> {
127        let stores = self
128            .list_stores(ListStoresRequest {
129                page_size: Some(2),
130                continuation_token: String::new(),
131                name: store_name.to_string(),
132            })
133            .await
134            .map_err(|e| {
135                tracing::error!("Failed to list stores in OpenFGA: {e}");
136                Error::RequestFailed(e)
137            })?
138            .into_inner();
139        let num_stores = stores.stores.len();
140
141        match stores.stores.first() {
142            Some(store) => {
143                if num_stores > 1 {
144                    tracing::error!("Multiple stores with the name `{}` found", store_name);
145                    Err(Error::AmbiguousStoreName(store_name.to_string()))
146                } else {
147                    Ok(Some(store.clone()))
148                }
149            }
150            None => Ok(None),
151        }
152    }
153
154    /// Get a store by name or create it if it doesn't exist.
155    /// Returns information about the store, including its ID.
156    ///
157    /// # Errors
158    /// * [`Error::RequestFailed`] If a request to OpenFGA fails.
159    /// * [`Error::AmbiguousStoreName`] If multiple stores with the same name are found.
160    pub async fn get_or_create_store(&mut self, store_name: &str) -> Result<Store> {
161        let store = self.get_store_by_name(store_name).await?;
162        match store {
163            None => {
164                tracing::debug!("OpenFGA Store {} not found. Creating it.", store_name);
165                let store = self
166                    .create_store(CreateStoreRequest {
167                        name: store_name.to_owned(),
168                    })
169                    .await
170                    .map_err(|e| {
171                        tracing::error!("Failed to create store in OpenFGA: {e}");
172                        Error::RequestFailed(e)
173                    })?
174                    .into_inner();
175                Ok(Store {
176                    id: store.id,
177                    name: store.name,
178                    created_at: store.created_at,
179                    updated_at: store.updated_at,
180                    deleted_at: None,
181                })
182            }
183            Some(store) => Ok(store),
184        }
185    }
186
187    /// Wrapper around [`Self::read`] that reads all pages of the result, handling pagination.
188    ///
189    /// # Errors
190    /// * [`Error::RequestFailed`] If a request to OpenFGA fails.
191    /// * [`Error::TooManyPages`] If the number of pages read exceeds `max_pages`.
192    pub async fn read_all_pages(
193        &mut self,
194        store_id: &str,
195        tuple: impl Into<ReadRequestTupleKey>,
196        consistency: impl Into<ConsistencyPreference>,
197        page_size: i32,
198        max_pages: u32,
199    ) -> Result<Vec<Tuple>> {
200        let mut continuation_token = String::new();
201        let mut tuples = Vec::new();
202        let mut count = 0;
203        let tuple = tuple.into();
204        let consistency = consistency.into();
205
206        loop {
207            let read_request = ReadRequest {
208                store_id: store_id.to_owned(),
209                tuple_key: Some(tuple.clone()),
210                page_size: Some(page_size),
211                continuation_token: continuation_token.clone(),
212                consistency: consistency.into(),
213            };
214            let response = self
215                .read(read_request.clone())
216                .await
217                .map_err(|e| {
218                    tracing::error!(
219                        "Failed to read from OpenFGA: {e}. Request: {:?}",
220                        read_request
221                    );
222                    Error::RequestFailed(e)
223                })?
224                .into_inner();
225            tuples.extend(response.tuples);
226            continuation_token.clone_from(&response.continuation_token);
227            if continuation_token.is_empty() || count > max_pages {
228                if count > max_pages {
229                    return Err(Error::TooManyPages { max_pages, tuple });
230                }
231                break;
232            }
233            count += 1;
234        }
235
236        Ok(tuples)
237    }
238}
239
/// Service stack used by [`BasicOpenFgaServiceClient`].
///
/// The gRPC [`Channel`] is one of the following:
/// * wrapped in a client-credential interceptor,
/// * wrapped in a Bearer-token interceptor, or
/// * used without any authentication.
#[cfg(feature = "auth-middle")]
pub type BasicAuthLayer = Either<
    Either<
        tonic::service::interceptor::InterceptedService<
            Channel,
            middle::BasicClientCredentialAuthorizer,
        >,
        tonic::service::interceptor::InterceptedService<Channel, middle::BearerTokenAuthorizer>,
    >,
    Channel,
>;
251
/// Optional authentication layer placed in front of the [`Channel`].
///
/// * `None`: no authentication.
/// * `Some(Left(_))`: client credentials.
/// * `Some(Right(_))`: a pre-shared Bearer token.
#[cfg(feature = "auth-middle")]
type EitherOrOption = Option<
    tower::util::Either<
        tonic::service::interceptor::InterceptorLayer<middle::BasicClientCredentialAuthorizer>,
        tonic::service::interceptor::InterceptorLayer<middle::BearerTokenAuthorizer>,
    >,
>;
259
/// Convert `endpoint` into a tonic [`Endpoint`], logging an error if tonic rejects it.
///
/// # Errors
/// * [`Error::InvalidEndpoint`] carrying the endpoint's string form if it cannot be
///   turned into a tonic [`Endpoint`].
#[cfg(feature = "auth-middle")]
fn get_tonic_endpoint_logged(endpoint: &url::Url) -> Result<Endpoint> {
    let endpoint_string = endpoint.to_string();
    Endpoint::new(endpoint_string.clone()).map_err(move |e| {
        tracing::error!("Could not construct OpenFGA client. Invalid endpoint `{endpoint}`: {e}");
        Error::InvalidEndpoint(endpoint_string)
    })
}
267
#[cfg(test)]
pub(crate) mod test {
    use needs_env_var::needs_env_var;

    // The tests below talk to a live OpenFGA server whose gRPC URL is taken from the
    // `TEST_OPENFGA_CLIENT_GRPC_URL` environment variable.
    // NOTE(review): the `#[needs_env_var(TEST_OPENFGA_CLIENT_GRPC_URL)]` gate on this module
    // is currently disabled, so these tests panic (with an explicit message) when the
    // variable is unset.
    #[cfg(feature = "auth-middle")]
    mod openfga {
        use std::collections::{HashMap, HashSet};

        use super::super::*;
        use crate::{
            client::{
                TupleKey, WriteAuthorizationModelRequest, WriteAuthorizationModelResponse,
                WriteRequest, WriteRequestWrites,
            },
            generated::AuthorizationModel,
        };

        /// Build an unauthenticated client for the test server.
        fn get_basic_client() -> BasicOpenFgaServiceClient {
            let endpoint = std::env::var("TEST_OPENFGA_CLIENT_GRPC_URL")
                .expect("TEST_OPENFGA_CLIENT_GRPC_URL must be set to run the OpenFGA tests");
            let endpoint = url::Url::parse(&endpoint)
                .expect("TEST_OPENFGA_CLIENT_GRPC_URL must be a valid URL");
            BasicOpenFgaServiceClient::new_unauthenticated(endpoint)
                .expect("Client can be created")
        }

        /// Create a fresh, uniquely named store and return it together with the client.
        async fn new_store() -> (BasicOpenFgaServiceClient, Store) {
            let mut client = get_basic_client();
            let store_name = format!("store-{}", uuid::Uuid::now_v7());
            let store = client
                .get_or_create_store(&store_name)
                .await
                .expect("Store can be created");
            (client, store)
        }

        /// Write the entitlements sample schema as an authorization model of `store`.
        async fn create_entitlements_model(
            client: &mut BasicOpenFgaServiceClient,
            store: &Store,
        ) -> WriteAuthorizationModelResponse {
            let schema = include_str!("../tests/sample-store/entitlements/schema.json");
            let model: AuthorizationModel =
                serde_json::from_str(schema).expect("Schema can be deserialized");
            let auth_model = client
                .write_authorization_model(WriteAuthorizationModelRequest {
                    store_id: store.id.clone(),
                    type_definitions: model.type_definitions,
                    schema_version: model.schema_version,
                    conditions: model.conditions,
                })
                .await
                .expect("Auth model can be written");

            auth_model.into_inner()
        }

        /// Read key matching every `member` relation on `object`.
        fn members_of(object: &str) -> ReadRequestTupleKey {
            ReadRequestTupleKey {
                user: String::new(),
                relation: "member".to_string(),
                object: object.to_string(),
            }
        }

        #[tokio::test]
        async fn test_get_store_by_name_many() {
            let mut client = get_basic_client();

            let mut stores = HashMap::new();
            for _ in 0..201 {
                let store_name = format!("store-{}", uuid::Uuid::now_v7());
                let r = client
                    .get_or_create_store(&store_name)
                    .await
                    .expect("Store can be created");
                assert_eq!(store_name, r.name);
                stores.insert(store_name, r.id);
            }

            for (store_name, store_id) in stores {
                let store = client
                    .get_store_by_name(&store_name)
                    .await
                    .expect("Store can be fetched")
                    .expect("Store exists");
                assert_eq!(store_id, store.id);
            }
        }

        #[tokio::test]
        async fn test_get_store_by_name_non_existent() {
            let mut client = get_basic_client();
            let store = client
                .get_store_by_name("non-existent-store")
                .await
                .unwrap();
            assert!(store.is_none());
        }

        #[tokio::test]
        async fn test_get_or_create_store() {
            let mut client = get_basic_client();
            let store_name = format!("store-{}", uuid::Uuid::now_v7());
            let store = client.get_or_create_store(&store_name).await.unwrap();
            let store2 = client.get_or_create_store(&store_name).await.unwrap();
            assert_eq!(store.id, store2.id);
        }

        #[tokio::test]
        async fn test_read_all_pages_many() {
            let (mut client, store) = new_store().await;
            let auth_model = create_entitlements_model(&mut client, &store).await;
            let object = "organization:org-1";

            let users = (0..501)
                .map(|i| format!("user:u-{i}"))
                .collect::<Vec<String>>();

            // Write in batches instead of one request per tuple.
            // NOTE(review): assumes the server accepts 100 tuples per write request
            // (the OpenFGA default) — confirm for the test deployment.
            for batch in users.chunks(100) {
                client
                    .write(WriteRequest {
                        authorization_model_id: auth_model.authorization_model_id.clone(),
                        store_id: store.id.clone(),
                        writes: Some(WriteRequestWrites {
                            tuple_keys: batch
                                .iter()
                                .map(|user| TupleKey {
                                    user: user.clone(),
                                    relation: "member".to_string(),
                                    object: object.to_string(),
                                    condition: None,
                                })
                                .collect(),
                        }),
                        deletes: None,
                    })
                    .await
                    .expect("Write can be done");
            }

            // 501 tuples at 100 per page fit into 6 pages.
            let tuples = client
                .read_all_pages(
                    &store.id,
                    members_of(object),
                    ConsistencyPreference::HigherConsistency,
                    100,
                    6,
                )
                .await
                .expect("Read can be done");

            assert_eq!(tuples.len(), 501);
            assert_eq!(
                tuples
                    .iter()
                    .map(|t| t.key.as_ref().expect("Tuple has a key").user.clone())
                    .collect::<HashSet<String>>(),
                users.into_iter().collect::<HashSet<String>>()
            );

            // ...but not into 5 pages: the page limit must be enforced exactly.
            let limited = client
                .read_all_pages(
                    &store.id,
                    members_of(object),
                    ConsistencyPreference::HigherConsistency,
                    100,
                    5,
                )
                .await;
            assert!(matches!(
                limited,
                Err(Error::TooManyPages { max_pages: 5, .. })
            ));
        }

        #[tokio::test]
        async fn test_read_all_pages_empty() {
            let (mut client, store) = new_store().await;
            let tuples = client
                .read_all_pages(
                    &store.id,
                    members_of("organization:org-1"),
                    ConsistencyPreference::HigherConsistency,
                    100,
                    5,
                )
                .await
                .expect("Read can be done");

            assert!(tuples.is_empty());
        }
    }
}