Skip to main content

ferogram_mtproto/
encrypted.rs

1// Copyright (c) Ankit Chaubey <ankitchaubey.dev@gmail.com>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3//
4// ferogram: async Telegram MTProto client in Rust
5// https://github.com/ankit-chaubey/ferogram
6//
7// Based on layer: https://github.com/ankit-chaubey/layer
8// Follows official Telegram client behaviour (tdesktop, TDLib).
9//
10// If you use or modify this code, keep this notice at the top of your file
11// and include the LICENSE-MIT or LICENSE-APACHE file from this repository:
12// https://github.com/ankit-chaubey/ferogram
13
14//! Encrypted MTProto 2.0 session (post auth-key).
15//!
16//! Once you have a `Finished` from [`crate::authentication`], construct an
17//! [`EncryptedSession`] and use it to serialize/deserialize all subsequent
18//! messages.
19
20use std::collections::VecDeque;
21use std::time::{SystemTime, UNIX_EPOCH};
22
23use ferogram_crypto::{AuthKey, DequeBuffer, decrypt_data_v2, encrypt_data_v2};
24use ferogram_tl_types::RemoteCall;
25
/// Number of recent server msg_ids remembered for replay detection
/// (mirrors tDesktop's 500-entry seen-msg_id set).
const SEEN_MSG_IDS_MAX: usize = 500;
/// Largest tolerated distance, in seconds, between the timestamp encoded in a
/// server msg_id and the offset-corrected local clock (five minutes).
const MSG_ID_TIME_WINDOW_SECS: i64 = 5 * 60;
30
31/// Errors that can occur when decrypting a server message.
32#[derive(Debug)]
33pub enum DecryptError {
34    /// The underlying crypto layer rejected the message.
35    Crypto(ferogram_crypto::DecryptError),
36    /// The decrypted inner message was too short to contain a valid header.
37    FrameTooShort,
38    /// Session-ID mismatch (possible replay or wrong connection).
39    SessionMismatch,
40    /// Server msg_id is outside the ±300 s window of corrected local time.
41    MsgIdTimeWindow,
42    /// This msg_id was already seen in the rolling 500-entry buffer.
43    DuplicateMsgId,
44}
45
46impl std::fmt::Display for DecryptError {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        match self {
49            Self::Crypto(e) => write!(f, "crypto: {e}"),
50            Self::FrameTooShort => write!(f, "inner plaintext too short"),
51            Self::SessionMismatch => write!(f, "session_id mismatch"),
52            Self::MsgIdTimeWindow => write!(f, "server msg_id outside ±300 s time window"),
53            Self::DuplicateMsgId => write!(f, "duplicate server msg_id (replay)"),
54        }
55    }
56}
57impl std::error::Error for DecryptError {}
58
/// The inner payload extracted from a successfully decrypted server frame.
///
/// Derives the standard traits so callers can log (`Debug`), duplicate
/// (`Clone`) and compare (`PartialEq`/`Eq`) decoded messages.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecryptedMessage {
    /// Server salt carried in the frame header.
    pub salt: i64,
    /// The `session_id` from the frame header.
    pub session_id: i64,
    /// The `msg_id` of the inner message.
    pub msg_id: i64,
    /// The `seq_no` of the inner message.
    pub seq_no: i32,
    /// TL-serialized body of the inner message (exactly the declared body length).
    pub body: Vec<u8>,
}
72
73/// MTProto 2.0 encrypted session state.
74pub struct EncryptedSession {
75    auth_key: AuthKey,
76    session_id: i64,
77    sequence: i32,
78    last_msg_id: i64,
79    /// Current server salt to include in outgoing messages.
80    pub salt: i64,
81    /// Clock skew in seconds vs. server.
82    pub time_offset: i32,
83    /// Rolling 500-entry dedup buffer of seen server msg_ids.
84    seen_msg_ids: std::sync::Mutex<VecDeque<i64>>,
85}
86
87impl EncryptedSession {
88    /// Create a new encrypted session from the output of `authentication::finish`.
89    pub fn new(auth_key: [u8; 256], first_salt: i64, time_offset: i32) -> Self {
90        let mut rnd = [0u8; 8];
91        getrandom::getrandom(&mut rnd).expect("getrandom");
92        Self {
93            auth_key: AuthKey::from_bytes(auth_key),
94            session_id: i64::from_le_bytes(rnd),
95            sequence: 0,
96            last_msg_id: 0,
97            salt: first_salt,
98            time_offset,
99            seen_msg_ids: std::sync::Mutex::new(VecDeque::with_capacity(SEEN_MSG_IDS_MAX)),
100        }
101    }
102
103    /// Compute the next message ID (based on corrected server time).
104    fn next_msg_id(&mut self) -> i64 {
105        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
106        let secs = (now.as_secs() as i32).wrapping_add(self.time_offset) as u64;
107        let nanos = now.subsec_nanos() as u64;
108        let mut id = ((secs << 32) | (nanos << 2)) as i64;
109        if self.last_msg_id >= id {
110            id = self.last_msg_id + 4;
111        }
112        self.last_msg_id = id;
113        id
114    }
115
116    /// Next content-related seq_no (odd) and advance the counter.
117    /// Used for all regular RPC requests.
118    fn next_seq_no(&mut self) -> i32 {
119        let n = self.sequence * 2 + 1;
120        self.sequence += 1;
121        n
122    }
123
124    /// Return the current even seq_no WITHOUT advancing the counter.
125    ///
126    /// Service messages (MsgsAck, containers, etc.) MUST use an even seqno
127    /// per the MTProto spec so the server does not expect a reply.
128    pub fn next_seq_no_ncr(&self) -> i32 {
129        self.sequence * 2
130    }
131
132    /// Correct the outgoing sequence counter when the server reports a
133    /// `bad_msg_notification` with error codes 32 (seq_no too low) or
134    /// 33 (seq_no too high).
135    ///
136    pub fn correct_seq_no(&mut self, code: u32) {
137        match code {
138            32 => {
139                // seq_no too low: jump forward so next send is well above server expectation
140                self.sequence += 64;
141                log::debug!(
142                    "[ferogram] seq_no correction: code 32, bumped seq to {}",
143                    self.sequence
144                );
145            }
146            33 => {
147                // seq_no too high: step back, but never below 1 to avoid
148                // re-using seq_no=1 which was already sent this session.
149                // Zeroing would make the next content message get seq_no=1,
150                // which the server already saw and will reject again with code 32.
151                self.sequence = self.sequence.saturating_sub(16).max(1);
152                log::debug!(
153                    "[ferogram] seq_no correction: code 33, lowered seq to {}",
154                    self.sequence
155                );
156            }
157            _ => {}
158        }
159    }
160
161    /// Re-derive the clock skew from a server-provided `msg_id`.
162    ///
163    /// Called on `bad_msg_notification` error codes 16 (msg_id too low) and
164    /// 17 (msg_id too high) so clock drift is corrected at any point in the
165    /// session, not only at connect time.
166    ///
167    pub fn correct_time_offset(&mut self, server_msg_id: i64) {
168        // Upper 32 bits of msg_id = Unix seconds on the server
169        let server_time = (server_msg_id >> 32) as i32;
170        let local_now = SystemTime::now()
171            .duration_since(UNIX_EPOCH)
172            .unwrap()
173            .as_secs() as i32;
174        let new_offset = server_time.wrapping_sub(local_now);
175        log::debug!(
176            "[ferogram] time_offset correction: {} → {} (server_time={server_time})",
177            self.time_offset,
178            new_offset
179        );
180        self.time_offset = new_offset;
181        // Also reset last_msg_id so next_msg_id rebuilds from corrected clock
182        self.last_msg_id = 0;
183    }
184
185    /// Allocate a fresh `(msg_id, seqno)` pair for an inner container message
186    /// WITHOUT encrypting anything.
187    ///
188    /// `content_related = true`  → odd seqno, advances counter  (regular RPCs)
189    /// `content_related = false` → even seqno, no advance       (MsgsAck, container)
190    ///
191    pub fn alloc_msg_seqno(&mut self, content_related: bool) -> (i64, i32) {
192        let msg_id = self.next_msg_id();
193        let seqno = if content_related {
194            self.next_seq_no()
195        } else {
196            self.next_seq_no_ncr()
197        };
198        (msg_id, seqno)
199    }
200
201    /// Encrypt a pre-serialized TL body into a wire-ready MTProto frame.
202    ///
203    /// `content_related` controls whether the seqno is odd (content, advances
204    /// the counter) or even (service, no advance).
205    ///
206    /// Returns `(encrypted_wire_bytes, msg_id)`.
207    /// Used for (bad_msg re-send) and (container inner messages).
208    pub fn pack_body_with_msg_id(&mut self, body: &[u8], content_related: bool) -> (Vec<u8>, i64) {
209        let msg_id = self.next_msg_id();
210        let seq_no = if content_related {
211            self.next_seq_no()
212        } else {
213            self.next_seq_no_ncr()
214        };
215
216        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
217        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
218        buf.extend(self.salt.to_le_bytes());
219        buf.extend(self.session_id.to_le_bytes());
220        buf.extend(msg_id.to_le_bytes());
221        buf.extend(seq_no.to_le_bytes());
222        buf.extend((body.len() as u32).to_le_bytes());
223        buf.extend(body.iter().copied());
224
225        encrypt_data_v2(&mut buf, &self.auth_key);
226        (buf.as_ref().to_vec(), msg_id)
227    }
228
229    /// Encrypt a pre-built `msg_container` body (the container itself is
230    /// a non-content-related message with an even seqno).
231    ///
232    /// Returns `(encrypted_wire_bytes, container_msg_id)`.
233    /// The container_msg_id is needed so callers can map it back to inner
234    /// requests when a bad_msg_notification or bad_server_salt arrives for
235    /// the container rather than the individual inner message.
236    ///
237    pub fn pack_container(&mut self, container_body: &[u8]) -> (Vec<u8>, i64) {
238        self.pack_body_with_msg_id(container_body, false)
239    }
240
241    // Original pack methods (unchanged)
242
243    /// Serialize and encrypt a TL function into a wire-ready byte vector.
244    pub fn pack_serializable<S: ferogram_tl_types::Serializable>(&mut self, call: &S) -> Vec<u8> {
245        let body = call.to_bytes();
246        let msg_id = self.next_msg_id();
247        let seq_no = self.next_seq_no();
248
249        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
250        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
251        buf.extend(self.salt.to_le_bytes());
252        buf.extend(self.session_id.to_le_bytes());
253        buf.extend(msg_id.to_le_bytes());
254        buf.extend(seq_no.to_le_bytes());
255        buf.extend((body.len() as u32).to_le_bytes());
256        buf.extend(body.iter().copied());
257
258        encrypt_data_v2(&mut buf, &self.auth_key);
259        buf.as_ref().to_vec()
260    }
261
262    /// Like `pack_serializable` but also returns the `msg_id`.
263    pub fn pack_serializable_with_msg_id<S: ferogram_tl_types::Serializable>(
264        &mut self,
265        call: &S,
266    ) -> (Vec<u8>, i64) {
267        let body = call.to_bytes();
268        let msg_id = self.next_msg_id();
269        let seq_no = self.next_seq_no();
270        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
271        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
272        buf.extend(self.salt.to_le_bytes());
273        buf.extend(self.session_id.to_le_bytes());
274        buf.extend(msg_id.to_le_bytes());
275        buf.extend(seq_no.to_le_bytes());
276        buf.extend((body.len() as u32).to_le_bytes());
277        buf.extend(body.iter().copied());
278        encrypt_data_v2(&mut buf, &self.auth_key);
279        (buf.as_ref().to_vec(), msg_id)
280    }
281
282    /// Like [`pack`] but also returns the `msg_id` allocated for this message.
283    pub fn pack_with_msg_id<R: RemoteCall>(&mut self, call: &R) -> (Vec<u8>, i64) {
284        let body = call.to_bytes();
285        let msg_id = self.next_msg_id();
286        let seq_no = self.next_seq_no();
287        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
288        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
289        buf.extend(self.salt.to_le_bytes());
290        buf.extend(self.session_id.to_le_bytes());
291        buf.extend(msg_id.to_le_bytes());
292        buf.extend(seq_no.to_le_bytes());
293        buf.extend((body.len() as u32).to_le_bytes());
294        buf.extend(body.iter().copied());
295        encrypt_data_v2(&mut buf, &self.auth_key);
296        (buf.as_ref().to_vec(), msg_id)
297    }
298
299    /// Encrypt and frame a [`RemoteCall`] into a ready-to-send MTProto message.
300    pub fn pack<R: RemoteCall>(&mut self, call: &R) -> Vec<u8> {
301        let body = call.to_bytes();
302        let msg_id = self.next_msg_id();
303        let seq_no = self.next_seq_no();
304
305        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
306        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
307        buf.extend(self.salt.to_le_bytes());
308        buf.extend(self.session_id.to_le_bytes());
309        buf.extend(msg_id.to_le_bytes());
310        buf.extend(seq_no.to_le_bytes());
311        buf.extend((body.len() as u32).to_le_bytes());
312        buf.extend(body.iter().copied());
313
314        encrypt_data_v2(&mut buf, &self.auth_key);
315        buf.as_ref().to_vec()
316    }
317
318    /// Decrypt an encrypted server frame.
319    pub fn unpack(&self, frame: &mut [u8]) -> Result<DecryptedMessage, DecryptError> {
320        let plaintext = decrypt_data_v2(frame, &self.auth_key).map_err(DecryptError::Crypto)?;
321
322        if plaintext.len() < 32 {
323            return Err(DecryptError::FrameTooShort);
324        }
325
326        let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
327        let session_id = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
328        let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
329        let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
330        let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
331
332        if session_id != self.session_id {
333            return Err(DecryptError::SessionMismatch);
334        }
335
336        // #3 reject if server time (upper 32 bits of msg_id) deviates > 300 s.
337        let server_secs = (msg_id as u64 >> 32) as i64;
338        let now = SystemTime::now()
339            .duration_since(UNIX_EPOCH)
340            .unwrap()
341            .as_secs() as i64;
342        let corrected = now + self.time_offset as i64;
343        if (server_secs - corrected).abs() > MSG_ID_TIME_WINDOW_SECS {
344            return Err(DecryptError::MsgIdTimeWindow);
345        }
346
347        // #2 rolling 500-entry dedup.
348        {
349            let mut seen = self.seen_msg_ids.lock().unwrap();
350            if seen.contains(&msg_id) {
351                return Err(DecryptError::DuplicateMsgId);
352            }
353            seen.push_back(msg_id);
354            if seen.len() > SEEN_MSG_IDS_MAX {
355                seen.pop_front();
356            }
357        }
358
359        if 32 + body_len > plaintext.len() {
360            return Err(DecryptError::FrameTooShort);
361        }
362        let body = plaintext[32..32 + body_len].to_vec();
363
364        Ok(DecryptedMessage {
365            salt,
366            session_id,
367            msg_id,
368            seq_no,
369            body,
370        })
371    }
372
373    /// Return the auth_key bytes (for persistence).
374    pub fn auth_key_bytes(&self) -> [u8; 256] {
375        self.auth_key.to_bytes()
376    }
377
378    /// Return the current session_id.
379    pub fn session_id(&self) -> i64 {
380        self.session_id
381    }
382}
383
384impl EncryptedSession {
385    /// Decrypt a frame using explicit key + session_id: no mutable state needed.
386    /// Used by the split-reader task so it can decrypt without locking the writer.
387    /// `time_offset` is the session's current clock skew (seconds); pass 0 if unknown.
388    pub fn decrypt_frame(
389        auth_key: &[u8; 256],
390        session_id: i64,
391        frame: &mut [u8],
392    ) -> Result<DecryptedMessage, DecryptError> {
393        Self::decrypt_frame_with_offset(auth_key, session_id, frame, 0)
394    }
395
396    /// Like [`decrypt_frame`] but applies the time-window check with the given
397    /// `time_offset` (seconds, server_time − local_time).
398    pub fn decrypt_frame_with_offset(
399        auth_key: &[u8; 256],
400        session_id: i64,
401        frame: &mut [u8],
402        time_offset: i32,
403    ) -> Result<DecryptedMessage, DecryptError> {
404        let key = AuthKey::from_bytes(*auth_key);
405        let plaintext = decrypt_data_v2(frame, &key).map_err(DecryptError::Crypto)?;
406        if plaintext.len() < 32 {
407            return Err(DecryptError::FrameTooShort);
408        }
409        let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
410        let sid = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
411        let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
412        let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
413        let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
414        if sid != session_id {
415            return Err(DecryptError::SessionMismatch);
416        }
417        // Time-window check.
418        let server_secs = (msg_id as u64 >> 32) as i64;
419        let now = SystemTime::now()
420            .duration_since(UNIX_EPOCH)
421            .unwrap()
422            .as_secs() as i64;
423        let corrected = now + time_offset as i64;
424        if (server_secs - corrected).abs() > MSG_ID_TIME_WINDOW_SECS {
425            return Err(DecryptError::MsgIdTimeWindow);
426        }
427        if 32 + body_len > plaintext.len() {
428            return Err(DecryptError::FrameTooShort);
429        }
430        let body = plaintext[32..32 + body_len].to_vec();
431        Ok(DecryptedMessage {
432            salt,
433            session_id: sid,
434            msg_id,
435            seq_no,
436            body,
437        })
438    }
439}