1use anyhow::{bail, Context, Result};
6use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
7use chrono::{DateTime, Utc};
8use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use sha2::{Digest, Sha256};
12
13use super::jcs;
14
/// DSSE payload type bound into the pre-authentication encoding of every tool signature.
/// `verify_tool` rejects signatures whose `payload_type` differs from this value.
pub const PAYLOAD_TYPE_TOOL_V1: &str = "application/vnd.assay.tool+json;v=1";

/// JSON object key under which the signature block is stored inside a tool definition.
/// It is stripped before canonicalization, so it is never part of the signed payload.
pub const SIG_FIELD: &str = "x-assay-sig";
20
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
23#[serde(rename_all = "lowercase")]
24pub enum SignatureAlgorithm {
25 Ed25519,
26}
27
/// Signature block stored under [`SIG_FIELD`] in a signed tool definition.
///
/// Keep the field declaration order as-is: serde's derive serializes fields in
/// declaration order.
///
/// NOTE(review): `sign_tool` signs only the canonical tool body (the tool with this
/// block removed). The fields of this struct, notably `signed_at`, are therefore not
/// covered by the Ed25519 signature itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSignature {
    /// Signature format version; `verify_tool` accepts only `1`.
    pub version: u8,
    /// Signature algorithm; only [`SignatureAlgorithm::Ed25519`] is produced/accepted.
    pub algorithm: SignatureAlgorithm,
    /// DSSE payload type used in the PAE; must equal [`PAYLOAD_TYPE_TOOL_V1`].
    pub payload_type: String,
    /// `sha256:<hex>` digest of the JCS-canonical tool without the signature field.
    pub payload_digest: String,
    /// `sha256:<hex>` of the signer's public key in SPKI DER form.
    pub key_id: String,
    /// Standard-alphabet base64 of the 64-byte Ed25519 signature over the PAE.
    pub signature: String,
    /// UTC time at which `sign_tool` produced the signature (not itself signed).
    pub signed_at: DateTime<Utc>,
    /// Optional standard-base64 SPKI DER public key of the signer; omitted from the
    /// JSON when `None`. `verify_tool` does not read this field.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub public_key: Option<String>,
}
41
42#[derive(Debug, Clone)]
44pub struct VerifyResult {
45 pub key_id: String,
46 pub signed_at: DateTime<Utc>,
47}
48
/// Reasons a tool signature can fail verification.
///
/// Use [`VerifyError::exit_code`] to map a variant to a numeric exit status.
#[derive(Debug, Clone, thiserror::Error)]
pub enum VerifyError {
    /// The tool has no [`SIG_FIELD`] entry.
    #[error("tool is not signed")]
    NoSignature,

    /// The signature's `payload_type` is not [`PAYLOAD_TYPE_TOOL_V1`].
    #[error("payload type mismatch: expected {expected}, got {got}")]
    PayloadTypeMismatch { expected: String, got: String },

    /// Ed25519 verification of the PAE against the trusted key failed.
    #[error("signature invalid: {reason}")]
    SignatureInvalid { reason: String },

    /// NOTE(review): not produced by `verify_tool`; presumably raised by callers that
    /// resolve keys from a trust store. TODO confirm.
    #[error("key not trusted: {key_id}")]
    KeyNotTrusted { key_id: String },

    /// One of the following:
    /// - the signature block could not be parsed;
    /// - it declares an unsupported version or algorithm;
    /// - its base64/signature bytes are invalid;
    /// - canonicalization or key encoding failed during verification.
    #[error("malformed signature: {reason}")]
    MalformedSignature { reason: String },

    /// The recorded `payload_digest` differs from the digest of the current tool
    /// content, so the tool was changed after signing.
    #[error("payload digest mismatch")]
    DigestMismatch,

