tideorm 0.9.14

A developer-friendly ORM for Rust with clean, expressive syntax
Documentation
#![allow(missing_docs)]

use std::any::Any;
#[cfg(not(feature = "runtime-tokio"))]
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
#[cfg(not(feature = "runtime-tokio"))]
use std::{
    pin::Pin,
    task::{Context, Poll},
};

use parking_lot::Mutex;

use crate::error::Result;
use crate::model::Model;

use super::{
    EntityManager, TideEntityManagerMergePersisted, TideEntityManagerMeta, TideEntityManagerSync,
};

/// Per-transaction undo log for the identity map.
///
/// Holds at most one entry per `IdentityKey`. The entry is recorded the first time
/// that key is touched inside an entity-manager transaction scope (see
/// `record_identity_map_rollback`) and replayed by `rollback_identity_map`.
pub(super) type IdentityRollbackLog = HashMap<super::IdentityKey, Box<dyn IdentityMapRollback>>;
/// Checkpoints taken from every managed entry before a transaction starts.
///
/// Each checkpoint is rolled back by `rollback_entity_manager_state` if the transaction fails.
pub(super) type ManagedCheckpoints = Vec<Box<dyn super::managed::ManagedCheckpoint>>;

/// Copies of the entity manager's managed bookkeeping taken before a transaction.
///
/// Built by `capture_entity_manager_rollback_state` and written back wholesale by
/// `rollback_entity_manager_state` when the transaction fails. The collections are
/// cloned, but the `Arc` values inside them are shared with the live manager.
pub(super) struct EntityManagerRollbackState {
    // Clone of `EntityManager::managed_entries` at capture time.
    managed_entries: Vec<Arc<dyn super::managed::ManagedOps>>,
    // Clone of `EntityManager::managed_identity_map` at capture time.
    managed_identity_map: HashMap<super::IdentityKey, Arc<dyn Any + Send + Sync>>,
    // Clone of `EntityManager::snapshots` at capture time.
    snapshots: HashMap<super::SnapshotKey, HashSet<String>>,
}

/// Type-erased undo action for a single identity-map key.
///
/// Entries are stored in an `IdentityRollbackLog` and consumed by `rollback_identity_map`.
pub(super) trait IdentityMapRollback: Send {
    /// Consumes the entry and applies it to `entity_manager`'s identity map.
    ///
    /// The implementation in this file reinstates the originally recorded value, or
    /// removes the key if none was recorded.
    fn restore(self: Box<Self>, entity_manager: &EntityManager);
}

/// Rollback entry for one identity-map key whose cached value has concrete type `T`.
struct IdentityMapRollbackEntry<T> {
    /// Key in `EntityManager::identity_map` that this entry restores.
    key: super::IdentityKey,
    /// Value cached under `key` when the entry was recorded.
    ///
    /// `None` if the key was absent, or held a value that did not downcast to `T`, at
    /// recording time.
    original: Option<T>,
}

impl<T> IdentityMapRollback for IdentityMapRollbackEntry<T>
where
    T: Send + Sync + 'static,
{
    fn restore(self: Box<Self>, entity_manager: &EntityManager) {
        let mut map = entity_manager.identity_map.write();
        match self.original {
            Some(entity) => {
                map.insert(self.key, Box::new(entity));
            }
            None => {
                map.remove(&self.key);
            }
        }
    }
}

// `true` while a future wrapped by `with_entity_manager_transaction_scope` is being
// driven. Read by `in_entity_manager_transaction_scope`; outside any scope,
// `try_with` fails and the reader treats that as `false`.
#[cfg(feature = "runtime-tokio")]
tokio::task_local! {
    static ENTITY_MANAGER_TRANSACTION_SCOPE: bool;
}

// Identity-map undo log belonging to the transaction scope the current task is
// running in. Read by `current_identity_rollback_log`.
#[cfg(feature = "runtime-tokio")]
tokio::task_local! {
    static ENTITY_MANAGER_IDENTITY_ROLLBACK: Arc<Mutex<IdentityRollbackLog>>;
}

