use crate::registration::COMPONENT_REGISTRY;
use crate::{PersistenceError, PersistenceSession, DatabaseConnection, Guid, Persist, TransactionOperation, Collection};
use crate::db::connection::DatabaseConnectionResource;
use crate::versioning::version_manager::VersionKey;
use bevy::app::PluginGroupBuilder;
use bevy::prelude::*;
use bevy::tasks::{IoTaskPool, TaskPool, Task};
use futures_lite::future;
use once_cell::sync::Lazy;
use std::any::TypeId;
use std::collections::{HashMap, HashSet};
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
};
use tokio::runtime::Runtime;
use tokio::sync::oneshot;
use crate::query::deferred_ops::DeferredWorldOperations;
use crate::query::immediate_world_ptr::ImmediateWorldPtr;
use crate::query::PersistenceQueryCache;
/// Process-wide multi-threaded Tokio runtime, built lazily on first use.
///
/// Commit tasks spawned on Bevy's `IoTaskPool` block on this runtime to drive
/// `DatabaseConnection::execute_transaction`. All Tokio drivers (I/O, time) are enabled.
///
/// # Panics
/// Panics on first access if the runtime cannot be built.
static TOKIO_RUNTIME: Lazy<Arc<Runtime>> = Lazy::new(|| {
    let runtime = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .expect("failed to build the persistence Tokio runtime");
    Arc::new(runtime)
});
/// Component holding the in-flight database transaction of one commit (or one batch of it).
///
/// Resolves to the keys returned by `execute_transaction`, paired with the
/// entities whose documents were created by this transaction.
#[derive(Component)]
struct CommitTask(Task<Result<(Vec<String>, Vec<Entity>), PersistenceError>>);
/// Component that counts the outstanding batches of a correlated commit that
/// was split into several transactions.
///
/// Matched in `handle_commit_completed` by `correlation_id`.
#[derive(Component)]
struct MultiBatchCommitTracker {
/// Correlation id of the `TriggerCommit` that started the commit.
correlation_id: u64,
/// Batches that have not reported completion yet.
remaining_batches: Arc<AtomicUsize>,
/// Listener for the commit's overall outcome; taken (and answered) on the
/// first failing batch or when the final batch completes.
result_sender: Arc<Mutex<Option<oneshot::Sender<Result<(), PersistenceError>>>>>,
}
/// System sets the persistence plugin schedules in `PostUpdate`.
///
/// `PreCommit` is configured to run before `Commit`.
#[derive(SystemSet, Debug, Clone, PartialEq, Eq, Hash)]
pub enum PersistenceSystemSet {
/// Applies deferred world operations, records despawns and starts commits.
PreCommit,
/// Polls in-flight commit tasks and applies their results.
Commit,
}
/// Message written when a commit finishes, or fails before any transaction starts.
///
/// Fields: the keys returned by the transaction (or the error); a list of
/// entities (every write site in this module passes an empty `Vec`); and the
/// correlation id of the triggering `TriggerCommit`, if it had one.
#[derive(Message)]
pub struct CommitCompleted(pub Result<Vec<String>, PersistenceError>, pub Vec<Entity>, pub Option<u64>);
/// Set of `TypeId`s registered for persistence.
// NOTE(review): not populated in this module — presumably filled by the
// registration callbacks; confirm.
#[derive(Resource, Default)]
pub struct RegisteredPersistTypes(pub HashSet<TypeId>);
/// Message requesting that the session's current dirty state be committed.
#[derive(Message, Clone)]
pub struct TriggerCommit {
/// Id echoed back in `CommitCompleted`; also used to answer a listener
/// registered in `CommitEventListeners`.
pub correlation_id: Option<u64>,
/// Connection the commit is executed against.
pub target_connection: Arc<dyn DatabaseConnection>,
}
/// Commit state shared by the trigger and completion systems.
#[derive(Resource, Default, PartialEq, Debug)]
pub enum CommitStatus {
/// No commit is running; the next trigger starts one.
#[default]
Idle,
/// A commit is running.
InProgress,
/// A commit is running and another trigger arrived meanwhile; a follow-up
/// commit is requested once the running one finishes without error.
InProgressAndDirty,
}
/// Where the plugins get their database connection from; currently always a
/// single connection fixed at construction time.
#[derive(Clone)]
enum PersistenceBackend {
Static(Arc<dyn DatabaseConnection>),
}
/// Resource exposing the shared Tokio runtime to systems.
#[derive(Resource)]
pub struct TokioRuntime(pub Arc<Runtime>);

