1use std::{
2 collections::{HashMap, HashSet},
3 future::Future,
4 hash::Hash,
5 pin::Pin,
6 str::FromStr,
7 sync::Arc,
8};
9
10use tonic::codegen::{Body, Bytes, StdError};
11
12use crate::{
13 client::{
14 AuthorizationModel, ConsistencyPreference, OpenFgaServiceClient, ReadRequestTupleKey,
15 Store, Tuple, TupleKey, WriteAuthorizationModelResponse, WriteRequest, WriteRequestWrites,
16 },
17 error::{Error, Result},
18};
19
/// Page size passed to `read_all_pages` when reading the bookkeeping tuples.
const DEFAULT_PAGE_SIZE: i32 = 100;
/// Page limit passed to `read_all_pages` (presumably the maximum number of pages it will
/// fetch before stopping — confirm against `read_all_pages`).
const MAX_PAGES: u32 = 1000;
22
/// Two-component version (`major.minor`) under which an authorization model is registered
/// with a [`TupleModelManager`].
///
/// Versions compare by `(major, minor)` and are rendered / parsed as `"{major}.{minor}"`
/// (see the `Ord`, `Display` and `FromStr` impls).
#[derive(Debug, Copy, Clone, PartialEq, Hash, Eq)]
pub struct AuthorizationModelVersion {
    /// Major component (the `1` in `1.2`).
    major: u32,
    /// Minor component (the `2` in `1.2`).
    minor: u32,
}
28
/// An authorization model paired with the version it is registered under.
#[derive(Debug, Clone, PartialEq)]
struct VersionedAuthorizationModel {
    /// The model that gets written to the store.
    model: AuthorizationModel,
    /// The version recorded for `model` once it has been applied.
    version: AuthorizationModelVersion,
}
34
/// Applies versioned authorization models to an `OpenFGA` store and records which versions
/// have already been applied.
///
/// Applied versions are tracked as tuples on `model_version:{model_prefix}-{version}` objects
/// inside the same store. Because of this, repeated calls to `migrate` only write models newer
/// than the highest recorded version.
#[derive(Debug, Clone)]
pub struct TupleModelManager<T, S>
where
    T: tonic::client::GrpcService<tonic::body::BoxBody>,
    T::Error: Into<StdError>,
    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
    <T::ResponseBody as Body>::Error: Into<StdError> + Send,
{
    /// gRPC client used for every store, tuple and model request.
    client: OpenFgaServiceClient<T>,
    /// Name of the store the models are written to.
    store_name: String,
    /// Namespace for the version-marker objects, so several model families can share a store.
    model_prefix: String,
    /// Registered migrations, keyed by the version they introduce.
    migrations: HashMap<AuthorizationModelVersion, Migration<T, S>>,
}
71
/// One registered migration: the model to write plus optional hooks around the write.
// NOTE(review): `derive(Clone)` adds `T: Clone, S: Clone` bounds on the impl even though the
// fields themselves are clonable for any `T`/`S`.
#[derive(Clone)]
struct Migration<T, S> {
    /// The model (and its version) written by this migration.
    model: VersionedAuthorizationModel,
    /// Hook awaited before the model is written.
    pre_migration_fn: Option<BoxedMigrationFn<T, S>>,
    /// Hook awaited after the model is written and before the version is marked as applied.
    post_migration_fn: Option<BoxedMigrationFn<T, S>>,
}
81
/// Heap-allocated, pinned, `Send` future with output `T` that lives for `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
84
/// Function-pointer form of a migration hook.
///
/// Arguments: a client clone, the previously active model id, the currently active model id,
/// and a clone of the migration state. This type can also name the type of an absent hook,
/// e.g. `None::<MigrationFn<_, _>>`.
pub type MigrationFn<T, S> = fn(
    OpenFgaServiceClient<T>,
    Option<String>,
    Option<String>,
    S,
) -> BoxFuture<'static, std::result::Result<(), StdError>>;
103
/// Type-erased migration hook with the same argument list as [`MigrationFn`].
// NOTE(review): this trait object has no `Send + Sync` bound. As a result `Arc<DynMigrationFn>`
// is neither `Send` nor `Sync`, and neither is `TupleModelManager`. Confirm this is intended.
type DynMigrationFn<T, S> = dyn Fn(
    OpenFgaServiceClient<T>,
    Option<String>,
    Option<String>,
    S,
) -> BoxFuture<'static, std::result::Result<(), StdError>>;
111
/// Shared, clonable handle to a type-erased migration hook (see `box_migration_fn`).
type BoxedMigrationFn<T, S> = Arc<DynMigrationFn<T, S>>;
114
115fn box_migration_fn<T, S, F, Fut>(f: F) -> BoxedMigrationFn<T, S>
117where
118 F: Fn(OpenFgaServiceClient<T>, Option<String>, Option<String>, S) -> Fut + Send + 'static,
119 Fut: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
120{
121 Arc::new(
122 move |client, prev_auth_model_id, active_auth_model_id, state| {
123 Box::pin(f(client, prev_auth_model_id, active_auth_model_id, state))
124 },
125 )
126}
127
128impl<T, S> TupleModelManager<T, S>
129where
130 T: tonic::client::GrpcService<tonic::body::BoxBody>,
131 T: Clone,
132 T::Error: Into<StdError>,
133 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
134 <T::ResponseBody as Body>::Error: Into<StdError> + Send,
135 S: Clone,
136{
137 const AUTH_MODEL_ID_TYPE: &'static str = "auth_model_id";
138 const MODEL_VERSION_EXISTS_RELATION: &'static str = "exists";
139 const MODEL_VERSION_TYPE: &'static str = "model_version";
140 const MODEL_VERSION_OPENFGA_ID_RELATION: &'static str = "openfga_id";
141
142 pub fn new(client: OpenFgaServiceClient<T>, store_name: &str, model_prefix: &str) -> Self {
148 TupleModelManager {
149 client,
150 model_prefix: model_prefix.to_string(),
151 store_name: store_name.to_string(),
152 migrations: HashMap::new(),
153 }
154 }
155
156 #[must_use]
162 pub fn add_model<FutPre, FutPost>(
163 mut self,
164 model: AuthorizationModel,
165 version: AuthorizationModelVersion,
166 pre_migration_fn: Option<
167 impl Fn(OpenFgaServiceClient<T>, Option<String>, Option<String>, S) -> FutPre
168 + Send
169 + 'static,
170 >,
171 post_migration_fn: Option<
172 impl Fn(OpenFgaServiceClient<T>, Option<String>, Option<String>, S) -> FutPost
173 + Send
174 + 'static,
175 >,
176 ) -> Self
177 where
178 FutPre: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
179 FutPost: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
180 {
181 let migration = Migration {
182 model: VersionedAuthorizationModel::new(model, version),
183 pre_migration_fn: pre_migration_fn.map(box_migration_fn),
184 post_migration_fn: post_migration_fn.map(box_migration_fn),
185 };
186 self.migrations.insert(migration.model.version(), migration);
187 self
188 }
189
190 #[allow(clippy::too_many_lines)]
205 pub async fn migrate(&mut self, state: S) -> Result<()> {
206 let span = tracing::span!(
207 tracing::Level::INFO,
208 "Running OpenFGA Migrations",
209 store_name = self.store_name,
210 model_prefix = self.model_prefix
211 );
212 let _enter = span.enter();
213
214 if self.migrations.is_empty() {
215 tracing::info!("No Migrations have been added. Nothing to do.");
216 return Ok(());
217 }
218
219 let store = self.client.get_or_create_store(&self.store_name).await?;
220 let existing_models = self.get_existing_versions().await?;
221 let max_existing_model = existing_models.iter().max().copied();
222 let mut curr_model_id = if let Some(version) = max_existing_model {
223 Some(self.require_authorization_model_id(version).await?)
224 } else {
225 None
226 };
227 let mut prev_model_id = None;
230
231 if let Some(max_existing_model) = max_existing_model {
232 tracing::info!(
233 "Currently the highest existing Model Version is: {}",
234 max_existing_model
235 );
236 } else {
237 tracing::info!("No model found in OpenFGA store");
238 }
239
240 let ordered_migrations = self.migrations_to_perform(max_existing_model);
241
242 let mut client = self.client.clone();
243 for migration in ordered_migrations {
244 tracing::info!("Migrating to model version: {}", migration.model.version());
245
246 if let Some(pre_migration_fn) = migration.pre_migration_fn.as_ref() {
248 pre_migration_fn(
249 client.clone(),
250 prev_model_id.clone(),
251 curr_model_id.clone(),
252 state.clone(),
253 )
254 .await
255 .map_err(|e| {
256 tracing::error!("Error in OpenFGA pre-migration hook: {:?}", e);
257 Error::MigrationHookFailed {
258 version: migration.model.version().to_string(),
259 error: Arc::new(e),
260 }
261 })?;
262 }
263
264 let request = migration
266 .model
267 .model()
268 .clone()
269 .into_write_request(store.id.clone());
270 let written_model = client
271 .write_authorization_model(request)
272 .await
273 .map_err(|e| {
274 tracing::error!("Error writing model: {:?}", e);
275 Error::RequestFailed(Box::new(e))
276 })?;
277 tracing::debug!("Model written: {:?}", written_model);
278
279 prev_model_id.clone_from(&curr_model_id);
281 curr_model_id = Some(written_model.get_ref().authorization_model_id.to_string());
282
283 if let Some(post_migration_fn) = migration.post_migration_fn.as_ref() {
285 post_migration_fn(
286 client.clone(),
287 prev_model_id.clone(),
288 curr_model_id.clone(),
289 state.clone(),
290 )
291 .await
292 .map_err(|e| {
293 tracing::error!("Error in OpenFGA post-migration hook: {:?}", e);
294 Error::MigrationHookFailed {
295 version: migration.model.version().to_string(),
296 error: Arc::new(e),
297 }
298 })?;
299 }
300
301 Self::mark_as_applied(
303 &mut client,
304 &self.model_prefix,
305 &store,
306 migration.model.version(),
307 written_model.into_inner(),
308 )
309 .await?;
310 }
311
312 Ok(())
313 }
314
315 pub async fn get_authorization_model_id(
322 &mut self,
323 version: AuthorizationModelVersion,
324 ) -> Result<Option<String>> {
325 let store = self
326 .client
327 .get_store_by_name(&self.store_name)
328 .await?
329 .ok_or_else(|| {
330 tracing::error!("Store with name {} not found", self.store_name);
331 Error::StoreNotFound(self.store_name.clone())
332 })?;
333
334 let applied_models = self
335 .client
336 .read_all_pages(
337 &store.id,
338 Some(ReadRequestTupleKey {
339 user: String::new(),
340 relation: Self::MODEL_VERSION_OPENFGA_ID_RELATION.to_string(),
341 object: Self::format_model_version_key(&self.model_prefix, version),
342 }),
343 ConsistencyPreference::HigherConsistency,
344 DEFAULT_PAGE_SIZE,
345 MAX_PAGES,
346 )
347 .await?;
348
349 let applied_models = applied_models
350 .into_iter()
351 .filter_map(|t| t.key)
352 .filter_map(|t| {
353 t.user
354 .strip_prefix(&format!("{}:", Self::AUTH_MODEL_ID_TYPE))
355 .map(ToString::to_string)
356 })
357 .collect::<Vec<_>>();
358
359 if applied_models.len() > 1 {
360 tracing::error!(
361 "Multiple authorization models with model prefix {} for version {} found.",
362 self.model_prefix,
363 version
364 );
365 return Err(Error::AmbiguousModelVersion {
366 model_prefix: self.model_prefix.clone(),
367 version: version.to_string(),
368 });
369 }
370
371 let model_id = applied_models.into_iter().next().map(|openfga_id| {
372 tracing::info!(
373 "Authorization model for version {version} found in OpenFGA store {}. Model ID: {openfga_id}",
374 self.store_name,
375 );
376 openfga_id
377 });
378
379 Ok(model_id)
380 }
381
382 async fn require_authorization_model_id(
385 &mut self,
386 version: AuthorizationModelVersion,
387 ) -> Result<String> {
388 self.get_authorization_model_id(version)
389 .await?
390 .ok_or_else(|| {
391 tracing::error!("Missing authorization model id for model version {version}");
392 Error::MissingAuthorizationModelId {
393 model_prefix: self.model_prefix.clone(),
394 version: version.to_string(),
395 }
396 })
397 }
398
399 async fn mark_as_applied(
401 client: &mut OpenFgaServiceClient<T>,
402 model_prefix: &str,
403 store: &Store,
404 version: AuthorizationModelVersion,
405 write_response: WriteAuthorizationModelResponse,
406 ) -> Result<()> {
407 let authorization_model_id = write_response.authorization_model_id;
408 let object = Self::format_model_version_key(model_prefix, version);
409
410 let write_request = WriteRequest {
411 store_id: store.id.clone(),
412 writes: Some(WriteRequestWrites {
413 tuple_keys: vec![
414 TupleKey {
415 user: format!("{}:{authorization_model_id}", Self::AUTH_MODEL_ID_TYPE),
416 relation: Self::MODEL_VERSION_OPENFGA_ID_RELATION.to_string(),
417 object: object.clone(),
418 condition: None,
419 },
420 TupleKey {
421 user: format!("{}:*", Self::AUTH_MODEL_ID_TYPE),
422 relation: Self::MODEL_VERSION_EXISTS_RELATION.to_string(),
423 object,
424 condition: None,
425 },
426 ],
427 }),
428 deletes: None,
429 authorization_model_id: authorization_model_id.to_string(),
430 };
431 client.write(write_request.clone()).await.map_err(|e| {
432 tracing::error!("Error marking model as applied: {:?}", e);
433 Error::RequestFailed(Box::new(e))
434 })?;
435 Ok(())
436 }
437
438 fn ordered_migrations(&self) -> Vec<&Migration<T, S>> {
441 let mut migrations = self.migrations.values().collect::<Vec<_>>();
442 migrations.sort_unstable_by_key(|m| m.model.version());
443 migrations
444 }
445
446 fn migrations_to_perform(
448 &self,
449 max_existing_model: Option<AuthorizationModelVersion>,
450 ) -> Vec<&Migration<T, S>> {
451 let ordered_migrations = self.ordered_migrations();
452 let migrations_to_perform = ordered_migrations
453 .into_iter()
454 .filter(|m| {
455 max_existing_model.is_none_or(|max_existing| m.model.version() > max_existing)
456 })
457 .collect::<Vec<_>>();
458
459 tracing::info!(
460 "{} migrations needed in OpenFGA store {} for model-prefix {}",
461 migrations_to_perform.len(),
462 self.store_name,
463 self.model_prefix
464 );
465 migrations_to_perform
466 }
467
468 pub async fn get_existing_versions(&mut self) -> Result<Vec<AuthorizationModelVersion>> {
475 let Some(store) = self.client.get_store_by_name(&self.store_name).await? else {
476 return Ok(vec![]);
477 };
478
479 let tuples = self
480 .client
481 .read_all_pages(
482 &store.id,
483 Some(ReadRequestTupleKey {
484 user: format!("{}:*", Self::AUTH_MODEL_ID_TYPE).to_string(),
485 relation: Self::MODEL_VERSION_EXISTS_RELATION.to_string(),
486 object: format!("{}:", Self::MODEL_VERSION_TYPE).to_string(),
487 }),
488 crate::client::ConsistencyPreference::HigherConsistency,
489 DEFAULT_PAGE_SIZE,
490 MAX_PAGES,
491 )
492 .await?;
493 let existing_models = Self::parse_existing_models(tuples, &self.model_prefix);
494 Ok(existing_models.into_iter().collect())
495 }
496
497 fn parse_existing_models(
498 exist_tuples: Vec<Tuple>,
499 model_prefix: &str,
500 ) -> HashSet<AuthorizationModelVersion> {
501 exist_tuples
502 .into_iter()
503 .filter_map(|t| t.key)
504 .filter_map(|t| Self::parse_model_version_from_key(&t.object, model_prefix))
505 .collect()
506 }
507
508 fn parse_model_version_from_key(
509 model: &str,
510 model_prefix: &str,
511 ) -> Option<AuthorizationModelVersion> {
512 model
513 .strip_prefix(&format!("{}:", Self::MODEL_VERSION_TYPE))
515 .and_then(|model| {
516 model
517 .strip_prefix(&format!("{model_prefix}-"))
518 .and_then(|version| AuthorizationModelVersion::from_str(version).ok())
519 })
520 }
521
522 fn format_model_version_key(model_prefix: &str, version: AuthorizationModelVersion) -> String {
523 format!("{}:{}-{}", Self::MODEL_VERSION_TYPE, model_prefix, version)
524 }
525}
526
527impl VersionedAuthorizationModel {
528 pub(crate) fn new(model: AuthorizationModel, version: AuthorizationModelVersion) -> Self {
529 VersionedAuthorizationModel { model, version }
530 }
531
532 pub(crate) fn version(&self) -> AuthorizationModelVersion {
533 self.version
534 }
535
536 pub(crate) fn model(&self) -> &AuthorizationModel {
537 &self.model
538 }
539}
540
541impl AuthorizationModelVersion {
542 #[must_use]
543 pub fn new(major: u32, minor: u32) -> Self {
544 AuthorizationModelVersion { major, minor }
545 }
546
547 #[must_use]
548 pub fn major(&self) -> u32 {
549 self.major
550 }
551
552 #[must_use]
553 pub fn minor(&self) -> u32 {
554 self.minor
555 }
556}
557
558impl PartialOrd for AuthorizationModelVersion {
560 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
561 Some(self.cmp(other))
562 }
563}
564
565impl Ord for AuthorizationModelVersion {
566 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
567 (self.major, self.minor).cmp(&(other.major, other.minor))
568 }
569}
570
571impl std::fmt::Display for AuthorizationModelVersion {
572 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
573 write!(f, "{}.{}", self.major, self.minor)
574 }
575}
576
577impl FromStr for AuthorizationModelVersion {
578 type Err = Error;
579
580 fn from_str(s: &str) -> Result<Self> {
581 let parts = s.split('.').collect::<Vec<_>>();
582 if parts.len() != 2 {
583 return Err(Error::InvalidModelVersion(s.to_string()));
584 }
585
586 let major = parts[0]
587 .parse()
588 .map_err(|_| Error::InvalidModelVersion(s.to_string()))?;
589 let minor = parts[1]
590 .parse()
591 .map_err(|_| Error::InvalidModelVersion(s.to_string()))?;
592
593 Ok(AuthorizationModelVersion::new(major, minor))
594 }
595}
596
597impl<T, S> std::fmt::Debug for Migration<T, S> {
598 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
599 f.debug_struct("Migration")
600 .field("model", &self.model)
601 .field("pre_migration_fn", &"...")
602 .field("post_migration_fn", &"...")
603 .finish()
604 }
605}
606
// Unit tests for version ordering and key formatting/parsing. The nested `openfga` module
// contains integration tests that need a running OpenFGA server at
// `TEST_OPENFGA_CLIENT_GRPC_URL`.
#[cfg(test)]
pub(crate) mod test {
    use std::sync::Mutex;

