openfga_client/
migration.rs

1use std::{
2    collections::{HashMap, HashSet},
3    future::Future,
4    hash::Hash,
5    pin::Pin,
6    str::FromStr,
7    sync::Arc,
8};
9
10use futures_timer::Delay;
11use tonic::codegen::{Body, Bytes, StdError};
12
13use crate::{
14    client::{
15        AuthorizationModel, ConsistencyPreference, OpenFgaServiceClient, ReadRequestTupleKey,
16        Store, Tuple, TupleKey, WriteAuthorizationModelResponse, WriteRequest, WriteRequestWrites,
17    },
18    error::{Error, Result},
19};
20
21const DEFAULT_PAGE_SIZE: i32 = 100;
22const MAX_PAGES: u32 = 1000;
23
/// Version of an authorization model, consisting of a major and a minor component.
///
/// Versions are the keys under which migrations are registered in a
/// [`TupleModelManager`].
#[derive(Debug, Copy, Clone, PartialEq, Hash, Eq)]
pub struct AuthorizationModelVersion {
    major: u32,
    minor: u32,
}
29
/// An [`AuthorizationModel`] paired with the [`AuthorizationModelVersion`] it is registered under.
#[derive(Debug, Clone, PartialEq)]
struct VersionedAuthorizationModel {
    model: AuthorizationModel,
    version: AuthorizationModelVersion,
}
35
/// Manages [`AuthorizationModel`]s in OpenFGA.
///
/// Authorization models in OpenFGA don't receive a unique name. Instead,
/// they receive a random id on creation. If we don't store this ID, we can't
/// find the model again and use its ID.
///
/// This `TupleModelManager` stores the mapping of [`AuthorizationModelVersion`]
/// to the ID of the model in OpenFGA directly inside OpenFGA.
/// This way we can query OpenFGA to determine if a model with a certain version
/// already exists.
///
/// When running [`TupleModelManager::migrate()`], the manager only applies models and their migrations
/// if they don't already exist in OpenFGA.
///
/// To store the mapping of model versions to OpenFGA IDs, the following needs to be part of your Authorization Model:
/// ```text
/// type auth_model_id
/// type model_version
///   relations
///     define openfga_id: [auth_model_id]
///     define exists: [auth_model_id:*]
/// ```
#[derive(Debug, Clone)]
pub struct TupleModelManager<T, S>
where
    T: tonic::client::GrpcService<tonic::body::Body>,
    T::Error: Into<StdError>,
    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
    <T::ResponseBody as Body>::Error: Into<StdError> + Send,
{
    /// Client used for all requests to OpenFGA.
    client: OpenFgaServiceClient<T>,
    /// Name of the OpenFGA store that models and tracking tuples are written to.
    store_name: String,
    /// Prefix embedded in the `model_version` object ids written by this manager.
    model_prefix: String,
    /// Registered migrations, keyed by the version of the model they migrate to.
    migrations: HashMap<AuthorizationModelVersion, Migration<T, S>>,
}
72
/// A single registered migration: the target model plus its optional hooks.
#[derive(Clone)]
struct Migration<T, S> {
    /// The model being migrated to.
    ///
    /// If this has value `vX`, the active model *after* the migration will be `vX`.
    model: VersionedAuthorizationModel,
    /// Hook invoked before the model is written to OpenFGA, if any.
    pre_migration_fn: Option<BoxedMigrationFn<T, S>>,
    /// Hook invoked after the model has been written to OpenFGA, if any.
    post_migration_fn: Option<BoxedMigrationFn<T, S>>,
}
82
/// A pinned, heap-allocated, `Send` future yielding `T` that is valid for the lifetime `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
85
/// Function pointer type matching the migration hook signature.
///
/// Useful to name the type of an absent hook, e.g. `None::<MigrationFn<_, _>>`.
///
/// The arguments are:
///
/// 1) A client connected to the endpoint that's being migrated
/// 2) The id of the authorization model active prior to the invocation of the function
/// 3) The id of the authorization model active when the function is invoked
/// 4) The (user defined) state passed into the migration function
///
/// Authorization model ids may be undefined (`None`), e.g. for the pre hook of the first
/// migration neither a previous nor a current model exists. For the first migration run by
/// [`TupleModelManager::migrate`], the previous model id is always `None`.
pub type MigrationFn<T, S> = fn(
    OpenFgaServiceClient<T>,
    Option<String>,
    Option<String>,
    S,
) -> BoxFuture<'static, std::result::Result<(), StdError>>;
104
/// Unsized (trait object) form of the migration hook signature.
///
/// Takes the same four arguments as [`MigrationFn`] but also admits closures.
type DynMigrationFn<T, S> = dyn Fn(
    OpenFgaServiceClient<T>,
    Option<String>,
    Option<String>,
    S,
) -> BoxFuture<'static, std::result::Result<(), StdError>>;
112
/// Reference-counted migration hook, as stored inside a [`Migration`].
type BoxedMigrationFn<T, S> = Arc<DynMigrationFn<T, S>>;
115
116// Function to box the async functions that take an i32 parameter
117fn box_migration_fn<T, S, F, Fut>(f: F) -> BoxedMigrationFn<T, S>
118where
119    F: Fn(OpenFgaServiceClient<T>, Option<String>, Option<String>, S) -> Fut + Send + 'static,
120    Fut: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
121{
122    Arc::new(
123        move |client, prev_auth_model_id, active_auth_model_id, state| {
124            Box::pin(f(client, prev_auth_model_id, active_auth_model_id, state))
125        },
126    )
127}
128
129impl<T, S> TupleModelManager<T, S>
130where
131    T: tonic::client::GrpcService<tonic::body::Body>,
132    T: Clone,
133    T::Error: Into<StdError>,
134    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
135    <T::ResponseBody as Body>::Error: Into<StdError> + Send,
136    S: Clone,
137{
    /// OpenFGA type name used for the users of the version-tracking tuples.
    const AUTH_MODEL_ID_TYPE: &'static str = "auth_model_id";
    /// Relation of the wildcard tuple that marks a model version as existing.
    const MODEL_VERSION_EXISTS_RELATION: &'static str = "exists";
    /// OpenFGA type name used for the objects of the version-tracking tuples.
    const MODEL_VERSION_TYPE: &'static str = "model_version";
    /// Relation of the tuple that links a model version to its OpenFGA model id.
    const MODEL_VERSION_OPENFGA_ID_RELATION: &'static str = "openfga_id";
142
143    /// Create a new `TupleModelManager` with the given client and model name.
144    /// The model prefix must not change after the first model has been added or
145    /// the model manager will not be able to find the model again.
146    /// Use different model prefixes if models for different purposes are stored in the
147    /// same OpenFGA store.
148    pub fn new(client: OpenFgaServiceClient<T>, store_name: &str, model_prefix: &str) -> Self {
149        TupleModelManager {
150            client,
151            model_prefix: model_prefix.to_string(),
152            store_name: store_name.to_string(),
153            migrations: HashMap::new(),
154        }
155    }
156
157    /// Add a new model to the manager.
158    /// If a model with the same version has already been added, the new model will replace the old one.
159    ///
160    /// Ensure that migration functions are written in an idempotent way.
161    /// If a migration fails, it might be retried.
162    #[must_use]
163    pub fn add_model<FutPre, FutPost>(
164        mut self,
165        model: AuthorizationModel,
166        version: AuthorizationModelVersion,
167        pre_migration_fn: Option<
168            impl Fn(OpenFgaServiceClient<T>, Option<String>, Option<String>, S) -> FutPre
169            + Send
170            + 'static,
171        >,
172        post_migration_fn: Option<
173            impl Fn(OpenFgaServiceClient<T>, Option<String>, Option<String>, S) -> FutPost
174            + Send
175            + 'static,
176        >,
177    ) -> Self
178    where
179        FutPre: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
180        FutPost: Future<Output = std::result::Result<(), StdError>> + Send + 'static,
181    {
182        let migration = Migration {
183            model: VersionedAuthorizationModel::new(model, version),
184            pre_migration_fn: pre_migration_fn.map(box_migration_fn),
185            post_migration_fn: post_migration_fn.map(box_migration_fn),
186        };
187        self.migrations.insert(migration.model.version(), migration);
188        self
189    }
190
191    /// Run migrations.
192    ///
193    /// This will:
194    /// 1. Get all existing models in the OpenFGA store.
195    /// 2. Determine which migrations need to be performed: All migrations with a version higher than the highest existing model.
196    /// 3. In order of the version of the model, perform the migrations:
197    ///    1. Run the pre-migration hook if it exists.
198    ///    2. Write the model to OpenFGA.
199    ///    3. Run the post-migration hook if it exists.
200    /// 4. Mark the model as applied in OpenFGA.
201    ///
202    /// # Errors
203    /// * If OpenFGA cannot be reached or a request fails.
204    /// * If any of the migration hooks fail.
205    #[allow(clippy::too_many_lines)]
206    pub async fn migrate(&mut self, state: S) -> Result<()> {
207        let span = tracing::span!(
208            tracing::Level::INFO,
209            "Running OpenFGA Migrations",
210            store_name = self.store_name,
211            model_prefix = self.model_prefix
212        );
213        let _enter = span.enter();
214
215        if self.migrations.is_empty() {
216            tracing::info!("No Migrations have been added. Nothing to do.");
217            return Ok(());
218        }
219
220        let store = self.client.get_or_create_store(&self.store_name).await?;
221        let existing_models = self.get_existing_versions().await?;
222        let max_existing_model = existing_models.iter().max().copied();
223        let mut curr_model_id = if let Some(version) = max_existing_model {
224            Some(self.require_authorization_model_id(version).await?)
225        } else {
226            None
227        };
228        // At this point the previous model id cannot be determined reliably, e.g. because
229        // migrations might have been run out of order.
230        let mut prev_model_id = None;
231
232        if let Some(max_existing_model) = max_existing_model {
233            tracing::info!(
234                "Currently the highest existing Model Version is: {}",
235                max_existing_model
236            );
237        } else {
238            tracing::info!("No model found in OpenFGA store");
239        }
240
241        let ordered_migrations = self.migrations_to_perform(max_existing_model);
242
243        let mut client = self.client.clone();
244        for migration in ordered_migrations {
245            tracing::info!("Migrating to model version: {}", migration.model.version());
246
247            // Pre-hook
248            if let Some(pre_migration_fn) = migration.pre_migration_fn.as_ref() {
249                pre_migration_fn(
250                    client.clone(),
251                    prev_model_id.clone(),
252                    curr_model_id.clone(),
253                    state.clone(),
254                )
255                .await
256                .map_err(|e| {
257                    tracing::error!("Error in OpenFGA pre-migration hook: {:?}", e);
258                    Error::MigrationHookFailed {
259                        version: migration.model.version().to_string(),
260                        error: Arc::new(e),
261                    }
262                })?;
263            }
264
265            // Write Model
266            let request = migration
267                .model
268                .model()
269                .clone()
270                .into_write_request(store.id.clone());
271            let written_model = client
272                .write_authorization_model(request)
273                .await
274                .map_err(|e| {
275                    tracing::error!("Error writing model: {:?}", e);
276                    Error::RequestFailed(Box::new(e))
277                })?;
278            tracing::info!(
279                "Model version {} written to OpenFGA store {} with model id {}",
280                migration.model.version(),
281                self.store_name,
282                written_model.get_ref().authorization_model_id,
283            );
284            tracing::debug!("Model written: {:?}", written_model);
285
286            // Update model versions passed to migration hooks.
287            prev_model_id.clone_from(&curr_model_id);
288            curr_model_id = Some(written_model.get_ref().authorization_model_id.clone());
289
290            // Post-hook
291            if let Some(post_migration_fn) = migration.post_migration_fn.as_ref() {
292                post_migration_fn(
293                    client.clone(),
294                    prev_model_id.clone(),
295                    curr_model_id.clone(),
296                    state.clone(),
297                )
298                .await
299                .map_err(|e| {
300                    tracing::error!("Error in OpenFGA post-migration hook: {:?}", e);
301                    Error::MigrationHookFailed {
302                        version: migration.model.version().to_string(),
303                        error: Arc::new(e),
304                    }
305                })?;
306            }
307
308            // Mark as applied
309            Self::mark_as_applied(
310                &mut client,
311                &self.model_prefix,
312                &store,
313                migration.model.version(),
314                written_model.into_inner(),
315            )
316            .await?;
317        }
318
319        Ok(())
320    }
321
322    /// Get the OpenFGA Authorization model ID for the specified model version.
323    /// Ensure that migrations have been run before calling this method.
324    ///
325    /// # Errors
326    /// * If the store with the given name does not exist.
327    /// * If a call to OpenFGA fails.
328    pub async fn get_authorization_model_id(
329        &mut self,
330        version: AuthorizationModelVersion,
331    ) -> Result<Option<String>> {
332        let store = self
333            .client
334            .get_store_by_name(&self.store_name)
335            .await?
336            .ok_or_else(|| {
337                tracing::error!("Store with name {} not found", self.store_name);
338                Error::StoreNotFound(self.store_name.clone())
339            })?;
340
341        let applied_models = self
342            .client
343            .read_all_pages(
344                &store.id,
345                Some(ReadRequestTupleKey {
346                    user: String::new(),
347                    relation: Self::MODEL_VERSION_OPENFGA_ID_RELATION.to_string(),
348                    object: Self::format_model_version_key(&self.model_prefix, version),
349                }),
350                ConsistencyPreference::HigherConsistency,
351                DEFAULT_PAGE_SIZE,
352                MAX_PAGES,
353            )
354            .await?;
355
356        let applied_models = applied_models
357            .into_iter()
358            .filter_map(|t| t.key)
359            .filter_map(|t| {
360                t.user
361                    .strip_prefix(&format!("{}:", Self::AUTH_MODEL_ID_TYPE))
362                    .map(ToString::to_string)
363            })
364            .collect::<Vec<_>>();
365
366        if applied_models.len() > 1 {
367            tracing::error!(
368                "Multiple authorization models with model prefix {} for version {} found.",
369                self.model_prefix,
370                version
371            );
372            return Err(Error::AmbiguousModelVersion {
373                model_prefix: self.model_prefix.clone(),
374                version: version.to_string(),
375            });
376        }
377
378        let model_id = applied_models.into_iter().next().map(|openfga_id| {
379            tracing::info!(
380                "Authorization model for version {version} found in OpenFGA store {}. Model ID: {openfga_id}",
381                self.store_name,
382            );
383            openfga_id
384        });
385
386        Ok(model_id)
387    }
388
389    /// Helper method that tries to get the authorization model id for `version` and returns
390    /// an error if the id cannot be found.
391    async fn require_authorization_model_id(
392        &mut self,
393        version: AuthorizationModelVersion,
394    ) -> Result<String> {
395        self.get_authorization_model_id(version)
396            .await?
397            .ok_or_else(|| {
398                tracing::error!("Missing authorization model id for model version {version}");
399                Error::MissingAuthorizationModelId {
400                    model_prefix: self.model_prefix.clone(),
401                    version: version.to_string(),
402                }
403            })
404    }
405
406    /// Mark a model version as applied in OpenFGA
407    async fn mark_as_applied(
408        client: &mut OpenFgaServiceClient<T>,
409        model_prefix: &str,
410        store: &Store,
411        version: AuthorizationModelVersion,
412        write_response: WriteAuthorizationModelResponse,
413    ) -> Result<()> {
414        let authorization_model_id = write_response.authorization_model_id;
415        let object = Self::format_model_version_key(model_prefix, version);
416
417        let write_request = WriteRequest {
418            store_id: store.id.clone(),
419            writes: Some(WriteRequestWrites {
420                on_duplicate: String::new(),
421                tuple_keys: vec![
422                    TupleKey {
423                        user: format!("{}:{authorization_model_id}", Self::AUTH_MODEL_ID_TYPE),
424                        relation: Self::MODEL_VERSION_OPENFGA_ID_RELATION.to_string(),
425                        object: object.clone(),
426                        condition: None,
427                    },
428                    TupleKey {
429                        user: format!("{}:*", Self::AUTH_MODEL_ID_TYPE),
430                        relation: Self::MODEL_VERSION_EXISTS_RELATION.to_string(),
431                        object,
432                        condition: None,
433                    },
434                ],
435            }),
436            deletes: None,
437            authorization_model_id: authorization_model_id.clone(),
438        };
439
440        // Retry once per second for up to 5 attempts (4 seconds max wait) to handle replication lag
441        let max_retries = 5;
442        let retry_delay = std::time::Duration::from_secs(1);
443
444        for attempt in 0..max_retries {
445            match client.write(write_request.clone()).await {
446                Ok(_) => {
447                    if attempt > 0 {
448                        tracing::info!(
449                            "Successfully marked model {version} as applied after {attempt} retries",
450                        );
451                    }
452                    return Ok(());
453                }
454                Err(e) => {
455                    if attempt == max_retries {
456                        tracing::error!(
457                            "Error marking model as applied after {} retries: {:?}",
458                            max_retries,
459                            e
460                        );
461                        return Err(Error::RequestFailed(Box::new(e)));
462                    }
463
464                    tracing::warn!(
465                        "Failed to mark model as applied (attempt {}/{max_retries}), retrying in {retry_delay:?}: {e:?}",
466                        attempt + 1,
467                    );
468
469                    Delay::new(retry_delay).await;
470                }
471            }
472        }
473
474        unreachable!();
475    }
476
477    /// Get all migrations that have been added to the manager
478    /// as a `Vec` sorted by the version of the model.
479    fn ordered_migrations(&self) -> Vec<&Migration<T, S>> {
480        let mut migrations = self.migrations.values().collect::<Vec<_>>();
481        migrations.sort_unstable_by_key(|m| m.model.version());
482        migrations
483    }
484
485    /// Get all migrations that need to be performed, given the maximum existing model version.
486    fn migrations_to_perform(
487        &self,
488        max_existing_model: Option<AuthorizationModelVersion>,
489    ) -> Vec<&Migration<T, S>> {
490        let ordered_migrations = self.ordered_migrations();
491        let migrations_to_perform = ordered_migrations
492            .into_iter()
493            .filter(|m| {
494                max_existing_model.is_none_or(|max_existing| m.model.version() > max_existing)
495            })
496            .collect::<Vec<_>>();
497
498        tracing::info!(
499            "{} migrations needed in OpenFGA store {} for model-prefix {}",
500            migrations_to_perform.len(),
501            self.store_name,
502            self.model_prefix
503        );
504        migrations_to_perform
505    }
506
507    /// Get versions of all existing models in OpenFGA.
508    /// Returns an empty vector if the store does not exist.
509    ///
510    /// # Errors
511    /// * If the call to determine existing stores fails.
512    /// * If a tuple read call fails.
513    pub async fn get_existing_versions(&mut self) -> Result<Vec<AuthorizationModelVersion>> {
514        let Some(store) = self.client.get_store_by_name(&self.store_name).await? else {
515            return Ok(vec![]);
516        };
517
518        let tuples = self
519            .client
520            .read_all_pages(
521                &store.id,
522                Some(ReadRequestTupleKey {
523                    user: format!("{}:*", Self::AUTH_MODEL_ID_TYPE).to_string(),
524                    relation: Self::MODEL_VERSION_EXISTS_RELATION.to_string(),
525                    object: format!("{}:", Self::MODEL_VERSION_TYPE).to_string(),
526                }),
527                crate::client::ConsistencyPreference::HigherConsistency,
528                DEFAULT_PAGE_SIZE,
529                MAX_PAGES,
530            )
531            .await?;
532        let existing_models = Self::parse_existing_models(tuples, &self.model_prefix);
533        Ok(existing_models.into_iter().collect())
534    }
535
536    fn parse_existing_models(
537        exist_tuples: Vec<Tuple>,
538        model_prefix: &str,
539    ) -> HashSet<AuthorizationModelVersion> {
540        exist_tuples
541            .into_iter()
542            .filter_map(|t| t.key)
543            .filter_map(|t| Self::parse_model_version_from_key(&t.object, model_prefix))
544            .collect()
545    }
546
547    fn parse_model_version_from_key(
548        model: &str,
549        model_prefix: &str,
550    ) -> Option<AuthorizationModelVersion> {
551        model
552            // Ignore models with wrong prefix
553            .strip_prefix(&format!("{}:", Self::MODEL_VERSION_TYPE))
554            .and_then(|model| {
555                model
556                    .strip_prefix(&format!("{model_prefix}-"))
557                    .and_then(|version| AuthorizationModelVersion::from_str(version).ok())
558            })
559    }
560
561    fn format_model_version_key(model_prefix: &str, version: AuthorizationModelVersion) -> String {
562        format!("{}:{}-{}", Self::MODEL_VERSION_TYPE, model_prefix, version)
563    }
564}
565
566impl VersionedAuthorizationModel {
567    pub(crate) fn new(model: AuthorizationModel, version: AuthorizationModelVersion) -> Self {
568        VersionedAuthorizationModel { model, version }
569    }
570
571    pub(crate) fn version(&self) -> AuthorizationModelVersion {
572        self.version
573    }
574
575    pub(crate) fn model(&self) -> &AuthorizationModel {
576        &self.model
577    }
578}
579
580impl AuthorizationModelVersion {
581    #[must_use]
582    pub fn new(major: u32, minor: u32) -> Self {
583        AuthorizationModelVersion { major, minor }
584    }
585
586    #[must_use]
587    pub fn major(&self) -> u32 {
588        self.major
589    }
590
591    #[must_use]
592    pub fn minor(&self) -> u32 {
593        self.minor
594    }
595}
596
// Sort by major version first, then by minor version.
// `partial_cmp` delegates to `Ord::cmp`, so the two orderings always agree.
impl PartialOrd for AuthorizationModelVersion {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
603
604impl Ord for AuthorizationModelVersion {
605    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
606        (self.major, self.minor).cmp(&(other.major, other.minor))
607    }
608}
609
610impl std::fmt::Display for AuthorizationModelVersion {
611    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
612        write!(f, "{}.{}", self.major, self.minor)
613    }
614}
615
616impl FromStr for AuthorizationModelVersion {
617    type Err = Error;
618
619    fn from_str(s: &str) -> Result<Self> {
620        let parts = s.split('.').collect::<Vec<_>>();
621        if parts.len() != 2 {
622            return Err(Error::InvalidModelVersion(s.to_string()));
623        }
624
625        let major = parts[0]
626            .parse()
627            .map_err(|_| Error::InvalidModelVersion(s.to_string()))?;
628        let minor = parts[1]
629            .parse()
630            .map_err(|_| Error::InvalidModelVersion(s.to_string()))?;
631
632        Ok(AuthorizationModelVersion::new(major, minor))
633    }
634}
635
636impl<T, S> std::fmt::Debug for Migration<T, S> {
637    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
638        f.debug_struct("Migration")
639            .field("model", &self.model)
640            .field("pre_migration_fn", &"...")
641            .field("post_migration_fn", &"...")
642            .finish()
643    }
644}
645
646#[cfg(test)]
647pub(crate) mod test {
648    use std::sync::Mutex;
649
650    use needs_env_var::needs_env_var;
651    use pretty_assertions::assert_eq;
652
653    use super::*;
654
655    type ChannelTupleManager = TupleModelManager<tonic::transport::Channel, ()>;
656
657    #[test]
658    fn test_ordering() {
659        let versioned_1_0 = AuthorizationModelVersion::new(1, 0);
660        let versioned_1_1 = AuthorizationModelVersion::new(1, 1);
661        let versioned_2_0 = AuthorizationModelVersion::new(2, 0);
662        let versioned_2_1 = AuthorizationModelVersion::new(2, 1);
663        let versioned_2_2 = AuthorizationModelVersion::new(2, 2);
664
665        assert!(versioned_1_0 < versioned_1_1);
666        assert!(versioned_1_1 < versioned_2_0);
667        assert!(versioned_2_0 < versioned_2_1);
668        assert!(versioned_2_1 < versioned_2_2);
669    }
670
671    #[test]
672    fn test_auth_model_version_str() {
673        let version = AuthorizationModelVersion::new(1, 0);
674        assert_eq!(version.to_string(), "1.0");
675        assert_eq!("1.0".parse::<AuthorizationModelVersion>().unwrap(), version);
676
677        let version = AuthorizationModelVersion::new(10, 2);
678        assert_eq!(version.to_string(), "10.2");
679        assert_eq!(
680            "10.2".parse::<AuthorizationModelVersion>().unwrap(),
681            version
682        );
683    }
684
685    #[test]
686    fn test_parse_model_version_from_key() {
687        let model_prefix = "test";
688        let model_version = AuthorizationModelVersion::new(1, 0);
689        let key = format!("model_version:{model_prefix}-{model_version}");
690        assert_eq!(
691            ChannelTupleManager::parse_model_version_from_key(&key, model_prefix),
692            Some(model_version)
693        );
694
695        // Prefix missing
696        assert!(
697            ChannelTupleManager::parse_model_version_from_key("model_version:1.0", model_prefix)
698                .is_none()
699        );
700
701        // Wrong prefix
702        assert!(
703            ChannelTupleManager::parse_model_version_from_key(
704                "model_version:foo-1.0",
705                model_prefix
706            )
707            .is_none()
708        );
709
710        // Higher version
711        assert_eq!(
712            ChannelTupleManager::parse_model_version_from_key(
713                "model_version:other-model-10.200",
714                "other-model"
715            ),
716            Some(AuthorizationModelVersion::new(10, 200))
717        );
718    }
719
720    #[test]
721    fn test_format_model_version_key() {
722        let model_prefix = "test";
723        let model_version = AuthorizationModelVersion::new(1, 0);
724        let key = ChannelTupleManager::format_model_version_key(model_prefix, model_version);
725        assert_eq!(key, "model_version:test-1.0");
726        let parsed = ChannelTupleManager::parse_model_version_from_key(&key, model_prefix).unwrap();
727        assert_eq!(parsed, model_version);
728    }
729
730    #[needs_env_var(TEST_OPENFGA_CLIENT_GRPC_URL)]
731    pub(crate) mod openfga {
732        use std::str::FromStr;
733
734        use pretty_assertions::assert_eq;
735
736        use super::*;
737        use crate::client::{OpenFgaServiceClient, ReadAuthorizationModelRequest};
738
739        pub(crate) async fn get_service_client() -> OpenFgaServiceClient<tonic::transport::Channel>
740        {
741            let endpoint = std::env::var("TEST_OPENFGA_CLIENT_GRPC_URL").unwrap();
742            let endpoint = tonic::transport::Endpoint::from_str(&endpoint).unwrap();
743            OpenFgaServiceClient::connect(endpoint)
744                .await
745                .expect("Client can be created")
746        }
747
748        pub(crate) async fn service_client_with_store()
749        -> (OpenFgaServiceClient<tonic::transport::Channel>, Store) {
750            let mut client = get_service_client().await;
751            let store_name = format!("test-{}", uuid::Uuid::now_v7());
752            let store = client.get_or_create_store(&store_name).await.unwrap();
753            (client, store)
754        }
755
756        #[tokio::test]
757        async fn test_get_existing_versions_nonexistent_store() {
758            let client = get_service_client().await;
759            let mut manager: TupleModelManager<_, ()> =
760                TupleModelManager::new(client, "nonexistent", "test");
761
762            let versions = manager.get_existing_versions().await.unwrap();
763            assert!(versions.is_empty());
764        }
765
766        #[tokio::test]
767        async fn test_get_existing_versions_nonexistent_auth_model() {
768            let mut client = get_service_client().await;
769            let store_name = format!("test-{}", uuid::Uuid::now_v7());
770            let _store = client.get_or_create_store(&store_name).await.unwrap();
771            let mut manager: TupleModelManager<_, ()> =
772                TupleModelManager::new(client, &store_name, "test");
773            let versions = manager.get_existing_versions().await.unwrap();
774            assert!(versions.is_empty());
775        }
776
777        #[tokio::test]
778        async fn test_get_authorization_model_id() {
779            let (mut client, store) = service_client_with_store().await;
780            let model_prefix = "test";
781            let version = AuthorizationModelVersion::new(1, 0);
782
783            let mut manager: TupleModelManager<_, ()> =
784                TupleModelManager::new(client.clone(), &store.name, model_prefix);
785
786            // Non-existent model
787            assert_eq!(
788                manager.get_authorization_model_id(version).await.unwrap(),
789                None
790            );
791
792            // Apply auth model
793            let model: AuthorizationModel =
794                serde_json::from_str(include_str!("../tests/model-manager/v1.0/schema.json"))
795                    .unwrap();
796            client
797                .write_authorization_model(model.into_write_request(store.id.clone()))
798                .await
799                .unwrap();
800
801            // Write model tuples
802            client
803                .write(WriteRequest {
804                    store_id: store.id.clone(),
805                    writes: Some(WriteRequestWrites {
806                        on_duplicate: String::new(),
807                        tuple_keys: vec![
808                            TupleKey {
809                                user: "auth_model_id:111111".to_string(),
810                                relation: "openfga_id".to_string(),
811                                object: "model_version:test-1.0".to_string(),
812                                condition: None,
813                            },
814                            TupleKey {
815                                user: "auth_model_id:*".to_string(),
816                                relation: "exists".to_string(),
817                                object: "model_version:test-1.0".to_string(),
818                                condition: None,
819                            },
820                            // Tuple with different model prefix should be ignored
821                            TupleKey {
822                                user: "auth_model_id:*".to_string(),
823                                relation: "exists".to_string(),
824                                object: "model_version:test2-1.0".to_string(),
825                                condition: None,
826                            },
827                        ],
828                    }),
829                    deletes: None,
830                    authorization_model_id: String::new(),
831                })
832                .await
833                .unwrap();
834
835            assert_eq!(
836                manager.get_authorization_model_id(version).await.unwrap(),
837                Some("111111".to_string())
838            );
839        }
840
841        #[tokio::test]
842        async fn test_model_manager() {
843            let store_name = format!("test-{}", uuid::Uuid::now_v7());
844            let mut client = get_service_client().await;
845
846            let model_1_0: AuthorizationModel =
847                serde_json::from_str(include_str!("../tests/model-manager/v1.0/schema.json"))
848                    .unwrap();
849
850            let version_1_0 = AuthorizationModelVersion::new(1, 0);
851
852            let migration_state = MigrationState::default();
853            let mut manager = TupleModelManager::new(client.clone(), &store_name, "test-model")
854                .add_model(
855                    model_1_0.clone(),
856                    version_1_0,
857                    Some(v1_pre_migration_fn),
858                    None::<MigrationFn<_, _>>,
859                );
860            manager.migrate(migration_state.clone()).await.unwrap();
861            // Check hook was called once
862            assert_eq!(*migration_state.counter_1.lock().unwrap(), 1);
863            manager.migrate(migration_state.clone()).await.unwrap();
864            // Check hook was not called again
865            assert_eq!(*migration_state.counter_1.lock().unwrap(), 1);
866
867            // Check written model
868            let auth_model_id = manager
869                .get_authorization_model_id(version_1_0)
870                .await
871                .unwrap()
872                .unwrap();
873            let mut auth_model =
874                get_auth_model_by_id(&mut client, &store_name, &auth_model_id).await;
875            auth_model.id = model_1_0.id.clone();
876            assert_eq!(
877                serde_json::to_value(&model_1_0).unwrap(),
878                serde_json::to_value(auth_model).unwrap()
879            );
880
881            // Add a second model
882            let model_1_1: AuthorizationModel =
883                serde_json::from_str(include_str!("../tests/model-manager/v1.1/schema.json"))
884                    .unwrap();
885            let version_1_1 = AuthorizationModelVersion::new(1, 1);
886            let mut manager = manager.add_model(
887                model_1_1.clone(),
888                version_1_1,
889                None::<MigrationFn<_, _>>,
890                Some(v2_post_migration_fn),
891            );
892            manager.migrate(migration_state.clone()).await.unwrap();
893            manager.migrate(migration_state.clone()).await.unwrap();
894            manager.migrate(migration_state.clone()).await.unwrap();
895
896            // First migration still only called once
897            assert_eq!(*migration_state.counter_1.lock().unwrap(), 1);
898            // Second migration called once
899            assert_eq!(*migration_state.counter_2.lock().unwrap(), 1);
900
901            // Check written model
902            let auth_model_id = manager
903                .get_authorization_model_id(version_1_1)
904                .await
905                .unwrap()
906                .unwrap();
907            let mut auth_model =
908                get_auth_model_by_id(&mut client, &store_name, &auth_model_id).await;
909            auth_model.id = model_1_1.id.clone();
910            assert_eq!(
911                serde_json::to_value(&model_1_1).unwrap(),
912                serde_json::to_value(auth_model).unwrap()
913            );
914        }
915
916        async fn get_auth_model_by_id(
917            client: &mut OpenFgaServiceClient<tonic::transport::Channel>,
918            store_name: &str,
919            auth_model_id: &str,
920        ) -> AuthorizationModel {
921            client
922                .read_authorization_model(ReadAuthorizationModelRequest {
923                    store_id: client
924                        .clone()
925                        .get_store_by_name(store_name)
926                        .await
927                        .unwrap()
928                        .unwrap()
929                        .id,
930                    id: auth_model_id.to_string(),
931                })
932                .await
933                .unwrap()
934                .into_inner()
935                .authorization_model
936                .unwrap()
937        }
938    }
939
940    #[derive(Default, Clone)]
941    struct MigrationState {
942        counter_1: Arc<Mutex<i32>>,
943        counter_2: Arc<Mutex<i32>>,
944    }
945
946    #[allow(clippy::unused_async)]
947    async fn v1_pre_migration_fn(
948        client: OpenFgaServiceClient<tonic::transport::Channel>,
949        _prev_model: Option<String>,
950        _curr_model: Option<String>,
951        state: MigrationState,
952    ) -> std::result::Result<(), StdError> {
953        let _ = client;
954        // Throw an error for the second call
955        let mut counter = state.counter_1.lock().unwrap();
956        *counter += 1;
957        if *counter == 2 {
958            return Err(Box::new(Error::RequestFailed(Box::new(
959                tonic::Status::new(tonic::Code::Internal, "Test"),
960            ))));
961        }
962        Ok(())
963    }
964
965    #[allow(clippy::unused_async)]
966    async fn v2_post_migration_fn(
967        client: OpenFgaServiceClient<tonic::transport::Channel>,
968        _prev_model: Option<String>,
969        _curr_model: Option<String>,
970        state: MigrationState,
971    ) -> std::result::Result<(), StdError> {
972        let _ = client;
973        // Throw an error for the second call
974        let mut counter = state.counter_2.lock().unwrap();
975        *counter += 1;
976        if *counter == 2 {
977            return Err(Box::new(Error::RequestFailed(Box::new(
978                tonic::Status::new(tonic::Code::Internal, "Test"),
979            ))));
980        }
981        Ok(())
982    }
983}