Skip to main content

ferogram_mtsender/
pool.rs

1// Copyright (c) Ankit Chaubey <ankitchaubey.dev@gmail.com>
2//
3// ferogram: async Telegram MTProto client in Rust
4// https://github.com/ankit-chaubey/ferogram
5//
6// Licensed under either the MIT License or the Apache License 2.0.
7// See the LICENSE-MIT or LICENSE-APACHE file in this repository:
8// https://github.com/ankit-chaubey/ferogram
9//
10// Feel free to use, modify, and share this code.
11// Please keep this notice when redistributing.
12
13use crate::errors::InvocationError;
14use crate::sender::DcConnection;
15use crate::sender_task::{FrameEvent, RpcEnqueue, spawn_sender_task};
16use ferogram_connect::util::maybe_gz_pack;
17use ferogram_connect::{Socks5Config, TransportKind};
18use ferogram_session::DcEntry;
19use ferogram_tl_types::{RemoteCall, Serializable};
20use std::collections::HashMap;
21use std::sync::Arc;
22use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
23use tokio::sync::{mpsc, oneshot};
24
/// Upper bound on how many simultaneous connections (slots) the pool keeps
/// open to any single DC.
const MAX_CONNS_PER_DC: usize = 3;
27
28/// One slot in the per-DC connection pool.
29///
30/// Each slot is backed by a background sender task (see
31/// [`crate::sender_task::spawn_sender_task`]), not a locked `DcConnection`.
32/// Enqueueing a request just posts it to the task's mpsc channel and waits
33/// on a oneshot for the result: no lock is held across the network round
34/// trip, so any number of callers can have requests in flight on the same
35/// slot at once. The task itself batches whatever is pending into as few
36/// frames as possible and matches replies back to callers by msg_id,
37/// regardless of the order responses arrive in.
38///
39/// `in_flight` still lets the pool pick the least-busy slot without needing
40/// to touch the connection itself.
41pub struct ConnSlot {
42    rpc_tx: mpsc::Sender<RpcEnqueue>,
43    pub in_flight: AtomicUsize,
44    /// Set to `false` by the drain task below once the connection's sender
45    /// task reports an error. Callers check this after a failed call to
46    /// decide whether to evict the slot and retry on a fresh one, instead of
47    /// matching on the specific `InvocationError` variant (the sender task
48    /// always reports connection failures as `Deserialize`, since it has no
49    /// way to know whether a given caller still cares about the original
50    /// `Io`/etc. error kind once `fail_all` has fanned it out to everyone
51    /// waiting on this connection).
52    alive: Arc<AtomicBool>,
53    /// Snapshot of the auth key / salt / time offset taken when the slot was
54    /// created. Used by `collect_keys` to persist session info. The auth key
55    /// never changes for a slot's lifetime; salt and time offset can drift a
56    /// little as the connection runs (FutureSalts rotation), but a stale
57    /// value here only costs one bad_server_salt round trip the next time
58    /// this DC is reconnected, since the sender task self-corrects from
59    /// server-supplied corrections either way.
60    auth_key: [u8; 256],
61    first_salt: i64,
62    time_offset: i32,
63}
64
65/// Pool of per-DC authenticated connections.
66/// Each DC holds up to MAX_CONNS_PER_DC slots. The pool lock is dropped
67/// before any network I/O so concurrent callers don't serialize on it.
68pub struct DcPool {
69    /// Per-DC connection slots; inner Vec holds slot Arcs.
70    pub conns: HashMap<i32, Vec<Arc<ConnSlot>>>,
71    addrs: HashMap<i32, String>,
72    #[allow(dead_code)]
73    home_dc_id: i32,
74    /// Proxy config forwarded to auto-reconnect.
75    socks5: Option<Socks5Config>,
76    /// Transport kind reused for secondary DC connections.
77    transport: TransportKind,
78    /// DCs that have already received `invokeWithLayer(initConnection(...))`.
79    init_done: std::collections::HashSet<i32>,
80}
81
82impl DcPool {
83    pub fn new(
84        home_dc_id: i32,
85        dc_entries: &[DcEntry],
86        socks5: Option<Socks5Config>,
87        transport: TransportKind,
88    ) -> Self {
89        let addrs = dc_entries
90            .iter()
91            .map(|e| (e.dc_id, e.addr.clone()))
92            .collect();
93        Self {
94            conns: HashMap::new(),
95            addrs,
96            home_dc_id,
97            socks5,
98            transport,
99            init_done: std::collections::HashSet::new(),
100        }
101    }
102
103    /// Returns true if at least one connection slot exists for `dc_id`.
104    pub fn has_connection(&self, dc_id: i32) -> bool {
105        self.conns.get(&dc_id).is_some_and(|v| !v.is_empty())
106    }
107
108    /// Graduate an already set-up `DcConnection` into a pipelined slot.
109    ///
110    /// The connection has already done its DH / PFS bind / initConnection as
111    /// a plain `DcConnection`. From here on its socket is owned by a single
112    /// background task; this function just spawns that task and wraps the
113    /// resulting handle in a `ConnSlot`.
114    fn spawn_slot(conn: DcConnection) -> Arc<ConnSlot> {
115        let auth_key = conn.auth_key_bytes();
116        let first_salt = conn.first_salt();
117        let time_offset = conn.time_offset();
118        let (stream, frame_kind, enc) = conn.into_parts();
119
120        let (handle, mut frame_rx) = spawn_sender_task(stream, enc, frame_kind, None);
121
122        // Pool slots don't support reconnect: on failure the pool just
123        // evicts the whole DC and a fresh slot is opened from scratch on the
124        // next call. Dropping reconnect_tx here means the sender task's
125        // error branch sees its reconnect channel closed and shuts itself
126        // down cleanly instead of waiting for a reconnect that will never
127        // come.
128        drop(handle.reconnect_tx);
129
130        let alive = Arc::new(AtomicBool::new(true));
131        let alive_for_drain = alive.clone();
132        tokio::spawn(async move {
133            while let Some(event) = frame_rx.recv().await {
134                if let FrameEvent::Error(e) = event {
135                    tracing::warn!("[ferogram::pool] worker connection dropped: {e}");
136                    alive_for_drain.store(false, Ordering::Release);
137                    break;
138                }
139                // FrameEvent::Update / Connected: pool connections don't
140                // dispatch updates, nothing to do.
141            }
142        });
143
144        Arc::new(ConnSlot {
145            rpc_tx: handle.rpc_tx,
146            in_flight: AtomicUsize::new(0),
147            alive,
148            auth_key,
149            first_salt,
150            time_offset,
151        })
152    }
153
154    /// Insert a pre-built, already initialized connection into the pool as a
155    /// new slot.
156    pub fn insert(&mut self, dc_id: i32, conn: DcConnection) {
157        let slot = Self::spawn_slot(conn);
158        self.conns.entry(dc_id).or_default().push(slot);
159        let total: usize = self.conns.values().map(|v| v.len()).sum();
160        metrics::gauge!("ferogram.connections_active").set(total as f64);
161    }
162
163    /// Returns the least-loaded slot for `dc_id`, creating one if needed.
164    /// Creates a new slot if all existing ones are busy and count < MAX_CONNS_PER_DC.
165    /// Drop the DcPool guard before locking the returned slot.
166    pub(crate) async fn get_or_create_slot(
167        &mut self,
168        dc_id: i32,
169        pfs: bool,
170        auth_key: Option<([u8; 256], i64, i32)>,
171    ) -> Result<Arc<ConnSlot>, InvocationError> {
172        let addr = self.addrs.get(&dc_id).cloned().ok_or_else(|| {
173            InvocationError::Deserialize(format!("dc_pool: no address for DC{dc_id}"))
174        })?;
175
176        // Ensure at least one slot exists.
177        if !self.conns.contains_key(&dc_id) || self.conns[&dc_id].is_empty() {
178            tracing::debug!("[ferogram::pool] opening first connection to DC{dc_id} at {addr}");
179            let conn = if let Some((key, salt, offset)) = auth_key {
180                DcConnection::connect_with_key(
181                    &addr,
182                    key,
183                    salt,
184                    offset,
185                    self.socks5.as_ref(),
186                    None,
187                    &self.transport,
188                    dc_id as i16,
189                    pfs,
190                )
191                .await?
192            } else {
193                DcConnection::connect_raw(
194                    &addr,
195                    self.socks5.as_ref(),
196                    &self.transport,
197                    dc_id as i16,
198                )
199                .await?
200            };
201            let slot = Self::spawn_slot(conn);
202            self.conns.entry(dc_id).or_default().push(slot);
203            self.init_done.remove(&dc_id);
204            let total: usize = self.conns.values().map(|v| v.len()).sum();
205            metrics::gauge!("ferogram.connections_active").set(total as f64);
206        }
207
208        let slots = self
209            .conns
210            .get(&dc_id)
211            .expect("dc_id must be registered before use");
212
213        // pick least-busy slot
214        let best = slots
215            .iter()
216            .min_by_key(|s| s.in_flight.load(Ordering::Relaxed))
217            .expect("slots vec is non-empty")
218            .clone();
219        let min_inflight = best.in_flight.load(Ordering::Relaxed);
220
221        // Spawn a new slot if: all are busy AND we have room for more.
222        //
223        // With pipelined slots this matters less than it used to (a single
224        // slot can now happily carry many in-flight requests at once), but
225        // it's still worth spreading load across a few real TCP connections
226        // for very heavy transfers.
227        if min_inflight > 0 && slots.len() < MAX_CONNS_PER_DC {
228            tracing::debug!(
229                "[ferogram::pool] DC{dc_id}: all {} slots busy (min_inflight={min_inflight}), opening extra connection",
230                slots.len()
231            );
232            let conn = if let Some((key, salt, offset)) = auth_key {
233                DcConnection::connect_with_key(
234                    &addr,
235                    key,
236                    salt,
237                    offset,
238                    self.socks5.as_ref(),
239                    None,
240                    &self.transport,
241                    dc_id as i16,
242                    pfs,
243                )
244                .await?
245            } else {
246                DcConnection::connect_raw(
247                    &addr,
248                    self.socks5.as_ref(),
249                    &self.transport,
250                    dc_id as i16,
251                )
252                .await?
253            };
254            let new_slot = Self::spawn_slot(conn);
255            let arc = new_slot.clone();
256            self.conns
257                .get_mut(&dc_id)
258                .expect("dc_id must be registered")
259                .push(new_slot);
260            let total: usize = self.conns.values().map(|v| v.len()).sum();
261            metrics::gauge!("ferogram.connections_active").set(total as f64);
262            return Ok(arc);
263        }
264
265        Ok(best)
266    }
267
268    /// Evict all slots for a DC (called on connection failure to force
269    /// reconnection on the next call).
270    pub fn evict(&mut self, dc_id: i32) {
271        self.conns.remove(&dc_id);
272        self.init_done.remove(&dc_id);
273        let total: usize = self.conns.values().map(|v| v.len()).sum();
274        metrics::gauge!("ferogram.connections_active").set(total as f64);
275        tracing::debug!("[ferogram::pool] evicted all connections for DC{dc_id}");
276    }
277
278    /// Enqueue `body` on `slot` and await the result.
279    ///
280    /// This is the only place that touches `rpc_tx`/the oneshot: no mutex,
281    /// no blocking for the duration of the round trip. Multiple callers can
282    /// call this against the same slot concurrently and their requests will
283    /// pipeline on the wire instead of queueing behind each other.
284    async fn send_via_slot(
285        slot: &Arc<ConnSlot>,
286        body: Vec<u8>,
287    ) -> Result<Vec<u8>, InvocationError> {
288        slot.in_flight.fetch_add(1, Ordering::Relaxed);
289        let (tx, rx) = oneshot::channel();
290        let send_result = slot.rpc_tx.send(RpcEnqueue { body, tx }).await;
291        let result = if send_result.is_err() {
292            slot.alive.store(false, Ordering::Release);
293            Err(InvocationError::Deserialize(
294                "worker sender task shut down".into(),
295            ))
296        } else {
297            match rx.await {
298                Ok(r) => r,
299                Err(_) => {
300                    slot.alive.store(false, Ordering::Release);
301                    Err(InvocationError::Deserialize(
302                        "worker rpc channel closed".into(),
303                    ))
304                }
305            }
306        };
307        slot.in_flight.fetch_sub(1, Ordering::Relaxed);
308        result
309    }
310
311    /// Invoke a raw RPC call on the given DC.
312    /// Pool lock is released before the network round-trip begins.
313    pub async fn invoke_on_dc<R: RemoteCall>(
314        &mut self,
315        dc_id: i32,
316        _dc_entries: &[DcEntry],
317        req: &R,
318    ) -> Result<Vec<u8>, InvocationError> {
319        let slot = self.get_or_create_slot(dc_id, false, None).await?;
320        let body = maybe_gz_pack(&req.to_bytes());
321        let result = Self::send_via_slot(&slot, body.clone()).await;
322
323        if let Err(ref e) = result {
324            let kind = match e {
325                InvocationError::Rpc(_) => "rpc",
326                InvocationError::Io(_) => "io",
327                _ => "other",
328            };
329            metrics::counter!("ferogram.rpc_errors_total", "kind" => kind).increment(1);
330        }
331
332        if result.is_err() && !slot.alive.load(Ordering::Acquire) {
333            tracing::warn!(
334                "[ferogram::pool] DC{dc_id} connection died mid-request; evicting and retrying on a fresh connection"
335            );
336            self.evict(dc_id);
337            let retry_slot = self.get_or_create_slot(dc_id, false, None).await?;
338            return Self::send_via_slot(&retry_slot, body).await;
339        }
340        result
341    }
342
343    /// Mark a DC as having completed initConnection.
344    pub fn mark_init_done(&mut self, dc_id: i32) {
345        self.init_done.insert(dc_id);
346    }
347
348    /// Returns true if this DC has already received initConnection this session.
349    pub fn is_init_done(&self, dc_id: i32) -> bool {
350        self.init_done.contains(&dc_id)
351    }
352
353    /// Like `invoke_on_dc` but accepts any `Serializable` type.
354    pub async fn invoke_on_dc_serializable<S: Serializable>(
355        &mut self,
356        dc_id: i32,
357        req: &S,
358    ) -> Result<Vec<u8>, InvocationError> {
359        let slot = self
360            .get_or_create_slot(dc_id, false, None)
361            .await
362            .map_err(|_| InvocationError::Deserialize(format!("no connection for DC{dc_id}")))?;
363        let body = maybe_gz_pack(&req.to_bytes());
364        let result = Self::send_via_slot(&slot, body.clone()).await;
365
366        if result.is_err() && !slot.alive.load(Ordering::Acquire) {
367            tracing::warn!(
368                "[ferogram::pool] DC{dc_id} connection died mid-request (serializable path); evicting and retrying"
369            );
370            self.evict(dc_id);
371            let retry_slot = self.get_or_create_slot(dc_id, false, None).await?;
372            return Self::send_via_slot(&retry_slot, body).await;
373        }
374        result
375    }
376
377    /// Update the address table (called after `initConnection`).
378    pub fn update_addrs(&mut self, entries: &[DcEntry]) {
379        for e in entries {
380            self.addrs.insert(e.dc_id, e.addr.clone());
381        }
382    }
383
384    /// Save the auth keys from pool connections back into the DC entry list.
385    /// Uses the first slot per DC (all slots share the same auth key).
386    pub fn collect_keys(&self, entries: &mut [DcEntry]) {
387        for e in entries.iter_mut() {
388            if let Some(slots) = self.conns.get(&e.dc_id)
389                && let Some(slot) = slots.first()
390            {
391                e.auth_key = Some(slot.auth_key);
392                e.first_salt = slot.first_salt;
393                e.time_offset = slot.time_offset;
394            }
395        }
396    }
397}
398
/// Serialize a `msgs_ack#62d6b459 { msg_ids: Vector<long> }` TL body.
///
/// Wire layout, all little-endian:
/// - the `msgs_ack` constructor id (4 bytes);
/// - the boxed `Vector` constructor id (4 bytes);
/// - the element count as u32 (4 bytes);
/// - each `msg_id` as i64 (8 bytes each).
///
/// Intended for acknowledging received server messages, so that Telegram
/// does not close the connection over un-acked messages.
pub(crate) fn build_msgs_ack_body(msg_ids: &[i64]) -> Vec<u8> {
    const MSGS_ACK_CTOR: u32 = 0x62d6b459;
    const VECTOR_CTOR: u32 = 0x1cb5c415;

    let header = [MSGS_ACK_CTOR, VECTOR_CTOR, msg_ids.len() as u32];
    let mut buf = Vec::with_capacity(12 + 8 * msg_ids.len());
    for word in header {
        buf.extend_from_slice(&word.to_le_bytes());
    }
    buf.extend(msg_ids.iter().flat_map(|id| id.to_le_bytes()));
    buf
}
414
/// Serialize a `ping_delay_disconnect#f3427b8c ping_id:long disconnect_delay:int`
/// body with `disconnect_delay` fixed at 75 seconds.
///
/// Despite the function name, no `msgs_ack` is written: the output is just
/// the 16-byte `ping_delay_disconnect` body (constructor, `ping_id`, delay,
/// all little-endian).
///
/// Per the MTProto service-message docs, the server closes the connection
/// `disconnect_delay` seconds after receiving this, unless another
/// `ping_delay_disconnect` resets the timer first.
pub(crate) fn build_msgs_ack_ping_body(ping_id: i64) -> Vec<u8> {
    const PING_DELAY_DISCONNECT_CTOR: u32 = 0xf3427b8c;
    const DISCONNECT_DELAY_SECS: i32 = 75;

    [
        PING_DELAY_DISCONNECT_CTOR.to_le_bytes().as_slice(),
        ping_id.to_le_bytes().as_slice(),
        DISCONNECT_DELAY_SECS.to_le_bytes().as_slice(),
    ]
    .concat()
}