    use needs_env_var::needs_env_var;
    use pretty_assertions::assert_eq;

    use super::*;

    // Concrete manager type used to reach the associated (non-`self`) helpers.
    type ChannelTupleManager = TupleModelManager<tonic::transport::Channel, ()>;

    // Versions must order by major first, then minor.
    #[test]
    fn test_ordering() {
        let versioned_1_0 = AuthorizationModelVersion::new(1, 0);
        let versioned_1_1 = AuthorizationModelVersion::new(1, 1);
        let versioned_2_0 = AuthorizationModelVersion::new(2, 0);
        let versioned_2_1 = AuthorizationModelVersion::new(2, 1);
        let versioned_2_2 = AuthorizationModelVersion::new(2, 2);

        assert!(versioned_1_0 < versioned_1_1);
        assert!(versioned_1_1 < versioned_2_0);
        assert!(versioned_2_0 < versioned_2_1);
        assert!(versioned_2_1 < versioned_2_2);
    }

    // `Display` and `FromStr` must round-trip, including multi-digit components.
    #[test]
    fn test_auth_model_version_str() {
        let version = AuthorizationModelVersion::new(1, 0);
        assert_eq!(version.to_string(), "1.0");
        assert_eq!("1.0".parse::<AuthorizationModelVersion>().unwrap(), version);

        let version = AuthorizationModelVersion::new(10, 2);
        assert_eq!(version.to_string(), "10.2");
        assert_eq!(
            "10.2".parse::<AuthorizationModelVersion>().unwrap(),
            version
        );
    }

