openfga_client/
client_ext.rs

1#![allow(unused_imports)]
2
3#[cfg(feature = "auth-middle")]
4use tonic::service::interceptor::InterceptedService;
5use tonic::{
6    codegen::{Body, Bytes, StdError},
7    service::interceptor::InterceptorLayer,
8    transport::{Channel, Endpoint},
9};
10#[cfg(feature = "auth-middle")]
11use tower::{ServiceBuilder, util::Either};
12
13use crate::{
14    client::{OpenFgaClient, OpenFgaServiceClient},
15    error::{Error, Result},
16    generated::{
17        ConsistencyPreference, CreateStoreRequest, ListStoresRequest, ReadRequest,
18        ReadRequestTupleKey, Store, Tuple,
19    },
20};
21
/// An [`OpenFgaServiceClient`] whose transport is a [`BasicAuthLayer`].
///
/// It supports optional authentication with a pre-shared key (sent as a Bearer token)
/// or with client credentials. If you need more fine-grained control, construct an
/// [`OpenFgaServiceClient`] directly and attach your own authentication interceptors.
#[cfg(feature = "auth-middle")]
pub type BasicOpenFgaServiceClient = OpenFgaServiceClient<BasicAuthLayer>;
28
#[cfg(feature = "auth-middle")]
impl BasicOpenFgaServiceClient {
    /// Create a new client that sends requests without any authentication.
    ///
    /// The underlying channel is created with `connect_lazy`, so no connection
    /// is attempted here.
    ///
    /// # Errors
    /// * [`Error::InvalidEndpoint`] if the endpoint is not a valid URL.
    pub fn new_unauthenticated(endpoint: impl Into<url::Url>) -> Result<Self> {
        let channel = get_tonic_endpoint_logged(&endpoint.into())?.connect_lazy();
        let service = Either::Right(InterceptedService::new(channel, NoOpInterceptor));
        Ok(Self::new(service))
    }

    /// Create a new client that authenticates with a pre-shared key, sent as a
    /// Bearer token.
    ///
    /// The underlying channel is created with `connect_lazy`, so no connection
    /// is attempted here.
    ///
    /// # Errors
    /// * [`Error::InvalidEndpoint`] if the endpoint is not a valid URL.
    /// * [`Error::InvalidToken`] if the token is not valid ASCII.
    pub fn new_with_basic_auth(endpoint: impl Into<url::Url>, token: &str) -> Result<Self> {
        let authorizer = middle::BearerTokenAuthorizer::new(token).map_err(|e| {
            tracing::error!("Could not construct OpenFGA client. Invalid token: {e}");
            Error::InvalidToken {
                reason: e.to_string(),
            }
        })?;
        let channel = get_tonic_endpoint_logged(&endpoint.into())?.connect_lazy();
        let service = Either::Left(Either::Right(InterceptedService::new(channel, authorizer)));
        Ok(Self::new(service))
    }

