use std::io::BufRead;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use futures::StreamExt;
use ubiquisync_core::{
codec::{
CodecError, DecodedEntry, EntryBufferReader, EntryBufferWriter, IndexableOp, Op,
OpIndexEntry,
},
event::{EventHandler, Publisher},
hlc::Timestamp,
log_entry::LogEntry,
sync::{CursorsEvent, HasCursors, LogProcessor, LogSource, SyncError},
uuid::Uuid,
};
use crate::{
db::{Db, DbBatch, DbError, DbStatementResult, DbType, DbValue, StmtId, ValueBinder},
processor::{Processor, ProcessorError},
reducer::Reducer,
tracker::LogIndexTracker,
util::quote_ident,
};
/// Tag byte written before every encoded `MaxOp` and stored in its op-index entry.
const TAG_MAX: u8 = 1;
/// "Raise the register for `key` to at least `value`."
///
/// `Debug`/`PartialEq`/`Eq` are derived so tests can compare reconstructed ops directly with
/// `assert_eq!` and get readable failure output.
#[derive(Clone, Debug, PartialEq, Eq)]
struct MaxOp {
    /// Register key; used as the blob primary key `k` of the backing table.
    key: Vec<u8>,
    /// Candidate value; the stored register keeps the larger of this and what it already holds.
    value: i64,
}
impl Op for MaxOp {
    /// Reconstructs a `MaxOp` by reading the key with `read_blob` and then the value with
    /// `read_zigzag`, mirroring the field order written by `encode`. The `_tag` argument is
    /// not inspected.
    fn decode<R: BufRead>(_tag: u8, r: &mut EntryBufferReader<R>) -> Result<Self, CodecError> {
        // Struct-literal fields are evaluated in the order written: key first, then value.
        Ok(MaxOp {
            key: r.read_blob()?,
            value: r.read_zigzag()?,
        })
    }

    /// Writes `TAG_MAX`, then the key as a blob, then the value as a zigzag integer.
    fn encode(&self, w: &mut EntryBufferWriter) -> Result<(), CodecError> {
        let MaxOp { key, value } = self;
        w.write_byte(TAG_MAX);
        w.write_blob(key);
        w.write_zigzag(*value);
        Ok(())
    }
}
impl IndexableOp for MaxOp {
    /// Flattens the op into an index entry: tag `TAG_MAX`, the raw key, and the value as
    /// 8 little-endian bytes.
    fn to_index_entry(&self) -> Result<OpIndexEntry, CodecError> {
        let value_bytes = Vec::from(self.value.to_le_bytes());
        Ok(OpIndexEntry {
            tag: TAG_MAX,
            key: self.key.clone(),
            value: value_bytes,
        })
    }

    /// Inverse of `to_index_entry`. Any `value` slice that is not exactly 8 bytes long is
    /// reported as `CodecError::UnexpectedEof`; `_tag` is not inspected.
    fn from_index_parts(_tag: u8, key: &[u8], value: &[u8]) -> Result<Self, CodecError> {
        let le_bytes: [u8; 8] = match value.try_into() {
            Ok(bytes) => bytes,
            Err(_) => return Err(CodecError::UnexpectedEof),
        };
        Ok(MaxOp {
            key: key.to_vec(),
            value: i64::from_le_bytes(le_bytes),
        })
    }
}
/// Reducer that keeps, per key, the largest integer ever written, in its own SQL table.
struct MaxRegister {
    /// Table name, already passed through `quote_ident` so it can be spliced into SQL.
    table: String,
}
impl MaxRegister {
    /// Creates a register backed by the table called `name`; the identifier is quoted once
    /// here rather than on every statement.
    fn new(name: &str) -> Self {
        let table = quote_ident(name);
        MaxRegister { table }
    }
}
#[async_trait]
impl Reducer for MaxRegister {
    type Op = MaxOp;
    type ReadState = ();
    type ApplyState = StmtId;
    /// The register's value for the written key after the merge.
    type Event = i64;
    type Error = DbError;

