s3util-rs 1.3.0

Tools for managing Amazon S3 objects and buckets
Documentation
use percent_encoding::percent_decode_str;
use regex::Regex;
use url::{ParseError, Url};

use crate::types::StoragePath;

const INVALID_SCHEME: &str = "scheme must be s3:// .";
const INVALID_PATH: &str = "path must be a valid URL or a local path.";
const NO_BUCKET_NAME_SPECIFIED: &str = "bucket name must be specified.";
const NO_PATH_SPECIFIED: &str = "path must be specified.";
const MULTI_REGION_ARN_REGEX: &str = r"^s3://arn:aws:s3::?.+:accesspoint/.+";

const STDIO_PATH: &str = "-";

pub fn check_storage_path(path: &str) -> Result<String, String> {
    if path == STDIO_PATH {
        return Ok(path.to_string());
    }

    if is_multi_region_arn(path) {
        return Ok(path.to_string());
    }

    let result = Url::parse(path);
    if result == Err(ParseError::RelativeUrlWithoutBase) {
        if path.is_empty() {
            return Err(NO_PATH_SPECIFIED.to_string());
        }

        return Ok(path.to_string());
    }

    if result.is_err() {
        return Err(INVALID_PATH.to_string());
    }

    let parsed = result.unwrap();
    match parsed.scheme() {
        "s3" => {
            if parsed.host_str().is_none() {
                return Err(NO_BUCKET_NAME_SPECIFIED.to_string());
            }
        }
        _ => {
            if !is_windows_absolute_path(path) {
                return Err(INVALID_SCHEME.to_string());
            }
        }
    }

    Ok(path.to_string())
}

fn is_multi_region_arn(path: &str) -> bool {
    Regex::new(MULTI_REGION_ARN_REGEX).unwrap().is_match(path)
}

pub fn parse_storage_path(path: &str) -> StoragePath {
    check_storage_path(path).unwrap();

    if path == STDIO_PATH {
        return StoragePath::Stdio;
    }

    if is_multi_region_arn(path) {
        return StoragePath::S3 {
            bucket: extract_multi_region_arn(path),
            prefix: extract_prefix(path),
        };
    }

    let result = Url::parse(path);
    if result == Err(ParseError::RelativeUrlWithoutBase) {
        return parse_local_path(path);
    }
    if is_windows_absolute_path(path) {
        return parse_local_path(path);
    }

    parse_s3_path(path)
}

fn extract_multi_region_arn(path: &str) -> String {
    let mut iter = path.match_indices('/');
    let third_slash = iter.nth(3);
    match third_slash {
        Some((idx, _)) => {
            let arn = &path[..=idx].to_string();
            arn.replace("s3://", "")
                .to_string()
                .strip_suffix('/')
                .map_or(arn.to_string(), |s| s.to_string())
                .to_string()
        }
        None => path.replace("s3://", "").to_string(),
    }
}

fn extract_prefix(path: &str) -> String {
    // skipcq: RS-W1201
    path.char_indices()
        .filter(|&(_, c)| c == '/')
        .nth(3)
        .map(|(i, _)| &path[i + 1..])
        .map_or("", |s| s)
        .to_string()
}

pub fn is_both_storage_local(source: &StoragePath, target: &StoragePath) -> bool {
    let source_local = matches!(source, StoragePath::Local(_));
    let target_local = matches!(target, StoragePath::Local(_));
    source_local && target_local
}

pub fn is_both_storage_s3(source: &StoragePath, target: &StoragePath) -> bool {
    let source_s3 = matches!(source, StoragePath::S3 { .. });
    let target_s3 = matches!(target, StoragePath::S3 { .. });
    source_s3 && target_s3
}

fn parse_local_path(path: &str) -> StoragePath {
    StoragePath::Local(path.into())
}

fn parse_s3_path(path: &str) -> StoragePath {
    let bucket = Url::parse(path).unwrap().host_str().unwrap().to_string();
    let mut prefix = Url::parse(path).unwrap().path().to_string();

    // remove first '/'
    if !prefix.is_empty() {
        prefix.remove(0);
    }

    prefix = percent_decode_str(&prefix)
        .decode_utf8()
        .unwrap()
        .to_string();

    StoragePath::S3 { bucket, prefix }
}

fn is_windows_absolute_path(path: &str) -> bool {
    if !cfg!(windows) {
        return false;
    }

    let re = Regex::new(r"^[a-zA-Z]:\\").unwrap();
    re.is_match(path)
}

