1#![allow(unused_imports)]
2
3use tonic::{
4 codegen::{Body, Bytes, StdError},
5 service::interceptor::InterceptorLayer,
6 transport::{Channel, Endpoint},
7};
8#[cfg(feature = "auth-middle")]
9use tower::{util::Either, ServiceBuilder};
10
11use crate::{
12 client::OpenFgaServiceClient,
13 error::{Error, Result},
14 generated::{
15 ConsistencyPreference, CreateStoreRequest, ListStoresRequest, ReadRequest,
16 ReadRequestTupleKey, Store, Tuple,
17 },
18};
19
/// OpenFGA gRPC client over a tonic [`Channel`] wrapped in an optional authentication
/// interceptor.
///
/// Instances are built by `new_unauthenticated`, `new_with_basic_auth` and
/// `new_with_client_credentials`.
#[cfg(feature = "auth-middle")]
pub type BasicOpenFgaServiceClient = crate::client::OpenFgaServiceClient<AuthLayer>;
26
#[cfg(feature = "auth-middle")]
impl BasicOpenFgaServiceClient {
    /// Create a client that sends requests without any authentication.
    ///
    /// The channel is created with `connect_lazy`, so no connection is attempted here.
    ///
    /// # Errors
    /// * [`Error::InvalidEndpoint`] if `endpoint` cannot be parsed into a tonic [`Endpoint`].
    pub fn new_unauthenticated(endpoint: &str) -> Result<Self> {
        let endpoint = get_tonic_endpoint_logged(endpoint)?;
        Ok(Self::from_endpoint_and_auth(endpoint, None))
    }

    /// Create a client that authenticates every request with a static token through
    /// `middle::BearerTokenAuthorizer`.
    ///
    /// The channel is created with `connect_lazy`, so no connection is attempted here.
    ///
    /// # Errors
    /// * [`Error::InvalidEndpoint`] if `endpoint` cannot be parsed into a tonic [`Endpoint`].
    /// * [`Error::InvalidToken`] if `middle::BearerTokenAuthorizer` rejects `token`.
    pub fn new_with_basic_auth(endpoint: &str, token: &str) -> Result<Self> {
        let endpoint = get_tonic_endpoint_logged(endpoint)?;
        let authorizer = middle::BearerTokenAuthorizer::new(token).map_err(|e| {
            tracing::error!("Could not construct OpenFGA client. Invalid token: {e}");
            Error::InvalidToken {
                reason: e.to_string(),
            }
        })?;
        Ok(Self::from_endpoint_and_auth(
            endpoint,
            Some(Either::Right(tonic::service::interceptor(authorizer))),
        ))
    }

    /// Create a client that authenticates through `middle::BasicClientCredentialAuthorizer`,
    /// configured with `client_id`, `client_secret` and `token_endpoint`.
    ///
    /// The endpoint string is validated first, so an invalid endpoint is reported without
    /// awaiting the authorizer construction (which may contact `token_endpoint`).
    ///
    /// # Errors
    /// * [`Error::InvalidEndpoint`] if `endpoint` cannot be parsed into a tonic [`Endpoint`].
    /// * [`Error::CredentialRefreshError`] if building the authorizer fails.
    pub async fn new_with_client_credentials(
        endpoint: &str,
        client_id: &str,
        client_secret: &str,
        token_endpoint: &url::Url,
    ) -> Result<Self> {
        let endpoint = get_tonic_endpoint_logged(endpoint)?;
        let authorizer = middle::BasicClientCredentialAuthorizer::basic_builder(
            client_id,
            client_secret,
            token_endpoint.clone(),
        )
        .build()
        .await
        .map_err(|e| {
            tracing::error!("Could not construct OpenFGA client. Failed to fetch or refresh Client Credentials: {e}");
            Error::CredentialRefreshError(e)
        })?;
        Ok(Self::from_endpoint_and_auth(
            endpoint,
            Some(Either::Left(tonic::service::interceptor(authorizer))),
        ))
    }

