1use std::{
2 collections::{HashMap, HashSet},
3 future::Future,
4 hash::Hash,
5 pin::Pin,
6 str::FromStr,
7 sync::Arc,
8};
9
10use futures_timer::Delay;
11use tonic::codegen::{Body, Bytes, StdError};
12
13use crate::{
14 client::{
15 AuthorizationModel, ConsistencyPreference, OpenFgaServiceClient, ReadRequestTupleKey,
16 Store, Tuple, TupleKey, WriteAuthorizationModelResponse, WriteRequest, WriteRequestWrites,
17 },
18 error::{Error, Result},
19};
20
/// Page size passed to `read_all_pages` when reading bookkeeping tuples from OpenFGA.
const DEFAULT_PAGE_SIZE: i32 = 100;
/// Page limit passed to `read_all_pages` when reading bookkeeping tuples from OpenFGA.
const MAX_PAGES: u32 = 1000;
23
/// Version of an authorization model, written and parsed as `major.minor`.
///
/// Versions are ordered by `major` first and then by `minor`.
#[derive(Debug, Copy, Clone, PartialEq, Hash, Eq)]
pub struct AuthorizationModelVersion {
    /// Major component (before the dot).
    major: u32,
    /// Minor component (after the dot).
    minor: u32,
}
29
/// An authorization model paired with the version it is registered under.
#[derive(Debug, Clone, PartialEq)]
struct VersionedAuthorizationModel {
    /// The model that is written to OpenFGA when this version is migrated.
    model: AuthorizationModel,
    /// The version this model is registered under.
    version: AuthorizationModelVersion,
}
35
/// Applies versioned OpenFGA authorization models to a store and records which versions
/// have been applied, using tuples inside the store itself as bookkeeping.
///
/// Models are registered with `add_model` and applied in ascending version order by
/// `migrate`.
#[derive(Debug, Clone)]
pub struct TupleModelManager<T, S>
where
    T: tonic::client::GrpcService<tonic::body::Body>,
    T::Error: Into<StdError>,
    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
    <T::ResponseBody as Body>::Error: Into<StdError> + Send,
{
    /// Client used for every OpenFGA request.
    client: OpenFgaServiceClient<T>,
    /// Name of the OpenFGA store that is migrated.
    store_name: String,
    /// Prefix embedded in the bookkeeping object ids (`model_version:<prefix>-<version>`).
    model_prefix: String,
    /// Registered migrations, keyed by their model version.
    migrations: HashMap<AuthorizationModelVersion, Migration<T, S>>,
}
72
/// A registered model version together with its optional hooks.
#[derive(Clone)]
struct Migration<T, S> {
    /// The model to write and the version it is registered under.
    model: VersionedAuthorizationModel,
    /// Hook run before the model is written.
    pre_migration_fn: Option<BoxedMigrationFn<T, S>>,
    /// Hook run after the model is written and before the version is recorded as applied.
    post_migration_fn: Option<BoxedMigrationFn<T, S>>,
}
82
/// A heap-allocated, pinned, type-erased future that is `Send` and valid for `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
85
/// Function-pointer form of a pre- or post-migration hook.
///
/// Arguments, in order:
/// - a clone of the OpenFGA client,
/// - the previously active authorization model id,
/// - the currently active authorization model id,
/// - a clone of the user-supplied migration state.
///
/// It is also handy for spelling an absent hook, e.g. `None::<MigrationFn<_, _>>`.
pub type MigrationFn<T, S> = fn(
    OpenFgaServiceClient<T>,
    Option<String>,
    Option<String>,
    S,
) -> BoxFuture<'static, std::result::Result<(), StdError>>;
104
/// Type-erased migration hook signature, as stored by the manager.
///
/// Takes the same arguments as [`MigrationFn`] and returns a boxed future.
type DynMigrationFn<T, S> = dyn Fn(
    OpenFgaServiceClient<T>,
    Option<String>,
    Option<String>,
    S,
) -> BoxFuture<'static, std::result::Result<(), StdError>>;
112
/// Shared, type-erased migration hook; cloning it only clones the `Arc`.
type BoxedMigrationFn<T, S> = Arc<DynMigrationFn<T, S>>;
115
116fn box_migration_fn<T, S, F, Fut>(f: F) -> BoxedMigrationFn<T, S>
118where
119 F: Fn(OpenFgaServiceClient<T>, Option<String>, Option<String>, S) -> Fut + Send + 'static,
120 Fut: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
121{
122 Arc::new(
123 move |client, prev_auth_model_id, active_auth_model_id, state| {
124 Box::pin(f(client, prev_auth_model_id, active_auth_model_id, state))
125 },
126 )
127}
128
129impl<T, S> TupleModelManager<T, S>
130where
131 T: tonic::client::GrpcService<tonic::body::Body>,
132 T: Clone,
133 T::Error: Into<StdError>,
134 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
135 <T::ResponseBody as Body>::Error: Into<StdError> + Send,
136 S: Clone,
137{
138 const AUTH_MODEL_ID_TYPE: &'static str = "auth_model_id";
139 const MODEL_VERSION_EXISTS_RELATION: &'static str = "exists";
140 const MODEL_VERSION_TYPE: &'static str = "model_version";
141 const MODEL_VERSION_OPENFGA_ID_RELATION: &'static str = "openfga_id";
142
143 pub fn new(client: OpenFgaServiceClient<T>, store_name: &str, model_prefix: &str) -> Self {
149 TupleModelManager {
150 client,
151 model_prefix: model_prefix.to_string(),
152 store_name: store_name.to_string(),
153 migrations: HashMap::new(),
154 }
155 }
156
157 #[must_use]
163 pub fn add_model<FutPre, FutPost>(
164 mut self,
165 model: AuthorizationModel,
166 version: AuthorizationModelVersion,
167 pre_migration_fn: Option<
168 impl Fn(OpenFgaServiceClient<T>, Option<String>, Option<String>, S) -> FutPre
169 + Send
170 + 'static,
171 >,
172 post_migration_fn: Option<
173 impl Fn(OpenFgaServiceClient<T>, Option<String>, Option<String>, S) -> FutPost
174 + Send
175 + 'static,
176 >,
177 ) -> Self
178 where
179 FutPre: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
180 FutPost: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
181 {
182 let migration = Migration {
183 model: VersionedAuthorizationModel::new(model, version),
184 pre_migration_fn: pre_migration_fn.map(box_migration_fn),
185 post_migration_fn: post_migration_fn.map(box_migration_fn),
186 };
187 self.migrations.insert(migration.model.version(), migration);
188 self
189 }
190
191 #[allow(clippy::too_many_lines)]
206 pub async fn migrate(&mut self, state: S) -> Result<()> {
207 let span = tracing::span!(
208 tracing::Level::INFO,
209 "Running OpenFGA Migrations",
210 store_name = self.store_name,
211 model_prefix = self.model_prefix
212 );
213 let _enter = span.enter();
214
215 if self.migrations.is_empty() {
216 tracing::info!("No Migrations have been added. Nothing to do.");
217 return Ok(());
218 }
219
220 let store = self.client.get_or_create_store(&self.store_name).await?;
221 let existing_models = self.get_existing_versions().await?;
222 let max_existing_model = existing_models.iter().max().copied();
223 let mut curr_model_id = if let Some(version) = max_existing_model {
224 Some(self.require_authorization_model_id(version).await?)
225 } else {
226 None
227 };
228 let mut prev_model_id = None;
231
232 if let Some(max_existing_model) = max_existing_model {
233 tracing::info!(
234 "Currently the highest existing Model Version is: {}",
235 max_existing_model
236 );
237 } else {
238 tracing::info!("No model found in OpenFGA store");
239 }
240
241 let ordered_migrations = self.migrations_to_perform(max_existing_model);
242
243 let mut client = self.client.clone();
244 for migration in ordered_migrations {
245 tracing::info!("Migrating to model version: {}", migration.model.version());
246
247 if let Some(pre_migration_fn) = migration.pre_migration_fn.as_ref() {
249 pre_migration_fn(
250 client.clone(),
251 prev_model_id.clone(),
252 curr_model_id.clone(),
253 state.clone(),
254 )
255 .await
256 .map_err(|e| {
257 tracing::error!("Error in OpenFGA pre-migration hook: {:?}", e);
258 Error::MigrationHookFailed {
259 version: migration.model.version().to_string(),
260 error: Arc::new(e),
261 }
262 })?;
263 }
264
265 let request = migration
267 .model
268 .model()
269 .clone()
270 .into_write_request(store.id.clone());
271 let written_model = client
272 .write_authorization_model(request)
273 .await
274 .map_err(|e| {
275 tracing::error!("Error writing model: {:?}", e);
276 Error::RequestFailed(Box::new(e))
277 })?;
278 tracing::info!(
279 "Model version {} written to OpenFGA store {} with model id {}",
280 migration.model.version(),
281 self.store_name,
282 written_model.get_ref().authorization_model_id,
283 );
284 tracing::debug!("Model written: {:?}", written_model);
285
286 prev_model_id.clone_from(&curr_model_id);
288 curr_model_id = Some(written_model.get_ref().authorization_model_id.clone());
289
290 if let Some(post_migration_fn) = migration.post_migration_fn.as_ref() {
292 post_migration_fn(
293 client.clone(),
294 prev_model_id.clone(),
295 curr_model_id.clone(),
296 state.clone(),
297 )
298 .await
299 .map_err(|e| {
300 tracing::error!("Error in OpenFGA post-migration hook: {:?}", e);
301 Error::MigrationHookFailed {
302 version: migration.model.version().to_string(),
303 error: Arc::new(e),
304 }
305 })?;
306 }
307
308 Self::mark_as_applied(
310 &mut client,
311 &self.model_prefix,
312 &store,
313 migration.model.version(),
314 written_model.into_inner(),
315 )
316 .await?;
317 }
318
319 Ok(())
320 }
321
322 pub async fn get_authorization_model_id(
329 &mut self,
330 version: AuthorizationModelVersion,
331 ) -> Result<Option<String>> {
332 let store = self
333 .client
334 .get_store_by_name(&self.store_name)
335 .await?
336 .ok_or_else(|| {
337 tracing::error!("Store with name {} not found", self.store_name);
338 Error::StoreNotFound(self.store_name.clone())
339 })?;
340
341 let applied_models = self
342 .client
343 .read_all_pages(
344 &store.id,
345 Some(ReadRequestTupleKey {
346 user: String::new(),
347 relation: Self::MODEL_VERSION_OPENFGA_ID_RELATION.to_string(),
348 object: Self::format_model_version_key(&self.model_prefix, version),
349 }),
350 ConsistencyPreference::HigherConsistency,
351 DEFAULT_PAGE_SIZE,
352 MAX_PAGES,
353 )
354 .await?;
355
356 let applied_models = applied_models
357 .into_iter()
358 .filter_map(|t| t.key)
359 .filter_map(|t| {
360 t.user
361 .strip_prefix(&format!("{}:", Self::AUTH_MODEL_ID_TYPE))
362 .map(ToString::to_string)
363 })
364 .collect::<Vec<_>>();
365
366 if applied_models.len() > 1 {
367 tracing::error!(
368 "Multiple authorization models with model prefix {} for version {} found.",
369 self.model_prefix,
370 version
371 );
372 return Err(Error::AmbiguousModelVersion {
373 model_prefix: self.model_prefix.clone(),
374 version: version.to_string(),
375 });
376 }
377
378 let model_id = applied_models.into_iter().next().map(|openfga_id| {
379 tracing::info!(
380 "Authorization model for version {version} found in OpenFGA store {}. Model ID: {openfga_id}",
381 self.store_name,
382 );
383 openfga_id
384 });
385
386 Ok(model_id)
387 }
388
389 async fn require_authorization_model_id(
392 &mut self,
393 version: AuthorizationModelVersion,
394 ) -> Result<String> {
395 self.get_authorization_model_id(version)
396 .await?
397 .ok_or_else(|| {
398 tracing::error!("Missing authorization model id for model version {version}");
399 Error::MissingAuthorizationModelId {
400 model_prefix: self.model_prefix.clone(),
401 version: version.to_string(),
402 }
403 })
404 }
405
406 async fn mark_as_applied(
408 client: &mut OpenFgaServiceClient<T>,
409 model_prefix: &str,
410 store: &Store,
411 version: AuthorizationModelVersion,
412 write_response: WriteAuthorizationModelResponse,
413 ) -> Result<()> {
414 let authorization_model_id = write_response.authorization_model_id;
415 let object = Self::format_model_version_key(model_prefix, version);
416
417 let write_request = WriteRequest {
418 store_id: store.id.clone(),
419 writes: Some(WriteRequestWrites {
420 on_duplicate: String::new(),
421 tuple_keys: vec![
422 TupleKey {
423 user: format!("{}:{authorization_model_id}", Self::AUTH_MODEL_ID_TYPE),
424 relation: Self::MODEL_VERSION_OPENFGA_ID_RELATION.to_string(),
425 object: object.clone(),
426 condition: None,
427 },
428 TupleKey {
429 user: format!("{}:*", Self::AUTH_MODEL_ID_TYPE),
430 relation: Self::MODEL_VERSION_EXISTS_RELATION.to_string(),
431 object,
432 condition: None,
433 },
434 ],
435 }),
436 deletes: None,
437 authorization_model_id: authorization_model_id.clone(),
438 };
439
440 let max_retries = 5;
442 let retry_delay = std::time::Duration::from_secs(1);
443
444 for attempt in 0..max_retries {
445 match client.write(write_request.clone()).await {
446 Ok(_) => {
447 if attempt > 0 {
448 tracing::info!(
449 "Successfully marked model {version} as applied after {attempt} retries",
450 );
451 }
452 return Ok(());
453 }
454 Err(e) => {
455 if attempt == max_retries {
456 tracing::error!(
457 "Error marking model as applied after {} retries: {:?}",
458 max_retries,
459 e
460 );
461 return Err(Error::RequestFailed(Box::new(e)));
462 }
463
464 tracing::warn!(
465 "Failed to mark model as applied (attempt {}/{max_retries}), retrying in {retry_delay:?}: {e:?}",
466 attempt + 1,
467 );
468
469 Delay::new(retry_delay).await;
470 }
471 }
472 }
473
474 unreachable!();
475 }
476
477 fn ordered_migrations(&self) -> Vec<&Migration<T, S>> {
480 let mut migrations = self.migrations.values().collect::<Vec<_>>();
481 migrations.sort_unstable_by_key(|m| m.model.version());
482 migrations
483 }
484
485 fn migrations_to_perform(
487 &self,
488 max_existing_model: Option<AuthorizationModelVersion>,
489 ) -> Vec<&Migration<T, S>> {
490 let ordered_migrations = self.ordered_migrations();
491 let migrations_to_perform = ordered_migrations
492 .into_iter()
493 .filter(|m| {
494 max_existing_model.is_none_or(|max_existing| m.model.version() > max_existing)
495 })
496 .collect::<Vec<_>>();
497
498 tracing::info!(
499 "{} migrations needed in OpenFGA store {} for model-prefix {}",
500 migrations_to_perform.len(),
501 self.store_name,
502 self.model_prefix
503 );
504 migrations_to_perform
505 }
506
507 pub async fn get_existing_versions(&mut self) -> Result<Vec<AuthorizationModelVersion>> {
514 let Some(store) = self.client.get_store_by_name(&self.store_name).await? else {
515 return Ok(vec![]);
516 };
517
518 let tuples = self
519 .client
520 .read_all_pages(
521 &store.id,
522 Some(ReadRequestTupleKey {
523 user: format!("{}:*", Self::AUTH_MODEL_ID_TYPE).to_string(),
524 relation: Self::MODEL_VERSION_EXISTS_RELATION.to_string(),
525 object: format!("{}:", Self::MODEL_VERSION_TYPE).to_string(),
526 }),
527 crate::client::ConsistencyPreference::HigherConsistency,
528 DEFAULT_PAGE_SIZE,
529 MAX_PAGES,
530 )
531 .await?;
532 let existing_models = Self::parse_existing_models(tuples, &self.model_prefix);
533 Ok(existing_models.into_iter().collect())
534 }
535
536 fn parse_existing_models(
537 exist_tuples: Vec<Tuple>,
538 model_prefix: &str,
539 ) -> HashSet<AuthorizationModelVersion> {
540 exist_tuples
541 .into_iter()
542 .filter_map(|t| t.key)
543 .filter_map(|t| Self::parse_model_version_from_key(&t.object, model_prefix))
544 .collect()
545 }
546
547 fn parse_model_version_from_key(
548 model: &str,
549 model_prefix: &str,
550 ) -> Option<AuthorizationModelVersion> {
551 model
552 .strip_prefix(&format!("{}:", Self::MODEL_VERSION_TYPE))
554 .and_then(|model| {
555 model
556 .strip_prefix(&format!("{model_prefix}-"))
557 .and_then(|version| AuthorizationModelVersion::from_str(version).ok())
558 })
559 }
560
561 fn format_model_version_key(model_prefix: &str, version: AuthorizationModelVersion) -> String {
562 format!("{}:{}-{}", Self::MODEL_VERSION_TYPE, model_prefix, version)
563 }
564}
565
566impl VersionedAuthorizationModel {
567 pub(crate) fn new(model: AuthorizationModel, version: AuthorizationModelVersion) -> Self {
568 VersionedAuthorizationModel { model, version }
569 }
570
571 pub(crate) fn version(&self) -> AuthorizationModelVersion {
572 self.version
573 }
574
575 pub(crate) fn model(&self) -> &AuthorizationModel {
576 &self.model
577 }
578}
579
580impl AuthorizationModelVersion {
581 #[must_use]
582 pub fn new(major: u32, minor: u32) -> Self {
583 AuthorizationModelVersion { major, minor }
584 }
585
586 #[must_use]
587 pub fn major(&self) -> u32 {
588 self.major
589 }
590
591 #[must_use]
592 pub fn minor(&self) -> u32 {
593 self.minor
594 }
595}
596
597impl PartialOrd for AuthorizationModelVersion {
599 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
600 Some(self.cmp(other))
601 }
602}
603
604impl Ord for AuthorizationModelVersion {
605 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
606 (self.major, self.minor).cmp(&(other.major, other.minor))
607 }
608}
609
610impl std::fmt::Display for AuthorizationModelVersion {
611 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
612 write!(f, "{}.{}", self.major, self.minor)
613 }
614}
615
616impl FromStr for AuthorizationModelVersion {
617 type Err = Error;
618
619 fn from_str(s: &str) -> Result<Self> {
620 let parts = s.split('.').collect::<Vec<_>>();
621 if parts.len() != 2 {
622 return Err(Error::InvalidModelVersion(s.to_string()));
623 }
624
625 let major = parts[0]
626 .parse()
627 .map_err(|_| Error::InvalidModelVersion(s.to_string()))?;
628 let minor = parts[1]
629 .parse()
630 .map_err(|_| Error::InvalidModelVersion(s.to_string()))?;
631
632 Ok(AuthorizationModelVersion::new(major, minor))
633 }
634}
635
636impl<T, S> std::fmt::Debug for Migration<T, S> {
637 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
638 f.debug_struct("Migration")
639 .field("model", &self.model)
640 .field("pre_migration_fn", &"...")
641 .field("post_migration_fn", &"...")
642 .finish()
643 }
644}
645
#[cfg(test)]
pub(crate) mod test {
    use std::sync::Mutex;

