use crate::prelude::*;
use axum::extract::ws::{Message, WebSocket};
use futures::sink::SinkExt;
use futures::stream::SplitSink;
use futures::stream::StreamExt;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::Mutex;
use yrs::sync::{Message as YMessage, SyncMessage};
use yrs::updates::decoder::Decode;
use yrs::updates::encoder::Encode;
use yrs::{Map, ReadTxn, Transact, Update};
/// Minimum number of seconds between two access/modification tracking writes
/// for the same connection (see `record_file_*_throttled`).
const TRACKING_THROTTLE_SECS: u64 = 60;
#[allow(clippy::cast_precision_loss)]
/// Converts a byte count to `f64` for percentage math; precision loss above
/// 2^53 is acceptable since the result is only used in log output.
fn usize_to_f64(v: usize) -> f64 {
v as f64
}
/// Per-WebSocket-connection state for one collaborative (yrs) document session.
struct CrdtConnection {
// Unique id of this connection, and the authenticated user it belongs to.
// The conn_id tags broadcast messages so receivers can skip their own echoes.
conn_id: String, user_id: String,
// Document being edited.
doc_id: String,
// Tenant the document belongs to.
tn_id: TnId,
// Broadcast channel pair shared by every connection to the same doc
// (looked up / created in CRDT_DOCS). Payload is (sender conn_id, raw bytes).
awareness_tx: Arc<tokio::sync::broadcast::Sender<(String, Vec<u8>)>>,
sync_tx: Arc<tokio::sync::broadcast::Sender<(String, Vec<u8>)>>,
// Throttle timestamps for activity tracking; None = never recorded yet.
last_access_update: Mutex<Option<Instant>>,
last_modify_update: Mutex<Option<Instant>>,
// Set once this connection successfully stores an update; drives the final
// modification record on disconnect (see record_final_activity).
has_modified: AtomicBool,
}
// (awareness_tx, sync_tx): the broadcast pair shared by all connections to one doc.
type DocChannels = (
Arc<tokio::sync::broadcast::Sender<(String, Vec<u8>)>>, Arc<tokio::sync::broadcast::Sender<(String, Vec<u8>)>>, );
// Registry mapping doc_id -> its broadcast channel pair.
type CrdtDocRegistry = tokio::sync::RwLock<HashMap<String, DocChannels>>;
// Process-wide registry of live CRDT documents. Entries are created on first
// connection and removed by `cleanup_registry` once the last receiver is gone.
static CRDT_DOCS: std::sync::LazyLock<CrdtDocRegistry> =
std::sync::LazyLock::new(|| tokio::sync::RwLock::new(HashMap::new()));
/// Top-level handler for one CRDT WebSocket connection.
///
/// Lifecycle: join (or create) the doc's broadcast channels, replay the stored
/// document state to the client, then run heartbeat / receive / forward tasks
/// until the socket closes. On close, activity tracking is flushed and — if
/// this was the last connection — the stored update log is compacted.
pub async fn handle_crdt_connection(
ws: WebSocket,
user_id: String,
doc_id: String,
app: App,
tn_id: TnId,
read_only: bool,
) {
// Per-connection id lets broadcast receivers filter out their own messages.
let conn_id =
cloudillo_types::utils::random_id().unwrap_or_else(|_| format!("conn-{}", now_timestamp()));
info!("CRDT connection: {} / {} (tn_id={}, conn_id={})", user_id, doc_id, tn_id.0, conn_id);
// Join or create the doc's channel pair under the registry write lock.
let (awareness_tx, sync_tx) = {
let mut docs = CRDT_DOCS.write().await;
docs.entry(doc_id.clone())
.or_insert_with(|| {
let (awareness_tx, _) = tokio::sync::broadcast::channel(256);
let (sync_tx, _) = tokio::sync::broadcast::channel(256);
(Arc::new(awareness_tx), Arc::new(sync_tx))
})
.clone()
};
let conn = Arc::new(CrdtConnection {
conn_id: conn_id.clone(),
user_id: user_id.clone(),
doc_id: doc_id.clone(),
tn_id,
awareness_tx,
sync_tx,
last_access_update: Mutex::new(None),
last_modify_update: Mutex::new(None),
has_modified: AtomicBool::new(false),
});
record_file_access_throttled(&app, &conn).await;
// The sink half is shared: heartbeat, echo and both forwarder tasks write to it.
let (ws_tx, ws_rx) = ws.split();
let ws_tx: Arc<tokio::sync::Mutex<_>> = Arc::new(tokio::sync::Mutex::new(ws_tx));
send_initial_sync(&app, tn_id, &doc_id, &user_id, &ws_tx).await;
let heartbeat_task = spawn_heartbeat_task(user_id.clone(), ws_tx.clone());
let ws_recv_task =
spawn_receive_task(conn.clone(), ws_tx.clone(), ws_rx, app.clone(), tn_id, read_only);
let sync_task =
spawn_broadcast_task(conn.clone(), ws_tx.clone(), conn.sync_tx.subscribe(), "SYNC");
let awareness_task = spawn_broadcast_task(
conn.clone(),
ws_tx.clone(),
conn.awareness_tx.subscribe(),
"AWARENESS",
);
// The receive task ending (socket closed or errored) drives teardown.
let _ = ws_recv_task.await;
debug!("WebSocket receive task ended");
record_final_activity(&app, &conn).await;
info!("CRDT connection closing for {}, aborting tasks...", user_id);
heartbeat_task.abort();
sync_task.abort();
awareness_task.abort();
// Await the aborted handles so their broadcast receivers are fully dropped
// before receiver counts are inspected during cleanup below.
let _ = tokio::join!(heartbeat_task, sync_task, awareness_task);
info!("CRDT connection closed: {} (all tasks cleaned up)", user_id);
log_doc_statistics(&app, tn_id, &conn.doc_id).await;
let should_optimize = cleanup_registry(&conn.doc_id).await;
if should_optimize {
info!("Last connection closed for doc {}, waiting before optimization...", conn.doc_id);
// Grace period: a quick client reconnect (e.g. page refresh) should not
// trigger the destructive compaction; re-check receivers afterwards.
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
let still_no_connections = {
let docs = CRDT_DOCS.read().await;
docs.get(&conn.doc_id).is_none_or(|(awareness_tx, sync_tx)| {
awareness_tx.receiver_count() == 0 && sync_tx.receiver_count() == 0
})
};
if still_no_connections {
info!(
"Confirmed no active connections for doc {}, proceeding with optimization",
conn.doc_id
);
optimize_document(&app, tn_id, &conn.doc_id).await;
} else {
info!(
"New connection established for doc {} during grace period, skipping optimization",
conn.doc_id
);
}
}
}
/// Sends the stored document state to a freshly connected client.
///
/// If the document has no stored updates yet, a minimal initial structure
/// (a "meta" map containing key "i") is created, persisted, and sent.
/// Otherwise every stored update is replayed to the client as a yrs SYNC
/// Update message. Failures are logged but never abort the connection.
async fn send_initial_sync(
app: &App,
tn_id: TnId,
doc_id: &str,
user_id: &str,
ws_tx: &Arc<tokio::sync::Mutex<SplitSink<WebSocket, Message>>>,
) {
match app.crdt_adapter.get_updates(tn_id, doc_id).await {
Ok(updates) => {
if updates.is_empty() {
info!("Document {} not initialized, creating initial structure", doc_id);
// yrs encoding is CPU work; keep it off the async runtime threads.
let initial_update = tokio::task::spawn_blocking(move || {
let doc = yrs::Doc::new();
let meta = doc.get_or_insert_map("meta");
let mut txn = doc.transact_mut();
meta.insert(&mut txn, "i", true);
drop(txn);
// Encoding against the default (empty) state vector yields the
// full document state as one update.
let state_vector = yrs::StateVector::default();
let txn = doc.transact();
txn.encode_state_as_update_v1(&state_vector)
})
.await;
if let Ok(update_data) = initial_update {
if !update_data.is_empty() {
let update = cloudillo_types::crdt_adapter::CrdtUpdate::with_client(
update_data.clone(),
"system".to_string(),
);
// Persist first; only send to the client if the store succeeded.
if let Err(e) = app.crdt_adapter.store_update(tn_id, doc_id, update).await {
warn!("Failed to store initial CRDT update for doc {}: {}", doc_id, e);
} else {
info!("Document {} initialized", doc_id);
let sync_msg = SyncMessage::Update(update_data);
let y_msg = YMessage::Sync(sync_msg);
let encoded = y_msg.encode_v1();
let ws_msg = Message::Binary(encoded.into());
let mut tx = ws_tx.lock().await;
if let Err(e) = tx.send(ws_msg).await {
warn!("Failed to send initial update to {}: {}", user_id, e);
}
}
}
}
} else {
info!(
"Sending {} initial CRDT updates to {} for doc {} (total: {} bytes)",
updates.len(),
user_id,
doc_id,
updates.iter().map(|u| u.data.len()).sum::<usize>()
);
// NOTE(review): per-update byte dumps at info level are very verbose
// for large documents; consider demoting to debug.
for (idx, update) in updates.iter().enumerate() {
info!(
" Update #{}: {} bytes, client_id={:?}, first 40 bytes: {:?}",
idx,
update.data.len(),
update.client_id,
&update.data[..40.min(update.data.len())]
);
}
// Hold the sink lock for the whole replay so updates arrive in order,
// uninterleaved with heartbeat/broadcast writes.
let mut tx = ws_tx.lock().await;
for (idx, update) in updates.iter().enumerate() {
let sync_msg = SyncMessage::Update(update.data.clone());
let y_msg = YMessage::Sync(sync_msg);
let encoded = y_msg.encode_v1();
info!(" Sending update #{}: raw={} bytes, encoded={} bytes, first 20 bytes: {:?}",
idx, update.data.len(), encoded.len(), &encoded[..20.min(encoded.len())]);
let ws_msg = Message::Binary(encoded.into());
if let Err(e) = tx.send(ws_msg).await {
warn!("Failed to send initial update to {}: {}", user_id, e);
break;
}
}
}
}
Err(e) => {
warn!("Failed to get initial CRDT updates for doc {}: {}", doc_id, e);
}
}
}
/// Spawns a background task that pings the client every 15 seconds so idle
/// connections (and intermediate proxies) keep the WebSocket alive. The task
/// exits as soon as a ping fails to send, i.e. the client is gone.
fn spawn_heartbeat_task(
    user_id: String,
    ws_tx: Arc<tokio::sync::Mutex<SplitSink<WebSocket, Message>>>,
) -> tokio::task::JoinHandle<()> {
    const HEARTBEAT_PERIOD: std::time::Duration = std::time::Duration::from_secs(15);
    tokio::spawn(async move {
        let mut ticker = tokio::time::interval(HEARTBEAT_PERIOD);
        loop {
            ticker.tick().await;
            debug!("CRDT heartbeat: {}", user_id);
            // Lock only for the duration of the send.
            let send_result = ws_tx.lock().await.send(Message::Ping(Vec::new().into())).await;
            if send_result.is_err() {
                debug!("Client disconnected during heartbeat");
                return;
            }
        }
    })
}
/// Spawns the task that drains incoming WebSocket frames and routes binary
/// yrs messages to `handle_yrs_message`. The task ends when the stream is
/// exhausted or errors, which drives the whole connection teardown.
fn spawn_receive_task(
    conn: Arc<CrdtConnection>,
    ws_tx: Arc<tokio::sync::Mutex<SplitSink<WebSocket, Message>>>,
    ws_rx: futures::stream::SplitStream<WebSocket>,
    app: App,
    tn_id: TnId,
    read_only: bool,
) -> tokio::task::JoinHandle<()> {
    let mut ws_rx = ws_rx;
    tokio::spawn(async move {
        loop {
            let Some(msg) = ws_rx.next().await else { break };
            match msg {
                Err(e) => {
                    warn!("CRDT connection error: {}", e);
                    break;
                }
                Ok(Message::Binary(data)) => {
                    handle_yrs_message(&conn, &data, &ws_tx, &app, tn_id, read_only).await;
                }
                // Control frames need no application-level handling.
                Ok(Message::Close(_) | Message::Ping(_) | Message::Pong(_)) => {}
                Ok(_) => {
                    warn!("Received non-binary WebSocket message");
                }
            }
        }
    })
}
fn spawn_broadcast_task(
conn: Arc<CrdtConnection>,
ws_tx: Arc<tokio::sync::Mutex<SplitSink<WebSocket, Message>>>,
mut rx: tokio::sync::broadcast::Receiver<(String, Vec<u8>)>,
label: &'static str,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
debug!(
"Connection {} (user {}) subscribed to {} broadcasts for doc {}",
conn.conn_id, conn.user_id, label, conn.doc_id
);
loop {
match rx.recv().await {
Ok((sender_conn_id, data)) => {
debug!(
"{} broadcast received by conn {}: from conn {} for doc {} ({} bytes)",
label,
conn.conn_id,
sender_conn_id,
conn.doc_id,
data.len()
);
if sender_conn_id == conn.conn_id {
debug!("Skipping {} echo to self for conn {}", label, conn.conn_id);
continue;
}
let ws_msg = Message::Binary(data.into());
debug!(
"Forwarding {} update from conn {} to conn {} (user {}) for doc {}",
label, sender_conn_id, conn.conn_id, conn.user_id, conn.doc_id
);
let mut tx = ws_tx.lock().await;
if tx.send(ws_msg).await.is_err() {
debug!("Client disconnected while forwarding {} update", label);
return;
}
debug!("{} update successfully forwarded to conn {}", label, conn.conn_id);
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
if label == "SYNC" {
warn!(
"Client {} lagged behind on {} updates for doc {}",
conn.user_id, label, conn.doc_id
);
} else {
debug!("Connection {} lagged on {} updates", conn.conn_id, label);
}
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
debug!("{} broadcast channel closed", label);
return;
}
}
}
})
}
/// Publishes `payload` on a doc broadcast channel, tagged with the sending
/// connection id so subscribers can filter out their own messages.
fn broadcast_message(
    tx: &Arc<tokio::sync::broadcast::Sender<(String, Vec<u8>)>>,
    conn_id: &str,
    user_id: &str,
    doc_id: &str,
    payload: Vec<u8>,
    label: &str,
) {
    let Ok(receiver_count) = tx.send((conn_id.to_string(), payload)) else {
        // send() only fails when nobody is subscribed; that is normal churn.
        debug!("CRDT {} broadcast failed - no receivers for doc {}", label, doc_id);
        return;
    };
    // Awareness traffic is too chatty for info-level logging.
    if label != "AWARENESS" {
        info!(
            "CRDT {} broadcast from conn {} (user {}) for doc {}: {} receivers",
            label, conn_id, user_id, doc_id, receiver_count
        );
    }
}
/// Sends the raw message bytes back to the connection they came from.
async fn send_echo_raw(
    ws_tx: &Arc<tokio::sync::Mutex<SplitSink<WebSocket, Message>>>,
    conn_id: &str,
    user_id: &str,
    doc_id: &str,
    payload: &[u8],
    label: &str,
) {
    let frame = Message::Binary(payload.to_vec().into());
    let result = ws_tx.lock().await.send(frame).await;
    match result {
        Err(e) => {
            warn!("Failed to send CRDT {} echo to conn {}: {}", label, conn_id, e);
        }
        Ok(()) => {
            debug!(
                "CRDT {} echo sent back to conn {} (user {}) for doc {} ({} bytes)",
                label,
                conn_id,
                user_id,
                doc_id,
                payload.len()
            );
        }
    }
}
/// Decodes one binary yrs protocol message from a client and dispatches it.
///
/// - Sync/Update: validated, persisted, then broadcast to the doc's other
///   connections and echoed back to the sender; rejected for read-only users.
/// - Sync/SyncStep1 and Sync/SyncStep2: relayed without being stored.
/// - Awareness: relayed (and echoed) without being stored.
async fn handle_yrs_message(
conn: &Arc<CrdtConnection>,
data: &[u8],
ws_tx: &Arc<tokio::sync::Mutex<SplitSink<WebSocket, Message>>>,
app: &App,
tn_id: TnId,
read_only: bool,
) {
if data.is_empty() {
warn!("Empty message from conn {}", conn.conn_id);
return;
}
match YMessage::decode_v1(data) {
Ok(YMessage::Sync(sync_msg)) => {
debug!(
"CRDT SYNC message from conn {} (user {}) for doc {}: {:?}",
conn.conn_id,
conn.user_id,
conn.doc_id,
match &sync_msg {
SyncMessage::SyncStep1(_) => "SyncStep1",
SyncMessage::SyncStep2(_) => "SyncStep2",
SyncMessage::Update(_) => "Update",
}
);
// Only full Update payloads are persisted; handshake steps are relayed only.
let update_data =
match &sync_msg {
SyncMessage::Update(data) => {
if read_only {
warn!(
"Rejecting CRDT Update from read-only user {} for doc {} ({} bytes)",
conn.user_id, conn.doc_id, data.len()
);
return;
}
if data.is_empty() {
debug!("Received empty Update message from conn {}", conn.conn_id);
None
} else {
Some(data.clone())
}
}
SyncMessage::SyncStep2(data) => {
// NOTE(review): SyncStep2 also carries document content and is
// still broadcast below even for read-only users — confirm this
// is intended.
debug!(
"Received SyncStep2 from conn {} ({} bytes) - not storing",
conn.conn_id,
data.len()
);
None
}
SyncMessage::SyncStep1(_) => None,
};
if let Some(data) = update_data {
// Validate before persisting so a corrupt update never enters the store.
if let Err(e) = Update::decode_v1(&data) {
warn!(
"Rejecting malformed update from conn {} - decode failed: {}",
conn.conn_id, e
);
return;
}
let update = cloudillo_types::crdt_adapter::CrdtUpdate::with_client(
data.clone(),
conn.user_id.clone(),
);
match app.crdt_adapter.store_update(tn_id, &conn.doc_id, update).await {
Ok(()) => {
info!(
"✓ CRDT update stored for doc {} from user {} ({} bytes)",
conn.doc_id,
conn.user_id,
data.len()
);
record_file_modification_throttled(app, conn).await;
}
Err(e) => {
// Deliberately skip the broadcast: peers must not see state the
// server failed to persist.
warn!("❌ CRDT update FAILED to store for doc {} from user {}: {} - NOT broadcasting to prevent data loss", conn.doc_id, conn.user_id, e);
return;
}
}
}
// Relay the raw message bytes to every other connection on this doc.
broadcast_message(
&conn.sync_tx,
&conn.conn_id,
&conn.user_id,
&conn.doc_id,
data.to_vec(),
"SYNC",
);
// NOTE(review): echoing the sender's own message back is unusual for the
// y-protocol; presumably the client tolerates re-applying its own update
// (CRDT idempotence) — confirm against the client implementation.
send_echo_raw(ws_tx, &conn.conn_id, &conn.user_id, &conn.doc_id, data, "SYNC").await;
}
Ok(YMessage::Awareness(_awareness_update)) => {
debug!(
"CRDT AWARENESS from conn {} (user {}) for doc {} ({} bytes)",
conn.conn_id,
conn.user_id,
conn.doc_id,
data.len()
);
broadcast_message(
&conn.awareness_tx,
&conn.conn_id,
&conn.user_id,
&conn.doc_id,
data.to_vec(),
"AWARENESS",
);
send_echo_raw(ws_tx, &conn.conn_id, &conn.user_id, &conn.doc_id, data, "AWARENESS")
.await;
}
Ok(other) => {
debug!("Received non-sync/awareness message: {:?}", other);
}
Err(e) => {
warn!("Failed to decode yrs message from conn {}: {}", conn.conn_id, e);
}
}
}
/// Logs update count and size statistics for a document's stored update log.
async fn log_doc_statistics(app: &App, tn_id: TnId, doc_id: &str) {
    let updates = match app.crdt_adapter.get_updates(tn_id, doc_id).await {
        Ok(u) => u,
        Err(e) => {
            warn!("Failed to get statistics for doc {}: {}", doc_id, e);
            return;
        }
    };
    let update_count = updates.len();
    let total_size: usize = updates.iter().map(|u| u.data.len()).sum();
    // checked_div avoids the divide-by-zero on an empty log.
    let avg_size = total_size.checked_div(update_count).unwrap_or(0);
    info!(
        "CRDT doc stats [{}]: {} updates, {} bytes total, {} bytes avg",
        doc_id, update_count, total_size, avg_size
    );
}
/// Compacts a document's stored update log into a single merged update.
///
/// All stored updates are decoded (corrupt ones skipped with a warning),
/// applied to a fresh yrs `Doc`, and re-encoded as one full-state update.
/// The merged result replaces the log only if it is non-empty and strictly
/// smaller than the sum of the original updates.
async fn optimize_document(app: &App, tn_id: TnId, doc_id: &str) {
let updates = match app.crdt_adapter.get_updates(tn_id, doc_id).await {
Ok(u) => u,
Err(e) => {
warn!("Failed to get updates for optimization of doc {}: {}", doc_id, e);
return;
}
};
// Nothing to merge with zero or one stored update.
if updates.len() <= 1 {
debug!("Skipping optimization for doc {} (only {} updates)", doc_id, updates.len());
return;
}
let updates_before = updates.len();
let size_before: usize = updates.iter().map(|u| u.data.len()).sum();
let doc_id_for_task = doc_id.to_string();
// Decoding and merging is CPU-bound; run it off the async runtime threads.
let (merged_update, skipped_count) = match tokio::task::spawn_blocking(move || {
let mut decoded_updates = Vec::new();
let mut skipped = 0;
for (idx, update) in updates.iter().enumerate() {
if update.data.is_empty() {
warn!("Skipping empty update #{} for doc {}", idx, &doc_id_for_task);
skipped += 1;
continue;
}
let decoded = match yrs::Update::decode_v1(&update.data) {
Ok(u) => u,
Err(e) => {
warn!(
"Failed to decode update #{} for doc {} (size: {} bytes, first 20 bytes: {:?}): {}",
idx,
&doc_id_for_task,
update.data.len(),
&update.data[..20.min(update.data.len())],
e
);
skipped += 1;
continue;
}
};
decoded_updates.push(decoded);
}
if decoded_updates.is_empty() {
return Err(format!(
"No valid updates to merge (all {} updates corrupted)",
updates.len()
));
}
let update_count = decoded_updates.len();
info!(
"Merging {} valid updates for doc {} ({} skipped)",
update_count, &doc_id_for_task, skipped
);
info!("Using Doc-based merge for {} updates", update_count);
// Applying every update to a fresh Doc lets yrs deduplicate and compact
// overwritten content rather than just concatenating updates.
let doc = yrs::Doc::new();
let mut failed_count = 0;
{
let mut txn = doc.transact_mut();
for (idx, decoded_update) in decoded_updates.into_iter().enumerate() {
match txn.apply_update(decoded_update) {
Ok(()) => {
debug!(
"Applied update #{} successfully during merge for doc {}",
idx, &doc_id_for_task
);
}
Err(e) => {
warn!(
"Failed to apply update #{} during merge for doc {}: {}",
idx, &doc_id_for_task, e
);
failed_count += 1;
}
}
}
}
if failed_count > 0 {
warn!(
"Optimization warning for doc {}: {} out of {} updates failed to apply",
&doc_id_for_task, failed_count, update_count
);
}
// Encoding against an empty state vector captures the full document state.
let state_vector = yrs::StateVector::default();
let txn = doc.transact();
let encoded = txn.encode_state_as_update_v1(&state_vector);
info!(
"Doc-based merge complete for {}: {} updates merged into {} bytes",
&doc_id_for_task,
update_count,
encoded.len()
);
// An empty merge result would wipe the document when stored; abort instead.
if encoded.is_empty() {
return Err(format!(
"Merged update for {} is empty (0 bytes)! This would cause data loss. Aborting optimization.",
&doc_id_for_task
));
}
info!("Merged update validation passed, proceeding with optimization");
Ok((encoded, skipped))
})
.await
{
Ok(Ok(result)) => result,
Ok(Err(e)) => {
warn!("Failed to merge updates for doc {}: {}", doc_id, e);
return;
}
Err(e) => {
warn!("Failed to spawn blocking task for doc {}: {}", doc_id, e);
return;
}
};
let size_after = merged_update.len();
info!(
"Optimization size check for doc {}: before={} bytes, after={} bytes, reduction={} bytes",
doc_id,
size_before,
size_after,
size_before.saturating_sub(size_after)
);
// Only replace the log when compaction actually shrinks the stored size.
if size_after >= size_before {
info!(
"Skipping optimization for doc {} (no size reduction: {} -> {})",
doc_id, size_before, size_after
);
return;
}
info!("Proceeding with optimization for doc {} (delete + store)", doc_id);
// NOTE(review): delete + store is not atomic — a crash between the two calls
// loses the document. Confirm whether the adapter offers a transactional
// replace; until then this window is an accepted risk.
if let Err(e) = app.crdt_adapter.delete_doc(tn_id, doc_id).await {
warn!("Failed to delete doc {} during optimization: {}", doc_id, e);
return;
}
let merged_crdt_update = cloudillo_types::crdt_adapter::CrdtUpdate::with_client(
merged_update,
"system".to_string(), );
if let Err(e) = app.crdt_adapter.store_update(tn_id, doc_id, merged_crdt_update).await {
warn!("Failed to store optimized update for doc {}: {}", doc_id, e);
return;
}
// Safe subtraction: guarded by the size_after >= size_before early return above.
let size_reduction = size_before - size_after;
let reduction_percent = (usize_to_f64(size_reduction) / usize_to_f64(size_before)) * 100.0;
let skipped_msg = if skipped_count > 0 {
format!(", {} corrupted updates skipped", skipped_count)
} else {
String::new()
};
info!(
"CRDT doc optimized [{}]: {} -> 1 updates, {} -> {} bytes ({:.1}% reduction){}",
doc_id, updates_before, size_before, size_after, reduction_percent, skipped_msg
);
}
/// Removes the broadcast-channel pair for `doc_id` from the registry when the
/// last subscriber has gone away. Returns `true` when the entry was removed,
/// signalling the caller that document optimization may proceed.
///
/// The receiver-count check and the removal happen under a single write lock.
/// The previous read-check / drop / write-remove sequence had a TOCTOU window:
/// a new connection could join the existing channels between the check and the
/// removal, after which the entry was removed anyway and a subsequent
/// connection would create a second, disjoint channel pair — splitting the
/// document's broadcast group.
async fn cleanup_registry(doc_id: &str) -> bool {
    let mut docs = CRDT_DOCS.write().await;
    if let Some((awareness_tx, sync_tx)) = docs.get(doc_id) {
        let awareness_count = awareness_tx.receiver_count();
        let sync_count = sync_tx.receiver_count();
        info!(
            "Checking CRDT registry cleanup for doc {}: {} awareness receivers, {} sync receivers",
            doc_id, awareness_count, sync_count
        );
        if awareness_count == 0 && sync_count == 0 {
            // Still under the write lock, so no connection can have joined
            // between the count check and this removal.
            docs.remove(doc_id);
            info!("✓ Cleaned up CRDT registry for doc {} - triggering optimization", doc_id);
            return true;
        }
        info!(
            "✗ Not cleaning up doc {} - still has active receivers (awareness: {}, sync: {})",
            doc_id, awareness_count, sync_count
        );
    } else {
        info!("✗ Doc {} not found in registry during cleanup", doc_id);
    }
    false
}
/// Seconds since the Unix epoch; yields 0 if the system clock reports a time
/// before the epoch (so this never panics).
fn now_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0)
}
/// Records a file-access event for this connection, rate-limited to once per
/// `TRACKING_THROTTLE_SECS`. Adapter failures are logged at debug level only.
async fn record_file_access_throttled(app: &App, conn: &CrdtConnection) {
    let now = Instant::now();
    // Decide (and stamp) under the mutex, but release it before the adapter call.
    let due = {
        let mut last_update = conn.last_access_update.lock().await;
        let elapsed_enough = last_update
            .map_or(true, |last| now.duration_since(last).as_secs() >= TRACKING_THROTTLE_SECS);
        if elapsed_enough {
            *last_update = Some(now);
        }
        elapsed_enough
    };
    if !due {
        return;
    }
    if let Err(e) =
        app.meta_adapter.record_file_access(conn.tn_id, &conn.user_id, &conn.doc_id).await
    {
        debug!("Failed to record file access for doc {}: {}", conn.doc_id, e);
    }
}
/// Records a file-modification event for this connection, rate-limited to once
/// per `TRACKING_THROTTLE_SECS`. Always marks the connection as having
/// modified the doc, so the final flush on disconnect can report it.
async fn record_file_modification_throttled(app: &App, conn: &CrdtConnection) {
    // Remember that this connection wrote at least once, even when throttled.
    conn.has_modified.store(true, Ordering::Relaxed);
    let now = Instant::now();
    // Decide (and stamp) under the mutex, but release it before the adapter call.
    let due = {
        let mut last_update = conn.last_modify_update.lock().await;
        let elapsed_enough = last_update
            .map_or(true, |last| now.duration_since(last).as_secs() >= TRACKING_THROTTLE_SECS);
        if elapsed_enough {
            *last_update = Some(now);
        }
        elapsed_enough
    };
    if !due {
        return;
    }
    if let Err(e) =
        app.meta_adapter.record_file_modification(conn.tn_id, &conn.user_id, &conn.doc_id).await
    {
        debug!("Failed to record file modification for doc {}: {}", conn.doc_id, e);
    }
}
/// Flushes activity tracking on disconnect, bypassing the throttle: always
/// records a final access, plus a final modification if this connection ever
/// stored an update.
async fn record_final_activity(app: &App, conn: &CrdtConnection) {
    let access_result =
        app.meta_adapter.record_file_access(conn.tn_id, &conn.user_id, &conn.doc_id).await;
    if let Err(e) = access_result {
        debug!("Failed to record final file access for doc {}: {}", conn.doc_id, e);
    }
    if !conn.has_modified.load(Ordering::Relaxed) {
        return;
    }
    let modify_result =
        app.meta_adapter.record_file_modification(conn.tn_id, &conn.user_id, &conn.doc_id).await;
    if let Err(e) = modify_result {
        debug!("Failed to record final file modification for doc {}: {}", conn.doc_id, e);
    }
}