use super::Method;
use actix_web::{Error as ActixError, HttpRequest, HttpResponse};
use futures::{Future, IntoFuture};
/// The outcome of evaluating a [`RequestGuard`] against a request.
pub enum Continuation {
    /// The guard does not object to the request.
    Continue,
    /// The guard rejects the request; the boxed future yields the response to
    /// return instead (presumably in place of the route handler's response —
    /// confirm against the caller that evaluates guards).
    Terminate(Box<dyn Future<Item = HttpResponse, Error = ActixError>>),
}
impl Continuation {
pub fn terminate<F>(fut: F) -> Continuation
where
F: Future<Item = HttpResponse, Error = ActixError> + 'static,
{
Continuation::Terminate(Box::new(fut))
}
}
/// A check run against an incoming request that decides whether it may proceed.
///
/// Implementors must be `Send + Sync`. Any `Send + Sync` closure of type
/// `Fn(&HttpRequest) -> Continuation` also implements this trait.
pub trait RequestGuard: Send + Sync {
    /// Inspects `req` and returns [`Continuation::Continue`] to let it pass, or
    /// [`Continuation::Terminate`] carrying the response to send instead.
    fn evaluate(&self, req: &HttpRequest) -> Continuation;
}
/// Lets plain closures act as guards: the closure itself is the evaluation.
impl<F> RequestGuard for F
where
    F: Fn(&HttpRequest) -> Continuation + Send + Sync,
{
    fn evaluate(&self, req: &HttpRequest) -> Continuation {
        // `&F` is itself callable because `F: Fn`, so no explicit deref is needed.
        self(req)
    }
}
/// Lets a boxed guard be used wherever a `RequestGuard` is expected.
impl RequestGuard for Box<dyn RequestGuard> {
    fn evaluate(&self, req: &HttpRequest) -> Continuation {
        // Reborrow the boxed trait object so the call dispatches to the inner
        // guard's implementation rather than recursing into this impl.
        let inner: &dyn RequestGuard = &**self;
        inner.evaluate(req)
    }
}
/// A [`RequestGuard`] that rejects requests whose `SplinterProtocolVersion`
/// header is not a `u32` within an inclusive range of supported versions.
///
/// Requests without the header are allowed through.
#[derive(Clone)]
pub struct ProtocolVersionRangeGuard {
    /// Lowest protocol version accepted.
    min: u32,
    /// Highest protocol version accepted; also reported to rejected clients
    /// as `splinter_protocol`.
    max: u32,
    /// When set, only requests with this method are checked; all others pass.
    method: Option<Method>,
}
impl ProtocolVersionRangeGuard {
    /// Creates a guard accepting protocol versions `min..=max` (inclusive),
    /// checked on requests of every method.
    pub fn new(min: u32, max: u32) -> Self {
        let method = None;
        Self { min, max, method }
    }

    /// Restricts the guard to requests whose method matches `method`;
    /// requests with any other method pass through unchecked.
    pub fn with_method(self, method: Method) -> Self {
        Self {
            method: Some(method),
            ..self
        }
    }
}
impl RequestGuard for ProtocolVersionRangeGuard {
fn evaluate(&self, req: &HttpRequest) -> Continuation {
if let Some(method) = &self.method {
if method != req.method() {
return Continuation::Continue;
}
}
if let Some(header_value) = req.headers().get("SplinterProtocolVersion") {
let parsed_header = header_value
.to_str()
.map_err(|err| {
format!(
"Invalid characters in SplinterProtocolVersion header: {}",
err
)
})
.and_then(|val_str| {
val_str.parse::<u32>().map_err(|_| {
"SplinterProtocolVersion must be a valid positive integer".to_string()
})
});
match parsed_header {
Err(msg) => Continuation::terminate(
HttpResponse::BadRequest()
.json(json!({
"message": msg,
}))
.into_future(),
),
Ok(version) if version < self.min => Continuation::terminate(
HttpResponse::BadRequest()
.json(json!({
"message": format!(
"Client must support protocol version {} or greater.",
self.min,
),
"requested_protocol": version,
"splinter_protocol": self.max,
"libsplinter_version": format!(
"{}.{}.{}",
env!("CARGO_PKG_VERSION_MAJOR"),
env!("CARGO_PKG_VERSION_MINOR"),
env!("CARGO_PKG_VERSION_PATCH")
)
}))
.into_future(),
),
Ok(version) if version > self.max => Continuation::terminate(
HttpResponse::BadRequest()
.json(json!({
"message": format!(
"Client requires a newer protocol than can be provided: {} > {}",
version,
self.max,
),
"requested_protocol": version,
"splinter_protocol": self.max,
"libsplinter_version": format!(
"{}.{}.{}",
env!("CARGO_PKG_VERSION_MAJOR"),
env!("CARGO_PKG_VERSION_MINOR"),
env!("CARGO_PKG_VERSION_PATCH")
)
}))
.into_future(),
),
Ok(_) => Continuation::Continue,
}
} else {
Continuation::Continue
}
}
}