use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, RwLock};
use yrs::{
Doc, GetString, Map, Observable, ReadTxn, Text, Transact, Update, updates::decoder::Decode,
updates::encoder::Encode,
};
use super::storage::{CrdtStorage, StorageResult};
use super::types::UpdateOrigin;
use crate::error::DiaryxError;
use crate::fs::FileSystemEvent;
/// Name of the root yrs `Text` type that holds the markdown body.
const BODY_TEXT_NAME: &str = "body";
/// Name of the root yrs `Map` type that holds frontmatter key/value pairs.
const FRONTMATTER_MAP_NAME: &str = "frontmatter";
/// Callback invoked as `(doc_name, update)` with each v1-encoded update observed on a
/// [`BodyDoc`] (see `BodyDoc::set_sync_callback`); `doc_name` is read when the update fires.
pub type SyncCallback = Arc<dyn Fn(&str, &[u8]) + Send + Sync>;
/// CRDT document for one file: the markdown body as a yrs `Text` plus frontmatter as a
/// yrs `Map`, with every local/remote change appended to a [`CrdtStorage`] backend.
pub struct BodyDoc {
/// The underlying yrs document that owns both root types below.
doc: Doc,
/// Root text named [`BODY_TEXT_NAME`]; holds the body content.
body_text: yrs::TextRef,
/// Root map named [`FRONTMATTER_MAP_NAME`]; values are read back as strings.
frontmatter_map: yrs::MapRef,
/// Persistence backend: updates are appended here and snapshots saved/loaded by name.
storage: Arc<dyn CrdtStorage>,
/// Storage key and name reported to the sync callback. Shared (`Arc`) with the update
/// observer closure so a rename via `set_doc_name` is seen by later callbacks.
doc_name: Arc<RwLock<String>>,
/// Optional filesystem-event sink used by `emit_event`.
event_callback: Option<Arc<dyn Fn(&FileSystemEvent) + Send + Sync>>,
/// Set while a non-local update is being applied so the sync observer skips it.
applying_remote: Arc<AtomicBool>,
/// Copy of the callback passed to `set_sync_callback`; written there, not read by any
/// code visible in this file.
sync_callback: RwLock<Option<SyncCallback>>,
/// Subscription for the sync observer; `Some` also marks "observer already registered".
/// Presumably dropping it unregisters the observer (yrs `Subscription`) — confirm.
_update_subscription: RwLock<Option<yrs::Subscription>>,
}
impl BodyDoc {
pub fn new(storage: Arc<dyn CrdtStorage>, doc_name: String) -> Self {
let doc = Doc::new();
let body_text = doc.get_or_insert_text(BODY_TEXT_NAME);
let frontmatter_map = doc.get_or_insert_map(FRONTMATTER_MAP_NAME);
Self {
doc,
body_text,
frontmatter_map,
storage,
doc_name: Arc::new(RwLock::new(doc_name)),
event_callback: None,
applying_remote: Arc::new(AtomicBool::new(false)),
sync_callback: RwLock::new(None),
_update_subscription: RwLock::new(None),
}
}
pub fn load(storage: Arc<dyn CrdtStorage>, doc_name: String) -> StorageResult<Self> {
let doc = Doc::new();
let body_text = doc.get_or_insert_text(BODY_TEXT_NAME);
let frontmatter_map = doc.get_or_insert_map(FRONTMATTER_MAP_NAME);
{
let mut txn = doc.transact_mut();
if let Some(state) = storage.load_doc(&doc_name)?
&& let Ok(update) = Update::decode_v1(&state)
&& let Err(e) = txn.apply_update(update)
{
log::warn!(
"Failed to apply stored snapshot for body doc {}: {}",
doc_name,
e
);
}
let updates = storage.get_all_updates(&doc_name)?;
for crdt_update in updates {
if let Ok(update) = Update::decode_v1(&crdt_update.data)
&& let Err(e) = txn.apply_update(update)
{
log::warn!(
"Failed to apply stored update {} for body doc {}: {}",
crdt_update.update_id,
doc_name,
e
);
}
}
}
Ok(Self {
doc,
body_text,
frontmatter_map,
storage,
doc_name: Arc::new(RwLock::new(doc_name)),
event_callback: None,
applying_remote: Arc::new(AtomicBool::new(false)),
sync_callback: RwLock::new(None),
_update_subscription: RwLock::new(None),
})
}
pub fn set_event_callback(&mut self, callback: Arc<dyn Fn(&FileSystemEvent) + Send + Sync>) {
self.event_callback = Some(callback);
}
pub fn set_sync_callback(&self, callback: SyncCallback) {
let doc_name = self.doc_name.read().unwrap().clone();
if self._update_subscription.read().unwrap().is_some() {
log::warn!(
"[BodyDoc] DEBUG set_sync_callback: observer ALREADY registered for '{}', skipping",
doc_name
);
return;
}
log::warn!(
"[BodyDoc] DEBUG set_sync_callback: REGISTERING observer for '{}'",
doc_name
);
{
let mut cb = self.sync_callback.write().unwrap();
*cb = Some(callback.clone());
}
let applying_remote = Arc::clone(&self.applying_remote);
let doc_name_ref = Arc::clone(&self.doc_name);
let subscription = self
.doc
.observe_update_v1(move |_, event| {
let current_doc_name = doc_name_ref.read().unwrap().clone();
if applying_remote.load(Ordering::SeqCst) {
log::warn!(
"[BodyDoc] DEBUG Observer: SKIPPING remote update for '{}', update_len={}",
current_doc_name,
event.update.len()
);
return;
}
log::warn!(
"[BodyDoc] DEBUG Observer: FIRING for '{}', update_len={}",
current_doc_name,
event.update.len()
);
callback(¤t_doc_name, &event.update);
})
.expect("Failed to observe document updates");
let mut sub = self._update_subscription.write().unwrap();
*sub = Some(subscription);
log::trace!(
"[BodyDoc] set_sync_callback: observer registered for '{}'",
doc_name
);
}
#[allow(dead_code)]
fn emit_event(&self, event: FileSystemEvent) {
if let Some(ref cb) = self.event_callback {
cb(&event);
}
}
pub fn doc_name(&self) -> String {
self.doc_name.read().unwrap().clone()
}
pub fn set_doc_name(&self, new_name: String) {
let mut name = self.doc_name.write().unwrap();
*name = new_name;
}
pub fn get_body(&self) -> String {
let txn = self.doc.transact();
self.body_text.get_string(&txn)
}
pub fn set_body(&self, content: &str) -> StorageResult<()> {
let doc_name = self.doc_name.read().unwrap().clone();
let has_observer = self._update_subscription.read().unwrap().is_some();
log::warn!(
"[BodyDoc] DEBUG set_body: doc='{}', content_len={}, has_observer={}",
doc_name,
content.len(),
has_observer
);
let (current, sv_before) = {
let txn = self.doc.transact();
(self.body_text.get_string(&txn), txn.state_vector())
};
if current == content {
log::warn!(
"[BodyDoc] DEBUG set_body: UNCHANGED doc='{}', both_len={}",
doc_name,
content.len()
);
return Ok(());
}
log::warn!(
"[BodyDoc] DEBUG set_body: CHANGED doc='{}', current_len={}, new_len={}, current_preview='{}', new_preview='{}'",
doc_name,
current.len(),
content.len(),
current.chars().take(80).collect::<String>(),
content.chars().take(80).collect::<String>()
);
let current_bytes = current.as_bytes();
let new_bytes = content.as_bytes();
let common_prefix = current_bytes
.iter()
.zip(new_bytes.iter())
.take_while(|(a, b)| a == b)
.count();
let remaining_current = current_bytes.len() - common_prefix;
let remaining_new = new_bytes.len() - common_prefix;
let common_suffix = current_bytes[common_prefix..]
.iter()
.rev()
.zip(new_bytes[common_prefix..].iter().rev())
.take_while(|(a, b)| a == b)
.take(remaining_current.min(remaining_new))
.count();
let delete_start = common_prefix;
let delete_end = current_bytes.len() - common_suffix;
let insert_start = common_prefix;
let insert_end = new_bytes.len() - common_suffix;
{
let mut txn = self.doc.transact_mut();
if delete_end > delete_start {
let delete_len = (delete_end - delete_start) as u32;
self.body_text
.remove_range(&mut txn, delete_start as u32, delete_len);
}
if insert_end > insert_start {
let insert_text = &content[insert_start..insert_end];
self.body_text
.insert(&mut txn, delete_start as u32, insert_text);
}
}
self.record_update(&sv_before)
}
pub fn insert_at(&self, index: u32, text: &str) -> StorageResult<()> {
let sv_before = {
let txn = self.doc.transact();
txn.state_vector()
};
{
let mut txn = self.doc.transact_mut();
self.body_text.insert(&mut txn, index, text);
}
self.record_update(&sv_before)
}
pub fn delete_range(&self, index: u32, length: u32) -> StorageResult<()> {
let sv_before = {
let txn = self.doc.transact();
txn.state_vector()
};
{
let mut txn = self.doc.transact_mut();
self.body_text.remove_range(&mut txn, index, length);
}
self.record_update(&sv_before)
}
fn record_update(&self, sv_before: &yrs::StateVector) -> StorageResult<()> {
let update = {
let txn = self.doc.transact();
txn.encode_state_as_update_v1(sv_before)
};
if !update.is_empty() {
let doc_name = self.doc_name.read().unwrap();
self.storage
.append_update(&doc_name, &update, UpdateOrigin::Local)?;
}
Ok(())
}
pub fn body_len(&self) -> u32 {
let txn = self.doc.transact();
self.body_text.len(&txn)
}
pub fn get_frontmatter(&self, key: &str) -> Option<String> {
let txn = self.doc.transact();
self.frontmatter_map
.get(&txn, key)
.and_then(|v| v.cast::<String>().ok())
}
pub fn set_frontmatter(&self, key: &str, value: &str) -> StorageResult<()> {
let sv_before = {
let txn = self.doc.transact();
txn.state_vector()
};
{
let mut txn = self.doc.transact_mut();
self.frontmatter_map.insert(&mut txn, key, value);
}
self.record_update(&sv_before)
}
pub fn remove_frontmatter(&self, key: &str) -> StorageResult<()> {
let sv_before = {
let txn = self.doc.transact();
txn.state_vector()
};
{
let mut txn = self.doc.transact_mut();
self.frontmatter_map.remove(&mut txn, key);
}
self.record_update(&sv_before)
}
pub fn frontmatter_keys(&self) -> Vec<String> {
let txn = self.doc.transact();
self.frontmatter_map.keys(&txn).map(String::from).collect()
}
pub fn encode_state_vector(&self) -> Vec<u8> {
let txn = self.doc.transact();
txn.state_vector().encode_v1()
}
pub fn encode_state_as_update(&self) -> Vec<u8> {
let txn = self.doc.transact();
txn.encode_state_as_update_v1(&Default::default())
}
pub fn encode_diff(&self, remote_state_vector: &[u8]) -> StorageResult<Vec<u8>> {
let sv = yrs::StateVector::decode_v1(remote_state_vector)
.map_err(|e| DiaryxError::Crdt(format!("Failed to decode state vector: {}", e)))?;
let txn = self.doc.transact();
Ok(txn.encode_state_as_update_v1(&sv))
}
pub fn apply_update(&self, update: &[u8], origin: UpdateOrigin) -> StorageResult<Option<i64>> {
let is_remote = origin != UpdateOrigin::Local;
if is_remote {
self.applying_remote.store(true, Ordering::SeqCst);
}
let decoded = Update::decode_v1(update)
.map_err(|e| DiaryxError::Crdt(format!("Failed to decode update: {}", e)))?;
let result = {
let mut txn = self.doc.transact_mut();
txn.apply_update(decoded)
.map_err(|e| DiaryxError::Crdt(format!("Failed to apply update: {}", e)))
};
if is_remote {
self.applying_remote.store(false, Ordering::SeqCst);
}
result?;
let doc_name = self.doc_name.read().unwrap();
let update_id = self.storage.append_update(&doc_name, update, origin)?;
Ok(Some(update_id))
}
pub fn save(&self) -> StorageResult<()> {
let state = self.encode_state_as_update();
let doc_name = self.doc_name.read().unwrap();
self.storage.save_doc(&doc_name, &state)
}
pub fn reload(&mut self) -> StorageResult<()> {
let doc_name = self.doc_name.read().unwrap().clone();
if let Some(state) = self.storage.load_doc(&doc_name)?
&& let Ok(update) = Update::decode_v1(&state)
{
let mut txn = self.doc.transact_mut();
if let Err(e) = txn.apply_update(update) {
log::warn!("Failed to reload body doc {}: {}", doc_name, e);
}
}
Ok(())
}
pub fn get_history(&self) -> StorageResult<Vec<super::types::CrdtUpdate>> {
let doc_name = self.doc_name.read().unwrap();
self.storage.get_all_updates(&doc_name)
}
pub fn get_updates_since(&self, since_id: i64) -> StorageResult<Vec<super::types::CrdtUpdate>> {
let doc_name = self.doc_name.read().unwrap();
self.storage.get_updates_since(&doc_name, since_id)
}
pub fn observe_body<F>(&self, callback: F) -> yrs::Subscription
where
F: Fn() + Send + Sync + 'static,
{
self.body_text.observe(move |_txn, _event| {
callback();
})
}
pub fn observe_updates<F>(&self, callback: F) -> yrs::Subscription
where
F: Fn(&[u8]) + Send + Sync + 'static,
{
self.doc
.observe_update_v1(move |_, event| {
callback(&event.update);
})
.expect("Failed to observe document updates")
}
}
impl std::fmt::Debug for BodyDoc {
    /// Renders the document name and current body length; remaining fields are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name_guard = self.doc_name.read().unwrap();
        let body_len = self.body_len();
        let mut builder = f.debug_struct("BodyDoc");
        builder.field("doc_name", &*name_guard);
        builder.field("body_len", &body_len);
        builder.finish_non_exhaustive()
    }
}
// Unit tests for `BodyDoc`: local body/frontmatter editing, persistence through
// `MemoryStorage`, and convergence of two replicas exchanging full-state updates.
#[cfg(test)]
mod tests {
use super::*;
use crate::crdt::MemoryStorage;
// Builds an empty `BodyDoc` with its own in-memory storage.
fn create_body_doc(name: &str) -> BodyDoc {
let storage = Arc::new(MemoryStorage::new());
BodyDoc::new(storage, name.to_string())
}
// A freshly created document has an empty body.
#[test]
fn test_new_body_doc_is_empty() {
let doc = create_body_doc("test.md");
assert_eq!(doc.get_body(), "");
assert_eq!(doc.body_len(), 0);
}
// set_body/get_body round-trip; for ASCII input body_len equals the byte length.
#[test]
fn test_set_and_get_body() {
let doc = create_body_doc("test.md");
let content = "# Hello World\n\nThis is content.";
doc.set_body(content).unwrap();
assert_eq!(doc.get_body(), content);
assert_eq!(doc.body_len(), content.len() as u32);
}
// A second set_body fully replaces the first body.
#[test]
fn test_replace_body() {
let doc = create_body_doc("test.md");
doc.set_body("Original content").unwrap();
doc.set_body("New content").unwrap();
assert_eq!(doc.get_body(), "New content");
}
// insert_at splices text at the given offset.
#[test]
fn test_insert_at() {
let doc = create_body_doc("test.md");
doc.set_body("Hello World").unwrap();
doc.insert_at(6, "Beautiful ").unwrap();
assert_eq!(doc.get_body(), "Hello Beautiful World");
}
// delete_range removes `length` units starting at `index`.
#[test]
fn test_delete_range() {
let doc = create_body_doc("test.md");
doc.set_body("Hello Beautiful World").unwrap();
doc.delete_range(6, 10).unwrap();
assert_eq!(doc.get_body(), "Hello World");
}
// Frontmatter set/get/keys/remove, including a lookup of a missing key.
#[test]
fn test_frontmatter_operations() {
let doc = create_body_doc("test.md");
doc.set_frontmatter("title", "My Title").unwrap();
doc.set_frontmatter("author", "John Doe").unwrap();
assert_eq!(doc.get_frontmatter("title"), Some("My Title".to_string()));
assert_eq!(doc.get_frontmatter("author"), Some("John Doe".to_string()));
assert_eq!(doc.get_frontmatter("nonexistent"), None);
let keys = doc.frontmatter_keys();
assert!(keys.contains(&"title".to_string()));
assert!(keys.contains(&"author".to_string()));
doc.remove_frontmatter("author").unwrap();
assert_eq!(doc.get_frontmatter("author"), None);
}
// Body and frontmatter survive save() followed by load() from the same storage.
#[test]
fn test_save_and_load() {
let storage = Arc::new(MemoryStorage::new());
let doc_name = "test.md".to_string();
{
let doc = BodyDoc::new(storage.clone(), doc_name.clone());
doc.set_body("# Persistent Content").unwrap();
doc.set_frontmatter("title", "Saved Title").unwrap();
doc.save().unwrap();
}
{
let doc = BodyDoc::load(storage, doc_name).unwrap();
assert_eq!(doc.get_body(), "# Persistent Content");
assert_eq!(
doc.get_frontmatter("title"),
Some("Saved Title".to_string())
);
}
}
// A full-state update from one replica reproduces body and frontmatter on another.
#[test]
fn test_sync_between_docs() {
let storage1 = Arc::new(MemoryStorage::new());
let storage2 = Arc::new(MemoryStorage::new());
let doc1 = BodyDoc::new(storage1, "test.md".to_string());
let doc2 = BodyDoc::new(storage2, "test.md".to_string());
doc1.set_body("Content from doc1").unwrap();
doc1.set_frontmatter("source", "doc1").unwrap();
let update = doc1.encode_state_as_update();
doc2.apply_update(&update, UpdateOrigin::Remote).unwrap();
assert_eq!(doc2.get_body(), "Content from doc1");
assert_eq!(doc2.get_frontmatter("source"), Some("doc1".to_string()));
}
// Concurrent inserts at opposite ends converge and keep both edits.
#[test]
fn test_concurrent_edits() {
let storage1 = Arc::new(MemoryStorage::new());
let storage2 = Arc::new(MemoryStorage::new());
let doc1 = BodyDoc::new(storage1, "test.md".to_string());
let doc2 = BodyDoc::new(storage2, "test.md".to_string());
doc1.set_body("Hello World").unwrap();
let initial = doc1.encode_state_as_update();
doc2.apply_update(&initial, UpdateOrigin::Remote).unwrap();
doc1.insert_at(0, "A: ").unwrap(); doc2.insert_at(11, "!").unwrap();
let update1 = doc1.encode_state_as_update();
let update2 = doc2.encode_state_as_update();
doc1.apply_update(&update2, UpdateOrigin::Remote).unwrap();
doc2.apply_update(&update1, UpdateOrigin::Remote).unwrap();
assert_eq!(doc1.get_body(), doc2.get_body());
let body = doc1.get_body();
assert!(body.contains("A: "));
assert!(body.contains("!"));
}
// encode_diff against a peer's state vector carries only the missing edit.
#[test]
fn test_encode_diff() {
let storage1 = Arc::new(MemoryStorage::new());
let storage2 = Arc::new(MemoryStorage::new());
let doc1 = BodyDoc::new(storage1, "test.md".to_string());
let doc2 = BodyDoc::new(storage2, "test.md".to_string());
doc1.set_body("Initial content").unwrap();
let initial = doc1.encode_state_as_update();
doc2.apply_update(&initial, UpdateOrigin::Remote).unwrap();
let sv2 = doc2.encode_state_vector();
doc1.insert_at(0, "NEW: ").unwrap();
let diff = doc1.encode_diff(&sv2).unwrap();
doc2.apply_update(&diff, UpdateOrigin::Remote).unwrap();
assert_eq!(doc2.get_body(), "NEW: Initial content");
}
// observe_updates fires for a local set_body.
#[test]
fn test_observer_fires_on_change() {
use std::sync::atomic::{AtomicBool, Ordering};
let doc = create_body_doc("test.md");
let changed = Arc::new(AtomicBool::new(false));
let changed_clone = changed.clone();
let _sub = doc.observe_updates(move |_update| {
changed_clone.store(true, Ordering::SeqCst);
});
doc.set_body("Trigger change").unwrap();
assert!(changed.load(Ordering::SeqCst));
}
// doc_name() returns the name given at construction, path separators included.
#[test]
fn test_doc_name() {
let doc = create_body_doc("workspace/notes/hello.md");
assert_eq!(doc.doc_name(), "workspace/notes/hello.md");
}
// Regression: re-saving identical content and editing a link containing a non-ASCII
// arrow must not corrupt either replica.
#[test]
fn test_set_body_with_links_after_sync() {
let storage1 = Arc::new(MemoryStorage::new());
let storage2 = Arc::new(MemoryStorage::new());
let doc1 = BodyDoc::new(storage1, "family.md".to_string());
let doc2 = BodyDoc::new(storage2, "family.md".to_string());
let v1 =
"# Hooray, you made it!\n\nThat's all folks!\n\n[← Go back](/Message for my family.md)";
doc1.set_body(v1).unwrap();
assert_eq!(doc1.get_body(), v1);
let update = doc1.encode_state_as_update();
doc2.apply_update(&update, UpdateOrigin::Remote).unwrap();
assert_eq!(doc2.get_body(), v1);
doc2.set_body(v1).unwrap();
assert_eq!(doc2.get_body(), v1, "After re-saving same content on doc2");
let update2 = doc2.encode_state_as_update();
doc1.apply_update(&update2, UpdateOrigin::Remote).unwrap();
assert_eq!(doc1.get_body(), v1, "After sync back to doc1");
let v2 = "# Hooray, you made it!\n\nThat's all folks!\n\n[← Go back](</Message for my family.md>)";
doc1.set_body(v2).unwrap();
assert_eq!(doc1.get_body(), v2, "After adding angle brackets to link");
let update3 = doc1.encode_state_as_update();
doc2.apply_update(&update3, UpdateOrigin::Remote).unwrap();
assert_eq!(
doc2.get_body(),
v2,
"doc2 after sync of angle bracket change"
);
}
// Successive set_body diffs around a link (append, wrap, unwrap, re-wrap) stay exact.
#[test]
fn test_set_body_incremental_edits_near_links() {
let doc = create_body_doc("family.md");
let v1 = "# Hooray, you made it!\n\nThat's all folks!";
doc.set_body(v1).unwrap();
assert_eq!(doc.get_body(), v1);
let v2 =
"# Hooray, you made it!\n\nThat's all folks!\n\n[← Go back](/Message for my family.md)";
doc.set_body(v2).unwrap();
assert_eq!(doc.get_body(), v2, "After adding link");
let v3 = "# Hooray, you made it!\n\nThat's all folks!\n\n[← Go back](</Message for my family.md>)";
doc.set_body(v3).unwrap();
assert_eq!(doc.get_body(), v3, "After adding angle brackets");
let v4 =
"# Hooray, you made it!\n\nThat's all folks!\n\n[← Go back](/Message for my family.md)";
doc.set_body(v4).unwrap();
assert_eq!(doc.get_body(), v4, "After removing angle brackets");
doc.set_body(v3).unwrap();
assert_eq!(doc.get_body(), v3, "After re-adding angle brackets");
}
// Concurrent set_body edits converge, and re-saving the merged text is a no-op.
#[test]
fn test_concurrent_edits_near_links() {
let storage1 = Arc::new(MemoryStorage::new());
let storage2 = Arc::new(MemoryStorage::new());
let doc1 = BodyDoc::new(storage1, "msg.md".to_string());
let doc2 = BodyDoc::new(storage2, "msg.md".to_string());
let initial = "Welcome!\n\n[Click me!](/family.md)";
doc1.set_body(initial).unwrap();
let update = doc1.encode_state_as_update();
doc2.apply_update(&update, UpdateOrigin::Remote).unwrap();
let v1 = "Welcome!\n\nNow, click this link:\n\n[Click me!](/family.md)";
doc1.set_body(v1).unwrap();
let v2 =
"Welcome! If you are reading this, that means it works.\n\n[Click me!](/family.md)";
doc2.set_body(v2).unwrap();
let update1 = doc1.encode_state_as_update();
let update2 = doc2.encode_state_as_update();
doc1.apply_update(&update2, UpdateOrigin::Remote).unwrap();
doc2.apply_update(&update1, UpdateOrigin::Remote).unwrap();
let body1 = doc1.get_body();
let body2 = doc2.get_body();
assert_eq!(body1, body2, "Clients should converge");
doc1.set_body(&body1).unwrap();
assert_eq!(
doc1.get_body(),
body1,
"Re-saving merged content should be no-op"
);
let update3 = doc1.encode_state_as_update();
doc2.apply_update(&update3, UpdateOrigin::Remote).unwrap();
assert_eq!(
doc2.get_body(),
body1,
"doc2 should still match after re-save sync"
);
}
// Regression for a reported corruption: sync, re-save identical content, sync back.
#[test]
fn test_set_body_message_file_corruption_scenario() {
let storage1 = Arc::new(MemoryStorage::new());
let storage2 = Arc::new(MemoryStorage::new());
let doc1 = BodyDoc::new(storage1, "Message for my family.md".to_string());
let doc2 = BodyDoc::new(storage2, "Message for my family.md".to_string());
let correct_content = concat!(
"Welcome! If you are reading this, that means site generation is officially working. ",
"Welcome to the first-ever Diaryx site!\n\n",
"Now, click this other link below: it will only work if you have been added to the ",
"\"family\" group:\n\n",
"[Click me!](/family.md)"
);
doc1.set_body(correct_content).unwrap();
assert_eq!(doc1.get_body(), correct_content);
let update = doc1.encode_state_as_update();
doc2.apply_update(&update, UpdateOrigin::Remote).unwrap();
assert_eq!(doc2.get_body(), correct_content);
doc2.set_body(correct_content).unwrap();
assert_eq!(
doc2.get_body(),
correct_content,
"Re-save should not corrupt"
);
let update2 = doc2.encode_state_as_update();
doc1.apply_update(&update2, UpdateOrigin::Remote).unwrap();
assert_eq!(
doc1.get_body(),
correct_content,
"After sync back, content should be intact"
);
}
// The sync callback reports the name in effect when each update fires, so a rename
// between edits changes the reported name.
#[test]
fn test_sync_callback_uses_updated_doc_name_after_rename() {
use std::sync::{Arc, Mutex};
let doc = create_body_doc("old-name.md");
let emitted_names = Arc::new(Mutex::new(Vec::<String>::new()));
let emitted_names_clone = Arc::clone(&emitted_names);
doc.set_sync_callback(Arc::new(move |doc_name: &str, _update: &[u8]| {
emitted_names_clone
.lock()
.unwrap()
.push(doc_name.to_string());
}));
doc.set_body("v1").unwrap();
doc.set_doc_name("new-name.md".to_string());
doc.set_body("v2").unwrap();
let names = emitted_names.lock().unwrap();
assert!(!names.is_empty(), "expected sync callbacks to fire");
assert!(names.iter().any(|name| name == "old-name.md"));
assert_eq!(names.last().map(String::as_str), Some("new-name.md"));
}
}