1use std::{
2 collections::{HashMap, HashSet},
3 future::Future,
4 hash::Hash,
5 pin::Pin,
6 str::FromStr,
7 sync::Arc,
8};
9
10use tonic::codegen::{Body, Bytes, StdError};
11
12use crate::{
13 client::{
14 AuthorizationModel, ConsistencyPreference, OpenFgaServiceClient, ReadRequestTupleKey,
15 Store, Tuple, TupleKey, WriteAuthorizationModelResponse, WriteRequest, WriteRequestWrites,
16 },
17 error::{Error, Result},
18};
19
/// Page size passed to `read_all_pages` when reading the manager's bookkeeping tuples.
const DEFAULT_PAGE_SIZE: i32 = 100;
/// Page limit passed to `read_all_pages` alongside [`DEFAULT_PAGE_SIZE`].
const MAX_PAGES: u32 = 1_000;
22
/// A `major.minor` version that identifies one registered authorization model.
///
/// Values are plain data: two versions are equal exactly when both components are equal.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct AuthorizationModelVersion {
    /// Major version component.
    major: u32,
    /// Minor version component.
    minor: u32,
}
28
29#[derive(Debug, Clone, PartialEq)]
30struct VersionedAuthorizationModel {
31 model: AuthorizationModel,
32 version: AuthorizationModelVersion,
33}
34
35#[derive(Debug, Clone)]
59pub struct TupleModelManager<T, S>
60where
61 T: tonic::client::GrpcService<tonic::body::Body>,
62 T::Error: Into<StdError>,
63 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
64 <T::ResponseBody as Body>::Error: Into<StdError> + Send,
65{
66 client: OpenFgaServiceClient<T>,
67 store_name: String,
68 model_prefix: String,
69 migrations: HashMap<AuthorizationModelVersion, Migration<T, S>>,
70}
71
72#[derive(Clone)]
73struct Migration<T, S> {
74 model: VersionedAuthorizationModel,
78 pre_migration_fn: Option<BoxedMigrationFn<T, S>>,
79 post_migration_fn: Option<BoxedMigrationFn<T, S>>,
80}
81
/// A heap-allocated, pinned, `Send` future producing `T` that may borrow for `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
84
85pub type MigrationFn<T, S> = fn(
98 OpenFgaServiceClient<T>,
99 Option<String>,
100 Option<String>,
101 S,
102) -> BoxFuture<'static, std::result::Result<(), StdError>>;
103
104type DynMigrationFn<T, S> = dyn Fn(
106 OpenFgaServiceClient<T>,
107 Option<String>,
108 Option<String>,
109 S,
110) -> BoxFuture<'static, std::result::Result<(), StdError>>;
111
112type BoxedMigrationFn<T, S> = Arc<DynMigrationFn<T, S>>;
114
115fn box_migration_fn<T, S, F, Fut>(f: F) -> BoxedMigrationFn<T, S>
117where
118 F: Fn(OpenFgaServiceClient<T>, Option<String>, Option<String>, S) -> Fut + Send + 'static,
119 Fut: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
120{
121 Arc::new(
122 move |client, prev_auth_model_id, active_auth_model_id, state| {
123 Box::pin(f(client, prev_auth_model_id, active_auth_model_id, state))
124 },
125 )
126}
127
128impl<T, S> TupleModelManager<T, S>
129where
130 T: tonic::client::GrpcService<tonic::body::Body>,
131 T: Clone,
132 T::Error: Into<StdError>,
133 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
134 <T::ResponseBody as Body>::Error: Into<StdError> + Send,
135 S: Clone,
136{
137 const AUTH_MODEL_ID_TYPE: &'static str = "auth_model_id";
138 const MODEL_VERSION_EXISTS_RELATION: &'static str = "exists";
139 const MODEL_VERSION_TYPE: &'static str = "model_version";
140 const MODEL_VERSION_OPENFGA_ID_RELATION: &'static str = "openfga_id";
141
142 pub fn new(client: OpenFgaServiceClient<T>, store_name: &str, model_prefix: &str) -> Self {
148 TupleModelManager {
149 client,
150 model_prefix: model_prefix.to_string(),
151 store_name: store_name.to_string(),
152 migrations: HashMap::new(),
153 }
154 }
155
156 #[must_use]
162 pub fn add_model<FutPre, FutPost>(
163 mut self,
164 model: AuthorizationModel,
165 version: AuthorizationModelVersion,
166 pre_migration_fn: Option<
167 impl Fn(OpenFgaServiceClient<T>, Option<String>, Option<String>, S) -> FutPre
168 + Send
169 + 'static,
170 >,
171 post_migration_fn: Option<
172 impl Fn(OpenFgaServiceClient<T>, Option<String>, Option<String>, S) -> FutPost
173 + Send
174 + 'static,
175 >,
176 ) -> Self
177 where
178 FutPre: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
179 FutPost: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
180 {
181 let migration = Migration {
182 model: VersionedAuthorizationModel::new(model, version),
183 pre_migration_fn: pre_migration_fn.map(box_migration_fn),
184 post_migration_fn: post_migration_fn.map(box_migration_fn),
185 };
186 self.migrations.insert(migration.model.version(), migration);
187 self
188 }
189
190 #[allow(clippy::too_many_lines)]
205 pub async fn migrate(&mut self, state: S) -> Result<()> {
206 let span = tracing::span!(
207 tracing::Level::INFO,
208 "Running OpenFGA Migrations",
209 store_name = self.store_name,
210 model_prefix = self.model_prefix
211 );
212 let _enter = span.enter();
213
214 if self.migrations.is_empty() {
215 tracing::info!("No Migrations have been added. Nothing to do.");
216 return Ok(());
217 }
218
219 let store = self.client.get_or_create_store(&self.store_name).await?;
220 let existing_models = self.get_existing_versions().await?;
221 let max_existing_model = existing_models.iter().max().copied();
222 let mut curr_model_id = if let Some(version) = max_existing_model {
223 Some(self.require_authorization_model_id(version).await?)
224 } else {
225 None
226 };
227 let mut prev_model_id = None;
230
231 if let Some(max_existing_model) = max_existing_model {
232 tracing::info!(
233 "Currently the highest existing Model Version is: {}",
234 max_existing_model
235 );
236 } else {
237 tracing::info!("No model found in OpenFGA store");
238 }
239
240 let ordered_migrations = self.migrations_to_perform(max_existing_model);
241
242 let mut client = self.client.clone();
243 for migration in ordered_migrations {
244 tracing::info!("Migrating to model version: {}", migration.model.version());
245
246 if let Some(pre_migration_fn) = migration.pre_migration_fn.as_ref() {
248 pre_migration_fn(
249 client.clone(),
250 prev_model_id.clone(),
251 curr_model_id.clone(),
252 state.clone(),
253 )
254 .await
255 .map_err(|e| {
256 tracing::error!("Error in OpenFGA pre-migration hook: {:?}", e);
257 Error::MigrationHookFailed {
258 version: migration.model.version().to_string(),
259 error: Arc::new(e),
260 }
261 })?;
262 }
263
264 let request = migration
266 .model
267 .model()
268 .clone()
269 .into_write_request(store.id.clone());
270 let written_model = client
271 .write_authorization_model(request)
272 .await
273 .map_err(|e| {
274 tracing::error!("Error writing model: {:?}", e);
275 Error::RequestFailed(Box::new(e))
276 })?;
277 tracing::debug!("Model written: {:?}", written_model);
278
279 prev_model_id.clone_from(&curr_model_id);
281 curr_model_id = Some(written_model.get_ref().authorization_model_id.clone());
282
283 if let Some(post_migration_fn) = migration.post_migration_fn.as_ref() {
285 post_migration_fn(
286 client.clone(),
287 prev_model_id.clone(),
288 curr_model_id.clone(),
289 state.clone(),
290 )
291 .await
292 .map_err(|e| {
293 tracing::error!("Error in OpenFGA post-migration hook: {:?}", e);
294 Error::MigrationHookFailed {
295 version: migration.model.version().to_string(),
296 error: Arc::new(e),
297 }
298 })?;
299 }
300
301 Self::mark_as_applied(
303 &mut client,
304 &self.model_prefix,
305 &store,
306 migration.model.version(),
307 written_model.into_inner(),
308 )
309 .await?;
310 }
311
312 Ok(())
313 }
314
315 pub async fn get_authorization_model_id(
322 &mut self,
323 version: AuthorizationModelVersion,
324 ) -> Result<Option<String>> {
325 let store = self
326 .client
327 .get_store_by_name(&self.store_name)
328 .await?
329 .ok_or_else(|| {
330 tracing::error!("Store with name {} not found", self.store_name);
331 Error::StoreNotFound(self.store_name.clone())
332 })?;
333
334 let applied_models = self
335 .client
336 .read_all_pages(
337 &store.id,
338 Some(ReadRequestTupleKey {
339 user: String::new(),
340 relation: Self::MODEL_VERSION_OPENFGA_ID_RELATION.to_string(),
341 object: Self::format_model_version_key(&self.model_prefix, version),
342 }),
343 ConsistencyPreference::HigherConsistency,
344 DEFAULT_PAGE_SIZE,
345 MAX_PAGES,
346 )
347 .await?;
348
349 let applied_models = applied_models
350 .into_iter()
351 .filter_map(|t| t.key)
352 .filter_map(|t| {
353 t.user
354 .strip_prefix(&format!("{}:", Self::AUTH_MODEL_ID_TYPE))
355 .map(ToString::to_string)
356 })
357 .collect::<Vec<_>>();
358
359 if applied_models.len() > 1 {
360 tracing::error!(
361 "Multiple authorization models with model prefix {} for version {} found.",
362 self.model_prefix,
363 version
364 );
365 return Err(Error::AmbiguousModelVersion {
366 model_prefix: self.model_prefix.clone(),
367 version: version.to_string(),
368 });
369 }
370
371 let model_id = applied_models.into_iter().next().map(|openfga_id| {
372 tracing::info!(
373 "Authorization model for version {version} found in OpenFGA store {}. Model ID: {openfga_id}",
374 self.store_name,
375 );
376 openfga_id
377 });
378
379 Ok(model_id)
380 }
381
382 async fn require_authorization_model_id(
385 &mut self,
386 version: AuthorizationModelVersion,
387 ) -> Result<String> {
388 self.get_authorization_model_id(version)
389 .await?
390 .ok_or_else(|| {
391 tracing::error!("Missing authorization model id for model version {version}");
392 Error::MissingAuthorizationModelId {
393 model_prefix: self.model_prefix.clone(),
394 version: version.to_string(),
395 }
396 })
397 }
398
399 async fn mark_as_applied(
401 client: &mut OpenFgaServiceClient<T>,
402 model_prefix: &str,
403 store: &Store,
404 version: AuthorizationModelVersion,
405 write_response: WriteAuthorizationModelResponse,
406 ) -> Result<()> {
407 let authorization_model_id = write_response.authorization_model_id;
408 let object = Self::format_model_version_key(model_prefix, version);
409
410 let write_request = WriteRequest {
411 store_id: store.id.clone(),
412 writes: Some(WriteRequestWrites {
413 on_duplicate: String::new(),
414 tuple_keys: vec![
415 TupleKey {
416 user: format!("{}:{authorization_model_id}", Self::AUTH_MODEL_ID_TYPE),
417 relation: Self::MODEL_VERSION_OPENFGA_ID_RELATION.to_string(),
418 object: object.clone(),
419 condition: None,
420 },
421 TupleKey {
422 user: format!("{}:*", Self::AUTH_MODEL_ID_TYPE),
423 relation: Self::MODEL_VERSION_EXISTS_RELATION.to_string(),
424 object,
425 condition: None,
426 },
427 ],
428 }),
429 deletes: None,
430 authorization_model_id: authorization_model_id.clone(),
431 };
432 client.write(write_request.clone()).await.map_err(|e| {
433 tracing::error!("Error marking model as applied: {:?}", e);
434 Error::RequestFailed(Box::new(e))
435 })?;
436 Ok(())
437 }
438
439 fn ordered_migrations(&self) -> Vec<&Migration<T, S>> {
442 let mut migrations = self.migrations.values().collect::<Vec<_>>();
443 migrations.sort_unstable_by_key(|m| m.model.version());
444 migrations
445 }
446
447 fn migrations_to_perform(
449 &self,
450 max_existing_model: Option<AuthorizationModelVersion>,
451 ) -> Vec<&Migration<T, S>> {
452 let ordered_migrations = self.ordered_migrations();
453 let migrations_to_perform = ordered_migrations
454 .into_iter()
455 .filter(|m| {
456 max_existing_model.is_none_or(|max_existing| m.model.version() > max_existing)
457 })
458 .collect::<Vec<_>>();
459
460 tracing::info!(
461 "{} migrations needed in OpenFGA store {} for model-prefix {}",
462 migrations_to_perform.len(),
463 self.store_name,
464 self.model_prefix
465 );
466 migrations_to_perform
467 }
468
469 pub async fn get_existing_versions(&mut self) -> Result<Vec<AuthorizationModelVersion>> {
476 let Some(store) = self.client.get_store_by_name(&self.store_name).await? else {
477 return Ok(vec![]);
478 };
479
480 let tuples = self
481 .client
482 .read_all_pages(
483 &store.id,
484 Some(ReadRequestTupleKey {
485 user: format!("{}:*", Self::AUTH_MODEL_ID_TYPE).to_string(),
486 relation: Self::MODEL_VERSION_EXISTS_RELATION.to_string(),
487 object: format!("{}:", Self::MODEL_VERSION_TYPE).to_string(),
488 }),
489 crate::client::ConsistencyPreference::HigherConsistency,
490 DEFAULT_PAGE_SIZE,
491 MAX_PAGES,
492 )
493 .await?;
494 let existing_models = Self::parse_existing_models(tuples, &self.model_prefix);
495 Ok(existing_models.into_iter().collect())
496 }
497
498 fn parse_existing_models(
499 exist_tuples: Vec<Tuple>,
500 model_prefix: &str,
501 ) -> HashSet<AuthorizationModelVersion> {
502 exist_tuples
503 .into_iter()
504 .filter_map(|t| t.key)
505 .filter_map(|t| Self::parse_model_version_from_key(&t.object, model_prefix))
506 .collect()
507 }
508
509 fn parse_model_version_from_key(
510 model: &str,
511 model_prefix: &str,
512 ) -> Option<AuthorizationModelVersion> {
513 model
514 .strip_prefix(&format!("{}:", Self::MODEL_VERSION_TYPE))
516 .and_then(|model| {
517 model
518 .strip_prefix(&format!("{model_prefix}-"))
519 .and_then(|version| AuthorizationModelVersion::from_str(version).ok())
520 })
521 }
522
523 fn format_model_version_key(model_prefix: &str, version: AuthorizationModelVersion) -> String {
524 format!("{}:{}-{}", Self::MODEL_VERSION_TYPE, model_prefix, version)
525 }
526}
527
528impl VersionedAuthorizationModel {
529 pub(crate) fn new(model: AuthorizationModel, version: AuthorizationModelVersion) -> Self {
530 VersionedAuthorizationModel { model, version }
531 }
532
533 pub(crate) fn version(&self) -> AuthorizationModelVersion {
534 self.version
535 }
536
537 pub(crate) fn model(&self) -> &AuthorizationModel {
538 &self.model
539 }
540}
541
542impl AuthorizationModelVersion {
543 #[must_use]
544 pub fn new(major: u32, minor: u32) -> Self {
545 AuthorizationModelVersion { major, minor }
546 }
547
548 #[must_use]
549 pub fn major(&self) -> u32 {
550 self.major
551 }
552
553 #[must_use]
554 pub fn minor(&self) -> u32 {
555 self.minor
556 }
557}
558
559impl PartialOrd for AuthorizationModelVersion {
561 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
562 Some(self.cmp(other))
563 }
564}
565
566impl Ord for AuthorizationModelVersion {
567 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
568 (self.major, self.minor).cmp(&(other.major, other.minor))
569 }
570}
571
572impl std::fmt::Display for AuthorizationModelVersion {
573 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
574 write!(f, "{}.{}", self.major, self.minor)
575 }
576}
577
578impl FromStr for AuthorizationModelVersion {
579 type Err = Error;
580
581 fn from_str(s: &str) -> Result<Self> {
582 let parts = s.split('.').collect::<Vec<_>>();
583 if parts.len() != 2 {
584 return Err(Error::InvalidModelVersion(s.to_string()));
585 }
586
587 let major = parts[0]
588 .parse()
589 .map_err(|_| Error::InvalidModelVersion(s.to_string()))?;
590 let minor = parts[1]
591 .parse()
592 .map_err(|_| Error::InvalidModelVersion(s.to_string()))?;
593
594 Ok(AuthorizationModelVersion::new(major, minor))
595 }
596}
597
598impl<T, S> std::fmt::Debug for Migration<T, S> {
599 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
600 f.debug_struct("Migration")
601 .field("model", &self.model)
602 .field("pre_migration_fn", &"...")
603 .field("post_migration_fn", &"...")
604 .finish()
605 }
606}
607
#[cfg(test)]
pub(crate) mod test {
    use std::sync::Mutex;

