1use crate::{
2 chain_type::ChainType,
3 session_record::{SessionRecord, SessionEntry, ChainInfo, ChainKey},
4 crypto,
5 curve,
6 errors::{SessionError, MessageCounterError},
7 protocol_address::ProtocolAddress,
8 queue_job::queue_job,
9 session_builder::SessionStorage,
10 protos::{WhisperMessage},
11};
12use std::sync::Arc;
13
/// A serialized message ready to be sent to (or received from) a peer.
///
/// `message_type` selects how `body` is decoded by `SessionCipher::decrypt`:
/// `1` for a plain `WhisperMessage`, `3` for a `PreKeyWhisperMessage`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CiphertextMessage {
    /// Wire message type (`1` = whisper, `3` = pre-key whisper).
    pub message_type: u8,
    /// Protobuf-encoded message bytes.
    pub body: Vec<u8>,
}
18
19pub struct SessionCipher<T: SessionStorage> {
20 storage: Arc<T>,
21 addr: ProtocolAddress,
22}
23
24impl<T: SessionStorage + 'static> SessionCipher<T> {
25 pub fn new(storage: Arc<T>, addr: ProtocolAddress) -> Self {
26 Self { storage, addr }
27 }
28
29
30 pub async fn encrypt(&self, plaintext: &[u8]) -> Result<CiphertextMessage, Box<dyn std::error::Error + Send + Sync>> {
31 let storage = self.storage.clone();
32 let addr = self.addr.clone();
33 let plaintext = plaintext.to_vec();
34
35 queue_job(addr.to_string(), async move {
36 let mut record = storage.load_session(&addr.to_string()).await.ok_or("No session record found")?;
37 let mut session = record.get_open_session().ok_or("No open session")?.clone();
38
39 session.index_info.used = chrono::Utc::now().timestamp() as u64;
40
41 let chain_key = session.current_ratchet.ephemeral_key_pair.pub_key.clone();
42 let chain = session.get_chain_mut(&chain_key).ok_or("Chain not found")?;
43
44 let counter = (chain.chain_key.counter + 1) as u32;
45 let message_keys = Self::static_fill_message_keys(chain, counter)?;
46 let ciphertext = Self::static_encrypt_message(&message_keys, &plaintext)?;
47
48 let whisper_message = WhisperMessage {
49 ephemeral_key: session.current_ratchet.ephemeral_key_pair.pub_key.clone(),
50 counter: message_keys.counter,
51 previous_counter: session.current_ratchet.previous_counter,
52 ciphertext,
53 };
54
55 let body = Self::static_serialize_whisper_message(&whisper_message)?;
56
57 record.set_session(session);
58 storage.store_session(&addr.to_string(), record).await;
59
60 Ok(CiphertextMessage {
61 message_type: 1,
62 body,
63 })
64 }).await
65 }
66
67 pub async fn decrypt(&self, ciphertext_message: &CiphertextMessage) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
68 let storage = self.storage.clone();
69 let addr = self.addr.clone();
70 let message_type = ciphertext_message.message_type;
71 let body = ciphertext_message.body.clone();
72
73 queue_job(addr.to_string(), async move {
74 let mut record = storage.load_session(&addr.to_string()).await.ok_or("No session record found")?;
75
76 let plaintext = match message_type {
77 1 => Self::static_decrypt_whisper_message(&mut record, &body).await,
78 3 => Self::decrypt_pre_key_whisper_message(storage.clone(), addr.clone(), &mut record, &body).await,
79 _ => Err("Unknown message type".into()),
80 }?;
81
82 storage.store_session(&addr.to_string(), record).await;
83
84 Ok(plaintext)
85 }).await
86 }
87
88 async fn static_decrypt_whisper_message(record: &mut SessionRecord, message_bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
89 let message = Self::static_deserialize_whisper_message(message_bytes)?;
90 let mut session = record.get_open_session().ok_or("No open session")?.clone();
91
92 session.index_info.used = chrono::Utc::now().timestamp() as u64;
93
94 Self::static_maybe_step_ratchet(&mut session, &message.ephemeral_key, message.previous_counter)?;
95
96 let chain = session.get_chain_mut(&message.ephemeral_key)
97 .ok_or("Chain not found")?;
98
99 let message_keys = Self::static_fill_message_keys(chain, message.counter)?;
100 let plaintext = Self::static_decrypt_message(&message_keys, &message.ciphertext)?;
101
102 record.set_session(session);
103 Ok(plaintext)
107 }
108
109 async fn decrypt_pre_key_whisper_message<S: SessionStorage + 'static>(storage: Arc<S>, addr: ProtocolAddress, record: &mut SessionRecord, message_bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
110 use prost::Message;
111 let prekey_message = crate::protos::PreKeyWhisperMessage::decode(message_bytes)
112 .map_err(|e| format!("Failed to decode PreKeyWhisperMessage: {}", e))?;
113
114 let whisper_message = crate::protos::WhisperMessage::decode(&prekey_message.message[..])
116 .map_err(|e| format!("Failed to decode WhisperMessage: {}", e))?;
117
118 if let Some(session) = record.get_open_session() {
120 let mut session_clone = session.clone();
122
123 if let Some(pending) = &session_clone.pending_pre_key {
125 if pending.signed_key_id == prekey_message.signed_pre_key_id &&
126 pending.base_key == prekey_message.base_key {
127 session_clone.index_info.used = chrono::Utc::now().timestamp() as u64;
129
130 Self::static_maybe_step_ratchet(&mut session_clone, &whisper_message.ephemeral_key, whisper_message.previous_counter)?;
131
132 let chain = session_clone.get_chain_mut(&whisper_message.ephemeral_key)
133 .ok_or("Chain not found")?;
134
135 let message_keys = Self::static_fill_message_keys(chain, whisper_message.counter)?;
136 let plaintext = Self::static_decrypt_message(&message_keys, &whisper_message.ciphertext)?;
137
138 session_clone.pending_pre_key = None;
140 record.set_session(session_clone);
141
142 return Ok(plaintext);
143 }
144 }
145
146 session_clone.index_info.used = chrono::Utc::now().timestamp() as u64;
148
149 Self::static_maybe_step_ratchet(&mut session_clone, &whisper_message.ephemeral_key, whisper_message.previous_counter)?;
150
151 let chain = session_clone.get_chain_mut(&whisper_message.ephemeral_key)
152 .ok_or("Chain not found")?;
153
154 let message_keys = Self::static_fill_message_keys(chain, whisper_message.counter)?;
155 let plaintext = Self::static_decrypt_message(&message_keys, &whisper_message.ciphertext)?;
156
157 record.set_session(session_clone);
158 Ok(plaintext)
159 } else {
160 use crate::session_builder::{SessionBuilder, PreKeyWhisperMessage as BuilderPreKeyMessage};
164
165 let builder_message = BuilderPreKeyMessage {
167 registration_id: prekey_message.registration_id,
168 pre_key_id: Some(prekey_message.pre_key_id),
169 signed_pre_key_id: prekey_message.signed_pre_key_id,
170 base_key: prekey_message.base_key.clone(),
171 identity_key: prekey_message.identity_key.clone(),
172 message: prekey_message.message.clone(),
173 };
174
175 let session_builder = SessionBuilder::new(storage, addr);
177 let _pre_key_id = session_builder.init_incoming(record, &builder_message).await?;
178
179 if let Some(session) = record.get_open_session() {
181 let mut session_clone = session.clone();
182 session_clone.index_info.used = chrono::Utc::now().timestamp() as u64;
183
184 Self::static_maybe_step_ratchet(&mut session_clone, &whisper_message.ephemeral_key, whisper_message.previous_counter)?;
185
186 let chain = session_clone.get_chain_mut(&whisper_message.ephemeral_key)
187 .ok_or("Chain not found")?;
188
189 let message_keys = Self::static_fill_message_keys(chain, whisper_message.counter)?;
190 let plaintext = Self::static_decrypt_message(&message_keys, &whisper_message.ciphertext)?;
191
192 record.set_session(session_clone);
193 Ok(plaintext)
194 } else {
195 Err("Failed to create session from PreKey message".into())
196 }
197 }
198 }
199
200 #[allow(dead_code)]
201 fn static_get_message_keys(session: &SessionEntry, chain_key: &[u8]) -> Result<MessageKeys, Box<dyn std::error::Error + Send + Sync>> {
202 let chain = session.get_chain(chain_key).ok_or("Chain not found")?;
203
204 if chain.chain_key.key.is_none() {
205 return Err("Chain closed".into());
206 }
207
208 let key = chain.chain_key.key.as_ref().unwrap();
209 let counter = chain.chain_key.counter + 1;
210
211 let derived_keys = crypto::derive_secrets(key, &[0u8; 32], b"WhisperMessageKeys", Some(3))
213 .map_err(|e| format!("Key derivation error: {}", e))?;
214
215 let cipher_key = derived_keys[0].clone(); let mac_key = derived_keys[1].clone(); let iv = derived_keys[2][..16].to_vec(); Ok(MessageKeys {
221 cipher_key,
222 mac_key,
223 iv,
224 counter: counter as u32,
225 })
226 }
227
228 fn static_encrypt_message(keys: &MessageKeys, plaintext: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
229 let ciphertext = crypto::encrypt(&keys.cipher_key, plaintext, &keys.iv)?;
230 let mac = crypto::calculate_mac(&keys.mac_key, &ciphertext);
231
232 let mut result = ciphertext;
233 result.extend_from_slice(&mac[..8]);
234 Ok(result)
235 }
236
237 fn static_decrypt_message(keys: &MessageKeys, ciphertext: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
238 if ciphertext.len() < 8 {
239 return Err("Ciphertext too short".into());
240 }
241
242 let (message_data, mac) = ciphertext.split_at(ciphertext.len() - 8);
243 crypto::verify_mac(message_data, &keys.mac_key, mac, 8)?;
244
245 crypto::decrypt(&keys.cipher_key, message_data, &keys.iv)
246 }
247
248 fn static_fill_message_keys(chain: &mut ChainInfo, counter: u32) -> Result<MessageKeys, Box<dyn std::error::Error + Send + Sync>> {
249 if let Some(message_key) = chain.message_keys.get(&counter) {
250 let cipher_key = &message_key[..32];
251 let mac_key = &message_key[32..64];
252 let iv = &message_key[64..80];
253
254 return Ok(MessageKeys {
255 cipher_key: cipher_key.to_vec(),
256 mac_key: mac_key.to_vec(),
257 iv: iv.to_vec(),
258 counter,
259 });
260 }
261
262 if chain.chain_key.counter >= counter as i32 {
263 return Err(Box::new(MessageCounterError::new("Message counter too old")));
264 }
265
266 if counter as i32 - chain.chain_key.counter > 2000 {
267 return Err(Box::new(SessionError::new("Over 2000 messages into the future!")));
268 }
269
270 if chain.chain_key.key.is_none() {
271 return Err(Box::new(SessionError::new("Chain closed")));
272 }
273
274 let mut current_key = chain.chain_key.key.clone().unwrap();
275 let mut current_counter = chain.chain_key.counter;
276
277 while current_counter < counter as i32 {
278 let derived_keys = crypto::derive_secrets(¤t_key, &[0u8; 32], b"WhisperMessageKeys", Some(3))
280 .map_err(|e| format!("Key derivation error: {}", e))?;
281
282 let mut message_key = Vec::with_capacity(80);
284 message_key.extend_from_slice(&derived_keys[0]); message_key.extend_from_slice(&derived_keys[1]); message_key.extend_from_slice(&derived_keys[2][..16]); chain.message_keys.insert((current_counter + 1) as u32, message_key);
289 current_key = crypto::calculate_mac(¤t_key, &[2u8]);
290 current_counter += 1;
291 }
292
293 chain.chain_key.counter = current_counter;
294 chain.chain_key.key = Some(current_key);
295
296 Self::static_fill_message_keys(chain, counter)
297 }
298
299 fn static_maybe_step_ratchet(session: &mut SessionEntry, remote_key: &[u8], previous_counter: u32) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
300 if session.get_chain(remote_key).is_some() {
301 return Ok(());
302 }
303
304 let last_remote = session.current_ratchet.last_remote_ephemeral_key.clone();
306 if let Some(previous_ratchet) = session.get_chain_mut(&last_remote) {
307 Self::static_fill_message_keys(previous_ratchet, previous_counter)?;
308 previous_ratchet.chain_key.key = None; }
310
311 Self::static_calculate_ratchet(session, remote_key, false)?;
312
313 let cur_pub = session.current_ratchet.ephemeral_key_pair.pub_key.clone();
315 let prev_counter = session.get_chain(&cur_pub)
316 .map(|chain| chain.chain_key.counter as u32);
317
318 if let Some(counter) = prev_counter {
319 session.current_ratchet.previous_counter = counter;
320 session.delete_chain(&cur_pub)?;
321 }
322
323 session.current_ratchet.ephemeral_key_pair = curve::generate_key_pair();
324 Self::static_calculate_ratchet(session, remote_key, true)?;
325 session.current_ratchet.last_remote_ephemeral_key = remote_key.to_vec();
326
327 Ok(())
328 }
329
330 fn static_calculate_ratchet(session: &mut SessionEntry, remote_key: &[u8], sending: bool) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
331 let root_key = session.current_ratchet.root_key.clone();
333 let priv_key = session.current_ratchet.ephemeral_key_pair.priv_key.clone();
334 let pub_key = session.current_ratchet.ephemeral_key_pair.pub_key.clone();
335
336 let shared_secret = curve::calculate_agreement(remote_key, &priv_key)?;
337 let master_key = crypto::derive_secrets(&shared_secret, &root_key, b"WhisperRatchet", Some(2))?;
338
339 let chain_key = if sending {
340 &pub_key
341 } else {
342 remote_key
343 };
344
345 session.add_chain(chain_key, ChainInfo {
346 message_keys: Default::default(),
347 chain_key: ChainKey {
348 counter: -1,
349 key: Some(master_key[1].clone()),
350 },
351 chain_type: if sending { ChainType::Sending } else { ChainType::Receiving },
352 })?;
353
354 session.current_ratchet.root_key = master_key[0].clone();
355 Ok(())
356 }
357
358 fn static_serialize_whisper_message(message: &WhisperMessage) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
359 use prost::Message;
360 Ok(message.encode_to_vec())
361 }
362
363 fn static_deserialize_whisper_message(data: &[u8]) -> Result<WhisperMessage, Box<dyn std::error::Error + Send + Sync>> {
364 use prost::Message;
365 WhisperMessage::decode(data)
366 .map_err(|e| format!("Failed to decode WhisperMessage: {}", e).into())
367 }
368
369 pub async fn has_open_session(&self) -> bool {
370 let storage = self.storage.clone();
371 let addr = self.addr.clone();
372
373 queue_job(addr.to_string(), async move {
374 if let Some(record) = storage.load_session(&addr.to_string()).await {
375 record.have_open_session()
376 } else {
377 false
378 }
379 }).await
380 }
381
382 pub async fn close_open_session(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
383 let storage = self.storage.clone();
384 let addr = self.addr.clone();
385
386 queue_job(addr.to_string(), async move {
387 if let Some(mut record) = storage.load_session(&addr.to_string()).await {
388 if let Some(open_session) = record.get_open_session() {
389 let base_key = open_session.index_info.base_key.clone();
390 record.close_session(&base_key);
391 storage.store_session(&addr.to_string(), record).await;
392 }
393 }
394 Ok(())
395 }).await
396 }
397}
398
/// Key material for a single message, derived from one step of a chain.
///
/// `Debug` is implemented by hand so the secret bytes never reach logs or
/// panic messages; only their lengths and the counter are shown.
#[derive(Clone)]
struct MessageKeys {
    /// Symmetric key passed to `crypto::encrypt` / `crypto::decrypt`.
    cipher_key: Vec<u8>,
    /// Key for the truncated MAC appended to each ciphertext.
    mac_key: Vec<u8>,
    /// IV passed to `crypto::encrypt` / `crypto::decrypt`.
    iv: Vec<u8>,
    /// Chain index these keys belong to.
    counter: u32,
}

impl std::fmt::Debug for MessageKeys {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MessageKeys")
            .field("cipher_key", &format_args!("<redacted {} bytes>", self.cipher_key.len()))
            .field("mac_key", &format_args!("<redacted {} bytes>", self.mac_key.len()))
            .field("iv", &format_args!("<redacted {} bytes>", self.iv.len()))
            .field("counter", &self.counter)
            .finish()
    }
}