1use base64::engine::general_purpose::URL_SAFE_NO_PAD;
11use base64::Engine;
12use serde::{Deserialize, Serialize};
13
14use crate::crypto;
15use crate::error::{Error, Result};
16
/// Upper bound, in bytes, on an encoded wire string (checked by both `encode` and `decode`).
const MAX_PAYLOAD: usize = 1 << 16;
/// Wire prefix marking an encrypted payload.
const ENC_PREFIX: &str = "enc:";
/// Wire prefix marking a gzip-compressed payload.
const GZ_PREFIX: &str = "gz:";
/// Serialized JSON size, in bytes, from which `encode` attempts gzip compression.
const GZ_MIN_SIZE: usize = 1 << 12;

/// Newest envelope version this build understands; `decode_with_keys` rejects newer ones.
pub const CURRENT_VERSION: u8 = 1;
31
/// Signed snapshot envelope, serialized to JSON and framed by [`encode`] / [`decode`].
///
/// Only `data` and `memo` are covered by `checksum`; `v` and `kid` are not part of the
/// signed bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Envelope {
    /// Envelope format version; `decode_with_keys` rejects values above [`CURRENT_VERSION`].
    pub v: u8,
    /// Arbitrary JSON payload, covered by `checksum`.
    pub data: serde_json::Value,
    /// Metadata covered by `checksum` together with `data`.
    pub memo: Memo,
    /// Signature from `crypto::sign` over the canonical `data` + `memo` bytes.
    pub checksum: String,
    /// Optional key id selecting which key-ring entry verifies `checksum`.
    /// Omitted from the serialized form when `None`, and defaults to `None` when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kid: Option<u8>,
}
49
/// Metadata signed alongside [`Envelope::data`].
///
/// The `serde_json` serialization of this struct is part of the checksum input, so
/// reordering or renaming fields changes the signed bytes and invalidates previously
/// issued checksums.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Memo {
    // NOTE(review): presumably the component instance id — confirm.
    pub id: String,
    // NOTE(review): presumably the component type name (tests use "tests::Counter") — confirm.
    pub class: String,
    // NOTE(review): presumably the view/template path — confirm.
    pub view: String,
    /// Defaults to an empty list when absent from the input.
    #[serde(default)]
    pub listeners: Vec<String>,
    /// Omitted from serialized output when `None`; defaults to `None` when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub errors: Option<serde_json::Value>,
    /// Defaults to 0 when absent from the input.
    #[serde(default)]
    pub rev: u64,
}
70
71impl Envelope {
72 pub fn build(app_key: &str, data: serde_json::Value, memo: Memo) -> Self {
75 let checksum = compute_checksum(app_key, &data, &memo);
76 Self {
77 v: 1,
78 data,
79 memo,
80 checksum,
81 kid: None,
82 }
83 }
84
85 pub fn build_with_kid(kid: u8, app_key: &str, data: serde_json::Value, memo: Memo) -> Self {
89 let checksum = compute_checksum(app_key, &data, &memo);
90 Self {
91 v: 1,
92 data,
93 memo,
94 checksum,
95 kid: Some(kid),
96 }
97 }
98
99 pub fn verify(&self, app_key: &str) -> Result<()> {
102 let expected = compute_checksum(app_key, &self.data, &self.memo);
103 if crate::const_eq(self.checksum.as_bytes(), expected.as_bytes()) {
104 Ok(())
105 } else {
106 Err(Error::SnapshotTampered)
107 }
108 }
109
110 pub fn verify_with_keys(&self, keys: &[(u8, &str)]) -> Result<()> {
121 if keys.is_empty() {
122 return Err(Error::SnapshotTampered);
123 }
124 let key = match self.kid {
125 Some(k) => keys
126 .iter()
127 .find_map(|(kid, key)| (*kid == k).then_some(*key))
128 .ok_or(Error::SnapshotTampered)?,
129 None => keys[0].1,
130 };
131 self.verify(key)
132 }
133}
134
/// Parses a comma-separated key ring of `kid:key` entries, e.g. `"1:keyA,2:keyB"`.
///
/// Whitespace around entries, ids and keys is trimmed. Only the first `:` separates id from
/// key, so keys may themselves contain `:`. Entries are skipped when they:
/// * are blank,
/// * lack a `:`,
/// * have an id that does not parse as a `u8`, or
/// * have an empty key.
///
/// Surviving entries keep their input order.
pub fn parse_keyring(raw: &str) -> Vec<(u8, String)> {
    raw.split(',')
        .filter_map(|entry| {
            // A blank entry has no ':' and is rejected here as well.
            let (kid, key) = entry.trim().split_once(':')?;
            let kid: u8 = kid.trim().parse().ok()?;
            let key = key.trim();
            // An empty key would sign snapshots with a publicly known secret; treat it as garbage.
            (!key.is_empty()).then(|| (kid, key.to_string()))
        })
        .collect()
}
155
156fn compute_checksum(app_key: &str, data: &serde_json::Value, memo: &Memo) -> String {
157 let body = canonical_pair(data, memo);
158 crypto::sign(app_key, &body)
159}
160
161fn canonical_pair(data: &serde_json::Value, memo: &Memo) -> Vec<u8> {
162 let mut out = serde_json::to_vec(data).unwrap_or_default();
166 out.extend_from_slice(b"||");
167 out.extend_from_slice(serde_json::to_vec(memo).unwrap_or_default().as_slice());
168 out
169}
170
171pub fn encode(envelope: &Envelope, app_key: &str, encrypt: bool) -> Result<String> {
180 let json = serde_json::to_vec(envelope)?;
181 if encrypt {
182 let blob = crypto::encrypt(app_key, &json);
183 let mut out = String::with_capacity(ENC_PREFIX.len() + blob.len() * 2);
184 out.push_str(ENC_PREFIX);
185 out.push_str(&URL_SAFE_NO_PAD.encode(blob));
186 if out.len() > MAX_PAYLOAD {
187 return Err(Error::SnapshotTooLarge {
188 size: out.len(),
189 max: MAX_PAYLOAD,
190 });
191 }
192 Ok(out)
193 } else if json.len() >= GZ_MIN_SIZE {
194 let compressed = gzip_encode(&json);
195 if compressed.len() < json.len() {
199 let mut out = String::with_capacity(GZ_PREFIX.len() + compressed.len() * 2);
200 out.push_str(GZ_PREFIX);
201 out.push_str(&URL_SAFE_NO_PAD.encode(&compressed));
202 if out.len() > MAX_PAYLOAD {
203 return Err(Error::SnapshotTooLarge {
204 size: out.len(),
205 max: MAX_PAYLOAD,
206 });
207 }
208 return Ok(out);
209 }
210 let encoded = URL_SAFE_NO_PAD.encode(&json);
211 if encoded.len() > MAX_PAYLOAD {
212 return Err(Error::SnapshotTooLarge {
213 size: encoded.len(),
214 max: MAX_PAYLOAD,
215 });
216 }
217 Ok(encoded)
218 } else {
219 let encoded = URL_SAFE_NO_PAD.encode(json);
220 if encoded.len() > MAX_PAYLOAD {
221 return Err(Error::SnapshotTooLarge {
222 size: encoded.len(),
223 max: MAX_PAYLOAD,
224 });
225 }
226 Ok(encoded)
227 }
228}
229
230fn gzip_encode(input: &[u8]) -> Vec<u8> {
231 use flate2::write::GzEncoder;
232 use flate2::Compression;
233 use std::io::Write;
234 let mut enc = GzEncoder::new(Vec::with_capacity(input.len() / 4), Compression::default());
235 let _ = enc.write_all(input);
236 enc.finish().unwrap_or_default()
237}
238
239fn gzip_decode(input: &[u8]) -> Result<Vec<u8>> {
240 use flate2::read::GzDecoder;
241 use std::io::Read;
242 let mut decoder = GzDecoder::new(input);
243 let mut out = Vec::with_capacity(input.len() * 2);
244 decoder
245 .read_to_end(&mut out)
246 .map_err(|e| Error::SnapshotDecode(format!("gzip: {e}")))?;
247 Ok(out)
248}
249
250pub fn decode(wire: &str, app_key: &str) -> Result<Envelope> {
254 decode_with_keys(wire, &[(0, app_key)])
255}
256
257pub fn decode_with_keys(wire: &str, keys: &[(u8, &str)]) -> Result<Envelope> {
267 if wire.len() > MAX_PAYLOAD {
268 return Err(Error::SnapshotTooLarge {
269 size: wire.len(),
270 max: MAX_PAYLOAD,
271 });
272 }
273 let primary_key = keys
274 .first()
275 .map(|(_, k)| *k)
276 .ok_or_else(|| Error::SnapshotDecode("empty keyring".into()))?;
277 let json_bytes = if let Some(rest) = wire.strip_prefix(ENC_PREFIX) {
278 let blob = URL_SAFE_NO_PAD
279 .decode(rest)
280 .map_err(|e| Error::SnapshotDecode(format!("b64: {e}")))?;
281 crypto::decrypt(primary_key, &blob)
282 .ok_or_else(|| Error::SnapshotDecode("aes-gcm decrypt failed".into()))?
283 } else if let Some(rest) = wire.strip_prefix(GZ_PREFIX) {
284 let compressed = URL_SAFE_NO_PAD
285 .decode(rest)
286 .map_err(|e| Error::SnapshotDecode(format!("b64: {e}")))?;
287 gzip_decode(&compressed)?
288 } else {
289 URL_SAFE_NO_PAD
290 .decode(wire)
291 .map_err(|e| Error::SnapshotDecode(format!("b64: {e}")))?
292 };
293 let envelope: Envelope = serde_json::from_slice(&json_bytes)
294 .map_err(|e| Error::SnapshotDecode(format!("json: {e}")))?;
295 if envelope.v > CURRENT_VERSION {
300 return Err(Error::SnapshotVersionMismatch {
301 client_v: envelope.v,
302 server_v: CURRENT_VERSION,
303 });
304 }
305 envelope.verify_with_keys(keys)?;
306 Ok(envelope)
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    const KEY: &str = "spark-test-app-key-thirty-two-bb";

    fn sample_memo() -> Memo {
        Memo {
            id: "01HX-test".into(),
            class: "tests::Counter".into(),
            view: "spark/counter".into(),
            listeners: vec!["posts.created".into()],
            errors: None,
            rev: 0,
        }
    }

    #[test]
    fn round_trip_unencrypted() {
        let envelope = Envelope::build(KEY, json!({"count": 5}), sample_memo());
        let wire = encode(&envelope, KEY, false).unwrap();
        let decoded = decode(&wire, KEY).unwrap();
        assert_eq!(decoded.data, envelope.data);
        assert_eq!(decoded.memo.class, envelope.memo.class);
    }

    #[test]
    fn round_trip_encrypted() {
        let envelope = Envelope::build(KEY, json!({"count": 5}), sample_memo());
        let wire = encode(&envelope, KEY, true).unwrap();
        assert!(wire.starts_with("enc:"));
        let decoded = decode(&wire, KEY).unwrap();
        assert_eq!(decoded.data, envelope.data);
    }

    #[test]
    fn tampered_unencrypted_fails() {
        let envelope = Envelope::build(KEY, json!({"count": 5}), sample_memo());
        let wire = encode(&envelope, KEY, false).unwrap();
        let mut bytes = wire.into_bytes();
        let last = bytes.last_mut().unwrap();
        *last = if *last == b'A' { b'B' } else { b'A' };
        let tampered = String::from_utf8(bytes).unwrap();
        assert!(decode(&tampered, KEY).is_err());
    }

    #[test]
    fn tampered_encrypted_fails() {
        let envelope = Envelope::build(KEY, json!({"count": 5}), sample_memo());
        let wire = encode(&envelope, KEY, true).unwrap();
        let mut bytes = wire.into_bytes();
        // Flip a character in the middle of the ciphertext, well past the `enc:` prefix.
        let mid = bytes.len() / 2;
        bytes[mid] = if bytes[mid] == b'A' { b'B' } else { b'A' };
        let tampered = String::from_utf8(bytes).unwrap();
        assert!(decode(&tampered, KEY).is_err());
    }

    #[test]
    fn encrypted_payload_rejects_wrong_key() {
        let envelope = Envelope::build(KEY, json!({"count": 5}), sample_memo());
        let wire = encode(&envelope, KEY, true).unwrap();
        assert!(decode(&wire, "wrong-key-thirty-two-bytes-paddd").is_err());
    }

    #[test]
    fn parse_keyring_handles_whitespace_and_skips_garbage() {
        let parsed = parse_keyring(" 1:keyA , bad , 2:keyB,");
        assert_eq!(
            parsed,
            vec![(1, "keyA".to_string()), (2, "keyB".to_string())]
        );
    }

    #[test]
    fn parse_keyring_keeps_colons_in_keys_and_skips_bad_ids() {
        assert_eq!(parse_keyring("3:a:b"), vec![(3, "a:b".to_string())]);
        assert!(parse_keyring("300:x, x:y, :z").is_empty());
    }

    #[test]
    fn keyring_verifies_under_either_active_key() {
        let env = Envelope::build_with_kid(
            2,
            "old-key-thirty-two-bytes-padding",
            json!({"x": 1}),
            sample_memo(),
        );
        let wire = encode(&env, "old-key-thirty-two-bytes-padding", false).unwrap();

        let keys: &[(u8, &str)] = &[
            (3, "new-key-thirty-two-bytes-padding"),
            (2, "old-key-thirty-two-bytes-padding"),
        ];
        let decoded = decode_with_keys(&wire, keys).expect("rotation should accept old kid");
        assert_eq!(decoded.kid, Some(2));
    }

    #[test]
    fn keyring_rejects_unknown_kid() {
        let env = Envelope::build_with_kid(99, KEY, json!({"x": 1}), sample_memo());
        let wire = encode(&env, KEY, false).unwrap();
        let keys: &[(u8, &str)] = &[(1, KEY)];
        assert!(decode_with_keys(&wire, keys).is_err());
    }

    #[test]
    fn empty_keyring_is_rejected() {
        let envelope = Envelope::build(KEY, json!({"x": 1}), sample_memo());
        let wire = encode(&envelope, KEY, false).unwrap();
        assert!(decode_with_keys(&wire, &[]).is_err());
    }

    #[test]
    fn oversized_wire_is_rejected_before_decoding() {
        let wire = "A".repeat(MAX_PAYLOAD + 1);
        assert!(matches!(
            decode(&wire, KEY),
            Err(crate::error::Error::SnapshotTooLarge { .. })
        ));
    }

    #[test]
    fn oversized_encrypted_payload_is_rejected_on_encode() {
        // The encrypted path never compresses, so MAX_PAYLOAD bytes of JSON cannot fit once
        // base64-encoded.
        let data = json!({ "blob": "a".repeat(MAX_PAYLOAD) });
        let envelope = Envelope::build(KEY, data, sample_memo());
        assert!(matches!(
            encode(&envelope, KEY, true),
            Err(crate::error::Error::SnapshotTooLarge { .. })
        ));
    }

    #[test]
    fn large_payload_round_trips_through_gzip_form() {
        let big_string = "a".repeat(8 * 1024);
        let data = json!({ "blob": big_string });
        let envelope = Envelope::build(KEY, data.clone(), sample_memo());
        let wire = encode(&envelope, KEY, false).unwrap();

        assert!(
            wire.starts_with("gz:"),
            "wire should be gzip-framed; got `{}`...",
            &wire[..20.min(wire.len())]
        );
        assert!(
            wire.len() < 8 * 1024,
            "gzipped payload must be smaller than raw"
        );

        let decoded = decode(&wire, KEY).unwrap();
        assert_eq!(decoded.data, data);
    }

    #[test]
    fn small_payload_does_not_use_gzip() {
        let envelope = Envelope::build(KEY, json!({"x": 1}), sample_memo());
        let wire = encode(&envelope, KEY, false).unwrap();
        assert!(!wire.starts_with("gz:"));
        assert!(!wire.starts_with("enc:"));
    }

    #[test]
    fn missing_kid_falls_back_to_first_key() {
        let env = Envelope::build(KEY, json!({"x": 1}), sample_memo());
        assert!(env.kid.is_none());
        let wire = encode(&env, KEY, false).unwrap();
        let keys: &[(u8, &str)] = &[(0, KEY), (1, "other-key-thirty-two-bytes-pad")];
        decode_with_keys(&wire, keys).expect("no-kid envelope should verify under first key");
    }
}