openfga_client/
client_ext.rs

1#![allow(unused_imports)]
2
3use tonic::{
4    codegen::{Body, Bytes, StdError},
5    service::interceptor::InterceptorLayer,
6    transport::{Channel, Endpoint},
7};
8#[cfg(feature = "auth-middle")]
9use tower::{util::Either, ServiceBuilder};
10
11use crate::{
12    client::OpenFgaServiceClient,
13    error::{Error, Result},
14    generated::{
15        ConsistencyPreference, CreateStoreRequest, ListStoresRequest, ReadRequest,
16        ReadRequestTupleKey, Store, Tuple,
17    },
18};
19
#[cfg(feature = "auth-middle")]
/// An [`OpenFgaServiceClient`] whose transport is wrapped in this module's
/// authentication stack.
///
/// Depending on the constructor used, requests go out unauthenticated, carry a
/// static pre-shared Bearer token, or carry credentials obtained from a token
/// endpoint with a client id / client secret pair. If you need finer-grained
/// control, build an [`OpenFgaServiceClient`] yourself and attach your own
/// authentication interceptors.
pub type BasicOpenFgaServiceClient = crate::client::OpenFgaServiceClient<AuthLayer>;
26
#[cfg(feature = "auth-middle")]
impl BasicOpenFgaServiceClient {
    /// Create a new client that sends requests without any authentication.
    ///
    /// The underlying channel connects lazily, so no network I/O happens here.
    ///
    /// # Errors
    /// * [`Error::InvalidEndpoint`] if the endpoint is not a valid URL.
    pub fn new_unauthenticated(endpoint: &str) -> Result<Self> {
        let endpoint = get_tonic_endpoint_logged(endpoint)?;
        Ok(Self::from_endpoint_with_auth(endpoint, None))
    }

    /// Create a new client that authenticates every request with a static,
    /// pre-shared Bearer token.
    ///
    /// The endpoint is validated before the token.
    ///
    /// # Errors
    /// * [`Error::InvalidEndpoint`] if the endpoint is not a valid URL.
    /// * [`Error::InvalidToken`] if the token is not valid ASCII.
    pub fn new_with_basic_auth(endpoint: &str, token: &str) -> Result<Self> {
        let endpoint = get_tonic_endpoint_logged(endpoint)?;
        let authorizer = middle::BearerTokenAuthorizer::new(token).map_err(|e| {
            tracing::error!("Could not construct OpenFGA client. Invalid token: {e}");
            Error::InvalidToken {
                reason: e.to_string(),
            }
        })?;
        let auth: EitherOrOption = Some(Either::Right(tonic::service::interceptor(authorizer)));
        Ok(Self::from_endpoint_with_auth(endpoint, auth))
    }

    /// Create a new client that authenticates using client credentials
    /// (`client_id` / `client_secret`) exchanged at `token_endpoint`.
    ///
    /// The endpoint is validated before any credentials are fetched.
    ///
    /// # Errors
    /// * [`Error::InvalidEndpoint`] if the endpoint is not a valid URL.
    /// * [`Error::CredentialRefreshError`] if the client credentials could not be exchanged for a token.
    pub async fn new_with_client_credentials(
        endpoint: &str,
        client_id: &str,
        client_secret: &str,
        token_endpoint: &url::Url,
    ) -> Result<Self> {
        // Validate the endpoint first. It is a cheap local check, whereas building the
        // authorizer below is async and may fail while fetching credentials from
        // `token_endpoint`. There is no point doing that work for an unusable endpoint.
        let endpoint = get_tonic_endpoint_logged(endpoint)?;
        let authorizer = middle::BasicClientCredentialAuthorizer::basic_builder(
            client_id,
            client_secret,
            token_endpoint.clone(),
        )
        .build()
        .await
        .map_err(|e| {
            tracing::error!("Could not construct OpenFGA client. Failed to fetch or refresh Client Credentials: {e}");
            Error::CredentialRefreshError(e)
        })?;
        let auth: EitherOrOption = Some(Either::Left(tonic::service::interceptor(authorizer)));
        Ok(Self::from_endpoint_with_auth(endpoint, auth))
    }