impl TokioRuntime {
    /// Drives `fut` to completion on the wrapped runtime, blocking the caller.
    ///
    /// # Panics
    /// Tokio's `Runtime::block_on` panics when called from within an
    /// asynchronous execution context.
    pub fn block_on<F: std::future::Future>(&self, fut: F) -> F::Output {
        let Self(runtime) = self;
        runtime.block_on(fut)
    }
}
fn handle_commit_trigger(world: &mut World) {
let mut should_commit = false;
let mut correlation_id = None;
let mut requested_connection: Option<Arc<dyn DatabaseConnection>> = None;
world.resource_scope(|world, mut events: Mut<Messages<TriggerCommit>>| {
let mut status = world.resource_mut::<CommitStatus>();
if !events.is_empty() {
let first_trigger = events.drain().next().unwrap();
requested_connection = Some(first_trigger.target_connection.clone());
match *status {
CommitStatus::Idle => {
info!("[handle_commit_trigger] TriggerCommit event received. Status is Idle.");
should_commit = true;
correlation_id = first_trigger.correlation_id;
}
CommitStatus::InProgress => {
info!("[handle_commit_trigger] TriggerCommit event received while another is in progress. Queuing.");
*status = CommitStatus::InProgressAndDirty;
}
CommitStatus::InProgressAndDirty => {
}
}
}
});
if !should_commit {
return;
}
let connection = if let Some(conn) = requested_connection {
conn
} else {
let err = PersistenceError::new("TriggerCommit missing target_connection");
world.write_message(CommitCompleted(Err(err.clone()), vec![], correlation_id));
bevy::log::error!(%err, "failed to select database connection before commit");
return;
};
let plugin_config = world.resource::<PersistencePluginConfig>().clone();
let (dirty_entities, despawned_entities, dirty_resources) = {
let mut session = world.resource_mut::<PersistenceSession>();
(
std::mem::take(&mut session.dirty_entities),
std::mem::take(&mut session.despawned_entities),
std::mem::take(&mut session.dirty_resources),
)
};
let commit_data = match PersistenceSession::_prepare_commit(
world.resource::<PersistenceSession>(),
world,
&dirty_entities,
&despawned_entities,
&dirty_resources,
plugin_config.thread_count,
connection.document_key_field(),
) {
Ok(data) if data.operations.is_empty() => {
world.write_message(CommitCompleted(Ok(vec![]), vec![], correlation_id));
let mut session = world.resource_mut::<PersistenceSession>();
session.dirty_entities.extend(dirty_entities);
session.despawned_entities.extend(despawned_entities);
session.dirty_resources.extend(dirty_resources);
return;
}
Ok(data) => data,
Err(e) => {
world.write_message(CommitCompleted(Err(e.clone()), vec![], correlation_id));
let mut session = world.resource_mut::<PersistenceSession>();
session.dirty_entities.extend(dirty_entities);
session.despawned_entities.extend(despawned_entities);
session.dirty_resources.extend(dirty_resources);
return;
}
};
*world.resource_mut::<CommitStatus>() = CommitStatus::InProgress;
let runtime = world.resource::<TokioRuntime>().0.clone();
let thread_pool = IoTaskPool::get();
let db = connection.clone();
let all_operations = commit_data.operations;
let new_entities = commit_data.new_entities;
if plugin_config.batching_enabled && all_operations.len() > plugin_config.commit_batch_size {
let batch_size = plugin_config.commit_batch_size;
let session = world.resource::<PersistenceSession>();
let mut entity_ops: HashMap<Entity, Vec<TransactionOperation>> = HashMap::new();
let mut new_entity_ops: Vec<(TransactionOperation, Entity)> = Vec::new();
let mut resource_ops: Vec<TransactionOperation> = Vec::new();
let mut new_entity_idx = 0;
for op in all_operations {
match &op {
TransactionOperation::UpdateDocument { collection: Collection::Entities, key, .. } => {
if let Some(entity) = session.entity_keys.iter().find(|(_, k)| *k == key).map(|(e, _)| *e) {
entity_ops.entry(entity).or_default().push(op);
}
},
TransactionOperation::DeleteDocument { collection: Collection::Entities, key, .. } => {
if let Some(entity) = session.entity_keys.iter().find(|(_, k)| *k == key).map(|(e, _)| *e) {
entity_ops.entry(entity).or_default().push(op);
}
},
TransactionOperation::CreateDocument { collection: Collection::Entities, .. } => {
if let Some(entity) = new_entities.get(new_entity_idx) {
new_entity_ops.push((op, *entity));
new_entity_idx += 1;
}
},
_ => resource_ops.push(op),
}
}
let mut batches: Vec<Vec<TransactionOperation>> = Vec::new();
let mut batch_entities: Vec<HashSet<Entity>> = Vec::new();
let mut batch_new_entities: Vec<Vec<Entity>> = Vec::new();
let mut current_batch = Vec::new();
let mut current_batch_entities = HashSet::new();
let mut current_batch_new_entities = Vec::new();
for (entity, ops) in entity_ops {
if current_batch.len() + ops.len() > batch_size && !current_batch.is_empty() {
batches.push(std::mem::take(&mut current_batch));
batch_entities.push(std::mem::take(&mut current_batch_entities));
batch_new_entities.push(std::mem::take(&mut current_batch_new_entities));
}
current_batch.extend(ops);
current_batch_entities.insert(entity);
}
for (op, entity) in new_entity_ops {
if current_batch.len() + 1 > batch_size && !current_batch.is_empty() {
batches.push(std::mem::take(&mut current_batch));
batch_entities.push(std::mem::take(&mut current_batch_entities));
batch_new_entities.push(std::mem::take(&mut current_batch_new_entities));
}
current_batch.push(op);
current_batch_new_entities.push(entity);
}
if current_batch.len() + resource_ops.len() > batch_size && !current_batch.is_empty() {
batches.push(std::mem::take(&mut current_batch));
batch_entities.push(std::mem::take(&mut current_batch_entities));
batch_new_entities.push(std::mem::take(&mut current_batch_new_entities));
}
current_batch.extend(resource_ops);
if !current_batch.is_empty() {
batches.push(current_batch);
batch_entities.push(current_batch_entities);
batch_new_entities.push(current_batch_new_entities);
}
let num_batches = batches.len();
info!(
"[handle_commit_trigger] Splitting commit into {} batches of size ~{}.",
num_batches, batch_size
);
if let Some(cid) = correlation_id {
if let Some(listener) = world.resource_mut::<CommitEventListeners>().0.remove(&cid) {
world.spawn(MultiBatchCommitTracker {
correlation_id: cid,
remaining_batches: Arc::new(AtomicUsize::new(num_batches)),
result_sender: Arc::new(Mutex::new(Some(listener))),
});
}
}
let mut resource_sets = Vec::with_capacity(num_batches);
for _ in 0..num_batches {
resource_sets.push(HashSet::new());
}
for (i, res_type) in dirty_resources.iter().enumerate() {
let batch_idx = i % num_batches;
resource_sets[batch_idx].insert(*res_type);
}
for (i, (batch_ops, batch_entities_set)) in batches.into_iter().zip(batch_entities.into_iter()).enumerate() {
let batch_db = db.clone();
let batch_runtime = runtime.clone();
let batch_new_entities = batch_new_entities.get(i).cloned().unwrap_or_default();
let db_for_task = batch_db.clone();
let task = thread_pool.spawn(async move {
batch_runtime
.block_on(async {
db_for_task
.execute_transaction(batch_ops)
.await
.map(|keys| (keys, batch_new_entities))
})
});
let meta = CommitMeta {
dirty_entities: batch_entities_set,
despawned_entities: if i == 0 { despawned_entities.clone() } else { HashSet::new() },
dirty_resources: resource_sets[i].clone(),
connection: batch_db.clone(),
};
world.spawn((CommitTask(task), TriggerID(correlation_id), meta));
}
} else {
let db_for_task = db.clone();
let task = thread_pool.spawn(async move {
runtime.block_on(async {
db_for_task
.execute_transaction(all_operations)
.await
.map(|keys| (keys, new_entities))
})
});
world.spawn((
CommitTask(task),
TriggerID(correlation_id),
CommitMeta {
dirty_entities,
despawned_entities,
dirty_resources,
connection: db.clone(),
},
));
}
}
/// Correlation id of the trigger that spawned a `CommitTask` (if it had one).
#[derive(Component)]
struct TriggerID(Option<u64>);
/// Dirty state covered by one commit task.
///
/// `handle_commit_completed` bumps or removes versions for it on success and
/// extends the session's dirty sets with it on failure.
#[derive(Component)]
struct CommitMeta {
/// Entities whose writes are part of this task.
dirty_entities: HashSet<Entity>,
/// Despawned entities whose versions are removed when this task succeeds.
despawned_entities: HashSet<Entity>,
/// Resource types whose versions are bumped when this task succeeds.
dirty_resources: HashSet<TypeId>,
/// Connection the task ran on; reused for a follow-up `TriggerCommit`.
connection: Arc<dyn DatabaseConnection>,
}
/// Polls in-flight [`CommitTask`]s and applies the results of those that finished.
///
/// On success:
/// - newly created entities receive a [`Guid`], a session key and version 1;
/// - the versions of the committed dirty entities and resources are bumped;
/// - the versions of despawned entities are removed.
///
/// On failure, the task's dirty state is put back into the
/// [`PersistenceSession`] so a later commit retries it.
///
/// For a commit split into several batches, the matching
/// [`MultiBatchCommitTracker`] counts completions. Its listener (if any) is
/// answered on the first error or when the last batch finishes. Only then is a
/// [`CommitCompleted`] written, and then the tracker entity is despawned.
///
/// When a commit finishes, or any task fails this frame, the status returns to
/// [`CommitStatus::Idle`]. If more changes were flagged meanwhile and nothing
/// failed, a follow-up [`TriggerCommit`] is sent on the same connection.
fn handle_commit_completed(
    mut commands: Commands,
    mut query: Query<(Entity, &mut CommitTask, &TriggerID, Option<&mut CommitMeta>)>,
    mut session: ResMut<PersistenceSession>,
    mut status: ResMut<CommitStatus>,
    mut completed: MessageWriter<CommitCompleted>,
    mut triggers: MessageWriter<TriggerCommit>,
    trackers: Query<(Entity, &MultiBatchCommitTracker)>,
) {
    let mut to_despawn = Vec::new();
    // Commands are deferred, so a tracker retired this frame is still visible
    // to `trackers`. Remember it so it is despawned only once.
    let mut retired_trackers: HashSet<Entity> = HashSet::new();
    // Sticky for the rest of the frame once any task has failed.
    let mut had_error = false;
    for (ent, mut task, trigger_id, meta_opt) in &mut query {
        let Some(result) = future::block_on(future::poll_once(&mut task.0)) else {
            continue;
        };
        let cid = trigger_id.0;
        let mut is_final_batch = true;
        let mut should_send_result = true;
        let mut commit_connection: Option<Arc<dyn DatabaseConnection>> = None;
        if result.is_err() {
            had_error = true;
        }
        if let Some(correlation_id) = cid {
            if let Some((tracker_entity, tracker)) = trackers
                .iter()
                .find(|(_, t)| t.correlation_id == correlation_id)
            {
                let remaining = tracker
                    .remaining_batches
                    .fetch_sub(1, Ordering::SeqCst)
                    - 1;
                is_final_batch = remaining == 0;
                if result.is_err() || is_final_batch {
                    let sender = tracker.result_sender.lock().unwrap().take();
                    if let Some(sender) = sender {
                        let outcome = match &result {
                            Ok(_) => Ok(()),
                            Err(err) => Err(err.clone()),
                        };
                        let _ = sender.send(outcome);
                    }
                    // Despawn rather than strip the component, so no empty
                    // entity is left behind for every multi-batch commit.
                    if retired_trackers.insert(tracker_entity) {
                        commands.entity(tracker_entity).despawn();
                    }
                } else {
                    should_send_result = false;
                }
            }
        }
        if let Some(mut meta) = meta_opt {
            commit_connection = Some(meta.connection.clone());
            let event_res = match &result {
                Ok((new_keys, new_entities)) => {
                    for (e, key) in new_entities.iter().zip(new_keys.iter()) {
                        commands.entity(*e).insert(Guid::new(key.clone()));
                        session.entity_keys.insert(*e, key.clone());
                        session.version_manager.set_version(VersionKey::Entity(key.clone()), 1);
                    }
                    for tid in &meta.dirty_resources {
                        let vk = VersionKey::Resource(*tid);
                        let nv = session.version_manager.get_version(&vk).unwrap_or(0) + 1;
                        session.version_manager.set_version(vk, nv);
                    }
                    // Created entities were just given version 1 above.
                    let created: HashSet<Entity> = new_entities.iter().copied().collect();
                    for &entity in meta.dirty_entities.iter() {
                        if created.contains(&entity) {
                            continue;
                        }
                        if let Some(key) = session.entity_keys.get(&entity) {
                            let vk = VersionKey::Entity(key.clone());
                            if let Some(v) = session.version_manager.get_version(&vk) {
                                session.version_manager.set_version(vk, v + 1);
                            }
                        }
                    }
                    for e in &meta.despawned_entities {
                        if let Some(key) = session.entity_keys.get(e).cloned() {
                            session.version_manager.remove_version(&VersionKey::Entity(key));
                        }
                    }
                    Ok(new_keys.clone())
                }
                Err(err) => {
                    // Re-queue everything this task covered for the next commit.
                    session.dirty_entities.extend(meta.dirty_entities.drain());
                    session.despawned_entities.extend(meta.despawned_entities.drain());
                    session.dirty_resources.extend(meta.dirty_resources.drain());
                    Err(err.clone())
                }
            };
            if should_send_result && (is_final_batch || result.is_err()) {
                completed.write(CommitCompleted(event_res, vec![], cid));
            }
        } else if let Err(e) = &result {
            if should_send_result {
                completed.write(CommitCompleted(Err(e.clone()), vec![], cid));
            }
        }
        to_despawn.push(ent);
        if is_final_batch || had_error {
            let should_trigger_next = !had_error && *status == CommitStatus::InProgressAndDirty;
            *status = CommitStatus::Idle;
            if should_trigger_next {
                if let Some(conn) = commit_connection {
                    triggers.write(TriggerCommit {
                        correlation_id: None,
                        target_connection: conn,
                    });
                }
            }
        }
    }
    if had_error {
        *status = CommitStatus::Idle;
    }
    for entity in to_despawn {
        commands.entity(entity).despawn();
    }
}
/// Marks every entity whose `T` component was added or mutated since this
/// system last ran as dirty in the [`PersistenceSession`].
///
/// Bevy's `Changed<T>` filter also matches freshly added components, so no
/// separate `Added<T>` filter is needed.
pub fn auto_dirty_tracking_entity_system<T: Component + Persist>(
    mut session: ResMut<PersistenceSession>,
    query: Query<Entity, Changed<T>>,
) {
    let component_name = std::any::type_name::<T>();
    // Insert one at a time so the session is only mutably touched (and thus
    // flagged as changed) when at least one entity matched.
    for entity in &query {
        debug!(
            "Marking entity {:?} as dirty due to component {}",
            entity, component_name
        );
        session.dirty_entities.insert(entity);
    }
}
/// Marks resource `T` dirty in the [`PersistenceSession`] when it exists and
/// has changed since this system last ran.
pub fn auto_dirty_tracking_resource_system<T: Resource + Persist>(
    mut session: ResMut<PersistenceSession>,
    resource: Option<Res<T>>,
) {
    let changed = resource.is_some_and(|res| res.is_changed());
    if changed {
        session.mark_resource_dirty::<T>();
    }
}
/// Records every entity that lost its [`Guid`] as despawned in the session.
///
/// `RemovedComponents` reports both despawned entities and explicit removals
/// of the component, so either case is recorded.
fn auto_despawn_tracking_system(
    mut session: ResMut<PersistenceSession>,
    mut removed: RemovedComponents<Guid>,
) {
    removed
        .read()
        .for_each(|entity| session.mark_despawned(entity));
}
/// Tuning options for the persistence plugin, inserted as a resource.
#[derive(Resource, Clone)]
pub struct PersistencePluginConfig {
    /// Whether large commits may be split into several transactions.
    pub batching_enabled: bool,
    /// Operation count above which a commit is split (when batching is
    /// enabled); also the target size of each batch.
    pub commit_batch_size: usize,
    /// Passed to `PersistenceSession::_prepare_commit`.
    /// NOTE(review): presumably the worker count used while preparing a commit — confirm.
    pub thread_count: usize,
}