    use needs_env_var::needs_env_var;
    use pretty_assertions::assert_eq;

    use super::*;

    // Concrete manager type used to call the associated (non-async) helpers in unit tests.
    type ChannelTupleManager = TupleModelManager<tonic::transport::Channel, ()>;

    #[test]
    fn test_ordering() {
        let versioned_1_0 = AuthorizationModelVersion::new(1, 0);
        let versioned_1_1 = AuthorizationModelVersion::new(1, 1);
        let versioned_2_0 = AuthorizationModelVersion::new(2, 0);
        let versioned_2_1 = AuthorizationModelVersion::new(2, 1);
        let versioned_2_2 = AuthorizationModelVersion::new(2, 2);

        assert!(versioned_1_0 < versioned_1_1);
        assert!(versioned_1_1 < versioned_2_0);
        assert!(versioned_2_0 < versioned_2_1);
        assert!(versioned_2_1 < versioned_2_2);
    }

    #[test]
    fn test_auth_model_version_str() {
        // Display and FromStr must round-trip.
        let version = AuthorizationModelVersion::new(1, 0);
        assert_eq!(version.to_string(), "1.0");
        assert_eq!("1.0".parse::<AuthorizationModelVersion>().unwrap(), version);

        let version = AuthorizationModelVersion::new(10, 2);
        assert_eq!(version.to_string(), "10.2");
        assert_eq!(
            "10.2".parse::<AuthorizationModelVersion>().unwrap(),
            version
        );
    }

