bevy_persistence_database 0.1.1

A persistence and database integration solution for the Bevy game engine
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
//! Core ECS‐to‐Arango bridge: defines `PersistenceSession`.
//! Handles local cache, change tracking, and commit logic (create/update/delete).

use crate::db::connection::{
    DatabaseConnection, PersistenceError, TransactionOperation,
    BEVY_PERSISTENCE_VERSION_FIELD, Collection,
};
use crate::plugins::TriggerCommit;
use crate::versioning::version_manager::{VersionKey, VersionManager};
use bevy::prelude::{info, App, Component, Entity, Resource, World};
use serde_json::Value;
use std::{
    any::TypeId,
    collections::{HashMap, HashSet},
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc,
    },
};
use crate::persist::Persist;
use crate::plugins::persistence_plugin::{CommitEventListeners, TokioRuntime};
use tokio::sync::oneshot;
use tokio::time::timeout;
use rayon::prelude::*;
use rayon::ThreadPoolBuilder;

/// Process-wide counter used to hand out correlation IDs that pair a commit
/// request with its completion notification.
///
/// Starts at 1; `commit` takes the next value with `fetch_add`, so every call
/// observes a distinct ID.
static NEXT_CORRELATION_ID: AtomicU64 = AtomicU64::new(1);

/// Serializes one component type of `Entity` read from the `World`.
/// Yields `Ok(Some((field_name, json)))` when the entity has the component,
/// `Ok(None)` when it does not.
type ComponentSerializer   = Box<dyn Fn(Entity, &World) -> Result<Option<(String, Value)>, PersistenceError> + Send + Sync>;
/// Builds one component from a JSON value and inserts it on the given entity.
type ComponentDeserializer = Box<dyn Fn(&mut World, Entity, Value) -> Result<(), PersistenceError> + Send + Sync>;
/// Serializes one resource type read from the `World`.
/// Yields `Ok(Some((name, json)))` when the resource exists, `Ok(None)` otherwise.
type ResourceSerializer    = Box<dyn Fn(&World, &PersistenceSession) -> Result<Option<(String, Value)>, PersistenceError> + Send + Sync>;
/// Builds one resource from a JSON value and inserts it into the `World`.
type ResourceDeserializer  = Box<dyn Fn(&mut World, Value) -> Result<(), PersistenceError> + Send + Sync>;

/// Manages a "unit of work": the database handle, change tracking, cached
/// document versions, and the per-type (de)serializers registered for persistence.
#[derive(Resource)]
pub struct PersistenceSession {
    /// Database backend; `_prepare_commit` asks it for the backend-specific
    /// document key field when creating resource documents.
    pub db: Arc<dyn DatabaseConnection>,
    /// Entities flagged as changed.
    /// NOTE(review): populated outside this file — confirm who writes it.
    pub(crate) dirty_entities: HashSet<Entity>,
    /// Entities recorded as removed via `mark_despawned`.
    pub despawned_entities: HashSet<Entity>,
    /// Document key of every entity that already has a stored document.
    /// `_prepare_commit` updates/deletes when a key is present and creates otherwise.
    pub entity_keys: HashMap<Entity, String>,
    /// `TypeId`s of resources recorded via `mark_resource_dirty`.
    pub dirty_resources: HashSet<TypeId>,
    /// Cached document versions, keyed by entity document key or resource `TypeId`.
    pub(crate) version_manager: VersionManager,
    /// Component serializers, keyed by component `TypeId` (see `register_component`).
    component_serializers: HashMap<TypeId, ComponentSerializer>,
    /// Component deserializers, keyed by the persisted component name.
    pub(crate) component_deserializers: HashMap<String, ComponentDeserializer>,
    /// Component `TypeId` -> persisted component name.
    pub(crate) component_type_id_to_name: HashMap<TypeId, &'static str>,
    /// Reverse lookup: persisted component name -> component `TypeId`.
    pub(crate) component_name_to_type_id: HashMap<String, TypeId>,
    /// Per-component-name predicate reporting whether an entity has that component.
    pub(crate) component_presence: HashMap<String, Box<dyn Fn(&World, Entity) -> bool + Send + Sync>>,
    /// Resource serializers, keyed by resource `TypeId` (see `register_resource`).
    resource_serializers: HashMap<TypeId, ResourceSerializer>,
    /// Resource deserializers, keyed by the persisted resource name.
    resource_deserializers: HashMap<String, ResourceDeserializer>,
    /// Persisted resource name -> resource `TypeId`, used to cache fetched versions.
    resource_name_to_type_id: HashMap<String, TypeId>,
}

