use std::io;
use std::net::SocketAddr;
use async_trait::async_trait;
use tokio::net::TcpStream;
use super::mse::{EncryptedStream, PeerStream};
/// The underlying protocol a peer connection runs over.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransportType {
    /// TCP connection.
    Tcp,
    /// uTP connection.
    Utp,
}
impl std::fmt::Display for TransportType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Tcp => write!(f, "TCP"),
Self::Utp => write!(f, "uTP"),
}
}
}
/// Transport-agnostic async interface to a connected peer.
///
/// Implemented in this module by `TcpTransport` and `UtpTransport`. The `Send + Sync`
/// bound means implementors can be moved to, and shared by reference across, other threads.
#[async_trait]
pub trait PeerTransport: Send + Sync {
    /// Address of the remote end of the connection.
    fn peer_addr(&self) -> io::Result<SocketAddr>;
    /// Address of the local end of the connection.
    fn local_addr(&self) -> io::Result<SocketAddr>;
    /// Which underlying protocol this transport runs over.
    fn transport_type(&self) -> TransportType;
    /// Whether traffic on this transport is encrypted.
    fn is_encrypted(&self) -> bool;
    /// Reads exactly `buf.len()` bytes into `buf`.
    // NOTE(review): assumed to mirror `tokio::io::AsyncReadExt::read_exact` (errors if the
    // stream ends before `buf` is full) — confirm against each implementor.
    async fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()>;
    /// Reads some bytes into `buf` and returns how many were read.
    // NOTE(review): presumably `Ok(0)` signals end of stream, as with `AsyncReadExt::read` —
    // confirm against each implementor.
    async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize>;
    /// Writes the entirety of `buf`.
    async fn write_all(&mut self, buf: &[u8]) -> io::Result<()>;
    /// Flushes any buffered outgoing data.
    async fn flush(&mut self) -> io::Result<()>;
    /// Shuts the connection down.
    // NOTE(review): whether this closes only the write half (as TCP `shutdown` does) or the
    // whole connection depends on the implementor — confirm.
    async fn shutdown(&mut self) -> io::Result<()>;
}
/// A peer connection backed by a [`PeerStream`].
///
/// Note that [`TcpTransport::from_peer_stream`] accepts any `PeerStream` variant, not only
/// TCP-backed ones.
pub struct TcpTransport {
    stream: PeerStream,
}

impl TcpTransport {
    /// Wraps a plaintext TCP stream as `PeerStream::Plain`.
    pub fn new(stream: TcpStream) -> Self {
        Self::from_peer_stream(PeerStream::Plain(stream))
    }

    /// Wraps an [`EncryptedStream`], boxing it into `PeerStream::Encrypted`.
    pub fn encrypted(stream: EncryptedStream) -> Self {
        let boxed = Box::new(stream);
        Self::from_peer_stream(PeerStream::Encrypted(boxed))
    }

    /// Wraps an existing [`PeerStream`] unchanged.
    pub fn from_peer_stream(stream: PeerStream) -> Self {
        TcpTransport { stream }
    }
}
#[async_trait]
impl PeerTransport for TcpTransport {
fn peer_addr(&self) -> io::Result<SocketAddr> {
self.stream.peer_addr()
}
fn local_addr(&self) -> io::Result<SocketAddr> {
match &self.stream {
PeerStream::Plain(s) => s.local_addr(),
PeerStream::Encrypted(s) => s.local_addr(),
PeerStream::Utp(s) => s.peer_addr(), }
}
fn transport_type(&self) -> TransportType {
TransportType::Tcp
}
fn is_encrypted(&self) -> bool {
self.stream.is_encrypted()
}
async fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
self.stream.read_exact(buf).await
}
async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.stream.read(buf).await
}
async fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
self.stream.write_all(buf).await
}
async fn flush(&mut self) -> io::Result<()> {
self.stream.flush().await
}
async fn shutdown(&mut self) -> io::Result<()> {
self.stream.shutdown().await
}
}
/// A peer connection backed by a uTP socket.
///
/// The local address is supplied by the caller and stored, not queried from the socket.
pub struct UtpTransport {
    socket: super::utp::UtpSocket,
    local_addr: SocketAddr,
}

impl UtpTransport {
    /// Builds a transport from a uTP socket and the local address to report for it.
    pub fn new(socket: super::utp::UtpSocket, local_addr: SocketAddr) -> Self {
        UtpTransport { local_addr, socket }
    }
}
#[async_trait]
impl PeerTransport for UtpTransport {
    // --- Connection metadata ---

    /// Always [`TransportType::Utp`].
    fn transport_type(&self) -> TransportType {
        TransportType::Utp
    }

    /// Always `false` for this transport.
    fn is_encrypted(&self) -> bool {
        false
    }

    /// The stored local address; the socket itself is never queried.
    fn local_addr(&self) -> io::Result<SocketAddr> {
        let addr = self.local_addr;
        Ok(addr)
    }

    /// Remote address as reported by the uTP socket.
    fn peer_addr(&self) -> io::Result<SocketAddr> {
        self.socket.peer_addr()
    }

    // --- I/O: every call forwards directly to the uTP socket ---

    async fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
        self.socket.write_all(buf).await
    }

    async fn flush(&mut self) -> io::Result<()> {
        self.socket.flush().await
    }

    async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.socket.read(buf).await
    }

    async fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
        self.socket.read_exact(buf).await
    }

    async fn shutdown(&mut self) -> io::Result<()> {
        self.socket.shutdown().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Display` must render each variant's display name.
    #[test]
    fn test_transport_type_display() {
        let cases = [(TransportType::Tcp, "TCP"), (TransportType::Utp, "uTP")];
        for (transport, expected) in cases.iter().copied() {
            assert_eq!(transport.to_string(), expected);
        }
    }
}