1#![allow(unused_imports)]
2
3#[cfg(feature = "auth-middle")]
4use tonic::service::interceptor::InterceptedService;
5use tonic::{
6 codegen::{Body, Bytes, StdError},
7 service::interceptor::InterceptorLayer,
8 transport::{Channel, Endpoint},
9};
10#[cfg(feature = "auth-middle")]
11use tower::{ServiceBuilder, util::Either};
12
13use crate::{
14 client::{OpenFgaClient, OpenFgaServiceClient},
15 error::{Error, Result},
16 generated::{
17 ConsistencyPreference, CreateStoreRequest, ListStoresRequest, ReadRequest,
18 ReadRequestTupleKey, Store, Tuple,
19 },
20};
21
/// [`OpenFgaServiceClient`] specialised to the [`BasicAuthLayer`] transport.
///
/// Only available when the `auth-middle` feature is enabled.
#[cfg(feature = "auth-middle")]
pub type BasicOpenFgaServiceClient = OpenFgaServiceClient<BasicAuthLayer>;
28
#[cfg(feature = "auth-middle")]
impl BasicOpenFgaServiceClient {
    /// Builds a client whose requests pass through [`NoOpInterceptor`], i.e. without
    /// any credentials attached.
    ///
    /// The channel is obtained from the endpoint via `connect_lazy`.
    ///
    /// # Errors
    /// Propagates the error of [`get_tonic_endpoint_logged`] for `endpoint`.
    pub fn new_unauthenticated(endpoint: impl Into<url::Url>) -> Result<Self> {
        let url = endpoint.into();
        let lazy_channel = get_tonic_endpoint_logged(&url)?.connect_lazy();
        let transport = InterceptedService::new(lazy_channel, NoOpInterceptor);
        Ok(Self::new(Either::Right(transport)))
    }

    /// Builds a client whose requests are intercepted by a
    /// `middle::BearerTokenAuthorizer` created from `token`.
    ///
    /// Note that, despite the method name, the authorizer used is the bearer-token one.
    /// The token is validated before the endpoint, so an invalid token is reported
    /// first when both are invalid.
    ///
    /// # Errors
    /// * [`Error::InvalidToken`] if the authorizer rejects `token`
    /// * the error of [`get_tonic_endpoint_logged`] for `endpoint`
    pub fn new_with_basic_auth(endpoint: impl Into<url::Url>, token: &str) -> Result<Self> {
        let authorizer = match middle::BearerTokenAuthorizer::new(token) {
            Ok(authorizer) => authorizer,
            Err(e) => {
                tracing::error!("Could not construct OpenFGA client. Invalid token: {e}");
                return Err(Error::InvalidToken {
                    reason: e.to_string(),
                });
            }
        };

        let url = endpoint.into();
        let lazy_channel = get_tonic_endpoint_logged(&url)?.connect_lazy();
        let transport = InterceptedService::new(lazy_channel, authorizer);
        Ok(Self::new(Either::Left(Either::Right(transport))))
    }

