use core::str::FromStr;
use std::sync::Arc;
use http::header::{ACCEPT, CONTENT_LENGTH, CONTENT_TYPE};
use http::{HeaderMap, HeaderValue, Request, Response, StatusCode, Uri, header, uri};
use tracing::debug;
use crate::error::NetError;
/// Everything needed to build the head of an outgoing DNS-over-HTTPS POST request.
///
/// See [`RequestContext::build`] for how the fields are turned into a `Request<()>`.
pub(crate) struct RequestContext {
/// HTTP version the request is sent with.
pub(crate) version: Version,
/// Used as the request URI's authority; `build` parses it as an `http::uri::Authority`.
pub(crate) server_name: Arc<str>,
/// Path (and optional query) of the DoH endpoint, e.g. `/dns-query`.
pub(crate) query_path: Arc<str>,
/// Optional hook that may add or modify headers after the standard DoH headers are set.
pub(crate) set_headers: Option<Arc<dyn SetHeaders>>,
}
impl RequestContext {
    /// Builds the head (no body) of a DoH `POST` request for a message of `message_len` bytes.
    ///
    /// The URI is `https://{server_name}{query_path}`. `Content-Type` and `Accept` are both set
    /// to `application/dns-message` and `Content-Length` to `message_len`. The optional
    /// [`SetHeaders`] hook then runs, so it can add to or override those headers.
    ///
    /// # Errors
    ///
    /// Returns a [`NetError`] if the path or authority cannot be parsed, if the URI cannot be
    /// assembled, if the header hook fails, or if the request builder reports an error.
    pub(crate) fn build(&self, message_len: usize) -> Result<Request<()>, NetError> {
        // Parse the path before the authority so errors surface in the same order as before.
        let path_and_query = uri::PathAndQuery::try_from(&*self.query_path)
            .map_err(|e| NetError::from(format!("invalid DoH path: {e}")))?;
        let authority = uri::Authority::from_str(&self.server_name)
            .map_err(|e| NetError::from(format!("invalid authority: {e}")))?;

        let mut parts = uri::Parts::default();
        parts.scheme = Some(uri::Scheme::HTTPS);
        parts.authority = Some(authority);
        parts.path_and_query = Some(path_and_query);
        let url =
            Uri::from_parts(parts).map_err(|e| NetError::from(format!("uri parse error: {e}")))?;

        let mut builder = Request::builder()
            .method("POST")
            .uri(url)
            .version(self.version.to_http())
            .header(CONTENT_TYPE, MIME_APPLICATION_DNS)
            .header(ACCEPT, MIME_APPLICATION_DNS)
            .header(CONTENT_LENGTH, message_len);

        // `headers_mut` yields `None` once the builder has recorded an error. In that case the
        // hook is skipped and the error is reported by `body` below.
        if let Some(extra) = self.set_headers.as_deref() {
            if let Some(map) = builder.headers_mut() {
                extra.set_headers(map)?;
            }
        }

        builder
            .body(())
            .map_err(|e| NetError::from(format!("http stream errored: {e}")))
    }
}
/// Checks that an incoming request is a well-formed DNS-over-HTTPS query.
///
/// The request is accepted only when all of these hold:
/// - the URI path equals `query_path`;
/// - the URI scheme is `https`;
/// - if `name_server` is `Some`, the URI has an authority whose host equals it;
/// - the `Content-Type` media type is `application/dns-message`, ignoring parameters and
///   compared case-insensitively;
/// - the `Accept` header admits `application/dns-message`, whether listed directly or via
///   `application/*` or `*/*`, with a non-zero quality;
/// - the request's HTTP version matches `version`.
///
/// # Errors
///
/// Returns a [`NetError`] describing the first check that failed.
pub fn verify<T>(
    version: Version,
    name_server: Option<&str>,
    query_path: &str,
    request: &Request<T>,
) -> Result<(), NetError> {
    let uri = request.uri();

    // URI checks: path, scheme and, optionally, the authority host.
    if uri.path() != query_path {
        return Err(format!("bad path: {}, expected: {}", uri.path(), query_path).into());
    }
    if Some(&uri::Scheme::HTTPS) != uri.scheme() {
        return Err("must be HTTPS scheme".into());
    }
    if let Some(name_server) = name_server {
        if let Some(authority) = uri.authority() {
            if authority.host() != name_server {
                return Err("incorrect authority".into());
            }
        } else {
            return Err("no authority in HTTPS request".into());
        }
    }

    // The body must be a DNS wire-format message.
    match request.headers().get(CONTENT_TYPE).map(|v| v.to_str()) {
        Some(Ok(ctype)) if is_dns_message_media_type(ctype) => {}
        _ => return Err("unsupported content type".into()),
    };

    // The client must be willing to receive a DNS wire-format message back.
    match request.headers().get(ACCEPT).map(|v| v.to_str()) {
        Some(Ok(accept)) if accepts_dns_message(accept) => {}
        Some(Ok(_)) => return Err("does not accept content type".into()),
        Some(Err(e)) => return Err(e.into()),
        None => return Err("Accept is unspecified".into()),
    };

    if request.version() != version.to_http() {
        let message = match version {
            #[cfg(feature = "__https")]
            Version::Http2 => "only HTTP/2 supported",
            #[cfg(feature = "__h3")]
            Version::Http3 => "only HTTP/3 supported",
        };
        return Err(message.into());
    }

    debug!(
        "verified request from: {}",
        request
            .headers()
            .get(header::USER_AGENT)
            .map(|h| h.to_str().unwrap_or("bad user agent"))
            .unwrap_or("unknown user agent")
    );
    Ok(())
}