/// Output of `PersistenceSession::_prepare_commit`: everything needed to run
/// one commit against the database.
pub(crate) struct CommitData {
    /// Delete, update and create operations prepared for this commit.
    pub(crate) operations: Vec<TransactionOperation>,
    /// Entities for which a `CreateDocument` operation was prepared
    /// (they had no entry in `entity_keys`).
    pub(crate) new_entities: Vec<Entity>,
}

impl PersistenceSession {
    /// Registers a component type for persistence.
    ///
    /// For any component implementing the `Persist` marker trait this installs:
    /// * the `TypeId` <-> name lookups (the name is `T::name()`),
    /// * a presence checker telling whether an entity currently has `T`,
    /// * a serializer keyed by `TypeId`,
    /// * a deserializer keyed by the component name.
    ///
    /// Registering the same type again overwrites the previous entries.
    pub fn register_component<T: Component + Persist>(&mut self) {
        // A single name is shared by every lookup table below.
        let name = T::name();
        let type_id = TypeId::of::<T>();

        self.component_type_id_to_name.insert(type_id, name);
        self.component_name_to_type_id.insert(name.to_string(), type_id);

        // `World::get` returns `None` both when the component is absent and when
        // the entity no longer exists, so a stale `Entity` reports `false`
        // instead of panicking the way `World::entity` would.
        self.component_presence.insert(
            name.to_string(),
            Box::new(|world: &World, entity: Entity| world.get::<T>(entity).is_some()),
        );

        self.component_serializers.insert(type_id, Box::new(
            move |entity, world| -> Result<Option<(String, Value)>, PersistenceError> {
                // An entity without this component contributes nothing to the document.
                let Some(component) = world.get::<T>(entity) else {
                    return Ok(None);
                };
                let json = serde_json::to_value(component)
                    .map_err(|_| PersistenceError::new("Serialization failed"))?;
                // A JSON `null` carries no data and is rejected.
                if json.is_null() {
                    return Err(PersistenceError::new("Could not serialize"));
                }
                Ok(Some((name.to_string(), json)))
            },
        ));

        self.component_deserializers.insert(name.to_string(), Box::new(
            |world, entity, json_val| {
                let component: T = serde_json::from_value(json_val)
                    .map_err(|e| PersistenceError::new(e.to_string()))?;
                world.entity_mut(entity).insert(component);
                Ok(())
            }
        ));
    }

    /// Registers a resource type for persistence.
    ///
    /// This method sets up both serialization and deserialization for any
    /// resource that implements the `Persist` marker trait.
    pub fn register_resource<R: Resource + Persist>(&mut self) {
        let ser_key = R::name();
        let type_id = std::any::TypeId::of::<R>();
        // Insert serializer into map keyed by TypeId
        self.resource_serializers.insert(type_id, Box::new(move |world, _session| {
            // Fetch and serialize the resource
            if let Some(r) = world.get_resource::<R>() {
                let v = serde_json::to_value(r).map_err(|e| PersistenceError::new(e.to_string()))?;
                if v.is_null() {
                    return Err(PersistenceError::new("Could not serialize"));
                }
                Ok(Some((ser_key.to_string(), v)))
            } else {
                Ok(None)
            }
        }));

        let de_key = R::name();
        self.resource_deserializers.insert(de_key.to_string(), Box::new(
            |world, json_val| {
                let res: R = serde_json::from_value(json_val).map_err(|e| PersistenceError::new(e.to_string()))?;
                world.insert_resource(res);
                Ok(())
            }
        ));
        self.resource_name_to_type_id.insert(de_key.to_string(), type_id);
    }

