1#![allow(unused_imports)]
2
3use tonic::{
4 codegen::{Body, Bytes, StdError},
5 service::interceptor::InterceptorLayer,
6 transport::{Channel, Endpoint},
7};
8#[cfg(feature = "auth-middle")]
9use tower::{util::Either, ServiceBuilder};
10
11use crate::{
12 client::{OpenFgaClient, OpenFgaServiceClient},
13 error::{Error, Result},
14 generated::{
15 ConsistencyPreference, CreateStoreRequest, ListStoresRequest, ReadRequest,
16 ReadRequestTupleKey, Store, Tuple,
17 },
18};
19
/// An [`OpenFgaServiceClient`] whose transport is a [`BasicAuthLayer`]:
/// a tonic `Channel` that is either used as-is or wrapped in one of the
/// supported authentication interceptors.
#[cfg(feature = "auth-middle")]
pub type BasicOpenFgaServiceClient = OpenFgaServiceClient<BasicAuthLayer>;

#[cfg(feature = "auth-middle")]
impl BasicOpenFgaServiceClient {
    /// Creates a client that sends requests without any authentication.
    ///
    /// The underlying channel is created with `connect_lazy`, so no connection
    /// is attempted here.
    ///
    /// # Errors
    /// * [`Error::InvalidEndpoint`] if `endpoint` is not accepted by tonic.
    pub fn new_unauthenticated(endpoint: impl Into<url::Url>) -> Result<Self> {
        Self::build_with_auth_layer(&endpoint.into(), None)
    }

    /// Creates a client that authenticates with a fixed `token` through
    /// `middle::BearerTokenAuthorizer`.
    ///
    /// The underlying channel is created with `connect_lazy`, so no connection
    /// is attempted here.
    ///
    /// # Errors
    /// * [`Error::InvalidToken`] if the authorizer rejects `token`.
    /// * [`Error::InvalidEndpoint`] if `endpoint` is not accepted by tonic.
    pub fn new_with_basic_auth(endpoint: impl Into<url::Url>, token: &str) -> Result<Self> {
        // The token is validated before the endpoint, so an invalid token is
        // reported first when both are wrong.
        let authorizer = middle::BearerTokenAuthorizer::new(token).map_err(|e| {
            tracing::error!("Could not construct OpenFGA client. Invalid token: {e}");
            Error::InvalidToken {
                reason: e.to_string(),
            }
        })?;

        Self::build_with_auth_layer(
            &endpoint.into(),
            Some(tower::util::Either::Right(tonic::service::interceptor(
                authorizer,
            ))),
        )
    }

    /// Creates a client that authenticates through
    /// `middle::BasicClientCredentialAuthorizer`, built from `client_id`,
    /// `client_secret` and `token_endpoint`.
    ///
    /// `scopes` are only added to the authorizer when the slice is non-empty.
    /// The authorizer is built (and awaited) before the endpoint is validated.
    ///
    /// # Errors
    /// * [`Error::CredentialRefreshError`] if building the authorizer fails.
    /// * [`Error::InvalidEndpoint`] if `endpoint` is not accepted by tonic.
    pub async fn new_with_client_credentials(
        endpoint: impl Into<url::Url>,
        client_id: &str,
        client_secret: &str,
        token_endpoint: impl Into<url::Url>,
        scopes: &[&str],
    ) -> Result<Self> {
        let builder = middle::BasicClientCredentialAuthorizer::basic_builder(
            client_id,
            client_secret,
            token_endpoint.into(),
        );
        let builder = if scopes.is_empty() {
            builder
        } else {
            builder.add_scopes(scopes)
        };
        let authorizer = builder.build().await.map_err(|e| {
            tracing::error!("Could not construct OpenFGA client. Failed to fetch or refresh Client Credentials: {e}");
            Error::CredentialRefreshError(e)
        })?;

        Self::build_with_auth_layer(
            &endpoint.into(),
            Some(tower::util::Either::Left(tonic::service::interceptor(
                authorizer,
            ))),
        )
    }

