Skip to main content

tap_agent/
message_packing.rs

1//! Message Packing and Unpacking Utilities
2//!
3//! This module provides traits and implementations for standardizing
4//! how messages are prepared for transmission (packed) and processed
5//! upon receipt (unpacked).
6
7use crate::agent_key::VerificationKey;
8use crate::error::{Error, Result};
9use crate::message::{Jwe, Jws, SecurityMode};
10use async_trait::async_trait;
11use serde::de::DeserializeOwned;
12use serde::Serialize;
13use serde_json::Value;
14use std::any::Any;
15use std::fmt::Debug;
16use std::sync::Arc;
17use tap_msg::didcomm::{PlainMessage, PlainMessageExt};
18use tap_msg::message::TapMessage;
19use uuid::Uuid;
20
21/// Result of unpacking a message containing both the PlainMessage
22/// and the parsed TAP message
23#[derive(Debug, Clone)]
24pub struct UnpackedMessage {
25    /// The unpacked PlainMessage
26    pub plain_message: PlainMessage,
27    /// The parsed TAP message (if it could be parsed)
28    pub tap_message: Option<TapMessage>,
29}
30
31impl UnpackedMessage {
32    /// Create a new UnpackedMessage
33    pub fn new(plain_message: PlainMessage) -> Self {
34        let tap_message = TapMessage::from_plain_message(&plain_message).ok();
35        Self {
36            plain_message,
37            tap_message,
38        }
39    }
40
41    /// Try to get the message as a specific typed message
42    pub fn as_typed<T: tap_msg::TapMessageBody>(&self) -> Result<PlainMessage<T>> {
43        self.plain_message
44            .clone()
45            .parse_as()
46            .map_err(|e| Error::Serialization(e.to_string()))
47    }
48
49    /// Convert to a typed message with untyped body
50    pub fn into_typed(self) -> PlainMessage<Value> {
51        self.plain_message.into_typed()
52    }
53}
54
55/// Error type specific to message packing and unpacking
56#[derive(Debug, thiserror::Error)]
57pub enum MessageError {
58    #[error("Serialization error: {0}")]
59    Serialization(#[from] serde_json::Error),
60
61    #[error("Key manager error: {0}")]
62    KeyManager(String),
63
64    #[error("Crypto operation failed: {0}")]
65    Crypto(String),
66
67    #[error("Invalid message format: {0}")]
68    InvalidFormat(String),
69
70    #[error("Unsupported security mode: {0:?}")]
71    UnsupportedSecurityMode(SecurityMode),
72
73    #[error("Missing required parameter: {0}")]
74    MissingParameter(String),
75
76    #[error("Key not found: {0}")]
77    KeyNotFound(String),
78
79    #[error("Verification failed")]
80    VerificationFailed,
81
82    #[error("Decryption failed")]
83    DecryptionFailed,
84}
85
86impl From<MessageError> for Error {
87    fn from(err: MessageError) -> Self {
88        match err {
89            MessageError::Serialization(e) => Error::Serialization(e.to_string()),
90            MessageError::KeyManager(e) => Error::Cryptography(e),
91            MessageError::Crypto(e) => Error::Cryptography(e),
92            MessageError::InvalidFormat(e) => Error::Validation(e),
93            MessageError::UnsupportedSecurityMode(mode) => {
94                Error::Validation(format!("Unsupported security mode: {:?}", mode))
95            }
96            MessageError::MissingParameter(e) => {
97                Error::Validation(format!("Missing parameter: {}", e))
98            }
99            MessageError::KeyNotFound(e) => Error::Cryptography(format!("Key not found: {}", e)),
100            MessageError::VerificationFailed => {
101                Error::Cryptography("Verification failed".to_string())
102            }
103            MessageError::DecryptionFailed => Error::Cryptography("Decryption failed".to_string()),
104        }
105    }
106}
107
108/// Options for packing a message
109#[derive(Debug, Clone)]
110pub struct PackOptions {
111    /// Security mode to use
112    pub security_mode: SecurityMode,
113    /// Key ID of the recipient (for JWE)
114    pub recipient_kid: Option<String>,
115    /// Key ID of the sender (for JWS and JWE)
116    pub sender_kid: Option<String>,
117}
118
119impl Default for PackOptions {
120    fn default() -> Self {
121        Self::new()
122    }
123}
124
125impl PackOptions {
126    /// Create new default packing options
127    pub fn new() -> Self {
128        Self {
129            security_mode: SecurityMode::Plain,
130            recipient_kid: None,
131            sender_kid: None,
132        }
133    }
134
135    /// Set to use plain mode (no security)
136    pub fn with_plain(mut self) -> Self {
137        self.security_mode = SecurityMode::Plain;
138        self
139    }
140
141    /// Set to use signed mode with the given sender key ID
142    pub fn with_sign(mut self, sender_kid: &str) -> Self {
143        self.security_mode = SecurityMode::Signed;
144        self.sender_kid = Some(sender_kid.to_string());
145        self
146    }
147
148    /// Set to use auth-crypt mode with the given sender and recipient key IDs
149    pub fn with_auth_crypt(mut self, sender_kid: &str, recipient_jwk: &serde_json::Value) -> Self {
150        self.security_mode = SecurityMode::AuthCrypt;
151        self.sender_kid = Some(sender_kid.to_string());
152
153        // Extract kid from JWK if available
154        if let Some(kid) = recipient_jwk.get("kid").and_then(|k| k.as_str()) {
155            self.recipient_kid = Some(kid.to_string());
156        }
157
158        self
159    }
160
161    /// Get the security mode
162    pub fn security_mode(&self) -> SecurityMode {
163        self.security_mode
164    }
165}
166
167/// Options for unpacking a message
168#[derive(Debug, Clone)]
169pub struct UnpackOptions {
170    /// Expected security mode, or Any to try all modes
171    pub expected_security_mode: SecurityMode,
172    /// Expected recipient key ID
173    pub expected_recipient_kid: Option<String>,
174    /// Whether to require a valid signature
175    pub require_signature: bool,
176}
177
178impl Default for UnpackOptions {
179    fn default() -> Self {
180        Self::new()
181    }
182}
183
184impl UnpackOptions {
185    /// Create new default unpacking options
186    pub fn new() -> Self {
187        Self {
188            expected_security_mode: SecurityMode::Any,
189            expected_recipient_kid: None,
190            require_signature: false,
191        }
192    }
193
194    /// Set whether to require a valid signature
195    pub fn with_require_signature(mut self, require: bool) -> Self {
196        self.require_signature = require;
197        self
198    }
199}
200
/// Trait for objects that can be packed for secure transmission.
///
/// `Output` defaults to `String`. The [`PlainMessage`] implementation in this
/// module produces a JSON string: plain JSON, a serialized JWS, or a
/// serialized JWE, depending on `PackOptions::security_mode`.
#[async_trait]
pub trait Packable<Output = String>: Sized {
    /// Pack `self`, obtaining any signing/encryption keys from `key_manager`.
    ///
    /// # Errors
    ///
    /// The implementation in this module fails when the options lack a key ID
    /// required by the chosen mode, when a key cannot be obtained, or when
    /// serialization or the cryptographic operation fails.
    async fn pack(
        &self,
        key_manager: &(impl KeyManagerPacking + ?Sized),
        options: PackOptions,
    ) -> Result<Output>;
}
211
/// Trait for objects that can be unpacked from a secure format.
///
/// `Input` is the packed representation (this module implements the trait for
/// [`Jws`], [`Jwe`] and `String`). `Output` is the value produced, defaulting
/// to [`PlainMessage`].
#[async_trait]
pub trait Unpackable<Input, Output = PlainMessage>: Sized {
    /// Unpack `packed_message`, obtaining verification/decryption keys from `key_manager`.
    async fn unpack(
        packed_message: &Input,
        key_manager: &(impl KeyManagerPacking + ?Sized),
        options: UnpackOptions,
    ) -> Result<Output>;
}
222
/// Interface required for key managers to support packing/unpacking.
///
/// The packing and unpacking code in this module obtains every key through
/// these methods. Each returns a shared (`Arc`), `Send + Sync` trait object.
#[async_trait]
pub trait KeyManagerPacking: Send + Sync + Debug {
    /// Get the signing key identified by `kid`.
    ///
    /// Signed-mode packing calls this with the sender's key ID.
    async fn get_signing_key(
        &self,
        kid: &str,
    ) -> Result<Arc<dyn crate::agent_key::SigningKey + Send + Sync>>;

    /// Get the encryption key identified by `kid`.
    ///
    /// The AuthCrypt/AnonCrypt packing paths in this module look this up with
    /// the *sender's* key ID and call `create_jwe` on the result.
    async fn get_encryption_key(
        &self,
        kid: &str,
    ) -> Result<Arc<dyn crate::agent_key::EncryptionKey + Send + Sync>>;

    /// Get the decryption key identified by `kid`.
    ///
    /// JWE unpacking calls this with each recipient's header `kid`. An error
    /// makes it move on to the next recipient.
    async fn get_decryption_key(
        &self,
        kid: &str,
    ) -> Result<Arc<dyn crate::agent_key::DecryptionKey + Send + Sync>>;

    /// Resolve the verification key identified by `kid`.
    ///
    /// Used both to verify JWS signatures (a resolution error skips that
    /// signature) and as the recipient key passed to `create_jwe` when packing
    /// encrypted messages.
    async fn resolve_verification_key(
        &self,
        kid: &str,
    ) -> Result<Arc<dyn VerificationKey + Send + Sync>>;
}
250
251/// Implement Packable for PlainMessage
252#[async_trait]
253impl Packable for PlainMessage {
254    async fn pack(
255        &self,
256        key_manager: &(impl KeyManagerPacking + ?Sized),
257        options: PackOptions,
258    ) -> Result<String> {
259        match options.security_mode {
260            SecurityMode::Plain => {
261                // For plain mode, just serialize the PlainMessage
262                serde_json::to_string(self).map_err(|e| Error::Serialization(e.to_string()))
263            }
264            SecurityMode::Signed => {
265                // Signed mode requires a sender KID
266                let sender_kid = options.sender_kid.clone().ok_or_else(|| {
267                    Error::Validation("Signed mode requires sender_kid".to_string())
268                })?;
269
270                // Get the signing key
271                let signing_key = key_manager.get_signing_key(&sender_kid).await?;
272
273                // Prepare the message payload to sign
274                let payload =
275                    serde_json::to_string(self).map_err(|e| Error::Serialization(e.to_string()))?;
276
277                // Create protected header with the sender_kid
278                let protected_header = crate::message::JwsProtected {
279                    typ: crate::message::DIDCOMM_SIGNED.to_string(),
280                    alg: String::new(), // Will be set by create_jws based on key type
281                    kid: sender_kid.clone(),
282                };
283
284                // Create a JWS
285                let jws = signing_key
286                    .create_jws(payload.as_bytes(), Some(protected_header))
287                    .await
288                    .map_err(|e| Error::Cryptography(format!("Failed to create JWS: {}", e)))?;
289
290                // Serialize the JWS
291                serde_json::to_string(&jws).map_err(|e| Error::Serialization(e.to_string()))
292            }
293            SecurityMode::AuthCrypt => {
294                // AuthCrypt mode requires both sender and recipient KIDs
295                let sender_kid = options.sender_kid.clone().ok_or_else(|| {
296                    Error::Validation("AuthCrypt mode requires sender_kid".to_string())
297                })?;
298
299                let recipient_kid = options.recipient_kid.clone().ok_or_else(|| {
300                    Error::Validation("AuthCrypt mode requires recipient_kid".to_string())
301                })?;
302
303                // Get the encryption key
304                let encryption_key = key_manager.get_encryption_key(&sender_kid).await?;
305
306                // Get the recipient's verification key
307                let recipient_key = key_manager.resolve_verification_key(&recipient_kid).await?;
308
309                // Serialize the message
310                let plaintext =
311                    serde_json::to_string(self).map_err(|e| Error::Serialization(e.to_string()))?;
312
313                // Create a JWE for the recipient
314                let jwe = encryption_key
315                    .create_jwe(plaintext.as_bytes(), &[recipient_key], None)
316                    .await
317                    .map_err(|e| Error::Cryptography(format!("Failed to create JWE: {}", e)))?;
318
319                // Serialize the JWE
320                serde_json::to_string(&jwe).map_err(|e| Error::Serialization(e.to_string()))
321            }
322            SecurityMode::AnonCrypt => {
323                // AnonCrypt mode requires only recipient KID (sender is anonymous)
324                let recipient_kid = options.recipient_kid.clone().ok_or_else(|| {
325                    Error::Validation("AnonCrypt mode requires recipient_kid".to_string())
326                })?;
327
328                // We need some key for encryption - use the first available key if no sender specified
329                let encryption_key = if let Some(sender_kid) = &options.sender_kid {
330                    key_manager.get_encryption_key(sender_kid).await?
331                } else {
332                    // For anonymous encryption, we can use any available encryption key
333                    // In practice, this might need to be handled differently depending on requirements
334                    return Err(Error::Validation(
335                        "AnonCrypt mode requires a temporary encryption key".to_string(),
336                    ));
337                };
338
339                // Get the recipient's verification key
340                let recipient_key = key_manager.resolve_verification_key(&recipient_kid).await?;
341
342                // Serialize the message
343                let plaintext =
344                    serde_json::to_string(self).map_err(|e| Error::Serialization(e.to_string()))?;
345
346                // Create a JWE for the recipient without sender information
347                let jwe = encryption_key
348                    .create_jwe(plaintext.as_bytes(), &[recipient_key], None)
349                    .await
350                    .map_err(|e| Error::Cryptography(format!("Failed to create JWE: {}", e)))?;
351
352                // Serialize the JWE
353                serde_json::to_string(&jwe).map_err(|e| Error::Serialization(e.to_string()))
354            }
355            SecurityMode::Any => {
356                // Any mode is not valid for packing, only for unpacking
357                Err(Error::Validation(
358                    "SecurityMode::Any is not valid for packing".to_string(),
359                ))
360            }
361        }
362    }
363}
364
365/// We can't implement Packable for all types due to the conflict with PlainMessage
366/// Instead, let's create a helper function:
367pub async fn pack_any<T>(
368    obj: &T,
369    key_manager: &(impl KeyManagerPacking + ?Sized),
370    options: PackOptions,
371) -> Result<String>
372where
373    T: Serialize + Send + Sync + std::fmt::Debug + 'static + Sized,
374{
375    // Skip attempt to implement Packable for generic types and use a helper function instead
376
377    // If the object is a PlainMessage, use PlainMessage's implementation
378    if obj.type_id() == std::any::TypeId::of::<PlainMessage>() {
379        // In this case, we can't easily downcast, so we'll serialize and deserialize
380        let value = serde_json::to_value(obj).map_err(|e| Error::Serialization(e.to_string()))?;
381        let plain_msg: PlainMessage =
382            serde_json::from_value(value).map_err(|e| Error::Serialization(e.to_string()))?;
383        return plain_msg.pack(key_manager, options).await;
384    }
385
386    // Otherwise, implement the same logic here as in the PlainMessage implementation
387    match options.security_mode {
388        SecurityMode::Plain => {
389            // For plain mode, just serialize the object to JSON
390            serde_json::to_string(obj).map_err(|e| Error::Serialization(e.to_string()))
391        }
392        SecurityMode::Signed => {
393            // Signed mode requires a sender KID
394            let sender_kid = options
395                .sender_kid
396                .clone()
397                .ok_or_else(|| Error::Validation("Signed mode requires sender_kid".to_string()))?;
398
399            // Get the signing key
400            let signing_key = key_manager.get_signing_key(&sender_kid).await?;
401
402            // Convert to a Value first
403            let value =
404                serde_json::to_value(obj).map_err(|e| Error::Serialization(e.to_string()))?;
405
406            // Ensure it's an object
407            let obj = value
408                .as_object()
409                .ok_or_else(|| Error::Validation("Message is not a JSON object".to_string()))?;
410
411            // Extract ID, or generate one if missing
412            let id_string = obj
413                .get("id")
414                .map(|v| v.as_str().unwrap_or_default().to_string())
415                .unwrap_or_else(|| Uuid::new_v4().to_string());
416            let id = id_string.as_str();
417
418            // Extract type, or use default
419            let msg_type = obj
420                .get("type")
421                .and_then(|v| v.as_str())
422                .unwrap_or("https://tap.rsvp/schema/1.0/message");
423
424            // Create sender/recipient lists
425            let from = options.sender_kid.as_ref().map(|kid| {
426                // Extract DID part from kid (assuming format is did#key-1)
427                kid.split('#').next().unwrap_or(kid).to_string()
428            });
429
430            let to = if let Some(kid) = &options.recipient_kid {
431                // Extract DID part from kid
432                let did = kid.split('#').next().unwrap_or(kid).to_string();
433                vec![did]
434            } else {
435                vec![]
436            };
437
438            // Create a PlainMessage
439            let plain_message = PlainMessage {
440                id: id.to_string(),
441                typ: "application/didcomm-plain+json".to_string(),
442                type_: msg_type.to_string(),
443                body: value,
444                from: from.unwrap_or_default(),
445                to,
446                thid: None,
447                pthid: None,
448                created_time: Some(chrono::Utc::now().timestamp() as u64),
449                expires_time: None,
450                from_prior: None,
451                attachments: None,
452                extra_headers: std::collections::HashMap::new(),
453            };
454
455            // Prepare the message payload to sign
456            let payload = serde_json::to_string(&plain_message)
457                .map_err(|e| Error::Serialization(e.to_string()))?;
458
459            // Create protected header with the sender_kid
460            let protected_header = crate::message::JwsProtected {
461                typ: crate::message::DIDCOMM_SIGNED.to_string(),
462                alg: String::new(), // Will be set by create_jws based on key type
463                kid: sender_kid.clone(),
464            };
465
466            // Create a JWS
467            let jws = signing_key
468                .create_jws(payload.as_bytes(), Some(protected_header))
469                .await
470                .map_err(|e| Error::Cryptography(format!("Failed to create JWS: {}", e)))?;
471
472            // Serialize the JWS
473            serde_json::to_string(&jws).map_err(|e| Error::Serialization(e.to_string()))
474        }
475        SecurityMode::AuthCrypt => {
476            // AuthCrypt mode requires both sender and recipient KIDs
477            let sender_kid = options.sender_kid.clone().ok_or_else(|| {
478                Error::Validation("AuthCrypt mode requires sender_kid".to_string())
479            })?;
480
481            let recipient_kid = options.recipient_kid.clone().ok_or_else(|| {
482                Error::Validation("AuthCrypt mode requires recipient_kid".to_string())
483            })?;
484
485            // Get the encryption key
486            let encryption_key = key_manager.get_encryption_key(&sender_kid).await?;
487
488            // Get the recipient's verification key
489            let recipient_key = key_manager.resolve_verification_key(&recipient_kid).await?;
490
491            // Convert to a Value first
492            let value =
493                serde_json::to_value(obj).map_err(|e| Error::Serialization(e.to_string()))?;
494
495            // Ensure it's an object
496            let obj = value
497                .as_object()
498                .ok_or_else(|| Error::Validation("Message is not a JSON object".to_string()))?;
499
500            // Extract ID, or generate one if missing
501            let id_string = obj
502                .get("id")
503                .map(|v| v.as_str().unwrap_or_default().to_string())
504                .unwrap_or_else(|| Uuid::new_v4().to_string());
505            let id = id_string.as_str();
506
507            // Extract type, or use default
508            let msg_type = obj
509                .get("type")
510                .and_then(|v| v.as_str())
511                .unwrap_or("https://tap.rsvp/schema/1.0/message");
512
513            // Create sender/recipient lists
514            let from = options.sender_kid.as_ref().map(|kid| {
515                // Extract DID part from kid (assuming format is did#key-1)
516                kid.split('#').next().unwrap_or(kid).to_string()
517            });
518
519            let to = if let Some(kid) = &options.recipient_kid {
520                // Extract DID part from kid
521                let did = kid.split('#').next().unwrap_or(kid).to_string();
522                vec![did]
523            } else {
524                vec![]
525            };
526
527            // Create a PlainMessage
528            let plain_message = PlainMessage {
529                id: id.to_string(),
530                typ: "application/didcomm-plain+json".to_string(),
531                type_: msg_type.to_string(),
532                body: value,
533                from: from.unwrap_or_default(),
534                to,
535                thid: None,
536                pthid: None,
537                created_time: Some(chrono::Utc::now().timestamp() as u64),
538                expires_time: None,
539                from_prior: None,
540                attachments: None,
541                extra_headers: std::collections::HashMap::new(),
542            };
543
544            // Serialize the message
545            let plaintext = serde_json::to_string(&plain_message)
546                .map_err(|e| Error::Serialization(e.to_string()))?;
547
548            // Create a JWE for the recipient
549            let jwe = encryption_key
550                .create_jwe(plaintext.as_bytes(), &[recipient_key], None)
551                .await
552                .map_err(|e| Error::Cryptography(format!("Failed to create JWE: {}", e)))?;
553
554            // Serialize the JWE
555            serde_json::to_string(&jwe).map_err(|e| Error::Serialization(e.to_string()))
556        }
557        SecurityMode::AnonCrypt => {
558            // AnonCrypt mode requires only recipient KID (sender is anonymous)
559            let recipient_kid = options.recipient_kid.clone().ok_or_else(|| {
560                Error::Validation("AnonCrypt mode requires recipient_kid".to_string())
561            })?;
562
563            // We need some key for encryption - use the first available key if no sender specified
564            let encryption_key = if let Some(sender_kid) = &options.sender_kid {
565                key_manager.get_encryption_key(sender_kid).await?
566            } else {
567                // For anonymous encryption, we can use any available encryption key
568                return Err(Error::Validation(
569                    "AnonCrypt mode requires a temporary encryption key".to_string(),
570                ));
571            };
572
573            // Get the recipient's verification key
574            let recipient_key = key_manager.resolve_verification_key(&recipient_kid).await?;
575
576            // Convert to a Value first and create a PlainMessage (similar to AuthCrypt)
577            let value =
578                serde_json::to_value(obj).map_err(|e| Error::Serialization(e.to_string()))?;
579
580            let obj = value
581                .as_object()
582                .ok_or_else(|| Error::Validation("Message is not a JSON object".to_string()))?;
583
584            let id_string = obj
585                .get("id")
586                .map(|v| v.as_str().unwrap_or_default().to_string())
587                .unwrap_or_else(|| Uuid::new_v4().to_string());
588
589            let msg_type = obj
590                .get("type")
591                .and_then(|v| v.as_str())
592                .unwrap_or("https://tap.rsvp/schema/1.0/message");
593
594            let to = if let Some(kid) = &options.recipient_kid {
595                let did = kid.split('#').next().unwrap_or(kid).to_string();
596                vec![did]
597            } else {
598                vec![]
599            };
600
601            // Create a PlainMessage (no sender info for anonymous)
602            let plain_message = PlainMessage {
603                id: id_string,
604                typ: "application/didcomm-plain+json".to_string(),
605                type_: msg_type.to_string(),
606                body: value,
607                from: String::new(), // Anonymous - no sender
608                to,
609                thid: None,
610                pthid: None,
611                created_time: Some(chrono::Utc::now().timestamp() as u64),
612                expires_time: None,
613                from_prior: None,
614                attachments: None,
615                extra_headers: std::collections::HashMap::new(),
616            };
617
618            // Serialize the message
619            let plaintext = serde_json::to_string(&plain_message)
620                .map_err(|e| Error::Serialization(e.to_string()))?;
621
622            // Create a JWE for the recipient without sender information
623            let jwe = encryption_key
624                .create_jwe(plaintext.as_bytes(), &[recipient_key], None)
625                .await
626                .map_err(|e| Error::Cryptography(format!("Failed to create JWE: {}", e)))?;
627
628            // Serialize the JWE
629            serde_json::to_string(&jwe).map_err(|e| Error::Serialization(e.to_string()))
630        }
631        SecurityMode::Any => {
632            // Any mode is not valid for packing, only for unpacking
633            Err(Error::Validation(
634                "SecurityMode::Any is not valid for packing".to_string(),
635            ))
636        }
637    }
638}
639
640/// Implement Unpackable for JWS
641#[async_trait]
642impl<T: DeserializeOwned + Send + 'static> Unpackable<Jws, T> for Jws {
643    async fn unpack(
644        packed_message: &Jws,
645        key_manager: &(impl KeyManagerPacking + ?Sized),
646        _options: UnpackOptions,
647    ) -> Result<T> {
648        // Decode the payload (accept both base64 and base64url)
649        let payload_bytes = crate::message::base64_decode_flexible(&packed_message.payload)
650            .map_err(|e| Error::Cryptography(format!("Failed to decode JWS payload: {}", e)))?;
651
652        // Convert to string
653        let payload_str = String::from_utf8(payload_bytes)
654            .map_err(|e| Error::Validation(format!("Invalid UTF-8 in payload: {}", e)))?;
655
656        // Parse as PlainMessage first
657        let plain_message: PlainMessage =
658            serde_json::from_str(&payload_str).map_err(|e| Error::Serialization(e.to_string()))?;
659
660        // Verify signatures
661        let mut verified = false;
662
663        for signature in &packed_message.signatures {
664            // Decode the protected header (accept both base64 and base64url)
665            let protected_bytes = crate::message::base64_decode_flexible(&signature.protected)
666                .map_err(|e| {
667                    Error::Cryptography(format!("Failed to decode protected header: {}", e))
668                })?;
669
670            // Parse the protected header
671            let protected: crate::message::JwsProtected = serde_json::from_slice(&protected_bytes)
672                .map_err(|e| {
673                    Error::Serialization(format!("Failed to parse protected header: {}", e))
674                })?;
675
676            // Get the key ID from protected header
677            let kid = match signature.get_kid() {
678                Some(kid) => kid,
679                None => continue, // Skip if no kid found
680            };
681
682            // Resolve the verification key
683            let verification_key = match key_manager.resolve_verification_key(&kid).await {
684                Ok(key) => key,
685                Err(_) => continue, // Skip key if we can't resolve it
686            };
687
688            // Decode the signature (accept both base64 and base64url)
689            let signature_bytes = crate::message::base64_decode_flexible(&signature.signature)
690                .map_err(|e| Error::Cryptography(format!("Failed to decode signature: {}", e)))?;
691
692            // Create the signing input (protected.payload)
693            let signing_input = format!("{}.{}", signature.protected, packed_message.payload);
694
695            // Verify the signature
696            match verification_key
697                .verify_signature(signing_input.as_bytes(), &signature_bytes, &protected)
698                .await
699            {
700                Ok(true) => {
701                    verified = true;
702                    break;
703                }
704                _ => continue,
705            }
706        }
707
708        if !verified {
709            return Err(Error::Cryptography(
710                "Signature verification failed".to_string(),
711            ));
712        }
713
714        // If we want the PlainMessage itself, return it
715        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<PlainMessage>() {
716            // This is safe because we've verified that T is PlainMessage
717            let result = serde_json::to_value(plain_message).unwrap();
718            return serde_json::from_value(result).map_err(|e| Error::Serialization(e.to_string()));
719        }
720
721        // Otherwise deserialize the body to the requested type
722        serde_json::from_value(plain_message.body).map_err(|e| Error::Serialization(e.to_string()))
723    }
724}
725
726/// Implement Unpackable for JWE
727#[async_trait]
728impl<T: DeserializeOwned + Send + 'static> Unpackable<Jwe, T> for Jwe {
729    async fn unpack(
730        packed_message: &Jwe,
731        key_manager: &(impl KeyManagerPacking + ?Sized),
732        options: UnpackOptions,
733    ) -> Result<T> {
734        // Find a recipient that matches our expected key, if any
735        let recipients = if let Some(kid) = &options.expected_recipient_kid {
736            // Filter to just the matching recipient
737            packed_message
738                .recipients
739                .iter()
740                .filter(|r| r.header.kid == *kid)
741                .collect::<Vec<_>>()
742        } else {
743            // Try all recipients
744            packed_message.recipients.iter().collect::<Vec<_>>()
745        };
746
747        // Try each recipient until we find one we can decrypt
748        let mut last_error = None;
749        for recipient in recipients {
750            // Get the recipient's key ID
751            let kid = &recipient.header.kid;
752
753            // Get the decryption key
754            let decryption_key = match key_manager.get_decryption_key(kid).await {
755                Ok(key) => key,
756                Err(e) => {
757                    last_error = Some(format!("Key lookup failed for {}: {}", kid, e));
758                    continue;
759                }
760            };
761
762            // Try to decrypt
763            match decryption_key.unwrap_jwe(packed_message).await {
764                Ok(plaintext) => {
765                    // Convert to string
766                    let plaintext_str = String::from_utf8(plaintext).map_err(|e| {
767                        Error::Validation(format!("Invalid UTF-8 in plaintext: {}", e))
768                    })?;
769
770                    // Parse as PlainMessage
771                    let plain_message: PlainMessage = match serde_json::from_str(&plaintext_str) {
772                        Ok(msg) => msg,
773                        Err(e) => {
774                            return Err(Error::Serialization(e.to_string()));
775                        }
776                    };
777
778                    // If we want the PlainMessage itself, return it
779                    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<PlainMessage>() {
780                        // This is safe because we've verified that T is PlainMessage
781                        let result = serde_json::to_value(plain_message).unwrap();
782                        return serde_json::from_value(result)
783                            .map_err(|e| Error::Serialization(e.to_string()));
784                    }
785
786                    // Otherwise deserialize the body to the requested type
787                    return serde_json::from_value(plain_message.body)
788                        .map_err(|e| Error::Serialization(e.to_string()));
789                }
790                Err(e) => {
791                    last_error = Some(format!("Decryption failed for {}: {}", kid, e));
792                    continue;
793                }
794            }
795        }
796
797        // If we get here, we couldn't decrypt for any recipient
798        Err(Error::Cryptography(format!(
799            "Failed to decrypt JWE for any of {} recipients{}",
800            packed_message.recipients.len(),
801            last_error.map(|e| format!(": {}", e)).unwrap_or_default()
802        )))
803    }
804}
805
806/// Implement Unpackable for String (to handle any packed format)
807#[async_trait]
808impl<T: DeserializeOwned + Send + 'static> Unpackable<String, T> for String {
809    async fn unpack(
810        packed_message: &String,
811        key_manager: &(impl KeyManagerPacking + ?Sized),
812        options: UnpackOptions,
813    ) -> Result<T> {
814        // Try to parse as JSON first
815        if let Ok(value) = serde_json::from_str::<Value>(packed_message) {
816            // Check if it's a JWS (General or Flattened serialization)
817            // General: has "payload" + "signatures" array
818            // Flattened: has "payload" + "signature" + "protected"
819            if value.get("payload").is_some()
820                && (value.get("signatures").is_some() || value.get("signature").is_some())
821            {
822                // Jws custom Deserialize handles both General and Flattened formats
823                let jws: Jws = serde_json::from_str(packed_message)
824                    .map_err(|e| Error::Serialization(e.to_string()))?;
825
826                return Jws::unpack(&jws, key_manager, options).await;
827            }
828
829            // Check if it's a JWE (has ciphertext, protected, and recipients fields)
830            if value.get("ciphertext").is_some()
831                && value.get("protected").is_some()
832                && value.get("recipients").is_some()
833            {
834                // Parse as JWE
835                let jwe: Jwe = serde_json::from_str(packed_message)
836                    .map_err(|e| Error::Serialization(e.to_string()))?;
837
838                return Jwe::unpack(&jwe, key_manager, options).await;
839            }
840
841            // Check if it's a PlainMessage (has body and type fields)
842            if value.get("body").is_some() && value.get("type").is_some() {
843                // Parse as PlainMessage
844                let plain: PlainMessage = serde_json::from_str(packed_message)
845                    .map_err(|e| Error::Serialization(e.to_string()))?;
846
847                // If we want the PlainMessage itself, return it
848                if std::any::TypeId::of::<T>() == std::any::TypeId::of::<PlainMessage>() {
849                    // This is safe because we've verified that T is PlainMessage
850                    let result = serde_json::to_value(plain).unwrap();
851                    return serde_json::from_value(result)
852                        .map_err(|e| Error::Serialization(e.to_string()));
853                }
854
855                // Otherwise get the body
856                return serde_json::from_value(plain.body)
857                    .map_err(|e| Error::Serialization(e.to_string()));
858            }
859
860            // If it doesn't match any known format but is a valid JSON, try to parse directly
861            return serde_json::from_value(value).map_err(|e| Error::Serialization(e.to_string()));
862        }
863
864        // If not valid JSON, return an error
865        Err(Error::Validation("Message is not valid JSON".to_string()))
866    }
867}
868
869/// Implement Unpackable for String to UnpackedMessage
870#[async_trait]
871impl Unpackable<String, UnpackedMessage> for String {
872    async fn unpack(
873        packed_message: &String,
874        key_manager: &(impl KeyManagerPacking + ?Sized),
875        options: UnpackOptions,
876    ) -> Result<UnpackedMessage> {
877        // First unpack to PlainMessage
878        let plain_message: PlainMessage =
879            String::unpack(packed_message, key_manager, options).await?;
880
881        // Then create UnpackedMessage which will try to parse the TAP message
882        Ok(UnpackedMessage::new(plain_message))
883    }
884}
885
886/// Implement Unpackable for JWS to UnpackedMessage
887#[async_trait]
888impl Unpackable<Jws, UnpackedMessage> for Jws {
889    async fn unpack(
890        packed_message: &Jws,
891        key_manager: &(impl KeyManagerPacking + ?Sized),
892        options: UnpackOptions,
893    ) -> Result<UnpackedMessage> {
894        // First unpack to PlainMessage
895        let plain_message: PlainMessage = Jws::unpack(packed_message, key_manager, options).await?;
896
897        // Then create UnpackedMessage which will try to parse the TAP message
898        Ok(UnpackedMessage::new(plain_message))
899    }
900}
901
902/// Implement Unpackable for JWE to UnpackedMessage
903#[async_trait]
904impl Unpackable<Jwe, UnpackedMessage> for Jwe {
905    async fn unpack(
906        packed_message: &Jwe,
907        key_manager: &(impl KeyManagerPacking + ?Sized),
908        options: UnpackOptions,
909    ) -> Result<UnpackedMessage> {
910        // First unpack to PlainMessage
911        let plain_message: PlainMessage = Jwe::unpack(packed_message, key_manager, options).await?;
912
913        // Then create UnpackedMessage which will try to parse the TAP message
914        Ok(UnpackedMessage::new(plain_message))
915    }
916}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_key_manager::AgentKeyManagerBuilder;
    use crate::did::{DIDGenerationOptions, KeyType};
    use crate::key_manager::KeyManager;
    use std::sync::Arc;
    use tap_msg::didcomm::PlainMessage;
    use tap_msg::message::agent::TapParticipant;

    /// Recipient DID used by every test message.
    const BOB_DID: &str = "did:example:bob";

    /// Build a plain DIDComm test message sent from `from` to [`BOB_DID`].
    ///
    /// Only the fields that differ between tests are parameters. Everything else
    /// (content type, fixed `created_time`, no threading/attachments) is shared,
    /// so each test states just what it is exercising.
    fn sample_message(id: &str, type_: &str, from: &str, body: serde_json::Value) -> PlainMessage {
        PlainMessage {
            id: id.to_string(),
            typ: "application/didcomm-plain+json".to_string(),
            type_: type_.to_string(),
            body,
            from: from.to_string(),
            to: vec![BOB_DID.to_string()],
            thid: None,
            pthid: None,
            created_time: Some(1234567890),
            expires_time: None,
            from_prior: None,
            attachments: None,
            extra_headers: Default::default(),
        }
    }

    /// Assert that the fields a pack/unpack round trip must preserve are unchanged.
    #[track_caller]
    fn assert_round_trip(unpacked: &PlainMessage, original: &PlainMessage) {
        assert_eq!(unpacked.id, original.id);
        assert_eq!(unpacked.type_, original.type_);
        assert_eq!(unpacked.body, original.body);
        assert_eq!(unpacked.from, original.from);
        assert_eq!(unpacked.to, original.to);
    }

    #[tokio::test]
    async fn test_plain_message_pack_unpack() {
        // Create a key manager with a test key
        let key_manager = Arc::new(AgentKeyManagerBuilder::new().build().unwrap());
        let key = key_manager
            .generate_key(DIDGenerationOptions {
                key_type: KeyType::Ed25519,
            })
            .unwrap();

        let message = sample_message(
            "test-message-1",
            "https://example.org/test",
            &key.did,
            serde_json::json!({ "content": "Hello, World!" }),
        );

        // Pack in plain mode
        let pack_options = PackOptions::new().with_plain();
        let packed = message.pack(&*key_manager, pack_options).await.unwrap();

        // Unpack through the format-detecting String path
        let unpack_options = UnpackOptions::new();
        let unpacked: PlainMessage = String::unpack(&packed, &*key_manager, unpack_options)
            .await
            .unwrap();

        assert_round_trip(&unpacked, &message);
    }

    #[tokio::test]
    async fn test_jws_message_pack_unpack() {
        // Create a key manager with a test key
        let key_manager = Arc::new(AgentKeyManagerBuilder::new().build().unwrap());
        let key = key_manager
            .generate_key(DIDGenerationOptions {
                key_type: KeyType::Ed25519,
            })
            .unwrap();

        // Get the actual verification method ID from the DID document
        let sender_kid = key.did_doc.verification_method[0].id.clone();

        let message = sample_message(
            "test-message-2",
            "https://example.org/test",
            &key.did,
            serde_json::json!({ "content": "Signed message" }),
        );

        // Pack with signing
        let pack_options = PackOptions::new().with_sign(&sender_kid);
        let packed = message.pack(&*key_manager, pack_options).await.unwrap();

        // Verify it's a JWS
        let jws: Jws = serde_json::from_str(&packed).unwrap();
        assert!(!jws.signatures.is_empty());

        // Check the protected header has the correct kid
        let protected_header = jws.signatures[0].get_protected_header().unwrap();
        assert_eq!(protected_header.kid, sender_kid);
        assert_eq!(protected_header.typ, "application/didcomm-signed+json");
        assert_eq!(protected_header.alg, "EdDSA");

        // Unpack
        let unpack_options = UnpackOptions::new();
        let unpacked: PlainMessage = String::unpack(&packed, &*key_manager, unpack_options)
            .await
            .unwrap();

        assert_round_trip(&unpacked, &message);
    }

    #[tokio::test]
    async fn test_different_key_types_jws() {
        // Key types exercised by this test; extend the array as more are supported.
        for key_type in [KeyType::Ed25519] {
            // Create a key manager with a test key
            let key_manager = Arc::new(AgentKeyManagerBuilder::new().build().unwrap());
            let key = key_manager
                .generate_key(DIDGenerationOptions { key_type })
                .unwrap();

            // Get the actual verification method ID from the DID document
            let sender_kid = key.did_doc.verification_method[0].id.clone();

            let message = sample_message(
                &format!("test-{:?}", key_type),
                "https://example.org/test",
                &key.did,
                serde_json::json!({ "content": format!("Signed with {:?}", key_type) }),
            );

            // Pack with signing
            let pack_options = PackOptions::new().with_sign(&sender_kid);
            let packed = message.pack(&*key_manager, pack_options).await.unwrap();

            // Verify it's a JWS
            let jws: Jws = serde_json::from_str(&packed).unwrap();
            assert!(!jws.signatures.is_empty());

            // Check the protected header
            let protected_header = jws.signatures[0].get_protected_header().unwrap();
            assert_eq!(protected_header.kid, sender_kid);

            // Check algorithm matches key type
            let expected_alg = match key_type {
                #[cfg(feature = "crypto-ed25519")]
                KeyType::Ed25519 => "EdDSA",
                #[cfg(feature = "crypto-p256")]
                KeyType::P256 => "ES256",
                #[cfg(feature = "crypto-secp256k1")]
                KeyType::Secp256k1 => "ES256K",
            };
            assert_eq!(protected_header.alg, expected_alg);

            // Unpack and verify
            let unpack_options = UnpackOptions::new();
            let unpacked: PlainMessage = String::unpack(&packed, &*key_manager, unpack_options)
                .await
                .unwrap();

            assert_eq!(unpacked.id, message.id);
            assert_eq!(unpacked.body, message.body);
        }
    }

    #[tokio::test]
    async fn test_unpack_with_wrong_signature() {
        // Create a key manager and sign a message
        let key_manager1 = Arc::new(AgentKeyManagerBuilder::new().build().unwrap());
        let key1 = key_manager1
            .generate_key(DIDGenerationOptions {
                key_type: KeyType::Ed25519,
            })
            .unwrap();

        let message = sample_message(
            "test-wrong-sig",
            "https://example.org/test",
            &key1.did,
            serde_json::json!({ "content": "Test wrong signature" }),
        );

        let sender_kid = key1.did_doc.verification_method[0].id.clone();
        let pack_options = PackOptions::new().with_sign(&sender_kid);
        let packed = message.pack(&*key_manager1, pack_options).await.unwrap();

        // Corrupt the signature bytes so verification must fail
        let mut jws: Jws = serde_json::from_str(&packed).unwrap();
        jws.signatures[0].signature = "AAAA_invalid_signature_AAAA".to_string();
        let tampered = serde_json::to_string(&jws).unwrap();

        // Try to unpack tampered message (should fail verification)
        let unpack_options = UnpackOptions::new();
        let result: Result<PlainMessage> =
            String::unpack(&tampered, &*key_manager1, unpack_options).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_unpack_cross_agent_with_did_key() {
        // Verify that did:key resolution allows cross-agent verification
        let key_manager1 = Arc::new(AgentKeyManagerBuilder::new().build().unwrap());
        let key1 = key_manager1
            .generate_key(DIDGenerationOptions {
                key_type: KeyType::Ed25519,
            })
            .unwrap();

        let key_manager2 = Arc::new(AgentKeyManagerBuilder::new().build().unwrap());
        let _key2 = key_manager2
            .generate_key(DIDGenerationOptions {
                key_type: KeyType::Ed25519,
            })
            .unwrap();

        let message = sample_message(
            "test-cross-agent",
            "https://example.org/test",
            &key1.did,
            serde_json::json!({ "content": "Cross-agent verification" }),
        );

        let sender_kid = key1.did_doc.verification_method[0].id.clone();
        let pack_options = PackOptions::new().with_sign(&sender_kid);
        let packed = message.pack(&*key_manager1, pack_options).await.unwrap();

        // key_manager2 can verify because did:key embeds the public key
        let unpack_options = UnpackOptions::new();
        let result: PlainMessage = String::unpack(&packed, &*key_manager2, unpack_options)
            .await
            .unwrap();

        assert_eq!(result.id, "test-cross-agent");
        assert_eq!(
            result.body,
            serde_json::json!({"content": "Cross-agent verification"})
        );
    }

    #[tokio::test]
    async fn test_unpack_to_unpacked_message() {
        // Create a key manager with a test key
        let key_manager = Arc::new(AgentKeyManagerBuilder::new().build().unwrap());
        let key = key_manager
            .generate_key(DIDGenerationOptions {
                key_type: KeyType::Ed25519,
            })
            .unwrap();

        // Create a TAP transfer message
        let message = sample_message(
            "test-transfer-1",
            "https://tap.rsvp/schema/1.0#Transfer",
            &key.did,
            serde_json::json!({
                "@type": "https://tap.rsvp/schema/1.0#Transfer",
                "transaction_id": "test-tx-123",
                "asset": "eip155:1/slip44:60",
                "originator": {
                    "@id": key.did.clone()
                },
                "amount": "100",
                "agents": [],
                "memo": null,
                "beneficiary": {
                    "@id": BOB_DID
                },
                "settlement_id": null,
                "connection_id": null,
                "metadata": {}
            }),
        );

        // Pack in plain mode
        let pack_options = PackOptions::new().with_plain();
        let packed = message.pack(&*key_manager, pack_options).await.unwrap();

        // Unpack to UnpackedMessage
        let unpack_options = UnpackOptions::new();
        let unpacked: UnpackedMessage = String::unpack(&packed, &*key_manager, unpack_options)
            .await
            .unwrap();

        // Verify PlainMessage
        assert_eq!(unpacked.plain_message.id, message.id);
        assert_eq!(unpacked.plain_message.type_, message.type_);

        // Verify the TAP message was parsed; on failure, show the body that failed.
        match &unpacked.tap_message {
            Some(TapMessage::Transfer(transfer)) => {
                assert_eq!(transfer.amount, "100");
                assert_eq!(transfer.originator.as_ref().unwrap().id(), key.did);
            }
            Some(_) => panic!("Expected Transfer message"),
            None => panic!(
                "TAP message parsing failed for body: {}",
                serde_json::to_string_pretty(&unpacked.plain_message.body).unwrap()
            ),
        }
    }

    #[tokio::test]
    async fn test_unpack_invalid_tap_message() {
        // Create a key manager with a test key
        let key_manager = Arc::new(AgentKeyManagerBuilder::new().build().unwrap());
        let key = key_manager
            .generate_key(DIDGenerationOptions {
                key_type: KeyType::Ed25519,
            })
            .unwrap();

        // Create a message with an unknown type
        let message = sample_message(
            "test-unknown-1",
            "https://example.org/unknown#message",
            &key.did,
            serde_json::json!({ "content": "Unknown message type" }),
        );

        // Pack in plain mode
        let pack_options = PackOptions::new().with_plain();
        let packed = message.pack(&*key_manager, pack_options).await.unwrap();

        // Unpack to UnpackedMessage
        let unpack_options = UnpackOptions::new();
        let unpacked: UnpackedMessage = String::unpack(&packed, &*key_manager, unpack_options)
            .await
            .unwrap();

        // Verify PlainMessage was unpacked
        assert_eq!(unpacked.plain_message.id, message.id);

        // Verify TAP message parsing failed (unknown type)
        assert!(unpacked.tap_message.is_none());
    }
}