ping-openmls-sdk-core 0.0.1

Platform-agnostic OpenMLS-based messaging engine
//! Per-conversation sync state. See `docs/SYNC_PROTOCOL.md`.

use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;

use crate::{clock::Hlc, codec, device::DeviceId, Error, Result};

/// Per-conversation position in the message stream, used to tell messages
/// that were already processed apart from new ones (see [`SyncCursor::is_new`]).
///
/// NOTE(review): if `codec` is a non-self-describing format, the field order
/// below is part of the encoded form (and therefore of tokens produced by
/// `to_token`); confirm before reordering or inserting fields.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SyncCursor {
    /// Current epoch; `last_seq_per_device` only describes this epoch and is
    /// cleared when the cursor moves to a later one.
    pub epoch: u64,
    /// Highest sequence number recorded per sending device within `epoch`.
    pub last_seq_per_device: BTreeMap<DeviceId, u64>,
    /// Result of merging the `hlc` of every message passed to `advance`.
    pub last_hlc: Hlc,
}

impl SyncCursor {
    /// True if `(epoch, sender, seq)` is strictly newer than the cursor.
    pub fn is_new(&self, epoch: u64, sender: &DeviceId, seq: u64) -> bool {
        match epoch.cmp(&self.epoch) {
            std::cmp::Ordering::Greater => true,
            std::cmp::Ordering::Less => false,
            std::cmp::Ordering::Equal => self
                .last_seq_per_device
                .get(sender)
                .is_none_or(|&last| seq > last),
        }
    }

    pub fn advance(&mut self, epoch: u64, sender: DeviceId, seq: u64, hlc: Hlc, now_ms: u64) {
        if epoch > self.epoch {
            self.epoch = epoch;
            self.last_seq_per_device.clear();
        }
        if epoch == self.epoch {
            let entry = self.last_seq_per_device.entry(sender).or_insert(0);
            if seq > *entry {
                *entry = seq;
            }
        }
        self.last_hlc = self.last_hlc.merge(hlc, now_ms);
    }

    pub fn encode(&self) -> Result<Vec<u8>> {
        codec::encode(self)
    }
    pub fn decode(bytes: &[u8]) -> Result<Self> {
        codec::decode(bytes)
    }

    /// Base64 form for opaque transmission to/from servers.
    pub fn to_token(&self) -> Result<String> {
        use base64ish::encode_b64;
        Ok(encode_b64(&self.encode()?))
    }
    pub fn from_token(token: &str) -> Result<Self> {
        use base64ish::decode_b64;
        let bytes = decode_b64(token).map_err(|e| Error::Codec(e.to_string()))?;
        Self::decode(&bytes)
    }
}

mod base64ish {
    //! Tiny URL-safe base64 (RFC 4648 §5 alphabet) without padding, kept
    //! in-crate to avoid pulling in another dependency.

    const ALPH: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

    /// Error message for a final partial group whose unused low bits are set.
    const NON_CANONICAL: &str = "non-canonical b64 (non-zero trailing bits)";

    /// Encodes `input` as unpadded URL-safe base64.
    ///
    /// Every 3 input bytes become 4 characters; a trailing group of 1 or 2
    /// bytes becomes 2 or 3 characters, with the unused low bits of the last
    /// character left at zero.
    pub fn encode_b64(input: &[u8]) -> String {
        let mut out = String::with_capacity((input.len() * 4).div_ceil(3));
        for chunk in input.chunks(3) {
            let b0 = chunk[0];
            let b1 = if chunk.len() > 1 { chunk[1] } else { 0 };
            let b2 = if chunk.len() > 2 { chunk[2] } else { 0 };
            out.push(ALPH[(b0 >> 2) as usize] as char);
            out.push(ALPH[(((b0 & 0b11) << 4) | (b1 >> 4)) as usize] as char);
            if chunk.len() > 1 {
                out.push(ALPH[(((b1 & 0b1111) << 2) | (b2 >> 6)) as usize] as char);
            }
            if chunk.len() > 2 {
                out.push(ALPH[(b2 & 0b111111) as usize] as char);
            }
        }
        out
    }

