Skip to main content

assay_core/mcp/
signing.rs

//! Tool signing and verification per SPEC-Tool-Signing-v1.
//!
//! Provides Ed25519 signing and verification over a DSSE-compatible
//! Pre-Authentication Encoding (PAE) of the JCS-canonicalized tool.
4
5use anyhow::{bail, Context, Result};
6use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
7use chrono::{DateTime, Utc};
8use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use sha2::{Digest, Sha256};
12
13use super::jcs;
14
/// Payload type for tool definitions (DSSE-style binding).
///
/// This string is fed into the PAE as the DSSE `type` field, so its exact
/// bytes (35 bytes of ASCII) are part of every signature.
pub const PAYLOAD_TYPE_TOOL_V1: &str = "application/vnd.assay.tool+json;v=1";

/// The x-assay-sig field name.
///
/// JSON object key under which the signature envelope is stored in a tool
/// definition; it is removed before the tool is canonicalized.
pub const SIG_FIELD: &str = "x-assay-sig";
20
21/// Signature algorithm.
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
23#[serde(rename_all = "lowercase")]
24pub enum SignatureAlgorithm {
25    Ed25519,
26}
27
/// The x-assay-sig structure.
///
/// Signature envelope stored under [`SIG_FIELD`] in a signed tool definition.
///
/// # Field Serialization
///
/// Producers SHOULD omit `public_key` when not embedding the key.
/// This is enforced via `skip_serializing_if = "Option::is_none"`.
/// Verifiers MUST treat `null` as equivalent to absent.
/// With serde, both a missing `public_key` and an explicit `null`
/// deserialize to `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSignature {
    /// Envelope format version; `verify_tool` accepts only `1`.
    pub version: u8,
    /// Signature algorithm (serialized in lowercase, e.g. `"ed25519"`).
    pub algorithm: SignatureAlgorithm,
    /// DSSE payload type bound into the PAE; expected to equal
    /// [`PAYLOAD_TYPE_TOOL_V1`].
    pub payload_type: String,
    /// `sha256:<lowercase-hex>` digest of the JCS-canonicalized tool with the
    /// signature field removed.
    pub payload_digest: String,
    /// `sha256:<lowercase-hex>` of the signer's SPKI DER public key
    /// (see [`compute_key_id`]).
    pub key_id: String,
    /// Standard base64 encoding of the raw Ed25519 signature over the PAE.
    pub signature: String,
    /// Signing time as reported by the signer. This field is not part of the
    /// signed bytes, so the signature does not authenticate it.
    pub signed_at: DateTime<Utc>,
    /// Embedded public key (SPKI DER, base64).
    /// Producers SHOULD omit (not set to null) when not embedding.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub public_key: Option<String>,
}
49
/// Result of successful verification.
///
/// Both fields are copied from the verified `x-assay-sig` envelope.
#[derive(Debug, Clone)]
pub struct VerifyResult {
    /// `key_id` claimed by the envelope; on success it equals the trusted
    /// key's computed id.
    pub key_id: String,
    /// `signed_at` claimed by the signer. It lives outside the signed payload,
    /// so it is metadata rather than an authenticated timestamp.
    pub signed_at: DateTime<Utc>,
}
56
/// Verification errors with exit codes.
///
/// Each variant maps to a CLI exit code via [`VerifyError::exit_code`].
#[derive(Debug, Clone, thiserror::Error)]
pub enum VerifyError {
    /// The tool has no `x-assay-sig` field.
    #[error("tool is not signed")]
    NoSignature,

    /// The envelope's `payload_type` is not the expected tool payload type.
    #[error("payload type mismatch: expected {expected}, got {got}")]
    PayloadTypeMismatch { expected: String, got: String },

    /// The Ed25519 signature did not verify against the trusted key.
    #[error("signature invalid: {reason}")]
    SignatureInvalid { reason: String },

    /// The signing key is not trusted.
    ///
    /// NOTE(review): not constructed by `verify_tool` in this module;
    /// presumably raised by a caller that consults a trust store — confirm.
    #[error("key not trusted: {key_id}")]
    KeyNotTrusted { key_id: String },

    /// The envelope (or the tool being verified) could not be parsed,
    /// decoded, or canonicalized, or it uses an unsupported version/algorithm.
    #[error("malformed signature: {reason}")]
    MalformedSignature { reason: String },

