openfga_client/
migration.rs

1use std::{
2    collections::{HashMap, HashSet},
3    future::Future,
4    hash::Hash,
5    pin::Pin,
6    str::FromStr,
7    sync::Arc,
8};
9
10use tonic::codegen::{Body, Bytes, StdError};
11
12use crate::{
13    client::{
14        AuthorizationModel, ConsistencyPreference, OpenFgaServiceClient, ReadRequestTupleKey,
15        Store, Tuple, TupleKey, WriteAuthorizationModelResponse, WriteRequest, WriteRequestWrites,
16    },
17    error::{Error, Result},
18};
19
/// Number of tuples requested per page when reading tuples from OpenFGA.
const DEFAULT_PAGE_SIZE: i32 = 100;
/// Maximum number of pages passed to `read_all_pages` for a single read.
const MAX_PAGES: u32 = 1_000;
22
/// Version of an authorization model, made of a `major` and a `minor` component.
///
/// Versions are compared by `major` first and `minor` second, and are rendered
/// as `"{major}.{minor}"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AuthorizationModelVersion {
    major: u32,
    minor: u32,
}
28
29#[derive(Debug, Clone, PartialEq)]
30struct VersionedAuthorizationModel {
31    model: AuthorizationModel,
32    version: AuthorizationModelVersion,
33}
34
35/// Manages [`AuthorizationModel`]s in OpenFGA.
36///
37/// Authorization models in OpenFGA don't receive a unique name. Instead,
38/// they receive a random id on creation. If we don't store this ID, we can't
39/// find the model again and use its ID.
40///
41/// This `ModelManager` stores the mapping of [`AuthorizationModelVersion`]
42/// to the ID of the model in OpenFGA directly inside OpenFGA.
43/// This way can query OpenFGA to determine if a model with a certain version
44/// has already exists.
45///
46/// When running [`TupleModelManager::migrate()`], the manager only applies models and their migrations
47/// if they don't already exist in OpenFGA.
48///
49/// To store the mapping of model versions to OpenFGA IDs, the following needs to part of your Authorization Model:
50/// ```text
51/// type auth_model_id
52/// type model_version
53///   relations
54///     define openfga_id: [auth_model_id]
55///     define exists: [auth_model_id:*]
56/// ```
57///
58#[derive(Debug, Clone)]
59pub struct TupleModelManager<T>
60where
61    T: tonic::client::GrpcService<tonic::body::BoxBody>,
62    T::Error: Into<StdError>,
63    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
64    <T::ResponseBody as Body>::Error: Into<StdError> + Send,
65{
66    client: OpenFgaServiceClient<T>,
67    store_name: String,
68    model_prefix: String,
69    migrations: HashMap<AuthorizationModelVersion, Migration<T>>,
70}
71
72#[derive(Clone)]
73struct Migration<T> {
74    model: VersionedAuthorizationModel,
75    pre_migration_fn: Option<BoxedMigrationFn<T>>,
76    post_migration_fn: Option<BoxedMigrationFn<T>>,
77}
78
/// A heap-allocated, pinned, `Send` future that resolves to `T` and may borrow data for `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
81
82/// Possible function pointer that implements the migration function signature.
83pub type MigrationFn<T> =
84    fn(OpenFgaServiceClient<T>) -> BoxFuture<'static, std::result::Result<(), StdError>>;
85
86/// Type alias for the migration function signature.
87type DynMigrationFn<T> =
88    dyn Fn(OpenFgaServiceClient<T>) -> BoxFuture<'static, std::result::Result<(), StdError>>;
89
90/// Boxed migration function
91type BoxedMigrationFn<T> = Arc<DynMigrationFn<T>>;
92
93// Function to box the async functions that take an i32 parameter
94fn box_migration_fn<T, F, Fut>(f: F) -> BoxedMigrationFn<T>
95where
96    F: Fn(OpenFgaServiceClient<T>) -> Fut + Send + 'static,
97    Fut: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
98{
99    Arc::new(move |v| Box::pin(f(v)))
100}
101
102impl<T> TupleModelManager<T>
103where
104    T: tonic::client::GrpcService<tonic::body::BoxBody>,
105    T: Clone,
106    T::Error: Into<StdError>,
107    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
108    <T::ResponseBody as Body>::Error: Into<StdError> + Send,
109{
110    const AUTH_MODEL_ID_TYPE: &'static str = "auth_model_id";
111    const MODEL_VERSION_EXISTS_RELATION: &'static str = "exists";
112    const MODEL_VERSION_TYPE: &'static str = "model_version";
113    const MODEL_VERSION_OPENFGA_ID_RELATION: &'static str = "openfga_id";
114
115    /// Create a new `TupleModelManager` with the given client and model name.
116    /// The model prefix must not change after the first model has been added or
117    /// the model manager will not be able to find the model again.
118    /// Use different model prefixes if models for different purposes are stored in the
119    /// same OpenFGA store.
120    pub fn new(client: OpenFgaServiceClient<T>, store_name: &str, model_prefix: &str) -> Self {
121        TupleModelManager {
122            client,
123            model_prefix: model_prefix.to_string(),
124            store_name: store_name.to_string(),
125            migrations: HashMap::new(),
126        }
127    }
128
129    /// Add a new model to the manager.
130    /// If a model with the same version has already been added, the new model will replace the old one.
131    ///
132    /// Ensure that migration functions are written in an idempotent way.
133    /// If a migration fails, it might be retried.
134    #[must_use]
135    pub fn add_model<FutPre, FutPost>(
136        mut self,
137        model: AuthorizationModel,
138        version: AuthorizationModelVersion,
139        pre_migration_fn: Option<impl Fn(OpenFgaServiceClient<T>) -> FutPre + Send + 'static>,
140        post_migration_fn: Option<impl Fn(OpenFgaServiceClient<T>) -> FutPost + Send + 'static>,
141    ) -> Self
142    where
143        FutPre: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
144        FutPost: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
145    {
146        let migration = Migration {
147            model: VersionedAuthorizationModel::new(model, version),
148            pre_migration_fn: pre_migration_fn.map(box_migration_fn),
149            post_migration_fn: post_migration_fn.map(box_migration_fn),
150        };
151        self.migrations.insert(migration.model.version(), migration);
152        self
153    }
154
155    /// Run migrations.
156    ///
157    /// This will:
158    /// 1. Get all existing models in the OpenFGA store.
159    /// 2. Determine which migrations need to be performed: All migrations with a version higher than the highest existing model.
160    /// 3. In order of the version of the model, perform the migrations:
161    ///    1. Run the pre-migration hook if it exists.
162    ///    2. Write the model to OpenFGA.
163    ///    3. Run the post-migration hook if it exists.
164    /// 4. Mark the model as applied in OpenFGA.
165    ///
166    /// # Errors
167    /// * If OpenFGA cannot be reached or a request fails.
168    /// * If any of the migration hooks fail.
169    pub async fn migrate(&mut self) -> Result<()> {
170        let span = tracing::span!(
171            tracing::Level::INFO,
172            "Running OpenFGA Migrations",
173            store_name = self.store_name,
174            model_prefix = self.model_prefix
175        );
176        let _enter = span.enter();
177
178        if self.migrations.is_empty() {
179            tracing::info!("No Migrations have been added. Nothing to do.");
180            return Ok(());
181        }
182
183        let store = self.client.get_or_create_store(&self.store_name).await?;
184        let existing_models = self.get_existing_versions().await?;
185        let max_existing_model = existing_models.iter().max();
186
187        if let Some(max_existing_model) = max_existing_model {
188            tracing::info!(
189                "Currently the highest existing Model Version is: {}",
190                max_existing_model
191            );
192        } else {
193            tracing::info!("No model found in OpenFGA store");
194        }
195
196        let ordered_migrations = self.migrations_to_perform(max_existing_model.copied());
197
198        let mut client = self.client.clone();
199        for migration in ordered_migrations {
200            tracing::info!("Migrating to model version: {}", migration.model.version());
201
202            // Pre-hook
203            if let Some(pre_migration_fn) = migration.pre_migration_fn.as_ref() {
204                pre_migration_fn(client.clone()).await.map_err(|e| {
205                    tracing::error!("Error in OpenFGA pre-migration hook: {:?}", e);
206                    Error::MigrationHookFailed {
207                        version: migration.model.version().to_string(),
208                        error: Arc::new(e),
209                    }
210                })?;
211            }
212
213            // Write Model
214            let request = migration
215                .model
216                .model()
217                .clone()
218                .into_write_request(store.id.clone());
219            let written_model = client
220                .write_authorization_model(request)
221                .await
222                .map_err(|e| {
223                    tracing::error!("Error writing model: {:?}", e);
224                    Error::RequestFailed(e)
225                })?;
226            tracing::debug!("Model written: {:?}", written_model);
227
228            // Post-hook
229            if let Some(post_migration_fn) = migration.post_migration_fn.as_ref() {
230                post_migration_fn(client.clone()).await.map_err(|e| {
231                    tracing::error!("Error in OpenFGA post-migration hook: {:?}", e);
232                    Error::MigrationHookFailed {
233                        version: migration.model.version().to_string(),
234                        error: Arc::new(e),
235                    }
236                })?;
237            }
238
239            // Mark as applied
240            Self::mark_as_applied(
241                &mut client,
242                &self.model_prefix,
243                &store,
244                migration.model.version(),
245                written_model.into_inner(),
246            )
247            .await?;
248        }
249
250        Ok(())
251    }
252
253    /// Get the OpenFGA Authorization model ID for the specified model version.
254    /// Ensure that migrations have been run before calling this method.
255    ///
256    /// # Errors
257    /// * If the store with the given name does not exist.
258    /// * If a call to OpenFGA fails.
259    pub async fn get_authorization_model_id(
260        &mut self,
261        version: AuthorizationModelVersion,
262    ) -> Result<Option<String>> {
263        let store = self
264            .client
265            .get_store_by_name(&self.store_name)
266            .await?
267            .ok_or_else(|| {
268                tracing::error!("Store with name {} not found", self.store_name);
269                Error::StoreNotFound(self.store_name.clone())
270            })?;
271
272        let applied_models = self
273            .client
274            .read_all_pages(
275                &store.id,
276                ReadRequestTupleKey {
277                    user: String::new(),
278                    relation: Self::MODEL_VERSION_OPENFGA_ID_RELATION.to_string(),
279                    object: Self::format_model_version_key(&self.model_prefix, version),
280                },
281                ConsistencyPreference::HigherConsistency,
282                DEFAULT_PAGE_SIZE,
283                MAX_PAGES,
284            )
285            .await?;
286
287        let applied_models = applied_models
288            .into_iter()
289            .filter_map(|t| t.key)
290            .filter_map(|t| {
291                t.user
292                    .strip_prefix(&format!("{}:", Self::AUTH_MODEL_ID_TYPE))
293                    .map(ToString::to_string)
294            })
295            .collect::<Vec<_>>();
296
297        if applied_models.len() > 1 {
298            tracing::error!(
299                "Multiple authorization models with model prefix {} for version {} found.",
300                self.model_prefix,
301                version
302            );
303            return Err(Error::AmbiguousModelVersion {
304                model_prefix: self.model_prefix.clone(),
305                version: version.to_string(),
306            });
307        }
308
309        let model_id = applied_models.into_iter().next().map(|openfga_id| {
310            tracing::info!(
311                "Authorization model for version {version} found in OpenFGA store {}. Model ID: {openfga_id}",
312                self.store_name,
313            );
314            openfga_id
315        });
316
317        Ok(model_id)
318    }
319
320    /// Mark a model version as applied in OpenFGA
321    async fn mark_as_applied(
322        client: &mut OpenFgaServiceClient<T>,
323        model_prefix: &str,
324        store: &Store,
325        version: AuthorizationModelVersion,
326        write_response: WriteAuthorizationModelResponse,
327    ) -> Result<()> {
328        let authorization_model_id = write_response.authorization_model_id;
329        let object = Self::format_model_version_key(model_prefix, version);
330
331        let write_request = WriteRequest {
332            store_id: store.id.clone(),
333            writes: Some(WriteRequestWrites {
334                tuple_keys: vec![
335                    TupleKey {
336                        user: format!("{}:{authorization_model_id}", Self::AUTH_MODEL_ID_TYPE),
337                        relation: Self::MODEL_VERSION_OPENFGA_ID_RELATION.to_string(),
338                        object: object.clone(),
339                        condition: None,
340                    },
341                    TupleKey {
342                        user: format!("{}:*", Self::AUTH_MODEL_ID_TYPE),
343                        relation: Self::MODEL_VERSION_EXISTS_RELATION.to_string(),
344                        object,
345                        condition: None,
346                    },
347                ],
348            }),
349            deletes: None,
350            authorization_model_id: authorization_model_id.to_string(),
351        };
352        client.write(write_request.clone()).await.map_err(|e| {
353            tracing::error!("Error marking model as applied: {:?}", e);
354            Error::RequestFailed(e)
355        })?;
356        Ok(())
357    }
358
359    /// Get all migrations that have been added to the manager
360    /// as a `Vec` sorted by the version of the model.
361    fn ordered_migrations(&self) -> Vec<&Migration<T>> {
362        let mut migrations = self.migrations.values().collect::<Vec<_>>();
363        migrations.sort_unstable_by_key(|m| m.model.version());
364        migrations
365    }
366
367    /// Get all migrations that need to be performed, given the maximum existing model version.
368    fn migrations_to_perform(
369        &self,
370        max_existing_model: Option<AuthorizationModelVersion>,
371    ) -> Vec<&Migration<T>> {
372        let ordered_migrations = self.ordered_migrations();
373        let migrations_to_perform = ordered_migrations
374            .into_iter()
375            .filter(|m| {
376                max_existing_model.map_or(true, |max_existing| m.model.version() > max_existing)
377            })
378            .collect::<Vec<_>>();
379
380        tracing::info!(
381            "{} migrations needed in OpenFGA store {} for model-prefix {}",
382            migrations_to_perform.len(),
383            self.store_name,
384            self.model_prefix
385        );
386        migrations_to_perform
387    }
388
389    /// Get versions of all existing models in OpenFGA.
390    /// Returns an empty vector if the store does not exist.
391    ///
392    /// # Errors
393    /// * If the call to determine existing stores fails.
394    /// * If a tuple read call fails.
395    pub async fn get_existing_versions(&mut self) -> Result<Vec<AuthorizationModelVersion>> {
396        let Some(store) = self.client.get_store_by_name(&self.store_name).await? else {
397            return Ok(vec![]);
398        };
399
400        let tuples = self
401            .client
402            .read_all_pages(
403                &store.id,
404                ReadRequestTupleKey {
405                    user: format!("{}:*", Self::AUTH_MODEL_ID_TYPE).to_string(),
406                    relation: Self::MODEL_VERSION_EXISTS_RELATION.to_string(),
407                    object: format!("{}:", Self::MODEL_VERSION_TYPE).to_string(),
408                },
409                crate::client::ConsistencyPreference::HigherConsistency,
410                DEFAULT_PAGE_SIZE,
411                MAX_PAGES,
412            )
413            .await?;
414        let existing_models = Self::parse_existing_models(tuples, &self.model_prefix);
415        Ok(existing_models.into_iter().collect())
416    }
417
418    fn parse_existing_models(
419        exist_tuples: Vec<Tuple>,
420        model_prefix: &str,
421    ) -> HashSet<AuthorizationModelVersion> {
422        exist_tuples
423            .into_iter()
424            .filter_map(|t| t.key)
425            .filter_map(|t| Self::parse_model_version_from_key(&t.object, model_prefix))
426            .collect()
427    }
428
429    fn parse_model_version_from_key(
430        model: &str,
431        model_prefix: &str,
432    ) -> Option<AuthorizationModelVersion> {
433        model
434            // Ignore models with wrong prefix
435            .strip_prefix(&format!("{}:", Self::MODEL_VERSION_TYPE))
436            .and_then(|model| {
437                model
438                    .strip_prefix(&format!("{model_prefix}-"))
439                    .and_then(|version| AuthorizationModelVersion::from_str(version).ok())
440            })
441    }
442
443    fn format_model_version_key(model_prefix: &str, version: AuthorizationModelVersion) -> String {
444        format!("{}:{}-{}", Self::MODEL_VERSION_TYPE, model_prefix, version)
445    }
446}
447
448impl VersionedAuthorizationModel {
449    pub(crate) fn new(model: AuthorizationModel, version: AuthorizationModelVersion) -> Self {
450        VersionedAuthorizationModel { model, version }
451    }
452
453    pub(crate) fn version(&self) -> AuthorizationModelVersion {
454        self.version
455    }
456
457    pub(crate) fn model(&self) -> &AuthorizationModel {
458        &self.model
459    }
460}
461
462impl AuthorizationModelVersion {
463    #[must_use]
464    pub fn new(major: u32, minor: u32) -> Self {
465        AuthorizationModelVersion { major, minor }
466    }
467
468    #[must_use]
469    pub fn major(&self) -> u32 {
470        self.major
471    }
472
473    #[must_use]
474    pub fn minor(&self) -> u32 {
475        self.minor
476    }
477}
478
479// Sort by major version first, then by subversion.
480impl PartialOrd for AuthorizationModelVersion {
481    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
482        Some(self.cmp(other))
483    }
484}
485
486impl Ord for AuthorizationModelVersion {
487    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
488        (self.major, self.minor).cmp(&(other.major, other.minor))
489    }
490}
491
492impl std::fmt::Display for AuthorizationModelVersion {
493    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
494        write!(f, "{}.{}", self.major, self.minor)
495    }
496}
497
498impl FromStr for AuthorizationModelVersion {
499    type Err = Error;
500
501    fn from_str(s: &str) -> Result<Self> {
502        let parts = s.split('.').collect::<Vec<_>>();
503        if parts.len() != 2 {
504            return Err(Error::InvalidModelVersion(s.to_string()));
505        }
506
507        let major = parts[0]
508            .parse()
509            .map_err(|_| Error::InvalidModelVersion(s.to_string()))?;
510        let minor = parts[1]
511            .parse()
512            .map_err(|_| Error::InvalidModelVersion(s.to_string()))?;
513
514        Ok(AuthorizationModelVersion::new(major, minor))
515    }
516}
517
518impl<T> std::fmt::Debug for Migration<T> {
519    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
520        f.debug_struct("Migration")
521            .field("model", &self.model)
522            .field("pre_migration_fn", &"...")
523            .field("post_migration_fn", &"...")
524            .finish()
525    }
526}
527
#[cfg(test)]
pub(crate) mod test {
    use std::sync::Mutex;

