Skip to main content

ferogram_connect/
connection.rs

1// Copyright (c) Ankit Chaubey <ankitchaubey.dev@gmail.com>
2//
3// ferogram: async Telegram MTProto client in Rust
4// https://github.com/ankit-chaubey/ferogram
5//
6// Licensed under either the MIT License or the Apache License 2.0.
7// See the LICENSE-MIT or LICENSE-APACHE file in this repository:
8// https://github.com/ankit-chaubey/ferogram
9//
10// Feel free to use, modify, and share this code.
11// Please keep this notice when redistributing.
12
13use std::sync::Arc;
14use std::time::Duration;
15
16use socket2::TcpKeepalive;
17use tokio::io::AsyncWriteExt;
18use tokio::net::TcpStream;
19use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
20
21use ferogram_mtproto::{EncryptedSession, Session, authentication as auth};
22use ferogram_tl_types as tl;
23
24use crate::envelope::decode_bind_response;
25use crate::error::ConnectError;
26use crate::frame::{recv_frame_plain, send_frame};
27use crate::transport::recv_raw_frame;
28use crate::transport_kind::TransportKind;
29
/// Ping cadence in seconds.
// NOTE(review): the ping loop is not visible in this section; presumably this is
// the delay between keep-alive pings sent on an idle connection — confirm.
pub const PING_DELAY_SECS: u64 = 60;
/// Disconnect delay (seconds) associated with pings.
// NOTE(review): presumably the `disconnect_delay` sent with
// `ping_delay_disconnect`, kept 15 s above `PING_DELAY_SECS` so a late ping does
// not drop the connection — confirm against the ping loop.
pub const NO_PING_DISCONNECT: i32 = 75;

/// Idle time (seconds) before the first TCP keepalive probe; passed to
/// `TcpKeepalive::with_time` when a socket is opened.
const TCP_KEEPALIVE_IDLE_SECS: u64 = 10;
/// Interval (seconds) between TCP keepalive probes; passed to
/// `TcpKeepalive::with_interval`.
const TCP_KEEPALIVE_INTERVAL_SECS: u64 = 5;
/// Number of keepalive probes before the OS drops the connection; passed to
/// `TcpKeepalive::with_retries`, which is not applied on Windows.
#[cfg(not(target_os = "windows"))]
const TCP_KEEPALIVE_PROBES: u32 = 3;
37
/// How framing bytes are sent/received on a connection.
///
/// The three obfuscated variants (`Obfuscated`, `PaddedIntermediate` and
/// `FakeTls`) each carry an `Arc<Mutex<ObfuscatedCipher>>`, so the same cipher
/// state is shared (safely) between the writer task (TX / `encrypt`) and the
/// reader task (RX / `decrypt`).  The two directions are separate AES-CTR
/// instances inside `ObfuscatedCipher`, so locking is only needed to prevent
/// concurrent mutation of the struct, not to serialise TX vs RX.
///
/// Cloning a `FrameKind` clones the `Arc`s, so clones share counters/cipher state.
#[derive(Clone)]
pub enum FrameKind {
    /// Plain Abridged framing (selected by the single `0xef` init byte).
    Abridged,
    /// Plain Intermediate framing (selected by the `0xeeeeeeee` init bytes).
    Intermediate,
    /// Full framing (no init byte); both sequence counters start at 0.
    // NOTE(review): why `dead_code` is allowed is not visible here — presumably
    // the seqno counters are not read in every build; confirm before removing.
    #[allow(dead_code)]
    Full {
        send_seqno: Arc<std::sync::atomic::AtomicU32>,
        recv_seqno: Arc<std::sync::atomic::AtomicU32>,
    },
    /// Obfuscated2 over Abridged framing.
    Obfuscated {
        cipher: std::sync::Arc<tokio::sync::Mutex<ferogram_crypto::ObfuscatedCipher>>,
    },
    /// Obfuscated2 over Intermediate+padding framing (`0xDD` MTProxy).
    PaddedIntermediate {
        cipher: std::sync::Arc<tokio::sync::Mutex<ferogram_crypto::ObfuscatedCipher>>,
    },
    /// FakeTLS framing (`0xEE` MTProxy).
    FakeTls {
        cipher: std::sync::Arc<tokio::sync::Mutex<ferogram_crypto::ObfuscatedCipher>>,
    },
}
67
/// A single server-provided salt with its validity window.
///
/// `valid_since` / `valid_until` are compared against an approximation of the
/// server clock when the active salt is chosen.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FutureSalt {
    /// Server time from which the salt may be used.
    pub valid_since: i32,
    /// Server time at which the salt stops being treated as usable.
    pub valid_until: i32,
    /// The salt value itself.
    pub salt: i64,
}