// Thread-local equivalents used without tokio. The non-tokio
// `with_entity_manager_transaction_scope` installs these values around each poll of
// the wrapped future and then puts the previous values back. The defaults mean
// "not in a transaction scope" and "no rollback log".
#[cfg(not(feature = "runtime-tokio"))]
thread_local! {
    static ENTITY_MANAGER_TRANSACTION_SCOPE: RefCell<bool> = const { RefCell::new(false) };
    static ENTITY_MANAGER_IDENTITY_ROLLBACK: RefCell<Option<Arc<Mutex<IdentityRollbackLog>>>> = const { RefCell::new(None) };
}

pub(super) fn new_identity_rollback_log() -> Arc<Mutex<IdentityRollbackLog>> {
    Arc::new(Mutex::new(HashMap::new()))
}

/// Returns the rollback log installed by the enclosing transaction scope.
///
/// Returns `None` when the current task is not inside one.
#[cfg(feature = "runtime-tokio")]
fn current_identity_rollback_log() -> Option<Arc<Mutex<IdentityRollbackLog>>> {
    ENTITY_MANAGER_IDENTITY_ROLLBACK.try_with(Arc::clone).ok()
}

/// Returns the rollback log installed on this thread by the enclosing transaction scope.
///
/// Returns `None` when no scope is active.
#[cfg(not(feature = "runtime-tokio"))]
fn current_identity_rollback_log() -> Option<Arc<Mutex<IdentityRollbackLog>>> {
    ENTITY_MANAGER_IDENTITY_ROLLBACK.with(|slot| slot.borrow().as_ref().cloned())
}

/// Reports whether the current task is running inside an entity-manager transaction scope.
///
/// A task that never entered a scope has no task-local value and reports `false`.
#[cfg(feature = "runtime-tokio")]
pub(super) fn in_entity_manager_transaction_scope() -> bool {
    matches!(
        ENTITY_MANAGER_TRANSACTION_SCOPE.try_with(|&active| active),
        Ok(true)
    )
}

/// Reports whether the current thread is polling a future inside an entity-manager
/// transaction scope.
#[cfg(not(feature = "runtime-tokio"))]
pub(super) fn in_entity_manager_transaction_scope() -> bool {
    ENTITY_MANAGER_TRANSACTION_SCOPE.with(|flag| *flag.borrow())
}

/// Runs `future` inside an entity-manager transaction scope (tokio runtime).
///
/// While the inner future is polled, two task-local values are set:
/// - `in_entity_manager_transaction_scope()` reports `true`;
/// - `current_identity_rollback_log()` yields `rollback_log`.
#[cfg(feature = "runtime-tokio")]
pub(super) async fn with_entity_manager_transaction_scope<F>(
    rollback_log: Arc<Mutex<IdentityRollbackLog>>,
    future: F,
) -> F::Output
where
    F: std::future::Future,
{
    let with_rollback_log = ENTITY_MANAGER_IDENTITY_ROLLBACK.scope(rollback_log, future);
    let in_transaction = ENTITY_MANAGER_TRANSACTION_SCOPE.scope(true, with_rollback_log);
    in_transaction.await
}