/// Returns `true` if the media type of a `Content-Type` value is `application/dns-message`.
///
/// Any `;`-separated parameters are ignored. Type and subtype are compared ASCII
/// case-insensitively, as HTTP media types are case-insensitive.
fn is_dns_message_media_type(content_type: &str) -> bool {
    // `split` always yields at least one item, so the fallback is never used.
    let media_type = content_type.split(';').next().unwrap_or("").trim();
    media_type.eq_ignore_ascii_case(MIME_APPLICATION_DNS)
}

/// Returns `true` if an `Accept` header value admits `application/dns-message`.
///
/// Each comma-separated media range is matched case-insensitively against
/// `application/dns-message`, `application/*` and `*/*`. A range whose `q` parameter is
/// zero means "not acceptable" and therefore does not count as a match.
fn accepts_dns_message(accept: &str) -> bool {
    accept.split(',').any(|media_range| {
        let mut parts = media_range.split(';');
        let mime = parts.next().unwrap_or("").trim();
        let matches = mime.eq_ignore_ascii_case(MIME_APPLICATION_DNS)
            || mime.eq_ignore_ascii_case("application/*")
            || mime == "*/*";
        // The remaining items are the range's parameters (e.g. ` q=0.5`).
        matches && !parts.any(is_zero_quality)
    })
}

