Skip to main content

ferogram_connect/
connection.rs

1// Copyright (c) Ankit Chaubey <ankitchaubey.dev@gmail.com>
2//
3// ferogram: async Telegram MTProto client in Rust
4// https://github.com/ankit-chaubey/ferogram
5//
6// Licensed under either the MIT License or the Apache License 2.0.
7// See the LICENSE-MIT or LICENSE-APACHE file in this repository:
8// https://github.com/ankit-chaubey/ferogram
9//
10// Feel free to use, modify, and share this code.
11// Please keep this notice when redistributing.
12
13use std::sync::Arc;
14use std::time::Duration;
15
16use socket2::TcpKeepalive;
17use tokio::io::AsyncWriteExt;
18use tokio::net::TcpStream;
19use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
20
21use ferogram_mtproto::{EncryptedSession, Session, authentication as auth};
22use ferogram_tl_types as tl;
23
24use crate::envelope::decode_bind_response;
25use crate::error::ConnectError;
26use crate::frame::{recv_frame_plain, send_frame};
27use crate::transport::recv_raw_frame;
28use crate::transport_kind::TransportKind;
29
/// Keep-alive ping period, in seconds.
///
/// NOTE(review): meaning inferred from the name; confirm against the ping loop.
pub const PING_DELAY_SECS: u64 = 60;
/// Disconnect delay associated with pings, in seconds.
///
/// NOTE(review): presumably the `disconnect_delay` sent with
/// `ping_delay_disconnect`; confirm at the call site.
pub const NO_PING_DISCONNECT: i32 = 75;