    use needs_env_var::needs_env_var;
    use pretty_assertions::assert_eq;

    use super::*;

    type ChannelTupleManager = TupleModelManager<tonic::transport::Channel>;

    #[test]
    fn test_ordering() {
        // Each version must compare strictly less than its successor.
        let ascending = [
            AuthorizationModelVersion::new(1, 0),
            AuthorizationModelVersion::new(1, 1),
            AuthorizationModelVersion::new(2, 0),
            AuthorizationModelVersion::new(2, 1),
            AuthorizationModelVersion::new(2, 2),
        ];
        for pair in ascending.windows(2) {
            assert!(pair[0] < pair[1]);
        }
    }

    #[test]
    fn test_auth_model_version_str() {
        // Display and FromStr must round-trip.
        for (major, minor, rendered) in [(1, 0, "1.0"), (10, 2, "10.2")] {
            let version = AuthorizationModelVersion::new(major, minor);
            assert_eq!(version.to_string(), rendered);
            assert_eq!(
                rendered.parse::<AuthorizationModelVersion>().unwrap(),
                version
            );
        }
    }

    #[test]
    fn test_parse_model_version_from_key() {
        let prefix = "test";
        let expected = AuthorizationModelVersion::new(1, 0);
        let key = format!("model_version:{prefix}-{expected}");
        assert_eq!(
            ChannelTupleManager::parse_model_version_from_key(&key, prefix),
            Some(expected)
        );

        // Keys with a missing prefix or a different prefix are ignored.
        for unrelated in ["model_version:1.0", "model_version:foo-1.0"] {
            assert!(ChannelTupleManager::parse_model_version_from_key(unrelated, prefix).is_none());
        }

        // Prefixes may contain dashes; versions may have several digits.
        assert_eq!(
            ChannelTupleManager::parse_model_version_from_key(
                "model_version:other-model-10.200",
                "other-model"
            ),
            Some(AuthorizationModelVersion::new(10, 200))
        );
    }

