use super::*;
use std::sync::atomic::Ordering;
/// Commands consumed by the per-connection task (`con_task`).
pub(crate) enum ConnCmd {
    /// A signal-server message associated with this connection's peer.
    SigRecv(tx5_signal::SignalMessage),
    /// An event from the WebRTC backend, forwarded by `webrtc_task`.
    WebrtcRecv(webrtc::WebrtcEvt),
    /// An outbound payload queued by `Conn::send`.
    SendMessage(Vec<u8>),
    /// Emitted once `webrtc_connect_timeout` has elapsed so the task can fall
    /// back to the signal relay if WebRTC is still not ready.
    WebrtcTimeoutCheck,
    /// Emitted by `webrtc_task` after the backend event stream has ended.
    WebrtcClosed,
}
/// Receiving half of a connection: yields payloads sent by the remote peer.
pub struct ConnRecv(CloseRecv<Vec<u8>>);

impl ConnRecv {
    /// Wait for the next inbound payload.
    ///
    /// Returns `None` once the sending side (owned by the connection task)
    /// has been closed.
    pub async fn recv(&mut self) -> Option<Vec<u8>> {
        let Self(inbound) = self;
        inbound.recv().await
    }
}
/// Handle to a single peer connection.
///
/// Dropping the handle aborts the connection's background tasks and notifies
/// the hub with `HubCmd::Disconnect`.
pub struct Conn {
    /// Holds no permits; the connection task closes it once traffic can flow,
    /// which releases waiters in `ready()`.
    ready: Arc<tokio::sync::Semaphore>,
    /// Key identifying the remote peer on the signal connection.
    pub_key: PubKey,
    /// Command sender into the connection task (flagged close-on-drop).
    cmd_send: CloseSend<ConnCmd>,
    /// Handshake / WebRTC / relay task; aborted on drop.
    conn_task: tokio::task::JoinHandle<()>,
    /// Periodic signal keepalive task; aborted on drop.
    keepalive_task: tokio::task::JoinHandle<()>,
    /// True while traffic is carried over WebRTC rather than the relay.
    is_webrtc: Arc<std::sync::atomic::AtomicBool>,
    /// Number of messages sent to the peer.
    send_msg_count: Arc<std::sync::atomic::AtomicU64>,
    /// Number of bytes sent to the peer.
    send_byte_count: Arc<std::sync::atomic::AtomicU64>,
    /// Number of messages received from the peer.
    recv_msg_count: Arc<std::sync::atomic::AtomicU64>,
    /// Number of bytes received from the peer.
    recv_byte_count: Arc<std::sync::atomic::AtomicU64>,
    /// Channel to the hub, used to report `HubCmd::Disconnect` on drop.
    hub_cmd_send: tokio::sync::mpsc::Sender<HubCmd>,
}
/// Emit a `tracing` event on the `NETAUDIT` target, tagged with
/// `m = "tx5-connection"`, at the named `tracing::Level` (e.g. `DEBUG`).
/// Everything after the level is passed through as event fields.
macro_rules! netaudit {
    ($level:ident, $($fields:tt)*) => {
        ::tracing::event!(
            target: "NETAUDIT",
            ::tracing::Level::$level,
            m = "tx5-connection",
            $($fields)*
        );
    };
}
impl Drop for Conn {
fn drop(&mut self) {
netaudit!(DEBUG, pub_key = ?self.pub_key, a = "drop");
self.conn_task.abort();
self.keepalive_task.abort();
let hub_cmd_send = self.hub_cmd_send.clone();
let pub_key = self.pub_key.clone();
tokio::task::spawn(async move {
let _ = hub_cmd_send.send(HubCmd::Disconnect(pub_key)).await;
});
}
}
impl Conn {
    /// Test-only hook that aborts the keepalive task.
    #[cfg(test)]
    pub(crate) fn test_kill_keepalive_task(&self) {
        self.keepalive_task.abort();
    }