    /// The signature verified under the trusted key, but its `key_id` names a
    /// different key.
    #[error("key_id mismatch: signature claims {claimed}, actual {actual}")]
    KeyIdMismatch { claimed: String, actual: String },
}
73
74impl VerifyError {
75 pub fn exit_code(&self) -> i32 {
77 match self {
78 Self::NoSignature => 2,
79 Self::KeyNotTrusted { .. } => 3,
80 Self::SignatureInvalid { .. }
81 | Self::PayloadTypeMismatch { .. }
82 | Self::DigestMismatch
83 | Self::KeyIdMismatch { .. } => 4,
84 Self::MalformedSignature { .. } => 1,
85 }
86 }
87}
88
89pub fn compute_key_id(spki_bytes: &[u8]) -> String {
93 let hash = Sha256::digest(spki_bytes);
94 format!("sha256:{:x}", hash)
95}
96
97pub fn compute_key_id_from_verifying_key(key: &VerifyingKey) -> Result<String> {
99 let spki_bytes = key_to_spki_der(key)?;
100 Ok(compute_key_id(&spki_bytes))
101}
102
103fn key_to_spki_der(key: &VerifyingKey) -> Result<Vec<u8>> {
105 use pkcs8::EncodePublicKey;
106 let doc = key
107 .to_public_key_der()
108 .context("failed to encode public key as SPKI DER")?;
109 Ok(doc.as_bytes().to_vec())
110}
111
/// Builds the DSSE v1 pre-authentication encoding (PAE) of a payload.
///
/// Layout: `"DSSEv1" SP len(type) SP type SP len(payload) SP payload`. Both lengths
/// are byte counts written as ASCII decimal.
fn build_pae(payload_type: &str, payload: &[u8]) -> Vec<u8> {
    let header = format!(
        "DSSEv1 {} {} {} ",
        payload_type.len(),
        payload_type,
        payload.len()
    );
    let mut encoded = Vec::with_capacity(header.len() + payload.len());
    encoded.extend_from_slice(header.as_bytes());
    encoded.extend_from_slice(payload);
    encoded
}
132
133fn strip_signature(tool: &Value) -> Result<Value> {
135 let mut tool = tool.clone();
136 if let Some(obj) = tool.as_object_mut() {
137 obj.remove(SIG_FIELD);
138 }
139 Ok(tool)
140}
141
142fn compute_payload_digest(canonical: &[u8]) -> String {
144 let hash = Sha256::digest(canonical);
145 format!("sha256:{:x}", hash)
146}
147
148pub fn sign_tool(tool: &Value, signing_key: &SigningKey, embed_pubkey: bool) -> Result<Value> {
160 let tool_without_sig = strip_signature(tool)?;
162
163 let canonical = jcs::to_vec(&tool_without_sig)?;
165
166 let pae = build_pae(PAYLOAD_TYPE_TOOL_V1, &canonical);
168
169 let signature: Signature = signing_key.sign(&pae);
171
172 let payload_digest = compute_payload_digest(&canonical);
174 let verifying_key = signing_key.verifying_key();
175 let key_id = compute_key_id_from_verifying_key(&verifying_key)?;
176
177 let sig = ToolSignature {
179 version: 1,
180 algorithm: SignatureAlgorithm::Ed25519,
181 payload_type: PAYLOAD_TYPE_TOOL_V1.to_string(),
182 payload_digest,
183 key_id,
184 signature: BASE64.encode(signature.to_bytes()),
185 signed_at: Utc::now(),
186 public_key: if embed_pubkey {
187 let spki = key_to_spki_der(&verifying_key)?;
188 Some(BASE64.encode(&spki))
189 } else {
190 None
191 },
192 };
193
194 let mut result = tool_without_sig;
196 if let Some(obj) = result.as_object_mut() {
197 obj.insert(SIG_FIELD.to_string(), serde_json::to_value(&sig)?);
198 } else {
199 bail!("tool must be a JSON object");
200 }
201
202 Ok(result)
203}
204
205pub fn verify_tool(tool: &Value, trusted_key: &VerifyingKey) -> Result<VerifyResult, VerifyError> {
216 let sig_value = tool.get(SIG_FIELD).ok_or(VerifyError::NoSignature)?;
218
219 let sig: ToolSignature =
220 serde_json::from_value(sig_value.clone()).map_err(|e| VerifyError::MalformedSignature {
221 reason: e.to_string(),
222 })?;
223
224 if sig.version != 1 {
226 return Err(VerifyError::MalformedSignature {
227 reason: format!("unsupported version: {}", sig.version),
228 });
229 }
230 if sig.algorithm != SignatureAlgorithm::Ed25519 {
231 return Err(VerifyError::MalformedSignature {
232 reason: format!("unsupported algorithm: {:?}", sig.algorithm),
233 });
234 }
235
236 if sig.payload_type != PAYLOAD_TYPE_TOOL_V1 {
238 return Err(VerifyError::PayloadTypeMismatch {
239 expected: PAYLOAD_TYPE_TOOL_V1.to_string(),
240 got: sig.payload_type,
241 });
242 }
243
244 let tool_without_sig = strip_signature(tool).map_err(|e| VerifyError::MalformedSignature {
246 reason: e.to_string(),
247 })?;
248 let canonical =
249 jcs::to_vec(&tool_without_sig).map_err(|e| VerifyError::MalformedSignature {
250 reason: e.to_string(),
251 })?;
252
253 let computed_digest = compute_payload_digest(&canonical);
255 if sig.payload_digest != computed_digest {
256 return Err(VerifyError::DigestMismatch);
257 }
258
259 let pae = build_pae(&sig.payload_type, &canonical);
261 let signature_bytes =
262 BASE64
263 .decode(&sig.signature)
264 .map_err(|e| VerifyError::MalformedSignature {
265 reason: format!("invalid base64 signature: {}", e),
266 })?;
267 let signature =
268 Signature::from_slice(&signature_bytes).map_err(|e| VerifyError::MalformedSignature {
269 reason: format!("invalid signature bytes: {}", e),
270 })?;
271
272 trusted_key
273 .verify(&pae, &signature)
274 .map_err(|_| VerifyError::SignatureInvalid {
275 reason: "ed25519 verification failed".to_string(),
276 })?;
277
278 let actual_key_id = compute_key_id_from_verifying_key(trusted_key).map_err(|e| {
280 VerifyError::MalformedSignature {
281 reason: e.to_string(),
282 }
283 })?;
284 if sig.key_id != actual_key_id {
285 return Err(VerifyError::KeyIdMismatch {
286 claimed: sig.key_id,
287 actual: actual_key_id,
288 });
289 }
290
291 Ok(VerifyResult {
292 key_id: sig.key_id,
293 signed_at: sig.signed_at,
294 })
295}
296
297pub fn extract_signature(tool: &Value) -> Option<ToolSignature> {
299 tool.get(SIG_FIELD)
300 .and_then(|v| serde_json::from_value(v.clone()).ok())
301}
302
303pub fn is_signed(tool: &Value) -> bool {
305 tool.get(SIG_FIELD).is_some()
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    fn generate_keypair() -> SigningKey {
        SigningKey::generate(&mut rand::thread_rng())
    }

    /// A small fixed tool definition used by the tamper tests.
    fn sample_tool() -> Value {
        json!({
            "name": "read_file",
            "description": "Read a file",
            "inputSchema": {"type": "object"}
        })
    }

    #[test]
    fn test_sign_and_verify_roundtrip() {
        let key = generate_keypair();
        let tool = sample_tool();

        let signed = sign_tool(&tool, &key, false).unwrap();
        assert!(is_signed(&signed));

        let result = verify_tool(&signed, &key.verifying_key()).unwrap();
        assert!(result.key_id.starts_with("sha256:"));
    }

    #[test]
    fn test_tamper_detection() {
        let key = generate_keypair();
        let mut signed = sign_tool(&sample_tool(), &key, false).unwrap();

        signed["description"] = json!("Malicious description");

        let result = verify_tool(&signed, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::DigestMismatch)));
    }

    #[test]
    fn test_wrong_key_fails() {
        let key1 = generate_keypair();
        let key2 = generate_keypair();
        let tool = json!({
            "name": "test_tool",
            "description": "Test",
            "inputSchema": {}
        });

        let signed = sign_tool(&tool, &key1, false).unwrap();
        let result = verify_tool(&signed, &key2.verifying_key());

        assert!(matches!(
            result,
            Err(VerifyError::SignatureInvalid { .. }) | Err(VerifyError::KeyIdMismatch { .. })
        ));
    }

    #[test]
    fn test_unsigned_tool() {
        let key = generate_keypair();
        let tool = json!({"name": "unsigned"});

        let result = verify_tool(&tool, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::NoSignature)));
    }

    #[test]
    fn test_embed_pubkey() {
        let key = generate_keypair();
        let tool = json!({"name": "test", "description": "test", "inputSchema": {}});

        let signed = sign_tool(&tool, &key, true).unwrap();
        let sig = extract_signature(&signed).unwrap();

        // The embedded SPKI must hash to the advertised key id.
        let embedded = sig.public_key.expect("public key should be embedded");
        let spki = BASE64.decode(embedded).unwrap();
        assert_eq!(compute_key_id(&spki), sig.key_id);
    }

    #[test]
    fn test_key_id_computation() {
        let key = generate_keypair();
        let key_id = compute_key_id_from_verifying_key(&key.verifying_key()).unwrap();

        assert!(key_id.starts_with("sha256:"));
        // "sha256:" prefix (7 chars) followed by 64 hex digits.
        assert_eq!(key_id.len(), 7 + 64);
    }

    #[test]
    fn test_pae_format() {
        let pae = build_pae("application/json", b"test");

        let expected = b"DSSEv1 16 application/json 4 test";
        assert_eq!(pae, expected);

        // Empty type and payload still produce all separators.
        assert_eq!(build_pae("", b""), b"DSSEv1 0  0 ");
    }

    #[test]
    fn test_canonicalization_stability() {
        let key = generate_keypair();

        let tool1 =
            json!({"name": "test", "description": "desc", "inputSchema": {"type": "object"}});
        let tool2 =
            json!({"inputSchema": {"type": "object"}, "name": "test", "description": "desc"});

        let signed1 = sign_tool(&tool1, &key, false).unwrap();
        let signed2 = sign_tool(&tool2, &key, false).unwrap();

        let sig1 = extract_signature(&signed1).unwrap();
        let sig2 = extract_signature(&signed2).unwrap();

        assert_eq!(sig1.payload_digest, sig2.payload_digest);
    }

    #[test]
    fn test_resign_replaces_existing_signature() {
        let key1 = generate_keypair();
        let key2 = generate_keypair();

        let signed_once = sign_tool(&sample_tool(), &key1, false).unwrap();
        let signed_twice = sign_tool(&signed_once, &key2, false).unwrap();

        // The old signature block is stripped before re-signing, so the payload is
        // unchanged and only the new key verifies.
        assert_eq!(
            extract_signature(&signed_once).unwrap().payload_digest,
            extract_signature(&signed_twice).unwrap().payload_digest
        );
        assert!(verify_tool(&signed_twice, &key2.verifying_key()).is_ok());
        assert!(verify_tool(&signed_twice, &key1.verifying_key()).is_err());
    }

    #[test]
    fn test_non_object_tool_rejected() {
        let key = generate_keypair();
        assert!(sign_tool(&json!(["not", "an", "object"]), &key, false).is_err());
        assert!(sign_tool(&json!("string"), &key, false).is_err());
    }

    #[test]
    fn test_payload_type_mismatch() {
        let key = generate_keypair();
        let mut signed = sign_tool(&sample_tool(), &key, false).unwrap();
        signed[SIG_FIELD]["payload_type"] = json!("application/json");

        let result = verify_tool(&signed, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::PayloadTypeMismatch { .. })));
    }

    #[test]
    fn test_malformed_base64_signature() {
        let key = generate_keypair();
        let mut signed = sign_tool(&sample_tool(), &key, false).unwrap();
        signed[SIG_FIELD]["signature"] = json!("***not base64***");

        let result = verify_tool(&signed, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::MalformedSignature { .. })));
    }

    #[test]
    fn test_unsupported_version() {
        let key = generate_keypair();
        let mut signed = sign_tool(&sample_tool(), &key, false).unwrap();
        signed[SIG_FIELD]["version"] = json!(2);

        let result = verify_tool(&signed, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::MalformedSignature { .. })));
    }

    #[test]
    fn test_unparseable_signature_block() {
        let key = generate_keypair();
        let mut signed = sign_tool(&sample_tool(), &key, false).unwrap();
        signed[SIG_FIELD] = json!("garbage");

        assert!(is_signed(&signed));
        assert!(extract_signature(&signed).is_none());
        let result = verify_tool(&signed, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::MalformedSignature { .. })));
    }

    #[test]
    fn test_key_id_mismatch() {
        let key = generate_keypair();
        let mut signed = sign_tool(&sample_tool(), &key, false).unwrap();
        // key_id is not part of the signed payload, so the signature still verifies;
        // the explicit key_id check must catch the mismatch.
        signed[SIG_FIELD]["key_id"] = json!("sha256:0000");

        let result = verify_tool(&signed, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::KeyIdMismatch { .. })));
    }

    #[test]
    fn test_exit_codes() {
        assert_eq!(VerifyError::NoSignature.exit_code(), 2);
        assert_eq!(
            VerifyError::KeyNotTrusted { key_id: "x".into() }.exit_code(),
            3
        );
        assert_eq!(
            VerifyError::SignatureInvalid { reason: "x".into() }.exit_code(),
            4
        );
        assert_eq!(
            VerifyError::PayloadTypeMismatch {
                expected: "a".into(),
                got: "b".into()
            }
            .exit_code(),
            4
        );
        assert_eq!(VerifyError::DigestMismatch.exit_code(), 4);
        assert_eq!(
            VerifyError::KeyIdMismatch {
                claimed: "a".into(),
                actual: "b".into()
            }
            .exit_code(),
            4
        );
        assert_eq!(
            VerifyError::MalformedSignature { reason: "x".into() }.exit_code(),
            1
        );
    }
}