Skip to main content

assay_core/mcp/
signing.rs

//! Tool signing and verification per SPEC-Tool-Signing-v1.
//!
//! Provides Ed25519 signing/verification with DSSE-compatible PAE encoding.
4
5use anyhow::{bail, Context, Result};
6use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
7use chrono::{DateTime, Utc};
8use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use sha2::{Digest, Sha256};
12
13use super::jcs;
14
/// DSSE payload type for tool definitions.
///
/// This string is mixed into the Pre-Authentication Encoding that gets signed,
/// so a signature only verifies under this exact payload type.
pub const PAYLOAD_TYPE_TOOL_V1: &str = "application/vnd.assay.tool+json;v=1";

/// Name of the JSON member that carries the signature inside a tool object.
pub const SIG_FIELD: &str = "x-assay-sig";
20
21/// Signature algorithm.
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
23#[serde(rename_all = "lowercase")]
24pub enum SignatureAlgorithm {
25    Ed25519,
26}
27
28/// The x-assay-sig structure.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ToolSignature {
31    pub version: u8,
32    pub algorithm: SignatureAlgorithm,
33    pub payload_type: String,
34    pub payload_digest: String,
35    pub key_id: String,
36    pub signature: String,
37    pub signed_at: DateTime<Utc>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub public_key: Option<String>,
40}
41
42/// Result of successful verification.
43#[derive(Debug, Clone)]
44pub struct VerifyResult {
45    pub key_id: String,
46    pub signed_at: DateTime<Utc>,
47}
48
49/// Verification errors with exit codes.
50#[derive(Debug, Clone, thiserror::Error)]
51pub enum VerifyError {
52    #[error("tool is not signed")]
53    NoSignature,
54
55    #[error("payload type mismatch: expected {expected}, got {got}")]
56    PayloadTypeMismatch { expected: String, got: String },
57
58    #[error("signature invalid: {reason}")]
59    SignatureInvalid { reason: String },
60
61    #[error("key not trusted: {key_id}")]
62    KeyNotTrusted { key_id: String },
63
64    #[error("malformed signature: {reason}")]
65    MalformedSignature { reason: String },
66
67    #[error("payload digest mismatch")]
68    DigestMismatch,
69
70    #[error("key_id mismatch: signature claims {claimed}, actual {actual}")]
71    KeyIdMismatch { claimed: String, actual: String },
72}
73
74impl VerifyError {
75    /// Exit code for CLI.
76    pub fn exit_code(&self) -> i32 {
77        match self {
78            Self::NoSignature => 2,
79            Self::KeyNotTrusted { .. } => 3,
80            Self::SignatureInvalid { .. }
81            | Self::PayloadTypeMismatch { .. }
82            | Self::DigestMismatch
83            | Self::KeyIdMismatch { .. } => 4,
84            Self::MalformedSignature { .. } => 1,
85        }
86    }
87}
88
89/// Compute key_id from SPKI-encoded public key bytes.
90///
91/// Returns `sha256:<lowercase-hex>`.
92pub fn compute_key_id(spki_bytes: &[u8]) -> String {
93    let hash = Sha256::digest(spki_bytes);
94    format!("sha256:{:x}", hash)
95}
96
97/// Compute key_id from a VerifyingKey.
98pub fn compute_key_id_from_verifying_key(key: &VerifyingKey) -> Result<String> {
99    let spki_bytes = key_to_spki_der(key)?;
100    Ok(compute_key_id(&spki_bytes))
101}
102
103/// Convert VerifyingKey to SPKI DER bytes.
104fn key_to_spki_der(key: &VerifyingKey) -> Result<Vec<u8>> {
105    use pkcs8::EncodePublicKey;
106    let doc = key
107        .to_public_key_der()
108        .context("failed to encode public key as SPKI DER")?;
109    Ok(doc.as_bytes().to_vec())
110}
111
/// Build the DSSE Pre-Authentication Encoding (PAE) of a payload.
///
/// ```text
/// PAE(type, payload) = "DSSEv1" SP LEN(type) SP type SP LEN(payload) SP payload
/// ```
///
/// Lengths are decimal byte counts, so a multi-byte UTF-8 `payload_type`
/// contributes its byte length, not its character count.
fn build_pae(payload_type: &str, payload: &[u8]) -> Vec<u8> {
    let type_len = payload_type.len().to_string();
    let payload_len = payload.len().to_string();

    let pieces: [&[u8]; 8] = [
        b"DSSEv1 ",
        type_len.as_bytes(),
        b" ",
        payload_type.as_bytes(),
        b" ",
        payload_len.as_bytes(),
        b" ",
        payload,
    ];
    pieces.concat()
}
132
133/// Remove x-assay-sig field from tool JSON.
134fn strip_signature(tool: &Value) -> Result<Value> {
135    let mut tool = tool.clone();
136    if let Some(obj) = tool.as_object_mut() {
137        obj.remove(SIG_FIELD);
138    }
139    Ok(tool)
140}
141
142/// Compute payload digest.
143fn compute_payload_digest(canonical: &[u8]) -> String {
144    let hash = Sha256::digest(canonical);
145    format!("sha256:{:x}", hash)
146}
147
148/// Sign a tool definition.
149///
150/// # Arguments
151///
152/// * `tool` - Tool definition JSON (may or may not have existing signature)
153/// * `signing_key` - Ed25519 private key
154/// * `embed_pubkey` - If true, include public_key in signature
155///
156/// # Returns
157///
158/// Tool definition with x-assay-sig field added.
159pub fn sign_tool(tool: &Value, signing_key: &SigningKey, embed_pubkey: bool) -> Result<Value> {
160    // 1. Remove existing signature
161    let tool_without_sig = strip_signature(tool)?;
162
163    // 2. Canonicalize
164    let canonical = jcs::to_vec(&tool_without_sig)?;
165
166    // 3. Build PAE
167    let pae = build_pae(PAYLOAD_TYPE_TOOL_V1, &canonical);
168
169    // 4. Sign
170    let signature: Signature = signing_key.sign(&pae);
171
172    // 5. Compute digests
173    let payload_digest = compute_payload_digest(&canonical);
174    let verifying_key = signing_key.verifying_key();
175    let key_id = compute_key_id_from_verifying_key(&verifying_key)?;
176
177    // 6. Build x-assay-sig
178    let sig = ToolSignature {
179        version: 1,
180        algorithm: SignatureAlgorithm::Ed25519,
181        payload_type: PAYLOAD_TYPE_TOOL_V1.to_string(),
182        payload_digest,
183        key_id,
184        signature: BASE64.encode(signature.to_bytes()),
185        signed_at: Utc::now(),
186        public_key: if embed_pubkey {
187            let spki = key_to_spki_der(&verifying_key)?;
188            Some(BASE64.encode(&spki))
189        } else {
190            None
191        },
192    };
193
194    // 7. Add to tool
195    let mut result = tool_without_sig;
196    if let Some(obj) = result.as_object_mut() {
197        obj.insert(SIG_FIELD.to_string(), serde_json::to_value(&sig)?);
198    } else {
199        bail!("tool must be a JSON object");
200    }
201
202    Ok(result)
203}
204
205/// Verify a signed tool definition.
206///
207/// # Arguments
208///
209/// * `tool` - Signed tool definition JSON
210/// * `trusted_key` - Public key to verify against
211///
212/// # Returns
213///
214/// `VerifyResult` on success, `VerifyError` on failure.
215pub fn verify_tool(tool: &Value, trusted_key: &VerifyingKey) -> Result<VerifyResult, VerifyError> {
216    // 1. Extract signature
217    let sig_value = tool.get(SIG_FIELD).ok_or(VerifyError::NoSignature)?;
218
219    let sig: ToolSignature =
220        serde_json::from_value(sig_value.clone()).map_err(|e| VerifyError::MalformedSignature {
221            reason: e.to_string(),
222        })?;
223
224    // 2. Validate version and algorithm
225    if sig.version != 1 {
226        return Err(VerifyError::MalformedSignature {
227            reason: format!("unsupported version: {}", sig.version),
228        });
229    }
230    if sig.algorithm != SignatureAlgorithm::Ed25519 {
231        return Err(VerifyError::MalformedSignature {
232            reason: format!("unsupported algorithm: {:?}", sig.algorithm),
233        });
234    }
235
236    // 3. Validate payload_type
237    if sig.payload_type != PAYLOAD_TYPE_TOOL_V1 {
238        return Err(VerifyError::PayloadTypeMismatch {
239            expected: PAYLOAD_TYPE_TOOL_V1.to_string(),
240            got: sig.payload_type,
241        });
242    }
243
244    // 4. Strip signature and canonicalize
245    let tool_without_sig = strip_signature(tool).map_err(|e| VerifyError::MalformedSignature {
246        reason: e.to_string(),
247    })?;
248    let canonical =
249        jcs::to_vec(&tool_without_sig).map_err(|e| VerifyError::MalformedSignature {
250            reason: e.to_string(),
251        })?;
252
253    // 5. Verify payload digest
254    let computed_digest = compute_payload_digest(&canonical);
255    if sig.payload_digest != computed_digest {
256        return Err(VerifyError::DigestMismatch);
257    }
258
259    // 6. Build PAE and verify signature
260    let pae = build_pae(&sig.payload_type, &canonical);
261    let signature_bytes =
262        BASE64
263            .decode(&sig.signature)
264            .map_err(|e| VerifyError::MalformedSignature {
265                reason: format!("invalid base64 signature: {}", e),
266            })?;
267    let signature =
268        Signature::from_slice(&signature_bytes).map_err(|e| VerifyError::MalformedSignature {
269            reason: format!("invalid signature bytes: {}", e),
270        })?;
271
272    trusted_key
273        .verify(&pae, &signature)
274        .map_err(|_| VerifyError::SignatureInvalid {
275            reason: "ed25519 verification failed".to_string(),
276        })?;
277
278    // 7. Verify key_id matches
279    let actual_key_id = compute_key_id_from_verifying_key(trusted_key).map_err(|e| {
280        VerifyError::MalformedSignature {
281            reason: e.to_string(),
282        }
283    })?;
284    if sig.key_id != actual_key_id {
285        return Err(VerifyError::KeyIdMismatch {
286            claimed: sig.key_id,
287            actual: actual_key_id,
288        });
289    }
290
291    Ok(VerifyResult {
292        key_id: sig.key_id,
293        signed_at: sig.signed_at,
294    })
295}
296
297/// Extract signature from a tool (if present).
298pub fn extract_signature(tool: &Value) -> Option<ToolSignature> {
299    tool.get(SIG_FIELD)
300        .and_then(|v| serde_json::from_value(v.clone()).ok())
301}
302
303/// Check if a tool is signed.
304pub fn is_signed(tool: &Value) -> bool {
305    tool.get(SIG_FIELD).is_some()
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    fn generate_keypair() -> SigningKey {
        SigningKey::generate(&mut rand::thread_rng())
    }

    /// A small, valid tool definition shared by several tests.
    fn sample_tool() -> Value {
        json!({
            "name": "read_file",
            "description": "Read a file",
            "inputSchema": {"type": "object"}
        })
    }

    /// `sample_tool()` signed with `key`, without an embedded public key.
    fn signed_sample(key: &SigningKey) -> Value {
        sign_tool(&sample_tool(), key, false).unwrap()
    }

    #[test]
    fn test_sign_and_verify_roundtrip() {
        let key = generate_keypair();
        let signed = signed_sample(&key);
        assert!(is_signed(&signed));

        let result = verify_tool(&signed, &key.verifying_key()).unwrap();
        assert!(result.key_id.starts_with("sha256:"));
        assert_eq!(
            result.key_id,
            compute_key_id_from_verifying_key(&key.verifying_key()).unwrap()
        );
    }

    #[test]
    fn test_tamper_detection() {
        let key = generate_keypair();
        let mut signed = signed_sample(&key);

        // Tamper with the tool
        signed["description"] = json!("Malicious description");

        let result = verify_tool(&signed, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::DigestMismatch)));
    }

    #[test]
    fn test_wrong_key_fails() {
        let key1 = generate_keypair();
        let key2 = generate_keypair();
        let tool = json!({
            "name": "test_tool",
            "description": "Test",
            "inputSchema": {}
        });

        let signed = sign_tool(&tool, &key1, false).unwrap();
        let result = verify_tool(&signed, &key2.verifying_key());

        // Should fail with either SignatureInvalid or KeyIdMismatch
        assert!(matches!(
            result,
            Err(VerifyError::SignatureInvalid { .. }) | Err(VerifyError::KeyIdMismatch { .. })
        ));
    }

    #[test]
    fn test_unsigned_tool() {
        let key = generate_keypair();
        let tool = json!({"name": "unsigned"});

        let result = verify_tool(&tool, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::NoSignature)));
    }

    #[test]
    fn test_sign_rejects_non_object() {
        let key = generate_keypair();
        assert!(sign_tool(&json!(["not", "an", "object"]), &key, false).is_err());
        assert!(sign_tool(&json!("string"), &key, false).is_err());
    }

    #[test]
    fn test_resign_replaces_existing_signature() {
        let key1 = generate_keypair();
        let key2 = generate_keypair();

        let signed1 = signed_sample(&key1);
        let signed2 = sign_tool(&signed1, &key2, false).unwrap();

        assert!(verify_tool(&signed2, &key2.verifying_key()).is_ok());
        assert!(verify_tool(&signed2, &key1.verifying_key()).is_err());

        // The old signature is stripped, so the signed payload is unchanged.
        assert_eq!(
            extract_signature(&signed1).unwrap().payload_digest,
            extract_signature(&signed2).unwrap().payload_digest
        );
    }

    #[test]
    fn test_tampered_signature_bytes() {
        let key = generate_keypair();
        let mut signed = signed_sample(&key);

        let mut raw = BASE64
            .decode(signed[SIG_FIELD]["signature"].as_str().unwrap())
            .unwrap();
        raw[0] ^= 0x01;
        signed[SIG_FIELD]["signature"] = json!(BASE64.encode(&raw));

        let result = verify_tool(&signed, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::SignatureInvalid { .. })));
    }

    #[test]
    fn test_invalid_base64_signature() {
        let key = generate_keypair();
        let mut signed = signed_sample(&key);
        signed[SIG_FIELD]["signature"] = json!("not base64!");

        let result = verify_tool(&signed, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::MalformedSignature { .. })));
    }

    #[test]
    fn test_tampered_key_id() {
        let key = generate_keypair();
        let mut signed = signed_sample(&key);
        signed[SIG_FIELD]["key_id"] = json!(format!("sha256:{}", "0".repeat(64)));

        let result = verify_tool(&signed, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::KeyIdMismatch { .. })));
    }

    #[test]
    fn test_payload_type_mismatch() {
        let key = generate_keypair();
        let mut signed = signed_sample(&key);
        signed[SIG_FIELD]["payload_type"] = json!("application/json");

        let result = verify_tool(&signed, &key.verifying_key());
        assert!(matches!(
            result,
            Err(VerifyError::PayloadTypeMismatch { .. })
        ));
    }

    #[test]
    fn test_unsupported_version() {
        let key = generate_keypair();
        let mut signed = signed_sample(&key);
        signed[SIG_FIELD]["version"] = json!(2);

        let result = verify_tool(&signed, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::MalformedSignature { .. })));
    }

    #[test]
    fn test_missing_signature_field_is_malformed() {
        let key = generate_keypair();
        let mut signed = signed_sample(&key);
        signed[SIG_FIELD]
            .as_object_mut()
            .unwrap()
            .remove("signature");

        let result = verify_tool(&signed, &key.verifying_key());
        assert!(matches!(result, Err(VerifyError::MalformedSignature { .. })));
    }

    #[test]
    fn test_malformed_signature_block() {
        let tool = json!({"name": "x", SIG_FIELD: 42});
        assert!(is_signed(&tool));
        assert!(extract_signature(&tool).is_none());
        assert!(!is_signed(&sample_tool()));
        assert!(!is_signed(&json!([1, 2, 3])));
    }

    #[test]
    fn test_embed_pubkey() {
        let key = generate_keypair();
        let tool = json!({"name": "test", "description": "test", "inputSchema": {}});

        let signed = sign_tool(&tool, &key, true).unwrap();
        let sig = extract_signature(&signed).unwrap();

        // The embedded key is the signer's SPKI DER and hashes to key_id.
        let spki = BASE64.decode(sig.public_key.as_deref().unwrap()).unwrap();
        assert_eq!(compute_key_id(&spki), sig.key_id);
    }

    #[test]
    fn test_no_embedded_pubkey_when_disabled() {
        let key = generate_keypair();
        let signed = signed_sample(&key);

        assert!(extract_signature(&signed).unwrap().public_key.is_none());
        assert!(signed[SIG_FIELD].get("public_key").is_none());
    }

    #[test]
    fn test_key_id_computation() {
        let key = generate_keypair();
        let key_id = compute_key_id_from_verifying_key(&key.verifying_key()).unwrap();

        assert!(key_id.starts_with("sha256:"));
        assert_eq!(key_id.len(), 7 + 64); // "sha256:" + 64 hex chars
    }

    #[test]
    fn test_pae_format() {
        let pae = build_pae("application/json", b"test");

        // "DSSEv1 16 application/json 4 test"
        let expected = b"DSSEv1 16 application/json 4 test";
        assert_eq!(pae, expected);
    }

    #[test]
    fn test_canonicalization_stability() {
        let key = generate_keypair();

        // Same tool, different JSON formatting
        let tool1 =
            json!({"name": "test", "description": "desc", "inputSchema": {"type": "object"}});
        let tool2 =
            json!({"inputSchema": {"type": "object"}, "name": "test", "description": "desc"});

        let signed1 = sign_tool(&tool1, &key, false).unwrap();
        let signed2 = sign_tool(&tool2, &key, false).unwrap();

        // Both should have the same payload_digest
        let sig1 = extract_signature(&signed1).unwrap();
        let sig2 = extract_signature(&signed2).unwrap();

        assert_eq!(sig1.payload_digest, sig2.payload_digest);
    }

    #[test]
    fn test_exit_codes() {
        assert_eq!(VerifyError::NoSignature.exit_code(), 2);
        assert_eq!(
            VerifyError::KeyNotTrusted { key_id: "x".into() }.exit_code(),
            3
        );
        assert_eq!(
            VerifyError::SignatureInvalid { reason: "x".into() }.exit_code(),
            4
        );
        assert_eq!(
            VerifyError::PayloadTypeMismatch {
                expected: "a".into(),
                got: "b".into()
            }
            .exit_code(),
            4
        );
        assert_eq!(VerifyError::DigestMismatch.exit_code(), 4);
        assert_eq!(
            VerifyError::KeyIdMismatch {
                claimed: "a".into(),
                actual: "b".into()
            }
            .exit_code(),
            4
        );
        assert_eq!(
            VerifyError::MalformedSignature { reason: "x".into() }.exit_code(),
            1
        );
    }
}