use std::{io::ErrorKind, os::fd::AsFd, sync::Arc, time::Duration};
use async_trait::async_trait;
use cdns_rs::
{
a_sync::
{
interface::{AsyncMutex, AsyncMutexGuard, MutexedCaches, UnifiedFs},
network::NetworkTapType,
request,
CacheInstance,
CachesController,
NetworkTap,
QDns,
SocketTap,
SocketTaps
},
cfg_resolv_parser::ResolveConfEntry,
internal_error,
internal_error_map,
CDnsErrorType,
CDnsResult,
HostConfig,
QType,
QuerySetup,
ResolveConfig,
SocketTapCommon
};
use tokio::
{
fs::File,
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpSocket, TcpStream, UdpSocket},
sync::{Mutex, MutexGuard},
time::{timeout, Instant}
};
/// Marker type that plugs this file's tokio-backed socket wrappers
/// (`LocaltcpStrem` for TCP, `LocalUdpSocket` for UDP) into `cdns_rs`.
#[derive(Clone, Debug)]
pub struct SocketCustomBase;

impl SocketTaps<SocketCustomBase> for SocketCustomBase
{
    type TcpSock = LocaltcpStrem;
    type UdpSock = LocalUdpSocket;

    /// Creates a TCP network tap for `resolver`, using `timeout` as its I/O timeout.
    #[inline]
    fn new_tcp_socket(resolver: Arc<ResolveConfEntry>, timeout: Duration) -> CDnsResult<Box<NetworkTapType<SocketCustomBase>>>
    {
        NetworkTap::<LocaltcpStrem, SocketCustomBase>::new(resolver, timeout)
    }

    /// Creates a UDP network tap for `resolver`, using `timeout` as its I/O timeout.
    #[inline]
    fn new_udp_socket(resolver: Arc<ResolveConfEntry>, timeout: Duration) -> CDnsResult<Box<NetworkTapType<SocketCustomBase>>>
    {
        NetworkTap::<LocalUdpSocket, SocketCustomBase>::new(resolver, timeout)
    }
}
/// Newtype over tokio's `UdpSocket`, so that `cdns_rs` traits can be
/// implemented for it locally.
#[repr(transparent)]
#[derive(Debug)]
pub struct LocalUdpSocket(UdpSocket);

impl AsFd for LocalUdpSocket
{
    /// Borrows the file descriptor of the wrapped tokio socket.
    fn as_fd(&self) -> std::os::fd::BorrowedFd<'_>
    {
        let LocalUdpSocket(inner) = self;
        inner.as_fd()
    }
}
/// Newtype over tokio's `TcpStream`, so that `cdns_rs` traits can be
/// implemented for it locally.
#[derive(Debug)]
pub struct LocaltcpStrem(TcpStream);

impl AsFd for LocaltcpStrem
{
    /// Borrows the file descriptor of the wrapped tokio stream.
    fn as_fd(&self) -> std::os::fd::BorrowedFd<'_>
    {
        let LocaltcpStrem(inner) = self;
        inner.as_fd()
    }
}
async
fn new_tcp_stream(cfg: &ResolveConfEntry, conn_timeout: Option<Duration>) -> CDnsResult<TcpStream>
{
let socket =
if cfg.get_resolver_ip().is_ipv4() == true
{
TcpSocket::new_v4()
}
else
{
TcpSocket::new_v6()
}
.map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;
socket.bind(*cfg.get_adapter_ip()).map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;
socket.set_keepalive(false).map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;
socket.set_nodelay(true).map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;
let tcpstream =
if let Some(c_timeout) = conn_timeout
{
timeout(c_timeout, socket.connect(*cfg.get_resolver_sa()))
.await
.map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?
.map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?
}
else
{
socket
.connect(*cfg.get_resolver_sa())
.await
.map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?
};
return Ok(tcpstream);
}
/// `SocketTap` implementation for plain (unencrypted) DNS over TCP, built on
/// the tokio-backed `LocaltcpStrem` wrapper.
///
/// # Panics
/// `poll_read`, `send` and `recv` call `unwrap()` on `self.sock`. They panic
/// if `connect` has not completed successfully first.
#[async_trait]
impl SocketTap<SocketCustomBase> for NetworkTap<LocaltcpStrem, SocketCustomBase>
{
    /// Establishes the TCP connection unless one already exists.
    ///
    /// `conn_timeout` optionally bounds the connect attempt (see `new_tcp_stream`).
    async
    fn connect(&mut self, conn_timeout: Option<Duration>) -> CDnsResult<()>
    {
        // Already connected: nothing to do.
        if self.sock.is_some()
        {
            return Ok(());
        }

        let tcpstream = new_tcp_stream(&self.cfg, conn_timeout).await?;
        self.sock = Some(LocaltcpStrem(tcpstream));

        return Ok(());
    }

