1use anyhow::{bail, Context, Result};
6use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
7use chrono::{DateTime, Utc};
8use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use sha2::{Digest, Sha256};
12
13use super::jcs;
14
/// DSSE payload type identifying a version-1 signed tool definition.
///
/// This string is fed into the pre-authentication encoding of every tool
/// signature, so its exact bytes (35 ASCII bytes) are part of what is signed;
/// changing it invalidates all existing signatures.
pub const PAYLOAD_TYPE_TOOL_V1: &str = "application/vnd.assay.tool+json;v=1";

/// Name of the JSON object field that holds a tool's signature envelope.
///
/// This field is removed before canonicalization, so the envelope itself is
/// never part of the signed payload.
pub const SIG_FIELD: &str = "x-assay-sig";
20
/// Signature algorithm recorded in a [`ToolSignature`] envelope.
///
/// Serialized in lowercase, so the `Ed25519` variant appears as `"ed25519"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SignatureAlgorithm {
    /// Ed25519; the only algorithm `sign_tool` emits and `verify_tool` accepts.
    Ed25519,
}
27
/// Signature envelope stored under [`SIG_FIELD`] in a signed tool definition.
///
/// Only the canonical tool payload is signed; the fields of this envelope
/// (notably `signed_at`, `key_id` and `public_key`) are metadata and are not
/// covered by `signature` itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSignature {
    /// Envelope format version; `verify_tool` only accepts `1`.
    pub version: u8,
    /// Algorithm used to produce `signature`.
    pub algorithm: SignatureAlgorithm,
    /// DSSE payload type used in the pre-authentication encoding; must equal
    /// [`PAYLOAD_TYPE_TOOL_V1`] for verification to proceed.
    pub payload_type: String,
    /// `"sha256:"` + lowercase hex SHA-256 of the JCS-canonical tool with the
    /// signature field removed.
    pub payload_digest: String,
    /// `"sha256:"` + lowercase hex SHA-256 of the signer's SPKI DER public key.
    pub key_id: String,
    /// Standard base64 encoding of the raw Ed25519 signature bytes.
    pub signature: String,
    /// Time at which the signature was created (UTC). Not signed.
    pub signed_at: DateTime<Utc>,
    /// Optional standard-base64 SPKI DER public key of the signer; omitted from
    /// the serialized form when `None`. `verify_tool` does not consult it.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub public_key: Option<String>,
}
49
/// Details of a successfully verified tool signature.
#[derive(Debug, Clone)]
pub struct VerifyResult {
    /// Key id from the envelope; `verify_tool` has already checked that it
    /// equals the key id of the trusted key.
    pub key_id: String,
    /// Signing time copied from the envelope. This value is not covered by the
    /// signature, so treat it as informational only.
    pub signed_at: DateTime<Utc>,
}
56
/// Reasons a tool signature can fail verification.
///
/// Each variant maps to a numeric code via `VerifyError::exit_code`.
#[derive(Debug, Clone, thiserror::Error)]
pub enum VerifyError {
    /// The tool has no [`SIG_FIELD`] entry (or is not a JSON object).
    #[error("tool is not signed")]
    NoSignature,

    /// The envelope's payload type differs from [`PAYLOAD_TYPE_TOOL_V1`].
    #[error("payload type mismatch: expected {expected}, got {got}")]
    PayloadTypeMismatch { expected: String, got: String },

    /// The Ed25519 signature did not verify under the trusted key.
    #[error("signature invalid: {reason}")]
    SignatureInvalid { reason: String },

    /// The signing key is not trusted. `verify_tool` in this module never
    /// returns this variant; presumably it is produced by callers that check a
    /// trust store — confirm.
    #[error("key not trusted: {key_id}")]
    KeyNotTrusted { key_id: String },

    /// The envelope could not be parsed or uses an unsupported version,
    /// algorithm or encoding.
    #[error("malformed signature: {reason}")]
    MalformedSignature { reason: String },

    /// The digest recorded in the envelope does not match the tool's content.
    #[error("payload digest mismatch")]
    DigestMismatch,