    #[test]
    fn test_format_model_version_key() {
        let version = AuthorizationModelVersion::new(1, 0);
        let key = ChannelTupleManager::format_model_version_key("test", version);
        assert_eq!(key, "model_version:test-1.0");
        assert_eq!(
            ChannelTupleManager::parse_model_version_from_key(&key, "test"),
            Some(version)
        );
    }

    #[needs_env_var(TEST_OPENFGA_CLIENT_GRPC_URL)]
    pub(crate) mod openfga {
        use std::{str::FromStr, sync::Mutex};

        use pretty_assertions::assert_eq;

        use super::*;
        use crate::client::{OpenFgaServiceClient, ReadAuthorizationModelRequest};

        /// Connect to the OpenFGA server named by `TEST_OPENFGA_CLIENT_GRPC_URL`.
        pub(crate) async fn get_service_client() -> OpenFgaServiceClient<tonic::transport::Channel>
        {
            let url = std::env::var("TEST_OPENFGA_CLIENT_GRPC_URL").unwrap();
            let endpoint = tonic::transport::Endpoint::from_str(&url).unwrap();
            OpenFgaServiceClient::connect(endpoint)
                .await
                .expect("Client can be created")
        }

        /// Connect and create a fresh store with a unique `test-<uuid>` name.
        pub(crate) async fn service_client_with_store(
        ) -> (OpenFgaServiceClient<tonic::transport::Channel>, Store) {
            let mut client = get_service_client().await;
            let name = format!("test-{}", uuid::Uuid::now_v7());
            let store = client.get_or_create_store(&name).await.unwrap();
            (client, store)
        }

