Skip to main content

layer_mtproto/
encrypted.rs

1//! Encrypted MTProto 2.0 session (post auth-key).
2//!
3//! Once you have a `Finished` from [`crate::authentication`], construct an
4//! [`EncryptedSession`] and use it to serialize/deserialize all subsequent
5//! messages.
6
7use std::time::{SystemTime, UNIX_EPOCH};
8
9use layer_crypto::{AuthKey, DequeBuffer, decrypt_data_v2, encrypt_data_v2};
10use layer_tl_types::RemoteCall;
11
12/// Errors that can occur when decrypting a server message.
13#[derive(Debug)]
14pub enum DecryptError {
15    /// The underlying crypto layer rejected the message.
16    Crypto(layer_crypto::DecryptError),
17    /// The decrypted inner message was too short to contain a valid header.
18    FrameTooShort,
19    /// Session-ID mismatch (possible replay or wrong connection).
20    SessionMismatch,
21}
22
23impl std::fmt::Display for DecryptError {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        match self {
26            Self::Crypto(e) => write!(f, "crypto: {e}"),
27            Self::FrameTooShort => write!(f, "inner plaintext too short"),
28            Self::SessionMismatch => write!(f, "session_id mismatch"),
29        }
30    }
31}
32impl std::error::Error for DecryptError {}
33
/// The inner payload extracted from a successfully decrypted server frame.
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so decoded frames can be logged,
/// buffered for later dispatch, and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecryptedMessage {
    /// `salt` sent by the server.
    pub salt: i64,
    /// The `session_id` from the frame.
    pub session_id: i64,
    /// The `msg_id` of the inner message.
    pub msg_id: i64,
    /// `seq_no` of the inner message.
    pub seq_no: i32,
    /// TL-serialized body of the inner message.
    pub body: Vec<u8>,
}
47
48/// MTProto 2.0 encrypted session state.
49///
50/// Wraps an `AuthKey` and tracks per-session counters (session_id, seq_no,
51/// last_msg_id, server salt).  Use [`EncryptedSession::pack`] to encrypt
52/// outgoing requests and [`EncryptedSession::unpack`] to decrypt incoming
53/// server frames.
54pub struct EncryptedSession {
55    auth_key: AuthKey,
56    session_id: i64,
57    sequence: i32,
58    last_msg_id: i64,
59    /// Current server salt to include in outgoing messages.
60    pub salt: i64,
61    /// Clock skew in seconds vs. server.
62    pub time_offset: i32,
63}
64
65impl EncryptedSession {
66    /// Create a new encrypted session from the output of `authentication::finish`.
67    pub fn new(auth_key: [u8; 256], first_salt: i64, time_offset: i32) -> Self {
68        let mut rnd = [0u8; 8];
69        getrandom::getrandom(&mut rnd).expect("getrandom");
70        Self {
71            auth_key: AuthKey::from_bytes(auth_key),
72            session_id: i64::from_le_bytes(rnd),
73            sequence: 0,
74            last_msg_id: 0,
75            salt: first_salt,
76            time_offset,
77        }
78    }
79
80    /// Compute the next message ID (based on corrected server time).
81    fn next_msg_id(&mut self) -> i64 {
82        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
83        let secs = (now.as_secs() as i32).wrapping_add(self.time_offset) as u64;
84        let nanos = now.subsec_nanos() as u64;
85        let mut id = ((secs << 32) | (nanos << 2)) as i64;
86        if self.last_msg_id >= id {
87            id = self.last_msg_id + 4;
88        }
89        self.last_msg_id = id;
90        id
91    }
92
93    /// Next content-related seq_no (odd) and advance the counter.
94    /// Used for all regular RPC requests.
95    fn next_seq_no(&mut self) -> i32 {
96        let n = self.sequence * 2 + 1;
97        self.sequence += 1;
98        n
99    }
100
101    // ── G-05: non-content-related seq_no ─────────────────────────────────────
102    /// Return the current even seq_no WITHOUT advancing the counter.
103    ///
104    /// Service messages (MsgsAck, containers, etc.) MUST use an even seqno
105    /// per the MTProto spec so the server does not expect a reply.
106    /// grammers reference: `mtp/encrypted.rs: get_seq_no(content_related: bool)`
107    pub fn next_seq_no_ncr(&self) -> i32 {
108        self.sequence * 2
109    }
110
111    // ── G-03: seq_no correction on bad_msg codes 32/33 ───────────────────────
112    /// Correct the outgoing sequence counter when the server reports a
113    /// `bad_msg_notification` with error codes 32 (seq_no too low) or
114    /// 33 (seq_no too high).
115    ///
116    /// grammers reference: `mtp/encrypted.rs: handle_bad_notification() codes 32/33`
117    pub fn correct_seq_no(&mut self, code: u32) {
118        match code {
119            32 => {
120                // seq_no too low — jump forward so next send is well above server expectation
121                self.sequence += 64;
122                log::debug!(
123                    "[layer] G-03 seq_no correction: code 32, bumped seq to {}",
124                    self.sequence
125                );
126            }
127            33 => {
128                // seq_no too high — step back, but never below 1 to avoid
129                // re-using seq_no=1 which was already sent this session.
130                // Zeroing would make the next content message get seq_no=1,
131                // which the server already saw and will reject again with code 32.
132                self.sequence = self.sequence.saturating_sub(16).max(1);
133                log::debug!(
134                    "[layer] G-03 seq_no correction: code 33, lowered seq to {}",
135                    self.sequence
136                );
137            }
138            _ => {}
139        }
140    }
141
142    // ── G-12: dynamic time_offset correction ─────────────────────────────────
143    /// Re-derive the clock skew from a server-provided `msg_id`.
144    ///
145    /// Called on `bad_msg_notification` error codes 16 (msg_id too low) and
146    /// 17 (msg_id too high) so clock drift is corrected at any point in the
147    /// session, not only at connect time.
148    ///
149    /// grammers reference: `mtp/encrypted.rs: correct_time_offset(msg_id)`
150    pub fn correct_time_offset(&mut self, server_msg_id: i64) {
151        // Upper 32 bits of msg_id = Unix seconds on the server
152        let server_time = (server_msg_id >> 32) as i32;
153        let local_now = SystemTime::now()
154            .duration_since(UNIX_EPOCH)
155            .unwrap()
156            .as_secs() as i32;
157        let new_offset = server_time.wrapping_sub(local_now);
158        log::debug!(
159            "[layer] G-12 time_offset correction: {} → {} (server_time={server_time})",
160            self.time_offset,
161            new_offset
162        );
163        self.time_offset = new_offset;
164        // Also reset last_msg_id so next_msg_id rebuilds from corrected clock
165        self.last_msg_id = 0;
166    }
167
168    // ── G-02 / G-07 helpers ───────────────────────────────────────────────────
169
170    /// Allocate a fresh `(msg_id, seqno)` pair for an inner container message
171    /// WITHOUT encrypting anything.
172    ///
173    /// `content_related = true`  → odd seqno, advances counter  (regular RPCs)
174    /// `content_related = false` → even seqno, no advance       (MsgsAck, container)
175    ///
176    /// grammers reference: `mtp/encrypted.rs: get_seq_no(content_related)`
177    pub fn alloc_msg_seqno(&mut self, content_related: bool) -> (i64, i32) {
178        let msg_id = self.next_msg_id();
179        let seqno = if content_related {
180            self.next_seq_no()
181        } else {
182            self.next_seq_no_ncr()
183        };
184        (msg_id, seqno)
185    }
186
187    /// Encrypt a pre-serialized TL body into a wire-ready MTProto frame.
188    ///
189    /// `content_related` controls whether the seqno is odd (content, advances
190    /// the counter) or even (service, no advance).
191    ///
192    /// Returns `(encrypted_wire_bytes, msg_id)`.
193    /// Used for G-02 (bad_msg re-send) and G-07 (container inner messages).
194    pub fn pack_body_with_msg_id(&mut self, body: &[u8], content_related: bool) -> (Vec<u8>, i64) {
195        let msg_id = self.next_msg_id();
196        let seq_no = if content_related {
197            self.next_seq_no()
198        } else {
199            self.next_seq_no_ncr()
200        };
201
202        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
203        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
204        buf.extend(self.salt.to_le_bytes());
205        buf.extend(self.session_id.to_le_bytes());
206        buf.extend(msg_id.to_le_bytes());
207        buf.extend(seq_no.to_le_bytes());
208        buf.extend((body.len() as u32).to_le_bytes());
209        buf.extend(body.iter().copied());
210
211        encrypt_data_v2(&mut buf, &self.auth_key);
212        (buf.as_ref().to_vec(), msg_id)
213    }
214
215    /// Encrypt a pre-built `msg_container` body (the container itself is
216    /// a non-content-related message with an even seqno).
217    ///
218    /// Returns `encrypted_wire_bytes`.
219    /// Used for G-07 (message container batching).
220    pub fn pack_container(&mut self, container_body: &[u8]) -> Vec<u8> {
221        let (wire, _msg_id) = self.pack_body_with_msg_id(container_body, false);
222        wire
223    }
224
225    // ── Original pack methods (unchanged) ────────────────────────────────────
226
227    /// Serialize and encrypt a TL function into a wire-ready byte vector.
228    pub fn pack_serializable<S: layer_tl_types::Serializable>(&mut self, call: &S) -> Vec<u8> {
229        let body = call.to_bytes();
230        let msg_id = self.next_msg_id();
231        let seq_no = self.next_seq_no();
232
233        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
234        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
235        buf.extend(self.salt.to_le_bytes());
236        buf.extend(self.session_id.to_le_bytes());
237        buf.extend(msg_id.to_le_bytes());
238        buf.extend(seq_no.to_le_bytes());
239        buf.extend((body.len() as u32).to_le_bytes());
240        buf.extend(body.iter().copied());
241
242        encrypt_data_v2(&mut buf, &self.auth_key);
243        buf.as_ref().to_vec()
244    }
245
246    /// Like `pack_serializable` but also returns the `msg_id`.
247    pub fn pack_serializable_with_msg_id<S: layer_tl_types::Serializable>(
248        &mut self,
249        call: &S,
250    ) -> (Vec<u8>, i64) {
251        let body = call.to_bytes();
252        let msg_id = self.next_msg_id();
253        let seq_no = self.next_seq_no();
254        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
255        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
256        buf.extend(self.salt.to_le_bytes());
257        buf.extend(self.session_id.to_le_bytes());
258        buf.extend(msg_id.to_le_bytes());
259        buf.extend(seq_no.to_le_bytes());
260        buf.extend((body.len() as u32).to_le_bytes());
261        buf.extend(body.iter().copied());
262        encrypt_data_v2(&mut buf, &self.auth_key);
263        (buf.as_ref().to_vec(), msg_id)
264    }
265
266    /// Like [`pack`] but also returns the `msg_id` allocated for this message.
267    pub fn pack_with_msg_id<R: RemoteCall>(&mut self, call: &R) -> (Vec<u8>, i64) {
268        let body = call.to_bytes();
269        let msg_id = self.next_msg_id();
270        let seq_no = self.next_seq_no();
271        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
272        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
273        buf.extend(self.salt.to_le_bytes());
274        buf.extend(self.session_id.to_le_bytes());
275        buf.extend(msg_id.to_le_bytes());
276        buf.extend(seq_no.to_le_bytes());
277        buf.extend((body.len() as u32).to_le_bytes());
278        buf.extend(body.iter().copied());
279        encrypt_data_v2(&mut buf, &self.auth_key);
280        (buf.as_ref().to_vec(), msg_id)
281    }
282
283    /// Encrypt and frame a [`RemoteCall`] into a ready-to-send MTProto message.
284    pub fn pack<R: RemoteCall>(&mut self, call: &R) -> Vec<u8> {
285        let body = call.to_bytes();
286        let msg_id = self.next_msg_id();
287        let seq_no = self.next_seq_no();
288
289        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
290        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
291        buf.extend(self.salt.to_le_bytes());
292        buf.extend(self.session_id.to_le_bytes());
293        buf.extend(msg_id.to_le_bytes());
294        buf.extend(seq_no.to_le_bytes());
295        buf.extend((body.len() as u32).to_le_bytes());
296        buf.extend(body.iter().copied());
297
298        encrypt_data_v2(&mut buf, &self.auth_key);
299        buf.as_ref().to_vec()
300    }
301
302    /// Decrypt an encrypted server frame.
303    pub fn unpack(&self, frame: &mut [u8]) -> Result<DecryptedMessage, DecryptError> {
304        let plaintext = decrypt_data_v2(frame, &self.auth_key).map_err(DecryptError::Crypto)?;
305
306        if plaintext.len() < 32 {
307            return Err(DecryptError::FrameTooShort);
308        }
309
310        let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
311        let session_id = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
312        let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
313        let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
314        let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
315
316        if session_id != self.session_id {
317            return Err(DecryptError::SessionMismatch);
318        }
319
320        let body = plaintext[32..32 + body_len.min(plaintext.len() - 32)].to_vec();
321
322        Ok(DecryptedMessage {
323            salt,
324            session_id,
325            msg_id,
326            seq_no,
327            body,
328        })
329    }
330
331    /// Return the auth_key bytes (for persistence).
332    pub fn auth_key_bytes(&self) -> [u8; 256] {
333        self.auth_key.to_bytes()
334    }
335
336    /// Return the current session_id.
337    pub fn session_id(&self) -> i64 {
338        self.session_id
339    }
340}
341
342impl EncryptedSession {
343    /// Decrypt a frame using explicit key + session_id — no mutable state needed.
344    /// Used by the split-reader task so it can decrypt without locking the writer.
345    pub fn decrypt_frame(
346        auth_key: &[u8; 256],
347        session_id: i64,
348        frame: &mut [u8],
349    ) -> Result<DecryptedMessage, DecryptError> {
350        let key = AuthKey::from_bytes(*auth_key);
351        let plaintext = decrypt_data_v2(frame, &key).map_err(DecryptError::Crypto)?;
352        if plaintext.len() < 32 {
353            return Err(DecryptError::FrameTooShort);
354        }
355        let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
356        let sid = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
357        let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
358        let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
359        let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
360        if sid != session_id {
361            return Err(DecryptError::SessionMismatch);
362        }
363        let body = plaintext[32..32 + body_len.min(plaintext.len() - 32)].to_vec();
364        Ok(DecryptedMessage {
365            salt,
366            session_id: sid,
367            msg_id,
368            seq_no,
369            body,
370        })
371    }
372}