1use super::*;
2use std::sync::atomic::Ordering;
3
/// Commands consumed by the per-connection task (`con_task` and the
/// webrtc / signal-relay phases it runs).
pub(crate) enum ConnCmd {
    /// A signal message for this peer (handshake, offer/answer/ice,
    /// relayed message, keepalive, ...).
    /// NOTE(review): the producer is outside this file — presumably the hub.
    SigRecv(tx5_signal::SignalMessage),
    /// An event from the webrtc backend, forwarded by `webrtc_task`.
    WebrtcRecv(webrtc::WebrtcEvt),
    /// An outgoing application message queued by `Conn::send`.
    SendMessage(Vec<u8>),
    /// Sent once, `max_idle` after a webrtc attempt begins; if webrtc is not
    /// ready by then the connection falls back to the signal relay.
    WebrtcTimeoutCheck,
    /// The webrtc event forwarding task (`webrtc_task`) has finished.
    WebrtcClosed,
}
11
12pub struct ConnRecv(CloseRecv<Vec<u8>>);
14
15impl ConnRecv {
16 pub async fn recv(&mut self) -> Option<Vec<u8>> {
18 self.0.recv().await
19 }
20}
21
/// A connection to a single remote peer.
///
/// Built by `Conn::priv_new`, which spawns the background tasks that drive
/// it. Dropping it aborts those tasks and asks the hub to disconnect.
pub struct Conn {
    // Created with zero permits and closed by the connection task once the
    // connection is usable; `ready()` waits on it.
    ready: Arc<tokio::sync::Semaphore>,
    // Key of the remote peer; passed to every signal-client call for it.
    pub_key: PubKey,
    // Command sender into the connection task (`set_close_on_drop(true)` is
    // applied to this copy in `priv_new`).
    cmd_send: CloseSend<ConnCmd>,
    // Task running `con_task`.
    conn_task: tokio::task::JoinHandle<()>,
    // Task periodically sending keepalives through the signal client.
    keepalive_task: tokio::task::JoinHandle<()>,
    // True while messages flow over webrtc rather than the signal relay.
    is_webrtc: Arc<std::sync::atomic::AtomicBool>,
    // Traffic counters shared with the connection task (see `get_stats`).
    send_msg_count: Arc<std::sync::atomic::AtomicU64>,
    send_byte_count: Arc<std::sync::atomic::AtomicU64>,
    recv_msg_count: Arc<std::sync::atomic::AtomicU64>,
    recv_byte_count: Arc<std::sync::atomic::AtomicU64>,
    // Used on drop to send `HubCmd::Disconnect` for this peer.
    hub_cmd_send: tokio::sync::mpsc::Sender<HubCmd>,
}
36
// Emit a `tracing` event with target "NETAUDIT" at the given level, tagged
// with `m = "tx5-connection"`, followed by the caller's fields.
//
// Example: `netaudit!(DEBUG, pub_key = ?key, a = "drop");`
// `$lvl` must name a `tracing::Level` constant (TRACE, DEBUG, INFO, WARN, ERROR).
macro_rules! netaudit {
    ($lvl:ident, $($all:tt)*) => {
        ::tracing::event!(
            target: "NETAUDIT",
            ::tracing::Level::$lvl,
            m = "tx5-connection",
            $($all)*
        );
    };
}
47
48impl Drop for Conn {
49 fn drop(&mut self) {
50 netaudit!(DEBUG, pub_key = ?self.pub_key, a = "drop");
51
52 self.conn_task.abort();
53 self.keepalive_task.abort();
54
55 let hub_cmd_send = self.hub_cmd_send.clone();
56 let pub_key = self.pub_key.clone();
57 tokio::task::spawn(async move {
58 let _ = hub_cmd_send.send(HubCmd::Disconnect(pub_key)).await;
59 });
60 }
61}
62
63impl Conn {
64 #[cfg(test)]
65 pub(crate) fn test_kill_keepalive_task(&self) {
66 self.keepalive_task.abort();
67 }
68
69 pub(crate) fn priv_new(
70 webrtc_config: Vec<u8>,
71 is_polite: bool,
72 pub_key: PubKey,
73 client: Weak<tx5_signal::SignalConnection>,
74 config: Arc<HubConfig>,
75 hub_cmd_send: tokio::sync::mpsc::Sender<HubCmd>,
76 ) -> (Arc<Self>, ConnRecv, CloseSend<ConnCmd>) {
77 netaudit!(
78 DEBUG,
79 webrtc_config = String::from_utf8_lossy(&webrtc_config).to_string(),
80 ?pub_key,
81 ?is_polite,
82 a = "open",
83 );
84
85 let is_webrtc = Arc::new(std::sync::atomic::AtomicBool::new(false));
86 let send_msg_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
87 let send_byte_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
88 let recv_msg_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
89 let recv_byte_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
90
91 let ready = Arc::new(tokio::sync::Semaphore::new(0));
93
94 let (mut msg_send, msg_recv) = CloseSend::sized_channel(1024);
95 let (cmd_send, cmd_recv) = CloseSend::sized_channel(1024);
96
97 let keepalive_dur = config.signal_config.max_idle / 2;
98 let client2 = client.clone();
99 let pub_key2 = pub_key.clone();
100 let keepalive_task = tokio::task::spawn(async move {
101 loop {
102 tokio::time::sleep(keepalive_dur).await;
103
104 if let Some(client) = client2.upgrade() {
105 if client.send_keepalive(&pub_key2).await.is_err() {
106 break;
107 }
108 } else {
109 break;
110 }
111 }
112 });
113
114 msg_send.set_close_on_drop(true);
115
116 let con_task_fut = con_task(
117 is_polite,
118 webrtc_config,
119 TaskCore {
120 client,
121 config,
122 pub_key: pub_key.clone(),
123 cmd_send: cmd_send.clone(),
124 cmd_recv,
125 send_msg_count: send_msg_count.clone(),
126 send_byte_count: send_byte_count.clone(),
127 recv_msg_count: recv_msg_count.clone(),
128 recv_byte_count: recv_byte_count.clone(),
129 msg_send,
130 ready: ready.clone(),
131 is_webrtc: is_webrtc.clone(),
132 },
133 );
134 let conn_task = tokio::task::spawn(con_task_fut);
135
136 let mut cmd_send2 = cmd_send.clone();
137 cmd_send2.set_close_on_drop(true);
138 let this = Self {
139 ready,
140 pub_key,
141 cmd_send: cmd_send2,
142 conn_task,
143 keepalive_task,
144 is_webrtc,
145 send_msg_count,
146 send_byte_count,
147 recv_msg_count,
148 recv_byte_count,
149 hub_cmd_send,
150 };
151
152 (Arc::new(this), ConnRecv(msg_recv), cmd_send)
153 }
154
155 pub async fn ready(&self) {
157 let _ = self.ready.acquire().await;
159 }
160
161 pub fn is_using_webrtc(&self) -> bool {
163 self.is_webrtc.load(Ordering::SeqCst)
164 }
165
166 pub fn pub_key(&self) -> &PubKey {
168 &self.pub_key
169 }
170
171 pub async fn send(&self, msg: Vec<u8>) -> Result<()> {
173 self.cmd_send.send(ConnCmd::SendMessage(msg)).await
174 }
175
176 pub fn get_stats(&self) -> ConnStats {
178 ConnStats {
179 send_msg_count: self.send_msg_count.load(Ordering::Relaxed),
180 send_byte_count: self.send_byte_count.load(Ordering::Relaxed),
181 recv_msg_count: self.recv_msg_count.load(Ordering::Relaxed),
182 recv_byte_count: self.recv_byte_count.load(Ordering::Relaxed),
183 }
184 }
185}
186
/// Snapshot of a connection's traffic counters, as returned by
/// `Conn::get_stats`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ConnStats {
    /// Messages successfully handed to the transport (webrtc or signal relay).
    pub send_msg_count: u64,

    /// Total payload bytes of the messages counted in `send_msg_count`.
    pub send_byte_count: u64,

    /// Messages received from the remote peer, counted on arrival.
    pub recv_msg_count: u64,

    /// Total payload bytes of the messages counted in `recv_msg_count`.
    pub recv_byte_count: u64,
}
202
/// State owned by the connection task (`con_task`) and handed between its
/// handshake, webrtc-attempt and signal-fallback phases.
struct TaskCore {
    // Hub configuration; `signal_config.max_idle` bounds the handshake,
    // webrtc readiness and per-command idle waits.
    config: Arc<HubConfig>,
    // Weak handle to the signal client; upgraded only while it is needed.
    client: Weak<tx5_signal::SignalConnection>,
    // Remote peer key, passed to every signal-client call.
    pub_key: PubKey,
    // Sender into our own command channel, cloned for helper tasks
    // (webrtc event forwarding, readiness timeout).
    cmd_send: CloseSend<ConnCmd>,
    // Receiving end of the command channel.
    cmd_recv: CloseRecv<ConnCmd>,
    // Delivers inbound messages to the `ConnRecv` side.
    msg_send: CloseSend<Vec<u8>>,
    // Closed to signal readiness to `Conn::ready`.
    ready: Arc<tokio::sync::Semaphore>,
    // Shared with `Conn::is_using_webrtc`.
    is_webrtc: Arc<std::sync::atomic::AtomicBool>,
    // Traffic counters shared with `Conn::get_stats`.
    send_msg_count: Arc<std::sync::atomic::AtomicU64>,
    send_byte_count: Arc<std::sync::atomic::AtomicU64>,
    recv_msg_count: Arc<std::sync::atomic::AtomicU64>,
    recv_byte_count: Arc<std::sync::atomic::AtomicU64>,
}
217
218impl TaskCore {
219 async fn handle_recv_msg(
220 &self,
221 msg: Vec<u8>,
222 ) -> std::result::Result<(), ()> {
223 self.recv_msg_count.fetch_add(1, Ordering::Relaxed);
224 self.recv_byte_count
225 .fetch_add(msg.len() as u64, Ordering::Relaxed);
226 if self.msg_send.send(msg).await.is_err() {
227 netaudit!(
228 DEBUG,
229 pub_key = ?self.pub_key,
230 a = "close: msg_send closed",
231 );
232 Err(())
233 } else {
234 Ok(())
235 }
236 }
237
238 fn track_send_msg(&self, len: usize) {
239 self.send_msg_count.fetch_add(1, Ordering::Relaxed);
240 self.send_byte_count
241 .fetch_add(len as u64, Ordering::Relaxed);
242 }
243}
244
245async fn con_task(
246 is_polite: bool,
247 webrtc_config: Vec<u8>,
248 mut task_core: TaskCore,
249) {
250 if let Some(client) = task_core.client.upgrade() {
252 let handshake_fut = async {
253 let nonce = client.send_handshake_req(&task_core.pub_key).await?;
254
255 let mut got_peer_res = false;
256 let mut sent_our_res = false;
257
258 while let Some(cmd) = task_core.cmd_recv.recv().await {
259 match cmd {
260 ConnCmd::SigRecv(sig) => {
261 use tx5_signal::SignalMessage::*;
262 match sig {
263 HandshakeReq(oth_nonce) => {
264 client
265 .send_handshake_res(
266 &task_core.pub_key,
267 oth_nonce,
268 )
269 .await?;
270 sent_our_res = true;
271 }
272 HandshakeRes(res_nonce) => {
273 if res_nonce != nonce {
274 return Err(Error::other("nonce mismatch"));
275 }
276 got_peer_res = true;
277 }
278 _ => (),
281 }
282 }
283 ConnCmd::SendMessage(_) => {
284 return Err(Error::other("send before ready"));
285 }
286 ConnCmd::WebrtcTimeoutCheck
287 | ConnCmd::WebrtcRecv(_)
288 | ConnCmd::WebrtcClosed => {
289 unreachable!()
292 }
293 }
294 if got_peer_res && sent_our_res {
295 break;
296 }
297 }
298
299 Result::Ok(())
300 };
301
302 match tokio::time::timeout(
303 task_core.config.signal_config.max_idle,
304 handshake_fut,
305 )
306 .await
307 {
308 Err(_) | Ok(Err(_)) => {
309 client.close_peer(&task_core.pub_key).await;
310 return;
311 }
312 Ok(Ok(_)) => (),
313 }
314 } else {
315 return;
316 }
317
318 let task_core = match con_task_attempt_webrtc(
320 is_polite,
321 webrtc_config,
322 task_core,
323 )
324 .await
325 {
326 AttemptWebrtcResult::Abort => return,
327 AttemptWebrtcResult::Fallback(task_core) => task_core,
328 };
329
330 task_core.is_webrtc.store(false, Ordering::SeqCst);
331
332 con_task_fallback_use_signal(task_core).await;
335}
336
337async fn recv_cmd(task_core: &mut TaskCore) -> Option<ConnCmd> {
338 match tokio::time::timeout(
339 task_core.config.signal_config.max_idle,
340 task_core.cmd_recv.recv(),
341 )
342 .await
343 {
344 Err(_) => {
345 netaudit!(
346 DEBUG,
347 pub_key = ?task_core.pub_key,
348 a = "close: connection idle",
349 );
350 None
351 }
352 Ok(None) => {
353 netaudit!(
354 DEBUG,
355 pub_key = ?task_core.pub_key,
356 a = "close: cmd_recv stream complete",
357 );
358 None
359 }
360 Ok(Some(cmd)) => Some(cmd),
361 }
362}
363
364async fn webrtc_task(
365 mut webrtc_recv: CloseRecv<webrtc::WebrtcEvt>,
366 cmd_send: CloseSend<ConnCmd>,
367) {
368 while let Some(evt) = webrtc_recv.recv().await {
369 if cmd_send.send(ConnCmd::WebrtcRecv(evt)).await.is_err() {
370 break;
371 }
372 }
373 let _ = cmd_send.send(ConnCmd::WebrtcClosed).await;
374}
375
/// Outcome of `con_task_attempt_webrtc`.
enum AttemptWebrtcResult {
    /// Tear the connection down; `con_task` returns without falling back.
    Abort,
    /// webrtc is not usable; continue over the signal relay with this state.
    Fallback(TaskCore),
}
380
381async fn con_task_attempt_webrtc(
382 is_polite: bool,
383 webrtc_config: Vec<u8>,
384 mut task_core: TaskCore,
385) -> AttemptWebrtcResult {
386 use AttemptWebrtcResult::*;
387
388 let timeout_dur = task_core.config.signal_config.max_idle;
389 let timeout_cmd_send = task_core.cmd_send.clone();
390 tokio::task::spawn(async move {
391 tokio::time::sleep(timeout_dur).await;
392 let _ = timeout_cmd_send.send(ConnCmd::WebrtcTimeoutCheck).await;
393 });
394
395 let (webrtc, webrtc_recv) = webrtc::new_backend_module(
396 task_core.config.backend_module,
397 is_polite,
398 webrtc_config.clone(),
399 4096,
401 );
402
403 struct AbortWebrtc(tokio::task::AbortHandle);
404
405 impl Drop for AbortWebrtc {
406 fn drop(&mut self) {
407 self.0.abort();
408 }
409 }
410
411 let _abort_webrtc = AbortWebrtc(
413 tokio::task::spawn(webrtc_task(
414 webrtc_recv,
415 task_core.cmd_send.clone(),
416 ))
417 .abort_handle(),
418 );
419
420 let mut is_ready = false;
421
422 #[cfg(test)]
423 if task_core.config.test_fail_webrtc {
424 netaudit!(
425 WARN,
426 pub_key = ?task_core.pub_key,
427 a = "webrtc fallback: test",
428 );
429 return Fallback(task_core);
430 }
431
432 while let Some(cmd) = recv_cmd(&mut task_core).await {
433 use tx5_signal::SignalMessage::*;
434 use webrtc::WebrtcEvt::*;
435 use ConnCmd::*;
436 match cmd {
437 SigRecv(HandshakeReq(_)) | SigRecv(HandshakeRes(_)) => {
438 netaudit!(
439 DEBUG,
440 pub_key = ?task_core.pub_key,
441 a = "close: unexpected handshake msg",
442 );
443 return Abort;
444 }
445 SigRecv(tx5_signal::SignalMessage::Message(msg)) => {
446 if task_core.handle_recv_msg(msg).await.is_err() {
447 return Abort;
448 }
449 netaudit!(
450 WARN,
451 pub_key = ?task_core.pub_key,
452 a = "webrtc fallback: remote sent us an sbd message",
453 );
454 return Fallback(task_core);
458 }
459 SigRecv(Offer(offer)) => {
460 netaudit!(
461 TRACE,
462 pub_key = ?task_core.pub_key,
463 offer = String::from_utf8_lossy(&offer).to_string(),
464 a = "recv_offer",
465 );
466 if let Err(err) = webrtc.in_offer(offer).await {
467 netaudit!(
468 WARN,
469 pub_key = ?task_core.pub_key,
470 ?err,
471 a = "webrtc fallback: failed to parse received offer",
472 );
473 return Fallback(task_core);
474 }
475 }
476 SigRecv(Answer(answer)) => {
477 netaudit!(
478 TRACE,
479 pub_key = ?task_core.pub_key,
480 offer = String::from_utf8_lossy(&answer).to_string(),
481 a = "recv_answer",
482 );
483 if let Err(err) = webrtc.in_answer(answer).await {
484 netaudit!(
485 WARN,
486 pub_key = ?task_core.pub_key,
487 ?err,
488 a = "webrtc fallback: failed to parse received answer",
489 );
490 return Fallback(task_core);
491 }
492 }
493 SigRecv(Ice(ice)) => {
494 netaudit!(
495 TRACE,
496 pub_key = ?task_core.pub_key,
497 offer = String::from_utf8_lossy(&ice).to_string(),
498 a = "recv_ice",
499 );
500 if let Err(err) = webrtc.in_ice(ice).await {
501 netaudit!(
502 DEBUG,
503 pub_key = ?task_core.pub_key,
504 ?err,
505 a = "ignoring webrtc in_ice error",
506 );
507 }
509 }
510 SigRecv(Keepalive) | SigRecv(Unknown) => {
511 }
513 WebrtcRecv(GeneratedOffer(offer)) => {
514 netaudit!(
515 TRACE,
516 pub_key = ?task_core.pub_key,
517 offer = String::from_utf8_lossy(&offer).to_string(),
518 a = "send_offer",
519 );
520 if let Some(client) = task_core.client.upgrade() {
521 if let Err(err) =
522 client.send_offer(&task_core.pub_key, offer).await
523 {
524 netaudit!(
525 DEBUG,
526 pub_key = ?task_core.pub_key,
527 ?err,
528 a = "webrtc send_offer error",
529 );
530 return Abort;
531 }
532 } else {
533 return Abort;
534 }
535 }
536 WebrtcRecv(GeneratedAnswer(answer)) => {
537 netaudit!(
538 TRACE,
539 pub_key = ?task_core.pub_key,
540 offer = String::from_utf8_lossy(&answer).to_string(),
541 a = "send_answer",
542 );
543 if let Some(client) = task_core.client.upgrade() {
544 if let Err(err) =
545 client.send_answer(&task_core.pub_key, answer).await
546 {
547 netaudit!(
548 DEBUG,
549 pub_key = ?task_core.pub_key,
550 ?err,
551 a = "webrtc send_answer error",
552 );
553 return Abort;
554 }
555 } else {
556 return Abort;
557 }
558 }
559 WebrtcRecv(GeneratedIce(ice)) => {
560 netaudit!(
561 TRACE,
562 pub_key = ?task_core.pub_key,
563 offer = String::from_utf8_lossy(&ice).to_string(),
564 a = "send_ice",
565 );
566 if let Some(client) = task_core.client.upgrade() {
567 if let Err(err) =
568 client.send_ice(&task_core.pub_key, ice).await
569 {
570 netaudit!(
571 DEBUG,
572 pub_key = ?task_core.pub_key,
573 ?err,
574 a = "webrtc send_ice error",
575 );
576 return Abort;
577 }
578 } else {
579 return Abort;
580 }
581 }
582 WebrtcRecv(webrtc::WebrtcEvt::Message(msg)) => {
583 if task_core.handle_recv_msg(msg).await.is_err() {
584 return Abort;
585 }
586 }
587 WebrtcRecv(Ready) => {
588 is_ready = true;
589 task_core.is_webrtc.store(true, Ordering::SeqCst);
590 task_core.ready.close();
591 }
592 SendMessage(msg) => {
593 let len = msg.len();
594
595 netaudit!(
596 TRACE,
597 pub_key = ?task_core.pub_key,
598 byte_len = len,
599 a = "queue msg for backend send",
600 );
601 if let Err(err) = webrtc.message(msg).await {
602 netaudit!(
603 WARN,
604 pub_key = ?task_core.pub_key,
605 ?err,
606 a = "webrtc fallback: failed to send message",
607 );
608 return Fallback(task_core);
609 }
610
611 task_core.track_send_msg(len);
612 }
613 WebrtcTimeoutCheck => {
614 if !is_ready {
615 netaudit!(
616 WARN,
617 pub_key = ?task_core.pub_key,
618 a = "webrtc fallback: failed to ready within timeout",
619 );
620 return Fallback(task_core);
621 }
622 }
623 WebrtcClosed => {
624 netaudit!(
625 WARN,
626 pub_key = ?task_core.pub_key,
627 a = "webrtc fallback: webrtc processing task closed",
628 );
629 return Fallback(task_core);
630 }
631 }
632 }
633
634 Abort
635}
636
637async fn con_task_fallback_use_signal(mut task_core: TaskCore) {
638 task_core.ready.close();
640
641 while let Some(cmd) = recv_cmd(&mut task_core).await {
642 match cmd {
643 ConnCmd::SigRecv(tx5_signal::SignalMessage::Message(msg)) => {
644 if task_core.handle_recv_msg(msg).await.is_err() {
645 break;
646 }
647 }
648 ConnCmd::SendMessage(msg) => match task_core.client.upgrade() {
649 Some(client) => {
650 let len = msg.len();
651 if let Err(err) =
652 client.send_message(&task_core.pub_key, msg).await
653 {
654 netaudit!(
655 DEBUG,
656 pub_key = ?task_core.pub_key,
657 ?err,
658 a = "close: sbd client send error",
659 );
660 break;
661 }
662 task_core.track_send_msg(len);
663 }
664 None => {
665 netaudit!(
666 DEBUG,
667 pub_key = ?task_core.pub_key,
668 a = "close: sbd client closed",
669 );
670 break;
671 }
672 },
673 _ => (),
674 }
675 }
676}