        #[tokio::test]
        async fn test_get_existing_versions_nonexistent_store() {
            let mut manager =
                TupleModelManager::new(get_service_client().await, "nonexistent", "test");
            assert!(manager.get_existing_versions().await.unwrap().is_empty());
        }

        #[tokio::test]
        async fn test_get_existing_versions_nonexistent_auth_model() {
            let mut client = get_service_client().await;
            let store_name = format!("test-{}", uuid::Uuid::now_v7());
            let _created = client.get_or_create_store(&store_name).await.unwrap();

            let mut manager = TupleModelManager::new(client, &store_name, "test");
            assert!(manager.get_existing_versions().await.unwrap().is_empty());
        }

        #[tokio::test]
        async fn test_get_authorization_model_id() {
            let (mut client, store) = service_client_with_store().await;
            let prefix = "test";
            let version = AuthorizationModelVersion::new(1, 0);
            let mut manager = TupleModelManager::new(client.clone(), &store.name, prefix);

            // Nothing has been recorded yet.
            assert_eq!(
                manager.get_authorization_model_id(version).await.unwrap(),
                None
            );

            // Write the schema so the bookkeeping types exist in the store.
            let model: AuthorizationModel =
                serde_json::from_str(include_str!("../tests/model-manager/v1.0/schema.json"))
                    .unwrap();
            client
                .write_authorization_model(model.into_write_request(store.id.clone()))
                .await
                .unwrap();

            // Record version 1.0 by hand. The `test2` tuple belongs to a different
            // model prefix and must be ignored.
            let tuple = |user: &str, relation: &str, object: &str| TupleKey {
                user: user.to_string(),
                relation: relation.to_string(),
                object: object.to_string(),
                condition: None,
            };
            client
                .write(WriteRequest {
                    store_id: store.id.clone(),
                    writes: Some(WriteRequestWrites {
                        tuple_keys: vec![
                            tuple("auth_model_id:111111", "openfga_id", "model_version:test-1.0"),
                            tuple("auth_model_id:*", "exists", "model_version:test-1.0"),
                            tuple("auth_model_id:*", "exists", "model_version:test2-1.0"),
                        ],
                    }),
                    deletes: None,
                    authorization_model_id: String::new(),
                })
                .await
                .unwrap();

            assert_eq!(
                manager.get_authorization_model_id(version).await.unwrap(),
                Some("111111".to_string())
            );
        }

