1#![allow(unused_imports)]
2
3use tonic::{
4 codegen::{Body, Bytes, StdError},
5 service::interceptor::InterceptorLayer,
6 transport::{Channel, Endpoint},
7};
8#[cfg(feature = "auth-middle")]
9use tower::{util::Either, ServiceBuilder};
10
11use crate::{
12 client::{OpenFgaClient, OpenFgaServiceClient},
13 error::{Error, Result},
14 generated::{
15 ConsistencyPreference, CreateStoreRequest, ListStoresRequest, ReadRequest,
16 ReadRequestTupleKey, Store, Tuple,
17 },
18};
19
/// OpenFGA gRPC service client whose transport is [`BasicAuthLayer`]: a tonic
/// [`Channel`] that is either used directly or wrapped in one of the
/// authentication interceptors provided by the `middle` crate.
#[cfg(feature = "auth-middle")]
pub type BasicOpenFgaServiceClient = crate::client::OpenFgaServiceClient<BasicAuthLayer>;
26
#[cfg(feature = "auth-middle")]
impl BasicOpenFgaServiceClient {
    /// Create a client that talks to OpenFGA without any authentication.
    ///
    /// The underlying channel is created with `connect_lazy`, so no connection is
    /// attempted here.
    ///
    /// # Errors
    /// * [`Error::InvalidEndpoint`] if `endpoint` cannot be converted into a tonic endpoint.
    pub fn new_unauthenticated(endpoint: impl Into<url::Url>) -> Result<Self> {
        let endpoint = get_tonic_endpoint_logged(&endpoint.into())?;
        Ok(Self::from_endpoint_with_auth(endpoint, None))
    }

    /// Create a client that attaches a static bearer `token` to every request.
    ///
    /// # Errors
    /// * [`Error::InvalidEndpoint`] if `endpoint` cannot be converted into a tonic endpoint.
    /// * [`Error::InvalidToken`] if `middle::BearerTokenAuthorizer::new` rejects `token`.
    pub fn new_with_basic_auth(endpoint: impl Into<url::Url>, token: &str) -> Result<Self> {
        let endpoint = get_tonic_endpoint_logged(&endpoint.into())?;
        let authorizer = middle::BearerTokenAuthorizer::new(token).map_err(|e| {
            tracing::error!("Could not construct OpenFGA client. Invalid token: {e}");
            Error::InvalidToken {
                reason: e.to_string(),
            }
        })?;
        let auth: EitherOrOption = Some(Either::Right(tonic::service::interceptor(authorizer)));
        Ok(Self::from_endpoint_with_auth(endpoint, auth))
    }

    /// Create a client that authenticates using the OAuth2 client-credentials flow.
    ///
    /// An initial token is obtained from `token_endpoint` while building the
    /// authorizer, so this function performs network I/O. `scopes` are only added
    /// to the request when non-empty.
    ///
    /// # Errors
    /// * [`Error::InvalidEndpoint`] if `endpoint` cannot be converted into a tonic endpoint.
    /// * [`Error::CredentialRefreshError`] if building the authorizer fails.
    pub async fn new_with_client_credentials(
        endpoint: impl Into<url::Url>,
        client_id: &str,
        client_secret: &str,
        token_endpoint: impl Into<url::Url>,
        scopes: &[&str],
    ) -> Result<Self> {
        // Validate the OpenFGA endpoint before the token round trip so that a
        // misconfigured URL fails fast without contacting the token endpoint.
        let endpoint = get_tonic_endpoint_logged(&endpoint.into())?;

        let builder = middle::BasicClientCredentialAuthorizer::basic_builder(
            client_id,
            client_secret,
            token_endpoint.into(),
        );
        let builder = if scopes.is_empty() {
            builder
        } else {
            builder.add_scopes(scopes)
        };
        let authorizer = builder.build().await.map_err(|e| {
            tracing::error!("Could not construct OpenFGA client. Failed to fetch or refresh Client Credentials: {e}");
            Error::CredentialRefreshError(e)
        })?;

        let auth: EitherOrOption = Some(Either::Left(tonic::service::interceptor(authorizer)));
        Ok(Self::from_endpoint_with_auth(endpoint, auth))
    }

