openfga_client/
migration.rs

1use std::{
2    collections::{HashMap, HashSet},
3    future::Future,
4    hash::Hash,
5    pin::Pin,
6    str::FromStr,
7    sync::Arc,
8};
9
10use tonic::codegen::{Body, Bytes, StdError};
11
12use crate::{
13    client::{
14        AuthorizationModel, ConsistencyPreference, OpenFgaServiceClient, ReadRequestTupleKey,
15        Store, Tuple, TupleKey, WriteAuthorizationModelResponse, WriteRequest, WriteRequestWrites,
16    },
17    error::{Error, Result},
18};
19
/// Page size handed to `read_all_pages` when reading the bookkeeping tuples.
const DEFAULT_PAGE_SIZE: i32 = 100;
/// Page limit handed to `read_all_pages` when reading the bookkeeping tuples.
const MAX_PAGES: u32 = 1000;
22
/// Application-side version of an authorization model.
///
/// A version consists of a `major` and a `minor` component. It is rendered as
/// `major.minor` by its `Display` implementation and ordered by `major` first,
/// then by `minor`.
#[derive(Debug, Copy, Clone, PartialEq, Hash, Eq)]
pub struct AuthorizationModelVersion {
    major: u32,
    minor: u32,
}
28
/// An [`AuthorizationModel`] paired with the application-side version it represents.
#[derive(Debug, Clone, PartialEq)]
struct VersionedAuthorizationModel {
    model: AuthorizationModel,
    version: AuthorizationModelVersion,
}
34
/// Manages [`AuthorizationModel`]s in OpenFGA by storing a mapping of their version in the
/// application to the ID of the model in OpenFGA in a tuple in OpenFGA.
///
/// Authorization models in OpenFGA don't receive a unique name. Instead,
/// they receive a random ID on creation. If we don't store this ID, we can't
/// find the model again and use its ID.
///
/// The [`TupleModelManager`] stores the mapping of [`AuthorizationModelVersion`]
/// to the ID of the model in OpenFGA directly inside OpenFGA.
/// This way we can query OpenFGA to determine if a model with a certain version
/// has already been applied.
///
/// The [`TupleModelManager`] will extend provided [`AuthorizationModel`]s with the following
/// OpenFGA types:
/// ```text
/// type auth_model_id
/// type model_version
///   relations
///     define openfga_id: [auth_model_id]
///     define exists: [auth_model_id:*]
/// ```
#[derive(Debug, Clone)]
pub struct TupleModelManager<T>
where
    T: tonic::client::GrpcService<tonic::body::BoxBody>,
    T::Error: Into<StdError>,
    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
    <T::ResponseBody as Body>::Error: Into<StdError> + Send,
{
    // Client used for all calls to OpenFGA; cloned and handed to migration hooks.
    client: OpenFgaServiceClient<T>,
    // Name of the OpenFGA store that models and bookkeeping tuples are written to.
    store_name: String,
    // Prefix that namespaces the bookkeeping tuples of this manager inside the store.
    model_prefix: String,
    // Registered migrations, keyed by the version of the model they install.
    migrations: HashMap<AuthorizationModelVersion, Migration<T>>,
}
70
/// A versioned authorization model together with the optional hooks that run
/// before and after the model is written to OpenFGA.
#[derive(Clone)]
struct Migration<T> {
    model: VersionedAuthorizationModel,
    // Hook invoked before the model is written, if any.
    pre_migration_fn: Option<BoxedMigrationFn<T>>,
    // Hook invoked after the model is written, if any.
    post_migration_fn: Option<BoxedMigrationFn<T>>,
}
77
/// A pinned, heap-allocated, `Send` future yielding `T` that may borrow data for `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// Plain function pointer that matches the migration hook signature.
///
/// Handy as the type annotation for an absent hook, e.g. `None::<MigrationFn<_>>`.
pub type MigrationFn<T> =
    fn(OpenFgaServiceClient<T>) -> BoxFuture<'static, std::result::Result<(), StdError>>;