impl Default for PersistencePluginConfig {
    /// Batching on, 1000 operations per batch, 4 threads.
    fn default() -> Self {
        let batching_enabled = true;
        let commit_batch_size = 1000;
        let thread_count = 4;
        Self {
            batching_enabled,
            commit_batch_size,
            thread_count,
        }
    }
}
/// Plugin that installs the persistence resources, messages and systems for
/// one database connection.
pub struct PersistencePluginCore {
    backend: PersistenceBackend,
    config: PersistencePluginConfig,
}

impl PersistencePluginCore {
    /// Creates the plugin for `db` using the default configuration.
    pub fn new(db: Arc<dyn DatabaseConnection>) -> Self {
        let backend = PersistenceBackend::Static(db);
        Self {
            backend,
            config: Default::default(),
        }
    }

    /// Replaces the configuration, builder style.
    pub fn with_config(self, config: PersistencePluginConfig) -> Self {
        Self { config, ..self }
    }
}
/// Pending one-shot result listeners keyed by commit correlation id.
///
/// A listener is answered (and removed) by `commit_event_listener`. For a
/// multi-batch commit, `handle_commit_trigger` moves it into the commit's
/// `MultiBatchCommitTracker` instead.
#[derive(Resource, Default)]
pub(crate) struct CommitEventListeners(pub(crate) HashMap<u64, oneshot::Sender<Result<(), PersistenceError>>>);
fn commit_event_listener(
mut events: MessageReader<CommitCompleted>,
mut listeners: ResMut<CommitEventListeners>,
) {
for event in events.read() {
if let Some(id) = event.2 {
if let Some(sender) = listeners.0.remove(&id) {
info!("Found listener for commit {}. Sending result.", id);
let result = match &event.0 {
Ok(_) => Ok(()),
Err(e) => Err(e.clone()),
};
let _ = sender.send(result);
}
}
}
}
impl Plugin for PersistencePluginCore {
/// Installs the persistence resources, messages and systems into `app`.
fn build(&self, app: &mut App) {
// `handle_commit_trigger` spawns commit tasks via `IoTaskPool::get()`, which
// requires the pool to be initialised; make sure it exists.
IoTaskPool::get_or_init(|| TaskPool::new());
let db_conn = match &self.backend {
PersistenceBackend::Static(db) => db.clone(),
};
let session = PersistenceSession::new();
app.insert_resource(session);
app.insert_resource(self.config.clone());
app.insert_resource(DatabaseConnectionResource(db_conn.clone()));
app.init_resource::<RegisteredPersistTypes>();
app.add_message::<TriggerCommit>();
app.add_message::<CommitCompleted>();
app.init_resource::<CommitStatus>();
app.init_resource::<CommitEventListeners>();
app.insert_resource(TokioRuntime(TOKIO_RUNTIME.clone()));
app.init_resource::<PersistenceQueryCache>();
app.init_resource::<DeferredWorldOperations>();
// Publish a raw pointer to the main world in `ImmediateWorldPtr`.
// NOTE(review): the pointer is re-published in `First` and in `PostUpdate`
// below; confirm consumers never dereference it outside a frame in which it
// was published (e.g. after the world moves).
{
let ptr: *mut World = app.world_mut() as *mut World;
bevy::log::trace!("PersistencePluginCore: inserting initial ImmediateWorldPtr {:p}", ptr);
if app.world().get_resource::<ImmediateWorldPtr>().is_none() {
app.insert_resource(ImmediateWorldPtr::new(ptr));
} else {
app.world_mut().resource_mut::<ImmediateWorldPtr>().set(ptr);
}
}
// Exclusive system: refreshes `ImmediateWorldPtr` with the current world address.
fn publish_immediate_world_ptr(world: &mut World) {
let ptr: *mut World = world as *mut World;
if world.get_resource::<ImmediateWorldPtr>().is_none() {
world.insert_resource(ImmediateWorldPtr::new(ptr));
} else {
world.resource_mut::<ImmediateWorldPtr>().set(ptr);
}
}
app.add_systems(First, publish_immediate_world_ptr);
// Run every registration callback collected in COMPONENT_REGISTRY
// (presumably one per persisted type — confirm). The registry lock is held
// while they run, and std's Mutex is not reentrant, so a callback must not
// lock COMPONENT_REGISTRY itself.
let registry = COMPONENT_REGISTRY.lock().unwrap();
for reg_fn in registry.iter() {
reg_fn(app);
}
// PreCommit runs before Commit within PostUpdate.
app.configure_sets(
PostUpdate,
(PersistenceSystemSet::PreCommit, PersistenceSystemSet::Commit).chain(),
);
// Exclusive system: runs every operation queued in DeferredWorldOperations
// with full world access.
fn apply_deferred_world_ops(world: &mut World) {
let mut pending = world.resource::<DeferredWorldOperations>().drain();
for op in pending.drain(..) {
op(world);
}
}
// These systems are not `.chain()`ed, so their relative order within
// PreCommit is left to Bevy's scheduler.
app.add_systems(
PostUpdate,
(
apply_deferred_world_ops,
publish_immediate_world_ptr,
auto_despawn_tracking_system,
handle_commit_trigger,
commit_event_listener,
)
.in_set(PersistenceSystemSet::PreCommit),
);
app.add_systems(
PostUpdate,
handle_commit_completed.in_set(PersistenceSystemSet::Commit),
);
}
}
/// Plugin group bundling Bevy's `MinimalPlugins`, a `LogPlugin` (only when no
/// logger is set up yet) and [`PersistencePluginCore`] for one database connection.
#[derive(Clone)]
pub struct PersistencePlugins {
    backend: PersistenceBackend,
    config: PersistencePluginConfig,
}

