1use base64::engine::general_purpose::URL_SAFE_NO_PAD;
11use base64::Engine;
12use serde::{Deserialize, Serialize};
13
14use crate::crypto;
15use crate::error::{Error, Result};
16
/// Upper bound, in bytes, on a wire-encoded snapshot string (64 KiB).
/// Enforced by both `encode` and `decode`.
const MAX_PAYLOAD: usize = 1 << 16;
/// Marker prepended to wire strings whose payload is encrypted.
const ENC_PREFIX: &str = "enc:";
19
/// Signed snapshot container exchanged over the wire via `encode` / `decode`.
///
/// Serialized with serde in field-declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Envelope {
    /// Envelope format version; `Envelope::build` always sets `1`.
    /// NOTE(review): `compute_checksum` signs only `data` and `memo`, so this
    /// field is NOT covered by the checksum.
    pub v: u8,
    /// Arbitrary JSON state; part of the signed body.
    pub data: serde_json::Value,
    /// Component metadata; part of the signed body.
    pub memo: Memo,
    /// Output of `crypto::sign` over `canonical_pair(data, memo)`.
    /// NOTE(review): textual encoding (hex/base64) depends on `crypto::sign` — confirm there.
    pub checksum: String,
}
27
/// Metadata signed alongside an envelope's `data`.
///
/// Field order matters: `canonical_pair` signs this struct's serde_json
/// serialization, and the derived `Serialize` emits fields in declaration
/// order, so reordering fields changes every checksum.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Memo {
    /// Instance identifier (tests use `"01HX-test"`); presumably unique per
    /// component instance — TODO confirm.
    pub id: String,
    /// Presumably the component's type path (tests use `"tests::Counter"`).
    pub class: String,
    /// Presumably the view/template name (tests use `"spark/counter"`).
    pub view: String,
    /// Event names (tests use `"posts.created"`); defaults to empty when the
    /// key is absent from the input JSON.
    #[serde(default)]
    pub listeners: Vec<String>,
    /// Optional error payload; omitted from serialized output when `None`.
    /// NOTE(review): `Some(Value::Null)` serializes as `null` but deserializes
    /// back as `None`, so it does not survive a JSON round-trip unchanged.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub errors: Option<serde_json::Value>,
}
38
39impl Envelope {
40 pub fn build(app_key: &str, data: serde_json::Value, memo: Memo) -> Self {
42 let checksum = compute_checksum(app_key, &data, &memo);
43 Self {
44 v: 1,
45 data,
46 memo,
47 checksum,
48 }
49 }
50
51 pub fn verify(&self, app_key: &str) -> Result<()> {
54 let expected = compute_checksum(app_key, &self.data, &self.memo);
55 if crate::const_eq(self.checksum.as_bytes(), expected.as_bytes()) {
58 Ok(())
59 } else {
60 Err(Error::SnapshotTampered)
61 }
62 }
63}
64
65fn compute_checksum(app_key: &str, data: &serde_json::Value, memo: &Memo) -> String {
66 let body = canonical_pair(data, memo);
67 crypto::sign(app_key, &body)
68}
69
70fn canonical_pair(data: &serde_json::Value, memo: &Memo) -> Vec<u8> {
71 let mut out = serde_json::to_vec(data).unwrap_or_default();
75 out.extend_from_slice(b"||");
76 out.extend_from_slice(serde_json::to_vec(memo).unwrap_or_default().as_slice());
77 out
78}
79
80pub fn encode(envelope: &Envelope, app_key: &str, encrypt: bool) -> Result<String> {
82 let json = serde_json::to_vec(envelope)?;
83 if encrypt {
84 let blob = crypto::encrypt(app_key, &json);
85 let mut out = String::with_capacity(ENC_PREFIX.len() + blob.len() * 2);
86 out.push_str(ENC_PREFIX);
87 out.push_str(&URL_SAFE_NO_PAD.encode(blob));
88 if out.len() > MAX_PAYLOAD {
89 return Err(Error::SnapshotTooLarge {
90 size: out.len(),
91 max: MAX_PAYLOAD,
92 });
93 }
94 Ok(out)
95 } else {
96 let encoded = URL_SAFE_NO_PAD.encode(json);
97 if encoded.len() > MAX_PAYLOAD {
98 return Err(Error::SnapshotTooLarge {
99 size: encoded.len(),
100 max: MAX_PAYLOAD,
101 });
102 }
103 Ok(encoded)
104 }
105}
106
107pub fn decode(wire: &str, app_key: &str) -> Result<Envelope> {
109 if wire.len() > MAX_PAYLOAD {
110 return Err(Error::SnapshotTooLarge {
111 size: wire.len(),
112 max: MAX_PAYLOAD,
113 });
114 }
115 let json_bytes = if let Some(rest) = wire.strip_prefix(ENC_PREFIX) {
116 let blob = URL_SAFE_NO_PAD
117 .decode(rest)
118 .map_err(|e| Error::SnapshotDecode(format!("b64: {e}")))?;
119 crypto::decrypt(app_key, &blob)
120 .ok_or_else(|| Error::SnapshotDecode("aes-gcm decrypt failed".into()))?
121 } else {
122 URL_SAFE_NO_PAD
123 .decode(wire)
124 .map_err(|e| Error::SnapshotDecode(format!("b64: {e}")))?
125 };
126 let envelope: Envelope = serde_json::from_slice(&json_bytes)
127 .map_err(|e| Error::SnapshotDecode(format!("json: {e}")))?;
128 envelope.verify(app_key)?;
129 Ok(envelope)
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    const KEY: &str = "spark-test-app-key-thirty-two-bb";
    const OTHER_KEY: &str = "spark-test-app-key-thirty-two-cc";

    fn sample_memo() -> Memo {
        Memo {
            id: "01HX-test".into(),
            class: "tests::Counter".into(),
            view: "spark/counter".into(),
            listeners: vec!["posts.created".into()],
            errors: None,
        }
    }

    #[test]
    fn round_trip_unencrypted() {
        let envelope = Envelope::build(KEY, json!({"count": 5}), sample_memo());
        let wire = encode(&envelope, KEY, false).unwrap();
        let decoded = decode(&wire, KEY).unwrap();
        assert_eq!(decoded.data, envelope.data);
        assert_eq!(decoded.memo.class, envelope.memo.class);
    }

    #[test]
    fn round_trip_encrypted() {
        let envelope = Envelope::build(KEY, json!({"count": 5}), sample_memo());
        let wire = encode(&envelope, KEY, true).unwrap();
        assert!(wire.starts_with("enc:"));
        let decoded = decode(&wire, KEY).unwrap();
        assert_eq!(decoded.data, envelope.data);
    }

    /// Corrupting the wire bytes must fail to decode. The failure may come
    /// from base64 or JSON parsing rather than the checksum; see the
    /// `tampered_*_is_rejected_by_checksum` tests for the checksum itself.
    #[test]
    fn tampered_unencrypted_fails() {
        let envelope = Envelope::build(KEY, json!({"count": 5}), sample_memo());
        let wire = encode(&envelope, KEY, false).unwrap();
        let mut bytes = wire.into_bytes();
        let last = bytes.last_mut().unwrap();
        *last = if *last == b'A' { b'B' } else { b'A' };
        let tampered = String::from_utf8(bytes).unwrap();
        assert!(decode(&tampered, KEY).is_err());
    }

    /// Changes signed `data` but keeps the wire well-formed, so the only
    /// thing that can reject it is checksum verification.
    #[test]
    fn tampered_data_is_rejected_by_checksum() {
        let mut envelope = Envelope::build(KEY, json!({"count": 5}), sample_memo());
        envelope.data = json!({"count": 6});
        let wire = encode(&envelope, KEY, false).unwrap();
        assert!(matches!(decode(&wire, KEY), Err(Error::SnapshotTampered)));
    }

    #[test]
    fn tampered_memo_is_rejected_by_checksum() {
        let mut envelope = Envelope::build(KEY, json!({"count": 5}), sample_memo());
        envelope.memo.class = "tests::Admin".into();
        let wire = encode(&envelope, KEY, false).unwrap();
        assert!(matches!(decode(&wire, KEY), Err(Error::SnapshotTampered)));
    }

    #[test]
    fn wrong_key_fails() {
        let envelope = Envelope::build(KEY, json!({"count": 5}), sample_memo());
        for encrypt in [false, true] {
            let wire = encode(&envelope, KEY, encrypt).unwrap();
            assert!(decode(&wire, OTHER_KEY).is_err(), "encrypt = {encrypt}");
        }
    }

    #[test]
    fn oversized_wire_is_rejected_by_decode() {
        let wire = "A".repeat(MAX_PAYLOAD + 1);
        assert!(matches!(
            decode(&wire, KEY),
            Err(Error::SnapshotTooLarge { .. })
        ));
    }

    #[test]
    fn oversized_envelope_is_rejected_by_encode() {
        let big = json!({"blob": "x".repeat(MAX_PAYLOAD)});
        let envelope = Envelope::build(KEY, big, sample_memo());
        assert!(matches!(
            encode(&envelope, KEY, false),
            Err(Error::SnapshotTooLarge { .. })
        ));
    }
}
178}