use std::cell::RefCell;
use std::fmt;
use std::io::{ErrorKind, Write};
use std::io::prelude::*;
use std::net::SocketAddr;
use std::net::UdpSocket;
use std::os::fd::AsFd;
use std::sync::Arc;
use std::time::Duration;
use std::fmt::Debug;
use std::net::TcpStream;
use nix::poll::{PollFd, PollFlags};
use socket2::{Socket, Domain, Type, Protocol, SockAddr};
use crate::cfg_resolv_parser::ResolveConfEntry;
use crate::common::DnsRequestAnswer;
use crate::network_common::SocketTapCommon;
use crate::{internal_error, internal_error_map, write_error};
use crate::error::*;
/// A synchronous transport that carries DNS messages to and from one resolver.
///
/// Implemented in this file for UDP and plain TCP, and (behind the
/// `use_sync_tls` feature) for DNS-over-TLS and DNS-over-HTTPS. The resolver
/// address is available through the [`SocketTapCommon`] supertrait.
pub trait SocketTap: SocketTapCommon
{
    /// `true` when the transport runs over a TCP stream.
    fn is_tcp(&self) -> bool;

    /// `true` when the traffic is protected by TLS.
    fn is_encrypted(&self) -> bool;

    /// `true` when outgoing messages must carry a 2-byte big-endian length
    /// prefix (the framing the TCP and TLS `recv` implementations read back).
    /// The implementations in this file do not add that prefix in `send`.
    fn should_append_len(&self) -> bool;

    /// Waits for the socket to become readable.
    ///
    /// `Ok(true)` means data can be read; `Ok(false)` means the poll reported
    /// an error/hang-up condition or nothing readable.
    fn poll_read(&self) -> CDnsResult<bool>;

    /// Writes `sndbuf` to the resolver, returning the number of bytes written.
    fn send(&self, sndbuf: &[u8]) -> CDnsResult<usize>;