    /// Shared constructor tail: validates `endpoint`, creates a lazily
    /// connecting channel and stacks the optional auth interceptor on top.
    ///
    /// `None` yields a plain channel; `Some(..)` wraps it in the given
    /// interceptor layer.
    fn build_with_auth_layer(endpoint: &url::Url, auth: EitherOrOption) -> Result<Self> {
        let auth_layer = tower::util::option_layer(auth);
        let endpoint = get_tonic_endpoint_logged(endpoint)?;
        let service = ServiceBuilder::new()
            .layer(auth_layer)
            .service(endpoint.connect_lazy());
        Ok(BasicOpenFgaServiceClient::new(service))
    }
}
106
107impl<T> OpenFgaServiceClient<T>
108where
109 T: tonic::client::GrpcService<tonic::body::BoxBody>,
110 T::Error: Into<StdError>,
111 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
112 <T::ResponseBody as Body>::Error: Into<StdError> + Send,
113 T: Clone,
114{
115 pub fn into_client(self, store_id: &str, authorization_model_id: &str) -> OpenFgaClient<T> {
117 OpenFgaClient::new(self, store_id, authorization_model_id)
118 }
119
120 pub async fn get_store_by_name(&mut self, store_name: &str) -> Result<Option<Store>> {
127 let stores = self
128 .list_stores(ListStoresRequest {
129 page_size: Some(2),
130 continuation_token: String::new(),
131 name: store_name.to_string(),
132 })
133 .await
134 .map_err(|e| {
135 tracing::error!("Failed to list stores in OpenFGA: {e}");
136 Error::RequestFailed(e)
137 })?
138 .into_inner();
139 let num_stores = stores.stores.len();
140
141 match stores.stores.first() {
142 Some(store) => {
143 if num_stores > 1 {
144 tracing::error!("Multiple stores with the name `{}` found", store_name);
145 Err(Error::AmbiguousStoreName(store_name.to_string()))
146 } else {
147 Ok(Some(store.clone()))
148 }
149 }
150 None => Ok(None),
151 }
152 }
153
154 pub async fn get_or_create_store(&mut self, store_name: &str) -> Result<Store> {
161 let store = self.get_store_by_name(store_name).await?;
162 match store {
163 None => {
164 tracing::debug!("OpenFGA Store {} not found. Creating it.", store_name);
165 let store = self
166 .create_store(CreateStoreRequest {
167 name: store_name.to_owned(),
168 })
169 .await
170 .map_err(|e| {
171 tracing::error!("Failed to create store in OpenFGA: {e}");
172 Error::RequestFailed(e)
173 })?
174 .into_inner();
175 Ok(Store {
176 id: store.id,
177 name: store.name,
178 created_at: store.created_at,
179 updated_at: store.updated_at,
180 deleted_at: None,
181 })
182 }
183 Some(store) => Ok(store),
184 }
185 }
186
187 pub async fn read_all_pages(
193 &mut self,
194 store_id: &str,
195 tuple: impl Into<ReadRequestTupleKey>,
196 consistency: impl Into<ConsistencyPreference>,
197 page_size: i32,
198 max_pages: u32,
199 ) -> Result<Vec<Tuple>> {
200 let mut continuation_token = String::new();
201 let mut tuples = Vec::new();
202 let mut count = 0;
203 let tuple = tuple.into();
204 let consistency = consistency.into();
205
206 loop {
207 let read_request = ReadRequest {
208 store_id: store_id.to_owned(),
209 tuple_key: Some(tuple.clone()),
210 page_size: Some(page_size),
211 continuation_token: continuation_token.clone(),
212 consistency: consistency.into(),
213 };
214 let response = self
215 .read(read_request.clone())
216 .await
217 .map_err(|e| {
218 tracing::error!(
219 "Failed to read from OpenFGA: {e}. Request: {:?}",
220 read_request
221 );
222 Error::RequestFailed(e)
223 })?
224 .into_inner();
225 tuples.extend(response.tuples);
226 continuation_token.clone_from(&response.continuation_token);
227 if continuation_token.is_empty() || count > max_pages {
228 if count > max_pages {
229 return Err(Error::TooManyPages { max_pages, tuple });
230 }
231 break;
232 }
233 count += 1;
234 }
235
236 Ok(tuples)
237 }
238}
239
/// Transport type behind `BasicOpenFgaServiceClient`.
///
/// * Outer `Left`: the channel is wrapped in an authentication interceptor.
///   * Inner `Left`: `middle::BasicClientCredentialAuthorizer`.
///   * Inner `Right`: `middle::BearerTokenAuthorizer`.
/// * Outer `Right`: the bare `Channel`, i.e. no authentication.
#[cfg(feature = "auth-middle")]
pub type BasicAuthLayer = Either<
    Either<
        tonic::service::interceptor::InterceptedService<
            Channel,
            middle::BasicClientCredentialAuthorizer,
        >,
        tonic::service::interceptor::InterceptedService<Channel, middle::BearerTokenAuthorizer>,
    >,
    Channel,
