openfga_client/
client_ext.rs

1#![allow(unused_imports)]
2
3use tonic::{
4    codegen::{Body, Bytes, StdError},
5    service::interceptor::InterceptorLayer,
6    transport::{Channel, Endpoint},
7};
8#[cfg(feature = "auth-middle")]
9use tower::{util::Either, ServiceBuilder};
10
11use crate::{
12    client::{OpenFgaClient, OpenFgaServiceClient},
13    error::{Error, Result},
14    generated::{
15        ConsistencyPreference, CreateStoreRequest, ListStoresRequest, ReadRequest,
16        ReadRequestTupleKey, Store, Tuple,
17    },
18};
19
#[cfg(feature = "auth-middle")]
/// Specialization of the [`OpenFgaServiceClient`] that includes optional
/// authentication with pre-shared keys (Bearer tokens) or client credentials.
/// For more fine-granular control, you can construct [`OpenFgaServiceClient`] directly
/// using interceptors for Authentication.
pub type BasicOpenFgaServiceClient = OpenFgaServiceClient<BasicAuthLayer>;

#[cfg(feature = "auth-middle")]
impl BasicOpenFgaServiceClient {
    /// Create a new client without authentication.
    ///
    /// The underlying channel connects lazily, on first use.
    ///
    /// # Errors
    /// * [`Error::InvalidEndpoint`] if the endpoint is not a valid URL.
    pub fn new_unauthenticated(endpoint: impl Into<url::Url>) -> Result<Self> {
        Self::from_auth_option(endpoint.into(), None)
    }

    /// Create a new client that sends the given pre-shared key as a Bearer
    /// token with every request.
    ///
    /// The underlying channel connects lazily, on first use.
    ///
    /// # Errors
    /// * [`Error::InvalidToken`] if the token is not valid ASCII.
    /// * [`Error::InvalidEndpoint`] if the endpoint is not a valid URL.
    pub fn new_with_basic_auth(endpoint: impl Into<url::Url>, token: &str) -> Result<Self> {
        // The token is validated before the endpoint, as before.
        let authorizer = middle::BearerTokenAuthorizer::new(token).map_err(|e| {
            tracing::error!("Could not construct OpenFGA client. Invalid token: {e}");
            Error::InvalidToken {
                reason: e.to_string(),
            }
        })?;
        Self::from_auth_option(
            endpoint.into(),
            Some(tower::util::Either::Right(tonic::service::interceptor(
                authorizer,
            ))),
        )
    }

    /// Create a new client using client credentials.
    ///
    /// The credentials are exchanged for a token (via `token_endpoint`) before
    /// the client is returned; the channel to `endpoint` itself connects lazily.
    ///
    /// # Errors
    /// * [`Error::CredentialRefreshError`] if the client credentials could not be exchanged for a token.
    /// * [`Error::InvalidEndpoint`] if the endpoint is not a valid URL.
    pub async fn new_with_client_credentials(
        endpoint: impl Into<url::Url>,
        client_id: &str,
        client_secret: &str,
        token_endpoint: impl Into<url::Url>,
        scopes: &[&str],
    ) -> Result<Self> {
        let builder = middle::BasicClientCredentialAuthorizer::basic_builder(
            client_id,
            client_secret,
            token_endpoint.into(),
        );
        // Only request explicit scopes when the caller supplied any.
        let builder = if scopes.is_empty() {
            builder
        } else {
            builder.add_scopes(scopes)
        };
        let authorizer = builder.build().await.map_err(|e| {
            tracing::error!("Could not construct OpenFGA client. Failed to fetch or refresh Client Credentials: {e}");
            Error::CredentialRefreshError(e)
        })?;
        Self::from_auth_option(
            endpoint.into(),
            Some(tower::util::Either::Left(tonic::service::interceptor(
                authorizer,
            ))),
        )
    }