/// Runs `future` inside an entity-manager transaction scope (no tokio runtime).
///
/// Thread-locals stand in for task-locals. On every `poll`:
/// - the transaction flag is set to `true` and `rollback_log` is installed;
/// - the values that were active before the poll are put back afterwards, so nested
///   scopes restore their outer scope.
///
/// Restoration is done by a drop guard, so it also happens if the inner future
/// panics while being polled. Without it, the thread would keep reporting an active
/// transaction scope, with a stale rollback log, after the panic. This matches the
/// guard-based restore used by tokio task-local scopes in the tokio variant.
#[cfg(not(feature = "runtime-tokio"))]
pub(super) fn with_entity_manager_transaction_scope<F>(
    rollback_log: Arc<Mutex<IdentityRollbackLog>>,
    future: F,
) -> impl std::future::Future<Output = F::Output>
where
    F: std::future::Future,
{
    /// Future adapter that installs the transaction-scope thread-locals for the
    /// duration of each poll of the wrapped future.
    struct ScopedEntityManagerTransactionFuture<F> {
        future: Pin<Box<F>>,
        rollback_log: Arc<Mutex<IdentityRollbackLog>>,
    }

    /// Restores the thread-local values that were active before a poll when dropped.
    struct RestorePreviousScope {
        previous_scope: bool,
        previous_log: Option<Arc<Mutex<IdentityRollbackLog>>>,
    }

    impl Drop for RestorePreviousScope {
        fn drop(&mut self) {
            let previous_log = self.previous_log.take();
            // `try_with` instead of `with`: this may run while unwinding, and a
            // second panic inside `drop` would abort the process.
            let _ = ENTITY_MANAGER_IDENTITY_ROLLBACK.try_with(|log| {
                *log.borrow_mut() = previous_log;
            });
            let previous_scope = self.previous_scope;
            let _ = ENTITY_MANAGER_TRANSACTION_SCOPE.try_with(|active| {
                *active.borrow_mut() = previous_scope;
            });
        }
    }

    impl<F> std::future::Future for ScopedEntityManagerTransactionFuture<F>
    where
        F: std::future::Future,
    {
        type Output = F::Output;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // Sound: every field is `Unpin` (`Pin<Box<_>>` and `Arc<_>`).
            let this = self.get_mut();
            let previous_scope =
                ENTITY_MANAGER_TRANSACTION_SCOPE.with(|active| active.replace(true));
            let previous_log = ENTITY_MANAGER_IDENTITY_ROLLBACK
                .with(|log| log.replace(Some(this.rollback_log.clone())));
            // The guard puts both values back when this function returns or unwinds.
            let _restore = RestorePreviousScope {
                previous_scope,
                previous_log,
            };
            this.future.as_mut().poll(cx)
        }
    }

    ScopedEntityManagerTransactionFuture {
        future: Box::pin(future),
        rollback_log,
    }
}

/// Records what the identity map held under `key` before the current transaction
/// first touches it, so `rollback_identity_map` can undo the change.
///
/// This is a no-op in any of these cases:
/// - the caller is outside an entity-manager transaction scope;
/// - no rollback log is installed;
/// - `key` already has an entry in the log, because the first recorded value wins.
///
/// A cached value that does not downcast to `T` is recorded as `None`, so a later
/// rollback removes the key.
pub(super) fn record_identity_map_rollback<T>(
    entity_manager: &EntityManager,
    key: &super::IdentityKey,
) where
    T: Clone + Send + Sync + 'static,
{
    if !in_entity_manager_transaction_scope() {
        return;
    }
    let Some(log_handle) = current_identity_rollback_log() else {
        return;
    };

    // The log stays locked until the entry is inserted, so concurrent recorders
    // cannot both capture an "original" for the same key.
    let mut pending = log_handle.lock();
    if pending.contains_key(key) {
        return;
    }

    let snapshot: Option<T> = entity_manager
        .identity_map
        .read()
        .get(key)
        .and_then(|value| value.downcast_ref::<T>())
        .cloned();
    let entry = IdentityMapRollbackEntry {
        key: key.clone(),
        original: snapshot,
    };
    pending.insert(key.clone(), Box::new(entry));
}

/// Drains `rollback_log` and applies every recorded entry to `entity_manager`'s
/// identity map.
///
/// The log is emptied under its own lock, and that lock is released before any
/// entry touches the identity map.
pub(super) fn rollback_identity_map(
    entity_manager: &EntityManager,
    rollback_log: &Arc<Mutex<IdentityRollbackLog>>,
) {
    let drained = std::mem::take(&mut *rollback_log.lock());
    drained
        .into_values()
        .for_each(|entry| entry.restore(entity_manager));
}