    #[test]
    fn test_parse_model_version_from_key() {
        let model_prefix = "test";
        let model_version = AuthorizationModelVersion::new(1, 0);
        let key = format!("model_version:{model_prefix}-{model_version}");
        assert_eq!(
            ChannelTupleManager::parse_model_version_from_key(&key, model_prefix),
            Some(model_version)
        );

        // A key without any model prefix is not matched.
        assert!(
            ChannelTupleManager::parse_model_version_from_key("model_version:1.0", model_prefix)
                .is_none()
        );

        // A key belonging to a different model prefix is not matched.
        assert!(
            ChannelTupleManager::parse_model_version_from_key(
                "model_version:foo-1.0",
                model_prefix
            )
            .is_none()
        );

        // Model prefixes may themselves contain dashes.
        assert_eq!(
            ChannelTupleManager::parse_model_version_from_key(
                "model_version:other-model-10.200",
                "other-model"
            ),
            Some(AuthorizationModelVersion::new(10, 200))
        );
    }

    #[test]
    fn test_format_model_version_key() {
        let model_prefix = "test";
        let model_version = AuthorizationModelVersion::new(1, 0);
        let key = ChannelTupleManager::format_model_version_key(model_prefix, model_version);
        assert_eq!(key, "model_version:test-1.0");
        // Formatting and parsing must round-trip.
        let parsed = ChannelTupleManager::parse_model_version_from_key(&key, model_prefix).unwrap();
        assert_eq!(parsed, model_version);
    }