    #[test]
    fn test_parse_model_version_from_key() {
        let model_prefix = "test";
        let model_version = AuthorizationModelVersion::new(1, 0);
        let key = format!("model_version:{model_prefix}-{model_version}");
        assert_eq!(
            ChannelTupleManager::parse_model_version_from_key(&key, model_prefix),
            Some(model_version)
        );

        // Missing prefix segment: not a key for `test`.
        assert!(ChannelTupleManager::parse_model_version_from_key(
            "model_version:1.0",
            model_prefix
        )
        .is_none());

        // Key of a different prefix must be ignored.
        assert!(ChannelTupleManager::parse_model_version_from_key(
            "model_version:foo-1.0",
            model_prefix
        )
        .is_none());

        // Prefixes may themselves contain '-'.
        assert_eq!(
            ChannelTupleManager::parse_model_version_from_key(
                "model_version:other-model-10.200",
                "other-model"
            ),
            Some(AuthorizationModelVersion::new(10, 200))
        );
    }

    // Formatting and parsing a key must be inverse operations.
    #[test]
    fn test_format_model_version_key() {
        let model_prefix = "test";
        let model_version = AuthorizationModelVersion::new(1, 0);
        let key = ChannelTupleManager::format_model_version_key(model_prefix, model_version);
        assert_eq!(key, "model_version:test-1.0");
        let parsed = ChannelTupleManager::parse_model_version_from_key(&key, model_prefix).unwrap();
        assert_eq!(parsed, model_version);
    }

