Skip to main content

layer_mtproto/
encrypted.rs

1//! Encrypted MTProto 2.0 session (post auth-key).
2//!
3//! Once you have a `Finished` from [`crate::authentication`], construct an
4//! [`EncryptedSession`] and use it to serialize/deserialize all subsequent
5//! messages.
6
7use std::time::{SystemTime, UNIX_EPOCH};
8
9use layer_crypto::{AuthKey, DequeBuffer, decrypt_data_v2, encrypt_data_v2};
10use layer_tl_types::RemoteCall;
11
12/// Errors that can occur when decrypting a server message.
13#[derive(Debug)]
14pub enum DecryptError {
15    /// The underlying crypto layer rejected the message.
16    Crypto(layer_crypto::DecryptError),
17    /// The decrypted inner message was too short to contain a valid header.
18    FrameTooShort,
19    /// Session-ID mismatch (possible replay or wrong connection).
20    SessionMismatch,
21}
22
23impl std::fmt::Display for DecryptError {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        match self {
26            Self::Crypto(e) => write!(f, "crypto: {e}"),
27            Self::FrameTooShort => write!(f, "inner plaintext too short"),
28            Self::SessionMismatch => write!(f, "session_id mismatch"),
29        }
30    }
31}
32impl std::error::Error for DecryptError {}
33
/// The inner payload extracted from a successfully decrypted server frame.
///
/// When produced by [`EncryptedSession::unpack`] or
/// [`EncryptedSession::decrypt_frame`], the header fields are copied verbatim
/// from the plaintext. Only `session_id` (against the expected session) and
/// the body length (against the available plaintext) have been checked.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecryptedMessage {
    /// `salt` sent by the server.
    pub salt: i64,
    /// The `session_id` from the frame.
    pub session_id: i64,
    /// The `msg_id` of the inner message.
    pub msg_id: i64,
    /// `seq_no` of the inner message.
    pub seq_no: i32,
    /// TL-serialized body of the inner message.
    pub body: Vec<u8>,
}
47
48/// MTProto 2.0 encrypted session state.
49///
50/// Wraps an `AuthKey` and tracks per-session counters (session_id, seq_no,
51/// last_msg_id, server salt).  Use [`EncryptedSession::pack`] to encrypt
52/// outgoing requests and [`EncryptedSession::unpack`] to decrypt incoming
53/// server frames.
54pub struct EncryptedSession {
55    auth_key: AuthKey,
56    session_id: i64,
57    sequence: i32,
58    last_msg_id: i64,
59    /// Current server salt to include in outgoing messages.
60    pub salt: i64,
61    /// Clock skew in seconds vs. server.
62    pub time_offset: i32,
63}
64
65impl EncryptedSession {
66    /// Create a new encrypted session from the output of `authentication::finish`.
67    pub fn new(auth_key: [u8; 256], first_salt: i64, time_offset: i32) -> Self {
68        let mut rnd = [0u8; 8];
69        getrandom::getrandom(&mut rnd).expect("getrandom");
70        Self {
71            auth_key: AuthKey::from_bytes(auth_key),
72            session_id: i64::from_le_bytes(rnd),
73            sequence: 0,
74            last_msg_id: 0,
75            salt: first_salt,
76            time_offset,
77        }
78    }
79
80    /// Compute the next message ID (based on corrected server time).
81    fn next_msg_id(&mut self) -> i64 {
82        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
83        let secs = (now.as_secs() as i32).wrapping_add(self.time_offset) as u64;
84        let nanos = now.subsec_nanos() as u64;
85        let mut id = ((secs << 32) | (nanos << 2)) as i64;
86        if self.last_msg_id >= id {
87            id = self.last_msg_id + 4;
88        }
89        self.last_msg_id = id;
90        id
91    }
92
93    /// Next content-related seq_no (odd) and advance the counter.
94    /// Used for all regular RPC requests.
95    fn next_seq_no(&mut self) -> i32 {
96        let n = self.sequence * 2 + 1;
97        self.sequence += 1;
98        n
99    }
100
101    /// Return the current even seq_no WITHOUT advancing the counter.
102    ///
103    /// Service messages (MsgsAck, containers, etc.) MUST use an even seqno
104    /// per the MTProto spec so the server does not expect a reply.
105    pub fn next_seq_no_ncr(&self) -> i32 {
106        self.sequence * 2
107    }
108
109    /// Correct the outgoing sequence counter when the server reports a
110    /// `bad_msg_notification` with error codes 32 (seq_no too low) or
111    /// 33 (seq_no too high).
112    ///
113    pub fn correct_seq_no(&mut self, code: u32) {
114        match code {
115            32 => {
116                // seq_no too low: jump forward so next send is well above server expectation
117                self.sequence += 64;
118                log::debug!(
119                    "[layer] seq_no correction: code 32, bumped seq to {}",
120                    self.sequence
121                );
122            }
123            33 => {
124                // seq_no too high: step back, but never below 1 to avoid
125                // re-using seq_no=1 which was already sent this session.
126                // Zeroing would make the next content message get seq_no=1,
127                // which the server already saw and will reject again with code 32.
128                self.sequence = self.sequence.saturating_sub(16).max(1);
129                log::debug!(
130                    "[layer] seq_no correction: code 33, lowered seq to {}",
131                    self.sequence
132                );
133            }
134            _ => {}
135        }
136    }
137
138    /// Re-derive the clock skew from a server-provided `msg_id`.
139    ///
140    /// Called on `bad_msg_notification` error codes 16 (msg_id too low) and
141    /// 17 (msg_id too high) so clock drift is corrected at any point in the
142    /// session, not only at connect time.
143    ///
144    pub fn correct_time_offset(&mut self, server_msg_id: i64) {
145        // Upper 32 bits of msg_id = Unix seconds on the server
146        let server_time = (server_msg_id >> 32) as i32;
147        let local_now = SystemTime::now()
148            .duration_since(UNIX_EPOCH)
149            .unwrap()
150            .as_secs() as i32;
151        let new_offset = server_time.wrapping_sub(local_now);
152        log::debug!(
153            "[layer] time_offset correction: {} → {} (server_time={server_time})",
154            self.time_offset,
155            new_offset
156        );
157        self.time_offset = new_offset;
158        // Also reset last_msg_id so next_msg_id rebuilds from corrected clock
159        self.last_msg_id = 0;
160    }
161
162    /// Allocate a fresh `(msg_id, seqno)` pair for an inner container message
163    /// WITHOUT encrypting anything.
164    ///
165    /// `content_related = true`  → odd seqno, advances counter  (regular RPCs)
166    /// `content_related = false` → even seqno, no advance       (MsgsAck, container)
167    ///
168    pub fn alloc_msg_seqno(&mut self, content_related: bool) -> (i64, i32) {
169        let msg_id = self.next_msg_id();
170        let seqno = if content_related {
171            self.next_seq_no()
172        } else {
173            self.next_seq_no_ncr()
174        };
175        (msg_id, seqno)
176    }
177
178    /// Encrypt a pre-serialized TL body into a wire-ready MTProto frame.
179    ///
180    /// `content_related` controls whether the seqno is odd (content, advances
181    /// the counter) or even (service, no advance).
182    ///
183    /// Returns `(encrypted_wire_bytes, msg_id)`.
184    /// Used for (bad_msg re-send) and (container inner messages).
185    pub fn pack_body_with_msg_id(&mut self, body: &[u8], content_related: bool) -> (Vec<u8>, i64) {
186        let msg_id = self.next_msg_id();
187        let seq_no = if content_related {
188            self.next_seq_no()
189        } else {
190            self.next_seq_no_ncr()
191        };
192
193        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
194        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
195        buf.extend(self.salt.to_le_bytes());
196        buf.extend(self.session_id.to_le_bytes());
197        buf.extend(msg_id.to_le_bytes());
198        buf.extend(seq_no.to_le_bytes());
199        buf.extend((body.len() as u32).to_le_bytes());
200        buf.extend(body.iter().copied());
201
202        encrypt_data_v2(&mut buf, &self.auth_key);
203        (buf.as_ref().to_vec(), msg_id)
204    }
205
206    /// Encrypt a pre-built `msg_container` body (the container itself is
207    /// a non-content-related message with an even seqno).
208    ///
209    /// Returns `(encrypted_wire_bytes, container_msg_id)`.
210    /// The container_msg_id is needed so callers can map it back to inner
211    /// requests when a bad_msg_notification or bad_server_salt arrives for
212    /// the container rather than the individual inner message.
213    ///
214    pub fn pack_container(&mut self, container_body: &[u8]) -> (Vec<u8>, i64) {
215        self.pack_body_with_msg_id(container_body, false)
216    }
217
218    // Original pack methods (unchanged)
219
220    /// Serialize and encrypt a TL function into a wire-ready byte vector.
221    pub fn pack_serializable<S: layer_tl_types::Serializable>(&mut self, call: &S) -> Vec<u8> {
222        let body = call.to_bytes();
223        let msg_id = self.next_msg_id();
224        let seq_no = self.next_seq_no();
225
226        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
227        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
228        buf.extend(self.salt.to_le_bytes());
229        buf.extend(self.session_id.to_le_bytes());
230        buf.extend(msg_id.to_le_bytes());
231        buf.extend(seq_no.to_le_bytes());
232        buf.extend((body.len() as u32).to_le_bytes());
233        buf.extend(body.iter().copied());
234
235        encrypt_data_v2(&mut buf, &self.auth_key);
236        buf.as_ref().to_vec()
237    }
238
239    /// Like `pack_serializable` but also returns the `msg_id`.
240    pub fn pack_serializable_with_msg_id<S: layer_tl_types::Serializable>(
241        &mut self,
242        call: &S,
243    ) -> (Vec<u8>, i64) {
244        let body = call.to_bytes();
245        let msg_id = self.next_msg_id();
246        let seq_no = self.next_seq_no();
247        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
248        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
249        buf.extend(self.salt.to_le_bytes());
250        buf.extend(self.session_id.to_le_bytes());
251        buf.extend(msg_id.to_le_bytes());
252        buf.extend(seq_no.to_le_bytes());
253        buf.extend((body.len() as u32).to_le_bytes());
254        buf.extend(body.iter().copied());
255        encrypt_data_v2(&mut buf, &self.auth_key);
256        (buf.as_ref().to_vec(), msg_id)
257    }
258
259    /// Like [`pack`] but also returns the `msg_id` allocated for this message.
260    pub fn pack_with_msg_id<R: RemoteCall>(&mut self, call: &R) -> (Vec<u8>, i64) {
261        let body = call.to_bytes();
262        let msg_id = self.next_msg_id();
263        let seq_no = self.next_seq_no();
264        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
265        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
266        buf.extend(self.salt.to_le_bytes());
267        buf.extend(self.session_id.to_le_bytes());
268        buf.extend(msg_id.to_le_bytes());
269        buf.extend(seq_no.to_le_bytes());
270        buf.extend((body.len() as u32).to_le_bytes());
271        buf.extend(body.iter().copied());
272        encrypt_data_v2(&mut buf, &self.auth_key);
273        (buf.as_ref().to_vec(), msg_id)
274    }
275
276    /// Encrypt and frame a [`RemoteCall`] into a ready-to-send MTProto message.
277    pub fn pack<R: RemoteCall>(&mut self, call: &R) -> Vec<u8> {
278        let body = call.to_bytes();
279        let msg_id = self.next_msg_id();
280        let seq_no = self.next_seq_no();
281
282        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
283        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
284        buf.extend(self.salt.to_le_bytes());
285        buf.extend(self.session_id.to_le_bytes());
286        buf.extend(msg_id.to_le_bytes());
287        buf.extend(seq_no.to_le_bytes());
288        buf.extend((body.len() as u32).to_le_bytes());
289        buf.extend(body.iter().copied());
290
291        encrypt_data_v2(&mut buf, &self.auth_key);
292        buf.as_ref().to_vec()
293    }
294
295    /// Decrypt an encrypted server frame.
296    pub fn unpack(&self, frame: &mut [u8]) -> Result<DecryptedMessage, DecryptError> {
297        let plaintext = decrypt_data_v2(frame, &self.auth_key).map_err(DecryptError::Crypto)?;
298
299        if plaintext.len() < 32 {
300            return Err(DecryptError::FrameTooShort);
301        }
302
303        let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
304        let session_id = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
305        let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
306        let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
307        let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
308
309        if session_id != self.session_id {
310            return Err(DecryptError::SessionMismatch);
311        }
312
313        // Bounds-check BEFORE slicing. The old `.min()` silently truncated the body
314        // when body_len exceeded available plaintext (e.g. framing mismatch), which
315        // caused the TL deserializer downstream to hit "unexpected end of buffer".
316        if 32 + body_len > plaintext.len() {
317            return Err(DecryptError::FrameTooShort);
318        }
319        let body = plaintext[32..32 + body_len].to_vec();
320
321        Ok(DecryptedMessage {
322            salt,
323            session_id,
324            msg_id,
325            seq_no,
326            body,
327        })
328    }
329
330    /// Return the auth_key bytes (for persistence).
331    pub fn auth_key_bytes(&self) -> [u8; 256] {
332        self.auth_key.to_bytes()
333    }
334
335    /// Return the current session_id.
336    pub fn session_id(&self) -> i64 {
337        self.session_id
338    }
339}
340
341impl EncryptedSession {
342    /// Decrypt a frame using explicit key + session_id: no mutable state needed.
343    /// Used by the split-reader task so it can decrypt without locking the writer.
344    pub fn decrypt_frame(
345        auth_key: &[u8; 256],
346        session_id: i64,
347        frame: &mut [u8],
348    ) -> Result<DecryptedMessage, DecryptError> {
349        let key = AuthKey::from_bytes(*auth_key);
350        let plaintext = decrypt_data_v2(frame, &key).map_err(DecryptError::Crypto)?;
351        if plaintext.len() < 32 {
352            return Err(DecryptError::FrameTooShort);
353        }
354        let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
355        let sid = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
356        let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
357        let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
358        let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
359        if sid != session_id {
360            return Err(DecryptError::SessionMismatch);
361        }
362        if 32 + body_len > plaintext.len() {
363            return Err(DecryptError::FrameTooShort);
364        }
365        let body = plaintext[32..32 + body_len].to_vec();
366        Ok(DecryptedMessage {
367            salt,
368            session_id: sid,
369            msg_id,
370            seq_no,
371            body,
372        })
373    }
374}