openfga_client/
migration.rs

1use std::{
2    collections::{HashMap, HashSet},
3    future::Future,
4    hash::Hash,
5    pin::Pin,
6    str::FromStr,
7    sync::Arc,
8};
9
10use tonic::codegen::{Body, Bytes, StdError};
11
12use crate::{
13    client::{
14        AuthorizationModel, ConsistencyPreference, OpenFgaServiceClient, ReadRequestTupleKey,
15        Store, Tuple, TupleKey, WriteAuthorizationModelResponse, WriteRequest, WriteRequestWrites,
16    },
17    error::{Error, Result},
18};
19
/// Number of tuples requested per page when reading tuples via `read_all_pages`.
const DEFAULT_PAGE_SIZE: i32 = 100;
/// Page limit passed to `read_all_pages` for paginated tuple reads.
const MAX_PAGES: u32 = 1_000;
22
/// Version of an authorization model, made of a major and a minor component.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AuthorizationModelVersion {
    /// Major version component.
    major: u32,
    /// Minor version component.
    minor: u32,
}
28
29#[derive(Debug, Clone, PartialEq)]
30struct VersionedAuthorizationModel {
31    model: AuthorizationModel,
32    version: AuthorizationModelVersion,
33}
34
35/// Manages [`AuthorizationModel`]s in OpenFGA.
36///
37/// Authorization models in OpenFGA don't receive a unique name. Instead,
38/// they receive a random id on creation. If we don't store this ID, we can't
39/// find the model again and use its ID.
40///
41/// This `ModelManager` stores the mapping of [`AuthorizationModelVersion`]
42/// to the ID of the model in OpenFGA directly inside OpenFGA.
43/// This way can query OpenFGA to determine if a model with a certain version
44/// has already exists.
45///
46/// When running [`TupleModelManager::migrate()`], the manager only applies models and their migrations
47/// if they don't already exist in OpenFGA.
48///
49/// To store the mapping of model versions to OpenFGA IDs, the following needs to part of your Authorization Model:
50/// ```text
51/// type auth_model_id
52/// type model_version
53///   relations
54///     define openfga_id: [auth_model_id]
55///     define exists: [auth_model_id:*]
56/// ```
57///
58#[derive(Debug, Clone)]
59pub struct TupleModelManager<T, S>
60where
61    T: tonic::client::GrpcService<tonic::body::Body>,
62    T::Error: Into<StdError>,
63    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
64    <T::ResponseBody as Body>::Error: Into<StdError> + Send,
65{
66    client: OpenFgaServiceClient<T>,
67    store_name: String,
68    model_prefix: String,
69    migrations: HashMap<AuthorizationModelVersion, Migration<T, S>>,
70}
71
72#[derive(Clone)]
73struct Migration<T, S> {
74    /// The model being migrated to.
75    ///
76    /// If this has value `vX`, the active model *after* the migration will be `vX`.
77    model: VersionedAuthorizationModel,
78    pre_migration_fn: Option<BoxedMigrationFn<T, S>>,
79    post_migration_fn: Option<BoxedMigrationFn<T, S>>,
80}
81
/// A pinned, heap-allocated, `Send` future resolving to `T` that may borrow data for `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
84
85/// Possible function pointer that implements the migration function signature.
86///
87/// The arguments are:
88///
89/// 1) A client connected to endpoint that's being migrated
90/// 2) The id of the authorization model active prior to the invocation of the function
91/// 3) The id of the authorization model active when the function is invoked
92/// 4) The (user defined) state passed into the migration function
93///
94/// Authorization model ids may be undefined (`None`), e.g. for the pre hook of the first
95/// migration neither a previous nor a current model exist. For the first migration run by
96/// [`TupleModelManager::migrate`], the previous model id is always `None`.
97pub type MigrationFn<T, S> = fn(
98    OpenFgaServiceClient<T>,
99    Option<String>,
100    Option<String>,
101    S,
102) -> BoxFuture<'static, std::result::Result<(), StdError>>;
103
104/// Type alias for the migration function signature.
105type DynMigrationFn<T, S> = dyn Fn(
106    OpenFgaServiceClient<T>,
107    Option<String>,
108    Option<String>,
109    S,
110) -> BoxFuture<'static, std::result::Result<(), StdError>>;
111
112/// Boxed migration function
113type BoxedMigrationFn<T, S> = Arc<DynMigrationFn<T, S>>;
114
115// Function to box the async functions that take an i32 parameter
116fn box_migration_fn<T, S, F, Fut>(f: F) -> BoxedMigrationFn<T, S>
117where
118    F: Fn(OpenFgaServiceClient<T>, Option<String>, Option<String>, S) -> Fut + Send + 'static,
119    Fut: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
120{
121    Arc::new(
122        move |client, prev_auth_model_id, active_auth_model_id, state| {
123            Box::pin(f(client, prev_auth_model_id, active_auth_model_id, state))
124        },
125    )
126}
127
128impl<T, S> TupleModelManager<T, S>
129where
130    T: tonic::client::GrpcService<tonic::body::Body>,
131    T: Clone,
132    T::Error: Into<StdError>,
133    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
134    <T::ResponseBody as Body>::Error: Into<StdError> + Send,
135    S: Clone,
136{
137    const AUTH_MODEL_ID_TYPE: &'static str = "auth_model_id";
138    const MODEL_VERSION_EXISTS_RELATION: &'static str = "exists";
139    const MODEL_VERSION_TYPE: &'static str = "model_version";
140    const MODEL_VERSION_OPENFGA_ID_RELATION: &'static str = "openfga_id";
141
142    /// Create a new `TupleModelManager` with the given client and model name.
143    /// The model prefix must not change after the first model has been added or
144    /// the model manager will not be able to find the model again.
145    /// Use different model prefixes if models for different purposes are stored in the
146    /// same OpenFGA store.
147    pub fn new(client: OpenFgaServiceClient<T>, store_name: &str, model_prefix: &str) -> Self {
148        TupleModelManager {
149            client,
150            model_prefix: model_prefix.to_string(),
151            store_name: store_name.to_string(),
152            migrations: HashMap::new(),
153        }
154    }
155
156    /// Add a new model to the manager.
157    /// If a model with the same version has already been added, the new model will replace the old one.
158    ///
159    /// Ensure that migration functions are written in an idempotent way.
160    /// If a migration fails, it might be retried.
161    #[must_use]
162    pub fn add_model<FutPre, FutPost>(
163        mut self,
164        model: AuthorizationModel,
165        version: AuthorizationModelVersion,
166        pre_migration_fn: Option<
167            impl Fn(OpenFgaServiceClient<T>, Option<String>, Option<String>, S) -> FutPre
168            + Send
169            + 'static,
170        >,
171        post_migration_fn: Option<
172            impl Fn(OpenFgaServiceClient<T>, Option<String>, Option<String>, S) -> FutPost
173            + Send
174            + 'static,
175        >,
176    ) -> Self
177    where
178        FutPre: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
179        FutPost: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
180    {
181        let migration = Migration {
182            model: VersionedAuthorizationModel::new(model, version),
183            pre_migration_fn: pre_migration_fn.map(box_migration_fn),
184            post_migration_fn: post_migration_fn.map(box_migration_fn),
185        };
186        self.migrations.insert(migration.model.version(), migration);
187        self
188    }
189
190    /// Run migrations.
191    ///
192    /// This will:
193    /// 1. Get all existing models in the OpenFGA store.
194    /// 2. Determine which migrations need to be performed: All migrations with a version higher than the highest existing model.
195    /// 3. In order of the version of the model, perform the migrations:
196    ///    1. Run the pre-migration hook if it exists.
197    ///    2. Write the model to OpenFGA.
198    ///    3. Run the post-migration hook if it exists.
199    /// 4. Mark the model as applied in OpenFGA.
200    ///
201    /// # Errors
202    /// * If OpenFGA cannot be reached or a request fails.
203    /// * If any of the migration hooks fail.
204    #[allow(clippy::too_many_lines)]
205    pub async fn migrate(&mut self, state: S) -> Result<()> {
206        let span = tracing::span!(
207            tracing::Level::INFO,
208            "Running OpenFGA Migrations",
209            store_name = self.store_name,
210            model_prefix = self.model_prefix
211        );
212        let _enter = span.enter();
213
214        if self.migrations.is_empty() {
215            tracing::info!("No Migrations have been added. Nothing to do.");
216            return Ok(());
217        }
218
219        let store = self.client.get_or_create_store(&self.store_name).await?;
220        let existing_models = self.get_existing_versions().await?;
221        let max_existing_model = existing_models.iter().max().copied();
222        let mut curr_model_id = if let Some(version) = max_existing_model {
223            Some(self.require_authorization_model_id(version).await?)
224        } else {
225            None
226        };
227        // At this point the previous model id cannot be determined reliably, e.g. because
228        // migrations might have been run out of order.
229        let mut prev_model_id = None;
230
231        if let Some(max_existing_model) = max_existing_model {
232            tracing::info!(
233                "Currently the highest existing Model Version is: {}",
234                max_existing_model
235            );
236        } else {
237            tracing::info!("No model found in OpenFGA store");
238        }
239
240        let ordered_migrations = self.migrations_to_perform(max_existing_model);
241
242        let mut client = self.client.clone();
243        for migration in ordered_migrations {
244            tracing::info!("Migrating to model version: {}", migration.model.version());
245
246            // Pre-hook
247            if let Some(pre_migration_fn) = migration.pre_migration_fn.as_ref() {
248                pre_migration_fn(
249                    client.clone(),
250                    prev_model_id.clone(),
251                    curr_model_id.clone(),
252                    state.clone(),
253                )
254                .await
255                .map_err(|e| {
256                    tracing::error!("Error in OpenFGA pre-migration hook: {:?}", e);
257                    Error::MigrationHookFailed {
258                        version: migration.model.version().to_string(),
259                        error: Arc::new(e),
260                    }
261                })?;
262            }
263
264            // Write Model
265            let request = migration
266                .model
267                .model()
268                .clone()
269                .into_write_request(store.id.clone());
270            let written_model = client
271                .write_authorization_model(request)
272                .await
273                .map_err(|e| {
274                    tracing::error!("Error writing model: {:?}", e);
275                    Error::RequestFailed(Box::new(e))
276                })?;
277            tracing::debug!("Model written: {:?}", written_model);
278
279            // Update model versions passed to migration hooks.
280            prev_model_id.clone_from(&curr_model_id);
281            curr_model_id = Some(written_model.get_ref().authorization_model_id.clone());
282
283            // Post-hook
284            if let Some(post_migration_fn) = migration.post_migration_fn.as_ref() {
285                post_migration_fn(
286                    client.clone(),
287                    prev_model_id.clone(),
288                    curr_model_id.clone(),
289                    state.clone(),
290                )
291                .await
292                .map_err(|e| {
293                    tracing::error!("Error in OpenFGA post-migration hook: {:?}", e);
294                    Error::MigrationHookFailed {
295                        version: migration.model.version().to_string(),
296                        error: Arc::new(e),
297                    }
298                })?;
299            }
300
301            // Mark as applied
302            Self::mark_as_applied(
303                &mut client,
304                &self.model_prefix,
305                &store,
306                migration.model.version(),
307                written_model.into_inner(),
308            )
309            .await?;
310        }
311
312        Ok(())
313    }
314
315    /// Get the OpenFGA Authorization model ID for the specified model version.
316    /// Ensure that migrations have been run before calling this method.
317    ///
318    /// # Errors
319    /// * If the store with the given name does not exist.
320    /// * If a call to OpenFGA fails.
321    pub async fn get_authorization_model_id(
322        &mut self,
323        version: AuthorizationModelVersion,
324    ) -> Result<Option<String>> {
325        let store = self
326            .client
327            .get_store_by_name(&self.store_name)
328            .await?
329            .ok_or_else(|| {
330                tracing::error!("Store with name {} not found", self.store_name);
331                Error::StoreNotFound(self.store_name.clone())
332            })?;
333
334        let applied_models = self
335            .client
336            .read_all_pages(
337                &store.id,
338                Some(ReadRequestTupleKey {
339                    user: String::new(),
340                    relation: Self::MODEL_VERSION_OPENFGA_ID_RELATION.to_string(),
341                    object: Self::format_model_version_key(&self.model_prefix, version),
342                }),
343                ConsistencyPreference::HigherConsistency,
344                DEFAULT_PAGE_SIZE,
345                MAX_PAGES,
346            )
347            .await?;
348
349        let applied_models = applied_models
350            .into_iter()
351            .filter_map(|t| t.key)
352            .filter_map(|t| {
353                t.user
354                    .strip_prefix(&format!("{}:", Self::AUTH_MODEL_ID_TYPE))
355                    .map(ToString::to_string)
356            })
357            .collect::<Vec<_>>();
358
359        if applied_models.len() > 1 {
360            tracing::error!(
361                "Multiple authorization models with model prefix {} for version {} found.",
362                self.model_prefix,
363                version
364            );
365            return Err(Error::AmbiguousModelVersion {
366                model_prefix: self.model_prefix.clone(),
367                version: version.to_string(),
368            });
369        }
370
371        let model_id = applied_models.into_iter().next().map(|openfga_id| {
372            tracing::info!(
373                "Authorization model for version {version} found in OpenFGA store {}. Model ID: {openfga_id}",
374                self.store_name,
375            );
376            openfga_id
377        });
378
379        Ok(model_id)
380    }
381
382    /// Helper method that tries to get the authorization model id for `version` and returns
383    /// an error if the id cannot be found.
384    async fn require_authorization_model_id(
385        &mut self,
386        version: AuthorizationModelVersion,
387    ) -> Result<String> {
388        self.get_authorization_model_id(version)
389            .await?
390            .ok_or_else(|| {
391                tracing::error!("Missing authorization model id for model version {version}");
392                Error::MissingAuthorizationModelId {
393                    model_prefix: self.model_prefix.clone(),
394                    version: version.to_string(),
395                }
396            })
397    }
398
399    /// Mark a model version as applied in OpenFGA
400    async fn mark_as_applied(
401        client: &mut OpenFgaServiceClient<T>,
402        model_prefix: &str,
403        store: &Store,
404        version: AuthorizationModelVersion,
405        write_response: WriteAuthorizationModelResponse,
406    ) -> Result<()> {
407        let authorization_model_id = write_response.authorization_model_id;
408        let object = Self::format_model_version_key(model_prefix, version);
409
410        let write_request = WriteRequest {
411            store_id: store.id.clone(),
412            writes: Some(WriteRequestWrites {
413                on_duplicate: String::new(),
414                tuple_keys: vec![
415                    TupleKey {
416                        user: format!("{}:{authorization_model_id}", Self::AUTH_MODEL_ID_TYPE),
417                        relation: Self::MODEL_VERSION_OPENFGA_ID_RELATION.to_string(),
418                        object: object.clone(),
419                        condition: None,
420                    },
421                    TupleKey {
422                        user: format!("{}:*", Self::AUTH_MODEL_ID_TYPE),
423                        relation: Self::MODEL_VERSION_EXISTS_RELATION.to_string(),
424                        object,
425                        condition: None,
426                    },
427                ],
428            }),
429            deletes: None,
430            authorization_model_id: authorization_model_id.clone(),
431        };
432        client.write(write_request.clone()).await.map_err(|e| {
433            tracing::error!("Error marking model as applied: {:?}", e);
434            Error::RequestFailed(Box::new(e))
435        })?;
436        Ok(())
437    }
438
439    /// Get all migrations that have been added to the manager
440    /// as a `Vec` sorted by the version of the model.
441    fn ordered_migrations(&self) -> Vec<&Migration<T, S>> {
442        let mut migrations = self.migrations.values().collect::<Vec<_>>();
443        migrations.sort_unstable_by_key(|m| m.model.version());
444        migrations
445    }
446
447    /// Get all migrations that need to be performed, given the maximum existing model version.
448    fn migrations_to_perform(
449        &self,
450        max_existing_model: Option<AuthorizationModelVersion>,
451    ) -> Vec<&Migration<T, S>> {
452        let ordered_migrations = self.ordered_migrations();
453        let migrations_to_perform = ordered_migrations
454            .into_iter()
455            .filter(|m| {
456                max_existing_model.is_none_or(|max_existing| m.model.version() > max_existing)
457            })
458            .collect::<Vec<_>>();
459
460        tracing::info!(
461            "{} migrations needed in OpenFGA store {} for model-prefix {}",
462            migrations_to_perform.len(),
463            self.store_name,
464            self.model_prefix
465        );
466        migrations_to_perform
467    }
468
469    /// Get versions of all existing models in OpenFGA.
470    /// Returns an empty vector if the store does not exist.
471    ///
472    /// # Errors
473    /// * If the call to determine existing stores fails.
474    /// * If a tuple read call fails.
475    pub async fn get_existing_versions(&mut self) -> Result<Vec<AuthorizationModelVersion>> {
476        let Some(store) = self.client.get_store_by_name(&self.store_name).await? else {
477            return Ok(vec![]);
478        };
479
480        let tuples = self
481            .client
482            .read_all_pages(
483                &store.id,
484                Some(ReadRequestTupleKey {
485                    user: format!("{}:*", Self::AUTH_MODEL_ID_TYPE).to_string(),
486                    relation: Self::MODEL_VERSION_EXISTS_RELATION.to_string(),
487                    object: format!("{}:", Self::MODEL_VERSION_TYPE).to_string(),
488                }),
489                crate::client::ConsistencyPreference::HigherConsistency,
490                DEFAULT_PAGE_SIZE,
491                MAX_PAGES,
492            )
493            .await?;
494        let existing_models = Self::parse_existing_models(tuples, &self.model_prefix);
495        Ok(existing_models.into_iter().collect())
496    }
497
498    fn parse_existing_models(
499        exist_tuples: Vec<Tuple>,
500        model_prefix: &str,
501    ) -> HashSet<AuthorizationModelVersion> {
502        exist_tuples
503            .into_iter()
504            .filter_map(|t| t.key)
505            .filter_map(|t| Self::parse_model_version_from_key(&t.object, model_prefix))
506            .collect()
507    }
508
509    fn parse_model_version_from_key(
510        model: &str,
511        model_prefix: &str,
512    ) -> Option<AuthorizationModelVersion> {
513        model
514            // Ignore models with wrong prefix
515            .strip_prefix(&format!("{}:", Self::MODEL_VERSION_TYPE))
516            .and_then(|model| {
517                model
518                    .strip_prefix(&format!("{model_prefix}-"))
519                    .and_then(|version| AuthorizationModelVersion::from_str(version).ok())
520            })
521    }
522
523    fn format_model_version_key(model_prefix: &str, version: AuthorizationModelVersion) -> String {
524        format!("{}:{}-{}", Self::MODEL_VERSION_TYPE, model_prefix, version)
525    }
526}
527
528impl VersionedAuthorizationModel {
529    pub(crate) fn new(model: AuthorizationModel, version: AuthorizationModelVersion) -> Self {
530        VersionedAuthorizationModel { model, version }
531    }
532
533    pub(crate) fn version(&self) -> AuthorizationModelVersion {
534        self.version
535    }
536
537    pub(crate) fn model(&self) -> &AuthorizationModel {
538        &self.model
539    }
540}
541
542impl AuthorizationModelVersion {
543    #[must_use]
544    pub fn new(major: u32, minor: u32) -> Self {
545        AuthorizationModelVersion { major, minor }
546    }
547
548    #[must_use]
549    pub fn major(&self) -> u32 {
550        self.major
551    }
552
553    #[must_use]
554    pub fn minor(&self) -> u32 {
555        self.minor
556    }
557}
558
559// Sort by major version first, then by subversion.
560impl PartialOrd for AuthorizationModelVersion {
561    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
562        Some(self.cmp(other))
563    }
564}
565
566impl Ord for AuthorizationModelVersion {
567    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
568        (self.major, self.minor).cmp(&(other.major, other.minor))
569    }
570}
571
572impl std::fmt::Display for AuthorizationModelVersion {
573    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
574        write!(f, "{}.{}", self.major, self.minor)
575    }
576}
577
578impl FromStr for AuthorizationModelVersion {
579    type Err = Error;
580
581    fn from_str(s: &str) -> Result<Self> {
582        let parts = s.split('.').collect::<Vec<_>>();
583        if parts.len() != 2 {
584            return Err(Error::InvalidModelVersion(s.to_string()));
585        }
586
587        let major = parts[0]
588            .parse()
589            .map_err(|_| Error::InvalidModelVersion(s.to_string()))?;
590        let minor = parts[1]
591            .parse()
592            .map_err(|_| Error::InvalidModelVersion(s.to_string()))?;
593
594        Ok(AuthorizationModelVersion::new(major, minor))
595    }
596}
597
598impl<T, S> std::fmt::Debug for Migration<T, S> {
599    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
600        f.debug_struct("Migration")
601            .field("model", &self.model)
602            .field("pre_migration_fn", &"...")
603            .field("post_migration_fn", &"...")
604            .finish()
605    }
606}
607
608#[cfg(test)]
609pub(crate) mod test {
610    use std::sync::Mutex;
611
612    use needs_env_var::needs_env_var;
613    use pretty_assertions::assert_eq;
614
615    use super::*;
616
617    type ChannelTupleManager = TupleModelManager<tonic::transport::Channel, ()>;
618
619    #[test]
620    fn test_ordering() {
621        let versioned_1_0 = AuthorizationModelVersion::new(1, 0);
622        let versioned_1_1 = AuthorizationModelVersion::new(1, 1);
623        let versioned_2_0 = AuthorizationModelVersion::new(2, 0);
624        let versioned_2_1 = AuthorizationModelVersion::new(2, 1);
625        let versioned_2_2 = AuthorizationModelVersion::new(2, 2);
626
627        assert!(versioned_1_0 < versioned_1_1);
628        assert!(versioned_1_1 < versioned_2_0);
629        assert!(versioned_2_0 < versioned_2_1);
630        assert!(versioned_2_1 < versioned_2_2);
631    }
632
633    #[test]
634    fn test_auth_model_version_str() {
635        let version = AuthorizationModelVersion::new(1, 0);
636        assert_eq!(version.to_string(), "1.0");
637        assert_eq!("1.0".parse::<AuthorizationModelVersion>().unwrap(), version);
638
639        let version = AuthorizationModelVersion::new(10, 2);
640        assert_eq!(version.to_string(), "10.2");
641        assert_eq!(
642            "10.2".parse::<AuthorizationModelVersion>().unwrap(),
643            version
644        );
645    }
646
647    #[test]
648    fn test_parse_model_version_from_key() {
649        let model_prefix = "test";
650        let model_version = AuthorizationModelVersion::new(1, 0);
651        let key = format!("model_version:{model_prefix}-{model_version}");
652        assert_eq!(
653            ChannelTupleManager::parse_model_version_from_key(&key, model_prefix),
654            Some(model_version)
655        );
656
657        // Prefix missing
658        assert!(
659            ChannelTupleManager::parse_model_version_from_key("model_version:1.0", model_prefix)
660                .is_none()
661        );
662
663        // Wrong prefix
664        assert!(
665            ChannelTupleManager::parse_model_version_from_key(
666                "model_version:foo-1.0",
667                model_prefix
668            )
669            .is_none()
670        );
671
672        // Higher version
673        assert_eq!(
674            ChannelTupleManager::parse_model_version_from_key(
675                "model_version:other-model-10.200",
676                "other-model"
677            ),
678            Some(AuthorizationModelVersion::new(10, 200))
679        );
680    }
681
682    #[test]
683    fn test_format_model_version_key() {
684        let model_prefix = "test";
685        let model_version = AuthorizationModelVersion::new(1, 0);
686        let key = ChannelTupleManager::format_model_version_key(model_prefix, model_version);
687        assert_eq!(key, "model_version:test-1.0");
688        let parsed = ChannelTupleManager::parse_model_version_from_key(&key, model_prefix).unwrap();
689        assert_eq!(parsed, model_version);
690    }
691
692    #[needs_env_var(TEST_OPENFGA_CLIENT_GRPC_URL)]
693    pub(crate) mod openfga {
694        use std::str::FromStr;
695
696        use pretty_assertions::assert_eq;
697
698        use super::*;
699        use crate::client::{OpenFgaServiceClient, ReadAuthorizationModelRequest};
700
701        pub(crate) async fn get_service_client() -> OpenFgaServiceClient<tonic::transport::Channel>
702        {
703            let endpoint = std::env::var("TEST_OPENFGA_CLIENT_GRPC_URL").unwrap();
704            let endpoint = tonic::transport::Endpoint::from_str(&endpoint).unwrap();
705            OpenFgaServiceClient::connect(endpoint)
706                .await
707                .expect("Client can be created")
708        }
709
710        pub(crate) async fn service_client_with_store()
711        -> (OpenFgaServiceClient<tonic::transport::Channel>, Store) {
712            let mut client = get_service_client().await;
713            let store_name = format!("test-{}", uuid::Uuid::now_v7());
714            let store = client.get_or_create_store(&store_name).await.unwrap();
715            (client, store)
716        }
717
718        #[tokio::test]
719        async fn test_get_existing_versions_nonexistent_store() {
720            let client = get_service_client().await;
721            let mut manager: TupleModelManager<_, ()> =
722                TupleModelManager::new(client, "nonexistent", "test");
723
724            let versions = manager.get_existing_versions().await.unwrap();
725            assert!(versions.is_empty());
726        }
727
728        #[tokio::test]
729        async fn test_get_existing_versions_nonexistent_auth_model() {
730            let mut client = get_service_client().await;
731            let store_name = format!("test-{}", uuid::Uuid::now_v7());
732            let _store = client.get_or_create_store(&store_name).await.unwrap();
733            let mut manager: TupleModelManager<_, ()> =
734                TupleModelManager::new(client, &store_name, "test");
735            let versions = manager.get_existing_versions().await.unwrap();
736            assert!(versions.is_empty());
737        }
738
739        #[tokio::test]
740        async fn test_get_authorization_model_id() {
741            let (mut client, store) = service_client_with_store().await;
742            let model_prefix = "test";
743            let version = AuthorizationModelVersion::new(1, 0);
744
745            let mut manager: TupleModelManager<_, ()> =
746                TupleModelManager::new(client.clone(), &store.name, model_prefix);
747
748            // Non-existent model
749            assert_eq!(
750                manager.get_authorization_model_id(version).await.unwrap(),
751                None
752            );
753
754            // Apply auth model
755            let model: AuthorizationModel =
756                serde_json::from_str(include_str!("../tests/model-manager/v1.0/schema.json"))
757                    .unwrap();
758            client
759                .write_authorization_model(model.into_write_request(store.id.clone()))
760                .await
761                .unwrap();
762
763            // Write model tuples
764            client
765                .write(WriteRequest {
766                    store_id: store.id.clone(),
767                    writes: Some(WriteRequestWrites {
768                        on_duplicate: String::new(),
769                        tuple_keys: vec![
770                            TupleKey {
771                                user: "auth_model_id:111111".to_string(),
772                                relation: "openfga_id".to_string(),
773                                object: "model_version:test-1.0".to_string(),
774                                condition: None,
775                            },
776                            TupleKey {
777                                user: "auth_model_id:*".to_string(),
778                                relation: "exists".to_string(),
779                                object: "model_version:test-1.0".to_string(),
780                                condition: None,
781                            },
782                            // Tuple with different model prefix should be ignored
783                            TupleKey {
784                                user: "auth_model_id:*".to_string(),
785                                relation: "exists".to_string(),
786                                object: "model_version:test2-1.0".to_string(),
787                                condition: None,
788                            },
789                        ],
790                    }),
791                    deletes: None,
792                    authorization_model_id: String::new(),
793                })
794                .await
795                .unwrap();
796
797            assert_eq!(
798                manager.get_authorization_model_id(version).await.unwrap(),
799                Some("111111".to_string())
800            );
801        }
802
803        #[tokio::test]
804        async fn test_model_manager() {
805            let store_name = format!("test-{}", uuid::Uuid::now_v7());
806            let mut client = get_service_client().await;
807
808            let model_1_0: AuthorizationModel =
809                serde_json::from_str(include_str!("../tests/model-manager/v1.0/schema.json"))
810                    .unwrap();
811
812            let version_1_0 = AuthorizationModelVersion::new(1, 0);
813
814            let migration_state = MigrationState::default();
815            let mut manager = TupleModelManager::new(client.clone(), &store_name, "test-model")
816                .add_model(
817                    model_1_0.clone(),
818                    version_1_0,
819                    Some(v1_pre_migration_fn),
820                    None::<MigrationFn<_, _>>,
821                );
822            manager.migrate(migration_state.clone()).await.unwrap();
823            // Check hook was called once
824            assert_eq!(*migration_state.counter_1.lock().unwrap(), 1);
825            manager.migrate(migration_state.clone()).await.unwrap();
826            // Check hook was not called again
827            assert_eq!(*migration_state.counter_1.lock().unwrap(), 1);
828
829            // Check written model
830            let auth_model_id = manager
831                .get_authorization_model_id(version_1_0)
832                .await
833                .unwrap()
834                .unwrap();
835            let mut auth_model =
836                get_auth_model_by_id(&mut client, &store_name, &auth_model_id).await;
837            auth_model.id = model_1_0.id.clone();
838            assert_eq!(
839                serde_json::to_value(&model_1_0).unwrap(),
840                serde_json::to_value(auth_model).unwrap()
841            );
842
843            // Add a second model
844            let model_1_1: AuthorizationModel =
845                serde_json::from_str(include_str!("../tests/model-manager/v1.1/schema.json"))
846                    .unwrap();
847            let version_1_1 = AuthorizationModelVersion::new(1, 1);
848            let mut manager = manager.add_model(
849                model_1_1.clone(),
850                version_1_1,
851                None::<MigrationFn<_, _>>,
852                Some(v2_post_migration_fn),
853            );
854            manager.migrate(migration_state.clone()).await.unwrap();
855            manager.migrate(migration_state.clone()).await.unwrap();
856            manager.migrate(migration_state.clone()).await.unwrap();
857
858            // First migration still only called once
859            assert_eq!(*migration_state.counter_1.lock().unwrap(), 1);
860            // Second migration called once
861            assert_eq!(*migration_state.counter_2.lock().unwrap(), 1);
862
863            // Check written model
864            let auth_model_id = manager
865                .get_authorization_model_id(version_1_1)
866                .await
867                .unwrap()
868                .unwrap();
869            let mut auth_model =
870                get_auth_model_by_id(&mut client, &store_name, &auth_model_id).await;
871            auth_model.id = model_1_1.id.clone();
872            assert_eq!(
873                serde_json::to_value(&model_1_1).unwrap(),
874                serde_json::to_value(auth_model).unwrap()
875            );
876        }
877
878        async fn get_auth_model_by_id(
879            client: &mut OpenFgaServiceClient<tonic::transport::Channel>,
880            store_name: &str,
881            auth_model_id: &str,
882        ) -> AuthorizationModel {
883            client
884                .read_authorization_model(ReadAuthorizationModelRequest {
885                    store_id: client
886                        .clone()
887                        .get_store_by_name(store_name)
888                        .await
889                        .unwrap()
890                        .unwrap()
891                        .id,
892                    id: auth_model_id.to_string(),
893                })
894                .await
895                .unwrap()
896                .into_inner()
897                .authorization_model
898                .unwrap()
899        }
900    }
901
902    #[derive(Default, Clone)]
903    struct MigrationState {
904        counter_1: Arc<Mutex<i32>>,
905        counter_2: Arc<Mutex<i32>>,
906    }
907
908    #[allow(clippy::unused_async)]
909    async fn v1_pre_migration_fn(
910        client: OpenFgaServiceClient<tonic::transport::Channel>,
911        _prev_model: Option<String>,
912        _curr_model: Option<String>,
913        state: MigrationState,
914    ) -> std::result::Result<(), StdError> {
915        let _ = client;
916        // Throw an error for the second call
917        let mut counter = state.counter_1.lock().unwrap();
918        *counter += 1;
919        if *counter == 2 {
920            return Err(Box::new(Error::RequestFailed(Box::new(
921                tonic::Status::new(tonic::Code::Internal, "Test"),
922            ))));
923        }
924        Ok(())
925    }
926
927    #[allow(clippy::unused_async)]
928    async fn v2_post_migration_fn(
929        client: OpenFgaServiceClient<tonic::transport::Channel>,
930        _prev_model: Option<String>,
931        _curr_model: Option<String>,
932        state: MigrationState,
933    ) -> std::result::Result<(), StdError> {
934        let _ = client;
935        // Throw an error for the second call
936        let mut counter = state.counter_2.lock().unwrap();
937        *counter += 1;
938        if *counter == 2 {
939            return Err(Box::new(Error::RequestFailed(Box::new(
940                tonic::Status::new(tonic::Code::Internal, "Test"),
941            ))));
942        }
943        Ok(())
944    }
945}