use std::array::IntoIter;
use std::convert::TryInto;
use std::error::Error;
use std::future::Future;
use std::io;
use std::net::{IpAddr, SocketAddr};
use std::pin::Pin;
use std::sync::Arc;
use async_channel::Sender;
use async_trait::async_trait;
#[cfg(test)]
use mockall::automock;
use tokio::io::{duplex, split};
use tokio::net::UdpSocket;
use tokio::time::{timeout, Duration, Instant};
use trust_dns_client::op::{Header, MessageType, OpCode};
use trust_dns_client::proto::rr::rdata::TXT;
use trust_dns_client::proto::rr::{Name, RData, Record, RecordType};
use trust_dns_server::authority::{MessageRequest, MessageResponse, MessageResponseBuilder};
use trust_dns_server::server::{Request, RequestHandler, ResponseHandler};
use trust_dns_server::ServerFuture;
use common::cache::{StreamCreator, StreamsCache};
use common::dns::{
AppendSuffixDecoder, ClientId, ClientIdSuffixDecoder, Decoder, Encoder, HexDecoder, HexEncoder,
CLIENT_ID_SIZE_IN_BYTES,
};
use common::io::Stream;
use crate::tunnel::Untunneler;
/// Untunneler that serves tunnel traffic over DNS: it binds a UDP socket and
/// answers TXT queries through an `UntunnelRequestHandler` (see `untunnel`).
pub(crate) struct DnsUntunneler {
/// Local address the UDP socket is bound to when `untunnel` runs.
listener_address: SocketAddr,
/// Timeout forwarded to the request handler; it bounds the read from a
/// client stream while a response is being prepared.
read_timeout: Duration,
/// Duration handed to the `StreamsCache` constructor
/// (presumably how long an idle client's stream is kept — confirm in `common::cache`).
idle_client_timeout: Duration,
/// Suffix handed to the `AppendSuffixDecoder` used for incoming query names.
client_suffix: String,
}
impl DnsUntunneler {
    /// Creates an untunneler that will listen on `local_address:local_port`.
    ///
    /// Nothing is bound here; the socket is opened later by `untunnel`.
    /// `read_timeout`, `idle_client_timeout` and `client_suffix` are stored
    /// unchanged for later use.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`; the `Result` is part
    /// of the constructor's signature.
    pub(crate) async fn new(
        local_address: IpAddr,
        local_port: u16,
        read_timeout: Duration,
        idle_client_timeout: Duration,
        client_suffix: String,
    ) -> Result<Self, Box<dyn Error>> {
        let untunneler = Self {
            listener_address: SocketAddr::new(local_address, local_port),
            read_timeout,
            idle_client_timeout,
            client_suffix,
        };
        Ok(untunneler)
    }
}
#[async_trait(? Send)]
impl Untunneler for DnsUntunneler {
    /// Runs the DNS server until it finishes.
    ///
    /// Every time the cache needs a new stream, an in-memory `duplex` pair is
    /// created: one end is wrapped in a `Stream` and pushed to `new_clients`,
    /// the other end is returned to the cache and used by the request handler.
    ///
    /// # Errors
    ///
    /// Fails when the UDP socket cannot be bound or when the server future
    /// ends with an error.
    async fn untunnel(&mut self, new_clients: Sender<Stream>) -> Result<(), Box<dyn Error>> {
        // The closure is passed inline so its return type is inferred from
        // the cache constructor's bound.
        let cache = StreamsCache::with_default_cleanup_duration(
            move || {
                let (local_end, remote_end) = duplex(4096);

                // Hand the remote end over first; if the channel rejects it
                // the creation fails and no stream is cached.
                let (remote_reader, remote_writer) = split(remote_end);
                new_clients.try_send(Stream::new(remote_reader, remote_writer))?;

                let (local_reader, local_writer) = split(local_end);
                Ok(Stream::new(local_reader, local_writer))
            },
            self.idle_client_timeout,
        );

        let udp_socket = UdpSocket::bind(self.listener_address).await?;
        let request_handler =
            UntunnelRequestHandler::new(cache, self.read_timeout, self.client_suffix.clone());

        let mut server = ServerFuture::new(request_handler);
        server.register_socket(udp_socket);
        server.block_until_done().await?;
        Ok(())
    }
}
/// Narrow interface for sending a DNS response.
///
/// Request handling is written against this trait instead of the server's
/// `ResponseHandler`; under `cfg(test)` a mock implementation is generated
/// by `automock`.
#[cfg_attr(test, automock)]
pub(crate) trait DnsResponseHandler {
/// Sends `response` to the requester, returning any I/O error reported by
/// the underlying implementation.
// NOTE(review): the explicit lifetimes are presumably required by `automock`;
// confirm before removing this `allow`.
#[allow(clippy::needless_lifetimes)]
fn send_response<'a, 'b>(&mut self, response: MessageResponse<'a, 'b>) -> io::Result<()>;
}
/// Adapter that exposes a server `ResponseHandler` through the
/// `DnsResponseHandler` trait.
struct ResponseHandlerWrapper<R: ResponseHandler> {
/// The wrapped handler that actually sends the response.
handler: R,
}
impl<R: ResponseHandler> DnsResponseHandler for ResponseHandlerWrapper<R> {
/// Forwards `response` to the wrapped handler and returns its result unchanged.
fn send_response(&mut self, response: MessageResponse) -> io::Result<()> {
self.handler.send_response(response)
}
}
/// Maximum encoded size of the TXT payload sent back in one response.
///
/// It is passed to `Encoder::calculate_max_decoded_size` to derive how many
/// raw bytes are read from the client stream per request.
// NOTE(review): the rationale for the value 54 is not visible in this file — confirm.
const MAXIMUM_TXT_RECORD_SIZE: usize = 54;
/// DNS `RequestHandler` that treats each TXT query as tunnel data: the query
/// name is decoded and written to the client's stream, and bytes read back
/// from that stream are returned in the TXT answer.
pub struct UntunnelRequestHandler<F: StreamCreator> {
/// Cache of client streams keyed by `ClientId`, shared with every request future.
untunneled_clients: Arc<StreamsCache<F, ClientId>>,
/// Upper bound on how long a request waits for data from the client stream.
read_timeout: Duration,
/// Suffix handed to the `AppendSuffixDecoder` built for each request.
client_suffix: String,
}
impl<F: StreamCreator> UntunnelRequestHandler<F> {
    /// Builds a handler around `cache`.
    ///
    /// The cache is moved into an `Arc` so each request future can hold its
    /// own handle; `read_timeout` and `client_suffix` are stored unchanged.
    pub(crate) fn new(
        cache: StreamsCache<F, ClientId>,
        read_timeout: Duration,
        client_suffix: String,
    ) -> Self {
        let untunneled_clients = Arc::new(cache);
        Self {
            untunneled_clients,
            read_timeout,
            client_suffix,
        }
    }
}
impl<F: StreamCreator> RequestHandler for UntunnelRequestHandler<F> {
type ResponseFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
fn handle_request<R: ResponseHandler>(
&self,
request: Request,
response_handle: R,
) -> Self::ResponseFuture {
Box::pin(untunnel_request(
request,
ResponseHandlerWrapper {
handler: response_handle,
},
self.untunneled_clients.clone(),
AppendSuffixDecoder::new(
ClientIdSuffixDecoder::new(HexDecoder {}),
self.client_suffix.clone(),
),
HexEncoder {},
self.read_timeout,
))
}
}
/// Handles one DNS request end to end.
///
/// Steps, each of which aborts the request (after logging) on failure:
/// 1. extract the query name via `get_data_from_tunnel`;
/// 2. decode a non-FQDN copy of that name with `decoder`;
/// 3. exchange data with the client stream via `untunnel_data`;
/// 4. encode the bytes read from the client with `encoder`;
/// 5. send them back as a single TXT record (TTL 0) named after the query.
///
/// No response is sent when any step before the last one fails.
async fn untunnel_request<R: DnsResponseHandler, F: StreamCreator, D: Decoder, E: Encoder>(
    request: Request,
    mut response_handle: R,
    untunneled_clients: Arc<StreamsCache<F, ClientId>>,
    decoder: D,
    encoder: E,
    read_timeout: Duration,
) {
    let message = &request.message;

    let query_name = match get_data_from_tunnel(message) {
        None => return,
        Some(name) => name,
    };
    log::debug!("received from tunnel {}", query_name);

    // The decoder is fed the name with its FQDN flag cleared; the original
    // name is kept untouched because it becomes the owner of the answer.
    let mut relative_name = query_name.clone();
    relative_name.set_fqdn(false);
    let decoded = match decoder.decode(&relative_name.to_string()) {
        Err(e) => {
            log::error!("{}: failed to decode: {}", message.id(), e);
            return;
        }
        Ok(bytes) => bytes,
    };

    let max_read_size = encoder.calculate_max_decoded_size(MAXIMUM_TXT_RECORD_SIZE);
    let reply_bytes = match untunnel_data(
        untunneled_clients,
        decoded,
        message.id(),
        max_read_size,
        read_timeout,
    )
    .await
    {
        None => return,
        Some(bytes) => bytes,
    };

    let encoded_reply = match encoder.encode(&reply_bytes) {
        Err(e) => {
            log::error!("{}: failed to encode: {}", message.id(), e);
            return;
        }
        Ok(text) => text,
    };
    log::debug!("sending to tunnel {:?}", encoded_reply);

    let answer = Record::from_rdata(query_name, 0, RData::TXT(TXT::new(vec![encoded_reply])));
    if let Err(e) = response_handle.send_response(create_response(message, &answer)) {
        log::error!("{}: failed to send response: {}", message.id(), e);
    }
}
fn get_data_from_tunnel(message: &MessageRequest) -> Option<Name> {
let first_query = match message.queries().len() {
1 => (&message.queries()[0]).original(),
x => {
log::error!("{}: unexpected number of queries {}", message.id(), x);
return None;
}
};
let data_from_tunnel = match first_query.query_type() {
RecordType::TXT => first_query.name().clone(),
x => {
log::error!("{}: unexpected type of query {}", message.id(), x);
return None;
}
};
Some(data_from_tunnel)
}
async fn untunnel_data<F: StreamCreator>(
clients: Arc<StreamsCache<F, ClientId>>,
data: Vec<u8>,
message_id: u16,
data_to_tunnel_max_size: usize,
read_timeout: Duration,
) -> Option<Vec<u8>> {
if data.len() < CLIENT_ID_SIZE_IN_BYTES {
log::error!(
"{}: decode return value too small: {}",
message_id,
data.len()
);
return None;
}
let (data_to_write, client_id) = data.split_at(data.len() - CLIENT_ID_SIZE_IN_BYTES);
let client = match clients.get(client_id.try_into().unwrap(), Instant::now()) {
Ok(x) => x,
Err(e) => {
log::error!("{}: failed to get client: {}", message_id, e);
return None;
}
};
if let Err(e) = client.lock().await.writer.write(data_to_write).await {
log::error!("failed to write to client: {}", e);
return None;
};
let mut data_to_tunnel = vec![0; data_to_tunnel_max_size];
let read_result = match timeout(
read_timeout,
client.lock().await.reader.read(&mut data_to_tunnel),
)
.await
{
Ok(x) => x,
Err(_) => {
data_to_tunnel.truncate(0);
return Some(data_to_tunnel);
}
};
let size = match read_result {
Ok(x) => x,
Err(e) => {
log::error!("failed to read from client: {}", e);
return None;
}
};
data_to_tunnel.truncate(size);
Some(data_to_tunnel)
}
/// Builds the response for `message` containing `answer` as its only record.
///
/// The header copies the request id and is marked as an authoritative
/// response to a standard query. The response echoes the request's raw
/// queries; `answer` goes into the first record section and the remaining
/// three sections are left empty.
fn create_response<'a>(message: &'a MessageRequest, answer: &'a Record) -> MessageResponse<'a, 'a> {
    let mut header = Header::new();
    header.set_id(message.id());
    header.set_op_code(OpCode::Query);
    header.set_message_type(MessageType::Response);
    header.set_authoritative(true);

    MessageResponseBuilder::new(Some(message.raw_queries())).build(
        header,
        new_iterator(Some(answer)),
        new_iterator(None),
        new_iterator(None),
        new_iterator(None),
    )
}
fn new_iterator<'a>(
record: Option<&'a Record>,
) -> Box<dyn Iterator<Item = &'a Record> + Send + 'a> {
match record {
None => Box::new(IntoIter::new([])),
Some(r) => Box::new(IntoIter::new([r])),
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for `untunnel_request`.
    //!
    //! Each test drives the request through one more stage than the previous
    //! one. Mocks are created without expectations for the stages that must
    //! not be reached, so reaching them makes the mock panic and the test fail.

    use std::net::Ipv4Addr;

    use mockall::mock;
    use tokio::io::ErrorKind;
    use tokio_test::io::Builder;
    use trust_dns_client::op::Message;
    use trust_dns_client::proto::serialize::binary::BinDecodable;
    use trust_dns_server::authority::MessageRequest;
    use trust_dns_server::proto::op::Query;

    use super::*;

    mock! {
        Encoder{}
        impl Encoder for Encoder {
            fn calculate_max_decoded_size(&self, max_encoded_size: usize) -> usize;
            fn encode(&self, data: &[u8]) -> Result<String, Box<dyn Error>>;
        }
    }

    mock! {
        Decoder{}
        impl Decoder for Decoder {
            fn decode(&self, data: &str) -> Result<Vec<u8>, Box<dyn Error>>;
        }
    }

    /// Read timeout handed to `untunnel_request` in every test.
    const READ_TIMEOUT: Duration = Duration::from_millis(100);
    /// Idle timeout handed to every `StreamsCache` built by the tests.
    const IDLE_CLIENT_TIMEOUT: Duration = Duration::from_secs(3 * 60);
    /// Value the mocked encoder reports from `calculate_max_decoded_size`.
    const MAX_DECODED_SIZE: usize = 17;
    /// Decoded payload used by the tests that get past decoding: the bytes
    /// `bla` followed by the client id `1234`.
    const DECODED_PAYLOAD: &str = "bla1234";

    /// Builds a `Request` from a message holding `queries`, serialized and
    /// parsed back the same way a request arriving from the network would be.
    fn request_with(queries: Vec<Query>) -> Request {
        let mut message = Message::default();
        for query in queries {
            message.queries_mut().push(query);
        }
        let message_bytes = message.to_vec().unwrap();
        Request {
            message: MessageRequest::from_bytes(&message_bytes).unwrap(),
            src: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080),
        }
    }

    /// Builds a `Request` holding a single default query of type TXT.
    fn txt_request() -> Request {
        let mut query = Query::default();
        query.set_query_type(RecordType::TXT);
        request_with(vec![query])
    }

    /// Decoder mock whose `decode` always succeeds with the bytes of `decoded`.
    fn decoder_returning(decoded: &'static str) -> MockDecoder {
        let mut decoder = MockDecoder::new();
        decoder
            .expect_decode()
            .returning(move |_| Ok(decoded.as_bytes().to_vec()));
        decoder
    }

    /// Encoder mock that only answers `calculate_max_decoded_size`.
    fn sizing_encoder() -> MockEncoder {
        let mut encoder = MockEncoder::new();
        encoder
            .expect_calculate_max_decoded_size()
            .return_const(MAX_DECODED_SIZE);
        encoder
    }

    /// A message without queries is dropped before decoding.
    #[tokio::test]
    async fn dns_empty_message() -> Result<(), Box<dyn Error>> {
        let cache = StreamsCache::with_default_cleanup_duration(
            || Ok(Stream::new(Builder::new().build(), Builder::new().build())),
            IDLE_CLIENT_TIMEOUT,
        );
        untunnel_request(
            request_with(Vec::new()),
            MockDnsResponseHandler::new(),
            Arc::new(cache),
            MockDecoder::new(),
            MockEncoder::new(),
            READ_TIMEOUT,
        )
        .await;
        Ok(())
    }

    /// A message with more than one query is dropped before decoding.
    #[tokio::test]
    async fn dns_multiple_queries_message() -> Result<(), Box<dyn Error>> {
        let cache = StreamsCache::with_default_cleanup_duration(
            || Ok(Stream::new(Builder::new().build(), Builder::new().build())),
            IDLE_CLIENT_TIMEOUT,
        );
        untunnel_request(
            request_with(vec![Query::default(), Query::default()]),
            MockDnsResponseHandler::new(),
            Arc::new(cache),
            MockDecoder::new(),
            MockEncoder::new(),
            READ_TIMEOUT,
        )
        .await;
        Ok(())
    }

    /// A single query that is not of type TXT is dropped before decoding.
    #[tokio::test]
    async fn dns_non_txt_query_message() -> Result<(), Box<dyn Error>> {
        let cache = StreamsCache::with_default_cleanup_duration(
            || Ok(Stream::new(Builder::new().build(), Builder::new().build())),
            IDLE_CLIENT_TIMEOUT,
        );
        untunnel_request(
            request_with(vec![Query::default()]),
            MockDnsResponseHandler::new(),
            Arc::new(cache),
            MockDecoder::new(),
            MockEncoder::new(),
            READ_TIMEOUT,
        )
        .await;
        Ok(())
    }

    /// A decoder failure stops the request before the encoder is consulted.
    #[tokio::test]
    async fn dns_failed_to_decode() -> Result<(), Box<dyn Error>> {
        let cache = StreamsCache::with_default_cleanup_duration(
            || Ok(Stream::new(Builder::new().build(), Builder::new().build())),
            IDLE_CLIENT_TIMEOUT,
        );
        let mut decoder_mock = MockDecoder::new();
        decoder_mock
            .expect_decode()
            .returning(|_| Err(String::from("bla").into()));
        untunnel_request(
            txt_request(),
            MockDnsResponseHandler::new(),
            Arc::new(cache),
            decoder_mock,
            MockEncoder::new(),
            READ_TIMEOUT,
        )
        .await;
        Ok(())
    }

    /// Decoded data shorter than a client id is rejected; no stream is touched.
    #[tokio::test]
    async fn dns_decode_return_value_too_small() -> Result<(), Box<dyn Error>> {
        let cache = StreamsCache::with_default_cleanup_duration(
            || Ok(Stream::new(Builder::new().build(), Builder::new().build())),
            IDLE_CLIENT_TIMEOUT,
        );
        untunnel_request(
            txt_request(),
            MockDnsResponseHandler::new(),
            Arc::new(cache),
            decoder_returning("12"),
            sizing_encoder(),
            READ_TIMEOUT,
        )
        .await;
        Ok(())
    }

    /// A failing stream creator stops the request before anything is encoded.
    #[tokio::test]
    async fn dns_failed_to_get_client() -> Result<(), Box<dyn Error>> {
        let cache = StreamsCache::with_default_cleanup_duration(
            || Err(String::from("bla").into()),
            IDLE_CLIENT_TIMEOUT,
        );
        untunnel_request(
            txt_request(),
            MockDnsResponseHandler::new(),
            Arc::new(cache),
            decoder_returning(DECODED_PAYLOAD),
            sizing_encoder(),
            READ_TIMEOUT,
        )
        .await;
        Ok(())
    }

    /// A write error on the client stream stops the request before encoding.
    #[tokio::test]
    async fn dns_failed_to_write_to_client() -> Result<(), Box<dyn Error>> {
        let cache = StreamsCache::with_default_cleanup_duration(
            || {
                Ok(Stream::new(
                    Builder::new().build(),
                    Builder::new()
                        .write_error(io::Error::new(ErrorKind::Other, "oh no!"))
                        .build(),
                ))
            },
            IDLE_CLIENT_TIMEOUT,
        );
        untunnel_request(
            txt_request(),
            MockDnsResponseHandler::new(),
            Arc::new(cache),
            decoder_returning(DECODED_PAYLOAD),
            sizing_encoder(),
            READ_TIMEOUT,
        )
        .await;
        Ok(())
    }

    /// A read error on the client stream stops the request before encoding.
    #[tokio::test]
    async fn dns_failed_to_read_from_client() -> Result<(), Box<dyn Error>> {
        let cache = StreamsCache::with_default_cleanup_duration(
            || {
                Ok(Stream::new(
                    Builder::new()
                        .read_error(io::Error::new(ErrorKind::Other, "oh no!"))
                        .build(),
                    Builder::new().write(b"bla").build(),
                ))
            },
            IDLE_CLIENT_TIMEOUT,
        );
        untunnel_request(
            txt_request(),
            MockDnsResponseHandler::new(),
            Arc::new(cache),
            decoder_returning(DECODED_PAYLOAD),
            sizing_encoder(),
            READ_TIMEOUT,
        )
        .await;
        Ok(())
    }

    /// An encoder failure stops the request before a response is sent.
    #[tokio::test]
    async fn dns_failed_to_encode() -> Result<(), Box<dyn Error>> {
        let cache = StreamsCache::with_default_cleanup_duration(
            || {
                Ok(Stream::new(
                    Builder::new().read(b"bli").build(),
                    Builder::new().write(b"bla").build(),
                ))
            },
            IDLE_CLIENT_TIMEOUT,
        );
        let mut encoder_mock = sizing_encoder();
        encoder_mock
            .expect_encode()
            .returning(|_| Err(String::from("bla").into()));
        untunnel_request(
            txt_request(),
            MockDnsResponseHandler::new(),
            Arc::new(cache),
            decoder_returning(DECODED_PAYLOAD),
            encoder_mock,
            READ_TIMEOUT,
        )
        .await;
        Ok(())
    }

    /// A failure while sending the response is tolerated (logged, no panic),
    /// and the send is attempted exactly once.
    #[tokio::test]
    async fn dns_failed_to_send_response() -> Result<(), Box<dyn Error>> {
        let mut handler_mock = MockDnsResponseHandler::new();
        handler_mock
            .expect_send_response()
            .times(1)
            .returning(|_| Err(io::Error::new(ErrorKind::Other, "oh no!")));
        let cache = StreamsCache::with_default_cleanup_duration(
            || {
                Ok(Stream::new(
                    Builder::new().read(b"bli").build(),
                    Builder::new().write(b"bla").build(),
                ))
            },
            IDLE_CLIENT_TIMEOUT,
        );
        let mut encoder_mock = sizing_encoder();
        encoder_mock
            .expect_encode()
            .returning(|_| Ok(String::from("encoded")));
        untunnel_request(
            txt_request(),
            handler_mock,
            Arc::new(cache),
            decoder_returning(DECODED_PAYLOAD),
            encoder_mock,
            READ_TIMEOUT,
        )
        .await;
        Ok(())
    }

    /// The happy path sends exactly one response.
    #[tokio::test]
    async fn dns_success() -> Result<(), Box<dyn Error>> {
        let mut handler_mock = MockDnsResponseHandler::new();
        handler_mock
            .expect_send_response()
            .times(1)
            .returning(|_| Ok(()));
        let cache = StreamsCache::with_default_cleanup_duration(
            || {
                Ok(Stream::new(
                    Builder::new().read(b"bli").build(),
                    Builder::new().write(b"bla").build(),
                ))
            },
            IDLE_CLIENT_TIMEOUT,
        );
        let mut encoder_mock = sizing_encoder();
        encoder_mock
            .expect_encode()
            .returning(|_| Ok(String::from("encoded")));
        untunnel_request(
            txt_request(),
            handler_mock,
            Arc::new(cache),
            decoder_returning(DECODED_PAYLOAD),
            encoder_mock,
            READ_TIMEOUT,
        )
        .await;
        Ok(())
    }
}