    fn is_encrypted(&self) -> bool
    {
        return false;
    }

    fn is_tcp(&self) -> bool
    {
        return true;
    }

    /// TCP transport: this tap requests a length to be appended to messages.
    fn should_append_len(&self) -> bool
    {
        return true;
    }

    /// Waits, bounded by `self.timeout`, until the stream becomes readable.
    ///
    /// # Errors
    /// `IoError` if the timeout elapses or the readiness poll fails.
    async
    fn poll_read(&self) -> CDnsResult<()>
    {
        timeout(self.timeout, self.sock.as_ref().unwrap().0.readable())
            .await
            .map_err(|e|
                internal_error_map!(CDnsErrorType::IoError, "Timeout {}", e)
            )?
            .map_err(|e|
                internal_error_map!(CDnsErrorType::IoError, "socket poll error {}", e)
            )
    }

    /// Writes the whole of `sndbuf` to the stream and returns `sndbuf.len()`.
    ///
    /// A single `write` on a TCP stream may accept only a prefix of the
    /// buffer, which would leave the DNS message truncated on the wire.
    /// `write_all` keeps writing until every byte has been handed to the socket.
    ///
    /// # Errors
    /// `IoError` if any underlying write fails.
    async
    fn send(&mut self, sndbuf: &[u8]) -> CDnsResult<usize>
    {
        self
            .sock
            .as_mut()
            .unwrap()
            .0
            .write_all(sndbuf)
            .await
            .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

        return Ok(sndbuf.len());
    }

    /// Reads available bytes into `rcvbuf`, bounded by `self.timeout`.
    ///
    /// `WouldBlock` and `Interrupted` errors are retried. The byte count from
    /// `read` is returned as-is, including `Ok(0)`.
    ///
    /// # Errors
    /// `RequestTimeout` if `self.timeout` elapses, `IoError` for other read failures.
    async
    fn recv(&mut self, rcvbuf: &mut [u8]) -> CDnsResult<usize>
    {
        // Retry loop around a single read; the caller applies the timeout.
        async
        fn sub_recv(this: &mut NetworkTap<LocaltcpStrem, SocketCustomBase>, rcvbuf: &mut [u8]) -> CDnsResult<usize>
        {
            loop
            {
                match this.sock.as_mut().unwrap().0.read(rcvbuf).await
                {
                    Ok(n) =>
                    {
                        return Ok(n);
                    },
                    Err(ref e) if e.kind() == ErrorKind::WouldBlock =>
                    {
                        continue;
                    },
                    Err(ref e) if e.kind() == ErrorKind::Interrupted =>
                    {
                        continue;
                    },
                    Err(e) =>
                    {
                        internal_error!(CDnsErrorType::IoError, "{}", e);
                    }
                }
            }
        }

        match timeout(self.timeout, sub_recv(self, rcvbuf)).await
        {
            Ok(r) => return r,
            Err(e) => internal_error!(CDnsErrorType::RequestTimeout, "{}", e)
        }
    }
}
/// `SocketTap` implementation for plain DNS over UDP, built on the
/// tokio-backed `LocalUdpSocket` wrapper.
///
/// # Panics
/// `poll_read`, `send` and `recv` call `unwrap()` on `self.sock`. They panic
/// if `connect` has not completed successfully first.
#[async_trait]
impl SocketTap<SocketCustomBase> for NetworkTap<LocalUdpSocket, SocketCustomBase>
{
    /// Binds a UDP socket to the adapter address and connects it to the
    /// resolver, unless a socket already exists.
    ///
    /// `_conn_timeout` is ignored for UDP.
    ///
    /// # Errors
    /// Bind failures are reported as `InternalError`; connect failures as `IoError`.
    async
    fn connect(&mut self, _conn_timeout: Option<Duration>) -> CDnsResult<()>
    {
        // Already connected: nothing to do.
        if self.sock.is_some()
        {
            return Ok(());
        }

        let socket =
            UdpSocket::bind(self.cfg.get_adapter_ip())
                .await
                .map_err(|e| internal_error_map!(CDnsErrorType::InternalError, "{}", e))?;

        socket.connect(self.cfg.get_resolver_sa())
            .await
            .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

        self.sock = Some(LocalUdpSocket(socket));

        return Ok(());
    }

