use crate::db::connection::DatabaseConnectionResource;
use crate::registration::COMPONENT_REGISTRY;
use crate::versioning::version_manager::VersionKey;
use crate::{
DatabaseConnection, Guid, Persist, PersistenceError, PersistenceSession, TransactionOperation,
};
use bevy::app::PluginGroupBuilder;
use bevy::prelude::TaskPoolPlugin;
use bevy::prelude::*;
use once_cell::sync::Lazy;
use std::any::TypeId;
use std::collections::{HashMap, HashSet};
use std::sync::{
Arc, Mutex,
atomic::{AtomicUsize, Ordering},
};
use tokio::runtime::Runtime;
use tokio::sync::oneshot;
use crate::query::PersistenceQueryCache;
use crate::query::deferred_ops::DeferredWorldOperations;
use crate::query::immediate_world_ptr::ImmediateWorldPtr;
/// Installs Bevy's `TaskPoolPlugin` on the app if it has not been added yet.
/// Idempotent: safe to call from multiple plugins during startup.
fn ensure_task_pools(app: &mut App) {
    let already_installed = app.is_plugin_added::<TaskPoolPlugin>();
    if !already_installed {
        app.add_plugins(TaskPoolPlugin::default());
    }
}
/// Process-wide tokio runtime shared by all persistence plugins.
///
/// Built lazily on first access; commit transactions are spawned onto this
/// runtime so database I/O never blocks the Bevy schedule.
static TOKIO_RUNTIME: Lazy<Arc<Runtime>> = Lazy::new(|| {
    Arc::new(
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            // A bare `unwrap()` here gave no context on panic; runtime
            // construction can fail (e.g. thread spawn limits), so say why.
            .expect("failed to build shared tokio runtime for persistence"),
    )
});
/// Component attached to a task-tracking entity for one in-flight commit
/// (or one batch of a batched commit).
///
/// The receiver yields `(new_document_keys, new_entities)` on success; it is
/// `take()`n while polling and put back if the result is not ready yet.
#[derive(Component)]
struct CommitTask {
    receiver: Option<
        tokio::sync::oneshot::Receiver<Result<(Vec<String>, Vec<Entity>), PersistenceError>>,
    >,
}
/// Tracks completion of a commit that was split into multiple batches.
///
/// One tracker entity is spawned per correlated multi-batch commit; each
/// finished batch decrements `remaining_batches`, and the listener in
/// `result_sender` is notified once on the final batch (or on first error).
#[derive(Component)]
struct MultiBatchCommitTracker {
    // Correlation id shared with the originating `TriggerCommit`.
    correlation_id: u64,
    // Count of batches still in flight; decremented by `handle_commit_completed`.
    remaining_batches: Arc<AtomicUsize>,
    // One-shot listener; `Option` so it can be taken exactly once.
    result_sender: Arc<Mutex<Option<oneshot::Sender<Result<(), PersistenceError>>>>>,
}
/// Ordered system sets for the persistence pipeline in `PostUpdate`
/// (chained in this order by `PersistencePluginCore::build`).
#[derive(SystemSet, Debug, Clone, PartialEq, Eq, Hash)]
pub enum PersistenceSystemSet {
    /// Dirty/despawn tracking and deferred world operations.
    ChangeDetection,
    /// Listener dispatch and commit triggering.
    PreCommit,
    /// Polling of in-flight commit tasks.
    Commit,
}
/// Message emitted when a commit (or its final batch) finishes.
#[derive(Message)]
pub struct CommitCompleted {
    /// Created document keys on success, or the commit error.
    pub result: Result<Vec<String>, PersistenceError>,
    /// NOTE(review): currently always emitted empty by the systems in this
    /// file — confirm whether callers rely on it being populated.
    pub dirty_entities: Vec<Entity>,
    /// Correlation id of the originating `TriggerCommit`, if any.
    pub correlation_id: Option<u64>,
}
/// Resource recording which `Persist` component/resource types have been
/// registered with the app (keyed by `TypeId`).
#[derive(Resource, Default)]
pub struct RegisteredPersistTypes {
    pub types: HashSet<TypeId>,
}
/// Message requesting a commit of the session's dirty state.
#[derive(Message, Clone)]
pub struct TriggerCommit {
    /// Optional id used to correlate the eventual `CommitCompleted` message
    /// (and any registered one-shot listener) with this request.
    pub correlation_id: Option<u64>,
    /// Database connection the commit should be executed against.
    pub target_connection: Arc<dyn DatabaseConnection>,
    /// Target store name; must be non-empty (validated in `handle_commit_trigger`).
    pub store: String,
}
/// Commit state machine resource.
#[derive(Resource, Default, PartialEq, Debug)]
pub enum CommitStatus {
    /// No commit in flight; a trigger may start one.
    #[default]
    Idle,
    /// A commit is running.
    InProgress,
    /// A commit is running AND another trigger arrived; a follow-up commit is
    /// re-triggered automatically when the current one finishes.
    InProgressAndDirty,
}
/// How the plugin obtains its database connection. Currently only a
/// statically supplied connection is supported.
#[derive(Clone)]
enum PersistenceBackend {
    Static(Arc<dyn DatabaseConnection>),
}
/// Resource handle to the shared tokio runtime used for async commit tasks.
#[derive(Resource)]
pub struct TokioRuntime {
    pub runtime: Arc<Runtime>,
}
impl TokioRuntime {
    /// Blocks the current thread until `fut` completes on the runtime.
    pub fn block_on<F: std::future::Future>(&self, fut: F) -> F::Output {
        self.runtime.block_on(fut)
    }
}
/// Exclusive system that services `TriggerCommit` messages.
///
/// Flow: consume the first pending trigger (respecting `CommitStatus`),
/// validate connection/store, snapshot the session's dirty state, prepare the
/// transaction via `PersistenceSession::_prepare_commit`, then spawn one or
/// more async commit tasks on the shared tokio runtime. When batching is
/// enabled and the operation count exceeds `commit_batch_size`, operations are
/// grouped per entity and split into batches, each tracked by its own
/// `CommitTask` entity; a `MultiBatchCommitTracker` aggregates results when a
/// commit listener was registered for the correlation id.
fn handle_commit_trigger(world: &mut World) {
    let mut should_commit = false;
    let mut correlation_id = None;
    let mut requested_connection: Option<Arc<dyn DatabaseConnection>> = None;
    let mut requested_store: Option<String> = None;
    // Drain triggers while holding exclusive access to CommitStatus.
    // NOTE(review): `drain().next()` discards any additional triggers queued
    // this frame (they are coalesced into at most one status transition) —
    // confirm that dropping their connection/store/correlation_id is intended.
    world.resource_scope(|world, mut events: Mut<Messages<TriggerCommit>>| {
        let mut status = world.resource_mut::<CommitStatus>();
        if !events.is_empty() {
            let first_trigger = events.drain().next().unwrap();
            requested_connection = Some(first_trigger.target_connection.clone());
            requested_store = Some(first_trigger.store.clone());
            match *status {
                CommitStatus::Idle => {
                    info!("[handle_commit_trigger] TriggerCommit event received. Status is Idle.");
                    should_commit = true;
                    correlation_id = first_trigger.correlation_id;
                }
                CommitStatus::InProgress => {
                    info!("[handle_commit_trigger] TriggerCommit event received while another is in progress. Queuing.");
                    *status = CommitStatus::InProgressAndDirty;
                }
                CommitStatus::InProgressAndDirty => {
                    // Already queued; nothing further to record.
                }
            }
        }
    });
    if !should_commit {
        return;
    }
    // Validate the requested connection; failures are reported via
    // CommitCompleted so correlated listeners are unblocked.
    let connection = if let Some(conn) = requested_connection {
        conn
    } else {
        let err = PersistenceError::new("TriggerCommit missing target_connection");
        world.write_message(CommitCompleted {
            result: Err(err.clone()),
            dirty_entities: vec![],
            correlation_id,
        });
        bevy::log::error!(%err, "failed to select database connection before commit");
        return;
    };
    // Validate the requested store (must be present and non-empty).
    let store = if let Some(store) = requested_store {
        if store.is_empty() {
            let err = PersistenceError::new("TriggerCommit store must be non-empty");
            world.write_message(CommitCompleted {
                result: Err(err.clone()),
                dirty_entities: vec![],
                correlation_id,
            });
            bevy::log::error!(%err, "invalid store for commit");
            return;
        }
        store
    } else {
        let err = PersistenceError::new("TriggerCommit missing store");
        world.write_message(CommitCompleted {
            result: Err(err.clone()),
            dirty_entities: vec![],
            correlation_id,
        });
        bevy::log::error!(%err, "failed to select store before commit");
        return;
    };
    let plugin_config = world.resource::<PersistencePluginConfig>().clone();
    // Take the dirty sets out of the session; they are restored below if
    // preparation fails or produces no operations.
    let (dirty_entities, despawned_entities, dirty_resources) = {
        let mut session = world.resource_mut::<PersistenceSession>();
        (
            std::mem::take(&mut session.dirty_entities),
            std::mem::take(&mut session.despawned_entities),
            std::mem::take(&mut session.dirty_resources),
        )
    };
    let commit_data = match PersistenceSession::_prepare_commit(
        world.resource::<PersistenceSession>(),
        world,
        &dirty_entities,
        &despawned_entities,
        &dirty_resources,
        plugin_config.thread_count,
        connection.document_key_field(),
        &store,
    ) {
        // Nothing to write: report success with no keys and put the dirty
        // state back so nothing is lost.
        Ok(data) if data.operations.is_empty() => {
            world.write_message(CommitCompleted {
                result: Ok(vec![]),
                dirty_entities: vec![],
                correlation_id,
            });
            let mut session = world.resource_mut::<PersistenceSession>();
            session.dirty_entities.extend(dirty_entities);
            session.despawned_entities.extend(despawned_entities);
            session.dirty_resources.extend(dirty_resources);
            return;
        }
        Ok(data) => data,
        // Preparation failed: report the error and restore the dirty state.
        Err(e) => {
            world.write_message(CommitCompleted {
                result: Err(e.clone()),
                dirty_entities: vec![],
                correlation_id,
            });
            let mut session = world.resource_mut::<PersistenceSession>();
            session.dirty_entities.extend(dirty_entities);
            session.despawned_entities.extend(despawned_entities);
            session.dirty_resources.extend(dirty_resources);
            return;
        }
    };
    *world.resource_mut::<CommitStatus>() = CommitStatus::InProgress;
    let runtime = world.resource::<TokioRuntime>().runtime.clone();
    let db = connection.clone();
    let all_operations = commit_data.operations;
    let new_entities = commit_data.new_entities;
    if plugin_config.batching_enabled && all_operations.len() > plugin_config.commit_batch_size {
        let batch_size = plugin_config.commit_batch_size;
        let session = world.resource::<PersistenceSession>();
        // Group operations so that all ops for one entity land in one batch.
        let mut entity_ops: HashMap<Entity, Vec<TransactionOperation>> = HashMap::new();
        let mut new_entity_ops: Vec<(TransactionOperation, Entity)> = Vec::new();
        let mut resource_ops: Vec<TransactionOperation> = Vec::new();
        // NOTE(review): pairing CreateDocument ops with `new_entities` by
        // running index assumes _prepare_commit emits them in the same order —
        // confirm that invariant.
        let mut new_entity_idx = 0;
        for op in all_operations {
            match &op {
                TransactionOperation::UpdateDocument {
                    kind: crate::db::connection::DocumentKind::Entity,
                    key,
                    ..
                } => {
                    // Reverse-lookup the entity by key (linear scan over
                    // entity_keys).
                    if let Some(entity) = session
                        .entity_keys
                        .iter()
                        .find(|(_, k)| *k == key)
                        .map(|(e, _)| *e)
                    {
                        entity_ops.entry(entity).or_default().push(op);
                    }
                }
                TransactionOperation::DeleteDocument {
                    kind: crate::db::connection::DocumentKind::Entity,
                    key,
                    ..
                } => {
                    if let Some(entity) = session
                        .entity_keys
                        .iter()
                        .find(|(_, k)| *k == key)
                        .map(|(e, _)| *e)
                    {
                        entity_ops.entry(entity).or_default().push(op);
                    }
                }
                TransactionOperation::CreateDocument {
                    kind: crate::db::connection::DocumentKind::Entity,
                    ..
                } => {
                    if let Some(entity) = new_entities.get(new_entity_idx) {
                        new_entity_ops.push((op, *entity));
                        new_entity_idx += 1;
                    }
                }
                _ => resource_ops.push(op),
            }
        }
        // Pack the grouped ops into batches of roughly `batch_size`,
        // recording which entities belong to each batch.
        let mut batches: Vec<Vec<TransactionOperation>> = Vec::new();
        let mut batch_entities: Vec<HashSet<Entity>> = Vec::new();
        let mut batch_new_entities: Vec<Vec<Entity>> = Vec::new();
        let mut current_batch = Vec::new();
        let mut current_batch_entities = HashSet::new();
        let mut current_batch_new_entities = Vec::new();
        for (entity, ops) in entity_ops {
            if current_batch.len() + ops.len() > batch_size && !current_batch.is_empty() {
                batches.push(std::mem::take(&mut current_batch));
                batch_entities.push(std::mem::take(&mut current_batch_entities));
                batch_new_entities.push(std::mem::take(&mut current_batch_new_entities));
            }
            current_batch.extend(ops);
            current_batch_entities.insert(entity);
        }
        for (op, entity) in new_entity_ops {
            if current_batch.len() + 1 > batch_size && !current_batch.is_empty() {
                batches.push(std::mem::take(&mut current_batch));
                batch_entities.push(std::mem::take(&mut current_batch_entities));
                batch_new_entities.push(std::mem::take(&mut current_batch_new_entities));
            }
            current_batch.push(op);
            current_batch_new_entities.push(entity);
        }
        // Resource ops are appended to the last batch (flushing first if they
        // would overflow it).
        if current_batch.len() + resource_ops.len() > batch_size && !current_batch.is_empty() {
            batches.push(std::mem::take(&mut current_batch));
            batch_entities.push(std::mem::take(&mut current_batch_entities));
            batch_new_entities.push(std::mem::take(&mut current_batch_new_entities));
        }
        current_batch.extend(resource_ops);
        if !current_batch.is_empty() {
            batches.push(current_batch);
            batch_entities.push(current_batch_entities);
            batch_new_entities.push(current_batch_new_entities);
        }
        let num_batches = batches.len();
        info!(
            "[handle_commit_trigger] Splitting commit into {} batches of size ~{}.",
            num_batches, batch_size
        );
        // If a one-shot listener was registered for this correlation id, move
        // it into a tracker so only the final batch (or first error) resolves it.
        if let Some(cid) = correlation_id {
            if let Some(listener) = take_commit_listener(world, cid) {
                bevy::log::debug!(
                    "registered multi-batch tracker for correlation_id={cid} batches={}",
                    num_batches
                );
                world.spawn(MultiBatchCommitTracker {
                    correlation_id: cid,
                    remaining_batches: Arc::new(AtomicUsize::new(num_batches)),
                    result_sender: Arc::new(Mutex::new(Some(listener))),
                });
            }
        }
        // Distribute dirty resource types round-robin across batches so each
        // batch bumps only its share of resource versions on success.
        let mut resource_sets = Vec::with_capacity(num_batches);
        for _ in 0..num_batches {
            resource_sets.push(HashSet::new());
        }
        for (i, res_type) in dirty_resources.iter().enumerate() {
            let batch_idx = i % num_batches;
            resource_sets[batch_idx].insert(*res_type);
        }
        for (i, (batch_ops, batch_entities_set)) in batches
            .into_iter()
            .zip(batch_entities.into_iter())
            .enumerate()
        {
            let batch_db = db.clone();
            let batch_runtime = runtime.clone();
            let batch_new_entities = batch_new_entities.get(i).cloned().unwrap_or_default();
            let db_for_task = batch_db.clone();
            let (tx, rx) = tokio::sync::oneshot::channel();
            // Fire the batch transaction on the tokio runtime; the result is
            // delivered back through the oneshot channel.
            batch_runtime.spawn(async move {
                bevy::log::trace!("commit batch task started (batched)");
                let res = db_for_task
                    .execute_transaction(batch_ops)
                    .await
                    .map(|keys| (keys, batch_new_entities));
                bevy::log::trace!("commit batch runtime task completed send");
                let _ = tx.send(res);
            });
            let meta = CommitMeta {
                dirty_entities: batch_entities_set,
                // Despawned entities are attributed to the first batch only.
                despawned_entities: if i == 0 {
                    despawned_entities.clone()
                } else {
                    HashSet::new()
                },
                dirty_resources: resource_sets[i].clone(),
                connection: batch_db.clone(),
                store: store.clone(),
            };
            world.spawn((
                CommitTask { receiver: Some(rx) },
                TriggerID { correlation_id },
                meta,
            ));
        }
    } else {
        // Single-batch path: one task carries the whole transaction.
        let db_for_task = db.clone();
        let runtime_for_task = runtime.clone();
        let (tx, rx) = tokio::sync::oneshot::channel();
        runtime_for_task.spawn(async move {
            bevy::log::trace!("commit task started (single batch)");
            let res = db_for_task
                .execute_transaction(all_operations)
                .await
                .map(|keys| (keys, new_entities));
            bevy::log::trace!("commit runtime task completed send");
            let _ = tx.send(res);
        });
        world.spawn((
            CommitTask { receiver: Some(rx) },
            TriggerID { correlation_id },
            CommitMeta {
                dirty_entities,
                despawned_entities,
                dirty_resources,
                connection: db.clone(),
                store: store.clone(),
            },
        ));
    }
}
/// Correlation id of the `TriggerCommit` that spawned a `CommitTask` entity.
#[derive(Component)]
struct TriggerID {
    correlation_id: Option<u64>,
}
/// Per-task snapshot of what a commit (batch) covers, used on completion to
/// bump versions on success or to re-mark state dirty on failure.
#[derive(Component)]
struct CommitMeta {
    dirty_entities: HashSet<Entity>,
    despawned_entities: HashSet<Entity>,
    dirty_resources: HashSet<TypeId>,
    // Connection/store are kept so a queued follow-up commit can reuse them.
    connection: Arc<dyn DatabaseConnection>,
    store: String,
}
/// Polls in-flight `CommitTask` entities and finalizes finished commits.
///
/// For each task whose oneshot result is ready: updates the multi-batch
/// tracker (if any), applies version bookkeeping from `CommitMeta` (assigning
/// `Guid`s to new entities on success, re-marking state dirty on failure),
/// emits `CommitCompleted` when appropriate, and re-triggers a queued commit
/// when status was `InProgressAndDirty`.
fn handle_commit_completed(
    mut commands: Commands,
    mut query: Query<(Entity, &mut CommitTask, &TriggerID, Option<&mut CommitMeta>)>,
    mut session: ResMut<PersistenceSession>,
    mut status: ResMut<CommitStatus>,
    mut completed: MessageWriter<CommitCompleted>,
    mut triggers: MessageWriter<TriggerCommit>,
    mut trackers: Query<(Entity, &MultiBatchCommitTracker)>,
) {
    // Caps "still pending" debug spam to the first few occurrences.
    static PENDING_LOG_COUNT: AtomicUsize = AtomicUsize::new(0);
    let mut to_despawn = Vec::new();
    // NOTE(review): `had_error` is shared across ALL tasks polled this frame;
    // once any task fails, later (even successful) tasks also take the
    // `is_final_batch || had_error` branch below — confirm this is intended.
    let mut had_error = false;
    for (ent, mut task, trigger_id, meta_opt) in &mut query {
        if let Some(mut receiver) = task.receiver.take() {
            let result: Result<(Vec<String>, Vec<Entity>), PersistenceError> =
                match receiver.try_recv() {
                    Ok(res) => res,
                    // Not finished yet: put the receiver back and try next frame.
                    Err(tokio::sync::oneshot::error::TryRecvError::Empty) => {
                        task.receiver = Some(receiver);
                        continue;
                    }
                    // Sender dropped without a result: treat as a failed commit.
                    Err(tokio::sync::oneshot::error::TryRecvError::Closed) => {
                        bevy::log::error!("commit task channel closed before result");
                        Err(PersistenceError::new(
                            "Commit task cancelled before completion",
                        ))
                    }
                };
            let cid = trigger_id.correlation_id;
            let mut is_final_batch = true;
            let mut should_send_result = true;
            let mut commit_connection: Option<Arc<dyn DatabaseConnection>> = None;
            let mut commit_store: Option<String> = None;
            let mut tracker_found = false;
            if result.is_err() {
                had_error = true;
            }
            // Multi-batch bookkeeping: decrement the tracker and resolve the
            // listener on the final batch or the first error.
            if let Some(correlation_id) = cid {
                if let Some((tracker_entity, tracker)) = trackers
                    .iter_mut()
                    .find(|(_, t)| t.correlation_id == correlation_id)
                {
                    tracker_found = true;
                    // fetch_sub returns the previous value; subtract 1 to get
                    // the count remaining after this batch.
                    let remaining = tracker.remaining_batches.fetch_sub(1, Ordering::SeqCst) - 1;
                    is_final_batch = remaining == 0;
                    if result.is_err() || is_final_batch {
                        if let Some(sender) = tracker.result_sender.lock().unwrap().take() {
                            if result.is_err() {
                                let _ = sender.send(Err(result.as_ref().err().unwrap().clone()));
                            } else if is_final_batch {
                                let _ = sender.send(Ok(()));
                            }
                        }
                        commands
                            .entity(tracker_entity)
                            .remove::<MultiBatchCommitTracker>();
                    } else {
                        // Intermediate batch: suppress the CommitCompleted event.
                        should_send_result = false;
                    }
                }
            }
            if let Err(err) = &result {
                bevy::log::error!(
                    "commit batch completed with error (cid={:?} tracker_found={} final_batch={} err={})",
                    cid,
                    tracker_found,
                    is_final_batch,
                    err
                );
            } else {
                bevy::log::trace!(
                    "commit batch completed ok (cid={:?} tracker_found={} final_batch={})",
                    cid,
                    tracker_found,
                    is_final_batch
                );
            }
            if let Some(mut meta) = meta_opt {
                commit_connection = Some(meta.connection.clone());
                commit_store = Some(meta.store.clone());
                let event_res = match &result {
                    Ok((new_keys, new_entities)) => {
                        // Assign Guids / keys to freshly created entities and
                        // seed their versions at 1.
                        for (e, key) in new_entities.iter().zip(new_keys.iter()) {
                            commands.entity(*e).insert(Guid::new(key.clone()));
                            session.entity_keys.insert(*e, key.clone());
                            session
                                .version_manager
                                .set_version(VersionKey::Entity(key.clone()), 1);
                        }
                        // Bump versions of committed resources.
                        for tid in &meta.dirty_resources {
                            let vk = VersionKey::Resource(*tid);
                            let nv = session.version_manager.get_version(&vk).unwrap_or(0) + 1;
                            session.version_manager.set_version(vk, nv);
                        }
                        // Bump versions of updated (pre-existing) entities.
                        for &entity in meta.dirty_entities.iter() {
                            if !new_entities.contains(&entity) {
                                if let Some(key) = session.entity_keys.get(&entity) {
                                    let vk = VersionKey::Entity(key.clone());
                                    if let Some(v) = session.version_manager.get_version(&vk) {
                                        session.version_manager.set_version(vk, v + 1);
                                    }
                                }
                            }
                        }
                        // Forget versions of despawned entities.
                        for e in &meta.despawned_entities {
                            if let Some(key) = session.entity_keys.get(e).cloned() {
                                session
                                    .version_manager
                                    .remove_version(&VersionKey::Entity(key));
                            }
                        }
                        Ok(new_keys.clone())
                    }
                    Err(err) => {
                        // Failed commit: put the dirty state back so the next
                        // commit retries it.
                        session.dirty_entities.extend(meta.dirty_entities.drain());
                        session
                            .despawned_entities
                            .extend(meta.despawned_entities.drain());
                        session.dirty_resources.extend(meta.dirty_resources.drain());
                        Err(err.clone())
                    }
                };
                if should_send_result && (is_final_batch || result.is_err()) {
                    bevy::log::debug!(
                        "emitting CommitCompleted for cid={:?} final_batch={} err={}",
                        cid,
                        is_final_batch,
                        result.is_err()
                    );
                    completed.write(CommitCompleted {
                        result: event_res,
                        dirty_entities: vec![],
                        correlation_id: cid,
                    });
                }
            } else if let Err(e) = &result {
                // No CommitMeta (unexpected): still surface the error.
                if should_send_result {
                    completed.write(CommitCompleted {
                        result: Err(e.clone()),
                        dirty_entities: vec![],
                        correlation_id: cid,
                    });
                }
            }
            to_despawn.push(ent);
            if is_final_batch || had_error {
                // NOTE(review): this resets status to Idle even if other
                // batches of the same commit are still in flight after an
                // error elsewhere this frame — verify that is the intended
                // recovery behavior.
                let should_trigger_next = !had_error && *status == CommitStatus::InProgressAndDirty;
                *status = CommitStatus::Idle;
                if should_trigger_next {
                    // A commit was requested while this one ran; start it now
                    // against the same connection/store.
                    if let (Some(conn), Some(store)) =
                        (commit_connection.clone(), commit_store.clone())
                    {
                        triggers.write(TriggerCommit {
                            correlation_id: None,
                            target_connection: conn,
                            store,
                        });
                    }
                }
            }
        } else if PENDING_LOG_COUNT.fetch_add(1, Ordering::Relaxed) < 5 {
            bevy::log::debug!(
                "commit task still pending (cid={:?})",
                trigger_id.correlation_id
            );
        }
    }
    if had_error {
        *status = CommitStatus::Idle;
    }
    // Remove finished task entities.
    for entity in to_despawn {
        commands.entity(entity).despawn();
    }
}
/// Marks every entity whose `T` component was added or mutated this frame as
/// dirty in the `PersistenceSession`, so the next commit picks it up.
pub fn auto_dirty_tracking_entity_system<T: Component + Persist>(
    mut session: ResMut<PersistenceSession>,
    query: Query<Entity, Or<(Added<T>, Changed<T>)>>,
) {
    for changed in &query {
        debug!(
            "Marking entity {:?} as dirty due to component {}",
            changed,
            std::any::type_name::<T>()
        );
        session.dirty_entities.insert(changed);
    }
}
/// Marks the persisted resource `T` dirty whenever Bevy change detection
/// reports it changed. Does nothing if the resource is absent.
pub fn auto_dirty_tracking_resource_system<T: Resource + Persist>(
    mut session: ResMut<PersistenceSession>,
    resource: Option<Res<T>>,
) {
    let Some(res) = resource else { return };
    if res.is_changed() {
        session.mark_resource_dirty::<T>();
    }
}
/// Records every entity that lost its `Guid` component (i.e. was despawned or
/// un-persisted) so the next commit deletes its document.
fn auto_despawn_tracking_system(
    mut session: ResMut<PersistenceSession>,
    mut removed: RemovedComponents<Guid>,
) {
    removed
        .read()
        .for_each(|gone| session.mark_despawned(gone));
}
/// Tunables for the persistence plugin.
#[derive(Resource, Clone)]
pub struct PersistencePluginConfig {
    /// When true, commits larger than `commit_batch_size` ops are split.
    pub batching_enabled: bool,
    /// Maximum number of transaction operations per batch.
    pub commit_batch_size: usize,
    /// Thread count handed to `_prepare_commit` for serialization work.
    pub thread_count: usize,
    /// Store name used when a caller does not specify one.
    pub default_store: String,
}
impl Default for PersistencePluginConfig {
fn default() -> Self {
Self {
batching_enabled: true,
commit_batch_size: 1000,
thread_count: 4, default_store: "default_store".to_string(),
}
}
}
/// The core persistence plugin: wires the session, commit systems, messages,
/// and tokio runtime into a Bevy `App`. Usually added via `PersistencePlugins`.
pub struct PersistencePluginCore {
    backend: PersistenceBackend,
    config: PersistencePluginConfig,
}
impl PersistencePluginCore {
    /// Creates a core plugin backed by `db`, with the default configuration.
    pub fn new(db: Arc<dyn DatabaseConnection>) -> Self {
        let backend = PersistenceBackend::Static(db);
        let config = PersistencePluginConfig::default();
        Self { backend, config }
    }
    /// Builder-style override of the configuration.
    pub fn with_config(mut self, config: PersistencePluginConfig) -> Self {
        self.config = config;
        self
    }
}
/// One-shot listeners awaiting a `CommitCompleted` for a given correlation id.
#[derive(Resource, Default)]
struct CommitEventListeners {
    pub listeners: HashMap<u64, oneshot::Sender<Result<(), PersistenceError>>>,
}
/// Registers a one-shot `sender` to be resolved when the commit identified by
/// `correlation_id` completes (see `commit_event_listener`).
pub fn register_commit_listener(
    world: &mut World,
    correlation_id: u64,
    sender: oneshot::Sender<Result<(), PersistenceError>>,
) {
    let mut registry = world.resource_mut::<CommitEventListeners>();
    registry.listeners.insert(correlation_id, sender);
}
/// Removes and returns the listener registered for `correlation_id`, if any.
pub fn take_commit_listener(
    world: &mut World,
    correlation_id: u64,
) -> Option<oneshot::Sender<Result<(), PersistenceError>>> {
    let mut registry = world.resource_mut::<CommitEventListeners>();
    registry.listeners.remove(&correlation_id)
}
/// Dispatches `CommitCompleted` messages to any registered one-shot listener
/// matching the message's correlation id.
fn commit_event_listener(
    mut events: MessageReader<CommitCompleted>,
    mut listeners: ResMut<CommitEventListeners>,
) {
    for event in events.read() {
        let Some(id) = event.correlation_id else {
            trace!("CommitCompleted event without correlation id consumed");
            continue;
        };
        match listeners.listeners.remove(&id) {
            Some(sender) => {
                info!("Found listener for commit {}. Sending result.", id);
                // Collapse the key list to () — listeners only care about
                // success/failure.
                let result = event.result.as_ref().map(|_| ()).map_err(Clone::clone);
                let _ = sender.send(result);
            }
            None => info!("Commit listener missing for correlation_id={}", id),
        }
    }
}
impl Plugin for PersistencePluginCore {
    /// Installs all persistence resources, messages, and systems on the app.
    ///
    /// Registration order matters: task pools first, then resources, then
    /// `#[persist]` registrations, then the chained `PostUpdate` system sets.
    fn build(&self, app: &mut App) {
        ensure_task_pools(app);
        let db_conn = match &self.backend {
            PersistenceBackend::Static(db) => db.clone(),
        };
        let session = PersistenceSession::new();
        app.insert_resource(session);
        app.insert_resource(self.config.clone());
        app.insert_resource(DatabaseConnectionResource {
            connection: db_conn.clone(),
        });
        app.init_resource::<RegisteredPersistTypes>();
        app.add_message::<TriggerCommit>();
        app.add_message::<CommitCompleted>();
        app.init_resource::<CommitStatus>();
        app.init_resource::<CommitEventListeners>();
        app.insert_resource(TokioRuntime {
            runtime: TOKIO_RUNTIME.clone(),
        });
        app.init_resource::<PersistenceQueryCache>();
        app.init_resource::<DeferredWorldOperations>();
        // Publish a raw pointer to the World for immediate-mode queries.
        // The pointer is refreshed every frame (see systems below) because the
        // App — and therefore the World — may be moved before startup.
        {
            let ptr: *mut World = app.world_mut() as *mut World;
            bevy::log::trace!(
                "PersistencePluginCore: inserting initial ImmediateWorldPtr {:p}",
                ptr
            );
            if app.world().get_resource::<ImmediateWorldPtr>().is_none() {
                app.insert_resource(ImmediateWorldPtr::new(ptr));
            } else {
                app.world_mut().resource_mut::<ImmediateWorldPtr>().set(ptr);
            }
        }
        // Re-publishes the current World address; runs at Startup and at the
        // start of every frame (First) to keep the pointer valid after moves.
        fn publish_immediate_world_ptr(world: &mut World) {
            let ptr: *mut World = world as *mut World;
            if world.get_resource::<ImmediateWorldPtr>().is_none() {
                world.insert_resource(ImmediateWorldPtr::new(ptr));
            } else {
                world.resource_mut::<ImmediateWorldPtr>().set(ptr);
            }
        }
        app.add_systems(Startup, publish_immediate_world_ptr);
        app.add_systems(First, publish_immediate_world_ptr);
        // Apply every registration queued by the #[persist] macro.
        let registry = COMPONENT_REGISTRY.lock().unwrap();
        let registrations = registry.len();
        if registrations == 0 {
            bevy::log::warn!(
                "No #[persist] registrations detected; components/resources will not be persisted"
            );
        } else {
            bevy::log::debug!(registrations, "Applying #[persist] registrations");
        }
        for reg_fn in registry.iter() {
            reg_fn(app);
        }
        // ChangeDetection -> PreCommit -> Commit, strictly ordered.
        app.configure_sets(
            PostUpdate,
            (
                PersistenceSystemSet::ChangeDetection,
                PersistenceSystemSet::PreCommit,
                PersistenceSystemSet::Commit,
            )
                .chain(),
        );
        // Drains and runs closures queued for exclusive world access.
        fn apply_deferred_world_ops(world: &mut World) {
            let mut pending = world.resource::<DeferredWorldOperations>().drain();
            for op in pending.drain(..) {
                op(world);
            }
        }
        app.add_systems(
            PostUpdate,
            (
                apply_deferred_world_ops,
                publish_immediate_world_ptr,
                auto_despawn_tracking_system,
            )
                .in_set(PersistenceSystemSet::ChangeDetection),
        );
        app.add_systems(
            PostUpdate,
            (commit_event_listener, handle_commit_trigger).in_set(PersistenceSystemSet::PreCommit),
        );
        app.add_systems(
            PostUpdate,
            handle_commit_completed.in_set(PersistenceSystemSet::Commit),
        );
    }
}
/// Public plugin group: adds `PersistenceGuards` followed by
/// `PersistencePluginCore`. This is the intended entry point for users.
#[derive(Clone)]
pub struct PersistencePlugins {
    backend: PersistenceBackend,
    config: PersistencePluginConfig,
}
impl PersistencePlugins {
pub fn new(db: Arc<dyn DatabaseConnection>) -> Self {
Self {
backend: PersistenceBackend::Static(db),
config: PersistencePluginConfig::default(),
}
}
pub fn with_config(mut self, config: PersistencePluginConfig) -> Self {
self.config = config;
self
}
}
/// Zero-sized guard plugin added first in the group purely to ensure Bevy's
/// task pools exist before the core plugin builds.
#[derive(Clone)]
struct PersistenceGuards;
impl Plugin for PersistenceGuards {
    fn build(&self, app: &mut App) {
        ensure_task_pools(app);
    }
}
impl PluginGroup for PersistencePlugins {
    /// Builds the group: guards first (task pools), then the configured core.
    fn build(self) -> PluginGroupBuilder {
        // Single-variant enum, so this let-pattern is irrefutable.
        let PersistenceBackend::Static(db) = self.backend;
        let core = PersistencePluginCore::new(db).with_config(self.config.clone());
        PluginGroupBuilder::start::<Self>()
            .add(PersistenceGuards)
            .add(core)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::connection::MockDatabaseConnection;
    use crate::{Persist, PersistenceSession, PersistentRes};
    use bevy_persistence_database_derive::persist;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use std::sync::Arc;
    // Hand-implemented Persist component used to drive dirty tracking.
    #[derive(Component, Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct TestHealth {
        value: i32,
    }
    impl Persist for TestHealth {
        fn name() -> &'static str {
            "TestHealth"
        }
    }
    // Macro-registered persisted resource used by the load test below.
    #[derive(Clone)]
    #[persist(resource)]
    struct TestSettings {
        difficulty: f32,
        map_name: String,
    }
    // Scratch resource the test system writes its observations into.
    #[derive(Resource, Default)]
    struct Capture {
        loaded: bool,
        map_name: Option<String>,
        difficulty: Option<f32>,
    }
    /// Read-only access must not mark an entity dirty; mutation must.
    #[test]
    fn test_read_only_access_doesnt_mark_dirty() {
        let mut app = App::new();
        let session = PersistenceSession::new();
        app.insert_resource(session);
        app.add_systems(Update, auto_dirty_tracking_entity_system::<TestHealth>);
        let entity = app.world_mut().spawn(TestHealth { value: 100 }).id();
        app.update();
        // Clear the dirty flag set by the initial spawn (Added<T>).
        {
            let mut session = app.world_mut().resource_mut::<PersistenceSession>();
            session.dirty_entities.clear();
        }
        // Immutable read only — must not trip change detection.
        {
            let health = app.world().get::<TestHealth>(entity).unwrap();
            assert_eq!(health.value, 100);
        }
        app.update();
        {
            let session = app.world().resource::<PersistenceSession>();
            assert!(
                !session.dirty_entities.contains(&entity),
                "Entity was incorrectly marked dirty after read-only access"
            );
        }
        // Mutable access — must trip change detection.
        {
            let mut health = app.world_mut().get_mut::<TestHealth>(entity).unwrap();
            health.value = 200;
        }
        app.update();
        {
            let session = app.world().resource::<PersistenceSession>();
            assert!(
                session.dirty_entities.contains(&entity),
                "Entity should be marked dirty after modification"
            );
        }
    }
    /// Moving the App (which relocates the World) must not invalidate the
    /// ImmediateWorldPtr: the First-schedule publisher refreshes it before
    /// systems run, so resource loads still succeed.
    #[test]
    fn refreshes_immediate_world_ptr_before_startup_after_app_move() {
        let mut db = MockDatabaseConnection::new();
        db.expect_fetch_resource()
            .returning(|_, _| Box::pin(async {
                Ok(Some((json!({ "difficulty": 0.3, "map_name": "moved" }), 1)))
            }));
        db.expect_document_key_field().return_const("_key");
        let mut app = App::new();
        app.add_plugins(MinimalPlugins);
        app.add_plugins(PersistencePlugins::new(Arc::new(db)));
        {
            let mut session = app.world_mut().resource_mut::<PersistenceSession>();
            session.register_resource::<TestSettings>();
        }
        app.insert_resource(Capture::default());
        // Force the App (and its World) to move in memory.
        let mut relocated = Vec::new();
        relocated.push(app);
        let mut app = relocated.pop().expect("relocated app");
        app.add_systems(
            Update,
            |mut res: PersistentRes<TestSettings>, mut cap: ResMut<Capture>| {
                if let Some(gs) = res.get() {
                    cap.loaded = true;
                    cap.map_name = Some(gs.map_name.clone());
                    cap.difficulty = Some(gs.difficulty);
                }
            },
        );
        app.update();
        let cap = app.world().resource::<Capture>();
        assert!(cap.loaded, "resource should load even after app move");
        assert_eq!(cap.map_name.as_deref(), Some("moved"));
        assert_eq!(cap.difficulty, Some(0.3));
    }
}