/// Delay (seconds) before a salt is considered usable after its `valid_since`.
///
/// A salt is only picked once `valid_since + SALT_USE_DELAY <= now`.
pub const SALT_USE_DELAY: i32 = 60;
81
/// Write half of a split connection.  Held under `Mutex` in `ClientInner`.
///
/// Owns the `EncryptedSession` used for packing outgoing messages, plus the
/// per-request bookkeeping (`sent_bodies`, `container_map`) needed to re-send
/// requests the server rejects.
pub struct ConnectionWriter {
    /// Encrypted session used to pack outgoing messages (temp key under PFS).
    pub enc: EncryptedSession,
    /// Framing (and obfuscation cipher, if any) for this connection.
    pub frame_kind: FrameKind,
    /// PFS: permanent auth key to save in session. None when PFS is off.
    pub perm_auth_key: Option<[u8; 256]>,
    /// msg_ids of received content messages waiting to be acked.
    /// Drained into a MsgsAck on every outgoing frame (bundled into container
    /// when sending an RPC, or sent standalone after route_frame).
    pub pending_ack: Vec<i64>,
    /// raw TL body bytes of every sent request, keyed by msg_id.
    /// On bad_msg_notification the matching body is re-encrypted with a fresh
    /// msg_id and re-sent transparently.
    pub sent_bodies: std::collections::HashMap<i64, Vec<u8>>,
    /// maps container_msg_id -> inner request msg_id.
    /// When bad_msg_notification / bad_server_salt arrives for a container
    /// rather than the individual inner message, we look here to find the
    /// inner request to retry.
    pub container_map: std::collections::HashMap<i64, i64>,
    /// Stale-msg resends queued under the writer lock and drained after release
    /// to avoid holding the lock across TCP I/O.
    // NOTE(review): the meaning of the two i64s is not visible here; presumably
    // (msg_id, something like a request/container id) plus the TL body — confirm.
    pub new_session_resend_queue: Vec<(i64, i64, Vec<u8>)>,
    /// Future-salt pool.
    /// Sorted by `valid_since` ascending, so the oldest salt is at index 0 and
    /// the newest salt is LAST.
    pub salts: Vec<FutureSalt>,
    /// Server-time anchor received with the last GetFutureSalts response.
    /// (server_now, local_instant) lets us approximate server time at any
    /// moment so we can check whether a salt's valid_since window has opened.
    /// `None` until the first anchor is recorded.
    pub start_salt_time: Option<(i32, std::time::Instant)>,
}
116
117impl ConnectionWriter {
118    pub fn auth_key_bytes(&self) -> [u8; 256] {
119        self.perm_auth_key
120            .unwrap_or_else(|| self.enc.auth_key_bytes())
121    }
122    pub fn first_salt(&self) -> i64 {
123        self.enc.salt
124    }
125    pub fn time_offset(&self) -> i32 {
126        self.enc.time_offset
127    }
128
129    /// Proactively advance the active salt and prune expired ones.
130    ///
131    /// Called at the top of every RPC send.
132    /// Salts are sorted ascending by `valid_since` (oldest=index 0, newest=last).
133    ///
134    /// Prunes expired salts, then advances `enc.salt` to the freshest usable one.
135    ///
136    /// Returns `true` when the pool has shrunk to a single entry: caller should
137    /// fire a proactive `GetFutureSalts` via `try_request_salts()`.
138    pub fn advance_salt_if_needed(&mut self) -> bool {
139        let Some((server_now, start_instant)) = self.start_salt_time else {
140            return self.salts.len() <= 1;
141        };
142
143        // Approximate current server time.
144        let now = server_now + start_instant.elapsed().as_secs() as i32;
145
146        // Prune expired salts.
147        while self.salts.len() > 1 && now > self.salts[0].valid_until {
148            let expired = self.salts.remove(0);
149            tracing::debug!(
150                "[ferogram] salt {:#x} expired (valid_until={}), pruned",
151                expired.salt,
152                expired.valid_until,
153            );
154        }
155
156        // Advance to the freshest salt whose use-delay has opened AND
157        // which has not yet expired.  The `valid_until > now` guard is the
158        // critical safety: without it we can advance enc.salt to an already-
159        // expired entry from a stale FutureSalts pool, triggering immediate
160        // bad_server_salt rejection and re-entering the fetch loop.
161        if self.salts.len() > 1 {
162            let best = self
163                .salts
164                .iter()
165                .rev()
166                .find(|s| s.valid_since + SALT_USE_DELAY <= now && s.valid_until > now)
167                .map(|s| s.salt);
168            if let Some(salt) = best
169                && salt != self.enc.salt
170            {
171                tracing::debug!(
172                    "[ferogram] proactive salt cycle: {:#x} -> {:#x}",
173                    self.enc.salt,
174                    salt
175                );
176                self.enc.salt = salt;
177                // Prune salts whose valid_until has passed.
178                self.salts.retain(|s| s.valid_until > now);
179                if self.salts.is_empty() {
180                    // Safety net: keep a sentinel so we never go saltless.
181                    self.salts.push(FutureSalt {
182                        valid_since: 0,
183                        valid_until: i32::MAX,
184                        salt,
185                    });
186                }
187            }
188        }
189
190        self.salts.len() <= 1
191    }
192}
193
/// An established connection: the TCP stream, the encrypted MTProto session
/// negotiated on it, and the framing selected by the transport init bytes.
///
/// Built by `connect_raw` (fresh DH) or `connect_with_key` (existing key,
/// optionally with a PFS temp key).
pub struct Connection {
    /// Underlying socket (to the DC, a SOCKS5 proxy, or an MTProxy).
    pub stream: TcpStream,
    /// Encrypted session used for packing; holds the temp key when PFS is active.
    pub enc: EncryptedSession,
    /// Framing (and obfuscation cipher, if any) used on `stream`.
    pub frame_kind: FrameKind,
    /// When PFS is active, the permanent auth key (stored in session).
    /// `enc` holds the temp key; this field holds the perm key so
    /// `auth_key_bytes()` returns the right value to persist.
    pub perm_auth_key: Option<[u8; 256]>,
}
203
204impl Connection {
205    /// Open a TCP stream, optionally via SOCKS5, and apply transport init bytes.
206    async fn open_stream(
207        addr: &str,
208        socks5: Option<&crate::socks5::Socks5Config>,
209        transport: &TransportKind,
210        dc_id: i16,
211    ) -> Result<(TcpStream, FrameKind), ConnectError> {
212        let stream = match socks5 {
213            Some(proxy) => proxy.connect(addr).await?,
214            None => {
215                let stream = TcpStream::connect(addr).await.map_err(ConnectError::Io)?;
216                stream.set_nodelay(true).ok();
217                {
218                    let sock = socket2::SockRef::from(&stream);
219                    let keepalive = TcpKeepalive::new()
220                        .with_time(Duration::from_secs(TCP_KEEPALIVE_IDLE_SECS))
221                        .with_interval(Duration::from_secs(TCP_KEEPALIVE_INTERVAL_SECS));
222                    #[cfg(not(target_os = "windows"))]
223                    let keepalive = keepalive.with_retries(TCP_KEEPALIVE_PROBES);
224                    sock.set_tcp_keepalive(&keepalive).ok();
225                }
226                stream
227            }
228        };
229        Self::apply_transport_init(stream, transport, dc_id).await
230    }
231
232    /// Open a stream routed through an MTProxy (connects to proxy host:port,
233    /// not to the Telegram DC address).
234    async fn open_stream_mtproxy(
235        mtproxy: &crate::proxy::MtProxyConfig,
236        dc_id: i16,
237    ) -> Result<(TcpStream, FrameKind), ConnectError> {
238        let stream = mtproxy.connect().await?;
239        stream.set_nodelay(true).ok();
240        Self::apply_transport_init(stream, &mtproxy.transport, dc_id).await
241    }
242
243    async fn apply_transport_init(
244        mut stream: TcpStream,
245        transport: &TransportKind,
246        dc_id: i16,
247    ) -> Result<(TcpStream, FrameKind), ConnectError> {
248        match transport {
249            TransportKind::Abridged => {
250                stream.write_all(&[0xef]).await?;
251                Ok((stream, FrameKind::Abridged))
252            }
253            TransportKind::Intermediate => {
254                stream.write_all(&[0xee, 0xee, 0xee, 0xee]).await?;
255                Ok((stream, FrameKind::Intermediate))
256            }
257            TransportKind::Full => {
258                // Full transport has no init byte.
259                Ok((
260                    stream,
261                    FrameKind::Full {
262                        send_seqno: Arc::new(std::sync::atomic::AtomicU32::new(0)),
263                        recv_seqno: Arc::new(std::sync::atomic::AtomicU32::new(0)),
264                    },
265                ))
266            }
267            TransportKind::Obfuscated { secret } => {
268                use sha2::Digest;
269
270                // Random 64-byte nonce: retry until it passes the reserved-pattern
271                // Reject reserved nonce patterns that could be misidentified as HTTP
272                // or another MTProto framing tag by a proxy or DPI filter.
273                let mut nonce = [0u8; 64];
274                loop {
275                    getrandom::getrandom(&mut nonce)
276                        .map_err(|_| ConnectError::other("getrandom"))?;
277                    let first = u32::from_le_bytes(nonce[0..4].try_into().expect("4-byte slice"));
278                    let second = u32::from_le_bytes(nonce[4..8].try_into().expect("4-byte slice"));
279                    let bad = nonce[0] == 0xEF
280                        || first == 0x44414548 // HEAD
281                        || first == 0x54534F50 // POST
282                        || first == 0x20544547 // GET
283                        || first == 0x4954504f // OPTIONS
284                        || first == 0xEEEEEEEE
285                        || first == 0xDDDDDDDD
286                        || first == 0x02010316
287                        || second == 0x00000000;
288                    if !bad {
289                        break;
290                    }
291                }
292
293                // Key derivation from nonce[8..56]:
294                //   TX: key=nonce[8..40]  iv=nonce[40..56]
295                //   RX: key=rev[0..32]    iv=rev[32..48]   (rev = nonce[8..56] reversed)
296                // When an MTProxy secret is present, each 32-byte key becomes
297                // SHA-256(raw_key_slice || secret) for MTProxy key derivation.
298                let tx_raw: [u8; 32] = nonce[8..40].try_into().expect("32-byte slice");
299                let tx_iv: [u8; 16] = nonce[40..56].try_into().expect("16-byte slice");
300                let mut rev48 = nonce[8..56].to_vec();
301                rev48.reverse();
302                let rx_raw: [u8; 32] = rev48[0..32].try_into().expect("32-byte slice");
303                let rx_iv: [u8; 16] = rev48[32..48].try_into().expect("16-byte slice");
304
305                let (tx_key, rx_key): ([u8; 32], [u8; 32]) = if let Some(s) = secret {
306                    let mut h = sha2::Sha256::new();
307                    h.update(tx_raw);
308                    h.update(s.as_ref());
309                    let tx: [u8; 32] = h.finalize().into();
310
311                    let mut h = sha2::Sha256::new();
312                    h.update(rx_raw);
313                    h.update(s.as_ref());
314                    let rx: [u8; 32] = h.finalize().into();
315                    (tx, rx)
316                } else {
317                    (tx_raw, rx_raw)
318                };
319
320                // Stamp protocol id (Abridged = 0xEFEFEFEF) at nonce[56..60]
321                // and DC id as little-endian i16 at nonce[60..62].
322                nonce[56] = 0xef;
323                nonce[57] = 0xef;
324                nonce[58] = 0xef;
325                nonce[59] = 0xef;
326                let dc_bytes = dc_id.to_le_bytes();
327                nonce[60] = dc_bytes[0];
328                nonce[61] = dc_bytes[1];
329
330                // Encrypt nonce[56..64] in-place using the TX cipher advanced
331                // past the first 56 bytes (which are sent as plaintext).
332                //
333                // The same cipher instance must be used for both the nonce tail
334                // encryption and all subsequent TX data: AES-CTR is a single continuous
335                // stream; the TX position after encrypting the full 64-byte nonce is 64.
336                let mut cipher =
337                    ferogram_crypto::ObfuscatedCipher::from_keys(&tx_key, &tx_iv, &rx_key, &rx_iv);
338                // Advance TX past nonce[0..56] (sent as plaintext, not encrypted).
339                let mut skip = [0u8; 56];
340                cipher.encrypt(&mut skip);
341                // Encrypt nonce[56..64] in-place; cipher TX is now at position 64.
342                cipher.encrypt(&mut nonce[56..64]);
343
344                stream.write_all(&nonce).await?;
345
346                let cipher_arc = std::sync::Arc::new(tokio::sync::Mutex::new(cipher));
347                Ok((stream, FrameKind::Obfuscated { cipher: cipher_arc }))
348            }
349            TransportKind::PaddedIntermediate { secret } => {
350                use sha2::Digest;
351                let mut nonce = [0u8; 64];
352                loop {
353                    getrandom::getrandom(&mut nonce)
354                        .map_err(|_| ConnectError::other("getrandom"))?;
355                    let first = u32::from_le_bytes(nonce[0..4].try_into().expect("4-byte slice"));
356                    let second = u32::from_le_bytes(nonce[4..8].try_into().expect("4-byte slice"));
357                    let bad = nonce[0] == 0xEF
358                        || first == 0x44414548
359                        || first == 0x54534F50
360                        || first == 0x20544547
361                        || first == 0x4954504f
362                        || first == 0xEEEEEEEE
363                        || first == 0xDDDDDDDD
364                        || first == 0x02010316
365                        || second == 0x00000000;
366                    if !bad {
367                        break;
368                    }
369                }
370                let tx_raw: [u8; 32] = nonce[8..40].try_into().expect("32-byte slice");
371                let tx_iv: [u8; 16] = nonce[40..56].try_into().expect("16-byte slice");
372                let mut rev48 = nonce[8..56].to_vec();
373                rev48.reverse();
374                let rx_raw: [u8; 32] = rev48[0..32].try_into().expect("32-byte slice");
375                let rx_iv: [u8; 16] = rev48[32..48].try_into().expect("16-byte slice");
376                let (tx_key, rx_key): ([u8; 32], [u8; 32]) = if let Some(s) = secret {
377                    let mut h = sha2::Sha256::new();
378                    h.update(tx_raw);
379                    h.update(s.as_ref());
380                    let tx: [u8; 32] = h.finalize().into();
381                    let mut h = sha2::Sha256::new();
382                    h.update(rx_raw);
383                    h.update(s.as_ref());
384                    let rx: [u8; 32] = h.finalize().into();
385                    (tx, rx)
386                } else {
387                    (tx_raw, rx_raw)
388                };
389                // PaddedIntermediate tag = 0xDDDDDDDD
390                nonce[56] = 0xdd;
391                nonce[57] = 0xdd;
392                nonce[58] = 0xdd;
393                nonce[59] = 0xdd;
394                let dc_bytes = dc_id.to_le_bytes();
395                nonce[60] = dc_bytes[0];
396                nonce[61] = dc_bytes[1];
397                let mut cipher =
398                    ferogram_crypto::ObfuscatedCipher::from_keys(&tx_key, &tx_iv, &rx_key, &rx_iv);
399                let mut skip = [0u8; 56];
400                cipher.encrypt(&mut skip);
401                cipher.encrypt(&mut nonce[56..64]);
402                stream.write_all(&nonce).await?;
403                let cipher_arc = std::sync::Arc::new(tokio::sync::Mutex::new(cipher));
404                Ok((stream, FrameKind::PaddedIntermediate { cipher: cipher_arc }))
405            }
406            TransportKind::FakeTls { secret, domain } => {
407                // Fake TLS 1.3 ClientHello with HMAC-SHA256 random field.
408                // After the handshake, data flows as TLS Application Data records
409                // over a shared Obfuscated2 cipher seeded from the secret+HMAC.
410                let domain_bytes = domain.as_bytes();
411                let mut session_id = [0u8; 32];
412                getrandom::getrandom(&mut session_id)
413                    .map_err(|_| ConnectError::other("getrandom"))?;
414
415                // Build ClientHello body (random placeholder = zeros)
416                let cipher_suites: &[u8] = &[0x00, 0x04, 0x13, 0x01, 0x13, 0x02];
417                let compression: &[u8] = &[0x01, 0x00];
418                let sni_name_len = domain_bytes.len() as u16;
419                let sni_list_len = sni_name_len + 3;
420                let sni_ext_len = sni_list_len + 2;
421                let mut sni_ext = Vec::new();
422                sni_ext.extend_from_slice(&[0x00, 0x00]);
423                sni_ext.extend_from_slice(&sni_ext_len.to_be_bytes());
424                sni_ext.extend_from_slice(&sni_list_len.to_be_bytes());
425                sni_ext.push(0x00);
426                sni_ext.extend_from_slice(&sni_name_len.to_be_bytes());
427                sni_ext.extend_from_slice(domain_bytes);
428                let sup_ver: &[u8] = &[0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04];
429                let sup_grp: &[u8] = &[0x00, 0x0a, 0x00, 0x04, 0x00, 0x02, 0x00, 0x1d];
430                let sess_tick: &[u8] = &[0x00, 0x23, 0x00, 0x00];
431                let ext_body_len = sni_ext.len() + sup_ver.len() + sup_grp.len() + sess_tick.len();
432                let mut extensions = Vec::new();
433                extensions.extend_from_slice(&(ext_body_len as u16).to_be_bytes());
434                extensions.extend_from_slice(&sni_ext);
435                extensions.extend_from_slice(sup_ver);
436                extensions.extend_from_slice(sup_grp);
437                extensions.extend_from_slice(sess_tick);
438
439                let mut hello_body = Vec::new();
440                hello_body.extend_from_slice(&[0x03, 0x03]);
441                hello_body.extend_from_slice(&[0u8; 32]); // random placeholder
442                hello_body.push(session_id.len() as u8);
443                hello_body.extend_from_slice(&session_id);
444                hello_body.extend_from_slice(cipher_suites);
445                hello_body.extend_from_slice(compression);
446                hello_body.extend_from_slice(&extensions);
447
448                let hs_len = hello_body.len() as u32;
449                let mut handshake = vec![
450                    0x01,
451                    ((hs_len >> 16) & 0xff) as u8,
452                    ((hs_len >> 8) & 0xff) as u8,
453                    (hs_len & 0xff) as u8,
454                ];
455                handshake.extend_from_slice(&hello_body);
456
457                let rec_len = handshake.len() as u16;
458                let mut record = Vec::new();
459                record.push(0x16);
460                record.extend_from_slice(&[0x03, 0x01]);
461                record.extend_from_slice(&rec_len.to_be_bytes());
462                record.extend_from_slice(&handshake);
463
464                // HMAC-SHA256(secret, record) -> fill random field at offset 11
465                use sha2::Digest;
466                let random_offset = 5 + 4 + 2; // TLS-rec(5) + HS-hdr(4) + version(2)
467                let hmac_result: [u8; 32] = {
468                    use hmac::{Hmac, Mac};
469                    type HmacSha256 = Hmac<sha2::Sha256>;
470                    let mut mac = HmacSha256::new_from_slice(secret)
471                        .map_err(|_| ConnectError::other("HMAC key error"))?;
472                    mac.update(&record);
473                    mac.finalize().into_bytes().into()
474                };
475                record[random_offset..random_offset + 32].copy_from_slice(&hmac_result);
476                stream.write_all(&record).await?;
477
478                // Derive Obfuscated2 key from secret + HMAC
479                let mut h = sha2::Sha256::new();
480                h.update(secret.as_ref());
481                h.update(hmac_result);
482                let derived: [u8; 32] = h.finalize().into();
483                let iv = [0u8; 16];
484                let cipher =
485                    ferogram_crypto::ObfuscatedCipher::from_keys(&derived, &iv, &derived, &iv);
486                let cipher_arc = std::sync::Arc::new(tokio::sync::Mutex::new(cipher));
487                Ok((stream, FrameKind::FakeTls { cipher: cipher_arc }))
488            }
489            TransportKind::Http => {
490                // HTTP transport is handled in dc_pool - fall back to Abridged framing.
491                stream.write_all(&[0xef]).await?;
492                Ok((stream, FrameKind::Abridged))
493            }
494        }
495    }
496
497    pub async fn connect_raw(
498        addr: &str,
499        socks5: Option<&crate::socks5::Socks5Config>,
500        mtproxy: Option<&crate::proxy::MtProxyConfig>,
501        transport: &TransportKind,
502        dc_id: i16,
503    ) -> Result<Self, ConnectError> {
504        let t_label = match transport {
505            TransportKind::Abridged => "Abridged",
506            TransportKind::Obfuscated { .. } => "Obfuscated",
507            TransportKind::PaddedIntermediate { .. } => "PaddedIntermediate",
508            TransportKind::Http => "Http",
509            TransportKind::Intermediate => "Intermediate",
510            TransportKind::Full => "Full",
511            TransportKind::FakeTls { .. } => "FakeTls",
512        };
513        tracing::debug!("[ferogram] Connecting to {addr} ({t_label}) DH …");
514
515        let addr2 = addr.to_string();
516        let socks5_c = socks5.cloned();
517        let mtproxy_c = mtproxy.cloned();
518        let transport_c = transport.clone();
519
520        let fut = async move {
521            let (mut stream, frame_kind) = if let Some(ref mp) = mtproxy_c {
522                Self::open_stream_mtproxy(mp, dc_id).await?
523            } else {
524                Self::open_stream(&addr2, socks5_c.as_ref(), &transport_c, dc_id).await?
525            };
526
527            let mut plain = Session::new();
528
529            let (req1, s1) = auth::step1().map_err(|e| ConnectError::other(e.to_string()))?;
530            send_frame(
531                &mut stream,
532                &plain.pack(&req1).to_plaintext_bytes(),
533                &frame_kind,
534            )
535            .await?;
536            let res_pq: tl::enums::ResPq = recv_frame_plain(&mut stream, &frame_kind).await?;
537
538            let (req2, s2) = auth::step2(s1, res_pq, dc_id as i32)
539                .map_err(|e| ConnectError::other(e.to_string()))?;
540            send_frame(
541                &mut stream,
542                &plain.pack(&req2).to_plaintext_bytes(),
543                &frame_kind,
544            )
545            .await?;
546            let dh: tl::enums::ServerDhParams = recv_frame_plain(&mut stream, &frame_kind).await?;
547
548            let (req3, s3) = auth::step3(s2, dh).map_err(|e| ConnectError::other(e.to_string()))?;
549            send_frame(
550                &mut stream,
551                &plain.pack(&req3).to_plaintext_bytes(),
552                &frame_kind,
553            )
554            .await?;
555            let ans: tl::enums::SetClientDhParamsAnswer =
556                recv_frame_plain(&mut stream, &frame_kind).await?;
557
558            // Retry loop for dh_gen_retry (up to 5 attempts).
559            let done = {
560                let mut result =
561                    auth::finish(s3, ans).map_err(|e| ConnectError::other(e.to_string()))?;
562                let mut attempts = 0u8;
563                loop {
564                    match result {
565                        auth::FinishResult::Done(d) => break d,
566                        auth::FinishResult::Retry {
567                            retry_id,
568                            dh_params,
569                            nonce,
570                            server_nonce,
571                            new_nonce,
572                        } => {
573                            attempts += 1;
574                            if attempts >= 5 {
575                                return Err(ConnectError::other(
576                                    "dh_gen_retry exceeded 5 attempts",
577                                ));
578                            }
579                            let (req_retry, s3_retry) = auth::retry_step3(
580                                &dh_params,
581                                nonce,
582                                server_nonce,
583                                new_nonce,
584                                retry_id,
585                            )
586                            .map_err(|e| ConnectError::other(e.to_string()))?;
587                            send_frame(
588                                &mut stream,
589                                &plain.pack(&req_retry).to_plaintext_bytes(),
590                                &frame_kind,
591                            )
592                            .await?;
593                            let ans_retry: tl::enums::SetClientDhParamsAnswer =
594                                recv_frame_plain(&mut stream, &frame_kind).await?;
595                            result = auth::finish(s3_retry, ans_retry)
596                                .map_err(|e| ConnectError::other(e.to_string()))?;
597                        }
598                    }
599                }
600            };
601            tracing::debug!("[ferogram] DH complete ✓");
602
603            Ok::<Self, ConnectError>(Self {
604                stream,
605                enc: EncryptedSession::new(done.auth_key, done.first_salt, done.time_offset),
606                frame_kind,
607                perm_auth_key: None, // connect_raw produces the perm key itself
608            })
609        };
610
611        tokio::time::timeout(Duration::from_secs(15), fut)
612            .await
613            .map_err(|_| {
614                ConnectError::other(format!("DH handshake with {addr} timed out after 15 s"))
615            })?
616    }
617
618    #[allow(clippy::too_many_arguments)]
619    pub async fn connect_with_key(
620        addr: &str,
621        auth_key: [u8; 256],
622        first_salt: i64,
623        time_offset: i32,
624        socks5: Option<&crate::socks5::Socks5Config>,
625        mtproxy: Option<&crate::proxy::MtProxyConfig>,
626        transport: &TransportKind,
627        dc_id: i16,
628        pfs: bool,
629    ) -> Result<Self, ConnectError> {
630        let addr2 = addr.to_string();
631        let socks5_c = socks5.cloned();
632        let mtproxy_c = mtproxy.cloned();
633        let transport_c = transport.clone();
634
635        let fut = async move {
636            let (mut stream, frame_kind) = if let Some(ref mp) = mtproxy_c {
637                Self::open_stream_mtproxy(mp, dc_id).await?
638            } else {
639                Self::open_stream(&addr2, socks5_c.as_ref(), &transport_c, dc_id).await?
640            };
641            if pfs {
642                tracing::debug!("[ferogram] PFS: temp DH bind for DC{dc_id}");
643                match Self::do_pfs_bind(&mut stream, &frame_kind, &auth_key, dc_id).await {
644                    Ok(temp_enc) => {
645                        tracing::info!("[ferogram] PFS bind complete DC{dc_id}");
646                        return Ok(Self {
647                            stream,
648                            enc: temp_enc,
649                            frame_kind,
650                            perm_auth_key: Some(auth_key),
651                        });
652                    }
653                    Err(e) => {
654                        tracing::warn!(
655                            "[ferogram] PFS bind failed DC{dc_id} ({e}); falling back to perm key"
656                        );
657                        // Graceful fallback: reconnect because DH frames left the stream dirty.
658                        // Return error and let the caller handle retry without PFS.
659                        return Err(e);
660                    }
661                }
662            }
663            Ok::<Self, ConnectError>(Self {
664                stream,
665                enc: EncryptedSession::new(auth_key, first_salt, time_offset),
666                frame_kind,
667                perm_auth_key: None,
668            })
669        };
670
671        tokio::time::timeout(Duration::from_secs(30), fut)
672            .await
673            .map_err(|_| {
674                ConnectError::other(format!("connect_with_key to {addr} timed out after 30 s"))
675            })?
676    }
677
678    /// Perform a fresh temp-key DH on an already-open stream, then
679    /// send `auth.bindTempAuthKey` encrypted with the temp key.
680    /// Returns an `EncryptedSession` keyed with the bound temp key.
681    async fn do_pfs_bind(
682        stream: &mut TcpStream,
683        frame_kind: &FrameKind,
684        perm_auth_key: &[u8; 256],
685        dc_id: i16,
686    ) -> Result<EncryptedSession, ConnectError> {
687        use ferogram_mtproto::{
688            auth_key_id_from_key, encrypt_bind_inner, gen_msg_id, new_seen_msg_ids,
689            serialize_bind_temp_auth_key,
690        };
691        const TEMP_EXPIRES: i32 = 86_400; // 24 h
692
693        // temp-key DH
694        let mut plain = Session::new();
695
696        let (req1, s1) = auth::step1().map_err(|e| ConnectError::other(e.to_string()))?;
697        send_frame(stream, &plain.pack(&req1).to_plaintext_bytes(), frame_kind).await?;
698        let res_pq: tl::enums::ResPq = recv_frame_plain(stream, frame_kind).await?;
699
700        let (req2, s2) = ferogram_mtproto::step2_temp(s1, res_pq, dc_id as i32, TEMP_EXPIRES)
701            .map_err(|e| ConnectError::other(e.to_string()))?;
702        send_frame(stream, &plain.pack(&req2).to_plaintext_bytes(), frame_kind).await?;
703        let dh: tl::enums::ServerDhParams = recv_frame_plain(stream, frame_kind).await?;
704
705        let (req3, s3) = auth::step3(s2, dh).map_err(|e| ConnectError::other(e.to_string()))?;
706        send_frame(stream, &plain.pack(&req3).to_plaintext_bytes(), frame_kind).await?;
707        let ans: tl::enums::SetClientDhParamsAnswer = recv_frame_plain(stream, frame_kind).await?;
708
709        let done = {
710            let mut result =
711                auth::finish(s3, ans).map_err(|e| ConnectError::other(e.to_string()))?;
712            let mut attempts = 0u8;
713            loop {
714                match result {
715                    ferogram_mtproto::FinishResult::Done(d) => break d,
716                    ferogram_mtproto::FinishResult::Retry {
717                        retry_id,
718                        dh_params,
719                        nonce,
720                        server_nonce,
721                        new_nonce,
722                    } => {
723                        attempts += 1;
724                        if attempts >= 5 {
725                            return Err(ConnectError::other(
726                                "PFS temp DH retry exceeded 5 attempts",
727                            ));
728                        }
729                        let (rr, s3r) = ferogram_mtproto::retry_step3(
730                            &dh_params,
731                            nonce,
732                            server_nonce,
733                            new_nonce,
734                            retry_id,
735                        )
736                        .map_err(|e| ConnectError::other(e.to_string()))?;
737                        send_frame(stream, &plain.pack(&rr).to_plaintext_bytes(), frame_kind)
738                            .await?;
739                        let ar: tl::enums::SetClientDhParamsAnswer =
740                            recv_frame_plain(stream, frame_kind).await?;
741                        result = auth::finish(s3r, ar)
742                            .map_err(|e| ConnectError::other(e.to_string()))?;
743                    }
744                }
745            }
746        };
747
748        let temp_key = done.auth_key;
749        let temp_salt = done.first_salt;
750        let temp_offset = done.time_offset;
751
752        // build bindTempAuthKey body
753        let temp_key_id = auth_key_id_from_key(&temp_key);
754        let perm_key_id = auth_key_id_from_key(perm_auth_key);
755
756        let mut nonce_buf = [0u8; 8];
757        getrandom::getrandom(&mut nonce_buf).map_err(|_| ConnectError::other("getrandom nonce"))?;
758        let nonce = i64::from_le_bytes(nonce_buf);
759
760        let server_now = std::time::SystemTime::now()
761            .duration_since(std::time::UNIX_EPOCH)
762            .unwrap()
763            .as_secs() as i32
764            + temp_offset;
765        let expires_at = server_now + TEMP_EXPIRES;
766
767        let seen = new_seen_msg_ids();
768        let mut temp_enc = EncryptedSession::with_seen(temp_key, temp_salt, temp_offset, seen);
769        let temp_session_id = temp_enc.session_id();
770
771        let msg_id = gen_msg_id();
772        let enc_msg = encrypt_bind_inner(
773            perm_auth_key,
774            msg_id,
775            nonce,
776            temp_key_id,
777            perm_key_id,
778            temp_session_id,
779            expires_at,
780        );
781        let bind_body = serialize_bind_temp_auth_key(perm_key_id, nonce, expires_at, &enc_msg);
782
783        // send encrypted bind request
784        let wire = temp_enc.pack_body_at_msg_id(&bind_body, msg_id);
785        send_frame(stream, &wire, frame_kind).await?;
786
787        // Receive and verify response.
788        // The server may send informational frames first (msgs_ack, new_session_created)
789        // before the actual rpc_result{boolTrue}, so we loop up to 5 frames.
790        for attempt in 0u8..5 {
791            let mut raw = recv_raw_frame(stream, frame_kind).await?;
792            let decrypted = temp_enc
793                .unpack(&mut raw)
794                .map_err(|e| ConnectError::other(format!("PFS bind decrypt: {e:?}")))?;
795            match decode_bind_response(&decrypted.body) {
796                Ok(()) => {
797                    // bindTempAuthKey succeeds under the temp key; keep the session
798                    // sequence as-is so subsequent RPCs continue from the same MTProto
799                    // message stream.
800                    return Ok(temp_enc);
801                }
802                Err(ref e) if e == "__need_more__" => {
803                    tracing::debug!(
804                        "[ferogram] PFS bind (DC{dc_id}): informational frame {attempt}, reading next"
805                    );
806                    continue;
807                }
808                Err(reason) => {
809                    tracing::error!("[ferogram] PFS bind server response (DC{dc_id}): {reason}");
810                    return Err(ConnectError::other(format!(
811                        "auth.bindTempAuthKey: {reason}"
812                    )));
813                }
814            }
815        }
816        Err(ConnectError::other(
817            "auth.bindTempAuthKey: no boolTrue after 5 frames",
818        ))
819    }
820
821    pub fn auth_key_bytes(&self) -> [u8; 256] {
822        // When PFS is active, perm_auth_key is the key to persist in the session.
823        // enc.auth_key_bytes() would return the short-lived temp key instead.
824        self.perm_auth_key
825            .unwrap_or_else(|| self.enc.auth_key_bytes())
826    }
827
828    /// Split into a write-only `ConnectionWriter` and the TCP read half.
829    pub fn into_writer(self) -> (ConnectionWriter, OwnedWriteHalf, OwnedReadHalf, FrameKind) {
830        let (read_half, write_half) = self.stream.into_split();
831        let writer = ConnectionWriter {
832            enc: self.enc,
833            frame_kind: self.frame_kind.clone(),
834            perm_auth_key: self.perm_auth_key,
835            pending_ack: Vec::new(),
836            new_session_resend_queue: Vec::new(),
837            sent_bodies: std::collections::HashMap::new(),
838            container_map: std::collections::HashMap::new(),
839            salts: Vec::new(),
840            start_salt_time: None,
841        };
842        (writer, write_half, read_half, self.frame_kind)
843    }
844}