    // Integration tests; only compiled when TEST_OPENFGA_CLIENT_GRPC_URL is set.
    #[needs_env_var(TEST_OPENFGA_CLIENT_GRPC_URL)]
    pub(crate) mod openfga {
        use std::str::FromStr;

        use pretty_assertions::assert_eq;

        use super::*;
        use crate::client::{OpenFgaServiceClient, ReadAuthorizationModelRequest};

        /// Connect to the OpenFGA server at `TEST_OPENFGA_CLIENT_GRPC_URL`; panics on failure.
        pub(crate) async fn get_service_client() -> OpenFgaServiceClient<tonic::transport::Channel>
        {
            let endpoint = std::env::var("TEST_OPENFGA_CLIENT_GRPC_URL").unwrap();
            let endpoint = tonic::transport::Endpoint::from_str(&endpoint).unwrap();
            OpenFgaServiceClient::connect(endpoint)
                .await
                .expect("Client can be created")
        }

        /// Connect and create a fresh store with a unique (UUIDv7-based) name.
        pub(crate) async fn service_client_with_store(
        ) -> (OpenFgaServiceClient<tonic::transport::Channel>, Store) {
            let mut client = get_service_client().await;
            let store_name = format!("test-{}", uuid::Uuid::now_v7());
            let store = client.get_or_create_store(&store_name).await.unwrap();
            (client, store)
        }