    // Integration tests against a live OpenFGA server; only compiled when
    // TEST_OPENFGA_CLIENT_GRPC_URL is set.
    #[needs_env_var(TEST_OPENFGA_CLIENT_GRPC_URL)]
    pub(crate) mod openfga {
        use std::str::FromStr;

        use pretty_assertions::assert_eq;

        use super::*;
        use crate::client::{OpenFgaServiceClient, ReadAuthorizationModelRequest};

        /// Connects to the OpenFGA server at `TEST_OPENFGA_CLIENT_GRPC_URL`.
        pub(crate) async fn get_service_client() -> OpenFgaServiceClient<tonic::transport::Channel>
        {
            let endpoint = std::env::var("TEST_OPENFGA_CLIENT_GRPC_URL").unwrap();
            let endpoint = tonic::transport::Endpoint::from_str(&endpoint).unwrap();
            OpenFgaServiceClient::connect(endpoint)
                .await
                .expect("Client can be created")
        }

        /// Connects and creates a fresh store with a unique `test-<uuid v7>` name.
        pub(crate) async fn service_client_with_store()
        -> (OpenFgaServiceClient<tonic::transport::Channel>, Store) {
            let mut client = get_service_client().await;
            let store_name = format!("test-{}", uuid::Uuid::now_v7());
            let store = client.get_or_create_store(&store_name).await.unwrap();
            (client, store)
        }