    use needs_env_var::needs_env_var;
    use pretty_assertions::assert_eq;

    use super::*;

    // Concrete manager type used to call the associated (non-`self`) helpers in unit tests.
    type ChannelTupleManager = TupleModelManager<tonic::transport::Channel, ()>;

    // Versions order by major first, then minor.
    #[test]
    fn test_ordering() {
        let versioned_1_0 = AuthorizationModelVersion::new(1, 0);
        let versioned_1_1 = AuthorizationModelVersion::new(1, 1);
        let versioned_2_0 = AuthorizationModelVersion::new(2, 0);
        let versioned_2_1 = AuthorizationModelVersion::new(2, 1);
        let versioned_2_2 = AuthorizationModelVersion::new(2, 2);

        assert!(versioned_1_0 < versioned_1_1);
        assert!(versioned_1_1 < versioned_2_0);
        assert!(versioned_2_0 < versioned_2_1);
        assert!(versioned_2_1 < versioned_2_2);
    }

    // `Display` and `FromStr` round-trip for single- and multi-digit components.
    #[test]
    fn test_auth_model_version_str() {
        let version = AuthorizationModelVersion::new(1, 0);
        assert_eq!(version.to_string(), "1.0");
        assert_eq!("1.0".parse::<AuthorizationModelVersion>().unwrap(), version);

        let version = AuthorizationModelVersion::new(10, 2);
        assert_eq!(version.to_string(), "10.2");
        assert_eq!(
            "10.2".parse::<AuthorizationModelVersion>().unwrap(),
            version
        );
    }

