axum 0.6.4

Web framework that focuses on ergonomics and modularity
Documentation
use http::{Request, Uri};
use std::{
    sync::Arc,
    task::{Context, Poll},
};
use tower_service::Service;

/// Service wrapper that rewrites the request URI by removing a route prefix
/// from its path before handing the request to the wrapped service.
///
/// Requests whose path does not match the prefix are forwarded with their URI
/// left untouched.
#[derive(Clone)]
pub(super) struct StripPrefix<S> {
    /// The wrapped service that receives the (possibly rewritten) request.
    inner: S,
    /// The prefix to strip. Held in an `Arc` so cloning the service shares the
    /// same string allocation.
    prefix: Arc<str>,
}

impl<S> StripPrefix<S> {
    pub(super) fn new(inner: S, prefix: &str) -> Self {
        Self {
            inner,
            prefix: prefix.into(),
        }
    }
}

impl<S, B> Service<Request<B>> for StripPrefix<S>
where
    S: Service<Request<B>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    /// Readiness is decided entirely by the wrapped service.
    #[inline]
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Strips the configured prefix from the request path (when it matches)
    /// and forwards the request to the wrapped service.
    fn call(&mut self, mut req: Request<B>) -> Self::Future {
        // Compute the rewritten URI first so the shared borrow of `req` has
        // ended before it is mutated.
        let rewritten = strip_prefix(req.uri(), &self.prefix);

        // A non-matching prefix leaves the request URI exactly as it came in.
        if let Some(uri) = rewritten {
            *req.uri_mut() = uri;
        }

        self.inner.call(req)
    }
}

/// Removes `prefix` from the front of the path of `uri`.
///
/// Matching is done segment by segment. A prefix segment starting with `:` is
/// a parameter and matches any path segment; every other prefix segment must
/// equal the corresponding path segment exactly. The query string, if any, is
/// carried over unchanged and the resulting path always starts with `/`.
///
/// Returns `None` when `uri` has no path-and-query component or when the
/// prefix does not match the path.
///
/// # Panics
///
/// Panics if the path of `uri` or `prefix` does not start with `/` (see
/// `segments`).
fn strip_prefix(uri: &Uri, prefix: &str) -> Option<Uri> {
    let path_and_query = uri.path_and_query()?;

    // Walk the path and the prefix in lockstep and count how many bytes of the
    // path are covered by the prefix.
    //
    // prefix = /api
    // path   = /api/users
    //          ^^^^ 4 bytes are matched; dropping them leaves the remainder
    //
    // prefix = /api/:version
    // path   = /api/v0/users
    //          ^^^^^^^ 7 bytes are matched
    let mut consumed: usize = 0;
    for pair in zip_longest(segments(path_and_query.path()), segments(prefix)) {
        // every segment is preceded by a `/`, which is part of the match
        consumed += 1;

        let (path_segment, prefix_segment) = match pair {
            Item::Both(path_segment, prefix_segment) => (path_segment, prefix_segment),
            // The prefix ran out first, so everything up to here matched:
            //
            // prefix = /foo
            // path   = /foo/bar
            Item::First(_) => break,
            // The path ran out before the prefix did: no match.
            Item::Second(_) => return None,
        };

        let is_param = prefix_segment.starts_with(':');
        if is_param || path_segment == prefix_segment {
            // Either a parameter (matches anything) or a literal match; the
            // whole path segment belongs to the prefix.
            consumed += path_segment.len();
        } else if prefix_segment.is_empty() {
            // The prefix ended in `/`, which still counts as a match:
            //
            // prefix = /foo/
            // path   = /foo/bar
            //
            // and the new path should be `/bar`.
            break;
        } else {
            // A literal prefix segment that differs from the path segment.
            return None;
        }
    }

    // A match always ends right at a `/` (or at the end of the path), never in
    // the middle of a segment, so `consumed` is a valid char boundary and
    // `split_at` won't panic.
    let (_, remainder) = uri.path().split_at(consumed);

    // Rebuild `<path>[?<query>]`, making sure the path keeps its leading `/`.
    let mut rebuilt = String::new();
    if !remainder.starts_with('/') {
        rebuilt.push('/');
    }
    rebuilt.push_str(remainder);
    if let Some(query) = path_and_query.query() {
        rebuilt.push('?');
        rebuilt.push_str(query);
    }
    let new_path_and_query = rebuilt.parse().unwrap();

    // Everything except the path and query is taken over from the original.
    let mut parts = uri.clone().into_parts();
    parts.path_and_query = Some(new_path_and_query);

    Some(Uri::from_parts(parts).unwrap())
}