    /// Reads one response from the resolver and parses it.
    fn recv(&self) -> CDnsResult<DnsRequestAnswer>;
}
impl Debug for dyn SocketTap
{
    /// Formats a transport trait object as a short summary.
    ///
    /// `SocketTap` does not require its implementors to be `Debug`, so the
    /// concrete socket cannot be printed. Formatting `self` with `{:?}` (as a
    /// naive implementation would) dispatches straight back into this method
    /// and recurses until the stack overflows. Instead, only information
    /// reachable through the trait is shown.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
    {
        f.debug_struct("SocketTap")
            .field("remote_addr", self.get_remote_addr())
            .field("is_tcp", &self.is_tcp())
            .field("is_encrypted", &self.is_encrypted())
            .finish()
    }
}
#[cfg(feature = "use_sync_tls")]
pub mod with_tls
{
use std::cell::RefCell;
use std::io::{ErrorKind, Read, Write};
use std::net::TcpStream;
use std::os::fd::{AsFd, BorrowedFd};
use std::sync::Arc;
use std::time::Duration;
use rustls::pki_types::ServerName;
use rustls::{ClientConnection, RootCertStore, StreamOwned};
use crate::cfg_resolv_parser::ResolveConfEntry;
use crate::common::{DnsRequestAnswer, DEF_USERAGENT};
use crate::network_common::SocketTapCommon;
use crate::sync::network::{new_tcp_stream, NetworkTap, SocketTap};
use crate::{internal_error, internal_error_map, CDnsErrorType, CDnsResult};
/// DNS-over-HTTPS connection: HTTP/1.1 requests exchanged over a TLS session.
#[derive(Debug)]
pub struct TcpHttpsConnection
{
    /// The TLS session the HTTP traffic is layered on.
    stream: TcpTlsConnection,
}

impl AsFd for TcpHttpsConnection
{
    /// Exposes the TCP socket underneath the TLS layer (used for polling).
    fn as_fd(&self) -> BorrowedFd<'_>
    {
        self.stream.as_fd()
    }
}

impl TcpHttpsConnection
{
    /// Opens the TLS connection described by `cfg` and wraps it for HTTPS use.
    fn connect(cfg: &ResolveConfEntry,
        conn_timeout: Option<Duration>, timeout: Option<Duration>, blk_flag: bool) -> CDnsResult<Self>
    {
        TcpTlsConnection::connect(cfg, conn_timeout, timeout, blk_flag)
            .map(|stream| TcpHttpsConnection { stream })
    }
}
#[derive(Debug)]
pub struct TcpTlsConnection
{
stream: StreamOwned<ClientConnection, TcpStream>,
}
impl AsFd for TcpTlsConnection
{
fn as_fd(&self) -> BorrowedFd<'_>
{
return self.stream.sock.as_fd();
}
}
impl TcpTlsConnection
{
fn connect(cfg: &ResolveConfEntry,
conn_timeout: Option<Duration>, timeout: Option<Duration>, blk_flag: bool) -> CDnsResult<Self>
{
let domain_name =
if let Some(domainname) = cfg.get_tls_domain()
{
ServerName::try_from(domainname.clone())
.map_err(|e|
internal_error_map!(CDnsErrorType::InternalError, "{}", e)
)?
}
else
{
internal_error!(CDnsErrorType::InternalError, "no domain is set for TLS conncection");
};
let config =
rustls
::ClientConfig
::builder_with_protocol_versions(&[&rustls::version::TLS12])
.with_root_certificates(RootCertStore{roots: webpki_roots::TLS_SERVER_ROOTS.into()})
.with_no_client_auth();
let conn =
rustls
::ClientConnection
::new(Arc::new(config), domain_name)
.map_err(|e| internal_error_map!(CDnsErrorType::InternalError, "{}", e))?;
let socket = new_tcp_stream(cfg, conn_timeout, timeout, blk_flag)?;
let mut tlssock = rustls::StreamOwned::new(conn, socket);
tlssock.flush().map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;
return Ok( Self{ stream: tlssock } );
}
}
impl NetworkTap<TcpHttpsConnection>
{
    /// Reads decrypted bytes from the HTTPS stream into `rcvbuf`.
    ///
    /// A read interrupted by a signal is retried. `WouldBlock` (read timeout
    /// expired, or no data on a non-blocking socket) is reported as
    /// `RequestTimeout`; any other I/O failure as `IoError`. `Ok(0)` means the
    /// peer closed the stream.
    fn sub_recv(&self, rcvbuf: &mut [u8]) -> CDnsResult<usize>
    {
        let read_res =
            loop
            {
                let res = self.sock.borrow_mut().stream.stream.read(rcvbuf);
                match res
                {
                    Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
                    other => break other,
                }
            };

        match read_res
        {
            Ok(n) => Ok(n),
            Err(e) if e.kind() == ErrorKind::WouldBlock =>
            {
                internal_error!(CDnsErrorType::RequestTimeout, "request timeout from: '{}'", self.get_remote_addr());
            },
            Err(e) =>
            {
                internal_error!(CDnsErrorType::IoError, "{}", e);
            }
        }
    }
}
impl SocketTap for NetworkTap<TcpHttpsConnection>
{
    fn is_encrypted(&self) -> bool
    {
        return true;
    }

    fn is_tcp(&self) -> bool
    {
        return true;
    }

    /// The DNS message is sent as the HTTP body, so no length prefix is used.
    fn should_append_len(&self) -> bool
    {
        return false;
    }

    fn poll_read(&self) -> CDnsResult<bool>
    {
        return self.internal_poll_read();
    }