    /// Flags resource type `R` as needing persistence by recording its `TypeId`
    /// in `dirty_resources`.
    pub fn mark_resource_dirty<R: Resource>(&mut self) {
        let resource_type = TypeId::of::<R>();
        self.dirty_resources.insert(resource_type);
    }

    /// Records `entity` as removed by adding it to `despawned_entities`.
    pub fn mark_despawned(&mut self, entity: Entity) {
        let despawned = &mut self.despawned_entities;
        despawned.insert(entity);
    }

    /// Testing constructor w/ mock DB.
    ///
    /// Produces exactly the same empty session as [`PersistenceSession::new`];
    /// it delegates so the two constructors cannot drift apart when fields
    /// are added.
    #[cfg(test)]
    pub fn new_mocked(db: Arc<dyn DatabaseConnection>) -> Self {
        Self::new(db)
    }

    /// Creates an empty session bound to `db`: nothing registered, nothing
    /// tracked as dirty or despawned, and no cached keys or versions.
    pub fn new(db: Arc<dyn DatabaseConnection>) -> Self {
        Self {
            db,
            // Change tracking.
            dirty_entities: HashSet::new(),
            despawned_entities: HashSet::new(),
            dirty_resources: HashSet::new(),
            // Known documents and their versions.
            entity_keys: HashMap::new(),
            version_manager: VersionManager::new(),
            // Component registry.
            component_serializers: HashMap::new(),
            component_deserializers: HashMap::new(),
            component_type_id_to_name: HashMap::new(),
            component_name_to_type_id: HashMap::new(),
            component_presence: HashMap::new(),
            // Resource registry.
            resource_serializers: HashMap::new(),
            resource_deserializers: HashMap::new(),
            resource_name_to_type_id: HashMap::new(),
        }
    }