/// Unsized, type-erased migration hook: takes an [`OpenFgaServiceClient`] and returns
/// a boxed future resolving to `Ok(())` or an error.
type DynMigrationFn<T> =
    dyn Fn(OpenFgaServiceClient<T>) -> BoxFuture<'static, std::result::Result<(), StdError>>;

/// Reference-counted, type-erased migration hook as stored inside a `Migration`.
type BoxedMigrationFn<T> = Arc<DynMigrationFn<T>>;
91
92// Function to box the async functions that take an i32 parameter
93fn box_migration_fn<T, F, Fut>(f: F) -> BoxedMigrationFn<T>
94where
95    F: Fn(OpenFgaServiceClient<T>) -> Fut + Send + 'static,
96    Fut: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
97{
98    Arc::new(move |v| Box::pin(f(v)))
99}
100
101// type StaticMigrationFn =
102//     fn(
103//         OpenFgaServiceClient<tonic::transport::Channel>,
104//     ) -> Pin<Box<dyn Future<Output = std::result::Result<(), StdError>> + Send>>;
105
106impl<T> TupleModelManager<T>
107where
108    T: tonic::client::GrpcService<tonic::body::BoxBody>,
109    T: Clone,
110    T::Error: Into<StdError>,
111    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
112    <T::ResponseBody as Body>::Error: Into<StdError> + Send,
113{
114    const AUTH_MODEL_ID_TYPE: &'static str = "auth_model_id";
115    const MODEL_VERSION_EXISTS_RELATION: &'static str = "exists";
116    const MODEL_VERSION_TYPE: &'static str = "model_version";
117    const MODEL_VERSION_OPENFGA_ID_RELATION: &'static str = "openfga_id";
118
119    /// Create a new `TupleModelManager` with the given client and model name.
120    /// The model prefix must not change after the first model has been added or
121    /// the model manager will not be able to find the model again.
122    /// Use different model prefixes if models for different purposes are stored in the
123    /// same OpenFGA store.
124    pub fn new(client: OpenFgaServiceClient<T>, store_name: String, model_prefix: String) -> Self {
125        TupleModelManager {
126            client,
127            model_prefix,
128            store_name,
129            migrations: HashMap::new(),
130        }
131    }
132
133    /// Add a new model to the manager.
134    /// If a model with the same version has already been added, the new model will replace the old one.
135    ///
136    /// Ensure that migration functions are written in an idempotent way.
137    /// If a migration fails, it might be retried.
138    #[must_use]
139    pub fn add_model<FutPre, FutPost>(
140        mut self,
141        model: AuthorizationModel,
142        version: AuthorizationModelVersion,
143        pre_migration_fn: Option<impl Fn(OpenFgaServiceClient<T>) -> FutPre + Send + 'static>,
144        post_migration_fn: Option<impl Fn(OpenFgaServiceClient<T>) -> FutPost + Send + 'static>,
145    ) -> Self
146    where
147        FutPre: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
148        FutPost: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
149    {
150        let migration = Migration {
151            model: VersionedAuthorizationModel::new(model, version),
152            pre_migration_fn: pre_migration_fn.map(box_migration_fn),
153            post_migration_fn: post_migration_fn.map(box_migration_fn),
154        };
155        self.migrations.insert(migration.model.version(), migration);
156        self
157    }
158
159    /// Run migrations.
160    ///
161    /// This will:
162    /// 1. Get all existing models in the OpenFGA store.
163    /// 2. Determine which migrations need to be performed: All migrations with a version higher than the highest existing model.
164    /// 3. In order of the version of the model, perform the migrations:
165    ///    1. Run the pre-migration hook if it exists.
166    ///    2. Write the model to OpenFGA.
167    ///    3. Run the post-migration hook if it exists.
168    /// 4. Mark the model as applied in OpenFGA.
169    ///
170    /// # Errors
171    /// * If OpenFGA cannot be reached or a request fails.
172    /// * If any of the migration hooks fail.
173    pub async fn migrate(&mut self) -> Result<()> {
174        let span = tracing::span!(
175            tracing::Level::INFO,
176            "Running OpenFGA Migrations",
177            store_name = self.store_name,
178            model_prefix = self.model_prefix
179        );
180        let _enter = span.enter();
181
182        if self.migrations.is_empty() {
183            tracing::info!("No Migrations have been added. Nothing to do.");
184            return Ok(());
185        }
186
187        let store = self.client.get_or_create_store(&self.store_name).await?;
188        let existing_models = self.get_existing_versions().await?;
189        let max_existing_model = existing_models.iter().max();
190
191        if let Some(max_existing_model) = max_existing_model {
192            tracing::info!(
193                "Currently the highest existing Model Version is: {}",
194                max_existing_model
195            );
196        } else {
197            tracing::info!("No model found in OpenFGA store");
198        }
199
200        let ordered_migrations = self.migrations_to_perform(max_existing_model.copied());
201
202        let mut client = self.client.clone();
203        for migration in ordered_migrations {
204            tracing::info!("Migrating to model version: {}", migration.model.version());
205
206            // Pre-hook
207            if let Some(pre_migration_fn) = migration.pre_migration_fn.as_ref() {
208                pre_migration_fn(client.clone()).await.map_err(|e| {
209                    tracing::error!("Error in OpenFGA pre-migration hook: {:?}", e);
210                    Error::MigrationHookFailed {
211                        version: migration.model.version().to_string(),
212                        error: Arc::new(e),
213                    }
214                })?;
215            }
216
217            // Write Model
218            let request = migration
219                .model
220                .model()
221                .clone()
222                .into_write_request(store.id.clone());
223            let written_model = client
224                .write_authorization_model(request)
225                .await
226                .map_err(|e| {
227                    tracing::error!("Error writing model: {:?}", e);
228                    Error::RequestFailed(e)
229                })?;
230            tracing::debug!("Model written: {:?}", written_model);
231
232            // Post-hook
233            if let Some(post_migration_fn) = migration.post_migration_fn.as_ref() {
234                post_migration_fn(client.clone()).await.map_err(|e| {
235                    tracing::error!("Error in OpenFGA post-migration hook: {:?}", e);
236                    Error::MigrationHookFailed {
237                        version: migration.model.version().to_string(),
238                        error: Arc::new(e),
239                    }
240                })?;
241            }
242
243            // Mark as applied
244            Self::mark_as_applied(
245                &mut client,
246                &self.model_prefix,
247                &store,
248                migration.model.version(),
249                written_model.into_inner(),
250            )
251            .await?;
252        }
253
254        Ok(())
255    }
256
257    /// Get the OpenFGA Authorization model ID for the specified model version.
258    /// Ensure that migrations have been run before calling this method.
259    ///
260    /// # Errors
261    /// * If the store with the given name does not exist.
262    /// * If a call to OpenFGA fails.
263    pub async fn get_authorization_model_id(
264        &mut self,
265        version: AuthorizationModelVersion,
266    ) -> Result<Option<String>> {
267        let store = self
268            .client
269            .get_store_by_name(&self.store_name)
270            .await?
271            .ok_or_else(|| {
272                tracing::error!("Store with name {} not found", self.store_name);
273                Error::StoreNotFound(self.store_name.clone())
274            })?;
275
276        let applied_models = self
277            .client
278            .read_all_pages(
279                &store.id,
280                ReadRequestTupleKey {
281                    user: String::new(),
282                    relation: Self::MODEL_VERSION_OPENFGA_ID_RELATION.to_string(),
283                    object: Self::format_model_version_key(&self.model_prefix, version),
284                },
285                ConsistencyPreference::HigherConsistency,
286                DEFAULT_PAGE_SIZE,
287                MAX_PAGES,
288            )
289            .await?;
290
291        let applied_models = applied_models
292            .into_iter()
293            .filter_map(|t| t.key)
294            .filter_map(|t| {
295                t.user
296                    .strip_prefix(&format!("{}:", Self::AUTH_MODEL_ID_TYPE))
297                    .map(ToString::to_string)
298            })
299            .collect::<Vec<_>>();
300
301        if applied_models.len() > 1 {
302            tracing::error!(
303                "Multiple authorization models with model prefix {} for version {} found.",
304                self.model_prefix,
305                version
306            );
307            return Err(Error::AmbiguousModelVersion {
308                model_prefix: self.model_prefix.clone(),
309                version: version.to_string(),
310            });
311        }
312
313        let model_id = applied_models.into_iter().next().map(|openfga_id| {
314            tracing::info!(
315                "Authorization model for version {version} found in OpenFGA store {}. Model ID: {openfga_id}",
316                self.store_name,
317            );
318            openfga_id
319        });
320
321        Ok(model_id)
322    }
323
324    /// Mark a model version as applied in OpenFGA
325    async fn mark_as_applied(
326        client: &mut OpenFgaServiceClient<T>,
327        model_prefix: &str,
328        store: &Store,
329        version: AuthorizationModelVersion,
330        write_response: WriteAuthorizationModelResponse,
331    ) -> Result<()> {
332        let authorization_model_id = write_response.authorization_model_id;
333        let object = Self::format_model_version_key(model_prefix, version);
334
335        let write_request = WriteRequest {
336            store_id: store.id.clone(),
337            writes: Some(WriteRequestWrites {
338                tuple_keys: vec![
339                    TupleKey {
340                        user: format!("{}:{authorization_model_id}", Self::AUTH_MODEL_ID_TYPE),
341                        relation: Self::MODEL_VERSION_OPENFGA_ID_RELATION.to_string(),
342                        object: object.clone(),
343                        condition: None,
344                    },
345                    TupleKey {
346                        user: format!("{}:*", Self::AUTH_MODEL_ID_TYPE),
347                        relation: Self::MODEL_VERSION_EXISTS_RELATION.to_string(),
348                        object,
349                        condition: None,
350                    },
351                ],
352            }),
353            deletes: None,
354            authorization_model_id: authorization_model_id.to_string(),
355        };
356        client.write(write_request.clone()).await.map_err(|e| {
357            tracing::error!("Error marking model as applied: {:?}", e);
358            Error::RequestFailed(e)
359        })?;
360        Ok(())
361    }
362
363    /// Get all migrations that have been added to the manager
364    /// as a `Vec` sorted by the version of the model.
365    fn ordered_migrations(&self) -> Vec<&Migration<T>> {
366        let mut migrations = self.migrations.values().collect::<Vec<_>>();
367        migrations.sort_unstable_by_key(|m| m.model.version());
368        migrations
369    }
370
371    /// Get all migrations that need to be performed, given the maximum existing model version.
372    fn migrations_to_perform(
373        &self,
374        max_existing_model: Option<AuthorizationModelVersion>,
375    ) -> Vec<&Migration<T>> {
376        let ordered_migrations = self.ordered_migrations();
377        let migrations_to_perform = ordered_migrations
378            .into_iter()
379            .filter(|m| {
380                max_existing_model.map_or(true, |max_existing| m.model.version() > max_existing)
381            })
382            .collect::<Vec<_>>();
383
384        tracing::info!(
385            "{} migrations needed in OpenFGA store {} for model-prefix {}",
386            migrations_to_perform.len(),
387            self.store_name,
388            self.model_prefix
389        );
390        migrations_to_perform
391    }
392
393    /// Get versions of all existing models in OpenFGA.
394    /// Returns an empty vector if the store does not exist.
395    ///
396    /// # Errors
397    /// * If the call to determine existing stores fails.
398    /// * If a tuple read call fails.
399    pub async fn get_existing_versions(&mut self) -> Result<Vec<AuthorizationModelVersion>> {
400        let Some(store) = self.client.get_store_by_name(&self.store_name).await? else {
401            return Ok(vec![]);
402        };
403
404        let tuples = self
405            .client
406            .read_all_pages(
407                &store.id,
408                ReadRequestTupleKey {
409                    user: format!("{}:*", Self::AUTH_MODEL_ID_TYPE).to_string(),
410                    relation: Self::MODEL_VERSION_EXISTS_RELATION.to_string(),
411                    object: format!("{}:", Self::MODEL_VERSION_TYPE).to_string(),
412                },
413                crate::client::ConsistencyPreference::HigherConsistency,
414                DEFAULT_PAGE_SIZE,
415                MAX_PAGES,
416            )
417            .await?;
418        let existing_models = Self::parse_existing_models(tuples, &self.model_prefix);
419        Ok(existing_models.into_iter().collect())
420    }
421
422    fn parse_existing_models(
423        exist_tuples: Vec<Tuple>,
424        model_prefix: &str,
425    ) -> HashSet<AuthorizationModelVersion> {
426        exist_tuples
427            .into_iter()
428            .filter_map(|t| t.key)
429            .filter_map(|t| Self::parse_model_version_from_key(&t.object, model_prefix))
430            .collect()
431    }
432
433    fn parse_model_version_from_key(
434        model: &str,
435        model_prefix: &str,
436    ) -> Option<AuthorizationModelVersion> {
437        model
438            // Ignore models with wrong prefix
439            .strip_prefix(&format!("{}:", Self::MODEL_VERSION_TYPE))
440            .and_then(|model| {
441                model
442                    .strip_prefix(&format!("{model_prefix}-"))
443                    .and_then(|version| AuthorizationModelVersion::from_str(version).ok())
444            })
445    }
446
447    fn format_model_version_key(model_prefix: &str, version: AuthorizationModelVersion) -> String {
448        format!("{}:{}-{}", Self::MODEL_VERSION_TYPE, model_prefix, version)
449    }
450}
451
452impl VersionedAuthorizationModel {
453    pub(crate) fn new(model: AuthorizationModel, version: AuthorizationModelVersion) -> Self {
454        VersionedAuthorizationModel { model, version }
455    }
456
457    pub(crate) fn version(&self) -> AuthorizationModelVersion {
458        self.version
459    }
460
461    pub(crate) fn model(&self) -> &AuthorizationModel {
462        &self.model
463    }
464}
465
466impl AuthorizationModelVersion {
467    #[must_use]
468    pub fn new(major: u32, minor: u32) -> Self {
469        AuthorizationModelVersion { major, minor }
470    }
471
472    #[must_use]
473    pub fn major(&self) -> u32 {
474        self.major
475    }
476
477    #[must_use]
478    pub fn minor(&self) -> u32 {
479        self.minor
480    }
481}
482
483// Sort by major version first, then by subversion.
484impl PartialOrd for AuthorizationModelVersion {
485    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
486        Some(self.cmp(other))
487    }
488}
489
490impl Ord for AuthorizationModelVersion {
491    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
492        (self.major, self.minor).cmp(&(other.major, other.minor))
493    }
494}
495
496impl std::fmt::Display for AuthorizationModelVersion {
497    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
498        write!(f, "{}.{}", self.major, self.minor)
499    }
500}
501
502impl FromStr for AuthorizationModelVersion {
503    type Err = Error;
504
505    fn from_str(s: &str) -> Result<Self> {
506        let parts = s.split('.').collect::<Vec<_>>();
507        if parts.len() != 2 {
508            return Err(Error::InvalidModelVersion(s.to_string()));
509        }
510
511        let major = parts[0]
512            .parse()
513            .map_err(|_| Error::InvalidModelVersion(s.to_string()))?;
514        let minor = parts[1]
515            .parse()
516            .map_err(|_| Error::InvalidModelVersion(s.to_string()))?;
517
518        Ok(AuthorizationModelVersion::new(major, minor))
519    }
520}
521
522impl<T> std::fmt::Debug for Migration<T> {
523    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
524        f.debug_struct("Migration")
525            .field("model", &self.model)
526            .field("pre_migration_fn", &"...")
527            .field("post_migration_fn", &"...")
528            .finish()
529    }
530}
531
#[cfg(test)]
mod test {
    use std::sync::Mutex;