    /// Wrap a lazily connecting channel for `endpoint` in the optional auth layer.
    ///
    /// `None` yields the plain channel; `Some(..)` yields the channel intercepted
    /// by the chosen authorizer. This is the single place the service stack is
    /// assembled, so every constructor produces the same [`BasicAuthLayer`] shape.
    fn from_endpoint_with_auth(endpoint: Endpoint, auth: EitherOrOption) -> Self {
        let service = ServiceBuilder::new()
            .layer(tower::util::option_layer(auth))
            .service(endpoint.connect_lazy());
        Self::new(service)
    }
}
106
107impl<T> OpenFgaServiceClient<T>
108where
109 T: tonic::client::GrpcService<tonic::body::BoxBody>,
110 T::Error: Into<StdError>,
111 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
112 <T::ResponseBody as Body>::Error: Into<StdError> + Send,
113 T: Clone,
114{
115 pub fn into_client(self, store_id: &str, authorization_model_id: &str) -> OpenFgaClient<T> {
117 OpenFgaClient::new(self, store_id, authorization_model_id)
118 }
119
120 pub async fn get_store_by_name(&mut self, store_name: &str) -> Result<Option<Store>> {
127 let stores = self
128 .list_stores(ListStoresRequest {
129 page_size: Some(2),
130 continuation_token: String::new(),
131 name: store_name.to_string(),
132 })
133 .await
134 .map_err(|e| {
135 tracing::error!("Failed to list stores in OpenFGA: {e}");
136 Error::RequestFailed(e)
137 })?
138 .into_inner();
139 let num_stores = stores.stores.len();
140
141 match stores.stores.first() {
142 Some(store) => {
143 if num_stores > 1 {
144 tracing::error!("Multiple stores with the name `{}` found", store_name);
145 Err(Error::AmbiguousStoreName(store_name.to_string()))
146 } else {
147 Ok(Some(store.clone()))
148 }
149 }
150 None => Ok(None),
151 }
152 }
153
154 pub async fn get_or_create_store(&mut self, store_name: &str) -> Result<Store> {
161 let store = self.get_store_by_name(store_name).await?;
162 match store {
163 None => {
164 tracing::debug!("OpenFGA Store {} not found. Creating it.", store_name);
165 let store = self
166 .create_store(CreateStoreRequest {
167 name: store_name.to_owned(),
168 })
169 .await
170 .map_err(|e| {
171 tracing::error!("Failed to create store in OpenFGA: {e}");
172 Error::RequestFailed(e)
173 })?
174 .into_inner();
175 Ok(Store {
176 id: store.id,
177 name: store.name,
178 created_at: store.created_at,
179 updated_at: store.updated_at,
180 deleted_at: None,
181 })
182 }
183 Some(store) => Ok(store),
184 }
185 }
186
187 pub async fn read_all_pages(
193 &mut self,
194 store_id: &str,
195 tuple: impl Into<ReadRequestTupleKey>,
196 consistency: impl Into<ConsistencyPreference>,
197 page_size: i32,
198 max_pages: u32,
199 ) -> Result<Vec<Tuple>> {
200 let mut continuation_token = String::new();
201 let mut tuples = Vec::new();
202 let mut count = 0;
203 let tuple = tuple.into();
204 let consistency = consistency.into();
205
206 loop {
207 let read_request = ReadRequest {
208 store_id: store_id.to_owned(),
209 tuple_key: Some(tuple.clone()),
210 page_size: Some(page_size),
211 continuation_token: continuation_token.clone(),
212 consistency: consistency.into(),
213 };
214 let response = self
215 .read(read_request.clone())
216 .await
217 .map_err(|e| {
218 tracing::error!(
219 "Failed to read from OpenFGA: {e}. Request: {:?}",
220 read_request
221 );
222 Error::RequestFailed(e)
223 })?
224 .into_inner();
225 tuples.extend(response.tuples);
226 continuation_token.clone_from(&response.continuation_token);
227 if continuation_token.is_empty() || count > max_pages {
228 if count > max_pages {
229 return Err(Error::TooManyPages { max_pages, tuple });
230 }
231 break;
232 }
233 count += 1;
234 }
235
236 Ok(tuples)
237 }
238}
239
/// Service stack underlying [`BasicOpenFgaServiceClient`].
///
/// * `Left(Left(_))` — channel intercepted by `middle::BasicClientCredentialAuthorizer`.
/// * `Left(Right(_))` — channel intercepted by `middle::BearerTokenAuthorizer`.
/// * `Right(_)` — the plain, unauthenticated [`Channel`].
#[cfg(feature = "auth-middle")]
pub type BasicAuthLayer = Either<
    Either<
        tonic::service::interceptor::InterceptedService<
            Channel,
            middle::BasicClientCredentialAuthorizer,
        >,
        tonic::service::interceptor::InterceptedService<Channel, middle::BearerTokenAuthorizer>,
    >,
    Channel,
>;
251
/// Optional authentication layer applied on top of the tonic channel:
/// `None` for no auth, `Left` for client credentials, `Right` for a bearer token.
#[cfg(feature = "auth-middle")]
type EitherOrOption = Option<
    tower::util::Either<
        tonic::service::interceptor::InterceptorLayer<middle::BasicClientCredentialAuthorizer>,
        tonic::service::interceptor::InterceptorLayer<middle::BearerTokenAuthorizer>,
    >,