        #[tokio::test]
        async fn test_model_manager() {
            let store_name = format!("test-{}", uuid::Uuid::now_v7());
            let mut client = get_service_client().await;

            // Version 1.0, with a pre-migration hook.
            let model_1_0: AuthorizationModel =
                serde_json::from_str(include_str!("../tests/model-manager/v1.0/schema.json"))
                    .unwrap();
            let version_1_0 = AuthorizationModelVersion::new(1, 0);
            let pre_hook_calls = Arc::new(Mutex::new(0));

            let counter = pre_hook_calls.clone();
            let mut manager = TupleModelManager::new(client.clone(), &store_name, "test-model")
                .add_model(
                    model_1_0.clone(),
                    version_1_0,
                    Some(move |client| {
                        let counter = counter.clone();
                        async move { v1_pre_migration_fn(client, counter).await }
                    }),
                    None::<MigrationFn<_>>,
                );

            // Migrating twice must run the hook only once.
            manager.migrate().await.unwrap();
            assert_eq!(*pre_hook_calls.lock().unwrap(), 1);
            manager.migrate().await.unwrap();
            assert_eq!(*pre_hook_calls.lock().unwrap(), 1);

            assert_model_applied(&mut manager, &mut client, &store_name, version_1_0, &model_1_0)
                .await;

            // Version 1.1, with a post-migration hook.
            let model_1_1: AuthorizationModel =
                serde_json::from_str(include_str!("../tests/model-manager/v1.1/schema.json"))
                    .unwrap();
            let version_1_1 = AuthorizationModelVersion::new(1, 1);
            let post_hook_calls = Arc::new(Mutex::new(0));

            let counter = post_hook_calls.clone();
            let mut manager = manager.add_model(
                model_1_1.clone(),
                version_1_1,
                None::<MigrationFn<_>>,
                Some(move |client| {
                    let counter = counter.clone();
                    async move { v2_post_migration_fn(client, counter).await }
                }),
            );
            for _ in 0..3 {
                manager.migrate().await.unwrap();
            }

            // Each hook ran exactly once overall.
            assert_eq!(*pre_hook_calls.lock().unwrap(), 1);
            assert_eq!(*post_hook_calls.lock().unwrap(), 1);

            assert_model_applied(&mut manager, &mut client, &store_name, version_1_1, &model_1_1)
                .await;
        }