    /// Shared tail of all constructors: wrap a lazily connecting channel to `endpoint` in
    /// the optional authentication interceptor (`None` leaves the channel untouched).
    fn from_endpoint_and_auth(endpoint: Endpoint, auth: EitherOrOption) -> Self {
        let service = ServiceBuilder::new()
            .layer(tower::util::option_layer(auth))
            .service(endpoint.connect_lazy());
        Self::new(service)
    }
}
98
99impl<T> OpenFgaServiceClient<T>
100where
101 T: tonic::client::GrpcService<tonic::body::BoxBody>,
102 T::Error: Into<StdError>,
103 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
104 <T::ResponseBody as Body>::Error: Into<StdError> + Send,
105{
106 pub async fn get_store_by_name(&mut self, store_name: &str) -> Result<Option<Store>> {
113 let stores = self
114 .list_stores(ListStoresRequest {
115 page_size: Some(2),
116 continuation_token: String::new(),
117 name: store_name.to_string(),
118 })
119 .await
120 .map_err(|e| {
121 tracing::error!("Failed to list stores in OpenFGA: {e}");
122 Error::RequestFailed(e)
123 })?
124 .into_inner();
125 let num_stores = stores.stores.len();
126
127 match stores.stores.first() {
128 Some(store) => {
129 if num_stores > 1 {
130 tracing::error!("Multiple stores with the name `{}` found", store_name);
131 Err(Error::AmbiguousStoreName(store_name.to_string()))
132 } else {
133 Ok(Some(store.clone()))
134 }
135 }
136 None => Ok(None),
137 }
138 }
139
140 pub async fn get_or_create_store(&mut self, store_name: &str) -> Result<Store> {
147 let store = self.get_store_by_name(store_name).await?;
148 match store {
149 None => {
150 tracing::debug!("OpenFGA Store {} not found. Creating it.", store_name);
151 let store = self
152 .create_store(CreateStoreRequest {
153 name: store_name.to_owned(),
154 })
155 .await
156 .map_err(|e| {
157 tracing::error!("Failed to create store in OpenFGA: {e}");
158 Error::RequestFailed(e)
159 })?
160 .into_inner();
161 Ok(Store {
162 id: store.id,
163 name: store.name,
164 created_at: store.created_at,
165 updated_at: store.updated_at,
166 deleted_at: None,
167 })
168 }
169 Some(store) => Ok(store),
170 }
171 }
172
173 pub async fn read_all_pages(
179 &mut self,
180 store_id: &str,
181 tuple: ReadRequestTupleKey,
182 consistency: ConsistencyPreference,
183 page_size: i32,
184 max_pages: u32,
185 ) -> Result<Vec<Tuple>> {
186 let mut continuation_token = String::new();
187 let mut tuples = Vec::new();
188 let mut count = 0;
189
190 loop {
191 let read_request = ReadRequest {
192 store_id: store_id.to_owned(),
193 tuple_key: Some(tuple.clone()),
194 page_size: Some(page_size),
195 continuation_token: continuation_token.clone(),
196 consistency: consistency.into(),
197 };
198 let response = self
199 .read(read_request.clone())
200 .await
201 .map_err(|e| {
202 tracing::error!(
203 "Failed to read from OpenFGA: {e}. Request: {:?}",
204 read_request
205 );
206 Error::RequestFailed(e)
207 })?
208 .into_inner();
209 tuples.extend(response.tuples);
210 continuation_token.clone_from(&response.continuation_token);
211 if continuation_token.is_empty() || count > max_pages {
212 if count > max_pages {
213 return Err(Error::TooManyPages { max_pages, tuple });
214 }
215 break;
216 }
217 count += 1;
218 }
219
220 Ok(tuples)
221 }
222}
223
/// Service type behind [`BasicOpenFgaServiceClient`]: the result of applying
/// `tower::util::option_layer` with an [`EitherOrOption`] on top of a tonic [`Channel`].
///
/// * `Left(Left(_))`: channel intercepted by the client-credentials authorizer.
/// * `Left(Right(_))`: channel intercepted by the bearer-token authorizer.
/// * `Right(_)`: plain channel, with no authentication layer.
#[cfg(feature = "auth-middle")]
type AuthLayer = Either<
    Either<
        tonic::service::interceptor::InterceptedService<
            Channel,
            middle::BasicClientCredentialAuthorizer,
        >,
        tonic::service::interceptor::InterceptedService<Channel, middle::BearerTokenAuthorizer>,
    >,
    Channel,
>;
235
/// Optional authentication interceptor layer, applied via `tower::util::option_layer`.
///
/// * `None`: no authentication.
/// * `Some(Left(_))`: client-credentials authorizer.
/// * `Some(Right(_))`: static bearer-token authorizer.
#[cfg(feature = "auth-middle")]
type EitherOrOption = Option<
    tower::util::Either<
        tonic::service::interceptor::InterceptorLayer<middle::BasicClientCredentialAuthorizer>,
        tonic::service::interceptor::InterceptorLayer<middle::BearerTokenAuthorizer>,
    >,
>;
243
/// Parse `endpoint` into a tonic [`Endpoint`].
///
/// # Errors
/// [`Error::InvalidEndpoint`], carrying the original string, if tonic rejects it. The
/// underlying parse error is logged, not returned.
#[cfg(feature = "auth-middle")]
fn get_tonic_endpoint_logged(endpoint: &str) -> Result<Endpoint> {
    match Endpoint::new(endpoint.to_string()) {
        Ok(parsed) => Ok(parsed),
        Err(e) => {
            tracing::error!("Could not construct OpenFGA client. Invalid endpoint `{endpoint}`: {e}");
            Err(Error::InvalidEndpoint(endpoint.to_string()))
        }
    }
}
251
#[cfg(test)]
pub(crate) mod test {
    use needs_env_var::needs_env_var;

    /// Integration tests against a live OpenFGA server whose gRPC URL is taken from
    /// `TEST_OPENFGA_CLIENT_GRPC_URL`. The module is gated on that variable via
    /// `needs_env_var`.
    #[needs_env_var(TEST_OPENFGA_CLIENT_GRPC_URL)]
    #[cfg(feature = "auth-middle")]
    mod openfga {
        use std::collections::HashSet;