    /// Builds a client whose requests are intercepted by a
    /// `middle::BasicClientCredentialAuthorizer`.
    ///
    /// The authorizer is built (and awaited) from `client_id`, `client_secret` and
    /// `token_endpoint`; `scopes` are added to the builder only when the slice is
    /// non-empty. The OpenFGA `endpoint` is validated afterwards.
    ///
    /// # Errors
    /// * [`Error::CredentialRefreshError`] if building the authorizer fails
    /// * the error of [`get_tonic_endpoint_logged`] for `endpoint`
    pub async fn new_with_client_credentials(
        endpoint: impl Into<url::Url>,
        client_id: &str,
        client_secret: &str,
        token_endpoint: impl Into<url::Url>,
        scopes: &[&str],
    ) -> Result<Self> {
        let mut credentials = middle::BasicClientCredentialAuthorizer::basic_builder(
            client_id,
            client_secret,
            token_endpoint.into(),
        );
        if !scopes.is_empty() {
            credentials = credentials.add_scopes(scopes);
        }

        let authorizer = credentials.build().await.map_err(|e| {
            tracing::error!("Could not construct OpenFGA client. Failed to fetch or refresh Client Credentials: {e}");
            Error::CredentialRefreshError(e)
        })?;

        let url = endpoint.into();
        let lazy_channel = get_tonic_endpoint_logged(&url)?.connect_lazy();
        let transport = InterceptedService::new(lazy_channel, authorizer);
        Ok(Self::new(Either::Left(Either::Left(transport))))
    }
}
97
98impl<T> OpenFgaServiceClient<T>
99where
100 T: tonic::client::GrpcService<tonic::body::Body>,
101 T::Error: Into<StdError>,
102 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
103 <T::ResponseBody as Body>::Error: Into<StdError> + Send,
104 T: Clone,
105{
106 pub fn into_client(self, store_id: &str, authorization_model_id: &str) -> OpenFgaClient<T> {
108 OpenFgaClient::new(self, store_id, authorization_model_id)
109 }
110
111 pub async fn get_store_by_name(&mut self, store_name: &str) -> Result<Option<Store>> {
118 let stores = self
119 .list_stores(ListStoresRequest {
120 page_size: Some(2),
121 continuation_token: String::new(),
122 name: store_name.to_string(),
123 })
124 .await
125 .map_err(|e| {
126 tracing::error!("Failed to list stores in OpenFGA: {e}");
127 Error::RequestFailed(Box::new(e))
128 })?
129 .into_inner();
130 let num_stores = stores.stores.len();
131
132 match stores.stores.first() {
133 Some(store) => {
134 if num_stores > 1 {
135 tracing::error!("Multiple stores with the name `{}` found", store_name);
136 Err(Error::AmbiguousStoreName(store_name.to_string()))
137 } else {
138 Ok(Some(store.clone()))
139 }
140 }
141 None => Ok(None),
142 }
143 }
144
145 pub async fn get_or_create_store(&mut self, store_name: &str) -> Result<Store> {
152 let store = self.get_store_by_name(store_name).await?;
153 match store {
154 None => {
155 tracing::debug!("OpenFGA Store {} not found. Creating it.", store_name);
156 let store = self
157 .create_store(CreateStoreRequest {
158 name: store_name.to_owned(),
159 })
160 .await
161 .map_err(|e| {
162 tracing::error!("Failed to create store in OpenFGA: {e}");
163 Error::RequestFailed(Box::new(e))
164 })?
165 .into_inner();
166 Ok(Store {
167 id: store.id,
168 name: store.name,
169 created_at: store.created_at,
170 updated_at: store.updated_at,
171 deleted_at: None,
172 })
173 }
174 Some(store) => Ok(store),
175 }
176 }
177
178 pub async fn read_all_pages(
184 &mut self,
185 store_id: &str,
186 tuple: Option<impl Into<ReadRequestTupleKey>>,
187 consistency: impl Into<ConsistencyPreference>,
188 page_size: i32,
189 max_pages: u32,
190 ) -> Result<Vec<Tuple>> {
191 let mut continuation_token = String::new();
192 let tuple = tuple.map(Into::into);
193 let mut tuples = Vec::new();
194 let mut count = 0;
195 let consistency = consistency.into();
196
197 loop {
198 let read_request = ReadRequest {
199 store_id: store_id.to_owned(),
200 tuple_key: tuple.clone(),
201 page_size: Some(page_size),
202 continuation_token: continuation_token.clone(),
203 consistency: consistency.into(),
204 };
205 let response = self
206 .read(read_request.clone())
207 .await
208 .map_err(|e| {
209 tracing::error!(
210 "Failed to read from OpenFGA: {e}. Request: {:?}",
211 read_request
212 );
213 Error::RequestFailed(Box::new(e))
214 })?
215 .into_inner();
216 tuples.extend(response.tuples);
217 continuation_token.clone_from(&response.continuation_token);
218 if continuation_token.is_empty() || count > max_pages {
219 if count > max_pages {
220 return Err(Error::TooManyPages { max_pages, tuple });
221 }
222 break;
223 }
224 count += 1;
225 }
226
227 Ok(tuples)
228 }
229}
230
/// Transport type used by [`BasicOpenFgaServiceClient`]: a nested `Either` over the
/// three supported interceptor setups.
///
/// Variant layout:
/// * `Left(Left(_))` - channel intercepted by `middle::BasicClientCredentialAuthorizer`
/// * `Left(Right(_))` - channel intercepted by `middle::BearerTokenAuthorizer`
/// * `Right(_)` - channel intercepted by [`NoOpInterceptor`]
#[cfg(feature = "auth-middle")]
pub type BasicAuthLayer = tower::util::Either<
    tower::util::Either<
        InterceptedService<Channel, middle::BasicClientCredentialAuthorizer>,
        InterceptedService<Channel, middle::BearerTokenAuthorizer>,
    >,
    InterceptedService<Channel, NoOpInterceptor>,
