sark-client 0.2.1

Simple Asynchronous Rust webKit - Client
use std::collections::HashSet;

use http::{Method, Uri};

use crate::connector::error::Error;

/// Outcome of comparing the origins of two URIs (produced by `Resolve::origin`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(super) enum OriginRelation {
    /// Both URIs have the same scheme, host and port.
    Same,
    /// The URIs differ in at least one of scheme, host or port.
    Cross,
}

/// Stateless namespace for the URI helpers used while following redirects.
pub(super) struct Resolve;

impl Resolve {
    /// Classifies `b` relative to `a` as same-origin or cross-origin.
    ///
    /// Two URIs share an origin when scheme, host and port all match. Hosts are
    /// compared ASCII case-insensitively, and an omitted port is treated as the
    /// scheme default (80 for `http`, 443 for `https`), so `http://h/` and
    /// `http://H:80/` are reported as [`OriginRelation::Same`].
    pub(super) fn origin(a: &Uri, b: &Uri) -> OriginRelation {
        // Explicit port if present, otherwise the well-known default for the scheme.
        fn effective_port(uri: &Uri) -> Option<u16> {
            uri.port_u16().or_else(|| match uri.scheme_str() {
                Some("http") => Some(80),
                Some("https") => Some(443),
                _ => None,
            })
        }

        // Host names are case-insensitive; two missing hosts still count as equal.
        let same_host = match (a.host(), b.host()) {
            (Some(x), Some(y)) => x.eq_ignore_ascii_case(y),
            (None, None) => true,
            _ => false,
        };

        if a.scheme() == b.scheme() && same_host && effective_port(a) == effective_port(b) {
            OriginRelation::Same
        } else {
            OriginRelation::Cross
        }
    }

    /// Resolves a `Location` header value against `base` and returns the redirect target.
    ///
    /// Supported reference forms:
    /// - absolute `http://` / `https://` URLs are parsed as-is;
    /// - network-path references (`//host/path`) inherit the scheme of `base`;
    /// - absolute paths (`/x`) replace the path of `base`;
    /// - relative paths are joined onto the directory part of `base`'s path and
    ///   then normalized with `normalize_path`;
    /// - references with an empty path (`""`, `?query`, `#frag`) keep `base`'s
    ///   path, and keep `base`'s query as well unless the reference brings its own.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Http`] when a relative reference is given but `base` has
    /// no authority, or when the resolved string does not parse as a URI.
    pub(super) fn redirect(base: &Uri, location: &str) -> Result<Uri, Error> {
        let parse = |candidate: &str| {
            candidate
                .parse::<Uri>()
                .map_err(|e| Error::Http(format!("invalid redirect URL: {e}")))
        };

        if location.starts_with("http://") || location.starts_with("https://") {
            return parse(location);
        }

        let scheme = base.scheme_str().unwrap_or("http");

        // Network-path reference: carries its own authority, only the scheme is inherited.
        if let Some(rest) = location.strip_prefix("//") {
            return parse(&format!("{scheme}://{rest}"));
        }

        let authority = base
            .authority()
            .ok_or_else(|| Error::Http("missing authority in base URI".into()))?;

        let (raw_path, suffix) = Self::split_location(location);
        let path = if raw_path.is_empty() {
            // Empty path: the reference targets the base document itself, so the
            // base path must be kept whole rather than cut back to its directory.
            let base_path = match base.path() {
                "" => "/",
                p => p,
            };
            if suffix.starts_with('?') {
                format!("{base_path}{suffix}")
            } else if let Some(query) = base.query() {
                // Only a fragment (or nothing) was supplied: the base query still applies.
                format!("{base_path}?{query}")
            } else {
                base_path.to_string()
            }
        } else {
            let joined = if raw_path.starts_with('/') {
                raw_path.to_string()
            } else {
                // Relative path: resolve against everything up to the last '/' of the base path.
                let base_path = base.path();
                let dir = base_path.rfind('/').map_or("/", |i| &base_path[..=i]);
                format!("{dir}{raw_path}")
            };
            format!("{}{suffix}", Self::normalize_path(&joined))
        };

        parse(&format!("{scheme}://{authority}{path}"))
    }

    /// Splits a redirect reference into its path part and the remainder.
    ///
    /// The remainder begins at the first `?` or `#` (inclusive) and is empty
    /// when the reference contains neither character.
    fn split_location(location: &str) -> (&str, &str) {
        let cut = location
            .find(|c: char| matches!(c, '?' | '#'))
            .unwrap_or(location.len());
        location.split_at(cut)
    }

