1use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
10use std::ops::{ControlFlow, Deref};
11use std::sync::Arc;
12use std::{fmt, panic};
13
14use grammers_mtproto::{mtp, transport};
15use grammers_session::types::{DcOption, PeerId, PeerInfo, PeerRef, UpdateState, UpdatesState};
16use grammers_session::updates::UpdatesLike;
17use grammers_session::{BoxFuture, ErasedSession, Session};
18use grammers_tl_types::{self as tl, enums};
19use tokio::task::AbortHandle;
20use tokio::{
21 sync::{mpsc, oneshot},
22 task::JoinSet,
23};
24
25use crate::configuration::ConnectionParams;
26use crate::errors::ReadError;
27use crate::{InvocationError, Sender, ServerAddr, connect, connect_with_auth};
28
29pub(crate) type Transport = transport::Full;
30
31type InvokeResponse = Vec<u8>;
32
33enum Request {
34 Invoke {
35 dc_id: i32,
36 body: Vec<u8>,
37 tx: oneshot::Sender<Result<InvokeResponse, InvocationError>>,
38 },
39 Disconnect {
40 dc_id: i32,
41 },
42 Quit,
43}
44
45struct Rpc {
46 body: Vec<u8>,
47 tx: oneshot::Sender<Result<InvokeResponse, InvocationError>>,
48}
49
50struct ConnectionInfo {
51 dc_id: i32,
52 rpc_tx: mpsc::UnboundedSender<Rpc>,
53 abort_handle: AbortHandle,
54}
55
56#[derive(Clone)]
58pub struct SenderPoolFatHandle {
59 pub thin: SenderPoolHandle,
63 pub session: Arc<ErasedSession>,
67 pub api_id: i32,
72}
73
74#[derive(Clone)]
76pub struct SenderPoolHandle(mpsc::UnboundedSender<Request>);
77
78pub struct SenderPool {
80 pub runner: SenderPoolRunner,
85 pub handle: SenderPoolFatHandle,
89 pub updates: mpsc::UnboundedReceiver<UpdatesLike>,
95}
96
97pub struct SenderPoolRunner {
101 session: Arc<ErasedSession>,
102 api_id: i32,
103 connection_params: ConnectionParams,
104 request_rx: mpsc::UnboundedReceiver<Request>,
105 updates_tx: mpsc::UnboundedSender<UpdatesLike>,
106 connections: Vec<ConnectionInfo>,
107 connection_pool: JoinSet<Result<(), ReadError>>,
108}
109
110impl Deref for SenderPoolFatHandle {
111 type Target = SenderPoolHandle;
112
113 fn deref(&self) -> &Self::Target {
114 &self.thin
115 }
116}
117
118impl SenderPoolHandle {
119 pub async fn invoke_in_dc(
122 &self,
123 dc_id: i32,
124 body: Vec<u8>,
125 ) -> Result<InvokeResponse, InvocationError> {
126 let (tx, rx) = oneshot::channel();
127 self.0
128 .send(Request::Invoke { dc_id, body, tx })
129 .map_err(|_| InvocationError::Dropped)?;
130 rx.await.map_err(|_| InvocationError::Dropped)?
131 }
132
133 pub fn disconnect_from_dc(&self, dc_id: i32) -> bool {
141 self.0.send(Request::Disconnect { dc_id }).is_ok()
142 }
143
144 pub fn quit(&self) -> bool {
147 self.0.send(Request::Quit).is_ok()
148 }
149}
150
151impl SenderPool {
152 pub fn new<S>(session: Arc<S>, api_id: i32) -> Self
161 where
162 S: Session + Sized,
163 S::Error: std::error::Error + Send + Sync + 'static,
164 {
165 Self::with_configuration(session, api_id, Default::default())
166 }
167
168 pub fn with_configuration<S>(
170 session: Arc<S>,
171 api_id: i32,
172 connection_params: ConnectionParams,
173 ) -> Self
174 where
175 S: Session + Sized,
176 S::Error: std::error::Error + Send + Sync + 'static,
177 {
178 let session: Arc<ErasedSession> = Arc::new(Eraser(session));
179 let (request_tx, request_rx) = mpsc::unbounded_channel();
180 let (updates_tx, updates_rx) = mpsc::unbounded_channel();
181
182 Self {
183 runner: SenderPoolRunner {
184 session: Arc::clone(&session),
185 api_id,
186 connection_params,
187 request_rx,
188 updates_tx,
189 connections: Vec::new(),
190 connection_pool: JoinSet::new(),
191 },
192 handle: SenderPoolFatHandle {
193 thin: SenderPoolHandle(request_tx),
194 session,
195 api_id,
196 },
197 updates: updates_rx,
198 }
199 }
200}
201
202impl SenderPoolRunner {
203 pub async fn run(mut self) {
207 loop {
208 tokio::select! {
209 biased;
210 completion = self.connection_pool.join_next(), if !self.connection_pool.is_empty() => {
211 if let Err(err) = completion.unwrap() {
212 if let Ok(reason) = err.try_into_panic() {
213 panic::resume_unwind(reason);
214 }
215 }
216 self.connections
217 .retain(|connection| !connection.abort_handle.is_finished());
218 }
219 request = self.request_rx.recv() => {
220 let flow = if let Some(request) = request {
221 self.process_request(request).await
222 } else {
223 ControlFlow::Break(())
224 };
225 match flow {
226 ControlFlow::Continue(_) => continue,
227 ControlFlow::Break(_) => break,
228 }
229 }
230 }
231 }
232
233 self.connections.clear(); self.connection_pool.join_all().await;
235 }
236
237 async fn process_request(&mut self, request: Request) -> ControlFlow<()> {
238 match request {
239 Request::Invoke { dc_id, body, tx } => {
240 let connection = match self
241 .connections
242 .iter()
243 .find(|connection| connection.dc_id == dc_id)
244 {
245 Some(connection) => connection,
246 None => match self.create_connection(dc_id).await {
247 Ok(x) => x,
248 Err(e) => {
249 let _ = tx.send(Err(e));
250 return ControlFlow::Continue(());
251 }
252 },
253 };
254 let _ = connection.rpc_tx.send(Rpc { body, tx });
255 ControlFlow::Continue(())
256 }
257 Request::Disconnect { dc_id } => {
258 self.connections.retain(|connection| {
259 if connection.dc_id == dc_id {
260 connection.abort_handle.abort();
261 false
262 } else {
263 true
264 }
265 });
266 ControlFlow::Continue(())
267 }
268 Request::Quit => ControlFlow::Break(()),
269 }
270 }
271
272 async fn create_connection(&mut self, dc_id: i32) -> Result<&ConnectionInfo, InvocationError> {
273 let mut dc_option = match self.session.dc_option(dc_id)? {
274 Some(x) => x,
275 None => return Err(InvocationError::InvalidDc),
276 };
277
278 let sender = self.connect_sender(&dc_option).await?;
279
280 dc_option.auth_key = Some(sender.auth_key());
281 self.session.set_dc_option(&dc_option).await?;
282
283 let (rpc_tx, rpc_rx) = mpsc::unbounded_channel();
284 let abort_handle = self.connection_pool.spawn(run_sender(
285 sender,
286 rpc_rx,
287 self.updates_tx.clone(),
288 dc_option.id == self.session.home_dc_id()?,
289 ));
290 self.connections.push(ConnectionInfo {
291 dc_id,
292 rpc_tx,
293 abort_handle,
294 });
295 Ok(self.connections.last().unwrap())
296 }
297
298 async fn connect_sender(
299 &mut self,
300 dc_option: &DcOption,
301 ) -> Result<Sender<transport::Full, mtp::Encrypted>, InvocationError> {
302 let transport = transport::Full::new;
303
304 let address = if self.connection_params.use_ipv6 {
305 dc_option.ipv6.into()
306 } else {
307 dc_option.ipv4.into()
308 };
309
310 #[cfg(feature = "proxy")]
311 let addr = || {
312 if let Some(proxy) = self.connection_params.proxy_url.clone() {
313 ServerAddr::Proxied { address, proxy }
314 } else {
315 ServerAddr::Tcp { address }
316 }
317 };
318 #[cfg(not(feature = "proxy"))]
319 let addr = || ServerAddr::Tcp { address };
320
321 let init_connection = tl::functions::InvokeWithLayer {
322 layer: tl::LAYER,
323 query: tl::functions::InitConnection {
324 api_id: self.api_id,
325 device_model: self.connection_params.device_model.clone(),
326 system_version: self.connection_params.system_version.clone(),
327 app_version: self.connection_params.app_version.clone(),
328 system_lang_code: self.connection_params.system_lang_code.clone(),
329 lang_pack: "".into(),
330 lang_code: self.connection_params.lang_code.clone(),
331 proxy: None,
332 params: None,
333 query: tl::functions::help::GetConfig {},
334 },
335 };
336
337 let mut sender = if let Some(auth_key) = dc_option.auth_key {
338 connect_with_auth(transport(), addr(), auth_key)
339 .await
340 .map_err(InvocationError::Io)?
341 } else {
342 connect(transport(), addr()).await?
343 };
344
345 let enums::Config::Config(remote_config) = match sender.invoke(&init_connection).await {
346 Ok(config) => config,
347 Err(InvocationError::Transport(transport::Error::BadStatus { status: 404 })) => {
348 sender = connect(transport(), addr()).await?;
349 sender.invoke(&init_connection).await?
350 }
351 Err(e) => return Err(e),
352 };
353
354 self.update_config(remote_config).await?;
355
356 Ok(sender)
357 }
358
359 async fn update_config(&mut self, config: tl::types::Config) -> Result<(), InvocationError> {
360 for option in config
361 .dc_options
362 .iter()
363 .map(|tl::enums::DcOption::Option(option)| option)
364 .filter(|option| !option.media_only && !option.tcpo_only && option.r#static)
365 {
366 let mut dc_option = self
367 .session
368 .dc_option(option.id)?
369 .unwrap_or_else(|| DcOption {
370 id: option.id,
371 ipv4: SocketAddrV4::new(Ipv4Addr::from_bits(0), 0),
372 ipv6: SocketAddrV6::new(Ipv6Addr::from_bits(0), 0, 0, 0),
373 auth_key: None,
374 });
375 if option.ipv6 {
376 dc_option.ipv6 = SocketAddrV6::new(
377 option
378 .ip_address
379 .parse()
380 .expect("Telegram to return a valid IPv6 address"),
381 option.port as _,
382 0,
383 0,
384 );
385 } else {
386 dc_option.ipv4 = SocketAddrV4::new(
387 option
388 .ip_address
389 .parse()
390 .expect("Telegram to return a valid IPv4 address"),
391 option.port as _,
392 );
393 if dc_option.ipv6.ip().to_bits() == 0 {
394 dc_option.ipv6 = SocketAddrV6::new(
395 dc_option.ipv4.ip().to_ipv6_mapped(),
396 dc_option.ipv4.port(),
397 0,
398 0,
399 )
400 }
401 }
402 }
403 Ok(())
404 }
405}
406
407async fn run_sender(
408 mut sender: Sender<Transport, grammers_mtproto::mtp::Encrypted>,
409 mut rpc_rx: mpsc::UnboundedReceiver<Rpc>,
410 updates: mpsc::UnboundedSender<UpdatesLike>,
411 home_sender: bool,
412) -> Result<(), ReadError> {
413 loop {
414 tokio::select! {
415 step = sender.step() => match step {
416 Ok(all_new_updates) => all_new_updates.into_iter().for_each(|new_updates| {
417 let _ = updates.send(new_updates);
418 }),
419 Err(err) => {
420 if home_sender {
421 let _ = updates.send(UpdatesLike::ConnectionClosed);
422 }
423 break Err(err)
424 },
425 },
426 rpc = rpc_rx.recv() => match rpc {
427 Some(rpc) => sender.enqueue_body(rpc.body, rpc.tx),
428 None => break Ok(()),
429 },
430 }
431 }
432}
433
434impl fmt::Debug for Request {
435 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
436 match self {
437 Self::Invoke { dc_id, body, tx } => f
438 .debug_struct("Invoke")
439 .field("dc_id", dc_id)
440 .field(
441 "request",
442 &body[..4]
443 .try_into()
444 .map(|constructor_id| tl::name_for_id(u32::from_le_bytes(constructor_id)))
445 .unwrap_or("?"),
446 )
447 .field("tx", tx)
448 .finish(),
449 Self::Disconnect { dc_id } => {
450 f.debug_struct("Disconnect").field("dc_id", dc_id).finish()
451 }
452 Self::Quit => write!(f, "Quit"),
453 }
454 }
455}
456
/// Adapter that exposes any [`Session`] with its error type boxed, so sessions of
/// different concrete types can be stored as `ErasedSession`.
///
/// The `Session` bound lives on the `impl` rather than the struct definition.
struct Eraser<S>(Arc<S>);
458
459impl<S> Session for Eraser<S>
460where
461 S: Session,
462 S::Error: std::error::Error + Send + Sync,
463{
464 type Error = Box<dyn std::error::Error + Send + Sync>;
465
466 fn home_dc_id(&self) -> Result<i32, Self::Error> {
467 Arc::clone(&self.0).home_dc_id().map_err(|e| e.into())
468 }
469
470 fn set_home_dc_id(&self, dc_id: i32) -> BoxFuture<'_, Result<(), Self::Error>> {
471 Box::pin(async move {
472 Arc::clone(&self.0)
473 .set_home_dc_id(dc_id)
474 .await
475 .map_err(|e| e.into())
476 })
477 }
478
479 fn dc_option(&self, dc_id: i32) -> Result<Option<DcOption>, Self::Error> {
480 Arc::clone(&self.0).dc_option(dc_id).map_err(|e| e.into())
481 }
482
483 fn set_dc_option(&self, dc_option: &DcOption) -> BoxFuture<'_, Result<(), Self::Error>> {
484 let dc_option = dc_option.clone();
485 Box::pin(async move {
486 Arc::clone(&self.0)
487 .set_dc_option(&dc_option)
488 .await
489 .map_err(|e| e.into())
490 })
491 }
492
493 fn peer(&self, peer: PeerId) -> BoxFuture<'_, Result<Option<PeerInfo>, Self::Error>> {
494 Box::pin(async move { Arc::clone(&self.0).peer(peer).await.map_err(|e| e.into()) })
495 }
496
497 fn peer_ref(&self, peer: PeerId) -> BoxFuture<'_, Result<Option<PeerRef>, Self::Error>> {
498 Box::pin(async move {
499 Arc::clone(&self.0)
500 .peer_ref(peer)
501 .await
502 .map_err(|e| e.into())
503 })
504 }
505
506 fn cache_peer(&self, peer: &PeerInfo) -> BoxFuture<'_, Result<(), Self::Error>> {
507 let peer = peer.clone();
508 Box::pin(async move {
509 Arc::clone(&self.0)
510 .cache_peer(&peer)
511 .await
512 .map_err(|e| e.into())
513 })
514 }
515
516 fn updates_state(&self) -> BoxFuture<'_, Result<UpdatesState, Self::Error>> {
517 Box::pin(async {
518 Arc::clone(&self.0)
519 .updates_state()
520 .await
521 .map_err(|e| e.into())
522 })
523 }
524
525 fn set_update_state(&self, update: UpdateState) -> BoxFuture<'_, Result<(), Self::Error>> {
526 Box::pin(async {
527 Arc::clone(&self.0)
528 .set_update_state(update)
529 .await
530 .map_err(|e| e.into())
531 })
532 }
533}