#![cfg(not(target_arch = "wasm32"))]
use std::collections::hash_map::Entry::{Occupied, Vacant};
use std::collections::HashMap;
use std::fmt::Debug;
use std::net::SocketAddr;
use futures::{SinkExt, StreamExt};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::select;
use tokio::task::spawn_local;
use tokio_stream::StreamMap;
use tokio_util::codec::{
BytesCodec, Decoder, Encoder, FramedRead, FramedWrite, LengthDelimitedCodec, LinesCodec,
};
use super::unsync::mpsc::{Receiver, Sender};
use super::unsync_channel;
/// Splits `stream` into independently owned write and read halves and wraps each one in a
/// framing adapter driven by its own copy of `codec`.
///
/// Returns `(writer, reader)`: a [`FramedWrite`] over the write half and a [`FramedRead`]
/// over the read half.
pub fn tcp_framed<Codec>(
    stream: TcpStream,
    codec: Codec,
) -> (
    FramedWrite<OwnedWriteHalf, Codec>,
    FramedRead<OwnedReadHalf, Codec>,
)
where
    Codec: Clone + Decoder,
{
    let (read_half, write_half) = stream.into_split();
    let writer_codec = codec.clone();
    (
        FramedWrite::new(write_half, writer_codec),
        FramedRead::new(read_half, codec),
    )
}
/// Frames `stream` with a [`LengthDelimitedCodec`] built by `LengthDelimitedCodec::new()`,
/// so each message on the wire is a length-prefixed chunk of bytes.
///
/// Returns the `(writer, reader)` pair produced by [`tcp_framed`].
pub fn tcp_bytes(
    stream: TcpStream,
) -> (
    FramedWrite<OwnedWriteHalf, LengthDelimitedCodec>,
    FramedRead<OwnedReadHalf, LengthDelimitedCodec>,
) {
    let codec = LengthDelimitedCodec::new();
    tcp_framed(stream, codec)
}
/// Wraps `stream` with a [`BytesCodec`], which passes raw bytes through without adding any
/// message framing of its own.
///
/// Returns the `(writer, reader)` pair produced by [`tcp_framed`].
pub fn tcp_bytestream(
    stream: TcpStream,
) -> (
    FramedWrite<OwnedWriteHalf, BytesCodec>,
    FramedRead<OwnedReadHalf, BytesCodec>,
) {
    let codec = BytesCodec::new();
    tcp_framed(stream, codec)
}
/// Frames `stream` as newline-delimited text using a [`LinesCodec`].
///
/// Returns the `(writer, reader)` pair produced by [`tcp_framed`].
///
/// NOTE(review): `LinesCodec::new()` places no upper bound on the length of a buffered line,
/// so a peer that never sends a newline can grow the read buffer without limit.
pub fn tcp_lines(
    stream: TcpStream,
) -> (
    FramedWrite<OwnedWriteHalf, LinesCodec>,
    FramedRead<OwnedReadHalf, LinesCodec>,
) {
    let codec = LinesCodec::new();
    tcp_framed(stream, codec)
}
/// Outgoing half returned by [`bind_tcp`] and [`connect_tcp`]: each `(item, addr)` pushed into
/// it is routed by the background task to the peer at `addr`.
pub type TcpFramedSink<T> = Sender<(T, SocketAddr)>;
/// Incoming half returned by [`bind_tcp`] and [`connect_tcp`]: yields each decoded item paired
/// with the address of the peer it came from, or the codec's decode error.
// The `Codec: Decoder` bound is not enforced on type aliases; it is kept (with the lint
// expectation below) purely so the alias reads as "a stream for this codec".
#[expect(type_alias_bounds, reason = "code readability")]
pub type TcpFramedStream<Codec: Decoder> =
    Receiver<Result<(<Codec as Decoder>::Item, SocketAddr), <Codec as Decoder>::Error>>;