        /// Assert that the model recorded for `version` equals `expected`, ignoring the model id.
        async fn assert_model_applied(
            manager: &mut ChannelTupleManager,
            client: &mut OpenFgaServiceClient<tonic::transport::Channel>,
            store_name: &str,
            version: AuthorizationModelVersion,
            expected: &AuthorizationModel,
        ) {
            let auth_model_id = manager
                .get_authorization_model_id(version)
                .await
                .unwrap()
                .unwrap();
            let mut stored = get_auth_model_by_id(client, store_name, &auth_model_id).await;
            stored.id = expected.id.clone();
            assert_eq!(
                serde_json::to_value(expected).unwrap(),
                serde_json::to_value(stored).unwrap()
            );
        }

        /// Read the authorization model `auth_model_id` from the store named `store_name`.
        async fn get_auth_model_by_id(
            client: &mut OpenFgaServiceClient<tonic::transport::Channel>,
            store_name: &str,
            auth_model_id: &str,
        ) -> AuthorizationModel {
            let store_id = client
                .clone()
                .get_store_by_name(store_name)
                .await
                .unwrap()
                .unwrap()
                .id;
            client
                .read_authorization_model(ReadAuthorizationModelRequest {
                    store_id,
                    id: auth_model_id.to_string(),
                })
                .await
                .unwrap()
                .into_inner()
                .authorization_model
                .unwrap()
        }
    }