    use needs_env_var::needs_env_var;
    use pretty_assertions::assert_eq;

    use super::*;

    // Concrete manager type used to call associated functions that need no client.
    type ChannelTupleManager = TupleModelManager<tonic::transport::Channel>;

    /// Versions compare by major component first, then by minor component.
    #[test]
    fn test_ordering() {
        let versioned_1_0 = AuthorizationModelVersion::new(1, 0);
        let versioned_1_1 = AuthorizationModelVersion::new(1, 1);
        let versioned_2_0 = AuthorizationModelVersion::new(2, 0);
        let versioned_2_1 = AuthorizationModelVersion::new(2, 1);
        let versioned_2_2 = AuthorizationModelVersion::new(2, 2);

        assert!(versioned_1_0 < versioned_1_1);
        assert!(versioned_1_1 < versioned_2_0);
        assert!(versioned_2_0 < versioned_2_1);
        assert!(versioned_2_1 < versioned_2_2);
    }

    /// `Display` and `FromStr` round-trip through the `major.minor` form.
    #[test]
    fn test_auth_model_version_str() {
        let version = AuthorizationModelVersion::new(1, 0);
        assert_eq!(version.to_string(), "1.0");
        assert_eq!("1.0".parse::<AuthorizationModelVersion>().unwrap(), version);

        let version = AuthorizationModelVersion::new(10, 2);
        assert_eq!(version.to_string(), "10.2");
        assert_eq!(
            "10.2".parse::<AuthorizationModelVersion>().unwrap(),
            version
        );
    }

