Skip to main content

ping_core/
sync.rs

1//! Per-conversation sync state. See `docs/SYNC_PROTOCOL.md`.
2
3use serde::{Deserialize, Serialize};
4use std::collections::BTreeMap;
5
6use crate::{clock::Hlc, codec, device::DeviceId, Error, Result};
7
/// Per-conversation sync position recorded by this replica.
///
/// A message tagged `(epoch, sender, seq)` counts as already covered when its
/// epoch is older than `epoch`, or when it is in the current epoch and `seq` is
/// not greater than the value recorded for `sender` in `last_seq_per_device`
/// (see [`SyncCursor::is_new`]).
// NOTE(review): this struct is serialized via `codec` (`encode`/`decode`/`to_token`).
// If that format is not self-describing, field order is part of the wire format —
// do not reorder fields. TODO confirm against `codec`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SyncCursor {
    /// Current epoch; per-device sequence numbers below are relative to it.
    pub epoch: u64,
    /// Highest `seq` recorded per sending device within `epoch`.
    /// `advance` clears this map whenever it moves to a newer epoch.
    pub last_seq_per_device: BTreeMap<DeviceId, u64>,
    /// Clock value folded (via `Hlc::merge`) from every call to `advance`,
    /// including calls carrying a stale epoch.
    pub last_hlc: Hlc,
}
14
15impl SyncCursor {
16    /// True if `(epoch, sender, seq)` is strictly newer than the cursor.
17    pub fn is_new(&self, epoch: u64, sender: &DeviceId, seq: u64) -> bool {
18        match epoch.cmp(&self.epoch) {
19            std::cmp::Ordering::Greater => true,
20            std::cmp::Ordering::Less => false,
21            std::cmp::Ordering::Equal => self
22                .last_seq_per_device
23                .get(sender)
24                .is_none_or(|&last| seq > last),
25        }
26    }
27
28    pub fn advance(&mut self, epoch: u64, sender: DeviceId, seq: u64, hlc: Hlc, now_ms: u64) {
29        if epoch > self.epoch {
30            self.epoch = epoch;
31            self.last_seq_per_device.clear();
32        }
33        if epoch == self.epoch {
34            let entry = self.last_seq_per_device.entry(sender).or_insert(0);
35            if seq > *entry {
36                *entry = seq;
37            }
38        }
39        self.last_hlc = self.last_hlc.merge(hlc, now_ms);
40    }
41
42    pub fn encode(&self) -> Result<Vec<u8>> {
43        codec::encode(self)
44    }
45    pub fn decode(bytes: &[u8]) -> Result<Self> {
46        codec::decode(bytes)
47    }
48
49    /// Base64 form for opaque transmission to/from servers.
50    pub fn to_token(&self) -> Result<String> {
51        use base64ish::encode_b64;
52        Ok(encode_b64(&self.encode()?))
53    }
54    pub fn from_token(token: &str) -> Result<Self> {
55        use base64ish::decode_b64;
56        let bytes = decode_b64(token).map_err(|e| Error::Codec(e.to_string()))?;
57        Self::decode(&bytes)
58    }
59}
60
mod base64ish {
    //! Tiny URL-safe base64 (RFC 4648 §5 alphabet) without padding, to avoid
    //! pulling in another dependency.

    const ALPH: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

    /// Encodes `input` as unpadded URL-safe base64.
    ///
    /// The output is exactly `ceil(4 * input.len() / 3)` characters: every
    /// 3-byte group yields 4 characters, and a 1- or 2-byte tail yields 2 or 3.
    /// Unused low bits of the final character are always zero.
    pub fn encode_b64(input: &[u8]) -> String {
        let mut out = String::with_capacity((input.len() * 4).div_ceil(3));
        for chunk in input.chunks(3) {
            let b0 = chunk[0];
            let b1 = if chunk.len() > 1 { chunk[1] } else { 0 };
            let b2 = if chunk.len() > 2 { chunk[2] } else { 0 };
            out.push(ALPH[(b0 >> 2) as usize] as char);
            out.push(ALPH[(((b0 & 0b11) << 4) | (b1 >> 4)) as usize] as char);
            if chunk.len() > 1 {
                out.push(ALPH[(((b1 & 0b1111) << 2) | (b2 >> 6)) as usize] as char);
            }
            if chunk.len() > 2 {
                out.push(ALPH[(b2 & 0b111111) as usize] as char);
            }
        }
        out
    }

