pub mod async_intrf;
use std::io::ErrorKind;
use std::{sync::Arc, time::Duration};
use async_trait::async_trait;
use tokio::net::{TcpStream, UdpSocket};
use tokio::io::{AsyncWriteExt, AsyncReadExt};
use tokio::time::timeout;
use tokio::net::{TcpSocket};
use crate::a_sync::network::SocketTap;
use crate::network_common::SocketTapCommon;
use crate::{internal_error, internal_error_map, CDnsErrorType};
use crate::{a_sync::{network::{NetworkTap, NetworkTapType}, SocketTaps}, cfg_resolv_parser::ResolveConfEntry, CDnsResult};
/// Marker type selecting the tokio runtime as the socket backend.
///
/// It carries no state; it only serves as the type parameter that ties the
/// tokio-specific `NetworkTap` implementations together.
#[derive(Debug, Clone)]
pub struct TokioSocketBase;
/// Factory producing tokio-backed [`NetworkTap`]s, one per supported transport.
impl SocketTaps<TokioSocketBase> for TokioSocketBase
{
    /// Transport used for plain DNS over TCP.
    type TcpSock = TcpStream;
    /// Transport used for plain DNS over UDP.
    type UdpSock = UdpSocket;
    /// Transport used for DNS over TLS (only with the `use_async_tokio_tls` feature).
    #[cfg(feature = "use_async_tokio_tls")]
    type TlsSock = self::with_tls::TcpTlsConnection;

    /// Builds a TCP tap for `resolver`; `timeout` is handed to the tap unchanged.
    #[inline]
    fn new_tcp_socket(resolver: Arc<ResolveConfEntry>, timeout: Duration) -> CDnsResult<Box<NetworkTapType<TokioSocketBase>>>
    {
        NetworkTap::<TcpStream, Self>::new(resolver, timeout)
    }

    /// Builds a UDP tap for `resolver`; `timeout` is handed to the tap unchanged.
    #[inline]
    fn new_udp_socket(resolver: Arc<ResolveConfEntry>, timeout: Duration) -> CDnsResult<Box<NetworkTapType<TokioSocketBase>>>
    {
        NetworkTap::<UdpSocket, Self>::new(resolver, timeout)
    }

    /// Builds a TLS tap for `resolver`; `timeout` is handed to the tap unchanged.
    #[cfg(feature = "use_async_tokio_tls")]
    #[inline]
    fn new_tls_socket(resolver: Arc<ResolveConfEntry>, timeout: Duration) -> CDnsResult<Box<NetworkTapType<TokioSocketBase>>>
    {
        NetworkTap::<self::with_tls::TcpTlsConnection, Self>::new(resolver, timeout)
    }
}
#[cfg(feature = "use_async_tokio_tls")]
pub mod with_tls
{
    use std::io::ErrorKind;
    use std::os::fd::{AsFd, BorrowedFd};
    use std::sync::Arc;
    use std::time::Duration;
    use async_trait::async_trait;
    use rustls::pki_types::ServerName;
    use rustls::RootCertStore;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;
    use tokio::time::timeout;
    use tokio_rustls::client::TlsStream;
    use crate::a_sync::network::{NetworkTap, SocketTap};
    use crate::a_sync::tokio_exc::new_tcp_stream;
    use crate::a_sync::TokioSocketBase;
    use crate::cfg_resolv_parser::ResolveConfEntry;
    use crate::network_common::SocketTapCommon;
    use crate::{internal_error, internal_error_map, CDnsErrorType, CDnsResult};

    /// A client TLS session (tokio-rustls) layered over a tokio [`TcpStream`].
    #[derive(Debug)]
    pub struct TcpTlsConnection
    {
        /// The TLS stream; `get_ref().0` is the raw TCP connection underneath it.
        stream: TlsStream<TcpStream>,
    }

    impl AsFd for TcpTlsConnection
    {
        /// Exposes the file descriptor of the underlying TCP socket.
        fn as_fd(&self) -> BorrowedFd<'_>
        {
            return self.stream.get_ref().0.as_fd();
        }
    }