    /// The recorded `payload_digest` does not match the canonicalized tool.
    #[error("payload digest mismatch")]
    DigestMismatch,

    /// The envelope's `key_id` does not match the key used for verification.
    #[error("key_id mismatch: signature claims {claimed}, actual {actual}")]
    KeyIdMismatch { claimed: String, actual: String },
}
81
82impl VerifyError {
83    /// Exit code for CLI.
84    pub fn exit_code(&self) -> i32 {
85        match self {
86            Self::NoSignature => 2,
87            Self::KeyNotTrusted { .. } => 3,
88            Self::SignatureInvalid { .. }
89            | Self::PayloadTypeMismatch { .. }
90            | Self::DigestMismatch
91            | Self::KeyIdMismatch { .. } => 4,
92            Self::MalformedSignature { .. } => 1,
93        }
94    }
95}
96
97/// Compute key_id from SPKI-encoded public key bytes.
98///
99/// Returns `sha256:<lowercase-hex>`.
100pub fn compute_key_id(spki_bytes: &[u8]) -> String {
101    let hash = Sha256::digest(spki_bytes);
102    format!("sha256:{:x}", hash)
103}
104
105/// Compute key_id from a VerifyingKey.
106pub fn compute_key_id_from_verifying_key(key: &VerifyingKey) -> Result<String> {
107    let spki_bytes = key_to_spki_der(key)?;
108    Ok(compute_key_id(&spki_bytes))
109}
110
111/// Convert VerifyingKey to SPKI DER bytes.
112fn key_to_spki_der(key: &VerifyingKey) -> Result<Vec<u8>> {
113    use pkcs8::EncodePublicKey;
114    let doc = key
115        .to_public_key_der()
116        .context("failed to encode public key as SPKI DER")?;
117    Ok(doc.as_bytes().to_vec())
118}
119
/// Build the DSSE Pre-Authentication Encoding (PAE) of a payload.
///
/// ```text
/// PAE(type, payload) = "DSSEv1" SP LEN(type) SP type SP LEN(payload) SP payload
/// ```
///
/// Both lengths are byte counts written in ASCII decimal. The payload is
/// appended verbatim after the final space.
fn build_pae(payload_type: &str, payload: &[u8]) -> Vec<u8> {
    let header = format!(
        "DSSEv1 {} {} {} ",
        payload_type.len(),
        payload_type,
        payload.len()
    );

    let mut encoded = Vec::with_capacity(header.len() + payload.len());
    encoded.extend_from_slice(header.as_bytes());
    encoded.extend_from_slice(payload);
    encoded
}
140
141/// Remove x-assay-sig field from tool JSON.
142fn strip_signature(tool: &Value) -> Result<Value> {
143    let mut tool = tool.clone();
144    if let Some(obj) = tool.as_object_mut() {
145        obj.remove(SIG_FIELD);
146    }
147    Ok(tool)
148}
149
150/// Compute payload digest.
151fn compute_payload_digest(canonical: &[u8]) -> String {
152    let hash = Sha256::digest(canonical);
153    format!("sha256:{:x}", hash)
154}
155
156/// Sign a tool definition.
157///
158/// # Arguments
159///
160/// * `tool` - Tool definition JSON (may or may not have existing signature)
161/// * `signing_key` - Ed25519 private key
162/// * `embed_pubkey` - If true, include public_key in signature
163///
164/// # Returns
165///
166/// Tool definition with x-assay-sig field added.
167pub fn sign_tool(tool: &Value, signing_key: &SigningKey, embed_pubkey: bool) -> Result<Value> {
168    // 1. Remove existing signature
169    let tool_without_sig = strip_signature(tool)?;
170
171    // 2. Canonicalize
172    let canonical = jcs::to_vec(&tool_without_sig)?;
173
174    // 3. Build PAE
175    let pae = build_pae(PAYLOAD_TYPE_TOOL_V1, &canonical);
176
177    // 4. Sign
178    let signature: Signature = signing_key.sign(&pae);
179
180    // 5. Compute digests
181    let payload_digest = compute_payload_digest(&canonical);
182    let verifying_key = signing_key.verifying_key();
183    let key_id = compute_key_id_from_verifying_key(&verifying_key)?;
184
185    // 6. Build x-assay-sig
186    let sig = ToolSignature {
187        version: 1,
188        algorithm: SignatureAlgorithm::Ed25519,
189        payload_type: PAYLOAD_TYPE_TOOL_V1.to_string(),
190        payload_digest,
191        key_id,
192        signature: BASE64.encode(signature.to_bytes()),
193        signed_at: Utc::now(),
194        public_key: if embed_pubkey {
195            let spki = key_to_spki_der(&verifying_key)?;
196            Some(BASE64.encode(&spki))
197        } else {
198            None
199        },
200    };
201
202    // 7. Add to tool
203    let mut result = tool_without_sig;
204    if let Some(obj) = result.as_object_mut() {
205        obj.insert(SIG_FIELD.to_string(), serde_json::to_value(&sig)?);
206    } else {
207        bail!("tool must be a JSON object");
208    }
209
210    Ok(result)
211}
212
213/// Verify a signed tool definition.
214///
215/// # Arguments
216///
217/// * `tool` - Signed tool definition JSON
218/// * `trusted_key` - Public key to verify against
219///
220/// # Returns
221///
222/// `VerifyResult` on success, `VerifyError` on failure.
223pub fn verify_tool(tool: &Value, trusted_key: &VerifyingKey) -> Result<VerifyResult, VerifyError> {
224    // 1. Extract signature
225    let sig_value = tool.get(SIG_FIELD).ok_or(VerifyError::NoSignature)?;
226
227    let sig: ToolSignature =
228        serde_json::from_value(sig_value.clone()).map_err(|e| VerifyError::MalformedSignature {
229            reason: e.to_string(),
230        })?;
231
232    // 2. Validate version and algorithm
233    if sig.version != 1 {
234        return Err(VerifyError::MalformedSignature {
235            reason: format!("unsupported version: {}", sig.version),
236        });
237    }
238    if sig.algorithm != SignatureAlgorithm::Ed25519 {
239        return Err(VerifyError::MalformedSignature {
240            reason: format!("unsupported algorithm: {:?}", sig.algorithm),
241        });
242    }
243
244    // 3. Validate payload_type
245    if sig.payload_type != PAYLOAD_TYPE_TOOL_V1 {
246        return Err(VerifyError::PayloadTypeMismatch {
247            expected: PAYLOAD_TYPE_TOOL_V1.to_string(),
248            got: sig.payload_type,
249        });
250    }
251
252    // 4. Strip signature and canonicalize
253    let tool_without_sig = strip_signature(tool).map_err(|e| VerifyError::MalformedSignature {
254        reason: e.to_string(),
255    })?;
256    let canonical =
257        jcs::to_vec(&tool_without_sig).map_err(|e| VerifyError::MalformedSignature {
258            reason: e.to_string(),
259        })?;
260
261    // 5. Verify payload digest
262    let computed_digest = compute_payload_digest(&canonical);
263    if sig.payload_digest != computed_digest {
264        return Err(VerifyError::DigestMismatch);
265    }
266
267    // 6. Build PAE and verify signature
268    let pae = build_pae(&sig.payload_type, &canonical);
269    let signature_bytes =
270        BASE64
271            .decode(&sig.signature)
272            .map_err(|e| VerifyError::MalformedSignature {
273                reason: format!("invalid base64 signature: {}", e),
274            })?;
275    let signature =
276        Signature::from_slice(&signature_bytes).map_err(|e| VerifyError::MalformedSignature {
277            reason: format!("invalid signature bytes: {}", e),
278        })?;
279
280    trusted_key
281        .verify(&pae, &signature)
282        .map_err(|_| VerifyError::SignatureInvalid {
283            reason: "ed25519 verification failed".to_string(),
284        })?;
285
286    // 7. Verify key_id matches
287    let actual_key_id = compute_key_id_from_verifying_key(trusted_key).map_err(|e| {
288        VerifyError::MalformedSignature {
289            reason: e.to_string(),
290        }
291    })?;
292    if sig.key_id != actual_key_id {
293        return Err(VerifyError::KeyIdMismatch {
294            claimed: sig.key_id,
295            actual: actual_key_id,
296        });
297    }
298
299    Ok(VerifyResult {
300        key_id: sig.key_id,
301        signed_at: sig.signed_at,
302    })
303}
304
305/// Extract signature from a tool (if present).
306pub fn extract_signature(tool: &Value) -> Option<ToolSignature> {
307    tool.get(SIG_FIELD)
308        .and_then(|v| serde_json::from_value(v.clone()).ok())
309}
310
311/// Check if a tool is signed.
312pub fn is_signed(tool: &Value) -> bool {
313    tool.get(SIG_FIELD).is_some()
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    fn generate_keypair() -> SigningKey {
        SigningKey::generate(&mut rand::thread_rng())
    }

    #[test]
    fn test_sign_and_verify_roundtrip() {
        let key = generate_keypair();
        let tool = json!({
            "name": "read_file",
            "description": "Read a file",
            "inputSchema": {"type": "object"}
        });

        let signed = sign_tool(&tool, &key, false).unwrap();
        assert!(is_signed(&signed));

        let result = verify_tool(&signed, &key.verifying_key()).unwrap();
        assert!(result.key_id.starts_with("sha256:"));
    }

    #[test]
    fn test_tamper_detection() {
        let key = generate_keypair();
        let tool = json!({
            "name": "read_file",
            "description": "Read a file",
            "inputSchema": {"type": "object"}
        });

        let mut signed = sign_tool(&tool, &key, false).unwrap();

        // Tamper with the tool
        signed["description"] = json!("Malicious description");

        let result = verify_tool(&signed, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::DigestMismatch)));
    }

    #[test]
    fn test_wrong_key_fails() {
        let key1 = generate_keypair();
        let key2 = generate_keypair();
        let tool = json!({
            "name": "test_tool",
            "description": "Test",
            "inputSchema": {}
        });

        let signed = sign_tool(&tool, &key1, false).unwrap();
        let result = verify_tool(&signed, &key2.verifying_key());

        // Should fail with either SignatureInvalid or KeyIdMismatch
        assert!(matches!(
            result,
            Err(VerifyError::SignatureInvalid { .. }) | Err(VerifyError::KeyIdMismatch { .. })
        ));
    }

    #[test]
    fn test_unsigned_tool() {
        let key = generate_keypair();
        let tool = json!({"name": "unsigned"});

        let result = verify_tool(&tool, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::NoSignature)));
    }

    #[test]
    fn test_embed_pubkey() {
        let key = generate_keypair();
        let tool = json!({"name": "test", "description": "test", "inputSchema": {}});

        let signed = sign_tool(&tool, &key, true).unwrap();
        let sig = extract_signature(&signed).unwrap();

        assert!(sig.public_key.is_some());
    }

    #[test]
    fn test_key_id_computation() {
        let key = generate_keypair();
        let key_id = compute_key_id_from_verifying_key(&key.verifying_key()).unwrap();

        assert!(key_id.starts_with("sha256:"));
        assert_eq!(key_id.len(), 7 + 64); // "sha256:" + 64 hex chars
    }

    #[test]
    fn test_pae_format() {
        let pae = build_pae("application/json", b"test");

        // "DSSEv1 16 application/json 4 test"
        let expected = b"DSSEv1 16 application/json 4 test";
        assert_eq!(pae, expected);
    }

    /// Normative test vector for PAYLOAD_TYPE_TOOL_V1 length.
    ///
    /// This test ensures the exact byte length of the payload type is
    /// consistent across implementations. PAE uses decimal length encoding,
    /// so any mismatch causes cross-impl verification failures.
    #[test]
    fn test_payload_type_length_normative() {
        // "application/vnd.assay.tool+json;v=1" is exactly 35 bytes UTF-8
        let payload_type = PAYLOAD_TYPE_TOOL_V1;
        assert_eq!(
            payload_type.len(),
            35,
            "PAYLOAD_TYPE_TOOL_V1 must be 35 bytes"
        );
        // Verify it's pure ASCII (each char = 1 byte)
        assert!(payload_type.is_ascii());

        // Verify PAE encoding uses correct length
        let pae = build_pae(payload_type, b"{}");
        let pae_str = String::from_utf8_lossy(&pae);
        assert!(
            pae_str.starts_with("DSSEv1 35 application/vnd.assay.tool+json;v=1 2 {}"),
            "PAE must start with 'DSSEv1 35 ...' for tool signing"
        );
    }

    /// Test that key_id uses lowercase hex (normative).
    #[test]
    fn test_key_id_lowercase_hex() {
        let key = generate_keypair();
        let key_id = compute_key_id_from_verifying_key(&key.verifying_key()).unwrap();

        // Must be lowercase hex
        assert!(key_id.starts_with("sha256:"));
        let hex_part = &key_id[7..];
        assert!(
            hex_part
                .chars()
                .all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase()),
            "key_id hex must be lowercase: {}",
            key_id
        );
    }

    #[test]
    fn test_canonicalization_stability() {
        let key = generate_keypair();

        // Same tool, different JSON formatting
        let tool1 =
            json!({"name": "test", "description": "desc", "inputSchema": {"type": "object"}});
        let tool2 =
            json!({"inputSchema": {"type": "object"}, "name": "test", "description": "desc"});

        let signed1 = sign_tool(&tool1, &key, false).unwrap();
        let signed2 = sign_tool(&tool2, &key, false).unwrap();

        // Both should have the same payload_digest
        let sig1 = extract_signature(&signed1).unwrap();
        let sig2 = extract_signature(&signed2).unwrap();

        assert_eq!(sig1.payload_digest, sig2.payload_digest);
    }

    /// Non-object tools have nowhere to carry `x-assay-sig` and must be rejected.
    #[test]
    fn test_sign_rejects_non_object() {
        let key = generate_keypair();

        assert!(sign_tool(&json!(["not", "an", "object"]), &key, false).is_err());
        assert!(sign_tool(&json!("tool"), &key, false).is_err());
    }

    /// Re-signing replaces the existing envelope instead of signing over it.
    #[test]
    fn test_resign_replaces_signature() {
        let key1 = generate_keypair();
        let key2 = generate_keypair();
        let tool = json!({"name": "test", "description": "test", "inputSchema": {}});

        let first = sign_tool(&tool, &key1, false).unwrap();
        let second = sign_tool(&first, &key2, false).unwrap();

        assert!(verify_tool(&second, &key2.verifying_key()).is_ok());
        assert_eq!(
            extract_signature(&first).unwrap().payload_digest,
            extract_signature(&second).unwrap().payload_digest
        );
    }

    /// An embedded public key must hash to the advertised key_id.
    #[test]
    fn test_embedded_pubkey_matches_key_id() {
        let key = generate_keypair();
        let tool = json!({"name": "test", "description": "test", "inputSchema": {}});

        let signed = sign_tool(&tool, &key, true).unwrap();
        let sig = extract_signature(&signed).unwrap();
        let spki = BASE64.decode(sig.public_key.unwrap()).unwrap();

        assert_eq!(compute_key_id(&spki), sig.key_id);
    }

    /// `public_key: null` is treated as absent and does not affect verification.
    #[test]
    fn test_null_public_key_is_absent() {
        let key = generate_keypair();
        let tool = json!({"name": "test", "description": "test", "inputSchema": {}});

        let mut signed = sign_tool(&tool, &key, false).unwrap();
        signed[SIG_FIELD]["public_key"] = Value::Null;

        let sig = extract_signature(&signed).unwrap();
        assert!(sig.public_key.is_none());
        assert!(verify_tool(&signed, &key.verifying_key()).is_ok());
    }

    /// Corrupting the signature bytes (same key, same body) fails Ed25519 verification.
    #[test]
    fn test_tampered_signature_bytes() {
        let key = generate_keypair();
        let tool = json!({"name": "test", "description": "test", "inputSchema": {}});

        let mut signed = sign_tool(&tool, &key, false).unwrap();
        let mut bytes = BASE64
            .decode(signed[SIG_FIELD]["signature"].as_str().unwrap())
            .unwrap();
        bytes[0] ^= 0x01;
        signed[SIG_FIELD]["signature"] = json!(BASE64.encode(&bytes));

        let result = verify_tool(&signed, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::SignatureInvalid { .. })));
    }

    #[test]
    fn test_exit_codes() {
        assert_eq!(VerifyError::NoSignature.exit_code(), 2);
        assert_eq!(
            VerifyError::KeyNotTrusted { key_id: "x".into() }.exit_code(),
            3
        );
        assert_eq!(
            VerifyError::SignatureInvalid { reason: "x".into() }.exit_code(),
            4
        );
        assert_eq!(
            VerifyError::MalformedSignature { reason: "x".into() }.exit_code(),
            1
        );
        assert_eq!(VerifyError::DigestMismatch.exit_code(), 4);
        assert_eq!(
            VerifyError::PayloadTypeMismatch {
                expected: "a".into(),
                got: "b".into()
            }
            .exit_code(),
            4
        );
        assert_eq!(
            VerifyError::KeyIdMismatch {
                claimed: "a".into(),
                actual: "b".into()
            }
            .exit_code(),
            4
        );
    }
}