Skip to main content

ferogram_mtproto/
encrypted.rs

1// Copyright (c) Ankit Chaubey <ankitchaubey.dev@gmail.com>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3//
4// ferogram: async Telegram MTProto client in Rust
5// https://github.com/ankit-chaubey/ferogram
6//
7// Based on layer: https://github.com/ankit-chaubey/layer
8// Follows official Telegram client behaviour (tdesktop, TDLib).
9//
10// If you use or modify this code, keep this notice at the top of your file
11// and include the LICENSE-MIT or LICENSE-APACHE file from this repository:
12// https://github.com/ankit-chaubey/ferogram
13
14//! Encrypted MTProto 2.0 session (post auth-key).
15//!
16//! Once you have a `Finished` from [`crate::authentication`], construct an
17//! [`EncryptedSession`] and use it to serialize/deserialize all subsequent
18//! messages.
19
20use std::collections::VecDeque;
21use std::time::{SystemTime, UNIX_EPOCH};
22
23use ferogram_crypto::{AuthKey, DequeBuffer, decrypt_data_v2, encrypt_data_v2};
24use ferogram_tl_types::RemoteCall;
25
26/// Rolling deduplication buffer – mirrors tDesktop's 500-entry seen-msg_id set.
27const SEEN_MSG_IDS_MAX: usize = 500;
/// Tolerated gap (seconds) between a server `msg_id` timestamp and the offset-corrected local clock; larger skews are logged as a warning.
29const MSG_ID_TIME_WINDOW_SECS: i64 = 300;
30
31/// Errors that can occur when decrypting a server message.
32#[derive(Debug)]
33pub enum DecryptError {
34    /// The underlying crypto layer rejected the message.
35    Crypto(ferogram_crypto::DecryptError),
36    /// The decrypted inner message was too short to contain a valid header.
37    FrameTooShort,
38    /// Session-ID mismatch (possible replay or wrong connection).
39    SessionMismatch,
40    /// Server msg_id is outside the ±300 s window of corrected local time.
41    MsgIdTimeWindow,
42    /// This msg_id was already seen in the rolling 500-entry buffer.
43    DuplicateMsgId,
44}
45
46impl std::fmt::Display for DecryptError {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        match self {
49            Self::Crypto(e) => write!(f, "crypto: {e}"),
50            Self::FrameTooShort => write!(f, "inner plaintext too short"),
51            Self::SessionMismatch => write!(f, "session_id mismatch"),
52            Self::MsgIdTimeWindow => write!(f, "server msg_id outside ±300 s time window"),
53            Self::DuplicateMsgId => write!(f, "duplicate server msg_id (replay)"),
54        }
55    }
56}
57impl std::error::Error for DecryptError {}
58
/// Header fields and body recovered from a successfully decrypted server frame.
///
/// All fields are plain data, so the type derives `Debug`, `Clone`,
/// `PartialEq` and `Eq` to make it loggable, copyable and comparable in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecryptedMessage {
    /// `salt` sent by the server (first 8 bytes of the plaintext).
    pub salt: i64,
    /// The `session_id` carried by the frame.
    pub session_id: i64,
    /// The `msg_id` of the inner message.
    pub msg_id: i64,
    /// `seq_no` of the inner message.
    pub seq_no: i32,
    /// TL-serialized body of the inner message, with header and padding removed.
    pub body: Vec<u8>,
}
72
73/// MTProto 2.0 encrypted session state.
74pub struct EncryptedSession {
75    auth_key: AuthKey,
76    session_id: i64,
77    sequence: i32,
78    last_msg_id: i64,
79    /// Current server salt to include in outgoing messages.
80    pub salt: i64,
81    /// Clock skew in seconds vs. server.
82    pub time_offset: i32,
83    /// Rolling 500-entry dedup buffer of seen server msg_ids.
84    seen_msg_ids: std::sync::Mutex<VecDeque<i64>>,
85}
86
87impl EncryptedSession {
88    /// Create a new encrypted session from the output of `authentication::finish`.
89    pub fn new(auth_key: [u8; 256], first_salt: i64, time_offset: i32) -> Self {
90        let mut rnd = [0u8; 8];
91        getrandom::getrandom(&mut rnd).expect("getrandom");
92        Self {
93            auth_key: AuthKey::from_bytes(auth_key),
94            session_id: i64::from_le_bytes(rnd),
95            sequence: 0,
96            last_msg_id: 0,
97            salt: first_salt,
98            time_offset,
99            seen_msg_ids: std::sync::Mutex::new(VecDeque::with_capacity(SEEN_MSG_IDS_MAX)),
100        }
101    }
102
103    /// Compute the next message ID (based on corrected server time).
104    fn next_msg_id(&mut self) -> i64 {
105        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
106        let secs = (now.as_secs() as i32).wrapping_add(self.time_offset) as u64;
107        let nanos = now.subsec_nanos() as u64;
108        let mut id = ((secs << 32) | (nanos << 2)) as i64;
109        if self.last_msg_id >= id {
110            id = self.last_msg_id + 4;
111        }
112        self.last_msg_id = id;
113        id
114    }
115
116    /// Next content-related seq_no (odd) and advance the counter.
117    /// Used for all regular RPC requests.
118    fn next_seq_no(&mut self) -> i32 {
119        let n = self.sequence * 2 + 1;
120        self.sequence += 1;
121        n
122    }
123
124    /// Return the current even seq_no WITHOUT advancing the counter.
125    ///
126    /// Service messages (MsgsAck, containers, etc.) MUST use an even seqno
127    /// per the MTProto spec so the server does not expect a reply.
128    pub fn next_seq_no_ncr(&self) -> i32 {
129        self.sequence * 2
130    }
131
132    /// Correct the outgoing sequence counter when the server reports a
133    /// `bad_msg_notification` with error codes 32 (seq_no too low) or
134    /// 33 (seq_no too high).
135    ///
136    pub fn correct_seq_no(&mut self, code: u32) {
137        match code {
138            32 => {
139                // seq_no too low: jump forward so next send is well above server expectation
140                self.sequence += 64;
141                log::debug!(
142                    "[ferogram] seq_no correction: code 32, bumped seq to {}",
143                    self.sequence
144                );
145            }
146            33 => {
147                // seq_no too high: step back, but never below 1 to avoid
148                // re-using seq_no=1 which was already sent this session.
149                // Zeroing would make the next content message get seq_no=1,
150                // which the server already saw and will reject again with code 32.
151                self.sequence = self.sequence.saturating_sub(16).max(1);
152                log::debug!(
153                    "[ferogram] seq_no correction: code 33, lowered seq to {}",
154                    self.sequence
155                );
156            }
157            _ => {}
158        }
159    }
160
161    /// Re-derive the clock skew from a server-provided `msg_id`.
162    ///
163    /// Called on `bad_msg_notification` error codes 16 (msg_id too low) and
164    /// 17 (msg_id too high) so clock drift is corrected at any point in the
165    /// session, not only at connect time.
166    ///
167    pub fn correct_time_offset(&mut self, server_msg_id: i64) {
168        // Upper 32 bits of msg_id = Unix seconds on the server
169        let server_time = (server_msg_id >> 32) as i32;
170        let local_now = SystemTime::now()
171            .duration_since(UNIX_EPOCH)
172            .unwrap()
173            .as_secs() as i32;
174        let new_offset = server_time.wrapping_sub(local_now);
175        log::debug!(
176            "[ferogram] time_offset correction: {} → {} (server_time={server_time})",
177            self.time_offset,
178            new_offset
179        );
180        self.time_offset = new_offset;
181        // Also reset last_msg_id so next_msg_id rebuilds from corrected clock
182        self.last_msg_id = 0;
183    }
184
185    /// Allocate a fresh `(msg_id, seqno)` pair for an inner container message
186    /// WITHOUT encrypting anything.
187    ///
188    /// `content_related = true`  → odd seqno, advances counter  (regular RPCs)
189    /// `content_related = false` → even seqno, no advance       (MsgsAck, container)
190    ///
191    pub fn alloc_msg_seqno(&mut self, content_related: bool) -> (i64, i32) {
192        let msg_id = self.next_msg_id();
193        let seqno = if content_related {
194            self.next_seq_no()
195        } else {
196            self.next_seq_no_ncr()
197        };
198        (msg_id, seqno)
199    }
200
201    /// Encrypt a pre-serialized TL body into a wire-ready MTProto frame.
202    ///
203    /// `content_related` controls whether the seqno is odd (content, advances
204    /// the counter) or even (service, no advance).
205    ///
206    /// Returns `(encrypted_wire_bytes, msg_id)`.
207    /// Used for (bad_msg re-send) and (container inner messages).
208    pub fn pack_body_with_msg_id(&mut self, body: &[u8], content_related: bool) -> (Vec<u8>, i64) {
209        let msg_id = self.next_msg_id();
210        let seq_no = if content_related {
211            self.next_seq_no()
212        } else {
213            self.next_seq_no_ncr()
214        };
215
216        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
217        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
218        buf.extend(self.salt.to_le_bytes());
219        buf.extend(self.session_id.to_le_bytes());
220        buf.extend(msg_id.to_le_bytes());
221        buf.extend(seq_no.to_le_bytes());
222        buf.extend((body.len() as u32).to_le_bytes());
223        buf.extend(body.iter().copied());
224
225        encrypt_data_v2(&mut buf, &self.auth_key);
226        (buf.as_ref().to_vec(), msg_id)
227    }
228
229    /// Encrypt a pre-built `msg_container` body (the container itself is
230    /// a non-content-related message with an even seqno).
231    ///
232    /// Returns `(encrypted_wire_bytes, container_msg_id)`.
233    /// The container_msg_id is needed so callers can map it back to inner
234    /// requests when a bad_msg_notification or bad_server_salt arrives for
235    /// the container rather than the individual inner message.
236    ///
237    pub fn pack_container(&mut self, container_body: &[u8]) -> (Vec<u8>, i64) {
238        self.pack_body_with_msg_id(container_body, false)
239    }
240
    // Typed packers: serialize a TL value, then frame and encrypt it as a content-related message.
242
243    /// Serialize and encrypt a TL function into a wire-ready byte vector.
244    pub fn pack_serializable<S: ferogram_tl_types::Serializable>(&mut self, call: &S) -> Vec<u8> {
245        let body = call.to_bytes();
246        let msg_id = self.next_msg_id();
247        let seq_no = self.next_seq_no();
248
249        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
250        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
251        buf.extend(self.salt.to_le_bytes());
252        buf.extend(self.session_id.to_le_bytes());
253        buf.extend(msg_id.to_le_bytes());
254        buf.extend(seq_no.to_le_bytes());
255        buf.extend((body.len() as u32).to_le_bytes());
256        buf.extend(body.iter().copied());
257
258        encrypt_data_v2(&mut buf, &self.auth_key);
259        buf.as_ref().to_vec()
260    }
261
262    /// Like `pack_serializable` but also returns the `msg_id`.
263    pub fn pack_serializable_with_msg_id<S: ferogram_tl_types::Serializable>(
264        &mut self,
265        call: &S,
266    ) -> (Vec<u8>, i64) {
267        let body = call.to_bytes();
268        let msg_id = self.next_msg_id();
269        let seq_no = self.next_seq_no();
270        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
271        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
272        buf.extend(self.salt.to_le_bytes());
273        buf.extend(self.session_id.to_le_bytes());
274        buf.extend(msg_id.to_le_bytes());
275        buf.extend(seq_no.to_le_bytes());
276        buf.extend((body.len() as u32).to_le_bytes());
277        buf.extend(body.iter().copied());
278        encrypt_data_v2(&mut buf, &self.auth_key);
279        (buf.as_ref().to_vec(), msg_id)
280    }
281
282    /// Like [`pack`] but also returns the `msg_id` allocated for this message.
283    pub fn pack_with_msg_id<R: RemoteCall>(&mut self, call: &R) -> (Vec<u8>, i64) {
284        let body = call.to_bytes();
285        let msg_id = self.next_msg_id();
286        let seq_no = self.next_seq_no();
287        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
288        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
289        buf.extend(self.salt.to_le_bytes());
290        buf.extend(self.session_id.to_le_bytes());
291        buf.extend(msg_id.to_le_bytes());
292        buf.extend(seq_no.to_le_bytes());
293        buf.extend((body.len() as u32).to_le_bytes());
294        buf.extend(body.iter().copied());
295        encrypt_data_v2(&mut buf, &self.auth_key);
296        (buf.as_ref().to_vec(), msg_id)
297    }
298
299    /// Encrypt and frame a [`RemoteCall`] into a ready-to-send MTProto message.
300    pub fn pack<R: RemoteCall>(&mut self, call: &R) -> Vec<u8> {
301        let body = call.to_bytes();
302        let msg_id = self.next_msg_id();
303        let seq_no = self.next_seq_no();
304
305        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
306        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
307        buf.extend(self.salt.to_le_bytes());
308        buf.extend(self.session_id.to_le_bytes());
309        buf.extend(msg_id.to_le_bytes());
310        buf.extend(seq_no.to_le_bytes());
311        buf.extend((body.len() as u32).to_le_bytes());
312        buf.extend(body.iter().copied());
313
314        encrypt_data_v2(&mut buf, &self.auth_key);
315        buf.as_ref().to_vec()
316    }
317
318    /// Decrypt an encrypted server frame.
319    pub fn unpack(&self, frame: &mut [u8]) -> Result<DecryptedMessage, DecryptError> {
320        let plaintext = decrypt_data_v2(frame, &self.auth_key).map_err(DecryptError::Crypto)?;
321
322        if plaintext.len() < 32 {
323            return Err(DecryptError::FrameTooShort);
324        }
325
326        let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
327        let session_id = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
328        let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
329        let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
330        let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
331
332        if session_id != self.session_id {
333            return Err(DecryptError::SessionMismatch);
334        }
335
336        // #3 Check server time (upper 32 bits of msg_id) against ±300 s window.
337        // tDesktop sets badTime=true and continues: msgs_ack / pong / bad_msg_notification
338        // can self-heal the clock without a reconnect. We warn and continue rather than
339        // hard-reject, so a drifted-clock client does not loop-reconnect on every message.
340        let server_secs = (msg_id as u64 >> 32) as i64;
341        let now = SystemTime::now()
342            .duration_since(UNIX_EPOCH)
343            .unwrap()
344            .as_secs() as i64;
345        let corrected = now + self.time_offset as i64;
346        if (server_secs - corrected).abs() > MSG_ID_TIME_WINDOW_SECS {
347            log::warn!(
348                "[ferogram] msg_id time-window violation: server_secs={server_secs} \
349                 corrected_local={corrected} skew={}s: processing anyway, \
350                 clock will self-correct via bad_msg_notification/pong",
351                (server_secs - corrected).abs()
352            );
353        }
354
355        // #2 rolling 500-entry dedup.
356        {
357            let mut seen = self.seen_msg_ids.lock().unwrap();
358            if seen.contains(&msg_id) {
359                return Err(DecryptError::DuplicateMsgId);
360            }
361            seen.push_back(msg_id);
362            if seen.len() > SEEN_MSG_IDS_MAX {
363                seen.pop_front();
364            }
365        }
366
367        // body_len upper bound (tDesktop kMaxMessageLength = 16 MB).
368        if body_len > 16 * 1024 * 1024 {
369            return Err(DecryptError::FrameTooShort);
370        }
371        if 32 + body_len > plaintext.len() {
372            return Err(DecryptError::FrameTooShort);
373        }
374        // padding must be 12–1024 bytes (tDesktop kMinPaddingSize/kMaxPaddingSize).
375        let padding = plaintext.len() - 32 - body_len;
376        if !(12..=1024).contains(&padding) {
377            return Err(DecryptError::FrameTooShort);
378        }
379        let body = plaintext[32..32 + body_len].to_vec();
380
381        Ok(DecryptedMessage {
382            salt,
383            session_id,
384            msg_id,
385            seq_no,
386            body,
387        })
388    }
389
390    /// Return the auth_key bytes (for persistence).
391    pub fn auth_key_bytes(&self) -> [u8; 256] {
392        self.auth_key.to_bytes()
393    }
394
395    /// Return the current session_id.
396    pub fn session_id(&self) -> i64 {
397        self.session_id
398    }
399
400    /// Reset session state: new random session_id, zeroed seq_no and last_msg_id,
401    /// cleared dedup buffer.
402    ///
403    /// Called on `bad_msg_notification` error codes 32/33 (seq_no mismatch).
404    /// This matches tDesktop's `ResetSession` which creates a new session_id
405    /// rather than trying to correct the seq_no counter, which can loop forever
406    /// on a persistent desync.
407    pub fn reset_session(&mut self) {
408        let mut rnd = [0u8; 8];
409        getrandom::getrandom(&mut rnd).expect("getrandom");
410        let old_session = self.session_id;
411        self.session_id = i64::from_le_bytes(rnd);
412        self.sequence = 0;
413        self.last_msg_id = 0;
414        self.seen_msg_ids.lock().unwrap().clear();
415        log::debug!(
416            "[ferogram] session reset: {:#018x} → {:#018x}",
417            old_session,
418            self.session_id
419        );
420    }
421}
422
423impl EncryptedSession {
424    /// Decrypt a frame using explicit key + session_id: no mutable state needed.
425    /// Used by the split-reader task so it can decrypt without locking the writer.
426    /// `time_offset` is the session's current clock skew (seconds); pass 0 if unknown.
427    pub fn decrypt_frame(
428        auth_key: &[u8; 256],
429        session_id: i64,
430        frame: &mut [u8],
431    ) -> Result<DecryptedMessage, DecryptError> {
432        Self::decrypt_frame_with_offset(auth_key, session_id, frame, 0)
433    }
434
435    /// Like [`decrypt_frame`] but applies the time-window check with the given
436    /// `time_offset` (seconds, server_time − local_time).
437    pub fn decrypt_frame_with_offset(
438        auth_key: &[u8; 256],
439        session_id: i64,
440        frame: &mut [u8],
441        time_offset: i32,
442    ) -> Result<DecryptedMessage, DecryptError> {
443        let key = AuthKey::from_bytes(*auth_key);
444        let plaintext = decrypt_data_v2(frame, &key).map_err(DecryptError::Crypto)?;
445        if plaintext.len() < 32 {
446            return Err(DecryptError::FrameTooShort);
447        }
448        let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
449        let sid = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
450        let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
451        let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
452        let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
453        if sid != session_id {
454            return Err(DecryptError::SessionMismatch);
455        }
456        // Time-window check: warn but continue (matches tDesktop badTime=true behaviour).
457        // Clock self-corrects via bad_msg_notification or pong; hard-reject causes reconnect loops.
458        let server_secs = (msg_id as u64 >> 32) as i64;
459        let now = SystemTime::now()
460            .duration_since(UNIX_EPOCH)
461            .unwrap()
462            .as_secs() as i64;
463        let corrected = now + time_offset as i64;
464        if (server_secs - corrected).abs() > MSG_ID_TIME_WINDOW_SECS {
465            log::warn!(
466                "[ferogram] msg_id time-window violation (split-reader): server_secs={server_secs} \
467                 corrected_local={corrected} skew={}s: processing anyway",
468                (server_secs - corrected).abs()
469            );
470        }
471        // body_len upper bound (tDesktop kMaxMessageLength = 16 MB).
472        if body_len > 16 * 1024 * 1024 {
473            return Err(DecryptError::FrameTooShort);
474        }
475        if 32 + body_len > plaintext.len() {
476            return Err(DecryptError::FrameTooShort);
477        }
478        // padding must be 12–1024 bytes.
479        let padding = plaintext.len() - 32 - body_len;
480        if !(12..=1024).contains(&padding) {
481            return Err(DecryptError::FrameTooShort);
482        }
483        let body = plaintext[32..32 + body_len].to_vec();
484        Ok(DecryptedMessage {
485            salt,
486            session_id: sid,
487            msg_id,
488            seq_no,
489            body,
490        })
491    }
492}