>;
259
/// Convert `endpoint` into a tonic [`Endpoint`], logging an error when tonic rejects it.
///
/// # Errors
/// * [`Error::InvalidEndpoint`] carrying the URL string if conversion fails.
#[cfg(feature = "auth-middle")]
fn get_tonic_endpoint_logged(endpoint: &url::Url) -> Result<Endpoint> {
    match Endpoint::new(endpoint.to_string()) {
        Ok(tonic_endpoint) => Ok(tonic_endpoint),
        Err(e) => {
            tracing::error!(
                "Could not construct OpenFGA client. Invalid endpoint `{endpoint}`: {e}"
            );
            Err(Error::InvalidEndpoint(endpoint.to_string()))
        }
    }
}
267
268#[cfg(test)]
269pub(crate) mod test {
270 use needs_env_var::needs_env_var;
271
272 #[cfg(feature = "auth-middle")]
274 mod openfga {
275 use std::collections::HashSet;
276
277 use super::super::*;
278 use crate::{
279 client::{
280 TupleKey, WriteAuthorizationModelRequest, WriteAuthorizationModelResponse,
281 WriteRequest, WriteRequestWrites,
282 },
283 generated::AuthorizationModel,
284 };
285
286 fn get_basic_client() -> BasicOpenFgaServiceClient {
287 let endpoint = std::env::var("TEST_OPENFGA_CLIENT_GRPC_URL").unwrap();
288 BasicOpenFgaServiceClient::new_unauthenticated(url::Url::parse(&endpoint).unwrap())
289 .expect("Client can be created")
290 }
291
292 async fn new_store() -> (BasicOpenFgaServiceClient, Store) {
293 let mut client = get_basic_client();
294 let store_name = format!("store-{}", uuid::Uuid::now_v7());
295 let store = client
296 .get_or_create_store(&store_name)
297 .await
298 .expect("Store can be created");
299 (client, store)
300 }
301
302 async fn create_entitlements_model(
303 client: &mut BasicOpenFgaServiceClient,
304 store: &Store,
305 ) -> WriteAuthorizationModelResponse {
306 let schema = include_str!("../tests/sample-store/entitlements/schema.json");
307 let model: AuthorizationModel =
308 serde_json::from_str(schema).expect("Schema can be deserialized");
309 let auth_model = client
310 .write_authorization_model(WriteAuthorizationModelRequest {
311 store_id: store.id.clone(),
312 type_definitions: model.type_definitions,
313 schema_version: model.schema_version,
314 conditions: model.conditions,
315 })
316 .await
317 .expect("Auth model can be written");
318
319 auth_model.into_inner()
320 }
321
322 #[tokio::test]
323 async fn test_get_store_by_name_non_existant() {
324 let mut client = get_basic_client();
325 let store = client
326 .get_store_by_name("non-existent-store")
327 .await
328 .unwrap();
329 assert!(store.is_none());
330 }
331
332 #[tokio::test]
333 async fn test_get_or_create_store() {
334 let mut client = get_basic_client();
335 let store_name = format!("store-{}", uuid::Uuid::now_v7());
336 let store = client.get_or_create_store(&store_name).await.unwrap();
337 let store2 = client.get_or_create_store(&store_name).await.unwrap();
338 assert_eq!(store.id, store2.id);
339 }
340
341 #[tokio::test]
342 async fn test_read_all_pages_many() {
343 let (mut client, store) = new_store().await;
344 let auth_model = create_entitlements_model(&mut client, &store).await;
345 let object = "organization:org-1";
346
347 let users = (0..501)
348 .map(|i| format!("user:u-{i}"))
349 .collect::<Vec<String>>();
350
351 for user in &users {
352 client
353 .write(WriteRequest {
354 authorization_model_id: auth_model.authorization_model_id.clone(),
355 store_id: store.id.clone(),
356 writes: Some(WriteRequestWrites {
357 tuple_keys: vec![TupleKey {
358 user: user.to_string(),
359 relation: "member".to_string(),
360 object: object.to_string(),
361 condition: None,
362 }],
363 }),
364 deletes: None,
365 })
366 .await
367 .expect("Write can be done");
368 }
369
370 let tuples = client
371 .read_all_pages(
372 &store.id,
373 ReadRequestTupleKey {
374 user: String::new(),
375 relation: "member".to_string(),
376 object: object.to_string(),
377 },
378 ConsistencyPreference::HigherConsistency,
379 100,
380 6,
381 )
382 .await
383 .expect("Read can be done");
384
385 assert_eq!(tuples.len(), 501);
386 assert_eq!(
387 tuples
388 .iter()
389 .map(|t| t.key.clone().unwrap().user)
390 .collect::<HashSet<String>>(),
391 HashSet::from_iter(users)
392 );
393 }
394
395 #[tokio::test]
396 async fn test_real_all_pages_empty() {
397 let (mut client, store) = new_store().await;
398 let tuples = client
399 .read_all_pages(
400 &store.id,
401 ReadRequestTupleKey {
402 user: String::new(),
403 relation: "member".to_string(),
404 object: "organization:org-1".to_string(),
405 },
406 ConsistencyPreference::HigherConsistency,
407 100,
408 5,
409 )
410 .await
411 .expect("Read can be done");
412
413 assert!(tuples.is_empty());
414 }
415 }
416}