use crate::id::ReplicaId;
use crate::version::VersionVector;
use std::collections::BTreeMap;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
pub mod update;
pub use update::{snapshot_string_to_yjs_update, snapshot_text_to_yjs_update};
/// Encoders and decoders for the lib0 binary primitives: base-128
/// variable-length unsigned integers, sign-magnitude variable-length signed
/// integers, and length-prefixed UTF-8 strings.
///
/// Each `write_*` function appends to a caller-owned buffer. Each `read_*`
/// function decodes from the start of `bytes` and returns the value together
/// with the number of bytes consumed, so callers can advance a cursor. The
/// decoders report malformed input as an [`Error`] instead of panicking.
pub mod lib0 {
    /// Failure modes of the lib0 decoders.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub enum Error {
        /// The input ended before the value was complete (including a final
        /// byte that still has its continuation bit set).
        Truncated,
        /// The encoded value does not fit in the target integer type.
        Overflow,
        /// A var-string payload was not valid UTF-8.
        InvalidUtf8,
    }

    impl std::fmt::Display for Error {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                Error::Truncated => f.write_str("lib0: unexpected end of buffer"),
                Error::Overflow => f.write_str("lib0: var-int/var-uint overflow"),
                Error::InvalidUtf8 => f.write_str("lib0: invalid UTF-8 in var-string"),
            }
        }
    }

    impl std::error::Error for Error {}

    /// Places a payload `chunk` at bit offset `shift`, failing with
    /// [`Error::Overflow`] if the offset, or any set bit of the chunk, would
    /// land outside a `u64`. A plain `<<` would silently drop those bits.
    fn place_chunk(chunk: u64, shift: u32) -> Result<u64, Error> {
        if shift >= 64 {
            return Err(Error::Overflow);
        }
        let shifted = chunk << shift;
        if shifted >> shift == chunk {
            Ok(shifted)
        } else {
            Err(Error::Overflow)
        }
    }

    /// Appends `n` as a var-uint: little-endian groups of 7 bits, with bit
    /// `0x80` set on every byte except the last (unsigned LEB128).
    pub fn write_var_uint(out: &mut Vec<u8>, mut n: u64) {
        while n >= 0x80 {
            out.push(((n & 0x7F) | 0x80) as u8);
            n >>= 7;
        }
        out.push(n as u8);
    }

    /// Decodes a var-uint as written by [`write_var_uint`].
    ///
    /// Returns the value and the number of bytes consumed.
    ///
    /// # Errors
    /// [`Error::Truncated`] if `bytes` ends before a byte without the
    /// continuation bit; [`Error::Overflow`] if the value needs more than
    /// 64 bits.
    pub fn read_var_uint(bytes: &[u8]) -> Result<(u64, usize), Error> {
        let mut n: u64 = 0;
        let mut shift: u32 = 0;
        for (i, &byte) in bytes.iter().enumerate() {
            n |= place_chunk(u64::from(byte & 0x7F), shift)?;
            if byte & 0x80 == 0 {
                return Ok((n, i + 1));
            }
            shift += 7;
        }
        Err(Error::Truncated)
    }

    /// Appends `n` as a var-int in sign-magnitude form.
    ///
    /// The first byte holds the continuation flag (`0x80`), the sign flag
    /// (`0x40`) and the low 6 bits of `|n|`. Each following byte holds 7 more
    /// magnitude bits plus its own continuation flag. `i64::MIN` is written as
    /// magnitude 2^63 with the sign flag set.
    pub fn write_var_int(out: &mut Vec<u8>, n: i64) {
        let mut value = n.unsigned_abs();
        let sign = if n < 0 { 0x40 } else { 0 };
        let cont = if value > 0x3F { 0x80 } else { 0 };
        out.push((value & 0x3F) as u8 | sign | cont);
        value >>= 6;
        while value > 0 {
            let cont = if value > 0x7F { 0x80 } else { 0 };
            out.push((value & 0x7F) as u8 | cont);
            value >>= 7;
        }
    }

    /// Decodes a var-int as written by [`write_var_int`].
    ///
    /// Returns the value and the number of bytes consumed.
    ///
    /// # Errors
    /// [`Error::Truncated`] if the input is empty or ends while a
    /// continuation bit is still set. [`Error::Overflow`] if the magnitude
    /// exceeds 64 bits, or exceeds the `i64` range for its sign.
    pub fn read_var_int(bytes: &[u8]) -> Result<(i64, usize), Error> {
        let first = *bytes.first().ok_or(Error::Truncated)?;
        let negative = first & 0x40 != 0;
        let mut magnitude = u64::from(first & 0x3F);
        let mut consumed = 1usize;
        let mut more = first & 0x80 != 0;
        let mut shift = 6u32;
        // `get` (rather than iterating `bytes[1..]`) makes a dangling
        // continuation bit on the very first byte report `Truncated`.
        while more {
            let byte = *bytes.get(consumed).ok_or(Error::Truncated)?;
            consumed += 1;
            magnitude |= place_chunk(u64::from(byte & 0x7F), shift)?;
            more = byte & 0x80 != 0;
            shift += 7;
        }
        // Two's complement is asymmetric: 2^63 is a valid magnitude only
        // for negative values (it is `i64::MIN`).
        let limit = if negative {
            i64::MIN.unsigned_abs()
        } else {
            i64::MAX.unsigned_abs()
        };
        if magnitude > limit {
            return Err(Error::Overflow);
        }
        // magnitude <= 2^63 here; 2^63 wraps to i64::MIN, which
        // `wrapping_neg` maps to itself, so every case yields the right value.
        #[allow(clippy::cast_possible_wrap)]
        let as_signed = magnitude as i64;
        let value = if negative {
            as_signed.wrapping_neg()
        } else {
            as_signed
        };
        Ok((value, consumed))
    }

    /// Appends `s` as a var-string: its UTF-8 byte length as a var-uint,
    /// followed by the raw bytes.
    pub fn write_var_string(out: &mut Vec<u8>, s: &str) {
        let bytes = s.as_bytes();
        write_var_uint(out, bytes.len() as u64);
        out.extend_from_slice(bytes);
    }

    /// Decodes a var-string as written by [`write_var_string`].
    ///
    /// Returns the string and the total number of bytes consumed (length
    /// header plus payload).
    ///
    /// # Errors
    /// Propagates [`read_var_uint`] errors for the header. Returns
    /// [`Error::Truncated`] if the declared length exceeds the remaining
    /// input, and [`Error::InvalidUtf8`] if the payload is not UTF-8.
    pub fn read_var_string(bytes: &[u8]) -> Result<(String, usize), Error> {
        let (len, hdr) = read_var_uint(bytes)?;
        // Compare in u64 space before converting, so an untrusted length can
        // neither truncate on 32-bit targets nor overflow `hdr + len`.
        // `hdr <= bytes.len()` because read_var_uint never over-reports.
        let available = bytes.len() - hdr;
        if len > available as u64 {
            return Err(Error::Truncated);
        }
        // Lossless: len <= available, which is a usize.
        #[allow(clippy::cast_possible_truncation)]
        let end = hdr + len as usize;
        let s = std::str::from_utf8(&bytes[hdr..end])
            .map_err(|_| Error::InvalidUtf8)?
            .to_string();
        Ok((s, end))
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn var_uint_round_trip() {
            for v in [0u64, 1, 127, 128, 255, 16384, u64::MAX] {
                let mut buf = Vec::new();
                write_var_uint(&mut buf, v);
                let (decoded, n) = read_var_uint(&buf).unwrap();
                assert_eq!(decoded, v);
                assert_eq!(n, buf.len());
            }
        }

        #[test]
        fn var_int_round_trip() {
            for v in [
                0i64,
                1,
                -1,
                63,
                -63,
                64,
                -64,
                1_000_000,
                -1_000_000,
                i64::MAX,
                i64::MIN,
                i64::MIN + 1,
            ] {
                let mut buf = Vec::new();
                write_var_int(&mut buf, v);
                let (decoded, n) = read_var_int(&buf).unwrap();
                assert_eq!(decoded, v, "round-trip failed for {v}");
                assert_eq!(n, buf.len());
            }
        }

        #[test]
        fn var_string_round_trip() {
            for s in ["", "hello", "👨👩👧 emoji"] {
                let mut buf = Vec::new();
                write_var_string(&mut buf, s);
                let (decoded, n) = read_var_string(&buf).unwrap();
                assert_eq!(decoded, s);
                assert_eq!(n, buf.len());
            }
        }

        #[test]
        fn var_uint_truncated() {
            let buf = [0x80u8, 0x80];
            assert_eq!(read_var_uint(&buf).err(), Some(Error::Truncated));
        }

        #[test]
        fn var_uint_overflow_in_last_byte() {
            let mut buf = vec![0xFFu8; 9];
            buf.push(0x02);
            assert_eq!(read_var_uint(&buf).err(), Some(Error::Overflow));
        }

        #[test]
        fn var_int_dangling_continuation_is_truncated() {
            assert_eq!(read_var_int(&[0xBF]).err(), Some(Error::Truncated));
            assert_eq!(read_var_int(&[0x80, 0x80]).err(), Some(Error::Truncated));
        }

        #[test]
        fn var_string_huge_length_is_truncated() {
            let mut buf = Vec::new();
            write_var_uint(&mut buf, u64::MAX);
            buf.extend_from_slice(b"abc");
            assert_eq!(read_var_string(&buf).err(), Some(Error::Truncated));
        }

        #[test]
        fn var_uint_known_encoding() {
            let mut buf = Vec::new();
            write_var_uint(&mut buf, 127);
            assert_eq!(buf, vec![0x7F]);
            let mut buf = Vec::new();
            write_var_uint(&mut buf, 128);
            assert_eq!(buf, vec![0x80, 0x01]);
        }
    }
}
/// Per-replica clocks, stored in ascending replica-id order.
///
/// Serialized with [`StateVector::encode`] as a var-uint entry count followed
/// by `(replica, clock)` var-uint pairs.
#[derive(Debug, Clone)]
#[derive(Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct StateVector {
    /// Clock recorded for each replica; replicas without an entry read as `0`
    /// through [`StateVector::get`].
    pub clocks: BTreeMap<ReplicaId, u64>,
}
impl StateVector {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn len(&self) -> usize {
self.clocks.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.clocks.is_empty()
}
#[must_use]
pub fn get(&self, replica: ReplicaId) -> u64 {
self.clocks.get(&replica).copied().unwrap_or(0)
}
#[must_use]
pub fn encode(&self) -> Vec<u8> {
let mut buf = Vec::with_capacity(self.clocks.len() * 4);
lib0::write_var_uint(&mut buf, self.clocks.len() as u64);
for (&replica, &clock) in &self.clocks {
lib0::write_var_uint(&mut buf, replica);
lib0::write_var_uint(&mut buf, clock);
}
buf
}
pub fn decode(bytes: &[u8]) -> Result<Self, lib0::Error> {
let mut cursor = 0usize;
let (count, n) = lib0::read_var_uint(&bytes[cursor..])?;
cursor += n;
let mut clocks = BTreeMap::new();
for _ in 0..count {
let (replica, n) = lib0::read_var_uint(&bytes[cursor..])?;
cursor += n;
let (clock, n) = lib0::read_var_uint(&bytes[cursor..])?;
cursor += n;
clocks.insert(replica, clock);
}
Ok(Self { clocks })
}
#[must_use]
pub fn from_version(v: &VersionVector) -> Self {
let mut clocks = BTreeMap::new();
for (replica, counter) in v.iter_clocks() {
clocks.insert(replica, counter);
}
Self { clocks }
}
#[must_use]
pub fn to_version(&self) -> VersionVector {
let mut v = VersionVector::new();
for (&replica, &counter) in &self.clocks {
v.observe(crate::id::OpId::new(counter, replica));
}
v
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a state vector from `(replica, clock)` pairs.
    fn sv_of(pairs: &[(u64, u64)]) -> StateVector {
        let mut sv = StateVector::new();
        for &(replica, clock) in pairs {
            sv.clocks.insert(replica, clock);
        }
        sv
    }

    /// Encodes `sv` and decodes the bytes back.
    fn reencode(sv: &StateVector) -> StateVector {
        StateVector::decode(&sv.encode()).expect("freshly encoded bytes must decode")
    }

    #[test]
    fn empty_state_vector_round_trip() {
        let empty = StateVector::new();
        assert_eq!(reencode(&empty), empty);
    }

    #[test]
    fn state_vector_round_trip() {
        let original = sv_of(&[(1, 100), (2, 50), (99, 1_000_000)]);
        assert_eq!(reencode(&original), original);
    }

    #[test]
    fn state_vector_yjs_byte_format() {
        // Entry count 2, then (replica, clock) pairs in ascending replica order.
        let expected: Vec<u8> = vec![0x02, 0x01, 0x05, 0x02, 0x07];
        assert_eq!(sv_of(&[(1, 5), (2, 7)]).encode(), expected);
    }

    #[test]
    fn convert_to_from_version_vector() {
        let original = sv_of(&[(7, 42), (8, 11)]);
        let via_version = StateVector::from_version(&original.to_version());
        assert_eq!(via_version, original);
    }
}