use std::borrow::Cow;
use reqwest::Url;
const SQLX_UNSUPPORTED_POSTGRES_CONNECT_PARAMS: &[&str] = &["channel_binding"];
pub fn sanitize_sqlx_postgres_connect_uri(uri: &str) -> Cow<'_, str> {
let normalized = crate::parser::resolve_compatible_postgres_uri(uri);
let changed_scheme = normalized != uri;
let Ok(mut parsed) = Url::parse(&normalized) else {
return if changed_scheme {
Cow::Owned(normalized)
} else {
Cow::Borrowed(uri)
};
};
if !matches!(parsed.scheme(), "postgresql") {
return if changed_scheme {
Cow::Owned(normalized)
} else {
Cow::Borrowed(uri)
};
}
let original_pairs: Vec<(String, String)> = parsed.query_pairs().into_owned().collect();
let original_pair_count = original_pairs.len();
if original_pair_count == 0 {
return if changed_scheme {
Cow::Owned(normalized)
} else {
Cow::Borrowed(uri)
};
}
let filtered_pairs: Vec<(String, String)> = original_pairs
.into_iter()
.filter(|(key, _)| {
!SQLX_UNSUPPORTED_POSTGRES_CONNECT_PARAMS
.iter()
.any(|candidate| key.eq_ignore_ascii_case(candidate))
})
.collect();
if filtered_pairs.len() == original_pair_count {
return if changed_scheme {
Cow::Owned(normalized)
} else {
Cow::Borrowed(uri)
};
}
parsed.set_query(None);
if !filtered_pairs.is_empty() {
let mut query_pairs = parsed.query_pairs_mut();
for (key, value) in &filtered_pairs {
query_pairs.append_pair(key, value);
}
}
Cow::Owned(parsed.into())
}
#[cfg(test)]
mod tests {
    //! Behavioural tests for `sanitize_sqlx_postgres_connect_uri`.

    use super::sanitize_sqlx_postgres_connect_uri;

    #[test]
    fn strips_channel_binding_from_postgres_uri() {
        let sanitized = sanitize_sqlx_postgres_connect_uri(
            "postgres://user:pass@localhost:5432/app?sslmode=require&channel_binding=require&application_name=athena",
        );
        assert_eq!(
            sanitized.as_ref(),
            "postgresql://user:pass@localhost:5432/app?sslmode=require&application_name=athena"
        );
    }

    #[test]
    fn strips_only_unsupported_parameter() {
        let sanitized =
            sanitize_sqlx_postgres_connect_uri("postgresql://localhost/app?channel_binding=require");
        assert_eq!(sanitized.as_ref(), "postgresql://localhost/app");
    }

    #[test]
    fn strips_unsupported_parameter_case_insensitively() {
        // Keys are compared with ASCII case folding, so mixed-case spellings are removed too.
        let sanitized = sanitize_sqlx_postgres_connect_uri(
            "postgresql://localhost/app?Channel_Binding=require&sslmode=disable",
        );
        assert_eq!(
            sanitized.as_ref(),
            "postgresql://localhost/app?sslmode=disable"
        );
    }

    #[test]
    fn strips_every_occurrence_of_unsupported_parameter() {
        let sanitized = sanitize_sqlx_postgres_connect_uri(
            "postgresql://localhost/app?channel_binding=prefer&sslmode=require&channel_binding=require",
        );
        assert_eq!(
            sanitized.as_ref(),
            "postgresql://localhost/app?sslmode=require"
        );
    }

    #[test]
    fn keeps_uri_with_only_supported_parameters() {
        let uri = "postgresql://localhost/app?sslmode=require&application_name=athena";
        assert_eq!(sanitize_sqlx_postgres_connect_uri(uri).as_ref(), uri);
    }

    #[test]
    fn leaves_unrelated_uris_unchanged() {
        let sanitized =
            sanitize_sqlx_postgres_connect_uri("https://example.com?channel_binding=require");
        assert_eq!(
            sanitized.as_ref(),
            "https://example.com?channel_binding=require"
        );
    }

    #[test]
    fn rewrites_athena_scheme_to_postgresql() {
        let sanitized = sanitize_sqlx_postgres_connect_uri("athena://user:pass@localhost/db");
        assert_eq!(sanitized.as_ref(), "postgresql://user:pass@localhost/db");
    }
}