    /// Wraps `sndbuf` in an HTTP/1.1 `POST` request with an
    /// `application/dns-message` body and writes it over the TLS stream.
    ///
    /// Returns the total number of bytes written, which includes the HTTP
    /// request head (not just `sndbuf.len()`).
    fn send(&self, sndbuf: &[u8]) -> CDnsResult<usize>
    {
        let url_path =
            self.cfg.get_tls_path().map_or("", |f| f.as_str());
        // Without a configured TLS domain, the resolver IP is used as Host.
        let host =
            self.cfg.get_tls_domain().map_or_else(|| self.cfg.get_resolver_ip().to_string(), |f| f.clone());
        let content_len = sndbuf.len().to_string();

        let mut http_req =
            [
                "POST /", url_path, " HTTP/1.1\r\n",
                "Host: ", host.as_str(), "\r\n",
                "Content-Type: application/dns-message\r\n",
                "Accept: application/dns-message\r\n",
                "User-Agent: ", DEF_USERAGENT, "\r\n",
                "Content-Length: ", content_len.as_str(), "\r\n\r\n",
            ]
            .concat()
            .into_bytes();
        http_req.extend_from_slice(sndbuf);

        return
            self
                .sock
                .borrow_mut()
                .stream
                .stream
                .write_all(&http_req)
                .map_err(|e|
                    internal_error_map!(CDnsErrorType::IoError, "{}", e)
                )
                .map(|_| http_req.len());
    }

    /// Reads one HTTP/1.1 response and parses its body as a DNS message.
    ///
    /// The response must have status `200` and a `Content-Length` header
    /// (chunked transfer encoding is not handled).
    ///
    /// # Errors
    /// * `IoError` - the stream hit EOF early, or the response head does not
    ///   fit in the 4096-byte receive buffer.
    /// * `DnsResponse` - the response head is not valid HTTP.
    /// * `HttpError` - non-200 status, or a missing, malformed or oversized
    ///   `Content-Length`.
    /// * `RequestTimeout` - a read timed out (via `sub_recv`).
    fn recv(&self) -> CDnsResult<DnsRequestAnswer>
    {
        // Sized for the response head; grown once the body length is known.
        let mut rcvbuf = vec![0_u8; 4096];
        let mut total_rcv_len = 0;

        // Phase 1: read until the complete response head has arrived.
        let (body_start, content_length) =
            loop
            {
                let rcv_len = self.sub_recv(&mut rcvbuf[total_rcv_len..])?;
                if rcv_len == 0
                {
                    internal_error!(CDnsErrorType::IoError, "received EOF");
                }
                total_rcv_len += rcv_len;

                let mut headers = [httparse::EMPTY_HEADER; 18];
                let mut response = httparse::Response::new(&mut headers);

                // Parse only what was actually received; the zero-filled tail
                // of the buffer is not part of the response.
                let status =
                    response
                        .parse(&rcvbuf[..total_rcv_len])
                        .map_err(|e|
                            internal_error_map!(CDnsErrorType::DnsResponse, "can not parse response - {}", e)
                        )?;

                let httparse::Status::Complete(head_len) = status
                else
                {
                    // Head still incomplete: keep reading unless the buffer is full.
                    if total_rcv_len == rcvbuf.len()
                    {
                        internal_error!(CDnsErrorType::IoError, "received invalid data");
                    }
                    continue;
                };

                if response.code != Some(200)
                {
                    internal_error!(CDnsErrorType::HttpError, "http error code {}, reason {}",
                        response.code.unwrap_or_default(), response.reason.unwrap_or("N/A"));
                }

                // HTTP header names are case-insensitive.
                let Some(cl_header) =
                    response
                        .headers
                        .iter()
                        .find(|h| h.name.eq_ignore_ascii_case("Content-Length"))
                else
                {
                    internal_error!(CDnsErrorType::HttpError, "http Content-Length header field was not found in resp");
                };

                // The value comes from the server: reject garbage instead of panicking.
                let cl_value = String::from_utf8_lossy(cl_header.value);
                let content_length: usize =
                    cl_value
                        .trim()
                        .parse()
                        .map_err(|e|
                            internal_error_map!(CDnsErrorType::HttpError, "invalid http Content-Length '{}' - {}", cl_value, e)
                        )?;

                break (head_len, content_length);
            };

        // A DNS message never exceeds 65535 bytes; refuse to size a buffer
        // from a larger, server-controlled value.
        if content_length > usize::from(u16::MAX)
        {
            internal_error!(CDnsErrorType::HttpError, "http Content-Length {} exceeds the maximum DNS message size", content_length);
        }

        let full_len = body_start + content_length;
        if rcvbuf.len() < full_len
        {
            // Grow the length (not merely the capacity) so it can be read into.
            rcvbuf.resize(full_len, 0);
        }

        // Phase 2: read the rest of the body.
        while total_rcv_len < full_len
        {
            let rcv_len = self.sub_recv(&mut rcvbuf[total_rcv_len..full_len])?;
            if rcv_len == 0
            {
                internal_error!(CDnsErrorType::IoError, "received EOF");
            }
            total_rcv_len += rcv_len;
        }

        // Any bytes beyond `full_len` that arrived with the head are not part of the body.
        return DnsRequestAnswer::parse(&rcvbuf[body_start..full_len]);
    }
}
impl NetworkTap<TcpTlsConnection>
{
fn sub_read(&self, rcvbuf: &mut [u8]) -> CDnsResult<usize>
{
loop
{
match self.sock.borrow_mut().stream.read(rcvbuf)
{
Ok(n) =>
{
return Ok(n);
},
Err(ref e) if e.kind() == ErrorKind::WouldBlock =>
{
internal_error!(CDnsErrorType::RequestTimeout, "request timeout from: '{}'", self.get_remote_addr());
},
Err(ref e) if e.kind() == ErrorKind::Interrupted =>
{
continue;
},
Err(e) =>
{
internal_error!(CDnsErrorType::IoError, "{}", e);
}
}
}
}
}
impl SocketTap for NetworkTap<TcpTlsConnection>
{
    fn is_encrypted(&self) -> bool
    {
        return true;
    }

