1use crate::errors::InvocationError;
14use crate::sender::DcConnection;
15use crate::sender_task::{FrameEvent, RpcEnqueue, spawn_sender_task};
16use ferogram_connect::util::maybe_gz_pack;
17use ferogram_connect::{Socks5Config, TransportKind};
18use ferogram_session::DcEntry;
19use ferogram_tl_types::{RemoteCall, Serializable};
20use std::collections::HashMap;
21use std::sync::Arc;
22use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
23use tokio::sync::{mpsc, oneshot};
24
/// Upper bound on the worker connections `DcPool::get_or_create_slot` will open for a
/// single DC. Extra connections are only opened while every existing slot is busy.
const MAX_CONNS_PER_DC: usize = 3;
27
28pub struct ConnSlot {
42 rpc_tx: mpsc::Sender<RpcEnqueue>,
43 pub in_flight: AtomicUsize,
44 alive: Arc<AtomicBool>,
53 auth_key: [u8; 256],
61 first_salt: i64,
62 time_offset: i32,
63}
64
65pub struct DcPool {
69 pub conns: HashMap<i32, Vec<Arc<ConnSlot>>>,
71 addrs: HashMap<i32, String>,
72 #[allow(dead_code)]
73 home_dc_id: i32,
74 socks5: Option<Socks5Config>,
76 transport: TransportKind,
78 init_done: std::collections::HashSet<i32>,
80}
81
82impl DcPool {
83 pub fn new(
87 home_dc_id: i32,
88 dc_entries: &[DcEntry],
89 socks5: Option<Socks5Config>,
90 transport: TransportKind,
91 ) -> Self {
92 let addrs = dc_entries
93 .iter()
94 .map(|e| (e.dc_id, e.addr.clone()))
95 .collect();
96 Self {
97 conns: HashMap::new(),
98 addrs,
99 home_dc_id,
100 socks5,
101 transport,
102 init_done: std::collections::HashSet::new(),
103 }
104 }
105
106 pub fn has_connection(&self, dc_id: i32) -> bool {
108 self.conns.get(&dc_id).is_some_and(|v| !v.is_empty())
109 }
110
111 fn spawn_slot(conn: DcConnection) -> Arc<ConnSlot> {
118 let auth_key = conn.auth_key_bytes();
119 let first_salt = conn.first_salt();
120 let time_offset = conn.time_offset();
121 let (stream, frame_kind, enc) = conn.into_parts();
122
123 let (handle, mut frame_rx) = spawn_sender_task(stream, enc, frame_kind, None);
124
125 drop(handle.reconnect_tx);
132
133 let alive = Arc::new(AtomicBool::new(true));
134 let alive_for_drain = alive.clone();
135 tokio::spawn(async move {
136 while let Some(event) = frame_rx.recv().await {
137 if let FrameEvent::Error(e) = event {
138 tracing::warn!("[ferogram::pool] worker connection dropped: {e}");
139 alive_for_drain.store(false, Ordering::Release);
140 break;
141 }
142 }
145 });
146
147 Arc::new(ConnSlot {
148 rpc_tx: handle.rpc_tx,
149 in_flight: AtomicUsize::new(0),
150 alive,
151 auth_key,
152 first_salt,
153 time_offset,
154 })
155 }
156
157 pub fn insert(&mut self, dc_id: i32, conn: DcConnection) {
160 let slot = Self::spawn_slot(conn);
161 self.conns.entry(dc_id).or_default().push(slot);
162 let total: usize = self.conns.values().map(|v| v.len()).sum();
163 metrics::gauge!("ferogram.connections_active").set(total as f64);
164 }
165
166 pub(crate) async fn get_or_create_slot(
170 &mut self,
171 dc_id: i32,
172 pfs: bool,
173 auth_key: Option<([u8; 256], i64, i32)>,
174 ) -> Result<Arc<ConnSlot>, InvocationError> {
175 let addr = self.addrs.get(&dc_id).cloned().ok_or_else(|| {
176 InvocationError::Deserialize(format!("dc_pool: no address for DC{dc_id}"))
177 })?;
178
179 if !self.conns.contains_key(&dc_id) || self.conns[&dc_id].is_empty() {
181 tracing::debug!("[ferogram::pool] opening first connection to DC{dc_id} at {addr}");
182 let conn = if let Some((key, salt, offset)) = auth_key {
183 DcConnection::connect_with_key(
184 &addr,
185 key,
186 salt,
187 offset,
188 self.socks5.as_ref(),
189 None,
190 &self.transport,
191 dc_id as i16,
192 pfs,
193 )
194 .await?
195 } else {
196 DcConnection::connect_raw(
197 &addr,
198 self.socks5.as_ref(),
199 &self.transport,
200 dc_id as i16,
201 )
202 .await?
203 };
204 let slot = Self::spawn_slot(conn);
205 self.conns.entry(dc_id).or_default().push(slot);
206 self.init_done.remove(&dc_id);
207 let total: usize = self.conns.values().map(|v| v.len()).sum();
208 metrics::gauge!("ferogram.connections_active").set(total as f64);
209 }
210
211 let slots = self
212 .conns
213 .get(&dc_id)
214 .expect("dc_id must be registered before use");
215
216 let best = slots
218 .iter()
219 .min_by_key(|s| s.in_flight.load(Ordering::Relaxed))
220 .expect("slots vec is non-empty")
221 .clone();
222 let min_inflight = best.in_flight.load(Ordering::Relaxed);
223
224 if min_inflight > 0 && slots.len() < MAX_CONNS_PER_DC {
231 tracing::debug!(
232 "[ferogram::pool] DC{dc_id}: all {} slots busy (min_inflight={min_inflight}), opening extra connection",
233 slots.len()
234 );
235 let conn = if let Some((key, salt, offset)) = auth_key {
236 DcConnection::connect_with_key(
237 &addr,
238 key,
239 salt,
240 offset,
241 self.socks5.as_ref(),
242 None,
243 &self.transport,
244 dc_id as i16,
245 pfs,
246 )
247 .await?
248 } else {
249 DcConnection::connect_raw(
250 &addr,
251 self.socks5.as_ref(),
252 &self.transport,
253 dc_id as i16,
254 )
255 .await?
256 };
257 let new_slot = Self::spawn_slot(conn);
258 let arc = new_slot.clone();
259 self.conns
260 .get_mut(&dc_id)
261 .expect("dc_id must be registered")
262 .push(new_slot);
263 let total: usize = self.conns.values().map(|v| v.len()).sum();
264 metrics::gauge!("ferogram.connections_active").set(total as f64);
265 return Ok(arc);
266 }
267
268 Ok(best)
269 }
270
271 pub fn evict(&mut self, dc_id: i32) {
274 self.conns.remove(&dc_id);
275 self.init_done.remove(&dc_id);
276 let total: usize = self.conns.values().map(|v| v.len()).sum();
277 metrics::gauge!("ferogram.connections_active").set(total as f64);
278 tracing::debug!("[ferogram::pool] evicted all connections for DC{dc_id}");
279 }
280
281 async fn send_via_slot(
288 slot: &Arc<ConnSlot>,
289 body: Vec<u8>,
290 ) -> Result<Vec<u8>, InvocationError> {
291 slot.in_flight.fetch_add(1, Ordering::Relaxed);
292 let (tx, rx) = oneshot::channel();
293 let send_result = slot.rpc_tx.send(RpcEnqueue { body, tx }).await;
294 let result = if send_result.is_err() {
295 slot.alive.store(false, Ordering::Release);
296 Err(InvocationError::Deserialize(
297 "worker sender task shut down".into(),
298 ))
299 } else {
300 match rx.await {
301 Ok(r) => r,
302 Err(_) => {
303 slot.alive.store(false, Ordering::Release);
304 Err(InvocationError::Deserialize(
305 "worker rpc channel closed".into(),
306 ))
307 }
308 }
309 };
310 slot.in_flight.fetch_sub(1, Ordering::Relaxed);
311 result
312 }
313
314 pub async fn invoke_on_dc<R: RemoteCall>(
317 &mut self,
318 dc_id: i32,
319 _dc_entries: &[DcEntry],
320 req: &R,
321 ) -> Result<Vec<u8>, InvocationError> {
322 let slot = self.get_or_create_slot(dc_id, false, None).await?;
323 let body = maybe_gz_pack(&req.to_bytes());
324 let result = Self::send_via_slot(&slot, body.clone()).await;
325
326 if let Err(ref e) = result {
327 let kind = match e {
328 InvocationError::Rpc(_) => "rpc",
329 InvocationError::Io(_) => "io",
330 _ => "other",
331 };
332 metrics::counter!("ferogram.rpc_errors_total", "kind" => kind).increment(1);
333 }
334
335 if let Err(InvocationError::Rpc(ref e)) = result
336 && e.code == -404
337 {
338 tracing::warn!(
341 "[ferogram::pool] DC{dc_id} returned -404 (auth key gone); evicting and redoing DH"
342 );
343 self.evict(dc_id);
344 let retry_slot = self.get_or_create_slot(dc_id, false, None).await?;
345 return Self::send_via_slot(&retry_slot, body).await;
346 }
347
348 if result.is_err() && !slot.alive.load(Ordering::Acquire) {
349 tracing::warn!(
350 "[ferogram::pool] DC{dc_id} connection died mid-request; evicting and retrying on a fresh connection"
351 );
352 self.evict(dc_id);
353 let retry_slot = self.get_or_create_slot(dc_id, false, None).await?;
354 return Self::send_via_slot(&retry_slot, body).await;
355 }
356 result
357 }
358
359 pub fn mark_init_done(&mut self, dc_id: i32) {
361 self.init_done.insert(dc_id);
362 }
363
364 pub fn is_init_done(&self, dc_id: i32) -> bool {
366 self.init_done.contains(&dc_id)
367 }
368
369 pub async fn invoke_on_dc_serializable<S: Serializable>(
371 &mut self,
372 dc_id: i32,
373 req: &S,
374 ) -> Result<Vec<u8>, InvocationError> {
375 let slot = self
376 .get_or_create_slot(dc_id, false, None)
377 .await
378 .map_err(|_| InvocationError::Deserialize(format!("no connection for DC{dc_id}")))?;
379 let body = maybe_gz_pack(&req.to_bytes());
380 let result = Self::send_via_slot(&slot, body.clone()).await;
381
382 if let Err(InvocationError::Rpc(ref e)) = result
383 && e.code == -404
384 {
385 tracing::warn!(
386 "[ferogram::pool] DC{dc_id} returned -404 (serializable path); evicting and redoing DH"
387 );
388 self.evict(dc_id);
389 let retry_slot = self.get_or_create_slot(dc_id, false, None).await?;
390 return Self::send_via_slot(&retry_slot, body).await;
391 }
392
393 if result.is_err() && !slot.alive.load(Ordering::Acquire) {
394 tracing::warn!(
395 "[ferogram::pool] DC{dc_id} connection died mid-request (serializable path); evicting and retrying"
396 );
397 self.evict(dc_id);
398 let retry_slot = self.get_or_create_slot(dc_id, false, None).await?;
399 return Self::send_via_slot(&retry_slot, body).await;
400 }
401 result
402 }
403
404 pub fn update_addrs(&mut self, entries: &[DcEntry]) {
406 for e in entries {
407 self.addrs.insert(e.dc_id, e.addr.clone());
408 }
409 }
410
411 pub fn collect_keys(&self, entries: &mut [DcEntry]) {
414 for e in entries.iter_mut() {
415 if let Some(slots) = self.conns.get(&e.dc_id)
416 && let Some(slot) = slots.first()
417 {
418 e.auth_key = Some(slot.auth_key);
419 e.first_salt = slot.first_salt;
420 e.time_offset = slot.time_offset;
421 }
422 }
423 }
424}
425
/// Serializes a `msgs_ack` body acknowledging every id in `msg_ids`.
///
/// Byte layout (all fields little-endian):
/// 1. constructor `0x62d6b459`
/// 2. vector constructor `0x1cb5c415`
/// 3. element count as a 32-bit integer
/// 4. each message id as 8 bytes
pub(crate) fn build_msgs_ack_body(msg_ids: &[i64]) -> Vec<u8> {
    const MSGS_ACK_ID: u32 = 0x62d6b459;
    const VECTOR_ID: u32 = 0x1cb5c415;

    let header = [MSGS_ACK_ID, VECTOR_ID, msg_ids.len() as u32];
    let mut body = Vec::with_capacity(header.len() * 4 + msg_ids.len() * 8);
    body.extend(header.iter().flat_map(|word| word.to_le_bytes()));
    body.extend(msg_ids.iter().flat_map(|id| id.to_le_bytes()));
    body
}
441
/// Serializes a `ping_delay_disconnect` body for `ping_id`.
///
/// Byte layout (all fields little-endian):
/// 1. constructor `0xf3427b8c`
/// 2. the 8-byte `ping_id`
/// 3. a 4-byte `disconnect_delay` of 75 (seconds, per the MTProto service-message docs)
pub(crate) fn build_msgs_ack_ping_body(ping_id: i64) -> Vec<u8> {
    const PING_DELAY_DISCONNECT_ID: u32 = 0xf3427b8c;
    const DISCONNECT_DELAY: i32 = 75;

    [
        &PING_DELAY_DISCONNECT_ID.to_le_bytes()[..],
        &ping_id.to_le_bytes()[..],
        &DISCONNECT_DELAY.to_le_bytes()[..],
    ]
    .concat()
}