    /// Removes `.` and `..` segments from `input` and collapses empty segments.
    ///
    /// A `..` that would climb above the root is ignored. A trailing slash is
    /// preserved, and a path ending in `.` or `..` also yields a trailing slash,
    /// because it names a directory (`/a/b/..` becomes `/a/`). Dropping the slash
    /// would turn `/dir/` into `/dir`, which is a different resource. The result
    /// is never empty: it falls back to `/`.
    fn normalize_path(input: &str) -> String {
        let absolute = input.starts_with('/');
        let mut segments: Vec<&str> = Vec::new();
        // True when the last segment seen was empty, `.` or `..`, i.e. the
        // path ends on a directory rather than on a named leaf.
        let mut ends_on_dir = false;
        for seg in input.split('/') {
            match seg {
                "" | "." => ends_on_dir = true,
                ".." => {
                    segments.pop();
                    ends_on_dir = true;
                }
                _ => {
                    segments.push(seg);
                    ends_on_dir = false;
                }
            }
        }

        let mut out = String::with_capacity(input.len() + 1);
        if absolute {
            out.push('/');
        }
        out.push_str(&segments.join("/"));
        // With no segments left the root slash (if any) already terminates the path.
        if ends_on_dir && !segments.is_empty() {
            out.push('/');
        }
        if out.is_empty() {
            out.push('/');
        }
        out
    }
}

/// Bookkeeping for following a chain of redirects that stay on one origin.
pub(super) struct RedirectState {
    /// Number of redirects that may still be followed.
    remaining: u32,
    /// URI whose origin every redirect target is compared against.
    base: Uri,
    /// URI of the request currently being issued.
    current: Uri,
    /// Every URI requested so far; used to detect redirect loops.
    visited: HashSet<Uri>,
}

impl RedirectState {
    /// Creates the state for a request to `first_path` on `base`, allowing at
    /// most `max_redirects` redirects to be followed.
    ///
    /// The initial URI is recorded as visited, so a later redirect back to it
    /// is reported as a loop.
    pub(super) fn new(max_redirects: u32, base: Uri, first_path: &str) -> Self {
        let start = Self::join_path(&base, first_path);
        Self {
            remaining: max_redirects,
            visited: HashSet::from([start.clone()]),
            current: start,
            base,
        }
    }

    /// Builds `scheme://authority/path` from the scheme and authority of `base`.
    ///
    /// A missing scheme defaults to `http`, a missing authority to the empty
    /// string, and `path` gets a leading `/` if it lacks one. If the combined
    /// string does not parse as a URI, a clone of `base` is returned instead.
    fn join_path(base: &Uri, path: &str) -> Uri {
        let scheme = base.scheme_str().unwrap_or("http");
        let authority = base.authority().map_or("", |a| a.as_str());
        let separator = if path.starts_with('/') { "" } else { "/" };
        match format!("{scheme}://{authority}{separator}{path}").parse() {
            Ok(uri) => uri,
            Err(_) => base.clone(),
        }
    }

    /// Returns the path (plus query, when present) of the current URI as an
    /// owned string.
    pub(super) fn path_and_query(&self) -> String {
        self.current
            .path_and_query()
            .map_or_else(|| self.current.path(), |pq| pq.as_str())
            .to_string()
    }

    /// Follows one redirect and returns the method to use for the next request.
    ///
    /// `location` is resolved against the current URI. On success the current
    /// URI is replaced by the target, the target is recorded as visited and the
    /// remaining budget shrinks by one. On error the state is left unchanged.
    ///
    /// Method selection: 307 and 308 keep `method` as-is. Every other status
    /// switches to `GET`, except that `HEAD` stays `HEAD` so that a header-only
    /// request never turns into a full body download.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Http`] when the redirect budget is exhausted, when
    /// `location` cannot be resolved, when the target has a different origin
    /// than the base URI, or when the target was already visited.
    pub(super) fn advance(
        &mut self,
        status: u16,
        location: &str,
        method: &Method,
    ) -> Result<Method, Error> {
        if self.remaining == 0 {
            return Err(Error::Http("redirect limit exceeded".into()));
        }
        let new_uri = Resolve::redirect(&self.current, location)?;
        if Resolve::origin(&self.base, &new_uri) == OriginRelation::Cross {
            return Err(Error::Http(format!(
                "cross-host redirect to {new_uri} is not supported"
            )));
        }
        // `insert` returns false when the URI was already in the set, i.e. a loop.
        if !self.visited.insert(new_uri.clone()) {
            return Err(Error::Http("redirect loop detected".into()));
        }
        self.current = new_uri;
        self.remaining -= 1;
        let next_method = match status {
            307 | 308 => method.clone(),
            _ if *method == Method::HEAD => Method::HEAD,
            _ => Method::GET,
        };
        Ok(next_method)
    }
}

// Unit tests for redirect resolution (`Resolve`) and redirect-chain tracking
// (`RedirectState`).
#[cfg(test)]
mod tests {
    use super::{OriginRelation, RedirectState, Resolve};