    fn is_tcp(&self) -> bool
    {
        return true;
    }

    /// DNS over TLS uses the TCP framing: a 2-byte length precedes each message.
    fn should_append_len(&self) -> bool
    {
        return true;
    }

    fn poll_read(&self) -> CDnsResult<bool>
    {
        return self.internal_poll_read();
    }

    /// Writes the whole of `sndbuf` to the TLS stream and returns its length.
    fn send(&self, sndbuf: &[u8]) -> CDnsResult<usize>
    {
        return
            self
                .sock
                .borrow_mut()
                .stream
                .write_all(sndbuf)
                .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))
                .map(|_| sndbuf.len());
    }

    /// Reads one length-prefixed message (2-byte big-endian length, then
    /// the message) and parses it.
    ///
    /// A single `read` may return fewer bytes than requested, so both the
    /// prefix and the message body are read in a loop until complete.
    ///
    /// # Errors
    /// `IoError` on EOF before anything was read, a short prefix, a zero-length
    /// frame, or a truncated message; errors from `sub_read` are propagated.
    fn recv(&self) -> CDnsResult<DnsRequestAnswer>
    {
        // Fills `buf` unless EOF is reached first; returns the bytes obtained.
        let read_full =
            |buf: &mut [u8]| -> CDnsResult<usize>
            {
                let mut filled = 0;
                while filled < buf.len()
                {
                    let n = self.sub_read(&mut buf[filled..])?;
                    if n == 0
                    {
                        break;
                    }
                    filled += n;
                }
                return Ok(filled);
            };

        let mut pkg_pen: [u8; 2] = [0, 0];
        let n = read_full(&mut pkg_pen[..])?;
        if n == 0
        {
            internal_error!(CDnsErrorType::IoError, "tcp received zero len message!");
        }
        else if n != pkg_pen.len()
        {
            internal_error!(CDnsErrorType::IoError, "tcp expected 2 bytes to be read!");
        }

        let ln = usize::from(u16::from_be_bytes(pkg_pen));
        if ln == 0
        {
            internal_error!(CDnsErrorType::IoError, "tcp received zero len message!");
        }

        let mut rcvbuf = vec![0_u8; ln];
        let n = read_full(&mut rcvbuf[..])?;
        if n == 0
        {
            internal_error!(CDnsErrorType::IoError, "tcp received zero len message!");
        }
        else if n != ln
        {
            internal_error!(CDnsErrorType::IoError, "tcp message truncated: received {} of {} bytes", n, ln);
        }

        return DnsRequestAnswer::parse(&rcvbuf);
    }
}
impl NetworkTap<TcpTlsConnection>
{
    /// Builds a DNS-over-TLS tap connected to `resolver`.
    ///
    /// `timeout` is used as the socket read/write timeout and as the poll
    /// timeout. `conn_timeout`, when set, bounds the TCP connect.
    /// `nonblk_flag` is applied to the socket with `set_nonblocking`.
    pub
    fn new_tls(
        resolver: Arc<ResolveConfEntry>,
        timeout: Duration,
        nonblk_flag: bool,
        conn_timeout: Option<Duration>
    ) -> CDnsResult<Box<dyn SocketTap>>
    where
        NetworkTap<TcpTlsConnection>: SocketTap + 'static
    {
        let tls_conn =
            TcpTlsConnection::connect(&resolver, conn_timeout, Some(timeout), nonblk_flag)?;

        Ok(Box::new(Self { sock: RefCell::new(tls_conn), timeout, cfg: resolver }))
    }
}
impl NetworkTap<TcpHttpsConnection>
{
    /// Builds a DNS-over-HTTPS tap connected to `resolver`.
    ///
    /// `timeout` is used as the socket read/write timeout and as the poll
    /// timeout. `conn_timeout`, when set, bounds the TCP connect.
    /// `nonblk_flag` is applied to the socket with `set_nonblocking`.
    pub
    fn new_https(
        resolver: Arc<ResolveConfEntry>,
        timeout: Duration,
        nonblk_flag: bool,
        conn_timeout: Option<Duration>
    ) -> CDnsResult<Box<dyn SocketTap>>
    where
        NetworkTap<TcpHttpsConnection>: SocketTap + 'static
    {
        let https_conn =
            TcpHttpsConnection::connect(&resolver, conn_timeout, Some(timeout), nonblk_flag)?;

        Ok(Box::new(Self { sock: RefCell::new(https_conn), timeout, cfg: resolver }))
    }
}
}
/// A DNS transport bound to a single resolver.
///
/// `T` is the underlying socket type; its file descriptor is what gets polled.
pub struct NetworkTap<T: AsFd>
{
    /// The socket. It sits in a `RefCell` because the `SocketTap` methods take
    /// `&self` while stream reads/writes need mutable access.
    sock: RefCell<T>,
    /// Poll timeout; the constructors also apply it as the socket read/write timeout.
    timeout: Duration,
    /// Resolver configuration (addresses, optional TLS domain and path).
    cfg: Arc<ResolveConfEntry>,
}

