Skip to main content

layer_mtproto/
encrypted.rs

1//! Encrypted MTProto 2.0 session (post auth-key).
2//!
3//! Once you have a `Finished` from [`crate::authentication`], construct an
4//! [`EncryptedSession`] and use it to serialize/deserialize all subsequent
5//! messages.
6
7use std::time::{SystemTime, UNIX_EPOCH};
8
9use layer_crypto::{AuthKey, DequeBuffer, decrypt_data_v2, encrypt_data_v2};
10use layer_tl_types::RemoteCall;
11
12/// Errors that can occur when decrypting a server message.
13#[derive(Debug)]
14pub enum DecryptError {
15    /// The underlying crypto layer rejected the message.
16    Crypto(layer_crypto::DecryptError),
17    /// The decrypted inner message was too short to contain a valid header.
18    FrameTooShort,
19    /// Session-ID mismatch (possible replay or wrong connection).
20    SessionMismatch,
21}
22
23impl std::fmt::Display for DecryptError {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        match self {
26            Self::Crypto(e) => write!(f, "crypto: {e}"),
27            Self::FrameTooShort => write!(f, "inner plaintext too short"),
28            Self::SessionMismatch => write!(f, "session_id mismatch"),
29        }
30    }
31}
32impl std::error::Error for DecryptError {}
33
/// The inner payload extracted from a successfully decrypted server frame.
///
/// Note: the derived `Debug` prints `body` (decrypted plaintext) verbatim;
/// be careful when logging values of this type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecryptedMessage {
    /// `salt` sent by the server.
    pub salt: i64,
    /// The `session_id` from the frame.
    pub session_id: i64,
    /// The `msg_id` of the inner message.
    pub msg_id: i64,
    /// `seq_no` of the inner message.
    pub seq_no: i32,
    /// TL-serialized body of the inner message.
    pub body: Vec<u8>,
}
47
48/// MTProto 2.0 encrypted session state.
49///
50/// Wraps an `AuthKey` and tracks per-session counters (session_id, seq_no,
51/// last_msg_id, server salt).  Use [`EncryptedSession::pack`] to encrypt
52/// outgoing requests and [`EncryptedSession::unpack`] to decrypt incoming
53/// server frames.
54pub struct EncryptedSession {
55    auth_key: AuthKey,
56    session_id: i64,
57    sequence: i32,
58    last_msg_id: i64,
59    /// Current server salt to include in outgoing messages.
60    pub salt: i64,
61    /// Clock skew in seconds vs. server.
62    pub time_offset: i32,
63}
64
65impl EncryptedSession {
66    /// Create a new encrypted session from the output of `authentication::finish`.
67    pub fn new(auth_key: [u8; 256], first_salt: i64, time_offset: i32) -> Self {
68        let mut rnd = [0u8; 8];
69        getrandom::getrandom(&mut rnd).expect("getrandom");
70        Self {
71            auth_key: AuthKey::from_bytes(auth_key),
72            session_id: i64::from_le_bytes(rnd),
73            sequence: 0,
74            last_msg_id: 0,
75            salt: first_salt,
76            time_offset,
77        }
78    }
79
80    /// Compute the next message ID (based on corrected server time).
81    fn next_msg_id(&mut self) -> i64 {
82        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
83        let secs = (now.as_secs() as i32).wrapping_add(self.time_offset) as u64;
84        let nanos = now.subsec_nanos() as u64;
85        let mut id = ((secs << 32) | (nanos << 2)) as i64;
86        if self.last_msg_id >= id {
87            id = self.last_msg_id + 4;
88        }
89        self.last_msg_id = id;
90        id
91    }
92
93    /// Next content-related seq_no (odd) and advance the counter.
94    /// Used for all regular RPC requests.
95    fn next_seq_no(&mut self) -> i32 {
96        let n = self.sequence * 2 + 1;
97        self.sequence += 1;
98        n
99    }
100
101    // ── G-05: non-content-related seq_no ─────────────────────────────────────
102    /// Return the current even seq_no WITHOUT advancing the counter.
103    ///
104    /// Service messages (MsgsAck, containers, etc.) MUST use an even seqno
105    /// per the MTProto spec so the server does not expect a reply.
106    /// grammers reference: `mtp/encrypted.rs: get_seq_no(content_related: bool)`
107    pub fn next_seq_no_ncr(&self) -> i32 {
108        self.sequence * 2
109    }
110
111    // ── G-03: seq_no correction on bad_msg codes 32/33 ───────────────────────
112    /// Correct the outgoing sequence counter when the server reports a
113    /// `bad_msg_notification` with error codes 32 (seq_no too low) or
114    /// 33 (seq_no too high).
115    ///
116    /// grammers reference: `mtp/encrypted.rs: handle_bad_notification() codes 32/33`
117    pub fn correct_seq_no(&mut self, code: u32) {
118        match code {
119            32 => {
120                // seq_no too low — jump forward so next send is well above server expectation
121                self.sequence += 64;
122                log::debug!(
123                    "[layer] G-03 seq_no correction: code 32, bumped seq to {}",
124                    self.sequence
125                );
126            }
127            33 => {
128                // seq_no too high — step back, but never below 1 to avoid
129                // re-using seq_no=1 which was already sent this session.
130                // Zeroing would make the next content message get seq_no=1,
131                // which the server already saw and will reject again with code 32.
132                self.sequence = self.sequence.saturating_sub(16).max(1);
133                log::debug!(
134                    "[layer] G-03 seq_no correction: code 33, lowered seq to {}",
135                    self.sequence
136                );
137            }
138            _ => {}
139        }
140    }
141
142    // ── G-12: dynamic time_offset correction ─────────────────────────────────
143    /// Re-derive the clock skew from a server-provided `msg_id`.
144    ///
145    /// Called on `bad_msg_notification` error codes 16 (msg_id too low) and
146    /// 17 (msg_id too high) so clock drift is corrected at any point in the
147    /// session, not only at connect time.
148    ///
149    /// grammers reference: `mtp/encrypted.rs: correct_time_offset(msg_id)`
150    pub fn correct_time_offset(&mut self, server_msg_id: i64) {
151        // Upper 32 bits of msg_id = Unix seconds on the server
152        let server_time = (server_msg_id >> 32) as i32;
153        let local_now = SystemTime::now()
154            .duration_since(UNIX_EPOCH)
155            .unwrap()
156            .as_secs() as i32;
157        let new_offset = server_time.wrapping_sub(local_now);
158        log::debug!(
159            "[layer] G-12 time_offset correction: {} → {} (server_time={server_time})",
160            self.time_offset,
161            new_offset
162        );
163        self.time_offset = new_offset;
164        // Also reset last_msg_id so next_msg_id rebuilds from corrected clock
165        self.last_msg_id = 0;
166    }
167
168    // ── G-02 / G-07 helpers ───────────────────────────────────────────────────
169
170    /// Allocate a fresh `(msg_id, seqno)` pair for an inner container message
171    /// WITHOUT encrypting anything.
172    ///
173    /// `content_related = true`  → odd seqno, advances counter  (regular RPCs)
174    /// `content_related = false` → even seqno, no advance       (MsgsAck, container)
175    ///
176    /// grammers reference: `mtp/encrypted.rs: get_seq_no(content_related)`
177    pub fn alloc_msg_seqno(&mut self, content_related: bool) -> (i64, i32) {
178        let msg_id = self.next_msg_id();
179        let seqno = if content_related {
180            self.next_seq_no()
181        } else {
182            self.next_seq_no_ncr()
183        };
184        (msg_id, seqno)
185    }
186
187    /// Encrypt a pre-serialized TL body into a wire-ready MTProto frame.
188    ///
189    /// `content_related` controls whether the seqno is odd (content, advances
190    /// the counter) or even (service, no advance).
191    ///
192    /// Returns `(encrypted_wire_bytes, msg_id)`.
193    /// Used for G-02 (bad_msg re-send) and G-07 (container inner messages).
194    pub fn pack_body_with_msg_id(&mut self, body: &[u8], content_related: bool) -> (Vec<u8>, i64) {
195        let msg_id = self.next_msg_id();
196        let seq_no = if content_related {
197            self.next_seq_no()
198        } else {
199            self.next_seq_no_ncr()
200        };
201
202        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
203        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
204        buf.extend(self.salt.to_le_bytes());
205        buf.extend(self.session_id.to_le_bytes());
206        buf.extend(msg_id.to_le_bytes());
207        buf.extend(seq_no.to_le_bytes());
208        buf.extend((body.len() as u32).to_le_bytes());
209        buf.extend(body.iter().copied());
210
211        encrypt_data_v2(&mut buf, &self.auth_key);
212        (buf.as_ref().to_vec(), msg_id)
213    }
214
215    /// Encrypt a pre-built `msg_container` body (the container itself is
216    /// a non-content-related message with an even seqno).
217    ///
218    /// Returns `(encrypted_wire_bytes, container_msg_id)`.
219    /// The container_msg_id is needed so callers can map it back to inner
220    /// requests when a bad_msg_notification or bad_server_salt arrives for
221    /// the container rather than the individual inner message.
222    ///
223    /// grammers reference: `MsgIdPair { msg_id, container_msg_id }`
224    pub fn pack_container(&mut self, container_body: &[u8]) -> (Vec<u8>, i64) {
225        self.pack_body_with_msg_id(container_body, false)
226    }
227
228    // ── Original pack methods (unchanged) ────────────────────────────────────
229
230    /// Serialize and encrypt a TL function into a wire-ready byte vector.
231    pub fn pack_serializable<S: layer_tl_types::Serializable>(&mut self, call: &S) -> Vec<u8> {
232        let body = call.to_bytes();
233        let msg_id = self.next_msg_id();
234        let seq_no = self.next_seq_no();
235
236        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
237        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
238        buf.extend(self.salt.to_le_bytes());
239        buf.extend(self.session_id.to_le_bytes());
240        buf.extend(msg_id.to_le_bytes());
241        buf.extend(seq_no.to_le_bytes());
242        buf.extend((body.len() as u32).to_le_bytes());
243        buf.extend(body.iter().copied());
244
245        encrypt_data_v2(&mut buf, &self.auth_key);
246        buf.as_ref().to_vec()
247    }
248
249    /// Like `pack_serializable` but also returns the `msg_id`.
250    pub fn pack_serializable_with_msg_id<S: layer_tl_types::Serializable>(
251        &mut self,
252        call: &S,
253    ) -> (Vec<u8>, i64) {
254        let body = call.to_bytes();
255        let msg_id = self.next_msg_id();
256        let seq_no = self.next_seq_no();
257        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
258        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
259        buf.extend(self.salt.to_le_bytes());
260        buf.extend(self.session_id.to_le_bytes());
261        buf.extend(msg_id.to_le_bytes());
262        buf.extend(seq_no.to_le_bytes());
263        buf.extend((body.len() as u32).to_le_bytes());
264        buf.extend(body.iter().copied());
265        encrypt_data_v2(&mut buf, &self.auth_key);
266        (buf.as_ref().to_vec(), msg_id)
267    }
268
269    /// Like [`pack`] but also returns the `msg_id` allocated for this message.
270    pub fn pack_with_msg_id<R: RemoteCall>(&mut self, call: &R) -> (Vec<u8>, i64) {
271        let body = call.to_bytes();
272        let msg_id = self.next_msg_id();
273        let seq_no = self.next_seq_no();
274        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
275        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
276        buf.extend(self.salt.to_le_bytes());
277        buf.extend(self.session_id.to_le_bytes());
278        buf.extend(msg_id.to_le_bytes());
279        buf.extend(seq_no.to_le_bytes());
280        buf.extend((body.len() as u32).to_le_bytes());
281        buf.extend(body.iter().copied());
282        encrypt_data_v2(&mut buf, &self.auth_key);
283        (buf.as_ref().to_vec(), msg_id)
284    }
285
286    /// Encrypt and frame a [`RemoteCall`] into a ready-to-send MTProto message.
287    pub fn pack<R: RemoteCall>(&mut self, call: &R) -> Vec<u8> {
288        let body = call.to_bytes();
289        let msg_id = self.next_msg_id();
290        let seq_no = self.next_seq_no();
291
292        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
293        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
294        buf.extend(self.salt.to_le_bytes());
295        buf.extend(self.session_id.to_le_bytes());
296        buf.extend(msg_id.to_le_bytes());
297        buf.extend(seq_no.to_le_bytes());
298        buf.extend((body.len() as u32).to_le_bytes());
299        buf.extend(body.iter().copied());
300
301        encrypt_data_v2(&mut buf, &self.auth_key);
302        buf.as_ref().to_vec()
303    }
304
305    /// Decrypt an encrypted server frame.
306    pub fn unpack(&self, frame: &mut [u8]) -> Result<DecryptedMessage, DecryptError> {
307        let plaintext = decrypt_data_v2(frame, &self.auth_key).map_err(DecryptError::Crypto)?;
308
309        if plaintext.len() < 32 {
310            return Err(DecryptError::FrameTooShort);
311        }
312
313        let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
314        let session_id = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
315        let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
316        let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
317        let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
318
319        if session_id != self.session_id {
320            return Err(DecryptError::SessionMismatch);
321        }
322
323        // Bounds-check BEFORE slicing. The old `.min()` silently truncated the body
324        // when body_len exceeded available plaintext (e.g. framing mismatch), which
325        // caused the TL deserializer downstream to hit "unexpected end of buffer".
326        if 32 + body_len > plaintext.len() {
327            return Err(DecryptError::FrameTooShort);
328        }
329        let body = plaintext[32..32 + body_len].to_vec();
330
331        Ok(DecryptedMessage {
332            salt,
333            session_id,
334            msg_id,
335            seq_no,
336            body,
337        })
338    }
339
340    /// Return the auth_key bytes (for persistence).
341    pub fn auth_key_bytes(&self) -> [u8; 256] {
342        self.auth_key.to_bytes()
343    }
344
345    /// Return the current session_id.
346    pub fn session_id(&self) -> i64 {
347        self.session_id
348    }
349}
350
351impl EncryptedSession {
352    /// Decrypt a frame using explicit key + session_id — no mutable state needed.
353    /// Used by the split-reader task so it can decrypt without locking the writer.
354    pub fn decrypt_frame(
355        auth_key: &[u8; 256],
356        session_id: i64,
357        frame: &mut [u8],
358    ) -> Result<DecryptedMessage, DecryptError> {
359        let key = AuthKey::from_bytes(*auth_key);
360        let plaintext = decrypt_data_v2(frame, &key).map_err(DecryptError::Crypto)?;
361        if plaintext.len() < 32 {
362            return Err(DecryptError::FrameTooShort);
363        }
364        let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
365        let sid = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
366        let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
367        let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
368        let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
369        if sid != session_id {
370            return Err(DecryptError::SessionMismatch);
371        }
372        if 32 + body_len > plaintext.len() {
373            return Err(DecryptError::FrameTooShort);
374        }
375        let body = plaintext[32..32 + body_len].to_vec();
376        Ok(DecryptedMessage {
377            salt,
378            session_id: sid,
379            msg_id,
380            seq_no,
381            body,
382        })
383    }
384}