>;
251
/// Optional authentication layer handed to `tower::util::option_layer`.
///
/// `None` means no authentication; `Some(Left(..))` carries the
/// client-credentials interceptor layer and `Some(Right(..))` the
/// bearer-token interceptor layer.
#[cfg(feature = "auth-middle")]
type EitherOrOption = Option<
    tower::util::Either<
        tonic::service::interceptor::InterceptorLayer<middle::BasicClientCredentialAuthorizer>,
        tonic::service::interceptor::InterceptorLayer<middle::BearerTokenAuthorizer>,
    >,
>;
259
/// Converts `endpoint` into a tonic [`Endpoint`].
///
/// On failure the problem is logged at error level and reported as
/// [`Error::InvalidEndpoint`] carrying the offending URL.
#[cfg(feature = "auth-middle")]
fn get_tonic_endpoint_logged(endpoint: &url::Url) -> Result<Endpoint> {
    match Endpoint::new(endpoint.to_string()) {
        Ok(tonic_endpoint) => Ok(tonic_endpoint),
        Err(e) => {
            tracing::error!(
                "Could not construct OpenFGA client. Invalid endpoint `{endpoint}`: {e}"
            );
            Err(Error::InvalidEndpoint(endpoint.to_string()))
        }
    }
}
267
#[cfg(test)]
pub(crate) mod test {
    use needs_env_var::needs_env_var;

    /// Integration tests that talk to an OpenFGA instance whose gRPC URL is
    /// read from the `TEST_OPENFGA_CLIENT_GRPC_URL` environment variable.
    #[cfg(feature = "auth-middle")]
    mod openfga {
        use std::collections::{HashMap, HashSet};

        use super::super::*;
        use crate::{
            client::{
                TupleKey, WriteAuthorizationModelRequest, WriteAuthorizationModelResponse,
                WriteRequest, WriteRequestWrites,
            },
            generated::AuthorizationModel,
        };

        /// Builds an unauthenticated client for the configured test server.
        fn get_basic_client() -> BasicOpenFgaServiceClient {
            let raw_url = std::env::var("TEST_OPENFGA_CLIENT_GRPC_URL").unwrap();
            let grpc_url = url::Url::parse(&raw_url).unwrap();
            BasicOpenFgaServiceClient::new_unauthenticated(grpc_url)
                .expect("Client can be created")
        }

        /// Creates a store with a unique name and returns it with its client.
        async fn new_store() -> (BasicOpenFgaServiceClient, Store) {
            let mut client = get_basic_client();
            let unique_name = format!("store-{}", uuid::Uuid::now_v7());
            let store = client
                .get_or_create_store(&unique_name)
                .await
                .expect("Store can be created");
            (client, store)
        }