impl<T: AsFd + Debug> fmt::Debug for NetworkTap<T>
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
    {
        let mut out = f.debug_struct("NetworkTap");
        out.field("sock", &self.sock);
        out.field("timeout", &self.timeout);
        out.field("cfg", &self.cfg);
        out.finish()
    }
}
impl NetworkTap<TcpStream>
{
    /// Builds a plain DNS-over-TCP tap connected to `resolver`.
    ///
    /// `timeout` is used as the socket read/write timeout and as the poll
    /// timeout. `conn_timeout`, when set, bounds the TCP connect.
    pub
    fn new_tcp(
        resolver: Arc<ResolveConfEntry>,
        nonblk_flag: bool,
        timeout: Duration,
        conn_timeout: Option<Duration>
    ) -> CDnsResult<Box<dyn SocketTap>>
    where
        NetworkTap<TcpStream>: SocketTap + 'static
    {
        let stream =
            new_tcp_stream(resolver.as_ref(), conn_timeout, Some(timeout), nonblk_flag)?;

        let tap = Self { sock: RefCell::new(stream), timeout, cfg: resolver };
        Ok(Box::new(tap))
    }
}
impl NetworkTap<UdpSocket>
{
    /// Builds a UDP tap for `resolver`.
    ///
    /// The socket is bound to the configured adapter address. It then gets
    /// `nonblk_flag` and `timeout` (for both reads and writes), and is
    /// finally connected to the resolver, so the OS only delivers datagrams
    /// from that peer.
    ///
    /// # Errors
    /// `InternalError` if binding fails, `IoError` for any later socket setup step.
    pub
    fn new_udp(
        resolver: Arc<ResolveConfEntry>,
        nonblk_flag: bool,
        timeout: Duration
    ) -> CDnsResult<Box<dyn SocketTap>>
    where
        NetworkTap<UdpSocket>: SocketTap + 'static
    {
        let socket =
            UdpSocket::bind(resolver.get_adapter_ip())
                .map_err(|e| internal_error_map!(CDnsErrorType::InternalError, "{}", e))?;

        // Each step runs only if the previous one succeeded, as in a `?` chain.
        socket.set_nonblocking(nonblk_flag)
            .and_then(|_| socket.set_read_timeout(Some(timeout)))
            .and_then(|_| socket.set_write_timeout(Some(timeout)))
            .and_then(|_| socket.connect(resolver.get_resolver_sa()))
            .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

        let tap = NetworkTap::<UdpSocket> { sock: RefCell::new(socket), timeout, cfg: resolver };
        Ok(Box::new(tap))
    }
}
impl<T: AsFd> NetworkTap<T>
{
fn internal_poll_read(&self) -> CDnsResult<bool>
{
let sock = self.sock.borrow();
let mut poll_tap = [PollFd::new(sock.as_fd(), PollFlags::from(PollFlags::POLLIN))];
let poll_res =
nix::poll::poll(&mut poll_tap, self.timeout.as_millis() as u16)
.map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;
if poll_res == 0
{
internal_error!(CDnsErrorType::IoError, "Timeout");
}
let Some(revents) = poll_tap[0].revents()
else
{
internal_error!(CDnsErrorType::IoError, "Crate NIX returned None due to unknown data returned by kernel!\
Try update nix crate! Probably it is outdated!");
};
if revents.intersects(PollFlags::POLLERR | PollFlags::POLLHUP | PollFlags::POLLNVAL ) == true
{
write_error!(internal_error_map!(CDnsErrorType::IoError, "socket poll error {:?}", revents));
return Ok(false);
}
else if revents.intersects(PollFlags::POLLIN) == true
{
return Ok(true);
}
return Ok(false);
}
}
impl<T: AsFd> SocketTapCommon for NetworkTap<T>
{
    /// The resolver socket address stored in the configuration.
    fn get_remote_addr(&self) -> &SocketAddr
    {
        self.cfg.get_resolver_sa()
    }
}
impl SocketTap for NetworkTap<UdpSocket>
{
    fn is_tcp(&self) -> bool
    {
        return false;
    }

