Skip to main content

ferogram_mtsender/
pool.rs

1// Copyright (c) Ankit Chaubey <ankitchaubey.dev@gmail.com>
2//
3// ferogram: async Telegram MTProto client in Rust
4// https://github.com/ankit-chaubey/ferogram
5//
6// Licensed under either the MIT License or the Apache License 2.0.
7// See the LICENSE-MIT or LICENSE-APACHE file in this repository:
8// https://github.com/ankit-chaubey/ferogram
9//
10// Feel free to use, modify, and share this code.
11// Please keep this notice when redistributing.
12
13use crate::errors::InvocationError;
14use crate::sender::DcConnection;
15use ferogram_connect::{Socks5Config, TransportKind};
16use ferogram_session::DcEntry;
17use ferogram_tl_types::RemoteCall;
18use std::collections::HashMap;
19
/// Upper bound on the number of simultaneous connections (slots) the pool
/// keeps open to any single DC.
const MAX_CONNS_PER_DC: usize = 3;
22
23/// One slot in the per-DC connection pool.
24/// `in_flight` lets the pool pick the least-busy slot without locking it.
25pub struct ConnSlot {
26    pub conn: tokio::sync::Mutex<DcConnection>,
27    pub in_flight: std::sync::atomic::AtomicUsize,
28}
29
30/// Pool of per-DC authenticated connections.
31/// Each DC holds up to MAX_CONNS_PER_DC slots. The pool lock is dropped
32/// before any network I/O so concurrent callers don't serialize on it.
33pub struct DcPool {
34    /// Per-DC connection slots; inner Vec holds slot Arcs.
35    pub conns: HashMap<i32, Vec<std::sync::Arc<ConnSlot>>>,
36    addrs: HashMap<i32, String>,
37    #[allow(dead_code)]
38    home_dc_id: i32,
39    /// Proxy config forwarded to auto-reconnect.
40    socks5: Option<Socks5Config>,
41    /// Transport kind reused for secondary DC connections.
42    transport: TransportKind,
43    /// DCs that have already received `invokeWithLayer(initConnection(...))`.
44    init_done: std::collections::HashSet<i32>,
45}
46
47impl DcPool {
48    pub fn new(
49        home_dc_id: i32,
50        dc_entries: &[DcEntry],
51        socks5: Option<Socks5Config>,
52        transport: TransportKind,
53    ) -> Self {
54        let addrs = dc_entries
55            .iter()
56            .map(|e| (e.dc_id, e.addr.clone()))
57            .collect();
58        Self {
59            conns: HashMap::new(),
60            addrs,
61            home_dc_id,
62            socks5,
63            transport,
64            init_done: std::collections::HashSet::new(),
65        }
66    }
67
68    /// Returns true if at least one connection slot exists for `dc_id`.
69    pub fn has_connection(&self, dc_id: i32) -> bool {
70        self.conns.get(&dc_id).is_some_and(|v| !v.is_empty())
71    }
72
73    /// Insert a pre-built connection into the pool as a new slot.
74    pub fn insert(&mut self, dc_id: i32, conn: DcConnection) {
75        let slot = std::sync::Arc::new(ConnSlot {
76            conn: tokio::sync::Mutex::new(conn),
77            in_flight: std::sync::atomic::AtomicUsize::new(0),
78        });
79        self.conns.entry(dc_id).or_default().push(slot);
80        let total: usize = self.conns.values().map(|v| v.len()).sum();
81        metrics::gauge!("ferogram.connections_active").set(total as f64);
82    }
83
84    /// Returns the least-loaded slot for `dc_id`, creating one if needed.
85    /// Creates a new slot if all existing ones are busy and count < MAX_CONNS_PER_DC.
86    /// Drop the DcPool guard before locking the returned slot.
87    pub(crate) async fn get_or_create_slot(
88        &mut self,
89        dc_id: i32,
90        pfs: bool,
91        auth_key: Option<([u8; 256], i64, i32)>,
92    ) -> Result<std::sync::Arc<ConnSlot>, InvocationError> {
93        use std::sync::atomic::Ordering;
94
95        let addr = self.addrs.get(&dc_id).cloned().ok_or_else(|| {
96            InvocationError::Deserialize(format!("dc_pool: no address for DC{dc_id}"))
97        })?;
98
99        // Ensure at least one slot exists.
100        if !self.conns.contains_key(&dc_id) || self.conns[&dc_id].is_empty() {
101            tracing::info!("[dc_pool] auto-connecting DC{dc_id} ({addr})");
102            let conn = if let Some((key, salt, offset)) = auth_key {
103                DcConnection::connect_with_key(
104                    &addr,
105                    key,
106                    salt,
107                    offset,
108                    self.socks5.as_ref(),
109                    None,
110                    &self.transport,
111                    dc_id as i16,
112                    pfs,
113                )
114                .await?
115            } else {
116                DcConnection::connect_raw(
117                    &addr,
118                    self.socks5.as_ref(),
119                    &self.transport,
120                    dc_id as i16,
121                )
122                .await?
123            };
124            let slot = std::sync::Arc::new(ConnSlot {
125                conn: tokio::sync::Mutex::new(conn),
126                in_flight: std::sync::atomic::AtomicUsize::new(0),
127            });
128            self.conns.entry(dc_id).or_default().push(slot);
129            self.init_done.remove(&dc_id);
130            let total: usize = self.conns.values().map(|v| v.len()).sum();
131            metrics::gauge!("ferogram.connections_active").set(total as f64);
132        }
133
134        let slots = self
135            .conns
136            .get(&dc_id)
137            .expect("dc_id must be registered before use");
138
139        // pick least-busy slot
140        let best = slots
141            .iter()
142            .min_by_key(|s| s.in_flight.load(Ordering::Relaxed))
143            .expect("slots vec is non-empty")
144            .clone();
145        let min_inflight = best.in_flight.load(Ordering::Relaxed);
146
147        // Spawn a new slot if: all are busy AND we have room for more.
148        if min_inflight > 0 && slots.len() < MAX_CONNS_PER_DC {
149            tracing::debug!(
150                "[dc_pool] DC{dc_id}: all {} slots busy (min_inflight={}), opening new slot",
151                slots.len(),
152                min_inflight
153            );
154            let conn = if let Some((key, salt, offset)) = auth_key {
155                DcConnection::connect_with_key(
156                    &addr,
157                    key,
158                    salt,
159                    offset,
160                    self.socks5.as_ref(),
161                    None,
162                    &self.transport,
163                    dc_id as i16,
164                    pfs,
165                )
166                .await?
167            } else {
168                DcConnection::connect_raw(
169                    &addr,
170                    self.socks5.as_ref(),
171                    &self.transport,
172                    dc_id as i16,
173                )
174                .await?
175            };
176            let new_slot = std::sync::Arc::new(ConnSlot {
177                conn: tokio::sync::Mutex::new(conn),
178                in_flight: std::sync::atomic::AtomicUsize::new(0),
179            });
180            let arc = new_slot.clone();
181            self.conns
182                .get_mut(&dc_id)
183                .expect("dc_id must be registered")
184                .push(new_slot);
185            let total: usize = self.conns.values().map(|v| v.len()).sum();
186            metrics::gauge!("ferogram.connections_active").set(total as f64);
187            return Ok(arc);
188        }
189
190        Ok(best)
191    }
192
193    /// Evict all slots for a DC (called on IO error to force reconnection).
194    pub fn evict(&mut self, dc_id: i32) {
195        self.conns.remove(&dc_id);
196        self.init_done.remove(&dc_id);
197        let total: usize = self.conns.values().map(|v| v.len()).sum();
198        metrics::gauge!("ferogram.connections_active").set(total as f64);
199        tracing::debug!("[dc_pool] evicted all slots for DC{dc_id}");
200    }
201
202    /// Invoke a raw RPC call on the given DC.
203    /// Pool lock is released before the network round-trip begins.
204    pub async fn invoke_on_dc<R: RemoteCall>(
205        &mut self,
206        dc_id: i32,
207        _dc_entries: &[DcEntry],
208        req: &R,
209    ) -> Result<Vec<u8>, InvocationError> {
210        use std::sync::atomic::Ordering;
211        let slot = self.get_or_create_slot(dc_id, false, None).await?;
212        slot.in_flight.fetch_add(1, Ordering::Relaxed);
213        let result = slot.conn.lock().await.rpc_call(req).await;
214        slot.in_flight.fetch_sub(1, Ordering::Relaxed);
215        if let Err(ref e) = result {
216            let kind = match e {
217                InvocationError::Rpc(_) => "rpc",
218                InvocationError::Io(_) => "io",
219                _ => "other",
220            };
221            metrics::counter!("ferogram.rpc_errors_total", "kind" => kind).increment(1);
222        }
223        if matches!(result, Err(InvocationError::Io(_))) {
224            tracing::warn!("[dc_pool] IO error on DC{dc_id}, evicting all slots and retrying");
225            self.evict(dc_id);
226            let retry_slot = self.get_or_create_slot(dc_id, false, None).await?;
227            retry_slot.in_flight.fetch_add(1, Ordering::Relaxed);
228            let r = retry_slot.conn.lock().await.rpc_call(req).await;
229            retry_slot.in_flight.fetch_sub(1, Ordering::Relaxed);
230            return r;
231        }
232        result
233    }
234
235    /// Mark a DC as having completed initConnection.
236    pub fn mark_init_done(&mut self, dc_id: i32) {
237        self.init_done.insert(dc_id);
238    }
239
240    /// Returns true if this DC has already received initConnection this session.
241    pub fn is_init_done(&self, dc_id: i32) -> bool {
242        self.init_done.contains(&dc_id)
243    }
244
245    /// Like `invoke_on_dc` but accepts any `Serializable` type.
246    pub async fn invoke_on_dc_serializable<S: ferogram_tl_types::Serializable>(
247        &mut self,
248        dc_id: i32,
249        req: &S,
250    ) -> Result<Vec<u8>, InvocationError> {
251        use std::sync::atomic::Ordering;
252        let slot = self
253            .get_or_create_slot(dc_id, false, None)
254            .await
255            .map_err(|_| InvocationError::Deserialize(format!("no connection for DC{dc_id}")))?;
256        slot.in_flight.fetch_add(1, Ordering::Relaxed);
257        let result = slot.conn.lock().await.rpc_call_serializable(req).await;
258        slot.in_flight.fetch_sub(1, Ordering::Relaxed);
259        if matches!(result, Err(InvocationError::Io(_))) {
260            tracing::warn!("[dc_pool] serializable IO error on DC{dc_id}, evicting and retrying");
261            self.evict(dc_id);
262            let retry_slot = self.get_or_create_slot(dc_id, false, None).await?;
263            retry_slot.in_flight.fetch_add(1, Ordering::Relaxed);
264            let r = retry_slot
265                .conn
266                .lock()
267                .await
268                .rpc_call_serializable(req)
269                .await;
270            retry_slot.in_flight.fetch_sub(1, Ordering::Relaxed);
271            return r;
272        }
273        result
274    }
275
276    /// Update the address table (called after `initConnection`).
277    pub fn update_addrs(&mut self, entries: &[DcEntry]) {
278        for e in entries {
279            self.addrs.insert(e.dc_id, e.addr.clone());
280        }
281    }
282
283    /// Save the auth keys from pool connections back into the DC entry list.
284    /// Uses the first slot per DC (all slots share the same auth key).
285    pub fn collect_keys(&self, entries: &mut [DcEntry]) {
286        for e in entries.iter_mut() {
287            if let Some(slots) = self.conns.get(&e.dc_id)
288                && let Some(slot) = slots.first()
289                && let Ok(conn) = slot.conn.try_lock()
290            {
291                e.auth_key = Some(conn.auth_key_bytes());
292                e.first_salt = conn.first_salt();
293                e.time_offset = conn.time_offset();
294            }
295        }
296    }
297}
298
/// Encodes the TL body of `msgs_ack#62d6b459 { msg_ids: Vector<long> }`.
///
/// Layout (all little-endian): the `msgs_ack` constructor, the boxed `Vector`
/// constructor, a `u32` element count, then each message id as 8 bytes.
/// It is sent as a non-content-related encrypted frame (even seq_no), so that
/// Telegram sees the server messages as acknowledged and does not drop the
/// connection over un-acked messages.
pub(crate) fn build_msgs_ack_body(msg_ids: &[i64]) -> Vec<u8> {
    const MSGS_ACK_CTOR: u32 = 0x62d6b459;
    const VECTOR_CTOR: u32 = 0x1cb5c415;

    let count = msg_ids.len() as u32;
    let mut body = Vec::with_capacity(12 + 8 * msg_ids.len());
    body.extend_from_slice(&MSGS_ACK_CTOR.to_le_bytes());
    body.extend_from_slice(&VECTOR_CTOR.to_le_bytes());
    body.extend_from_slice(&count.to_le_bytes());
    body.extend(msg_ids.iter().flat_map(|id| id.to_le_bytes()));
    body
}
314
/// Serialize a `ping_delay_disconnect#f3427b8c { ping_id, disconnect_delay: 75 }` body.
///
/// Despite its name, this builds a `ping_delay_disconnect` request, not a
/// `msgs_ack`. It asks Telegram to close the connection 75 seconds after this
/// ping unless another `ping_delay_disconnect` arrives first and resets the timer.
/// Use [`build_ping_delay_disconnect_body`] to choose a different delay.
pub(crate) fn build_msgs_ack_ping_body(ping_id: i64) -> Vec<u8> {
    build_ping_delay_disconnect_body(ping_id, 75)
}

/// Serialize a `ping_delay_disconnect#f3427b8c { ping_id, disconnect_delay }` body.
///
/// Layout (little-endian): 4-byte constructor, 8-byte `ping_id`, then 4-byte
/// `disconnect_delay` (seconds).
pub(crate) fn build_ping_delay_disconnect_body(ping_id: i64, disconnect_delay: i32) -> Vec<u8> {
    // ping_delay_disconnect#f3427b8c ping_id:long disconnect_delay:int = Pong
    let mut out = Vec::with_capacity(4 + 8 + 4);
    out.extend_from_slice(&0xf3427b8c_u32.to_le_bytes()); // constructor
    out.extend_from_slice(&ping_id.to_le_bytes());
    out.extend_from_slice(&disconnect_delay.to_le_bytes());
    out
}