        /// Writes the sample "entitlements" authorization model into `store`.
        async fn create_entitlements_model(
            client: &mut BasicOpenFgaServiceClient,
            store: &Store,
        ) -> WriteAuthorizationModelResponse {
            let model: AuthorizationModel = serde_json::from_str(include_str!(
                "../tests/sample-store/entitlements/schema.json"
            ))
            .expect("Schema can be deserialized");

            let request = WriteAuthorizationModelRequest {
                store_id: store.id.clone(),
                type_definitions: model.type_definitions,
                schema_version: model.schema_version,
                conditions: model.conditions,
            };

            client
                .write_authorization_model(request)
                .await
                .expect("Auth model can be written")
                .into_inner()
        }

        /// Each store must be found by name even when many stores exist.
        #[tokio::test]
        async fn test_get_store_by_name_many() {
            let mut client = get_basic_client();

            let mut ids_by_name = HashMap::new();
            for _ in 0..201 {
                let name = format!("store-{}", uuid::Uuid::now_v7());
                let created = client
                    .get_or_create_store(&name)
                    .await
                    .expect("Store can be created");
                assert_eq!(name, created.name);
                ids_by_name.insert(name, created.id);
            }

            for (name, expected_id) in ids_by_name {
                let fetched = client
                    .get_store_by_name(&name)
                    .await
                    .expect("Store can be fetched")
                    .expect("Store exists");
                assert_eq!(expected_id, fetched.id);
            }
        }

        /// An unknown store name yields `None` rather than an error.
        #[tokio::test]
        async fn test_get_store_by_name_non_existant() {
            let mut client = get_basic_client();
            let lookup = client.get_store_by_name("non-existent-store").await.unwrap();
            assert!(lookup.is_none());
        }

        /// A second call with the same name returns the store created first.
        #[tokio::test]
        async fn test_get_or_create_store() {
            let mut client = get_basic_client();
            let name = format!("store-{}", uuid::Uuid::now_v7());
            let first = client.get_or_create_store(&name).await.unwrap();
            let second = client.get_or_create_store(&name).await.unwrap();
            assert_eq!(first.id, second.id);
        }

        /// 501 tuples at a page size of 100 must all be returned.
        #[tokio::test]
        async fn test_read_all_pages_many() {
            let (mut client, store) = new_store().await;
            let auth_model = create_entitlements_model(&mut client, &store).await;
            let object = "organization:org-1";

            let users: Vec<String> = (0..501).map(|i| format!("user:u-{i}")).collect();

            // One write request per tuple.
            for user in &users {
                let membership = TupleKey {
                    user: user.clone(),
                    relation: "member".to_string(),
                    object: object.to_string(),
                    condition: None,
                };
                client
                    .write(WriteRequest {
                        authorization_model_id: auth_model.authorization_model_id.clone(),
                        store_id: store.id.clone(),
                        writes: Some(WriteRequestWrites {
                            tuple_keys: vec![membership],
                        }),
                        deletes: None,
                    })
                    .await
                    .expect("Write can be done");
            }

            let filter = ReadRequestTupleKey {
                user: String::new(),
                relation: "member".to_string(),
                object: object.to_string(),
            };
            let tuples = client
                .read_all_pages(
                    &store.id,
                    filter,
                    ConsistencyPreference::HigherConsistency,
                    100,
                    6,
                )
                .await
                .expect("Read can be done");

            assert_eq!(tuples.len(), 501);

            let returned_users: HashSet<String> = tuples
                .iter()
                .map(|t| t.key.clone().unwrap().user)
                .collect();
            let expected_users: HashSet<String> = users.into_iter().collect();
            assert_eq!(returned_users, expected_users);
        }

        /// Reading from a store without tuples returns an empty list.
        #[tokio::test]
        async fn test_real_all_pages_empty() {
            let (mut client, store) = new_store().await;
            let filter = ReadRequestTupleKey {
                user: String::new(),
                relation: "member".to_string(),
                object: "organization:org-1".to_string(),
            };
            let tuples = client
                .read_all_pages(
                    &store.id,
                    filter,
                    ConsistencyPreference::HigherConsistency,
                    100,
                    5,
                )
                .await
                .expect("Read can be done");

            assert!(tuples.is_empty());
        }
    }
}