Skip to main content

spark/
snapshot.rs

1//! Spark snapshot — the encoded, signed state envelope that lives on the page.
2//!
//! Wire form (default mode): `b64url(JSON(envelope))`, where the envelope contains
//! `v`, `data`, `memo`, and `checksum` (HMAC-SHA256 over `canonical(data)||canonical(memo)`;
//! the version field `v` is not covered by the checksum).
5//!
//! Encrypted mode (`"enc:" + b64url(AES-256-GCM(JSON(envelope)))`): the entire envelope
//! is AES-GCM-sealed under APP_KEY, so only the holder of APP_KEY (the server) can
//! read it.
9
10use base64::engine::general_purpose::URL_SAFE_NO_PAD;
11use base64::Engine;
12use serde::{Deserialize, Serialize};
13
14use crate::crypto;
15use crate::error::{Error, Result};
16
/// Largest accepted wire-form snapshot, in bytes, measured on the final string
/// (base64 text plus any `enc:` prefix). Checked by both `encode` and `decode`.
const MAX_PAYLOAD: usize = 64 << 10;
/// Marker prepended to the base64 text of an AES-GCM-sealed snapshot so that
/// `decode` can tell it apart from a plain one.
const ENC_PREFIX: &str = "enc:";
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Envelope {
22    pub v: u8,
23    pub data: serde_json::Value,
24    pub memo: Memo,
25    pub checksum: String,
26}
27
28#[derive(Debug, Clone, Default, Serialize, Deserialize)]
29pub struct Memo {
30    pub id: String,
31    pub class: String,
32    pub view: String,
33    #[serde(default)]
34    pub listeners: Vec<String>,
35    #[serde(default, skip_serializing_if = "Option::is_none")]
36    pub errors: Option<serde_json::Value>,
37}
38
39impl Envelope {
40    /// Build a fresh envelope from state + memo, computing the HMAC.
41    pub fn build(app_key: &str, data: serde_json::Value, memo: Memo) -> Self {
42        let checksum = compute_checksum(app_key, &data, &memo);
43        Self {
44            v: 1,
45            data,
46            memo,
47            checksum,
48        }
49    }
50
51    /// Verify that the envelope's checksum matches the body. Returns `Ok(())` if
52    /// the snapshot has not been tampered with.
53    pub fn verify(&self, app_key: &str) -> Result<()> {
54        let expected = compute_checksum(app_key, &self.data, &self.memo);
55        // We compare via crypto::verify for constant-time-ish behavior — but since
56        // we recomputed both sides, a plain == is fine and what most signers do.
57        if crate::const_eq(self.checksum.as_bytes(), expected.as_bytes()) {
58            Ok(())
59        } else {
60            Err(Error::SnapshotTampered)
61        }
62    }
63}
64
65fn compute_checksum(app_key: &str, data: &serde_json::Value, memo: &Memo) -> String {
66    let body = canonical_pair(data, memo);
67    crypto::sign(app_key, &body)
68}
69
70fn canonical_pair(data: &serde_json::Value, memo: &Memo) -> Vec<u8> {
71    // Stable canonical form: serialize both as compact JSON. serde_json by default
72    // preserves insertion order for Maps; with arbitrary nested data this is good
73    // enough for HMAC purposes — the server signs and verifies with the same code.
74    let mut out = serde_json::to_vec(data).unwrap_or_default();
75    out.extend_from_slice(b"||");
76    out.extend_from_slice(serde_json::to_vec(memo).unwrap_or_default().as_slice());
77    out
78}
79
80/// Encode an envelope to the wire form (base64-URL-no-pad of JSON).
81pub fn encode(envelope: &Envelope, app_key: &str, encrypt: bool) -> Result<String> {
82    let json = serde_json::to_vec(envelope)?;
83    if encrypt {
84        let blob = crypto::encrypt(app_key, &json);
85        let mut out = String::with_capacity(ENC_PREFIX.len() + blob.len() * 2);
86        out.push_str(ENC_PREFIX);
87        out.push_str(&URL_SAFE_NO_PAD.encode(blob));
88        if out.len() > MAX_PAYLOAD {
89            return Err(Error::SnapshotTooLarge {
90                size: out.len(),
91                max: MAX_PAYLOAD,
92            });
93        }
94        Ok(out)
95    } else {
96        let encoded = URL_SAFE_NO_PAD.encode(json);
97        if encoded.len() > MAX_PAYLOAD {
98            return Err(Error::SnapshotTooLarge {
99                size: encoded.len(),
100                max: MAX_PAYLOAD,
101            });
102        }
103        Ok(encoded)
104    }
105}
106
107/// Decode + verify a wire-form snapshot.
108pub fn decode(wire: &str, app_key: &str) -> Result<Envelope> {
109    if wire.len() > MAX_PAYLOAD {
110        return Err(Error::SnapshotTooLarge {
111            size: wire.len(),
112            max: MAX_PAYLOAD,
113        });
114    }
115    let json_bytes = if let Some(rest) = wire.strip_prefix(ENC_PREFIX) {
116        let blob = URL_SAFE_NO_PAD
117            .decode(rest)
118            .map_err(|e| Error::SnapshotDecode(format!("b64: {e}")))?;
119        crypto::decrypt(app_key, &blob)
120            .ok_or_else(|| Error::SnapshotDecode("aes-gcm decrypt failed".into()))?
121    } else {
122        URL_SAFE_NO_PAD
123            .decode(wire)
124            .map_err(|e| Error::SnapshotDecode(format!("b64: {e}")))?
125    };
126    let envelope: Envelope = serde_json::from_slice(&json_bytes)
127        .map_err(|e| Error::SnapshotDecode(format!("json: {e}")))?;
128    envelope.verify(app_key)?;
129    Ok(envelope)
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    const KEY: &str = "spark-test-app-key-thirty-two-bb";
    /// A different key of the same length, for "wrong key" checks.
    const OTHER_KEY: &str = "spark-test-app-key-thirty-two-cc";

    fn sample_memo() -> Memo {
        Memo {
            id: "01HX-test".into(),
            class: "tests::Counter".into(),
            view: "spark/counter".into(),
            listeners: vec!["posts.created".into()],
            errors: None,
        }
    }

    #[test]
    fn round_trip_unencrypted() {
        let envelope = Envelope::build(KEY, json!({"count": 5}), sample_memo());
        let wire = encode(&envelope, KEY, false).unwrap();
        let decoded = decode(&wire, KEY).unwrap();
        assert_eq!(decoded.data, envelope.data);
        assert_eq!(decoded.memo.class, envelope.memo.class);
    }

    #[test]
    fn round_trip_encrypted() {
        let envelope = Envelope::build(KEY, json!({"count": 5}), sample_memo());
        let wire = encode(&envelope, KEY, true).unwrap();
        assert!(wire.starts_with("enc:"));
        let decoded = decode(&wire, KEY).unwrap();
        assert_eq!(decoded.data, envelope.data);
    }

    #[test]
    fn tampered_unencrypted_fails() {
        let envelope = Envelope::build(KEY, json!({"count": 5}), sample_memo());
        let wire = encode(&envelope, KEY, false).unwrap();
        // Corrupt the wire bytes by flipping the last char. This usually fails
        // at base64/JSON parsing rather than at the checksum; the tests below
        // exercise the checksum itself.
        let mut bytes = wire.into_bytes();
        let last = bytes.last_mut().unwrap();
        *last = if *last == b'A' { b'B' } else { b'A' };
        let tampered = String::from_utf8(bytes).unwrap();
        assert!(decode(&tampered, KEY).is_err());
    }

    #[test]
    fn modified_data_fails_checksum() {
        // A well-formed envelope whose data changed after signing must be
        // rejected by the HMAC check specifically.
        let mut envelope = Envelope::build(KEY, json!({"count": 5}), sample_memo());
        envelope.data = json!({"count": 500});
        let wire = encode(&envelope, KEY, false).unwrap();
        assert!(matches!(decode(&wire, KEY), Err(Error::SnapshotTampered)));
    }

    #[test]
    fn modified_memo_fails_checksum() {
        let mut envelope = Envelope::build(KEY, json!({"count": 5}), sample_memo());
        envelope.memo.class = "tests::Admin".into();
        let wire = encode(&envelope, KEY, false).unwrap();
        assert!(matches!(decode(&wire, KEY), Err(Error::SnapshotTampered)));
    }

    #[test]
    fn wrong_key_is_rejected() {
        let envelope = Envelope::build(KEY, json!({"count": 5}), sample_memo());

        let plain = encode(&envelope, KEY, false).unwrap();
        assert!(matches!(decode(&plain, OTHER_KEY), Err(Error::SnapshotTampered)));

        let sealed = encode(&envelope, KEY, true).unwrap();
        assert!(decode(&sealed, OTHER_KEY).is_err());
    }

    #[test]
    fn oversized_payloads_are_rejected() {
        let big = Envelope::build(KEY, json!({"blob": "x".repeat(MAX_PAYLOAD)}), sample_memo());
        assert!(matches!(
            encode(&big, KEY, false),
            Err(Error::SnapshotTooLarge { .. })
        ));

        let wire = "A".repeat(MAX_PAYLOAD + 1);
        assert!(matches!(
            decode(&wire, KEY),
            Err(Error::SnapshotTooLarge { .. })
        ));
    }
}