Skip to main content

layer_client/
dc_pool.rs

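//! Multi-DC connection pool.
//!
//! Maintains at most one authenticated [`DcConnection`] per DC ID and sends RPC calls
//! to the connection of a caller-specified DC. Auth keys for non-home DCs are meant to be
//! shared from the home DC via `auth.exportAuthorization` / `auth.importAuthorization`;
//! that exchange is not performed in this module — callers add ready connections with
//! `DcPool::insert`.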
6
7use std::collections::HashMap;
8use layer_tl_types as tl;
9use layer_tl_types::{Cursor, Deserializable, RemoteCall};
10use layer_mtproto::{EncryptedSession, Session, authentication as auth};
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::net::TcpStream;
13
14use crate::{InvocationError, TransportKind, session::DcEntry};
15
16// ─── DcConnection ─────────────────────────────────────────────────────────────
17
18/// A single encrypted connection to one Telegram DC.
19pub struct DcConnection {
20    stream: TcpStream,
21    enc:    EncryptedSession,
22}
23
24impl DcConnection {
25    /// Connect and perform full DH handshake.
26    pub async fn connect_raw(
27        addr:      &str,
28        socks5:    Option<&crate::socks5::Socks5Config>,
29        transport: &TransportKind,
30    ) -> Result<Self, InvocationError> {
31        log::info!("[dc_pool] Connecting to {addr} …");
32        let mut stream = Self::open_tcp(addr, socks5).await?;
33        Self::send_transport_init(&mut stream, transport).await?;
34
35        let mut plain = Session::new();
36
37        let (req1, s1) = auth::step1().map_err(|e| InvocationError::Deserialize(e.to_string()))?;
38        Self::send_plain_frame(&mut stream, &plain.pack(&req1).to_plaintext_bytes()).await?;
39        let res_pq: tl::enums::ResPq = Self::recv_plain_frame(&mut stream).await?;
40
41        let (req2, s2) = auth::step2(s1, res_pq).map_err(|e| InvocationError::Deserialize(e.to_string()))?;
42        Self::send_plain_frame(&mut stream, &plain.pack(&req2).to_plaintext_bytes()).await?;
43        let dh: tl::enums::ServerDhParams = Self::recv_plain_frame(&mut stream).await?;
44
45        let (req3, s3) = auth::step3(s2, dh).map_err(|e| InvocationError::Deserialize(e.to_string()))?;
46        Self::send_plain_frame(&mut stream, &plain.pack(&req3).to_plaintext_bytes()).await?;
47        let ans: tl::enums::SetClientDhParamsAnswer = Self::recv_plain_frame(&mut stream).await?;
48
49        let done = auth::finish(s3, ans).map_err(|e| InvocationError::Deserialize(e.to_string()))?;
50        log::info!("[dc_pool] DH complete ✓ for {addr}");
51
52        Ok(Self {
53            stream,
54            enc: EncryptedSession::new(done.auth_key, done.first_salt, done.time_offset),
55        })
56    }
57
58    /// Connect with an already-known auth key (no DH needed).
59    pub async fn connect_with_key(
60        addr:        &str,
61        auth_key:    [u8; 256],
62        first_salt:  i64,
63        time_offset: i32,
64        socks5:      Option<&crate::socks5::Socks5Config>,
65        transport:   &TransportKind,
66    ) -> Result<Self, InvocationError> {
67        let mut stream = Self::open_tcp(addr, socks5).await?;
68        Self::send_transport_init(&mut stream, transport).await?;
69        Ok(Self {
70            stream,
71            enc: EncryptedSession::new(auth_key, first_salt, time_offset),
72        })
73    }
74
75    async fn open_tcp(
76        addr:   &str,
77        socks5: Option<&crate::socks5::Socks5Config>,
78    ) -> Result<TcpStream, InvocationError> {
79        match socks5 {
80            Some(proxy) => proxy.connect(addr).await,
81            None        => Ok(TcpStream::connect(addr).await?),
82        }
83    }
84
85    async fn send_transport_init(
86        stream:    &mut TcpStream,
87        transport: &TransportKind,
88    ) -> Result<(), InvocationError> {
89        match transport {
90            TransportKind::Abridged       => { stream.write_all(&[0xef]).await?; }
91            TransportKind::Intermediate   => { stream.write_all(&[0xee, 0xee, 0xee, 0xee]).await?; }
92            TransportKind::Full           => {} // no init byte
93            TransportKind::Obfuscated { secret } => {
94                let mut nonce = [0u8; 64];
95                getrandom::getrandom(&mut nonce).map_err(|_| InvocationError::Deserialize("getrandom".into()))?;
96                nonce[56] = 0xef; nonce[57] = 0xef; nonce[58] = 0xef; nonce[59] = 0xef;
97                let (enc_key, enc_iv, _, _) = crate::transport_obfuscated::derive_keys(&nonce, secret.as_ref());
98                let mut enc = crate::transport_obfuscated::ObfCipher::new(enc_key, enc_iv);
99                let mut handshake = nonce;
100                enc.apply(&mut handshake[56..]);
101                stream.write_all(&handshake).await?;
102            }
103        }
104        Ok(())
105    }
106
107    pub fn auth_key_bytes(&self) -> [u8; 256] { self.enc.auth_key_bytes() }
108    pub fn first_salt(&self)     -> i64         { self.enc.salt }
109    pub fn time_offset(&self)    -> i32         { self.enc.time_offset }
110
111    pub async fn rpc_call<R: RemoteCall>(&mut self, req: &R) -> Result<Vec<u8>, InvocationError> {
112        let wire = self.enc.pack(req);
113        Self::send_abridged(&mut self.stream, &wire).await?;
114        self.recv_rpc().await
115    }
116
117    async fn recv_rpc(&mut self) -> Result<Vec<u8>, InvocationError> {
118        loop {
119            let mut raw = Self::recv_abridged(&mut self.stream).await?;
120            let msg = self.enc.unpack(&mut raw)
121                .map_err(|e| InvocationError::Deserialize(e.to_string()))?;
122            if msg.salt != 0 { self.enc.salt = msg.salt; }
123            if msg.body.len() < 4 { return Ok(msg.body); }
124            let cid = u32::from_le_bytes(msg.body[..4].try_into().unwrap());
125            match cid {
126                0xf35c6d01 /* rpc_result */ => {
127                    if msg.body.len() >= 12 { return Ok(msg.body[12..].to_vec()); }
128                    return Ok(msg.body);
129                }
130                0x2144ca19 /* rpc_error */ => {
131                    if msg.body.len() < 8 {
132                        return Err(InvocationError::Deserialize("rpc_error short".into()));
133                    }
134                    let code = i32::from_le_bytes(msg.body[4..8].try_into().unwrap());
135                    let message = tl_read_string(&msg.body[8..]).unwrap_or_default();
136                    return Err(InvocationError::Rpc(crate::RpcError::from_telegram(code, &message)));
137                }
138                0x347773c5 | 0x62d6b459 | 0x9ec20908 | 0xedab447b | 0xa7eff811 => continue,
139                _ => return Ok(msg.body),
140            }
141        }
142    }
143
144    async fn send_abridged(stream: &mut TcpStream, data: &[u8]) -> Result<(), InvocationError> {
145        let words = data.len() / 4;
146        if words < 0x7f {
147            stream.write_all(&[words as u8]).await?;
148        } else {
149            stream.write_all(&[0x7f, (words & 0xff) as u8, ((words >> 8) & 0xff) as u8, ((words >> 16) & 0xff) as u8]).await?;
150        }
151        stream.write_all(data).await?;
152        Ok(())
153    }
154
155    async fn recv_abridged(stream: &mut TcpStream) -> Result<Vec<u8>, InvocationError> {
156        let mut h = [0u8; 1];
157        stream.read_exact(&mut h).await?;
158        let words = if h[0] < 0x7f {
159            h[0] as usize
160        } else {
161            let mut b = [0u8; 3];
162            stream.read_exact(&mut b).await?;
163            b[0] as usize | (b[1] as usize) << 8 | (b[2] as usize) << 16
164        };
165        let mut buf = vec![0u8; words * 4];
166        stream.read_exact(&mut buf).await?;
167        Ok(buf)
168    }
169
170    async fn send_plain_frame(stream: &mut TcpStream, data: &[u8]) -> Result<(), InvocationError> {
171        Self::send_abridged(stream, data).await
172    }
173
174    async fn recv_plain_frame<T: Deserializable>(stream: &mut TcpStream) -> Result<T, InvocationError> {
175        let raw = Self::recv_abridged(stream).await?;
176        if raw.len() < 20 {
177            return Err(InvocationError::Deserialize("plain frame too short".into()));
178        }
179        if u64::from_le_bytes(raw[..8].try_into().unwrap()) != 0 {
180            return Err(InvocationError::Deserialize("expected auth_key_id=0 in plaintext".into()));
181        }
182        let body_len = u32::from_le_bytes(raw[16..20].try_into().unwrap()) as usize;
183        let mut cur = Cursor::from_slice(&raw[20..20 + body_len]);
184        T::deserialize(&mut cur).map_err(Into::into)
185    }
186}
187
/// Reads a TL-serialized `bytes`/`string` value from the start of `data`.
///
/// Short form: one length byte (< 254) followed by the payload. Long form: marker byte
/// 254 followed by a 3-byte little-endian length, then the payload. Trailing padding is
/// neither required nor checked.
///
/// Returns `Some(vec![])` for empty input. Returns `None` for the invalid marker 255 or a
/// truncated header/payload.
fn tl_read_bytes(data: &[u8]) -> Option<Vec<u8>> {
    let (len, start) = match data {
        [] => return Some(Vec::new()),
        [n, ..] if *n < 254 => (usize::from(*n), 1),
        [254, a, b, c, ..] => (usize::from(*a) | usize::from(*b) << 8 | usize::from(*c) << 16, 4),
        // 255 is not a valid TL length prefix; 254 with fewer than 3 length bytes is truncated.
        _ => return None,
    };
    data.get(start..start + len).map(<[u8]>::to_vec)
}
197
198fn tl_read_string(data: &[u8]) -> Option<String> {
199    tl_read_bytes(data).map(|b| String::from_utf8_lossy(&b).into_owned())
200}
201
202// ─── DcPool ───────────────────────────────────────────────────────────────────
203
204/// Pool of per-DC authenticated connections.
205pub struct DcPool {
206    conns:      HashMap<i32, DcConnection>,
207    addrs:      HashMap<i32, String>,
208    #[allow(dead_code)]
209    home_dc_id: i32,
210}
211
212impl DcPool {
213    pub fn new(home_dc_id: i32, dc_entries: &[DcEntry]) -> Self {
214        let addrs = dc_entries.iter().map(|e| (e.dc_id, e.addr.clone())).collect();
215        Self { conns: HashMap::new(), addrs, home_dc_id }
216    }
217
218    /// Returns true if a connection for `dc_id` already exists in the pool.
219    pub fn has_connection(&self, dc_id: i32) -> bool {
220        self.conns.contains_key(&dc_id)
221    }
222
223    /// Insert a pre-built connection into the pool.
224    pub fn insert(&mut self, dc_id: i32, conn: DcConnection) {
225        self.conns.insert(dc_id, conn);
226    }
227
228    /// Invoke a raw RPC call on the given DC.
229    pub async fn invoke_on_dc<R: RemoteCall>(
230        &mut self,
231        dc_id:      i32,
232        _dc_entries: &[DcEntry],
233        req:        &R,
234    ) -> Result<Vec<u8>, InvocationError> {
235        let conn = self.conns.get_mut(&dc_id)
236            .ok_or_else(|| InvocationError::Deserialize(format!("no connection for DC{dc_id}")))?;
237        conn.rpc_call(req).await
238    }
239
240    /// Update the address table (called after `initConnection`).
241    pub fn update_addrs(&mut self, entries: &[DcEntry]) {
242        for e in entries { self.addrs.insert(e.dc_id, e.addr.clone()); }
243    }
244
245    /// Save the auth keys from pool connections back into the DC entry list.
246    pub fn collect_keys(&self, entries: &mut Vec<DcEntry>) {
247        for e in entries.iter_mut() {
248            if let Some(conn) = self.conns.get(&e.dc_id) {
249                e.auth_key    = Some(conn.auth_key_bytes());
250                e.first_salt  = conn.first_salt();
251                e.time_offset = conn.time_offset();
252            }
253        }
254    }
255}