/// Clones the entity manager's managed entries, managed identity map and snapshots,
/// so they can be restored if a transaction fails.
///
/// The three read locks are taken in field order and held together until the copy
/// is built.
pub(super) fn capture_entity_manager_rollback_state(
    entity_manager: &EntityManager,
) -> EntityManagerRollbackState {
    let entries = entity_manager.managed_entries.read();
    let identities = entity_manager.managed_identity_map.read();
    let snapshot_sets = entity_manager.snapshots.read();
    EntityManagerRollbackState {
        managed_entries: (*entries).clone(),
        managed_identity_map: (*identities).clone(),
        snapshots: (*snapshot_sets).clone(),
    }
}

pub(super) fn capture_managed_checkpoints(entity_manager: &EntityManager) -> ManagedCheckpoints {
    entity_manager
        .managed_entries
        .read()
        .iter()
        .cloned()
        .map(|entry| entry.checkpoint())
        .collect()
}

pub(super) fn rollback_entity_manager_state(
    entity_manager: &EntityManager,
    checkpoints: ManagedCheckpoints,
    rollback_state: EntityManagerRollbackState,
    identity_rollback: &Arc<Mutex<IdentityRollbackLog>>,
) {
    for checkpoint in checkpoints {
        checkpoint.rollback(entity_manager);
    }

    *entity_manager.managed_entries.write() = rollback_state.managed_entries;
    *entity_manager.managed_identity_map.write() = rollback_state.managed_identity_map;
    rollback_identity_map(entity_manager, identity_rollback);
    *entity_manager.snapshots.write() = rollback_state.snapshots;
}

/// Saves `entity` and reconciles it with `entity_manager` as one unit.
///
/// If the caller is already inside an entity-manager transaction scope, the save
/// joins that scope directly. Otherwise:
/// - a transaction is opened on `entity_manager.db`;
/// - the save runs inside a fresh transaction scope;
/// - if the transaction returns an error, the manager's in-memory state is rolled
///   back (managed entries, managed identity map, identity-map entries, snapshots
///   and managed checkpoints).
///
/// # Errors
///
/// Returns any error produced by `db.transaction`. That includes errors from the
/// save itself and from relation syncing.
pub async fn save_with_entity_manager<T>(
    entity: &T,
    entity_manager: &Arc<EntityManager>,
) -> Result<T>
where
    T: TideEntityManagerMeta
        + TideEntityManagerMergePersisted
        + TideEntityManagerSync
        + Model
        + Clone
        + Send
        + Sync
        + 'static,
{
    // The enclosing scope owns the transaction and its rollback bookkeeping.
    if in_entity_manager_transaction_scope() {
        return save_with_entity_manager_impl(entity, entity_manager).await;
    }

    // Capture everything needed to undo in-memory changes before touching anything.
    let rollback_state = capture_entity_manager_rollback_state(entity_manager.as_ref());
    let checkpoints = capture_managed_checkpoints(entity_manager.as_ref());
    let identity_rollback = new_identity_rollback_log();

    let db = entity_manager.db.clone();
    let scoped_manager = Arc::clone(entity_manager);
    let scoped_log = Arc::clone(&identity_rollback);
    let owned_entity = entity.clone();
    let outcome = db
        .transaction(move |_| {
            let manager = scoped_manager.clone();
            let log = scoped_log.clone();
            Box::pin(async move {
                with_entity_manager_transaction_scope(
                    log,
                    save_with_entity_manager_impl(&owned_entity, &manager),
                )
                .await
            })
        })
        .await;

    if outcome.is_err() {
        rollback_entity_manager_state(
            entity_manager.as_ref(),
            checkpoints,
            rollback_state,
            &identity_rollback,
        );
    }
    outcome
}

