Skip to main content

layer_mtproto/
encrypted.rs

1//! Encrypted MTProto 2.0 session (post auth-key).
2//!
3//! Once you have a `Finished` from [`crate::authentication`], construct an
4//! [`EncryptedSession`] and use it to serialize/deserialize all subsequent
5//! messages.
6
7use std::collections::VecDeque;
8use std::time::{SystemTime, UNIX_EPOCH};
9
10use layer_crypto::{AuthKey, DequeBuffer, decrypt_data_v2, encrypt_data_v2};
11use layer_tl_types::RemoteCall;
12
/// Capacity of the rolling window of already-accepted server msg_ids used for
/// replay detection (mirrors tDesktop's 500-entry seen-msg_id set).
const SEEN_MSG_IDS_MAX: usize = 500;
/// Largest tolerated distance, in seconds, between the timestamp embedded in a
/// server msg_id and the corrected local clock (five minutes).
const MSG_ID_TIME_WINDOW_SECS: i64 = 5 * 60;
17
18/// Errors that can occur when decrypting a server message.
19#[derive(Debug)]
20pub enum DecryptError {
21    /// The underlying crypto layer rejected the message.
22    Crypto(layer_crypto::DecryptError),
23    /// The decrypted inner message was too short to contain a valid header.
24    FrameTooShort,
25    /// Session-ID mismatch (possible replay or wrong connection).
26    SessionMismatch,
27    /// Server msg_id is outside the ±300 s window of corrected local time.
28    MsgIdTimeWindow,
29    /// This msg_id was already seen in the rolling 500-entry buffer.
30    DuplicateMsgId,
31}
32
33impl std::fmt::Display for DecryptError {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        match self {
36            Self::Crypto(e) => write!(f, "crypto: {e}"),
37            Self::FrameTooShort => write!(f, "inner plaintext too short"),
38            Self::SessionMismatch => write!(f, "session_id mismatch"),
39            Self::MsgIdTimeWindow => write!(f, "server msg_id outside ±300 s time window"),
40            Self::DuplicateMsgId => write!(f, "duplicate server msg_id (replay)"),
41        }
42    }
43}
44impl std::error::Error for DecryptError {}
45
/// The inner payload extracted from a successfully decrypted server frame.
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so callers can log, copy and
/// compare decoded messages (it holds only plain integers and the body bytes).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecryptedMessage {
    /// `salt` sent by the server.
    pub salt: i64,
    /// The `session_id` from the frame.
    pub session_id: i64,
    /// The `msg_id` of the inner message.
    pub msg_id: i64,
    /// `seq_no` of the inner message.
    pub seq_no: i32,
    /// TL-serialized body of the inner message.
    pub body: Vec<u8>,
}
59
60/// MTProto 2.0 encrypted session state.
61pub struct EncryptedSession {
62    auth_key: AuthKey,
63    session_id: i64,
64    sequence: i32,
65    last_msg_id: i64,
66    /// Current server salt to include in outgoing messages.
67    pub salt: i64,
68    /// Clock skew in seconds vs. server.
69    pub time_offset: i32,
70    /// Rolling 500-entry dedup buffer of seen server msg_ids.
71    seen_msg_ids: std::sync::Mutex<VecDeque<i64>>,
72}
73
74impl EncryptedSession {
75    /// Create a new encrypted session from the output of `authentication::finish`.
76    pub fn new(auth_key: [u8; 256], first_salt: i64, time_offset: i32) -> Self {
77        let mut rnd = [0u8; 8];
78        getrandom::getrandom(&mut rnd).expect("getrandom");
79        Self {
80            auth_key: AuthKey::from_bytes(auth_key),
81            session_id: i64::from_le_bytes(rnd),
82            sequence: 0,
83            last_msg_id: 0,
84            salt: first_salt,
85            time_offset,
86            seen_msg_ids: std::sync::Mutex::new(VecDeque::with_capacity(SEEN_MSG_IDS_MAX)),
87        }
88    }
89
90    /// Compute the next message ID (based on corrected server time).
91    fn next_msg_id(&mut self) -> i64 {
92        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
93        let secs = (now.as_secs() as i32).wrapping_add(self.time_offset) as u64;
94        let nanos = now.subsec_nanos() as u64;
95        let mut id = ((secs << 32) | (nanos << 2)) as i64;
96        if self.last_msg_id >= id {
97            id = self.last_msg_id + 4;
98        }
99        self.last_msg_id = id;
100        id
101    }
102
103    /// Next content-related seq_no (odd) and advance the counter.
104    /// Used for all regular RPC requests.
105    fn next_seq_no(&mut self) -> i32 {
106        let n = self.sequence * 2 + 1;
107        self.sequence += 1;
108        n
109    }
110
111    /// Return the current even seq_no WITHOUT advancing the counter.
112    ///
113    /// Service messages (MsgsAck, containers, etc.) MUST use an even seqno
114    /// per the MTProto spec so the server does not expect a reply.
115    pub fn next_seq_no_ncr(&self) -> i32 {
116        self.sequence * 2
117    }
118
119    /// Correct the outgoing sequence counter when the server reports a
120    /// `bad_msg_notification` with error codes 32 (seq_no too low) or
121    /// 33 (seq_no too high).
122    ///
123    pub fn correct_seq_no(&mut self, code: u32) {
124        match code {
125            32 => {
126                // seq_no too low: jump forward so next send is well above server expectation
127                self.sequence += 64;
128                log::debug!(
129                    "[layer] seq_no correction: code 32, bumped seq to {}",
130                    self.sequence
131                );
132            }
133            33 => {
134                // seq_no too high: step back, but never below 1 to avoid
135                // re-using seq_no=1 which was already sent this session.
136                // Zeroing would make the next content message get seq_no=1,
137                // which the server already saw and will reject again with code 32.
138                self.sequence = self.sequence.saturating_sub(16).max(1);
139                log::debug!(
140                    "[layer] seq_no correction: code 33, lowered seq to {}",
141                    self.sequence
142                );
143            }
144            _ => {}
145        }
146    }
147
148    /// Re-derive the clock skew from a server-provided `msg_id`.
149    ///
150    /// Called on `bad_msg_notification` error codes 16 (msg_id too low) and
151    /// 17 (msg_id too high) so clock drift is corrected at any point in the
152    /// session, not only at connect time.
153    ///
154    pub fn correct_time_offset(&mut self, server_msg_id: i64) {
155        // Upper 32 bits of msg_id = Unix seconds on the server
156        let server_time = (server_msg_id >> 32) as i32;
157        let local_now = SystemTime::now()
158            .duration_since(UNIX_EPOCH)
159            .unwrap()
160            .as_secs() as i32;
161        let new_offset = server_time.wrapping_sub(local_now);
162        log::debug!(
163            "[layer] time_offset correction: {} → {} (server_time={server_time})",
164            self.time_offset,
165            new_offset
166        );
167        self.time_offset = new_offset;
168        // Also reset last_msg_id so next_msg_id rebuilds from corrected clock
169        self.last_msg_id = 0;
170    }
171
172    /// Allocate a fresh `(msg_id, seqno)` pair for an inner container message
173    /// WITHOUT encrypting anything.
174    ///
175    /// `content_related = true`  → odd seqno, advances counter  (regular RPCs)
176    /// `content_related = false` → even seqno, no advance       (MsgsAck, container)
177    ///
178    pub fn alloc_msg_seqno(&mut self, content_related: bool) -> (i64, i32) {
179        let msg_id = self.next_msg_id();
180        let seqno = if content_related {
181            self.next_seq_no()
182        } else {
183            self.next_seq_no_ncr()
184        };
185        (msg_id, seqno)
186    }
187
188    /// Encrypt a pre-serialized TL body into a wire-ready MTProto frame.
189    ///
190    /// `content_related` controls whether the seqno is odd (content, advances
191    /// the counter) or even (service, no advance).
192    ///
193    /// Returns `(encrypted_wire_bytes, msg_id)`.
194    /// Used for (bad_msg re-send) and (container inner messages).
195    pub fn pack_body_with_msg_id(&mut self, body: &[u8], content_related: bool) -> (Vec<u8>, i64) {
196        let msg_id = self.next_msg_id();
197        let seq_no = if content_related {
198            self.next_seq_no()
199        } else {
200            self.next_seq_no_ncr()
201        };
202
203        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
204        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
205        buf.extend(self.salt.to_le_bytes());
206        buf.extend(self.session_id.to_le_bytes());
207        buf.extend(msg_id.to_le_bytes());
208        buf.extend(seq_no.to_le_bytes());
209        buf.extend((body.len() as u32).to_le_bytes());
210        buf.extend(body.iter().copied());
211
212        encrypt_data_v2(&mut buf, &self.auth_key);
213        (buf.as_ref().to_vec(), msg_id)
214    }
215
216    /// Encrypt a pre-built `msg_container` body (the container itself is
217    /// a non-content-related message with an even seqno).
218    ///
219    /// Returns `(encrypted_wire_bytes, container_msg_id)`.
220    /// The container_msg_id is needed so callers can map it back to inner
221    /// requests when a bad_msg_notification or bad_server_salt arrives for
222    /// the container rather than the individual inner message.
223    ///
224    pub fn pack_container(&mut self, container_body: &[u8]) -> (Vec<u8>, i64) {
225        self.pack_body_with_msg_id(container_body, false)
226    }
227
228    // Original pack methods (unchanged)
229
230    /// Serialize and encrypt a TL function into a wire-ready byte vector.
231    pub fn pack_serializable<S: layer_tl_types::Serializable>(&mut self, call: &S) -> Vec<u8> {
232        let body = call.to_bytes();
233        let msg_id = self.next_msg_id();
234        let seq_no = self.next_seq_no();
235
236        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
237        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
238        buf.extend(self.salt.to_le_bytes());
239        buf.extend(self.session_id.to_le_bytes());
240        buf.extend(msg_id.to_le_bytes());
241        buf.extend(seq_no.to_le_bytes());
242        buf.extend((body.len() as u32).to_le_bytes());
243        buf.extend(body.iter().copied());
244
245        encrypt_data_v2(&mut buf, &self.auth_key);
246        buf.as_ref().to_vec()
247    }
248
249    /// Like `pack_serializable` but also returns the `msg_id`.
250    pub fn pack_serializable_with_msg_id<S: layer_tl_types::Serializable>(
251        &mut self,
252        call: &S,
253    ) -> (Vec<u8>, i64) {
254        let body = call.to_bytes();
255        let msg_id = self.next_msg_id();
256        let seq_no = self.next_seq_no();
257        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
258        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
259        buf.extend(self.salt.to_le_bytes());
260        buf.extend(self.session_id.to_le_bytes());
261        buf.extend(msg_id.to_le_bytes());
262        buf.extend(seq_no.to_le_bytes());
263        buf.extend((body.len() as u32).to_le_bytes());
264        buf.extend(body.iter().copied());
265        encrypt_data_v2(&mut buf, &self.auth_key);
266        (buf.as_ref().to_vec(), msg_id)
267    }
268
269    /// Like [`pack`] but also returns the `msg_id` allocated for this message.
270    pub fn pack_with_msg_id<R: RemoteCall>(&mut self, call: &R) -> (Vec<u8>, i64) {
271        let body = call.to_bytes();
272        let msg_id = self.next_msg_id();
273        let seq_no = self.next_seq_no();
274        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
275        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
276        buf.extend(self.salt.to_le_bytes());
277        buf.extend(self.session_id.to_le_bytes());
278        buf.extend(msg_id.to_le_bytes());
279        buf.extend(seq_no.to_le_bytes());
280        buf.extend((body.len() as u32).to_le_bytes());
281        buf.extend(body.iter().copied());
282        encrypt_data_v2(&mut buf, &self.auth_key);
283        (buf.as_ref().to_vec(), msg_id)
284    }
285
286    /// Encrypt and frame a [`RemoteCall`] into a ready-to-send MTProto message.
287    pub fn pack<R: RemoteCall>(&mut self, call: &R) -> Vec<u8> {
288        let body = call.to_bytes();
289        let msg_id = self.next_msg_id();
290        let seq_no = self.next_seq_no();
291
292        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
293        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
294        buf.extend(self.salt.to_le_bytes());
295        buf.extend(self.session_id.to_le_bytes());
296        buf.extend(msg_id.to_le_bytes());
297        buf.extend(seq_no.to_le_bytes());
298        buf.extend((body.len() as u32).to_le_bytes());
299        buf.extend(body.iter().copied());
300
301        encrypt_data_v2(&mut buf, &self.auth_key);
302        buf.as_ref().to_vec()
303    }
304
305    /// Decrypt an encrypted server frame.
306    pub fn unpack(&self, frame: &mut [u8]) -> Result<DecryptedMessage, DecryptError> {
307        let plaintext = decrypt_data_v2(frame, &self.auth_key).map_err(DecryptError::Crypto)?;
308
309        if plaintext.len() < 32 {
310            return Err(DecryptError::FrameTooShort);
311        }
312
313        let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
314        let session_id = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
315        let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
316        let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
317        let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
318
319        if session_id != self.session_id {
320            return Err(DecryptError::SessionMismatch);
321        }
322
323        // #3 reject if server time (upper 32 bits of msg_id) deviates > 300 s.
324        let server_secs = (msg_id as u64 >> 32) as i64;
325        let now = SystemTime::now()
326            .duration_since(UNIX_EPOCH)
327            .unwrap()
328            .as_secs() as i64;
329        let corrected = now + self.time_offset as i64;
330        if (server_secs - corrected).abs() > MSG_ID_TIME_WINDOW_SECS {
331            return Err(DecryptError::MsgIdTimeWindow);
332        }
333
334        // #2 rolling 500-entry dedup.
335        {
336            let mut seen = self.seen_msg_ids.lock().unwrap();
337            if seen.contains(&msg_id) {
338                return Err(DecryptError::DuplicateMsgId);
339            }
340            seen.push_back(msg_id);
341            if seen.len() > SEEN_MSG_IDS_MAX {
342                seen.pop_front();
343            }
344        }
345
346        if 32 + body_len > plaintext.len() {
347            return Err(DecryptError::FrameTooShort);
348        }
349        let body = plaintext[32..32 + body_len].to_vec();
350
351        Ok(DecryptedMessage {
352            salt,
353            session_id,
354            msg_id,
355            seq_no,
356            body,
357        })
358    }
359
360    /// Return the auth_key bytes (for persistence).
361    pub fn auth_key_bytes(&self) -> [u8; 256] {
362        self.auth_key.to_bytes()
363    }
364
365    /// Return the current session_id.
366    pub fn session_id(&self) -> i64 {
367        self.session_id
368    }
369}
370
371impl EncryptedSession {
372    /// Decrypt a frame using explicit key + session_id: no mutable state needed.
373    /// Used by the split-reader task so it can decrypt without locking the writer.
374    /// `time_offset` is the session's current clock skew (seconds); pass 0 if unknown.
375    pub fn decrypt_frame(
376        auth_key: &[u8; 256],
377        session_id: i64,
378        frame: &mut [u8],
379    ) -> Result<DecryptedMessage, DecryptError> {
380        Self::decrypt_frame_with_offset(auth_key, session_id, frame, 0)
381    }
382
383    /// Like [`decrypt_frame`] but applies the time-window check with the given
384    /// `time_offset` (seconds, server_time − local_time).
385    pub fn decrypt_frame_with_offset(
386        auth_key: &[u8; 256],
387        session_id: i64,
388        frame: &mut [u8],
389        time_offset: i32,
390    ) -> Result<DecryptedMessage, DecryptError> {
391        let key = AuthKey::from_bytes(*auth_key);
392        let plaintext = decrypt_data_v2(frame, &key).map_err(DecryptError::Crypto)?;
393        if plaintext.len() < 32 {
394            return Err(DecryptError::FrameTooShort);
395        }
396        let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
397        let sid = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
398        let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
399        let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
400        let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
401        if sid != session_id {
402            return Err(DecryptError::SessionMismatch);
403        }
404        // Time-window check.
405        let server_secs = (msg_id as u64 >> 32) as i64;
406        let now = SystemTime::now()
407            .duration_since(UNIX_EPOCH)
408            .unwrap()
409            .as_secs() as i64;
410        let corrected = now + time_offset as i64;
411        if (server_secs - corrected).abs() > MSG_ID_TIME_WINDOW_SECS {
412            return Err(DecryptError::MsgIdTimeWindow);
413        }
414        if 32 + body_len > plaintext.len() {
415            return Err(DecryptError::FrameTooShort);
416        }
417        let body = plaintext[32..32 + body_len].to_vec();
418        Ok(DecryptedMessage {
419            salt,
420            session_id: sid,
421            msg_id,
422            seq_no,
423            body,
424        })
425    }
426}