Skip to main content

grammers_mtsender/
sender_pool.rs

1// Copyright 2020 - developers of the `grammers` project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
10use std::ops::{ControlFlow, Deref};
11use std::sync::Arc;
12use std::{fmt, panic};
13
14use grammers_mtproto::{mtp, transport};
15use grammers_session::Session;
16use grammers_session::types::DcOption;
17use grammers_session::updates::UpdatesLike;
18use grammers_tl_types::{self as tl, enums};
19use tokio::task::AbortHandle;
20use tokio::{
21    sync::{mpsc, oneshot},
22    task::JoinSet,
23};
24
25use crate::configuration::ConnectionParams;
26use crate::errors::ReadError;
27use crate::{InvocationError, Sender, ServerAddr, connect, connect_with_auth};
28
29pub(crate) type Transport = transport::Full;
30
31type InvokeResponse = Vec<u8>;
32
33enum Request {
34    Invoke {
35        dc_id: i32,
36        body: Vec<u8>,
37        tx: oneshot::Sender<Result<InvokeResponse, InvocationError>>,
38    },
39    Disconnect {
40        dc_id: i32,
41    },
42    Quit,
43}
44
45struct Rpc {
46    body: Vec<u8>,
47    tx: oneshot::Sender<Result<InvokeResponse, InvocationError>>,
48}
49
50struct ConnectionInfo {
51    dc_id: i32,
52    rpc_tx: mpsc::UnboundedSender<Rpc>,
53    abort_handle: AbortHandle,
54}
55
56/// A fat [`SenderPoolHandle`] with additional metadata from its attached [`SenderPoolRunner`].
57#[derive(Clone)]
58pub struct SenderPoolFatHandle {
59    /// The inner thin handle that self can be derefed into.
60    ///
61    /// The rest of fields can be dropped if they are no longer needed.
62    pub thin: SenderPoolHandle,
63    /// The session in use by the attached [`SenderPoolRunner`].
64    ///
65    /// The runner will read and persist datacenter options in it.
66    pub session: Arc<dyn Session>,
67    /// Developer's [Application Identifier](https://core.telegram.org/myapp).
68    ///
69    /// The [`SenderPoolRunner`] will make use of this value when it needs
70    /// to invoke [`tl::functions::InitConnection`] after creating a new connection.
71    pub api_id: i32,
72}
73
74/// Cheaply cloneable handle to interact with its [`SenderPoolRunner`].
75#[derive(Clone)]
76pub struct SenderPoolHandle(mpsc::UnboundedSender<Request>);
77
78/// Builder to configure the runner to drive I/O and linked handles.
79pub struct SenderPool {
80    /// The single mutable instance responsible for driving I/O.
81    ///
82    /// Connections are created on-demand, so any errors while the pool
83    /// is running can only be retrieved with one of the [`SenderPool::handle`]s.
84    pub runner: SenderPoolRunner,
85    /// Starting fat handle attached to the [`SenderPool::runner`].
86    ///
87    /// Handles are the only way to interact with the runner once it's running.
88    pub handle: SenderPoolFatHandle,
89    /// The single mutable channel through which updates received
90    /// from the network by the [`SenderPool::runner`] are delivered.
91    ///
92    /// Update handling must be processed in a sequential manner,
93    /// so this is a separate instance with no way to clone it.
94    pub updates: mpsc::UnboundedReceiver<UpdatesLike>,
95}
96
97/// Manages and runs a pool of zero or more [`Sender`]s.
98///
99/// Use [`SenderPool::new`] to create an instance of this type and associated channels.
100pub struct SenderPoolRunner {
101    session: Arc<dyn Session>,
102    api_id: i32,
103    connection_params: ConnectionParams,
104    request_rx: mpsc::UnboundedReceiver<Request>,
105    updates_tx: mpsc::UnboundedSender<UpdatesLike>,
106    connections: Vec<ConnectionInfo>,
107    connection_pool: JoinSet<Result<(), ReadError>>,
108}
109
110impl Deref for SenderPoolFatHandle {
111    type Target = SenderPoolHandle;
112
113    fn deref(&self) -> &Self::Target {
114        &self.thin
115    }
116}
117
118impl SenderPoolHandle {
119    /// Communicate with the running [`SenderPoolRunner`] instance
120    /// to invoke the serialized request body in the specified datacenter.
121    pub async fn invoke_in_dc(
122        &self,
123        dc_id: i32,
124        body: Vec<u8>,
125    ) -> Result<InvokeResponse, InvocationError> {
126        let (tx, rx) = oneshot::channel();
127        self.0
128            .send(Request::Invoke { dc_id, body, tx })
129            .map_err(|_| InvocationError::Dropped)?;
130        rx.await.map_err(|_| InvocationError::Dropped)?
131    }
132
133    /// Communicate with the running [`SenderPoolRunner`] instance
134    /// to drop any active connections to the given datacenter.
135    ///
136    /// Has no effect if there was no connection to the datacenter.
137    ///
138    /// This is useful after datacenter migrations during sign in,
139    /// when the old connection is known to not be needed anymore.
140    pub fn disconnect_from_dc(&self, dc_id: i32) -> bool {
141        self.0.send(Request::Disconnect { dc_id }).is_ok()
142    }
143
144    /// Communicate with the running [`SenderPoolRunner`] instance
145    /// to drop all active connections and gracefully stop running.
146    pub fn quit(&self) -> bool {
147        self.0.send(Request::Quit).is_ok()
148    }
149}
150
151impl SenderPool {
152    /// Creates a new sender pool instance with default configuration,
153    /// attached to the given session and using the provided
154    /// [Application Identifier](https://core.telegram.org/myapp)
155    /// belonging to the developer.
156    ///
157    /// Session instance **should not** be reused by multiple pools at the same time.
158    /// The session instance will only be used to query datacenter options and persist
159    /// any permanent Authorization Keys generated for previously-unconncected datacenters.
160    pub fn new<S: Session + 'static>(session: Arc<S>, api_id: i32) -> Self {
161        Self::with_configuration(session, api_id, Default::default())
162    }
163
164    /// Creates a new sender pool with non-[`ConnectionParams::default`] configuration.
165    pub fn with_configuration<S: Session + 'static>(
166        session: Arc<S>,
167        api_id: i32,
168        connection_params: ConnectionParams,
169    ) -> Self {
170        let (request_tx, request_rx) = mpsc::unbounded_channel();
171        let (updates_tx, updates_rx) = mpsc::unbounded_channel();
172        let session = session as Arc<dyn Session>;
173
174        Self {
175            runner: SenderPoolRunner {
176                session: Arc::clone(&session),
177                api_id,
178                connection_params,
179                request_rx,
180                updates_tx,
181                connections: Vec::new(),
182                connection_pool: JoinSet::new(),
183            },
184            handle: SenderPoolFatHandle {
185                thin: SenderPoolHandle(request_tx),
186                session,
187                api_id,
188            },
189            updates: updates_rx,
190        }
191    }
192}
193
194impl SenderPoolRunner {
195    /// Run the sender pool until [`SenderPoolHandle::quit`] is called or the returned future is dropped.
196    ///
197    /// Connections will be initiated on-demand whenever the first request to a datacenter is made.
198    pub async fn run(mut self) {
199        loop {
200            tokio::select! {
201                biased;
202                completion = self.connection_pool.join_next(), if !self.connection_pool.is_empty() => {
203                    if let Err(err) = completion.unwrap() {
204                        if let Ok(reason) = err.try_into_panic() {
205                            panic::resume_unwind(reason);
206                        }
207                    }
208                    self.connections
209                        .retain(|connection| !connection.abort_handle.is_finished());
210                }
211                request = self.request_rx.recv() => {
212                    let flow = if let Some(request) = request {
213                        self.process_request(request).await
214                    } else {
215                        ControlFlow::Break(())
216                    };
217                    match flow {
218                        ControlFlow::Continue(_) => continue,
219                        ControlFlow::Break(_) => break,
220                    }
221                }
222            }
223        }
224
225        self.connections.clear(); // drop all channels to cause the `run_sender` loops to stop
226        self.connection_pool.join_all().await;
227    }
228
229    async fn process_request(&mut self, request: Request) -> ControlFlow<()> {
230        match request {
231            Request::Invoke { dc_id, body, tx } => {
232                let Some(mut dc_option) = self.session.dc_option(dc_id) else {
233                    let _ = tx.send(Err(InvocationError::InvalidDc));
234                    return ControlFlow::Continue(());
235                };
236
237                let connection = match self
238                    .connections
239                    .iter()
240                    .find(|connection| connection.dc_id == dc_id)
241                {
242                    Some(connection) => connection,
243                    None => {
244                        let sender = match self.connect_sender(&dc_option).await {
245                            Ok(t) => t,
246                            Err(e) => {
247                                let _ = tx.send(Err(e));
248                                return ControlFlow::Continue(());
249                            }
250                        };
251
252                        dc_option.auth_key = Some(sender.auth_key());
253                        self.session.set_dc_option(&dc_option).await;
254
255                        let (rpc_tx, rpc_rx) = mpsc::unbounded_channel();
256                        let abort_handle = self.connection_pool.spawn(run_sender(
257                            sender,
258                            rpc_rx,
259                            self.updates_tx.clone(),
260                            dc_option.id == self.session.home_dc_id(),
261                        ));
262                        self.connections.push(ConnectionInfo {
263                            dc_id,
264                            rpc_tx,
265                            abort_handle,
266                        });
267                        self.connections.last().unwrap()
268                    }
269                };
270                let _ = connection.rpc_tx.send(Rpc { body, tx });
271                ControlFlow::Continue(())
272            }
273            Request::Disconnect { dc_id } => {
274                self.connections.retain(|connection| {
275                    if connection.dc_id == dc_id {
276                        connection.abort_handle.abort();
277                        false
278                    } else {
279                        true
280                    }
281                });
282                ControlFlow::Continue(())
283            }
284            Request::Quit => ControlFlow::Break(()),
285        }
286    }
287
288    async fn connect_sender(
289        &mut self,
290        dc_option: &DcOption,
291    ) -> Result<Sender<transport::Full, mtp::Encrypted>, InvocationError> {
292        let transport = transport::Full::new;
293
294        let address = if self.connection_params.use_ipv6 {
295            dc_option.ipv6.into()
296        } else {
297            dc_option.ipv4.into()
298        };
299
300        #[cfg(feature = "proxy")]
301        let addr = || {
302            if let Some(proxy) = self.connection_params.proxy_url.clone() {
303                ServerAddr::Proxied { address, proxy }
304            } else {
305                ServerAddr::Tcp { address }
306            }
307        };
308        #[cfg(not(feature = "proxy"))]
309        let addr = || ServerAddr::Tcp { address };
310
311        let init_connection = tl::functions::InvokeWithLayer {
312            layer: tl::LAYER,
313            query: tl::functions::InitConnection {
314                api_id: self.api_id,
315                device_model: self.connection_params.device_model.clone(),
316                system_version: self.connection_params.system_version.clone(),
317                app_version: self.connection_params.app_version.clone(),
318                system_lang_code: self.connection_params.system_lang_code.clone(),
319                lang_pack: "".into(),
320                lang_code: self.connection_params.lang_code.clone(),
321                proxy: None,
322                params: None,
323                query: tl::functions::help::GetConfig {},
324            },
325        };
326
327        let mut sender = if let Some(auth_key) = dc_option.auth_key {
328            connect_with_auth(transport(), addr(), auth_key).await?
329        } else {
330            connect(transport(), addr()).await?
331        };
332
333        let enums::Config::Config(remote_config) = match sender.invoke(&init_connection).await {
334            Ok(config) => config,
335            Err(InvocationError::Transport(transport::Error::BadStatus { status: 404 })) => {
336                sender = connect(transport(), addr()).await?;
337                sender.invoke(&init_connection).await?
338            }
339            Err(e) => return Err(dbg!(e).into()),
340        };
341
342        self.update_config(remote_config).await;
343
344        Ok(sender)
345    }
346
347    async fn update_config(&mut self, config: tl::types::Config) {
348        for option in config
349            .dc_options
350            .iter()
351            .map(|tl::enums::DcOption::Option(option)| option)
352            .filter(|option| !option.media_only && !option.tcpo_only && option.r#static)
353        {
354            let mut dc_option = self
355                .session
356                .dc_option(option.id)
357                .unwrap_or_else(|| DcOption {
358                    id: option.id,
359                    ipv4: SocketAddrV4::new(Ipv4Addr::from_bits(0), 0),
360                    ipv6: SocketAddrV6::new(Ipv6Addr::from_bits(0), 0, 0, 0),
361                    auth_key: None,
362                });
363            if option.ipv6 {
364                dc_option.ipv6 = SocketAddrV6::new(
365                    option
366                        .ip_address
367                        .parse()
368                        .expect("Telegram to return a valid IPv6 address"),
369                    option.port as _,
370                    0,
371                    0,
372                );
373            } else {
374                dc_option.ipv4 = SocketAddrV4::new(
375                    option
376                        .ip_address
377                        .parse()
378                        .expect("Telegram to return a valid IPv4 address"),
379                    option.port as _,
380                );
381                if dc_option.ipv6.ip().to_bits() == 0 {
382                    dc_option.ipv6 = SocketAddrV6::new(
383                        dc_option.ipv4.ip().to_ipv6_mapped(),
384                        dc_option.ipv4.port(),
385                        0,
386                        0,
387                    )
388                }
389            }
390        }
391    }
392}
393
394async fn run_sender(
395    mut sender: Sender<Transport, grammers_mtproto::mtp::Encrypted>,
396    mut rpc_rx: mpsc::UnboundedReceiver<Rpc>,
397    updates: mpsc::UnboundedSender<UpdatesLike>,
398    home_sender: bool,
399) -> Result<(), ReadError> {
400    loop {
401        tokio::select! {
402            step = sender.step() => match step {
403                Ok(all_new_updates) => all_new_updates.into_iter().for_each(|new_updates| {
404                    let _ = updates.send(new_updates);
405                }),
406                Err(err) => {
407                    if home_sender {
408                        let _ = updates.send(UpdatesLike::ConnectionClosed);
409                    }
410                    break Err(err)
411                },
412            },
413            rpc = rpc_rx.recv() => match rpc {
414                Some(rpc) => sender.enqueue_body(rpc.body, rpc.tx),
415                None => break Ok(()),
416            },
417        }
418    }
419}
420
421impl fmt::Debug for Request {
422    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
423        match self {
424            Self::Invoke { dc_id, body, tx } => f
425                .debug_struct("Invoke")
426                .field("dc_id", dc_id)
427                .field(
428                    "request",
429                    &body[..4]
430                        .try_into()
431                        .map(|constructor_id| tl::name_for_id(u32::from_le_bytes(constructor_id)))
432                        .unwrap_or("?"),
433                )
434                .field("tx", tx)
435                .finish(),
436            Self::Disconnect { dc_id } => {
437                f.debug_struct("Disconnect").field("dc_id", dc_id).finish()
438            }
439            Self::Quit => write!(f, "Quit"),
440        }
441    }
442}