celers_protocol/
extensions.rs

1//! Message extensions and utilities
2//!
//! This module provides helper functions and extensions for working with
//! Celery protocol messages: basic validation and expiry/ETA checks, plus
//! body signing (`signing` feature) and body encryption (`encryption` feature).
5//!
6//! # Example
7//!
8//! ```
9//! use celers_protocol::extensions::MessageExt;
10//! use celers_protocol::{Message, TaskArgs};
11//! use uuid::Uuid;
12//!
13//! let task_id = Uuid::new_v4();
14//! let body = serde_json::to_vec(&TaskArgs::new()).unwrap();
15//! let msg = Message::new("tasks.add".to_string(), task_id, body);
16//!
17//! // Validate the message
18//! assert!(msg.validate_basic().is_ok());
19//! ```
20
21use crate::Message;
22
23#[cfg(feature = "signing")]
24use crate::auth::{MessageSigner, SignatureError};
25
26#[cfg(feature = "encryption")]
27use crate::crypto::{EncryptionError, MessageEncryptor};
28
29use std::fmt;
30
/// Errors produced by the message extension helpers in this module.
///
/// The `Signature` and `Encryption` variants exist only when the `signing`
/// and `encryption` Cargo features, respectively, are enabled.
#[derive(Debug)]
pub enum ExtensionError {
    /// Wraps an error reported by the message signer.
    #[cfg(feature = "signing")]
    Signature(SignatureError),
    /// Wraps an error reported by the message encryptor.
    #[cfg(feature = "encryption")]
    Encryption(EncryptionError),
    /// The message failed validation; the payload is a textual description.
    Validation(String),
    /// A serialization problem; the payload is a textual description.
    Serialization(String),
}