    fn is_encrypted(&self) -> bool
    {
        return false;
    }

    /// UDP datagrams carry the DNS message without a length prefix.
    fn should_append_len(&self) -> bool
    {
        return false;
    }

    fn poll_read(&self) -> CDnsResult<bool>
    {
        return self.internal_poll_read();
    }

    /// Sends `sndbuf` as one datagram to the connected resolver.
    fn send(&self, sndbuf: &[u8]) -> CDnsResult<usize>
    {
        return
            self
                .sock
                .borrow()
                .send(sndbuf)
                .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e));
    }

    /// Receives one datagram from the resolver and parses it.
    ///
    /// Only the bytes actually received are passed to the parser.
    ///
    /// # Errors
    /// * `DnsResponse` - the datagram did not come from the resolver.
    /// * `RequestTimeout` - `WouldBlock` (timeout / no data on a non-blocking socket).
    /// * `IoError` - any other receive failure.
    fn recv(&self) -> CDnsResult<DnsRequestAnswer>
    {
        // NOTE(review): datagrams longer than this buffer are cut off by the
        // OS; confirm 1457 bytes is large enough for the EDNS sizes in use.
        let mut rcvbuf = vec![0_u8; 1457];

        let rcv_len =
            loop
            {
                match self.sock.borrow().recv_from(&mut rcvbuf)
                {
                    Ok((rcv_len, rcv_src)) =>
                    {
                        if &rcv_src != self.get_remote_addr()
                        {
                            // Received address first, expected address second,
                            // matching the message text.
                            internal_error!(
                                CDnsErrorType::DnsResponse,
                                "received answer from unknown host: '{}' exp: '{}'",
                                rcv_src,
                                self.get_remote_addr()
                            );
                        }
                        break rcv_len;
                    },
                    Err(ref e) if e.kind() == ErrorKind::WouldBlock =>
                    {
                        internal_error!(CDnsErrorType::RequestTimeout, "request timeout from: '{}'", self.get_remote_addr());
                    },
                    Err(ref e) if e.kind() == ErrorKind::Interrupted =>
                    {
                        continue;
                    },
                    Err(e) =>
                    {
                        internal_error!(CDnsErrorType::IoError, "{}", e);
                    }
                }
            };

        return DnsRequestAnswer::parse(&rcvbuf[..rcv_len]);
    }
}
/// Opens a TCP connection from the adapter address in `cfg` to its resolver.
///
/// Nagle's algorithm is disabled and keep-alive is turned off.
///
/// * `conn_timeout` - upper bound for the connect; `None` uses a plain connect.
/// * `timeout` - applied as both the read and the write timeout.
/// * `nonblk_flag` - passed to `set_nonblocking` once connected.
///
/// # Errors
/// `IoError` for any failing socket operation.
fn new_tcp_stream(cfg: &ResolveConfEntry, conn_timeout: Option<Duration>, timeout: Option<Duration>, nonblk_flag: bool) -> CDnsResult<TcpStream>
{
    let resolver_sa = SockAddr::from(*cfg.get_resolver_sa());

    let socket =
        Socket::new(Domain::for_address(*cfg.get_resolver_sa()), Type::STREAM, Some(Protocol::TCP))
            .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

    // Every step runs only if the previous one succeeded.
    socket.bind(&SockAddr::from(*cfg.get_adapter_ip()))
        .and_then(|_| socket.set_tcp_nodelay(true))
        .and_then(|_| socket.set_keepalive(false))
        .and_then(|_|
            match conn_timeout
            {
                Some(c_timeout) => socket.connect_timeout(&resolver_sa, c_timeout),
                None => socket.connect(&resolver_sa),
            }
        )
        .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

    let stream = TcpStream::from(socket);

    stream.set_nonblocking(nonblk_flag)
        .and_then(|_| stream.set_read_timeout(timeout))
        .and_then(|_| stream.set_write_timeout(timeout))
        .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

    Ok(stream)
}
impl NetworkTap<TcpStream>
{
fn sub_read(&self, rcvbuf: &mut [u8]) -> CDnsResult<usize>
{
loop
{
match self.sock.borrow_mut().read(rcvbuf)
{
Ok(n) =>
{
return Ok(n);
},
Err(ref e) if e.kind() == ErrorKind::WouldBlock =>
{
internal_error!(CDnsErrorType::RequestTimeout, "request timeout from: '{}'", self.get_remote_addr());
},
Err(ref e) if e.kind() == ErrorKind::Interrupted =>
{
continue;
},
Err(e) =>
{
internal_error!(CDnsErrorType::IoError, "{}", e);
}
}
}
}
}
impl SocketTap for NetworkTap<TcpStream>
{
    fn is_tcp(&self) -> bool
    {
        return true;
    }