        use super::super::*;
        use crate::{
            client::{
                TupleKey, WriteAuthorizationModelRequest, WriteAuthorizationModelResponse,
                WriteRequest, WriteRequestWrites,
            },
            generated::AuthorizationModel,
        };

        /// Tuples per `Write` request. OpenFGA's default per-write tuple limit is 100.
        const WRITE_BATCH_SIZE: usize = 100;

        /// Unauthenticated client for the test server.
        fn get_basic_client() -> BasicOpenFgaServiceClient {
            let endpoint = std::env::var("TEST_OPENFGA_CLIENT_GRPC_URL").unwrap();
            BasicOpenFgaServiceClient::new_unauthenticated(&endpoint)
                .expect("Client can be created")
        }

        /// Fresh store with a unique (UUIDv7-based) name, plus the client that created it.
        async fn new_store() -> (BasicOpenFgaServiceClient, Store) {
            let mut client = get_basic_client();
            let store_name = format!("store-{}", uuid::Uuid::now_v7());
            let store = client
                .get_or_create_store(&store_name)
                .await
                .expect("Store can be created");
            (client, store)
        }

        /// Write the sample entitlements authorization model into `store`.
        async fn create_entitlements_model(
            client: &mut BasicOpenFgaServiceClient,
            store: &Store,
        ) -> WriteAuthorizationModelResponse {
            let schema = include_str!("../tests/sample-store/entitlements/schema.json");
            let model: AuthorizationModel =
                serde_json::from_str(schema).expect("Schema can be deserialized");
            let auth_model = client
                .write_authorization_model(WriteAuthorizationModelRequest {
                    store_id: store.id.clone(),
                    type_definitions: model.type_definitions,
                    schema_version: model.schema_version,
                    conditions: model.conditions,
                })
                .await
                .expect("Auth model can be written");

            auth_model.into_inner()
        }

        #[tokio::test]
        async fn test_get_store_by_name_non_existent() {
            let mut client = get_basic_client();
            let store = client
                .get_store_by_name("non-existent-store")
                .await
                .unwrap();
            assert!(store.is_none());
        }

        #[tokio::test]
        async fn test_get_or_create_store() {
            let mut client = get_basic_client();
            let store_name = format!("store-{}", uuid::Uuid::now_v7());
            let store = client.get_or_create_store(&store_name).await.unwrap();
            let store2 = client.get_or_create_store(&store_name).await.unwrap();
            assert_eq!(store.id, store2.id);
        }

        #[tokio::test]
        async fn test_read_all_pages_many() {
            let (mut client, store) = new_store().await;
            let auth_model = create_entitlements_model(&mut client, &store).await;
            let object = "organization:org-1";

            let users = (0..501)
                .map(|i| format!("user:u-{i}"))
                .collect::<Vec<String>>();

            // Batch the writes: a handful of requests instead of one round-trip per tuple.
            for batch in users.chunks(WRITE_BATCH_SIZE) {
                client
                    .write(WriteRequest {
                        authorization_model_id: auth_model.authorization_model_id.clone(),
                        store_id: store.id.clone(),
                        writes: Some(WriteRequestWrites {
                            tuple_keys: batch
                                .iter()
                                .map(|user| TupleKey {
                                    user: user.clone(),
                                    relation: "member".to_string(),
                                    object: object.to_string(),
                                    condition: None,
                                })
                                .collect(),
                        }),
                        deletes: None,
                    })
                    .await
                    .expect("Write can be done");
            }

            let tuple_key = ReadRequestTupleKey {
                user: String::new(),
                relation: "member".to_string(),
                object: object.to_string(),
            };

            // 501 tuples at 100 per page fit into 6 pages.
            let tuples = client
                .read_all_pages(
                    &store.id,
                    tuple_key.clone(),
                    ConsistencyPreference::HigherConsistency,
                    100,
                    6,
                )
                .await
                .expect("Read can be done");

            assert_eq!(tuples.len(), 501);
            assert_eq!(
                tuples
                    .iter()
                    .map(|t| t.key.as_ref().expect("Tuple has a key").user.clone())
                    .collect::<HashSet<String>>(),
                users.into_iter().collect::<HashSet<String>>()
            );

            // 6 pages are needed, so a budget of 4 must be rejected.
            let err = client
                .read_all_pages(
                    &store.id,
                    tuple_key,
                    ConsistencyPreference::HigherConsistency,
                    100,
                    4,
                )
                .await
                .expect_err("Reading 501 tuples with max_pages = 4 must fail");
            assert!(matches!(err, Error::TooManyPages { max_pages: 4, .. }));
        }

        #[tokio::test]
        async fn test_read_all_pages_empty() {
            let (mut client, store) = new_store().await;
            let tuples = client
                .read_all_pages(
                    &store.id,
                    ReadRequestTupleKey {
                        user: String::new(),
                        relation: "member".to_string(),
                        object: "organization:org-1".to_string(),
                    },
                    ConsistencyPreference::HigherConsistency,
                    100,
                    5,
                )
                .await
                .expect("Read can be done");

            assert!(tuples.is_empty());
        }
    }
}