/// Idle time before the first TCP keepalive probe (`TcpKeepalive::with_time`).
const TCP_KEEPALIVE_IDLE_SECS: u64 = 10;
/// Gap between TCP keepalive probes (`TcpKeepalive::with_interval`).
const TCP_KEEPALIVE_INTERVAL_SECS: u64 = 5;
/// Unanswered probes before the socket is dropped (`TcpKeepalive::with_retries`).
/// Not configured on Windows.
#[cfg(not(target_os = "windows"))]
const TCP_KEEPALIVE_PROBES: u32 = 3;
37
38/// How framing bytes are sent/received on a connection.
39///
40/// `Obfuscated` carries an `Arc<Mutex<ObfuscatedCipher>>` so the same cipher
41/// state is shared (safely) between the writer task (TX / `encrypt`) and the
42/// reader task (RX / `decrypt`).  The two directions are separate AES-CTR
43/// instances inside `ObfuscatedCipher`, so locking is only needed to prevent
44/// concurrent mutation of the struct, not to serialise TX vs RX.
45#[derive(Clone)]
46pub enum FrameKind {
47    Abridged,
48    Intermediate,
49    #[allow(dead_code)]
50    Full {
51        send_seqno: Arc<std::sync::atomic::AtomicU32>,
52        recv_seqno: Arc<std::sync::atomic::AtomicU32>,
53    },
54    /// Obfuscated2 over Abridged framing.
55    Obfuscated {
56        cipher: std::sync::Arc<tokio::sync::Mutex<ferogram_crypto::ObfuscatedCipher>>,
57    },
58    /// Obfuscated2 over Intermediate+padding framing (`0xDD` MTProxy).
59    PaddedIntermediate {
60        cipher: std::sync::Arc<tokio::sync::Mutex<ferogram_crypto::ObfuscatedCipher>>,
61    },
62    /// FakeTLS framing (`0xEE` MTProxy).
63    FakeTls {
64        cipher: std::sync::Arc<tokio::sync::Mutex<ferogram_crypto::ObfuscatedCipher>>,
65    },
66}
67
68/// Write half of a split connection.  Held under `Mutex` in `ClientInner`.
69/// A single server-provided salt with its validity window.
70///
71#[derive(Clone, Debug)]
72pub struct FutureSalt {
73    pub valid_since: i32,
74    /// Stored as `u32` because Telegram sends validity windows that extend
75    /// past 2038 (e.g. valid_until ≈ 2_751_656_413, year 2057).  Those values
76    /// overflow `i32` and wrap negative, making every salt look expired when
77    /// compared against the current server time with a signed comparison.
78    pub valid_until: u32,
79    pub salt: i64,
80}
81
82/// Delay (seconds) before a salt is considered usable after its `valid_since`.
83///
84pub const SALT_USE_DELAY: i32 = 60;
85
86/// Owns the EncryptedSession (for packing) and the pending-RPC map.
87pub struct ConnectionWriter {
88    pub enc: EncryptedSession,
89    pub frame_kind: FrameKind,
90    /// PFS: permanent auth key to save in session. None when PFS is off.
91    pub perm_auth_key: Option<[u8; 256]>,
92    /// msg_ids of received content messages waiting to be acked.
93    /// Drained into a MsgsAck on every outgoing frame (bundled into container
94    /// when sending an RPC, or sent standalone after route_frame).
95    pub pending_ack: Vec<i64>,
96    /// raw TL body bytes of every sent request, keyed by msg_id.
97    /// On bad_msg_notification the matching body is re-encrypted with a fresh
98    /// msg_id and re-sent transparently.
99    pub sent_bodies: std::collections::HashMap<i64, Vec<u8>>,
100    /// maps container_msg_id -> inner request msg_id.
101    /// When bad_msg_notification / bad_server_salt arrives for a container
102    /// rather than the individual inner message, we look here to find the
103    /// inner request to retry.
104    ///
105    pub container_map: std::collections::HashMap<i64, i64>,
106    /// Stale-msg resends queued under the writer lock and drained after release
107    /// to avoid holding the lock across TCP I/O.
108    pub new_session_resend_queue: Vec<(i64, i64, Vec<u8>)>,
109    /// -style future salt pool.
110    /// Sorted by valid_since ascending so the newest salt is LAST
111    /// (.valid_since), which puts
112    /// the highest valid_since at the end in ascending-key order).
113    pub salts: Vec<FutureSalt>,
114    /// Server-time anchor received with the last GetFutureSalts response.
115    /// (server_now, local_instant) lets us approximate server time at any
116    /// moment so we can check whether a salt's valid_since window has opened.
117    ///
118    pub start_salt_time: Option<(i32, std::time::Instant)>,
119}
120
121impl ConnectionWriter {
122    pub fn auth_key_bytes(&self) -> [u8; 256] {
123        self.perm_auth_key
124            .unwrap_or_else(|| self.enc.auth_key_bytes())
125    }
126    pub fn first_salt(&self) -> i64 {
127        self.enc.salt
128    }
129    pub fn time_offset(&self) -> i32 {
130        self.enc.time_offset
131    }
132
133    /// Proactively advance the active salt and prune expired ones.
134    ///
135    /// Called at the top of every RPC send.
136    /// Salts are sorted ascending by `valid_since` (oldest=index 0, newest=last).
137    ///
138    /// Prunes expired salts, then advances `enc.salt` to the freshest usable one.
139    ///
140    /// Returns `true` when the pool has shrunk to a single entry: caller should
141    /// fire a proactive `GetFutureSalts` via `try_request_salts()`.
142    pub fn advance_salt_if_needed(&mut self) -> bool {
143        let Some((server_now, start_instant)) = self.start_salt_time else {
144            return self.salts.len() <= 1;
145        };
146
147        // Approximate current server time.
148        let now = server_now + start_instant.elapsed().as_secs() as i32;
149
150        // Prune expired salts.
151        while self.salts.len() > 1 && (now as u32) > self.salts[0].valid_until {
152            let expired = self.salts.remove(0);
153            tracing::debug!(
154                "[ferogram] salt {:#x} expired (valid_until={}), pruned",
155                expired.salt,
156                expired.valid_until,
157            );
158        }
159
160        // Advance to the freshest salt whose use-delay has opened AND
161        // which has not yet expired.  The `valid_until > now` guard is the
162        // critical safety: without it we can advance enc.salt to an already-
163        // expired entry from a stale FutureSalts pool, triggering immediate
164        // bad_server_salt rejection and re-entering the fetch loop.
165        if self.salts.len() > 1 {
166            let best = self
167                .salts
168                .iter()
169                .rev()
170                .find(|s| s.valid_since + SALT_USE_DELAY <= now && s.valid_until > (now as u32))
171                .map(|s| s.salt);
172            if let Some(salt) = best
173                && salt != self.enc.salt
174            {
175                tracing::debug!(
176                    "[ferogram] proactive salt cycle: {:#x} -> {:#x}",
177                    self.enc.salt,
178                    salt
179                );
180                self.enc.salt = salt;
181                // Prune salts whose valid_until has passed.
182                self.salts.retain(|s| s.valid_until > (now as u32));
183                if self.salts.is_empty() {
184                    // Safety net: keep a sentinel so we never go saltless.
185                    self.salts.push(FutureSalt {
186                        valid_since: 0,
187                        valid_until: u32::MAX,
188                        salt,
189                    });
190                }
191            }
192        }
193
194        self.salts.len() <= 1
195    }
196}
197
198pub struct Connection {
199    pub stream: TcpStream,
200    pub enc: EncryptedSession,
201    pub frame_kind: FrameKind,
202    /// When PFS is active, the permanent auth key (stored in session).
203    /// `enc` holds the temp key; this field holds the perm key so
204    /// `auth_key_bytes()` returns the right value to persist.
205    pub perm_auth_key: Option<[u8; 256]>,
206}
207
208impl Connection {
209    /// Open a TCP stream, optionally via SOCKS5, and apply transport init bytes.
210    async fn open_stream(
211        addr: &str,
212        socks5: Option<&crate::socks5::Socks5Config>,
213        transport: &TransportKind,
214        dc_id: i16,
215    ) -> Result<(TcpStream, FrameKind), ConnectError> {
216        let stream = match socks5 {
217            Some(proxy) => proxy.connect(addr).await?,
218            None => {
219                let stream = TcpStream::connect(addr).await.map_err(ConnectError::Io)?;
220                stream.set_nodelay(true).ok();
221                {
222                    let sock = socket2::SockRef::from(&stream);
223                    let keepalive = TcpKeepalive::new()
224                        .with_time(Duration::from_secs(TCP_KEEPALIVE_IDLE_SECS))
225                        .with_interval(Duration::from_secs(TCP_KEEPALIVE_INTERVAL_SECS));
226                    #[cfg(not(target_os = "windows"))]
227                    let keepalive = keepalive.with_retries(TCP_KEEPALIVE_PROBES);
228                    sock.set_tcp_keepalive(&keepalive).ok();
229                }
230                stream
231            }
232        };
233        Self::apply_transport_init(stream, transport, dc_id).await
234    }
235
236    /// Open a stream routed through an MTProxy (connects to proxy host:port,
237    /// not to the Telegram DC address).
238    async fn open_stream_mtproxy(
239        mtproxy: &crate::proxy::MtProxyConfig,
240        dc_id: i16,
241    ) -> Result<(TcpStream, FrameKind), ConnectError> {
242        let stream = mtproxy.connect().await?;
243        stream.set_nodelay(true).ok();
244        Self::apply_transport_init(stream, &mtproxy.transport, dc_id).await
245    }
246
247    async fn apply_transport_init(
248        mut stream: TcpStream,
249        transport: &TransportKind,
250        dc_id: i16,
251    ) -> Result<(TcpStream, FrameKind), ConnectError> {
252        match transport {
253            TransportKind::Abridged => {
254                stream.write_all(&[0xef]).await?;
255                Ok((stream, FrameKind::Abridged))
256            }
257            TransportKind::Intermediate => {
258                stream.write_all(&[0xee, 0xee, 0xee, 0xee]).await?;
259                Ok((stream, FrameKind::Intermediate))
260            }
261            TransportKind::Full => {
262                // Full transport has no init byte.
263                Ok((
264                    stream,
265                    FrameKind::Full {
266                        send_seqno: Arc::new(std::sync::atomic::AtomicU32::new(0)),
267                        recv_seqno: Arc::new(std::sync::atomic::AtomicU32::new(0)),
268                    },
269                ))
270            }
271            TransportKind::Obfuscated { secret } => {
272                use sha2::Digest;
273
274                // Random 64-byte nonce: retry until it passes the reserved-pattern
275                // Reject reserved nonce patterns that could be misidentified as HTTP
276                // or another MTProto framing tag by a proxy or DPI filter.
277                let mut nonce = [0u8; 64];
278                loop {
279                    getrandom::getrandom(&mut nonce)
280                        .map_err(|_| ConnectError::other("getrandom"))?;
281                    let first = u32::from_le_bytes(nonce[0..4].try_into().expect("4-byte slice"));
282                    let second = u32::from_le_bytes(nonce[4..8].try_into().expect("4-byte slice"));
283                    let bad = nonce[0] == 0xEF
284                        || first == 0x44414548 // HEAD
285                        || first == 0x54534F50 // POST
286                        || first == 0x20544547 // GET
287                        || first == 0x4954504f // OPTIONS
288                        || first == 0xEEEEEEEE
289                        || first == 0xDDDDDDDD
290                        || first == 0x02010316
291                        || second == 0x00000000;
292                    if !bad {
293                        break;
294                    }
295                }
296
297                // Key derivation from nonce[8..56]:
298                //   TX: key=nonce[8..40]  iv=nonce[40..56]
299                //   RX: key=rev[0..32]    iv=rev[32..48]   (rev = nonce[8..56] reversed)
300                // When an MTProxy secret is present, each 32-byte key becomes
301                // SHA-256(raw_key_slice || secret) for MTProxy key derivation.
302                let tx_raw: [u8; 32] = nonce[8..40].try_into().expect("32-byte slice");
303                let tx_iv: [u8; 16] = nonce[40..56].try_into().expect("16-byte slice");
304                let mut rev48 = nonce[8..56].to_vec();
305                rev48.reverse();
306                let rx_raw: [u8; 32] = rev48[0..32].try_into().expect("32-byte slice");
307                let rx_iv: [u8; 16] = rev48[32..48].try_into().expect("16-byte slice");
308
309                let (tx_key, rx_key): ([u8; 32], [u8; 32]) = if let Some(s) = secret {
310                    let mut h = sha2::Sha256::new();
311                    h.update(tx_raw);
312                    h.update(s.as_ref());
313                    let tx: [u8; 32] = h.finalize().into();
314
315                    let mut h = sha2::Sha256::new();
316                    h.update(rx_raw);
317                    h.update(s.as_ref());
318                    let rx: [u8; 32] = h.finalize().into();
319                    (tx, rx)
320                } else {
321                    (tx_raw, rx_raw)
322                };
323
324                // Stamp protocol id (Abridged = 0xEFEFEFEF) at nonce[56..60]
325                // and DC id as little-endian i16 at nonce[60..62].
326                nonce[56] = 0xef;
327                nonce[57] = 0xef;
328                nonce[58] = 0xef;
329                nonce[59] = 0xef;
330                let dc_bytes = dc_id.to_le_bytes();
331                nonce[60] = dc_bytes[0];
332                nonce[61] = dc_bytes[1];
333
334                // Encrypt nonce[56..64] in-place using the TX cipher advanced
335                // past the first 56 bytes (which are sent as plaintext).
336                //
337                // The same cipher instance must be used for both the nonce tail
338                // encryption and all subsequent TX data: AES-CTR is a single continuous
339                // stream; the TX position after encrypting the full 64-byte nonce is 64.
340                let mut cipher =
341                    ferogram_crypto::ObfuscatedCipher::from_keys(&tx_key, &tx_iv, &rx_key, &rx_iv);
342                // Advance TX past nonce[0..56] (sent as plaintext, not encrypted).
343                let mut skip = [0u8; 56];
344                cipher.encrypt(&mut skip);
345                // Encrypt nonce[56..64] in-place; cipher TX is now at position 64.
346                cipher.encrypt(&mut nonce[56..64]);
347
348                stream.write_all(&nonce).await?;
349
350                let cipher_arc = std::sync::Arc::new(tokio::sync::Mutex::new(cipher));
351                Ok((stream, FrameKind::Obfuscated { cipher: cipher_arc }))
352            }
353            TransportKind::PaddedIntermediate { secret } => {
354                use sha2::Digest;
355                let mut nonce = [0u8; 64];
356                loop {
357                    getrandom::getrandom(&mut nonce)
358                        .map_err(|_| ConnectError::other("getrandom"))?;
359                    let first = u32::from_le_bytes(nonce[0..4].try_into().expect("4-byte slice"));
360                    let second = u32::from_le_bytes(nonce[4..8].try_into().expect("4-byte slice"));
361                    let bad = nonce[0] == 0xEF
362                        || first == 0x44414548
363                        || first == 0x54534F50
364                        || first == 0x20544547
365                        || first == 0x4954504f
366                        || first == 0xEEEEEEEE
367                        || first == 0xDDDDDDDD
368                        || first == 0x02010316
369                        || second == 0x00000000;
370                    if !bad {
371                        break;
372                    }
373                }
374                let tx_raw: [u8; 32] = nonce[8..40].try_into().expect("32-byte slice");
375                let tx_iv: [u8; 16] = nonce[40..56].try_into().expect("16-byte slice");
376                let mut rev48 = nonce[8..56].to_vec();
377                rev48.reverse();
378                let rx_raw: [u8; 32] = rev48[0..32].try_into().expect("32-byte slice");
379                let rx_iv: [u8; 16] = rev48[32..48].try_into().expect("16-byte slice");
380                let (tx_key, rx_key): ([u8; 32], [u8; 32]) = if let Some(s) = secret {
381                    let mut h = sha2::Sha256::new();
382                    h.update(tx_raw);
383                    h.update(s.as_ref());
384                    let tx: [u8; 32] = h.finalize().into();
385                    let mut h = sha2::Sha256::new();
386                    h.update(rx_raw);
387                    h.update(s.as_ref());
388                    let rx: [u8; 32] = h.finalize().into();
389                    (tx, rx)
390                } else {
391                    (tx_raw, rx_raw)
392                };
393                // PaddedIntermediate tag = 0xDDDDDDDD
394                nonce[56] = 0xdd;
395                nonce[57] = 0xdd;
396                nonce[58] = 0xdd;
397                nonce[59] = 0xdd;
398                let dc_bytes = dc_id.to_le_bytes();
399                nonce[60] = dc_bytes[0];
400                nonce[61] = dc_bytes[1];
401                let mut cipher =
402                    ferogram_crypto::ObfuscatedCipher::from_keys(&tx_key, &tx_iv, &rx_key, &rx_iv);
403                let mut skip = [0u8; 56];
404                cipher.encrypt(&mut skip);
405                cipher.encrypt(&mut nonce[56..64]);
406                stream.write_all(&nonce).await?;
407                let cipher_arc = std::sync::Arc::new(tokio::sync::Mutex::new(cipher));
408                Ok((stream, FrameKind::PaddedIntermediate { cipher: cipher_arc }))
409            }
410            TransportKind::FakeTls { secret, domain } => {
411                // Fake TLS 1.3 ClientHello with HMAC-SHA256 random field.
412                // After the handshake, data flows as TLS Application Data records
413                // over a shared Obfuscated2 cipher seeded from the secret+HMAC.
414                let domain_bytes = domain.as_bytes();
415                let mut session_id = [0u8; 32];
416                getrandom::getrandom(&mut session_id)
417                    .map_err(|_| ConnectError::other("getrandom"))?;
418
419                // Build ClientHello body (random placeholder = zeros)
420                let cipher_suites: &[u8] = &[0x00, 0x04, 0x13, 0x01, 0x13, 0x02];
421                let compression: &[u8] = &[0x01, 0x00];
422                let sni_name_len = domain_bytes.len() as u16;
423                let sni_list_len = sni_name_len + 3;
424                let sni_ext_len = sni_list_len + 2;
425                let mut sni_ext = Vec::new();
426                sni_ext.extend_from_slice(&[0x00, 0x00]);
427                sni_ext.extend_from_slice(&sni_ext_len.to_be_bytes());
428                sni_ext.extend_from_slice(&sni_list_len.to_be_bytes());
429                sni_ext.push(0x00);
430                sni_ext.extend_from_slice(&sni_name_len.to_be_bytes());
431                sni_ext.extend_from_slice(domain_bytes);
432                let sup_ver: &[u8] = &[0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04];
433                let sup_grp: &[u8] = &[0x00, 0x0a, 0x00, 0x04, 0x00, 0x02, 0x00, 0x1d];
434                let sess_tick: &[u8] = &[0x00, 0x23, 0x00, 0x00];
435                let ext_body_len = sni_ext.len() + sup_ver.len() + sup_grp.len() + sess_tick.len();
436                let mut extensions = Vec::new();
437                extensions.extend_from_slice(&(ext_body_len as u16).to_be_bytes());
438                extensions.extend_from_slice(&sni_ext);
439                extensions.extend_from_slice(sup_ver);
440                extensions.extend_from_slice(sup_grp);
441                extensions.extend_from_slice(sess_tick);
442
443                let mut hello_body = Vec::new();
444                hello_body.extend_from_slice(&[0x03, 0x03]);
445                hello_body.extend_from_slice(&[0u8; 32]); // random placeholder
446                hello_body.push(session_id.len() as u8);
447                hello_body.extend_from_slice(&session_id);
448                hello_body.extend_from_slice(cipher_suites);
449                hello_body.extend_from_slice(compression);
450                hello_body.extend_from_slice(&extensions);
451
452                let hs_len = hello_body.len() as u32;
453                let mut handshake = vec![
454                    0x01,
455                    ((hs_len >> 16) & 0xff) as u8,
456                    ((hs_len >> 8) & 0xff) as u8,
457                    (hs_len & 0xff) as u8,
458                ];
459                handshake.extend_from_slice(&hello_body);
460
461                let rec_len = handshake.len() as u16;
462                let mut record = Vec::new();
463                record.push(0x16);
464                record.extend_from_slice(&[0x03, 0x01]);
465                record.extend_from_slice(&rec_len.to_be_bytes());
466                record.extend_from_slice(&handshake);
467
468                // HMAC-SHA256(secret, record) -> fill random field at offset 11
469                use sha2::Digest;
470                let random_offset = 5 + 4 + 2; // TLS-rec(5) + HS-hdr(4) + version(2)
471                let hmac_result: [u8; 32] = {
472                    use hmac::{Hmac, Mac};
473                    type HmacSha256 = Hmac<sha2::Sha256>;
474                    let mut mac = HmacSha256::new_from_slice(secret)
475                        .map_err(|_| ConnectError::other("HMAC key error"))?;
476                    mac.update(&record);
477                    mac.finalize().into_bytes().into()
478                };
479                record[random_offset..random_offset + 32].copy_from_slice(&hmac_result);
480                stream.write_all(&record).await?;
481
482                // Derive Obfuscated2 key from secret + HMAC
483                let mut h = sha2::Sha256::new();
484                h.update(secret.as_ref());
485                h.update(hmac_result);
486                let derived: [u8; 32] = h.finalize().into();
487                let iv = [0u8; 16];
488                let cipher =
489                    ferogram_crypto::ObfuscatedCipher::from_keys(&derived, &iv, &derived, &iv);
490                let cipher_arc = std::sync::Arc::new(tokio::sync::Mutex::new(cipher));
491                Ok((stream, FrameKind::FakeTls { cipher: cipher_arc }))
492            }
493            TransportKind::Http => {
494                // HTTP transport is handled in dc_pool - fall back to Abridged framing.
495                stream.write_all(&[0xef]).await?;
496                Ok((stream, FrameKind::Abridged))
497            }
498        }
499    }
500
501    pub async fn connect_raw(
502        addr: &str,
503        socks5: Option<&crate::socks5::Socks5Config>,
504        mtproxy: Option<&crate::proxy::MtProxyConfig>,
505        transport: &TransportKind,
506        dc_id: i16,
507    ) -> Result<Self, ConnectError> {
508        let t_label = match transport {
509            TransportKind::Abridged => "Abridged",
510            TransportKind::Obfuscated { .. } => "Obfuscated",
511            TransportKind::PaddedIntermediate { .. } => "PaddedIntermediate",
512            TransportKind::Http => "Http",
513            TransportKind::Intermediate => "Intermediate",
514            TransportKind::Full => "Full",
515            TransportKind::FakeTls { .. } => "FakeTls",
516        };
517        tracing::debug!("[ferogram] Connecting to {addr} ({t_label}) DH …");
518
519        let addr2 = addr.to_string();
520        let socks5_c = socks5.cloned();
521        let mtproxy_c = mtproxy.cloned();
522        let transport_c = transport.clone();
523
524        let fut = async move {
525            let (mut stream, frame_kind) = if let Some(ref mp) = mtproxy_c {
526                Self::open_stream_mtproxy(mp, dc_id).await?
527            } else {
528                Self::open_stream(&addr2, socks5_c.as_ref(), &transport_c, dc_id).await?
529            };
530
531            let mut plain = Session::new();
532
533            let (req1, s1) = auth::step1().map_err(|e| ConnectError::other(e.to_string()))?;
534            send_frame(
535                &mut stream,
536                &plain.pack(&req1).to_plaintext_bytes(),
537                &frame_kind,
538            )
539            .await?;
540            let res_pq: tl::enums::ResPq = recv_frame_plain(&mut stream, &frame_kind).await?;
541
542            let (req2, s2) = auth::step2(s1, res_pq, dc_id as i32)
543                .map_err(|e| ConnectError::other(e.to_string()))?;
544            send_frame(
545                &mut stream,
546                &plain.pack(&req2).to_plaintext_bytes(),
547                &frame_kind,
548            )
549            .await?;
550            let dh: tl::enums::ServerDhParams = recv_frame_plain(&mut stream, &frame_kind).await?;
551
552            let (req3, s3) = auth::step3(s2, dh).map_err(|e| ConnectError::other(e.to_string()))?;
553            send_frame(
554                &mut stream,
555                &plain.pack(&req3).to_plaintext_bytes(),
556                &frame_kind,
557            )
558            .await?;
559            let ans: tl::enums::SetClientDhParamsAnswer =
560                recv_frame_plain(&mut stream, &frame_kind).await?;
561
562            // Retry loop for dh_gen_retry (up to 5 attempts).
563            let done = {
564                let mut result =
565                    auth::finish(s3, ans).map_err(|e| ConnectError::other(e.to_string()))?;
566                let mut attempts = 0u8;
567                loop {
568                    match result {
569                        auth::FinishResult::Done(d) => break d,
570                        auth::FinishResult::Retry {
571                            retry_id,
572                            dh_params,
573                            nonce,
574                            server_nonce,
575                            new_nonce,
576                        } => {
577                            attempts += 1;
578                            if attempts >= 5 {
579                                return Err(ConnectError::other(
580                                    "dh_gen_retry exceeded 5 attempts",
581                                ));
582                            }
583                            let (req_retry, s3_retry) = auth::retry_step3(
584                                &dh_params,
585                                nonce,
586                                server_nonce,
587                                new_nonce,
588                                retry_id,
589                            )
590                            .map_err(|e| ConnectError::other(e.to_string()))?;
591                            send_frame(
592                                &mut stream,
593                                &plain.pack(&req_retry).to_plaintext_bytes(),
594                                &frame_kind,
595                            )
596                            .await?;
597                            let ans_retry: tl::enums::SetClientDhParamsAnswer =
598                                recv_frame_plain(&mut stream, &frame_kind).await?;
599                            result = auth::finish(s3_retry, ans_retry)
600                                .map_err(|e| ConnectError::other(e.to_string()))?;
601                        }
602                    }
603                }
604            };
605            tracing::debug!("[ferogram] DH complete ✓");
606
607            Ok::<Self, ConnectError>(Self {
608                stream,
609                enc: EncryptedSession::new(done.auth_key, done.first_salt, done.time_offset),
610                frame_kind,
611                perm_auth_key: None, // connect_raw produces the perm key itself
612            })
613        };
614
615        tokio::time::timeout(Duration::from_secs(15), fut)
616            .await
617            .map_err(|_| {
618                ConnectError::other(format!("DH handshake with {addr} timed out after 15 s"))
619            })?
620    }
621
622    #[allow(clippy::too_many_arguments)]
623    pub async fn connect_with_key(
624        addr: &str,
625        auth_key: [u8; 256],
626        first_salt: i64,
627        time_offset: i32,
628        socks5: Option<&crate::socks5::Socks5Config>,
629        mtproxy: Option<&crate::proxy::MtProxyConfig>,
630        transport: &TransportKind,
631        dc_id: i16,
632        pfs: bool,
633    ) -> Result<Self, ConnectError> {
634        let addr2 = addr.to_string();
635        let socks5_c = socks5.cloned();
636        let mtproxy_c = mtproxy.cloned();
637        let transport_c = transport.clone();
638
639        let fut = async move {
640            let (mut stream, frame_kind) = if let Some(ref mp) = mtproxy_c {
641                Self::open_stream_mtproxy(mp, dc_id).await?
642            } else {
643                Self::open_stream(&addr2, socks5_c.as_ref(), &transport_c, dc_id).await?
644            };
645            if pfs {
646                tracing::debug!("[ferogram] PFS: temp DH bind for DC{dc_id}");
647                match Self::do_pfs_bind(&mut stream, &frame_kind, &auth_key, dc_id).await {
648                    Ok(temp_enc) => {
649                        tracing::info!("[ferogram] PFS bind complete DC{dc_id}");
650                        return Ok(Self {
651                            stream,
652                            enc: temp_enc,
653                            frame_kind,
654                            perm_auth_key: Some(auth_key),
655                        });
656                    }
657                    Err(e) => {
658                        tracing::warn!(
659                            "[ferogram] PFS bind failed DC{dc_id} ({e}); falling back to perm key"
660                        );
661                        // Graceful fallback: reconnect because DH frames left the stream dirty.
662                        // Return error and let the caller handle retry without PFS.
663                        return Err(e);
664                    }
665                }
666            }
667            Ok::<Self, ConnectError>(Self {
668                stream,
669                enc: EncryptedSession::new(auth_key, first_salt, time_offset),
670                frame_kind,
671                perm_auth_key: None,
672            })
673        };
674
675        tokio::time::timeout(Duration::from_secs(30), fut)
676            .await
677            .map_err(|_| {
678                ConnectError::other(format!("connect_with_key to {addr} timed out after 30 s"))
679            })?
680    }
681
682    /// Perform a fresh temp-key DH on an already-open stream, then
683    /// send `auth.bindTempAuthKey` encrypted with the temp key.
684    /// Returns an `EncryptedSession` keyed with the bound temp key.
685    async fn do_pfs_bind(
686        stream: &mut TcpStream,
687        frame_kind: &FrameKind,
688        perm_auth_key: &[u8; 256],
689        dc_id: i16,
690    ) -> Result<EncryptedSession, ConnectError> {
691        use ferogram_mtproto::{
692            auth_key_id_from_key, encrypt_bind_inner, gen_msg_id, new_seen_msg_ids,
693            serialize_bind_temp_auth_key,
694        };
695        const TEMP_EXPIRES: i32 = 86_400; // 24 h
696
697        // temp-key DH
698        let mut plain = Session::new();
699
700        let (req1, s1) = auth::step1().map_err(|e| ConnectError::other(e.to_string()))?;
701        send_frame(stream, &plain.pack(&req1).to_plaintext_bytes(), frame_kind).await?;
702        let res_pq: tl::enums::ResPq = recv_frame_plain(stream, frame_kind).await?;
703
704        let (req2, s2) = ferogram_mtproto::step2_temp(s1, res_pq, dc_id as i32, TEMP_EXPIRES)
705            .map_err(|e| ConnectError::other(e.to_string()))?;
706        send_frame(stream, &plain.pack(&req2).to_plaintext_bytes(), frame_kind).await?;
707        let dh: tl::enums::ServerDhParams = recv_frame_plain(stream, frame_kind).await?;
708
709        let (req3, s3) = auth::step3(s2, dh).map_err(|e| ConnectError::other(e.to_string()))?;
710        send_frame(stream, &plain.pack(&req3).to_plaintext_bytes(), frame_kind).await?;
711        let ans: tl::enums::SetClientDhParamsAnswer = recv_frame_plain(stream, frame_kind).await?;
712
713        let done = {
714            let mut result =
715                auth::finish(s3, ans).map_err(|e| ConnectError::other(e.to_string()))?;
716            let mut attempts = 0u8;
717            loop {
718                match result {
719                    ferogram_mtproto::FinishResult::Done(d) => break d,
720                    ferogram_mtproto::FinishResult::Retry {
721                        retry_id,
722                        dh_params,
723                        nonce,
724                        server_nonce,
725                        new_nonce,
726                    } => {
727                        attempts += 1;
728                        if attempts >= 5 {
729                            return Err(ConnectError::other(
730                                "PFS temp DH retry exceeded 5 attempts",
731                            ));
732                        }
733                        let (rr, s3r) = ferogram_mtproto::retry_step3(
734                            &dh_params,
735                            nonce,
736                            server_nonce,
737                            new_nonce,
738                            retry_id,
739                        )
740                        .map_err(|e| ConnectError::other(e.to_string()))?;
741                        send_frame(stream, &plain.pack(&rr).to_plaintext_bytes(), frame_kind)
742                            .await?;
743                        let ar: tl::enums::SetClientDhParamsAnswer =
744                            recv_frame_plain(stream, frame_kind).await?;
745                        result = auth::finish(s3r, ar)
746                            .map_err(|e| ConnectError::other(e.to_string()))?;
747                    }
748                }
749            }
750        };
751
752        let temp_key = done.auth_key;
753        let temp_salt = done.first_salt;
754        let temp_offset = done.time_offset;
755
756        // build bindTempAuthKey body
757        let temp_key_id = auth_key_id_from_key(&temp_key);
758        let perm_key_id = auth_key_id_from_key(perm_auth_key);
759
760        let mut nonce_buf = [0u8; 8];
761        getrandom::getrandom(&mut nonce_buf).map_err(|_| ConnectError::other("getrandom nonce"))?;
762        let nonce = i64::from_le_bytes(nonce_buf);
763
764        let server_now = std::time::SystemTime::now()
765            .duration_since(std::time::UNIX_EPOCH)
766            .unwrap()
767            .as_secs() as i32
768            + temp_offset;
769        let expires_at = server_now + TEMP_EXPIRES;
770
771        let seen = new_seen_msg_ids();
772        let mut temp_enc = EncryptedSession::with_seen(temp_key, temp_salt, temp_offset, seen);
773        let temp_session_id = temp_enc.session_id();
774
775        let msg_id = gen_msg_id();
776        let enc_msg = encrypt_bind_inner(
777            perm_auth_key,
778            msg_id,
779            nonce,
780            temp_key_id,
781            perm_key_id,
782            temp_session_id,
783            expires_at,
784        );
785        let bind_body = serialize_bind_temp_auth_key(perm_key_id, nonce, expires_at, &enc_msg);
786
787        // send encrypted bind request
788        let wire = temp_enc.pack_body_at_msg_id(&bind_body, msg_id);
789        send_frame(stream, &wire, frame_kind).await?;
790
791        // Receive and verify response.
792        // The server may send informational frames first (msgs_ack, new_session_created)
793        // before the actual rpc_result{boolTrue}, so we loop up to 5 frames.
794        for attempt in 0u8..5 {
795            let mut raw = recv_raw_frame(stream, frame_kind).await?;
796            let decrypted = temp_enc
797                .unpack(&mut raw)
798                .map_err(|e| ConnectError::other(format!("PFS bind decrypt: {e:?}")))?;
799            match decode_bind_response(&decrypted.body) {
800                Ok(()) => {
801                    // bindTempAuthKey succeeds under the temp key; keep the session
802                    // sequence as-is so subsequent RPCs continue from the same MTProto
803                    // message stream.
804                    return Ok(temp_enc);
805                }
806                Err(ref e) if e == "__need_more__" => {
807                    tracing::debug!(
808                        "[ferogram] PFS bind (DC{dc_id}): informational frame {attempt}, reading next"
809                    );
810                    continue;
811                }
812                Err(reason) => {
813                    tracing::error!("[ferogram] PFS bind server response (DC{dc_id}): {reason}");
814                    return Err(ConnectError::other(format!(
815                        "auth.bindTempAuthKey: {reason}"
816                    )));
817                }
818            }
819        }
820        Err(ConnectError::other(
821            "auth.bindTempAuthKey: no boolTrue after 5 frames",
822        ))
823    }
824
825    pub fn auth_key_bytes(&self) -> [u8; 256] {
826        // When PFS is active, perm_auth_key is the key to persist in the session.
827        // enc.auth_key_bytes() would return the short-lived temp key instead.
828        self.perm_auth_key
829            .unwrap_or_else(|| self.enc.auth_key_bytes())
830    }
831
832    /// Split into a write-only `ConnectionWriter` and the TCP read half.
833    pub fn into_writer(self) -> (ConnectionWriter, OwnedWriteHalf, OwnedReadHalf, FrameKind) {
834        let (read_half, write_half) = self.stream.into_split();
835        let writer = ConnectionWriter {
836            enc: self.enc,
837            frame_kind: self.frame_kind.clone(),
838            perm_auth_key: self.perm_auth_key,
839            pending_ack: Vec::new(),
840            new_session_resend_queue: Vec::new(),
841            sent_bodies: std::collections::HashMap::new(),
842            container_map: std::collections::HashMap::new(),
843            salts: Vec::new(),
844            start_salt_time: None,
845        };
846        (writer, write_half, read_half, self.frame_kind)
847    }
848}