    /// Decodes unpadded URL-safe base64 as produced by [`encode_b64`].
    ///
    /// Only the canonical encoding of each byte string is accepted, so
    /// `decode_b64` is the exact inverse of `encode_b64`.
    ///
    /// # Errors
    /// - `"truncated b64"` if the input length is ≡ 1 (mod 4): a lone trailing
    ///   character cannot form a byte.
    /// - `"invalid b64 char 0x.."` for any byte outside the alphabet, including
    ///   `=` padding.
    /// - `"non-canonical b64"` if the unused low bits of a 2- or 3-character tail
    ///   are not zero. Rejecting these (RFC 4648 §3.5) keeps server-supplied
    ///   tokens from having many string forms for the same bytes.
    pub fn decode_b64(input: &str) -> Result<Vec<u8>, String> {
        fn val(c: u8) -> Result<u8, String> {
            match c {
                b'A'..=b'Z' => Ok(c - b'A'),
                b'a'..=b'z' => Ok(c - b'a' + 26),
                b'0'..=b'9' => Ok(c - b'0' + 52),
                b'-' => Ok(62),
                b'_' => Ok(63),
                _ => Err(format!("invalid b64 char {c:#x}")),
            }
        }
        const NON_CANONICAL: &str = "non-canonical b64";

        let bytes = input.as_bytes();
        // Exact output size for unpadded input: floor(6 * len / 8).
        let mut out = Vec::with_capacity(bytes.len() * 3 / 4);
        // Only the final chunk can be shorter than 4 characters.
        for chunk in bytes.chunks(4) {
            if chunk.len() < 2 {
                return Err("truncated b64".into());
            }
            // `u8 << k` (k < 8) intentionally drops the bits shifted out the top.
            let v0 = val(chunk[0])?;
            let v1 = val(chunk[1])?;
            out.push((v0 << 2) | (v1 >> 4));
            if chunk.len() == 2 {
                // Only the top 2 bits of `v1` were used; the rest must be zero.
                if v1 & 0b1111 != 0 {
                    return Err(NON_CANONICAL.into());
                }
                continue;
            }
            let v2 = val(chunk[2])?;
            out.push((v1 << 4) | (v2 >> 2));
            if chunk.len() == 3 {
                // Only the top 4 bits of `v2` were used; the rest must be zero.
                if v2 & 0b11 != 0 {
                    return Err(NON_CANONICAL.into());
                }
                continue;
            }
            let v3 = val(chunk[3])?;
            out.push((v2 << 6) | v3);
        }
        Ok(out)
    }
}
115
#[cfg(test)]
mod tests {
    use super::base64ish::{decode_b64, encode_b64};
    use super::*;

    #[test]
    fn cursor_is_new() {
        let mut c = SyncCursor::default();
        let d = DeviceId(vec![1; 32]);
        assert!(c.is_new(0, &d, 1));
        c.advance(0, d.clone(), 1, Hlc::ZERO.tick(100), 100);
        assert!(!c.is_new(0, &d, 1));
        assert!(c.is_new(0, &d, 2));
        assert!(c.is_new(1, &d, 0));
    }

    /// Moving to a newer epoch forgets old per-device seqs; stale epochs are ignored.
    #[test]
    fn advance_resets_on_new_epoch_and_ignores_stale() {
        let mut c = SyncCursor::default();
        let a = DeviceId(vec![1; 32]);
        let b = DeviceId(vec![2; 32]);
        c.advance(0, a.clone(), 5, Hlc::ZERO.tick(100), 100);
        c.advance(1, b.clone(), 1, Hlc::ZERO.tick(200), 200);
        assert_eq!(c.epoch, 1);
        assert!(!c.last_seq_per_device.contains_key(&a));
        assert!(c.is_new(1, &a, 0));
        assert!(!c.is_new(1, &b, 1));
        assert!(!c.is_new(0, &a, 6));

        c.advance(0, a.clone(), 9, Hlc::ZERO.tick(300), 300);
        assert_eq!(c.epoch, 1);
        assert!(!c.last_seq_per_device.contains_key(&a));
    }

    /// A lower seq in the current epoch never moves the cursor backwards.
    #[test]
    fn advance_never_regresses_seq() {
        let mut c = SyncCursor::default();
        let d = DeviceId(vec![3; 32]);
        c.advance(0, d.clone(), 7, Hlc::ZERO.tick(100), 100);
        c.advance(0, d.clone(), 3, Hlc::ZERO.tick(200), 200);
        assert_eq!(c.last_seq_per_device.get(&d), Some(&7));
        assert!(!c.is_new(0, &d, 7));
        assert!(c.is_new(0, &d, 8));
    }

    #[test]
    fn token_roundtrip() {
        let mut c = SyncCursor {
            epoch: 7,
            ..Default::default()
        };
        c.last_seq_per_device.insert(DeviceId(vec![9; 32]), 42);
        let t = c.to_token().unwrap();
        let back = SyncCursor::from_token(&t).unwrap();
        assert_eq!(back.epoch, 7);
        assert_eq!(
            back.last_seq_per_device.get(&DeviceId(vec![9; 32])),
            Some(&42)
        );
    }

    /// RFC 4648 test vectors, plus one exercising the URL-safe `-`/`_` characters.
    #[test]
    fn b64_known_vectors() {
        assert_eq!(encode_b64(b""), "");
        assert_eq!(encode_b64(b"f"), "Zg");
        assert_eq!(encode_b64(b"fo"), "Zm8");
        assert_eq!(encode_b64(b"foo"), "Zm9v");
        assert_eq!(encode_b64(b"foobar"), "Zm9vYmFy");
        assert_eq!(encode_b64(&[0xfb, 0xff]), "-_8");
        assert_eq!(decode_b64("Zm9vYmFy").unwrap(), b"foobar".to_vec());
        assert_eq!(decode_b64("-_8").unwrap(), vec![0xfb, 0xff]);
        assert_eq!(decode_b64("").unwrap(), Vec::<u8>::new());
    }

    /// Every tail length (len % 3 == 0, 1, 2) round-trips with the expected size.
    #[test]
    fn b64_roundtrip_all_lengths() {
        let data: Vec<u8> = (0..=255u8).rev().collect();
        for len in 0..=data.len() {
            let enc = encode_b64(&data[..len]);
            assert_eq!(enc.len(), (len * 4).div_ceil(3));
            assert!(!enc.contains('='));
            assert_eq!(decode_b64(&enc).unwrap(), data[..len].to_vec());
        }
    }

    #[test]
    fn b64_rejects_malformed() {
        assert!(decode_b64("A").is_err()); // lone trailing char
        assert!(decode_b64("Zm9vY").is_err()); // len % 4 == 1
        assert!(decode_b64("Zm+v").is_err()); // standard-alphabet char
        assert!(decode_b64("Zg==").is_err()); // padding is not accepted
        assert!(SyncCursor::from_token("A").is_err());
    }
}