    /// Keys are only parsed when both the type and the model prefix match.
    #[test]
    fn test_parse_model_version_from_key() {
        let model_prefix = "test";
        let model_version = AuthorizationModelVersion::new(1, 0);
        let key = format!("model_version:{model_prefix}-{model_version}");
        assert_eq!(
            ChannelTupleManager::parse_model_version_from_key(&key, model_prefix),
            Some(model_version)
        );

        // Prefix missing
        assert!(ChannelTupleManager::parse_model_version_from_key(
            "model_version:1.0",
            model_prefix
        )
        .is_none());

        // Wrong prefix
        assert!(ChannelTupleManager::parse_model_version_from_key(
            "model_version:foo-1.0",
            model_prefix
        )
        .is_none());

        // Higher version, with a model prefix that itself contains a dash
        assert_eq!(
            ChannelTupleManager::parse_model_version_from_key(
                "model_version:other-model-10.200",
                "other-model"
            ),
            Some(AuthorizationModelVersion::new(10, 200))
        );
    }

    /// A formatted key has the expected shape and parses back to the same version.
    #[test]
    fn test_format_model_version_key() {
        let model_prefix = "test";
        let model_version = AuthorizationModelVersion::new(1, 0);
        let key = ChannelTupleManager::format_model_version_key(model_prefix, model_version);
        assert_eq!(key, "model_version:test-1.0");
        let parsed = ChannelTupleManager::parse_model_version_from_key(&key, model_prefix).unwrap();
        assert_eq!(parsed, model_version);
    }

