use crate::bail;
use crate::clients::mime::content_type_equal;
use crate::clients::AsyncExchanger;
use crate::clients::ToUrls;
use crate::Message;
use crate::clients::stats::StatsBuilder;
use async_trait::async_trait;
use http::header::*;
use http::{Method, Request};
use hyper::client::connect::HttpInfo;
use hyper::{Body, Client as HyperClient};
use hyper_alpn::AlpnConnector;
use std::net::IpAddr;
use std::net::Ipv4Addr;
use std::net::SocketAddr;
use std::time::Duration;
use url::Url;
/// DNS-over-HTTPS query endpoint served at `dns.google`.
pub const GOOGLE: &str = "https://dns.google/dns-query";
/// Media type of DNS wire-format messages; sent in `Accept` (and, for POST,
/// `Content-Type`) and required on successful responses.
const CONTENT_TYPE_APPLICATION_DNS_MESSAGE: &str = "application/dns-message";
/// Name of the URL query parameter that carries the unpadded base64url-encoded
/// query when the GET method is used.
const DNS_QUERY_PARAM: &str = "dns";
/// DNS-over-HTTPS client.
///
/// Holds the list of server URLs and the HTTP method used to send queries.
pub struct Client {
    /// Server URLs; exchanges are sent to the first entry.
    servers: Vec<Url>,
    /// HTTP method used for queries (GET or POST when built via `Client::new`).
    method: Method,
}
impl Default for Client {
    /// Builds a client with no servers configured that would send queries with GET.
    ///
    /// Prefer `Client::new`, which lets you supply the servers.
    fn default() -> Self {
        Self {
            method: Method::GET,
            servers: Vec::new(),
        }
    }
}
impl Client {
pub fn new<A: ToUrls>(servers: A, method: Method) -> Result<Self, crate::Error> {
match method {
Method::GET | Method::POST => (), _ => bail!(InvalidInput, "only GET and POST allowed"),
}
Ok(Self {
servers: servers.to_urls()?.collect(),
method,
})
}
}
#[async_trait]
impl AsyncExchanger for Client {
    /// Sends `query` over DNS-over-HTTPS and returns the parsed response.
    ///
    /// The query is cloned and its ID set to 0 before being serialized. With
    /// GET, the wire-format message is base64url-encoded (no padding) into the
    /// `dns` query parameter. With POST, it is sent as the request body with
    /// `Content-Type: application/dns-message`. Only the first configured
    /// server is contacted; any further servers are ignored.
    ///
    /// # Errors
    ///
    /// * `InvalidInput`: no server is configured, the method is neither GET
    ///   nor POST, the HTTP request cannot be built, or the server replies
    ///   with a non-success status code.
    /// * `InvalidData`: a successful response declares a `Content-Type` other
    ///   than `application/dns-message`.
    /// * Any error from serializing the query, parsing the URI, the HTTP
    ///   transport, reading the body, or decoding the response message.
    async fn exchange(&self, query: &Message) -> Result<Message, crate::Error> {
        // A client without servers (e.g. `Client::default()`) has nothing to
        // contact; report it instead of panicking on `self.servers[0]`.
        let server = match self.servers.first() {
            Some(server) => server,
            None => bail!(InvalidInput, "no DNS-over-HTTPS server configured"),
        };

        let mut query = query.clone();
        // RFC 8484 §4.1: clients SHOULD use DNS ID 0 for HTTP cache friendliness.
        query.id = 0;
        let p = query.to_vec()?;

        // NOTE(review): a new connector and client are built on every call, so
        // connections are never reused and `pool_idle_timeout` has no effect.
        // Keeping the client on `Client` would fix that but changes the struct.
        let alpn = AlpnConnector::new();
        let client = HyperClient::builder()
            .pool_idle_timeout(Duration::from_secs(30))
            .http2_only(true)
            .build::<_, hyper::Body>(alpn);

        let req = Request::builder()
            .method(&self.method)
            .header(ACCEPT, CONTENT_TYPE_APPLICATION_DNS_MESSAGE);

        let req = match self.method {
            Method::GET => {
                let mut buf = String::new();
                base64::encode_config_buf(p, base64::URL_SAFE_NO_PAD, &mut buf);
                let mut url = server.clone();
                url.query_pairs_mut().append_pair(DNS_QUERY_PARAM, &buf);
                let uri: hyper::Uri = url.as_str().parse()?;
                req.uri(uri).body(Body::empty())
            }
            Method::POST => req
                .uri(server.as_str())
                .header(CONTENT_TYPE, CONTENT_TYPE_APPLICATION_DNS_MESSAGE)
                .body(Body::from(p)),
            _ => bail!(InvalidInput, "only GET and POST allowed"),
        };
        // Building can fail (e.g. a URI the `http` crate rejects); surface that
        // as an error rather than panicking.
        let req = match req {
            Ok(req) => req,
            Err(e) => bail!(InvalidInput, "failed to build HTTP request: {}", e),
        };

        let stats = StatsBuilder::start(0);
        let resp = client.request(req).await?;

        // Check the status first: error responses usually carry a non-DNS
        // content type, and the status code is the more useful diagnostic.
        if !resp.status().is_success() {
            bail!(
                InvalidInput,
                "received unexpected HTTP status code: {:}",
                resp.status()
            );
        }

        if let Some(content_type) = resp.headers().get(CONTENT_TYPE) {
            if !content_type_equal(content_type, CONTENT_TYPE_APPLICATION_DNS_MESSAGE) {
                bail!(
                    InvalidData,
                    "received invalid content-type: {:?} expected {}",
                    content_type,
                    CONTENT_TYPE_APPLICATION_DNS_MESSAGE,
                );
            }
        }

        // Fall back to 0.0.0.0:0 when the connector did not attach connection info.
        let remote_addr = match resp.extensions().get::<HttpInfo>() {
            Some(http_info) => http_info.remote_addr(),
            None => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
        };
        let body = hyper::body::to_bytes(resp.into_body()).await?;
        let mut m = Message::from_slice(&body)?;
        m.stats = Some(stats.end(remote_addr, body.len()));
        Ok(m)
    }
}