impl fmt::Display for ExtensionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Each variant renders as "<category> error: <detail>".
        match self {
            #[cfg(feature = "signing")]
            Self::Signature(inner) => write!(f, "Signature error: {}", inner),
            #[cfg(feature = "encryption")]
            Self::Encryption(inner) => write!(f, "Encryption error: {}", inner),
            Self::Validation(detail) => write!(f, "Validation error: {}", detail),
            Self::Serialization(detail) => write!(f, "Serialization error: {}", detail),
        }
    }
}
58
59impl From<crate::ValidationError> for ExtensionError {
60    fn from(err: crate::ValidationError) -> Self {
61        ExtensionError::Validation(err.to_string())
62    }
63}
64
65impl std::error::Error for ExtensionError {}
66
67#[cfg(feature = "signing")]
68impl From<SignatureError> for ExtensionError {
69    fn from(e: SignatureError) -> Self {
70        ExtensionError::Signature(e)
71    }
72}
73
74#[cfg(feature = "encryption")]
75impl From<EncryptionError> for ExtensionError {
76    fn from(e: EncryptionError) -> Self {
77        ExtensionError::Encryption(e)
78    }
79}
80
81/// Extension trait for Message with additional utilities
82pub trait MessageExt {
83    /// Validate basic message structure
84    fn validate_basic(&self) -> Result<(), ExtensionError>;
85
86    /// Check if message is expired
87    fn is_expired(&self) -> bool;
88
89    /// Check if message is scheduled for future execution
90    fn is_scheduled(&self) -> bool;
91
92    /// Get message age in seconds
93    fn get_age_seconds(&self) -> Option<i64>;
94
95    /// Sign the message body
96    #[cfg(feature = "signing")]
97    fn sign_body(&self, signer: &MessageSigner) -> Vec<u8>;
98
99    /// Verify the message body signature
100    #[cfg(feature = "signing")]
101    fn verify_body(&self, signer: &MessageSigner, signature: &[u8]) -> Result<(), ExtensionError>;
102
103    /// Encrypt the message body
104    #[cfg(feature = "encryption")]
105    fn encrypt_body(&mut self, encryptor: &MessageEncryptor) -> Result<Vec<u8>, ExtensionError>;
106
107    /// Decrypt the message body
108    #[cfg(feature = "encryption")]
109    fn decrypt_body(
110        &self,
111        encryptor: &MessageEncryptor,
112        nonce: &[u8],
113    ) -> Result<Vec<u8>, ExtensionError>;
114}
115
116impl MessageExt for Message {
117    fn validate_basic(&self) -> Result<(), ExtensionError> {
118        self.validate().map_err(ExtensionError::from)
119    }
120
121    fn is_expired(&self) -> bool {
122        if let Some(expires) = self.headers.expires {
123            chrono::Utc::now() > expires
124        } else {
125            false
126        }
127    }
128
129    fn is_scheduled(&self) -> bool {
130        if let Some(eta) = self.headers.eta {
131            chrono::Utc::now() < eta
132        } else {
133            false
134        }
135    }
136
137    fn get_age_seconds(&self) -> Option<i64> {
138        // In a real implementation, you'd track message creation time
139        // For now, return None as we don't store creation timestamp
140        None
141    }
142
143    #[cfg(feature = "signing")]
144    fn sign_body(&self, signer: &MessageSigner) -> Vec<u8> {
145        signer.sign(&self.body)
146    }
147
148    #[cfg(feature = "signing")]
149    fn verify_body(&self, signer: &MessageSigner, signature: &[u8]) -> Result<(), ExtensionError> {
150        signer.verify(&self.body, signature)?;
151        Ok(())
152    }
153
154    #[cfg(feature = "encryption")]
155    fn encrypt_body(&mut self, encryptor: &MessageEncryptor) -> Result<Vec<u8>, ExtensionError> {
156        let (ciphertext, nonce) = encryptor.encrypt(&self.body)?;
157        self.body = ciphertext;
158        Ok(nonce)
159    }
160
161    #[cfg(feature = "encryption")]
162    fn decrypt_body(
163        &self,
164        encryptor: &MessageEncryptor,
165        nonce: &[u8],
166    ) -> Result<Vec<u8>, ExtensionError> {
167        let plaintext = encryptor.decrypt(&self.body, nonce)?;
168        Ok(plaintext)
169    }
170}
171
/// A [`Message`] bundled with a signature over its body.
///
/// Only the body bytes are signed; other message fields are not covered by
/// `signature`.
#[cfg(feature = "signing")]
#[derive(Debug, Clone)]
pub struct SignedMessage {
    /// The message whose body was signed.
    pub message: Message,
    /// Signature bytes produced by `MessageExt::sign_body`.
    pub signature: Vec<u8>,
}

#[cfg(feature = "signing")]
impl SignedMessage {
    /// Signs the body of `message` with `signer` and keeps both together.
    pub fn new(message: Message, signer: &MessageSigner) -> Self {
        // `signature` is evaluated first, while `message` is still borrowable.
        Self {
            signature: message.sign_body(signer),
            message,
        }
    }

    /// Checks the stored signature against the current message body.
    ///
    /// # Errors
    ///
    /// Returns an [`ExtensionError`] if verification fails.
    pub fn verify(&self, signer: &MessageSigner) -> Result<(), ExtensionError> {
        let Self { message, signature } = self;
        message.verify_body(signer, signature)
    }

    /// Returns the signature encoded as a hexadecimal string.
    pub fn signature_hex(&self) -> String {
        hex::encode(self.signature.as_slice())
    }
}
200
/// A [`Message`] whose body has been encrypted, together with the nonce
/// needed to decrypt it.
#[cfg(feature = "encryption")]
#[derive(Debug, Clone)]
pub struct EncryptedMessage {
    /// The message; its body holds the ciphertext.
    pub message: Message,
    /// Nonce returned by the encryptor; required for decryption.
    pub nonce: Vec<u8>,
}

#[cfg(feature = "encryption")]
impl EncryptedMessage {
    /// Encrypts the body of `message` in place and records the resulting nonce.
    ///
    /// # Errors
    ///
    /// Returns an [`ExtensionError`] if encryption fails.
    pub fn new(message: Message, encryptor: &MessageEncryptor) -> Result<Self, ExtensionError> {
        let mut sealed = message;
        sealed
            .encrypt_body(encryptor)
            .map(|nonce| Self {
                message: sealed,
                nonce,
            })
    }