/// Returns `true` for a media-range parameter that sets the quality to zero (`q=0`, `q=0.000`).
fn is_zero_quality(param: &str) -> bool {
    match param.split_once('=') {
        Some((name, value)) => {
            name.trim().eq_ignore_ascii_case("q")
                && matches!(value.trim().parse::<f32>(), Ok(q) if q == 0.0)
        }
        None => false,
    }
}
/// Builds the head (no body) of a successful `200 OK` DoH response.
///
/// The response uses the HTTP version given by `version`. Its `Content-Type` is
/// `application/dns-message` and its `Content-Length` is `message_len`.
///
/// # Errors
///
/// Returns a [`NetError`] if the response builder reports an error.
pub fn response(version: Version, message_len: usize) -> Result<Response<()>, NetError> {
    let head = Response::builder()
        .version(version.to_http())
        .status(StatusCode::OK)
        .header(CONTENT_TYPE, MIME_APPLICATION_DNS)
        .header(CONTENT_LENGTH, message_len);

    head.body(())
        .map_err(|err| NetError::from(format!("invalid response: {err}")))
}
/// HTTP version used to carry DoH traffic.
///
/// Only the variants whose cargo features are enabled exist: `Http2` requires `__https`
/// and `Http3` requires `__h3`.
#[derive(Clone, Copy, Debug)]
pub enum Version {
/// HTTP/2 (available with the `__https` feature).
#[cfg(feature = "__https")]
Http2,
/// HTTP/3 (available with the `__h3` feature).
#[cfg(feature = "__h3")]
Http3,
}
impl Version {
fn to_http(self) -> http::Version {
match self {
#[cfg(feature = "__https")]
Self::Http2 => http::Version::HTTP_2,
#[cfg(feature = "__h3")]
Self::Http3 => http::Version::HTTP_3,
}
}
}
/// Hook for adding or modifying headers on outgoing DoH requests.
///
/// The `Send + Sync + 'static` bounds allow a single implementation to be shared as
/// `Arc<dyn SetHeaders>`.
pub trait SetHeaders: Send + Sync + 'static {
/// Adds to or modifies `headers`.
///
/// `RequestContext::build` calls this after it has set `Content-Type`, `Accept` and
/// `Content-Length`, so implementations may override those values.
///
/// # Errors
///
/// An error returned here aborts building the request and is propagated to the caller.
fn set_headers(&self, headers: &mut HeaderMap<HeaderValue>) -> Result<(), NetError>;
}
/// Media type of a DNS wire-format message.
///
/// Requests use it for both `Content-Type` and `Accept`; responses use it for `Content-Type`.
pub(crate) const MIME_APPLICATION_DNS: &str = "application/dns-message";
/// Conventional default URI path for DoH queries.
pub const DEFAULT_DNS_QUERY_PATH: &str = "/dns-query";
#[cfg(test)]
mod tests {
    use http::{
        HeaderMap,
        header::{HeaderName, HeaderValue},
    };

    use super::*;

    /// Builds a 512-byte query request for `ns.example.com/dns-query` and asserts that
    /// `verify` accepts it with the same version, name server and path.
    #[cfg(any(feature = "__https", feature = "__h3"))]
    fn build_and_verify(
        version: Version,
        set_headers: Option<Arc<dyn SetHeaders>>,
    ) -> Request<()> {
        let context = RequestContext {
            version,
            server_name: Arc::from("ns.example.com"),
            query_path: Arc::from("/dns-query"),
            set_headers,
        };
        let request = context.build(512).expect("error converting to http");
        let outcome = verify(version, Some("ns.example.com"), "/dns-query", &request);
        assert!(outcome.is_ok());
        request
    }

    #[test]
    #[cfg(feature = "__https")]
    fn test_new_verify_h2() {
        build_and_verify(Version::Http2, None);
    }

    #[test]
    #[cfg(feature = "__https")]
    fn test_additional_headers() {
        let extra: Arc<dyn SetHeaders> = Arc::new(vec![(
            HeaderName::from_static("test-header"),
            HeaderValue::from_static("test-header-value"),
        )]);
        let request = build_and_verify(Version::Http2, Some(extra));

        let added = request
            .headers()
            .get(HeaderName::from_static("test-header"))
            .expect("header to be set");
        assert_eq!(added, HeaderValue::from_static("test-header-value"));
    }

    #[test]
    #[cfg(feature = "__h3")]
    fn test_new_verify_h3() {
        build_and_verify(Version::Http3, None);
    }

    /// Test-only hook: inserts every pair, replacing any existing value for that header name.
    impl SetHeaders for Vec<(HeaderName, HeaderValue)> {
        fn set_headers(&self, map: &mut HeaderMap<HeaderValue>) -> Result<(), NetError> {
            self.iter().for_each(|(name, value)| {
                map.insert(name.clone(), value.clone());
            });
            Ok(())
        }
    }
}