use crate::Body;
use crate::HttpError;
use http::HeaderName;
use hyper::Request;
use semver::Version;
use slog::Logger;
use std::str::FromStr;
/// Describes how the API version of an incoming request is determined.
#[derive(Debug)]
pub enum VersionPolicy {
    /// No version is computed: `request_version` returns `Ok(None)` without
    /// looking at the request.
    Unversioned,

    /// The version is computed for each request by the boxed
    /// [`DynamicVersionPolicy`] implementation.
    Dynamic(Box<dyn DynamicVersionPolicy>),
}
impl VersionPolicy {
    /// Determines the API version of `request` according to this policy.
    ///
    /// * [`VersionPolicy::Unversioned`] always yields `Ok(None)`.
    /// * [`VersionPolicy::Dynamic`] asks the wrapped policy.  A version it
    ///   produces is returned as `Ok(Some(version))`; an error it produces
    ///   is returned unchanged.
    ///
    /// For the dynamic case the outcome is also written to `request_log`:
    /// success at debug level, failure at error level.
    pub(crate) fn request_version(
        &self,
        request: &Request<Body>,
        request_log: &Logger,
    ) -> Result<Option<Version>, HttpError> {
        // Nothing to compute (or log) when the server is unversioned.
        let policy = match self {
            VersionPolicy::Unversioned => return Ok(None),
            VersionPolicy::Dynamic(policy) => policy,
        };

        match policy.request_extract_version(request, request_log) {
            Ok(version) => {
                debug!(request_log, "determined request API version";
                    "version" => %version,
                );
                Ok(Some(version))
            }
            Err(error) => {
                error!(
                    request_log,
                    "failed to determine request API version";
                    "error" => ?error,
                );
                Err(error)
            }
        }
    }
}
/// Hook for computing the API version of a request at request time.
///
/// Implementations are used as boxed trait objects through
/// [`VersionPolicy::Dynamic`].
pub trait DynamicVersionPolicy: Send + Sync + std::fmt::Debug {
    /// Returns the API version to use for `request`, or an [`HttpError`] if
    /// no acceptable version can be determined.
    ///
    /// `log` is a logger the implementation may write to.
    fn request_extract_version(
        &self,
        request: &Request<Body>,
        log: &Logger,
    ) -> Result<Version, HttpError>;
}
/// A [`DynamicVersionPolicy`] that reads the requested version from a
/// request header and rejects versions above a configured maximum.
#[derive(Debug)]
pub struct ClientSpecifiesVersionInHeader {
    /// Name of the header from which the version is read.
    name: HeaderName,
    /// Greatest version that is accepted (inclusive).
    max_version: Version,
}
impl ClientSpecifiesVersionInHeader {
pub fn new(
name: HeaderName,
max_version: Version,
) -> ClientSpecifiesVersionInHeader {
ClientSpecifiesVersionInHeader { name, max_version }
}
}
impl DynamicVersionPolicy for ClientSpecifiesVersionInHeader {
    /// Parses the version from the configured header of `request`.
    ///
    /// Fails with the error from `parse_header` when the header is missing
    /// or cannot be parsed, and with a bad-request [`HttpError`] when the
    /// parsed version is greater than the configured maximum.  The logger
    /// is not used.
    fn request_extract_version(
        &self,
        request: &Request<Body>,
        _log: &Logger,
    ) -> Result<Version, HttpError> {
        let requested: Version = parse_header(request.headers(), &self.name)?;

        // Anything newer than the configured maximum is rejected.
        if requested > self.max_version {
            return Err(HttpError::for_bad_request(
                None,
                format!(
                    "server does not support this API version: {}",
                    requested
                ),
            ));
        }

        Ok(requested)
    }
}
/// Looks up `header_name` in `headers` and parses its value as a `T`.
///
/// Each failure produces an [`HttpError`] built with
/// `HttpError::for_bad_request`, with a message naming the header:
///
/// * the header is not present,
/// * the value cannot be converted to `&str` with `to_str()`, or
/// * `T::from_str` rejects the string (the message includes the parse error
///   and the offending string).
fn parse_header<T>(
    headers: &http::HeaderMap,
    header_name: &HeaderName,
) -> Result<T, HttpError>
where
    T: FromStr,
    <T as FromStr>::Err: std::fmt::Display,
{
    let raw = match headers.get(header_name) {
        Some(value) => value,
        None => {
            return Err(HttpError::for_bad_request(
                None,
                format!("missing expected header {:?}", header_name),
            ));
        }
    };

    let text = match raw.to_str() {
        Ok(text) => text,
        Err(_) => {
            return Err(HttpError::for_bad_request(
                None,
                format!(
                    "bad value for header {:?}: not ASCII: {:?}",
                    header_name, raw
                ),
            ));
        }
    };

    match text.parse::<T>() {
        Ok(parsed) => Ok(parsed),
        Err(parse_error) => Err(HttpError::for_bad_request(
            None,
            format!(
                "bad value for header {:?}: {}: {}",
                header_name, parse_error, text
            ),
        )),
    }
}