    /// Create a connection to `pub_key` and spawn its background tasks.
    ///
    /// Two tasks are spawned.
    /// - A keepalive loop that calls `send_keepalive` every `max_idle / 2`
    ///   until the signal client is gone or a keepalive fails.
    /// - The connection task (`con_task`), which runs the handshake and then
    ///   carries traffic over WebRTC or the signal relay.
    ///
    /// Returns the shared handle, the receiver for inbound peer payloads, and
    /// a command sender that feeds commands into the connection task.
    pub(crate) fn priv_new(
        webrtc_config: WebRtcConfig,
        is_polite: bool,
        pub_key: PubKey,
        client: Weak<tx5_signal::SignalConnection>,
        config: Arc<HubConfig>,
        hub_cmd_send: tokio::sync::mpsc::Sender<HubCmd>,
    ) -> (Arc<Self>, ConnRecv, CloseSend<ConnCmd>) {
        netaudit!(DEBUG, ?webrtc_config, ?pub_key, ?is_polite, a = "open",);

        // State shared between the handle and the connection task.
        let new_counter = || Arc::new(std::sync::atomic::AtomicU64::new(0));
        let is_webrtc = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let send_msg_count = new_counter();
        let send_byte_count = new_counter();
        let recv_msg_count = new_counter();
        let recv_byte_count = new_counter();
        // Starts with zero permits; `ready()` resolves once the task closes it.
        let ready = Arc::new(tokio::sync::Semaphore::new(0));

        let (mut msg_send, msg_recv) = CloseSend::sized_channel(1024);
        msg_send.set_close_on_drop(true);
        let (cmd_send, cmd_recv) = CloseSend::sized_channel(1024);

        let keepalive_dur = config.signal_config.max_idle / 2;
        let keepalive_task = {
            let weak_client = client.clone();
            let peer_key = pub_key.clone();
            tokio::task::spawn(async move {
                loop {
                    tokio::time::sleep(keepalive_dur).await;
                    match weak_client.upgrade() {
                        Some(sig) => {
                            if sig.send_keepalive(&peer_key).await.is_err() {
                                break;
                            }
                        }
                        None => break,
                    }
                }
            })
        };

        let core = TaskCore {
            client,
            config,
            pub_key: pub_key.clone(),
            cmd_send: cmd_send.clone(),
            cmd_recv,
            send_msg_count: send_msg_count.clone(),
            send_byte_count: send_byte_count.clone(),
            recv_msg_count: recv_msg_count.clone(),
            recv_byte_count: recv_byte_count.clone(),
            msg_send,
            ready: ready.clone(),
            is_webrtc: is_webrtc.clone(),
        };
        let conn_task =
            tokio::task::spawn(con_task(is_polite, webrtc_config, core));

        // Only the handle's own sender is flagged close-on-drop; the sender
        // handed back to the caller is left unflagged.
        let mut handle_cmd_send = cmd_send.clone();
        handle_cmd_send.set_close_on_drop(true);

        let conn = Arc::new(Self {
            ready,
            pub_key,
            cmd_send: handle_cmd_send,
            conn_task,
            keepalive_task,
            is_webrtc,
            send_msg_count,
            send_byte_count,
            recv_msg_count,
            recv_byte_count,
            hub_cmd_send,
        });

        (conn, ConnRecv(msg_recv), cmd_send)
    }

    /// Wait until the connection can carry traffic.
    pub async fn ready(&self) {
        // The semaphore never holds permits, so this only resolves when the
        // connection task closes it; the resulting error is expected.
        drop(self.ready.acquire().await);
    }

    /// Whether traffic currently flows over WebRTC (vs. the signal relay).
    pub fn is_using_webrtc(&self) -> bool {
        self.is_webrtc.load(Ordering::SeqCst)
    }