    /// Build a client on a lazily-connecting channel to `endpoint`, wrapped in the
    /// optional authentication layer `auth`. `None` means no authentication.
    fn from_endpoint_with_auth(endpoint: Endpoint, auth: EitherOrOption) -> Self {
        let service = ServiceBuilder::new()
            .layer(tower::util::option_layer(auth))
            .service(endpoint.connect_lazy());
        Self::new(service)
    }
}
98
99impl<T> OpenFgaServiceClient<T>
100where
101    T: tonic::client::GrpcService<tonic::body::BoxBody>,
102    T::Error: Into<StdError>,
103    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
104    <T::ResponseBody as Body>::Error: Into<StdError> + Send,
105{
106    /// Fetch a store by name.
107    /// If no store is found, returns `Ok(None)`.
108    ///
109    /// # Errors
110    /// * [`Error::AmbiguousStoreName`] if multiple stores with the same name are found.
111    /// * [`Error::RequestFailed`] if the request to OpenFGA fails.
112    pub async fn get_store_by_name(&mut self, store_name: &str) -> Result<Option<Store>> {
113        let stores = self
114            .list_stores(ListStoresRequest {
115                page_size: Some(2),
116                continuation_token: String::new(),
117                name: store_name.to_string(),
118            })
119            .await
120            .map_err(|e| {
121                tracing::error!("Failed to list stores in OpenFGA: {e}");
122                Error::RequestFailed(e)
123            })?
124            .into_inner();
125        let num_stores = stores.stores.len();
126
127        match stores.stores.first() {
128            Some(store) => {
129                if num_stores > 1 {
130                    tracing::error!("Multiple stores with the name `{}` found", store_name);
131                    Err(Error::AmbiguousStoreName(store_name.to_string()))
132                } else {
133                    Ok(Some(store.clone()))
134                }
135            }
136            None => Ok(None),
137        }
138    }
139
140    /// Get a store by name or create it if it doesn't exist.
141    /// Returns information about the store, including its ID.
142    ///
143    /// # Errors
144    /// * [`Error::RequestFailed`] If a request to OpenFGA fails.
145    /// * [`Error::AmbiguousStoreName`] If multiple stores with the same name are found.
146    pub async fn get_or_create_store(&mut self, store_name: &str) -> Result<Store> {
147        let store = self.get_store_by_name(store_name).await?;
148        match store {
149            None => {
150                tracing::debug!("OpenFGA Store {} not found. Creating it.", store_name);
151                let store = self
152                    .create_store(CreateStoreRequest {
153                        name: store_name.to_owned(),
154                    })
155                    .await
156                    .map_err(|e| {
157                        tracing::error!("Failed to create store in OpenFGA: {e}");
158                        Error::RequestFailed(e)
159                    })?
160                    .into_inner();
161                Ok(Store {
162                    id: store.id,
163                    name: store.name,
164                    created_at: store.created_at,
165                    updated_at: store.updated_at,
166                    deleted_at: None,
167                })
168            }
169            Some(store) => Ok(store),
170        }
171    }
172
173    /// Wrapper around [`Self::read`] that reads all pages of the result, handling pagination.
174    ///
175    /// # Errors
176    /// * [`Error::RequestFailed`] If a request to OpenFGA fails.
177    /// * [`Error::TooManyPages`] If the number of pages read exceeds `max_pages`.
178    pub async fn read_all_pages(
179        &mut self,
180        store_id: &str,
181        tuple: ReadRequestTupleKey,
182        consistency: ConsistencyPreference,
183        page_size: i32,
184        max_pages: u32,
185    ) -> Result<Vec<Tuple>> {
186        let mut continuation_token = String::new();
187        let mut tuples = Vec::new();
188        let mut count = 0;
189
190        loop {
191            let read_request = ReadRequest {
192                store_id: store_id.to_owned(),
193                tuple_key: Some(tuple.clone()),
194                page_size: Some(page_size),
195                continuation_token: continuation_token.clone(),
196                consistency: consistency.into(),
197            };
198            let response = self
199                .read(read_request.clone())
200                .await
201                .map_err(|e| {
202                    tracing::error!(
203                        "Failed to read from OpenFGA: {e}. Request: {:?}",
204                        read_request
205                    );
206                    Error::RequestFailed(e)
207                })?
208                .into_inner();
209            tuples.extend(response.tuples);
210            continuation_token.clone_from(&response.continuation_token);
211            if continuation_token.is_empty() || count > max_pages {
212                if count > max_pages {
213                    return Err(Error::TooManyPages { max_pages, tuple });
214                }
215                break;
216            }
217            count += 1;
218        }
219
220        Ok(tuples)
221    }
222}
223
#[cfg(feature = "auth-middle")]
/// Concrete service type produced when an `EitherOrOption` is turned into a layer
/// with `tower::util::option_layer` and applied to a [`Channel`].
///
/// It is one of:
/// * a channel intercepted by the client-credentials authorizer;
/// * a channel intercepted by the Bearer-token authorizer;
/// * the bare channel, when no authentication is configured.
type AuthLayer = Either<
    Either<
        tonic::service::interceptor::InterceptedService<
            Channel,
            middle::BasicClientCredentialAuthorizer,
        >,
        tonic::service::interceptor::InterceptedService<Channel, middle::BearerTokenAuthorizer>,
    >,
    Channel,
