Skip to main content

ferogram_mtproto/
encrypted.rs

// Copyright (c) Ankit Chaubey <ankitchaubey.dev@gmail.com>
// SPDX-License-Identifier: MIT OR Apache-2.0
//
// ferogram: async Telegram MTProto client in Rust
// https://github.com/ankit-chaubey/ferogram
//
//
// If you use or modify this code, keep this notice at the top of your file
// and include the LICENSE-MIT or LICENSE-APACHE file from this repository:
// https://github.com/ankit-chaubey/ferogram

//! Encrypted MTProto 2.0 session (post auth-key).
//!
//! Once you have a `Finished` from [`crate::authentication`], construct an
//! [`EncryptedSession`] and use it to serialize and encrypt every outgoing
//! message and to decrypt and validate every incoming one.
17
18use std::collections::VecDeque;
19use std::time::{SystemTime, UNIX_EPOCH};
20
21use ferogram_crypto::{AuthKey, DequeBuffer, decrypt_data_v2, encrypt_data_v2};
22use ferogram_tl_types::RemoteCall;
23
/// Capacity of the rolling deduplication ring of server `msg_id`s.
///
/// After each insertion the ring is trimmed back to this many entries by
/// evicting the oldest `msg_id`.
const SEEN_MSG_IDS_MAX: usize = 500;
/// Maximum tolerated difference, in seconds, between the Unix time encoded in
/// a server `msg_id` and the corrected local clock.
///
/// Exceeding it only logs a warning; the message is still processed, on the
/// assumption that the clock offset gets corrected later (e.g. via
/// `bad_msg_notification`).
const MSG_ID_TIME_WINDOW_SECS: i64 = 300;
28
29/// Errors that can occur when decrypting a server message.
30#[derive(Debug)]
31pub enum DecryptError {
32    /// The underlying crypto layer rejected the message.
33    Crypto(ferogram_crypto::DecryptError),
34    /// The decrypted inner message was too short to contain a valid header.
35    FrameTooShort,
36    /// Session-ID mismatch (possible replay or wrong connection).
37    SessionMismatch,
38    /// Server msg_id is outside the ±300 s window of corrected local time.
39    MsgIdTimeWindow,
40    /// This msg_id was already seen in the rolling 500-entry buffer.
41    DuplicateMsgId,
42}
43
44impl std::fmt::Display for DecryptError {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        match self {
47            Self::Crypto(e) => write!(f, "crypto: {e}"),
48            Self::FrameTooShort => write!(f, "inner plaintext too short"),
49            Self::SessionMismatch => write!(f, "session_id mismatch"),
50            Self::MsgIdTimeWindow => write!(f, "server msg_id outside ±300 s time window"),
51            Self::DuplicateMsgId => write!(f, "duplicate server msg_id (replay)"),
52        }
53    }
54}
55impl std::error::Error for DecryptError {}
56
/// The inner payload extracted from a successfully decrypted server frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecryptedMessage {
    /// Server salt carried in the frame header.
    pub salt: i64,
    /// The `session_id` from the frame; the decrypt paths check it against
    /// the expected session before building this value.
    pub session_id: i64,
    /// The `msg_id` of the inner message.
    pub msg_id: i64,
    /// `seq_no` of the inner message.
    pub seq_no: i32,
    /// TL-serialized body of the inner message (header and padding stripped).
    pub body: Vec<u8>,
}
70
/// Dedup ring of server `msg_id`s, shared behind `Arc<Mutex<..>>`.
///
/// It is meant to live longer than any single `EncryptedSession`, so frames
/// replayed from an earlier connection cycle are still rejected after a
/// reconnect.
pub type SeenMsgIds = std::sync::Arc<std::sync::Mutex<std::collections::VecDeque<i64>>>;
76
77/// Allocate a fresh seen-msg_id ring.
78pub fn new_seen_msg_ids() -> SeenMsgIds {
79    std::sync::Arc::new(std::sync::Mutex::new(VecDeque::with_capacity(
80        SEEN_MSG_IDS_MAX,
81    )))
82}
83
84/// MTProto 2.0 encrypted session state.
85pub struct EncryptedSession {
86    auth_key: AuthKey,
87    session_id: i64,
88    sequence: i32,
89    last_msg_id: i64,
90    /// Current server salt to include in outgoing messages.
91    pub salt: i64,
92    /// Clock skew in seconds vs. server.
93    pub time_offset: i32,
94    /// Rolling 500-entry dedup buffer of seen server msg_ids.
95    /// Shared with the owning DcConnection so it survives reconnects.
96    seen_msg_ids: SeenMsgIds,
97}
98
99impl EncryptedSession {
100    /// Create a new encrypted session from the output of `authentication::finish`.
101    ///
102    /// `seen_msg_ids` should be the persistent ring owned by the `DcConnection`
103    /// (or any other owner that outlives individual sessions).  Pass
104    /// `new_seen_msg_ids()` for the very first connection on a slot.
105    pub fn new(auth_key: [u8; 256], first_salt: i64, time_offset: i32) -> Self {
106        Self::with_seen(auth_key, first_salt, time_offset, new_seen_msg_ids())
107    }
108
109    /// Like `new` but reuses an existing seen-msg_id ring (reconnect path).
110    pub fn with_seen(
111        auth_key: [u8; 256],
112        first_salt: i64,
113        time_offset: i32,
114        seen_msg_ids: SeenMsgIds,
115    ) -> Self {
116        let mut rnd = [0u8; 8];
117        getrandom::getrandom(&mut rnd).expect("getrandom");
118        Self {
119            auth_key: AuthKey::from_bytes(auth_key),
120            session_id: i64::from_le_bytes(rnd),
121            sequence: 0,
122            last_msg_id: 0,
123            salt: first_salt,
124            time_offset,
125            seen_msg_ids,
126        }
127    }
128
129    /// Return a clone of the shared seen-msg_id ring for passing to a
130    /// replacement session on reconnect.
131    pub fn seen_msg_ids(&self) -> SeenMsgIds {
132        std::sync::Arc::clone(&self.seen_msg_ids)
133    }
134
135    /// Compute the next message ID (based on corrected server time).
136    fn next_msg_id(&mut self) -> i64 {
137        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
138        // Keep arithmetic in u64: seconds since epoch with time_offset applied.
139        let secs = now.as_secs().wrapping_add(self.time_offset as i64 as u64);
140        let nanos = now.subsec_nanos() as u64;
141        let mut id = ((secs << 32) | (nanos << 2)) as i64;
142        if self.last_msg_id >= id {
143            id = self.last_msg_id + 4;
144        }
145        self.last_msg_id = id;
146        id
147    }
148
149    /// Next content-related seq_no (odd) and advance the counter.
150    /// Used for all regular RPC requests.
151    fn next_seq_no(&mut self) -> i32 {
152        let n = self.sequence * 2 + 1;
153        self.sequence += 1;
154        n
155    }
156
157    /// Return the current even seq_no WITHOUT advancing the counter.
158    ///
159    /// Service messages (MsgsAck, containers, etc.) MUST use an even seqno
160    /// per the MTProto spec so the server does not expect a reply.
161    pub fn next_seq_no_ncr(&self) -> i32 {
162        self.sequence * 2
163    }
164
165    /// Correct the outgoing sequence counter when the server reports a
166    /// `bad_msg_notification` with error codes 32 (seq_no too low) or
167    /// 33 (seq_no too high).
168    ///
169    pub fn correct_seq_no(&mut self, code: u32) {
170        match code {
171            32 => {
172                // seq_no too low: jump forward so next send is well above server expectation
173                self.sequence += 64;
174                log::debug!(
175                    "[ferogram] seq_no correction: code 32, bumped seq to {}",
176                    self.sequence
177                );
178            }
179            33 => {
180                // seq_no too high: step back, but never below 1 to avoid
181                // re-using seq_no=1 which was already sent this session.
182                // Zeroing would make the next content message get seq_no=1,
183                // which the server already saw and will reject again with code 32.
184                self.sequence = self.sequence.saturating_sub(16).max(1);
185                log::debug!(
186                    "[ferogram] seq_no correction: code 33, lowered seq to {}",
187                    self.sequence
188                );
189            }
190            _ => {}
191        }
192    }
193
194    /// Undo the last `next_seq_no` increment.
195    ///
196    /// Called before retrying a request after `bad_server_salt` so the resent
197    /// message uses the same seq_no slot rather than advancing the counter a
198    /// second time (which would produce seq_no too high → bad_msg_notification
199    /// code 33 → server closes TCP → early eof).
200    pub fn undo_seq_no(&mut self) {
201        self.sequence = self.sequence.saturating_sub(1);
202    }
203
204    /// Re-derive the clock skew from a server-provided `msg_id`.
205    ///
206    /// Called on `bad_msg_notification` error codes 16 (msg_id too low) and
207    /// 17 (msg_id too high) so clock drift is corrected at any point in the
208    /// session, not only at connect time.
209    ///
210    pub fn correct_time_offset(&mut self, server_msg_id: i64) {
211        // Upper 32 bits of msg_id = Unix seconds on the server
212        let server_time = (server_msg_id >> 32) as i32;
213        let local_now = SystemTime::now()
214            .duration_since(UNIX_EPOCH)
215            .unwrap()
216            .as_secs() as i32;
217        let new_offset = server_time.wrapping_sub(local_now);
218        log::debug!(
219            "[ferogram] time_offset correction: {} → {} (server_time={server_time})",
220            self.time_offset,
221            new_offset
222        );
223        self.time_offset = new_offset;
224        // Seed last_msg_id from the server's msg_id (bits 1-0 cleared to 0b00)
225        // so the next next_msg_id() call produces a strictly larger value.
226        self.last_msg_id = (server_msg_id & !0x3i64).max(self.last_msg_id);
227    }
228
229    /// Allocate a fresh `(msg_id, seqno)` pair for an inner container message
230    /// WITHOUT encrypting anything.
231    ///
232    /// `content_related = true`  → odd seqno, advances counter  (regular RPCs)
233    /// `content_related = false` → even seqno, no advance       (MsgsAck, container)
234    ///
235    pub fn alloc_msg_seqno(&mut self, content_related: bool) -> (i64, i32) {
236        let msg_id = self.next_msg_id();
237        let seqno = if content_related {
238            self.next_seq_no()
239        } else {
240            self.next_seq_no_ncr()
241        };
242        (msg_id, seqno)
243    }
244
245    /// Encrypt a pre-serialized TL body into a wire-ready MTProto frame.
246    ///
247    /// `content_related` controls whether the seqno is odd (content, advances
248    /// the counter) or even (service, no advance).
249    ///
250    /// Returns `(encrypted_wire_bytes, msg_id)`.
251    /// Used for (bad_msg re-send) and (container inner messages).
252    pub fn pack_body_with_msg_id(&mut self, body: &[u8], content_related: bool) -> (Vec<u8>, i64) {
253        let msg_id = self.next_msg_id();
254        let seq_no = if content_related {
255            self.next_seq_no()
256        } else {
257            self.next_seq_no_ncr()
258        };
259
260        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
261        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
262        buf.extend(self.salt.to_le_bytes());
263        buf.extend(self.session_id.to_le_bytes());
264        buf.extend(msg_id.to_le_bytes());
265        buf.extend(seq_no.to_le_bytes());
266        buf.extend((body.len() as u32).to_le_bytes());
267        buf.extend(body.iter().copied());
268
269        encrypt_data_v2(&mut buf, &self.auth_key);
270        (buf.as_ref().to_vec(), msg_id)
271    }
272
273    /// Encrypt a pre-built `msg_container` body (the container itself is
274    /// a non-content-related message with an even seqno).
275    ///
276    /// Returns `(encrypted_wire_bytes, container_msg_id)`.
277    /// The container_msg_id is needed so callers can map it back to inner
278    /// requests when a bad_msg_notification or bad_server_salt arrives for
279    /// the container rather than the individual inner message.
280    ///
281    pub fn pack_container(&mut self, container_body: &[u8]) -> (Vec<u8>, i64) {
282        self.pack_body_with_msg_id(container_body, false)
283    }
284
285    /// Encrypt `body` using a **caller-supplied** `msg_id` instead of generating one.
286    ///
287    /// Required by `auth.bindTempAuthKey`, which must use the same `msg_id`
288    /// in both the outer MTProto envelope and the inner `bind_auth_key_inner`.
289    pub fn pack_body_at_msg_id(&mut self, body: &[u8], msg_id: i64) -> Vec<u8> {
290        let seq_no = self.next_seq_no();
291        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
292        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
293        buf.extend(self.salt.to_le_bytes());
294        buf.extend(self.session_id.to_le_bytes());
295        buf.extend(msg_id.to_le_bytes());
296        buf.extend(seq_no.to_le_bytes());
297        buf.extend((body.len() as u32).to_le_bytes());
298        buf.extend(body.iter().copied());
299        encrypt_data_v2(&mut buf, &self.auth_key);
300        buf.as_ref().to_vec()
301    }
302
303    /// Serialize and encrypt a TL function into a wire-ready byte vector.
304    pub fn pack_serializable<S: ferogram_tl_types::Serializable>(&mut self, call: &S) -> Vec<u8> {
305        let body = call.to_bytes();
306        let msg_id = self.next_msg_id();
307        let seq_no = self.next_seq_no();
308
309        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
310        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
311        buf.extend(self.salt.to_le_bytes());
312        buf.extend(self.session_id.to_le_bytes());
313        buf.extend(msg_id.to_le_bytes());
314        buf.extend(seq_no.to_le_bytes());
315        buf.extend((body.len() as u32).to_le_bytes());
316        buf.extend(body.iter().copied());
317
318        encrypt_data_v2(&mut buf, &self.auth_key);
319        buf.as_ref().to_vec()
320    }
321
322    /// Like `pack_serializable` but also returns the `msg_id`.
323    pub fn pack_serializable_with_msg_id<S: ferogram_tl_types::Serializable>(
324        &mut self,
325        call: &S,
326    ) -> (Vec<u8>, i64) {
327        let body = call.to_bytes();
328        let msg_id = self.next_msg_id();
329        let seq_no = self.next_seq_no();
330        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
331        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
332        buf.extend(self.salt.to_le_bytes());
333        buf.extend(self.session_id.to_le_bytes());
334        buf.extend(msg_id.to_le_bytes());
335        buf.extend(seq_no.to_le_bytes());
336        buf.extend((body.len() as u32).to_le_bytes());
337        buf.extend(body.iter().copied());
338        encrypt_data_v2(&mut buf, &self.auth_key);
339        (buf.as_ref().to_vec(), msg_id)
340    }
341
342    /// Like [`pack`] but also returns the `msg_id` allocated for this message.
343    pub fn pack_with_msg_id<R: RemoteCall>(&mut self, call: &R) -> (Vec<u8>, i64) {
344        let body = call.to_bytes();
345        let msg_id = self.next_msg_id();
346        let seq_no = self.next_seq_no();
347        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
348        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
349        buf.extend(self.salt.to_le_bytes());
350        buf.extend(self.session_id.to_le_bytes());
351        buf.extend(msg_id.to_le_bytes());
352        buf.extend(seq_no.to_le_bytes());
353        buf.extend((body.len() as u32).to_le_bytes());
354        buf.extend(body.iter().copied());
355        encrypt_data_v2(&mut buf, &self.auth_key);
356        (buf.as_ref().to_vec(), msg_id)
357    }
358
359    /// Encrypt and frame a [`RemoteCall`] into a ready-to-send MTProto message.
360    pub fn pack<R: RemoteCall>(&mut self, call: &R) -> Vec<u8> {
361        let body = call.to_bytes();
362        let msg_id = self.next_msg_id();
363        let seq_no = self.next_seq_no();
364
365        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
366        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
367        buf.extend(self.salt.to_le_bytes());
368        buf.extend(self.session_id.to_le_bytes());
369        buf.extend(msg_id.to_le_bytes());
370        buf.extend(seq_no.to_le_bytes());
371        buf.extend((body.len() as u32).to_le_bytes());
372        buf.extend(body.iter().copied());
373
374        encrypt_data_v2(&mut buf, &self.auth_key);
375        buf.as_ref().to_vec()
376    }
377
378    /// Decrypt an encrypted server frame.
379    pub fn unpack(&self, frame: &mut [u8]) -> Result<DecryptedMessage, DecryptError> {
380        let plaintext = decrypt_data_v2(frame, &self.auth_key).map_err(DecryptError::Crypto)?;
381
382        if plaintext.len() < 32 {
383            return Err(DecryptError::FrameTooShort);
384        }
385
386        let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
387        let session_id = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
388        let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
389        let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
390        let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
391
392        if session_id != self.session_id {
393            return Err(DecryptError::SessionMismatch);
394        }
395
396        // Check server time (upper 32 bits of msg_id) against ±300 s window.
397        // Warn and continue: clock self-corrects via bad_msg_notification or pong.
398        let server_secs = (msg_id as u64 >> 32) as i64;
399        let now = SystemTime::now()
400            .duration_since(UNIX_EPOCH)
401            .unwrap()
402            .as_secs() as i64;
403        let corrected = now + self.time_offset as i64;
404        if (server_secs - corrected).abs() > MSG_ID_TIME_WINDOW_SECS {
405            log::warn!(
406                "[ferogram] msg_id time-window violation: server_secs={server_secs} \
407                 corrected_local={corrected} skew={}s: processing anyway, \
408                 clock will self-correct via bad_msg_notification/pong",
409                (server_secs - corrected).abs()
410            );
411        }
412
413        // Rolling 500-entry dedup.
414        {
415            let mut seen = self.seen_msg_ids.lock().unwrap();
416            if seen.contains(&msg_id) {
417                return Err(DecryptError::DuplicateMsgId);
418            }
419            seen.push_back(msg_id);
420            if seen.len() > SEEN_MSG_IDS_MAX {
421                seen.pop_front();
422            }
423        }
424
425        // Maximum body length: 16 MB.
426        if body_len > 16 * 1024 * 1024 {
427            return Err(DecryptError::FrameTooShort);
428        }
429        if 32 + body_len > plaintext.len() {
430            return Err(DecryptError::FrameTooShort);
431        }
432        // MTProto 2.0: minimum padding is 12 bytes; no upper bound.
433        let padding = plaintext.len() - 32 - body_len;
434        if padding < 12 {
435            return Err(DecryptError::FrameTooShort);
436        }
437        let body = plaintext[32..32 + body_len].to_vec();
438
439        Ok(DecryptedMessage {
440            salt,
441            session_id,
442            msg_id,
443            seq_no,
444            body,
445        })
446    }
447
448    /// Return the auth_key bytes (for persistence).
449    pub fn auth_key_bytes(&self) -> [u8; 256] {
450        self.auth_key.to_bytes()
451    }
452
453    /// Return the current session_id.
454    pub fn session_id(&self) -> i64 {
455        self.session_id
456    }
457
458    /// Reset session state: new random session_id, zeroed seq_no and last_msg_id,
459    /// cleared dedup buffer.
460    ///
461    /// Called on `bad_msg_notification` error codes 32/33 (seq_no mismatch).
462    /// Creates a new session_id and resets seq_no to avoid persistent desync.
463    pub fn reset_session(&mut self) {
464        let mut rnd = [0u8; 8];
465        getrandom::getrandom(&mut rnd).expect("getrandom");
466        let old_session = self.session_id;
467        self.session_id = i64::from_le_bytes(rnd);
468        self.sequence = 0;
469        self.last_msg_id = 0;
470        // Do not clear seen_msg_ids: the ring is shared with the owning
471        // DcConnection and must survive session resets to reject replayed frames.
472        log::debug!(
473            "[ferogram] session reset: {:#018x} → {:#018x}",
474            old_session,
475            self.session_id
476        );
477    }
478}
479
480impl EncryptedSession {
481    /// Like [`decrypt_frame`] but also performs seen-msg_id deduplication using the
482    /// supplied ring.  Pass `&self.inner.seen_msg_ids` from the client.
483    pub fn decrypt_frame_dedup(
484        auth_key: &[u8; 256],
485        session_id: i64,
486        frame: &mut [u8],
487        seen: &SeenMsgIds,
488    ) -> Result<DecryptedMessage, DecryptError> {
489        let msg = Self::decrypt_frame_with_offset(auth_key, session_id, frame, 0)?;
490        {
491            let mut s = seen.lock().unwrap();
492            if s.contains(&msg.msg_id) {
493                return Err(DecryptError::DuplicateMsgId);
494            }
495            s.push_back(msg.msg_id);
496            if s.len() > SEEN_MSG_IDS_MAX {
497                s.pop_front();
498            }
499        }
500        Ok(msg)
501    }
502
503    /// Decrypt a frame using explicit key + session_id: no mutable state needed.
504    /// Used by the split-reader task so it can decrypt without locking the writer.
505    /// `time_offset` is the session's current clock skew (seconds); pass 0 if unknown.
506    pub fn decrypt_frame(
507        auth_key: &[u8; 256],
508        session_id: i64,
509        frame: &mut [u8],
510    ) -> Result<DecryptedMessage, DecryptError> {
511        Self::decrypt_frame_with_offset(auth_key, session_id, frame, 0)
512    }
513
514    /// Like [`decrypt_frame`] but applies the time-window check with the given
515    /// `time_offset` (seconds, server_time − local_time).
516    pub fn decrypt_frame_with_offset(
517        auth_key: &[u8; 256],
518        session_id: i64,
519        frame: &mut [u8],
520        time_offset: i32,
521    ) -> Result<DecryptedMessage, DecryptError> {
522        let key = AuthKey::from_bytes(*auth_key);
523        let plaintext = decrypt_data_v2(frame, &key).map_err(DecryptError::Crypto)?;
524        if plaintext.len() < 32 {
525            return Err(DecryptError::FrameTooShort);
526        }
527        let salt = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
528        let sid = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
529        let msg_id = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
530        let seq_no = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
531        let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
532        if sid != session_id {
533            return Err(DecryptError::SessionMismatch);
534        }
535        // Warn but continue: clock self-corrects via bad_msg_notification or pong.
536        let server_secs = (msg_id as u64 >> 32) as i64;
537        let now = SystemTime::now()
538            .duration_since(UNIX_EPOCH)
539            .unwrap()
540            .as_secs() as i64;
541        let corrected = now + time_offset as i64;
542        if (server_secs - corrected).abs() > MSG_ID_TIME_WINDOW_SECS {
543            log::warn!(
544                "[ferogram] msg_id time-window violation (split-reader): server_secs={server_secs} \
545                 corrected_local={corrected} skew={}s: processing anyway",
546                (server_secs - corrected).abs()
547            );
548        }
549        // Maximum body length: 16 MB.
550        if body_len > 16 * 1024 * 1024 {
551            return Err(DecryptError::FrameTooShort);
552        }
553        if 32 + body_len > plaintext.len() {
554            return Err(DecryptError::FrameTooShort);
555        }
556        // MTProto 2.0: minimum padding is 12 bytes; no upper bound.
557        let padding = plaintext.len() - 32 - body_len;
558        if padding < 12 {
559            return Err(DecryptError::FrameTooShort);
560        }
561        let body = plaintext[32..32 + body_len].to_vec();
562        Ok(DecryptedMessage {
563            salt,
564            session_id: sid,
565            msg_id,
566            seq_no,
567            body,
568        })
569    }
570}