1use std::collections::HashMap;
8use layer_tl_types as tl;
9use layer_tl_types::{Cursor, Deserializable, RemoteCall};
10use layer_mtproto::{EncryptedSession, Session, authentication as auth};
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::net::TcpStream;
13
14use crate::{InvocationError, TransportKind, session::DcEntry};
15
16pub struct DcConnection {
20 stream: TcpStream,
21 enc: EncryptedSession,
22}
23
24impl DcConnection {
25 pub async fn connect_raw(
27 addr: &str,
28 socks5: Option<&crate::socks5::Socks5Config>,
29 transport: &TransportKind,
30 ) -> Result<Self, InvocationError> {
31 log::info!("[dc_pool] Connecting to {addr} …");
32 let mut stream = Self::open_tcp(addr, socks5).await?;
33 Self::send_transport_init(&mut stream, transport).await?;
34
35 let mut plain = Session::new();
36
37 let (req1, s1) = auth::step1().map_err(|e| InvocationError::Deserialize(e.to_string()))?;
38 Self::send_plain_frame(&mut stream, &plain.pack(&req1).to_plaintext_bytes()).await?;
39 let res_pq: tl::enums::ResPq = Self::recv_plain_frame(&mut stream).await?;
40
41 let (req2, s2) = auth::step2(s1, res_pq).map_err(|e| InvocationError::Deserialize(e.to_string()))?;
42 Self::send_plain_frame(&mut stream, &plain.pack(&req2).to_plaintext_bytes()).await?;
43 let dh: tl::enums::ServerDhParams = Self::recv_plain_frame(&mut stream).await?;
44
45 let (req3, s3) = auth::step3(s2, dh).map_err(|e| InvocationError::Deserialize(e.to_string()))?;
46 Self::send_plain_frame(&mut stream, &plain.pack(&req3).to_plaintext_bytes()).await?;
47 let ans: tl::enums::SetClientDhParamsAnswer = Self::recv_plain_frame(&mut stream).await?;
48
49 let done = auth::finish(s3, ans).map_err(|e| InvocationError::Deserialize(e.to_string()))?;
50 log::info!("[dc_pool] DH complete ✓ for {addr}");
51
52 Ok(Self {
53 stream,
54 enc: EncryptedSession::new(done.auth_key, done.first_salt, done.time_offset),
55 })
56 }
57
58 pub async fn connect_with_key(
60 addr: &str,
61 auth_key: [u8; 256],
62 first_salt: i64,
63 time_offset: i32,
64 socks5: Option<&crate::socks5::Socks5Config>,
65 transport: &TransportKind,
66 ) -> Result<Self, InvocationError> {
67 let mut stream = Self::open_tcp(addr, socks5).await?;
68 Self::send_transport_init(&mut stream, transport).await?;
69 Ok(Self {
70 stream,
71 enc: EncryptedSession::new(auth_key, first_salt, time_offset),
72 })
73 }
74
75 async fn open_tcp(
76 addr: &str,
77 socks5: Option<&crate::socks5::Socks5Config>,
78 ) -> Result<TcpStream, InvocationError> {
79 match socks5 {
80 Some(proxy) => proxy.connect(addr).await,
81 None => Ok(TcpStream::connect(addr).await?),
82 }
83 }
84
85 async fn send_transport_init(
86 stream: &mut TcpStream,
87 transport: &TransportKind,
88 ) -> Result<(), InvocationError> {
89 match transport {
90 TransportKind::Abridged => { stream.write_all(&[0xef]).await?; }
91 TransportKind::Intermediate => { stream.write_all(&[0xee, 0xee, 0xee, 0xee]).await?; }
92 TransportKind::Full => {} TransportKind::Obfuscated { secret } => {
94 let mut nonce = [0u8; 64];
95 getrandom::getrandom(&mut nonce).map_err(|_| InvocationError::Deserialize("getrandom".into()))?;
96 nonce[56] = 0xef; nonce[57] = 0xef; nonce[58] = 0xef; nonce[59] = 0xef;
97 let (enc_key, enc_iv, _, _) = crate::transport_obfuscated::derive_keys(&nonce, secret.as_ref());
98 let mut enc = crate::transport_obfuscated::ObfCipher::new(enc_key, enc_iv);
99 let mut handshake = nonce;
100 enc.apply(&mut handshake[56..]);
101 stream.write_all(&handshake).await?;
102 }
103 }
104 Ok(())
105 }
106
107 pub fn auth_key_bytes(&self) -> [u8; 256] { self.enc.auth_key_bytes() }
108 pub fn first_salt(&self) -> i64 { self.enc.salt }
109 pub fn time_offset(&self) -> i32 { self.enc.time_offset }
110
111 pub async fn rpc_call<R: RemoteCall>(&mut self, req: &R) -> Result<Vec<u8>, InvocationError> {
112 let wire = self.enc.pack(req);
113 Self::send_abridged(&mut self.stream, &wire).await?;
114 self.recv_rpc().await
115 }
116
117 async fn recv_rpc(&mut self) -> Result<Vec<u8>, InvocationError> {
118 loop {
119 let mut raw = Self::recv_abridged(&mut self.stream).await?;
120 let msg = self.enc.unpack(&mut raw)
121 .map_err(|e| InvocationError::Deserialize(e.to_string()))?;
122 if msg.salt != 0 { self.enc.salt = msg.salt; }
123 if msg.body.len() < 4 { return Ok(msg.body); }
124 let cid = u32::from_le_bytes(msg.body[..4].try_into().unwrap());
125 match cid {
126 0xf35c6d01 => {
127 if msg.body.len() >= 12 { return Ok(msg.body[12..].to_vec()); }
128 return Ok(msg.body);
129 }
130 0x2144ca19 => {
131 if msg.body.len() < 8 {
132 return Err(InvocationError::Deserialize("rpc_error short".into()));
133 }
134 let code = i32::from_le_bytes(msg.body[4..8].try_into().unwrap());
135 let message = tl_read_string(&msg.body[8..]).unwrap_or_default();
136 return Err(InvocationError::Rpc(crate::RpcError::from_telegram(code, &message)));
137 }
138 0x347773c5 | 0x62d6b459 | 0x9ec20908 | 0xedab447b | 0xa7eff811 => continue,
139 _ => return Ok(msg.body),
140 }
141 }
142 }
143
144 async fn send_abridged(stream: &mut TcpStream, data: &[u8]) -> Result<(), InvocationError> {
145 let words = data.len() / 4;
146 if words < 0x7f {
147 stream.write_all(&[words as u8]).await?;
148 } else {
149 stream.write_all(&[0x7f, (words & 0xff) as u8, ((words >> 8) & 0xff) as u8, ((words >> 16) & 0xff) as u8]).await?;
150 }
151 stream.write_all(data).await?;
152 Ok(())
153 }
154
155 async fn recv_abridged(stream: &mut TcpStream) -> Result<Vec<u8>, InvocationError> {
156 let mut h = [0u8; 1];
157 stream.read_exact(&mut h).await?;
158 let words = if h[0] < 0x7f {
159 h[0] as usize
160 } else {
161 let mut b = [0u8; 3];
162 stream.read_exact(&mut b).await?;
163 b[0] as usize | (b[1] as usize) << 8 | (b[2] as usize) << 16
164 };
165 let mut buf = vec![0u8; words * 4];
166 stream.read_exact(&mut buf).await?;
167 Ok(buf)
168 }
169
170 async fn send_plain_frame(stream: &mut TcpStream, data: &[u8]) -> Result<(), InvocationError> {
171 Self::send_abridged(stream, data).await
172 }
173
174 async fn recv_plain_frame<T: Deserializable>(stream: &mut TcpStream) -> Result<T, InvocationError> {
175 let raw = Self::recv_abridged(stream).await?;
176 if raw.len() < 20 {
177 return Err(InvocationError::Deserialize("plain frame too short".into()));
178 }
179 if u64::from_le_bytes(raw[..8].try_into().unwrap()) != 0 {
180 return Err(InvocationError::Deserialize("expected auth_key_id=0 in plaintext".into()));
181 }
182 let body_len = u32::from_le_bytes(raw[16..20].try_into().unwrap()) as usize;
183 let mut cur = Cursor::from_slice(&raw[20..20 + body_len]);
184 T::deserialize(&mut cur).map_err(Into::into)
185 }
186}
187
188fn tl_read_bytes(data: &[u8]) -> Option<Vec<u8>> {
189 if data.is_empty() { return Some(vec![]); }
190 let (len, start) = if data[0] < 254 { (data[0] as usize, 1) }
191 else if data.len() >= 4 {
192 (data[1] as usize | (data[2] as usize) << 8 | (data[3] as usize) << 16, 4)
193 } else { return None; };
194 if data.len() < start + len { return None; }
195 Some(data[start..start + len].to_vec())
196}
197
198fn tl_read_string(data: &[u8]) -> Option<String> {
199 tl_read_bytes(data).map(|b| String::from_utf8_lossy(&b).into_owned())
200}
201
202pub struct DcPool {
206 conns: HashMap<i32, DcConnection>,
207 addrs: HashMap<i32, String>,
208 #[allow(dead_code)]
209 home_dc_id: i32,
210}
211
212impl DcPool {
213 pub fn new(home_dc_id: i32, dc_entries: &[DcEntry]) -> Self {
214 let addrs = dc_entries.iter().map(|e| (e.dc_id, e.addr.clone())).collect();
215 Self { conns: HashMap::new(), addrs, home_dc_id }
216 }
217
218 pub fn has_connection(&self, dc_id: i32) -> bool {
220 self.conns.contains_key(&dc_id)
221 }
222
223 pub fn insert(&mut self, dc_id: i32, conn: DcConnection) {
225 self.conns.insert(dc_id, conn);
226 }
227
228 pub async fn invoke_on_dc<R: RemoteCall>(
230 &mut self,
231 dc_id: i32,
232 _dc_entries: &[DcEntry],
233 req: &R,
234 ) -> Result<Vec<u8>, InvocationError> {
235 let conn = self.conns.get_mut(&dc_id)
236 .ok_or_else(|| InvocationError::Deserialize(format!("no connection for DC{dc_id}")))?;
237 conn.rpc_call(req).await
238 }
239
240 pub fn update_addrs(&mut self, entries: &[DcEntry]) {
242 for e in entries { self.addrs.insert(e.dc_id, e.addr.clone()); }
243 }
244
245 pub fn collect_keys(&self, entries: &mut Vec<DcEntry>) {
247 for e in entries.iter_mut() {
248 if let Some(conn) = self.conns.get(&e.dc_id) {
249 e.auth_key = Some(conn.auth_key_bytes());
250 e.first_salt = conn.first_salt();
251 e.time_offset = conn.time_offset();
252 }
253 }
254 }
255}