    // Integration tests that talk to the OpenFGA instance named by the
    // `TEST_OPENFGA_CLIENT_GRPC_URL` environment variable.
    #[needs_env_var(TEST_OPENFGA_CLIENT_GRPC_URL)]
    mod openfga {
        use std::{str::FromStr, sync::Mutex};

        use pretty_assertions::assert_eq;

        use super::*;
        use crate::client::{OpenFgaServiceClient, ReadAuthorizationModelRequest};

        /// Connect to the endpoint given in `TEST_OPENFGA_CLIENT_GRPC_URL`.
        async fn get_client() -> OpenFgaServiceClient<tonic::transport::Channel> {
            let endpoint = std::env::var("TEST_OPENFGA_CLIENT_GRPC_URL").unwrap();
            let endpoint = tonic::transport::Endpoint::from_str(&endpoint).unwrap();
            OpenFgaServiceClient::connect(endpoint)
                .await
                .expect("Client can be created")
        }

        /// Connect and create a store with a unique name of the form `test-<uuid>`.
        async fn client_with_store() -> (OpenFgaServiceClient<tonic::transport::Channel>, Store) {
            let mut client = get_client().await;
            let store_name = format!("test-{}", uuid::Uuid::now_v7());
            let store = client.get_or_create_store(&store_name).await.unwrap();
            (client, store)
        }

