use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use postgres::Client;
use pylon_crdt::{
apply_patch, apply_update as crdt_apply_update, encode_snapshot, encode_update_since,
loro::{LoroDoc, VersionVector},
project_doc_to_json, CrdtField,
};
use pylon_storage::pg_exec::PgConn;
use serde_json::Value;
use crate::loro_store::LoroStoreError;
/// DDL for the Postgres sidecar table that holds one encoded Loro snapshot
/// per `(entity, row_id)` pair.
///
/// `IF NOT EXISTS` makes re-running it a no-op once the table exists. Each
/// trailing `\` in this literal is a Rust line continuation, so the newline and
/// the next line's leading whitespace are not part of the SQL text.
pub const CREATE_PG_SIDECAR_SQL: &str = "\
CREATE TABLE IF NOT EXISTS _pylon_crdt_snapshots (\
entity text NOT NULL,\
row_id text NOT NULL,\
snapshot bytea NOT NULL,\
updated_at timestamptz NOT NULL DEFAULT now(),\
PRIMARY KEY (entity, row_id)\
)";
/// Creates the `_pylon_crdt_snapshots` sidecar table when it is missing.
///
/// # Errors
/// Returns `LoroStoreError::Storage` carrying the driver error text when the
/// DDL statement fails.
pub fn ensure_sidecar(client: &mut Client) -> Result<(), LoroStoreError> {
    match client.execute(CREATE_PG_SIDECAR_SQL, &[]) {
        Ok(_rows) => Ok(()),
        Err(err) => Err(LoroStoreError::Storage(format!("create pg sidecar: {err}"))),
    }
}
/// Postgres-backed store of Loro CRDT documents, one document per
/// `(entity, row_id)`.
///
/// Write paths re-read the persisted snapshot under an advisory lock keyed on
/// hashes of the entity and row id. The `docs` map is an in-process cache that
/// the read paths consult and that is maintained through `cache_after_commit`
/// and `evict`.
#[derive(Default)]
pub struct PgLoroStore {
    // Hydrated documents keyed by `(entity, row_id)`. The outer mutex guards
    // the map itself; each document has its own mutex so encoding one row's
    // document does not require holding the map lock.
    docs: Mutex<HashMap<(String, String), Arc<Mutex<LoroDoc>>>>,
}
impl PgLoroStore {
    /// Creates a store whose in-memory document cache starts empty.
    pub fn new() -> Self {
        Self::default()
    }

    /// Loads the authoritative document for `(entity, row_id)` ahead of a write.
    ///
    /// It first takes `pg_advisory_xact_lock` on the hashed entity / row id
    /// pair. Postgres releases that lock when the surrounding transaction ends,
    /// so concurrent writers of the same row queue up here. It then replays the
    /// persisted snapshot into a fresh `LoroDoc`; a missing row yields an empty
    /// document. The in-memory cache is neither consulted nor updated.
    ///
    /// # Errors
    /// `LoroStoreError::Storage` if taking the lock or reading the snapshot
    /// fails; `LoroStoreError::Decode` if the stored bytes cannot be applied.
    fn hydrate_for_write<C: PgConn>(
        conn: &mut C,
        entity: &str,
        row_id: &str,
    ) -> Result<LoroDoc, LoroStoreError> {
        let (entity_lock_key, row_lock_key) = (pg_advisory_key(entity), pg_advisory_key(row_id));
        if let Err(e) = conn.execute(
            "SELECT pg_advisory_xact_lock($1::int, $2::int)",
            &[&entity_lock_key, &row_lock_key],
        ) {
            return Err(LoroStoreError::Storage(format!("crdt advisory lock: {e}")));
        }

        let persisted = match conn.query_opt(
            "SELECT snapshot FROM _pylon_crdt_snapshots \
             WHERE entity = $1 AND row_id = $2",
            &[&entity, &row_id],
        ) {
            Ok(found) => found.map(|row| row.get::<_, Vec<u8>>(0)),
            Err(e) => return Err(LoroStoreError::Storage(format!("read pg snapshot: {e}"))),
        };

        let doc = LoroDoc::new();
        if let Some(bytes) = persisted {
            if let Err(e) = crdt_apply_update(&doc, &bytes) {
                return Err(LoroStoreError::Decode(e));
            }
        }
        Ok(doc)
    }
}
/// Maps `s` to a 32-bit key for `pg_advisory_xact_lock(int, int)`.
///
/// The hash is 32-bit FNV-1a over the UTF-8 bytes of `s`. It replaces
/// `std::collections::hash_map::DefaultHasher`, whose algorithm the standard
/// library leaves unspecified across Rust releases. FNV-1a is a fixed,
/// published function, so every process sharing the database derives the same
/// key for the same string, including binaries built with different
/// toolchains. That agreement is what lets the advisory lock actually exclude
/// concurrent writers.
///
/// Distinct strings may collide. A collision only makes unrelated rows share a
/// lock (extra serialization); it never breaks correctness.
///
/// Note: the key values differ from the previous `DefaultHasher`-based ones.
/// Old and new binaries running concurrently against the same database (e.g.
/// mid rolling deploy) will not exclude each other.
fn pg_advisory_key(s: &str) -> i32 {
    const FNV_OFFSET_BASIS: u32 = 0x811c_9dc5;
    const FNV_PRIME: u32 = 0x0100_0193;
    let hash = s
        .bytes()
        .fold(FNV_OFFSET_BASIS, |acc, byte| {
            (acc ^ u32::from(byte)).wrapping_mul(FNV_PRIME)
        });
    // Same-width bit reinterpretation: every u32 maps to exactly one i32, so
    // no information is truncated.
    i32::from_ne_bytes(hash.to_ne_bytes())
}
impl PgLoroStore {
    /// Query shared by every path in this impl that loads a persisted snapshot.
    const SELECT_SNAPSHOT_SQL: &'static str =
        "SELECT snapshot FROM _pylon_crdt_snapshots WHERE entity = $1 AND row_id = $2";

