ai_memory/federation/
signing.rs1use base64::Engine;
23use base64::engine::general_purpose::STANDARD as B64;
24use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
25
/// Request header whose value is the detached body signature, formatted as
/// `ed25519=<standard base64 of the 64-byte signature>`.
pub const SIGNATURE_HEADER: &str = "x-memory-sig";

/// Environment variable read by [`require_sig`]. Only the exact value `"0"`
/// disables signature enforcement.
pub const REQUIRE_SIG_ENV: &str = "AI_MEMORY_FED_REQUIRE_SIG";

/// Request header carrying the nonce that is bound into the signature.
pub const NONCE_HEADER: &str = "x-memory-nonce";

/// Separator byte placed between body and nonce when building the
/// nonce-bound signing input: `body || NONCE_DOMAIN_SEP || nonce`.
const NONCE_DOMAIN_SEP: u8 = b'\0';

/// Environment variable read by [`require_nonce`]. Only the exact value `"0"`
/// disables nonce enforcement.
pub const REQUIRE_NONCE_ENV: &str = "AI_MEMORY_FED_REQUIRE_NONCE";

/// Algorithm tag that prefixes the base64 payload in the signature header.
pub const ED25519_PREFIX: &str = "ed25519=";
46
47#[must_use]
53pub fn sign_body_header(key: &SigningKey, body: &[u8]) -> String {
54 let sig: Signature = key.sign(body);
55 let b64 = B64.encode(sig.to_bytes());
56 format!("{ED25519_PREFIX}{b64}")
57}
58
59#[must_use]
61pub fn sign_body_with_nonce_header(key: &SigningKey, body: &[u8], nonce: &str) -> String {
62 let mut input = Vec::with_capacity(body.len() + 1 + nonce.len());
63 input.extend_from_slice(body);
64 input.push(NONCE_DOMAIN_SEP);
65 input.extend_from_slice(nonce.as_bytes());
66 let sig: Signature = key.sign(&input);
67 let b64 = B64.encode(sig.to_bytes());
68 format!("{ED25519_PREFIX}{b64}")
69}
70
/// Reason a signature or nonce check failed.
///
/// Every variant has a fixed snake_case identifier, available via
/// [`VerifyError::tag`], which is also what `Display` prints.
// Unit-only enum: `Copy`/`Eq`/`Hash` are free and let callers compare errors
// directly (e.g. `assert_eq!`) instead of going through `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VerifyError {
    /// No signature header was supplied.
    Missing,
    /// The signature header does not start with the `ed25519=` algorithm tag.
    UnknownAlgorithm,
    /// The authentication data is malformed, e.g. the signature payload is not
    /// valid standard base64 or does not decode to exactly 64 bytes.
    Malformed,
    /// The signature is well-formed but does not verify under the public key.
    BadSignature,
    /// The nonce was already used. Not returned by the verify functions in this
    /// module; presumably raised by the caller's replay check (confirm).
    ReplayedNonce,
    /// A nonce was required but absent. Not returned by the verify functions in
    /// this module; presumably raised by the caller (confirm).
    NonceMissing,
}

impl VerifyError {
    /// Fixed snake_case identifier for this failure; identical to the
    /// `Display` output.
    #[must_use]
    pub const fn tag(&self) -> &'static str {
        match self {
            Self::Missing => "x_memory_sig_missing",
            Self::UnknownAlgorithm => "x_memory_sig_unknown_algorithm",
            Self::Malformed => "x_memory_sig_malformed",
            Self::BadSignature => "x_memory_sig_bad_signature",
            Self::ReplayedNonce => "x_memory_nonce_replay",
            Self::NonceMissing => "x_memory_nonce_missing",
        }
    }
}

impl std::fmt::Display for VerifyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.tag())
    }
}

