Skip to main content

layer_client/
dc_pool.rs

//! Multi-DC connection pool.
//!
//! Maintains one authenticated [`DcConnection`] per DC ID and dispatches RPC
//! calls to the connection of the requested DC.  Each connection has its own
//! auth key (from a fresh DH handshake or a previously saved key); the user's
//! authorization is carried over from the home DC via
//! `auth.exportAuthorization` / `auth.importAuthorization`.
6
7use layer_mtproto::{EncryptedSession, Session, authentication as auth};
8use layer_tl_types as tl;
9use layer_tl_types::{Cursor, Deserializable, RemoteCall};
10use std::collections::HashMap;
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::net::TcpStream;
13
14use crate::{InvocationError, TransportKind, session::DcEntry};
15
// DcConnection
17
18/// A single encrypted connection to one Telegram DC.
19pub struct DcConnection {
20    stream: TcpStream,
21    enc: EncryptedSession,
22}
23
24impl DcConnection {
25    /// Connect and perform full DH handshake.
26    pub async fn connect_raw(
27        addr: &str,
28        socks5: Option<&crate::socks5::Socks5Config>,
29        transport: &TransportKind,
30        dc_id: i16,
31    ) -> Result<Self, InvocationError> {
32        tracing::debug!("[dc_pool] Connecting to {addr} …");
33        let mut stream = Self::open_tcp(addr, socks5).await?;
34        Self::send_transport_init(&mut stream, transport, dc_id).await?;
35
36        let mut plain = Session::new();
37
38        let (req1, s1) = auth::step1().map_err(|e| InvocationError::Deserialize(e.to_string()))?;
39        Self::send_plain_frame(&mut stream, &plain.pack(&req1).to_plaintext_bytes()).await?;
40        let res_pq: tl::enums::ResPq = Self::recv_plain_frame(&mut stream).await?;
41
42        let (req2, s2) = auth::step2(s1, res_pq, dc_id as i32)
43            .map_err(|e| InvocationError::Deserialize(e.to_string()))?;
44        Self::send_plain_frame(&mut stream, &plain.pack(&req2).to_plaintext_bytes()).await?;
45        let dh: tl::enums::ServerDhParams = Self::recv_plain_frame(&mut stream).await?;
46
47        let (req3, s3) =
48            auth::step3(s2, dh).map_err(|e| InvocationError::Deserialize(e.to_string()))?;
49        Self::send_plain_frame(&mut stream, &plain.pack(&req3).to_plaintext_bytes()).await?;
50        let ans: tl::enums::SetClientDhParamsAnswer = Self::recv_plain_frame(&mut stream).await?;
51
52        // Retry loop for dh_gen_retry (up to 5 attempts, mirroring tDesktop).
53        let done = {
54            let mut result =
55                auth::finish(s3, ans).map_err(|e| InvocationError::Deserialize(e.to_string()))?;
56            let mut attempts = 0u8;
57            loop {
58                match result {
59                    auth::FinishResult::Done(d) => break d,
60                    auth::FinishResult::Retry {
61                        retry_id,
62                        dh_params,
63                        nonce,
64                        server_nonce,
65                        new_nonce,
66                    } => {
67                        attempts += 1;
68                        if attempts >= 5 {
69                            return Err(InvocationError::Deserialize(
70                                "dh_gen_retry exceeded 5 attempts".into(),
71                            ));
72                        }
73                        let (req_retry, s3_retry) =
74                            auth::retry_step3(&dh_params, nonce, server_nonce, new_nonce, retry_id)
75                                .map_err(|e| InvocationError::Deserialize(e.to_string()))?;
76                        Self::send_plain_frame(
77                            &mut stream,
78                            &plain.pack(&req_retry).to_plaintext_bytes(),
79                        )
80                        .await?;
81                        let ans_retry: tl::enums::SetClientDhParamsAnswer =
82                            Self::recv_plain_frame(&mut stream).await?;
83                        result = auth::finish(s3_retry, ans_retry)
84                            .map_err(|e| InvocationError::Deserialize(e.to_string()))?;
85                    }
86                }
87            }
88        };
89        tracing::debug!("[dc_pool] DH complete ✓ for {addr}");
90
91        Ok(Self {
92            stream,
93            enc: EncryptedSession::new(done.auth_key, done.first_salt, done.time_offset),
94        })
95    }
96
97    /// Connect with an already-known auth key (no DH needed).
98    pub async fn connect_with_key(
99        addr: &str,
100        auth_key: [u8; 256],
101        first_salt: i64,
102        time_offset: i32,
103        socks5: Option<&crate::socks5::Socks5Config>,
104        transport: &TransportKind,
105        dc_id: i16,
106    ) -> Result<Self, InvocationError> {
107        let mut stream = Self::open_tcp(addr, socks5).await?;
108        Self::send_transport_init(&mut stream, transport, dc_id).await?;
109        Ok(Self {
110            stream,
111            enc: EncryptedSession::new(auth_key, first_salt, time_offset),
112        })
113    }
114
115    async fn open_tcp(
116        addr: &str,
117        socks5: Option<&crate::socks5::Socks5Config>,
118    ) -> Result<TcpStream, InvocationError> {
119        match socks5 {
120            Some(proxy) => proxy.connect(addr).await,
121            None => Ok(TcpStream::connect(addr).await?),
122        }
123    }
124
125    async fn send_transport_init(
126        stream: &mut TcpStream,
127        transport: &TransportKind,
128        dc_id: i16,
129    ) -> Result<(), InvocationError> {
130        match transport {
131            TransportKind::Abridged => {
132                stream.write_all(&[0xef]).await?;
133            }
134            TransportKind::Intermediate => {
135                stream.write_all(&[0xee, 0xee, 0xee, 0xee]).await?;
136            }
137            TransportKind::Full => {}
138            TransportKind::Obfuscated { secret } => {
139                use sha2::Digest;
140                let mut nonce = [0u8; 64];
141                loop {
142                    getrandom::getrandom(&mut nonce)
143                        .map_err(|_| InvocationError::Deserialize("getrandom".into()))?;
144                    let first = u32::from_le_bytes(nonce[0..4].try_into().unwrap());
145                    let second = u32::from_le_bytes(nonce[4..8].try_into().unwrap());
146                    let bad = nonce[0] == 0xEF
147                        || first == 0x44414548
148                        || first == 0x54534F50
149                        || first == 0x20544547
150                        || first == 0xEEEEEEEE
151                        || first == 0xDDDDDDDD
152                        || first == 0x02010316
153                        || second == 0x00000000;
154                    if !bad {
155                        break;
156                    }
157                }
158                let tx_raw: [u8; 32] = nonce[8..40].try_into().unwrap();
159                let tx_iv: [u8; 16] = nonce[40..56].try_into().unwrap();
160                let mut rev48 = nonce[8..56].to_vec();
161                rev48.reverse();
162                let rx_raw: [u8; 32] = rev48[0..32].try_into().unwrap();
163                let rx_iv: [u8; 16] = rev48[32..48].try_into().unwrap();
164                let (tx_key, rx_key): ([u8; 32], [u8; 32]) = if let Some(s) = secret {
165                    let mut h = sha2::Sha256::new();
166                    h.update(tx_raw);
167                    h.update(s.as_ref());
168                    let tx: [u8; 32] = h.finalize().into();
169                    let mut h = sha2::Sha256::new();
170                    h.update(rx_raw);
171                    h.update(s.as_ref());
172                    let rx: [u8; 32] = h.finalize().into();
173                    (tx, rx)
174                } else {
175                    (tx_raw, rx_raw)
176                };
177                nonce[56] = 0xef;
178                nonce[57] = 0xef;
179                nonce[58] = 0xef;
180                nonce[59] = 0xef;
181                let dc_bytes = dc_id.to_le_bytes();
182                nonce[60] = dc_bytes[0];
183                nonce[61] = dc_bytes[1];
184                {
185                    let mut enc =
186                        layer_crypto::ObfuscatedCipher::from_keys(&tx_key, &tx_iv, &rx_key, &rx_iv);
187                    let mut skip = [0u8; 56];
188                    enc.encrypt(&mut skip);
189                    enc.encrypt(&mut nonce[56..64]);
190                }
191                stream.write_all(&nonce).await?;
192            }
193            // PaddedIntermediate and FakeTls are handled by the main Connection path
194            // (lib.rs apply_transport_init).  DcPool connections always use the
195            // transport supplied by the caller if a 0xDD/0xEE proxy is used,
196            // the caller should open the stream through Connection::open_stream_mtproxy
197            // and not use DcPool::connect_raw.  Treat these as Abridged fallback so
198            // dc_pool.rs compiles cleanly for non-proxy aux-DC connections.
199            TransportKind::PaddedIntermediate { .. } | TransportKind::FakeTls { .. } => {
200                stream.write_all(&[0xef]).await?;
201            }
202        }
203        Ok(())
204    }
205
206    pub fn auth_key_bytes(&self) -> [u8; 256] {
207        self.enc.auth_key_bytes()
208    }
209    pub fn first_salt(&self) -> i64 {
210        self.enc.salt
211    }
212    pub fn time_offset(&self) -> i32 {
213        self.enc.time_offset
214    }
215
216    pub async fn rpc_call<R: RemoteCall>(&mut self, req: &R) -> Result<Vec<u8>, InvocationError> {
217        let wire = self.enc.pack(req);
218        Self::send_abridged(&mut self.stream, &wire).await?;
219        self.recv_rpc().await
220    }
221
222    async fn recv_rpc(&mut self) -> Result<Vec<u8>, InvocationError> {
223        loop {
224            let mut raw = Self::recv_abridged(&mut self.stream).await?;
225            let msg = self
226                .enc
227                .unpack(&mut raw)
228                .map_err(|e| InvocationError::Deserialize(e.to_string()))?;
229            if msg.salt != 0 {
230                self.enc.salt = msg.salt;
231            }
232            if msg.body.len() < 4 {
233                return Ok(msg.body);
234            }
235            let cid = u32::from_le_bytes(msg.body[..4].try_into().unwrap());
236            match cid {
237                0xf35c6d01 /* rpc_result */ => {
238                    if msg.body.len() >= 12 { return Ok(msg.body[12..].to_vec()); }
239                    return Ok(msg.body);
240                }
241                0x2144ca19 /* rpc_error */ => {
242                    if msg.body.len() < 8 {
243                        return Err(InvocationError::Deserialize("rpc_error short".into()));
244                    }
245                    let code = i32::from_le_bytes(msg.body[4..8].try_into().unwrap());
246                    let message = tl_read_string(&msg.body[8..]).unwrap_or_default();
247                    return Err(InvocationError::Rpc(crate::RpcError::from_telegram(code, &message)));
248                }
249                0x347773c5 | 0x62d6b459 | 0x9ec20908 | 0xedab447b | 0xa7eff811 => continue,
250                _ => return Ok(msg.body),
251            }
252        }
253    }
254
255    async fn send_abridged(stream: &mut TcpStream, data: &[u8]) -> Result<(), InvocationError> {
256        let words = data.len() / 4;
257        if words < 0x7f {
258            stream.write_all(&[words as u8]).await?;
259        } else {
260            stream
261                .write_all(&[
262                    0x7f,
263                    (words & 0xff) as u8,
264                    ((words >> 8) & 0xff) as u8,
265                    ((words >> 16) & 0xff) as u8,
266                ])
267                .await?;
268        }
269        stream.write_all(data).await?;
270        Ok(())
271    }
272
273    async fn recv_abridged(stream: &mut TcpStream) -> Result<Vec<u8>, InvocationError> {
274        let mut h = [0u8; 1];
275        stream.read_exact(&mut h).await?;
276        let words = if h[0] < 0x7f {
277            h[0] as usize
278        } else {
279            let mut b = [0u8; 3];
280            stream.read_exact(&mut b).await?;
281            b[0] as usize | (b[1] as usize) << 8 | (b[2] as usize) << 16
282        };
283        let mut buf = vec![0u8; words * 4];
284        stream.read_exact(&mut buf).await?;
285        Ok(buf)
286    }
287
288    async fn send_plain_frame(stream: &mut TcpStream, data: &[u8]) -> Result<(), InvocationError> {
289        Self::send_abridged(stream, data).await
290    }
291
292    async fn recv_plain_frame<T: Deserializable>(
293        stream: &mut TcpStream,
294    ) -> Result<T, InvocationError> {
295        let raw = Self::recv_abridged(stream).await?;
296        if raw.len() < 20 {
297            return Err(InvocationError::Deserialize("plain frame too short".into()));
298        }
299        if u64::from_le_bytes(raw[..8].try_into().unwrap()) != 0 {
300            return Err(InvocationError::Deserialize(
301                "expected auth_key_id=0 in plaintext".into(),
302            ));
303        }
304        let body_len = u32::from_le_bytes(raw[16..20].try_into().unwrap()) as usize;
305        let mut cur = Cursor::from_slice(&raw[20..20 + body_len]);
306        T::deserialize(&mut cur).map_err(Into::into)
307    }
308}
309
/// Decodes a TL-serialized `bytes`/`string` value from the start of `data`.
///
/// TL encodes lengths up to 253 in a single prefix byte. Longer values use the
/// marker `254` followed by a 24-bit little-endian length. Any padding after
/// the payload is ignored.
///
/// Returns `None` in three cases: `data` is empty (a valid encoding always has
/// a prefix byte), it starts with the reserved marker `255`, or it is shorter
/// than the encoded length.
fn tl_read_bytes(data: &[u8]) -> Option<Vec<u8>> {
    let (&marker, rest) = data.split_first()?;
    let (len, payload) = match marker {
        0..=253 => (usize::from(marker), rest),
        254 => {
            let ext = rest.get(..3)?;
            let len = usize::from(ext[0]) | usize::from(ext[1]) << 8 | usize::from(ext[2]) << 16;
            (len, &rest[3..])
        }
        // 255 is not a valid TL length marker.
        255 => return None,
    };
    payload.get(..len).map(<[u8]>::to_vec)
}
329
330fn tl_read_string(data: &[u8]) -> Option<String> {
331    tl_read_bytes(data).map(|b| String::from_utf8_lossy(&b).into_owned())
332}
333
// DcPool
335
336/// Pool of per-DC authenticated connections.
337pub struct DcPool {
338    conns: HashMap<i32, DcConnection>,
339    addrs: HashMap<i32, String>,
340    #[allow(dead_code)]
341    home_dc_id: i32,
342}
343
344impl DcPool {
345    pub fn new(home_dc_id: i32, dc_entries: &[DcEntry]) -> Self {
346        let addrs = dc_entries
347            .iter()
348            .map(|e| (e.dc_id, e.addr.clone()))
349            .collect();
350        Self {
351            conns: HashMap::new(),
352            addrs,
353            home_dc_id,
354        }
355    }
356
357    /// Returns true if a connection for `dc_id` already exists in the pool.
358    pub fn has_connection(&self, dc_id: i32) -> bool {
359        self.conns.contains_key(&dc_id)
360    }
361
362    /// Insert a pre-built connection into the pool.
363    pub fn insert(&mut self, dc_id: i32, conn: DcConnection) {
364        self.conns.insert(dc_id, conn);
365    }
366
367    /// Invoke a raw RPC call on the given DC.
368    pub async fn invoke_on_dc<R: RemoteCall>(
369        &mut self,
370        dc_id: i32,
371        _dc_entries: &[DcEntry],
372        req: &R,
373    ) -> Result<Vec<u8>, InvocationError> {
374        let conn = self
375            .conns
376            .get_mut(&dc_id)
377            .ok_or_else(|| InvocationError::Deserialize(format!("no connection for DC{dc_id}")))?;
378        conn.rpc_call(req).await
379    }
380
381    /// Update the address table (called after `initConnection`).
382    pub fn update_addrs(&mut self, entries: &[DcEntry]) {
383        for e in entries {
384            self.addrs.insert(e.dc_id, e.addr.clone());
385        }
386    }
387
388    /// Save the auth keys from pool connections back into the DC entry list.
389    pub fn collect_keys(&self, entries: &mut [DcEntry]) {
390        for e in entries.iter_mut() {
391            if let Some(conn) = self.conns.get(&e.dc_id) {
392                e.auth_key = Some(conn.auth_key_bytes());
393                e.first_salt = conn.first_salt();
394                e.time_offset = conn.time_offset();
395            }
396        }
397    }
398}