Skip to main content

layer_mtproto/
encrypted.rs

1//! Encrypted MTProto 2.0 session (post auth-key).
2//!
3//! Once you have a `Finished` from [`crate::authentication`], construct an
4//! [`EncryptedSession`] and use it to serialize/deserialize all subsequent
5//! messages.
6
7use std::time::{SystemTime, UNIX_EPOCH};
8
9use layer_crypto::{AuthKey, DequeBuffer, decrypt_data_v2, encrypt_data_v2};
10use layer_tl_types::RemoteCall;
11
12
13/// Errors that can occur when decrypting a server message.
14#[derive(Debug)]
15pub enum DecryptError {
16    /// The underlying crypto layer rejected the message.
17    Crypto(layer_crypto::DecryptError),
18    /// The decrypted inner message was too short to contain a valid header.
19    FrameTooShort,
20    /// Session-ID mismatch (possible replay or wrong connection).
21    SessionMismatch,
22}
23
24impl std::fmt::Display for DecryptError {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        match self {
27            Self::Crypto(e) => write!(f, "crypto: {e}"),
28            Self::FrameTooShort => write!(f, "inner plaintext too short"),
29            Self::SessionMismatch => write!(f, "session_id mismatch"),
30        }
31    }
32}
33impl std::error::Error for DecryptError {}
34
/// The inner payload extracted from a successfully decrypted server frame.
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so decoded messages can be logged,
/// duplicated (e.g. for re-dispatch) and compared in tests; all fields are
/// plain integers or bytes, so these derives carry no hidden cost.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecryptedMessage {
    /// `salt` sent by the server.
    pub salt:       i64,
    /// The `session_id` from the frame.
    pub session_id: i64,
    /// The `msg_id` of the inner message.
    pub msg_id:     i64,
    /// `seq_no` of the inner message.
    pub seq_no:     i32,
    /// TL-serialized body of the inner message.
    pub body:       Vec<u8>,
}
48
49/// MTProto 2.0 encrypted session state.
50///
51/// Wraps an `AuthKey` and tracks per-session counters (session_id, seq_no,
52/// last_msg_id, server salt).  Use [`EncryptedSession::pack`] to encrypt
53/// outgoing requests and [`EncryptedSession::unpack`] to decrypt incoming
54/// server frames.
55pub struct EncryptedSession {
56    auth_key:    AuthKey,
57    session_id:  i64,
58    sequence:    i32,
59    last_msg_id: i64,
60    /// Current server salt to include in outgoing messages.
61    pub salt:    i64,
62    /// Clock skew in seconds vs. server.
63    pub time_offset: i32,
64}
65
66impl EncryptedSession {
67    /// Create a new encrypted session from the output of `authentication::finish`.
68    pub fn new(auth_key: [u8; 256], first_salt: i64, time_offset: i32) -> Self {
69        let mut rnd = [0u8; 8];
70        getrandom::getrandom(&mut rnd).expect("getrandom");
71        Self {
72            auth_key: AuthKey::from_bytes(auth_key),
73            session_id: i64::from_le_bytes(rnd),
74            sequence: 0,
75            last_msg_id: 0,
76            salt: first_salt,
77            time_offset,
78        }
79    }
80
81    /// Compute the next message ID (based on corrected server time).
82    fn next_msg_id(&mut self) -> i64 {
83        let now = SystemTime::now()
84            .duration_since(UNIX_EPOCH).unwrap();
85        let secs = (now.as_secs() as i32).wrapping_add(self.time_offset) as u64;
86        let nanos = now.subsec_nanos() as u64;
87        let mut id = ((secs << 32) | (nanos << 2)) as i64;
88        if self.last_msg_id >= id { id = self.last_msg_id + 4; }
89        self.last_msg_id = id;
90        id
91    }
92
93    /// Next content-related seq_no (odd) and advance the counter.
94    /// Used for all regular RPC requests.
95    fn next_seq_no(&mut self) -> i32 {
96        let n = self.sequence * 2 + 1;
97        self.sequence += 1;
98        n
99    }
100
101    // ── G-05: non-content-related seq_no ─────────────────────────────────────
102    /// Return the current even seq_no WITHOUT advancing the counter.
103    ///
104    /// Service messages (MsgsAck, containers, etc.) MUST use an even seqno
105    /// per the MTProto spec so the server does not expect a reply.
106    /// grammers reference: `mtp/encrypted.rs: get_seq_no(content_related: bool)`
107    pub fn next_seq_no_ncr(&self) -> i32 {
108        self.sequence * 2
109    }
110
111    // ── G-03: seq_no correction on bad_msg codes 32/33 ───────────────────────
112    /// Correct the outgoing sequence counter when the server reports a
113    /// `bad_msg_notification` with error codes 32 (seq_no too low) or
114    /// 33 (seq_no too high).
115    ///
116    /// grammers reference: `mtp/encrypted.rs: handle_bad_notification() codes 32/33`
117    pub fn correct_seq_no(&mut self, code: u32) {
118        match code {
119            32 => {
120                // seq_no too low — jump forward so next send is well above server expectation
121                self.sequence += 64;
122                log::debug!("[layer] G-03 seq_no correction: code 32, bumped seq to {}", self.sequence);
123            }
124            33 => {
125                // seq_no too high — step back, but never below 1 to avoid
126                // re-using seq_no=1 which was already sent this session.
127                // Zeroing would make the next content message get seq_no=1,
128                // which the server already saw and will reject again with code 32.
129                self.sequence = self.sequence.saturating_sub(16).max(1);
130                log::debug!("[layer] G-03 seq_no correction: code 33, lowered seq to {}", self.sequence);
131            }
132            _ => {}
133        }
134    }
135
136    // ── G-12: dynamic time_offset correction ─────────────────────────────────
137    /// Re-derive the clock skew from a server-provided `msg_id`.
138    ///
139    /// Called on `bad_msg_notification` error codes 16 (msg_id too low) and
140    /// 17 (msg_id too high) so clock drift is corrected at any point in the
141    /// session, not only at connect time.
142    ///
143    /// grammers reference: `mtp/encrypted.rs: correct_time_offset(msg_id)`
144    pub fn correct_time_offset(&mut self, server_msg_id: i64) {
145        // Upper 32 bits of msg_id = Unix seconds on the server
146        let server_time = (server_msg_id >> 32) as i32;
147        let local_now   = SystemTime::now()
148            .duration_since(UNIX_EPOCH).unwrap()
149            .as_secs() as i32;
150        let new_offset = server_time.wrapping_sub(local_now);
151        log::debug!(
152            "[layer] G-12 time_offset correction: {} → {} (server_time={server_time})",
153            self.time_offset, new_offset
154        );
155        self.time_offset = new_offset;
156        // Also reset last_msg_id so next_msg_id rebuilds from corrected clock
157        self.last_msg_id = 0;
158    }
159
160    // ── G-02 / G-07 helpers ───────────────────────────────────────────────────
161
162    /// Allocate a fresh `(msg_id, seqno)` pair for an inner container message
163    /// WITHOUT encrypting anything.
164    ///
165    /// `content_related = true`  → odd seqno, advances counter  (regular RPCs)
166    /// `content_related = false` → even seqno, no advance       (MsgsAck, container)
167    ///
168    /// grammers reference: `mtp/encrypted.rs: get_seq_no(content_related)`
169    pub fn alloc_msg_seqno(&mut self, content_related: bool) -> (i64, i32) {
170        let msg_id = self.next_msg_id();
171        let seqno  = if content_related { self.next_seq_no() } else { self.next_seq_no_ncr() };
172        (msg_id, seqno)
173    }
174
175    /// Encrypt a pre-serialized TL body into a wire-ready MTProto frame.
176    ///
177    /// `content_related` controls whether the seqno is odd (content, advances
178    /// the counter) or even (service, no advance).
179    ///
180    /// Returns `(encrypted_wire_bytes, msg_id)`.
181    /// Used for G-02 (bad_msg re-send) and G-07 (container inner messages).
182    pub fn pack_body_with_msg_id(&mut self, body: &[u8], content_related: bool) -> (Vec<u8>, i64) {
183        let msg_id = self.next_msg_id();
184        let seq_no = if content_related { self.next_seq_no() } else { self.next_seq_no_ncr() };
185
186        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
187        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
188        buf.extend(self.salt.to_le_bytes());
189        buf.extend(self.session_id.to_le_bytes());
190        buf.extend(msg_id.to_le_bytes());
191        buf.extend(seq_no.to_le_bytes());
192        buf.extend((body.len() as u32).to_le_bytes());
193        buf.extend(body.iter().copied());
194
195        encrypt_data_v2(&mut buf, &self.auth_key);
196        (buf.as_ref().to_vec(), msg_id)
197    }
198
199    /// Encrypt a pre-built `msg_container` body (the container itself is
200    /// a non-content-related message with an even seqno).
201    ///
202    /// Returns `encrypted_wire_bytes`.
203    /// Used for G-07 (message container batching).
204    pub fn pack_container(&mut self, container_body: &[u8]) -> Vec<u8> {
205        let (wire, _msg_id) = self.pack_body_with_msg_id(container_body, false);
206        wire
207    }
208
209    // ── Original pack methods (unchanged) ────────────────────────────────────
210
211    /// Serialize and encrypt a TL function into a wire-ready byte vector.
212    pub fn pack_serializable<S: layer_tl_types::Serializable>(&mut self, call: &S) -> Vec<u8> {
213        let body = call.to_bytes();
214        let msg_id = self.next_msg_id();
215        let seq_no = self.next_seq_no();
216
217        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
218        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
219        buf.extend(self.salt.to_le_bytes());
220        buf.extend(self.session_id.to_le_bytes());
221        buf.extend(msg_id.to_le_bytes());
222        buf.extend(seq_no.to_le_bytes());
223        buf.extend((body.len() as u32).to_le_bytes());
224        buf.extend(body.iter().copied());
225
226        encrypt_data_v2(&mut buf, &self.auth_key);
227        buf.as_ref().to_vec()
228    }
229
230    /// Like `pack_serializable` but also returns the `msg_id`.
231    pub fn pack_serializable_with_msg_id<S: layer_tl_types::Serializable>(&mut self, call: &S) -> (Vec<u8>, i64) {
232        let body   = call.to_bytes();
233        let msg_id = self.next_msg_id();
234        let seq_no = self.next_seq_no();
235        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
236        let mut buf   = DequeBuffer::with_capacity(inner_len, 32);
237        buf.extend(self.salt.to_le_bytes());
238        buf.extend(self.session_id.to_le_bytes());
239        buf.extend(msg_id.to_le_bytes());
240        buf.extend(seq_no.to_le_bytes());
241        buf.extend((body.len() as u32).to_le_bytes());
242        buf.extend(body.iter().copied());
243        encrypt_data_v2(&mut buf, &self.auth_key);
244        (buf.as_ref().to_vec(), msg_id)
245    }
246
247    /// Like [`pack`] but also returns the `msg_id` allocated for this message.
248    pub fn pack_with_msg_id<R: RemoteCall>(&mut self, call: &R) -> (Vec<u8>, i64) {
249        let body   = call.to_bytes();
250        let msg_id = self.next_msg_id();
251        let seq_no = self.next_seq_no();
252        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
253        let mut buf   = DequeBuffer::with_capacity(inner_len, 32);
254        buf.extend(self.salt.to_le_bytes());
255        buf.extend(self.session_id.to_le_bytes());
256        buf.extend(msg_id.to_le_bytes());
257        buf.extend(seq_no.to_le_bytes());
258        buf.extend((body.len() as u32).to_le_bytes());
259        buf.extend(body.iter().copied());
260        encrypt_data_v2(&mut buf, &self.auth_key);
261        (buf.as_ref().to_vec(), msg_id)
262    }
263
264    /// Encrypt and frame a [`RemoteCall`] into a ready-to-send MTProto message.
265    pub fn pack<R: RemoteCall>(&mut self, call: &R) -> Vec<u8> {
266        let body = call.to_bytes();
267        let msg_id = self.next_msg_id();
268        let seq_no = self.next_seq_no();
269
270        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
271        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
272        buf.extend(self.salt.to_le_bytes());
273        buf.extend(self.session_id.to_le_bytes());
274        buf.extend(msg_id.to_le_bytes());
275        buf.extend(seq_no.to_le_bytes());
276        buf.extend((body.len() as u32).to_le_bytes());
277        buf.extend(body.iter().copied());
278
279        encrypt_data_v2(&mut buf, &self.auth_key);
280        buf.as_ref().to_vec()
281    }
282
283    /// Decrypt an encrypted server frame.
284    pub fn unpack(&self, frame: &mut [u8]) -> Result<DecryptedMessage, DecryptError> {
285        let plaintext = decrypt_data_v2(frame, &self.auth_key)
286            .map_err(DecryptError::Crypto)?;
287
288        if plaintext.len() < 32 {
289            return Err(DecryptError::FrameTooShort);
290        }
291
292        let salt       = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
293        let session_id = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
294        let msg_id     = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
295        let seq_no     = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
296        let body_len   = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
297
298        if session_id != self.session_id {
299            return Err(DecryptError::SessionMismatch);
300        }
301
302        let body = plaintext[32..32 + body_len.min(plaintext.len() - 32)].to_vec();
303
304        Ok(DecryptedMessage { salt, session_id, msg_id, seq_no, body })
305    }
306
307    /// Return the auth_key bytes (for persistence).
308    pub fn auth_key_bytes(&self) -> [u8; 256] { self.auth_key.to_bytes() }
309
310    /// Return the current session_id.
311    pub fn session_id(&self) -> i64 { self.session_id }
312}
313
314impl EncryptedSession {
315    /// Decrypt a frame using explicit key + session_id — no mutable state needed.
316    /// Used by the split-reader task so it can decrypt without locking the writer.
317    pub fn decrypt_frame(
318        auth_key:   &[u8; 256],
319        session_id: i64,
320        frame:      &mut [u8],
321    ) -> Result<DecryptedMessage, DecryptError> {
322        let key = AuthKey::from_bytes(*auth_key);
323        let plaintext = decrypt_data_v2(frame, &key)
324            .map_err(DecryptError::Crypto)?;
325        if plaintext.len() < 32 {
326            return Err(DecryptError::FrameTooShort);
327        }
328        let salt     = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
329        let sid      = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
330        let msg_id   = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
331        let seq_no   = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
332        let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
333        if sid != session_id {
334            return Err(DecryptError::SessionMismatch);
335        }
336        let body = plaintext[32..32 + body_len.min(plaintext.len() - 32)].to_vec();
337        Ok(DecryptedMessage { salt, session_id: sid, msg_id, seq_no, body })
338    }
339}