    /// Creates the register table if it does not exist yet: `k` (blob, primary key) and
    /// `v` (integer, NOT NULL), using the type names of `db`'s dialect. The op is not used.
    async fn prepare(&mut self, db: &dyn Db, _op: &MaxOp) -> Result<(), DbError> {
        let int_type = DbType::Integer.sql_type(db.dialect());
        let blob_type = DbType::Blob.sql_type(db.dialect());
        let sql = format!(
            "CREATE TABLE IF NOT EXISTS {} (k {blob_type} PRIMARY KEY, v {int_type} NOT NULL)",
            self.table
        );
        db.exec(&sql, &[]).await?;
        Ok(())
    }

    /// Queues a single upsert into `batch`: insert `(key, value)`, or on a key conflict keep
    /// the larger of the stored and incoming values via the dialect's scalar max function.
    /// The statement returns the resulting `v`; its id is handed to `post_apply`.
    fn apply(
        &self,
        batch: &mut dyn DbBatch,
        _timestamp: Timestamp,
        op: &MaxOp,
        _read: (),
    ) -> Result<StmtId, DbError> {
        let mut binder = ValueBinder::new(batch.dialect());
        let k = binder.bind_next(DbValue::Blob(op.key.clone()));
        let v = binder.bind_next(DbValue::Integer(op.value));
        let max = batch.dialect().scalar_max();
        // The existing-row column is qualified with the table name: inside `DO UPDATE` both the
        // target row and `EXCLUDED` are in scope, and PostgreSQL rejects a bare `v` as
        // ambiguous (SQLite accepts either form). No COALESCE: `v` is declared NOT NULL, and
        // defaulting to 0 would wrongly clamp negative registers.
        let sql = format!(
            "INSERT INTO {tbl} (k, v) VALUES ({k}, {v}) \
             ON CONFLICT(k) DO UPDATE SET v = {max}({tbl}.v, EXCLUDED.v) RETURNING v",
            tbl = self.table
        );
        Ok(batch.add_statement(&sql, &binder.values()))
    }

    /// Emits the merged register value read from the first `RETURNING v` row of the
    /// statement queued by `apply`.
    ///
    /// Panics (index out of bounds) if that statement produced no rows.
    fn post_apply(
        &self,
        apply_state: StmtId,
        batch_result: &[DbStatementResult],
    ) -> Result<Vec<i64>, DbError> {
        Ok(vec![batch_result[apply_state.0].rows[0].get_i64(0)?])
    }
}
/// `Processor` specialised to the max-register reducer with a `LogIndexTracker<MaxOp>`,
/// generic over the database backend and the event handler.
type MaxProcessor<Database, Handler> =
    Processor<MaxRegister, Database, LogIndexTracker<MaxOp>, Handler>;
