use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use crate::{clock::Hlc, codec, device::DeviceId, Error, Result};
/// Resumable sync position recording which messages have already been applied.
///
/// Persisted through the crate's `codec` (see `SyncCursor::encode`) and, via
/// `SyncCursor::to_token`, as a base64 token. NOTE(review): if `codec` is a
/// field-order-sensitive binary format, field order here is part of the
/// stored format — confirm before reordering or retyping fields.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SyncCursor {
/// Highest epoch accepted so far. When `advance` sees a larger epoch it
/// replaces this value and clears `last_seq_per_device`.
pub epoch: u64,
/// Highest sequence number recorded for each sending device within the
/// current `epoch`; only ever raised by `advance`.
pub last_seq_per_device: BTreeMap<DeviceId, u64>,
/// Clock value produced by merging the `hlc` of every `advance` call
/// (including calls for older epochs) via `Hlc::merge`.
pub last_hlc: Hlc,
}
impl SyncCursor {
pub fn is_new(&self, epoch: u64, sender: &DeviceId, seq: u64) -> bool {
match epoch.cmp(&self.epoch) {
std::cmp::Ordering::Greater => true,
std::cmp::Ordering::Less => false,
std::cmp::Ordering::Equal => self
.last_seq_per_device
.get(sender)
.is_none_or(|&last| seq > last),
}
}
pub fn advance(&mut self, epoch: u64, sender: DeviceId, seq: u64, hlc: Hlc, now_ms: u64) {
if epoch > self.epoch {
self.epoch = epoch;
self.last_seq_per_device.clear();
}
if epoch == self.epoch {
let entry = self.last_seq_per_device.entry(sender).or_insert(0);
if seq > *entry {
*entry = seq;
}
}
self.last_hlc = self.last_hlc.merge(hlc, now_ms);
}
pub fn encode(&self) -> Result<Vec<u8>> {
codec::encode(self)
}
pub fn decode(bytes: &[u8]) -> Result<Self> {
codec::decode(bytes)
}
pub fn to_token(&self) -> Result<String> {
use base64ish::encode_b64;
Ok(encode_b64(&self.encode()?))
}
pub fn from_token(token: &str) -> Result<Self> {
use base64ish::decode_b64;
let bytes = decode_b64(token).map_err(|e| Error::Codec(e.to_string()))?;
Self::decode(&bytes)
}
}
mod base64ish {
    //! Minimal unpadded base64 codec over the URL-safe alphabet
    //! (`A-Z a-z 0-9 - _`), used to render cursor tokens.

    const ALPH: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

    /// Encodes `input` as unpadded base64 using the URL-safe alphabet.
    ///
    /// Each full 3-byte group becomes 4 characters. A trailing 1- or 2-byte
    /// group becomes 2 or 3 characters, with the unused low bits of the last
    /// character set to zero. No `=` padding is emitted.
    pub fn encode_b64(input: &[u8]) -> String {
        // Exact unpadded output length: ceil(4n / 3).
        let mut out = String::with_capacity((input.len() * 4).div_ceil(3));
        for chunk in input.chunks(3) {
            let b0 = chunk[0];
            let b1 = if chunk.len() > 1 { chunk[1] } else { 0 };
            let b2 = if chunk.len() > 2 { chunk[2] } else { 0 };
            out.push(ALPH[(b0 >> 2) as usize] as char);
            out.push(ALPH[(((b0 & 0b11) << 4) | (b1 >> 4)) as usize] as char);
            if chunk.len() > 1 {
                out.push(ALPH[(((b1 & 0b1111) << 2) | (b2 >> 6)) as usize] as char);
            }
            if chunk.len() > 2 {
                out.push(ALPH[(b2 & 0b111111) as usize] as char);
            }
        }
        out
    }

    /// Maps one alphabet character to its 6-bit value.
    fn val(c: u8) -> Result<u8, String> {
        match c {
            b'A'..=b'Z' => Ok(c - b'A'),
            b'a'..=b'z' => Ok(c - b'a' + 26),
            b'0'..=b'9' => Ok(c - b'0' + 52),
            b'-' => Ok(62),
            b'_' => Ok(63),
            _ => Err(format!("invalid b64 char {c:#x}")),
        }
    }

    /// Fails if any of the leftover `bits` of a short final group are set.
    ///
    /// [`encode_b64`] always zero-fills these bits. Accepting non-zero values
    /// would let several distinct strings decode to the same bytes, so a
    /// cursor would no longer have a single canonical token.
    fn ensure_zero_pad(bits: u8) -> Result<(), String> {
        if bits == 0 {
            Ok(())
        } else {
            Err("non-canonical b64 (nonzero trailing bits)".into())
        }
    }

    /// Decodes unpadded URL-safe base64 as produced by [`encode_b64`].
    ///
    /// # Errors
    /// - a character outside the alphabet (including `=` padding);
    /// - a final group of a single character (`"truncated b64"`);
    /// - a 2- or 3-character final group whose unused low bits are not zero.
    pub fn decode_b64(input: &str) -> Result<Vec<u8>, String> {
        let bytes = input.as_bytes();
        let mut out = Vec::with_capacity(bytes.len() * 3 / 4);
        for chunk in bytes.chunks(4) {
            if chunk.len() < 2 {
                return Err("truncated b64".into());
            }
            let v0 = val(chunk[0])?;
            let v1 = val(chunk[1])?;
            out.push((v0 << 2) | (v1 >> 4));
            if chunk.len() == 2 {
                // Only the top 2 bits of v1 were used; the low 4 must be zero.
                ensure_zero_pad(v1 & 0b1111)?;
                continue;
            }
            let v2 = val(chunk[2])?;
            out.push((v1 << 4) | (v2 >> 2));
            if chunk.len() == 3 {
                // Only the top 4 bits of v2 were used; the low 2 must be zero.
                ensure_zero_pad(v2 & 0b11)?;
                continue;
            }
            let v3 = val(chunk[3])?;
            out.push((v2 << 6) | v3);
        }
        Ok(out)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A sequence number stays "new" until the cursor has advanced past it,
    /// and any strictly higher epoch is always new.
    #[test]
    fn cursor_is_new() {
        let device = DeviceId(vec![1; 32]);
        let mut cursor = SyncCursor::default();

        // Fresh cursor: nothing has been seen yet.
        assert!(cursor.is_new(0, &device, 1));

        cursor.advance(0, device.clone(), 1, Hlc::ZERO.tick(100), 100);

        // Same seq is now a duplicate; the next seq and a later epoch are not.
        assert!(!cursor.is_new(0, &device, 1));
        assert!(cursor.is_new(0, &device, 2));
        assert!(cursor.is_new(1, &device, 0));
    }

    /// `to_token` followed by `from_token` preserves the epoch and the
    /// per-device sequence table.
    #[test]
    fn token_roundtrip() {
        let peer = DeviceId(vec![9; 32]);
        let mut original = SyncCursor {
            epoch: 7,
            ..Default::default()
        };
        original.last_seq_per_device.insert(peer.clone(), 42);

        let token = original.to_token().unwrap();
        let decoded = SyncCursor::from_token(&token).unwrap();

        assert_eq!(decoded.epoch, 7);
        assert_eq!(decoded.last_seq_per_device.get(&peer), Some(&42));
    }
}