        /// A store that does not exist yields no versions instead of an error.
        #[tokio::test]
        async fn test_get_existing_versions_nonexistent_store() {
            let client = get_client().await;
            let mut manager =
                TupleModelManager::new(client, "nonexistent".to_string(), "test".to_string());

            let versions = manager.get_existing_versions().await.unwrap();
            assert!(versions.is_empty());
        }

        /// An existing store without any written model yields no versions.
        #[tokio::test]
        async fn test_get_existing_versions_nonexistent_auth_model() {
            let mut client = get_client().await;
            let store_name = format!("test-{}", uuid::Uuid::now_v7());
            let _store = client.get_or_create_store(&store_name).await.unwrap();
            let mut manager = TupleModelManager::new(client, store_name, "test".to_string());
            let versions = manager.get_existing_versions().await.unwrap();
            assert!(versions.is_empty());
        }

        /// The model ID is `None` before the bookkeeping tuples exist and is read
        /// back from the `openfga_id` tuple afterwards.
        #[tokio::test]
        async fn test_get_authorization_model_id() {
            let (mut client, store) = client_with_store().await;
            let model_prefix = "test";
            let version = AuthorizationModelVersion::new(1, 0);

            let mut manager = TupleModelManager::new(
                client.clone(),
                store.name.clone(),
                model_prefix.to_string(),
            );

            // Non-existent model
            assert_eq!(
                manager.get_authorization_model_id(version).await.unwrap(),
                None
            );

            // Apply auth model
            let model: AuthorizationModel =
                serde_json::from_str(include_str!("../tests/model-manager/v1.0/schema.json"))
                    .unwrap();
            client
                .write_authorization_model(model.into_write_request(store.id.clone()))
                .await
                .unwrap();

            // Write model tuples
            client
                .write(WriteRequest {
                    store_id: store.id.clone(),
                    writes: Some(WriteRequestWrites {
                        tuple_keys: vec![
                            TupleKey {
                                user: "auth_model_id:111111".to_string(),
                                relation: "openfga_id".to_string(),
                                object: "model_version:test-1.0".to_string(),
                                condition: None,
                            },
                            TupleKey {
                                user: "auth_model_id:*".to_string(),
                                relation: "exists".to_string(),
                                object: "model_version:test-1.0".to_string(),
                                condition: None,
                            },
                            // Tuple with different model prefix should be ignored
                            TupleKey {
                                user: "auth_model_id:*".to_string(),
                                relation: "exists".to_string(),
                                object: "model_version:test2-1.0".to_string(),
                                condition: None,
                            },
                        ],
                    }),
                    deletes: None,
                    authorization_model_id: String::new(),
                })
                .await
                .unwrap();

            assert_eq!(
                manager.get_authorization_model_id(version).await.unwrap(),
                Some("111111".to_string())
            );
        }