    /// Create a new client using client credentials.
    ///
    /// `scopes` are only added to the authorizer when the slice is non-empty.
    /// The OpenFGA endpoint is validated *before* the credentials are exchanged, so an
    /// unusable endpoint is reported without first awaiting the authorizer build.
    ///
    /// # Errors
    /// * [`Error::InvalidEndpoint`] if the endpoint is not a valid URL.
    /// * [`Error::CredentialRefreshError`] if the client credentials could not be exchanged for a token.
    pub async fn new_with_client_credentials(
        endpoint: impl Into<url::Url>,
        client_id: &str,
        client_secret: &str,
        token_endpoint: impl Into<url::Url>,
        scopes: &[&str],
    ) -> Result<Self> {
        // Fail fast on a bad endpoint; this is purely local validation.
        let endpoint = get_tonic_endpoint_logged(&endpoint.into())?;

        let mut builder = middle::BasicClientCredentialAuthorizer::basic_builder(
            client_id,
            client_secret,
            token_endpoint.into(),
        );
        if !scopes.is_empty() {
            builder = builder.add_scopes(scopes);
        }
        let authorizer = builder.build().await.map_err(|e| {
            tracing::error!("Could not construct OpenFGA client. Failed to fetch or refresh Client Credentials: {e}");
            Error::CredentialRefreshError(e)
        })?;

        // The channel is created only after the await, exactly where the original created it.
        let channel = endpoint.connect_lazy();
        let service = Either::Left(Either::Left(InterceptedService::new(channel, authorizer)));
        Ok(Self::new(service))
    }
}
97
98impl<T> OpenFgaServiceClient<T>
99where
100    T: tonic::client::GrpcService<tonic::body::Body>,
101    T::Error: Into<StdError>,
102    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
103    <T::ResponseBody as Body>::Error: Into<StdError> + Send,
104    T: Clone,
105{
106    /// Transform this service client into a higher-level [`OpenFgaClient`].
107    pub fn into_client(self, store_id: &str, authorization_model_id: &str) -> OpenFgaClient<T> {
108        OpenFgaClient::new(self, store_id, authorization_model_id)
109    }
110
111    /// Fetch a store by name.
112    /// If no store is found, returns `Ok(None)`.
113    ///
114    /// # Errors
115    /// * [`Error::AmbiguousStoreName`] if multiple stores with the same name are found.
116    /// * [`Error::RequestFailed`] if the request to OpenFGA fails.
117    pub async fn get_store_by_name(&mut self, store_name: &str) -> Result<Option<Store>> {
118        let stores = self
119            .list_stores(ListStoresRequest {
120                page_size: Some(2),
121                continuation_token: String::new(),
122                name: store_name.to_string(),
123            })
124            .await
125            .map_err(|e| {
126                tracing::error!("Failed to list stores in OpenFGA: {e}");
127                Error::RequestFailed(Box::new(e))
128            })?
129            .into_inner();
130        let num_stores = stores.stores.len();
131
132        match stores.stores.first() {
133            Some(store) => {
134                if num_stores > 1 {
135                    tracing::error!("Multiple stores with the name `{}` found", store_name);
136                    Err(Error::AmbiguousStoreName(store_name.to_string()))
137                } else {
138                    Ok(Some(store.clone()))
139                }
140            }
141            None => Ok(None),
142        }
143    }
144
145    /// Get a store by name or create it if it doesn't exist.
146    /// Returns information about the store, including its ID.
147    ///
148    /// # Errors
149    /// * [`Error::RequestFailed`] If a request to OpenFGA fails.
150    /// * [`Error::AmbiguousStoreName`] If multiple stores with the same name are found.
151    pub async fn get_or_create_store(&mut self, store_name: &str) -> Result<Store> {
152        let store = self.get_store_by_name(store_name).await?;
153        match store {
154            None => {
155                tracing::debug!("OpenFGA Store {} not found. Creating it.", store_name);
156                let store = self
157                    .create_store(CreateStoreRequest {
158                        name: store_name.to_owned(),
159                    })
160                    .await
161                    .map_err(|e| {
162                        tracing::error!("Failed to create store in OpenFGA: {e}");
163                        Error::RequestFailed(Box::new(e))
164                    })?
165                    .into_inner();
166                Ok(Store {
167                    id: store.id,
168                    name: store.name,
169                    created_at: store.created_at,
170                    updated_at: store.updated_at,
171                    deleted_at: None,
172                })
173            }
174            Some(store) => Ok(store),
175        }
176    }
177
178    /// Wrapper around [`Self::read`] that reads all pages of the result, handling pagination.
179    ///
180    /// # Errors
181    /// * [`Error::RequestFailed`] If a request to OpenFGA fails.
182    /// * [`Error::TooManyPages`] If the number of pages read exceeds `max_pages`.
183    pub async fn read_all_pages(
184        &mut self,
185        store_id: &str,
186        tuple: Option<impl Into<ReadRequestTupleKey>>,
187        consistency: impl Into<ConsistencyPreference>,
188        page_size: i32,
189        max_pages: u32,
190    ) -> Result<Vec<Tuple>> {
191        let mut continuation_token = String::new();
192        let tuple = tuple.map(Into::into);
193        let mut tuples = Vec::new();
194        let mut count = 0;
195        let consistency = consistency.into();
196
197        loop {
198            let read_request = ReadRequest {
199                store_id: store_id.to_owned(),
200                tuple_key: tuple.clone(),
201                page_size: Some(page_size),
202                continuation_token: continuation_token.clone(),
203                consistency: consistency.into(),
204            };
205            let response = self
206                .read(read_request.clone())
207                .await
208                .map_err(|e| {
209                    tracing::error!(
210                        "Failed to read from OpenFGA: {e}. Request: {:?}",
211                        read_request
212                    );
213                    Error::RequestFailed(Box::new(e))
214                })?
215                .into_inner();
216            tuples.extend(response.tuples);
217            continuation_token.clone_from(&response.continuation_token);
218            if continuation_token.is_empty() || count > max_pages {
219                if count > max_pages {
220                    return Err(Error::TooManyPages { max_pages, tuple });
221                }
222                break;
223            }
224            count += 1;
225        }
226
227        Ok(tuples)
228    }
229}
230
/// Transport used by [`BasicOpenFgaServiceClient`]: one of three intercepted channels.
///
/// * `Left(Left(_))`: requests are authorized with client credentials.
/// * `Left(Right(_))`: requests are authorized with a bearer token.
/// * `Right(_)`: requests pass through [`NoOpInterceptor`] unchanged.
#[cfg(feature = "auth-middle")]
pub type BasicAuthLayer = Either<
    Either<
        InterceptedService<Channel, middle::BasicClientCredentialAuthorizer>,
        InterceptedService<Channel, middle::BearerTokenAuthorizer>,
    >,
    InterceptedService<Channel, NoOpInterceptor>,
