use std::sync::Arc;
use std::net::SocketAddr;
use std::fs::File;
use std::io::{BufReader, Error, ErrorKind};
use std::thread::sleep;
use std::time::Duration;
use data_encoding::BASE64URL_NOPAD;
use tokio_rustls::{Connect, TlsConnector, TlsStream};
use tokio::timer::Timeout;
use tokio::net::TcpStream;
use tokio::net::tcp::ConnectFuture;
use tokio::prelude::FutureExt;
use rustls::{ClientSession, ClientConfig};
use webpki::DNSNameRef;
use futures_locks::{Mutex, MutexFut, MutexGuard};
use futures::{Async, Future, Stream};
use h2::client::{SendRequest, Handshake, ResponseFuture, Connection, handshake};
use h2::RecvStream;
use http::Request;
use bytes::Bytes;
use cache::Cache;
use dns::DnsPacket;
use ::Context;
/// Builds the rustls client configuration used for every DoH connection.
///
/// The root store is filled from the PEM file at `cafile`, and ALPN offers
/// only `h2` because requests are always sent over HTTP/2.
///
/// # Errors
///
/// Returns the I/O error if `cafile` cannot be opened, and an
/// `ErrorKind::Other` error if the file cannot be parsed as PEM or if it
/// yields no certificate that could be added to the root store (rather than
/// silently returning a configuration with an empty trust store).
pub fn create_config(cafile: &str) -> Result<ClientConfig, Error> {
    let certfile = File::open(cafile)?;
    let mut config = ClientConfig::new();
    // `add_pem_file` reports (certificates added, certificates ignored).
    match config.root_store.add_pem_file(&mut BufReader::new(certfile)) {
        Ok((0, _)) => {
            return Err(Error::new(ErrorKind::Other, "No valid certificate in pem file"));
        }
        Ok(_) => {}
        Err(()) => return Err(Error::new(ErrorKind::Other, "Cannot parse pem file")),
    }
    config.alpn_protocols.push(b"h2".to_vec());
    Ok(config)
}
enum Http2RequestState {
GetMutexCache(MutexFut<Cache<Bytes, Bytes>>),
GetMutexSendRequest(MutexFut<(Option<SendRequest<Bytes>>, u32)>),
GetConnection(MutexGuard<(Option<SendRequest<Bytes>>, u32)>, Http2ConnectionFuture, u32),
GetResponse(Timeout<Http2ResponseFuture>, u32),
CloseConnection(MutexFut<(Option<SendRequest<Bytes>>, u32)>, u32),
GetMutexCacheFallback(MutexFut<Cache<Bytes, Bytes>>),
SaveInCache(MutexFut<Cache<Bytes, Bytes>>, Bytes, Duration),
}
pub struct Http2RequestFuture {
mutex_send_request: Mutex<(Option<SendRequest<Bytes>>, u32)>,
mutex_cache: Mutex<Cache<Bytes, Bytes>>,
state: Http2RequestState,
context: &'static Context,
msg: DnsPacket,
addr: SocketAddr,
}
impl Http2RequestFuture {
pub fn new(mutex_send_request: Mutex<(Option<SendRequest<Bytes>>, u32)>, mutex_cache: Mutex<Cache<Bytes, Bytes>>, msg: DnsPacket, addr: SocketAddr, context: &'static Context) -> Http2RequestFuture {
use self::Http2RequestState::{GetMutexCache, GetMutexSendRequest};
debug!("Received UDP packet from {} {:#?}", addr, msg.get_tid());
let state = if context.config.cache_size == 0 {
GetMutexSendRequest(mutex_send_request.lock())
} else {
GetMutexCache(mutex_cache.lock())
};
Http2RequestFuture { mutex_send_request, mutex_cache, state, msg, addr, context }
}
}
macro_rules! send_request {
($a:ident, $b:ident) => {
{
let config = &$a.context.config;
let post = config.post;
let msg = &$a.msg;
let request = if post {
Request::builder()
.method("POST")
.uri(config.uri.clone())
.header("accept", "application/dns-message")
.header("content-type", "application/dns-message")
.header("content-length", msg.len().to_string())
.body(())
.unwrap()
} else {
Request::builder()
.method("GET")
.uri(format!("{}?dns={}", config.uri, BASE64URL_NOPAD.encode(&msg.get_without_tid())))
.header("accept", "application/dns-message")
.body(())
.unwrap()
};
let id = (*$b).1;
match (*$b).0 {
Some(ref mut send_request) => {
match send_request.send_request(request, false) {
Ok((response, mut request)) => {
if post {
match request.send_data(msg.get_without_tid(), true) {
Ok(()) => GetResponse(Http2ResponseFuture::new(response).timeout(Duration::from_secs(config.timeout)), id),
Err(e) => {
error!("send_data: {}", e);
CloseConnection($a.mutex_send_request.lock(), id)
}
}
} else {
GetResponse(Http2ResponseFuture::new(response).timeout(Duration::from_secs(config.timeout)), id)
}
},
Err(e) => {
error!("send_request: {}", e);
CloseConnection($a.mutex_send_request.lock(), id)
}
}
},
None => return Err(())
}
}
}
}
impl Future for Http2RequestFuture {
type Item = ();
type Error = ();
fn poll(&mut self) -> Result<Async<()>, ()> {
use self::Http2RequestState::*;
use self::Async::*;
loop {
self.state = match self.state {
GetMutexCache(ref mut mutex_fut) => {
match mutex_fut.poll() {
Ok(async_) => {
match async_ {
Ready(mut guard) => {
let result = if self.context.config.cache_fallback {
(*guard).get_expired(&self.msg.get_without_tid())
} else {
(*guard).get(&self.msg.get_without_tid())
};
match result {
Some(buffer) => {
debug!("GetMutexCache: found in cache");
match DnsPacket::from_tid((*buffer).clone(), self.msg.get_tid()) {
Ok(dns) => {
match self.context.sender.unbounded_send((dns, self.addr)) {
Ok(()) => return Ok(Ready(())),
Err(e) => {
error!("GetMutexCache: unbounded_send: {}", e);
return Err(());
}
}
}
Err(e) => {
error!("GetMutexCache: parse error: {}", e);
GetMutexSendRequest(self.mutex_send_request.lock())
}
}
}
None => {
debug!("GetMutexCache: missing in cache");
GetMutexSendRequest(self.mutex_send_request.lock())
}
}
}
NotReady => return Ok(NotReady)
}
}
Err(_e) => {
error!("GetMutexCache: could not get mutex");
Http2RequestState::GetMutexSendRequest(self.mutex_send_request.lock())
}
}
}
GetMutexSendRequest(ref mut mutex_fut) => {
let config = &self.context.config;
match mutex_fut.poll() {
Ok(async_) => {
match async_ {
Ready(mut guard) => {
if (*guard).0.is_some() {
send_request!(self, guard)
} else {
GetConnection(guard, Http2ConnectionFuture::new(config.remote_addr, config.client_config.clone(), config.domain.clone()), 1)
}
}
NotReady => return Ok(NotReady)
}
}
Err(_e) => {
error!("GetMutexSendRequest: could not get mutex");
return Err(());
}
}
}
GetConnection(ref mut guard, ref mut http2_connection_future, ref mut tries) => {
let config = &self.context.config;
match http2_connection_future.poll() {
Ok(async_) => {
match async_ {
Ready((mut send_request, connection)) => {
tokio::spawn(connection.map_err(|e| {
error!("GetConnection: H2 connection error: {}", e)
}));
info!("GetConnection: connection was successfully established to remote server {} ({})", config.remote_addr, config.domain);
(*guard).0.replace(send_request);
(*guard).1 += 1;
send_request!(self, guard)
}
NotReady => return Ok(NotReady)
}
}
Err(e) => {
error!("GetConnection: connection to remote server {} ({}) failed: {}: retry: {}", config.remote_addr, config.domain, e, *tries);
sleep(Duration::from_secs(1));
if config.retries > *tries {
*tries += 1;
*http2_connection_future = Http2ConnectionFuture::new(config.remote_addr, config.client_config.clone(), config.domain.clone());
continue;
} else {
error!("GetConnection: too many connection attempts to remote server {} ({})", config.remote_addr, config.domain);
if self.context.config.cache_fallback {
GetMutexCacheFallback(self.mutex_cache.lock())
} else {
return Err(());
}
}
}
}
}
GetResponse(ref mut http2_response_future, ref id) => {
match http2_response_future.poll() {
Ok(async_) => {
match async_ {
Ready(result) => {
let (buffer, duration) = result;
match DnsPacket::from_tid(buffer, self.msg.get_tid()) {
Ok(dns) => {
if dns.is_response() {
let context = &self.context;
match context.sender.unbounded_send((dns.clone(), self.addr)) {
Ok(()) => {
if context.config.cache_size == 0 {
return Ok(Ready(()));
} else {
if let Some(duration) = duration {
SaveInCache(self.mutex_cache.lock(), dns.get_without_tid(), duration)
} else {
return Ok(Ready(()));
}
}
}
Err(e) => {
error!("GetResponse: unbounded_send: {}", e);
return Err(());
}
}
} else {
error!("GetResponse: get a non DNS response");
return Err(());
}
}
Err(e) => {
error!("GetResponse: DNS parser error: {}", e);
return Err(());
}
}
}
NotReady => return Ok(NotReady)
}
}
Err(_e) => {
error!("GetResponse: timeout");
CloseConnection(self.mutex_send_request.lock(), *id)
}
}
}
CloseConnection(ref mut mutex_fut, ref id) => {
match mutex_fut.poll() {
Ok(async_) => {
match async_ {
Ready(mut guard) => {
if (*guard).1 == *id {
(*guard).0.take();
}
if self.context.config.cache_fallback {
GetMutexCacheFallback(self.mutex_cache.lock())
} else {
return Err(());
}
}
NotReady => return Ok(NotReady)
}
}
Err(_e) => {
error!("CloseConnection: could not get mutex");
return Err(());
}
}
}
GetMutexCacheFallback(ref mut mutex_fut) => {
match mutex_fut.poll() {
Ok(async_) => {
match async_ {
Ready(mut guard) => {
match (*guard).get_expired_fallback(&self.msg.get_without_tid()) {
Some(buffer) => {
debug!("GetMutexCacheFallback: found in cache");
match DnsPacket::from_tid((*buffer).clone(), self.msg.get_tid()) {
Ok(dns) => {
match self.context.sender.unbounded_send((dns, self.addr)) {
Ok(()) => return Ok(Ready(())),
Err(e) => {
error!("GetMutexCache: unbounded_send: {}", e);
return Err(());
}
}
}
Err(e) => {
error!("GetMutexCacheFallback: parse error: {}", e);
return Err(())
}
}
}
None => {
debug!("GetMutexCacheFallback: missing in cache");
return Err(())
}
}
}
NotReady => return Ok(NotReady)
}
}
Err(_e) => {
error!("GetMutexCacheFallback: could not get mutex");
return Err(());
}
}
}
SaveInCache(ref mut mutex_fut, ref buffer, ref duration) => {
match mutex_fut.poll() {
Ok(async_) => {
match async_ {
Ready(mut guard) => {
(*guard).put(self.msg.get_without_tid(), buffer.clone(), duration.clone());
return Ok(Ready(()));
}
NotReady => return Ok(NotReady)
}
}
Err(_e) => {
error!("SaveInCache: could not get mutex");
return Err(());
}
}
}
}
}
}
}
/// Stages of `Http2ConnectionFuture`: TCP connect, TLS handshake, HTTP/2 handshake.
enum Http2ConnectionState {
    /// Waiting for the TCP connection to the server.
    GetTcpConnection(ConnectFuture),
    /// Waiting for the TLS handshake over the TCP stream.
    GetTlsConnection(Connect<TcpStream>),
    /// Waiting for the HTTP/2 handshake over the TLS stream.
    GetHttp2Connection(Handshake<TlsStream<TcpStream, ClientSession>, Bytes>),
}
/// Future that connects to a DoH server and performs the TLS and HTTP/2
/// handshakes, yielding an HTTP/2 request handle and the connection driver.
pub struct Http2ConnectionFuture {
    /// Current connection stage.
    state: Http2ConnectionState,
    /// TLS connector built from the client configuration passed to `new`.
    tls_connector: TlsConnector,
    /// Server name handed to the TLS handshake (as a `DNSNameRef`).
    domain: String,
}
impl Http2ConnectionFuture {
    /// Creates a connection future for the DoH server at `remote_addr`.
    ///
    /// The TCP connect is requested right away. The later TLS handshake uses
    /// `config`, with `domain` as the server name.
    pub fn new(remote_addr: SocketAddr, config: ClientConfig, domain: String) -> Http2ConnectionFuture {
        let state = Http2ConnectionState::GetTcpConnection(TcpStream::connect(&remote_addr));
        let tls_connector = TlsConnector::from(Arc::new(config));
        Http2ConnectionFuture {
            state,
            tls_connector,
            domain,
        }
    }
}
impl Future for Http2ConnectionFuture {
type Item = (SendRequest<Bytes>, Connection<TlsStream<TcpStream, ClientSession>>);
type Error = Error;
fn poll(&mut self) -> Result<Async<(SendRequest<Bytes>, Connection<TlsStream<TcpStream, ClientSession>>)>, Error> {
use self::Http2ConnectionState::*;
use self::Async::*;
loop {
self.state = match self.state {
GetTcpConnection(ref mut future) => {
match future.poll() {
Ok(async_) => {
match async_ {
Ready(tcp) => {
if let Err(e) = tcp.set_keepalive(Some(Duration::from_secs(1))) {
error!("GetTcpConnection: could not set keepalive on TCP: {}", e);
}
if let Err(e) = tcp.set_nodelay(true) {
error!("GetTcpConnection: could not set nodelay on TCP: {}", e);
}
GetTlsConnection(self.tls_connector.connect(DNSNameRef::try_from_ascii_str(&self.domain).unwrap(), tcp))
}
NotReady => return Ok(NotReady),
}
}
Err(e) => return Err(e)
}
}
GetTlsConnection(ref mut connect) => {
match connect.poll() {
Ok(async_) => {
match async_ {
Ready(tls) => GetHttp2Connection(handshake(tls)),
NotReady => return Ok(NotReady),
}
}
Err(e) => return Err(e)
}
}
GetHttp2Connection(ref mut handshake) => {
match handshake.poll() {
Ok(async_) => return Ok(async_),
Err(e) => return Err(Error::new(ErrorKind::Other, e))
}
}
}
}
}
}
/// Stages of `Http2ResponseFuture`.
enum Http2ResponseState {
    /// Waiting for the response headers.
    GetResponse(ResponseFuture),
    /// Reading the response body.
    GetBody(RecvStream),
}
/// Future that waits for a DoH response over HTTP/2 and collects its body.
pub struct Http2ResponseFuture {
    /// Current stage: waiting for headers, then reading the body.
    state: Http2ResponseState,
    /// Body bytes collected so far; `poll` keeps at most 1024 bytes.
    buffer: Bytes,
    /// `max-age` from the response's `cache-control` header, if present.
    duration: Option<Duration>,
}
impl Http2ResponseFuture {
    /// Wraps an h2 `ResponseFuture`.
    ///
    /// Starts in the header stage with an empty body buffer and no cache lifetime.
    pub fn new(response_future: ResponseFuture) -> Http2ResponseFuture {
        let state = Http2ResponseState::GetResponse(response_future);
        Http2ResponseFuture {
            state,
            buffer: Bytes::new(),
            duration: None,
        }
    }
}
impl Future for Http2ResponseFuture {
type Item = (Bytes, Option<Duration>);
type Error = ();
fn poll(&mut self) -> Result<Async<(Bytes, Option<Duration>)>, ()> {
use self::Http2ResponseState::*;
use self::Async::*;
loop {
self.state = match self.state {
GetResponse(ref mut future) => {
match future.poll() {
Ok(async_) => {
match async_ {
Ready(response) => {
let (header, body) = response.into_parts();
if header.status != 200 {
error!("GetResponse: header.status != 200");
return Err(());
}
let headers = &header.headers;
match headers.get("content-type") {
Some(value) => {
if value != "application/dns-message" {
error!("GetResponse: content-type != application/dns-message");
return Err(());
}
}
None => {
error!("GetResponse: content-type is None");
return Err(());
}
}
if let Some(value) = headers.get("cache-control") {
for i in value.to_str().unwrap().split(",") {
let key_value: Vec<&str> = i.splitn(2, "=").map(|s| s.trim()).collect();
if key_value.len() == 2 && key_value[0] == "max-age" {
if let Ok(value) = key_value[1].parse::<u64>() {
self.duration.replace(Duration::from_secs(value));
}
}
}
}
GetBody(body)
}
NotReady => return Ok(NotReady),
}
}
Err(e) => {
error!("GetResponse: {}", e);
return Err(());
}
}
}
GetBody(ref mut stream) => {
loop {
match stream.poll() {
Ok(async_) => {
match async_ {
Ready(mut body) => {
if let Some(b) = body {
let buffer_len = self.buffer.len();
let b_len = b.len();
if buffer_len < 1024 {
if buffer_len + b_len < 1024 {
self.buffer.extend(b);
} else {
self.buffer.extend(b.slice_to(1024 - buffer_len));
}
}
match stream.release_capacity().release_capacity(b_len) {
Ok(()) => {}
Err(e) => error!("GetBody: release_capacity: {}", e)
}
} else {
return Ok(Ready((self.buffer.clone(), self.duration)));
}
}
NotReady => return Ok(NotReady),
}
}
Err(e) => {
error!("GetBody: {}", e);
return Err(());
}
}
}
}
}
}
}
}