1use std::collections::{HashSet, VecDeque};
14use std::time::{SystemTime, UNIX_EPOCH};
15
16use ferogram_crypto::{AuthKey, DequeBuffer, decrypt_data_v2, encrypt_data_v2};
17use ferogram_tl_types::RemoteCall;
18
/// Maximum number of recent server `msg_id`s kept for replay detection.
/// When the queue grows past this size, the oldest-inserted id is evicted.
const SEEN_MSG_IDS_MAX: usize = 500;
21
22#[derive(Debug)]
24pub enum DecryptError {
25 Crypto(ferogram_crypto::DecryptError),
27 FrameTooShort,
29 SessionMismatch,
31 MsgIdTimeWindow,
33 DuplicateMsgId,
35 InvalidMsgId,
37}
38
39impl std::fmt::Display for DecryptError {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 match self {
42 Self::Crypto(e) => write!(f, "crypto: {e}"),
43 Self::FrameTooShort => write!(f, "inner plaintext too short"),
44 Self::SessionMismatch => write!(f, "session_id mismatch"),
45 Self::MsgIdTimeWindow => write!(f, "server msg_id outside -300s/+30s time window"),
46 Self::DuplicateMsgId => write!(f, "duplicate server msg_id (replay)"),
47 Self::InvalidMsgId => write!(f, "server msg_id has even parity (must be odd)"),
48 }
49 }
50}
// Marker impl so `DecryptError` works with `?`, `Box<dyn Error>` and error
// reporters. It relies on the `Debug` derive and the `Display` impl.
impl std::error::Error for DecryptError {}
52
/// An incoming message that has been decrypted and passed header validation.
pub struct DecryptedMessage {
    /// Server salt taken from the frame header.
    pub salt: i64,
    /// Session id taken from the frame header.
    pub session_id: i64,
    /// Server-assigned message identifier.
    pub msg_id: i64,
    /// Sequence number taken from the frame header.
    pub seq_no: i32,
    /// Message payload, with the 32-byte header and trailing padding removed.
    pub body: Vec<u8>,
}
66
/// Shared replay-protection state for received server `msg_id`s.
///
/// `.0` keeps the ids in insertion order so the oldest can be evicted (FIFO).
/// `.1` holds the same ids for constant-time membership tests.
/// The `Arc<Mutex<..>>` wrapper lets several `EncryptedSession`s share one history.
pub type SeenMsgIds = std::sync::Arc<std::sync::Mutex<(VecDeque<i64>, HashSet<i64>)>>;
76
77pub fn new_seen_msg_ids() -> SeenMsgIds {
79 std::sync::Arc::new(std::sync::Mutex::new((
80 VecDeque::with_capacity(SEEN_MSG_IDS_MAX),
81 HashSet::with_capacity(SEEN_MSG_IDS_MAX),
82 )))
83}
84
85pub struct EncryptedSession {
87 auth_key: AuthKey,
88 session_id: i64,
89 sequence: i32,
90 last_msg_id: i64,
91 pub salt: i64,
93 pub time_offset: i32,
95 seen_msg_ids: SeenMsgIds,
98}
99
100impl EncryptedSession {
101 pub fn new(auth_key: [u8; 256], first_salt: i64, time_offset: i32) -> Self {
107 Self::with_seen(auth_key, first_salt, time_offset, new_seen_msg_ids())
108 }
109
110 pub fn with_seen(
112 auth_key: [u8; 256],
113 first_salt: i64,
114 time_offset: i32,
115 seen_msg_ids: SeenMsgIds,
116 ) -> Self {
117 let mut rnd = [0u8; 8];
118 ferogram_crypto::fill_random(&mut rnd);
119 Self {
120 auth_key: AuthKey::from_bytes(auth_key),
121 session_id: i64::from_le_bytes(rnd),
122 sequence: 0,
123 last_msg_id: 0,
124 salt: first_salt,
125 time_offset,
126 seen_msg_ids,
127 }
128 }
129
130 pub fn seen_msg_ids(&self) -> SeenMsgIds {
133 std::sync::Arc::clone(&self.seen_msg_ids)
134 }
135
136 fn next_msg_id(&mut self) -> i64 {
138 let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
139 let secs = now.as_secs().wrapping_add(self.time_offset as i64 as u64);
141 let nanos = now.subsec_nanos() as u64;
142 let mut id = ((secs << 32) | (nanos << 2)) as i64;
143 if (id as u64 & 0xFFFF_FFFF) == 0 {
148 id |= 4;
149 }
150 if self.last_msg_id >= id {
151 id = self.last_msg_id + 4;
152 }
153 self.last_msg_id = id;
154 id
155 }
156
157 fn next_seq_no(&mut self) -> i32 {
160 let n = self.sequence * 2 + 1;
161 self.sequence += 1;
162 n
163 }
164
165 pub fn next_seq_no_ncr(&self) -> i32 {
170 self.sequence * 2
171 }
172
173 pub fn correct_seq_no(&mut self, _code: u32) {
184 self.reset_session();
187 tracing::debug!(
188 code = _code,
189 "[ferogram::mtproto] seq_no desync: full session reset (new session_id, seq_no=0)"
190 );
191 }
192
193 pub fn undo_seq_no(&mut self) {
200 self.sequence = self.sequence.saturating_sub(1);
201 }
202
203 pub fn correct_time_offset(&mut self, server_msg_id: i64) {
210 let server_time = (server_msg_id >> 32) as i32;
212 let local_now = SystemTime::now()
213 .duration_since(UNIX_EPOCH)
214 .unwrap()
215 .as_secs() as i32;
216 let new_offset = server_time.wrapping_sub(local_now);
217 tracing::debug!(
218 old_offset = self.time_offset,
219 new_offset,
220 server_time,
221 "[ferogram::mtproto] clock skew corrected from bad_msg_notification"
222 );
223 self.time_offset = new_offset;
224 self.last_msg_id = (server_msg_id & !0x3i64).max(self.last_msg_id);
227 }
228
229 pub fn alloc_msg_seqno(&mut self, content_related: bool) -> (i64, i32) {
236 let msg_id = self.next_msg_id();
237 let seqno = if content_related {
238 self.next_seq_no()
239 } else {
240 self.next_seq_no_ncr()
241 };
242 (msg_id, seqno)
243 }
244
245 pub fn pack_body_with_msg_id(&mut self, body: &[u8], content_related: bool) -> (Vec<u8>, i64) {
253 let msg_id = self.next_msg_id();
254 let seq_no = if content_related {
255 self.next_seq_no()
256 } else {
257 self.next_seq_no_ncr()
258 };
259
260 let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
261 let mut buf = DequeBuffer::with_capacity(inner_len, 32);
262 buf.extend(self.salt.to_le_bytes());
263 buf.extend(self.session_id.to_le_bytes());
264 buf.extend(msg_id.to_le_bytes());
265 buf.extend(seq_no.to_le_bytes());
266 buf.extend((body.len() as u32).to_le_bytes());
267 buf.extend(body.iter().copied());
268
269 encrypt_data_v2(&mut buf, &self.auth_key);
270 (buf.as_ref().to_vec(), msg_id)
271 }
272
273 pub fn pack_container(&mut self, container_body: &[u8]) -> (Vec<u8>, i64) {
282 self.pack_body_with_msg_id(container_body, false)
283 }
284
285 pub fn pack_body_at_msg_id(&mut self, body: &[u8], msg_id: i64) -> Vec<u8> {
290 let seq_no = self.next_seq_no();
291 let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
292 let mut buf = DequeBuffer::with_capacity(inner_len, 32);
293 buf.extend(self.salt.to_le_bytes());
294 buf.extend(self.session_id.to_le_bytes());
295 buf.extend(msg_id.to_le_bytes());
296 buf.extend(seq_no.to_le_bytes());
297 buf.extend((body.len() as u32).to_le_bytes());
298 buf.extend(body.iter().copied());
299 encrypt_data_v2(&mut buf, &self.auth_key);
300 buf.as_ref().to_vec()
301 }
302
303 pub fn pack_serializable<S: ferogram_tl_types::Serializable>(&mut self, call: &S) -> Vec<u8> {
305 let body = call.to_bytes();
306 let msg_id = self.next_msg_id();
307 let seq_no = self.next_seq_no();
308
309 let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
310 let mut buf = DequeBuffer::with_capacity(inner_len, 32);
311 buf.extend(self.salt.to_le_bytes());
312 buf.extend(self.session_id.to_le_bytes());
313 buf.extend(msg_id.to_le_bytes());
314 buf.extend(seq_no.to_le_bytes());
315 buf.extend((body.len() as u32).to_le_bytes());
316 buf.extend(body.iter().copied());
317
318 encrypt_data_v2(&mut buf, &self.auth_key);
319 buf.as_ref().to_vec()
320 }
321
322 pub fn pack_serializable_with_msg_id<S: ferogram_tl_types::Serializable>(
324 &mut self,
325 call: &S,
326 ) -> (Vec<u8>, i64) {
327 let body = call.to_bytes();
328 let msg_id = self.next_msg_id();
329 let seq_no = self.next_seq_no();
330 let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
331 let mut buf = DequeBuffer::with_capacity(inner_len, 32);
332 buf.extend(self.salt.to_le_bytes());
333 buf.extend(self.session_id.to_le_bytes());
334 buf.extend(msg_id.to_le_bytes());
335 buf.extend(seq_no.to_le_bytes());
336 buf.extend((body.len() as u32).to_le_bytes());
337 buf.extend(body.iter().copied());
338 encrypt_data_v2(&mut buf, &self.auth_key);
339 (buf.as_ref().to_vec(), msg_id)
340 }
341
342 pub fn pack_with_msg_id<R: RemoteCall>(&mut self, call: &R) -> (Vec<u8>, i64) {
344 let body = call.to_bytes();
345 let msg_id = self.next_msg_id();
346 let seq_no = self.next_seq_no();
347 let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
348 let mut buf = DequeBuffer::with_capacity(inner_len, 32);
349 buf.extend(self.salt.to_le_bytes());
350 buf.extend(self.session_id.to_le_bytes());
351 buf.extend(msg_id.to_le_bytes());
352 buf.extend(seq_no.to_le_bytes());
353 buf.extend((body.len() as u32).to_le_bytes());
354 buf.extend(body.iter().copied());
355 encrypt_data_v2(&mut buf, &self.auth_key);
356 (buf.as_ref().to_vec(), msg_id)
357 }
358
359 pub fn pack<R: RemoteCall>(&mut self, call: &R) -> Vec<u8> {
361 let body = call.to_bytes();
362 let msg_id = self.next_msg_id();
363 let seq_no = self.next_seq_no();
364
365 let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
366 let mut buf = DequeBuffer::with_capacity(inner_len, 32);
367 buf.extend(self.salt.to_le_bytes());
368 buf.extend(self.session_id.to_le_bytes());
369 buf.extend(msg_id.to_le_bytes());
370 buf.extend(seq_no.to_le_bytes());
371 buf.extend((body.len() as u32).to_le_bytes());
372 buf.extend(body.iter().copied());
373
374 encrypt_data_v2(&mut buf, &self.auth_key);
375 buf.as_ref().to_vec()
376 }
377
378 pub fn unpack(&self, frame: &mut [u8]) -> Result<DecryptedMessage, DecryptError> {
380 let plaintext = decrypt_data_v2(frame, &self.auth_key).map_err(DecryptError::Crypto)?;
381
382 if plaintext.len() < 32 {
383 return Err(DecryptError::FrameTooShort);
384 }
385
386 let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
387 let session_id = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
388 let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
389 let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
390 let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
391
392 if session_id != self.session_id {
393 return Err(DecryptError::SessionMismatch);
394 }
395
396 if msg_id & 1 == 0 {
398 return Err(DecryptError::InvalidMsgId);
399 }
400
401 let server_secs = (msg_id as u64 >> 32) as i64;
403 let now = SystemTime::now()
404 .duration_since(UNIX_EPOCH)
405 .unwrap()
406 .as_secs() as i64;
407 let corrected = now + self.time_offset as i64;
408 let skew = server_secs - corrected;
409 if !(-300..=30).contains(&skew) {
410 return Err(DecryptError::MsgIdTimeWindow);
411 }
412
413 {
415 let mut seen = self.seen_msg_ids.lock().unwrap();
416 if seen.1.contains(&msg_id) {
417 return Err(DecryptError::DuplicateMsgId);
418 }
419 seen.0.push_back(msg_id);
420 seen.1.insert(msg_id);
421 if seen.0.len() > SEEN_MSG_IDS_MAX
422 && let Some(old_id) = seen.0.pop_front()
423 {
424 seen.1.remove(&old_id);
425 }
426 }
427
428 if body_len > 16 * 1024 * 1024 {
430 return Err(DecryptError::FrameTooShort);
431 }
432 if 32 + body_len > plaintext.len() {
433 return Err(DecryptError::FrameTooShort);
434 }
435 if !body_len.is_multiple_of(4) {
437 return Err(DecryptError::FrameTooShort);
438 }
439 let padding = plaintext.len() - 32 - body_len;
441 if !(12..=1024).contains(&padding) {
442 return Err(DecryptError::FrameTooShort);
443 }
444 let body = plaintext[32..32 + body_len].to_vec();
445
446 Ok(DecryptedMessage {
447 salt,
448 session_id,
449 msg_id,
450 seq_no,
451 body,
452 })
453 }
454
455 pub fn auth_key_bytes(&self) -> [u8; 256] {
457 self.auth_key.to_bytes()
458 }
459
460 pub fn session_id(&self) -> i64 {
462 self.session_id
463 }
464
465 pub fn reset_session(&mut self) {
473 let mut rnd = [0u8; 8];
474 ferogram_crypto::fill_random(&mut rnd);
475 let old_session = self.session_id;
476 self.session_id = i64::from_le_bytes(rnd);
477 self.sequence = 0;
478 self.last_msg_id = 0;
479 tracing::debug!(
482 old_session = format_args!("{old_session:#018x}"),
483 new_session = format_args!("{:#018x}", self.session_id),
484 "[ferogram::mtproto] session reset: new session_id assigned, seq_no zeroed"
485 );
486 }
487
488 pub fn reset_seq_no_only(&mut self) {
499 self.sequence = 0;
500 self.last_msg_id = 0;
501 tracing::debug!(
502 session_id = format_args!("{:#018x}", self.session_id),
503 "[ferogram::mtproto] seq_no reset after new_session_created (session_id unchanged)"
504 );
505 }
506}
507
508impl EncryptedSession {
509 pub fn decrypt_frame_dedup(
517 auth_key: &[u8; 256],
518 session_id: i64,
519 frame: &mut [u8],
520 seen: &SeenMsgIds,
521 ) -> Result<DecryptedMessage, DecryptError> {
522 Self::decrypt_frame_dedup_with_offset(auth_key, session_id, frame, seen, 0)
523 }
524
525 pub fn decrypt_frame_dedup_with_offset(
532 auth_key: &[u8; 256],
533 session_id: i64,
534 frame: &mut [u8],
535 seen: &SeenMsgIds,
536 time_offset: i32,
537 ) -> Result<DecryptedMessage, DecryptError> {
538 let msg = Self::decrypt_frame_with_offset(auth_key, session_id, frame, time_offset)?;
539 {
540 let mut s = seen.lock().unwrap();
541 if s.1.contains(&msg.msg_id) {
542 return Err(DecryptError::DuplicateMsgId);
543 }
544 s.0.push_back(msg.msg_id);
545 s.1.insert(msg.msg_id);
546 if s.0.len() > SEEN_MSG_IDS_MAX
547 && let Some(old_id) = s.0.pop_front()
548 {
549 s.1.remove(&old_id);
550 }
551 }
552 Ok(msg)
553 }
554
555 pub fn decrypt_frame(
559 auth_key: &[u8; 256],
560 session_id: i64,
561 frame: &mut [u8],
562 ) -> Result<DecryptedMessage, DecryptError> {
563 Self::decrypt_frame_with_offset(auth_key, session_id, frame, 0)
564 }
565
566 pub fn decrypt_frame_with_offset(
569 auth_key: &[u8; 256],
570 session_id: i64,
571 frame: &mut [u8],
572 time_offset: i32,
573 ) -> Result<DecryptedMessage, DecryptError> {
574 let key = AuthKey::from_bytes(*auth_key);
575 let plaintext = decrypt_data_v2(frame, &key).map_err(DecryptError::Crypto)?;
576 if plaintext.len() < 32 {
577 return Err(DecryptError::FrameTooShort);
578 }
579 let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
580 let sid = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
581 let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
582 let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
583 let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
584 if sid != session_id {
585 return Err(DecryptError::SessionMismatch);
586 }
587 if msg_id & 1 == 0 {
589 return Err(DecryptError::InvalidMsgId);
590 }
591 let server_secs = (msg_id as u64 >> 32) as i64;
593 let now = SystemTime::now()
594 .duration_since(UNIX_EPOCH)
595 .unwrap()
596 .as_secs() as i64;
597 let corrected = now + time_offset as i64;
598 let skew = server_secs - corrected;
599 if !(-300..=30).contains(&skew) {
600 return Err(DecryptError::MsgIdTimeWindow);
601 }
602 if body_len > 16 * 1024 * 1024 {
604 return Err(DecryptError::FrameTooShort);
605 }
606 if 32 + body_len > plaintext.len() {
607 return Err(DecryptError::FrameTooShort);
608 }
609 if !body_len.is_multiple_of(4) {
611 return Err(DecryptError::FrameTooShort);
612 }
613 let padding = plaintext.len() - 32 - body_len;
615 if !(12..=1024).contains(&padding) {
616 return Err(DecryptError::FrameTooShort);
617 }
618 let body = plaintext[32..32 + body_len].to_vec();
619 Ok(DecryptedMessage {
620 salt,
621 session_id: sid,
622 msg_id,
623 seq_no,
624 body,
625 })
626 }
627}