        #[tokio::test]
        async fn test_get_existing_versions_nonexistent_store() {
            let client = get_service_client().await;
            let mut manager: TupleModelManager<_, ()> =
                TupleModelManager::new(client, "nonexistent", "test");

            let versions = manager.get_existing_versions().await.unwrap();
            assert!(versions.is_empty());
        }

        #[tokio::test]
        async fn test_get_existing_versions_nonexistent_auth_model() {
            let mut client = get_service_client().await;
            let store_name = format!("test-{}", uuid::Uuid::now_v7());
            let _store = client.get_or_create_store(&store_name).await.unwrap();
            let mut manager: TupleModelManager<_, ()> =
                TupleModelManager::new(client, &store_name, "test");
            let versions = manager.get_existing_versions().await.unwrap();
            assert!(versions.is_empty());
        }

        #[tokio::test]
        async fn test_get_authorization_model_id() {
            let (mut client, store) = service_client_with_store().await;
            let model_prefix = "test";
            let version = AuthorizationModelVersion::new(1, 0);

            let mut manager: TupleModelManager<_, ()> =
                TupleModelManager::new(client.clone(), &store.name, model_prefix);

            // Nothing has been recorded yet.
            assert_eq!(
                manager.get_authorization_model_id(version).await.unwrap(),
                None
            );

            // A model defining the bookkeeping types is needed before tuples can be written.
            let model: AuthorizationModel =
                serde_json::from_str(include_str!("../tests/model-manager/v1.0/schema.json"))
                    .unwrap();
            client
                .write_authorization_model(model.into_write_request(store.id.clone()))
                .await
                .unwrap();

            // Write the bookkeeping tuples by hand.
            client
                .write(WriteRequest {
                    store_id: store.id.clone(),
                    writes: Some(WriteRequestWrites {
                        on_duplicate: String::new(),
                        tuple_keys: vec![
                            TupleKey {
                                user: "auth_model_id:111111".to_string(),
                                relation: "openfga_id".to_string(),
                                object: "model_version:test-1.0".to_string(),
                                condition: None,
                            },
                            TupleKey {
                                user: "auth_model_id:*".to_string(),
                                relation: "exists".to_string(),
                                object: "model_version:test-1.0".to_string(),
                                condition: None,
                            },
                            // Tuple for a different prefix ("test2"); must not affect "test".
                            TupleKey {
                                user: "auth_model_id:*".to_string(),
                                relation: "exists".to_string(),
                                object: "model_version:test2-1.0".to_string(),
                                condition: None,
                            },
                        ],
                    }),
                    deletes: None,
                    authorization_model_id: String::new(),
                })
                .await
                .unwrap();

            assert_eq!(
                manager.get_authorization_model_id(version).await.unwrap(),
                Some("111111".to_string())
            );
        }