>;
239
/// Interceptor that forwards every request untouched.
///
/// It fills the "no authentication" slot of [`BasicAuthLayer`].
#[cfg(feature = "auth-middle")]
#[derive(Clone, Copy, Debug)]
pub struct NoOpInterceptor;

#[cfg(feature = "auth-middle")]
impl tonic::service::Interceptor for NoOpInterceptor {
    /// Returns the incoming request as-is; never fails.
    fn call(
        &mut self,
        req: tonic::Request<()>,
    ) -> std::result::Result<tonic::Request<()>, tonic::Status> {
        Ok(req)
    }
}
253
/// Build a tonic [`Endpoint`] from `endpoint`, logging every failure via `tracing::error!`.
///
/// For `https` URLs TLS is configured through `with_enabled_roots()` when the
/// `tls-rustls` feature is enabled; without that feature an `https` URL is rejected.
///
/// # Errors
/// * [`Error::InvalidEndpoint`] if tonic rejects the URL.
/// * [`Error::TlsConfigurationFailed`] if TLS could not be configured, or if the URL is
///   `https` and the `tls-rustls` feature is disabled.
#[cfg(feature = "auth-middle")]
fn get_tonic_endpoint_logged(endpoint: &url::Url) -> Result<Endpoint> {
    let tonic_endpoint = match Endpoint::new(endpoint.to_string()) {
        Ok(parsed) => parsed,
        Err(e) => {
            tracing::error!(
                "Could not construct OpenFGA client. Invalid endpoint `{endpoint}`: {e}"
            );
            return Err(Error::InvalidEndpoint(endpoint.to_string()));
        }
    };

    // Only HTTPS endpoints need TLS configuration.
    if endpoint.scheme() != "https" {
        return Ok(tonic_endpoint);
    }

    #[cfg(feature = "tls-rustls")]
    {
        use tonic::transport::ClientTlsConfig;

        let tls_config = ClientTlsConfig::new().with_enabled_roots();
        match tonic_endpoint.tls_config(tls_config) {
            Ok(secured) => Ok(secured),
            Err(e) => {
                tracing::error!(
                    "Could not configure TLS for OpenFGA client endpoint `{endpoint}`: {e}"
                );
                Err(Error::TlsConfigurationFailed {
                    endpoint: endpoint.to_string(),
                    reason: e.to_string(),
                })
            }
        }
    }
    #[cfg(not(feature = "tls-rustls"))]
    {
        Err(Error::TlsConfigurationFailed {
            endpoint: endpoint.to_string(),
            reason: "HTTPS endpoint requires the `tls-rustls` feature to be enabled"
                .to_string(),
        })
    }
}
289
#[cfg(test)]
pub(crate) mod test {
    use needs_env_var::needs_env_var;

    // Integration tests against an OpenFGA server whose gRPC URL is read from
    // `TEST_OPENFGA_CLIENT_GRPC_URL`.
    // NOTE(review): the env-var gate below is commented out, so these tests panic
    // when the variable is unset — confirm whether the gate should be restored.
    // #[needs_env_var(TEST_OPENFGA_CLIENT_GRPC_URL)]
    #[cfg(feature = "auth-middle")]
    mod openfga {
        use std::collections::{HashMap, HashSet};

        use super::super::*;
        use crate::{
            client::{
                TupleKey, WriteAuthorizationModelRequest, WriteAuthorizationModelResponse,
                WriteRequest, WriteRequestWrites,
            },
            generated::AuthorizationModel,
        };

        /// Generate a store name that is unique per call.
        fn unique_store_name() -> String {
            format!("store-{}", uuid::Uuid::now_v7())
        }

        /// Build an unauthenticated client for the server under test.
        fn get_basic_client() -> BasicOpenFgaServiceClient {
            let endpoint = std::env::var("TEST_OPENFGA_CLIENT_GRPC_URL")
                .expect("TEST_OPENFGA_CLIENT_GRPC_URL must be set");
            let endpoint = url::Url::parse(&endpoint)
                .expect("TEST_OPENFGA_CLIENT_GRPC_URL must be a valid URL");
            BasicOpenFgaServiceClient::new_unauthenticated(endpoint)
                .expect("Client can be created")
        }

