use crate::commons::TlsPeerVerificationMode;
use std::collections::HashMap;
use url::Url;
/// Builds a URI by layering query parameters (in particular TLS client
/// settings) on top of a parsed base URI.
///
/// Parameter edits are collected in an in-memory map and only written back to
/// the underlying [`Url`] when the URI is rendered or inspected.
#[derive(Debug, Clone)]
#[must_use = "builder methods consume `self` and return the updated `UriBuilder`; dropping it discards the change"]
pub struct UriBuilder {
    /// The parsed URI. Its query string may lag behind `cached_params` until
    /// pending changes are applied.
    url: Url,
    /// Query parameters as key/value strings, populated from `url` the first
    /// time parameters are modified.
    cached_params: Option<HashMap<String, String>>,
    /// Whether `cached_params` holds edits not yet written to `url`.
    has_pending_changes: bool,
}
impl UriBuilder {
pub const PEER_VERIFICATION_MODE_KEY: &'static str = "verify";
pub const CA_CERTIFICATE_BUNDLE_PATH_KEY: &'static str = "cacertfile";
pub const CLIENT_CERTIFICATE_PATH_KEY: &'static str = "certfile";
pub const CLIENT_PRIVATE_KEY_PATH_KEY: &'static str = "keyfile";
pub const SERVER_NAME_INDICATION_KEY: &'static str = "server_name_indication";
const TLS_PARAMS: &'static [&'static str] = &[
Self::PEER_VERIFICATION_MODE_KEY,
Self::CA_CERTIFICATE_BUNDLE_PATH_KEY,
Self::CLIENT_CERTIFICATE_PATH_KEY,
Self::CLIENT_PRIVATE_KEY_PATH_KEY,
Self::SERVER_NAME_INDICATION_KEY,
];
pub fn new(base_uri: &str) -> Result<Self, url::ParseError> {
let url = Url::parse(base_uri)?;
Ok(Self {
url,
cached_params: None,
has_pending_changes: false,
})
}
pub fn with_tls_peer_verification(mut self, mode: TlsPeerVerificationMode) -> Self {
self.set_query_param(Self::PEER_VERIFICATION_MODE_KEY, mode.as_ref());
self
}
pub fn with_ca_cert_file<S: AsRef<str>>(mut self, path: S) -> Self {
self.set_query_param(Self::CA_CERTIFICATE_BUNDLE_PATH_KEY, path.as_ref());
self
}
pub fn with_client_cert_file<S: AsRef<str>>(mut self, path: S) -> Self {
self.set_query_param(Self::CLIENT_CERTIFICATE_PATH_KEY, path.as_ref());
self
}
pub fn with_client_key_file<S: AsRef<str>>(mut self, path: S) -> Self {
self.set_query_param(Self::CLIENT_PRIVATE_KEY_PATH_KEY, path.as_ref());
self
}
pub fn with_server_name_indication<S: AsRef<str>>(mut self, hostname: S) -> Self {
self.set_query_param(Self::SERVER_NAME_INDICATION_KEY, hostname.as_ref());
self
}
pub fn with_query_param<K: AsRef<str>, V: AsRef<str>>(mut self, key: K, value: V) -> Self {
self.set_query_param(key.as_ref(), value.as_ref());
self
}
pub fn without_query_param<K: AsRef<str>>(mut self, key: K) -> Self {
self.remove_query_param(key.as_ref());
self
}
pub fn replace(mut self, config: TlsClientSettings) -> Self {
self.ensure_params_cached();
if let Some(ref mut params) = self.cached_params {
let mut any_removed = false;
for &key in Self::TLS_PARAMS {
if params.remove(key).is_some() {
any_removed = true;
}
}
if any_removed {
self.has_pending_changes = true;
}
}
self.apply_tls_settings(&config);
self
}
pub fn merge(mut self, settings: TlsClientSettings) -> Self {
self.ensure_params_cached();
self.apply_tls_settings(&settings);
self
}
pub fn build(mut self) -> Result<String, url::ParseError> {
self.apply_cached_params_to_url();
Ok(self.url.to_string())
}
pub fn as_url(&mut self) -> &Url {
self.apply_cached_params_to_url();
&self.url
}
pub fn query_params(&mut self) -> HashMap<String, String> {
self.apply_cached_params_to_url();
self.url
.query_pairs()
.map(|(k, v)| (k.into_owned(), v.into_owned()))
.collect()
}
fn set_query_param(&mut self, key: &str, value: &str) {
self.ensure_params_cached();
let Some(ref mut params) = self.cached_params else {
return;
};
params.insert(key.to_string(), value.to_string());
self.has_pending_changes = true;
}
fn remove_query_param(&mut self, key: &str) {
self.ensure_params_cached();
let Some(ref mut params) = self.cached_params else {
return;
};
if params.remove(key).is_some() {
self.has_pending_changes = true;
}
}
fn recalculate_query_params(&self) -> HashMap<String, String> {
self.url
.query_pairs()
.map(|(k, v)| {
let decoded_key = urlencoding::decode(&k).unwrap_or_else(|_| k.clone());
let decoded_value = urlencoding::decode(&v).unwrap_or_else(|_| v.clone());
(decoded_key.into_owned(), decoded_value.into_owned())
})
.collect()
}
fn rebuild_query_string_from_map(&mut self, params: &HashMap<String, String>) {
if params.is_empty() {
self.url.set_query(None);
return;
}
let query_string = params
.iter()
.map(|(k, v)| format!("{}={}", k, v))
.collect::<Vec<_>>()
.join("&");
self.url.set_query(Some(&query_string));
}
fn ensure_params_cached(&mut self) {
if self.cached_params.is_none() {
let params = self.recalculate_query_params();
self.cached_params = Some(params);
}
}
fn apply_cached_params_to_url(&mut self) {
if !self.has_pending_changes {
return;
}
if let Some(params) = self.cached_params.clone() {
self.rebuild_query_string_from_map(¶ms);
}
self.has_pending_changes = false;
}
fn apply_tls_settings(&mut self, settings: &TlsClientSettings) {
if let Some(verification) = &settings.peer_verification {
self.set_query_param(Self::PEER_VERIFICATION_MODE_KEY, verification.as_ref());
}
if let Some(ca_cert) = &settings.ca_cert_file {
self.set_query_param(Self::CA_CERTIFICATE_BUNDLE_PATH_KEY, ca_cert);
}
if let Some(client_cert) = &settings.client_cert_file {
self.set_query_param(Self::CLIENT_CERTIFICATE_PATH_KEY, client_cert);
}
if let Some(client_key) = &settings.client_key_file {
self.set_query_param(Self::CLIENT_PRIVATE_KEY_PATH_KEY, client_key);
}
if let Some(sni) = &settings.server_name_indication {
self.set_query_param(Self::SERVER_NAME_INDICATION_KEY, sni);
}
}
}
/// TLS client settings that `UriBuilder` can write as URI query parameters.
///
/// Every field is optional. `UriBuilder` only writes the fields that are
/// `Some`; fields left as `None` produce no query parameter.
#[derive(Debug, Clone, Default)]
pub struct TlsClientSettings {
/// Peer verification mode, written to the `verify` query parameter.
pub peer_verification: Option<TlsPeerVerificationMode>,
/// Path to the CA certificate bundle, written to `cacertfile`.
pub ca_cert_file: Option<String>,
/// Path to the client certificate, written to `certfile`.
pub client_cert_file: Option<String>,
/// Path to the client private key, written to `keyfile`.
pub client_key_file: Option<String>,
/// Server name indication hostname, written to `server_name_indication`.
pub server_name_indication: Option<String>,
}
impl TlsClientSettings {
    /// Returns settings with every field unset, the same as `Default::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns settings whose only field is peer verification set to
    /// `TlsPeerVerificationMode::Enabled`.
    pub fn with_verification() -> Self {
        Self::new().peer_verification(TlsPeerVerificationMode::Enabled)
    }

    /// Returns settings whose only field is peer verification set to
    /// `TlsPeerVerificationMode::Disabled`.
    pub fn without_verification() -> Self {
        Self::new().peer_verification(TlsPeerVerificationMode::Disabled)
    }

    /// Replaces the peer verification mode.
    pub fn peer_verification(self, mode: TlsPeerVerificationMode) -> Self {
        Self {
            peer_verification: Some(mode),
            ..self
        }
    }

    /// Replaces the CA certificate bundle path.
    pub fn ca_cert_file<S: Into<String>>(self, path: S) -> Self {
        Self {
            ca_cert_file: Some(path.into()),
            ..self
        }
    }

    /// Replaces the client certificate path.
    pub fn client_cert_file<S: Into<String>>(self, path: S) -> Self {
        Self {
            client_cert_file: Some(path.into()),
            ..self
        }
    }

    /// Replaces the client private key path.
    pub fn client_key_file<S: Into<String>>(self, path: S) -> Self {
        Self {
            client_key_file: Some(path.into()),
            ..self
        }
    }

    /// Replaces the server name indication hostname.
    pub fn server_name_indication<S: Into<String>>(self, hostname: S) -> Self {
        Self {
            server_name_indication: Some(hostname.into()),
            ..self
        }
    }
}