    /// Locks the in-memory document cache, recovering the guard if a previous
    /// holder panicked.
    ///
    /// Every critical section on `docs` in this impl is a single
    /// lookup/insert/remove, so a poisoned mutex still guards a consistent map.
    /// Recovering keeps one panicking request from making every later cache
    /// access in the process panic too.
    fn cache(
        &self,
    ) -> std::sync::MutexGuard<'_, HashMap<(String, String), Arc<Mutex<LoroDoc>>>> {
        self.docs
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Reads the persisted snapshot bytes for `(entity, row_id)`.
    ///
    /// Returns `Ok(None)` when no sidecar row exists.
    ///
    /// # Errors
    /// `LoroStoreError::Storage` if the query fails.
    fn fetch_snapshot_bytes<C: PgConn>(
        conn: &mut C,
        entity: &str,
        row_id: &str,
    ) -> Result<Option<Vec<u8>>, LoroStoreError> {
        conn.query_opt(Self::SELECT_SNAPSHOT_SQL, &[&entity, &row_id])
            .map_err(|e| LoroStoreError::Storage(format!("read pg snapshot: {e}")))
            .map(|found| found.map(|r| r.get::<_, Vec<u8>>(0)))
    }

    /// Builds a document by replaying `bytes`; `None` yields an empty document.
    ///
    /// # Errors
    /// `LoroStoreError::Decode` if the bytes cannot be applied.
    fn doc_from_snapshot(bytes: Option<&[u8]>) -> Result<LoroDoc, LoroStoreError> {
        let doc = LoroDoc::new();
        if let Some(bytes) = bytes {
            crdt_apply_update(&doc, bytes).map_err(LoroStoreError::Decode)?;
        }
        Ok(doc)
    }

    /// Returns the cached document for `(entity, row_id)`, hydrating it from
    /// Postgres on a miss.
    ///
    /// Hydration runs without holding the cache lock. If another thread
    /// inserted the same key meanwhile, its handle wins so all readers share
    /// one document.
    fn get_or_hydrate_read<C: PgConn>(
        &self,
        conn: &mut C,
        entity: &str,
        row_id: &str,
    ) -> Result<Arc<Mutex<LoroDoc>>, LoroStoreError> {
        let key = (entity.to_string(), row_id.to_string());
        if let Some(doc) = self.cache().get(&key) {
            return Ok(Arc::clone(doc));
        }
        let bytes = Self::fetch_snapshot_bytes(conn, entity, row_id)?;
        let doc = Self::doc_from_snapshot(bytes.as_deref())?;
        let mut cache = self.cache();
        let shared = cache
            .entry(key)
            .or_insert_with(|| Arc::new(Mutex::new(doc)));
        Ok(Arc::clone(shared))
    }