/// Persists `entity` and folds the stored row back into an in-memory aggregate.
///
/// Steps:
/// 1. Run `Model::save` via `__with_entity_manager_db`.
/// 2. Merge the persisted value into a clone of `entity`.
/// 3. Refresh runtime relations from the pre-merge clone.
/// 4. Sync relations through the entity manager.
/// 5. Cache the result with `EntityManager::put`.
///
/// # Errors
///
/// Propagates errors from the save or from relation syncing. On error the
/// aggregate is not cached.
pub(crate) async fn save_with_entity_manager_impl<T>(
    entity: &T,
    entity_manager: &Arc<EntityManager>,
) -> Result<T>
where
    T: TideEntityManagerMeta
        + TideEntityManagerMergePersisted
        + TideEntityManagerSync
        + Model
        + Clone
        + Send
        + Sync
        + 'static,
{
    let save_future = <T as crate::model::Model>::save(entity.clone());
    let persisted = __with_entity_manager_db(entity_manager, save_future).await?;

    // `before_merge` keeps the caller's relation state for the refresh below.
    let mut aggregate = entity.clone();
    let before_merge = aggregate.clone();
    aggregate.tide_merge_persisted(persisted);
    <T as crate::internal::InternalModel>::refresh_runtime_relations_from(
        &mut aggregate,
        &before_merge,
    );

    aggregate
        .tide_sync_entity_manager_relations(entity_manager)
        .await?;
    entity_manager.put(aggregate.clone());
    Ok(aggregate)
}

/// Syncs `entity`'s relations through the entity manager without saving the
/// entity itself.
///
/// Runtime relations are refreshed from an unmodified clone first. The result is
/// then synced and cached via `EntityManager::put`.
///
/// # Errors
///
/// Propagates errors from relation syncing. On error nothing is cached.
pub(crate) async fn sync_entity_manager_relations_only_impl<T>(
    entity: &T,
    entity_manager: &Arc<EntityManager>,
) -> Result<T>
where
    T: TideEntityManagerMeta
        + TideEntityManagerMergePersisted
        + TideEntityManagerSync
        + Model
        + Clone
        + Send
        + Sync
        + 'static,
{
    let mut aggregate = entity.clone();
    // NOTE(review): the refresh source is identical to the target here; presumably
    // `refresh_runtime_relations_from` rebuilds runtime handles. Confirm intent.
    let baseline = aggregate.clone();
    <T as crate::internal::InternalModel>::refresh_runtime_relations_from(
        &mut aggregate,
        &baseline,
    );

    aggregate
        .tide_sync_entity_manager_relations(entity_manager)
        .await?;
    entity_manager.put(aggregate.clone());
    Ok(aggregate)
}

/// Saves `entity` through `entity_manager` without opening a new transaction.
///
/// Intended for callers that already run inside an entity-manager transaction
/// scope; it delegates straight to `save_with_entity_manager_impl`.
#[doc(hidden)]
pub async fn __save_with_entity_manager_in_scope<T>(
    entity: &T,
    entity_manager: &Arc<EntityManager>,
) -> Result<T>
where
    T: TideEntityManagerMeta
        + TideEntityManagerMergePersisted
        + TideEntityManagerSync
        + Model
        + Clone
        + Send
        + Sync
        + 'static,
{
    let saved = save_with_entity_manager_impl(entity, entity_manager).await?;
    Ok(saved)
}

/// Awaits `future` with the database of `entity_manager`.
///
/// Inside an entity-manager transaction scope, the future is awaited as-is.
/// Otherwise it is wrapped in `crate::database::__in_db_scope` using
/// `entity_manager.db`.
#[doc(hidden)]
pub async fn __with_entity_manager_db<F, T>(
    entity_manager: &Arc<EntityManager>,
    future: F,
) -> Result<T>
where
    F: std::future::Future<Output = Result<T>>,
{
    if in_entity_manager_transaction_scope() {
        future.await
    } else {
        crate::database::__in_db_scope(entity_manager.db.as_ref(), future).await
    }
}