// tap_agent/message_packing.rs

//! Message packing and unpacking utilities.
//!
//! This module provides traits and implementations that standardize how
//! messages are prepared for transmission (packed) and processed upon
//! receipt (unpacked).

use std::any::Any;
use std::any::TypeId;
use std::fmt::Debug;
use std::sync::Arc;

use async_trait::async_trait;
use base64::Engine;
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json::Value;
use tap_msg::didcomm::PlainMessage;
use uuid::Uuid;

use crate::agent_key::VerificationKey;
use crate::error::{Error, Result};
use crate::message::{Jwe, Jws, SecurityMode};

21/// Error type specific to message packing and unpacking
22#[derive(Debug, thiserror::Error)]
23pub enum MessageError {
24    #[error("Serialization error: {0}")]
25    Serialization(#[from] serde_json::Error),
26
27    #[error("Key manager error: {0}")]
28    KeyManager(String),
29
30    #[error("Crypto operation failed: {0}")]
31    Crypto(String),
32
33    #[error("Invalid message format: {0}")]
34    InvalidFormat(String),
35
36    #[error("Unsupported security mode: {0:?}")]
37    UnsupportedSecurityMode(SecurityMode),
38
39    #[error("Missing required parameter: {0}")]
40    MissingParameter(String),
41
42    #[error("Key not found: {0}")]
43    KeyNotFound(String),
44
45    #[error("Verification failed")]
46    VerificationFailed,
47
48    #[error("Decryption failed")]
49    DecryptionFailed,
50}
51
52impl From<MessageError> for Error {
53    fn from(err: MessageError) -> Self {
54        match err {
55            MessageError::Serialization(e) => Error::Serialization(e.to_string()),
56            MessageError::KeyManager(e) => Error::Cryptography(e),
57            MessageError::Crypto(e) => Error::Cryptography(e),
58            MessageError::InvalidFormat(e) => Error::Validation(e),
59            MessageError::UnsupportedSecurityMode(mode) => {
60                Error::Validation(format!("Unsupported security mode: {:?}", mode))
61            }
62            MessageError::MissingParameter(e) => {
63                Error::Validation(format!("Missing parameter: {}", e))
64            }
65            MessageError::KeyNotFound(e) => Error::Cryptography(format!("Key not found: {}", e)),
66            MessageError::VerificationFailed => {
67                Error::Cryptography("Verification failed".to_string())
68            }
69            MessageError::DecryptionFailed => Error::Cryptography("Decryption failed".to_string()),
70        }
71    }
72}
73
74/// Options for packing a message
75#[derive(Debug, Clone)]
76pub struct PackOptions {
77    /// Security mode to use
78    pub security_mode: SecurityMode,
79    /// Key ID of the recipient (for JWE)
80    pub recipient_kid: Option<String>,
81    /// Key ID of the sender (for JWS and JWE)
82    pub sender_kid: Option<String>,
83}
84
85impl Default for PackOptions {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91impl PackOptions {
92    /// Create new default packing options
93    pub fn new() -> Self {
94        Self {
95            security_mode: SecurityMode::Plain,
96            recipient_kid: None,
97            sender_kid: None,
98        }
99    }
100
101    /// Set to use plain mode (no security)
102    pub fn with_plain(mut self) -> Self {
103        self.security_mode = SecurityMode::Plain;
104        self
105    }
106
107    /// Set to use signed mode with the given sender key ID
108    pub fn with_sign(mut self, sender_kid: &str) -> Self {
109        self.security_mode = SecurityMode::Signed;
110        self.sender_kid = Some(sender_kid.to_string());
111        self
112    }
113
114    /// Set to use auth-crypt mode with the given sender and recipient key IDs
115    pub fn with_auth_crypt(mut self, sender_kid: &str, recipient_jwk: &serde_json::Value) -> Self {
116        self.security_mode = SecurityMode::AuthCrypt;
117        self.sender_kid = Some(sender_kid.to_string());
118
119        // Extract kid from JWK if available
120        if let Some(kid) = recipient_jwk.get("kid").and_then(|k| k.as_str()) {
121            self.recipient_kid = Some(kid.to_string());
122        }
123
124        self
125    }
126
127    /// Get the security mode
128    pub fn security_mode(&self) -> SecurityMode {
129        self.security_mode
130    }
131}
132
133/// Options for unpacking a message
134#[derive(Debug, Clone)]
135pub struct UnpackOptions {
136    /// Expected security mode, or Any to try all modes
137    pub expected_security_mode: SecurityMode,
138    /// Expected recipient key ID
139    pub expected_recipient_kid: Option<String>,
140    /// Whether to require a valid signature
141    pub require_signature: bool,
142}
143
144impl Default for UnpackOptions {
145    fn default() -> Self {
146        Self::new()
147    }
148}
149
150impl UnpackOptions {
151    /// Create new default unpacking options
152    pub fn new() -> Self {
153        Self {
154            expected_security_mode: SecurityMode::Any,
155            expected_recipient_kid: None,
156            require_signature: false,
157        }
158    }
159
160    /// Set whether to require a valid signature
161    pub fn with_require_signature(mut self, require: bool) -> Self {
162        self.require_signature = require;
163        self
164    }
165}
166
167/// Trait for objects that can be packed for secure transmission
168#[async_trait]
169pub trait Packable<Output = String>: Sized {
170    /// Pack the object for secure transmission
171    async fn pack(
172        &self,
173        key_manager: &(impl KeyManagerPacking + ?Sized),
174        options: PackOptions,
175    ) -> Result<Output>;
176}
177
178/// Trait for objects that can be unpacked from a secure format
179#[async_trait]
180pub trait Unpackable<Input, Output = PlainMessage>: Sized {
181    /// Unpack the object from its secure format
182    async fn unpack(
183        packed_message: &Input,
184        key_manager: &(impl KeyManagerPacking + ?Sized),
185        options: UnpackOptions,
186    ) -> Result<Output>;
187}
188
189/// Interface required for key managers to support packing/unpacking
190#[async_trait]
191pub trait KeyManagerPacking: Send + Sync + Debug {
192    /// Get a signing key by ID
193    async fn get_signing_key(
194        &self,
195        kid: &str,
196    ) -> Result<Arc<dyn crate::agent_key::SigningKey + Send + Sync>>;
197
198    /// Get an encryption key by ID
199    async fn get_encryption_key(
200        &self,
201        kid: &str,
202    ) -> Result<Arc<dyn crate::agent_key::EncryptionKey + Send + Sync>>;
203
204    /// Get a decryption key by ID
205    async fn get_decryption_key(
206        &self,
207        kid: &str,
208    ) -> Result<Arc<dyn crate::agent_key::DecryptionKey + Send + Sync>>;
209
210    /// Resolve a verification key
211    async fn resolve_verification_key(
212        &self,
213        kid: &str,
214    ) -> Result<Arc<dyn VerificationKey + Send + Sync>>;
215}
216
217/// Implement Packable for PlainMessage
218#[async_trait]
219impl Packable for PlainMessage {
220    async fn pack(
221        &self,
222        key_manager: &(impl KeyManagerPacking + ?Sized),
223        options: PackOptions,
224    ) -> Result<String> {
225        match options.security_mode {
226            SecurityMode::Plain => {
227                // For plain mode, just serialize the PlainMessage
228                serde_json::to_string(self).map_err(|e| Error::Serialization(e.to_string()))
229            }
230            SecurityMode::Signed => {
231                // Signed mode requires a sender KID
232                let sender_kid = options.sender_kid.clone().ok_or_else(|| {
233                    Error::Validation("Signed mode requires sender_kid".to_string())
234                })?;
235
236                // Get the signing key
237                let signing_key = key_manager.get_signing_key(&sender_kid).await?;
238
239                // Prepare the message payload to sign
240                let payload =
241                    serde_json::to_string(self).map_err(|e| Error::Serialization(e.to_string()))?;
242
243                // Create a JWS
244                let jws = signing_key
245                    .create_jws(payload.as_bytes(), None)
246                    .await
247                    .map_err(|e| Error::Cryptography(format!("Failed to create JWS: {}", e)))?;
248
249                // Serialize the JWS
250                serde_json::to_string(&jws).map_err(|e| Error::Serialization(e.to_string()))
251            }
252            SecurityMode::AuthCrypt => {
253                // AuthCrypt mode requires both sender and recipient KIDs
254                let sender_kid = options.sender_kid.clone().ok_or_else(|| {
255                    Error::Validation("AuthCrypt mode requires sender_kid".to_string())
256                })?;
257
258                let recipient_kid = options.recipient_kid.clone().ok_or_else(|| {
259                    Error::Validation("AuthCrypt mode requires recipient_kid".to_string())
260                })?;
261
262                // Get the encryption key
263                let encryption_key = key_manager.get_encryption_key(&sender_kid).await?;
264
265                // Get the recipient's verification key
266                let recipient_key = key_manager.resolve_verification_key(&recipient_kid).await?;
267
268                // Serialize the message
269                let plaintext =
270                    serde_json::to_string(self).map_err(|e| Error::Serialization(e.to_string()))?;
271
272                // Create a JWE for the recipient
273                let jwe = encryption_key
274                    .create_jwe(plaintext.as_bytes(), &[recipient_key], None)
275                    .await
276                    .map_err(|e| Error::Cryptography(format!("Failed to create JWE: {}", e)))?;
277
278                // Serialize the JWE
279                serde_json::to_string(&jwe).map_err(|e| Error::Serialization(e.to_string()))
280            }
281            SecurityMode::Any => {
282                // Any mode is not valid for packing, only for unpacking
283                Err(Error::Validation(
284                    "SecurityMode::Any is not valid for packing".to_string(),
285                ))
286            }
287        }
288    }
289}
290
291/// We can't implement Packable for all types due to the conflict with PlainMessage
292/// Instead, let's create a helper function:
293pub async fn pack_any<T>(
294    obj: &T,
295    key_manager: &(impl KeyManagerPacking + ?Sized),
296    options: PackOptions,
297) -> Result<String>
298where
299    T: Serialize + Send + Sync + std::fmt::Debug + 'static + Sized,
300{
301    // Skip attempt to implement Packable for generic types and use a helper function instead
302
303    // If the object is a PlainMessage, use PlainMessage's implementation
304    if obj.type_id() == std::any::TypeId::of::<PlainMessage>() {
305        // In this case, we can't easily downcast, so we'll serialize and deserialize
306        let value = serde_json::to_value(obj).map_err(|e| Error::Serialization(e.to_string()))?;
307        let plain_msg: PlainMessage =
308            serde_json::from_value(value).map_err(|e| Error::Serialization(e.to_string()))?;
309        return plain_msg.pack(key_manager, options).await;
310    }
311
312    // Otherwise, implement the same logic here as in the PlainMessage implementation
313    match options.security_mode {
314        SecurityMode::Plain => {
315            // For plain mode, just serialize the object to JSON
316            serde_json::to_string(obj).map_err(|e| Error::Serialization(e.to_string()))
317        }
318        SecurityMode::Signed => {
319            // Signed mode requires a sender KID
320            let sender_kid = options
321                .sender_kid
322                .clone()
323                .ok_or_else(|| Error::Validation("Signed mode requires sender_kid".to_string()))?;
324
325            // Get the signing key
326            let signing_key = key_manager.get_signing_key(&sender_kid).await?;
327
328            // Convert to a Value first
329            let value =
330                serde_json::to_value(obj).map_err(|e| Error::Serialization(e.to_string()))?;
331
332            // Ensure it's an object
333            let obj = value
334                .as_object()
335                .ok_or_else(|| Error::Validation("Message is not a JSON object".to_string()))?;
336
337            // Extract ID, or generate one if missing
338            let id_string = obj
339                .get("id")
340                .map(|v| v.as_str().unwrap_or_default().to_string())
341                .unwrap_or_else(|| Uuid::new_v4().to_string());
342            let id = id_string.as_str();
343
344            // Extract type, or use default
345            let msg_type = obj
346                .get("type")
347                .and_then(|v| v.as_str())
348                .unwrap_or("https://tap.rsvp/schema/1.0/message");
349
350            // Create sender/recipient lists
351            let from = options.sender_kid.as_ref().map(|kid| {
352                // Extract DID part from kid (assuming format is did#key-1)
353                kid.split('#').next().unwrap_or(kid).to_string()
354            });
355
356            let to = if let Some(kid) = &options.recipient_kid {
357                // Extract DID part from kid
358                let did = kid.split('#').next().unwrap_or(kid).to_string();
359                vec![did]
360            } else {
361                vec![]
362            };
363
364            // Create a PlainMessage
365            let plain_message = PlainMessage {
366                id: id.to_string(),
367                typ: "application/didcomm-plain+json".to_string(),
368                type_: msg_type.to_string(),
369                body: value,
370                from: from.unwrap_or_default(),
371                to,
372                thid: None,
373                pthid: None,
374                created_time: Some(chrono::Utc::now().timestamp() as u64),
375                expires_time: None,
376                from_prior: None,
377                attachments: None,
378                extra_headers: std::collections::HashMap::new(),
379            };
380
381            // Prepare the message payload to sign
382            let payload = serde_json::to_string(&plain_message)
383                .map_err(|e| Error::Serialization(e.to_string()))?;
384
385            // Create a JWS
386            let jws = signing_key
387                .create_jws(payload.as_bytes(), None)
388                .await
389                .map_err(|e| Error::Cryptography(format!("Failed to create JWS: {}", e)))?;
390
391            // Serialize the JWS
392            serde_json::to_string(&jws).map_err(|e| Error::Serialization(e.to_string()))
393        }
394        SecurityMode::AuthCrypt => {
395            // AuthCrypt mode requires both sender and recipient KIDs
396            let sender_kid = options.sender_kid.clone().ok_or_else(|| {
397                Error::Validation("AuthCrypt mode requires sender_kid".to_string())
398            })?;
399
400            let recipient_kid = options.recipient_kid.clone().ok_or_else(|| {
401                Error::Validation("AuthCrypt mode requires recipient_kid".to_string())
402            })?;
403
404            // Get the encryption key
405            let encryption_key = key_manager.get_encryption_key(&sender_kid).await?;
406
407            // Get the recipient's verification key
408            let recipient_key = key_manager.resolve_verification_key(&recipient_kid).await?;
409
410            // Convert to a Value first
411            let value =
412                serde_json::to_value(obj).map_err(|e| Error::Serialization(e.to_string()))?;
413
414            // Ensure it's an object
415            let obj = value
416                .as_object()
417                .ok_or_else(|| Error::Validation("Message is not a JSON object".to_string()))?;
418
419            // Extract ID, or generate one if missing
420            let id_string = obj
421                .get("id")
422                .map(|v| v.as_str().unwrap_or_default().to_string())
423                .unwrap_or_else(|| Uuid::new_v4().to_string());
424            let id = id_string.as_str();
425
426            // Extract type, or use default
427            let msg_type = obj
428                .get("type")
429                .and_then(|v| v.as_str())
430                .unwrap_or("https://tap.rsvp/schema/1.0/message");
431
432            // Create sender/recipient lists
433            let from = options.sender_kid.as_ref().map(|kid| {
434                // Extract DID part from kid (assuming format is did#key-1)
435                kid.split('#').next().unwrap_or(kid).to_string()
436            });
437
438            let to = if let Some(kid) = &options.recipient_kid {
439                // Extract DID part from kid
440                let did = kid.split('#').next().unwrap_or(kid).to_string();
441                vec![did]
442            } else {
443                vec![]
444            };
445
446            // Create a PlainMessage
447            let plain_message = PlainMessage {
448                id: id.to_string(),
449                typ: "application/didcomm-plain+json".to_string(),
450                type_: msg_type.to_string(),
451                body: value,
452                from: from.unwrap_or_default(),
453                to,
454                thid: None,
455                pthid: None,
456                created_time: Some(chrono::Utc::now().timestamp() as u64),
457                expires_time: None,
458                from_prior: None,
459                attachments: None,
460                extra_headers: std::collections::HashMap::new(),
461            };
462
463            // Serialize the message
464            let plaintext = serde_json::to_string(&plain_message)
465                .map_err(|e| Error::Serialization(e.to_string()))?;
466
467            // Create a JWE for the recipient
468            let jwe = encryption_key
469                .create_jwe(plaintext.as_bytes(), &[recipient_key], None)
470                .await
471                .map_err(|e| Error::Cryptography(format!("Failed to create JWE: {}", e)))?;
472
473            // Serialize the JWE
474            serde_json::to_string(&jwe).map_err(|e| Error::Serialization(e.to_string()))
475        }
476        SecurityMode::Any => {
477            // Any mode is not valid for packing, only for unpacking
478            Err(Error::Validation(
479                "SecurityMode::Any is not valid for packing".to_string(),
480            ))
481        }
482    }
483}
484
485/// Implement Unpackable for JWS
486#[async_trait]
487impl<T: DeserializeOwned + Send + 'static> Unpackable<Jws, T> for Jws {
488    async fn unpack(
489        packed_message: &Jws,
490        key_manager: &(impl KeyManagerPacking + ?Sized),
491        _options: UnpackOptions,
492    ) -> Result<T> {
493        // Decode the payload
494        let payload_bytes = base64::engine::general_purpose::STANDARD
495            .decode(&packed_message.payload)
496            .map_err(|e| Error::Cryptography(format!("Failed to decode JWS payload: {}", e)))?;
497
498        // Convert to string
499        let payload_str = String::from_utf8(payload_bytes)
500            .map_err(|e| Error::Validation(format!("Invalid UTF-8 in payload: {}", e)))?;
501
502        // Parse as PlainMessage first
503        let plain_message: PlainMessage =
504            serde_json::from_str(&payload_str).map_err(|e| Error::Serialization(e.to_string()))?;
505
506        // Verify signatures
507        let mut verified = false;
508
509        for signature in &packed_message.signatures {
510            // Decode the protected header
511            let protected_bytes = base64::engine::general_purpose::STANDARD
512                .decode(&signature.protected)
513                .map_err(|e| {
514                    Error::Cryptography(format!("Failed to decode protected header: {}", e))
515                })?;
516
517            // Parse the protected header
518            let protected: crate::message::JwsProtected = serde_json::from_slice(&protected_bytes)
519                .map_err(|e| {
520                    Error::Serialization(format!("Failed to parse protected header: {}", e))
521                })?;
522
523            // Get the key ID
524            let kid = &signature.header.kid;
525
526            // Resolve the verification key
527            let verification_key = match key_manager.resolve_verification_key(kid).await {
528                Ok(key) => key,
529                Err(_) => continue, // Skip key if we can't resolve it
530            };
531
532            // Decode the signature
533            let signature_bytes = base64::engine::general_purpose::STANDARD
534                .decode(&signature.signature)
535                .map_err(|e| Error::Cryptography(format!("Failed to decode signature: {}", e)))?;
536
537            // Create the signing input (protected.payload)
538            let signing_input = format!("{}.{}", signature.protected, packed_message.payload);
539
540            // Verify the signature
541            match verification_key
542                .verify_signature(signing_input.as_bytes(), &signature_bytes, &protected)
543                .await
544            {
545                Ok(true) => {
546                    verified = true;
547                    break;
548                }
549                _ => continue,
550            }
551        }
552
553        if !verified {
554            return Err(Error::Cryptography(
555                "Signature verification failed".to_string(),
556            ));
557        }
558
559        // If we want the PlainMessage itself, return it
560        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<PlainMessage>() {
561            // This is safe because we've verified that T is PlainMessage
562            let result = serde_json::to_value(plain_message).unwrap();
563            return serde_json::from_value(result).map_err(|e| Error::Serialization(e.to_string()));
564        }
565
566        // Otherwise deserialize the body to the requested type
567        serde_json::from_value(plain_message.body).map_err(|e| Error::Serialization(e.to_string()))
568    }
569}
570
571/// Implement Unpackable for JWE
572#[async_trait]
573impl<T: DeserializeOwned + Send + 'static> Unpackable<Jwe, T> for Jwe {
574    async fn unpack(
575        packed_message: &Jwe,
576        key_manager: &(impl KeyManagerPacking + ?Sized),
577        options: UnpackOptions,
578    ) -> Result<T> {
579        // Find a recipient that matches our expected key, if any
580        let recipients = if let Some(kid) = &options.expected_recipient_kid {
581            // Filter to just the matching recipient
582            packed_message
583                .recipients
584                .iter()
585                .filter(|r| r.header.kid == *kid)
586                .collect::<Vec<_>>()
587        } else {
588            // Try all recipients
589            packed_message.recipients.iter().collect::<Vec<_>>()
590        };
591
592        // Try each recipient until we find one we can decrypt
593        for recipient in recipients {
594            // Get the recipient's key ID
595            let kid = &recipient.header.kid;
596
597            // Get the decryption key
598            let decryption_key = match key_manager.get_decryption_key(kid).await {
599                Ok(key) => key,
600                Err(_) => continue, // Skip if we don't have the key
601            };
602
603            // Try to decrypt
604            match decryption_key.unwrap_jwe(packed_message).await {
605                Ok(plaintext) => {
606                    // Convert to string
607                    let plaintext_str = String::from_utf8(plaintext).map_err(|e| {
608                        Error::Validation(format!("Invalid UTF-8 in plaintext: {}", e))
609                    })?;
610
611                    // Parse as PlainMessage
612                    let plain_message: PlainMessage = match serde_json::from_str(&plaintext_str) {
613                        Ok(msg) => msg,
614                        Err(e) => {
615                            return Err(Error::Serialization(e.to_string()));
616                        }
617                    };
618
619                    // If we want the PlainMessage itself, return it
620                    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<PlainMessage>() {
621                        // This is safe because we've verified that T is PlainMessage
622                        let result = serde_json::to_value(plain_message).unwrap();
623                        return serde_json::from_value(result)
624                            .map_err(|e| Error::Serialization(e.to_string()));
625                    }
626
627                    // Otherwise deserialize the body to the requested type
628                    return serde_json::from_value(plain_message.body)
629                        .map_err(|e| Error::Serialization(e.to_string()));
630                }
631                Err(_) => continue, // Try next recipient
632            }
633        }
634
635        // If we get here, we couldn't decrypt for any recipient
636        Err(Error::Cryptography("Failed to decrypt message".to_string()))
637    }
638}
639
640/// Implement Unpackable for String (to handle any packed format)
641#[async_trait]
642impl<T: DeserializeOwned + Send + 'static> Unpackable<String, T> for String {
643    async fn unpack(
644        packed_message: &String,
645        key_manager: &(impl KeyManagerPacking + ?Sized),
646        options: UnpackOptions,
647    ) -> Result<T> {
648        // Try to parse as JSON first
649        if let Ok(value) = serde_json::from_str::<Value>(packed_message) {
650            // Check if it's a JWS (has payload and signatures fields)
651            if value.get("payload").is_some() && value.get("signatures").is_some() {
652                // Parse as JWS
653                let jws: Jws = serde_json::from_str(packed_message)
654                    .map_err(|e| Error::Serialization(e.to_string()))?;
655
656                return Jws::unpack(&jws, key_manager, options).await;
657            }
658
659            // Check if it's a JWE (has ciphertext, protected, and recipients fields)
660            if value.get("ciphertext").is_some()
661                && value.get("protected").is_some()
662                && value.get("recipients").is_some()
663            {
664                // Parse as JWE
665                let jwe: Jwe = serde_json::from_str(packed_message)
666                    .map_err(|e| Error::Serialization(e.to_string()))?;
667
668                return Jwe::unpack(&jwe, key_manager, options).await;
669            }
670
671            // Check if it's a PlainMessage (has body and type fields)
672            if value.get("body").is_some() && value.get("type").is_some() {
673                // Parse as PlainMessage
674                let plain: PlainMessage = serde_json::from_str(packed_message)
675                    .map_err(|e| Error::Serialization(e.to_string()))?;
676
677                // If we want the PlainMessage itself, return it
678                if std::any::TypeId::of::<T>() == std::any::TypeId::of::<PlainMessage>() {
679                    // This is safe because we've verified that T is PlainMessage
680                    let result = serde_json::to_value(plain).unwrap();
681                    return serde_json::from_value(result)
682                        .map_err(|e| Error::Serialization(e.to_string()));
683                }
684
685                // Otherwise get the body
686                return serde_json::from_value(plain.body)
687                    .map_err(|e| Error::Serialization(e.to_string()));
688            }
689
690            // If it doesn't match any known format but is a valid JSON, try to parse directly
691            return serde_json::from_value(value).map_err(|e| Error::Serialization(e.to_string()));
692        }
693
694        // If not valid JSON, return an error
695        Err(Error::Validation("Message is not valid JSON".to_string()))
696    }
697}