        #[tokio::test]
        async fn test_model_manager() {
            let store_name = format!("test-{}", uuid::Uuid::now_v7());
            let mut client = get_service_client().await;

            let model_1_0: AuthorizationModel =
                serde_json::from_str(include_str!("../tests/model-manager/v1.0/schema.json"))
                    .unwrap();

            let version_1_0 = AuthorizationModelVersion::new(1, 0);

            let migration_state = MigrationState::default();
            let mut manager = TupleModelManager::new(client.clone(), &store_name, "test-model")
                .add_model(
                    model_1_0.clone(),
                    version_1_0,
                    Some(v1_pre_migration_fn),
                    None::<MigrationFn<_, _>>,
                );
            manager.migrate(migration_state.clone()).await.unwrap();
            assert_eq!(*migration_state.counter_1.lock().unwrap(), 1);
            // Re-running must not invoke the hook again (it fails on its second call).
            manager.migrate(migration_state.clone()).await.unwrap();
            assert_eq!(*migration_state.counter_1.lock().unwrap(), 1);

            // The stored model must match the one that was registered (ignoring its id).
            let auth_model_id = manager
                .get_authorization_model_id(version_1_0)
                .await
                .unwrap()
                .unwrap();
            let mut auth_model =
                get_auth_model_by_id(&mut client, &store_name, &auth_model_id).await;
            auth_model.id = model_1_0.id.clone();
            assert_eq!(
                serde_json::to_value(&model_1_0).unwrap(),
                serde_json::to_value(auth_model).unwrap()
            );

            // Register a newer version; only it must be applied, exactly once.
            let model_1_1: AuthorizationModel =
                serde_json::from_str(include_str!("../tests/model-manager/v1.1/schema.json"))
                    .unwrap();
            let version_1_1 = AuthorizationModelVersion::new(1, 1);
            let mut manager = manager.add_model(
                model_1_1.clone(),
                version_1_1,
                None::<MigrationFn<_, _>>,
                Some(v2_post_migration_fn),
            );
            manager.migrate(migration_state.clone()).await.unwrap();
            manager.migrate(migration_state.clone()).await.unwrap();
            manager.migrate(migration_state.clone()).await.unwrap();

            // The v1.0 hook was not re-run, and the v1.1 hook ran exactly once.
            assert_eq!(*migration_state.counter_1.lock().unwrap(), 1);
            assert_eq!(*migration_state.counter_2.lock().unwrap(), 1);

            let auth_model_id = manager
                .get_authorization_model_id(version_1_1)
                .await
                .unwrap()
                .unwrap();
            let mut auth_model =
                get_auth_model_by_id(&mut client, &store_name, &auth_model_id).await;
            auth_model.id = model_1_1.id.clone();
            assert_eq!(
                serde_json::to_value(&model_1_1).unwrap(),
                serde_json::to_value(auth_model).unwrap()
            );
        }

