use crate::errors::InvocationError;
use crate::sender::DcConnection;
use ferogram_connect::{Socks5Config, TransportKind};
use ferogram_session::DcEntry;
use ferogram_tl_types::RemoteCall;
use std::collections::HashMap;
/// Upper bound on the number of slots `DcPool::get_or_create_slot` will open
/// on its own for a single DC; once reached, the least-busy slot is shared.
const MAX_CONNS_PER_DC: usize = 3;
/// One pooled connection to a DC together with a counter of the calls
/// currently using it.
pub struct ConnSlot {
/// The connection itself, behind an async mutex so that users of the slot
/// take turns issuing calls on it.
pub conn: tokio::sync::Mutex<DcConnection>,
/// Number of calls currently in progress on this slot. `DcPool` bumps it
/// around each call and prefers the slot with the lowest value.
pub in_flight: std::sync::atomic::AtomicUsize,
}
/// Pool of connections grouped by DC id.
///
/// Each DC can hold several [`ConnSlot`]s; slots are opened lazily by
/// `get_or_create_slot` and dropped as a group by `evict`.
pub struct DcPool {
/// Open slots, keyed by DC id.
pub conns: HashMap<i32, Vec<std::sync::Arc<ConnSlot>>>,
/// Address to connect to for each known DC id.
addrs: HashMap<i32, String>,
/// DC id passed to `new`; it is stored but not read by the methods in this
/// file, hence the lint allowance.
#[allow(dead_code)]
home_dc_id: i32,
/// Optional SOCKS5 configuration forwarded to every new connection.
socks5: Option<Socks5Config>,
/// Transport forwarded to every new connection.
transport: TransportKind,
/// DC ids flagged through `mark_init_done`. An id is removed again when the
/// DC's slots are evicted or its first connection is (re)opened.
/// NOTE(review): presumably tracks a per-DC connection-init handshake — confirm.
init_done: std::collections::HashSet<i32>,
}
impl DcPool {
/// Creates an empty pool that knows the addresses listed in `dc_entries`.
///
/// No connection is opened here; slots are created on demand. When
/// `dc_entries` lists the same DC id more than once, the last address wins.
pub fn new(
    home_dc_id: i32,
    dc_entries: &[DcEntry],
    socks5: Option<Socks5Config>,
    transport: TransportKind,
) -> Self {
    let mut addrs = HashMap::with_capacity(dc_entries.len());
    for entry in dc_entries {
        addrs.insert(entry.dc_id, entry.addr.clone());
    }

    let conns = HashMap::new();
    let init_done = std::collections::HashSet::new();
    Self {
        conns,
        addrs,
        home_dc_id,
        socks5,
        transport,
        init_done,
    }
}
/// Returns `true` when at least one slot is currently registered for `dc_id`.
pub fn has_connection(&self, dc_id: i32) -> bool {
    match self.conns.get(&dc_id) {
        Some(slots) => !slots.is_empty(),
        None => false,
    }
}
/// Registers an already-established connection as a new idle slot for
/// `dc_id` and refreshes the `ferogram.connections_active` gauge.
pub fn insert(&mut self, dc_id: i32, conn: DcConnection) {
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;

    let dc_slots = self.conns.entry(dc_id).or_default();
    dc_slots.push(Arc::new(ConnSlot {
        conn: tokio::sync::Mutex::new(conn),
        in_flight: AtomicUsize::new(0),
    }));

    // The gauge reports the slot count across every DC, not just this one.
    let active = self
        .conns
        .values()
        .fold(0usize, |count, slots| count + slots.len());
    metrics::gauge!("ferogram.connections_active").set(active as f64);
}
/// Returns a connection slot for `dc_id`, opening a connection when needed.
///
/// * If the DC has no slot yet, one connection is opened, the DC's
///   `init_done` flag is cleared and the new slot is returned.
/// * Otherwise the slot with the fewest in-flight calls is chosen. If even
///   that slot is busy and the DC holds fewer than `MAX_CONNS_PER_DC` slots,
///   an additional connection is opened and returned instead.
///
/// `auth_key`, when present, is the `(key, salt, offset)` triple handed to
/// `DcConnection::connect_with_key` together with `pfs`; without it the
/// connection is opened with `DcConnection::connect_raw`.
///
/// # Errors
///
/// Returns `InvocationError::Deserialize` when no address is known for
/// `dc_id`, or whatever error the connection attempt produces.
pub(crate) async fn get_or_create_slot(
    &mut self,
    dc_id: i32,
    pfs: bool,
    auth_key: Option<([u8; 256], i64, i32)>,
) -> Result<std::sync::Arc<ConnSlot>, InvocationError> {
    use std::sync::atomic::Ordering;

    let addr = self.addrs.get(&dc_id).cloned().ok_or_else(|| {
        InvocationError::Deserialize(format!("dc_pool: no address for DC{dc_id}"))
    })?;

    // A single lookup yields both the least-busy slot and the slot count;
    // `None` covers a missing entry as well as an empty slot list.
    let least_busy = self.conns.get(&dc_id).and_then(|slots| {
        slots
            .iter()
            .min_by_key(|slot| slot.in_flight.load(Ordering::Relaxed))
            .map(|slot| (std::sync::Arc::clone(slot), slots.len()))
    });

    let Some((best, slot_count)) = least_busy else {
        tracing::info!("[dc_pool] auto-connecting DC{dc_id} ({addr})");
        let slot = self.connect_new_slot(dc_id, addr, pfs, auth_key).await?;
        // First connection for this DC: any earlier init flag is stale.
        self.init_done.remove(&dc_id);
        return Ok(slot);
    };

    let min_inflight = best.in_flight.load(Ordering::Relaxed);
    if min_inflight == 0 || slot_count >= MAX_CONNS_PER_DC {
        return Ok(best);
    }

    tracing::debug!(
        "[dc_pool] DC{dc_id}: all {} slots busy (min_inflight={}), opening new slot",
        slot_count,
        min_inflight
    );
    self.connect_new_slot(dc_id, addr, pfs, auth_key).await
}

/// Opens one connection to `addr`, registers it as a fresh idle slot for
/// `dc_id`, refreshes the `ferogram.connections_active` gauge and returns
/// the new slot. Nothing is registered when the connection attempt fails.
async fn connect_new_slot(
    &mut self,
    dc_id: i32,
    addr: String,
    pfs: bool,
    auth_key: Option<([u8; 256], i64, i32)>,
) -> Result<std::sync::Arc<ConnSlot>, InvocationError> {
    let conn = match auth_key {
        Some((key, salt, offset)) => {
            DcConnection::connect_with_key(
                &addr,
                key,
                salt,
                offset,
                self.socks5.as_ref(),
                None,
                &self.transport,
                dc_id as i16,
                pfs,
            )
            .await?
        }
        None => {
            DcConnection::connect_raw(
                &addr,
                self.socks5.as_ref(),
                &self.transport,
                dc_id as i16,
            )
            .await?
        }
    };

    let slot = std::sync::Arc::new(ConnSlot {
        conn: tokio::sync::Mutex::new(conn),
        in_flight: std::sync::atomic::AtomicUsize::new(0),
    });
    self.conns
        .entry(dc_id)
        .or_default()
        .push(std::sync::Arc::clone(&slot));

    let total: usize = self.conns.values().map(|v| v.len()).sum();
    metrics::gauge!("ferogram.connections_active").set(total as f64);
    Ok(slot)
}
/// Drops every slot registered for `dc_id`, clears the DC's `init_done` flag
/// and refreshes the `ferogram.connections_active` gauge.
///
/// Calling this for a DC without slots is harmless.
pub fn evict(&mut self, dc_id: i32) {
    drop(self.conns.remove(&dc_id));
    self.init_done.remove(&dc_id);

    let remaining = self
        .conns
        .values()
        .fold(0usize, |count, slots| count + slots.len());
    metrics::gauge!("ferogram.connections_active").set(remaining as f64);

    tracing::debug!("[dc_pool] evicted all slots for DC{dc_id}");
}
pub async fn invoke_on_dc<R: RemoteCall>(
&mut self,
dc_id: i32,
_dc_entries: &[DcEntry],
req: &R,
) -> Result<Vec<u8>, InvocationError> {
use std::sync::atomic::Ordering;
let slot = self.get_or_create_slot(dc_id, false, None).await?;
slot.in_flight.fetch_add(1, Ordering::Relaxed);
let result = slot.conn.lock().await.rpc_call(req).await;
slot.in_flight.fetch_sub(1, Ordering::Relaxed);
if let Err(ref e) = result {
let kind = match e {
InvocationError::Rpc(_) => "rpc",
InvocationError::Io(_) => "io",
_ => "other",
};
metrics::counter!("ferogram.rpc_errors_total", "kind" => kind).increment(1);
}
if matches!(result, Err(InvocationError::Io(_))) {
tracing::warn!("[dc_pool] IO error on DC{dc_id}, evicting all slots and retrying");
self.evict(dc_id);
let retry_slot = self.get_or_create_slot(dc_id, false, None).await?;
retry_slot.in_flight.fetch_add(1, Ordering::Relaxed);
let r = retry_slot.conn.lock().await.rpc_call(req).await;
retry_slot.in_flight.fetch_sub(1, Ordering::Relaxed);
return r;
}
result
}
/// Records `dc_id` in the pool's `init_done` set. The flag stays set until
/// the DC's slots are evicted or its first connection is opened again.
pub fn mark_init_done(&mut self, dc_id: i32) {
self.init_done.insert(dc_id);
}
/// Returns `true` when `dc_id` is currently present in the `init_done` set,
/// i.e. it was flagged with `mark_init_done` and not cleared since.
pub fn is_init_done(&self, dc_id: i32) -> bool {
self.init_done.contains(&dc_id)
}
pub async fn invoke_on_dc_serializable<S: ferogram_tl_types::Serializable>(
&mut self,
dc_id: i32,
req: &S,
) -> Result<Vec<u8>, InvocationError> {
use std::sync::atomic::Ordering;
let slot = self
.get_or_create_slot(dc_id, false, None)
.await
.map_err(|_| InvocationError::Deserialize(format!("no connection for DC{dc_id}")))?;
slot.in_flight.fetch_add(1, Ordering::Relaxed);
let result = slot.conn.lock().await.rpc_call_serializable(req).await;
slot.in_flight.fetch_sub(1, Ordering::Relaxed);
if matches!(result, Err(InvocationError::Io(_))) {
tracing::warn!("[dc_pool] serializable IO error on DC{dc_id}, evicting and retrying");
self.evict(dc_id);
let retry_slot = self.get_or_create_slot(dc_id, false, None).await?;
retry_slot.in_flight.fetch_add(1, Ordering::Relaxed);
let r = retry_slot
.conn
.lock()
.await
.rpc_call_serializable(req)
.await;
retry_slot.in_flight.fetch_sub(1, Ordering::Relaxed);
return r;
}
result
}
/// Adds or overwrites the stored address of every DC listed in `entries`.
///
/// Addresses of DCs not mentioned in `entries` are left untouched, and
/// existing connections are not affected.
pub fn update_addrs(&mut self, entries: &[DcEntry]) {
    self.addrs.extend(
        entries
            .iter()
            .map(|entry| (entry.dc_id, entry.addr.clone())),
    );
}
pub fn collect_keys(&self, entries: &mut [DcEntry]) {
for e in entries.iter_mut() {
if let Some(slots) = self.conns.get(&e.dc_id)
&& let Some(slot) = slots.first()
&& let Ok(conn) = slot.conn.try_lock()
{
e.auth_key = Some(conn.auth_key_bytes());
e.first_salt = conn.first_salt();
e.time_offset = conn.time_offset();
}
}
}
}
/// Serializes an acknowledgement body for `msg_ids`.
///
/// Layout, all little-endian: constructor id `0x62d6b459`, vector
/// constructor id `0x1cb5c415`, the element count as `u32`, then every
/// message id as an `i64`.
pub(crate) fn build_msgs_ack_body(msg_ids: &[i64]) -> Vec<u8> {
    const MSGS_ACK_ID: u32 = 0x62d6b459;
    const VECTOR_ID: u32 = 0x1cb5c415;
    // Two constructor ids plus the element count, four bytes each.
    const HEADER_LEN: usize = 4 + 4 + 4;

    let mut body = Vec::with_capacity(HEADER_LEN + msg_ids.len() * 8);
    for word in &[MSGS_ACK_ID, VECTOR_ID, msg_ids.len() as u32] {
        body.extend_from_slice(&word.to_le_bytes());
    }
    msg_ids
        .iter()
        .for_each(|id| body.extend_from_slice(&id.to_le_bytes()));
    body
}
/// Serializes a 16-byte ping body for `ping_id`.
///
/// Layout, all little-endian: constructor id `0xf3427b8c`, `ping_id` as an
/// `i64`, then the constant `75` as an `i32`.
/// NOTE(review): despite the function name, this constructor id matches
/// MTProto's `ping_delay_disconnect`, which would make `75` the disconnect
/// delay in seconds — confirm against the schema.
pub(crate) fn build_msgs_ack_ping_body(ping_id: i64) -> Vec<u8> {
    const CONSTRUCTOR_ID: u32 = 0xf3427b8c;
    const TRAILING_VALUE: i32 = 75;

    [
        &CONSTRUCTOR_ID.to_le_bytes()[..],
        &ping_id.to_le_bytes()[..],
        &TRAILING_VALUE.to_le_bytes()[..],
    ]
    .concat()
}