1use std::collections::VecDeque;
26use std::time::{SystemTime, UNIX_EPOCH};
27
28use layer_crypto::{AuthKey, DequeBuffer, decrypt_data_v2, encrypt_data_v2};
29use layer_tl_types::RemoteCall;
30
/// Number of most recent server msg_ids remembered for replay detection.
const SEEN_MSG_IDS_MAX: usize = 500;
/// Largest tolerated gap, in seconds (five minutes), between the time embedded in a
/// server msg_id and the offset-corrected local clock before the message is flagged.
const MSG_ID_TIME_WINDOW_SECS: i64 = 5 * 60;
35
36#[derive(Debug)]
38pub enum DecryptError {
39 Crypto(layer_crypto::DecryptError),
41 FrameTooShort,
43 SessionMismatch,
45 MsgIdTimeWindow,
47 DuplicateMsgId,
49}
50
51impl std::fmt::Display for DecryptError {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 match self {
54 Self::Crypto(e) => write!(f, "crypto: {e}"),
55 Self::FrameTooShort => write!(f, "inner plaintext too short"),
56 Self::SessionMismatch => write!(f, "session_id mismatch"),
57 Self::MsgIdTimeWindow => write!(f, "server msg_id outside ±300 s time window"),
58 Self::DuplicateMsgId => write!(f, "duplicate server msg_id (replay)"),
59 }
60 }
61}
62impl std::error::Error for DecryptError {}
63
/// A server message whose inner plaintext has been decrypted and parsed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecryptedMessage {
    /// Server salt taken from the plaintext header.
    pub salt: i64,
    /// Session id from the header; the decrypt paths reject frames whose value
    /// differs from the expected session.
    pub session_id: i64,
    /// Server msg_id; its upper 32 bits are read as the server's Unix time in seconds.
    pub msg_id: i64,
    /// Sequence number from the header.
    pub seq_no: i32,
    /// Exactly the `body_len` bytes declared by the header; any trailing bytes are dropped.
    pub body: Vec<u8>,
    /// `true` when the msg_id's embedded time differs from the offset-corrected local
    /// clock by more than `MSG_ID_TIME_WINDOW_SECS`.
    pub bad_time: bool,
}
86
87pub struct EncryptedSession {
89 auth_key: AuthKey,
90 session_id: i64,
91 sequence: i32,
92 last_msg_id: i64,
93 pub salt: i64,
95 pub time_offset: i32,
97 seen_msg_ids: std::sync::Mutex<VecDeque<i64>>,
99}
100
101impl EncryptedSession {
102 pub fn new(auth_key: [u8; 256], first_salt: i64, time_offset: i32) -> Self {
104 let mut rnd = [0u8; 8];
105 getrandom::getrandom(&mut rnd).expect("getrandom");
106 Self {
107 auth_key: AuthKey::from_bytes(auth_key),
108 session_id: i64::from_le_bytes(rnd),
109 sequence: 0,
110 last_msg_id: 0,
111 salt: first_salt,
112 time_offset,
113 seen_msg_ids: std::sync::Mutex::new(VecDeque::with_capacity(SEEN_MSG_IDS_MAX)),
114 }
115 }
116
117 fn next_msg_id(&mut self) -> i64 {
119 let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
120 let secs = (now.as_secs() as i32).wrapping_add(self.time_offset) as u64;
121 let nanos = now.subsec_nanos() as u64;
122 let mut id = ((secs << 32) | (nanos << 2)) as i64;
123 if self.last_msg_id >= id {
124 id = self.last_msg_id + 4;
125 }
126 self.last_msg_id = id;
127 id
128 }
129
130 fn next_seq_no(&mut self) -> i32 {
133 let n = self.sequence * 2 + 1;
134 self.sequence += 1;
135 n
136 }
137
138 pub fn next_seq_no_ncr(&self) -> i32 {
143 self.sequence * 2
144 }
145
146 pub fn correct_seq_no(&mut self, code: u32) {
151 match code {
152 32 => {
153 self.sequence += 64;
155 log::debug!(
156 "[layer] seq_no correction: code 32, bumped seq to {}",
157 self.sequence
158 );
159 }
160 33 => {
161 self.sequence = self.sequence.saturating_sub(16).max(1);
166 log::debug!(
167 "[layer] seq_no correction: code 33, lowered seq to {}",
168 self.sequence
169 );
170 }
171 _ => {}
172 }
173 }
174
175 pub fn correct_time_offset(&mut self, server_msg_id: i64) {
182 let server_time = (server_msg_id >> 32) as i32;
184 let local_now = SystemTime::now()
185 .duration_since(UNIX_EPOCH)
186 .unwrap()
187 .as_secs() as i32;
188 let new_offset = server_time.wrapping_sub(local_now);
189 log::debug!(
190 "[layer] time_offset correction: {} → {} (server_time={server_time})",
191 self.time_offset,
192 new_offset
193 );
194 self.time_offset = new_offset;
195 self.last_msg_id = 0;
197 }
198
199 pub fn alloc_msg_seqno(&mut self, content_related: bool) -> (i64, i32) {
206 let msg_id = self.next_msg_id();
207 let seqno = if content_related {
208 self.next_seq_no()
209 } else {
210 self.next_seq_no_ncr()
211 };
212 (msg_id, seqno)
213 }
214
215 pub fn pack_body_with_msg_id(&mut self, body: &[u8], content_related: bool) -> (Vec<u8>, i64) {
223 let msg_id = self.next_msg_id();
224 let seq_no = if content_related {
225 self.next_seq_no()
226 } else {
227 self.next_seq_no_ncr()
228 };
229
230 let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
231 let mut buf = DequeBuffer::with_capacity(inner_len, 32);
232 buf.extend(self.salt.to_le_bytes());
233 buf.extend(self.session_id.to_le_bytes());
234 buf.extend(msg_id.to_le_bytes());
235 buf.extend(seq_no.to_le_bytes());
236 buf.extend((body.len() as u32).to_le_bytes());
237 buf.extend(body.iter().copied());
238
239 encrypt_data_v2(&mut buf, &self.auth_key);
240 (buf.as_ref().to_vec(), msg_id)
241 }
242
243 pub fn pack_container(&mut self, container_body: &[u8]) -> (Vec<u8>, i64) {
252 self.pack_body_with_msg_id(container_body, false)
253 }
254
255 pub fn pack_serializable<S: layer_tl_types::Serializable>(&mut self, call: &S) -> Vec<u8> {
259 let body = call.to_bytes();
260 let msg_id = self.next_msg_id();
261 let seq_no = self.next_seq_no();
262
263 let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
264 let mut buf = DequeBuffer::with_capacity(inner_len, 32);
265 buf.extend(self.salt.to_le_bytes());
266 buf.extend(self.session_id.to_le_bytes());
267 buf.extend(msg_id.to_le_bytes());
268 buf.extend(seq_no.to_le_bytes());
269 buf.extend((body.len() as u32).to_le_bytes());
270 buf.extend(body.iter().copied());
271
272 encrypt_data_v2(&mut buf, &self.auth_key);
273 buf.as_ref().to_vec()
274 }
275
276 pub fn pack_serializable_with_msg_id<S: layer_tl_types::Serializable>(
278 &mut self,
279 call: &S,
280 ) -> (Vec<u8>, i64) {
281 let body = call.to_bytes();
282 let msg_id = self.next_msg_id();
283 let seq_no = self.next_seq_no();
284 let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
285 let mut buf = DequeBuffer::with_capacity(inner_len, 32);
286 buf.extend(self.salt.to_le_bytes());
287 buf.extend(self.session_id.to_le_bytes());
288 buf.extend(msg_id.to_le_bytes());
289 buf.extend(seq_no.to_le_bytes());
290 buf.extend((body.len() as u32).to_le_bytes());
291 buf.extend(body.iter().copied());
292 encrypt_data_v2(&mut buf, &self.auth_key);
293 (buf.as_ref().to_vec(), msg_id)
294 }
295
296 pub fn pack_with_msg_id<R: RemoteCall>(&mut self, call: &R) -> (Vec<u8>, i64) {
298 let body = call.to_bytes();
299 let msg_id = self.next_msg_id();
300 let seq_no = self.next_seq_no();
301 let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
302 let mut buf = DequeBuffer::with_capacity(inner_len, 32);
303 buf.extend(self.salt.to_le_bytes());
304 buf.extend(self.session_id.to_le_bytes());
305 buf.extend(msg_id.to_le_bytes());
306 buf.extend(seq_no.to_le_bytes());
307 buf.extend((body.len() as u32).to_le_bytes());
308 buf.extend(body.iter().copied());
309 encrypt_data_v2(&mut buf, &self.auth_key);
310 (buf.as_ref().to_vec(), msg_id)
311 }
312
313 pub fn pack<R: RemoteCall>(&mut self, call: &R) -> Vec<u8> {
315 let body = call.to_bytes();
316 let msg_id = self.next_msg_id();
317 let seq_no = self.next_seq_no();
318
319 let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
320 let mut buf = DequeBuffer::with_capacity(inner_len, 32);
321 buf.extend(self.salt.to_le_bytes());
322 buf.extend(self.session_id.to_le_bytes());
323 buf.extend(msg_id.to_le_bytes());
324 buf.extend(seq_no.to_le_bytes());
325 buf.extend((body.len() as u32).to_le_bytes());
326 buf.extend(body.iter().copied());
327
328 encrypt_data_v2(&mut buf, &self.auth_key);
329 buf.as_ref().to_vec()
330 }
331
332 pub fn unpack(&self, frame: &mut [u8]) -> Result<DecryptedMessage, DecryptError> {
334 let plaintext = decrypt_data_v2(frame, &self.auth_key).map_err(DecryptError::Crypto)?;
335
336 if plaintext.len() < 32 {
337 return Err(DecryptError::FrameTooShort);
338 }
339
340 let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
341 let session_id = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
342 let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
343 let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
344 let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
345
346 if session_id != self.session_id {
347 return Err(DecryptError::SessionMismatch);
348 }
349
350 let server_secs = (msg_id as u64 >> 32) as i64;
355 let now = SystemTime::now()
356 .duration_since(UNIX_EPOCH)
357 .unwrap()
358 .as_secs() as i64;
359 let corrected = now + self.time_offset as i64;
360 let bad_time = (server_secs - corrected).abs() > MSG_ID_TIME_WINDOW_SECS;
361
362 {
364 let mut seen = self.seen_msg_ids.lock().unwrap();
365 if seen.contains(&msg_id) {
366 return Err(DecryptError::DuplicateMsgId);
367 }
368 seen.push_back(msg_id);
369 if seen.len() > SEEN_MSG_IDS_MAX {
370 seen.pop_front();
371 }
372 }
373
374 if 32 + body_len > plaintext.len() {
375 return Err(DecryptError::FrameTooShort);
376 }
377 let body = plaintext[32..32 + body_len].to_vec();
378
379 Ok(DecryptedMessage {
380 salt,
381 session_id,
382 msg_id,
383 seq_no,
384 body,
385 bad_time,
386 })
387 }
388
389 pub fn auth_key_bytes(&self) -> [u8; 256] {
391 self.auth_key.to_bytes()
392 }
393
394 pub fn session_id(&self) -> i64 {
396 self.session_id
397 }
398}
399
400impl EncryptedSession {
401 pub fn decrypt_frame(
405 auth_key: &[u8; 256],
406 session_id: i64,
407 frame: &mut [u8],
408 ) -> Result<DecryptedMessage, DecryptError> {
409 Self::decrypt_frame_with_offset(auth_key, session_id, frame, 0)
410 }
411
412 pub fn decrypt_frame_with_offset(
415 auth_key: &[u8; 256],
416 session_id: i64,
417 frame: &mut [u8],
418 time_offset: i32,
419 ) -> Result<DecryptedMessage, DecryptError> {
420 let key = AuthKey::from_bytes(*auth_key);
421 let plaintext = decrypt_data_v2(frame, &key).map_err(DecryptError::Crypto)?;
422 if plaintext.len() < 32 {
423 return Err(DecryptError::FrameTooShort);
424 }
425 let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
426 let sid = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
427 let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
428 let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
429 let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
430 if sid != session_id {
431 return Err(DecryptError::SessionMismatch);
432 }
433 let server_secs = (msg_id as u64 >> 32) as i64;
435 let now = SystemTime::now()
436 .duration_since(UNIX_EPOCH)
437 .unwrap()
438 .as_secs() as i64;
439 let corrected = now + time_offset as i64;
440 let bad_time = (server_secs - corrected).abs() > MSG_ID_TIME_WINDOW_SECS;
441 if 32 + body_len > plaintext.len() {
442 return Err(DecryptError::FrameTooShort);
443 }
444 let body = plaintext[32..32 + body_len].to_vec();
445 Ok(DecryptedMessage {
446 salt,
447 session_id: sid,
448 msg_id,
449 seq_no,
450 body,
451 bad_time,
452 })
453 }
454}