use crate::{AtomPtr, IoPair, LinkType, LockedStream, Packet, PacketBuilder};
use async_std::{future::timeout, io::prelude::WriteExt, net::TcpStream, sync::Arc, task};
use std::sync::atomic::{AtomicBool, Ordering};
use std::{net::SocketAddr, time::Duration};
/// Process-wide source of peer identifiers.
mod id {
    use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};

    /// Next identifier to hand out; the first caller receives `0`.
    static NEXT_ID: AtomicUsize = AtomicUsize::new(0);

    /// Returns the current counter value and advances the counter by one.
    pub fn next() -> usize {
        NEXT_ID.fetch_add(1, Relaxed)
    }
}
/// Socket address kept in a peer's `src` slot (set by `Peer::from_src` / `Peer::set_src`).
// NOTE(review): presumably the remote address of an incoming connection — confirm with callers.
pub(crate) type SourceAddr = SocketAddr;
/// Socket address a peer connects out to; used as the `TcpStream::connect` target.
pub(crate) type DstAddr = SocketAddr;
/// Connection state of a peer, derived by `Peer::state` from which of its two
/// addresses (source / destination) are currently known.
#[derive(Debug)]
pub(crate) enum PeerState {
/// Only a source address is known.
RxOnly,
/// Only a destination address is known.
TxOnly,
/// Both a source and a destination address are known.
Duplex,
/// State for a peer that fits none of the variants above. The lint allowance
/// silences the dead-code warning in builds where nothing constructs it.
#[allow(unused)]
Invalid,
}
/// A remote peer: its addresses, the stream used to send to it, and the packet
/// queue feeding the background sender task.
#[derive(Clone, Debug, Default)]
pub(crate) struct Peer {
/// Identifier taken from `id::next()` by the constructors.
pub(crate) id: usize,
/// Source address of the peer, replaceable at runtime through `set_src`.
src: AtomPtr<Option<SourceAddr>>,
/// Address to connect to when a stream has to be opened; only `open` sets it.
dst: Option<DstAddr>,
/// Swappable handle to the lockable stream slot used for sending; the slot holds
/// `None` while no stream is available.
sender: AtomPtr<LockedStream>,
/// Link type of this peer, returned by `link_type`.
_type: LinkType,
/// Shared run flag: set to `true` by the constructors, cleared by `stop`.
#[doc(hidden)]
_run: Arc<AtomicBool>,
/// Packet queue: `send` pushes onto `io.tx`, the sender task drains `io.rx`.
io: Arc<IoPair<Packet>>,
}
impl Peer {
pub(crate) fn from_src(src: SourceAddr) -> Arc<Self> {
Arc::new(Self {
id: id::next(),
src: AtomPtr::new(Some(src)),
_run: Arc::new(true.into()),
..Default::default()
})
}
#[tracing::instrument(level = "trace")]
pub(crate) fn open(dst: DstAddr, port: u16, _type: LinkType) -> Arc<Self> {
let p = Arc::new(Self {
id: id::next(),
dst: Some(dst),
_run: Arc::new(true.into()),
_type,
..Default::default()
});
Arc::clone(&p).run_io_sender(port, _type);
task::block_on(async { Arc::clone(&p).send(Packet::Hello { port, _type }).await });
return p;
}
/// Replaces the stored source address.
///
/// Accepts anything convertible into `Option<SourceAddr>`, so callers may pass
/// a bare address, `Some(addr)`, or `None` to clear it.
pub(crate) fn set_src<O: Into<Option<SourceAddr>>>(&self, src: O) {
    let new_src: Option<SourceAddr> = src.into();
    self.src.swap(new_src);
}
/// Installs `s` as this peer's sender stream handle by swapping it into the
/// `sender` slot. Declared `async` although the body awaits nothing.
pub(crate) async fn set_stream(&self, s: LockedStream) {
self.sender.swap(s);
}
/// Clears the run flag; `alive` reports `false` afterwards and loops that poll
/// the flag wind down on their next check.
pub(crate) fn stop(&self) {
    self._run.store(false, Ordering::Relaxed);
}
/// Classifies the peer by which of its addresses are currently known.
///
/// Returns `Duplex` when both source and destination are set, `RxOnly` for
/// source only, `TxOnly` for destination only, and `Invalid` when neither is
/// set. The last case is reachable (a default-constructed peer, or a peer whose
/// source was cleared via `set_src(None)`), so it is reported instead of
/// panicking.
pub(crate) fn state(&self) -> PeerState {
    match (self.get_src(), self.dst) {
        (Some(_), Some(_)) => PeerState::Duplex,
        (Some(_), None) => PeerState::RxOnly,
        (None, Some(_)) => PeerState::TxOnly,
        (None, None) => PeerState::Invalid,
    }
}
/// Returns the peer's link type.
pub(crate) fn link_type(&self) -> LinkType {
self._type
}
/// Returns the current value of the run flag (set to `true` by the
/// constructors, cleared by `stop`).
pub(crate) fn alive(&self) -> bool {
self._run.load(Ordering::Relaxed)
}
async fn send_packet(self: &Arc<Self>, p: &Packet) -> Option<()> {
let r = self.sender.get_ref();
let mut s = r.write().await;
match *s {
Some(ref mut stream) => {
let addr = match stream.peer_addr() {
Ok(addr) => addr.to_string(),
Err(_) => {
std::mem::swap(&mut *s, &mut None);
return None;
}
};
let buf = p.serialize();
if let Err(e) = stream.write_all(&buf).await {
error!("Failed to send message: {}!", e.to_string());
std::mem::swap(&mut *s, &mut None);
return None;
}
match p {
Packet::Hello { .. } => {
trace!("Sending HELLO to {}", addr);
if self._type == LinkType::Bidirect {
let _self = Arc::clone(self);
task::spawn(_self.wait_for_ack());
}
}
_ => {}
}
Some(())
}
None => unreachable!(),
}
}
/// Waits for a reply on the sender stream, giving up after 10 seconds overall.
///
/// Each round takes the stream's write lock and tries for 1 ms to parse one
/// packet. If that attempt finishes in time the wait ends, whatever was read: an
/// ACK is traced, any other packet is logged as an error, and a parse failure
/// additionally clears the stream slot. If the attempt times out, the lock is
/// released and the next round starts after a 50 ms pause. The wait also ends
/// when the stream slot is already empty. If the 10 second budget runs out, the
/// stream slot is cleared.
async fn wait_for_ack(self: Arc<Self>) {
// Overall budget for the whole polling loop.
let t = timeout(Duration::from_secs(10), async {
loop {
let r = self.sender.get_ref();
let mut s = r.write().await;
// Nothing to read from once the stream is gone.
if s.is_none() {
break;
}
// Short read attempt so the write lock is not held for long.
if timeout(Duration::from_millis(1), async {
// The `is_none` check above makes this unwrap safe: the lock is still held.
let mut pb = PacketBuilder::new((*s).as_mut().unwrap());
match pb.parse().await {
Ok(_) => match pb.build() {
Some(Packet::Ack) => trace!("Received an ACK."),
_ => error!("Invalid data (only ACKs)!"),
},
_ => {
// Parsing failed: drop the stream.
std::mem::swap(&mut *s, &mut None);
error!("Failed to read ACK from sender stream");
}
}
})
.await
.is_ok()
{
// The attempt completed within 1 ms (ACK, other packet, or parse error).
break;
}
// Release the lock before sleeping so other tasks can use the stream.
drop(s);
task::sleep(Duration::from_millis(50)).await;
}
});
match t.await {
Ok(_) => {}
Err(_) => {
// No result within 10 seconds: drop the stream.
let _ref = self.sender.get_ref();
let mut s = _ref.write().await;
std::mem::swap(&mut *s, &mut None);
}
}
}
async fn send_or_introduce(self: &Arc<Self>, p: Packet, port: u16, _type: LinkType) {
loop {
if self.sender.get_ref().read().await.is_some() {
match self.send_packet(&p).await {
Some(_) => break,
None => continue, }
} else {
if _type == LinkType::Bidirect {
trace!("Sender is None, opening a connection first...");
Arc::clone(&self).introduce_blocking(port).await;
}
}
}
}
/// Spawns the background task that drains the peer's outgoing packet queue.
///
/// Every packet received from `io.rx` is passed to `send_or_introduce` together
/// with `port` and `_type`. The task ends when receiving fails or, after a
/// packet has been handled, when the peer is no longer alive.
pub(crate) fn run_io_sender(self: Arc<Self>, port: u16, _type: LinkType) {
    trace!("Running IO sender");
    task::spawn(async move {
        loop {
            let packet = match self.io.rx.recv().await {
                Ok(packet) => packet,
                Err(_) => break,
            };
            trace!("Queued packet {:?}", packet);
            self.send_or_introduce(packet, port, _type).await;
            if !self.alive() {
                break;
            }
        }
        trace!("Shutting down packet sender for peer {}", self.id);
    });
}
async fn introduce_blocking(self: Arc<Self>, _port: u16) {
let id = self.id.clone();
let dst = self.dst.clone().unwrap();
let run = Arc::clone(&self._run);
let sender = Arc::clone(&self.sender.get_ref());
let mut ctr = 0;
while run.load(Ordering::Relaxed) {
let pre = match ctr {
0 => "".into(),
n => format!("[retry #{}]", n),
};
if sender.read().await.is_some() {
trace!(
"Peer `{}` (ID: {}) is already connected!",
dst.to_string(),
id
);
break;
}
trace!(
"{}: Attempting to connect to peer `{}`",
pre,
dst.to_string()
);
let s = match TcpStream::connect(dst).await {
Ok(s) => s,
Err(_) => {
error!(
"Failed to connect to peer `{}`. Starting timeout...",
dst.to_string()
);
task::sleep(Duration::from_secs(5)).await;
ctr += 1;
continue;
}
};
s.set_nodelay(true).unwrap();
trace!("Successfully connected to peer `{}`", &dst);
let mut sender = sender.write().await;
*sender = Some(s);
break;
}
}
/// Queues `packet` on the peer's outgoing channel (`io.tx`).
///
/// # Panics
///
/// Panics if sending on the channel fails.
pub(crate) async fn send(&self, packet: Packet) {
self.io.tx.send(packet).await.unwrap();
}
/// Returns a copy of the currently stored source address, if any.
pub(crate) fn get_src(&self) -> Option<SourceAddr> {
*self.src.get_ref().clone()
}
/// Returns a copy of the destination address, if the peer has one.
pub(crate) fn get_dst(&self) -> Option<DstAddr> {
    self.dst
}
}