/// Cloneable event sink: every clone shares one `Vec` and appends published events in order.
#[derive(Clone, Default)]
struct Captured(Arc<Mutex<Vec<i64>>>);
impl Publisher<i64> for Captured {
    fn publish(&self, event: i64) {
        let mut events = self.0.lock().unwrap();
        events.push(event);
    }
}
impl Captured {
    /// Most recently published event; panics if nothing has been published yet.
    fn last(&self) -> i64 {
        let events = self.0.lock().unwrap();
        events.last().copied().expect("an event was published")
    }
    /// Number of events published so far.
    fn count(&self) -> usize {
        let events = self.0.lock().unwrap();
        events.len()
    }
}
/// Test-side view of the events recorded by the `Captured` publisher it was created with.
struct CaptureHandler(Captured);
impl CaptureHandler {
    /// Most recently published event; panics if none has been published.
    fn last(&self) -> i64 {
        let CaptureHandler(sink) = self;
        sink.last()
    }
    /// Number of events published so far.
    fn count(&self) -> usize {
        let CaptureHandler(sink) = self;
        sink.count()
    }
}
impl EventHandler<i64> for CaptureHandler {
    type Publish = Captured;
    /// Returns a publisher and a handler that share the same underlying event buffer.
    fn init() -> (Self::Publish, Self) {
        let handler = CaptureHandler(Captured::default());
        let publisher = handler.0.clone();
        (publisher, handler)
    }
}
/// Remote peer whose entries the suites feed through `process_one` / `apply`.
const PEER: Uuid = [0x07; 16];
/// User id attached to "attributed" entries via `server_user_id`.
const USER: Uuid = [0x09; 16];
/// Local node id passed to `Processor::open`; local `exec` writes advance its cursor.
const NODE: Uuid = [0x01; 16];
/// Table-name prefix handed to the processor (op-log and clock tables are `{PREFIX}__…`).
const PREFIX: &str = "app";
/// Builds a log entry carrying `MaxOp { key, value }` stamped at `millis` with logical
/// counter 0, optionally attributed to `server_user_id`.
fn entry(key: &[u8], value: i64, millis: u64, server_user_id: Option<Uuid>) -> LogEntry<MaxOp> {
    let timestamp = Timestamp::from_parts(millis, 0);
    let op = MaxOp {
        key: key.to_owned(),
        value,
    };
    LogEntry {
        server_user_id,
        timestamp,
        op,
    }
}
/// Counts the rows currently stored in the processor's `{PREFIX}__oplog` table.
async fn oplog_row_count<D: Db>(db: &D) -> i64 {
    let oplog_table = quote_ident(&format!("{PREFIX}__oplog"));
    let sql = format!("SELECT COUNT(*) FROM {}", oplog_table);
    let rows = db.query(&sql, &[]).await.unwrap();
    rows[0].get_i64(0).unwrap()
}
/// Reads the (nullable) `server_user_id` column of the op-log row with `entry_idx`.
/// Panics if no such row exists.
async fn oplog_server_user_id<D: Db>(db: &D, entry_idx: u64) -> Option<Uuid> {
    let oplog_table = quote_ident(&format!("{PREFIX}__oplog"));
    let mut params = ValueBinder::new(db.dialect());
    let idx = params.bind_next(DbValue::from_u64(entry_idx).unwrap());
    let sql = format!(
        "SELECT server_user_id FROM {} WHERE entry_idx = {idx}",
        oplog_table
    );
    let rows = db.query(&sql, &params.values()).await.unwrap();
    rows[0].get_optional_uuid(0).unwrap()
}
/// Reads the stored value of `key` straight from the `reg` table, bypassing events.
/// Panics if the key has never been written.
async fn register_value<D: Db>(db: &D, key: &[u8]) -> i64 {
    let reg_table = quote_ident("reg");
    let mut params = ValueBinder::new(db.dialect());
    let k = params.bind_next(DbValue::Blob(key.to_vec()));
    let sql = format!("SELECT v FROM {} WHERE k = {k}", reg_table);
    let rows = db.query(&sql, &params.values()).await.unwrap();
    rows[0].get_i64(0).unwrap()
}
/// Reads the persisted clock value (`ts` of row `id = 1`) from `{PREFIX}__hlc`.
async fn clock_register<D: Db>(db: &D) -> u64 {
    let hlc_table = quote_ident(&format!("{PREFIX}__hlc"));
    let sql = format!("SELECT ts FROM {} WHERE id = 1", hlc_table);
    let rows = db.query(&sql, &[]).await.unwrap();
    rows[0].get_u64(0).unwrap()
}
pub async fn run_max_register_suite<D: Db>(db: D) {
let processor: MaxProcessor<D, CaptureHandler> =
Processor::open(MaxRegister::new("reg"), db, PREFIX, NODE)
.await
.unwrap();
processor
.process_one(&PEER, 0, &entry(b"x", 5, 1_700_000_000_000, None))
.await
.unwrap();
assert_eq!(processor.event_handler().last(), 5, "first write sets the value");
processor
.process_one(&PEER, 1, &entry(b"x", 3, 1_700_000_000_001, None))
.await
.unwrap();
assert_eq!(processor.event_handler().last(), 5, "smaller value does not lower the register");
processor
.process_one(&PEER, 2, &entry(b"x", 9, 1_700_000_000_002, Some(USER)))
.await
.unwrap();
assert_eq!(processor.event_handler().last(), 9, "larger value raises the register");
assert_eq!(
oplog_server_user_id(processor.db(), 2).await,
Some(USER),
"attributed entry stores its user id"
);
assert_eq!(
oplog_server_user_id(processor.db(), 0).await,
None,
"unattributed entry stores NULL user id"
);
let committed_rows = oplog_row_count(processor.db()).await;
let committed_clock = clock_register(processor.db()).await;
assert_eq!(committed_rows, 3);
assert_eq!(
committed_clock,
Timestamp::from_parts(1_700_000_000_002, 0).raw()
);
let err = processor
.process_one(&PEER, 0, &entry(b"x", 100, 1_700_000_000_003, None))
.await
.unwrap_err();
assert!(
matches!(err, ProcessorError::Db(DbError::UniqueViolation)),
"duplicate surfaces as a unique violation, got {err:?}"
);
assert_eq!(
oplog_row_count(processor.db()).await,
committed_rows,
"rolled-back duplicate added no op-log row"
);
assert_eq!(
clock_register(processor.db()).await,
committed_clock,
"rolled-back observe did not advance the persisted clock"
);
processor
.process_one(&PEER, 3, &entry(b"x", 1, 1_700_000_000_004, None))
.await
.unwrap();
assert_eq!(processor.event_handler().last(), 9, "rolled-back duplicate left the register at 9");
processor
.process_one(&PEER, 4, &entry(b"y", 7, 1_700_000_000_005, None))
.await
.unwrap();
assert_eq!(processor.event_handler().last(), 7, "distinct key has its own register");
}
/// Exercises the replica-facing API of a `MaxProcessor` over `db`: `apply`, `read_since`,
/// `cursors`, `watch_cursors`, and local `exec`.
///
/// The per-peer cursor checked here is the next index expected from that peer (after
/// applying indexes 0..=3 it reads 4). The scenario is strictly ordered: later assertions
/// depend on the exact cursor and row counts produced by earlier steps.
///
/// # Panics
///
/// Panics on any failed assertion or unexpected error.
pub async fn run_replica_suite<D: Db>(db: D) {
    let processor: MaxProcessor<D, CaptureHandler> =
        Processor::open(MaxRegister::new("reg"), db, PREFIX, NODE)
            .await
            .unwrap();
    // A fresh processor has no cursor for the peer yet.
    assert!(!processor.cursors().await.unwrap().contains_key(&PEER));

    // Contiguous stream 0..=3 with an expunged marker at index 2.
    let stream = [
        DecodedEntry::LogEntry(entry(b"x", 5, 1_700_000_000_000, None)),
        DecodedEntry::LogEntry(entry(b"x", 9, 1_700_000_000_001, Some(USER))),
        DecodedEntry::Expunged(blake3::hash(b"gone")),
        DecodedEntry::LogEntry(entry(b"y", 4, 1_700_000_000_002, None)),
    ];
    for (idx, e) in stream.iter().cloned().enumerate() {
        assert!(
            processor.apply(PEER, idx as u64, e).await.unwrap().new,
            "fresh slot {idx} applies"
        );
    }
    // Three real entries produce three events; the expunged marker produces none.
    assert_eq!(
        processor.event_handler().count(),
        3,
        "expunged apply emits no event"
    );
    assert_eq!(
        processor.event_handler().last(),
        4,
        "the last applied LogEntry emitted its value"
    );
    // Every slot, including the expunged one, occupies an op-log row and advances the cursor.
    assert_eq!(processor.cursors().await.unwrap().get(&PEER).copied(), Some(4));
    assert_eq!(oplog_row_count(processor.db()).await, 4);
    assert_eq!(oplog_server_user_id(processor.db(), 1).await, Some(USER));
    assert_eq!(
        oplog_server_user_id(processor.db(), 2).await,
        None,
        "expunged marker carries no attribution"
    );

    // Reading back from index 0 reconstructs all four slots in order.
    let read = processor.read_since(PEER, 0).await.unwrap();
    assert_eq!(
        read.iter().map(|(i, _)| *i).collect::<Vec<_>>(),
        vec![0, 1, 2, 3]
    );
    match read[2].1 {
        DecodedEntry::Expunged(h) => {
            assert_eq!(h, blake3::hash(b"gone"), "expunged hash round-trips");
        }
        _ => panic!("index 2 should round-trip as the expunged marker"),
    }
    match &read[1].1 {
        DecodedEntry::LogEntry(e) => {
            assert_eq!(e.server_user_id, Some(USER));
            assert_eq!(e.op.value, 9);
            assert_eq!(e.op.key, b"x");
        }
        _ => panic!("index 1 should reconstruct as a real entry"),
    }

    // Re-delivering an already-applied index is reported as not new and changes nothing.
    let redelivery = processor
        .apply(
            PEER,
            1,
            DecodedEntry::LogEntry(entry(b"x", 999, 1_700_000_000_009, None)),
        )
        .await
        .unwrap();
    assert!(!redelivery.new, "re-delivered index is dropped");
    assert_eq!(
        oplog_row_count(processor.db()).await,
        4,
        "dropped re-delivery added no row"
    );
    assert_eq!(
        register_value(processor.db(), b"x").await,
        9,
        "dropped re-delivery left the register at 9, not 999"
    );
    assert_eq!(
        processor.event_handler().count(),
        3,
        "dropped re-delivery emits no event"
    );

    // The cursor watch starts with a snapshot, then reports an advance after a new apply.
    let mut watch = processor.watch_cursors();
    match watch.next().await {
        Some(CursorsEvent::Snapshot(c)) => assert_eq!(c.get(&PEER).copied(), Some(4)),
        other => panic!("expected a snapshot first, got {other:?}"),
    }
    assert!(
        processor
            .apply(
                PEER,
                4,
                DecodedEntry::LogEntry(entry(b"x", 12, 1_700_000_000_003, None)),
            )
            .await
            .unwrap()
            .new
    );
    match watch.next().await {
        Some(CursorsEvent::Advanced(c)) => assert_eq!(c.get(&PEER).copied(), Some(5)),
        other => panic!("expected an advance, got {other:?}"),
    }
    assert!(
        processor
            .apply(
                PEER,
                5,
                DecodedEntry::LogEntry(entry(b"z", 1, 1_700_000_000_004, None)),
            )
            .await
            .unwrap()
            .new
    );
    assert_eq!(processor.cursors().await.unwrap().get(&PEER).copied(), Some(6));
    assert_eq!(oplog_row_count(processor.db()).await, 6);

    // Skipping ahead (8 while 6 is expected) is rejected without writing anything.
    let err = processor
        .apply(
            PEER,
            8,
            DecodedEntry::LogEntry(entry(b"x", 1, 1_700_000_000_010, None)),
        )
        .await
        .unwrap_err();
    assert!(
        matches!(
            err,
            SyncError::CursorMismatch {
                expected_idx: 6,
                actual_idx: 8
            }
        ),
        "gap rejected, got {err:?}"
    );
    assert_eq!(
        oplog_row_count(processor.db()).await,
        6,
        "rejected gap wrote nothing"
    );

    // Two concurrent applies of the same index: exactly one may win. This expects the loser
    // to return Ok with `new == false` rather than an error.
    let e1 = DecodedEntry::LogEntry(entry(b"x", 2, 1_700_000_000_011, None));
    let e2 = DecodedEntry::LogEntry(entry(b"x", 3, 1_700_000_000_012, None));
    let (a, b) = futures::join!(processor.apply(PEER, 6, e1), processor.apply(PEER, 6, e2));
    assert!(
        a.unwrap().new ^ b.unwrap().new,
        "exactly one of two concurrent same-index applies is new"
    );
    assert_eq!(processor.cursors().await.unwrap().get(&PEER).copied(), Some(7));
    assert_eq!(
        oplog_row_count(processor.db()).await,
        7,
        "only one of the two concurrent applies committed"
    );

    // Local writes via `exec` are tracked under this node's own id (NODE).
    processor
        .exec(
            None,
            MaxOp {
                key: b"local".to_vec(),
                value: 1,
            },
        )
        .await
        .unwrap();
    processor
        .exec(
            None,
            MaxOp {
                key: b"local".to_vec(),
                value: 5,
            },
        )
        .await
        .unwrap();
    assert_eq!(
        processor.cursors().await.unwrap().get(&NODE).copied(),
        Some(2),
        "two local writes advanced self's cursor"
    );
    assert_eq!(
        processor.read_since(NODE, 0).await.unwrap().len(),
        2,
        "both local writes reconstruct from the op-log"
    );
    assert_eq!(
        register_value(processor.db(), b"local").await,
        5,
        "max register merged the two local writes"
    );
}