1use std::collections::{HashSet, VecDeque};
19use std::time::{SystemTime, UNIX_EPOCH};
20
21use ferogram_crypto::{AuthKey, DequeBuffer, decrypt_data_v2, encrypt_data_v2};
22use ferogram_tl_types::RemoteCall;
23
/// Maximum number of incoming server `msg_id`s remembered for replay detection.
/// Once this many are recorded, each new id evicts the oldest one (FIFO).
const SEEN_MSG_IDS_MAX: usize = 500;
26
/// Errors produced while decrypting or validating an incoming encrypted frame.
#[derive(Debug)]
pub enum DecryptError {
    /// The underlying decryption step (`decrypt_data_v2`) failed.
    Crypto(ferogram_crypto::DecryptError),
    /// The decrypted plaintext is structurally invalid: shorter than the 32-byte
    /// header, or its declared body length is too large, not a multiple of 4,
    /// larger than the plaintext, or leaves an invalid amount of padding.
    FrameTooShort,
    /// The `session_id` in the decrypted header differs from the expected one.
    SessionMismatch,
    /// The timestamp embedded in the server `msg_id` is more than 300 s in the
    /// past or more than 30 s in the future relative to the corrected local clock.
    MsgIdTimeWindow,
    /// The server `msg_id` has already been seen (possible replay).
    DuplicateMsgId,
    /// The server `msg_id` is even; server message ids must be odd.
    InvalidMsgId,
}
43
44impl std::fmt::Display for DecryptError {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 match self {
47 Self::Crypto(e) => write!(f, "crypto: {e}"),
48 Self::FrameTooShort => write!(f, "inner plaintext too short"),
49 Self::SessionMismatch => write!(f, "session_id mismatch"),
50 Self::MsgIdTimeWindow => write!(f, "server msg_id outside -300s/+30s time window"),
51 Self::DuplicateMsgId => write!(f, "duplicate server msg_id (replay)"),
52 Self::InvalidMsgId => write!(f, "server msg_id has even parity (must be odd)"),
53 }
54 }
55}
56impl std::error::Error for DecryptError {}
57
/// An incoming message that has been decrypted and passed validation.
pub struct DecryptedMessage {
    /// Server salt read from the first 8 bytes of the decrypted header.
    pub salt: i64,
    /// Session id read from the decrypted header (checked against the expected one).
    pub session_id: i64,
    /// Server message id (checked to be odd and inside the time window).
    pub msg_id: i64,
    /// Sequence number read from the decrypted header.
    pub seq_no: i32,
    /// Message body: exactly `body_len` bytes, without header or padding.
    pub body: Vec<u8>,
}
71
/// Replay-detection state for incoming server `msg_id`s.
///
/// The `VecDeque` keeps insertion order (used to evict the oldest id once
/// `SEEN_MSG_IDS_MAX` is exceeded) and the `HashSet` gives fast membership tests.
/// The `Arc<Mutex<..>>` wrapper lets several sessions share one instance
/// (see `EncryptedSession::with_seen` / `EncryptedSession::seen_msg_ids`).
pub type SeenMsgIds = std::sync::Arc<std::sync::Mutex<(VecDeque<i64>, HashSet<i64>)>>;
81
82pub fn new_seen_msg_ids() -> SeenMsgIds {
84 std::sync::Arc::new(std::sync::Mutex::new((
85 VecDeque::with_capacity(SEEN_MSG_IDS_MAX),
86 HashSet::with_capacity(SEEN_MSG_IDS_MAX),
87 )))
88}
89
/// Client-side state of an encrypted session: auth key, session id, outgoing
/// sequence and message-id counters, current server salt, clock offset and the
/// replay-detection state for incoming messages.
pub struct EncryptedSession {
    /// Key used to encrypt outgoing and decrypt incoming messages.
    auth_key: AuthKey,
    /// Random id chosen at construction and regenerated by `reset_session`.
    session_id: i64,
    /// Counter from which `seq_no` values are derived; advanced once per
    /// content-related message and adjusted by `correct_seq_no`/`undo_seq_no`.
    sequence: i32,
    /// Most recent outgoing `msg_id`; new ids are forced strictly above it.
    last_msg_id: i64,
    /// Server salt written into every outgoing message header.
    pub salt: i64,
    /// Seconds added to the local clock when generating and checking `msg_id`s.
    pub time_offset: i32,
    /// Recently seen incoming server `msg_id`s; may be shared with other sessions.
    seen_msg_ids: SeenMsgIds,
}
104
105impl EncryptedSession {
106 pub fn new(auth_key: [u8; 256], first_salt: i64, time_offset: i32) -> Self {
112 Self::with_seen(auth_key, first_salt, time_offset, new_seen_msg_ids())
113 }
114
115 pub fn with_seen(
117 auth_key: [u8; 256],
118 first_salt: i64,
119 time_offset: i32,
120 seen_msg_ids: SeenMsgIds,
121 ) -> Self {
122 let mut rnd = [0u8; 8];
123 getrandom::getrandom(&mut rnd).expect("getrandom");
124 Self {
125 auth_key: AuthKey::from_bytes(auth_key),
126 session_id: i64::from_le_bytes(rnd),
127 sequence: 0,
128 last_msg_id: 0,
129 salt: first_salt,
130 time_offset,
131 seen_msg_ids,
132 }
133 }
134
135 pub fn seen_msg_ids(&self) -> SeenMsgIds {
138 std::sync::Arc::clone(&self.seen_msg_ids)
139 }
140
141 fn next_msg_id(&mut self) -> i64 {
143 let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
144 let secs = now.as_secs().wrapping_add(self.time_offset as i64 as u64);
146 let nanos = now.subsec_nanos() as u64;
147 let mut id = ((secs << 32) | (nanos << 2)) as i64;
148 if self.last_msg_id >= id {
149 id = self.last_msg_id + 4;
150 }
151 self.last_msg_id = id;
152 id
153 }
154
155 fn next_seq_no(&mut self) -> i32 {
158 let n = self.sequence * 2 + 1;
159 self.sequence += 1;
160 n
161 }
162
163 pub fn next_seq_no_ncr(&self) -> i32 {
168 self.sequence * 2
169 }
170
171 pub fn correct_seq_no(&mut self, code: u32) {
176 match code {
177 32 => {
178 self.sequence += 64;
180 log::debug!(
181 "[ferogram] seq_no correction: code 32, bumped seq to {}",
182 self.sequence
183 );
184 }
185 33 => {
186 self.sequence = self.sequence.saturating_sub(16).max(1);
191 log::debug!(
192 "[ferogram] seq_no correction: code 33, lowered seq to {}",
193 self.sequence
194 );
195 }
196 _ => {}
197 }
198 }
199
200 pub fn undo_seq_no(&mut self) {
207 self.sequence = self.sequence.saturating_sub(1);
208 }
209
210 pub fn correct_time_offset(&mut self, server_msg_id: i64) {
217 let server_time = (server_msg_id >> 32) as i32;
219 let local_now = SystemTime::now()
220 .duration_since(UNIX_EPOCH)
221 .unwrap()
222 .as_secs() as i32;
223 let new_offset = server_time.wrapping_sub(local_now);
224 log::debug!(
225 "[ferogram] time_offset correction: {} → {} (server_time={server_time})",
226 self.time_offset,
227 new_offset
228 );
229 self.time_offset = new_offset;
230 self.last_msg_id = (server_msg_id & !0x3i64).max(self.last_msg_id);
233 }
234
235 pub fn alloc_msg_seqno(&mut self, content_related: bool) -> (i64, i32) {
242 let msg_id = self.next_msg_id();
243 let seqno = if content_related {
244 self.next_seq_no()
245 } else {
246 self.next_seq_no_ncr()
247 };
248 (msg_id, seqno)
249 }
250
251 pub fn pack_body_with_msg_id(&mut self, body: &[u8], content_related: bool) -> (Vec<u8>, i64) {
259 let msg_id = self.next_msg_id();
260 let seq_no = if content_related {
261 self.next_seq_no()
262 } else {
263 self.next_seq_no_ncr()
264 };
265
266 let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
267 let mut buf = DequeBuffer::with_capacity(inner_len, 32);
268 buf.extend(self.salt.to_le_bytes());
269 buf.extend(self.session_id.to_le_bytes());
270 buf.extend(msg_id.to_le_bytes());
271 buf.extend(seq_no.to_le_bytes());
272 buf.extend((body.len() as u32).to_le_bytes());
273 buf.extend(body.iter().copied());
274
275 encrypt_data_v2(&mut buf, &self.auth_key);
276 (buf.as_ref().to_vec(), msg_id)
277 }
278
279 pub fn pack_container(&mut self, container_body: &[u8]) -> (Vec<u8>, i64) {
288 self.pack_body_with_msg_id(container_body, false)
289 }
290
291 pub fn pack_body_at_msg_id(&mut self, body: &[u8], msg_id: i64) -> Vec<u8> {
296 let seq_no = self.next_seq_no();
297 let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
298 let mut buf = DequeBuffer::with_capacity(inner_len, 32);
299 buf.extend(self.salt.to_le_bytes());
300 buf.extend(self.session_id.to_le_bytes());
301 buf.extend(msg_id.to_le_bytes());
302 buf.extend(seq_no.to_le_bytes());
303 buf.extend((body.len() as u32).to_le_bytes());
304 buf.extend(body.iter().copied());
305 encrypt_data_v2(&mut buf, &self.auth_key);
306 buf.as_ref().to_vec()
307 }
308
309 pub fn pack_serializable<S: ferogram_tl_types::Serializable>(&mut self, call: &S) -> Vec<u8> {
311 let body = call.to_bytes();
312 let msg_id = self.next_msg_id();
313 let seq_no = self.next_seq_no();
314
315 let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
316 let mut buf = DequeBuffer::with_capacity(inner_len, 32);
317 buf.extend(self.salt.to_le_bytes());
318 buf.extend(self.session_id.to_le_bytes());
319 buf.extend(msg_id.to_le_bytes());
320 buf.extend(seq_no.to_le_bytes());
321 buf.extend((body.len() as u32).to_le_bytes());
322 buf.extend(body.iter().copied());
323
324 encrypt_data_v2(&mut buf, &self.auth_key);
325 buf.as_ref().to_vec()
326 }
327
328 pub fn pack_serializable_with_msg_id<S: ferogram_tl_types::Serializable>(
330 &mut self,
331 call: &S,
332 ) -> (Vec<u8>, i64) {
333 let body = call.to_bytes();
334 let msg_id = self.next_msg_id();
335 let seq_no = self.next_seq_no();
336 let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
337 let mut buf = DequeBuffer::with_capacity(inner_len, 32);
338 buf.extend(self.salt.to_le_bytes());
339 buf.extend(self.session_id.to_le_bytes());
340 buf.extend(msg_id.to_le_bytes());
341 buf.extend(seq_no.to_le_bytes());
342 buf.extend((body.len() as u32).to_le_bytes());
343 buf.extend(body.iter().copied());
344 encrypt_data_v2(&mut buf, &self.auth_key);
345 (buf.as_ref().to_vec(), msg_id)
346 }
347
348 pub fn pack_with_msg_id<R: RemoteCall>(&mut self, call: &R) -> (Vec<u8>, i64) {
350 let body = call.to_bytes();
351 let msg_id = self.next_msg_id();
352 let seq_no = self.next_seq_no();
353 let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
354 let mut buf = DequeBuffer::with_capacity(inner_len, 32);
355 buf.extend(self.salt.to_le_bytes());
356 buf.extend(self.session_id.to_le_bytes());
357 buf.extend(msg_id.to_le_bytes());
358 buf.extend(seq_no.to_le_bytes());
359 buf.extend((body.len() as u32).to_le_bytes());
360 buf.extend(body.iter().copied());
361 encrypt_data_v2(&mut buf, &self.auth_key);
362 (buf.as_ref().to_vec(), msg_id)
363 }
364
365 pub fn pack<R: RemoteCall>(&mut self, call: &R) -> Vec<u8> {
367 let body = call.to_bytes();
368 let msg_id = self.next_msg_id();
369 let seq_no = self.next_seq_no();
370
371 let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
372 let mut buf = DequeBuffer::with_capacity(inner_len, 32);
373 buf.extend(self.salt.to_le_bytes());
374 buf.extend(self.session_id.to_le_bytes());
375 buf.extend(msg_id.to_le_bytes());
376 buf.extend(seq_no.to_le_bytes());
377 buf.extend((body.len() as u32).to_le_bytes());
378 buf.extend(body.iter().copied());
379
380 encrypt_data_v2(&mut buf, &self.auth_key);
381 buf.as_ref().to_vec()
382 }
383
384 pub fn unpack(&self, frame: &mut [u8]) -> Result<DecryptedMessage, DecryptError> {
386 let plaintext = decrypt_data_v2(frame, &self.auth_key).map_err(DecryptError::Crypto)?;
387
388 if plaintext.len() < 32 {
389 return Err(DecryptError::FrameTooShort);
390 }
391
392 let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
393 let session_id = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
394 let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
395 let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
396 let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
397
398 if session_id != self.session_id {
399 return Err(DecryptError::SessionMismatch);
400 }
401
402 if msg_id & 1 == 0 {
404 return Err(DecryptError::InvalidMsgId);
405 }
406
407 let server_secs = (msg_id as u64 >> 32) as i64;
409 let now = SystemTime::now()
410 .duration_since(UNIX_EPOCH)
411 .unwrap()
412 .as_secs() as i64;
413 let corrected = now + self.time_offset as i64;
414 let skew = server_secs - corrected;
415 if !(-300..=30).contains(&skew) {
416 return Err(DecryptError::MsgIdTimeWindow);
417 }
418
419 {
421 let mut seen = self.seen_msg_ids.lock().unwrap();
422 if seen.1.contains(&msg_id) {
423 return Err(DecryptError::DuplicateMsgId);
424 }
425 seen.0.push_back(msg_id);
426 seen.1.insert(msg_id);
427 if seen.0.len() > SEEN_MSG_IDS_MAX
428 && let Some(old_id) = seen.0.pop_front()
429 {
430 seen.1.remove(&old_id);
431 }
432 }
433
434 if body_len > 16 * 1024 * 1024 {
436 return Err(DecryptError::FrameTooShort);
437 }
438 if 32 + body_len > plaintext.len() {
439 return Err(DecryptError::FrameTooShort);
440 }
441 if !body_len.is_multiple_of(4) {
443 return Err(DecryptError::FrameTooShort);
444 }
445 let padding = plaintext.len() - 32 - body_len;
447 if padding < 12 {
448 return Err(DecryptError::FrameTooShort);
449 }
450 let body = plaintext[32..32 + body_len].to_vec();
451
452 Ok(DecryptedMessage {
453 salt,
454 session_id,
455 msg_id,
456 seq_no,
457 body,
458 })
459 }
460
461 pub fn auth_key_bytes(&self) -> [u8; 256] {
463 self.auth_key.to_bytes()
464 }
465
466 pub fn session_id(&self) -> i64 {
468 self.session_id
469 }
470
471 pub fn reset_session(&mut self) {
477 let mut rnd = [0u8; 8];
478 getrandom::getrandom(&mut rnd).expect("getrandom");
479 let old_session = self.session_id;
480 self.session_id = i64::from_le_bytes(rnd);
481 self.sequence = 0;
482 self.last_msg_id = 0;
483 log::debug!(
486 "[ferogram] session reset: {:#018x} → {:#018x}",
487 old_session,
488 self.session_id
489 );
490 }
491}
492
493impl EncryptedSession {
494 pub fn decrypt_frame_dedup(
497 auth_key: &[u8; 256],
498 session_id: i64,
499 frame: &mut [u8],
500 seen: &SeenMsgIds,
501 ) -> Result<DecryptedMessage, DecryptError> {
502 let msg = Self::decrypt_frame_with_offset(auth_key, session_id, frame, 0)?;
503 {
504 let mut s = seen.lock().unwrap();
505 if s.1.contains(&msg.msg_id) {
506 return Err(DecryptError::DuplicateMsgId);
507 }
508 s.0.push_back(msg.msg_id);
509 s.1.insert(msg.msg_id);
510 if s.0.len() > SEEN_MSG_IDS_MAX
511 && let Some(old_id) = s.0.pop_front()
512 {
513 s.1.remove(&old_id);
514 }
515 }
516 Ok(msg)
517 }
518
519 pub fn decrypt_frame(
523 auth_key: &[u8; 256],
524 session_id: i64,
525 frame: &mut [u8],
526 ) -> Result<DecryptedMessage, DecryptError> {
527 Self::decrypt_frame_with_offset(auth_key, session_id, frame, 0)
528 }
529
530 pub fn decrypt_frame_with_offset(
533 auth_key: &[u8; 256],
534 session_id: i64,
535 frame: &mut [u8],
536 time_offset: i32,
537 ) -> Result<DecryptedMessage, DecryptError> {
538 let key = AuthKey::from_bytes(*auth_key);
539 let plaintext = decrypt_data_v2(frame, &key).map_err(DecryptError::Crypto)?;
540 if plaintext.len() < 32 {
541 return Err(DecryptError::FrameTooShort);
542 }
543 let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
544 let sid = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
545 let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
546 let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
547 let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
548 if sid != session_id {
549 return Err(DecryptError::SessionMismatch);
550 }
551 if msg_id & 1 == 0 {
553 return Err(DecryptError::InvalidMsgId);
554 }
555 let server_secs = (msg_id as u64 >> 32) as i64;
557 let now = SystemTime::now()
558 .duration_since(UNIX_EPOCH)
559 .unwrap()
560 .as_secs() as i64;
561 let corrected = now + time_offset as i64;
562 let skew = server_secs - corrected;
563 if !(-300..=30).contains(&skew) {
564 return Err(DecryptError::MsgIdTimeWindow);
565 }
566 if body_len > 16 * 1024 * 1024 {
568 return Err(DecryptError::FrameTooShort);
569 }
570 if 32 + body_len > plaintext.len() {
571 return Err(DecryptError::FrameTooShort);
572 }
573 if !body_len.is_multiple_of(4) {
575 return Err(DecryptError::FrameTooShort);
576 }
577 let padding = plaintext.len() - 32 - body_len;
579 if padding < 12 {
580 return Err(DecryptError::FrameTooShort);
581 }
582 let body = plaintext[32..32 + body_len].to_vec();
583 Ok(DecryptedMessage {
584 salt,
585 session_id: sid,
586 msg_id,
587 seq_no,
588 body,
589 })
590 }
591}