/// Splits a path into its `/`-separated segments.
///
/// The leading `/` does not produce a segment of its own, so `/a/b` yields
/// `"a"`, `"b"` and `/` yields a single empty segment.
///
/// # Panics
///
/// Panics immediately (not lazily on first iteration) if `s` does not start
/// with `/`.
fn segments(s: &str) -> impl Iterator<Item = &str> {
    // Dropping the leading `/` up front means the split below never produces
    // the empty piece that would otherwise sit in front of it.
    let without_leading_slash = match s.strip_prefix('/') {
        Some(rest) => rest,
        None => panic!("path didn't start with '/'. axum should have caught this higher up."),
    };

    without_leading_slash.split('/')
}

/// Iterates over `a` and `b` in lockstep until both are exhausted.
///
/// While both still produce values the result is `Item::Both`; once one side
/// has run out, the leftover values of the other side are reported as
/// `Item::First` (from `a`) or `Item::Second` (from `b`).
fn zip_longest<I, I2>(a: I, b: I2) -> impl Iterator<Item = Item<I::Item>>
where
    I: Iterator,
    I2: Iterator<Item = I::Item>,
{
    // Fused so that an exhausted side is never polled again while the other
    // side is still being drained.
    let mut left = a.fuse();
    let mut right = b.fuse();

    std::iter::from_fn(move || {
        let from_left = left.next();
        let from_right = right.next();
        match (from_left, from_right) {
            (Some(l), Some(r)) => Some(Item::Both(l, r)),
            (Some(l), None) => Some(Item::First(l)),
            (None, Some(r)) => Some(Item::Second(r)),
            (None, None) => None,
        }
    })
}

/// One step of `zip_longest`, recording which of the two iterators produced a
/// value at that position.
#[derive(Debug)]
enum Item<T> {
    /// Both iterators produced a value (first iterator's value, then second's).
    Both(T, T),
    /// Only the first iterator produced a value; the second is exhausted.
    First(T),
    /// Only the second iterator produced a value; the first is exhausted.
    Second(T),
}