        // A store that does not exist yields no versions (not an error).
        #[tokio::test]
        async fn test_get_existing_versions_nonexistent_store() {
            let client = get_service_client().await;
            let mut manager: TupleModelManager<_, ()> =
                TupleModelManager::new(client, "nonexistent", "test");

            let versions = manager.get_existing_versions().await.unwrap();
            assert!(versions.is_empty());
        }

        // An existing but empty store also yields no versions.
        #[tokio::test]
        async fn test_get_existing_versions_nonexistent_auth_model() {
            let mut client = get_service_client().await;
            let store_name = format!("test-{}", uuid::Uuid::now_v7());
            let _store = client.get_or_create_store(&store_name).await.unwrap();
            let mut manager: TupleModelManager<_, ()> =
                TupleModelManager::new(client, &store_name, "test");
            let versions = manager.get_existing_versions().await.unwrap();
            assert!(versions.is_empty());
        }

        #[tokio::test]
        async fn test_get_authorization_model_id() {
            let (mut client, store) = service_client_with_store().await;
            let model_prefix = "test";
            let version = AuthorizationModelVersion::new(1, 0);

            let mut manager: TupleModelManager<_, ()> =
                TupleModelManager::new(client.clone(), &store.name, model_prefix);

            // Nothing recorded yet.
            assert_eq!(
                manager.get_authorization_model_id(version).await.unwrap(),
                None
            );

            // A model is needed so that the bookkeeping tuples below can be written.
            let model: AuthorizationModel =
                serde_json::from_str(include_str!("../tests/model-manager/v1.0/schema.json"))
                    .unwrap();
            client
                .write_authorization_model(model.into_write_request(store.id.clone()))
                .await
                .unwrap();

            // Write the markers by hand, as `mark_as_applied` would.
            client
                .write(WriteRequest {
                    store_id: store.id.clone(),
                    writes: Some(WriteRequestWrites {
                        tuple_keys: vec![
                            TupleKey {
                                user: "auth_model_id:111111".to_string(),
                                relation: "openfga_id".to_string(),
                                object: "model_version:test-1.0".to_string(),
                                condition: None,
                            },
                            TupleKey {
                                user: "auth_model_id:*".to_string(),
                                relation: "exists".to_string(),
                                object: "model_version:test-1.0".to_string(),
                                condition: None,
                            },
                            // Marker for a different prefix; must not affect `test`.
                            TupleKey {
                                user: "auth_model_id:*".to_string(),
                                relation: "exists".to_string(),
                                object: "model_version:test2-1.0".to_string(),
                                condition: None,
                            },
                        ],
                    }),
                    deletes: None,
                    authorization_model_id: String::new(),
                })
                .await
                .unwrap();

            assert_eq!(
                manager.get_authorization_model_id(version).await.unwrap(),
                Some("111111".to_string())
            );
        }

