Skip to main content

layer_mtproto/
encrypted.rs

1//! Encrypted MTProto 2.0 session (post auth-key).
2//!
3//! Once you have a `Finished` from [`crate::authentication`], construct an
4//! [`EncryptedSession`] and use it to serialize/deserialize all subsequent
5//! messages.
6
7use std::time::{SystemTime, UNIX_EPOCH};
8
9use layer_crypto::{AuthKey, DequeBuffer, decrypt_data_v2, encrypt_data_v2};
10use layer_tl_types::RemoteCall;
11
12
13/// Errors that can occur when decrypting a server message.
14#[derive(Debug)]
15pub enum DecryptError {
16    /// The underlying crypto layer rejected the message.
17    Crypto(layer_crypto::DecryptError),
18    /// The decrypted inner message was too short to contain a valid header.
19    FrameTooShort,
20    /// Session-ID mismatch (possible replay or wrong connection).
21    SessionMismatch,
22}
23
24impl std::fmt::Display for DecryptError {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        match self {
27            Self::Crypto(e) => write!(f, "crypto: {e}"),
28            Self::FrameTooShort => write!(f, "inner plaintext too short"),
29            Self::SessionMismatch => write!(f, "session_id mismatch"),
30        }
31    }
32}
33impl std::error::Error for DecryptError {}
34
/// The inner payload extracted from a successfully decrypted server frame.
///
/// Derives `Debug` so values (and `Result`s carrying them) can be logged and
/// used with `Result::unwrap_err` / `expect_err`. Derives `Clone`,
/// `PartialEq` and `Eq` so messages can be duplicated and compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecryptedMessage {
    /// Server `salt` carried in the frame header.
    pub salt: i64,
    /// The `session_id` from the frame (checked against the expected session
    /// before this value is constructed).
    pub session_id: i64,
    /// The `msg_id` of the inner message.
    pub msg_id: i64,
    /// `seq_no` of the inner message.
    pub seq_no: i32,
    /// TL-serialized body of the inner message.
    pub body: Vec<u8>,
}
48
49/// MTProto 2.0 encrypted session state.
50///
51/// Wraps an `AuthKey` and tracks per-session counters (session_id, seq_no,
52/// last_msg_id, server salt).  Use [`EncryptedSession::pack`] to encrypt
53/// outgoing requests and [`EncryptedSession::unpack`] to decrypt incoming
54/// server frames.
55pub struct EncryptedSession {
56    auth_key:    AuthKey,
57    session_id:  i64,
58    sequence:    i32,
59    last_msg_id: i64,
60    /// Current server salt to include in outgoing messages.
61    pub salt:    i64,
62    /// Clock skew in seconds vs. server.
63    pub time_offset: i32,
64}
65
66impl EncryptedSession {
67    /// Create a new encrypted session from the output of `authentication::finish`.
68    pub fn new(auth_key: [u8; 256], first_salt: i64, time_offset: i32) -> Self {
69        let mut rnd = [0u8; 8];
70        getrandom::getrandom(&mut rnd).expect("getrandom");
71        Self {
72            auth_key: AuthKey::from_bytes(auth_key),
73            session_id: i64::from_le_bytes(rnd),
74            sequence: 0,
75            last_msg_id: 0,
76            salt: first_salt,
77            time_offset,
78        }
79    }
80
81    /// Compute the next message ID (based on corrected server time).
82    fn next_msg_id(&mut self) -> i64 {
83        let now = SystemTime::now()
84            .duration_since(UNIX_EPOCH).unwrap();
85        let secs = (now.as_secs() as i32).wrapping_add(self.time_offset) as u64;
86        let nanos = now.subsec_nanos() as u64;
87        let mut id = ((secs << 32) | (nanos << 2)) as i64;
88        if self.last_msg_id >= id { id = self.last_msg_id + 4; }
89        self.last_msg_id = id;
90        id
91    }
92
93    /// Next content-related seq_no (odd) and advance the counter.
94    fn next_seq_no(&mut self) -> i32 {
95        let n = self.sequence * 2 + 1;
96        self.sequence += 1;
97        n
98    }
99
100    /// Serialize and encrypt a TL function into a wire-ready byte vector.
101    ///
102    /// Layout of the plaintext before encryption:
103    /// ```text
104    /// salt:       i64
105    /// session_id: i64
106    /// msg_id:     i64
107    /// seq_no:     i32
108    /// body_len:   i32
109    /// body:       [u8; body_len]
110    /// ```
111    /// Like `pack` but only requires `Serializable` (not `RemoteCall`).
112    /// Useful for generic wrapper types like `InvokeWithLayer<InitConnection<X>>`
113    /// where the return type is determined by the inner call, not the wrapper.
114    pub fn pack_serializable<S: layer_tl_types::Serializable>(&mut self, call: &S) -> Vec<u8> {
115        let body = call.to_bytes();
116        let msg_id = self.next_msg_id();
117        let seq_no = self.next_seq_no();
118
119        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
120        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
121        buf.extend(self.salt.to_le_bytes());
122        buf.extend(self.session_id.to_le_bytes());
123        buf.extend(msg_id.to_le_bytes());
124        buf.extend(seq_no.to_le_bytes());
125        buf.extend((body.len() as u32).to_le_bytes());
126        buf.extend(body.iter().copied());
127
128        encrypt_data_v2(&mut buf, &self.auth_key);
129        buf.as_ref().to_vec()
130    }
131
132
133    /// Like [`pack_serializable`] but also returns the `msg_id`.
134    /// Used by the split-writer path for write RPCs (Serializable but not RemoteCall).
135    pub fn pack_serializable_with_msg_id<S: layer_tl_types::Serializable>(&mut self, call: &S) -> (Vec<u8>, i64) {
136        let body   = call.to_bytes();
137        let msg_id = self.next_msg_id();
138        let seq_no = self.next_seq_no();
139        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
140        let mut buf   = DequeBuffer::with_capacity(inner_len, 32);
141        buf.extend(self.salt.to_le_bytes());
142        buf.extend(self.session_id.to_le_bytes());
143        buf.extend(msg_id.to_le_bytes());
144        buf.extend(seq_no.to_le_bytes());
145        buf.extend((body.len() as u32).to_le_bytes());
146        buf.extend(body.iter().copied());
147        encrypt_data_v2(&mut buf, &self.auth_key);
148        (buf.as_ref().to_vec(), msg_id)
149    }
150
151    /// Like [`pack`] but also returns the `msg_id` allocated for this message.
152    ///
153    /// Used by the async client to register a pending RPC reply channel keyed
154    /// by `msg_id` *before* sending the packet.
155    pub fn pack_with_msg_id<R: RemoteCall>(&mut self, call: &R) -> (Vec<u8>, i64) {
156        let body   = call.to_bytes();
157        let msg_id = self.next_msg_id();
158        let seq_no = self.next_seq_no();
159        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
160        let mut buf   = DequeBuffer::with_capacity(inner_len, 32);
161        buf.extend(self.salt.to_le_bytes());
162        buf.extend(self.session_id.to_le_bytes());
163        buf.extend(msg_id.to_le_bytes());
164        buf.extend(seq_no.to_le_bytes());
165        buf.extend((body.len() as u32).to_le_bytes());
166        buf.extend(body.iter().copied());
167        encrypt_data_v2(&mut buf, &self.auth_key);
168        (buf.as_ref().to_vec(), msg_id)
169    }
170
171    /// Encrypt and frame a [`RemoteCall`] into a ready-to-send MTProto message.
172    ///
173    /// Returns the encrypted bytes to pass directly to the transport layer.
174    pub fn pack<R: RemoteCall>(&mut self, call: &R) -> Vec<u8> {
175        let body = call.to_bytes();
176        let msg_id = self.next_msg_id();
177        let seq_no = self.next_seq_no();
178
179        // Build plaintext inner payload
180        let inner_len = 8 + 8 + 8 + 4 + 4 + body.len();
181        // Front capacity = 32 for auth_key_id + msg_key
182        let mut buf = DequeBuffer::with_capacity(inner_len, 32);
183        buf.extend(self.salt.to_le_bytes());
184        buf.extend(self.session_id.to_le_bytes());
185        buf.extend(msg_id.to_le_bytes());
186        buf.extend(seq_no.to_le_bytes());
187        buf.extend((body.len() as u32).to_le_bytes());
188        buf.extend(body.iter().copied());
189
190        encrypt_data_v2(&mut buf, &self.auth_key);
191        buf.as_ref().to_vec()
192    }
193
194    /// Decrypt an encrypted server frame.
195    ///
196    /// `frame` should be a raw frame received from the transport (already
197    /// stripped of the abridged-length prefix).
198    pub fn unpack(&self, frame: &mut Vec<u8>) -> Result<DecryptedMessage, DecryptError> {
199        let plaintext = decrypt_data_v2(frame, &self.auth_key)
200            .map_err(DecryptError::Crypto)?;
201
202        // inner: salt(8) + session_id(8) + msg_id(8) + seq_no(4) + len(4) + body
203        if plaintext.len() < 32 {
204            return Err(DecryptError::FrameTooShort);
205        }
206
207        let salt       = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
208        let session_id = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
209        let msg_id     = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
210        let seq_no     = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
211        let body_len   = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
212
213        if session_id != self.session_id {
214            return Err(DecryptError::SessionMismatch);
215        }
216
217        let body = plaintext[32..32 + body_len.min(plaintext.len() - 32)].to_vec();
218
219        Ok(DecryptedMessage { salt, session_id, msg_id, seq_no, body })
220    }
221
222    /// Return the auth_key bytes (for persistence).
223    pub fn auth_key_bytes(&self) -> [u8; 256] { self.auth_key.to_bytes() }
224
225    /// Return the current session_id.
226    pub fn session_id(&self) -> i64 { self.session_id }
227}
228
229impl EncryptedSession {
230    /// Decrypt a frame using explicit key + session_id — no mutable state needed.
231    /// Used by the split-reader task so it can decrypt without locking the writer.
232    pub fn decrypt_frame(
233        auth_key:   &[u8; 256],
234        session_id: i64,
235        frame:      &mut Vec<u8>,
236    ) -> Result<DecryptedMessage, DecryptError> {
237        let key = AuthKey::from_bytes(*auth_key);
238        let plaintext = decrypt_data_v2(frame, &key)
239            .map_err(DecryptError::Crypto)?;
240        if plaintext.len() < 32 {
241            return Err(DecryptError::FrameTooShort);
242        }
243        let salt     = i64::from_le_bytes(plaintext[..8].try_into().unwrap());
244        let sid      = i64::from_le_bytes(plaintext[8..16].try_into().unwrap());
245        let msg_id   = i64::from_le_bytes(plaintext[16..24].try_into().unwrap());
246        let seq_no   = i32::from_le_bytes(plaintext[24..28].try_into().unwrap());
247        let body_len = u32::from_le_bytes(plaintext[28..32].try_into().unwrap()) as usize;
248        if sid != session_id {
249            return Err(DecryptError::SessionMismatch);
250        }
251        let body = plaintext[32..32 + body_len.min(plaintext.len() - 32)].to_vec();
252        Ok(DecryptedMessage { salt, session_id: sid, msg_id, seq_no, body })
253    }
254}