Skip to main content

celers_protocol/
extensions.rs

//! Message extensions and utilities
//!
//! This module provides helper functions and extensions for working with
//! Celery protocol messages, including signing, encryption, and validation.
//!
//! # Example
//!
//! ```
//! use celers_protocol::extensions::MessageExt;
//! use celers_protocol::{Message, TaskArgs};
//! use uuid::Uuid;
//!
//! let task_id = Uuid::new_v4();
//! let body = serde_json::to_vec(&TaskArgs::new()).unwrap();
//! let msg = Message::new("tasks.add".to_string(), task_id, body);
//!
//! // Validate the message
//! assert!(msg.validate_basic().is_ok());
//! ```
20
21use crate::Message;
22
23#[cfg(feature = "signing")]
24use crate::auth::{MessageSigner, SignatureError};
25
26#[cfg(feature = "encryption")]
27use crate::crypto::{EncryptionError, MessageEncryptor};
28
29use std::fmt;
30
/// Errors produced by the [`MessageExt`] helpers and the wrapper types in
/// this module.
///
/// The `Signature` and `Encryption` variants only exist when the matching
/// crate feature (`signing` / `encryption`) is enabled.
#[derive(Debug)]
pub enum ExtensionError {
    /// Structural validation failed; holds a description of the problem.
    Validation(String),
    /// A serialization step failed; holds a description of the failure.
    Serialization(String),
    /// Signing or signature verification failed.
    #[cfg(feature = "signing")]
    Signature(SignatureError),
    /// Encrypting or decrypting the message body failed.
    #[cfg(feature = "encryption")]
    Encryption(EncryptionError),
}
45
46impl fmt::Display for ExtensionError {
47    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48        match self {
49            #[cfg(feature = "signing")]
50            ExtensionError::Signature(e) => write!(f, "Signature error: {}", e),
51            #[cfg(feature = "encryption")]
52            ExtensionError::Encryption(e) => write!(f, "Encryption error: {}", e),
53            ExtensionError::Validation(msg) => write!(f, "Validation error: {}", msg),
54            ExtensionError::Serialization(msg) => write!(f, "Serialization error: {}", msg),
55        }
56    }
57}
58
59impl From<crate::ValidationError> for ExtensionError {
60    fn from(err: crate::ValidationError) -> Self {
61        ExtensionError::Validation(err.to_string())
62    }
63}
64
// `ExtensionError` implements `Debug` (derived) and `Display`, so the default
// `std::error::Error` methods are sufficient; this lets it be boxed as
// `Box<dyn std::error::Error>` and propagated with `?` into such boxes.
// NOTE(review): `source()` is not overridden, so a wrapped signature or
// encryption error is only visible through `Display`, not as the error's
// cause. Consider forwarding it if those types implement `std::error::Error`.
impl std::error::Error for ExtensionError {}
66
#[cfg(feature = "signing")]
impl From<SignatureError> for ExtensionError {
    /// Wrap a signing/verification failure so it can be propagated with `?`.
    fn from(source: SignatureError) -> Self {
        Self::Signature(source)
    }
}
73
#[cfg(feature = "encryption")]
impl From<EncryptionError> for ExtensionError {
    /// Wrap an encryption/decryption failure so it can be propagated with `?`.
    fn from(source: EncryptionError) -> Self {
        Self::Encryption(source)
    }
}
80
81/// Extension trait for Message with additional utilities
82pub trait MessageExt {
83    /// Validate basic message structure
84    fn validate_basic(&self) -> Result<(), ExtensionError>;
85
86    /// Check if message is expired
87    fn is_expired(&self) -> bool;
88
89    /// Check if message is scheduled for future execution
90    fn is_scheduled(&self) -> bool;
91
92    /// Get message age in seconds
93    fn get_age_seconds(&self) -> Option<i64>;
94
95    /// Sign the message body
96    #[cfg(feature = "signing")]
97    fn sign_body(&self, signer: &MessageSigner) -> Result<Vec<u8>, crate::auth::SignatureError>;
98
99    /// Verify the message body signature
100    #[cfg(feature = "signing")]
101    fn verify_body(&self, signer: &MessageSigner, signature: &[u8]) -> Result<(), ExtensionError>;
102
103    /// Encrypt the message body
104    #[cfg(feature = "encryption")]
105    fn encrypt_body(&mut self, encryptor: &MessageEncryptor) -> Result<Vec<u8>, ExtensionError>;
106
107    /// Decrypt the message body
108    #[cfg(feature = "encryption")]
109    fn decrypt_body(
110        &self,
111        encryptor: &MessageEncryptor,
112        nonce: &[u8],
113    ) -> Result<Vec<u8>, ExtensionError>;
114}
115
116impl MessageExt for Message {
117    fn validate_basic(&self) -> Result<(), ExtensionError> {
118        self.validate().map_err(ExtensionError::from)
119    }
120
121    fn is_expired(&self) -> bool {
122        if let Some(expires) = self.headers.expires {
123            chrono::Utc::now() > expires
124        } else {
125            false
126        }
127    }
128
129    fn is_scheduled(&self) -> bool {
130        if let Some(eta) = self.headers.eta {
131            chrono::Utc::now() < eta
132        } else {
133            false
134        }
135    }
136
137    fn get_age_seconds(&self) -> Option<i64> {
138        // In a real implementation, you'd track message creation time
139        // For now, return None as we don't store creation timestamp
140        None
141    }
142
143    #[cfg(feature = "signing")]
144    fn sign_body(&self, signer: &MessageSigner) -> Result<Vec<u8>, crate::auth::SignatureError> {
145        signer.sign(&self.body)
146    }
147
148    #[cfg(feature = "signing")]
149    fn verify_body(&self, signer: &MessageSigner, signature: &[u8]) -> Result<(), ExtensionError> {
150        signer.verify(&self.body, signature)?;
151        Ok(())
152    }
153
154    #[cfg(feature = "encryption")]
155    fn encrypt_body(&mut self, encryptor: &MessageEncryptor) -> Result<Vec<u8>, ExtensionError> {
156        let (ciphertext, nonce) = encryptor.encrypt(&self.body)?;
157        self.body = ciphertext;
158        Ok(nonce)
159    }
160
161    #[cfg(feature = "encryption")]
162    fn decrypt_body(
163        &self,
164        encryptor: &MessageEncryptor,
165        nonce: &[u8],
166    ) -> Result<Vec<u8>, ExtensionError> {
167        let plaintext = encryptor.decrypt(&self.body, nonce)?;
168        Ok(plaintext)
169    }
170}
171
/// A [`Message`] paired with a signature over its body.
///
/// Created by [`SignedMessage::new`] and checked with [`SignedMessage::verify`].
/// The signature covers only `message.body`, so headers and properties are not
/// protected by it. Changing the body after signing makes verification fail.
#[cfg(feature = "signing")]
#[derive(Debug, Clone)]
pub struct SignedMessage {
    /// The message whose body was signed.
    pub message: Message,
    /// Raw signature bytes produced by the signer.
    pub signature: Vec<u8>,
}
181
#[cfg(feature = "signing")]
impl SignedMessage {
    /// Sign `message`'s body with `signer` and bundle the message with the
    /// resulting signature.
    ///
    /// # Errors
    ///
    /// Propagates the signer's error unchanged.
    pub fn new(
        message: Message,
        signer: &MessageSigner,
    ) -> Result<Self, crate::auth::SignatureError> {
        message
            .sign_body(signer)
            .map(|signature| Self { message, signature })
    }

