use crate::Body;
use crate::HttpError;
use http::HeaderName;
use hyper::Request;
use semver::Version;
use slog::Logger;
use std::str::FromStr;
/// Describes how (and whether) the server picks an API version for each
/// incoming request.
#[derive(Debug)]
pub enum VersionPolicy {
    /// The server is not versioned; no version is ever extracted from a
    /// request.
    Unversioned,
    /// The version of each request is decided by the boxed
    /// [`DynamicVersionPolicy`] implementation.
    Dynamic(Box<dyn DynamicVersionPolicy>),
}

impl VersionPolicy {
    /// Works out which API version (if any) applies to `request`.
    ///
    /// * [`VersionPolicy::Unversioned`] always yields `Ok(None)`.
    /// * [`VersionPolicy::Dynamic`] delegates to the policy's
    ///   `request_extract_version`. It logs the outcome to `request_log`
    ///   (debug level on success, error level on failure). It then returns the
    ///   version wrapped in `Some`, or the policy's error unchanged.
    pub(crate) fn request_version(
        &self,
        request: &Request<Body>,
        request_log: &Logger,
    ) -> Result<Option<Version>, HttpError> {
        // Matched exhaustively so that adding a variant forces this function
        // to be revisited.
        let policy = match self {
            VersionPolicy::Unversioned => return Ok(None),
            VersionPolicy::Dynamic(policy) => policy,
        };

        policy
            .request_extract_version(request, request_log)
            .map(|version| {
                debug!(request_log, "determined request API version";
                    "version" => %version,
                );
                Some(version)
            })
            .map_err(|error| {
                error!(
                    request_log,
                    "failed to determine request API version";
                    "error" => ?error,
                );
                error
            })
    }
}
/// Pluggable strategy for determining the API version of an incoming request.
///
/// Used through `VersionPolicy::Dynamic`. Implementors must also be `Debug`,
/// `Send`, and `Sync`, as required by this trait's supertraits.
pub trait DynamicVersionPolicy: std::fmt::Debug + Send + Sync {
    /// Extracts the API version to use for handling `request`.
    ///
    /// The signature requires either a `Version` or an `HttpError`; there is
    /// no way to decline to pick a version.
    ///
    /// `log` is provided for implementation-specific diagnostics.
    fn request_extract_version(
        &self,
        request: &Request<Body>,
        log: &Logger,
    ) -> Result<Version, HttpError>;
}
/// A [`DynamicVersionPolicy`] in which the client names the desired API
/// version in a request header.
///
/// A version greater than `max_version` is rejected with a bad-request error.
/// When the header is absent, the version configured via
/// [`ClientSpecifiesVersionInHeader::on_missing`] is used. If none was
/// configured, the request is rejected.
#[derive(Debug)]
pub struct ClientSpecifiesVersionInHeader {
    /// Header from which the requested version is read.
    name: HeaderName,
    /// Highest API version this server accepts from the header.
    max_version: Version,
    /// Version assumed when the header is absent. `None` means a missing
    /// header is an error.
    on_missing: Option<Version>,
}

impl ClientSpecifiesVersionInHeader {
    /// Creates a policy that reads the version from header `name` and rejects
    /// versions newer than `max_version`.
    ///
    /// By default a request without the header is rejected; use
    /// [`ClientSpecifiesVersionInHeader::on_missing`] to supply a fallback.
    #[must_use]
    pub fn new(
        name: HeaderName,
        max_version: Version,
    ) -> ClientSpecifiesVersionInHeader {
        ClientSpecifiesVersionInHeader { name, max_version, on_missing: None }
    }

    /// Returns this policy configured to use `version` for requests that do
    /// not carry the header.
    ///
    /// This consumes and returns `self`. Discarding the result leaves the
    /// caller without the configured policy, hence `#[must_use]`.
    ///
    /// Note: `version` is not checked against `max_version`. It is returned
    /// as-is whenever the header is missing.
    #[must_use]
    pub fn on_missing(mut self, version: Version) -> Self {
        self.on_missing = Some(version);
        self
    }
}
impl DynamicVersionPolicy for ClientSpecifiesVersionInHeader {
    /// Reads the version from the configured header.
    ///
    /// * Header present and `<= max_version`: that version is returned.
    /// * Header present but newer than `max_version`: bad-request error.
    /// * Header absent: the `on_missing` fallback if one was configured,
    ///   otherwise a bad-request error naming the header.
    ///
    /// A header that is present but unparsable yields the error produced by
    /// `parse_header`.
    fn request_extract_version(
        &self,
        request: &Request<Body>,
        _log: &Logger,
    ) -> Result<Version, HttpError> {
        let requested: Option<Version> =
            parse_header(request.headers(), &self.name)?;

        let Some(requested) = requested else {
            // No header at all: fall back to the configured default, if any.
            return self.on_missing.clone().ok_or_else(|| {
                HttpError::for_bad_request(
                    None,
                    format!("missing expected header {:?}", self.name),
                )
            });
        };

        // `Version` is totally ordered, so this is exactly `!(requested <= max)`.
        if requested > self.max_version {
            return Err(HttpError::for_bad_request(
                None,
                format!(
                    "server does not support this API version: {}",
                    requested
                ),
            ));
        }

        Ok(requested)
    }
}
/// Looks up `header_name` in `headers` and parses its value as a `T`.
///
/// Returns `Ok(None)` when the header is absent, or `Ok(Some(value))` when it
/// is present and parses successfully. Only the value returned by
/// `HeaderMap::get` is examined; per the `http` crate docs, that is the first
/// value if the header repeats.
///
/// # Errors
///
/// Returns a bad-request [`HttpError`] in either of these cases:
/// * the header value cannot be viewed as a string (`HeaderValue::to_str`
///   fails, i.e. it is not visible ASCII);
/// * the string does not parse as `T`. The message then includes the parse
///   error's `Display` text and the raw value.
fn parse_header<T>(
    headers: &http::HeaderMap,
    header_name: &HeaderName,
) -> Result<Option<T>, HttpError>
where
    T: FromStr,
    <T as FromStr>::Err: std::fmt::Display,
{
    headers
        .get(header_name)
        .map(|raw| -> Result<T, HttpError> {
            let text = raw.to_str().map_err(|_| {
                HttpError::for_bad_request(
                    None,
                    format!(
                        "bad value for header {:?}: not ASCII: {:?}",
                        header_name, raw
                    ),
                )
            })?;
            text.parse::<T>().map_err(|parse_err| {
                HttpError::for_bad_request(
                    None,
                    format!(
                        "bad value for header {:?}: {}: {}",
                        header_name, parse_err, text
                    ),
                )
            })
        })
        .transpose()
}