Skip to main content

ferogram_mtsender/
pool.rs

1// Copyright (c) Ankit Chaubey <ankitchaubey.dev@gmail.com>
2//
3// ferogram: async Telegram MTProto client in Rust
4// https://github.com/ankit-chaubey/ferogram
5//
6// Licensed under either the MIT License or the Apache License 2.0.
7// See the LICENSE-MIT or LICENSE-APACHE file in this repository:
8// https://github.com/ankit-chaubey/ferogram
9//
10// Feel free to use, modify, and share this code.
11// Please keep this notice when redistributing.
12
13use crate::errors::InvocationError;
14use crate::sender::DcConnection;
15use crate::sender_task::{FrameEvent, RpcEnqueue, spawn_sender_task};
16use ferogram_connect::util::maybe_gz_pack;
17use ferogram_connect::{Socks5Config, TransportKind};
18use ferogram_session::DcEntry;
19use ferogram_tl_types::{RemoteCall, Serializable};
20use std::collections::HashMap;
21use std::sync::Arc;
22use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
23use tokio::sync::{mpsc, oneshot};
24
/// Maximum number of simultaneous connection slots the pool keeps open per DC.
///
/// Extra slots beyond the first are only opened when every existing slot for
/// that DC already has requests in flight.
const MAX_CONNS_PER_DC: usize = 3;
27
28/// One slot in the per-DC connection pool.
29///
30/// Each slot is backed by a background sender task (see
31/// [`crate::sender_task::spawn_sender_task`]), not a locked `DcConnection`.
32/// Enqueueing a request just posts it to the task's mpsc channel and waits
33/// on a oneshot for the result: no lock is held across the network round
34/// trip, so any number of callers can have requests in flight on the same
35/// slot at once. The task itself batches whatever is pending into as few
36/// frames as possible and matches replies back to callers by msg_id,
37/// regardless of the order responses arrive in.
38///
39/// `in_flight` still lets the pool pick the least-busy slot without needing
40/// to touch the connection itself.
41pub struct ConnSlot {
42    rpc_tx: mpsc::Sender<RpcEnqueue>,
43    pub in_flight: AtomicUsize,
44    /// Set to `false` by the drain task below once the connection's sender
45    /// task reports an error. Callers check this after a failed call to
46    /// decide whether to evict the slot and retry on a fresh one, instead of
47    /// matching on the specific `InvocationError` variant (the sender task
48    /// always reports connection failures as `Deserialize`, since it has no
49    /// way to know whether a given caller still cares about the original
50    /// `Io`/etc. error kind once `fail_all` has fanned it out to everyone
51    /// waiting on this connection).
52    alive: Arc<AtomicBool>,
53    /// Snapshot of the auth key / salt / time offset taken when the slot was
54    /// created. Used by `collect_keys` to persist session info. The auth key
55    /// never changes for a slot's lifetime; salt and time offset can drift a
56    /// little as the connection runs (FutureSalts rotation), but a stale
57    /// value here only costs one bad_server_salt round trip the next time
58    /// this DC is reconnected, since the sender task self-corrects from
59    /// server-supplied corrections either way.
60    auth_key: [u8; 256],
61    first_salt: i64,
62    time_offset: i32,
63}
64
65/// Pool of per-DC authenticated connections.
66/// Each DC holds up to MAX_CONNS_PER_DC slots. The pool lock is dropped
67/// before any network I/O so concurrent callers don't serialize on it.
68pub struct DcPool {
69    /// Per-DC connection slots; inner Vec holds slot Arcs.
70    pub conns: HashMap<i32, Vec<Arc<ConnSlot>>>,
71    addrs: HashMap<i32, String>,
72    #[allow(dead_code)]
73    home_dc_id: i32,
74    /// Proxy config forwarded to auto-reconnect.
75    socks5: Option<Socks5Config>,
76    /// Transport kind reused for secondary DC connections.
77    transport: TransportKind,
78    /// DCs that have already received `invokeWithLayer(initConnection(...))`.
79    init_done: std::collections::HashSet<i32>,
80}
81
82impl DcPool {
83    /// Build an empty pool for `home_dc_id`, seeded with addresses for every
84    /// DC in `dc_entries`. No connections are opened yet; slots get created
85    /// lazily on first use of each DC.
86    pub fn new(
87        home_dc_id: i32,
88        dc_entries: &[DcEntry],
89        socks5: Option<Socks5Config>,
90        transport: TransportKind,
91    ) -> Self {
92        let addrs = dc_entries
93            .iter()
94            .map(|e| (e.dc_id, e.addr.clone()))
95            .collect();
96        Self {
97            conns: HashMap::new(),
98            addrs,
99            home_dc_id,
100            socks5,
101            transport,
102            init_done: std::collections::HashSet::new(),
103        }
104    }
105
106    /// Returns true if at least one connection slot exists for `dc_id`.
107    pub fn has_connection(&self, dc_id: i32) -> bool {
108        self.conns.get(&dc_id).is_some_and(|v| !v.is_empty())
109    }
110
111    /// Graduate an already set-up `DcConnection` into a pipelined slot.
112    ///
113    /// The connection has already done its DH / PFS bind / initConnection as
114    /// a plain `DcConnection`. From here on its socket is owned by a single
115    /// background task; this function just spawns that task and wraps the
116    /// resulting handle in a `ConnSlot`.
117    fn spawn_slot(conn: DcConnection) -> Arc<ConnSlot> {
118        let auth_key = conn.auth_key_bytes();
119        let first_salt = conn.first_salt();
120        let time_offset = conn.time_offset();
121        let (stream, frame_kind, enc) = conn.into_parts();
122
123        let (handle, mut frame_rx) = spawn_sender_task(stream, enc, frame_kind, None);
124
125        // Pool slots don't support reconnect: on failure the pool just
126        // evicts the whole DC and a fresh slot is opened from scratch on the
127        // next call. Dropping reconnect_tx here means the sender task's
128        // error branch sees its reconnect channel closed and shuts itself
129        // down cleanly instead of waiting for a reconnect that will never
130        // come.
131        drop(handle.reconnect_tx);
132
133        let alive = Arc::new(AtomicBool::new(true));
134        let alive_for_drain = alive.clone();
135        tokio::spawn(async move {
136            while let Some(event) = frame_rx.recv().await {
137                if let FrameEvent::Error(e) = event {
138                    tracing::warn!("[ferogram::pool] worker connection dropped: {e}");
139                    alive_for_drain.store(false, Ordering::Release);
140                    break;
141                }
142                // FrameEvent::Update / Connected: pool connections don't
143                // dispatch updates, nothing to do.
144            }
145        });
146
147        Arc::new(ConnSlot {
148            rpc_tx: handle.rpc_tx,
149            in_flight: AtomicUsize::new(0),
150            alive,
151            auth_key,
152            first_salt,
153            time_offset,
154        })
155    }
156
157    /// Insert a pre-built, already initialized connection into the pool as a
158    /// new slot.
159    pub fn insert(&mut self, dc_id: i32, conn: DcConnection) {
160        let slot = Self::spawn_slot(conn);
161        self.conns.entry(dc_id).or_default().push(slot);
162        let total: usize = self.conns.values().map(|v| v.len()).sum();
163        metrics::gauge!("ferogram.connections_active").set(total as f64);
164    }
165
166    /// Returns the least-loaded slot for `dc_id`, creating one if needed.
167    /// Creates a new slot if all existing ones are busy and count < MAX_CONNS_PER_DC.
168    /// Drop the DcPool guard before locking the returned slot.
169    pub(crate) async fn get_or_create_slot(
170        &mut self,
171        dc_id: i32,
172        pfs: bool,
173        auth_key: Option<([u8; 256], i64, i32)>,
174    ) -> Result<Arc<ConnSlot>, InvocationError> {
175        let addr = self.addrs.get(&dc_id).cloned().ok_or_else(|| {
176            InvocationError::Deserialize(format!("dc_pool: no address for DC{dc_id}"))
177        })?;
178
179        // Ensure at least one slot exists.
180        if !self.conns.contains_key(&dc_id) || self.conns[&dc_id].is_empty() {
181            tracing::debug!("[ferogram::pool] opening first connection to DC{dc_id} at {addr}");
182            let conn = if let Some((key, salt, offset)) = auth_key {
183                DcConnection::connect_with_key(
184                    &addr,
185                    key,
186                    salt,
187                    offset,
188                    self.socks5.as_ref(),
189                    None,
190                    &self.transport,
191                    dc_id as i16,
192                    pfs,
193                )
194                .await?
195            } else {
196                DcConnection::connect_raw(
197                    &addr,
198                    self.socks5.as_ref(),
199                    &self.transport,
200                    dc_id as i16,
201                )
202                .await?
203            };
204            let slot = Self::spawn_slot(conn);
205            self.conns.entry(dc_id).or_default().push(slot);
206            self.init_done.remove(&dc_id);
207            let total: usize = self.conns.values().map(|v| v.len()).sum();
208            metrics::gauge!("ferogram.connections_active").set(total as f64);
209        }
210
211        let slots = self
212            .conns
213            .get(&dc_id)
214            .expect("dc_id must be registered before use");
215
216        // pick least-busy slot
217        let best = slots
218            .iter()
219            .min_by_key(|s| s.in_flight.load(Ordering::Relaxed))
220            .expect("slots vec is non-empty")
221            .clone();
222        let min_inflight = best.in_flight.load(Ordering::Relaxed);
223
224        // Spawn a new slot if: all are busy AND we have room for more.
225        //
226        // With pipelined slots this matters less than it used to (a single
227        // slot can now happily carry many in-flight requests at once), but
228        // it's still worth spreading load across a few real TCP connections
229        // for very heavy transfers.
230        if min_inflight > 0 && slots.len() < MAX_CONNS_PER_DC {
231            tracing::debug!(
232                "[ferogram::pool] DC{dc_id}: all {} slots busy (min_inflight={min_inflight}), opening extra connection",
233                slots.len()
234            );
235            let conn = if let Some((key, salt, offset)) = auth_key {
236                DcConnection::connect_with_key(
237                    &addr,
238                    key,
239                    salt,
240                    offset,
241                    self.socks5.as_ref(),
242                    None,
243                    &self.transport,
244                    dc_id as i16,
245                    pfs,
246                )
247                .await?
248            } else {
249                DcConnection::connect_raw(
250                    &addr,
251                    self.socks5.as_ref(),
252                    &self.transport,
253                    dc_id as i16,
254                )
255                .await?
256            };
257            let new_slot = Self::spawn_slot(conn);
258            let arc = new_slot.clone();
259            self.conns
260                .get_mut(&dc_id)
261                .expect("dc_id must be registered")
262                .push(new_slot);
263            let total: usize = self.conns.values().map(|v| v.len()).sum();
264            metrics::gauge!("ferogram.connections_active").set(total as f64);
265            return Ok(arc);
266        }
267
268        Ok(best)
269    }
270
271    /// Evict all slots for a DC (called on connection failure to force
272    /// reconnection on the next call).
273    pub fn evict(&mut self, dc_id: i32) {
274        self.conns.remove(&dc_id);
275        self.init_done.remove(&dc_id);
276        let total: usize = self.conns.values().map(|v| v.len()).sum();
277        metrics::gauge!("ferogram.connections_active").set(total as f64);
278        tracing::debug!("[ferogram::pool] evicted all connections for DC{dc_id}");
279    }
280
281    /// Enqueue `body` on `slot` and await the result.
282    ///
283    /// This is the only place that touches `rpc_tx`/the oneshot: no mutex,
284    /// no blocking for the duration of the round trip. Multiple callers can
285    /// call this against the same slot concurrently and their requests will
286    /// pipeline on the wire instead of queueing behind each other.
287    async fn send_via_slot(
288        slot: &Arc<ConnSlot>,
289        body: Vec<u8>,
290    ) -> Result<Vec<u8>, InvocationError> {
291        slot.in_flight.fetch_add(1, Ordering::Relaxed);
292        let (tx, rx) = oneshot::channel();
293        let send_result = slot.rpc_tx.send(RpcEnqueue { body, tx }).await;
294        let result = if send_result.is_err() {
295            slot.alive.store(false, Ordering::Release);
296            Err(InvocationError::Deserialize(
297                "worker sender task shut down".into(),
298            ))
299        } else {
300            match rx.await {
301                Ok(r) => r,
302                Err(_) => {
303                    slot.alive.store(false, Ordering::Release);
304                    Err(InvocationError::Deserialize(
305                        "worker rpc channel closed".into(),
306                    ))
307                }
308            }
309        };
310        slot.in_flight.fetch_sub(1, Ordering::Relaxed);
311        result
312    }
313
314    /// Invoke a raw RPC call on the given DC.
315    /// Pool lock is released before the network round-trip begins.
316    pub async fn invoke_on_dc<R: RemoteCall>(
317        &mut self,
318        dc_id: i32,
319        _dc_entries: &[DcEntry],
320        req: &R,
321    ) -> Result<Vec<u8>, InvocationError> {
322        let slot = self.get_or_create_slot(dc_id, false, None).await?;
323        let body = maybe_gz_pack(&req.to_bytes());
324        let result = Self::send_via_slot(&slot, body.clone()).await;
325
326        if let Err(ref e) = result {
327            let kind = match e {
328                InvocationError::Rpc(_) => "rpc",
329                InvocationError::Io(_) => "io",
330                _ => "other",
331            };
332            metrics::counter!("ferogram.rpc_errors_total", "kind" => kind).increment(1);
333        }
334
335        if let Err(InvocationError::Rpc(ref e)) = result
336            && e.code == -404
337        {
338            // Telegram dropped the auth key (e.g. AndroidTV killed the socket during sleep).
339            // Evict and redo a full DH exchange; the login session is still valid server-side.
340            tracing::warn!(
341                "[ferogram::pool] DC{dc_id} returned -404 (auth key gone); evicting and redoing DH"
342            );
343            self.evict(dc_id);
344            let retry_slot = self.get_or_create_slot(dc_id, false, None).await?;
345            return Self::send_via_slot(&retry_slot, body).await;
346        }
347
348        if result.is_err() && !slot.alive.load(Ordering::Acquire) {
349            tracing::warn!(
350                "[ferogram::pool] DC{dc_id} connection died mid-request; evicting and retrying on a fresh connection"
351            );
352            self.evict(dc_id);
353            let retry_slot = self.get_or_create_slot(dc_id, false, None).await?;
354            return Self::send_via_slot(&retry_slot, body).await;
355        }
356        result
357    }
358
359    /// Mark a DC as having completed initConnection.
360    pub fn mark_init_done(&mut self, dc_id: i32) {
361        self.init_done.insert(dc_id);
362    }
363
364    /// Returns true if this DC has already received initConnection this session.
365    pub fn is_init_done(&self, dc_id: i32) -> bool {
366        self.init_done.contains(&dc_id)
367    }
368
369    /// Like `invoke_on_dc` but accepts any `Serializable` type.
370    pub async fn invoke_on_dc_serializable<S: Serializable>(
371        &mut self,
372        dc_id: i32,
373        req: &S,
374    ) -> Result<Vec<u8>, InvocationError> {
375        let slot = self
376            .get_or_create_slot(dc_id, false, None)
377            .await
378            .map_err(|_| InvocationError::Deserialize(format!("no connection for DC{dc_id}")))?;
379        let body = maybe_gz_pack(&req.to_bytes());
380        let result = Self::send_via_slot(&slot, body.clone()).await;
381
382        if let Err(InvocationError::Rpc(ref e)) = result
383            && e.code == -404
384        {
385            tracing::warn!(
386                "[ferogram::pool] DC{dc_id} returned -404 (serializable path); evicting and redoing DH"
387            );
388            self.evict(dc_id);
389            let retry_slot = self.get_or_create_slot(dc_id, false, None).await?;
390            return Self::send_via_slot(&retry_slot, body).await;
391        }
392
393        if result.is_err() && !slot.alive.load(Ordering::Acquire) {
394            tracing::warn!(
395                "[ferogram::pool] DC{dc_id} connection died mid-request (serializable path); evicting and retrying"
396            );
397            self.evict(dc_id);
398            let retry_slot = self.get_or_create_slot(dc_id, false, None).await?;
399            return Self::send_via_slot(&retry_slot, body).await;
400        }
401        result
402    }
403
404    /// Update the address table (called after `initConnection`).
405    pub fn update_addrs(&mut self, entries: &[DcEntry]) {
406        for e in entries {
407            self.addrs.insert(e.dc_id, e.addr.clone());
408        }
409    }
410
411    /// Save the auth keys from pool connections back into the DC entry list.
412    /// Uses the first slot per DC (all slots share the same auth key).
413    pub fn collect_keys(&self, entries: &mut [DcEntry]) {
414        for e in entries.iter_mut() {
415            if let Some(slots) = self.conns.get(&e.dc_id)
416                && let Some(slot) = slots.first()
417            {
418                e.auth_key = Some(slot.auth_key);
419                e.first_salt = slot.first_salt;
420                e.time_offset = slot.time_offset;
421            }
422        }
423    }
424}
425
/// Encode a `msgs_ack#62d6b459 { msg_ids: Vector<long> }` TL body.
///
/// The layout is three little-endian 32-bit words (constructor, `Vector`
/// constructor, element count), followed by each message id as a
/// little-endian 64-bit value. This is sent as a non-content-related
/// encrypted frame (even seq_no) to acknowledge received server messages, so
/// Telegram does not close the connection over un-acked messages.
pub(crate) fn build_msgs_ack_body(msg_ids: &[i64]) -> Vec<u8> {
    const MSGS_ACK_CONSTRUCTOR: u32 = 0x62d6b459;
    const VECTOR_CONSTRUCTOR: u32 = 0x1cb5c415;

    let header = [
        MSGS_ACK_CONSTRUCTOR,
        VECTOR_CONSTRUCTOR,
        msg_ids.len() as u32,
    ];

    let mut buf = Vec::with_capacity(header.len() * 4 + msg_ids.len() * 8);
    for word in header {
        buf.extend_from_slice(&word.to_le_bytes());
    }
    buf.extend(msg_ids.iter().flat_map(|id| id.to_le_bytes()));
    buf
}
441
/// Encode a `ping_delay_disconnect#f3427b8c { ping_id, disconnect_delay: 75 }` body.
///
/// Despite the function name, this builds a ping, not an ack. The 16-byte
/// layout is the constructor (LE u32), then `ping_id` (LE i64), then
/// `disconnect_delay` (LE i32). Tells Telegram to close the connection after
/// 75 seconds of silence.
pub(crate) fn build_msgs_ack_ping_body(ping_id: i64) -> Vec<u8> {
    // ping_delay_disconnect#f3427b8c ping_id:long disconnect_delay:int = Pong
    const PING_DELAY_DISCONNECT: u32 = 0xf3427b8c;
    const DISCONNECT_DELAY_SECS: i32 = 75;

    [
        &PING_DELAY_DISCONNECT.to_le_bytes()[..],
        &ping_id.to_le_bytes()[..],
        &DISCONNECT_DELAY_SECS.to_le_bytes()[..],
    ]
    .concat()
}