#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use quickcheck::Arbitrary;
    use quickcheck_macros::quickcheck;

    // Declares a `#[test]` function called `$name` which parses `$uri`, strips
    // `$prefix` from it and compares the stringified result with `$expected`
    // (an `Option<&str>`).
    macro_rules! test {
        (
            $name:ident,
            uri = $uri:literal,
            prefix = $prefix:literal,
            expected = $expected:expr,
        ) => {
            #[test]
            fn $name() {
                let input: Uri = $uri.parse().unwrap();
                let stripped = strip_prefix(&input, $prefix);
                let rendered = stripped.map(|stripped_uri| stripped_uri.to_string());
                assert_eq!(rendered.as_deref(), $expected);
            }
        };
    }

    // Root prefix on the root path: nothing is removed and the path stays `/`.
    test!(empty, uri = "/", prefix = "/", expected = Some("/"),);

    // The prefix covers the whole path, so only `/` is left.
    test!(
        single_segment,
        uri = "/a",
        prefix = "/a",
        expected = Some("/"),
    );

    // The prefix has a segment the path lacks: no match.
    test!(
        single_segment_root_uri,
        uri = "/",
        prefix = "/a",
        expected = None,
    );

    // the prefix is empty, so removing it should have no effect
    test!(
        single_segment_root_prefix,
        uri = "/a",
        prefix = "/",
        expected = Some("/a"),
    );

    // Differing literal segments: no match.
    test!(
        single_segment_no_match,
        uri = "/a",
        prefix = "/b",
        expected = None,
    );

    // Trailing slash on both the path and the prefix: full match.
    test!(
        single_segment_trailing_slash,
        uri = "/a/",
        prefix = "/a/",
        expected = Some("/"),
    );

    // Trailing slash only on the prefix: the path is one segment short, no match.
    test!(
        single_segment_trailing_slash_2,
        uri = "/a",
        prefix = "/a/",
        expected = None,
    );

    // Trailing slash only on the path: the prefix matches and `/` remains.
    test!(
        single_segment_trailing_slash_3,
        uri = "/a/",
        prefix = "/a",
        expected = Some("/"),
    );

    // The prefix matches the first segment; the rest of the path is kept.
    test!(
        multi_segment,
        uri = "/a/b",
        prefix = "/a",
        expected = Some("/b"),
    );

    // The prefix segment occurs later in the path but not at the start: no match.
    test!(
        multi_segment_2,
        uri = "/b/a",
        prefix = "/a",
        expected = None,
    );

    // The prefix is longer than the path: no match.
    test!(
        multi_segment_3,
        uri = "/a",
        prefix = "/a/b",
        expected = None,
    );

    // The prefix equals the second path segment, not the first: no match.
    test!(
        multi_segment_4,
        uri = "/a/b",
        prefix = "/b",
        expected = None,
    );

    // Same trailing-slash cases as for a single segment, with two segments.
    test!(
        multi_segment_trailing_slash,
        uri = "/a/b/",
        prefix = "/a/b/",
        expected = Some("/"),
    );

    test!(
        multi_segment_trailing_slash_2,
        uri = "/a/b",
        prefix = "/a/b/",
        expected = None,
    );

    test!(
        multi_segment_trailing_slash_3,
        uri = "/a/b/",
        prefix = "/a/b",
        expected = Some("/"),
    );

    // `:param` prefix segments match any path segment, including the empty
    // segment of the root path.
    test!(param_0, uri = "/", prefix = "/:param", expected = Some("/"),);

    test!(
        param_1,
        uri = "/a",
        prefix = "/:param",
        expected = Some("/"),
    );

    test!(
        param_2,
        uri = "/a/b",
        prefix = "/:param",
        expected = Some("/b"),
    );

    test!(
        param_3,
        uri = "/b/a",
        prefix = "/:param",
        expected = Some("/a"),
    );

    // Parameters mixed with literal segments: the literal segments must still
    // match exactly, in their position.
    test!(
        param_4,
        uri = "/a/b",
        prefix = "/a/:param",
        expected = Some("/"),
    );

    test!(param_5, uri = "/b/a", prefix = "/a/:param", expected = None,);

    test!(param_6, uri = "/a/b", prefix = "/:param/a", expected = None,);

    test!(
        param_7,
        uri = "/b/a",
        prefix = "/:param/a",
        expected = Some("/"),
    );

    test!(
        param_8,
        uri = "/a/b/c",
        prefix = "/a/:param/c",
        expected = Some("/"),
    );

    test!(
        param_9,
        uri = "/c/b/a",
        prefix = "/a/:param/c",
        expected = None,
    );

    // Parameters combined with trailing slashes on the path and/or the prefix.
    test!(
        param_10,
        uri = "/a/",
        prefix = "/:param",
        expected = Some("/"),
    );

    test!(param_11, uri = "/a", prefix = "/:param/", expected = None,);

    test!(
        param_12,
        uri = "/a/",
        prefix = "/:param/",
        expected = Some("/"),
    );

    // A literal prefix ending in `/` followed by a further path segment: the
    // segment after the prefix is kept.
    test!(
        param_13,
        uri = "/a/a",
        prefix = "/a/",
        expected = Some("/a"),
    );

    // Property: `strip_prefix` never panics for generated URI/prefix pairs.
    // Whether the prefix matched is irrelevant, so the result is discarded.
    #[quickcheck]
    fn does_not_panic(uri_and_prefix: UriAndPrefix) -> bool {
        let _ = strip_prefix(&uri_and_prefix.uri, &uri_and_prefix.prefix);
        true
    }

    /// Randomly generated input for the `does_not_panic` property test: a URI
    /// together with a prefix to strip from it.
    #[derive(Clone, Debug)]
    struct UriAndPrefix {
        /// The URI handed to `strip_prefix`.
        uri: Uri,
        /// The prefix handed to `strip_prefix`.
        prefix: String,
    }

    impl Arbitrary for UriAndPrefix {
        fn arbitrary(g: &mut quickcheck::Gen) -> Self {
            let mut uri = String::new();
            let mut prefix = String::new();

            let size = u8_between(1, 20, g);

            for _ in 0..size {
                let segment = ascii_alphanumeric(g);

                uri.push('/');
                uri.push_str(&segment);

                prefix.push('/');

                let make_matching_segment = bool::arbitrary(g);
                let make_capture = bool::arbitrary(g);

                match (make_matching_segment, make_capture) {
                    (_, true) => {
                        prefix.push_str(":a");
                    }
                    (true, false) => {
                        prefix.push_str(&segment);
                    }
                    (false, false) => {
                        prefix.push_str(&ascii_alphanumeric(g));
                    }
                }
            }

            if bool::arbitrary(g) {
                uri.push('/');
            }

            if bool::arbitrary(g) {
                prefix.push('/');
            }

            Self {
                uri: uri.parse().unwrap(),
                prefix,
            }
        }
    }

    fn ascii_alphanumeric(g: &mut quickcheck::Gen) -> String {
        #[derive(Clone)]
        struct AsciiAlphanumeric(String);

        impl Arbitrary for AsciiAlphanumeric {
            fn arbitrary(g: &mut quickcheck::Gen) -> Self {
                let mut out = String::new();

                let size = u8_between(1, 20, g) as usize;

                while out.len() < size {
                    let c = char::arbitrary(g);
                    if c.is_ascii_alphanumeric() {
                        out.push(c);
                    }
                }
                Self(out)
            }
        }

        let out = AsciiAlphanumeric::arbitrary(g).0;
        assert!(!out.is_empty());
        out
    }

    /// Draws random `u8` values from `g` until one falls in the range
    /// `lower < value <= upper`, and returns it.
    ///
    /// NOTE(review): the lower bound is exclusive, so `u8_between(1, 20, g)`
    /// never yields 1 — confirm this is intended.
    fn u8_between(lower: u8, upper: u8, g: &mut quickcheck::Gen) -> u8 {
        loop {
            let candidate = u8::arbitrary(g);
            let in_range = lower < candidate && candidate <= upper;
            if in_range {
                return candidate;
            }
        }
    }
}