impl PersistencePlugins {
    /// Creates the plugin group for `db` using the default configuration.
    pub fn new(db: Arc<dyn DatabaseConnection>) -> Self {
        let backend = PersistenceBackend::Static(db);
        Self {
            backend,
            config: Default::default(),
        }
    }

    /// Replaces the configuration, builder style.
    pub fn with_config(self, config: PersistencePluginConfig) -> Self {
        Self { config, ..self }
    }
}
/// Adds Bevy's `LogPlugin` unless a global tracing subscriber is already set or
/// the `LogPlugin` was already added, since installing a second one would conflict.
#[derive(Clone)]
struct MaybeAddLogPlugin;

impl Plugin for MaybeAddLogPlugin {
    fn build(&self, app: &mut App) {
        let subscriber_installed = bevy::log::tracing::dispatcher::has_been_set();
        if subscriber_installed || app.is_plugin_added::<bevy::log::LogPlugin>() {
            return;
        }
        app.add_plugins(bevy::log::LogPlugin::default());
    }
}
impl PluginGroup for PersistencePlugins {
fn build(self) -> PluginGroupBuilder {
let core = PersistencePluginCore::new(match self.backend {
PersistenceBackend::Static(db) => db,
})
.with_config(self.config.clone());
MinimalPlugins
.build()
.add(MaybeAddLogPlugin)
.add(core)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{Persist, PersistenceSession};
use serde::{Serialize, Deserialize};
/// Minimal persisted component used to exercise dirty tracking.
#[derive(Component, Debug, Clone, PartialEq, Serialize, Deserialize)]
struct TestHealth {
value: i32
}
impl Persist for TestHealth {
fn name() -> &'static str {
"TestHealth"
}
}
/// Checks that reading a component does not mark its entity dirty, while
/// mutating it does.
#[test]
fn test_read_only_access_doesnt_mark_dirty() {
let mut app = App::new();
let session = PersistenceSession::new();
app.insert_resource(session);
app.add_systems(Update, auto_dirty_tracking_entity_system::<TestHealth>);
let entity = app.world_mut().spawn(TestHealth { value: 100 }).id();
// First update observes the spawn (component added); clear that mark so
// the following checks start from a clean slate.
app.update();
{
let mut session = app.world_mut().resource_mut::<PersistenceSession>();
session.dirty_entities.clear();
}
// Immutable access only.
{
let health = app.world().get::<TestHealth>(entity).unwrap();
assert_eq!(health.value, 100);
}
app.update();
{
let session = app.world().resource::<PersistenceSession>();
assert!(
!session.dirty_entities.contains(&entity),
"Entity was incorrectly marked dirty after read-only access"
);
}
// Mutable access through `get_mut` triggers change detection.
{
let mut health = app.world_mut().get_mut::<TestHealth>(entity).unwrap();
health.value = 200;
}
app.update();
{
let session = app.world().resource::<PersistenceSession>();
assert!(
session.dirty_entities.contains(&entity),
"Entity should be marked dirty after modification"
);
}
}
}