    fn is_encrypted(&self) -> bool
    {
        return false;
    }

    fn is_tcp(&self) -> bool
    {
        return false;
    }

    /// UDP transport: this tap does not request a length to be appended.
    fn should_append_len(&self) -> bool
    {
        return false;
    }

    /// Waits, bounded by `self.timeout`, until the socket becomes readable.
    ///
    /// # Errors
    /// `IoError` if the timeout elapses or the readiness poll fails.
    async
    fn poll_read(&self) -> CDnsResult<()>
    {
        timeout(self.timeout, self.sock.as_ref().unwrap().0.readable())
            .await
            .map_err(|e|
                internal_error_map!(CDnsErrorType::IoError, "Timeout {}", e)
            )?
            .map_err(|e|
                internal_error_map!(CDnsErrorType::IoError, "socket poll error {}", e)
            )
    }

    /// Sends `sndbuf` as one datagram to the connected resolver.
    ///
    /// # Errors
    /// `IoError` if the send fails.
    async
    fn send(&mut self, sndbuf: &[u8]) -> CDnsResult<usize>
    {
        return
            self.sock.as_mut()
                .unwrap()
                .0
                .send(sndbuf)
                .await
                .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e));
    }

    /// Receives one datagram into `rcvbuf`, bounded by `self.timeout`.
    ///
    /// A datagram whose source differs from `get_remote_addr()` is rejected.
    /// `WouldBlock` and `Interrupted` errors are retried.
    ///
    /// # Errors
    /// `DnsResponse` for a reply from an unexpected source, `RequestTimeout` if
    /// `self.timeout` elapses, `IoError` for other receive failures.
    async
    fn recv(&mut self, rcvbuf: &mut [u8]) -> CDnsResult<usize>
    {
        // Retry loop around a single recv_from; the caller applies the timeout.
        async
        fn sub_recv(this: &mut NetworkTap<LocalUdpSocket, SocketCustomBase>, rcvbuf: &mut [u8]) -> CDnsResult<usize>
        {
            loop
            {
                match this.sock.as_mut().unwrap().0.recv_from(rcvbuf).await
                {
                    Ok((rcv_len, rcv_src)) =>
                    {
                        // The socket was `connect`ed in `connect()`, which should already
                        // restrict the peer. This is a defensive check on the sender.
                        if &rcv_src != this.get_remote_addr()
                        {
                            // Report the actual (unknown) sender first, then the expected address.
                            internal_error!(
                                CDnsErrorType::DnsResponse,
                                "received answer from unknown host: '{}' exp: '{}'",
                                rcv_src,
                                this.get_remote_addr()
                            );
                        }

                        return Ok(rcv_len);
                    },
                    Err(ref e) if e.kind() == ErrorKind::WouldBlock =>
                    {
                        continue;
                    },
                    Err(ref e) if e.kind() == ErrorKind::Interrupted =>
                    {
                        continue;
                    },
                    Err(e) =>
                    {
                        internal_error!(CDnsErrorType::IoError, "{}", e);
                    }
                }
            }
        }

        match timeout(self.timeout, sub_recv(self, rcvbuf)).await
        {
            Ok(r) =>
                return r,
            Err(e) =>
                internal_error!(CDnsErrorType::RequestTimeout, "{}", e)
        }
    }
}
/// Bundles this file's tokio-based filesystem (`LocalFile`) and mutex
/// (`LocalMutex`) adapters as the cache backend for `cdns_rs`.
#[derive(Debug)]
pub struct LocalTokioInterf;

impl MutexedCaches for LocalTokioInterf
{
    /// Filesystem access used by the caches.
    type MetadataFs = LocalFile;

    /// Mutex-guarded cache holding a `ResolveConfig`.
    type ResolveCache = LocalMutex<CacheInstance<ResolveConfig, LocalFile>>;

    /// Mutex-guarded cache holding a `HostConfig`. The associated-type name
    /// (including its spelling) is dictated by the `MutexedCaches` trait.
    type HostCahae = LocalMutex<CacheInstance<HostConfig, LocalFile>>;
}
#[derive(Debug)]
pub struct LocalFile;
impl UnifiedFs for LocalFile
{
type ErrRes = std::io::Error;
type FileOp = File;
async
fn metadata(path: &std::path::Path) -> Result<std::fs::Metadata, Self::ErrRes>
{
return tokio::fs::metadata(path).await;
}
async
fn open<P: AsRef<std::path::Path>>(path: P) -> std::io::Result<Self::FileOp>
{
return tokio::fs::File::open(path).await;
}
async
fn read_to_string(file: &mut Self::FileOp, buf: &mut String) -> std::io::Result<usize>
{
return file.read_to_string(buf).await;
}
}
/// `AsyncMutex` adapter for `cdns_rs` wrapping `tokio::sync::Mutex`.
#[repr(transparent)]
#[derive(Debug)]
pub struct LocalMutex<DS: Sized>(Mutex<DS>);