    /// Check the stored signature against the current message body.
    ///
    /// # Errors
    ///
    /// Returns [`ExtensionError::Signature`] if the signature does not verify.
    pub fn verify(&self, signer: &MessageSigner) -> Result<(), ExtensionError> {
        let signature = self.signature.as_slice();
        self.message.verify_body(signer, signature)
    }

    /// Hex-encode the signature bytes.
    pub fn signature_hex(&self) -> String {
        hex::encode(self.signature.as_slice())
    }
}
203
/// A [`Message`] whose body has been encrypted in place, together with the
/// nonce needed to decrypt it.
///
/// Only the body is encrypted; headers and properties remain in the clear.
/// Use [`EncryptedMessage::decrypt`] with the same encryptor to recover the
/// plaintext body.
#[cfg(feature = "encryption")]
#[derive(Debug, Clone)]
pub struct EncryptedMessage {
    /// The message; its `body` holds ciphertext.
    pub message: Message,
    /// The nonce returned by encryption; required for decryption.
    pub nonce: Vec<u8>,
}
213
#[cfg(feature = "encryption")]
impl EncryptedMessage {
    /// Encrypt `message`'s body in place with `encryptor` and keep the
    /// resulting nonce alongside it.
    ///
    /// # Errors
    ///
    /// Returns an [`ExtensionError`] if encryption fails; the message is
    /// dropped in that case.
    pub fn new(mut message: Message, encryptor: &MessageEncryptor) -> Result<Self, ExtensionError> {
        message
            .encrypt_body(encryptor)
            .map(|nonce| Self { message, nonce })
    }

