use std::sync::Arc;
use std::time::Duration;
use cipher::StreamCipher;
use futures_util::SinkExt;
use futures_util::StreamExt;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::{debug, info, warn};
use tungstenite::Message;
use crate::config::{default_dc_ips, default_dc_overrides, Config};
use crate::crypto::{
build_connection_ciphers, generate_client_handshake, generate_relay_init, parse_handshake,
AesCtr256, ConnectionCiphers,
};
use crate::pool::WsPool;
use crate::splitter::MsgSplitter;
use crate::ws_client::{connect_cf_ws_for_dc, connect_ws_for_dc, ws_send, TgWsStream};
use std::collections::HashMap;
use std::sync::Mutex as StdMutex;
use std::time::Instant;
/// Cooldown deadlines keyed by `(dc, is_media)`, shared by every connection.
/// Written by `blacklist_ws` / `set_dc_cooldown`, erased by `clear_dc_cooldown`
/// and consulted by `ws_timeout_for`. The map is created lazily: the slot stays
/// `None` until the first deadline is recorded.
static DC_FAIL_UNTIL: StdMutex<Option<HashMap<(u32, bool), Instant>>> = StdMutex::new(None);
/// Cooldown deadlines for upstream proxies, keyed by the `host:port` string
/// produced by `upstream_key`. Also lazily created (`None` until first insert).
static UPSTREAM_FAIL_UNTIL: StdMutex<Option<HashMap<String, Instant>>> = StdMutex::new(None);
/// Builds the cooldown-table key for an upstream proxy, in the form `host:port`.
fn upstream_key(host: &str, port: u16) -> String {
    let mut key = String::from(host);
    key.push(':');
    key.push_str(&port.to_string());
    key
}
/// Marks the upstream `host:port` as failed: it is considered "in cooldown"
/// until `cooldown` has elapsed from now. An existing deadline is overwritten.
///
/// Panics if the global cooldown mutex is poisoned.
fn set_upstream_cooldown(host: &str, port: u16, cooldown: Duration) {
    let key = upstream_key(host, port);
    let mut guard = UPSTREAM_FAIL_UNTIL.lock().unwrap();
    // The table is allocated on first use.
    let table = guard.get_or_insert_with(HashMap::new);
    let retry_at = Instant::now() + cooldown;
    table.insert(key, retry_at);
}
/// Forgets any cooldown recorded for the upstream `host:port`.
/// Does nothing when no deadline was ever stored.
///
/// Panics if the global cooldown mutex is poisoned.
fn clear_upstream_cooldown(host: &str, port: u16) {
    let key = upstream_key(host, port);
    let mut guard = UPSTREAM_FAIL_UNTIL.lock().unwrap();
    // `None` means the table was never created, so there is nothing to remove.
    let _removed = guard.as_mut().and_then(|table| table.remove(&key));
}
/// Returns `true` while the upstream `host:port` has a recorded deadline that
/// still lies in the future; `false` if it has none or the deadline has passed.
///
/// Panics if the global cooldown mutex is poisoned.
fn upstream_in_cooldown(host: &str, port: u16) -> bool {
    let key = upstream_key(host, port);
    let guard = UPSTREAM_FAIL_UNTIL.lock().unwrap();
    match guard.as_ref().and_then(|table| table.get(&key)) {
        Some(until) => Instant::now() < *until,
        None => false,
    }
}
/// Cooldown deadlines for the CF proxy route, keyed by `(dc, is_media)`.
/// Managed by `set_cf_cooldown` / `clear_cf_cooldown` and read by
/// `cf_in_cooldown`. The slot stays `None` until the first deadline is stored.
static CF_FAIL_UNTIL: StdMutex<Option<HashMap<(u32, bool), Instant>>> = StdMutex::new(None);
/// Records that the CF proxy route for `(dc, is_media)` should be skipped
/// until `cooldown` has elapsed from now, replacing any earlier deadline.
///
/// Panics if the global cooldown mutex is poisoned.
fn set_cf_cooldown(dc: u32, is_media: bool, cooldown: Duration) {
    let mut guard = CF_FAIL_UNTIL.lock().unwrap();
    // Lazily create the table the first time a failure is recorded.
    let table = guard.get_or_insert_with(HashMap::new);
    let retry_at = Instant::now() + cooldown;
    table.insert((dc, is_media), retry_at);
}
/// Drops the CF proxy cooldown for `(dc, is_media)`, if one is recorded.
///
/// Panics if the global cooldown mutex is poisoned.
fn clear_cf_cooldown(dc: u32, is_media: bool) {
    let mut guard = CF_FAIL_UNTIL.lock().unwrap();
    // A `None` slot means no deadline was ever stored.
    if let Some(table) = &mut *guard {
        table.remove(&(dc, is_media));
    }
}
/// Reports whether the CF proxy route for `(dc, is_media)` is still cooling
/// down, i.e. a deadline is recorded and has not been reached yet.
///
/// Panics if the global cooldown mutex is poisoned.
fn cf_in_cooldown(dc: u32, is_media: bool) -> bool {
    let guard = CF_FAIL_UNTIL.lock().unwrap();
    let deadline = guard
        .as_ref()
        .and_then(|table| table.get(&(dc, is_media)));
    match deadline {
        Some(&until) => Instant::now() < until,
        None => false,
    }
}
/// Stores a WebSocket cooldown deadline (`now + cooldown`) for `(dc, is_media)`
/// in `DC_FAIL_UNTIL`, overwriting any previous one.
///
/// Panics if the global cooldown mutex is poisoned.
fn blacklist_ws(dc: u32, is_media: bool, cooldown: Duration) {
    let mut guard = DC_FAIL_UNTIL.lock().unwrap();
    let table = guard.get_or_insert_with(HashMap::new);
    let blocked_until = Instant::now() + cooldown;
    table.insert((dc, is_media), blocked_until);
}
/// Records that WebSocket connects to `(dc, is_media)` failed: until
/// `cooldown` has elapsed, `ws_timeout_for` will hand out the probe timeout.
///
/// Panics if the global cooldown mutex is poisoned.
fn set_dc_cooldown(dc: u32, is_media: bool, cooldown: Duration) {
    let slot = (dc, is_media);
    let mut guard = DC_FAIL_UNTIL.lock().unwrap();
    // The table does not exist until the first failure is recorded.
    let table = guard.get_or_insert_with(HashMap::new);
    table.insert(slot, Instant::now() + cooldown);
}
/// Removes the WebSocket cooldown entry for `(dc, is_media)`, if present.
///
/// Panics if the global cooldown mutex is poisoned.
fn clear_dc_cooldown(dc: u32, is_media: bool) {
    let slot = (dc, is_media);
    let mut guard = DC_FAIL_UNTIL.lock().unwrap();
    let _removed = guard.as_mut().and_then(|table| table.remove(&slot));
}
/// Chooses the WebSocket connect timeout for `(dc, is_media)`.
///
/// Returns `fail_probe_timeout` while a cooldown deadline recorded in
/// `DC_FAIL_UNTIL` is still in the future, and `normal_timeout` otherwise.
///
/// Panics if the global cooldown mutex is poisoned.
fn ws_timeout_for(
    dc: u32,
    is_media: bool,
    normal_timeout: Duration,
    fail_probe_timeout: Duration,
) -> Duration {
    // The guard is a temporary, so the lock is released at the end of this statement.
    let cooling_down = DC_FAIL_UNTIL
        .lock()
        .unwrap()
        .as_ref()
        .and_then(|table| table.get(&(dc, is_media)).copied())
        .map_or(false, |until| Instant::now() < until);
    if cooling_down {
        fail_probe_timeout
    } else {
        normal_timeout
    }
}
/// Serves one accepted client connection from handshake to session end.
///
/// Steps:
/// 1. Reads the 64-byte client handshake (bounded by `config.handshake_timeout`)
///    and validates it with `parse_handshake` against the configured secret.
/// 2. Builds the relay init packet and the four stream ciphers for the session.
/// 3. Picks a route to the requested DC and hands the socket halves to a bridge:
///    - DC has no entry in `config.dc_redirects()`: CF proxy (if domains are
///      configured and not in cooldown), then each configured MTProto upstream
///      that is not in cooldown, then plain TCP to the default IP of that DC.
///    - DC has an entry: CF proxy first when `config.cf_priority` is set, then a
///      pooled WebSocket, then a fresh WebSocket; if that fails a cooldown is
///      recorded and CF proxy (unless it was already tried as priority),
///      upstreams and finally TCP fallback are attempted in that order.
///
/// Nothing is returned; every failure is logged and ends the connection.
///
/// * `stream` - accepted client socket.
/// * `peer`   - client address, used only as the log label.
/// * `config` - proxy configuration (secret, timeouts, cooldowns, routes).
/// * `pool`   - pool queried for an already-open WebSocket.
pub async fn handle_client(
stream: TcpStream,
peer: std::net::SocketAddr,
config: Config,
pool: Arc<WsPool>,
) {
let label = peer.to_string();
// Best effort: a failure to set TCP_NODELAY is ignored.
let _ = stream.set_nodelay(true);
let secret = config.secret_bytes();
let dc_redirects = config.dc_redirects();
let dc_overrides = default_dc_overrides();
let dc_fallback_ips = default_dc_ips();
let skip_tls = config.skip_tls_verify;
// All timeouts and cooldowns in the config are expressed in whole seconds.
let ws_connect_timeout = Duration::from_secs(config.ws_connect_timeout);
let ws_fail_probe_timeout = Duration::from_secs(config.ws_fail_probe_timeout);
let ws_fail_cooldown = Duration::from_secs(config.ws_fail_cooldown);
let ws_redirect_cooldown = Duration::from_secs(config.ws_redirect_cooldown);
let handshake_timeout = Duration::from_secs(config.handshake_timeout);
let tcp_fallback_timeout = Duration::from_secs(config.tcp_fallback_timeout);
let upstream_connect_timeout = Duration::from_secs(config.upstream_connect_timeout);
let upstream_fail_cooldown = Duration::from_secs(config.upstream_fail_cooldown);
let cf_connect_timeout = Duration::from_secs(config.cf_connect_timeout);
let cf_fail_cooldown = Duration::from_secs(config.cf_fail_cooldown);
let (mut reader, writer) = tokio::io::split(stream);
// The client handshake is exactly 64 bytes; wait for all of it, but not forever.
let mut handshake_buf = [0u8; 64];
match tokio::time::timeout(
handshake_timeout,
reader.read_exact(&mut handshake_buf),
)
.await
{
Ok(Ok(_)) => {}
Ok(Err(e)) => {
debug!("[{}] read handshake: {}", label, e);
return;
}
Err(_) => {
warn!("[{}] handshake timeout", label);
return;
}
}
let info = match parse_handshake(&handshake_buf, &secret) {
Some(i) => i,
None => {
debug!(
"[{}] bad handshake (wrong secret or reserved prefix)",
label
);
// Keep reading and discarding whatever the peer sends until it closes the
// connection, instead of hanging up immediately.
// NOTE(review): presumably so a probe cannot detect the proxy from an
// instant close; there is no timeout on this drain - confirm that is intended.
let _ = tokio::io::copy(&mut reader, &mut tokio::io::sink()).await;
return;
}
};
let dc_id = info.dc_id;
let is_media = info.is_media;
let proto = info.proto;
// DC id used for the WebSocket connect; falls back to `dc_id` when no override exists.
let ws_dc = *dc_overrides.get(&dc_id).unwrap_or(&dc_id);
// Signed DC index passed to the relay/upstream handshakes: negative for media.
// NOTE(review): `as i16` truncates silently; assumes `dc_id` fits in i16.
let dc_idx: i16 = if is_media {
-(dc_id as i16)
} else {
dc_id as i16
};
debug!(
"[{}] handshake ok: DC{}{} proto={:?}",
label,
dc_id,
if is_media { " media" } else { "" },
proto
);
let relay_init = generate_relay_init(proto, dc_idx);
let ciphers = build_connection_ciphers(&info.prekey_and_iv, &secret, &relay_init);
let target_ip = dc_redirects.get(&dc_id).cloned();
let media_tag = if is_media { "m" } else { "" };
// ---- Route A: the DC has no configured redirect, so WebSocket is not attempted.
if target_ip.is_none() {
let reason = format!("DC{} not in --dc-ip config", dc_id);
let fallback = match dc_fallback_ips.get(&dc_id) {
Some(ip) => ip.clone(),
None => {
warn!("[{}] {} — no fallback IP available", label, reason);
return;
}
};
// A.1: CF proxy, unless no domains are configured or it recently failed.
if !config.cf_domains.is_empty() {
if !cf_in_cooldown(dc_id, is_media) {
debug!(
"[{}] DC{}{} {} → trying CF proxy via {:?}",
label, dc_id, media_tag, reason, config.cf_domains
);
let (cf_ws_opt, _all_redirects) =
connect_cf_ws_for_dc(dc_id, &config.cf_domains, is_media, skip_tls, cf_connect_timeout)
.await;
if let Some(ws) = cf_ws_opt {
clear_cf_cooldown(dc_id, is_media);
info!(
"[{}] DC{}{} {} → CF proxy connected",
label, dc_id, media_tag, reason
);
bridge_ws(
&label, reader, writer, ws, relay_init, ciphers, proto, dc_id, is_media,
)
.await;
return;
} else {
set_cf_cooldown(dc_id, is_media, cf_fail_cooldown);
warn!(
"[{}] DC{}{} CF proxy failed, cooldown {}s",
label, dc_id, media_tag, cf_fail_cooldown.as_secs()
);
}
} else {
debug!(
"[{}] DC{}{} CF proxy in cooldown, skipping",
label, dc_id, media_tag
);
}
}
// A.2: configured MTProto upstreams, in order, skipping those in cooldown.
for upstream in &config.mtproto_proxies {
if upstream_in_cooldown(&upstream.host, upstream.port) {
debug!(
"[{}] upstream {}:{} in cooldown, skipping",
label, upstream.host, upstream.port
);
continue;
}
match connect_mtproto_upstream(
&upstream.host,
upstream.port,
&upstream.secret,
dc_idx,
proto,
upstream_connect_timeout,
)
.await
{
Some((rem_reader, rem_writer, up_enc, up_dec)) => {
clear_upstream_cooldown(&upstream.host, upstream.port);
info!(
"[{}] DC{}{} {} → upstream MTProto {}:{}",
label, dc_id, media_tag, reason, upstream.host, upstream.port
);
// Keep the client-side ciphers but replace the Telegram-side pair with the
// ones negotiated with the upstream proxy.
let ConnectionCiphers { clt_dec, clt_enc, .. } = ciphers;
let up_ciphers = ConnectionCiphers {
clt_dec,
clt_enc,
tg_enc: up_enc,
tg_dec: up_dec,
};
bridge_mtproto_relay(
&label, reader, writer, rem_reader, rem_writer, up_ciphers, dc_id,
is_media,
)
.await;
return;
}
None => {
set_upstream_cooldown(&upstream.host, upstream.port, upstream_fail_cooldown);
warn!(
"[{}] upstream {}:{} failed, cooldown {}s",
label,
upstream.host,
upstream.port,
upstream_fail_cooldown.as_secs()
);
}
}
}
// A.3: last resort, plain TCP to the default IP for this DC.
info!("[{}] {} → TCP fallback {}:443", label, reason, fallback);
bridge_tcp(
&label,
reader,
writer,
&fallback,
&relay_init,
ciphers,
dc_id,
is_media,
tcp_fallback_timeout,
)
.await;
return;
}
// ---- Route B: the DC has a configured redirect IP.
// Cannot panic: the `is_none()` branch above always returns.
let target_ip = target_ip.unwrap();
// While the DC is in WS cooldown only the (separately configured) probe timeout is used.
let ws_timeout = ws_timeout_for(dc_id, is_media, ws_connect_timeout, ws_fail_probe_timeout);
// B.1: with `cf_priority`, the CF proxy is tried before any direct WebSocket.
if config.cf_priority && !config.cf_domains.is_empty() {
if !cf_in_cooldown(dc_id, is_media) {
debug!(
"[{}] DC{}{} cf-priority → trying CF proxy first",
label, dc_id, media_tag
);
let (cf_ws_opt, _all_redirects) =
connect_cf_ws_for_dc(dc_id, &config.cf_domains, is_media, skip_tls, cf_connect_timeout)
.await;
if let Some(ws) = cf_ws_opt {
clear_cf_cooldown(dc_id, is_media);
info!(
"[{}] DC{}{} → CF proxy connected (priority)",
label, dc_id, media_tag
);
bridge_ws(
&label, reader, writer, ws, relay_init, ciphers, proto, dc_id, is_media,
)
.await;
return;
} else {
set_cf_cooldown(dc_id, is_media, cf_fail_cooldown);
warn!(
"[{}] DC{}{} CF proxy failed (priority), cooldown {}s — falling back to WS",
label, dc_id, media_tag, cf_fail_cooldown.as_secs()
);
}
} else {
debug!(
"[{}] DC{}{} CF proxy in cooldown (priority), trying WS",
label, dc_id, media_tag
);
}
}
// B.2: reuse a pooled WebSocket if one is available, otherwise connect a new one.
let ws_opt = pool.get(dc_id, is_media, target_ip.clone(), skip_tls).await;
let ws = if let Some(ws) = ws_opt {
info!(
"[{}] DC{}{} → pool hit via {}",
label, dc_id, media_tag, target_ip
);
ws
} else {
let (ws_opt, all_redirects) =
connect_ws_for_dc(&target_ip, ws_dc, is_media, skip_tls, ws_timeout).await;
match ws_opt {
Some(ws) => {
clear_dc_cooldown(dc_id, is_media);
info!(
"[{}] DC{}{} → WS connected via {}",
label, dc_id, media_tag, target_ip
);
ws
}
None => {
// The WebSocket connect failed: record a cooldown whose length depends on
// whether every domain answered with a redirect.
if all_redirects {
blacklist_ws(dc_id, is_media, ws_redirect_cooldown);
warn!(
"[{}] DC{}{} WS cooldown {}s (all domains returned redirect)",
label,
dc_id,
media_tag,
ws_redirect_cooldown.as_secs()
);
} else {
set_dc_cooldown(dc_id, is_media, ws_fail_cooldown);
info!(
"[{}] DC{}{} WS cooldown {}s",
label,
dc_id,
media_tag,
ws_fail_cooldown.as_secs()
);
}
// B.3: CF proxy as a fallback (skipped when it already ran as the priority route).
if !config.cf_priority && !config.cf_domains.is_empty() {
if !cf_in_cooldown(dc_id, is_media) {
debug!(
"[{}] DC{}{} WS failed → trying CF proxy",
label, dc_id, media_tag
);
let (cf_ws_opt, _all_redirects) = connect_cf_ws_for_dc(
dc_id,
&config.cf_domains,
is_media,
skip_tls,
cf_connect_timeout,
)
.await;
if let Some(ws) = cf_ws_opt {
clear_cf_cooldown(dc_id, is_media);
info!(
"[{}] DC{}{} → CF proxy connected",
label, dc_id, media_tag
);
bridge_ws(
&label, reader, writer, ws, relay_init, ciphers, proto, dc_id,
is_media,
)
.await;
return;
} else {
set_cf_cooldown(dc_id, is_media, cf_fail_cooldown);
warn!(
"[{}] DC{}{} CF proxy failed, cooldown {}s",
label, dc_id, media_tag, cf_fail_cooldown.as_secs()
);
}
} else {
debug!(
"[{}] DC{}{} CF proxy in cooldown, skipping",
label, dc_id, media_tag
);
}
}
// B.4: configured MTProto upstreams, in order, skipping those in cooldown.
for upstream in &config.mtproto_proxies {
if upstream_in_cooldown(&upstream.host, upstream.port) {
debug!(
"[{}] upstream {}:{} in cooldown, skipping",
label, upstream.host, upstream.port
);
continue;
}
match connect_mtproto_upstream(
&upstream.host,
upstream.port,
&upstream.secret,
dc_idx,
proto,
upstream_connect_timeout,
)
.await
{
Some((rem_reader, rem_writer, up_enc, up_dec)) => {
clear_upstream_cooldown(&upstream.host, upstream.port);
info!(
"[{}] DC{}{} → upstream MTProto {}:{}",
label, dc_id, media_tag, upstream.host, upstream.port
);
// Client-side ciphers are kept; the Telegram-side pair comes from the upstream handshake.
let ConnectionCiphers { clt_dec, clt_enc, .. } = ciphers;
let up_ciphers = ConnectionCiphers {
clt_dec,
clt_enc,
tg_enc: up_enc,
tg_dec: up_dec,
};
bridge_mtproto_relay(
&label, reader, writer, rem_reader, rem_writer, up_ciphers,
dc_id, is_media,
)
.await;
return;
}
None => {
set_upstream_cooldown(&upstream.host, upstream.port, upstream_fail_cooldown);
warn!(
"[{}] upstream {}:{} failed, cooldown {}s",
label,
upstream.host,
upstream.port,
upstream_fail_cooldown.as_secs()
);
}
}
}
// B.5: last resort, plain TCP to the default IP for the DC, or to the
// configured redirect IP when no default is known.
let fallback = dc_fallback_ips
.get(&dc_id)
.cloned()
.unwrap_or(target_ip.clone());
info!(
"[{}] DC{}{} → TCP fallback {}:443",
label, dc_id, media_tag, fallback
);
bridge_tcp(
&label,
reader,
writer,
&fallback,
&relay_init,
ciphers,
dc_id,
is_media,
tcp_fallback_timeout,
)
.await;
return;
}
}
};
// Reached only with a pooled or freshly connected direct WebSocket.
bridge_ws(
&label, reader, writer, ws, relay_init, ciphers, proto, dc_id, is_media,
)
.await;
}
async fn bridge_ws(
label: &str,
reader: tokio::io::ReadHalf<TcpStream>,
writer: tokio::io::WriteHalf<TcpStream>,
mut ws: TgWsStream,
relay_init: [u8; 64],
ciphers: crate::crypto::ConnectionCiphers,
proto: crate::crypto::ProtoTag,
dc: u32,
is_media: bool,
) {
if let Err(e) = ws_send(&mut ws, relay_init.to_vec()).await {
warn!("[{}] failed to send relay init: {}", label, e);
return;
}
let ConnectionCiphers {
mut clt_dec,
mut clt_enc,
mut tg_enc,
mut tg_dec,
} = ciphers;
let splitter = MsgSplitter::new(&relay_init, proto);
let (mut ws_sink, mut ws_source) = ws.split();
let start = std::time::Instant::now();
let mut upload = tokio::spawn({
let mut splitter = splitter;
async move {
let mut reader = reader;
let mut buf = vec![0u8; 65536];
let mut total = 0u64;
loop {
let n = match reader.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
let chunk = &mut buf[..n];
clt_dec.apply_keystream(chunk);
tg_enc.apply_keystream(chunk);
let parts = splitter.split(chunk);
for part in parts {
if ws_sink.send(Message::Binary(part)).await.is_err() {
return total;
}
}
total += n as u64;
}
for part in splitter.flush() {
let _ = ws_sink.send(Message::Binary(part)).await;
}
let _ = ws_sink.close().await;
total
}
});
let mut download = tokio::spawn(async move {
let mut writer = writer;
let mut total = 0u64;
loop {
let data = match ws_source.next().await {
Some(Ok(Message::Binary(b))) => b,
Some(Ok(Message::Text(t))) => t.into_bytes(),
Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) => continue,
_ => break,
};
let mut data = data;
tg_dec.apply_keystream(&mut data);
clt_enc.apply_keystream(&mut data);
if writer.write_all(&data).await.is_err() {
break;
}
total += data.len() as u64;
}
total
});
let (bytes_up, bytes_down) = tokio::select! {
result = &mut upload => {
let up = result.unwrap_or_else(|_| 0);
download.abort();
let down = download.await.unwrap_or_else(|_| 0);
(up, down)
}
result = &mut download => {
let down = result.unwrap_or_else(|_| 0);
upload.abort();
let up = upload.await.unwrap_or_else(|_| 0);
(up, down)
}
};
let elapsed = start.elapsed().as_secs_f32();
info!(
"[{}] DC{}{} WS session closed: ↑{} ↓{} {:.1}s",
label,
dc,
if is_media { "m" } else { "" },
human_bytes(bytes_up),
human_bytes(bytes_down),
elapsed
);
}
/// Opens a TCP connection to an upstream MTProto proxy and sends the client
/// handshake produced by `generate_client_handshake`.
///
/// * `host` / `port` - upstream address.
/// * `secret_hex`    - upstream secret as a hex string. When it decodes to
///   exactly 17 bytes starting with `0xdd` or `0xee`, that first byte is
///   dropped before the key is used.
/// * `dc_idx` / `proto` - forwarded to the handshake generator.
/// * `timeout`       - limit for the TCP connect only.
///
/// Returns the split socket halves together with the two ciphers produced by
/// the handshake generator, or `None` (after logging a warning) when the
/// secret is not valid hex, the connect fails or times out, or the handshake
/// cannot be written.
async fn connect_mtproto_upstream(
    host: &str,
    port: u16,
    secret_hex: &str,
    dc_idx: i16,
    proto: crate::crypto::ProtoTag,
    timeout: Duration,
) -> Option<(
    tokio::io::ReadHalf<TcpStream>,
    tokio::io::WriteHalf<TcpStream>,
    AesCtr256,
    AesCtr256,
)> {
    let secret = hex::decode(secret_hex)
        .map_err(|e| {
            warn!(
                "[upstream] {}:{} invalid hex secret: {}",
                host, port, e
            );
        })
        .ok()?;
    // A 17-byte secret with a 0xdd/0xee marker carries the key after the marker byte.
    let key_bytes: &[u8] = match secret.as_slice() {
        [0xdd | 0xee, key @ ..] if secret.len() == 17 => key,
        whole => whole,
    };

    let address = format!("{}:{}", host, port);
    let connected = tokio::time::timeout(timeout, TcpStream::connect(address)).await;
    let stream = match connected {
        Err(_) => {
            warn!("[upstream] {}:{} connect timed out", host, port);
            return None;
        }
        Ok(Err(e)) => {
            warn!("[upstream] {}:{} connect error: {}", host, port, e);
            return None;
        }
        Ok(Ok(socket)) => socket,
    };
    // Best effort: a failure to set TCP_NODELAY is ignored.
    let _ = stream.set_nodelay(true);

    let (handshake, enc, dec) = generate_client_handshake(key_bytes, dc_idx, proto);
    let (reader, mut writer) = tokio::io::split(stream);
    match writer.write_all(&handshake).await {
        Ok(()) => Some((reader, writer, enc, dec)),
        Err(e) => {
            warn!("[upstream] {}:{} send handshake error: {}", host, port, e);
            None
        }
    }
}
async fn bridge_mtproto_relay(
label: &str,
reader: tokio::io::ReadHalf<TcpStream>,
writer: tokio::io::WriteHalf<TcpStream>,
rem_reader: tokio::io::ReadHalf<TcpStream>,
mut rem_writer: tokio::io::WriteHalf<TcpStream>,
ciphers: ConnectionCiphers,
dc: u32,
is_media: bool,
) {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let ConnectionCiphers {
mut clt_dec,
mut clt_enc,
mut tg_enc,
mut tg_dec,
} = ciphers;
let start = std::time::Instant::now();
let mut upload = tokio::spawn(async move {
let mut reader = reader;
let mut buf = vec![0u8; 65536];
let mut total = 0u64;
loop {
let n = match reader.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
let chunk = &mut buf[..n];
clt_dec.apply_keystream(chunk);
tg_enc.apply_keystream(chunk);
if rem_writer.write_all(chunk).await.is_err() {
break;
}
total += n as u64;
}
total
});
let mut download = tokio::spawn(async move {
let mut rem_reader = rem_reader;
let mut writer = writer;
let mut buf = vec![0u8; 65536];
let mut total = 0u64;
loop {
let n = match rem_reader.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
let chunk = &mut buf[..n];
tg_dec.apply_keystream(chunk);
clt_enc.apply_keystream(chunk);
if writer.write_all(chunk).await.is_err() {
break;
}
total += n as u64;
}
total
});
let (bytes_up, bytes_down) = tokio::select! {
result = &mut upload => {
let up = result.unwrap_or(0);
download.abort();
let down = download.await.unwrap_or(0);
(up, down)
}
result = &mut download => {
let down = result.unwrap_or(0);
upload.abort();
let up = upload.await.unwrap_or(0);
(up, down)
}
};
let elapsed = start.elapsed().as_secs_f32();
info!(
"[{}] DC{}{} upstream session closed: ↑{} ↓{} {:.1}s",
label,
dc,
if is_media { "m" } else { "" },
human_bytes(bytes_up),
human_bytes(bytes_down),
elapsed
);
}
async fn bridge_tcp(
label: &str,
mut reader: tokio::io::ReadHalf<TcpStream>,
mut writer: tokio::io::WriteHalf<TcpStream>,
dst: &str,
relay_init: &[u8; 64],
ciphers: crate::crypto::ConnectionCiphers,
dc: u32,
is_media: bool,
connect_timeout: Duration,
) {
let remote = match tokio::time::timeout(
connect_timeout,
TcpStream::connect(format!("{}:443", dst)),
)
.await
{
Ok(Ok(s)) => s,
Ok(Err(e)) => {
warn!("[{}] TCP fallback connect failed: {}", label, e);
return;
}
Err(_) => {
warn!("[{}] TCP fallback connect timed out", label);
return;
}
};
let _ = remote.set_nodelay(true);
let (mut rem_reader, mut rem_writer) = tokio::io::split(remote);
if let Err(e) = rem_writer.write_all(relay_init).await {
warn!("[{}] TCP fallback: send relay init failed: {}", label, e);
return;
}
let crate::crypto::ConnectionCiphers {
mut clt_dec,
mut clt_enc,
mut tg_enc,
mut tg_dec,
} = ciphers;
let start = std::time::Instant::now();
let mut upload = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
let mut total = 0u64;
loop {
let n = match reader.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
let chunk = &mut buf[..n];
clt_dec.apply_keystream(chunk);
tg_enc.apply_keystream(chunk);
if rem_writer.write_all(chunk).await.is_err() {
break;
}
total += n as u64;
}
total
});
let mut download = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
let mut total = 0u64;
loop {
let n = match rem_reader.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
let chunk = &mut buf[..n];
tg_dec.apply_keystream(chunk);
clt_enc.apply_keystream(chunk);
if writer.write_all(chunk).await.is_err() {
break;
}
total += n as u64;
}
total
});
let (bytes_up, bytes_down) = tokio::select! {
result = &mut upload => {
let up = result.unwrap_or_else(|_| 0);
download.abort();
let down = download.await.unwrap_or_else(|_| 0);
(up, down)
}
result = &mut download => {
let down = result.unwrap_or_else(|_| 0);
upload.abort();
let up = upload.await.unwrap_or_else(|_| 0);
(up, down)
}
};
let elapsed = start.elapsed().as_secs_f32();
info!(
"[{}] DC{}{} TCP session closed: ↑{} ↓{} {:.1}s",
label,
dc,
if is_media { "m" } else { "" },
human_bytes(bytes_up),
human_bytes(bytes_down),
elapsed
);
}
/// Formats a byte count with one decimal place and a binary-scaled suffix
/// (`B`, `KB`, `MB`, `GB`, `TB`, `PB`), dividing by 1024 per step.
/// Values of 1024 TB and above are all expressed in `PB`.
fn human_bytes(n: u64) -> String {
    const SUFFIXES: [&str; 5] = ["B", "KB", "MB", "GB", "TB"];
    let mut scaled = n as f64;
    // Used when the value is still >= 1024 after the largest listed suffix.
    let mut suffix = "PB";
    for candidate in SUFFIXES.iter() {
        if scaled < 1024.0 {
            suffix = *candidate;
            break;
        }
        scaled /= 1024.0;
    }
    format!("{:.1}{}", scaled, suffix)
}