    /// Decodes unpadded URL-safe base64 as produced by [`encode_b64`].
    ///
    /// # Errors
    /// Returns a description of the problem when `input`:
    /// - contains a character outside the URL-safe alphabet (including `=`),
    /// - leaves a single dangling character (`len % 4 == 1`), or
    /// - ends in a partial group whose unused low bits are not zero. Rejecting
    ///   these keeps decoding one-to-one; otherwise e.g. `"AA"` and `"AB"` would
    ///   both decode to `[0]`, letting one cursor be spelled by several tokens.
    pub fn decode_b64(input: &str) -> Result<Vec<u8>, String> {
        fn val(c: u8) -> Result<u8, String> {
            match c {
                b'A'..=b'Z' => Ok(c - b'A'),
                b'a'..=b'z' => Ok(c - b'a' + 26),
                b'0'..=b'9' => Ok(c - b'0' + 52),
                b'-' => Ok(62),
                b'_' => Ok(63),
                _ => Err(format!("invalid b64 char {c:#x}")),
            }
        }
        let bytes = input.as_bytes();
        let mut out = Vec::with_capacity(bytes.len() * 3 / 4);
        for chunk in bytes.chunks(4) {
            if chunk.len() < 2 {
                return Err("truncated b64".into());
            }
            let v0 = val(chunk[0])?;
            let v1 = val(chunk[1])?;
            out.push((v0 << 2) | (v1 >> 4));
            if chunk.len() > 2 {
                let v2 = val(chunk[2])?;
                out.push((v1 << 4) | (v2 >> 2));
                if chunk.len() > 3 {
                    let v3 = val(chunk[3])?;
                    out.push((v2 << 6) | v3);
                } else if v2 & 0b11 != 0 {
                    // 3-char tail carries 16 bits; the low 2 bits of v2 are padding.
                    return Err(NON_CANONICAL.into());
                }
            } else if v1 & 0b1111 != 0 {
                // 2-char tail carries 8 bits; the low 4 bits of v1 are padding.
                return Err(NON_CANONICAL.into());
            }
        }
        Ok(out)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cursor_is_new() {
        let mut c = SyncCursor::default();
        let d = DeviceId(vec![1; 32]);
        assert!(c.is_new(0, &d, 1));
        c.advance(0, d.clone(), 1, Hlc::ZERO.tick(100), 100);
        assert!(!c.is_new(0, &d, 1));
        assert!(c.is_new(0, &d, 2));
        assert!(c.is_new(1, &d, 0));
    }

    /// Moving to a later epoch drops old per-device state; a stale epoch
    /// leaves epoch and sequence numbers untouched.
    #[test]
    fn advance_epoch_rollover_and_stale_epoch() {
        let mut c = SyncCursor::default();
        let d = DeviceId(vec![1; 32]);
        c.advance(0, d.clone(), 5, Hlc::ZERO.tick(100), 100);
        c.advance(1, d.clone(), 0, Hlc::ZERO.tick(200), 200);
        assert_eq!(c.epoch, 1);
        assert_eq!(c.last_seq_per_device.get(&d), Some(&0));
        assert!(!c.is_new(1, &d, 0));
        assert!(c.is_new(1, &d, 1));
        assert!(!c.is_new(0, &d, 6));

        c.advance(0, d.clone(), 9, Hlc::ZERO.tick(300), 300);
        assert_eq!(c.epoch, 1);
        assert_eq!(c.last_seq_per_device.get(&d), Some(&0));

        // Lower seq within the current epoch never moves the cursor backwards.
        c.advance(1, d.clone(), 4, Hlc::ZERO.tick(400), 400);
        c.advance(1, d.clone(), 2, Hlc::ZERO.tick(500), 500);
        assert_eq!(c.last_seq_per_device.get(&d), Some(&4));
    }

    #[test]
    fn token_roundtrip() {
        let mut c = SyncCursor {
            epoch: 7,
            ..Default::default()
        };
        c.last_seq_per_device.insert(DeviceId(vec![9; 32]), 42);
        let t = c.to_token().unwrap();
        let back = SyncCursor::from_token(&t).unwrap();
        assert_eq!(back.epoch, 7);
        assert_eq!(
            back.last_seq_per_device.get(&DeviceId(vec![9; 32])),
            Some(&42)
        );
    }

    #[test]
    fn malformed_tokens_rejected() {
        assert!(SyncCursor::from_token("A").is_err());
        assert!(SyncCursor::from_token("AA+A").is_err());
        assert!(SyncCursor::from_token("AA==").is_err());
    }

    /// Covers every tail length (0, 1 and 2 bytes mod 3) plus known vectors.
    #[test]
    fn b64_roundtrip_every_tail_length() {
        for len in 0..=7u8 {
            let data: Vec<u8> = (0..len).map(|i| i.wrapping_mul(91) ^ 0xA5).collect();
            let encoded = base64ish::encode_b64(&data);
            assert_eq!(encoded.len(), (data.len() * 4).div_ceil(3));
            assert!(!encoded.contains('='));
            assert_eq!(base64ish::decode_b64(&encoded).unwrap(), data);
        }
        assert_eq!(base64ish::encode_b64(b"Man"), "TWFu");
        assert_eq!(base64ish::encode_b64(&[0xfb, 0xff]), "-_8");
        assert!(base64ish::decode_b64("A").is_err());
        assert!(base64ish::decode_b64("AA+A").is_err());
    }
}