    /// Decrypt the stored body with the stored nonce and return the
    /// plaintext. `self` is not modified.
    ///
    /// # Errors
    ///
    /// Returns an [`ExtensionError`] if decryption fails.
    pub fn decrypt(&self, encryptor: &MessageEncryptor) -> Result<Vec<u8>, ExtensionError> {
        let nonce = self.nonce.as_slice();
        self.message.decrypt_body(encryptor, nonce)
    }

    /// Hex-encode the nonce bytes.
    pub fn nonce_hex(&self) -> String {
        hex::encode(self.nonce.as_slice())
    }
}
232
/// Result of [`SecureMessageBuilder::build_secure`]: `(message, signature, nonce)`.
///
/// `signature` is `Some` when a signer was configured, and `nonce` is `Some`
/// when an encryptor was configured. When present, the nonce is required to
/// decrypt the message body.
#[cfg(all(feature = "signing", feature = "encryption"))]
pub type SecureBuildResult = Result<(Message, Option<Vec<u8>>, Option<Vec<u8>>), ExtensionError>;
236
237/// Message builder with security features
238pub struct SecureMessageBuilder {
239    message: Message,
240    #[cfg(feature = "signing")]
241    signer: Option<MessageSigner>,
242    #[cfg(feature = "encryption")]
243    encryptor: Option<MessageEncryptor>,
244}
245
246impl SecureMessageBuilder {
247    /// Create a new secure message builder
248    pub fn new(task: String, id: uuid::Uuid, body: Vec<u8>) -> Self {
249        Self {
250            message: Message::new(task, id, body),
251            #[cfg(feature = "signing")]
252            signer: None,
253            #[cfg(feature = "encryption")]
254            encryptor: None,
255        }
256    }
257
258    /// Set the message signer
259    #[cfg(feature = "signing")]
260    pub fn with_signer(mut self, key: &[u8]) -> Self {
261        self.signer = Some(MessageSigner::new(key));
262        self
263    }
264
265    /// Set the message encryptor
266    #[cfg(feature = "encryption")]
267    pub fn with_encryptor(mut self, key: &[u8]) -> Result<Self, ExtensionError> {
268        self.encryptor = Some(MessageEncryptor::new(key)?);
269        Ok(self)
270    }
271
272    /// Set priority
273    pub fn with_priority(mut self, priority: u8) -> Self {
274        self.message = self.message.with_priority(priority);
275        self
276    }
277
278    /// Build the message with optional signing
279    #[cfg(feature = "signing")]
280    #[cfg(not(feature = "encryption"))]
281    pub fn build(self) -> Result<(Message, Option<Vec<u8>>), ExtensionError> {
282        let signature = self.signer.as_ref().map(|s| self.message.sign_body(s));
283        Ok((self.message, signature))
284    }
285
286    /// Build the message with optional encryption
287    #[cfg(feature = "encryption")]
288    #[cfg(not(feature = "signing"))]
289    pub fn build(mut self) -> Result<(Message, Option<Vec<u8>>), ExtensionError> {
290        let nonce = if let Some(enc) = self.encryptor.as_ref() {
291            Some(self.message.encrypt_body(enc)?)
292        } else {
293            None
294        };
295        Ok((self.message, nonce))
296    }
297
298    /// Build with both signing and encryption
299    #[cfg(all(feature = "signing", feature = "encryption"))]
300    pub fn build_secure(mut self) -> SecureBuildResult {
301        let signature = self
302            .signer
303            .as_ref()
304            .map(|s| self.message.sign_body(s))
305            .transpose()?;
306        let nonce = if let Some(enc) = self.encryptor.as_ref() {
307            Some(self.message.encrypt_body(enc)?)
308        } else {
309            None
310        };
311        Ok((self.message, signature, nonce))
312    }
313
314    /// Build without security features
315    #[cfg(not(any(feature = "signing", feature = "encryption")))]
316    pub fn build(self) -> Message {
317        self.message
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TaskArgs;
    use uuid::Uuid;

    /// 32-byte key accepted by `MessageEncryptor::new` in these tests.
    #[cfg(feature = "encryption")]
    const ENCRYPTION_KEY: &[u8] = b"32-byte-secret-key-for-aes-256!!";

    #[test]
    fn test_message_validate_basic() {
        let task_id = Uuid::new_v4();
        let body = serde_json::to_vec(&TaskArgs::new()).unwrap();
        let msg = Message::new("tasks.add".to_string(), task_id, body);

        assert!(msg.validate_basic().is_ok());
    }

    #[test]
    fn test_message_is_expired() {
        let task_id = Uuid::new_v4();
        let body = vec![1, 2, 3];
        let mut msg = Message::new("tasks.test".to_string(), task_id, body);

        // Not expired initially
        assert!(!msg.is_expired());

        // Set expiration in the past
        msg.headers.expires = Some(chrono::Utc::now() - chrono::Duration::hours(1));
        assert!(msg.is_expired());

        // Set expiration in the future
        msg.headers.expires = Some(chrono::Utc::now() + chrono::Duration::hours(1));
        assert!(!msg.is_expired());
    }

    #[test]
    fn test_message_is_scheduled() {
        let task_id = Uuid::new_v4();
        let body = vec![1, 2, 3];
        let mut msg = Message::new("tasks.test".to_string(), task_id, body);

        // Not scheduled initially
        assert!(!msg.is_scheduled());

        // Set ETA in the future
        msg.headers.eta = Some(chrono::Utc::now() + chrono::Duration::hours(1));
        assert!(msg.is_scheduled());

        // Set ETA in the past
        msg.headers.eta = Some(chrono::Utc::now() - chrono::Duration::hours(1));
        assert!(!msg.is_scheduled());
    }

    #[cfg(feature = "signing")]
    #[test]
    fn test_sign_and_verify_message() {
        let task_id = Uuid::new_v4();
        let body = serde_json::to_vec(&TaskArgs::new()).unwrap();
        let msg = Message::new("tasks.add".to_string(), task_id, body);

        let signer = MessageSigner::new(b"secret-key");
        let signature = msg.sign_body(&signer).expect("signing failed in test");

        assert!(msg.verify_body(&signer, &signature).is_ok());
    }

    #[cfg(feature = "signing")]
    #[test]
    fn test_signed_message_wrapper() {
        let task_id = Uuid::new_v4();
        let body = serde_json::to_vec(&TaskArgs::new()).unwrap();
        let msg = Message::new("tasks.add".to_string(), task_id, body);

        let signer = MessageSigner::new(b"secret-key");
        let signed = SignedMessage::new(msg, &signer).expect("signing should not fail");

        assert!(signed.verify(&signer).is_ok());
        assert!(!signed.signature_hex().is_empty());
    }

    #[cfg(feature = "signing")]
    #[test]
    fn test_signed_message_rejects_wrong_key_and_tampered_body() {
        let task_id = Uuid::new_v4();
        let body = serde_json::to_vec(&TaskArgs::new()).unwrap();
        let msg = Message::new("tasks.add".to_string(), task_id, body);

        let signer = MessageSigner::new(b"secret-key");
        let mut signed = SignedMessage::new(msg, &signer).expect("signing should not fail");

        // A different key must not verify the signature.
        let other_signer = MessageSigner::new(b"another-key");
        assert!(signed.verify(&other_signer).is_err());

        // Modifying the body after signing must invalidate the signature.
        signed.message.body.push(b'x');
        assert!(signed.verify(&signer).is_err());
    }

    #[cfg(feature = "encryption")]
    #[test]
    fn test_encrypt_and_decrypt_message() {
        let task_id = Uuid::new_v4();
        let body = b"secret data".to_vec();
        let mut msg = Message::new("tasks.add".to_string(), task_id, body.clone());

        let encryptor = MessageEncryptor::new(ENCRYPTION_KEY).unwrap();
        let nonce = msg.encrypt_body(&encryptor).unwrap();

        // Body should be different after encryption
        assert_ne!(msg.body, body);

        // Decrypt should recover original
        let decrypted = msg.decrypt_body(&encryptor, &nonce).unwrap();
        assert_eq!(decrypted, body);
    }

    #[cfg(feature = "encryption")]
    #[test]
    fn test_encrypted_message_wrapper() {
        let task_id = Uuid::new_v4();
        let body = b"secret data".to_vec();
        let msg = Message::new("tasks.add".to_string(), task_id, body.clone());

        let encryptor = MessageEncryptor::new(ENCRYPTION_KEY).unwrap();
        let encrypted = EncryptedMessage::new(msg, &encryptor).unwrap();

        let decrypted = encrypted.decrypt(&encryptor).unwrap();
        assert_eq!(decrypted, body);
        assert!(!encrypted.nonce_hex().is_empty());
    }

    #[cfg(all(feature = "signing", not(feature = "encryption")))]
    #[test]
    fn test_secure_message_builder_with_signing() {
        let task_id = Uuid::new_v4();
        let body = serde_json::to_vec(&TaskArgs::new()).unwrap();

        let (msg, signature) = SecureMessageBuilder::new("tasks.add".to_string(), task_id, body)
            .with_signer(b"secret-key")
            .with_priority(5)
            .build()
            .expect("build failed in test");

        assert_eq!(msg.properties.priority, Some(5));
        let signature = signature.expect("a signer was configured");
        let signer = MessageSigner::new(b"secret-key");
        assert!(msg.verify_body(&signer, &signature).is_ok());
    }

    #[cfg(all(feature = "encryption", not(feature = "signing")))]
    #[test]
    fn test_secure_message_builder_with_encryption() {
        let task_id = Uuid::new_v4();
        let body = b"secret data".to_vec();

        let (msg, nonce) = SecureMessageBuilder::new("tasks.add".to_string(), task_id, body.clone())
            .with_encryptor(ENCRYPTION_KEY)
            .expect("test key is valid")
            .with_priority(5)
            .build()
            .expect("build failed in test");

        assert_eq!(msg.properties.priority, Some(5));
        assert_ne!(msg.body, body);
        let nonce = nonce.expect("an encryptor was configured");
        let encryptor = MessageEncryptor::new(ENCRYPTION_KEY).unwrap();
        assert_eq!(msg.decrypt_body(&encryptor, &nonce).unwrap(), body);
    }

    #[cfg(all(feature = "signing", feature = "encryption"))]
    #[test]
    fn test_secure_message_builder_build_secure() {
        let task_id = Uuid::new_v4();
        let body = b"secret data".to_vec();

        let (msg, signature, nonce) =
            SecureMessageBuilder::new("tasks.add".to_string(), task_id, body.clone())
                .with_signer(b"secret-key")
                .with_encryptor(ENCRYPTION_KEY)
                .expect("test key is valid")
                .with_priority(5)
                .build_secure()
                .expect("build_secure failed in test");

        assert_eq!(msg.properties.priority, Some(5));
        assert_ne!(msg.body, body);

        // The body is signed before it is encrypted: decrypt first, then verify.
        let nonce = nonce.expect("an encryptor was configured");
        let signature = signature.expect("a signer was configured");
        let encryptor = MessageEncryptor::new(ENCRYPTION_KEY).unwrap();
        let plaintext = msg.decrypt_body(&encryptor, &nonce).unwrap();
        assert_eq!(plaintext, body);

        let plain_msg = Message::new("tasks.add".to_string(), task_id, plaintext);
        let signer = MessageSigner::new(b"secret-key");
        assert!(plain_msg.verify_body(&signer, &signature).is_ok());
    }

    #[cfg(not(any(feature = "signing", feature = "encryption")))]
    #[test]
    fn test_secure_message_builder_without_security() {
        let task_id = Uuid::new_v4();
        let msg = SecureMessageBuilder::new("tasks.add".to_string(), task_id, vec![1, 2, 3])
            .with_priority(5)
            .build();

        assert_eq!(msg.properties.priority, Some(5));
        assert_eq!(msg.body, vec![1, 2, 3]);
    }

    #[test]
    fn test_extension_error_display() {
        let err = ExtensionError::Validation("test error".to_string());
        assert_eq!(err.to_string(), "Validation error: test error");

        let err = ExtensionError::Serialization("parse failed".to_string());
        assert_eq!(err.to_string(), "Serialization error: parse failed");
    }
}