    /// Shared body of the test hooks. Increments the counter and returns an error on
    /// the second invocation, so a hook that is re-run makes `migrate()` fail.
    fn count_and_fail_on_second_call(
        counter_mutex: &Mutex<i32>,
    ) -> std::result::Result<(), StdError> {
        let mut calls = counter_mutex.lock().unwrap();
        *calls += 1;
        if *calls == 2 {
            return Err(Box::new(Error::RequestFailed(tonic::Status::new(
                tonic::Code::Internal,
                "Test",
            ))));
        }
        Ok(())
    }

    // Must be `async` so it can serve as a migration hook, even though it never awaits.
    #[allow(clippy::unused_async)]
    async fn v1_pre_migration_fn(
        client: OpenFgaServiceClient<tonic::transport::Channel>,
        counter_mutex: Arc<Mutex<i32>>,
    ) -> std::result::Result<(), StdError> {
        let _ = client;
        count_and_fail_on_second_call(&counter_mutex)
    }

    // Must be `async` so it can serve as a migration hook, even though it never awaits.
    #[allow(clippy::unused_async)]
    async fn v2_post_migration_fn(
        client: OpenFgaServiceClient<tonic::transport::Channel>,
        counter_mutex: Arc<Mutex<i32>>,
    ) -> std::result::Result<(), StdError> {
        let _ = client;
        count_and_fail_on_second_call(&counter_mutex)
    }
}