    #[test]
    fn test_parse_model_version_from_key() {
        let model_prefix = "test";
        let model_version = AuthorizationModelVersion::new(1, 0);
        let key = format!("model_version:{model_prefix}-{model_version}");
        assert_eq!(
            ChannelTupleManager::parse_model_version_from_key(&key, model_prefix),
            Some(model_version)
        );

        assert!(
            // Missing model prefix: must not parse.
            ChannelTupleManager::parse_model_version_from_key("model_version:1.0", model_prefix)
                .is_none()
        );

        assert!(
            // Key belongs to a different model prefix ("foo"): must not parse.
            ChannelTupleManager::parse_model_version_from_key(
                "model_version:foo-1.0",
                model_prefix
            )
            .is_none()
        );

        assert_eq!(
            // A prefix containing '-' still parses, as does a multi-digit minor component.
            ChannelTupleManager::parse_model_version_from_key(
                "model_version:other-model-10.200",
                "other-model"
            ),
            Some(AuthorizationModelVersion::new(10, 200))
        );
    }

    // The formatted key has the expected shape and parses back to the same version.
    #[test]
    fn test_format_model_version_key() {
        let model_prefix = "test";
        let model_version = AuthorizationModelVersion::new(1, 0);
        let key = ChannelTupleManager::format_model_version_key(model_prefix, model_version);
        assert_eq!(key, "model_version:test-1.0");
        let parsed = ChannelTupleManager::parse_model_version_from_key(&key, model_prefix).unwrap();
        assert_eq!(parsed, model_version);
    }