    /// Upserts the full encoded snapshot of `doc` into the sidecar table and
    /// bumps `updated_at`.
    fn persist_snapshot<C: PgConn>(
        conn: &mut C,
        entity: &str,
        row_id: &str,
        doc: &LoroDoc,
    ) -> Result<(), LoroStoreError> {
        let snap = encode_snapshot(doc);
        conn.execute(
            "INSERT INTO _pylon_crdt_snapshots (entity, row_id, snapshot, updated_at) \
             VALUES ($1, $2, $3, now()) \
             ON CONFLICT (entity, row_id) DO UPDATE \
             SET snapshot = EXCLUDED.snapshot, updated_at = EXCLUDED.updated_at",
            &[&entity, &row_id, &snap],
        )
        .map(|_| ())
        .map_err(|e| LoroStoreError::Storage(format!("persist pg snapshot: {e}")))
    }

    /// Applies a JSON `patch` to the row's document and persists the result.
    ///
    /// The row is locked and re-read via `hydrate_for_write`, and the
    /// projection of the updated document is returned. The in-memory cache is
    /// not touched. Callers presumably refresh it after commit via
    /// `cache_after_commit` or `evict` (the Pg hook in this file evicts).
    pub fn apply_patch<C: PgConn>(
        &self,
        conn: &mut C,
        entity: &str,
        row_id: &str,
        fields: &[CrdtField],
        patch: &Value,
    ) -> Result<Value, LoroStoreError> {
        let doc = Self::hydrate_for_write(conn, entity, row_id)?;
        apply_patch(&doc, fields, patch).map_err(LoroStoreError::Apply)?;
        Self::persist_snapshot(conn, entity, row_id, &doc)?;
        Ok(project_doc_to_json(&doc, fields))
    }

    /// Merges an encoded remote Loro `update` into the row's document,
    /// persists it, and returns the JSON projection.
    ///
    /// The locking and cache behaviour is the same as `apply_patch`.
    pub fn apply_remote_update<C: PgConn>(
        &self,
        conn: &mut C,
        entity: &str,
        row_id: &str,
        fields: &[CrdtField],
        update: &[u8],
    ) -> Result<Value, LoroStoreError> {
        let doc = Self::hydrate_for_write(conn, entity, row_id)?;
        crdt_apply_update(&doc, update).map_err(LoroStoreError::Decode)?;
        Self::persist_snapshot(conn, entity, row_id, &doc)?;
        Ok(project_doc_to_json(&doc, fields))
    }

    /// Encodes a full snapshot of the row's document, served from the
    /// in-process cache when present.
    ///
    /// Writes by other processes are not visible until the cache entry is
    /// refreshed or evicted.
    pub fn snapshot<C: PgConn>(
        &self,
        conn: &mut C,
        entity: &str,
        row_id: &str,
    ) -> Result<Vec<u8>, LoroStoreError> {
        let handle = self.get_or_hydrate_read(conn, entity, row_id)?;
        let doc = handle
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        Ok(encode_snapshot(&doc))
    }

    /// Encodes the updates the row's (cached) document holds beyond `since`.
    pub fn update_since<C: PgConn>(
        &self,
        conn: &mut C,
        entity: &str,
        row_id: &str,
        since: &VersionVector,
    ) -> Result<Vec<u8>, LoroStoreError> {
        let handle = self.get_or_hydrate_read(conn, entity, row_id)?;
        let doc = handle
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        Ok(encode_update_since(&doc, since))
    }

    /// Reads the persisted snapshot straight from `conn`, bypassing the cache.
    ///
    /// A missing row or empty bytes yields an empty document's snapshot. The
    /// bypass is presumably so a caller inside a transaction sees its own
    /// uncommitted writes; confirm against callers.
    pub fn read_snapshot_via_conn<C: PgConn>(
        conn: &mut C,
        entity: &str,
        row_id: &str,
    ) -> Result<Vec<u8>, LoroStoreError> {
        match Self::fetch_snapshot_bytes(conn, entity, row_id)? {
            Some(bytes) if !bytes.is_empty() => Ok(bytes),
            _ => Ok(encode_snapshot(&LoroDoc::new())),
        }
    }

    /// Best-effort refresh of the cache entry from the persisted snapshot.
    ///
    /// Any read or decode failure, or a missing row, evicts the entry instead,
    /// so the next read re-hydrates from Postgres.
    pub fn cache_after_commit<C: PgConn>(&self, conn: &mut C, entity: &str, row_id: &str) {
        let fresh = match Self::fetch_snapshot_bytes(conn, entity, row_id) {
            Ok(Some(bytes)) => Self::doc_from_snapshot(Some(bytes.as_slice())).ok(),
            _ => None,
        };
        match fresh {
            Some(doc) => {
                self.cache().insert(
                    (entity.to_string(), row_id.to_string()),
                    Arc::new(Mutex::new(doc)),
                );
            }
            None => self.evict(entity, row_id),
        }
    }