    impl TcpTlsConnection
    {
        /// Opens a TCP connection to the resolver in `cfg` and performs the TLS handshake.
        ///
        /// The server certificate is checked against the `webpki_roots` trust anchors,
        /// using the name returned by `cfg.get_tls_domain()`. Only TLS 1.2 is offered.
        ///
        /// # Errors
        /// * `InternalError` if no TLS domain is configured or it is not a valid server name.
        /// * `IoError` if the TCP connect, the handshake or the post-handshake flush fails.
        async
        fn connect(cfg: &ResolveConfEntry, conn_timeout: Option<Duration>) -> CDnsResult<Self>
        {
            let domain_name =
                if let Some(domainname) = cfg.get_tls_domain()
                {
                    ServerName::try_from(domainname.clone())
                        .map_err(|e| internal_error_map!(CDnsErrorType::InternalError, "{}", e))?
                }
                else
                {
                    internal_error!(CDnsErrorType::InternalError, "no domain is set for TLS connection");
                };

            // NOTE(review): only TLS 1.2 is enabled; confirm whether TLS 1.3 should be offered too.
            let config =
                rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS12])
                    .with_root_certificates(RootCertStore { roots: webpki_roots::TLS_SERVER_ROOTS.into() })
                    .with_no_client_auth();

            let connector = tokio_rustls::TlsConnector::from(Arc::new(config));
            let socket = new_tcp_stream(cfg, conn_timeout).await?;

            let mut stream_tls =
                connector
                    .connect(domain_name, socket)
                    .await
                    .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

            stream_tls.flush().await.map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