    /// The envelope's `key_id` differs from the trusted key's id.
    #[error("key_id mismatch: signature claims {claimed}, actual {actual}")]
    KeyIdMismatch { claimed: String, actual: String },
}
81
82impl VerifyError {
83 pub fn exit_code(&self) -> i32 {
85 match self {
86 Self::NoSignature => 2,
87 Self::KeyNotTrusted { .. } => 3,
88 Self::SignatureInvalid { .. }
89 | Self::PayloadTypeMismatch { .. }
90 | Self::DigestMismatch
91 | Self::KeyIdMismatch { .. } => 4,
92 Self::MalformedSignature { .. } => 1,
93 }
94 }
95}
96
97pub fn compute_key_id(spki_bytes: &[u8]) -> String {
101 let hash = Sha256::digest(spki_bytes);
102 format!("sha256:{:x}", hash)
103}
104
105pub fn compute_key_id_from_verifying_key(key: &VerifyingKey) -> Result<String> {
107 let spki_bytes = key_to_spki_der(key)?;
108 Ok(compute_key_id(&spki_bytes))
109}
110
111fn key_to_spki_der(key: &VerifyingKey) -> Result<Vec<u8>> {
113 use pkcs8::EncodePublicKey;
114 let doc = key
115 .to_public_key_der()
116 .context("failed to encode public key as SPKI DER")?;
117 Ok(doc.as_bytes().to_vec())
118}
119
/// Builds the DSSE v1 pre-authentication encoding (PAE) that gets signed.
///
/// Layout: `"DSSEv1" SP len(type) SP type SP len(payload) SP payload`. Both
/// lengths are ASCII decimal *byte* counts (not character counts).
fn build_pae(payload_type: &str, payload: &[u8]) -> Vec<u8> {
    // Everything before the payload is text, so format the header in one go
    // and then append the raw (possibly non-UTF-8) payload bytes.
    let header = format!("DSSEv1 {} {} {} ", payload_type.len(), payload_type, payload.len());
    let mut encoded = header.into_bytes();
    encoded.extend_from_slice(payload);
    encoded
}
140
141fn strip_signature(tool: &Value) -> Result<Value> {
143 let mut tool = tool.clone();
144 if let Some(obj) = tool.as_object_mut() {
145 obj.remove(SIG_FIELD);
146 }
147 Ok(tool)
148}
149
150fn compute_payload_digest(canonical: &[u8]) -> String {
152 let hash = Sha256::digest(canonical);
153 format!("sha256:{:x}", hash)
154}
155
156pub fn sign_tool(tool: &Value, signing_key: &SigningKey, embed_pubkey: bool) -> Result<Value> {
168 let tool_without_sig = strip_signature(tool)?;
170
171 let canonical = jcs::to_vec(&tool_without_sig)?;
173
174 let pae = build_pae(PAYLOAD_TYPE_TOOL_V1, &canonical);
176
177 let signature: Signature = signing_key.sign(&pae);
179
180 let payload_digest = compute_payload_digest(&canonical);
182 let verifying_key = signing_key.verifying_key();
183 let key_id = compute_key_id_from_verifying_key(&verifying_key)?;
184
185 let sig = ToolSignature {
187 version: 1,
188 algorithm: SignatureAlgorithm::Ed25519,
189 payload_type: PAYLOAD_TYPE_TOOL_V1.to_string(),
190 payload_digest,
191 key_id,
192 signature: BASE64.encode(signature.to_bytes()),
193 signed_at: Utc::now(),
194 public_key: if embed_pubkey {
195 let spki = key_to_spki_der(&verifying_key)?;
196 Some(BASE64.encode(&spki))
197 } else {
198 None
199 },
200 };
201
202 let mut result = tool_without_sig;
204 if let Some(obj) = result.as_object_mut() {
205 obj.insert(SIG_FIELD.to_string(), serde_json::to_value(&sig)?);
206 } else {
207 bail!("tool must be a JSON object");
208 }
209
210 Ok(result)
211}
212
213pub fn verify_tool(tool: &Value, trusted_key: &VerifyingKey) -> Result<VerifyResult, VerifyError> {
224 let sig_value = tool.get(SIG_FIELD).ok_or(VerifyError::NoSignature)?;
226
227 let sig: ToolSignature =
228 serde_json::from_value(sig_value.clone()).map_err(|e| VerifyError::MalformedSignature {
229 reason: e.to_string(),
230 })?;
231
232 if sig.version != 1 {
234 return Err(VerifyError::MalformedSignature {
235 reason: format!("unsupported version: {}", sig.version),
236 });
237 }
238 if sig.algorithm != SignatureAlgorithm::Ed25519 {
239 return Err(VerifyError::MalformedSignature {
240 reason: format!("unsupported algorithm: {:?}", sig.algorithm),
241 });
242 }
243
244 if sig.payload_type != PAYLOAD_TYPE_TOOL_V1 {
246 return Err(VerifyError::PayloadTypeMismatch {
247 expected: PAYLOAD_TYPE_TOOL_V1.to_string(),
248 got: sig.payload_type,
249 });
250 }
251
252 let tool_without_sig = strip_signature(tool).map_err(|e| VerifyError::MalformedSignature {
254 reason: e.to_string(),
255 })?;
256 let canonical =
257 jcs::to_vec(&tool_without_sig).map_err(|e| VerifyError::MalformedSignature {
258 reason: e.to_string(),
259 })?;
260
261 let computed_digest = compute_payload_digest(&canonical);
263 if sig.payload_digest != computed_digest {
264 return Err(VerifyError::DigestMismatch);
265 }
266
267 let pae = build_pae(&sig.payload_type, &canonical);
269 let signature_bytes =
270 BASE64
271 .decode(&sig.signature)
272 .map_err(|e| VerifyError::MalformedSignature {
273 reason: format!("invalid base64 signature: {}", e),
274 })?;
275 let signature =
276 Signature::from_slice(&signature_bytes).map_err(|e| VerifyError::MalformedSignature {
277 reason: format!("invalid signature bytes: {}", e),
278 })?;
279
280 trusted_key
281 .verify(&pae, &signature)
282 .map_err(|_| VerifyError::SignatureInvalid {
283 reason: "ed25519 verification failed".to_string(),
284 })?;
285
286 let actual_key_id = compute_key_id_from_verifying_key(trusted_key).map_err(|e| {
288 VerifyError::MalformedSignature {
289 reason: e.to_string(),
290 }
291 })?;
292 if sig.key_id != actual_key_id {
293 return Err(VerifyError::KeyIdMismatch {
294 claimed: sig.key_id,
295 actual: actual_key_id,
296 });
297 }
298
299 Ok(VerifyResult {
300 key_id: sig.key_id,
301 signed_at: sig.signed_at,
302 })
303}
304
305pub fn extract_signature(tool: &Value) -> Option<ToolSignature> {
307 tool.get(SIG_FIELD)
308 .and_then(|v| serde_json::from_value(v.clone()).ok())
309}
310
311pub fn is_signed(tool: &Value) -> bool {
313 tool.get(SIG_FIELD).is_some()
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    fn generate_keypair() -> SigningKey {
        SigningKey::generate(&mut rand::thread_rng())
    }

    /// Well-formed tool definition shared by the envelope-tampering tests.
    fn sample_tool() -> Value {
        json!({
            "name": "read_file",
            "description": "Read a file",
            "inputSchema": {"type": "object"}
        })
    }

    /// Returns a copy of `signed` whose envelope was edited through the typed
    /// [`ToolSignature`], so these tests do not depend on serialized field names.
    fn with_envelope(signed: &Value, edit: impl FnOnce(&mut ToolSignature)) -> Value {
        let mut sig = extract_signature(signed).expect("tool should carry a valid envelope");
        edit(&mut sig);
        let mut tampered = signed.clone();
        tampered[SIG_FIELD] = serde_json::to_value(&sig).expect("envelope serializes");
        tampered
    }

    #[test]
    fn test_sign_and_verify_roundtrip() {
        let key = generate_keypair();
        let tool = json!({
            "name": "read_file",
            "description": "Read a file",
            "inputSchema": {"type": "object"}
        });

        let signed = sign_tool(&tool, &key, false).unwrap();
        assert!(is_signed(&signed));

        let result = verify_tool(&signed, &key.verifying_key()).unwrap();
        assert!(result.key_id.starts_with("sha256:"));
    }

    #[test]
    fn test_tamper_detection() {
        let key = generate_keypair();
        let tool = json!({
            "name": "read_file",
            "description": "Read a file",
            "inputSchema": {"type": "object"}
        });

        let mut signed = sign_tool(&tool, &key, false).unwrap();

        signed["description"] = json!("Malicious description");

        let result = verify_tool(&signed, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::DigestMismatch)));
    }

    #[test]
    fn test_wrong_key_fails() {
        let key1 = generate_keypair();
        let key2 = generate_keypair();
        let tool = json!({
            "name": "test_tool",
            "description": "Test",
            "inputSchema": {}
        });

        let signed = sign_tool(&tool, &key1, false).unwrap();
        let result = verify_tool(&signed, &key2.verifying_key());

        assert!(matches!(
            result,
            Err(VerifyError::SignatureInvalid { .. }) | Err(VerifyError::KeyIdMismatch { .. })
        ));
    }

    #[test]
    fn test_unsigned_tool() {
        let key = generate_keypair();
        let tool = json!({"name": "unsigned"});

        let result = verify_tool(&tool, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::NoSignature)));
    }

    #[test]
    fn test_embed_pubkey() {
        let key = generate_keypair();
        let tool = json!({"name": "test", "description": "test", "inputSchema": {}});

        let signed = sign_tool(&tool, &key, true).unwrap();
        let sig = extract_signature(&signed).unwrap();

        assert!(sig.public_key.is_some());
    }

    #[test]
    fn test_embedded_pubkey_matches_key_id() {
        let key = generate_keypair();
        let signed = sign_tool(&sample_tool(), &key, true).unwrap();
        let sig = extract_signature(&signed).unwrap();

        let spki = BASE64
            .decode(sig.public_key.as_deref().expect("public key embedded"))
            .expect("embedded key is base64");
        assert_eq!(compute_key_id(&spki), sig.key_id);
    }

    #[test]
    fn test_key_id_computation() {
        let key = generate_keypair();
        let key_id = compute_key_id_from_verifying_key(&key.verifying_key()).unwrap();

        assert!(key_id.starts_with("sha256:"));
        // "sha256:" prefix (7 bytes) followed by 64 hex characters.
        assert_eq!(key_id.len(), 7 + 64);
    }

    #[test]
    fn test_pae_format() {
        let pae = build_pae("application/json", b"test");

        let expected = b"DSSEv1 16 application/json 4 test";
        assert_eq!(pae, expected);
    }

    #[test]
    fn test_payload_type_length_normative() {
        let payload_type = PAYLOAD_TYPE_TOOL_V1;
        assert_eq!(
            payload_type.len(),
            35,
            "PAYLOAD_TYPE_TOOL_V1 must be 35 bytes"
        );
        assert!(payload_type.is_ascii());

        let pae = build_pae(payload_type, b"{}");
        let pae_str = String::from_utf8_lossy(&pae);
        assert!(
            pae_str.starts_with("DSSEv1 35 application/vnd.assay.tool+json;v=1 2 {}"),
            "PAE must start with 'DSSEv1 35 ...' for tool signing"
        );
    }

    #[test]
    fn test_key_id_lowercase_hex() {
        let key = generate_keypair();
        let key_id = compute_key_id_from_verifying_key(&key.verifying_key()).unwrap();

        assert!(key_id.starts_with("sha256:"));
        let hex_part = &key_id[7..];
        assert!(
            hex_part
                .chars()
                .all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase()),
            "key_id hex must be lowercase: {}",
            key_id
        );
    }

    #[test]
    fn test_canonicalization_stability() {
        let key = generate_keypair();

        let tool1 =
            json!({"name": "test", "description": "desc", "inputSchema": {"type": "object"}});
        let tool2 =
            json!({"inputSchema": {"type": "object"}, "name": "test", "description": "desc"});

        let signed1 = sign_tool(&tool1, &key, false).unwrap();
        let signed2 = sign_tool(&tool2, &key, false).unwrap();

        let sig1 = extract_signature(&signed1).unwrap();
        let sig2 = extract_signature(&signed2).unwrap();

        assert_eq!(sig1.payload_digest, sig2.payload_digest);
    }

    #[test]
    fn test_exit_codes() {
        assert_eq!(VerifyError::NoSignature.exit_code(), 2);
        assert_eq!(
            VerifyError::KeyNotTrusted { key_id: "x".into() }.exit_code(),
            3
        );
        assert_eq!(
            VerifyError::SignatureInvalid { reason: "x".into() }.exit_code(),
            4
        );
        assert_eq!(
            VerifyError::MalformedSignature { reason: "x".into() }.exit_code(),
            1
        );
    }

    #[test]
    fn test_resign_replaces_existing_signature() {
        let key1 = generate_keypair();
        let key2 = generate_keypair();

        let signed_once = sign_tool(&sample_tool(), &key1, false).unwrap();
        let signed_twice = sign_tool(&signed_once, &key2, false).unwrap();

        assert!(verify_tool(&signed_twice, &key2.verifying_key()).is_ok());
        // The old envelope must not leak into the newly signed payload.
        let first = extract_signature(&signed_once).unwrap();
        let second = extract_signature(&signed_twice).unwrap();
        assert_eq!(first.payload_digest, second.payload_digest);
    }

    #[test]
    fn test_sign_rejects_non_object() {
        let key = generate_keypair();
        assert!(sign_tool(&json!(["not", "an", "object"]), &key, false).is_err());
        assert!(sign_tool(&json!("string"), &key, false).is_err());
    }

    #[test]
    fn test_is_signed_only_checks_objects() {
        assert!(!is_signed(&json!([SIG_FIELD])));
        assert!(!is_signed(&json!({"name": "t"})));
        assert!(is_signed(&json!({ SIG_FIELD: {} })));
    }

    #[test]
    fn test_payload_type_mismatch() {
        let key = generate_keypair();
        let signed = sign_tool(&sample_tool(), &key, false).unwrap();
        let tampered = with_envelope(&signed, |sig| {
            sig.payload_type = "application/json".to_string();
        });

        let result = verify_tool(&tampered, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::PayloadTypeMismatch { .. })));
    }

    #[test]
    fn test_unsupported_version_is_malformed() {
        let key = generate_keypair();
        let signed = sign_tool(&sample_tool(), &key, false).unwrap();
        let tampered = with_envelope(&signed, |sig| {
            sig.version = 2;
        });

        let result = verify_tool(&tampered, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::MalformedSignature { .. })));
    }

    #[test]
    fn test_invalid_base64_signature_is_malformed() {
        let key = generate_keypair();
        let signed = sign_tool(&sample_tool(), &key, false).unwrap();
        let tampered = with_envelope(&signed, |sig| {
            sig.signature = "not base64!".to_string();
        });

        let result = verify_tool(&tampered, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::MalformedSignature { .. })));
    }

    #[test]
    fn test_wrong_length_signature_is_malformed() {
        let key = generate_keypair();
        let signed = sign_tool(&sample_tool(), &key, false).unwrap();
        let tampered = with_envelope(&signed, |sig| {
            sig.signature = BASE64.encode([0u8; 10]);
        });

        let result = verify_tool(&tampered, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::MalformedSignature { .. })));
    }

    #[test]
    fn test_corrupted_signature_is_invalid() {
        let key = generate_keypair();
        let signed = sign_tool(&sample_tool(), &key, false).unwrap();
        let tampered = with_envelope(&signed, |sig| {
            let mut raw = BASE64.decode(&sig.signature).expect("signature is base64");
            raw[0] ^= 0x01;
            sig.signature = BASE64.encode(&raw);
        });

        let result = verify_tool(&tampered, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::SignatureInvalid { .. })));
    }

    #[test]
    fn test_envelope_digest_mismatch() {
        let key = generate_keypair();
        let signed = sign_tool(&sample_tool(), &key, false).unwrap();
        let tampered = with_envelope(&signed, |sig| {
            sig.payload_digest = format!("sha256:{}", "f".repeat(64));
        });

        let result = verify_tool(&tampered, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::DigestMismatch)));
    }

    #[test]
    fn test_key_id_mismatch_with_valid_signature() {
        // key_id lives outside the signed payload, so the signature still
        // verifies and the explicit key_id comparison must catch the change.
        let key = generate_keypair();
        let signed = sign_tool(&sample_tool(), &key, false).unwrap();
        let tampered = with_envelope(&signed, |sig| {
            sig.key_id = format!("sha256:{}", "0".repeat(64));
        });

        let result = verify_tool(&tampered, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::KeyIdMismatch { .. })));
    }
}