openfga_client/
migration.rs

1use std::{
2    collections::{HashMap, HashSet},
3    future::Future,
4    hash::Hash,
5    pin::Pin,
6    str::FromStr,
7    sync::Arc,
8};
9
10use tonic::codegen::{Body, Bytes, StdError};
11
12use crate::{
13    client::{
14        AuthorizationModel, ConsistencyPreference, OpenFgaServiceClient, ReadRequestTupleKey,
15        Store, Tuple, TupleKey, WriteAuthorizationModelResponse, WriteRequest, WriteRequestWrites,
16    },
17    error::{Error, Result},
18};
19
/// Page size passed to `read_all_pages` when reading tuples from OpenFGA.
const DEFAULT_PAGE_SIZE: i32 = 100;
/// Maximum number of pages passed to `read_all_pages` when reading tuples from OpenFGA.
const MAX_PAGES: u32 = 1000;
22
/// Version of an authorization model, made up of a major and a minor component.
///
/// Rendered and parsed as `<major>.<minor>`, e.g. `1.0`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct AuthorizationModelVersion {
    /// Major component of the version.
    major: u32,
    /// Minor component of the version.
    minor: u32,
}
28
29#[derive(Debug, Clone, PartialEq)]
30struct VersionedAuthorizationModel {
31    model: AuthorizationModel,
32    version: AuthorizationModelVersion,
33}
34
35/// Manages [`AuthorizationModel`]s in OpenFGA.
36///
37/// Authorization models in OpenFGA don't receive a unique name. Instead,
38/// they receive a random id on creation. If we don't store this ID, we can't
39/// find the model again and use its ID.
40///
41/// This `ModelManager` stores the mapping of [`AuthorizationModelVersion`]
42/// to the ID of the model in OpenFGA directly inside OpenFGA.
43/// This way can query OpenFGA to determine if a model with a certain version
44/// has already exists.
45///
46/// When running [`TupleModelManager::migrate()`], the manager only applies models and their migrations
47/// if they don't already exist in OpenFGA.
48///
49/// To store the mapping of model versions to OpenFGA IDs, the following needs to part of your Authorization Model:
50/// ```text
51/// type auth_model_id
52/// type model_version
53///   relations
54///     define openfga_id: [auth_model_id]
55///     define exists: [auth_model_id:*]
56/// ```
57///
58#[derive(Debug, Clone)]
59pub struct TupleModelManager<T, S>
60where
61    T: tonic::client::GrpcService<tonic::body::BoxBody>,
62    T::Error: Into<StdError>,
63    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
64    <T::ResponseBody as Body>::Error: Into<StdError> + Send,
65{
66    client: OpenFgaServiceClient<T>,
67    store_name: String,
68    model_prefix: String,
69    migrations: HashMap<AuthorizationModelVersion, Migration<T, S>>,
70}
71
72#[derive(Clone)]
73struct Migration<T, S> {
74    /// The model being migrated to.
75    ///
76    /// If this has value `vX`, the active model *after* the migration will be `vX`.
77    model: VersionedAuthorizationModel,
78    pre_migration_fn: Option<BoxedMigrationFn<T, S>>,
79    post_migration_fn: Option<BoxedMigrationFn<T, S>>,
80}
81
/// A heap-allocated, pinned, `Send` future that yields `T` and may borrow for `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
84
85/// Possible function pointer that implements the migration function signature.
86///
87/// The arguments are:
88///
89/// 1) A client connected to endpoint that's being migrated
90/// 2) The id of the authorization model active prior to the invocation of the function
91/// 3) The id of the authorization model active when the function is invoked
92/// 4) The (user defined) state passed into the migration function
93///
94/// Authorization model ids may be undefined (`None`), e.g. for the pre hook of the first
95/// migration neither a previous nor a current model exist. For the first migration run by
96/// [`TupleModelManager::migrate`], the previous model id is always `None`.
97pub type MigrationFn<T, S> = fn(
98    OpenFgaServiceClient<T>,
99    Option<String>,
100    Option<String>,
101    S,
102) -> BoxFuture<'static, std::result::Result<(), StdError>>;
103
104/// Type alias for the migration function signature.
105type DynMigrationFn<T, S> = dyn Fn(
106    OpenFgaServiceClient<T>,
107    Option<String>,
108    Option<String>,
109    S,
110) -> BoxFuture<'static, std::result::Result<(), StdError>>;
111
112/// Boxed migration function
113type BoxedMigrationFn<T, S> = Arc<DynMigrationFn<T, S>>;
114
115// Function to box the async functions that take an i32 parameter
116fn box_migration_fn<T, S, F, Fut>(f: F) -> BoxedMigrationFn<T, S>
117where
118    F: Fn(OpenFgaServiceClient<T>, Option<String>, Option<String>, S) -> Fut + Send + 'static,
119    Fut: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
120{
121    Arc::new(
122        move |client, prev_auth_model_id, active_auth_model_id, state| {
123            Box::pin(f(client, prev_auth_model_id, active_auth_model_id, state))
124        },
125    )
126}
127
128impl<T, S> TupleModelManager<T, S>
129where
130    T: tonic::client::GrpcService<tonic::body::BoxBody>,
131    T: Clone,
132    T::Error: Into<StdError>,
133    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
134    <T::ResponseBody as Body>::Error: Into<StdError> + Send,
135    S: Clone,
136{
137    const AUTH_MODEL_ID_TYPE: &'static str = "auth_model_id";
138    const MODEL_VERSION_EXISTS_RELATION: &'static str = "exists";
139    const MODEL_VERSION_TYPE: &'static str = "model_version";
140    const MODEL_VERSION_OPENFGA_ID_RELATION: &'static str = "openfga_id";
141
142    /// Create a new `TupleModelManager` with the given client and model name.
143    /// The model prefix must not change after the first model has been added or
144    /// the model manager will not be able to find the model again.
145    /// Use different model prefixes if models for different purposes are stored in the
146    /// same OpenFGA store.
147    pub fn new(client: OpenFgaServiceClient<T>, store_name: &str, model_prefix: &str) -> Self {
148        TupleModelManager {
149            client,
150            model_prefix: model_prefix.to_string(),
151            store_name: store_name.to_string(),
152            migrations: HashMap::new(),
153        }
154    }
155
156    /// Add a new model to the manager.
157    /// If a model with the same version has already been added, the new model will replace the old one.
158    ///
159    /// Ensure that migration functions are written in an idempotent way.
160    /// If a migration fails, it might be retried.
161    #[must_use]
162    pub fn add_model<FutPre, FutPost>(
163        mut self,
164        model: AuthorizationModel,
165        version: AuthorizationModelVersion,
166        pre_migration_fn: Option<
167            impl Fn(OpenFgaServiceClient<T>, Option<String>, Option<String>, S) -> FutPre
168                + Send
169                + 'static,
170        >,
171        post_migration_fn: Option<
172            impl Fn(OpenFgaServiceClient<T>, Option<String>, Option<String>, S) -> FutPost
173                + Send
174                + 'static,
175        >,
176    ) -> Self
177    where
178        FutPre: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
179        FutPost: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
180    {
181        let migration = Migration {
182            model: VersionedAuthorizationModel::new(model, version),
183            pre_migration_fn: pre_migration_fn.map(box_migration_fn),
184            post_migration_fn: post_migration_fn.map(box_migration_fn),
185        };
186        self.migrations.insert(migration.model.version(), migration);
187        self
188    }
189
190    /// Run migrations.
191    ///
192    /// This will:
193    /// 1. Get all existing models in the OpenFGA store.
194    /// 2. Determine which migrations need to be performed: All migrations with a version higher than the highest existing model.
195    /// 3. In order of the version of the model, perform the migrations:
196    ///    1. Run the pre-migration hook if it exists.
197    ///    2. Write the model to OpenFGA.
198    ///    3. Run the post-migration hook if it exists.
199    /// 4. Mark the model as applied in OpenFGA.
200    ///
201    /// # Errors
202    /// * If OpenFGA cannot be reached or a request fails.
203    /// * If any of the migration hooks fail.
204    #[allow(clippy::too_many_lines)]
205    pub async fn migrate(&mut self, state: S) -> Result<()> {
206        let span = tracing::span!(
207            tracing::Level::INFO,
208            "Running OpenFGA Migrations",
209            store_name = self.store_name,
210            model_prefix = self.model_prefix
211        );
212        let _enter = span.enter();
213
214        if self.migrations.is_empty() {
215            tracing::info!("No Migrations have been added. Nothing to do.");
216            return Ok(());
217        }
218
219        let store = self.client.get_or_create_store(&self.store_name).await?;
220        let existing_models = self.get_existing_versions().await?;
221        let max_existing_model = existing_models.iter().max().copied();
222        let mut curr_model_id = if let Some(version) = max_existing_model {
223            Some(self.require_authorization_model_id(version).await?)
224        } else {
225            None
226        };
227        // At this point the previous model id cannot be determined reliably, e.g. because
228        // migrations might have been run out of order.
229        let mut prev_model_id = None;
230
231        if let Some(max_existing_model) = max_existing_model {
232            tracing::info!(
233                "Currently the highest existing Model Version is: {}",
234                max_existing_model
235            );
236        } else {
237            tracing::info!("No model found in OpenFGA store");
238        }
239
240        let ordered_migrations = self.migrations_to_perform(max_existing_model);
241
242        let mut client = self.client.clone();
243        for migration in ordered_migrations {
244            tracing::info!("Migrating to model version: {}", migration.model.version());
245
246            // Pre-hook
247            if let Some(pre_migration_fn) = migration.pre_migration_fn.as_ref() {
248                pre_migration_fn(
249                    client.clone(),
250                    prev_model_id.clone(),
251                    curr_model_id.clone(),
252                    state.clone(),
253                )
254                .await
255                .map_err(|e| {
256                    tracing::error!("Error in OpenFGA pre-migration hook: {:?}", e);
257                    Error::MigrationHookFailed {
258                        version: migration.model.version().to_string(),
259                        error: Arc::new(e),
260                    }
261                })?;
262            }
263
264            // Write Model
265            let request = migration
266                .model
267                .model()
268                .clone()
269                .into_write_request(store.id.clone());
270            let written_model = client
271                .write_authorization_model(request)
272                .await
273                .map_err(|e| {
274                    tracing::error!("Error writing model: {:?}", e);
275                    Error::RequestFailed(Box::new(e))
276                })?;
277            tracing::debug!("Model written: {:?}", written_model);
278
279            // Update model versions passed to migration hooks.
280            prev_model_id.clone_from(&curr_model_id);
281            curr_model_id = Some(written_model.get_ref().authorization_model_id.to_string());
282
283            // Post-hook
284            if let Some(post_migration_fn) = migration.post_migration_fn.as_ref() {
285                post_migration_fn(
286                    client.clone(),
287                    prev_model_id.clone(),
288                    curr_model_id.clone(),
289                    state.clone(),
290                )
291                .await
292                .map_err(|e| {
293                    tracing::error!("Error in OpenFGA post-migration hook: {:?}", e);
294                    Error::MigrationHookFailed {
295                        version: migration.model.version().to_string(),
296                        error: Arc::new(e),
297                    }
298                })?;
299            }
300
301            // Mark as applied
302            Self::mark_as_applied(
303                &mut client,
304                &self.model_prefix,
305                &store,
306                migration.model.version(),
307                written_model.into_inner(),
308            )
309            .await?;
310        }
311
312        Ok(())
313    }
314
315    /// Get the OpenFGA Authorization model ID for the specified model version.
316    /// Ensure that migrations have been run before calling this method.
317    ///
318    /// # Errors
319    /// * If the store with the given name does not exist.
320    /// * If a call to OpenFGA fails.
321    pub async fn get_authorization_model_id(
322        &mut self,
323        version: AuthorizationModelVersion,
324    ) -> Result<Option<String>> {
325        let store = self
326            .client
327            .get_store_by_name(&self.store_name)
328            .await?
329            .ok_or_else(|| {
330                tracing::error!("Store with name {} not found", self.store_name);
331                Error::StoreNotFound(self.store_name.clone())
332            })?;
333
334        let applied_models = self
335            .client
336            .read_all_pages(
337                &store.id,
338                Some(ReadRequestTupleKey {
339                    user: String::new(),
340                    relation: Self::MODEL_VERSION_OPENFGA_ID_RELATION.to_string(),
341                    object: Self::format_model_version_key(&self.model_prefix, version),
342                }),
343                ConsistencyPreference::HigherConsistency,
344                DEFAULT_PAGE_SIZE,
345                MAX_PAGES,
346            )
347            .await?;
348
349        let applied_models = applied_models
350            .into_iter()
351            .filter_map(|t| t.key)
352            .filter_map(|t| {
353                t.user
354                    .strip_prefix(&format!("{}:", Self::AUTH_MODEL_ID_TYPE))
355                    .map(ToString::to_string)
356            })
357            .collect::<Vec<_>>();
358
359        if applied_models.len() > 1 {
360            tracing::error!(
361                "Multiple authorization models with model prefix {} for version {} found.",
362                self.model_prefix,
363                version
364            );
365            return Err(Error::AmbiguousModelVersion {
366                model_prefix: self.model_prefix.clone(),
367                version: version.to_string(),
368            });
369        }
370
371        let model_id = applied_models.into_iter().next().map(|openfga_id| {
372            tracing::info!(
373                "Authorization model for version {version} found in OpenFGA store {}. Model ID: {openfga_id}",
374                self.store_name,
375            );
376            openfga_id
377        });
378
379        Ok(model_id)
380    }
381
382    /// Helper method that tries to get the authorization model id for `version` and returns
383    /// an error if the id cannot be found.
384    async fn require_authorization_model_id(
385        &mut self,
386        version: AuthorizationModelVersion,
387    ) -> Result<String> {
388        self.get_authorization_model_id(version)
389            .await?
390            .ok_or_else(|| {
391                tracing::error!("Missing authorization model id for model version {version}");
392                Error::MissingAuthorizationModelId {
393                    model_prefix: self.model_prefix.clone(),
394                    version: version.to_string(),
395                }
396            })
397    }
398
399    /// Mark a model version as applied in OpenFGA
400    async fn mark_as_applied(
401        client: &mut OpenFgaServiceClient<T>,
402        model_prefix: &str,
403        store: &Store,
404        version: AuthorizationModelVersion,
405        write_response: WriteAuthorizationModelResponse,
406    ) -> Result<()> {
407        let authorization_model_id = write_response.authorization_model_id;
408        let object = Self::format_model_version_key(model_prefix, version);
409
410        let write_request = WriteRequest {
411            store_id: store.id.clone(),
412            writes: Some(WriteRequestWrites {
413                tuple_keys: vec![
414                    TupleKey {
415                        user: format!("{}:{authorization_model_id}", Self::AUTH_MODEL_ID_TYPE),
416                        relation: Self::MODEL_VERSION_OPENFGA_ID_RELATION.to_string(),
417                        object: object.clone(),
418                        condition: None,
419                    },
420                    TupleKey {
421                        user: format!("{}:*", Self::AUTH_MODEL_ID_TYPE),
422                        relation: Self::MODEL_VERSION_EXISTS_RELATION.to_string(),
423                        object,
424                        condition: None,
425                    },
426                ],
427            }),
428            deletes: None,
429            authorization_model_id: authorization_model_id.to_string(),
430        };
431        client.write(write_request.clone()).await.map_err(|e| {
432            tracing::error!("Error marking model as applied: {:?}", e);
433            Error::RequestFailed(Box::new(e))
434        })?;
435        Ok(())
436    }
437
438    /// Get all migrations that have been added to the manager
439    /// as a `Vec` sorted by the version of the model.
440    fn ordered_migrations(&self) -> Vec<&Migration<T, S>> {
441        let mut migrations = self.migrations.values().collect::<Vec<_>>();
442        migrations.sort_unstable_by_key(|m| m.model.version());
443        migrations
444    }
445
446    /// Get all migrations that need to be performed, given the maximum existing model version.
447    fn migrations_to_perform(
448        &self,
449        max_existing_model: Option<AuthorizationModelVersion>,
450    ) -> Vec<&Migration<T, S>> {
451        let ordered_migrations = self.ordered_migrations();
452        let migrations_to_perform = ordered_migrations
453            .into_iter()
454            .filter(|m| {
455                max_existing_model.is_none_or(|max_existing| m.model.version() > max_existing)
456            })
457            .collect::<Vec<_>>();
458
459        tracing::info!(
460            "{} migrations needed in OpenFGA store {} for model-prefix {}",
461            migrations_to_perform.len(),
462            self.store_name,
463            self.model_prefix
464        );
465        migrations_to_perform
466    }
467
468    /// Get versions of all existing models in OpenFGA.
469    /// Returns an empty vector if the store does not exist.
470    ///
471    /// # Errors
472    /// * If the call to determine existing stores fails.
473    /// * If a tuple read call fails.
474    pub async fn get_existing_versions(&mut self) -> Result<Vec<AuthorizationModelVersion>> {
475        let Some(store) = self.client.get_store_by_name(&self.store_name).await? else {
476            return Ok(vec![]);
477        };
478
479        let tuples = self
480            .client
481            .read_all_pages(
482                &store.id,
483                Some(ReadRequestTupleKey {
484                    user: format!("{}:*", Self::AUTH_MODEL_ID_TYPE).to_string(),
485                    relation: Self::MODEL_VERSION_EXISTS_RELATION.to_string(),
486                    object: format!("{}:", Self::MODEL_VERSION_TYPE).to_string(),
487                }),
488                crate::client::ConsistencyPreference::HigherConsistency,
489                DEFAULT_PAGE_SIZE,
490                MAX_PAGES,
491            )
492            .await?;
493        let existing_models = Self::parse_existing_models(tuples, &self.model_prefix);
494        Ok(existing_models.into_iter().collect())
495    }
496
497    fn parse_existing_models(
498        exist_tuples: Vec<Tuple>,
499        model_prefix: &str,
500    ) -> HashSet<AuthorizationModelVersion> {
501        exist_tuples
502            .into_iter()
503            .filter_map(|t| t.key)
504            .filter_map(|t| Self::parse_model_version_from_key(&t.object, model_prefix))
505            .collect()
506    }
507
508    fn parse_model_version_from_key(
509        model: &str,
510        model_prefix: &str,
511    ) -> Option<AuthorizationModelVersion> {
512        model
513            // Ignore models with wrong prefix
514            .strip_prefix(&format!("{}:", Self::MODEL_VERSION_TYPE))
515            .and_then(|model| {
516                model
517                    .strip_prefix(&format!("{model_prefix}-"))
518                    .and_then(|version| AuthorizationModelVersion::from_str(version).ok())
519            })
520    }
521
522    fn format_model_version_key(model_prefix: &str, version: AuthorizationModelVersion) -> String {
523        format!("{}:{}-{}", Self::MODEL_VERSION_TYPE, model_prefix, version)
524    }
525}
526
527impl VersionedAuthorizationModel {
528    pub(crate) fn new(model: AuthorizationModel, version: AuthorizationModelVersion) -> Self {
529        VersionedAuthorizationModel { model, version }
530    }
531
532    pub(crate) fn version(&self) -> AuthorizationModelVersion {
533        self.version
534    }
535
536    pub(crate) fn model(&self) -> &AuthorizationModel {
537        &self.model
538    }
539}
540
541impl AuthorizationModelVersion {
542    #[must_use]
543    pub fn new(major: u32, minor: u32) -> Self {
544        AuthorizationModelVersion { major, minor }
545    }
546
547    #[must_use]
548    pub fn major(&self) -> u32 {
549        self.major
550    }
551
552    #[must_use]
553    pub fn minor(&self) -> u32 {
554        self.minor
555    }
556}
557
558// Sort by major version first, then by subversion.
559impl PartialOrd for AuthorizationModelVersion {
560    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
561        Some(self.cmp(other))
562    }
563}
564
565impl Ord for AuthorizationModelVersion {
566    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
567        (self.major, self.minor).cmp(&(other.major, other.minor))
568    }
569}
570
571impl std::fmt::Display for AuthorizationModelVersion {
572    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
573        write!(f, "{}.{}", self.major, self.minor)
574    }
575}
576
577impl FromStr for AuthorizationModelVersion {
578    type Err = Error;
579
580    fn from_str(s: &str) -> Result<Self> {
581        let parts = s.split('.').collect::<Vec<_>>();
582        if parts.len() != 2 {
583            return Err(Error::InvalidModelVersion(s.to_string()));
584        }
585
586        let major = parts[0]
587            .parse()
588            .map_err(|_| Error::InvalidModelVersion(s.to_string()))?;
589        let minor = parts[1]
590            .parse()
591            .map_err(|_| Error::InvalidModelVersion(s.to_string()))?;
592
593        Ok(AuthorizationModelVersion::new(major, minor))
594    }
595}
596
597impl<T, S> std::fmt::Debug for Migration<T, S> {
598    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
599        f.debug_struct("Migration")
600            .field("model", &self.model)
601            .field("pre_migration_fn", &"...")
602            .field("post_migration_fn", &"...")
603            .finish()
604    }
605}
606
607#[cfg(test)]
608pub(crate) mod test {
609    use std::sync::Mutex;
610
611    use needs_env_var::needs_env_var;
612    use pretty_assertions::assert_eq;
613
614    use super::*;
615
616    type ChannelTupleManager = TupleModelManager<tonic::transport::Channel, ()>;
617
618    #[test]
619    fn test_ordering() {
620        let versioned_1_0 = AuthorizationModelVersion::new(1, 0);
621        let versioned_1_1 = AuthorizationModelVersion::new(1, 1);
622        let versioned_2_0 = AuthorizationModelVersion::new(2, 0);
623        let versioned_2_1 = AuthorizationModelVersion::new(2, 1);
624        let versioned_2_2 = AuthorizationModelVersion::new(2, 2);
625
626        assert!(versioned_1_0 < versioned_1_1);
627        assert!(versioned_1_1 < versioned_2_0);
628        assert!(versioned_2_0 < versioned_2_1);
629        assert!(versioned_2_1 < versioned_2_2);
630    }
631
632    #[test]
633    fn test_auth_model_version_str() {
634        let version = AuthorizationModelVersion::new(1, 0);
635        assert_eq!(version.to_string(), "1.0");
636        assert_eq!("1.0".parse::<AuthorizationModelVersion>().unwrap(), version);
637
638        let version = AuthorizationModelVersion::new(10, 2);
639        assert_eq!(version.to_string(), "10.2");
640        assert_eq!(
641            "10.2".parse::<AuthorizationModelVersion>().unwrap(),
642            version
643        );
644    }
645
646    #[test]
647    fn test_parse_model_version_from_key() {
648        let model_prefix = "test";
649        let model_version = AuthorizationModelVersion::new(1, 0);
650        let key = format!("model_version:{model_prefix}-{model_version}");
651        assert_eq!(
652            ChannelTupleManager::parse_model_version_from_key(&key, model_prefix),
653            Some(model_version)
654        );
655
656        // Prefix missing
657        assert!(ChannelTupleManager::parse_model_version_from_key(
658            "model_version:1.0",
659            model_prefix
660        )
661        .is_none());
662
663        // Wrong prefix
664        assert!(ChannelTupleManager::parse_model_version_from_key(
665            "model_version:foo-1.0",
666            model_prefix
667        )
668        .is_none());
669
670        // Higher version
671        assert_eq!(
672            ChannelTupleManager::parse_model_version_from_key(
673                "model_version:other-model-10.200",
674                "other-model"
675            ),
676            Some(AuthorizationModelVersion::new(10, 200))
677        );
678    }
679
680    #[test]
681    fn test_format_model_version_key() {
682        let model_prefix = "test";
683        let model_version = AuthorizationModelVersion::new(1, 0);
684        let key = ChannelTupleManager::format_model_version_key(model_prefix, model_version);
685        assert_eq!(key, "model_version:test-1.0");
686        let parsed = ChannelTupleManager::parse_model_version_from_key(&key, model_prefix).unwrap();
687        assert_eq!(parsed, model_version);
688    }
689
690    #[needs_env_var(TEST_OPENFGA_CLIENT_GRPC_URL)]
691    pub(crate) mod openfga {
692        use std::str::FromStr;
693
694        use pretty_assertions::assert_eq;
695
696        use super::*;
697        use crate::client::{OpenFgaServiceClient, ReadAuthorizationModelRequest};
698
699        pub(crate) async fn get_service_client() -> OpenFgaServiceClient<tonic::transport::Channel>
700        {
701            let endpoint = std::env::var("TEST_OPENFGA_CLIENT_GRPC_URL").unwrap();
702            let endpoint = tonic::transport::Endpoint::from_str(&endpoint).unwrap();
703            OpenFgaServiceClient::connect(endpoint)
704                .await
705                .expect("Client can be created")
706        }
707
708        pub(crate) async fn service_client_with_store(
709        ) -> (OpenFgaServiceClient<tonic::transport::Channel>, Store) {
710            let mut client = get_service_client().await;
711            let store_name = format!("test-{}", uuid::Uuid::now_v7());
712            let store = client.get_or_create_store(&store_name).await.unwrap();
713            (client, store)
714        }
715
716        #[tokio::test]
717        async fn test_get_existing_versions_nonexistent_store() {
718            let client = get_service_client().await;
719            let mut manager: TupleModelManager<_, ()> =
720                TupleModelManager::new(client, "nonexistent", "test");
721
722            let versions = manager.get_existing_versions().await.unwrap();
723            assert!(versions.is_empty());
724        }
725
726        #[tokio::test]
727        async fn test_get_existing_versions_nonexistent_auth_model() {
728            let mut client = get_service_client().await;
729            let store_name = format!("test-{}", uuid::Uuid::now_v7());
730            let _store = client.get_or_create_store(&store_name).await.unwrap();
731            let mut manager: TupleModelManager<_, ()> =
732                TupleModelManager::new(client, &store_name, "test");
733            let versions = manager.get_existing_versions().await.unwrap();
734            assert!(versions.is_empty());
735        }
736
737        #[tokio::test]
738        async fn test_get_authorization_model_id() {
739            let (mut client, store) = service_client_with_store().await;
740            let model_prefix = "test";
741            let version = AuthorizationModelVersion::new(1, 0);
742
743            let mut manager: TupleModelManager<_, ()> =
744                TupleModelManager::new(client.clone(), &store.name, model_prefix);
745
746            // Non-existent model
747            assert_eq!(
748                manager.get_authorization_model_id(version).await.unwrap(),
749                None
750            );
751
752            // Apply auth model
753            let model: AuthorizationModel =
754                serde_json::from_str(include_str!("../tests/model-manager/v1.0/schema.json"))
755                    .unwrap();
756            client
757                .write_authorization_model(model.into_write_request(store.id.clone()))
758                .await
759                .unwrap();
760
761            // Write model tuples
762            client
763                .write(WriteRequest {
764                    store_id: store.id.clone(),
765                    writes: Some(WriteRequestWrites {
766                        tuple_keys: vec![
767                            TupleKey {
768                                user: "auth_model_id:111111".to_string(),
769                                relation: "openfga_id".to_string(),
770                                object: "model_version:test-1.0".to_string(),
771                                condition: None,
772                            },
773                            TupleKey {
774                                user: "auth_model_id:*".to_string(),
775                                relation: "exists".to_string(),
776                                object: "model_version:test-1.0".to_string(),
777                                condition: None,
778                            },
779                            // Tuple with different model prefix should be ignored
780                            TupleKey {
781                                user: "auth_model_id:*".to_string(),
782                                relation: "exists".to_string(),
783                                object: "model_version:test2-1.0".to_string(),
784                                condition: None,
785                            },
786                        ],
787                    }),
788                    deletes: None,
789                    authorization_model_id: String::new(),
790                })
791                .await
792                .unwrap();
793
794            assert_eq!(
795                manager.get_authorization_model_id(version).await.unwrap(),
796                Some("111111".to_string())
797            );
798        }
799
800        #[tokio::test]
801        async fn test_model_manager() {
802            let store_name = format!("test-{}", uuid::Uuid::now_v7());
803            let mut client = get_service_client().await;
804
805            let model_1_0: AuthorizationModel =
806                serde_json::from_str(include_str!("../tests/model-manager/v1.0/schema.json"))
807                    .unwrap();
808
809            let version_1_0 = AuthorizationModelVersion::new(1, 0);
810
811            let migration_state = MigrationState::default();
812            let mut manager = TupleModelManager::new(client.clone(), &store_name, "test-model")
813                .add_model(
814                    model_1_0.clone(),
815                    version_1_0,
816                    Some(v1_pre_migration_fn),
817                    None::<MigrationFn<_, _>>,
818                );
819            manager.migrate(migration_state.clone()).await.unwrap();
820            // Check hook was called once
821            assert_eq!(*migration_state.counter_1.lock().unwrap(), 1);
822            manager.migrate(migration_state.clone()).await.unwrap();
823            // Check hook was not called again
824            assert_eq!(*migration_state.counter_1.lock().unwrap(), 1);
825
826            // Check written model
827            let auth_model_id = manager
828                .get_authorization_model_id(version_1_0)
829                .await
830                .unwrap()
831                .unwrap();
832            let mut auth_model =
833                get_auth_model_by_id(&mut client, &store_name, &auth_model_id).await;
834            auth_model.id = model_1_0.id.clone();
835            assert_eq!(
836                serde_json::to_value(&model_1_0).unwrap(),
837                serde_json::to_value(auth_model).unwrap()
838            );
839
840            // Add a second model
841            let model_1_1: AuthorizationModel =
842                serde_json::from_str(include_str!("../tests/model-manager/v1.1/schema.json"))
843                    .unwrap();
844            let version_1_1 = AuthorizationModelVersion::new(1, 1);
845            let mut manager = manager.add_model(
846                model_1_1.clone(),
847                version_1_1,
848                None::<MigrationFn<_, _>>,
849                Some(v2_post_migration_fn),
850            );
851            manager.migrate(migration_state.clone()).await.unwrap();
852            manager.migrate(migration_state.clone()).await.unwrap();
853            manager.migrate(migration_state.clone()).await.unwrap();
854
855            // First migration still only called once
856            assert_eq!(*migration_state.counter_1.lock().unwrap(), 1);
857            // Second migration called once
858            assert_eq!(*migration_state.counter_2.lock().unwrap(), 1);
859
860            // Check written model
861            let auth_model_id = manager
862                .get_authorization_model_id(version_1_1)
863                .await
864                .unwrap()
865                .unwrap();
866            let mut auth_model =
867                get_auth_model_by_id(&mut client, &store_name, &auth_model_id).await;
868            auth_model.id = model_1_1.id.clone();
869            assert_eq!(
870                serde_json::to_value(&model_1_1).unwrap(),
871                serde_json::to_value(auth_model).unwrap()
872            );
873        }
874
875        async fn get_auth_model_by_id(
876            client: &mut OpenFgaServiceClient<tonic::transport::Channel>,
877            store_name: &str,
878            auth_model_id: &str,
879        ) -> AuthorizationModel {
880            client
881                .read_authorization_model(ReadAuthorizationModelRequest {
882                    store_id: client
883                        .clone()
884                        .get_store_by_name(store_name)
885                        .await
886                        .unwrap()
887                        .unwrap()
888                        .id,
889                    id: auth_model_id.to_string(),
890                })
891                .await
892                .unwrap()
893                .into_inner()
894                .authorization_model
895                .unwrap()
896        }
897    }
898
899    #[derive(Default, Clone)]
900    struct MigrationState {
901        counter_1: Arc<Mutex<i32>>,
902        counter_2: Arc<Mutex<i32>>,
903    }
904
905    #[allow(clippy::unused_async)]
906    async fn v1_pre_migration_fn(
907        client: OpenFgaServiceClient<tonic::transport::Channel>,
908        _prev_model: Option<String>,
909        _curr_model: Option<String>,
910        state: MigrationState,
911    ) -> std::result::Result<(), StdError> {
912        let _ = client;
913        // Throw an error for the second call
914        let mut counter = state.counter_1.lock().unwrap();
915        *counter += 1;
916        if *counter == 2 {
917            return Err(Box::new(Error::RequestFailed(Box::new(
918                tonic::Status::new(tonic::Code::Internal, "Test"),
919            ))));
920        }
921        Ok(())
922    }
923
924    #[allow(clippy::unused_async)]
925    async fn v2_post_migration_fn(
926        client: OpenFgaServiceClient<tonic::transport::Channel>,
927        _prev_model: Option<String>,
928        _curr_model: Option<String>,
929        state: MigrationState,
930    ) -> std::result::Result<(), StdError> {
931        let _ = client;
932        // Throw an error for the second call
933        let mut counter = state.counter_2.lock().unwrap();
934        *counter += 1;
935        if *counter == 2 {
936            return Err(Box::new(Error::RequestFailed(Box::new(
937                tonic::Status::new(tonic::Code::Internal, "Test"),
938            ))));
939        }
940        Ok(())
941    }
942}