    fn is_encrypted(&self) -> bool
    {
        return false;
    }

    /// DNS over TCP frames each message with a 2-byte length prefix.
    fn should_append_len(&self) -> bool
    {
        return true;
    }

    fn poll_read(&self) -> CDnsResult<bool>
    {
        return self.internal_poll_read();
    }

    /// Writes the whole of `sndbuf` to the stream and returns its length.
    ///
    /// `write_all` is required because a single `write` on a TCP stream may
    /// accept only part of the buffer. This matches the TLS transport.
    fn send(&self, sndbuf: &[u8]) -> CDnsResult<usize>
    {
        return
            self
                .sock
                .borrow_mut()
                .write_all(sndbuf)
                .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))
                .map(|_| sndbuf.len());
    }

    /// Reads one length-prefixed message (2-byte big-endian length, then
    /// the message) and parses it.
    ///
    /// A single `read` may return fewer bytes than requested, so both the
    /// prefix and the message body are read in a loop until complete.
    ///
    /// # Errors
    /// `IoError` on EOF before anything was read, a short prefix, a zero-length
    /// frame, or a truncated message; errors from `sub_read` are propagated.
    fn recv(&self) -> CDnsResult<DnsRequestAnswer>
    {
        // Fills `buf` unless EOF is reached first; returns the bytes obtained.
        let read_full =
            |buf: &mut [u8]| -> CDnsResult<usize>
            {
                let mut filled = 0;
                while filled < buf.len()
                {
                    let n = self.sub_read(&mut buf[filled..])?;
                    if n == 0
                    {
                        break;
                    }
                    filled += n;
                }
                return Ok(filled);
            };

        let mut pkg_pen: [u8; 2] = [0, 0];
        let n = read_full(&mut pkg_pen[..])?;
        if n == 0
        {
            internal_error!(CDnsErrorType::IoError, "tcp received zero len message!");
        }
        else if n != pkg_pen.len()
        {
            internal_error!(CDnsErrorType::IoError, "tcp expected 2 bytes to be read!");
        }

        let ln = usize::from(u16::from_be_bytes(pkg_pen));
        if ln == 0
        {
            internal_error!(CDnsErrorType::IoError, "tcp received zero len message!");
        }

        let mut rcvbuf = vec![0_u8; ln];
        let n = read_full(&mut rcvbuf[..])?;
        if n == 0
        {
            internal_error!(CDnsErrorType::IoError, "tcp received zero len message!");
        }
        else if n != ln
        {
            internal_error!(CDnsErrorType::IoError, "tcp message truncated: received {} of {} bytes", n, ln);
        }

        return DnsRequestAnswer::parse(&rcvbuf);
    }
}
#[cfg(test)]
mod tests
{
    use std::{net::{IpAddr, SocketAddr, UdpSocket}, sync::Arc, time::Duration};
    use crate::{cfg_resolv_parser::ResolveConfEntry, sync::network::NetworkTap};
    use crate::common::IPV4_BIND_ALL;

    /// Smoke test: a non-blocking UDP tap towards 127.0.0.1:53 can be built.
    /// UDP `connect` only records the peer, so no server has to be listening.
    #[test]
    fn test_struct()
    {
        let resolver_ip: IpAddr = "127.0.0.1".parse().unwrap();
        let bind_addr = SocketAddr::from((IPV4_BIND_ALL, 0));
        let resolver_sa = SocketAddr::new(resolver_ip, 53);
        let cfg = Arc::new(ResolveConfEntry::new(resolver_sa, None, bind_addr).unwrap());

        let _tap =
            NetworkTap::<UdpSocket>::new_udp(cfg, true, Duration::from_secs(5))
                .unwrap();
    }
}