    // Integration tests against a live OpenFGA server; only compiled when
    // TEST_OPENFGA_CLIENT_GRPC_URL is set.
    #[needs_env_var(TEST_OPENFGA_CLIENT_GRPC_URL)]
    pub(crate) mod openfga {
        use std::str::FromStr;

        use pretty_assertions::assert_eq;

        use super::*;
        use crate::client::{OpenFgaServiceClient, ReadAuthorizationModelRequest};

        // Connects to the server at TEST_OPENFGA_CLIENT_GRPC_URL; panics if unset or unreachable.
        pub(crate) async fn get_service_client() -> OpenFgaServiceClient<tonic::transport::Channel>
        {
            let endpoint = std::env::var("TEST_OPENFGA_CLIENT_GRPC_URL").unwrap();
            let endpoint = tonic::transport::Endpoint::from_str(&endpoint).unwrap();
            OpenFgaServiceClient::connect(endpoint)
                .await
                .expect("Client can be created")
        }

        // Returns a client plus a store with a unique (UUIDv7-based) name, so tests do not collide.
        pub(crate) async fn service_client_with_store()
        -> (OpenFgaServiceClient<tonic::transport::Channel>, Store) {
            let mut client = get_service_client().await;
            let store_name = format!("test-{}", uuid::Uuid::now_v7());
            let store = client.get_or_create_store(&store_name).await.unwrap();
            (client, store)
        }