>;
239
/// Field-less interceptor type used for the `Right` variant of [`BasicAuthLayer`],
/// i.e. for clients created without credentials.
#[cfg(feature = "auth-middle")]
#[derive(Clone, Copy, Debug)]
pub struct NoOpInterceptor;
243
#[cfg(feature = "auth-middle")]
impl tonic::service::Interceptor for NoOpInterceptor {
    /// Hands the incoming request back unchanged and never fails.
    fn call(
        &mut self,
        req: tonic::Request<()>,
    ) -> std::result::Result<tonic::Request<()>, tonic::Status> {
        Ok(req)
    }
}
253
/// Converts `endpoint` into a tonic [`Endpoint`], logging every failure with
/// `tracing::error!` before returning it.
///
/// For URLs with the `https` scheme:
/// * with the `tls-rustls` feature, TLS is configured using
///   `ClientTlsConfig::new().with_enabled_roots()`;
/// * without it, the call fails because TLS cannot be configured.
///
/// URLs with any other scheme are returned without TLS configuration.
///
/// # Errors
/// * [`Error::InvalidEndpoint`] if `Endpoint::new` rejects the URL
/// * [`Error::TlsConfigurationFailed`] if TLS setup fails or `tls-rustls` is disabled
///   for an `https` URL
#[cfg(feature = "auth-middle")]
fn get_tonic_endpoint_logged(endpoint: &url::Url) -> Result<Endpoint> {
    let ep = match Endpoint::new(endpoint.to_string()) {
        Ok(ep) => ep,
        Err(e) => {
            tracing::error!(
                "Could not construct OpenFGA client. Invalid endpoint `{endpoint}`: {e}"
            );
            return Err(Error::InvalidEndpoint(endpoint.to_string()));
        }
    };

    // Only HTTPS endpoints need any TLS handling.
    if endpoint.scheme() != "https" {
        return Ok(ep);
    }

    #[cfg(feature = "tls-rustls")]
    {
        use tonic::transport::ClientTlsConfig;

        match ep.tls_config(ClientTlsConfig::new().with_enabled_roots()) {
            Ok(secured) => Ok(secured),
            Err(e) => {
                tracing::error!(
                    "Could not configure TLS for OpenFGA client endpoint `{endpoint}`: {e}"
                );
                Err(Error::TlsConfigurationFailed {
                    endpoint: endpoint.to_string(),
                    reason: e.to_string(),
                })
            }
        }
    }
    #[cfg(not(feature = "tls-rustls"))]
    {
        Err(Error::TlsConfigurationFailed {
            endpoint: endpoint.to_string(),
            reason: "HTTPS endpoint requires the `tls-rustls` feature to be enabled"
                .to_string(),
        })
    }
}
289
290#[cfg(test)]
291pub(crate) mod test {
292 use needs_env_var::needs_env_var;
293
294 #[cfg(feature = "auth-middle")]
296 mod openfga {
297 use std::collections::{HashMap, HashSet};
298
299 use super::super::*;
300 use crate::{
301 client::{
302 TupleKey, WriteAuthorizationModelRequest, WriteAuthorizationModelResponse,
303 WriteRequest, WriteRequestWrites,
304 },
305 generated::AuthorizationModel,
306 };
307
308 fn get_basic_client() -> BasicOpenFgaServiceClient {
309 let endpoint = std::env::var("TEST_OPENFGA_CLIENT_GRPC_URL").unwrap();
310 BasicOpenFgaServiceClient::new_unauthenticated(url::Url::parse(&endpoint).unwrap())
311 .expect("Client can be created")
312 }
313
314 async fn new_store() -> (BasicOpenFgaServiceClient, Store) {
315 let mut client = get_basic_client();
316 let store_name = format!("store-{}", uuid::Uuid::now_v7());
317 let store = client
318 .get_or_create_store(&store_name)
319 .await
320 .expect("Store can be created");
321 (client, store)
322 }
323
324 async fn create_entitlements_model(
325 client: &mut BasicOpenFgaServiceClient,
326 store: &Store,
327 ) -> WriteAuthorizationModelResponse {
328 let schema = include_str!("../tests/sample-store/entitlements/schema.json");
329 let model: AuthorizationModel =
330 serde_json::from_str(schema).expect("Schema can be deserialized");
331 let auth_model = client
332 .write_authorization_model(WriteAuthorizationModelRequest {
333 store_id: store.id.clone(),
334 type_definitions: model.type_definitions,
335 schema_version: model.schema_version,
336 conditions: model.conditions,
337 })
338 .await
339 .expect("Auth model can be written");
340
341 auth_model.into_inner()
342 }
343
344 #[tokio::test]
345 async fn test_get_store_by_name_many() {
346 let mut client = get_basic_client();
347
348 let mut stores = HashMap::new();
349 for _i in 0..201 {
350 let store_name = format!("store-{}", uuid::Uuid::now_v7());
351 let r = client
352 .get_or_create_store(&store_name)
353 .await
354 .expect("Store can be created");
355 assert_eq!(store_name, r.name);
356 stores.insert(store_name, r.id);
357 }
358
359 for (store_name, store_id) in stores {
360 let store = client
361 .get_store_by_name(&store_name)
362 .await
363 .expect("Store can be fetched")
364 .expect("Store exists");
365 assert_eq!(store_id, store.id);
366 }
367 }
368
369 #[tokio::test]
370 async fn test_get_store_by_name_non_existant() {
371 let mut client = get_basic_client();
372 let store = client
373 .get_store_by_name("non-existent-store")
374 .await
375 .unwrap();
376 assert!(store.is_none());
377 }
378
379 #[tokio::test]
380 async fn test_get_or_create_store() {
381 let mut client = get_basic_client();
382 let store_name = format!("store-{}", uuid::Uuid::now_v7());
383 let store = client.get_or_create_store(&store_name).await.unwrap();
384 let store2 = client.get_or_create_store(&store_name).await.unwrap();
385 assert_eq!(store.id, store2.id);
386 }
387
388 #[tokio::test]
389 async fn test_read_all_pages_many() {
390 let (mut client, store) = new_store().await;
391 let auth_model = create_entitlements_model(&mut client, &store).await;
392 let object = "organization:org-1";
393
394 let users = (0..501)
395 .map(|i| format!("user:u-{i}"))
396 .collect::<Vec<String>>();
397
398 for user in &users {
399 client
400 .write(WriteRequest {
401 authorization_model_id: auth_model.authorization_model_id.clone(),
402 store_id: store.id.clone(),
403 writes: Some(WriteRequestWrites {
404 on_duplicate: String::new(),
405 tuple_keys: vec![TupleKey {
406 user: user.clone(),
407 relation: "member".to_string(),
408 object: object.to_string(),
409 condition: None,
410 }],
411 }),
412 deletes: None,
413 })
414 .await
415 .expect("Write can be done");
416 }
417
418 let tuples = client
419 .read_all_pages(
420 &store.id,
421 Some(ReadRequestTupleKey {
422 user: String::new(),
423 relation: "member".to_string(),
424 object: object.to_string(),
425 }),
426 ConsistencyPreference::HigherConsistency,
427 100,
428 6,
429 )
430 .await
431 .expect("Read can be done");
432
433 assert_eq!(tuples.len(), 501);
434 assert_eq!(
435 tuples
436 .iter()
437 .map(|t| t.key.clone().unwrap().user)
438 .collect::<HashSet<String>>(),
439 HashSet::from_iter(users)
440 );
441 }
442
443 #[tokio::test]
444 async fn test_real_all_pages_empty() {
445 let (mut client, store) = new_store().await;
446 let tuples = client
447 .read_all_pages(
448 &store.id,
449 Some(ReadRequestTupleKey {
450 user: String::new(),
451 relation: "member".to_string(),
452 object: "organization:org-1".to_string(),
453 }),
454 ConsistencyPreference::HigherConsistency,
455 100,
456 5,
457 )
458 .await
459 .expect("Read can be done");
460
461 assert!(tuples.is_empty());
462 }
463 }
464}