    /// Shared tail of all constructors: put the optional auth interceptor in
    /// front of a lazily connected channel to `endpoint`.
    ///
    /// # Errors
    /// * [`Error::InvalidEndpoint`] if `endpoint` cannot be turned into a tonic endpoint.
    fn from_auth_option(endpoint: url::Url, auth: EitherOrOption) -> Result<Self> {
        let endpoint = get_tonic_endpoint_logged(&endpoint)?;
        let service = ServiceBuilder::new()
            // `None` becomes an identity layer, leaving the bare `Channel`.
            .layer(tower::util::option_layer(auth))
            .service(endpoint.connect_lazy());
        Ok(Self::new(service))
    }
}
106
107impl<T> OpenFgaServiceClient<T>
108where
109    T: tonic::client::GrpcService<tonic::body::BoxBody>,
110    T::Error: Into<StdError>,
111    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
112    <T::ResponseBody as Body>::Error: Into<StdError> + Send,
113    T: Clone,
114{
115    /// Transform this service client into a higher-level [`OpenFgaClient`].
116    pub fn into_client(self, store_id: &str, authorization_model_id: &str) -> OpenFgaClient<T> {
117        OpenFgaClient::new(self, store_id, authorization_model_id)
118    }
119
120    /// Fetch a store by name.
121    /// If no store is found, returns `Ok(None)`.
122    ///
123    /// # Errors
124    /// * [`Error::AmbiguousStoreName`] if multiple stores with the same name are found.
125    /// * [`Error::RequestFailed`] if the request to OpenFGA fails.
126    pub async fn get_store_by_name(&mut self, store_name: &str) -> Result<Option<Store>> {
127        let stores = self
128            .list_stores(ListStoresRequest {
129                page_size: Some(2),
130                continuation_token: String::new(),
131                name: store_name.to_string(),
132            })
133            .await
134            .map_err(|e| {
135                tracing::error!("Failed to list stores in OpenFGA: {e}");
136                Error::RequestFailed(e)
137            })?
138            .into_inner();
139        let num_stores = stores.stores.len();
140
141        match stores.stores.first() {
142            Some(store) => {
143                if num_stores > 1 {
144                    tracing::error!("Multiple stores with the name `{}` found", store_name);
145                    Err(Error::AmbiguousStoreName(store_name.to_string()))
146                } else {
147                    Ok(Some(store.clone()))
148                }
149            }
150            None => Ok(None),
151        }
152    }
153
154    /// Get a store by name or create it if it doesn't exist.
155    /// Returns information about the store, including its ID.
156    ///
157    /// # Errors
158    /// * [`Error::RequestFailed`] If a request to OpenFGA fails.
159    /// * [`Error::AmbiguousStoreName`] If multiple stores with the same name are found.
160    pub async fn get_or_create_store(&mut self, store_name: &str) -> Result<Store> {
161        let store = self.get_store_by_name(store_name).await?;
162        match store {
163            None => {
164                tracing::debug!("OpenFGA Store {} not found. Creating it.", store_name);
165                let store = self
166                    .create_store(CreateStoreRequest {
167                        name: store_name.to_owned(),
168                    })
169                    .await
170                    .map_err(|e| {
171                        tracing::error!("Failed to create store in OpenFGA: {e}");
172                        Error::RequestFailed(e)
173                    })?
174                    .into_inner();
175                Ok(Store {
176                    id: store.id,
177                    name: store.name,
178                    created_at: store.created_at,
179                    updated_at: store.updated_at,
180                    deleted_at: None,
181                })
182            }
183            Some(store) => Ok(store),
184        }
185    }
186
187    /// Wrapper around [`Self::read`] that reads all pages of the result, handling pagination.
188    ///
189    /// # Errors
190    /// * [`Error::RequestFailed`] If a request to OpenFGA fails.
191    /// * [`Error::TooManyPages`] If the number of pages read exceeds `max_pages`.
192    pub async fn read_all_pages(
193        &mut self,
194        store_id: &str,
195        tuple: impl Into<ReadRequestTupleKey>,
196        consistency: impl Into<ConsistencyPreference>,
197        page_size: i32,
198        max_pages: u32,
199    ) -> Result<Vec<Tuple>> {
200        let mut continuation_token = String::new();
201        let mut tuples = Vec::new();
202        let mut count = 0;
203        let tuple = tuple.into();
204        let consistency = consistency.into();
205
206        loop {
207            let read_request = ReadRequest {
208                store_id: store_id.to_owned(),
209                tuple_key: Some(tuple.clone()),
210                page_size: Some(page_size),
211                continuation_token: continuation_token.clone(),
212                consistency: consistency.into(),
213            };
214            let response = self
215                .read(read_request.clone())
216                .await
217                .map_err(|e| {
218                    tracing::error!(
219                        "Failed to read from OpenFGA: {e}. Request: {:?}",
220                        read_request
221                    );
222                    Error::RequestFailed(e)
223                })?
224                .into_inner();
225            tuples.extend(response.tuples);
226            continuation_token.clone_from(&response.continuation_token);
227            if continuation_token.is_empty() || count > max_pages {
228                if count > max_pages {
229                    return Err(Error::TooManyPages { max_pages, tuple });
230                }
231                break;
232            }
233            count += 1;
234        }
235
236        Ok(tuples)
237    }
238}
239
#[cfg(feature = "auth-middle")]
/// Service stack wrapped by [`BasicOpenFgaServiceClient`].
///
/// Despite the name, this is the composed *service* type rather than a tower
/// `Layer`:
/// * outer `Left` / inner `Left`: channel behind the client-credentials interceptor,
/// * outer `Left` / inner `Right`: channel behind the Bearer-token interceptor,
/// * outer `Right`: the bare [`Channel`], used when no authentication is configured.
pub type BasicAuthLayer = Either<
    Either<
        tonic::service::interceptor::InterceptedService<
            Channel,
            middle::BasicClientCredentialAuthorizer,
        >,
        tonic::service::interceptor::InterceptedService<Channel, middle::BearerTokenAuthorizer>,
    >,
    Channel,