    /// Loads the document stored under `key` from `db` and inserts the requested
    /// components on `entity` in `world`.
    ///
    /// When the document exists its version is cached under
    /// `VersionKey::Entity(key)`. Names in `component_names` that are missing
    /// from the document, or that have no registered deserializer, are skipped.
    /// A missing document is not an error.
    ///
    /// # Errors
    /// Propagates failures from the database fetch and from any deserializer.
    pub async fn fetch_and_insert_document(
        &mut self,
        db: &(dyn DatabaseConnection + 'static),
        world: &mut World,
        key: &str,
        entity: Entity,
        component_names: &[&'static str],
    ) -> Result<(), PersistenceError> {
        let Some((doc, version)) = db.fetch_document(key).await? else {
            return Ok(());
        };

        // Remember the version so later updates/deletes can state what they expect.
        self.version_manager
            .set_version(VersionKey::Entity(key.to_string()), version);

        for &name in component_names {
            let (Some(json), Some(deserialize)) =
                (doc.get(name), self.component_deserializers.get(name))
            else {
                continue;
            };
            deserialize(world, entity, json.clone())?;
        }
        Ok(())
    }

    /// Loads every registered resource from `db` and inserts it into `world`.
    ///
    /// For each registered resource name that the database has a value for, the
    /// fetched version is cached under the resource's `TypeId` (when the name is
    /// known to the name -> `TypeId` map) and the registered deserializer is run.
    /// Resources absent from the database are skipped.
    ///
    /// # Errors
    /// Propagates failures from the database fetch and from any deserializer.
    pub async fn fetch_and_insert_resources(
        &mut self,
        db: &(dyn DatabaseConnection + 'static),
        world: &mut World,
    ) -> Result<(), PersistenceError> {
        for (res_name, deserialize) in &self.resource_deserializers {
            let Some((json, version)) = db.fetch_resource(res_name).await? else {
                continue;
            };
            if let Some(&type_id) = self.resource_name_to_type_id.get(res_name) {
                self.version_manager
                    .set_version(VersionKey::Resource(type_id), version);
            }
            deserialize(world, json)?;
        }
        Ok(())
    }

    /// Fetches the named components of document `key` one by one from `db` and
    /// inserts each on `entity` in `world` through its registered deserializer.
    ///
    /// Components the database has no value for, and names without a registered
    /// deserializer, are skipped.
    ///
    /// # Errors
    /// Propagates failures from the database fetch and from any deserializer.
    pub async fn fetch_and_insert_components(
        &self,
        db: &(dyn DatabaseConnection + 'static),
        world: &mut World,
        key: &str,
        entity: Entity,
        component_names: &[&'static str],
    ) -> Result<(), PersistenceError> {
        for &name in component_names {
            let Some(json) = db.fetch_component(key, name).await? else {
                continue;
            };
            match self.component_deserializers.get(name) {
                Some(deserialize) => deserialize(world, entity, json)?,
                None => {}
            }
        }
        Ok(())
    }

    /// Prepare all operations (deletions, creations, updates of entities and
    /// resources) for one commit.
    ///
    /// * `session` – supplies the serializers, the entity -> key map and the
    ///   cached document versions.
    /// * `world` – read-only source of the component and resource values.
    /// * `dirty_entities` – entities to update (key known) or create (no key).
    ///   Entities for which no serializer yields data produce no operation.
    /// * `despawned_entities` – entities whose documents are deleted; entities
    ///   without a known key are skipped.
    /// * `dirty_resources` – `TypeId`s of resources to update or create; ids
    ///   without a registered serializer are skipped.
    /// * `thread_count` – number of threads of the Rayon pool used to serialize
    ///   the dirty entities. The pool is only built when there is at least one
    ///   dirty entity.
    ///
    /// The returned operations list deletions first, then entity operations,
    /// then resource operations. `new_entities` holds the entities for which a
    /// `CreateDocument` was prepared.
    ///
    /// # Errors
    /// Fails when a keyed entity has no cached version, when the thread pool
    /// cannot be built, or when any serializer fails.
    pub(crate) fn _prepare_commit(
        session: &PersistenceSession,
        world: &World,
        dirty_entities: &HashSet<Entity>,
        despawned_entities: &HashSet<Entity>,
        dirty_resources: &HashSet<TypeId>,
        thread_count: usize,
    ) -> Result<CommitData, PersistenceError> {
        let mut operations = Vec::new();
        let mut newly_created_entities = Vec::new();

        // 1) Deletions: cheap lookups only, done sequentially.
        for &entity in despawned_entities {
            // An entity that never received a key has no document to delete.
            let Some(key) = session.entity_keys.get(&entity) else {
                continue;
            };
            let version_key = VersionKey::Entity(key.clone());
            let current_version = session
                .version_manager
                .get_version(&version_key)
                .ok_or_else(|| PersistenceError::new("Missing version for deletion"))?;
            operations.push(TransactionOperation::DeleteDocument {
                collection: Collection::Entities,
                key: key.clone(),
                expected_current_version: current_version,
            });
        }

        // 2) Creations & updates of entities: serialized in parallel. Building a
        //    pool spawns threads, so skip it entirely when there is nothing to do.
        if !dirty_entities.is_empty() {
            let pool = ThreadPoolBuilder::new()
                .num_threads(thread_count)
                .build()
                .map_err(|e| PersistenceError::new(format!("ThreadPool error: {}", e)))?;

            let prepared = pool.install(|| {
                dirty_entities
                    .par_iter()
                    .map(|&entity| Self::_prepare_entity_operation(session, world, entity))
                    // Drop entities that produced no data, keep errors.
                    .filter_map(|res| res.transpose())
                    .collect::<Result<Vec<(TransactionOperation, Option<Entity>)>, PersistenceError>>()
            })?;

            for (operation, created) in prepared {
                operations.push(operation);
                newly_created_entities.extend(created);
            }
        }

        // 3) Resources: processed one after another on the calling thread.
        for &resource_type_id in dirty_resources {
            if let Some(operation) =
                Self::_prepare_resource_operation(session, world, resource_type_id)?
            {
                operations.push(operation);
            }
        }

        info!("[_prepare_commit] Prepared {} operations.", operations.len());
        Ok(CommitData {
            operations,
            new_entities: newly_created_entities,
        })
    }

    /// Builds the create/update operation for one dirty entity.
    ///
    /// Returns `Ok(None)` when no registered serializer yields data for the
    /// entity. The second tuple element is `Some(entity)` only for creations.
    fn _prepare_entity_operation(
        session: &PersistenceSession,
        world: &World,
        entity: Entity,
    ) -> Result<Option<(TransactionOperation, Option<Entity>)>, PersistenceError> {
        let mut data_map = serde_json::Map::new();
        // Every registered serializer gets a chance; those whose component is
        // absent on this entity return `None`.
        for serializer in session.component_serializers.values() {
            if let Some((field_name, value)) = serializer(entity, world)? {
                data_map.insert(field_name, value);
            }
        }
        if data_map.is_empty() {
            return Ok(None);
        }

        let Some(key) = session.entity_keys.get(&entity) else {
            // No key yet: create a new document starting at version 1.
            data_map.insert(
                BEVY_PERSISTENCE_VERSION_FIELD.to_string(),
                serde_json::json!(1u64),
            );
            return Ok(Some((
                TransactionOperation::CreateDocument {
                    collection: Collection::Entities,
                    data: Value::Object(data_map),
                },
                Some(entity),
            )));
        };

        // Known key: update the existing document and bump its version.
        let version_key = VersionKey::Entity(key.clone());
        let current_version = session
            .version_manager
            .get_version(&version_key)
            .ok_or_else(|| PersistenceError::new("Missing version for update"))?;
        let next_version = current_version + 1;
        data_map.insert(
            BEVY_PERSISTENCE_VERSION_FIELD.to_string(),
            serde_json::json!(next_version),
        );
        Ok(Some((
            TransactionOperation::UpdateDocument {
                collection: Collection::Entities,
                key: key.clone(),
                expected_current_version: current_version,
                patch: Value::Object(data_map),
            },
            None,
        )))
    }

    /// Builds the create/update operation for one dirty resource.
    ///
    /// Returns `Ok(None)` when the type has no registered serializer or the
    /// serializer reports that the resource is absent.
    fn _prepare_resource_operation(
        session: &PersistenceSession,
        world: &World,
        resource_type_id: TypeId,
    ) -> Result<Option<TransactionOperation>, PersistenceError> {
        let Some(serializer) = session.resource_serializers.get(&resource_type_id) else {
            return Ok(None);
        };
        let Some((name, mut value)) = serializer(world, session)? else {
            return Ok(None);
        };

        let version_key = VersionKey::Resource(resource_type_id);
        let operation = if let Some(current_version) =
            session.version_manager.get_version(&version_key)
        {
            // A cached version means the resource document already exists: update it.
            let next_version = current_version + 1;
            // The version field can only be embedded when the value is a JSON object.
            if let Some(obj) = value.as_object_mut() {
                obj.insert(
                    BEVY_PERSISTENCE_VERSION_FIELD.to_string(),
                    serde_json::json!(next_version),
                );
            }
            TransactionOperation::UpdateDocument {
                collection: Collection::Resources,
                key: name,
                expected_current_version: current_version,
                patch: value,
            }
        } else {
            // No cached version: create the document at version 1, keyed by the
            // resource name under the backend-specific key field.
            if let Some(obj) = value.as_object_mut() {
                obj.insert(
                    BEVY_PERSISTENCE_VERSION_FIELD.to_string(),
                    serde_json::json!(1u64),
                );
                let key_field = session.db.document_key_field();
                obj.insert(key_field.to_string(), Value::String(name));
            }
            TransactionOperation::CreateDocument {
                collection: Collection::Resources,
                data: value,
            }
        };
        Ok(Some(operation))
    }
}

/// This function provides a clean `await`-able interface for the event-driven
/// commit system. It registers a oneshot sender under a fresh correlation ID in
/// `CommitEventListeners`, sends a `TriggerCommit` event carrying that ID, and
/// then repeatedly calls `app.update()` until a result arrives on the channel.
///
/// # Errors
/// * the error reported by the commit itself,
/// * an error when the sender is dropped without sending a result,
/// * a timeout error when no result arrives within 60 seconds; in that case the
///   listener registered for this call is removed again so it does not
///   accumulate in `CommitEventListeners`.
pub async fn commit(app: &mut App) -> Result<(), PersistenceError> {
    let correlation_id = NEXT_CORRELATION_ID.fetch_add(1, Ordering::Relaxed);
    let (tx, mut rx) = oneshot::channel();

    // Insert the sender into the world so the listener system can find it.
    app.world_mut()
        .resource_mut::<CommitEventListeners>()
        .0
        .insert(correlation_id, tx);

    // Send the event to trigger the commit.
    app.world_mut().send_event(TriggerCommit {
        correlation_id: Some(correlation_id),
    });

    // The timeout is applied to the entire commit-and-wait process.
    let outcome = timeout(std::time::Duration::from_secs(60), async {
        loop {
            app.update();

            // Check if the receiver has a value without blocking.
            match rx.try_recv() {
                Ok(result) => {
                    info!("Received commit result for correlation ID {}", correlation_id);
                    return result;
                }
                Err(oneshot::error::TryRecvError::Empty) => {
                    // No result yet: yield so the timeout gets a chance to fire,
                    // then run another update.
                    tokio::task::yield_now().await;
                }
                Err(oneshot::error::TryRecvError::Closed) => {
                    // The sender was dropped, which indicates an error.
                    return Err(PersistenceError::new(
                        "Commit channel closed unexpectedly. The commit listener might have panicked.",
                    ));
                }
            }
        }
    })
    .await;

    match outcome {
        Ok(result) => result,
        Err(_elapsed) => {
            // Nobody is waiting on this correlation ID any more; drop the stale
            // sender instead of leaving it registered indefinitely.
            app.world_mut()
                .resource_mut::<CommitEventListeners>()
                .0
                .remove(&correlation_id);
            Err(PersistenceError::new("Commit timed out after 60 seconds"))
        }
    }
}

/// Synchronous convenience wrapper around [`commit`]: blocks the calling thread
/// on the runtime stored in the `TokioRuntime` resource until the commit
/// finishes, and returns its result.
pub fn commit_sync(app: &mut App) -> Result<(), PersistenceError> {
    // Take an owned handle first so no borrow of `app` is held while `commit`
    // needs it mutably.
    let runtime = app.world().resource::<TokioRuntime>().0.clone();
    let pending_commit = commit(app);
    runtime.block_on(pending_commit)
}

#[cfg(test)]
mod arango_session {
    use super::*;
    use crate::db::connection::MockDatabaseConnection;
    use crate::persist::Persist;
    use crate::registration::COMPONENT_REGISTRY;
    use bevy::prelude::World;
    use bevy_persistence_database_derive::persist;
    use serde_json::json;

    /// Empties the global component registry so registrations made by other
    /// test modules cannot leak into these tests.
    fn setup() {
        COMPONENT_REGISTRY.lock().unwrap().clear();
    }

    /// Session backed by a fresh mock database connection.
    fn mock_session() -> PersistenceSession {
        PersistenceSession::new(Arc::new(MockDatabaseConnection::new()))
    }

    #[persist(resource)]
    struct MyRes { value: i32 }
    #[persist(component)]
    struct MyComp { value: i32 }

    /// A freshly built session tracks no dirty or despawned entities.
    #[test]
    fn new_session_is_empty() {
        setup();
        let session = PersistenceSession::new_mocked(Arc::new(MockDatabaseConnection::new()));
        assert!(session.dirty_entities.is_empty());
        assert!(session.despawned_entities.is_empty());
    }

    /// The registered component deserializer inserts the component on the entity.
    #[test]
    fn deserializer_inserts_component() {
        setup();
        let mut session = mock_session();
        session.register_component::<MyComp>();

        let mut world = World::new();
        let target = world.spawn_empty().id();

        let insert_component = session.component_deserializers.get(MyComp::name()).unwrap();
        insert_component(&mut world, target, json!({"value": 42})).unwrap();

        let stored = world.get::<MyComp>(target).unwrap();
        assert_eq!(stored.value, 42);
    }

    /// The registered resource deserializer inserts the resource into the world.
    #[test]
    fn deserializer_inserts_resource() {
        setup();
        let mut session = mock_session();
        session.register_resource::<MyRes>();

        let mut world = World::new();

        let insert_resource = session.resource_deserializers.get(MyRes::name()).unwrap();
        insert_resource(&mut world, json!({"value": 5})).unwrap();

        let stored = world.resource::<MyRes>();
        assert_eq!(stored.value, 5);
    }
}