            return Ok(Self { stream: stream_tls });
        }

        /// Waits up to `timeout_dur` for the underlying TCP socket to become readable.
        ///
        /// NOTE(review): this checks readiness of the raw TCP socket only. Plaintext the
        /// TLS session has already decrypted and buffered is not reflected here, so this
        /// wait may time out even though a subsequent `read` could return data. Confirm
        /// how callers sequence `poll_read` and `recv`.
        async
        fn internal_poll_read(&self, timeout_dur: Duration) -> CDnsResult<()>
        {
            timeout(timeout_dur, self.stream.get_ref().0.readable())
                .await
                .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "Timeout {}", e))?
                .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "socket poll error {}", e))
        }
    }

    /// DNS-over-TLS transport: TCP framing (length prefix) with an encrypted stream.
    #[async_trait]
    impl SocketTap<TokioSocketBase> for NetworkTap<TcpTlsConnection, TokioSocketBase>
    {
        /// Establishes the TLS session unless one already exists.
        async
        fn connect(&mut self, conn_timeout: Option<Duration>) -> CDnsResult<()>
        {
            if self.sock.is_some()
            {
                return Ok(());
            }

            let socket = TcpTlsConnection::connect(self.cfg.as_ref(), conn_timeout).await?;
            self.sock = Some(socket);

            return Ok(());
        }

        fn is_encrypted(&self) -> bool
        {
            return true;
        }

        fn is_tcp(&self) -> bool
        {
            return true;
        }

        fn should_append_len(&self) -> bool
        {
            return true;
        }

        /// Waits up to `self.timeout` for the raw socket to become readable.
        /// Panics if called before a successful `connect`.
        async
        fn poll_read(&self) -> CDnsResult<()>
        {
            return self.sock.as_ref().unwrap().internal_poll_read(self.timeout).await;
        }

        /// Encrypts and sends the whole of `sndbuf`, returning its length.
        /// Panics if called before a successful `connect`.
        async
        fn send(&mut self, sndbuf: &[u8]) -> CDnsResult<usize>
        {
            let stream = &mut self.sock.as_mut().unwrap().stream;

            stream
                .write_all(sndbuf)
                .await
                .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

            // tokio-rustls may accept the plaintext while the encrypted records are still
            // buffered (socket not writable yet); flushing pushes them onto the wire.
            stream
                .flush()
                .await
                .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

            return Ok(sndbuf.len());
        }

        /// Reads decrypted data into `rcvbuf`, giving up after `self.timeout`.
        /// Panics if called before a successful `connect`.
        async
        fn recv(&mut self, rcvbuf: &mut [u8]) -> CDnsResult<usize>
        {
            /// Reads until data, EOF (`Ok(0)`) or a non-retryable error arrives.
            async
            fn sub_recv(this: &mut NetworkTap<TcpTlsConnection, TokioSocketBase>, rcvbuf: &mut [u8]) -> CDnsResult<usize>
            {
                loop
                {
                    match this.sock.as_mut().unwrap().stream.read(rcvbuf).await
                    {
                        Ok(n) =>
                        {
                            return Ok(n);
                        },
                        Err(ref e) if e.kind() == ErrorKind::WouldBlock =>
                        {
                            internal_error!(CDnsErrorType::RequestTimeout, "request timeout from: '{}'", this.get_remote_addr());
                        },
                        Err(ref e) if e.kind() == ErrorKind::Interrupted =>
                        {
                            continue;
                        },
                        Err(e) =>
                        {
                            internal_error!(CDnsErrorType::IoError, "{}", e);
                        }
                    }
                }
            }

            // An async read waits instead of reporting WouldBlock, so the deadline has to
            // be enforced here, as the plain TCP and UDP taps do.
            match timeout(self.timeout, sub_recv(self, rcvbuf)).await
            {
                Ok(r) => return r,
                Err(e) => internal_error!(CDnsErrorType::RequestTimeout, "{}", e)
            }
        }
    }
}
#[async_trait]
impl SocketTap<TokioSocketBase> for NetworkTap<UdpSocket, TokioSocketBase>
{
async
fn connect(&mut self, _conn_timeout: Option<Duration>) -> CDnsResult<()>
{
if self.sock.is_some() == true
{
return Ok(());
}
let socket =
UdpSocket::bind(self.cfg.get_adapter_ip())
.await
.map_err(|e| internal_error_map!(CDnsErrorType::InternalError, "{}", e))?;
socket.connect(self.cfg.get_resolver_sa())
.await
.map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;
self.sock = Some(socket);
return Ok(());
}
fn is_encrypted(&self) -> bool
{
return false;
}
fn is_tcp(&self) -> bool
{
return false;
}
fn should_append_len(&self) -> bool
{
return false;
}
async
fn poll_read(&self) -> CDnsResult<()>
{
timeout(self.timeout, self.sock.as_ref().unwrap().readable())
.await
.map_err(|e|
internal_error_map!(CDnsErrorType::IoError, "Timeout {}", e)
)?
.map_err(|e|
internal_error_map!(CDnsErrorType::IoError, "socket poll error {}", e)
)
}
async
fn send(&mut self, sndbuf: &[u8]) -> CDnsResult<usize>
{
return
self.sock.as_mut()
.unwrap()
.send(sndbuf)
.await
.map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e));
}
async
fn recv(&mut self, rcvbuf: &mut [u8]) -> CDnsResult<usize>
{
async
fn sub_recv(this: &mut NetworkTap<UdpSocket, TokioSocketBase>, rcvbuf: &mut [u8]) -> CDnsResult<usize>
{
loop
{
match this.sock.as_mut().unwrap().recv_from(rcvbuf).await
{
Ok((rcv_len, rcv_src)) =>
{
if &rcv_src != this.get_remote_addr()
{
internal_error!(
CDnsErrorType::DnsResponse,
"received answer from unknown host: '{}' exp: '{}'",
this.get_remote_addr(),
rcv_src
);
}
return Ok(rcv_len);
},
Err(ref e) if e.kind() == ErrorKind::WouldBlock =>
{
continue;
},
Err(ref e) if e.kind() == ErrorKind::Interrupted =>
{
continue;
},
Err(e) =>
{
internal_error!(CDnsErrorType::IoError, "{}", e);
}
} }
}
match timeout(self.timeout, sub_recv(self, rcvbuf)).await
{
Ok(r) =>
return r,
Err(e) =>
internal_error!(CDnsErrorType::RequestTimeout, "{}", e)
}
}
}
/// Opens a TCP connection to the resolver described by `cfg`.
///
/// The socket family (IPv4 or IPv6) follows the resolver IP. The socket is bound to the
/// configured adapter address, keepalive is disabled and `TCP_NODELAY` is enabled. When
/// `conn_timeout` is `Some`, the connect attempt is bounded by that duration; otherwise
/// no timeout is applied here.
///
/// # Errors
/// Every failure (socket creation, bind, option setup, connect, connect timeout) is
/// reported as `CDnsErrorType::IoError`.
async
fn new_tcp_stream(cfg: &ResolveConfEntry, conn_timeout: Option<Duration>) -> CDnsResult<TcpStream>
{
    let fresh =
        if cfg.get_resolver_ip().is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() };
    let sock = fresh.map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

    // Stops at the first failing step, exactly like three separate `?` checks.
    sock.bind(*cfg.get_adapter_ip())
        .and_then(|_| sock.set_keepalive(false))
        .and_then(|_| sock.set_nodelay(true))
        .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

    // `connect` is lazy: nothing happens until the future is awaited below.
    let pending = sock.connect(*cfg.get_resolver_sa());

    let stream =
        match conn_timeout
        {
            Some(limit) =>
                timeout(limit, pending)
                    .await
                    .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?
                    .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?,
            None =>
                pending
                    .await
                    .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?,
        };

    Ok(stream)
}
#[async_trait]
impl SocketTap<TokioSocketBase> for NetworkTap<TcpStream, TokioSocketBase>
{
async
fn connect(&mut self, conn_timeout: Option<Duration>) -> CDnsResult<()>
{
if self.sock.is_some() == true
{
return Ok(());
}
let tcpstream = new_tcp_stream(&self.cfg, conn_timeout).await?;
self.sock = Some(tcpstream);
return Ok(());
}
fn is_encrypted(&self) -> bool
{
return false;
}
fn is_tcp(&self) -> bool
{
return true;
}
fn should_append_len(&self) -> bool
{
return true;
}
async
fn poll_read(&self) -> CDnsResult<()>
{
timeout(self.timeout, self.sock.as_ref().unwrap().readable())
.await
.map_err(|e|
internal_error_map!(CDnsErrorType::IoError, "Timeout {}", e)
)?
.map_err(|e|
internal_error_map!(CDnsErrorType::IoError, "socket poll error {}", e)
)
}
async
fn send(&mut self, sndbuf: &[u8]) -> CDnsResult<usize>
{
return
self
.sock
.as_mut()
.unwrap()
.write(sndbuf)
.await
.map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e));
}
async
fn recv(&mut self, rcvbuf: &mut [u8]) -> CDnsResult<usize>
{
async
fn sub_recv(this: &mut NetworkTap<TcpStream, TokioSocketBase>, rcvbuf: &mut [u8]) -> CDnsResult<usize>
{
loop
{
match this.sock.as_mut().unwrap().read(rcvbuf).await
{
Ok(n) =>
{
return Ok(n);
},
Err(ref e) if e.kind() == ErrorKind::WouldBlock =>
{
continue;
},
Err(ref e) if e.kind() == ErrorKind::Interrupted =>
{
continue;
},
Err(e) =>
{
internal_error!(CDnsErrorType::IoError, "{}", e);
}
} } }
match timeout(self.timeout, sub_recv(self, rcvbuf)).await
{
Ok(r) => return r,
Err(e) => internal_error!(CDnsErrorType::RequestTimeout, "{}", e)
}
}
}
#[cfg(test)]
mod tests
{
    use std::{net::{IpAddr, SocketAddr}, sync::Arc, time::Duration};
    use tokio::net::UdpSocket;
    use crate::{a_sync::{network::NetworkTap, TokioSocketBase}, cfg_resolv_parser::ResolveConfEntry, common::IPV4_BIND_ALL};

    /// Constructing a UDP tap for a loopback resolver must succeed.
    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn test_struct()
    {
        let resolver_ip: IpAddr = "127.0.0.1".parse().unwrap();
        let local = SocketAddr::from((IPV4_BIND_ALL, 0));
        let entry = ResolveConfEntry::new(SocketAddr::new(resolver_ip, 53), None, local).unwrap();

        let outcome =
            NetworkTap::<UdpSocket, TokioSocketBase>::new(Arc::new(entry), Duration::from_secs(5));

        match outcome
        {
            Ok(_tap) => {},
            Err(e) => panic!("{}", e),
        }
    }
}