/// Binds a TCP listener on `endpoint` and spawns a background task that multiplexes every
/// accepted connection through a single pair of channels.
///
/// Returns `(sink, stream, bound_endpoint)`:
/// * `sink` — send `(item, peer_addr)` to encode `item` and write it to the already-connected
///   peer at `peer_addr`. Messages addressed to a peer that is not connected are dropped with a
///   warning. If writing to a peer fails, that peer's write half is discarded.
/// * `stream` — yields `Ok((item, peer_addr))` for every frame decoded from any connected peer,
///   or the codec's `Err` when decoding fails.
/// * `bound_endpoint` — the address the listener actually bound (e.g. the concrete port when
///   `endpoint` uses port 0).
///
/// The background task is started with [`spawn_local`], so this must run inside a tokio
/// `LocalSet`. Dropping `sink` only stops outgoing traffic; the task keeps accepting peers and
/// forwarding their messages to `stream`.
///
/// # Errors
///
/// Returns the I/O error if binding the listener or querying its local address fails.
pub async fn bind_tcp<Item, Codec>(
    endpoint: SocketAddr,
    codec: Codec,
) -> Result<(TcpFramedSink<Item>, TcpFramedStream<Codec>, SocketAddr), std::io::Error>
where
    Item: 'static,
    Codec: 'static + Clone + Decoder + Encoder<Item>,
    <Codec as Encoder<Item>>::Error: Debug,
{
    let listener = TcpListener::bind(endpoint).await?;
    let bound_endpoint = listener.local_addr()?;

    let (send_egress, mut recv_egress) = unsync_channel::<(Item, SocketAddr)>(None);
    let (send_ingress, recv_ingress) = unsync_channel(None);

    spawn_local(async move {
        // Write halves of connected peers, keyed by peer address.
        let mut peers_send = HashMap::new();
        // Read halves of connected peers; `StreamMap` removes a stream once it ends.
        let mut peers_recv = StreamMap::<SocketAddr, FramedRead<OwnedReadHalf, Codec>>::new();
        // Cleared once every egress sender is dropped. A closed receiver resolves to `None`
        // on every poll, so without this guard the `biased` select would keep choosing that
        // branch and spin forever without servicing reads or accepts.
        let mut egress_open = true;

        loop {
            select! {
                // Poll in listed order: outgoing messages, then incoming, then new peers.
                biased;
                msg_send = recv_egress.next(), if egress_open => {
                    let Some((payload, peer_addr)) = msg_send else {
                        egress_open = false;
                        continue;
                    };
                    let Some(stream) = peers_send.get_mut(&peer_addr) else {
                        tracing::warn!("Dropping message to non-connected peer: {}", peer_addr);
                        continue;
                    };
                    if let Err(err) = SinkExt::send(stream, payload).await {
                        tracing::error!("IO or codec error sending message to peer {}, disconnecting: {:?}", peer_addr, err);
                        peers_send.remove(&peer_addr);
                    }
                }
                msg_recv = peers_recv.next(), if !peers_recv.is_empty() => {
                    let Some((peer_addr, payload_result)) = msg_recv else {
                        continue;
                    };
                    if let Err(err) = send_ingress.send(payload_result.map(|payload| (payload, peer_addr))).await {
                        tracing::error!("Error passing along received message: {:?}", err);
                    }
                }
                new_peer = listener.accept() => {
                    // `accept` already reports the remote address; no extra `peer_addr()` call.
                    let (stream, peer_addr) = match new_peer {
                        Ok(accepted) => accepted,
                        Err(err) => {
                            tracing::warn!("Failed to accept TCP connection: {:?}", err);
                            continue;
                        }
                    };
                    let (peer_send, peer_recv) = tcp_framed(stream, codec.clone());
                    peers_send.insert(peer_addr, peer_send);
                    peers_recv.insert(peer_addr, peer_recv);
                }
            }
        }
    });

    Ok((send_egress, recv_ingress, bound_endpoint))
}
/// Spawns a background task that lazily opens one outgoing TCP connection per peer and
/// multiplexes all of them through a single pair of channels.
///
/// Returns `(sink, stream)`:
/// * `sink` — send `(item, peer_addr)`. The first message to a given `peer_addr` opens a
///   connection to it; later messages reuse that connection. If connecting fails, the message
///   is dropped and an error is logged. If writing fails, the connection's write half is
///   discarded, and the next message to that address reconnects.
/// * `stream` — yields `Ok((item, peer_addr))` for each frame decoded from any open
///   connection, or the codec's `Err` when decoding fails.
///
/// The background task is started with [`spawn_local`], so this must run inside a tokio
/// `LocalSet`. The task exits once `sink` has been dropped and every open connection's read
/// side has ended; that drops the ingress sender and thereby ends `stream`.
pub fn connect_tcp<Item, Codec>(codec: Codec) -> (TcpFramedSink<Item>, TcpFramedStream<Codec>)
where
    Item: 'static,
    Codec: 'static + Clone + Decoder + Encoder<Item>,
    <Codec as Encoder<Item>>::Error: Debug,
{
    let (send_egress, mut recv_egress) = unsync_channel(None);
    let (send_ingress, recv_ingress) = unsync_channel(None);

    spawn_local(async move {
        // Write halves of open connections, keyed by peer address.
        let mut peers_send = HashMap::new();
        // Read halves of open connections; `StreamMap` removes a stream once it ends.
        let mut peers_recv = StreamMap::new();
        // Cleared once every egress sender is dropped. A closed receiver resolves to `None`
        // on every poll, so without this guard the `biased` select would spin on it forever.
        let mut egress_open = true;

        loop {
            select! {
                biased;
                msg_send = recv_egress.next(), if egress_open => {
                    let Some((payload, peer_addr)) = msg_send else {
                        egress_open = false;
                        continue;
                    };
                    let stream = match peers_send.entry(peer_addr) {
                        Occupied(entry) => entry.into_mut(),
                        Vacant(entry) => {
                            // Match the socket family to the peer so IPv6 peers are reachable.
                            let socket = if peer_addr.is_ipv4() {
                                TcpSocket::new_v4()
                            } else {
                                TcpSocket::new_v6()
                            };
                            let connected = match socket {
                                Ok(socket) => socket.connect(peer_addr).await,
                                Err(err) => Err(err),
                            };
                            // A refused/failed connection must not take down the whole task.
                            let tcp_stream = match connected {
                                Ok(tcp_stream) => tcp_stream,
                                Err(err) => {
                                    tracing::error!("Failed to connect to peer {}, dropping message: {:?}", peer_addr, err);
                                    continue;
                                }
                            };
                            let (peer_send, peer_recv) = tcp_framed(tcp_stream, codec.clone());
                            peers_recv.insert(peer_addr, peer_recv);
                            entry.insert(peer_send)
                        }
                    };
                    if let Err(err) = stream.send(payload).await {
                        tracing::error!("IO or codec error sending message to peer {}, disconnecting: {:?}", peer_addr, err);
                        peers_send.remove(&peer_addr);
                    }
                }
                msg_recv = peers_recv.next(), if !peers_recv.is_empty() => {
                    let Some((peer_addr, payload_result)) = msg_recv else {
                        continue;
                    };
                    if let Err(err) = send_ingress.send(payload_result.map(|payload| (payload, peer_addr))).await {
                        tracing::error!("Error passing along received message: {:?}", err);
                    }
                }
                // Sink dropped and no open read halves: nothing can ever arrive again.
                // Without this arm `select!` would panic with every branch disabled.
                else => break,
            }
        }
    });

    (send_egress, recv_ingress)
}