>;
235
#[cfg(feature = "auth-middle")]
/// Optional authentication interceptor layer:
/// * `None` means no authentication;
/// * `Some(Left(_))` uses client credentials;
/// * `Some(Right(_))` uses a static Bearer token.
///
/// It is converted into a tower layer with `tower::util::option_layer`.
type EitherOrOption = Option<
    tower::util::Either<
        tonic::service::interceptor::InterceptorLayer<middle::BasicClientCredentialAuthorizer>,
        tonic::service::interceptor::InterceptorLayer<middle::BearerTokenAuthorizer>,
    >,
>;
243
#[cfg(feature = "auth-middle")]
/// Parse `endpoint` into a tonic [`Endpoint`], logging an error if it is rejected.
///
/// # Errors
/// * [`Error::InvalidEndpoint`] carrying the original input if tonic cannot parse it.
fn get_tonic_endpoint_logged(endpoint: &str) -> Result<Endpoint> {
    match Endpoint::new(endpoint.to_string()) {
        Ok(parsed) => Ok(parsed),
        Err(e) => {
            tracing::error!(
                "Could not construct OpenFGA client. Invalid endpoint `{endpoint}`: {e}"
            );
            Err(Error::InvalidEndpoint(endpoint.to_string()))
        }
    }
}
251
#[cfg(test)]
pub(crate) mod test {
    use needs_env_var::needs_env_var;

    // These tests talk to a live OpenFGA server whose gRPC URL is read from
    // `TEST_OPENFGA_CLIENT_GRPC_URL`. The attribute below is currently disabled, so
    // the tests panic if the variable is unset. To skip them at compile time
    // instead, re-enable: #[needs_env_var(TEST_OPENFGA_CLIENT_GRPC_URL)]
    #[cfg(feature = "auth-middle")]
    mod openfga {
        use std::collections::HashSet;

        use super::super::*;
        use crate::{
            client::{
                TupleKey, WriteAuthorizationModelRequest, WriteAuthorizationModelResponse,
                WriteRequest, WriteRequestWrites,
            },
            generated::AuthorizationModel,
        };

        /// Maximum number of tuples sent in a single `Write` request.
        /// NOTE(review): assumes the test server accepts at least 100 tuples per
        /// write (OpenFGA's default limit). Lower this if the server is configured
        /// with a smaller limit.
        const WRITE_BATCH_SIZE: usize = 100;

        /// Build an unauthenticated client for the server named by
        /// `TEST_OPENFGA_CLIENT_GRPC_URL`.
        fn get_basic_client() -> BasicOpenFgaServiceClient {
            let endpoint = std::env::var("TEST_OPENFGA_CLIENT_GRPC_URL")
                .expect("TEST_OPENFGA_CLIENT_GRPC_URL must be set to run the OpenFGA tests");
            BasicOpenFgaServiceClient::new_unauthenticated(&endpoint)
                .expect("Client can be created")
        }

        /// Create a client plus a fresh, uniquely named store.
        async fn new_store() -> (BasicOpenFgaServiceClient, Store) {
            let mut client = get_basic_client();
            let store_name = format!("store-{}", uuid::Uuid::now_v7());
            let store = client
                .get_or_create_store(&store_name)
                .await
                .expect("Store can be created");
            (client, store)
        }