impl std::error::Error for VerifyError {}
111
112pub fn verify_header(
121 header: Option<&str>,
122 body: &[u8],
123 pubkey: &VerifyingKey,
124) -> Result<(), VerifyError> {
125 let raw = header.ok_or(VerifyError::Missing)?;
126 let primary = raw.split(';').next().unwrap_or(raw).trim();
127 let b64 = primary
128 .strip_prefix(ED25519_PREFIX)
129 .ok_or(VerifyError::UnknownAlgorithm)?;
130 let bytes = B64
131 .decode(b64.as_bytes())
132 .map_err(|_| VerifyError::Malformed)?;
133 if bytes.len() != 64 {
134 return Err(VerifyError::Malformed);
135 }
136 let mut sig_arr = [0u8; 64];
137 sig_arr.copy_from_slice(&bytes);
138 let sig = Signature::from_bytes(&sig_arr);
139 pubkey
140 .verify(body, &sig)
141 .map_err(|_| VerifyError::BadSignature)
142}
143
144pub fn verify_header_with_nonce(
152 header: Option<&str>,
153 body: &[u8],
154 nonce: &str,
155 pubkey: &VerifyingKey,
156) -> Result<(), VerifyError> {
157 let raw = header.ok_or(VerifyError::Missing)?;
158 let primary = raw.split(';').next().unwrap_or(raw).trim();
159 let b64 = primary
160 .strip_prefix(ED25519_PREFIX)
161 .ok_or(VerifyError::UnknownAlgorithm)?;
162 let bytes = B64
163 .decode(b64.as_bytes())
164 .map_err(|_| VerifyError::Malformed)?;
165 if bytes.len() != 64 {
166 return Err(VerifyError::Malformed);
167 }
168 let mut sig_arr = [0u8; 64];
169 sig_arr.copy_from_slice(&bytes);
170 let sig = Signature::from_bytes(&sig_arr);
171 let mut input = Vec::with_capacity(body.len() + 1 + nonce.len());
172 input.extend_from_slice(body);
173 input.push(NONCE_DOMAIN_SEP);
174 input.extend_from_slice(nonce.as_bytes());
175 pubkey
176 .verify(&input, &sig)
177 .map_err(|_| VerifyError::BadSignature)
178}
179
180#[must_use]
182pub fn require_sig() -> bool {
183 match std::env::var(REQUIRE_SIG_ENV) {
184 Ok(v) => v != "0",
185 Err(_) => true,
186 }
187}
188
189#[must_use]
191pub fn require_nonce() -> bool {
192 match std::env::var(REQUIRE_NONCE_ENV) {
193 Ok(v) => v != "0",
194 Err(_) => true,
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use ed25519_dalek::SigningKey;
    use rand_core::OsRng;

    /// A fresh random signing key, so tests never share key material.
    fn fresh_key() -> SigningKey {
        SigningKey::generate(&mut OsRng)
    }

    #[test]
    fn sign_then_verify_round_trips() {
        let key = fresh_key();
        let pubkey = key.verifying_key();
        let body = br#"{"memories":[{"id":"a"}]}"#;
        let header = sign_body_header(&key, body);
        assert!(header.starts_with(ED25519_PREFIX));
        assert!(verify_header(Some(&header), body, &pubkey).is_ok());
    }

    #[test]
    fn tampered_body_fails_verify() {
        let key = fresh_key();
        let pubkey = key.verifying_key();
        let body = br#"{"memories":[{"id":"a"}]}"#;
        let header = sign_body_header(&key, body);
        let tampered = br#"{"memories":[{"id":"EVIL"}]}"#;
        let err = verify_header(Some(&header), tampered, &pubkey).unwrap_err();
        assert!(matches!(err, VerifyError::BadSignature));
    }

    #[test]
    fn missing_header_returns_missing_variant() {
        let key = fresh_key();
        let pubkey = key.verifying_key();
        let err = verify_header(None, b"body", &pubkey).unwrap_err();
        assert!(matches!(err, VerifyError::Missing));
    }

    #[test]
    fn unknown_algorithm_prefix_rejected() {
        let key = fresh_key();
        let pubkey = key.verifying_key();
        let err = verify_header(Some("rsa=abc"), b"body", &pubkey).unwrap_err();
        assert!(matches!(err, VerifyError::UnknownAlgorithm));
    }

    #[test]
    fn malformed_base64_rejected() {
        let key = fresh_key();
        let pubkey = key.verifying_key();
        let err = verify_header(Some("ed25519=not-base64!!!"), b"body", &pubkey).unwrap_err();
        assert!(matches!(err, VerifyError::Malformed));
    }

    #[test]
    fn wrong_length_signature_rejected() {
        let key = fresh_key();
        let pubkey = key.verifying_key();
        let header = format!("ed25519={}", B64.encode([0u8; 32]));
        let err = verify_header(Some(&header), b"body", &pubkey).unwrap_err();
        assert!(matches!(err, VerifyError::Malformed));
    }

    #[test]
    fn trailing_suffix_tolerated() {
        let key = fresh_key();
        let pubkey = key.verifying_key();
        let body = b"hello";
        let header_with_suffix = format!("{}; rsa=other", sign_body_header(&key, body));
        assert!(verify_header(Some(&header_with_suffix), body, &pubkey).is_ok());
    }

    #[test]
    fn nonce_sign_then_verify_round_trips() {
        let key = fresh_key();
        let pubkey = key.verifying_key();
        let body = br#"{"memories":[{"id":"a"}]}"#;
        let header = sign_body_with_nonce_header(&key, body, "n-1");
        assert!(header.starts_with(ED25519_PREFIX));
        assert!(verify_header_with_nonce(Some(&header), body, "n-1", &pubkey).is_ok());
    }

    #[test]
    fn wrong_nonce_fails_verify() {
        let key = fresh_key();
        let pubkey = key.verifying_key();
        let body = b"hello";
        let header = sign_body_with_nonce_header(&key, body, "n-1");
        let err = verify_header_with_nonce(Some(&header), body, "n-2", &pubkey).unwrap_err();
        assert!(matches!(err, VerifyError::BadSignature));
    }

    #[test]
    fn tampered_body_fails_nonce_verify() {
        let key = fresh_key();
        let pubkey = key.verifying_key();
        let header = sign_body_with_nonce_header(&key, b"hello", "n-1");
        let err = verify_header_with_nonce(Some(&header), b"HELLO", "n-1", &pubkey).unwrap_err();
        assert!(matches!(err, VerifyError::BadSignature));
    }

    #[test]
    fn plain_and_nonce_signatures_are_not_interchangeable() {
        let key = fresh_key();
        let pubkey = key.verifying_key();
        let body = b"hello";
        let plain = sign_body_header(&key, body);
        let bound = sign_body_with_nonce_header(&key, body, "n-1");
        let err = verify_header_with_nonce(Some(&plain), body, "n-1", &pubkey).unwrap_err();
        assert!(matches!(err, VerifyError::BadSignature));
        let err = verify_header(Some(&bound), body, &pubkey).unwrap_err();
        assert!(matches!(err, VerifyError::BadSignature));
    }

    #[test]
    fn nonce_verify_reports_missing_header() {
        let key = fresh_key();
        let pubkey = key.verifying_key();
        let err = verify_header_with_nonce(None, b"body", "n-1", &pubkey).unwrap_err();
        assert!(matches!(err, VerifyError::Missing));
    }

    #[test]
    fn error_tags_match_display() {
        for err in [
            VerifyError::Missing,
            VerifyError::UnknownAlgorithm,
            VerifyError::Malformed,
            VerifyError::BadSignature,
            VerifyError::ReplayedNonce,
            VerifyError::NonceMissing,
        ] {
            assert_eq!(err.to_string(), err.tag());
        }
        assert_eq!(VerifyError::ReplayedNonce.tag(), "x_memory_nonce_replay");
    }

    /// Serializes tests that mutate environment variables. The environment is
    /// process-global and the test harness runs tests on parallel threads.
    fn lock_env() -> std::sync::MutexGuard<'static, ()> {
        static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
        // A panicking env test poisons the lock. The guarded data is `()`, so
        // continuing is safe.
        ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Sets `name` to `value` (or removes it when `None`), runs `read`, then
    /// removes `name` again before returning `read`'s result. The caller
    /// asserts only afterwards, so a failing assertion cannot leak the setting
    /// into other tests.
    fn with_env<R>(name: &str, value: Option<&str>, read: impl FnOnce() -> R) -> R {
        let _guard = lock_env();
        // SAFETY: every environment write in this module happens under
        // `lock_env()`, so these writes never race each other. Assumes no other
        // test in this binary touches these variables concurrently; confirm if
        // env-reading tests are added elsewhere.
        unsafe {
            match value {
                Some(v) => std::env::set_var(name, v),
                None => std::env::remove_var(name),
            }
        }
        let out = read();
        // SAFETY: same invariant as above; `lock_env()` is still held.
        unsafe {
            std::env::remove_var(name);
        }
        out
    }

    #[test]
    fn require_sig_defaults_to_true() {
        assert!(with_env(REQUIRE_SIG_ENV, None, require_sig));
    }

    #[test]
    fn require_sig_false_when_zero() {
        assert!(!with_env(REQUIRE_SIG_ENV, Some("0"), require_sig));
    }

    #[test]
    fn require_sig_stays_on_for_other_values() {
        assert!(with_env(REQUIRE_SIG_ENV, Some("false"), require_sig));
    }

    #[test]
    fn require_nonce_defaults_to_true() {
        assert!(with_env(REQUIRE_NONCE_ENV, None, require_nonce));
    }

    #[test]
    fn require_nonce_false_when_zero() {
        assert!(!with_env(REQUIRE_NONCE_ENV, Some("0"), require_nonce));
    }
}