>;
251
#[cfg(feature = "auth-middle")]
/// Optional authentication layer handed to [`tower::util::option_layer`]:
/// `None` means no authentication, `Left` the client-credentials interceptor,
/// and `Right` the Bearer-token interceptor.
type EitherOrOption = Option<
    tower::util::Either<
        tonic::service::interceptor::InterceptorLayer<middle::BasicClientCredentialAuthorizer>,
        tonic::service::interceptor::InterceptorLayer<middle::BearerTokenAuthorizer>,
    >,
>;
259
#[cfg(feature = "auth-middle")]
/// Convert `endpoint` into a tonic [`Endpoint`], logging the failure reason.
///
/// # Errors
/// * [`Error::InvalidEndpoint`] carrying the URL text if tonic rejects it.
fn get_tonic_endpoint_logged(endpoint: &url::Url) -> Result<Endpoint> {
    let raw = endpoint.as_str();
    match Endpoint::new(raw.to_owned()) {
        Ok(tonic_endpoint) => Ok(tonic_endpoint),
        Err(e) => {
            tracing::error!("Could not construct OpenFGA client. Invalid endpoint `{endpoint}`: {e}");
            Err(Error::InvalidEndpoint(raw.to_owned()))
        }
    }
}
267
#[cfg(test)]
pub(crate) mod test {
    use needs_env_var::needs_env_var;

    // NOTE(review): the `needs_env_var` gate below is commented out, so with the
    // `auth-middle` feature enabled these tests always run and need a reachable
    // OpenFGA gRPC server at `TEST_OPENFGA_CLIENT_GRPC_URL`.
    // #[needs_env_var(TEST_OPENFGA_CLIENT_GRPC_URL)]
    #[cfg(feature = "auth-middle")]
    mod openfga {
        use std::collections::HashSet;

        use super::super::*;
        use crate::{
            client::{
                TupleKey, WriteAuthorizationModelRequest, WriteAuthorizationModelResponse,
                WriteRequest, WriteRequestWrites,
            },
            generated::AuthorizationModel,
        };

        /// Unauthenticated client for the server named by `TEST_OPENFGA_CLIENT_GRPC_URL`.
        fn get_basic_client() -> BasicOpenFgaServiceClient {
            let endpoint = std::env::var("TEST_OPENFGA_CLIENT_GRPC_URL")
                .expect("TEST_OPENFGA_CLIENT_GRPC_URL must be set to the OpenFGA gRPC endpoint");
            let endpoint = url::Url::parse(&endpoint)
                .expect("TEST_OPENFGA_CLIENT_GRPC_URL must be a valid URL");
            BasicOpenFgaServiceClient::new_unauthenticated(endpoint)
                .expect("Client can be created")
        }

        /// Fresh client plus a newly created store with a unique name.
        async fn new_store() -> (BasicOpenFgaServiceClient, Store) {
            let mut client = get_basic_client();
            let store_name = format!("store-{}", uuid::Uuid::now_v7());
            let store = client
                .get_or_create_store(&store_name)
                .await
                .expect("Store can be created");
            (client, store)
        }