impl<DS: Sized> AsyncMutex<DS> for LocalMutex<DS>
{
    type MutxGuard<'mux> = LocalMutexGuard<'mux, DS> where DS: 'mux;

    /// Wraps `v` in a new tokio mutex.
    fn a_new(v: DS) -> Self
    {
        Self(Mutex::new(v))
    }

    /// Asynchronously acquires the lock and returns the wrapped guard.
    async
    fn a_lock<'mux>(&'mux self) -> Self::MutxGuard<'mux>
    {
        let guard = self.0.lock().await;
        LocalMutexGuard(guard)
    }
}
/// Guard returned by `LocalMutex::a_lock`, wrapping a tokio `MutexGuard`.
#[repr(transparent)]
#[derive(Debug)]
pub struct LocalMutexGuard<'mux, DS: Sized>(MutexGuard<'mux, DS>);

impl<'mux, DS: Sized> AsyncMutexGuard<'mux, DS> for LocalMutexGuard<'mux, DS>
{
    /// Shared access to the protected value.
    fn guard(&self) -> &DS
    {
        &*self.0
    }

    /// Exclusive access to the protected value.
    fn guard_mut(&mut self) -> &mut DS
    {
        &mut *self.0
    }
}
/// Demo entry point.
///
/// Using the custom socket/cache/fs glue from this file and the resolver
/// `1.1.1.1`, it sequentially resolves:
/// - A/AAAA and MX records for `protonmail.com`,
/// - PTR records for `::1` and `8.8.8.8`,
/// - the SOA record for `protonmail.com`.
///
/// It then prints the elapsed time followed by each result set. Setup and the
/// A/MX/PTR lookups `unwrap()`, so any failure there panics; a failed SOA
/// query is printed instead.
#[tokio::main]
async fn main()
{
    let caches = Arc::new(CachesController::<LocalTokioInterf>::new_custom().await.unwrap());
    let resolv_conf = Arc::new(ResolveConfig::async_custom_config("nameserver 1.1.1.1").await.unwrap());

    let started = Instant::now();

    let res_a =
        request::resolve_fqdn::<_, SocketCustomBase, SocketCustomBase, LocalTokioInterf>("protonmail.com", Some(resolv_conf.clone()), caches.clone())
            .await
            .unwrap();
    let res_mx =
        request::resolve_mx::<_, SocketCustomBase, SocketCustomBase, LocalTokioInterf>("protonmail.com", Some(resolv_conf.clone()), caches.clone())
            .await
            .unwrap();
    let res_ptr_local =
        request::resolve_reverse::<_, SocketCustomBase, SocketCustomBase, LocalTokioInterf>("::1", Some(resolv_conf.clone()), caches.clone())
            .await
            .unwrap();
    let res_ptr =
        request::resolve_reverse::<_, SocketCustomBase, SocketCustomBase, LocalTokioInterf>("8.8.8.8", Some(resolv_conf.clone()), caches.clone())
            .await
            .unwrap();

    let mut dns_req =
        QDns::<SocketCustomBase, SocketCustomBase, LocalTokioInterf>::make_empty(Some(resolv_conf.clone()), QuerySetup::default(), caches.clone())
            .await
            .unwrap();
    dns_req.add_request(QType::SOA, "protonmail.com");
    let soa_query = dns_req.query().await;

    println!("Elapsed: {:.2?}", started.elapsed());

    println!("A/AAAA:");
    res_a.into_iter().for_each(|a| println!("\t{}", a));

    println!("MX:");
    res_mx.into_iter().for_each(|mx| println!("\t{}", mx));

    println!("PTR local:");
    res_ptr_local.into_iter().for_each(|ptr| println!("\t{}", ptr));

    println!("PTR:");
    res_ptr.into_iter().for_each(|ptr| println!("\t{}", ptr));

    println!("SOA:");
    match soa_query.get_result()
    {
        Err(e) =>
            println!("error: {}", e),
        Ok(soa) if soa.is_empty() =>
            println!("\tNo SOA found!"),
        Ok(soa) =>
            soa.into_iter()
                .flat_map(|s| s.resp)
                .for_each(|rec| println!("\t{}", rec)),
    }
}