        /// Reads the authorization model `auth_model_id` from the store named `store_name`.
        async fn get_auth_model_by_id(
            client: &mut OpenFgaServiceClient<tonic::transport::Channel>,
            store_name: &str,
            auth_model_id: &str,
        ) -> AuthorizationModel {
            client
                .read_authorization_model(ReadAuthorizationModelRequest {
                    store_id: client
                        .clone()
                        .get_store_by_name(store_name)
                        .await
                        .unwrap()
                        .unwrap()
                        .id,
                    id: auth_model_id.to_string(),
                })
                .await
                .unwrap()
                .into_inner()
                .authorization_model
                .unwrap()
        }
    }

    /// Shared counters recording how often each test hook ran.
    #[derive(Default, Clone)]
    struct MigrationState {
        counter_1: Arc<Mutex<i32>>,
        counter_2: Arc<Mutex<i32>>,
    }

    /// Pre-migration hook for v1.0: increments `counter_1` and returns an error on its
    /// second invocation, so a test fails if the hook is ever run twice.
    #[allow(clippy::unused_async)]
    async fn v1_pre_migration_fn(
        client: OpenFgaServiceClient<tonic::transport::Channel>,
        _prev_model: Option<String>,
        _curr_model: Option<String>,
        state: MigrationState,
    ) -> std::result::Result<(), StdError> {
        let _ = client;
        let mut counter = state.counter_1.lock().unwrap();
        *counter += 1;
        if *counter == 2 {
            return Err(Box::new(Error::RequestFailed(Box::new(
                tonic::Status::new(tonic::Code::Internal, "Test"),
            ))));
        }
        Ok(())
    }

    /// Post-migration hook for v1.1: increments `counter_2` and returns an error on its
    /// second invocation, so a test fails if the hook is ever run twice.
    #[allow(clippy::unused_async)]
    async fn v2_post_migration_fn(
        client: OpenFgaServiceClient<tonic::transport::Channel>,
        _prev_model: Option<String>,
        _curr_model: Option<String>,
        state: MigrationState,
    ) -> std::result::Result<(), StdError> {
        let _ = client;
        let mut counter = state.counter_2.lock().unwrap();
        *counter += 1;
        if *counter == 2 {
            return Err(Box::new(Error::RequestFailed(Box::new(
                tonic::Status::new(tonic::Code::Internal, "Test"),
            ))));
        }
        Ok(())
    }
}