athena_rs 3.26.2

Hyper-performant polyglot database driver
Documentation
use std::borrow::Cow;

use reqwest::Url;

/// Connection-URI query parameter names that this module treats as unsupported by
/// sqlx's PostgreSQL connector. `sanitize_sqlx_postgres_connect_uri` removes any
/// query pair whose key matches one of these entries (ASCII-case-insensitive).
const SQLX_UNSUPPORTED_POSTGRES_CONNECT_PARAMS: &[&str] = &["channel_binding"];

pub fn sanitize_sqlx_postgres_connect_uri(uri: &str) -> Cow<'_, str> {
    let normalized = crate::parser::resolve_compatible_postgres_uri(uri);
    let changed_scheme = normalized != uri;

    let Ok(mut parsed) = Url::parse(&normalized) else {
        return if changed_scheme {
            Cow::Owned(normalized)
        } else {
            Cow::Borrowed(uri)
        };
    };

    if !matches!(parsed.scheme(), "postgresql") {
        return if changed_scheme {
            Cow::Owned(normalized)
        } else {
            Cow::Borrowed(uri)
        };
    }

    let original_pairs: Vec<(String, String)> = parsed.query_pairs().into_owned().collect();
    let original_pair_count = original_pairs.len();
    if original_pair_count == 0 {
        return if changed_scheme {
            Cow::Owned(normalized)
        } else {
            Cow::Borrowed(uri)
        };
    }

    let filtered_pairs: Vec<(String, String)> = original_pairs
        .into_iter()
        .filter(|(key, _)| {
            !SQLX_UNSUPPORTED_POSTGRES_CONNECT_PARAMS
                .iter()
                .any(|candidate| key.eq_ignore_ascii_case(candidate))
        })
        .collect();

    if filtered_pairs.len() == original_pair_count {
        return if changed_scheme {
            Cow::Owned(normalized)
        } else {
            Cow::Borrowed(uri)
        };
    }

    parsed.set_query(None);
    if !filtered_pairs.is_empty() {
        let mut query_pairs = parsed.query_pairs_mut();
        for (key, value) in &filtered_pairs {
            query_pairs.append_pair(key, value);
        }
    }

    Cow::Owned(parsed.into())
}

#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use super::sanitize_sqlx_postgres_connect_uri;

    /// `channel_binding` is removed, the other parameters survive in order, and the
    /// `postgres` scheme is normalized to `postgresql`.
    #[test]
    fn strips_channel_binding_from_postgres_uri() {
        let sanitized = sanitize_sqlx_postgres_connect_uri(
            "postgres://user:pass@localhost:5432/app?sslmode=require&channel_binding=require&application_name=athena",
        );

        assert_eq!(
            sanitized.as_ref(),
            "postgresql://user:pass@localhost:5432/app?sslmode=require&application_name=athena"
        );
    }

    /// Removing the only parameter leaves no trailing `?`.
    #[test]
    fn strips_only_unsupported_parameter() {
        let sanitized = sanitize_sqlx_postgres_connect_uri(
            "postgresql://localhost/app?channel_binding=require",
        );

        assert_eq!(sanitized.as_ref(), "postgresql://localhost/app");
    }

    /// Parameter keys are matched ASCII-case-insensitively.
    #[test]
    fn matches_unsupported_parameter_case_insensitively() {
        let sanitized = sanitize_sqlx_postgres_connect_uri(
            "postgresql://localhost/app?CHANNEL_BINDING=require&sslmode=disable",
        );

        assert_eq!(
            sanitized.as_ref(),
            "postgresql://localhost/app?sslmode=disable"
        );
    }

    /// Non-PostgreSQL URIs are returned untouched and without allocating.
    #[test]
    fn leaves_unrelated_uris_unchanged() {
        let sanitized =
            sanitize_sqlx_postgres_connect_uri("https://example.com?channel_binding=require");

        assert_eq!(
            sanitized.as_ref(),
            "https://example.com?channel_binding=require"
        );
        assert!(matches!(sanitized, Cow::Borrowed(_)));
    }

    /// A URI that needs neither normalization nor stripping is borrowed, not copied.
    #[test]
    fn borrows_input_when_nothing_changes() {
        let input = "postgresql://localhost/app?sslmode=require";
        let sanitized = sanitize_sqlx_postgres_connect_uri(input);

        assert_eq!(sanitized.as_ref(), input);
        assert!(matches!(sanitized, Cow::Borrowed(_)));
    }

    /// The custom `athena` scheme is rewritten to `postgresql`.
    #[test]
    fn rewrites_athena_scheme_to_postgresql() {
        let sanitized = sanitize_sqlx_postgres_connect_uri("athena://user:pass@localhost/db");

        assert_eq!(sanitized.as_ref(), "postgresql://user:pass@localhost/db");
    }
}