        /// End-to-end migration run: each hook runs exactly once even when `migrate`
        /// is called repeatedly, and the written models match the schema files.
        #[tokio::test]
        async fn test_model_manager() {
            let store_name = format!("test-{}", uuid::Uuid::now_v7());
            let mut client = get_client().await;

            let model_1_0: AuthorizationModel =
                serde_json::from_str(include_str!("../tests/model-manager/v1.0/schema.json"))
                    .unwrap();

            let version_1_0 = AuthorizationModelVersion::new(1, 0);
            let execution_counter_1 = Arc::new(Mutex::new(0));

            let execution_counter_clone = execution_counter_1.clone();
            let mut manager = TupleModelManager::new(
                client.clone(),
                store_name.clone(),
                "test-model".to_string(),
            )
            .add_model(
                model_1_0.clone(),
                version_1_0,
                Some(move |client| {
                    let counter = execution_counter_clone.clone();
                    async move { v1_pre_migration_fn(client, counter).await }
                }),
                None::<MigrationFn<_>>,
            );
            manager.migrate().await.unwrap();
            // Check hook was called once
            assert_eq!(*execution_counter_1.lock().unwrap(), 1);
            // The hook fails on a second invocation, so this `unwrap` also proves
            // that an already applied migration is skipped.
            manager.migrate().await.unwrap();
            // Check hook was not called again
            assert_eq!(*execution_counter_1.lock().unwrap(), 1);

            // Check written model
            let auth_model_id = manager
                .get_authorization_model_id(version_1_0)
                .await
                .unwrap()
                .unwrap();
            let mut auth_model =
                get_auth_model_by_id(&mut client, &store_name, &auth_model_id).await;
            // Overwrite the ID before comparing: it is the only field not expected to match.
            auth_model.id = model_1_0.id.clone();
            assert_eq!(
                serde_json::to_value(&model_1_0).unwrap(),
                serde_json::to_value(auth_model).unwrap()
            );

            // Add a second model
            let model_1_1: AuthorizationModel =
                serde_json::from_str(include_str!("../tests/model-manager/v1.1/schema.json"))
                    .unwrap();
            let version_1_1 = AuthorizationModelVersion::new(1, 1);
            let execution_counter_2 = Arc::new(Mutex::new(0));
            let execution_counter_clone = execution_counter_2.clone();
            let mut manager = manager.add_model(
                model_1_1.clone(),
                version_1_1,
                None::<MigrationFn<_>>,
                Some(move |client| {
                    let counter = execution_counter_clone.clone();
                    async move { v2_post_migration_fn(client, counter).await }
                }),
            );
            manager.migrate().await.unwrap();
            manager.migrate().await.unwrap();
            manager.migrate().await.unwrap();

            // First migration still only called once
            assert_eq!(*execution_counter_1.lock().unwrap(), 1);
            // Second migration called once
            assert_eq!(*execution_counter_2.lock().unwrap(), 1);

            // Check written model
            let auth_model_id = manager
                .get_authorization_model_id(version_1_1)
                .await
                .unwrap()
                .unwrap();
            let mut auth_model =
                get_auth_model_by_id(&mut client, &store_name, &auth_model_id).await;
            auth_model.id = model_1_1.id.clone();
            assert_eq!(
                serde_json::to_value(&model_1_1).unwrap(),
                serde_json::to_value(auth_model).unwrap()
            );
        }