        /// Write the sample entitlements schema as an authorization model into `store`.
        async fn create_entitlements_model(
            client: &mut BasicOpenFgaServiceClient,
            store: &Store,
        ) -> WriteAuthorizationModelResponse {
            let schema = include_str!("../tests/sample-store/entitlements/schema.json");
            let model: AuthorizationModel =
                serde_json::from_str(schema).expect("Schema can be deserialized");
            let auth_model = client
                .write_authorization_model(WriteAuthorizationModelRequest {
                    store_id: store.id.clone(),
                    type_definitions: model.type_definitions,
                    schema_version: model.schema_version,
                    conditions: model.conditions,
                })
                .await
                .expect("Auth model can be written");

            auth_model.into_inner()
        }

        /// Write one `member` tuple on `object` for every entry in `users`,
        /// one write request per tuple.
        async fn write_members(
            client: &mut BasicOpenFgaServiceClient,
            store: &Store,
            authorization_model_id: &str,
            object: &str,
            users: &[String],
        ) {
            for user in users {
                client
                    .write(WriteRequest {
                        authorization_model_id: authorization_model_id.to_string(),
                        store_id: store.id.clone(),
                        writes: Some(WriteRequestWrites {
                            tuple_keys: vec![TupleKey {
                                user: user.to_string(),
                                relation: "member".to_string(),
                                object: object.to_string(),
                                condition: None,
                            }],
                        }),
                        deletes: None,
                    })
                    .await
                    .expect("Write can be done");
            }
        }

        /// Read filter matching every `member` tuple on `object`.
        fn members_of(object: &str) -> ReadRequestTupleKey {
            ReadRequestTupleKey {
                user: String::new(),
                relation: "member".to_string(),
                object: object.to_string(),
            }
        }

        #[tokio::test]
        async fn test_get_store_by_name_non_existent() {
            let mut client = get_basic_client();
            let store = client
                .get_store_by_name("non-existent-store")
                .await
                .unwrap();
            assert!(store.is_none());
        }

        #[tokio::test]
        async fn test_get_or_create_store() {
            let mut client = get_basic_client();
            let store_name = format!("store-{}", uuid::Uuid::now_v7());
            let store = client.get_or_create_store(&store_name).await.unwrap();
            let store2 = client.get_or_create_store(&store_name).await.unwrap();
            assert_eq!(store.id, store2.id);
        }

        #[tokio::test]
        async fn test_read_all_pages_many() {
            let (mut client, store) = new_store().await;
            let auth_model = create_entitlements_model(&mut client, &store).await;
            let object = "organization:org-1";

            let users = (0..501)
                .map(|i| format!("user:u-{i}"))
                .collect::<Vec<String>>();
            write_members(
                &mut client,
                &store,
                &auth_model.authorization_model_id,
                object,
                &users,
            )
            .await;

            // 501 tuples at 100 per page need exactly 6 pages.
            let tuples = client
                .read_all_pages(
                    &store.id,
                    members_of(object),
                    ConsistencyPreference::HigherConsistency,
                    100,
                    6,
                )
                .await
                .expect("Read can be done");

            assert_eq!(tuples.len(), 501);
            assert_eq!(
                tuples
                    .iter()
                    .map(|t| t.key.clone().unwrap().user)
                    .collect::<HashSet<String>>(),
                HashSet::from_iter(users)
            );
        }

        #[tokio::test]
        async fn test_read_all_pages_too_many_pages() {
            let (mut client, store) = new_store().await;
            let auth_model = create_entitlements_model(&mut client, &store).await;
            let object = "organization:org-1";

            let users = (0..5)
                .map(|i| format!("user:u-{i}"))
                .collect::<Vec<String>>();
            write_members(
                &mut client,
                &store,
                &auth_model.authorization_model_id,
                object,
                &users,
            )
            .await;

            // Five tuples at one tuple per page cannot fit into a one-page budget.
            let result = client
                .read_all_pages(
                    &store.id,
                    members_of(object),
                    ConsistencyPreference::HigherConsistency,
                    1,
                    1,
                )
                .await;

            assert!(
                matches!(result, Err(Error::TooManyPages { .. })),
                "expected Error::TooManyPages"
            );
        }

        #[tokio::test]
        async fn test_read_all_pages_empty() {
            let (mut client, store) = new_store().await;
            let tuples = client
                .read_all_pages(
                    &store.id,
                    members_of("organization:org-1"),
                    ConsistencyPreference::HigherConsistency,
                    100,
                    5,
                )
                .await
                .expect("Read can be done");

            assert!(tuples.is_empty());
        }
    }
}