Skip to main content

layer_mtproto/
encrypted.rs

1// Copyright (c) Ankit Chaubey <ankitchaubey.dev@gmail.com>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4// NOTE:
5// The "Layer" project is no longer maintained or supported.
6// Its original purpose for personal SDK/APK experimentation and learning
7// has been fulfilled.
8//
9// Please use Ferogram instead:
10// https://github.com/ankit-chaubey/ferogram
11// Ferogram will receive future updates and development, although progress
12// may be slower.
13//
14// Ferogram is an async Telegram MTProto client library written in Rust.
15// Its implementation follows the behaviour of the official Telegram clients,
16// particularly Telegram Desktop and TDLib, and aims to provide a clean and
17// modern async interface for building Telegram clients and tools.
18
19//! Encrypted MTProto 2.0 session (post auth-key).
20//!
21//! Once you have a `Finished` from [`crate::authentication`], construct an
22//! [`EncryptedSession`] and use it to serialize/deserialize all subsequent
23//! messages.
24
25use std::collections::VecDeque;
26use std::time::{SystemTime, UNIX_EPOCH};
27
28use layer_crypto::{AuthKey, DequeBuffer, decrypt_data_v2, encrypt_data_v2};
29use layer_tl_types::RemoteCall;
30
/// Capacity of the rolling buffer of already-seen server `msg_id`s used for
/// replay detection (the same 500-entry size tDesktop uses).
const SEEN_MSG_IDS_MAX: usize = 500;
/// Largest tolerated distance, in seconds, between the timestamp embedded in a
/// server `msg_id` and the offset-corrected local clock. Frames outside this
/// window are flagged via `DecryptedMessage::bad_time`.
const MSG_ID_TIME_WINDOW_SECS: i64 = 5 * 60;
35
36/// Errors that can occur when decrypting a server message.
37#[derive(Debug)]
38pub enum DecryptError {
39    /// The underlying crypto layer rejected the message.
40    Crypto(layer_crypto::DecryptError),
41    /// The decrypted inner message was too short to contain a valid header.
42    FrameTooShort,
43    /// Session-ID mismatch (possible replay or wrong connection).
44    SessionMismatch,
45    /// Server msg_id is outside the ±300 s window of corrected local time.
46    MsgIdTimeWindow,
47    /// This msg_id was already seen in the rolling 500-entry buffer.
48    DuplicateMsgId,
49}
50
51impl std::fmt::Display for DecryptError {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        match self {
54            Self::Crypto(e) => write!(f, "crypto: {e}"),
55            Self::FrameTooShort => write!(f, "inner plaintext too short"),
56            Self::SessionMismatch => write!(f, "session_id mismatch"),
57            Self::MsgIdTimeWindow => write!(f, "server msg_id outside ±300 s time window"),
58            Self::DuplicateMsgId => write!(f, "duplicate server msg_id (replay)"),
59        }
60    }
61}
62impl std::error::Error for DecryptError {}
63
/// The inner payload extracted from a successfully decrypted server frame.
///
/// All fields are plain data copied out of the decrypted plaintext header
/// (plus the `bad_time` flag computed locally), so the type derives `Debug`,
/// `Clone`, `PartialEq` and `Eq` for logging, queuing and testing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecryptedMessage {
    /// `salt` sent by the server (header bytes 0..8, little-endian).
    pub salt: i64,
    /// The `session_id` from the frame (header bytes 8..16).
    pub session_id: i64,
    /// The `msg_id` of the inner message (header bytes 16..24). Its upper
    /// 32 bits are treated as the server's Unix time in seconds.
    pub msg_id: i64,
    /// `seq_no` of the inner message (header bytes 24..28).
    pub seq_no: i32,
    /// TL-serialized body of the inner message, exactly as long as the
    /// length field in the header declares.
    pub body: Vec<u8>,
    /// Set when the server's `msg_id` timestamp fell outside the ±300 s window
    /// - the tDesktop equivalent of `badTime = true`.
    ///
    /// Instead of rejecting the frame (which causes a reconnect loop on a
    /// clock-drifted machine), the frame is returned with this flag set. The
    /// reader corrects `time_offset` from `msg_id` and routes the frame
    /// normally - matching tDesktop's per-handler `requestsFixTimeSalt` /
    /// `correctUnixtimeWithBadLocal` self-healing path.
    pub bad_time: bool,
}
86
87/// MTProto 2.0 encrypted session state.
88pub struct EncryptedSession {
89    auth_key: AuthKey,
90    session_id: i64,
91    sequence: i32,
92    last_msg_id: i64,
93    /// Current server salt to include in outgoing messages.
94    pub salt: i64,
95    /// Clock skew in seconds vs. server.
96    pub time_offset: i32,
97    /// Rolling 500-entry dedup buffer of seen server msg_ids.
98    seen_msg_ids: std::sync::Mutex<VecDeque<i64>>,
99}
100
101impl EncryptedSession {
102    /// Create a new encrypted session from the output of `authentication::finish`.
103    pub fn new(auth_key: [u8; 256], first_salt: i64, time_offset: i32) -> Self {
104        let mut rnd = [0u8; 8];
105        getrandom::getrandom(&mut rnd).expect("getrandom");
106        Self {
107            auth_key: AuthKey::from_bytes(auth_key),
108            session_id: i64::from_le_bytes(rnd),
109            sequence: 0,
110            last_msg_id: 0,
111            salt: first_salt,
112            time_offset,
113            seen_msg_ids: std::sync::Mutex::new(VecDeque::with_capacity(SEEN_MSG_IDS_MAX)),
114        }
115    }
116
117    /// Compute the next message ID (based on corrected server time).
118    fn next_msg_id(&mut self) -> i64 {
119        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
120        let secs = (now.as_secs() as i32).wrapping_add(self.time_offset) as u64;
121        let nanos = now.subsec_nanos() as u64;
122        let mut id = ((secs << 32) | (nanos << 2)) as i64;
123        if self.last_msg_id >= id {
124            id = self.last_msg_id + 4;
125        }
126        self.last_msg_id = id;
127        id
128    }
129
130    /// Next content-related seq_no (odd) and advance the counter.
131    /// Used for all regular RPC requests.
132    fn next_seq_no(&mut self) -> i32 {
133        let n = self.sequence * 2 + 1;
134        self.sequence += 1;
135        n
136    }
137
138    /// Return the current even seq_no WITHOUT advancing the counter.
139    ///
140    /// Service messages (MsgsAck, containers, etc.) MUST use an even seqno
141    /// per the MTProto spec so the server does not expect a reply.
142    pub fn next_seq_no_ncr(&self) -> i32 {
143        self.sequence * 2
144    }
145
146    /// Correct the outgoing sequence counter when the server reports a
147    /// `bad_msg_notification` with error codes 32 (seq_no too low) or
148    /// 33 (seq_no too high).
149    ///
150    pub fn correct_seq_no(&mut self, code: u32) {
151        match code {
152            32 => {
153                // seq_no too low: jump forward so next send is well above server expectation
154                self.sequence += 64;
155                log::debug!(
156                    "[layer] seq_no correction: code 32, bumped seq to {}",
157                    self.sequence
158                );
159            }
160            33 => {
161                // seq_no too high: step back, but never below 1 to avoid
162                // re-using seq_no=1 which was already sent this session.
163                // Zeroing would make the next content message get seq_no=1,
164                // which the server already saw and will reject again with code 32.
165                self.sequence = self.sequence.saturating_sub(16).max(1);
166                log::debug!(
167                    "[layer] seq_no correction: code 33, lowered seq to {}",
168                    self.sequence
169                );
170            }
171            _ => {}
172        }
173    }
174
175    /// Re-derive the clock skew from a server-provided `msg_id`.
176    ///
177    /// Called on `bad_msg_notification` error codes 16 (msg_id too low) and
178    /// 17 (msg_id too high) so clock drift is corrected at any point in the
179    /// session, not only at connect time.
180    ///
181    pub fn correct_time_offset(&mut self, server_msg_id: i64) {
182        // Upper 32 bits of msg_id = Unix seconds on the server
183        let server_time = (server_msg_id >> 32) as i32;
184        let local_now = SystemTime::now()
185            .duration_since(UNIX_EPOCH)
186            .unwrap()
187            .as_secs() as i32;
188        let new_offset = server_time.wrapping_sub(local_now);
189        log::debug!(
190            "[layer] time_offset correction: {} → {} (server_time={server_time})",
191            self.time_offset,
192            new_offset
193        );
194        self.time_offset = new_offset;
195        // Also reset last_msg_id so next_msg_id rebuilds from corrected clock
196        self.last_msg_id = 0;
197    }
198
199    /// Allocate a fresh `(msg_id, seqno)` pair for an inner container message
200    /// WITHOUT encrypting anything.
201    ///
202    /// `content_related = true`  → odd seqno, advances counter  (regular RPCs)
203    /// `content_related = false` → even seqno, no advance       (MsgsAck, container)
204    ///
205    pub fn alloc_msg_seqno(&mut self, content_related: bool) -> (i64, i32) {
206        let msg_id = self.next_msg_id();
207        let seqno = if content_related {
208            self.next_seq_no()
209        } else {
210            self.next_seq_no_ncr()
211        };
212        (msg_id, seqno)
213    }
214
215    /// Encrypt a pre-serialized TL body into a wire-ready MTProto frame.
216    ///
217    /// `content_related` controls whether the seqno is odd (content, advances
218    /// the counter) or even (service, no advance).
219    ///
220    /// Returns `(encrypted_wire_bytes, msg_id)`.
221    /// Used for (bad_msg re-send) and (container inner messages).
222    pub fn pack_body_with_msg_id(&mut self, body: &[u8], content_related: bool) -> (Vec<u8>, i64) {
223        let msg_id = self.next_msg_id();
224        let seq_no = if content_related {
225            self.next_seq_no()
226        } else {
227            self.next_seq_no_ncr()
228        };
229
230        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
231        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
232        buf.extend(self.salt.to_le_bytes());
233        buf.extend(self.session_id.to_le_bytes());
234        buf.extend(msg_id.to_le_bytes());
235        buf.extend(seq_no.to_le_bytes());
236        buf.extend((body.len() as u32).to_le_bytes());
237        buf.extend(body.iter().copied());
238
239        encrypt_data_v2(&mut buf, &self.auth_key);
240        (buf.as_ref().to_vec(), msg_id)
241    }
242
243    /// Encrypt a pre-built `msg_container` body (the container itself is
244    /// a non-content-related message with an even seqno).
245    ///
246    /// Returns `(encrypted_wire_bytes, container_msg_id)`.
247    /// The container_msg_id is needed so callers can map it back to inner
248    /// requests when a bad_msg_notification or bad_server_salt arrives for
249    /// the container rather than the individual inner message.
250    ///
251    pub fn pack_container(&mut self, container_body: &[u8]) -> (Vec<u8>, i64) {
252        self.pack_body_with_msg_id(container_body, false)
253    }
254
255    // Original pack methods (unchanged)
256
257    /// Serialize and encrypt a TL function into a wire-ready byte vector.
258    pub fn pack_serializable<S: layer_tl_types::Serializable>(&mut self, call: &S) -> Vec<u8> {
259        let body = call.to_bytes();
260        let msg_id = self.next_msg_id();
261        let seq_no = self.next_seq_no();
262
263        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
264        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
265        buf.extend(self.salt.to_le_bytes());
266        buf.extend(self.session_id.to_le_bytes());
267        buf.extend(msg_id.to_le_bytes());
268        buf.extend(seq_no.to_le_bytes());
269        buf.extend((body.len() as u32).to_le_bytes());
270        buf.extend(body.iter().copied());
271
272        encrypt_data_v2(&mut buf, &self.auth_key);
273        buf.as_ref().to_vec()
274    }
275
276    /// Like `pack_serializable` but also returns the `msg_id`.
277    pub fn pack_serializable_with_msg_id<S: layer_tl_types::Serializable>(
278        &mut self,
279        call: &S,
280    ) -> (Vec<u8>, i64) {
281        let body = call.to_bytes();
282        let msg_id = self.next_msg_id();
283        let seq_no = self.next_seq_no();
284        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
285        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
286        buf.extend(self.salt.to_le_bytes());
287        buf.extend(self.session_id.to_le_bytes());
288        buf.extend(msg_id.to_le_bytes());
289        buf.extend(seq_no.to_le_bytes());
290        buf.extend((body.len() as u32).to_le_bytes());
291        buf.extend(body.iter().copied());
292        encrypt_data_v2(&mut buf, &self.auth_key);
293        (buf.as_ref().to_vec(), msg_id)
294    }
295
296    /// Like [`pack`] but also returns the `msg_id` allocated for this message.
297    pub fn pack_with_msg_id<R: RemoteCall>(&mut self, call: &R) -> (Vec<u8>, i64) {
298        let body = call.to_bytes();
299        let msg_id = self.next_msg_id();
300        let seq_no = self.next_seq_no();
301        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
302        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
303        buf.extend(self.salt.to_le_bytes());
304        buf.extend(self.session_id.to_le_bytes());
305        buf.extend(msg_id.to_le_bytes());
306        buf.extend(seq_no.to_le_bytes());
307        buf.extend((body.len() as u32).to_le_bytes());
308        buf.extend(body.iter().copied());
309        encrypt_data_v2(&mut buf, &self.auth_key);
310        (buf.as_ref().to_vec(), msg_id)
311    }
312
313    /// Encrypt and frame a [`RemoteCall`] into a ready-to-send MTProto message.
314    pub fn pack<R: RemoteCall>(&mut self, call: &R) -> Vec<u8> {
315        let body = call.to_bytes();
316        let msg_id = self.next_msg_id();
317        let seq_no = self.next_seq_no();
318
319        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
320        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
321        buf.extend(self.salt.to_le_bytes());
322        buf.extend(self.session_id.to_le_bytes());
323        buf.extend(msg_id.to_le_bytes());
324        buf.extend(seq_no.to_le_bytes());
325        buf.extend((body.len() as u32).to_le_bytes());
326        buf.extend(body.iter().copied());
327
328        encrypt_data_v2(&mut buf, &self.auth_key);
329        buf.as_ref().to_vec()
330    }
331
332    /// Decrypt an encrypted server frame.
333    pub fn unpack(&self, frame: &mut [u8]) -> Result<DecryptedMessage, DecryptError> {
334        let plaintext = decrypt_data_v2(frame, &self.auth_key).map_err(DecryptError::Crypto)?;
335
336        if plaintext.len() < 32 {
337            return Err(DecryptError::FrameTooShort);
338        }
339
340        let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
341        let session_id = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
342        let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
343        let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
344        let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
345
346        if session_id != self.session_id {
347            return Err(DecryptError::SessionMismatch);
348        }
349
350        // #3 check if server time (upper 32 bits of msg_id) deviates > 300 s.
351        // tDesktop: sets badTime=true and continues; self-healing happens
352        // per-handler via requestsFixTimeSalt / correctUnixtimeWithBadLocal.
353        // Layer: same - set bad_time flag, caller corrects time_offset.
354        let server_secs = (msg_id as u64 >> 32) as i64;
355        let now = SystemTime::now()
356            .duration_since(UNIX_EPOCH)
357            .unwrap()
358            .as_secs() as i64;
359        let corrected = now + self.time_offset as i64;
360        let bad_time = (server_secs - corrected).abs() > MSG_ID_TIME_WINDOW_SECS;
361
362        // #2 rolling 500-entry dedup.
363        {
364            let mut seen = self.seen_msg_ids.lock().unwrap();
365            if seen.contains(&msg_id) {
366                return Err(DecryptError::DuplicateMsgId);
367            }
368            seen.push_back(msg_id);
369            if seen.len() > SEEN_MSG_IDS_MAX {
370                seen.pop_front();
371            }
372        }
373
374        if 32 + body_len > plaintext.len() {
375            return Err(DecryptError::FrameTooShort);
376        }
377        let body = plaintext[32..32 + body_len].to_vec();
378
379        Ok(DecryptedMessage {
380            salt,
381            session_id,
382            msg_id,
383            seq_no,
384            body,
385            bad_time,
386        })
387    }
388
389    /// Return the auth_key bytes (for persistence).
390    pub fn auth_key_bytes(&self) -> [u8; 256] {
391        self.auth_key.to_bytes()
392    }
393
394    /// Return the current session_id.
395    pub fn session_id(&self) -> i64 {
396        self.session_id
397    }
398}
399
400impl EncryptedSession {
401    /// Decrypt a frame using explicit key + session_id: no mutable state needed.
402    /// Used by the split-reader task so it can decrypt without locking the writer.
403    /// `time_offset` is the session's current clock skew (seconds); pass 0 if unknown.
404    pub fn decrypt_frame(
405        auth_key: &[u8; 256],
406        session_id: i64,
407        frame: &mut [u8],
408    ) -> Result<DecryptedMessage, DecryptError> {
409        Self::decrypt_frame_with_offset(auth_key, session_id, frame, 0)
410    }
411
412    /// Like [`decrypt_frame`] but applies the time-window check with the given
413    /// `time_offset` (seconds, server_time − local_time).
414    pub fn decrypt_frame_with_offset(
415        auth_key: &[u8; 256],
416        session_id: i64,
417        frame: &mut [u8],
418        time_offset: i32,
419    ) -> Result<DecryptedMessage, DecryptError> {
420        let key = AuthKey::from_bytes(*auth_key);
421        let plaintext = decrypt_data_v2(frame, &key).map_err(DecryptError::Crypto)?;
422        if plaintext.len() < 32 {
423            return Err(DecryptError::FrameTooShort);
424        }
425        let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
426        let sid = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
427        let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
428        let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
429        let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
430        if sid != session_id {
431            return Err(DecryptError::SessionMismatch);
432        }
433        // Time-window check - flag only, not an error (see unpack() comment).
434        let server_secs = (msg_id as u64 >> 32) as i64;
435        let now = SystemTime::now()
436            .duration_since(UNIX_EPOCH)
437            .unwrap()
438            .as_secs() as i64;
439        let corrected = now + time_offset as i64;
440        let bad_time = (server_secs - corrected).abs() > MSG_ID_TIME_WINDOW_SECS;
441        if 32 + body_len > plaintext.len() {
442            return Err(DecryptError::FrameTooShort);
443        }
444        let body = plaintext[32..32 + body_len].to_vec();
445        Ok(DecryptedMessage {
446            salt,
447            session_id: sid,
448            msg_id,
449            seq_no,
450            body,
451            bad_time,
452        })
453    }
454}