// grammers_mtsender/sender_pool.rs

// Copyright 2020 - developers of the `grammers` project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
8
9use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
10use std::ops::{ControlFlow, Deref};
11use std::sync::Arc;
12use std::{fmt, panic};
13
14use grammers_mtproto::{mtp, transport};
15use grammers_session::types::{DcOption, PeerId, PeerInfo, PeerRef, UpdateState, UpdatesState};
16use grammers_session::updates::UpdatesLike;
17use grammers_session::{BoxFuture, ErasedSession, Session};
18use grammers_tl_types::{self as tl, enums};
19use tokio::task::AbortHandle;
20use tokio::{
21    sync::{mpsc, oneshot},
22    task::JoinSet,
23};
24
25use crate::configuration::ConnectionParams;
26use crate::errors::ReadError;
27use crate::{InvocationError, Sender, ServerAddr, connect, connect_with_auth};
28
29pub(crate) type Transport = transport::Full;
30
31type InvokeResponse = Vec<u8>;
32
33enum Request {
34    Invoke {
35        dc_id: i32,
36        body: Vec<u8>,
37        tx: oneshot::Sender<Result<InvokeResponse, InvocationError>>,
38    },
39    Disconnect {
40        dc_id: i32,
41    },
42    Quit,
43}
44
45struct Rpc {
46    body: Vec<u8>,
47    tx: oneshot::Sender<Result<InvokeResponse, InvocationError>>,
48}
49
50struct ConnectionInfo {
51    dc_id: i32,
52    rpc_tx: mpsc::UnboundedSender<Rpc>,
53    abort_handle: AbortHandle,
54}
55
56/// A fat [`SenderPoolHandle`] with additional metadata from its attached [`SenderPoolRunner`].
57#[derive(Clone)]
58pub struct SenderPoolFatHandle {
59    /// The inner thin handle that self can be derefed into.
60    ///
61    /// The rest of fields can be dropped if they are no longer needed.
62    pub thin: SenderPoolHandle,
63    /// The session in use by the attached [`SenderPoolRunner`].
64    ///
65    /// The runner will read and persist datacenter options in it.
66    pub session: Arc<ErasedSession>,
67    /// Developer's [Application Identifier](https://core.telegram.org/myapp).
68    ///
69    /// The [`SenderPoolRunner`] will make use of this value when it needs
70    /// to invoke [`tl::functions::InitConnection`] after creating a new connection.
71    pub api_id: i32,
72}
73
74/// Cheaply cloneable handle to interact with its [`SenderPoolRunner`].
75#[derive(Clone)]
76pub struct SenderPoolHandle(mpsc::UnboundedSender<Request>);
77
78/// Builder to configure the runner to drive I/O and linked handles.
79pub struct SenderPool {
80    /// The single mutable instance responsible for driving I/O.
81    ///
82    /// Connections are created on-demand, so any errors while the pool
83    /// is running can only be retrieved with one of the [`SenderPool::handle`]s.
84    pub runner: SenderPoolRunner,
85    /// Starting fat handle attached to the [`SenderPool::runner`].
86    ///
87    /// Handles are the only way to interact with the runner once it's running.
88    pub handle: SenderPoolFatHandle,
89    /// The single mutable channel through which updates received
90    /// from the network by the [`SenderPool::runner`] are delivered.
91    ///
92    /// Update handling must be processed in a sequential manner,
93    /// so this is a separate instance with no way to clone it.
94    pub updates: mpsc::UnboundedReceiver<UpdatesLike>,
95}
96
97/// Manages and runs a pool of zero or more [`Sender`]s.
98///
99/// Use [`SenderPool::new`] to create an instance of this type and associated channels.
100pub struct SenderPoolRunner {
101    session: Arc<ErasedSession>,
102    api_id: i32,
103    connection_params: ConnectionParams,
104    request_rx: mpsc::UnboundedReceiver<Request>,
105    updates_tx: mpsc::UnboundedSender<UpdatesLike>,
106    connections: Vec<ConnectionInfo>,
107    connection_pool: JoinSet<Result<(), ReadError>>,
108}
109
110impl Deref for SenderPoolFatHandle {
111    type Target = SenderPoolHandle;
112
113    fn deref(&self) -> &Self::Target {
114        &self.thin
115    }
116}
117
118impl SenderPoolHandle {
119    /// Communicate with the running [`SenderPoolRunner`] instance
120    /// to invoke the serialized request body in the specified datacenter.
121    pub async fn invoke_in_dc(
122        &self,
123        dc_id: i32,
124        body: Vec<u8>,
125    ) -> Result<InvokeResponse, InvocationError> {
126        let (tx, rx) = oneshot::channel();
127        self.0
128            .send(Request::Invoke { dc_id, body, tx })
129            .map_err(|_| InvocationError::Dropped)?;
130        rx.await.map_err(|_| InvocationError::Dropped)?
131    }
132
133    /// Communicate with the running [`SenderPoolRunner`] instance
134    /// to drop any active connections to the given datacenter.
135    ///
136    /// Has no effect if there was no connection to the datacenter.
137    ///
138    /// This is useful after datacenter migrations during sign in,
139    /// when the old connection is known to not be needed anymore.
140    pub fn disconnect_from_dc(&self, dc_id: i32) -> bool {
141        self.0.send(Request::Disconnect { dc_id }).is_ok()
142    }
143
144    /// Communicate with the running [`SenderPoolRunner`] instance
145    /// to drop all active connections and gracefully stop running.
146    pub fn quit(&self) -> bool {
147        self.0.send(Request::Quit).is_ok()
148    }
149}
150
151impl SenderPool {
152    /// Creates a new sender pool instance with default configuration,
153    /// attached to the given session and using the provided
154    /// [Application Identifier](https://core.telegram.org/myapp)
155    /// belonging to the developer.
156    ///
157    /// Session instance **should not** be reused by multiple pools at the same time.
158    /// The session instance will only be used to query datacenter options and persist
159    /// any permanent Authorization Keys generated for previously-unconncected datacenters.
160    pub fn new<S>(session: Arc<S>, api_id: i32) -> Self
161    where
162        S: Session + Sized,
163        S::Error: std::error::Error + Send + Sync + 'static,
164    {
165        Self::with_configuration(session, api_id, Default::default())
166    }
167
168    /// Creates a new sender pool with non-[`ConnectionParams::default`] configuration.
169    pub fn with_configuration<S>(
170        session: Arc<S>,
171        api_id: i32,
172        connection_params: ConnectionParams,
173    ) -> Self
174    where
175        S: Session + Sized,
176        S::Error: std::error::Error + Send + Sync + 'static,
177    {
178        let session: Arc<ErasedSession> = Arc::new(Eraser(session));
179        let (request_tx, request_rx) = mpsc::unbounded_channel();
180        let (updates_tx, updates_rx) = mpsc::unbounded_channel();
181
182        Self {
183            runner: SenderPoolRunner {
184                session: Arc::clone(&session),
185                api_id,
186                connection_params,
187                request_rx,
188                updates_tx,
189                connections: Vec::new(),
190                connection_pool: JoinSet::new(),
191            },
192            handle: SenderPoolFatHandle {
193                thin: SenderPoolHandle(request_tx),
194                session,
195                api_id,
196            },
197            updates: updates_rx,
198        }
199    }
200}
201
202impl SenderPoolRunner {
203    /// Run the sender pool until [`SenderPoolHandle::quit`] is called or the returned future is dropped.
204    ///
205    /// Connections will be initiated on-demand whenever the first request to a datacenter is made.
206    pub async fn run(mut self) {
207        loop {
208            tokio::select! {
209                biased;
210                completion = self.connection_pool.join_next(), if !self.connection_pool.is_empty() => {
211                    if let Err(err) = completion.unwrap() {
212                        if let Ok(reason) = err.try_into_panic() {
213                            panic::resume_unwind(reason);
214                        }
215                    }
216                    self.connections
217                        .retain(|connection| !connection.abort_handle.is_finished());
218                }
219                request = self.request_rx.recv() => {
220                    let flow = if let Some(request) = request {
221                        self.process_request(request).await
222                    } else {
223                        ControlFlow::Break(())
224                    };
225                    match flow {
226                        ControlFlow::Continue(_) => continue,
227                        ControlFlow::Break(_) => break,
228                    }
229                }
230            }
231        }
232
233        self.connections.clear(); // drop all channels to cause the `run_sender` loops to stop
234        self.connection_pool.join_all().await;
235    }
236
237    async fn process_request(&mut self, request: Request) -> ControlFlow<()> {
238        match request {
239            Request::Invoke { dc_id, body, tx } => {
240                let connection = match self
241                    .connections
242                    .iter()
243                    .find(|connection| connection.dc_id == dc_id)
244                {
245                    Some(connection) => connection,
246                    None => match self.create_connection(dc_id).await {
247                        Ok(x) => x,
248                        Err(e) => {
249                            let _ = tx.send(Err(e));
250                            return ControlFlow::Continue(());
251                        }
252                    },
253                };
254                let _ = connection.rpc_tx.send(Rpc { body, tx });
255                ControlFlow::Continue(())
256            }
257            Request::Disconnect { dc_id } => {
258                self.connections.retain(|connection| {
259                    if connection.dc_id == dc_id {
260                        connection.abort_handle.abort();
261                        false
262                    } else {
263                        true
264                    }
265                });
266                ControlFlow::Continue(())
267            }
268            Request::Quit => ControlFlow::Break(()),
269        }
270    }
271
272    async fn create_connection(&mut self, dc_id: i32) -> Result<&ConnectionInfo, InvocationError> {
273        let mut dc_option = match self.session.dc_option(dc_id)? {
274            Some(x) => x,
275            None => return Err(InvocationError::InvalidDc),
276        };
277
278        let sender = self.connect_sender(&dc_option).await?;
279
280        dc_option.auth_key = Some(sender.auth_key());
281        self.session.set_dc_option(&dc_option).await?;
282
283        let (rpc_tx, rpc_rx) = mpsc::unbounded_channel();
284        let abort_handle = self.connection_pool.spawn(run_sender(
285            sender,
286            rpc_rx,
287            self.updates_tx.clone(),
288            dc_option.id == self.session.home_dc_id()?,
289        ));
290        self.connections.push(ConnectionInfo {
291            dc_id,
292            rpc_tx,
293            abort_handle,
294        });
295        Ok(self.connections.last().unwrap())
296    }
297
298    async fn connect_sender(
299        &mut self,
300        dc_option: &DcOption,
301    ) -> Result<Sender<transport::Full, mtp::Encrypted>, InvocationError> {
302        let transport = transport::Full::new;
303
304        let address = if self.connection_params.use_ipv6 {
305            dc_option.ipv6.into()
306        } else {
307            dc_option.ipv4.into()
308        };
309
310        #[cfg(feature = "proxy")]
311        let addr = || {
312            if let Some(proxy) = self.connection_params.proxy_url.clone() {
313                ServerAddr::Proxied { address, proxy }
314            } else {
315                ServerAddr::Tcp { address }
316            }
317        };
318        #[cfg(not(feature = "proxy"))]
319        let addr = || ServerAddr::Tcp { address };
320
321        let init_connection = tl::functions::InvokeWithLayer {
322            layer: tl::LAYER,
323            query: tl::functions::InitConnection {
324                api_id: self.api_id,
325                device_model: self.connection_params.device_model.clone(),
326                system_version: self.connection_params.system_version.clone(),
327                app_version: self.connection_params.app_version.clone(),
328                system_lang_code: self.connection_params.system_lang_code.clone(),
329                lang_pack: "".into(),
330                lang_code: self.connection_params.lang_code.clone(),
331                proxy: None,
332                params: None,
333                query: tl::functions::help::GetConfig {},
334            },
335        };
336
337        let mut sender = if let Some(auth_key) = dc_option.auth_key {
338            connect_with_auth(transport(), addr(), auth_key)
339                .await
340                .map_err(InvocationError::Io)?
341        } else {
342            connect(transport(), addr()).await?
343        };
344
345        let enums::Config::Config(remote_config) = match sender.invoke(&init_connection).await {
346            Ok(config) => config,
347            Err(InvocationError::Transport(transport::Error::BadStatus { status: 404 })) => {
348                sender = connect(transport(), addr()).await?;
349                sender.invoke(&init_connection).await?
350            }
351            Err(e) => return Err(e),
352        };
353
354        self.update_config(remote_config).await?;
355
356        Ok(sender)
357    }
358
359    async fn update_config(&mut self, config: tl::types::Config) -> Result<(), InvocationError> {
360        for option in config
361            .dc_options
362            .iter()
363            .map(|tl::enums::DcOption::Option(option)| option)
364            .filter(|option| !option.media_only && !option.tcpo_only && option.r#static)
365        {
366            let mut dc_option = self
367                .session
368                .dc_option(option.id)?
369                .unwrap_or_else(|| DcOption {
370                    id: option.id,
371                    ipv4: SocketAddrV4::new(Ipv4Addr::from_bits(0), 0),
372                    ipv6: SocketAddrV6::new(Ipv6Addr::from_bits(0), 0, 0, 0),
373                    auth_key: None,
374                });
375            if option.ipv6 {
376                dc_option.ipv6 = SocketAddrV6::new(
377                    option
378                        .ip_address
379                        .parse()
380                        .expect("Telegram to return a valid IPv6 address"),
381                    option.port as _,
382                    0,
383                    0,
384                );
385            } else {
386                dc_option.ipv4 = SocketAddrV4::new(
387                    option
388                        .ip_address
389                        .parse()
390                        .expect("Telegram to return a valid IPv4 address"),
391                    option.port as _,
392                );
393                if dc_option.ipv6.ip().to_bits() == 0 {
394                    dc_option.ipv6 = SocketAddrV6::new(
395                        dc_option.ipv4.ip().to_ipv6_mapped(),
396                        dc_option.ipv4.port(),
397                        0,
398                        0,
399                    )
400                }
401            }
402        }
403        Ok(())
404    }
405}
406
407async fn run_sender(
408    mut sender: Sender<Transport, grammers_mtproto::mtp::Encrypted>,
409    mut rpc_rx: mpsc::UnboundedReceiver<Rpc>,
410    updates: mpsc::UnboundedSender<UpdatesLike>,
411    home_sender: bool,
412) -> Result<(), ReadError> {
413    loop {
414        tokio::select! {
415            step = sender.step() => match step {
416                Ok(all_new_updates) => all_new_updates.into_iter().for_each(|new_updates| {
417                    let _ = updates.send(new_updates);
418                }),
419                Err(err) => {
420                    if home_sender {
421                        let _ = updates.send(UpdatesLike::ConnectionClosed);
422                    }
423                    break Err(err)
424                },
425            },
426            rpc = rpc_rx.recv() => match rpc {
427                Some(rpc) => sender.enqueue_body(rpc.body, rpc.tx),
428                None => break Ok(()),
429            },
430        }
431    }
432}
433
434impl fmt::Debug for Request {
435    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
436        match self {
437            Self::Invoke { dc_id, body, tx } => f
438                .debug_struct("Invoke")
439                .field("dc_id", dc_id)
440                .field(
441                    "request",
442                    &body[..4]
443                        .try_into()
444                        .map(|constructor_id| tl::name_for_id(u32::from_le_bytes(constructor_id)))
445                        .unwrap_or("?"),
446                )
447                .field("tx", tx)
448                .finish(),
449            Self::Disconnect { dc_id } => {
450                f.debug_struct("Disconnect").field("dc_id", dc_id).finish()
451            }
452            Self::Quit => write!(f, "Quit"),
453        }
454    }
455}
456
/// Wrapper that forwards [`Session`] calls to the inner session while boxing its error type,
/// so it can be stored as an [`ErasedSession`]. The `Session` bound lives on the impl.
struct Eraser<S>(Arc<S>);
458
459impl<S> Session for Eraser<S>
460where
461    S: Session,
462    S::Error: std::error::Error + Send + Sync,
463{
464    type Error = Box<dyn std::error::Error + Send + Sync>;
465
466    fn home_dc_id(&self) -> Result<i32, Self::Error> {
467        Arc::clone(&self.0).home_dc_id().map_err(|e| e.into())
468    }
469
470    fn set_home_dc_id(&self, dc_id: i32) -> BoxFuture<'_, Result<(), Self::Error>> {
471        Box::pin(async move {
472            Arc::clone(&self.0)
473                .set_home_dc_id(dc_id)
474                .await
475                .map_err(|e| e.into())
476        })
477    }
478
479    fn dc_option(&self, dc_id: i32) -> Result<Option<DcOption>, Self::Error> {
480        Arc::clone(&self.0).dc_option(dc_id).map_err(|e| e.into())
481    }
482
483    fn set_dc_option(&self, dc_option: &DcOption) -> BoxFuture<'_, Result<(), Self::Error>> {
484        let dc_option = dc_option.clone();
485        Box::pin(async move {
486            Arc::clone(&self.0)
487                .set_dc_option(&dc_option)
488                .await
489                .map_err(|e| e.into())
490        })
491    }
492
493    fn peer(&self, peer: PeerId) -> BoxFuture<'_, Result<Option<PeerInfo>, Self::Error>> {
494        Box::pin(async move { Arc::clone(&self.0).peer(peer).await.map_err(|e| e.into()) })
495    }
496
497    fn peer_ref(&self, peer: PeerId) -> BoxFuture<'_, Result<Option<PeerRef>, Self::Error>> {
498        Box::pin(async move {
499            Arc::clone(&self.0)
500                .peer_ref(peer)
501                .await
502                .map_err(|e| e.into())
503        })
504    }
505
506    fn cache_peer(&self, peer: &PeerInfo) -> BoxFuture<'_, Result<(), Self::Error>> {
507        let peer = peer.clone();
508        Box::pin(async move {
509            Arc::clone(&self.0)
510                .cache_peer(&peer)
511                .await
512                .map_err(|e| e.into())
513        })
514    }
515
516    fn updates_state(&self) -> BoxFuture<'_, Result<UpdatesState, Self::Error>> {
517        Box::pin(async {
518            Arc::clone(&self.0)
519                .updates_state()
520                .await
521                .map_err(|e| e.into())
522        })
523    }
524
525    fn set_update_state(&self, update: UpdateState) -> BoxFuture<'_, Result<(), Self::Error>> {
526        Box::pin(async {
527            Arc::clone(&self.0)
528                .set_update_state(update)
529                .await
530                .map_err(|e| e.into())
531        })
532    }
533}