        // NOTE(review): assumes no store named "nonexistent" exists on the test server.
        #[tokio::test]
        async fn test_get_existing_versions_nonexistent_store() {
            let client = get_service_client().await;
            let mut manager: TupleModelManager<_, ()> =
                TupleModelManager::new(client, "nonexistent", "test");

            let versions = manager.get_existing_versions().await.unwrap();
            assert!(versions.is_empty());
        }

        // A freshly created store has no bookkeeping tuples, hence no versions.
        #[tokio::test]
        async fn test_get_existing_versions_nonexistent_auth_model() {
            let mut client = get_service_client().await;
            let store_name = format!("test-{}", uuid::Uuid::now_v7());
            let _store = client.get_or_create_store(&store_name).await.unwrap();
            let mut manager: TupleModelManager<_, ()> =
                TupleModelManager::new(client, &store_name, "test");
            let versions = manager.get_existing_versions().await.unwrap();
            assert!(versions.is_empty());
        }

        #[tokio::test]
        async fn test_get_authorization_model_id() {
            let (mut client, store) = service_client_with_store().await;
            let model_prefix = "test";
            let version = AuthorizationModelVersion::new(1, 0);

            let mut manager: TupleModelManager<_, ()> =
                TupleModelManager::new(client.clone(), &store.name, model_prefix);

            assert_eq!(
                // No bookkeeping tuples have been written yet.
                manager.get_authorization_model_id(version).await.unwrap(),
                None
            );

            let model: AuthorizationModel =
                // Write a model so the store has one that can hold the bookkeeping tuples below.
                serde_json::from_str(include_str!("../tests/model-manager/v1.0/schema.json"))
                    .unwrap();
            client
                .write_authorization_model(model.into_write_request(store.id.clone()))
                .await
                .unwrap();

            client
                // Hand-write the bookkeeping tuples that `mark_as_applied` would produce.
                .write(WriteRequest {
                    store_id: store.id.clone(),
                    writes: Some(WriteRequestWrites {
                        on_duplicate: String::new(),
                        tuple_keys: vec![
                            TupleKey {
                                user: "auth_model_id:111111".to_string(),
                                relation: "openfga_id".to_string(),
                                object: "model_version:test-1.0".to_string(),
                                condition: None,
                            },
                            TupleKey {
                                user: "auth_model_id:*".to_string(),
                                relation: "exists".to_string(),
                                object: "model_version:test-1.0".to_string(),
                                condition: None,
                            },
                            TupleKey {
                                // Different model prefix ("test2"); must not affect the "test" lookup.
                                user: "auth_model_id:*".to_string(),
                                relation: "exists".to_string(),
                                object: "model_version:test2-1.0".to_string(),
                                condition: None,
                            },
                        ],
                    }),
                    deletes: None,
                    authorization_model_id: String::new(),
                })
                .await
                .unwrap();

            assert_eq!(
                manager.get_authorization_model_id(version).await.unwrap(),
                Some("111111".to_string())
            );
        }