#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::*;

    #[test]
    fn check_valid_url() {
        check_storage_path("/etc/").unwrap();
        check_storage_path("etc/dir").unwrap();
        check_storage_path("s3://my-bucket").unwrap();
        check_storage_path("s3://my-bucket/").unwrap();
        check_storage_path("s3://my-bucket/xyz.dat").unwrap();
        check_storage_path("-").unwrap();
    }

    #[test]
    fn check_valid_url_error() {
        assert!(check_storage_path("s3://arn:aws").is_err());
    }

    #[test]
    fn parse_stdio_path() {
        assert!(matches!(parse_storage_path("-"), StoragePath::Stdio));
    }

    #[test]
    fn parse_local() {
        let local_path = "dir1/";
        if let StoragePath::Local(path) = parse_storage_path(local_path) {
            assert_eq!(path, PathBuf::from("dir1/"));
        } else {
            panic!("local path not found");
        }
    }

    #[test]
    fn parse_s3_url_with_key() {
        let s3_url = "s3://test-bucket/my_key";
        if let StoragePath::S3 { bucket, prefix } = parse_storage_path(s3_url) {
            assert_eq!(bucket, "test-bucket");
            assert_eq!(prefix, "my_key");
        } else {
            panic!("s3 url not found");
        }
    }

    #[test]
    fn empty_local_path() {
        let result = check_storage_path("");
        assert!(result.is_err());
    }

    #[test]
    fn invalid_scheme() {
        let result = check_storage_path("https://my-bucket");
        assert!(result.is_err());
    }

    #[test]
    fn no_bucket_name() {
        let result = check_storage_path("s3://");
        assert!(result.is_err());
    }

    #[test]
    fn both_storage_local() {
        assert!(is_both_storage_local(
            &StoragePath::Local("/source".into()),
            &StoragePath::Local("/target".into()),
        ));
    }

    #[test]
    fn both_storage_s3() {
        assert!(is_both_storage_s3(
            &StoragePath::S3 {
                bucket: "source-bucket".to_string(),
                prefix: "source-prefix".to_string(),
            },
            &StoragePath::S3 {
                bucket: "target-bucket".to_string(),
                prefix: "target-prefix".to_string(),
            },
        ));
    }

    #[test]
    fn is_both_storage_s3_false_for_mixed() {
        let s3 = StoragePath::S3 {
            bucket: "b".to_string(),
            prefix: String::new(),
        };
        let local = StoragePath::Local("/x".into());
        assert!(!is_both_storage_s3(&s3, &local));
        assert!(!is_both_storage_s3(&local, &s3));
        assert!(!is_both_storage_s3(&local, &local));
    }

    #[test]
    fn is_both_storage_local_false_for_mixed() {
        let s3 = StoragePath::S3 {
            bucket: "b".to_string(),
            prefix: String::new(),
        };
        let local = StoragePath::Local("/x".into());
        assert!(!is_both_storage_local(&s3, &local));
        assert!(!is_both_storage_local(&local, &s3));
        assert!(!is_both_storage_local(&s3, &s3));
    }

    #[test]
    fn extract_prefix_returns_segment_after_fourth_slash() {
        // extract_prefix is only called from the multi-region-ARN branch of
        // parse_storage_path. The function takes the 4th '/' (nth(3)) and
        // returns everything after it. For ARN-style input
        // `s3://arn:.../accesspoint/<ap-name>/<key>`, that lands on the key.
        let path = "s3://arn:aws:s3::123456789012:accesspoint/my-ap/some/key";
        assert_eq!(extract_prefix(path), "some/key");
    }

    #[test]
    fn extract_prefix_returns_empty_when_fewer_than_four_slashes() {
        // No 4th slash → empty prefix.
        assert_eq!(extract_prefix("s3://bucket"), "");
        assert_eq!(extract_prefix("s3://bucket/"), "");
        assert_eq!(extract_prefix("s3://arn:aws:s3::123:accesspoint/ap"), "");
    }

    #[test]
    fn extract_multi_region_arn_strips_path_after_arn() {
        // Returns the ARN portion only, without trailing key/path.
        let path = "s3://arn:aws:s3::123456789012:accesspoint/my-ap/some/key";
        let arn = extract_multi_region_arn(path);
        // Trailing key is dropped; the ARN itself is preserved minus the s3:// scheme.
        assert!(!arn.contains("/some"));
        assert!(!arn.contains("/key"));
        assert!(arn.starts_with("arn:aws:s3"));
    }

    #[test]
    fn extract_multi_region_arn_no_path_returns_full_minus_scheme() {
        // No slash after s3://… → returns input minus s3:// prefix.
        let path = "s3://bucket-only";
        assert_eq!(extract_multi_region_arn(path), "bucket-only");
    }

    #[test]
    fn check_storage_path_accepts_multi_region_arn() {
        // Multi-region ARN is recognized by the dedicated regex branch and
        // returned verbatim; it must not fall into the Url::parse path where
        // "arn:aws" would look like a scheme and fail validation.
        let path = "s3://arn:aws:s3::123456789012:accesspoint/my-ap/some/key";
        assert_eq!(check_storage_path(path).unwrap(), path);
    }

    #[test]
    fn parse_storage_path_returns_s3_for_multi_region_arn() {
        let path = "s3://arn:aws:s3::123456789012:accesspoint/my-ap/some/key";
        match parse_storage_path(path) {
            StoragePath::S3 { bucket, prefix } => {
                assert!(bucket.starts_with("arn:aws:s3"));
                assert!(bucket.contains("accesspoint/my-ap"));
                assert_eq!(prefix, "some/key");
            }
            _ => panic!("expected multi-region ARN to parse as S3"),
        }
    }
}