    /// Returns the decrypted body; the stored message is left encrypted.
    ///
    /// # Errors
    ///
    /// Returns an [`ExtensionError`] if decryption fails.
    pub fn decrypt(&self, encryptor: &MessageEncryptor) -> Result<Vec<u8>, ExtensionError> {
        let Self { message, nonce } = self;
        message.decrypt_body(encryptor, nonce)
    }

    /// Returns the nonce encoded as a hexadecimal string.
    pub fn nonce_hex(&self) -> String {
        hex::encode(self.nonce.as_slice())
    }
}
229
230/// Result type for build_secure method
231#[cfg(all(feature = "signing", feature = "encryption"))]
232pub type SecureBuildResult = Result<(Message, Option<Vec<u8>>, Option<Vec<u8>>), ExtensionError>;
233
234/// Message builder with security features
235pub struct SecureMessageBuilder {
236    message: Message,
237    #[cfg(feature = "signing")]
238    signer: Option<MessageSigner>,
239    #[cfg(feature = "encryption")]
240    encryptor: Option<MessageEncryptor>,
241}
242
243impl SecureMessageBuilder {
244    /// Create a new secure message builder
245    pub fn new(task: String, id: uuid::Uuid, body: Vec<u8>) -> Self {
246        Self {
247            message: Message::new(task, id, body),
248            #[cfg(feature = "signing")]
249            signer: None,
250            #[cfg(feature = "encryption")]
251            encryptor: None,
252        }
253    }
254
255    /// Set the message signer
256    #[cfg(feature = "signing")]
257    pub fn with_signer(mut self, key: &[u8]) -> Self {
258        self.signer = Some(MessageSigner::new(key));
259        self
260    }
261
262    /// Set the message encryptor
263    #[cfg(feature = "encryption")]
264    pub fn with_encryptor(mut self, key: &[u8]) -> Result<Self, ExtensionError> {
265        self.encryptor = Some(MessageEncryptor::new(key)?);
266        Ok(self)
267    }
268
269    /// Set priority
270    pub fn with_priority(mut self, priority: u8) -> Self {
271        self.message = self.message.with_priority(priority);
272        self
273    }
274
275    /// Build the message with optional signing
276    #[cfg(feature = "signing")]
277    #[cfg(not(feature = "encryption"))]
278    pub fn build(self) -> Result<(Message, Option<Vec<u8>>), ExtensionError> {
279        let signature = self.signer.as_ref().map(|s| self.message.sign_body(s));
280        Ok((self.message, signature))
281    }
282
283    /// Build the message with optional encryption
284    #[cfg(feature = "encryption")]
285    #[cfg(not(feature = "signing"))]
286    pub fn build(mut self) -> Result<(Message, Option<Vec<u8>>), ExtensionError> {
287        let nonce = if let Some(enc) = self.encryptor.as_ref() {
288            Some(self.message.encrypt_body(enc)?)
289        } else {
290            None
291        };
292        Ok((self.message, nonce))
293    }
294
295    /// Build with both signing and encryption
296    #[cfg(all(feature = "signing", feature = "encryption"))]
297    pub fn build_secure(mut self) -> SecureBuildResult {
298        let signature = self.signer.as_ref().map(|s| self.message.sign_body(s));
299        let nonce = if let Some(enc) = self.encryptor.as_ref() {
300            Some(self.message.encrypt_body(enc)?)
301        } else {
302            None
303        };
304        Ok((self.message, signature, nonce))
305    }
306
307    /// Build without security features
308    #[cfg(not(any(feature = "signing", feature = "encryption")))]
309    pub fn build(self) -> Message {
310        self.message
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TaskArgs;
    use uuid::Uuid;

    #[test]
    fn test_message_validate_basic() {
        let task_id = Uuid::new_v4();
        let body = serde_json::to_vec(&TaskArgs::new()).unwrap();
        let msg = Message::new("tasks.add".to_string(), task_id, body);

        assert!(msg.validate_basic().is_ok());
    }

    #[test]
    fn test_message_is_expired() {
        let task_id = Uuid::new_v4();
        let body = vec![1, 2, 3];
        let mut msg = Message::new("tasks.test".to_string(), task_id, body);

        // Not expired initially
        assert!(!msg.is_expired());

        // Set expiration in the past
        msg.headers.expires = Some(chrono::Utc::now() - chrono::Duration::hours(1));
        assert!(msg.is_expired());

        // Set expiration in the future
        msg.headers.expires = Some(chrono::Utc::now() + chrono::Duration::hours(1));
        assert!(!msg.is_expired());
    }

    #[test]
    fn test_message_is_scheduled() {
        let task_id = Uuid::new_v4();
        let body = vec![1, 2, 3];
        let mut msg = Message::new("tasks.test".to_string(), task_id, body);

        // Not scheduled initially
        assert!(!msg.is_scheduled());

        // Set ETA in the future
        msg.headers.eta = Some(chrono::Utc::now() + chrono::Duration::hours(1));
        assert!(msg.is_scheduled());

        // Set ETA in the past
        msg.headers.eta = Some(chrono::Utc::now() - chrono::Duration::hours(1));
        assert!(!msg.is_scheduled());
    }

    #[cfg(feature = "signing")]
    #[test]
    fn test_sign_and_verify_message() {
        let task_id = Uuid::new_v4();
        let body = serde_json::to_vec(&TaskArgs::new()).unwrap();
        let msg = Message::new("tasks.add".to_string(), task_id, body);

        let signer = MessageSigner::new(b"secret-key");
        let signature = msg.sign_body(&signer);

        assert!(msg.verify_body(&signer, &signature).is_ok());
    }

    #[cfg(feature = "signing")]
    #[test]
    fn test_verify_rejects_tampered_body() {
        let task_id = Uuid::new_v4();
        let body = serde_json::to_vec(&TaskArgs::new()).unwrap();
        let mut msg = Message::new("tasks.add".to_string(), task_id, body);

        let signer = MessageSigner::new(b"secret-key");
        let signature = msg.sign_body(&signer);

        // Any change to the signed bytes must invalidate the signature.
        msg.body.push(b' ');
        assert!(msg.verify_body(&signer, &signature).is_err());
    }

    #[cfg(feature = "signing")]
    #[test]
    fn test_verify_rejects_wrong_key() {
        let task_id = Uuid::new_v4();
        let body = serde_json::to_vec(&TaskArgs::new()).unwrap();
        let msg = Message::new("tasks.add".to_string(), task_id, body);

        let signer = MessageSigner::new(b"secret-key");
        let signature = msg.sign_body(&signer);

        let other_signer = MessageSigner::new(b"different-key");
        assert!(msg.verify_body(&other_signer, &signature).is_err());
    }

    #[cfg(feature = "signing")]
    #[test]
    fn test_signed_message_wrapper() {
        let task_id = Uuid::new_v4();
        let body = serde_json::to_vec(&TaskArgs::new()).unwrap();
        let msg = Message::new("tasks.add".to_string(), task_id, body);

        let signer = MessageSigner::new(b"secret-key");
        let signed = SignedMessage::new(msg, &signer);

        assert!(signed.verify(&signer).is_ok());
        assert!(!signed.signature_hex().is_empty());
    }

    #[cfg(feature = "encryption")]
    #[test]
    fn test_encrypt_and_decrypt_message() {
        let task_id = Uuid::new_v4();
        let body = b"secret data".to_vec();
        let mut msg = Message::new("tasks.add".to_string(), task_id, body.clone());

        let encryptor = MessageEncryptor::new(b"32-byte-secret-key-for-aes-256!!").unwrap();
        let nonce = msg.encrypt_body(&encryptor).unwrap();

        // Body should be different after encryption
        assert_ne!(msg.body, body);

        // Decrypt should recover original
        let decrypted = msg.decrypt_body(&encryptor, &nonce).unwrap();
        assert_eq!(decrypted, body);
    }

    #[cfg(feature = "encryption")]
    #[test]
    fn test_encrypted_message_wrapper() {
        let task_id = Uuid::new_v4();
        let body = b"secret data".to_vec();
        let msg = Message::new("tasks.add".to_string(), task_id, body.clone());

        let encryptor = MessageEncryptor::new(b"32-byte-secret-key-for-aes-256!!").unwrap();
        let encrypted = EncryptedMessage::new(msg, &encryptor).unwrap();

        let decrypted = encrypted.decrypt(&encryptor).unwrap();
        assert_eq!(decrypted, body);
        assert!(!encrypted.nonce_hex().is_empty());
    }

    #[cfg(feature = "encryption")]
    #[test]
    fn test_decrypt_with_wrong_key_does_not_recover_plaintext() {
        let task_id = Uuid::new_v4();
        let body = b"secret data".to_vec();
        let msg = Message::new("tasks.add".to_string(), task_id, body.clone());

        let encryptor = MessageEncryptor::new(b"32-byte-secret-key-for-aes-256!!").unwrap();
        let encrypted = EncryptedMessage::new(msg, &encryptor).unwrap();

        // Decryption may fail outright (authenticated cipher) or produce
        // garbage. Either way the original plaintext must not come back.
        let wrong_key = MessageEncryptor::new(b"another-32-byte-key-for-aes-256!").unwrap();
        assert_ne!(encrypted.decrypt(&wrong_key).ok(), Some(body));
    }

    #[cfg(feature = "signing")]
    #[test]
    fn test_secure_message_builder_with_signing() {
        let task_id = Uuid::new_v4();
        let body = serde_json::to_vec(&TaskArgs::new()).unwrap();

        let builder = SecureMessageBuilder::new("tasks.add".to_string(), task_id, body)
            .with_signer(b"secret-key")
            .with_priority(5);

        #[cfg(not(feature = "encryption"))]
        {
            let (msg, signature) = builder.build().unwrap();
            assert_eq!(msg.properties.priority, Some(5));
            assert!(signature.is_some());
        }

        // With both features enabled `build` does not exist. Check the same
        // expectations through `build_secure` instead of skipping them.
        #[cfg(feature = "encryption")]
        {
            let (msg, signature, nonce) = builder.build_secure().unwrap();
            assert_eq!(msg.properties.priority, Some(5));
            assert!(signature.is_some());
            assert!(nonce.is_none());
        }
    }

    #[cfg(all(feature = "signing", feature = "encryption"))]
    #[test]
    fn test_secure_message_builder_build_secure_round_trip() {
        let task_id = Uuid::new_v4();
        let body = serde_json::to_vec(&TaskArgs::new()).unwrap();
        let enc_key = b"32-byte-secret-key-for-aes-256!!";

        let (msg, signature, nonce) =
            SecureMessageBuilder::new("tasks.add".to_string(), task_id, body.clone())
                .with_signer(b"secret-key")
                .with_encryptor(enc_key)
                .unwrap()
                .build_secure()
                .unwrap();

        let signature = signature.expect("signer was configured");
        let nonce = nonce.expect("encryptor was configured");
        assert_ne!(msg.body, body);

        // The body is signed before it is encrypted, so the signature is
        // checked against the decrypted plaintext.
        let encryptor = MessageEncryptor::new(enc_key).unwrap();
        let plaintext = msg.decrypt_body(&encryptor, &nonce).unwrap();
        assert_eq!(plaintext, body);
        assert!(MessageSigner::new(b"secret-key")
            .verify(&plaintext, &signature)
            .is_ok());
    }

    #[test]
    fn test_extension_error_display() {
        let err = ExtensionError::Validation("test error".to_string());
        assert_eq!(err.to_string(), "Validation error: test error");

        let err = ExtensionError::Serialization("parse failed".to_string());
        assert_eq!(err.to_string(), "Serialization error: parse failed");
    }
}