        // End-to-end: apply 1.0, re-run (no-op), then add and apply 1.1, checking hook counts and
        // that the stored models match the schemas.
        #[tokio::test]
        async fn test_model_manager() {
            let store_name = format!("test-{}", uuid::Uuid::now_v7());
            let mut client = get_service_client().await;

            let model_1_0: AuthorizationModel =
                serde_json::from_str(include_str!("../tests/model-manager/v1.0/schema.json"))
                    .unwrap();

            let version_1_0 = AuthorizationModelVersion::new(1, 0);

            let migration_state = MigrationState::default();
            let mut manager = TupleModelManager::new(client.clone(), &store_name, "test-model")
                .add_model(
                    model_1_0.clone(),
                    version_1_0,
                    Some(v1_pre_migration_fn),
                    None::<MigrationFn<_, _>>,
                );
            manager.migrate(migration_state.clone()).await.unwrap();
            assert_eq!(*migration_state.counter_1.lock().unwrap(), 1);
            // Second run: 1.0 is already recorded, so the hook must not run again (it returns an
            // error on its second invocation, which would fail `unwrap`).
            manager.migrate(migration_state.clone()).await.unwrap();
            assert_eq!(*migration_state.counter_1.lock().unwrap(), 1);

            let auth_model_id = manager
                // The id recorded for 1.0 must point at a model equal to the v1.0 schema.
                .get_authorization_model_id(version_1_0)
                .await
                .unwrap()
                .unwrap();
            let mut auth_model =
                get_auth_model_by_id(&mut client, &store_name, &auth_model_id).await;
            auth_model.id = model_1_0.id.clone();
            assert_eq!(
                serde_json::to_value(&model_1_0).unwrap(),
                serde_json::to_value(auth_model).unwrap()
            );

            let model_1_1: AuthorizationModel =
                // Register 1.1 with only a post-migration hook.
                serde_json::from_str(include_str!("../tests/model-manager/v1.1/schema.json"))
                    .unwrap();
            let version_1_1 = AuthorizationModelVersion::new(1, 1);
            let mut manager = manager.add_model(
                model_1_1.clone(),
                version_1_1,
                None::<MigrationFn<_, _>>,
                Some(v2_post_migration_fn),
            );
            manager.migrate(migration_state.clone()).await.unwrap();
            manager.migrate(migration_state.clone()).await.unwrap();
            manager.migrate(migration_state.clone()).await.unwrap();

            // 1.0's hook was not re-run ...
            assert_eq!(*migration_state.counter_1.lock().unwrap(), 1);
            // ... and 1.1's hook ran exactly once across the three runs.
            assert_eq!(*migration_state.counter_2.lock().unwrap(), 1);

            let auth_model_id = manager
                // The id recorded for 1.1 must point at a model equal to the v1.1 schema.
                .get_authorization_model_id(version_1_1)
                .await
                .unwrap()
                .unwrap();
            let mut auth_model =
                get_auth_model_by_id(&mut client, &store_name, &auth_model_id).await;
            auth_model.id = model_1_1.id.clone();
            assert_eq!(
                serde_json::to_value(&model_1_1).unwrap(),
                serde_json::to_value(auth_model).unwrap()
            );
        }