        // End-to-end: migrations are applied once, hooks run once, and the stored models
        // match the schema files.
        #[tokio::test]
        async fn test_model_manager() {
            let store_name = format!("test-{}", uuid::Uuid::now_v7());
            let mut client = get_service_client().await;

            let model_1_0: AuthorizationModel =
                serde_json::from_str(include_str!("../tests/model-manager/v1.0/schema.json"))
                    .unwrap();

            let version_1_0 = AuthorizationModelVersion::new(1, 0);

            let migration_state = MigrationState::default();
            let mut manager = TupleModelManager::new(client.clone(), &store_name, "test-model")
                .add_model(
                    model_1_0.clone(),
                    version_1_0,
                    Some(v1_pre_migration_fn),
                    None::<MigrationFn<_, _>>,
                );
            manager.migrate(migration_state.clone()).await.unwrap();
            assert_eq!(*migration_state.counter_1.lock().unwrap(), 1);
            // Second run must be a no-op; the hook errors if it is ever called twice.
            manager.migrate(migration_state.clone()).await.unwrap();
            assert_eq!(*migration_state.counter_1.lock().unwrap(), 1);

            let auth_model_id = manager
                .get_authorization_model_id(version_1_0)
                .await
                .unwrap()
                .unwrap();
            let mut auth_model =
                get_auth_model_by_id(&mut client, &store_name, &auth_model_id).await;
            // Ids are server-assigned; align them before comparing the model bodies.
            auth_model.id = model_1_0.id.clone();
            assert_eq!(
                serde_json::to_value(&model_1_0).unwrap(),
                serde_json::to_value(auth_model).unwrap()
            );

            let model_1_1: AuthorizationModel =
                serde_json::from_str(include_str!("../tests/model-manager/v1.1/schema.json"))
                    .unwrap();
            let version_1_1 = AuthorizationModelVersion::new(1, 1);
            let mut manager = manager.add_model(
                model_1_1.clone(),
                version_1_1,
                None::<MigrationFn<_, _>>,
                Some(v2_post_migration_fn),
            );
            manager.migrate(migration_state.clone()).await.unwrap();
            manager.migrate(migration_state.clone()).await.unwrap();
            manager.migrate(migration_state.clone()).await.unwrap();

            // Only 1.1 was pending; each hook ran exactly once overall.
            assert_eq!(*migration_state.counter_1.lock().unwrap(), 1);
            assert_eq!(*migration_state.counter_2.lock().unwrap(), 1);

            let auth_model_id = manager
                .get_authorization_model_id(version_1_1)
                .await
                .unwrap()
                .unwrap();
            let mut auth_model =
                get_auth_model_by_id(&mut client, &store_name, &auth_model_id).await;
            auth_model.id = model_1_1.id.clone();
            assert_eq!(
                serde_json::to_value(&model_1_1).unwrap(),
                serde_json::to_value(auth_model).unwrap()
            );
        }