    // An absolute URL in `Location` is returned as-is, even for another host.
    #[test]
    fn redirect_absolute_url() {
        let base: http::Uri = "http://example.com/old".parse().unwrap();
        let result = Resolve::redirect(&base, "http://other.com/new").unwrap();
        assert_eq!(result.to_string(), "http://other.com/new");
    }

    // An absolute path replaces the whole path of the base URI.
    #[test]
    fn redirect_absolute_path() {
        let base: http::Uri = "http://example.com/old/page".parse().unwrap();
        let result = Resolve::redirect(&base, "/new/page").unwrap();
        assert_eq!(result.to_string(), "http://example.com/new/page");
    }

    // A relative path replaces only the last segment of the base path.
    #[test]
    fn redirect_relative_path() {
        let base: http::Uri = "http://example.com/dir/page".parse().unwrap();
        let result = Resolve::redirect(&base, "other").unwrap();
        assert_eq!(result.to_string(), "http://example.com/dir/other");
    }

    // The scheme of the base URI carries over to the resolved target.
    #[test]
    fn redirect_preserves_https() {
        let base: http::Uri = "https://example.com/page".parse().unwrap();
        let result = Resolve::redirect(&base, "/new").unwrap();
        assert_eq!(result.to_string(), "https://example.com/new");
    }

    // An explicit port in the base URI carries over to the resolved target.
    #[test]
    fn redirect_preserves_port() {
        let base: http::Uri = "http://example.com:8080/page".parse().unwrap();
        let result = Resolve::redirect(&base, "/new").unwrap();
        assert_eq!(result.to_string(), "http://example.com:8080/new");
    }

    // `..` and `.` segments are resolved after joining with the base directory.
    #[test]
    fn redirect_normalizes_dot_segments() {
        let base: http::Uri = "http://example.com/dir/sub/page".parse().unwrap();
        let result = Resolve::redirect(&base, "../other/./file").unwrap();
        assert_eq!(result.to_string(), "http://example.com/dir/other/file");
    }

    // NOTE: despite the test name, the expected value below contains only the
    // query; the `#frag` part of the location does not appear in the result.
    #[test]
    fn redirect_preserves_query_and_fragment() {
        let base: http::Uri = "http://example.com/dir/page".parse().unwrap();
        let result = Resolve::redirect(&base, "next?x=1#frag").unwrap();
        assert_eq!(result.to_string(), "http://example.com/dir/next?x=1");
    }

    // Same host with different paths is same-origin; a different host is cross-origin.
    #[test]
    fn same_origin_detection() {
        let a: http::Uri = "http://example.com/a".parse().unwrap();
        let b: http::Uri = "http://example.com/b".parse().unwrap();
        let c: http::Uri = "http://other.com/a".parse().unwrap();
        assert_eq!(Resolve::origin(&a, &b), OriginRelation::Same);
        assert_eq!(Resolve::origin(&a, &c), OriginRelation::Cross);
    }

    // A 302 switches POST to GET, a 307 keeps POST; each hop updates the
    // current path, and the second hop is resolved relative to the first.
    #[test]
    fn same_origin_redirect_chain_advances_path() {
        let base: http::Uri = "http://h.example/".parse().unwrap();
        let mut st = RedirectState::new(3, base, "/start");
        let m = st.advance(302, "/next", &http::Method::POST).unwrap();
        assert_eq!(st.path_and_query(), "/next");
        assert_eq!(m, http::Method::GET);
        let m = st.advance(307, "final?q=1", &http::Method::POST).unwrap();
        assert_eq!(st.path_and_query(), "/final?q=1");
        assert_eq!(m, http::Method::POST);
    }

    // Redirecting back to the starting URI is reported as a loop.
    #[test]
    fn redirect_loop_detected() {
        let base: http::Uri = "http://h.example/".parse().unwrap();
        let mut st = RedirectState::new(5, base, "/a");
        st.advance(302, "/b", &http::Method::GET).unwrap();
        let err = st.advance(302, "/a", &http::Method::GET).unwrap_err();
        assert!(err.to_string().contains("loop"));
    }

    // A redirect to a different host is rejected.
    #[test]
    fn cross_host_redirect_rejected() {
        let base: http::Uri = "http://h.example/".parse().unwrap();
        let mut st = RedirectState::new(5, base, "/a");
        let err = st
            .advance(302, "http://other.example/x", &http::Method::GET)
            .unwrap_err();
        assert!(err.to_string().contains("cross-host"));
    }

    // With a budget of one, the second redirect fails with a limit error.
    #[test]
    fn redirect_limit_enforced() {
        let base: http::Uri = "http://h.example/".parse().unwrap();
        let mut st = RedirectState::new(1, base, "/a");
        st.advance(302, "/b", &http::Method::GET).unwrap();
        let err = st.advance(302, "/c", &http::Method::GET).unwrap_err();
        assert!(err.to_string().contains("limit"));
    }
}