use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use tracing::{debug, warn};
use futures_util::{FutureExt, StreamExt};
use crate::config::Config;
use crate::ws_client::{connect_ws_for_dc, TgWsStream};
/// Maximum age of a pooled connection, measured from the moment it was pushed
/// into the pool. `WsPool::get` drops entries older than this instead of
/// handing them out.
// NOTE(review): 55 s presumably sits just below a ~60 s server-side idle
// timeout — confirm against the upstream behaviour.
const MAX_AGE: Duration = Duration::from_secs(55);
/// One idle WebSocket connection held by the pool, tagged with its insertion time.
struct PoolEntry {
/// The pre-established stream that is handed to the caller on a pool hit.
ws: TgWsStream,
/// When the entry was pushed into its bucket; compared against `MAX_AGE`
/// when the entry is popped.
created: Instant,
}
/// Idle connections for a single pool key. Entries are pushed to the back and
/// popped from the back, so the most recently added connection is used first.
type Bucket = Vec<PoolEntry>;
/// All buckets, keyed by `(dc, is_media)`.
type PoolMap = HashMap<(u32, bool), Bucket>;
/// Pool of pre-connected WebSocket streams, bucketed per `(dc, is_media)` pair.
pub struct WsPool {
/// Target number of idle connections kept in each bucket.
pool_size: usize,
/// Idle connections. Guarded by the async (tokio) mutex.
idle: Mutex<PoolMap>,
/// Keys that currently have a refill in progress; a second refill for the
/// same key returns immediately. Guarded by a std mutex that is only held
/// for the duration of a single insert/remove.
refilling: StdMutex<HashSet<(u32, bool)>>,
}
/// Drop guard for a key registered in the "refill in progress" set.
///
/// The creator inserts `key` into `set` before constructing the guard; dropping
/// the guard removes it again on every exit path (normal return, early return,
/// or unwinding).
struct RefillGuard<'a> {
    /// Shared set of keys that currently have a refill running.
    set: &'a StdMutex<HashSet<(u32, bool)>>,
    /// The key this guard unregisters when dropped.
    key: (u32, bool),
}

impl Drop for RefillGuard<'_> {
    fn drop(&mut self) {
        // Recover the guard from a poisoned mutex instead of unwrapping:
        // panicking inside `drop` while already unwinding aborts the process,
        // and skipping the removal would leave the key registered forever,
        // blocking every later refill for it. The set holds plain keys only,
        // so using it after a panic elsewhere is safe.
        self.set
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .remove(&self.key);
    }
}
impl WsPool {
pub fn new(pool_size: usize) -> Self {
Self {
pool_size,
idle: Mutex::new(HashMap::new()),
refilling: StdMutex::new(HashSet::new()),
}
}
pub async fn get(
self: &Arc<Self>,
dc: u32,
is_media: bool,
target_ip: String,
skip_tls_verify: bool,
) -> Option<TgWsStream> {
let now = Instant::now();
let mut lock = self.idle.lock().await;
let bucket = lock.entry((dc, is_media)).or_default();
while let Some(mut entry) = bucket.pop() {
if now.saturating_duration_since(entry.created) > MAX_AGE {
continue;
}
if entry.ws.next().now_or_never().is_some() {
debug!(
"pool: discarding stale DC{}{} connection",
dc,
if is_media { "m" } else { "" }
);
continue;
}
let remaining = bucket.len();
drop(lock);
debug!(
"pool hit DC{}{} ({} left)",
dc,
if is_media { "m" } else { "" },
remaining
);
let pool = Arc::clone(self);
tokio::spawn(async move {
pool.refill(dc, is_media, target_ip, skip_tls_verify).await;
});
return Some(entry.ws);
}
drop(lock);
let pool = Arc::clone(self);
tokio::spawn(async move {
pool.refill(dc, is_media, target_ip, skip_tls_verify).await;
});
None
}
/// Pre-connects idle WebSockets for every `(dc, ip)` pair returned by
/// `config.dc_redirects()`, for both the regular and the media variant.
///
/// Buckets are filled one after another. Each bucket is only topped up to
/// `pool_size`: connections already present (for example from a refill that
/// ran concurrently) are counted, and anything that would exceed the limit is
/// dropped rather than stored.
pub async fn warmup(&self, config: &Config) {
    let dc_redirects = config.dc_redirects();
    let skip_tls = config.skip_tls_verify;
    let pool_size = self.pool_size;
    for (dc, ip) in dc_redirects {
        for is_media in [false, true] {
            let key = (dc, is_media);

            // Only connect what is actually missing; the lock is released
            // before the (slow) connection attempts start.
            let needed = {
                let idle = self.idle.lock().await;
                let current = idle.get(&key).map_or(0, |b| b.len());
                pool_size.saturating_sub(current)
            };
            if needed == 0 {
                continue;
            }

            let new_conns = Self::connect_batch(&ip, dc, is_media, skip_tls, needed).await;

            let mut idle = self.idle.lock().await;
            let bucket = idle.entry(key).or_default();
            // Re-check the free room: the bucket may have grown while we
            // were connecting without the lock.
            let room = pool_size.saturating_sub(bucket.len());
            bucket.extend(new_conns.into_iter().take(room).map(|ws| PoolEntry {
                ws,
                created: Instant::now(),
            }));
        }
    }
    debug!("WS pool warmup complete");
}
async fn refill(&self, dc: u32, is_media: bool, target_ip: String, skip_tls: bool) {
let registered = self.refilling.lock().unwrap().insert((dc, is_media));
if !registered {
return; }
let _guard = RefillGuard { set: &self.refilling, key: (dc, is_media) };
let needed = {
let lock = self.idle.lock().await;
let current = lock.get(&(dc, is_media)).map_or(0, |b| b.len());
if current >= self.pool_size {
return;
}
self.pool_size - current
};
let new_conns = Self::connect_batch(&target_ip, dc, is_media, skip_tls, needed).await;
if !new_conns.is_empty() {
let mut lock = self.idle.lock().await;
let bucket = lock.entry((dc, is_media)).or_default();
let can_add = self.pool_size.saturating_sub(bucket.len());
for ws in new_conns.into_iter().take(can_add) {
bucket.push(PoolEntry {
ws,
created: Instant::now(),
});
}
debug!(
"pool refilled DC{}{}: {} ready",
dc,
if is_media { "m" } else { "" },
lock.get(&(dc, is_media)).map_or(0, |b| b.len())
);
}
}
/// Opens up to `count` connections to `ip` for the given DC, one after another.
///
/// Each attempt is passed an 8-second timeout. The first failed attempt is
/// logged and ends the batch, so the returned vector may hold fewer than
/// `count` streams (possibly none).
async fn connect_batch(
    ip: &str,
    dc: u32,
    is_media: bool,
    skip_tls: bool,
    count: usize,
) -> Vec<TgWsStream> {
    let per_attempt = Duration::from_secs(8);
    let mut opened = Vec::new();
    // Every iteration either pushes one stream or breaks, so this runs at
    // most `count` times.
    while opened.len() < count {
        let (maybe_ws, _) = connect_ws_for_dc(ip, dc, is_media, skip_tls, per_attempt).await;
        if let Some(ws) = maybe_ws {
            opened.push(ws);
        } else {
            warn!("pool: failed to pre-connect DC{}{}", dc, if is_media { "m" } else { "" });
            break;
        }
    }
    opened
}
}