    /// Key identifying the remote peer.
    pub fn pub_key(&self) -> &PubKey {
        &self.pub_key
    }

    /// Queue `msg` for delivery to the remote peer.
    pub async fn send(&self, msg: Vec<u8>) -> Result<()> {
        let cmd = ConnCmd::SendMessage(msg);
        self.cmd_send.send(cmd).await
    }

    /// Snapshot of the message and byte counters for this connection.
    pub fn get_stats(&self) -> ConnStats {
        let read = |counter: &Arc<std::sync::atomic::AtomicU64>| {
            counter.load(Ordering::Relaxed)
        };
        ConnStats {
            send_msg_count: read(&self.send_msg_count),
            send_byte_count: read(&self.send_byte_count),
            recv_msg_count: read(&self.recv_msg_count),
            recv_byte_count: read(&self.recv_byte_count),
        }
    }
}
/// Message and byte counters for a single connection, as returned by
/// `Conn::get_stats`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConnStats {
    /// Messages sent to the peer.
    pub send_msg_count: u64,
    /// Bytes sent to the peer.
    pub send_byte_count: u64,
    /// Messages received from the peer.
    pub recv_msg_count: u64,
    /// Bytes received from the peer.
    pub recv_byte_count: u64,
}
/// State owned by the connection task, including its shares of the
/// counters and flags exposed through the `Conn` handle.
struct TaskCore {
    /// Hub configuration (timeouts, backend module, relay switches).
    config: Arc<HubConfig>,
    /// Signal client; held weakly so this task does not keep it alive.
    client: Weak<tx5_signal::SignalConnection>,
    /// Key identifying the remote peer.
    pub_key: PubKey,
    /// Sender into this task's own command channel (cloned for helper tasks).
    cmd_send: CloseSend<ConnCmd>,
    /// Receiver for commands addressed to this connection.
    cmd_recv: CloseRecv<ConnCmd>,
    /// Sender feeding the `ConnRecv` handed to the user.
    msg_send: CloseSend<Vec<u8>>,
    /// Closed once the connection is ready, releasing `Conn::ready` waiters.
    ready: Arc<tokio::sync::Semaphore>,
    /// Set while traffic flows over WebRTC.
    is_webrtc: Arc<std::sync::atomic::AtomicBool>,
    /// Messages sent to the peer.
    send_msg_count: Arc<std::sync::atomic::AtomicU64>,
    /// Bytes sent to the peer.
    send_byte_count: Arc<std::sync::atomic::AtomicU64>,
    /// Messages received from the peer.
    recv_msg_count: Arc<std::sync::atomic::AtomicU64>,
    /// Bytes received from the peer.
    recv_byte_count: Arc<std::sync::atomic::AtomicU64>,
}
impl TaskCore {
    /// Count an inbound message and forward it to the `ConnRecv` side.
    ///
    /// Returns `Err(())` when the receiving side has been closed; callers
    /// treat that as a signal to stop processing this connection.
    async fn handle_recv_msg(
        &self,
        msg: Vec<u8>,
    ) -> std::result::Result<(), ()> {
        let byte_len = msg.len() as u64;
        self.recv_msg_count.fetch_add(1, Ordering::Relaxed);
        self.recv_byte_count.fetch_add(byte_len, Ordering::Relaxed);

        match self.msg_send.send(msg).await {
            Ok(_) => Ok(()),
            Err(_) => {
                netaudit!(
                    DEBUG,
                    pub_key = ?self.pub_key,
                    a = "close: msg_send closed",
                );
                Err(())
            }
        }
    }