        /// Read the authorization model `auth_model_id` from the store named `store_name`.
        /// Panics if the store or the model cannot be found.
        async fn get_auth_model_by_id(
            client: &mut OpenFgaServiceClient<tonic::transport::Channel>,
            store_name: &str,
            auth_model_id: &str,
        ) -> AuthorizationModel {
            client
                .read_authorization_model(ReadAuthorizationModelRequest {
                    store_id: client
                        .clone()
                        .get_store_by_name(store_name)
                        .await
                        .unwrap()
                        .unwrap()
                        .id,
                    id: auth_model_id.to_string(),
                })
                .await
                .unwrap()
                .into_inner()
                .authorization_model
                .unwrap()
        }
    }

    /// Pre-migration hook for model 1.0: counts its invocations in `counter_mutex`
    /// and fails on the second one.
    async fn v1_pre_migration_fn(
        client: OpenFgaServiceClient<tonic::transport::Channel>,
        counter_mutex: Arc<Mutex<i32>>,
    ) -> std::result::Result<(), StdError> {
        let _ = client;
        // Throw an error for the second call
        let mut counter = counter_mutex.lock().unwrap();
        *counter += 1;
        if *counter == 2 {
            return Err(Box::new(Error::RequestFailed(tonic::Status::new(
                tonic::Code::Internal,
                "Test",
            ))));
        }
        Ok(())
    }

    /// Post-migration hook for model 1.1: counts its invocations in `counter_mutex`
    /// and fails on the second one.
    async fn v2_post_migration_fn(
        client: OpenFgaServiceClient<tonic::transport::Channel>,
        counter_mutex: Arc<Mutex<i32>>,
    ) -> std::result::Result<(), StdError> {
        let _ = client;
        // Throw an error for the second call
        let mut counter = counter_mutex.lock().unwrap();
        *counter += 1;
        if *counter == 2 {
            return Err(Box::new(Error::RequestFailed(tonic::Status::new(
                tonic::Code::Internal,
                "Test",
            ))));
        }
        Ok(())
    }
}