        // Fetches the authorization model `auth_model_id` from the store named `store_name`;
        // panics if the store or model is missing.
        async fn get_auth_model_by_id(
            client: &mut OpenFgaServiceClient<tonic::transport::Channel>,
            store_name: &str,
            auth_model_id: &str,
        ) -> AuthorizationModel {
            client
                .read_authorization_model(ReadAuthorizationModelRequest {
                    store_id: client
                        .clone()
                        .get_store_by_name(store_name)
                        .await
                        .unwrap()
                        .unwrap()
                        .id,
                    id: auth_model_id.to_string(),
                })
                .await
                .unwrap()
                .into_inner()
                .authorization_model
                .unwrap()
        }
    }

    // Shared hook state: clones share the same counters via `Arc`, so the test can observe how
    // often each hook ran.
    #[derive(Default, Clone)]
    struct MigrationState {
        counter_1: Arc<Mutex<i32>>,
        counter_2: Arc<Mutex<i32>>,
    }

    // Pre-migration hook for 1.0: increments `counter_1` and fails on its second invocation, so a
    // re-run of this hook surfaces as a test failure.
    #[allow(clippy::unused_async)]
    async fn v1_pre_migration_fn(
        client: OpenFgaServiceClient<tonic::transport::Channel>,
        _prev_model: Option<String>,
        _curr_model: Option<String>,
        state: MigrationState,
    ) -> std::result::Result<(), StdError> {
        let _ = client;
        let mut counter = state.counter_1.lock().unwrap();
        // Counts invocations across all clones of the state.
        *counter += 1;
        if *counter == 2 {
            return Err(Box::new(Error::RequestFailed(Box::new(
                tonic::Status::new(tonic::Code::Internal, "Test"),
            ))));
        }
        Ok(())
    }

    // Post-migration hook for 1.1: same behavior as `v1_pre_migration_fn`, but on `counter_2`.
    #[allow(clippy::unused_async)]
    async fn v2_post_migration_fn(
        client: OpenFgaServiceClient<tonic::transport::Channel>,
        _prev_model: Option<String>,
        _curr_model: Option<String>,
        state: MigrationState,
    ) -> std::result::Result<(), StdError> {
        let _ = client;
        let mut counter = state.counter_2.lock().unwrap();
        // Counts invocations across all clones of the state.
        *counter += 1;
        if *counter == 2 {
            return Err(Box::new(Error::RequestFailed(Box::new(
                tonic::Status::new(tonic::Code::Internal, "Test"),
            ))));
        }
        Ok(())
    }
}