1use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
10use std::ops::{ControlFlow, Deref};
11use std::sync::Arc;
12use std::{fmt, panic};
13
14use grammers_mtproto::{mtp, transport};
15use grammers_session::Session;
16use grammers_session::types::DcOption;
17use grammers_session::updates::UpdatesLike;
18use grammers_tl_types::{self as tl, enums};
19use tokio::task::AbortHandle;
20use tokio::{
21 sync::{mpsc, oneshot},
22 task::JoinSet,
23};
24
25use crate::configuration::ConnectionParams;
26use crate::errors::ReadError;
27use crate::{InvocationError, Sender, ServerAddr, connect, connect_with_auth};
28
/// Transport used for every connection opened by the pool.
///
/// `run_sender` is typed over this alias, and `connect_sender` builds `transport::Full` senders.
pub(crate) type Transport = transport::Full;

/// Bytes of a successful RPC response, handed back to the caller of
/// [`SenderPoolHandle::invoke_in_dc`]. Presumably the serialized TL result — TODO confirm.
type InvokeResponse = Vec<u8>;
32
/// Message sent from a [`SenderPoolHandle`] to the [`SenderPoolRunner`].
enum Request {
    /// Send `body` over the connection to `dc_id`, opening one if needed.
    /// The outcome is delivered through `tx`.
    Invoke {
        dc_id: i32,
        body: Vec<u8>,
        tx: oneshot::Sender<Result<InvokeResponse, InvocationError>>,
    },
    /// Abort every connection task for `dc_id` and forget it.
    Disconnect {
        dc_id: i32,
    },
    /// Stop the runner loop and shut down all connections.
    Quit,
}
44
/// One invocation forwarded by the runner to a connection task (`run_sender`).
///
/// That task passes `body` and `tx` to `Sender::enqueue_body`.
struct Rpc {
    body: Vec<u8>,
    tx: oneshot::Sender<Result<InvokeResponse, InvocationError>>,
}
49
/// Runner-side bookkeeping for one live connection task.
struct ConnectionInfo {
    /// Datacenter this connection talks to.
    dc_id: i32,
    /// Feeds the task's `rpc_rx`. Dropping it makes `run_sender` return `Ok(())`.
    rpc_tx: mpsc::UnboundedSender<Rpc>,
    /// Handle returned by `JoinSet::spawn`.
    /// Used to abort the task on `Disconnect` and to detect that it has finished.
    abort_handle: AbortHandle,
}
55
/// A [`SenderPoolHandle`] bundled with the session and API ID the pool was built with.
///
/// Dereferences to the thin handle, so its request methods can be called directly.
#[derive(Clone)]
pub struct SenderPoolFatHandle {
    /// Handle used to send requests to the runner.
    pub thin: SenderPoolHandle,
    /// The same session instance that the runner reads and updates.
    pub session: Arc<dyn Session>,
    /// The API ID the runner uses when initializing connections.
    pub api_id: i32,
}
73
/// Cheap, cloneable handle that enqueues [`Request`]s for a [`SenderPoolRunner`].
///
/// All clones share the runner's request channel.
#[derive(Clone)]
pub struct SenderPoolHandle(mpsc::UnboundedSender<Request>);
77
/// The three parts produced when creating a pool.
///
/// The caller is expected to split them up: drive the runner, keep the handle, and consume
/// the updates.
pub struct SenderPool {
    /// Event loop that owns the connections.
    /// Requests sent through `handle` are only processed while [`SenderPoolRunner::run`] is being awaited.
    pub runner: SenderPoolRunner,
    /// Handle for invoking requests and controlling connections.
    pub handle: SenderPoolFatHandle,
    /// Receives updates forwarded by connection tasks.
    /// Also receives `UpdatesLike::ConnectionClosed` when the home-DC connection fails.
    pub updates: mpsc::UnboundedReceiver<UpdatesLike>,
}
96
/// State owned by the pool's event loop (see [`SenderPoolRunner::run`]).
pub struct SenderPoolRunner {
    /// Source of DC addresses, auth keys and the home DC. Updated as connections are made.
    session: Arc<dyn Session>,
    /// API ID sent in `InitConnection`.
    api_id: i32,
    /// Client identification and address-family/proxy settings used when connecting.
    connection_params: ConnectionParams,
    /// Incoming requests from every [`SenderPoolHandle`] clone.
    request_rx: mpsc::UnboundedReceiver<Request>,
    /// Cloned into each connection task to forward updates.
    updates_tx: mpsc::UnboundedSender<UpdatesLike>,
    /// Currently known connections, at most one live task per DC in the common case.
    connections: Vec<ConnectionInfo>,
    /// The spawned `run_sender` tasks.
    connection_pool: JoinSet<Result<(), ReadError>>,
}
109
110impl Deref for SenderPoolFatHandle {
111 type Target = SenderPoolHandle;
112
113 fn deref(&self) -> &Self::Target {
114 &self.thin
115 }
116}
117
118impl SenderPoolHandle {
119 pub async fn invoke_in_dc(
122 &self,
123 dc_id: i32,
124 body: Vec<u8>,
125 ) -> Result<InvokeResponse, InvocationError> {
126 let (tx, rx) = oneshot::channel();
127 self.0
128 .send(Request::Invoke { dc_id, body, tx })
129 .map_err(|_| InvocationError::Dropped)?;
130 rx.await.map_err(|_| InvocationError::Dropped)?
131 }
132
133 pub fn disconnect_from_dc(&self, dc_id: i32) -> bool {
141 self.0.send(Request::Disconnect { dc_id }).is_ok()
142 }
143
144 pub fn quit(&self) -> bool {
147 self.0.send(Request::Quit).is_ok()
148 }
149}
150
151impl SenderPool {
152 pub fn new<S: Session + 'static>(session: Arc<S>, api_id: i32) -> Self {
161 Self::with_configuration(session, api_id, Default::default())
162 }
163
164 pub fn with_configuration<S: Session + 'static>(
166 session: Arc<S>,
167 api_id: i32,
168 connection_params: ConnectionParams,
169 ) -> Self {
170 let (request_tx, request_rx) = mpsc::unbounded_channel();
171 let (updates_tx, updates_rx) = mpsc::unbounded_channel();
172 let session = session as Arc<dyn Session>;
173
174 Self {
175 runner: SenderPoolRunner {
176 session: Arc::clone(&session),
177 api_id,
178 connection_params,
179 request_rx,
180 updates_tx,
181 connections: Vec::new(),
182 connection_pool: JoinSet::new(),
183 },
184 handle: SenderPoolFatHandle {
185 thin: SenderPoolHandle(request_tx),
186 session,
187 api_id,
188 },
189 updates: updates_rx,
190 }
191 }
192}
193
194impl SenderPoolRunner {
195 pub async fn run(mut self) {
199 loop {
200 tokio::select! {
201 biased;
202 completion = self.connection_pool.join_next(), if !self.connection_pool.is_empty() => {
203 if let Err(err) = completion.unwrap() {
204 if let Ok(reason) = err.try_into_panic() {
205 panic::resume_unwind(reason);
206 }
207 }
208 self.connections
209 .retain(|connection| !connection.abort_handle.is_finished());
210 }
211 request = self.request_rx.recv() => {
212 let flow = if let Some(request) = request {
213 self.process_request(request).await
214 } else {
215 ControlFlow::Break(())
216 };
217 match flow {
218 ControlFlow::Continue(_) => continue,
219 ControlFlow::Break(_) => break,
220 }
221 }
222 }
223 }
224
225 self.connections.clear(); self.connection_pool.join_all().await;
227 }
228
229 async fn process_request(&mut self, request: Request) -> ControlFlow<()> {
230 match request {
231 Request::Invoke { dc_id, body, tx } => {
232 let Some(mut dc_option) = self.session.dc_option(dc_id) else {
233 let _ = tx.send(Err(InvocationError::InvalidDc));
234 return ControlFlow::Continue(());
235 };
236
237 let connection = match self
238 .connections
239 .iter()
240 .find(|connection| connection.dc_id == dc_id)
241 {
242 Some(connection) => connection,
243 None => {
244 let sender = match self.connect_sender(&dc_option).await {
245 Ok(t) => t,
246 Err(e) => {
247 let _ = tx.send(Err(e));
248 return ControlFlow::Continue(());
249 }
250 };
251
252 dc_option.auth_key = Some(sender.auth_key());
253 self.session.set_dc_option(&dc_option).await;
254
255 let (rpc_tx, rpc_rx) = mpsc::unbounded_channel();
256 let abort_handle = self.connection_pool.spawn(run_sender(
257 sender,
258 rpc_rx,
259 self.updates_tx.clone(),
260 dc_option.id == self.session.home_dc_id(),
261 ));
262 self.connections.push(ConnectionInfo {
263 dc_id,
264 rpc_tx,
265 abort_handle,
266 });
267 self.connections.last().unwrap()
268 }
269 };
270 let _ = connection.rpc_tx.send(Rpc { body, tx });
271 ControlFlow::Continue(())
272 }
273 Request::Disconnect { dc_id } => {
274 self.connections.retain(|connection| {
275 if connection.dc_id == dc_id {
276 connection.abort_handle.abort();
277 false
278 } else {
279 true
280 }
281 });
282 ControlFlow::Continue(())
283 }
284 Request::Quit => ControlFlow::Break(()),
285 }
286 }
287
288 async fn connect_sender(
289 &mut self,
290 dc_option: &DcOption,
291 ) -> Result<Sender<transport::Full, mtp::Encrypted>, InvocationError> {
292 let transport = transport::Full::new;
293
294 let address = if self.connection_params.use_ipv6 {
295 dc_option.ipv6.into()
296 } else {
297 dc_option.ipv4.into()
298 };
299
300 #[cfg(feature = "proxy")]
301 let addr = || {
302 if let Some(proxy) = self.connection_params.proxy_url.clone() {
303 ServerAddr::Proxied { address, proxy }
304 } else {
305 ServerAddr::Tcp { address }
306 }
307 };
308 #[cfg(not(feature = "proxy"))]
309 let addr = || ServerAddr::Tcp { address };
310
311 let init_connection = tl::functions::InvokeWithLayer {
312 layer: tl::LAYER,
313 query: tl::functions::InitConnection {
314 api_id: self.api_id,
315 device_model: self.connection_params.device_model.clone(),
316 system_version: self.connection_params.system_version.clone(),
317 app_version: self.connection_params.app_version.clone(),
318 system_lang_code: self.connection_params.system_lang_code.clone(),
319 lang_pack: "".into(),
320 lang_code: self.connection_params.lang_code.clone(),
321 proxy: None,
322 params: None,
323 query: tl::functions::help::GetConfig {},
324 },
325 };
326
327 let mut sender = if let Some(auth_key) = dc_option.auth_key {
328 connect_with_auth(transport(), addr(), auth_key).await?
329 } else {
330 connect(transport(), addr()).await?
331 };
332
333 let enums::Config::Config(remote_config) = match sender.invoke(&init_connection).await {
334 Ok(config) => config,
335 Err(InvocationError::Transport(transport::Error::BadStatus { status: 404 })) => {
336 sender = connect(transport(), addr()).await?;
337 sender.invoke(&init_connection).await?
338 }
339 Err(e) => return Err(dbg!(e).into()),
340 };
341
342 self.update_config(remote_config).await;
343
344 Ok(sender)
345 }
346
347 async fn update_config(&mut self, config: tl::types::Config) {
348 for option in config
349 .dc_options
350 .iter()
351 .map(|tl::enums::DcOption::Option(option)| option)
352 .filter(|option| !option.media_only && !option.tcpo_only && option.r#static)
353 {
354 let mut dc_option = self
355 .session
356 .dc_option(option.id)
357 .unwrap_or_else(|| DcOption {
358 id: option.id,
359 ipv4: SocketAddrV4::new(Ipv4Addr::from_bits(0), 0),
360 ipv6: SocketAddrV6::new(Ipv6Addr::from_bits(0), 0, 0, 0),
361 auth_key: None,
362 });
363 if option.ipv6 {
364 dc_option.ipv6 = SocketAddrV6::new(
365 option
366 .ip_address
367 .parse()
368 .expect("Telegram to return a valid IPv6 address"),
369 option.port as _,
370 0,
371 0,
372 );
373 } else {
374 dc_option.ipv4 = SocketAddrV4::new(
375 option
376 .ip_address
377 .parse()
378 .expect("Telegram to return a valid IPv4 address"),
379 option.port as _,
380 );
381 if dc_option.ipv6.ip().to_bits() == 0 {
382 dc_option.ipv6 = SocketAddrV6::new(
383 dc_option.ipv4.ip().to_ipv6_mapped(),
384 dc_option.ipv4.port(),
385 0,
386 0,
387 )
388 }
389 }
390 }
391 }
392}
393
394async fn run_sender(
395 mut sender: Sender<Transport, grammers_mtproto::mtp::Encrypted>,
396 mut rpc_rx: mpsc::UnboundedReceiver<Rpc>,
397 updates: mpsc::UnboundedSender<UpdatesLike>,
398 home_sender: bool,
399) -> Result<(), ReadError> {
400 loop {
401 tokio::select! {
402 step = sender.step() => match step {
403 Ok(all_new_updates) => all_new_updates.into_iter().for_each(|new_updates| {
404 let _ = updates.send(new_updates);
405 }),
406 Err(err) => {
407 if home_sender {
408 let _ = updates.send(UpdatesLike::ConnectionClosed);
409 }
410 break Err(err)
411 },
412 },
413 rpc = rpc_rx.recv() => match rpc {
414 Some(rpc) => sender.enqueue_body(rpc.body, rpc.tx),
415 None => break Ok(()),
416 },
417 }
418 }
419}
420
421impl fmt::Debug for Request {
422 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
423 match self {
424 Self::Invoke { dc_id, body, tx } => f
425 .debug_struct("Invoke")
426 .field("dc_id", dc_id)
427 .field(
428 "request",
429 &body[..4]
430 .try_into()
431 .map(|constructor_id| tl::name_for_id(u32::from_le_bytes(constructor_id)))
432 .unwrap_or("?"),
433 )
434 .field("tx", tx)
435 .finish(),
436 Self::Disconnect { dc_id } => {
437 f.debug_struct("Disconnect").field("dc_id", dc_id).finish()
438 }
439 Self::Quit => write!(f, "Quit"),
440 }
441 }
442}