        /// Fetch the authorization model `auth_model_id` from the store named `store_name`.
        /// Panics if the store or model cannot be read.
        async fn get_auth_model_by_id(
            client: &mut OpenFgaServiceClient<tonic::transport::Channel>,
            store_name: &str,
            auth_model_id: &str,
        ) -> AuthorizationModel {
            client
                .read_authorization_model(ReadAuthorizationModelRequest {
                    store_id: client
                        .clone()
                        .get_store_by_name(store_name)
                        .await
                        .unwrap()
                        .unwrap()
                        .id,
                    id: auth_model_id.to_string(),
                })
                .await
                .unwrap()
                .into_inner()
                .authorization_model
                .unwrap()
        }
    }

    // Shared state handed to the hooks; each counter records how often one hook ran.
    #[derive(Default, Clone)]
    struct MigrationState {
        counter_1: Arc<Mutex<i32>>,
        counter_2: Arc<Mutex<i32>>,
    }

    // `async` is required by the hook signature even though nothing is awaited.
    #[allow(clippy::unused_async)]
    async fn v1_pre_migration_fn(
        client: OpenFgaServiceClient<tonic::transport::Channel>,
        _prev_model: Option<String>,
        _curr_model: Option<String>,
        state: MigrationState,
    ) -> std::result::Result<(), StdError> {
        let _ = client;
        let mut counter = state.counter_1.lock().unwrap();
        *counter += 1;
        // Fail on a second invocation so a re-run migration surfaces as a test failure.
        if *counter == 2 {
            return Err(Box::new(Error::RequestFailed(Box::new(
                tonic::Status::new(tonic::Code::Internal, "Test"),
            ))));
        }
        Ok(())
    }

    // Same as `v1_pre_migration_fn`, but counts on `counter_2`.
    #[allow(clippy::unused_async)]
    async fn v2_post_migration_fn(
        client: OpenFgaServiceClient<tonic::transport::Channel>,
        _prev_model: Option<String>,
        _curr_model: Option<String>,
        state: MigrationState,
    ) -> std::result::Result<(), StdError> {
        let _ = client;
        let mut counter = state.counter_2.lock().unwrap();
        *counter += 1;
        if *counter == 2 {
            return Err(Box::new(Error::RequestFailed(Box::new(
                tonic::Status::new(tonic::Code::Internal, "Test"),
            ))));
        }
        Ok(())
    }
}