use crate::errors::InvocationError;
use crate::sender::DcConnection;
use crate::sender_task::{FrameEvent, RpcEnqueue, spawn_sender_task};
use ferogram_connect::util::maybe_gz_pack;
use ferogram_connect::{Socks5Config, TransportKind};
use ferogram_session::DcEntry;
use ferogram_tl_types::{RemoteCall, Serializable};
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use tokio::sync::{mpsc, oneshot};
/// Upper bound on the number of worker connections the pool keeps open per DC;
/// extra connections are only opened while a DC has fewer slots than this.
const MAX_CONNS_PER_DC: usize = 3;

/// One pooled worker connection to a DC.
///
/// Bundles the channel that feeds the connection's sender task, a load counter,
/// a liveness flag, and the key material captured from the `DcConnection` the
/// slot was built from.
pub struct ConnSlot {
    /// Requests currently enqueued on this slot and awaiting a response; the
    /// pool hands out the slot with the smallest value.
    pub in_flight: AtomicUsize,
    /// Enqueues requests on this connection's sender task.
    rpc_tx: mpsc::Sender<RpcEnqueue>,
    /// Set to `false` once the connection has been observed to fail.
    alive: Arc<AtomicBool>,
    /// Auth key bytes read from the originating connection.
    auth_key: [u8; 256],
    /// Salt read from the originating connection.
    first_salt: i64,
    /// Time offset read from the originating connection.
    time_offset: i32,
}
/// Pool of worker connections, grouped by DC id.
///
/// Connections are opened lazily and load-balanced by in-flight request count.
pub struct DcPool {
    /// Open connection slots for each DC id.
    pub conns: HashMap<i32, Vec<Arc<ConnSlot>>>,
    /// Address used when opening a connection to each DC id.
    addrs: HashMap<i32, String>,
    /// DC ids marked via `mark_init_done`; an id is cleared on `evict` and when
    /// the first connection for that DC is opened.
    init_done: std::collections::HashSet<i32>,
    /// Optional SOCKS5 proxy applied to every new connection.
    socks5: Option<Socks5Config>,
    /// Transport applied to every new connection.
    transport: TransportKind,
    /// Home DC id supplied at construction.
    // Not read anywhere in this module, hence the lint allowance.
    #[allow(dead_code)]
    home_dc_id: i32,
}
impl DcPool {
/// Creates a pool with no open connections.
///
/// Only the `dc_id -> addr` mapping is taken from `dc_entries` (a later entry
/// for the same DC id overrides an earlier one); connections are opened on demand.
pub fn new(
    home_dc_id: i32,
    dc_entries: &[DcEntry],
    socks5: Option<Socks5Config>,
    transport: TransportKind,
) -> Self {
    let mut addrs = HashMap::with_capacity(dc_entries.len());
    for entry in dc_entries {
        addrs.insert(entry.dc_id, entry.addr.clone());
    }
    Self {
        home_dc_id,
        transport,
        socks5,
        addrs,
        conns: HashMap::new(),
        init_done: Default::default(),
    }
}
/// Returns `true` when at least one slot is pooled for `dc_id`.
///
/// This counts slots regardless of their liveness flag.
pub fn has_connection(&self, dc_id: i32) -> bool {
    match self.conns.get(&dc_id) {
        Some(slots) => !slots.is_empty(),
        None => false,
    }
}
fn spawn_slot(conn: DcConnection) -> Arc<ConnSlot> {
let auth_key = conn.auth_key_bytes();
let first_salt = conn.first_salt();
let time_offset = conn.time_offset();
let (stream, frame_kind, enc) = conn.into_parts();
let (handle, mut frame_rx) = spawn_sender_task(stream, enc, frame_kind, None);
drop(handle.reconnect_tx);
let alive = Arc::new(AtomicBool::new(true));
let alive_for_drain = alive.clone();
tokio::spawn(async move {
while let Some(event) = frame_rx.recv().await {
if let FrameEvent::Error(e) = event {
tracing::warn!("[ferogram::pool] worker connection dropped: {e}");
alive_for_drain.store(false, Ordering::Release);
break;
}
}
});
Arc::new(ConnSlot {
rpc_tx: handle.rpc_tx,
in_flight: AtomicUsize::new(0),
alive,
auth_key,
first_salt,
time_offset,
})
}
/// Adds an already-established connection to the pool for `dc_id` and
/// republishes the `ferogram.connections_active` gauge.
///
/// Unlike opening a first connection through the pool, this leaves the DC's
/// `init_done` marker untouched.
pub fn insert(&mut self, dc_id: i32, conn: DcConnection) {
    let slot = Self::spawn_slot(conn);
    let dc_slots = self.conns.entry(dc_id).or_insert_with(Vec::new);
    dc_slots.push(slot);
    let active = self
        .conns
        .values()
        .fold(0usize, |acc, slots| acc + slots.len());
    metrics::gauge!("ferogram.connections_active").set(active as f64);
}
pub(crate) async fn get_or_create_slot(
&mut self,
dc_id: i32,
pfs: bool,
auth_key: Option<([u8; 256], i64, i32)>,
) -> Result<Arc<ConnSlot>, InvocationError> {
let addr = self.addrs.get(&dc_id).cloned().ok_or_else(|| {
InvocationError::Deserialize(format!("dc_pool: no address for DC{dc_id}"))
})?;
if !self.conns.contains_key(&dc_id) || self.conns[&dc_id].is_empty() {
tracing::debug!("[ferogram::pool] opening first connection to DC{dc_id} at {addr}");
let conn = if let Some((key, salt, offset)) = auth_key {
DcConnection::connect_with_key(
&addr,
key,
salt,
offset,
self.socks5.as_ref(),
None,
&self.transport,
dc_id as i16,
pfs,
)
.await?
} else {
DcConnection::connect_raw(
&addr,
self.socks5.as_ref(),
&self.transport,
dc_id as i16,
)
.await?
};
let slot = Self::spawn_slot(conn);
self.conns.entry(dc_id).or_default().push(slot);
self.init_done.remove(&dc_id);
let total: usize = self.conns.values().map(|v| v.len()).sum();
metrics::gauge!("ferogram.connections_active").set(total as f64);
}
let slots = self
.conns
.get(&dc_id)
.expect("dc_id must be registered before use");
let best = slots
.iter()
.min_by_key(|s| s.in_flight.load(Ordering::Relaxed))
.expect("slots vec is non-empty")
.clone();
let min_inflight = best.in_flight.load(Ordering::Relaxed);
if min_inflight > 0 && slots.len() < MAX_CONNS_PER_DC {
tracing::debug!(
"[ferogram::pool] DC{dc_id}: all {} slots busy (min_inflight={min_inflight}), opening extra connection",
slots.len()
);
let conn = if let Some((key, salt, offset)) = auth_key {
DcConnection::connect_with_key(
&addr,
key,
salt,
offset,
self.socks5.as_ref(),
None,
&self.transport,
dc_id as i16,
pfs,
)
.await?
} else {
DcConnection::connect_raw(
&addr,
self.socks5.as_ref(),
&self.transport,
dc_id as i16,
)
.await?
};
let new_slot = Self::spawn_slot(conn);
let arc = new_slot.clone();
self.conns
.get_mut(&dc_id)
.expect("dc_id must be registered")
.push(new_slot);
let total: usize = self.conns.values().map(|v| v.len()).sum();
metrics::gauge!("ferogram.connections_active").set(total as f64);
return Ok(arc);
}
Ok(best)
}
/// Drops every pooled slot for `dc_id` and clears its `init_done` marker, then
/// republishes the `ferogram.connections_active` gauge.
pub fn evict(&mut self, dc_id: i32) {
    let _ = self.conns.remove(&dc_id);
    let _ = self.init_done.remove(&dc_id);
    let active = self
        .conns
        .values()
        .fold(0usize, |acc, slots| acc + slots.len());
    metrics::gauge!("ferogram.connections_active").set(active as f64);
    tracing::debug!("[ferogram::pool] evicted all connections for DC{dc_id}");
}
async fn send_via_slot(
slot: &Arc<ConnSlot>,
body: Vec<u8>,
) -> Result<Vec<u8>, InvocationError> {
slot.in_flight.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = oneshot::channel();
let send_result = slot.rpc_tx.send(RpcEnqueue { body, tx }).await;
let result = if send_result.is_err() {
slot.alive.store(false, Ordering::Release);
Err(InvocationError::Deserialize(
"worker sender task shut down".into(),
))
} else {
match rx.await {
Ok(r) => r,
Err(_) => {
slot.alive.store(false, Ordering::Release);
Err(InvocationError::Deserialize(
"worker rpc channel closed".into(),
))
}
}
};
slot.in_flight.fetch_sub(1, Ordering::Relaxed);
result
}
/// Invokes `req` on a pooled connection to `dc_id` and returns the raw response bytes.
///
/// If the attempt fails and the chosen slot has been marked dead, all
/// connections to the DC are evicted and the request is retried once on a fresh
/// connection. Each failed send attempt increments `ferogram.rpc_errors_total`,
/// labelled with the error kind.
///
/// `_dc_entries` is accepted for API compatibility and is not used.
///
/// # Errors
/// Propagates errors from obtaining a slot and from the send/receive itself.
pub async fn invoke_on_dc<R: RemoteCall>(
    &mut self,
    dc_id: i32,
    _dc_entries: &[DcEntry],
    req: &R,
) -> Result<Vec<u8>, InvocationError> {
    /// Counts a failed attempt in `ferogram.rpc_errors_total`, labelled by kind.
    fn record_failure(result: &Result<Vec<u8>, InvocationError>) {
        if let Err(e) = result {
            let kind = match e {
                InvocationError::Rpc(_) => "rpc",
                InvocationError::Io(_) => "io",
                _ => "other",
            };
            metrics::counter!("ferogram.rpc_errors_total", "kind" => kind).increment(1);
        }
    }

    let slot = self.get_or_create_slot(dc_id, false, None).await?;
    // The body is moved into the send. It is rebuilt from `req` only on the rare
    // retry path instead of being cloned up front for every request.
    let result = Self::send_via_slot(&slot, maybe_gz_pack(&req.to_bytes())).await;
    record_failure(&result);
    if result.is_ok() || slot.alive.load(Ordering::Acquire) {
        return result;
    }

    tracing::warn!(
        "[ferogram::pool] DC{dc_id} connection died mid-request; evicting and retrying on a fresh connection"
    );
    self.evict(dc_id);
    let retry_slot = self.get_or_create_slot(dc_id, false, None).await?;
    let retry = Self::send_via_slot(&retry_slot, maybe_gz_pack(&req.to_bytes())).await;
    record_failure(&retry);
    retry
}
/// Records `dc_id` as initialised.
///
/// The marker is cleared again by `evict` or when the DC's first connection is
/// (re)opened by the pool.
pub fn mark_init_done(&mut self, dc_id: i32) {
    if !self.init_done.contains(&dc_id) {
        self.init_done.insert(dc_id);
    }
}
/// Reports whether `mark_init_done` has been recorded for `dc_id` and not
/// cleared since.
pub fn is_init_done(&self, dc_id: i32) -> bool {
    self.init_done.get(&dc_id).is_some()
}
/// Sends any `Serializable` request on a pooled connection to `dc_id` and
/// returns the raw response bytes.
///
/// If the attempt fails and the chosen slot has been marked dead, all
/// connections to the DC are evicted and the request is retried once on a
/// fresh connection. No error metrics are recorded on this path.
///
/// # Errors
/// - If the initial slot cannot be obtained, returns
///   `InvocationError::Deserialize("no connection for DC{dc_id}")`; the
///   underlying error is discarded.
/// - Errors from obtaining the retry slot and from send/receive are propagated.
pub async fn invoke_on_dc_serializable<S: Serializable>(
    &mut self,
    dc_id: i32,
    req: &S,
) -> Result<Vec<u8>, InvocationError> {
    let slot = self
        .get_or_create_slot(dc_id, false, None)
        .await
        .map_err(|_| InvocationError::Deserialize(format!("no connection for DC{dc_id}")))?;
    // The body is moved into the send. It is rebuilt from `req` only on the rare
    // retry path instead of being cloned up front for every request.
    let result = Self::send_via_slot(&slot, maybe_gz_pack(&req.to_bytes())).await;
    if result.is_ok() || slot.alive.load(Ordering::Acquire) {
        return result;
    }

    tracing::warn!(
        "[ferogram::pool] DC{dc_id} connection died mid-request (serializable path); evicting and retrying"
    );
    self.evict(dc_id);
    let retry_slot = self.get_or_create_slot(dc_id, false, None).await?;
    Self::send_via_slot(&retry_slot, maybe_gz_pack(&req.to_bytes())).await
}
/// Overwrites the stored address for every DC id present in `entries`.
///
/// Already-pooled connections are left alone; only connections opened
/// afterwards use the new addresses.
pub fn update_addrs(&mut self, entries: &[DcEntry]) {
    let updates = entries
        .iter()
        .map(|entry| (entry.dc_id, entry.addr.clone()));
    self.addrs.extend(updates);
}
pub fn collect_keys(&self, entries: &mut [DcEntry]) {
for e in entries.iter_mut() {
if let Some(slots) = self.conns.get(&e.dc_id)
&& let Some(slot) = slots.first()
{
e.auth_key = Some(slot.auth_key);
e.first_salt = slot.first_salt;
e.time_offset = slot.time_offset;
}
}
}
}
/// Serializes an acknowledgement body for `msg_ids`.
///
/// Layout, all little-endian: the `0x62d6b459` constructor, then `0x1cb5c415`
/// (presumably the TL boxed-vector constructor; verify against the schema),
/// then a `u32` element count, then each id as 8 bytes.
pub(crate) fn build_msgs_ack_body(msg_ids: &[i64]) -> Vec<u8> {
    const MSGS_ACK: u32 = 0x62d6b459;
    const VECTOR: u32 = 0x1cb5c415;
    let count = msg_ids.len() as u32;
    let mut body = Vec::with_capacity(12 + 8 * msg_ids.len());
    for word in [MSGS_ACK, VECTOR, count] {
        body.extend_from_slice(&word.to_le_bytes());
    }
    body.extend(msg_ids.iter().flat_map(|id| id.to_le_bytes()));
    body
}
/// Serializes a 16-byte ping body carrying `ping_id`.
///
/// Layout, all little-endian: the `0xf3427b8c` constructor (presumably MTProto
/// `ping_delay_disconnect`; verify against the schema), the `ping_id`, then
/// the `i32` value 75.
pub(crate) fn build_msgs_ack_ping_body(ping_id: i64) -> Vec<u8> {
    const CONSTRUCTOR: u32 = 0xf3427b8c;
    const DISCONNECT_DELAY: i32 = 75;
    [
        &CONSTRUCTOR.to_le_bytes()[..],
        &ping_id.to_le_bytes()[..],
        &DISCONNECT_DELAY.to_le_bytes()[..],
    ]
    .concat()
}