        /// Write the sample "entitlements" authorization model into `store`.
        async fn create_entitlements_model(
            client: &mut BasicOpenFgaServiceClient,
            store: &Store,
        ) -> WriteAuthorizationModelResponse {
            let schema = include_str!("../tests/sample-store/entitlements/schema.json");
            let model: AuthorizationModel =
                serde_json::from_str(schema).expect("Schema can be deserialized");
            let auth_model = client
                .write_authorization_model(WriteAuthorizationModelRequest {
                    store_id: store.id.clone(),
                    type_definitions: model.type_definitions,
                    schema_version: model.schema_version,
                    conditions: model.conditions,
                })
                .await
                .expect("Auth model can be written");

            auth_model.into_inner()
        }

        /// Write each of `users` as a `member` of `object`, batching the writes
        /// to cut down on round-trips.
        async fn write_members(
            client: &mut BasicOpenFgaServiceClient,
            store: &Store,
            auth_model: &WriteAuthorizationModelResponse,
            object: &str,
            users: &[String],
        ) {
            for batch in users.chunks(WRITE_BATCH_SIZE) {
                let tuple_keys = batch
                    .iter()
                    .map(|user| TupleKey {
                        user: user.clone(),
                        relation: "member".to_string(),
                        object: object.to_string(),
                        condition: None,
                    })
                    .collect();
                client
                    .write(WriteRequest {
                        authorization_model_id: auth_model.authorization_model_id.clone(),
                        store_id: store.id.clone(),
                        writes: Some(WriteRequestWrites { tuple_keys }),
                        deletes: None,
                    })
                    .await
                    .expect("Write can be done");
            }
        }

        /// Read filter matching every `member` tuple of `object`.
        fn members_of(object: &str) -> ReadRequestTupleKey {
            ReadRequestTupleKey {
                user: String::new(),
                relation: "member".to_string(),
                object: object.to_string(),
            }
        }

        #[tokio::test]
        async fn test_get_store_by_name_non_existent() {
            let mut client = get_basic_client();
            let store = client
                .get_store_by_name("non-existent-store")
                .await
                .unwrap();
            assert!(store.is_none());
        }

        #[tokio::test]
        async fn test_get_or_create_store() {
            let mut client = get_basic_client();
            let store_name = format!("store-{}", uuid::Uuid::now_v7());
            let store = client.get_or_create_store(&store_name).await.unwrap();
            let store2 = client.get_or_create_store(&store_name).await.unwrap();
            assert_eq!(store.id, store2.id);
        }

        #[tokio::test]
        async fn test_read_all_pages_many() {
            let (mut client, store) = new_store().await;
            let auth_model = create_entitlements_model(&mut client, &store).await;
            let object = "organization:org-1";

            let users = (0..501)
                .map(|i| format!("user:u-{i}"))
                .collect::<Vec<String>>();
            write_members(&mut client, &store, &auth_model, object, &users).await;

            let tuples = client
                .read_all_pages(
                    &store.id,
                    members_of(object),
                    ConsistencyPreference::HigherConsistency,
                    100,
                    6,
                )
                .await
                .expect("Read can be done");

            assert_eq!(tuples.len(), 501);
            let read_users = tuples
                .iter()
                .map(|t| t.key.as_ref().expect("Tuple has a key").user.clone())
                .collect::<HashSet<String>>();
            assert_eq!(read_users, users.into_iter().collect::<HashSet<String>>());
        }

        #[tokio::test]
        async fn test_read_all_pages_too_many_pages() {
            let (mut client, store) = new_store().await;
            let auth_model = create_entitlements_model(&mut client, &store).await;
            let object = "organization:org-1";

            let users = (0..3)
                .map(|i| format!("user:u-{i}"))
                .collect::<Vec<String>>();
            write_members(&mut client, &store, &auth_model, object, &users).await;

            // Three tuples with one tuple per page cannot fit into a single page.
            let result = client
                .read_all_pages(
                    &store.id,
                    members_of(object),
                    ConsistencyPreference::HigherConsistency,
                    1,
                    1,
                )
                .await;

            assert!(matches!(
                result,
                Err(Error::TooManyPages { max_pages: 1, .. })
            ));
        }

        #[tokio::test]
        async fn test_read_all_pages_empty() {
            let (mut client, store) = new_store().await;
            let tuples = client
                .read_all_pages(
                    &store.id,
                    members_of("organization:org-1"),
                    ConsistencyPreference::HigherConsistency,
                    100,
                    5,
                )
                .await
                .expect("Read can be done");

            assert!(tuples.is_empty());
        }
    }
}