        /// Create a client together with a freshly created, uniquely named store.
        async fn new_store() -> (BasicOpenFgaServiceClient, Store) {
            let mut client = get_basic_client();
            let store = client
                .get_or_create_store(&unique_store_name())
                .await
                .expect("Store can be created");
            (client, store)
        }

        /// Write the sample "entitlements" authorization model into `store`.
        async fn create_entitlements_model(
            client: &mut BasicOpenFgaServiceClient,
            store: &Store,
        ) -> WriteAuthorizationModelResponse {
            let schema = include_str!("../tests/sample-store/entitlements/schema.json");
            let model: AuthorizationModel =
                serde_json::from_str(schema).expect("Schema can be deserialized");
            client
                .write_authorization_model(WriteAuthorizationModelRequest {
                    store_id: store.id.clone(),
                    type_definitions: model.type_definitions,
                    schema_version: model.schema_version,
                    conditions: model.conditions,
                })
                .await
                .expect("Auth model can be written")
                .into_inner()
        }

        #[tokio::test]
        async fn test_get_store_by_name_many() {
            let mut client = get_basic_client();

            // Create 201 stores so that name lookups run against a large set of stores.
            let mut stores = HashMap::new();
            for _ in 0..201 {
                let store_name = unique_store_name();
                let created = client
                    .get_or_create_store(&store_name)
                    .await
                    .expect("Store can be created");
                assert_eq!(store_name, created.name);
                stores.insert(store_name, created.id);
            }

            for (store_name, store_id) in stores {
                let store = client
                    .get_store_by_name(&store_name)
                    .await
                    .expect("Store can be fetched")
                    .expect("Store exists");
                assert_eq!(store_id, store.id);
            }
        }

        #[tokio::test]
        async fn test_get_store_by_name_nonexistent() {
            let mut client = get_basic_client();
            let store = client
                .get_store_by_name("non-existent-store")
                .await
                .expect("Store lookup succeeds");
            assert!(store.is_none());
        }

        #[tokio::test]
        async fn test_get_or_create_store() {
            let mut client = get_basic_client();
            let store_name = unique_store_name();
            let first = client
                .get_or_create_store(&store_name)
                .await
                .expect("Store can be created");
            // The second call must find the existing store instead of creating another one.
            let second = client
                .get_or_create_store(&store_name)
                .await
                .expect("Store can be fetched");
            assert_eq!(first.id, second.id);
        }

        #[tokio::test]
        async fn test_read_all_pages_many() {
            let (mut client, store) = new_store().await;
            let auth_model = create_entitlements_model(&mut client, &store).await;
            let object = "organization:org-1";

            // 501 tuples with a page size of 100 need six pages.
            let users: Vec<String> = (0..501).map(|i| format!("user:u-{i}")).collect();

            for user in &users {
                client
                    .write(WriteRequest {
                        authorization_model_id: auth_model.authorization_model_id.clone(),
                        store_id: store.id.clone(),
                        writes: Some(WriteRequestWrites {
                            on_duplicate: String::new(),
                            tuple_keys: vec![TupleKey {
                                user: user.clone(),
                                relation: "member".to_string(),
                                object: object.to_string(),
                                condition: None,
                            }],
                        }),
                        deletes: None,
                    })
                    .await
                    .expect("Write can be done");
            }

            let tuples = client
                .read_all_pages(
                    &store.id,
                    Some(ReadRequestTupleKey {
                        user: String::new(),
                        relation: "member".to_string(),
                        object: object.to_string(),
                    }),
                    ConsistencyPreference::HigherConsistency,
                    100,
                    6,
                )
                .await
                .expect("Read can be done");

            assert_eq!(tuples.len(), 501);
            let read_users: HashSet<String> = tuples
                .iter()
                .map(|t| t.key.as_ref().expect("Tuple has a key").user.clone())
                .collect();
            let expected_users: HashSet<String> = users.into_iter().collect();
            assert_eq!(read_users, expected_users);
        }

        #[tokio::test]
        async fn test_read_all_pages_empty() {
            let (mut client, store) = new_store().await;
            let tuples = client
                .read_all_pages(
                    &store.id,
                    Some(ReadRequestTupleKey {
                        user: String::new(),
                        relation: "member".to_string(),
                        object: "organization:org-1".to_string(),
                    }),
                    ConsistencyPreference::HigherConsistency,
                    100,
                    5,
                )
                .await
                .expect("Read can be done");

            assert!(tuples.is_empty());
        }
    }
}