    /// Drops the cached document for `(entity, row_id)`, if any.
    pub fn evict(&self, entity: &str, row_id: &str) {
        self.cache()
            .remove(&(entity.to_string(), row_id.to_string()));
    }

    /// Number of documents currently held in the in-memory cache.
    pub fn cached_rows(&self) -> usize {
        self.cache().len()
    }
}
use pylon_kernel::AppManifest;
use pylon_storage::pg_tx_store::PgCrdtHook;
/// `PgCrdtHook` implementation that mirrors entity writes into the CRDT
/// sidecar table using the caller's `postgres::Transaction`.
pub struct PgCrdtHookImpl {
    /// Store that hook callbacks apply patches through and evict cache
    /// entries from.
    pub crdt: std::sync::Arc<PgLoroStore>,
    /// Manifest used to find an entity's fields and their CRDT kinds.
    pub manifest: std::sync::Arc<AppManifest>,
}
impl PgCrdtHook for PgCrdtHookImpl {
fn before_insert(
&self,
tx: &mut postgres::Transaction<'_>,
entity: &str,
data: &serde_json::Value,
) -> Result<Option<serde_json::Value>, pylon_http::DataError> {
let ent = self
.manifest
.entities
.iter()
.find(|e| e.name == entity)
.ok_or_else(|| pylon_http::DataError {
code: "ENTITY_NOT_FOUND".into(),
message: format!("Unknown entity: {entity}"),
})?;
let crdt_fields = crdt_fields_for(ent)?;
let id = data
.get("id")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.unwrap_or_else(crate::generate_id);
self.crdt
.apply_patch(tx, entity, &id, &crdt_fields, data)
.map_err(|e| pylon_http::DataError {
code: "CRDT_APPLY_FAILED".into(),
message: format!("crdt write {entity}/{id}: {e}"),
})?;
let mut row = data.clone();
if let Some(obj) = row.as_object_mut() {
obj.insert("id".into(), serde_json::Value::String(id.clone()));
}
Ok(Some(row))
}
fn before_update(
&self,
tx: &mut postgres::Transaction<'_>,
entity: &str,
id: &str,
data: &serde_json::Value,
) -> Result<(), pylon_http::DataError> {
let ent = self
.manifest
.entities
.iter()
.find(|e| e.name == entity)
.ok_or_else(|| pylon_http::DataError {
code: "ENTITY_NOT_FOUND".into(),
message: format!("Unknown entity: {entity}"),
})?;
let crdt_fields = crdt_fields_for(ent)?;
self.crdt
.apply_patch(tx, entity, id, &crdt_fields, data)
.map(|_| ())
.map_err(|e| pylon_http::DataError {
code: "CRDT_APPLY_FAILED".into(),
message: format!("crdt update {entity}/{id}: {e}"),
})
}
fn before_delete(
&self,
tx: &mut postgres::Transaction<'_>,
entity: &str,
id: &str,
) -> Result<(), pylon_http::DataError> {
tx.execute(
"DELETE FROM _pylon_crdt_snapshots WHERE entity = $1 AND row_id = $2",
&[&entity, &id],
)
.map(|_| ())
.map_err(|e| pylon_http::DataError {
code: "CRDT_SIDECAR_DELETE_FAILED".into(),
message: format!("delete pg crdt snapshot {entity}/{id}: {e}"),
})
}
fn after_commit(&self, entity: &str, id: &str) {
self.crdt.evict(entity, id);
}
fn on_rollback(&self, entity: &str, id: &str) {
self.crdt.evict(entity, id);
}
}
/// Derives the CRDT field list for a manifest entity, skipping the `id` column.
///
/// # Errors
/// Returns `INVALID_CRDT_FIELD` for the first field whose declared type / crdt
/// setting `pylon_crdt::field_kind` rejects; later fields are not examined.
fn crdt_fields_for(
    ent: &pylon_kernel::ManifestEntity,
) -> Result<Vec<pylon_crdt::CrdtField>, pylon_http::DataError> {
    ent.fields
        .iter()
        .filter(|field| field.name != "id")
        .map(|field| {
            pylon_crdt::field_kind(&field.field_type, field.crdt)
                .map(|kind| pylon_crdt::CrdtField {
                    name: field.name.clone(),
                    kind,
                })
                .map_err(|e| pylon_http::DataError {
                    code: "INVALID_CRDT_FIELD".into(),
                    message: format!(
                        "{}.{}: {e} (declared type={}, crdt={:?})",
                        ent.name, field.name, field.field_type, field.crdt
                    ),
                })
        })
        .collect()
}