    /// Count one outbound message of `len` bytes.
    fn track_send_msg(&self, len: usize) {
        let byte_len = len as u64;
        self.send_msg_count.fetch_add(1, Ordering::Relaxed);
        self.send_byte_count.fetch_add(byte_len, Ordering::Relaxed);
    }
}
/// Main per-connection task: handshake, then WebRTC, else the signal relay.
///
/// 1. Sends a handshake request and processes signal messages, bounded by
///    `signal_config.max_idle`, until the peer has answered our request with
///    the matching nonce and we have answered the peer's request. On timeout
///    or error the peer is closed via `close_peer` and the task ends.
/// 2. Attempts a WebRTC connection (`con_task_attempt_webrtc`).
/// 3. If that requests fallback and `danger_deny_signal_relay` is not set,
///    serves the connection over the signal relay.
///
/// Returns immediately if the signal client has already been dropped.
async fn con_task(
    is_polite: bool,
    webrtc_config: WebRtcConfig,
    mut task_core: TaskCore,
) {
    if let Some(client) = task_core.client.upgrade() {
        // Handshake: send our request, then process commands until both
        // halves of the exchange are complete. If the command channel closes
        // first, this still yields Ok(()); the next stage then observes the
        // closed channel via `recv_cmd` and aborts.
        let handshake_fut = async {
            let nonce = client.send_handshake_req(&task_core.pub_key).await?;
            let mut got_peer_res = false;
            let mut sent_our_res = false;
            while let Some(cmd) = task_core.cmd_recv.recv().await {
                match cmd {
                    ConnCmd::SigRecv(sig) => {
                        use tx5_signal::SignalMessage::*;
                        match sig {
                            // Answer the peer's request, echoing its nonce.
                            HandshakeReq(oth_nonce) => {
                                client
                                    .send_handshake_res(
                                        &task_core.pub_key,
                                        oth_nonce,
                                    )
                                    .await?;
                                sent_our_res = true;
                            }
                            // The peer's answer must echo the nonce we sent.
                            HandshakeRes(res_nonce) => {
                                if res_nonce != nonce {
                                    return Err(Error::other("nonce mismatch"));
                                }
                                got_peer_res = true;
                            }
                            // Other signal messages are ignored here.
                            _ => (),
                        }
                    }
                    // `ready` is not closed until after the handshake, so a
                    // send arriving now is treated as an error.
                    ConnCmd::SendMessage(_) => {
                        return Err(Error::other("send before ready"));
                    }
                    // These are only produced by helpers spawned in
                    // `con_task_attempt_webrtc`, which runs after this.
                    ConnCmd::WebrtcTimeoutCheck
                    | ConnCmd::WebrtcRecv(_)
                    | ConnCmd::WebrtcClosed => {
                        unreachable!()
                    }
                }
                if got_peer_res && sent_our_res {
                    break;
                }
            }
            Result::Ok(())
        };
        // Any handshake failure or timeout closes the peer and ends the task.
        match tokio::time::timeout(
            task_core.config.signal_config.max_idle,
            handshake_fut,
        )
        .await
        {
            Err(_) | Ok(Err(_)) => {
                client.close_peer(&task_core.pub_key).await;
                return;
            }
            Ok(Ok(_)) => (),
        }
    } else {
        // The signal client is already gone; nothing to connect over.
        return;
    }
    let task_core = match con_task_attempt_webrtc(
        is_polite,
        webrtc_config,
        task_core,
    )
    .await
    {
        AttemptWebrtcResult::Abort => return,
        AttemptWebrtcResult::Fallback(task_core) => {
            // Relay fallback can be disabled by configuration.
            if task_core.config.danger_deny_signal_relay {
                netaudit!(
                    INFO,
                    pub_key = ?task_core.pub_key,
                    a = "webrtc fallback: denied signal relay",
                );
                return;
            }
            task_core
        }
    };
    // WebRTC may have been marked ready before the fallback; clear it.
    task_core.is_webrtc.store(false, Ordering::SeqCst);
    con_task_fallback_use_signal(task_core).await;
}
/// Wait for the next command, giving up after `signal_config.max_idle`.
///
/// Returns `None`, after logging why, when the wait times out or when the
/// command channel has been closed.
async fn recv_cmd(task_core: &mut TaskCore) -> Option<ConnCmd> {
    let idle_limit = task_core.config.signal_config.max_idle;
    let outcome =
        tokio::time::timeout(idle_limit, task_core.cmd_recv.recv()).await;

    let reason = match outcome {
        Ok(Some(cmd)) => return Some(cmd),
        Ok(None) => "close: cmd_recv stream complete",
        Err(_) => "close: connection idle",
    };

    netaudit!(
        DEBUG,
        pub_key = ?task_core.pub_key,
        a = reason,
    );
    None
}
/// Forward every WebRTC backend event into the connection's command channel.
///
/// Stops when the backend stream ends or the command channel rejects a send,
/// then makes a best-effort attempt to deliver `ConnCmd::WebrtcClosed`.
async fn webrtc_task(
    mut webrtc_recv: CloseRecv<webrtc::WebrtcEvt>,
    cmd_send: CloseSend<ConnCmd>,
) {
    while let Some(event) = webrtc_recv.recv().await {
        let forwarded = cmd_send.send(ConnCmd::WebrtcRecv(event)).await;
        if forwarded.is_err() {
            break;
        }
    }

    netaudit!(DEBUG, a = "webrtc task closed, sending WebrtcClosed",);
    let _ = cmd_send.send(ConnCmd::WebrtcClosed).await;
}
/// Outcome of `con_task_attempt_webrtc`.
enum AttemptWebrtcResult {
    /// Tear the connection down.
    Abort,
    /// Continue over the signal relay using the returned task state.
    Fallback(TaskCore),
}
/// Try to carry the connection over a direct WebRTC link.
///
/// Creates a WebRTC backend and relays its generated offer/answer/ICE
/// payloads to the peer over the signal client, feeding the peer's payloads
/// back in. Once the backend reports `Ready`, the connection is marked ready
/// and this function keeps carrying traffic over WebRTC.
///
/// Returns `Fallback(task_core)` when the caller should continue over the
/// signal relay. That happens when:
/// - `danger_force_signal_relay` is set;
/// - the peer sent a relayed message;
/// - an offer or answer was rejected;
/// - a WebRTC send failed;
/// - the backend was not ready within `webrtc_connect_timeout`.
///
/// Returns `Abort` when the connection should be torn down.
async fn con_task_attempt_webrtc(
    is_polite: bool,
    webrtc_config: WebRtcConfig,
    mut task_core: TaskCore,
) -> AttemptWebrtcResult {
    use AttemptWebrtcResult::*;

    // Aborts the wrapped task on drop, so the helper tasks spawned below
    // never outlive this function.
    struct AbortOnDrop(tokio::task::AbortHandle);
    impl Drop for AbortOnDrop {
        fn drop(&mut self) {
            self.0.abort();
        }
    }

    // Test switch that skips WebRTC entirely. It is checked before any backend
    // or helper task is created, since they would be discarded immediately.
    if task_core.config.danger_force_signal_relay {
        netaudit!(
            WARN,
            pub_key = ?task_core.pub_key,
            a = "webrtc fallback: test",
        );
        return Fallback(task_core);
    }

    // Schedule a readiness check for when the connect timeout elapses. The
    // task is aborted on return so a stale check never reaches the relay
    // loop, where it would count as activity for the idle timeout.
    let timeout_dur = task_core.config.webrtc_connect_timeout;
    let timeout_cmd_send = task_core.cmd_send.clone();
    let _abort_timeout = AbortOnDrop(
        tokio::task::spawn(async move {
            tokio::time::sleep(timeout_dur).await;
            let _ = timeout_cmd_send.send(ConnCmd::WebrtcTimeoutCheck).await;
        })
        .abort_handle(),
    );

    let (webrtc, webrtc_recv) = webrtc::new_backend_module(
        task_core.config.backend_module,
        is_polite,
        webrtc_config,
        4096,
    );

    // Forward backend events into this task's command channel.
    let _abort_webrtc = AbortOnDrop(
        tokio::task::spawn(webrtc_task(
            webrtc_recv,
            task_core.cmd_send.clone(),
        ))
        .abort_handle(),
    );

    let mut is_ready = false;

    while let Some(cmd) = recv_cmd(&mut task_core).await {
        use tx5_signal::SignalMessage::*;
        use webrtc::WebrtcEvt::*;
        use ConnCmd::*;
        match cmd {
            // The handshake already completed; a repeat is a protocol error.
            SigRecv(HandshakeReq(_)) | SigRecv(HandshakeRes(_)) => {
                netaudit!(
                    DEBUG,
                    pub_key = ?task_core.pub_key,
                    a = "close: unexpected handshake msg",
                );
                break;
            }
            // The peer is already using the relay: deliver and follow it.
            SigRecv(tx5_signal::SignalMessage::Message(msg)) => {
                if task_core.handle_recv_msg(msg).await.is_err() {
                    break;
                }
                netaudit!(
                    WARN,
                    pub_key = ?task_core.pub_key,
                    a = "webrtc fallback: remote sent us an sbd message",
                );
                return Fallback(task_core);
            }
            SigRecv(Offer(offer)) => {
                netaudit!(
                    TRACE,
                    pub_key = ?task_core.pub_key,
                    offer = String::from_utf8_lossy(&offer).to_string(),
                    a = "recv_offer",
                );
                if let Err(err) = webrtc.in_offer(offer).await {
                    netaudit!(
                        WARN,
                        pub_key = ?task_core.pub_key,
                        ?err,
                        a = "webrtc fallback: failed to parse received offer",
                    );
                    return Fallback(task_core);
                }
            }
            SigRecv(Answer(answer)) => {
                netaudit!(
                    TRACE,
                    pub_key = ?task_core.pub_key,
                    answer = String::from_utf8_lossy(&answer).to_string(),
                    a = "recv_answer",
                );
                if let Err(err) = webrtc.in_answer(answer).await {
                    netaudit!(
                        WARN,
                        pub_key = ?task_core.pub_key,
                        ?err,
                        a = "webrtc fallback: failed to parse received answer",
                    );
                    return Fallback(task_core);
                }
            }
            SigRecv(Ice(ice)) => {
                netaudit!(
                    TRACE,
                    pub_key = ?task_core.pub_key,
                    ice = String::from_utf8_lossy(&ice).to_string(),
                    a = "recv_ice",
                );
                // A single bad candidate is not fatal.
                if let Err(err) = webrtc.in_ice(ice).await {
                    netaudit!(
                        DEBUG,
                        pub_key = ?task_core.pub_key,
                        ?err,
                        a = "ignoring webrtc in_ice error",
                    );
                }
            }
            SigRecv(Keepalive) | SigRecv(Unknown) => {}
            WebrtcRecv(GeneratedOffer(offer)) => {
                netaudit!(
                    TRACE,
                    pub_key = ?task_core.pub_key,
                    offer = String::from_utf8_lossy(&offer).to_string(),
                    a = "send_offer",
                );
                if let Some(client) = task_core.client.upgrade() {
                    if let Err(err) =
                        client.send_offer(&task_core.pub_key, offer).await
                    {
                        netaudit!(
                            DEBUG,
                            pub_key = ?task_core.pub_key,
                            ?err,
                            a = "webrtc send_offer error",
                        );
                        break;
                    }
                } else {
                    break;
                }
            }
            WebrtcRecv(GeneratedAnswer(answer)) => {
                netaudit!(
                    TRACE,
                    pub_key = ?task_core.pub_key,
                    answer = String::from_utf8_lossy(&answer).to_string(),
                    a = "send_answer",
                );
                if let Some(client) = task_core.client.upgrade() {
                    if let Err(err) =
                        client.send_answer(&task_core.pub_key, answer).await
                    {
                        netaudit!(
                            DEBUG,
                            pub_key = ?task_core.pub_key,
                            ?err,
                            a = "webrtc send_answer error",
                        );
                        break;
                    }
                } else {
                    break;
                }
            }
            WebrtcRecv(GeneratedIce(ice)) => {
                netaudit!(
                    TRACE,
                    pub_key = ?task_core.pub_key,
                    ice = String::from_utf8_lossy(&ice).to_string(),
                    a = "send_ice",
                );
                if let Some(client) = task_core.client.upgrade() {
                    if let Err(err) =
                        client.send_ice(&task_core.pub_key, ice).await
                    {
                        netaudit!(
                            DEBUG,
                            pub_key = ?task_core.pub_key,
                            ?err,
                            a = "webrtc send_ice error",
                        );
                        break;
                    }
                } else {
                    break;
                }
            }
            WebrtcRecv(webrtc::WebrtcEvt::Message(msg)) => {
                if task_core.handle_recv_msg(msg).await.is_err() {
                    break;
                }
            }
            WebrtcRecv(Ready) => {
                is_ready = true;
                task_core.is_webrtc.store(true, Ordering::SeqCst);
                // Closing the zero-permit semaphore releases `Conn::ready`.
                task_core.ready.close();
            }
            SendMessage(msg) => {
                let len = msg.len();
                netaudit!(
                    TRACE,
                    pub_key = ?task_core.pub_key,
                    byte_len = len,
                    a = "queue msg for backend send",
                );
                if let Err(err) = webrtc.message(msg).await {
                    netaudit!(
                        WARN,
                        pub_key = ?task_core.pub_key,
                        ?err,
                        a = "webrtc fallback: failed to send message",
                    );
                    return Fallback(task_core);
                }
                task_core.track_send_msg(len);
            }
            WebrtcTimeoutCheck => {
                if !is_ready {
                    netaudit!(
                        WARN,
                        pub_key = ?task_core.pub_key,
                        a = "webrtc fallback: failed to ready within timeout",
                    );
                    return Fallback(task_core);
                }
            }
            WebrtcClosed => {
                netaudit!(
                    WARN,
                    pub_key = ?task_core.pub_key,
                    a = "webrtc processing task closed",
                );
                break;
            }
        }
    }

    Abort
}
/// Carry the connection over the signal relay instead of WebRTC.
///
/// Marks the connection ready, then loops until one of these happens:
/// - `recv_cmd` gives up (idle timeout or closed channel);
/// - an inbound payload cannot be delivered;
/// - the signal client is gone or a relayed send fails.
///
/// Each iteration relays outbound messages through the signal client and
/// delivers relayed inbound messages. All other commands are ignored.
async fn con_task_fallback_use_signal(mut task_core: TaskCore) {
    task_core.ready.close();

    while let Some(cmd) = recv_cmd(&mut task_core).await {
        let keep_going = match cmd {
            ConnCmd::SigRecv(tx5_signal::SignalMessage::Message(msg)) => {
                task_core.handle_recv_msg(msg).await.is_ok()
            }
            ConnCmd::SendMessage(msg) => {
                if let Some(client) = task_core.client.upgrade() {
                    let len = msg.len();
                    match client.send_message(&task_core.pub_key, msg).await {
                        Ok(_) => {
                            task_core.track_send_msg(len);
                            true
                        }
                        Err(err) => {
                            netaudit!(
                                DEBUG,
                                pub_key = ?task_core.pub_key,
                                ?err,
                                a = "close: sbd client send error",
                            );
                            false
                        }
                    }
                } else {
                    netaudit!(
                        DEBUG,
                        pub_key = ?task_core.pub_key,
                        a = "close: sbd client closed",
                    );
                    false
                }
            }
            _ => true,
        };

        if !keep_going {
            break;
        }
    }
}