// libsignal_rust/session_cipher.rs

1use crate::{
2    chain_type::ChainType,
3    session_record::{SessionRecord, SessionEntry, ChainInfo, ChainKey},
4    crypto,
5    curve,
6    errors::{SessionError, MessageCounterError},
7    protocol_address::ProtocolAddress,
8    queue_job::queue_job,
9    session_builder::SessionStorage,
10    protos::{WhisperMessage},
11};
12use std::sync::Arc;
13
/// A serialized message ready for transport.
///
/// `message_type` tells the receiver how `body` is encoded: `SessionCipher::encrypt`
/// produces `1` (a protobuf `WhisperMessage`), and `SessionCipher::decrypt` also
/// accepts `3` (a protobuf `PreKeyWhisperMessage`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CiphertextMessage {
    /// Wire message type: `1` = WhisperMessage, `3` = PreKeyWhisperMessage.
    pub message_type: u8,
    /// Protobuf-encoded message bytes.
    pub body: Vec<u8>,
}
18
19pub struct SessionCipher<T: SessionStorage> {
20    storage: Arc<T>,
21    addr: ProtocolAddress,
22}
23
24impl<T: SessionStorage + 'static> SessionCipher<T> {
25    pub fn new(storage: Arc<T>, addr: ProtocolAddress) -> Self {
26        Self { storage, addr }
27    }
28
29
30    pub async fn encrypt(&self, plaintext: &[u8]) -> Result<CiphertextMessage, Box<dyn std::error::Error + Send + Sync>> {
31        let storage = self.storage.clone();
32        let addr = self.addr.clone();
33        let plaintext = plaintext.to_vec();
34        
35        queue_job(addr.to_string(), async move {
36            let mut record = storage.load_session(&addr.to_string()).await.ok_or("No session record found")?;
37            let mut session = record.get_open_session().ok_or("No open session")?.clone();
38            
39            session.index_info.used = chrono::Utc::now().timestamp() as u64;
40            
41            let chain_key = session.current_ratchet.ephemeral_key_pair.pub_key.clone();
42            let chain = session.get_chain_mut(&chain_key).ok_or("Chain not found")?;
43            
44            let counter = (chain.chain_key.counter + 1) as u32;
45            let message_keys = Self::static_fill_message_keys(chain, counter)?;
46            let ciphertext = Self::static_encrypt_message(&message_keys, &plaintext)?;
47            
48            let whisper_message = WhisperMessage {
49                ephemeral_key: session.current_ratchet.ephemeral_key_pair.pub_key.clone(),
50                counter: message_keys.counter,
51                previous_counter: session.current_ratchet.previous_counter,
52                ciphertext,
53            };
54
55            let body = Self::static_serialize_whisper_message(&whisper_message)?;
56            
57            record.set_session(session);
58            storage.store_session(&addr.to_string(), record).await;
59            
60            Ok(CiphertextMessage {
61                message_type: 1,
62                body,
63            })
64        }).await
65    }
66
67    pub async fn decrypt(&self, ciphertext_message: &CiphertextMessage) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
68        let storage = self.storage.clone();
69        let addr = self.addr.clone();
70        let message_type = ciphertext_message.message_type;
71        let body = ciphertext_message.body.clone();
72        
73        queue_job(addr.to_string(), async move {
74            let mut record = storage.load_session(&addr.to_string()).await.ok_or("No session record found")?;
75            
76            let plaintext = match message_type {
77                1 => Self::static_decrypt_whisper_message(&mut record, &body).await,
78                3 => Self::decrypt_pre_key_whisper_message(storage.clone(), addr.clone(), &mut record, &body).await,
79                _ => Err("Unknown message type".into()),
80            }?;
81            
82            storage.store_session(&addr.to_string(), record).await;
83            
84            Ok(plaintext)
85        }).await
86    }
87
88    async fn static_decrypt_whisper_message(record: &mut SessionRecord, message_bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
89        let message = Self::static_deserialize_whisper_message(message_bytes)?;
90        let mut session = record.get_open_session().ok_or("No open session")?.clone();
91
92        session.index_info.used = chrono::Utc::now().timestamp() as u64;
93        
94        Self::static_maybe_step_ratchet(&mut session, &message.ephemeral_key, message.previous_counter)?;
95        
96        let chain = session.get_chain_mut(&message.ephemeral_key)
97            .ok_or("Chain not found")?;
98        
99        let message_keys = Self::static_fill_message_keys(chain, message.counter)?;
100        let plaintext = Self::static_decrypt_message(&message_keys, &message.ciphertext)?;
101        
102        record.set_session(session);
103        // Return the updated record so caller can store it
104        // Note (DitzDev): Storage should be handled by the calling function
105        
106        Ok(plaintext)
107    }
108
109    async fn decrypt_pre_key_whisper_message<S: SessionStorage + 'static>(storage: Arc<S>, addr: ProtocolAddress, record: &mut SessionRecord, message_bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
110        use prost::Message;
111        let prekey_message = crate::protos::PreKeyWhisperMessage::decode(message_bytes)
112            .map_err(|e| format!("Failed to decode PreKeyWhisperMessage: {}", e))?;
113        
114        // Extract the wrapped WhisperMessage
115        let whisper_message = crate::protos::WhisperMessage::decode(&prekey_message.message[..])
116            .map_err(|e| format!("Failed to decode WhisperMessage: {}", e))?;
117
118        // Check if we already have a session for this PreKey message
119        if let Some(session) = record.get_open_session() {
120            // Clone the session to avoid borrowing conflicts
121            let mut session_clone = session.clone();
122            
123            // Check if this message matches the pending PreKey
124            if let Some(pending) = &session_clone.pending_pre_key {
125                if pending.signed_key_id == prekey_message.signed_pre_key_id &&
126                   pending.base_key == prekey_message.base_key {
127                    // This matches our pending prekey, proceed with decryption
128                    session_clone.index_info.used = chrono::Utc::now().timestamp() as u64;
129                    
130                    Self::static_maybe_step_ratchet(&mut session_clone, &whisper_message.ephemeral_key, whisper_message.previous_counter)?;
131                    
132                    let chain = session_clone.get_chain_mut(&whisper_message.ephemeral_key)
133                        .ok_or("Chain not found")?;
134                    
135                    let message_keys = Self::static_fill_message_keys(chain, whisper_message.counter)?;
136                    let plaintext = Self::static_decrypt_message(&message_keys, &whisper_message.ciphertext)?;
137                    
138                    // Clear the pending prekey since we've successfully used it
139                    session_clone.pending_pre_key = None;
140                    record.set_session(session_clone);
141                    
142                    return Ok(plaintext);
143                }
144            }
145            
146            // If we have a session but it doesn't match the PreKey, treat as normal whisper message
147            session_clone.index_info.used = chrono::Utc::now().timestamp() as u64;
148            
149            Self::static_maybe_step_ratchet(&mut session_clone, &whisper_message.ephemeral_key, whisper_message.previous_counter)?;
150            
151            let chain = session_clone.get_chain_mut(&whisper_message.ephemeral_key)
152                .ok_or("Chain not found")?;
153            
154            let message_keys = Self::static_fill_message_keys(chain, whisper_message.counter)?;
155            let plaintext = Self::static_decrypt_message(&message_keys, &whisper_message.ciphertext)?;
156            
157            record.set_session(session_clone);
158            Ok(plaintext)
159        } else {
160            // No existing session
161            // We can use SessionBuilder to create 
162            // one from the PreKey message
163            use crate::session_builder::{SessionBuilder, PreKeyWhisperMessage as BuilderPreKeyMessage};
164            
165            // Convert protobuf PreKeyWhisperMessage to SessionBuilder PreKeyWhisperMessage
166            let builder_message = BuilderPreKeyMessage {
167                registration_id: prekey_message.registration_id,
168                pre_key_id: Some(prekey_message.pre_key_id),
169                signed_pre_key_id: prekey_message.signed_pre_key_id,
170                base_key: prekey_message.base_key.clone(),
171                identity_key: prekey_message.identity_key.clone(),
172                message: prekey_message.message.clone(),
173            };
174            
175            // Create SessionBuilder and initialize incoming session
176            let session_builder = SessionBuilder::new(storage, addr);
177            let _pre_key_id = session_builder.init_incoming(record, &builder_message).await?;
178            
179            // Now that we have a session, proceed with decryption
180            if let Some(session) = record.get_open_session() {
181                let mut session_clone = session.clone();
182                session_clone.index_info.used = chrono::Utc::now().timestamp() as u64;
183                
184                Self::static_maybe_step_ratchet(&mut session_clone, &whisper_message.ephemeral_key, whisper_message.previous_counter)?;
185                
186                let chain = session_clone.get_chain_mut(&whisper_message.ephemeral_key)
187                    .ok_or("Chain not found")?;
188                
189                let message_keys = Self::static_fill_message_keys(chain, whisper_message.counter)?;
190                let plaintext = Self::static_decrypt_message(&message_keys, &whisper_message.ciphertext)?;
191                
192                record.set_session(session_clone);
193                Ok(plaintext)
194            } else {
195                Err("Failed to create session from PreKey message".into())
196            }
197        }
198    }
199    
200    #[allow(dead_code)]
201    fn static_get_message_keys(session: &SessionEntry, chain_key: &[u8]) -> Result<MessageKeys, Box<dyn std::error::Error + Send + Sync>> {
202        let chain = session.get_chain(chain_key).ok_or("Chain not found")?;
203        
204        if chain.chain_key.key.is_none() {
205            return Err("Chain closed".into());
206        }
207        
208        let key = chain.chain_key.key.as_ref().unwrap();
209        let counter = chain.chain_key.counter + 1;
210        
211        // Use HKDF to derive 80 bytes (32 cipher + 32 mac + 16 IV) just like static_fill_message_keys
212        let derived_keys = crypto::derive_secrets(key, &[0u8; 32], b"WhisperMessageKeys", Some(3))
213            .map_err(|e| format!("Key derivation error: {}", e))?;
214        
215        // Concatenate the derived keys: 32 bytes cipher + 32 bytes mac + 16 bytes IV  
216        let cipher_key = derived_keys[0].clone();  // 32 bytes cipher key
217        let mac_key = derived_keys[1].clone();     // 32 bytes mac key
218        let iv = derived_keys[2][..16].to_vec();   // 16 bytes IV
219        
220        Ok(MessageKeys {
221            cipher_key,
222            mac_key,
223            iv,
224            counter: counter as u32,
225        })
226    }
227
228    fn static_encrypt_message(keys: &MessageKeys, plaintext: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
229        let ciphertext = crypto::encrypt(&keys.cipher_key, plaintext, &keys.iv)?;
230        let mac = crypto::calculate_mac(&keys.mac_key, &ciphertext);
231        
232        let mut result = ciphertext;
233        result.extend_from_slice(&mac[..8]);
234        Ok(result)
235    }
236
237    fn static_decrypt_message(keys: &MessageKeys, ciphertext: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
238        if ciphertext.len() < 8 {
239            return Err("Ciphertext too short".into());
240        }
241        
242        let (message_data, mac) = ciphertext.split_at(ciphertext.len() - 8);
243        crypto::verify_mac(message_data, &keys.mac_key, mac, 8)?;
244        
245        crypto::decrypt(&keys.cipher_key, message_data, &keys.iv)
246    }
247
248    fn static_fill_message_keys(chain: &mut ChainInfo, counter: u32) -> Result<MessageKeys, Box<dyn std::error::Error + Send + Sync>> {
249        if let Some(message_key) = chain.message_keys.get(&counter) {
250            let cipher_key = &message_key[..32];
251            let mac_key = &message_key[32..64];
252            let iv = &message_key[64..80];
253            
254            return Ok(MessageKeys {
255                cipher_key: cipher_key.to_vec(),
256                mac_key: mac_key.to_vec(),
257                iv: iv.to_vec(),
258                counter,
259            });
260        }
261
262        if chain.chain_key.counter >= counter as i32 {
263            return Err(Box::new(MessageCounterError::new("Message counter too old")));
264        }
265
266        if counter as i32 - chain.chain_key.counter > 2000 {
267            return Err(Box::new(SessionError::new("Over 2000 messages into the future!")));
268        }
269
270        if chain.chain_key.key.is_none() {
271            return Err(Box::new(SessionError::new("Chain closed")));
272        }
273
274        let mut current_key = chain.chain_key.key.clone().unwrap();
275        let mut current_counter = chain.chain_key.counter;
276
277        while current_counter < counter as i32 {
278            // Use HKDF to derive 80 bytes (32 cipher + 32 mac + 16 IV)
279            let derived_keys = crypto::derive_secrets(&current_key, &[0u8; 32], b"WhisperMessageKeys", Some(3))
280                .map_err(|e| format!("Key derivation error: {}", e))?;
281            
282            // Concatenate the derived keys: 32 bytes cipher + 32 bytes mac + 16 bytes IV
283            let mut message_key = Vec::with_capacity(80);
284            message_key.extend_from_slice(&derived_keys[0]);  // 32 bytes cipher key
285            message_key.extend_from_slice(&derived_keys[1]);  // 32 bytes mac key  
286            message_key.extend_from_slice(&derived_keys[2][..16]); // 16 bytes IV
287            
288            chain.message_keys.insert((current_counter + 1) as u32, message_key);
289            current_key = crypto::calculate_mac(&current_key, &[2u8]);
290            current_counter += 1;
291        }
292
293        chain.chain_key.counter = current_counter;
294        chain.chain_key.key = Some(current_key);
295
296        Self::static_fill_message_keys(chain, counter)
297    }
298
299    fn static_maybe_step_ratchet(session: &mut SessionEntry, remote_key: &[u8], previous_counter: u32) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
300        if session.get_chain(remote_key).is_some() {
301            return Ok(());
302        }
303
304        // Clone the key before mutation to avoid borrowing conflicts
305        let last_remote = session.current_ratchet.last_remote_ephemeral_key.clone();
306        if let Some(previous_ratchet) = session.get_chain_mut(&last_remote) {
307            Self::static_fill_message_keys(previous_ratchet, previous_counter)?;
308            previous_ratchet.chain_key.key = None; // Close chain
309        }
310
311        Self::static_calculate_ratchet(session, remote_key, false)?;
312
313        // Clone the pub key to avoid borrowing conflicts
314        let cur_pub = session.current_ratchet.ephemeral_key_pair.pub_key.clone();
315        let prev_counter = session.get_chain(&cur_pub)
316            .map(|chain| chain.chain_key.counter as u32);
317
318        if let Some(counter) = prev_counter {
319            session.current_ratchet.previous_counter = counter;
320            session.delete_chain(&cur_pub)?;
321        }
322
323        session.current_ratchet.ephemeral_key_pair = curve::generate_key_pair();
324        Self::static_calculate_ratchet(session, remote_key, true)?;
325        session.current_ratchet.last_remote_ephemeral_key = remote_key.to_vec();
326
327        Ok(())
328    }
329
330    fn static_calculate_ratchet(session: &mut SessionEntry, remote_key: &[u8], sending: bool) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
331        // Clone needed values to avoid borrowing conflicts
332        let root_key = session.current_ratchet.root_key.clone();
333        let priv_key = session.current_ratchet.ephemeral_key_pair.priv_key.clone();
334        let pub_key = session.current_ratchet.ephemeral_key_pair.pub_key.clone();
335
336        let shared_secret = curve::calculate_agreement(remote_key, &priv_key)?;
337        let master_key = crypto::derive_secrets(&shared_secret, &root_key, b"WhisperRatchet", Some(2))?;
338
339        let chain_key = if sending {
340            &pub_key
341        } else {
342            remote_key
343        };
344
345        session.add_chain(chain_key, ChainInfo {
346            message_keys: Default::default(),
347            chain_key: ChainKey {
348                counter: -1,
349                key: Some(master_key[1].clone()),
350            },
351            chain_type: if sending { ChainType::Sending } else { ChainType::Receiving },
352        })?;
353
354        session.current_ratchet.root_key = master_key[0].clone();
355        Ok(())
356    }
357
358    fn static_serialize_whisper_message(message: &WhisperMessage) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
359        use prost::Message;
360        Ok(message.encode_to_vec())
361    }
362
363    fn static_deserialize_whisper_message(data: &[u8]) -> Result<WhisperMessage, Box<dyn std::error::Error + Send + Sync>> {
364        use prost::Message;
365        WhisperMessage::decode(data)
366            .map_err(|e| format!("Failed to decode WhisperMessage: {}", e).into())
367    }
368
369    pub async fn has_open_session(&self) -> bool {
370        let storage = self.storage.clone();
371        let addr = self.addr.clone();
372        
373        queue_job(addr.to_string(), async move {
374            if let Some(record) = storage.load_session(&addr.to_string()).await {
375                record.have_open_session()
376            } else {
377                false
378            }
379        }).await
380    }
381
382    pub async fn close_open_session(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
383        let storage = self.storage.clone();
384        let addr = self.addr.clone();
385        
386        queue_job(addr.to_string(), async move {
387            if let Some(mut record) = storage.load_session(&addr.to_string()).await {
388                if let Some(open_session) = record.get_open_session() {
389                    let base_key = open_session.index_info.base_key.clone();
390                    record.close_session(&base_key);
391                    storage.store_session(&addr.to_string(), record).await;
392                }
393            }
394            Ok(())
395        }).await
396    }
397}
398
/// Key material for a single message, sliced from one cached 80-byte chain entry.
#[derive(Clone, Debug)]
struct MessageKeys {
    /// Symmetric encryption key (bytes 0..32 of the cached entry).
    cipher_key: Vec<u8>,
    /// Key for the truncated message MAC (bytes 32..64).
    mac_key: Vec<u8>,
    /// Initialization vector (bytes 64..80).
    iv: Vec<u8>,
    /// Chain index these keys belong to.
    counter: u32,
}