use std::convert::Infallible;
use tower::Layer;
use tower::Service;
use crate::body::BoxBody;
use crate::routing::tiny_map::TinyMap;
use crate::routing::Route;
use crate::routing::Router;
use http::header::ToStrError;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum Error {
#[error("relative URI is not \"/\"")]
NotRootUrl,
#[error("method not POST")]
MethodNotAllowed,
#[error("missing the \"x-amz-target\" header")]
MissingHeader,
#[error("failed to parse header: {0}")]
InvalidHeader(ToStrError),
#[error("operation not found")]
NotFound,
}
/// Const-generic parameter passed to the `TinyMap` that stores [`AwsJsonRouter`]'s routes.
// NOTE(review): presumably the number of entries at which `TinyMap` changes its storage /
// lookup strategy — confirm against `routing::tiny_map`.
pub(crate) const ROUTE_CUTOFF: usize = 15;
/// Router that dispatches requests on the value of their `x-amz-target` header.
#[derive(Debug, Clone)]
pub struct AwsJsonRouter<S> {
    // Keyed by the exact `x-amz-target` value; the value is the service handling that target.
    routes: TinyMap<&'static str, S, ROUTE_CUTOFF>,
}
impl<S> AwsJsonRouter<S> {
    /// Applies `layer` to every registered route, producing a router over the layered services.
    ///
    /// The `x-amz-target` keys are carried over unchanged; only the services are wrapped.
    pub fn layer<L>(self, layer: L) -> AwsJsonRouter<L::Service>
    where
        L: Layer<S>,
    {
        let Self { routes } = self;
        let layered = routes
            .into_iter()
            .map(|(target, service)| (target, layer.layer(service)))
            .collect();
        AwsJsonRouter { routes: layered }
    }

    /// Converts every route into a [`Route<B>`], removing the concrete service type `S`
    /// from the router's type.
    pub fn boxed<B>(self) -> AwsJsonRouter<Route<B>>
    where
        S: Service<http::Request<B>, Response = http::Response<BoxBody>, Error = Infallible>,
        S: Send + Clone + 'static,
        S::Future: Send + 'static,
    {
        let Self { routes } = self;
        let erased = routes
            .into_iter()
            .map(|(target, service)| {
                let route = Route::new(service);
                (target, route)
            })
            .collect();
        AwsJsonRouter { routes: erased }
    }
}
impl<B, S> Router<B> for AwsJsonRouter<S>
where
    S: Clone,
{
    type Service = S;
    type Error = Error;

    /// Selects the service registered for an AWS JSON request.
    ///
    /// The request must be a `POST` to the root path `/` with no query string and must carry
    /// an `x-amz-target` header whose value names a registered route. On success, a clone of
    /// that route's service is returned.
    ///
    /// # Errors
    ///
    /// - [`Error::NotRootUrl`] if the path is not `/` or a non-empty query string is present.
    /// - [`Error::MethodNotAllowed`] if the method is not `POST`.
    /// - [`Error::MissingHeader`] if there is no `x-amz-target` header.
    /// - [`Error::InvalidHeader`] if the `x-amz-target` value cannot be converted to `&str`.
    /// - [`Error::NotFound`] if no route is registered for the `x-amz-target` value.
    fn match_route(&self, request: &http::Request<B>) -> Result<S, Self::Error> {
        // Check only the path and query. `Uri == "/"` also compares scheme and authority,
        // so an absolute-form URI such as `http://host/` (which hyper may hand us for
        // HTTP/2 requests) would never match even though it targets the root path.
        let uri = request.uri();
        let has_query = matches!(uri.query(), Some(query) if !query.is_empty());
        if uri.path() != "/" || has_query {
            return Err(Error::NotRootUrl);
        }

        if request.method() != http::Method::POST {
            return Err(Error::MethodNotAllowed);
        }

        let target = request.headers().get("x-amz-target").ok_or(Error::MissingHeader)?;
        let target = target.to_str().map_err(Error::InvalidHeader)?;

        // The header value is used verbatim as the lookup key.
        let route = self.routes.get(target).ok_or(Error::NotFound)?;
        Ok(route.clone())
    }
}
impl<S> FromIterator<(&'static str, S)> for AwsJsonRouter<S> {
    /// Builds a router from `(x-amz-target value, service)` pairs.
    #[inline]
    fn from_iter<T: IntoIterator<Item = (&'static str, S)>>(iter: T) -> Self {
        let routes = iter.into_iter().collect();
        Self { routes }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{protocol::test_helpers::req, routing::Router};
    use http::{HeaderMap, HeaderValue, Method};
    use pretty_assertions::assert_eq;

    /// Builds a header map containing only the given `x-amz-target` value.
    fn target_headers(target: HeaderValue) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("x-amz-target", target);
        headers
    }

    /// Router with a single route for `Service.Operation`.
    fn single_route_router() -> AwsJsonRouter<()> {
        std::iter::once(("Service.Operation", ())).collect()
    }

    #[test]
    fn simple_routing() {
        let router = single_route_router();
        let headers = target_headers(HeaderValue::from_static("Service.Operation"));

        // A root-path POST with a registered target resolves to a route.
        router
            .match_route(&req(&Method::POST, "/", Some(headers.clone())))
            .unwrap();

        let res = router.match_route(&req(&Method::POST, "/", None));
        assert_eq!(res.unwrap_err().to_string(), Error::MissingHeader.to_string());

        let res = router.match_route(&req(&Method::GET, "/", Some(headers.clone())));
        assert_eq!(res.unwrap_err().to_string(), Error::MethodNotAllowed.to_string());

        let res = router.match_route(&req(&Method::POST, "/something", Some(headers)));
        assert_eq!(res.unwrap_err().to_string(), Error::NotRootUrl.to_string());
    }

    #[test]
    fn rejects_unknown_and_malformed_targets() {
        let router = single_route_router();

        // Well-formed header, but no route registered for it.
        let unknown = target_headers(HeaderValue::from_static("Service.Unknown"));
        let res = router.match_route(&req(&Method::POST, "/", Some(unknown)));
        assert_eq!(res.unwrap_err().to_string(), Error::NotFound.to_string());

        // 0xFF is a legal header byte but not visible ASCII, so `to_str` fails.
        let malformed = target_headers(HeaderValue::from_bytes(b"Service.\xffOperation").unwrap());
        let res = router.match_route(&req(&Method::POST, "/", Some(malformed)));
        assert!(matches!(res, Err(Error::InvalidHeader(_))));
    }
}