1use crate::errors::InvocationError;
14use crate::sender::DcConnection;
15use crate::sender_task::{FrameEvent, RpcEnqueue, spawn_sender_task};
16use ferogram_connect::util::maybe_gz_pack;
17use ferogram_connect::{Socks5Config, TransportKind};
18use ferogram_session::DcEntry;
19use ferogram_tl_types::{RemoteCall, Serializable};
20use std::collections::HashMap;
21use std::sync::Arc;
22use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
23use tokio::sync::{mpsc, oneshot};
24
/// Upper bound on how many connections `DcPool::get_or_create_slot` will open
/// for a single DC; further requests share the least-busy existing slot.
const MAX_CONNS_PER_DC: usize = 3;
27
28pub struct ConnSlot {
42 rpc_tx: mpsc::Sender<RpcEnqueue>,
43 pub in_flight: AtomicUsize,
44 alive: Arc<AtomicBool>,
53 auth_key: [u8; 256],
61 first_salt: i64,
62 time_offset: i32,
63}
64
65pub struct DcPool {
69 pub conns: HashMap<i32, Vec<Arc<ConnSlot>>>,
71 addrs: HashMap<i32, String>,
72 #[allow(dead_code)]
73 home_dc_id: i32,
74 socks5: Option<Socks5Config>,
76 transport: TransportKind,
78 init_done: std::collections::HashSet<i32>,
80}
81
82impl DcPool {
83 pub fn new(
84 home_dc_id: i32,
85 dc_entries: &[DcEntry],
86 socks5: Option<Socks5Config>,
87 transport: TransportKind,
88 ) -> Self {
89 let addrs = dc_entries
90 .iter()
91 .map(|e| (e.dc_id, e.addr.clone()))
92 .collect();
93 Self {
94 conns: HashMap::new(),
95 addrs,
96 home_dc_id,
97 socks5,
98 transport,
99 init_done: std::collections::HashSet::new(),
100 }
101 }
102
103 pub fn has_connection(&self, dc_id: i32) -> bool {
105 self.conns.get(&dc_id).is_some_and(|v| !v.is_empty())
106 }
107
108 fn spawn_slot(conn: DcConnection) -> Arc<ConnSlot> {
115 let auth_key = conn.auth_key_bytes();
116 let first_salt = conn.first_salt();
117 let time_offset = conn.time_offset();
118 let (stream, frame_kind, enc) = conn.into_parts();
119
120 let (handle, mut frame_rx) = spawn_sender_task(stream, enc, frame_kind, None);
121
122 drop(handle.reconnect_tx);
129
130 let alive = Arc::new(AtomicBool::new(true));
131 let alive_for_drain = alive.clone();
132 tokio::spawn(async move {
133 while let Some(event) = frame_rx.recv().await {
134 if let FrameEvent::Error(e) = event {
135 tracing::warn!("[ferogram::pool] worker connection dropped: {e}");
136 alive_for_drain.store(false, Ordering::Release);
137 break;
138 }
139 }
142 });
143
144 Arc::new(ConnSlot {
145 rpc_tx: handle.rpc_tx,
146 in_flight: AtomicUsize::new(0),
147 alive,
148 auth_key,
149 first_salt,
150 time_offset,
151 })
152 }
153
154 pub fn insert(&mut self, dc_id: i32, conn: DcConnection) {
157 let slot = Self::spawn_slot(conn);
158 self.conns.entry(dc_id).or_default().push(slot);
159 let total: usize = self.conns.values().map(|v| v.len()).sum();
160 metrics::gauge!("ferogram.connections_active").set(total as f64);
161 }
162
163 pub(crate) async fn get_or_create_slot(
167 &mut self,
168 dc_id: i32,
169 pfs: bool,
170 auth_key: Option<([u8; 256], i64, i32)>,
171 ) -> Result<Arc<ConnSlot>, InvocationError> {
172 let addr = self.addrs.get(&dc_id).cloned().ok_or_else(|| {
173 InvocationError::Deserialize(format!("dc_pool: no address for DC{dc_id}"))
174 })?;
175
176 if !self.conns.contains_key(&dc_id) || self.conns[&dc_id].is_empty() {
178 tracing::debug!("[ferogram::pool] opening first connection to DC{dc_id} at {addr}");
179 let conn = if let Some((key, salt, offset)) = auth_key {
180 DcConnection::connect_with_key(
181 &addr,
182 key,
183 salt,
184 offset,
185 self.socks5.as_ref(),
186 None,
187 &self.transport,
188 dc_id as i16,
189 pfs,
190 )
191 .await?
192 } else {
193 DcConnection::connect_raw(
194 &addr,
195 self.socks5.as_ref(),
196 &self.transport,
197 dc_id as i16,
198 )
199 .await?
200 };
201 let slot = Self::spawn_slot(conn);
202 self.conns.entry(dc_id).or_default().push(slot);
203 self.init_done.remove(&dc_id);
204 let total: usize = self.conns.values().map(|v| v.len()).sum();
205 metrics::gauge!("ferogram.connections_active").set(total as f64);
206 }
207
208 let slots = self
209 .conns
210 .get(&dc_id)
211 .expect("dc_id must be registered before use");
212
213 let best = slots
215 .iter()
216 .min_by_key(|s| s.in_flight.load(Ordering::Relaxed))
217 .expect("slots vec is non-empty")
218 .clone();
219 let min_inflight = best.in_flight.load(Ordering::Relaxed);
220
221 if min_inflight > 0 && slots.len() < MAX_CONNS_PER_DC {
228 tracing::debug!(
229 "[ferogram::pool] DC{dc_id}: all {} slots busy (min_inflight={min_inflight}), opening extra connection",
230 slots.len()
231 );
232 let conn = if let Some((key, salt, offset)) = auth_key {
233 DcConnection::connect_with_key(
234 &addr,
235 key,
236 salt,
237 offset,
238 self.socks5.as_ref(),
239 None,
240 &self.transport,
241 dc_id as i16,
242 pfs,
243 )
244 .await?
245 } else {
246 DcConnection::connect_raw(
247 &addr,
248 self.socks5.as_ref(),
249 &self.transport,
250 dc_id as i16,
251 )
252 .await?
253 };
254 let new_slot = Self::spawn_slot(conn);
255 let arc = new_slot.clone();
256 self.conns
257 .get_mut(&dc_id)
258 .expect("dc_id must be registered")
259 .push(new_slot);
260 let total: usize = self.conns.values().map(|v| v.len()).sum();
261 metrics::gauge!("ferogram.connections_active").set(total as f64);
262 return Ok(arc);
263 }
264
265 Ok(best)
266 }
267
268 pub fn evict(&mut self, dc_id: i32) {
271 self.conns.remove(&dc_id);
272 self.init_done.remove(&dc_id);
273 let total: usize = self.conns.values().map(|v| v.len()).sum();
274 metrics::gauge!("ferogram.connections_active").set(total as f64);
275 tracing::debug!("[ferogram::pool] evicted all connections for DC{dc_id}");
276 }
277
278 async fn send_via_slot(
285 slot: &Arc<ConnSlot>,
286 body: Vec<u8>,
287 ) -> Result<Vec<u8>, InvocationError> {
288 slot.in_flight.fetch_add(1, Ordering::Relaxed);
289 let (tx, rx) = oneshot::channel();
290 let send_result = slot.rpc_tx.send(RpcEnqueue { body, tx }).await;
291 let result = if send_result.is_err() {
292 slot.alive.store(false, Ordering::Release);
293 Err(InvocationError::Deserialize(
294 "worker sender task shut down".into(),
295 ))
296 } else {
297 match rx.await {
298 Ok(r) => r,
299 Err(_) => {
300 slot.alive.store(false, Ordering::Release);
301 Err(InvocationError::Deserialize(
302 "worker rpc channel closed".into(),
303 ))
304 }
305 }
306 };
307 slot.in_flight.fetch_sub(1, Ordering::Relaxed);
308 result
309 }
310
311 pub async fn invoke_on_dc<R: RemoteCall>(
314 &mut self,
315 dc_id: i32,
316 _dc_entries: &[DcEntry],
317 req: &R,
318 ) -> Result<Vec<u8>, InvocationError> {
319 let slot = self.get_or_create_slot(dc_id, false, None).await?;
320 let body = maybe_gz_pack(&req.to_bytes());
321 let result = Self::send_via_slot(&slot, body.clone()).await;
322
323 if let Err(ref e) = result {
324 let kind = match e {
325 InvocationError::Rpc(_) => "rpc",
326 InvocationError::Io(_) => "io",
327 _ => "other",
328 };
329 metrics::counter!("ferogram.rpc_errors_total", "kind" => kind).increment(1);
330 }
331
332 if result.is_err() && !slot.alive.load(Ordering::Acquire) {
333 tracing::warn!(
334 "[ferogram::pool] DC{dc_id} connection died mid-request; evicting and retrying on a fresh connection"
335 );
336 self.evict(dc_id);
337 let retry_slot = self.get_or_create_slot(dc_id, false, None).await?;
338 return Self::send_via_slot(&retry_slot, body).await;
339 }
340 result
341 }
342
343 pub fn mark_init_done(&mut self, dc_id: i32) {
345 self.init_done.insert(dc_id);
346 }
347
348 pub fn is_init_done(&self, dc_id: i32) -> bool {
350 self.init_done.contains(&dc_id)
351 }
352
353 pub async fn invoke_on_dc_serializable<S: Serializable>(
355 &mut self,
356 dc_id: i32,
357 req: &S,
358 ) -> Result<Vec<u8>, InvocationError> {
359 let slot = self
360 .get_or_create_slot(dc_id, false, None)
361 .await
362 .map_err(|_| InvocationError::Deserialize(format!("no connection for DC{dc_id}")))?;
363 let body = maybe_gz_pack(&req.to_bytes());
364 let result = Self::send_via_slot(&slot, body.clone()).await;
365
366 if result.is_err() && !slot.alive.load(Ordering::Acquire) {
367 tracing::warn!(
368 "[ferogram::pool] DC{dc_id} connection died mid-request (serializable path); evicting and retrying"
369 );
370 self.evict(dc_id);
371 let retry_slot = self.get_or_create_slot(dc_id, false, None).await?;
372 return Self::send_via_slot(&retry_slot, body).await;
373 }
374 result
375 }
376
377 pub fn update_addrs(&mut self, entries: &[DcEntry]) {
379 for e in entries {
380 self.addrs.insert(e.dc_id, e.addr.clone());
381 }
382 }
383
384 pub fn collect_keys(&self, entries: &mut [DcEntry]) {
387 for e in entries.iter_mut() {
388 if let Some(slots) = self.conns.get(&e.dc_id)
389 && let Some(slot) = slots.first()
390 {
391 e.auth_key = Some(slot.auth_key);
392 e.first_salt = slot.first_salt;
393 e.time_offset = slot.time_offset;
394 }
395 }
396 }
397}
398
/// Serializes a `msgs_ack` body acknowledging `msg_ids`.
///
/// Layout, all little-endian: constructor `0x62d6b459`, vector id `0x1cb5c415`,
/// element count as `u32`, then each id as `i64`.
pub(crate) fn build_msgs_ack_body(msg_ids: &[i64]) -> Vec<u8> {
    const MSGS_ACK_ID: u32 = 0x62d6b459;
    const VECTOR_ID: u32 = 0x1cb5c415;

    let header = [MSGS_ACK_ID, VECTOR_ID, msg_ids.len() as u32];
    let mut body = Vec::with_capacity(header.len() * 4 + msg_ids.len() * 8);
    body.extend(header.iter().flat_map(|word| word.to_le_bytes()));
    body.extend(msg_ids.iter().flat_map(|id| id.to_le_bytes()));
    body
}
414
/// Serializes a 16-byte ping body: constructor id `0xf3427b8c`, `ping_id`
/// as a little-endian `i64`, and the fixed little-endian `i32` value 75.
///
/// NOTE(review): `0xf3427b8c` appears to be MTProto `ping_delay_disconnect`,
/// which would make 75 the disconnect delay in seconds. Confirm against the schema.
pub(crate) fn build_msgs_ack_ping_body(ping_id: i64) -> Vec<u8> {
    const CONSTRUCTOR_ID: u32 = 0xf3427b8c;
    const DISCONNECT_DELAY: i32 = 75;

    [
        CONSTRUCTOR_ID.to_le_bytes().as_slice(),
        ping_id.to_le_bytes().as_slice(),
        DISCONNECT_DELAY.to_le_bytes().as_slice(),
    ]
    .concat()
}