use std::sync::Arc;
use http::{HeaderValue, header::HeaderName};
#[cfg(feature = "openssl-tls")] use hyper::rt::{Read, Write};
use hyper_util::client::legacy::connect::HttpConnector;
use jiff::Timestamp;
use secrecy::ExposeSecret;
use tower::{filter::AsyncFilterLayer, util::Either};
#[cfg(any(feature = "rustls-tls", feature = "openssl-tls"))] use super::tls;
use super::{
auth::Auth,
middleware::{AddAuthorizationLayer, AuthLayer, BaseUriLayer, ExtraHeadersLayer},
};
use crate::{Config, Error, Result};
/// Extension methods on [`Config`] for building the individual pieces of a client stack
/// (request layers and TLS connectors).
///
/// The trait is sealed through `private::Sealed`; within this file it is implemented
/// for [`Config`] only.
pub trait ConfigExt: private::Sealed {
/// Returns a layer built from the config's `cluster_url`.
fn base_uri_layer(&self) -> BaseUriLayer;
/// Returns the authentication layer derived from the config's `auth_info`.
///
/// Yields `Ok(None)` when no request-level authentication applies (no auth configured,
/// or client-certificate auth, which is handled by the TLS connectors instead).
///
/// # Errors
///
/// Fails when `auth_info` cannot be converted into a usable auth method.
fn auth_layer(&self) -> Result<Option<AuthLayer>>;
/// Returns a layer carrying the config's extra headers plus the `impersonate-user` /
/// `impersonate-group` headers taken from `auth_info`.
///
/// # Errors
///
/// Fails when an impersonation value is not a valid HTTP header value.
fn extra_headers_layer(&self) -> Result<ExtraHeadersLayer>;
/// Creates a rustls-based HTTPS connector on top of a default [`HttpConnector`].
#[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls")))]
#[cfg(feature = "rustls-tls")]
fn rustls_https_connector(&self) -> Result<hyper_rustls::HttpsConnector<HttpConnector>>;
/// Creates a rustls-based HTTPS connector that wraps the caller-supplied `connector`.
#[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls")))]
#[cfg(feature = "rustls-tls")]
fn rustls_https_connector_with_connector<H>(
&self,
connector: H,
) -> Result<hyper_rustls::HttpsConnector<H>>;
/// Creates the `rustls::ClientConfig` (client identity, root certificates and
/// verification settings) described by this config.
#[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls")))]
#[cfg(feature = "rustls-tls")]
fn rustls_client_config(&self) -> Result<rustls::ClientConfig>;
/// Creates an OpenSSL-based HTTPS connector on top of a default [`HttpConnector`].
#[cfg_attr(docsrs, doc(cfg(feature = "openssl-tls")))]
#[cfg(feature = "openssl-tls")]
fn openssl_https_connector(&self)
-> Result<hyper_openssl::client::legacy::HttpsConnector<HttpConnector>>;
/// Creates an OpenSSL-based HTTPS connector that wraps the caller-supplied `connector`.
///
/// `connector` must be a `tower::Service<http::Uri>` whose response is a readable and
/// writable connection, as spelled out in the `where` clause.
#[cfg_attr(docsrs, doc(cfg(feature = "openssl-tls")))]
#[cfg(feature = "openssl-tls")]
fn openssl_https_connector_with_connector<H>(
&self,
connector: H,
) -> Result<hyper_openssl::client::legacy::HttpsConnector<H>>
where
H: tower::Service<http::Uri> + Send,
H::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
H::Future: Send + 'static,
H::Response: Read + Write + hyper_util::client::legacy::connect::Connection + Unpin;
/// Creates an `SslConnectorBuilder` preloaded with the client identity and root
/// certificates described by this config.
#[cfg_attr(docsrs, doc(cfg(feature = "openssl-tls")))]
#[cfg(feature = "openssl-tls")]
fn openssl_ssl_connector_builder(&self) -> Result<openssl::ssl::SslConnectorBuilder>;
}
/// End-to-end checks of how `tls_server_name` and `accept_invalid_certs` influence the
/// OpenSSL connector, run against a throwaway local TLS server.
#[cfg(all(test, feature = "openssl-tls"))]
mod openssl_tls_server_name_tests {
    use std::{
        net::TcpListener,
        sync::{Arc, Mutex},
    };
    use openssl::{
        asn1::Asn1Time,
        hash::MessageDigest,
        pkey::{PKey, Private},
        rsa::Rsa,
        ssl::{NameType, SslAcceptor, SslMethod},
        x509::{
            extension::{BasicConstraints, SubjectAlternativeName},
            X509NameBuilder, X509,
        },
    };
    use tower::ServiceExt as _;
    use super::*;

    /// DNS name placed in the test certificates; deliberately not the connection host.
    const SERVER_NAME: &str = "kubernetes.example.com";

    /// The single subject alternative name a test certificate carries.
    enum San<'a> {
        Dns(&'a str),
        Ip(&'a str),
    }

    /// Generates a one-day, self-signed CA certificate whose CN and sole SAN entry are
    /// taken from `san`, together with its freshly generated RSA key.
    fn self_signed_cert(san: San) -> (X509, PKey<Private>) {
        // Both variants hold the text that doubles as the common name.
        let (San::Dns(common_name) | San::Ip(common_name)) = san;

        let key = PKey::from_rsa(Rsa::generate(2048).unwrap()).unwrap();
        let subject = {
            let mut entries = X509NameBuilder::new().unwrap();
            entries.append_entry_by_text("CN", common_name).unwrap();
            entries.build()
        };

        let mut cert = X509::builder().unwrap();
        cert.set_version(2).unwrap();
        // Self-signed: subject and issuer are the same name.
        cert.set_subject_name(&subject).unwrap();
        cert.set_issuer_name(&subject).unwrap();
        cert.set_pubkey(&key).unwrap();
        cert.set_not_before(&Asn1Time::days_from_now(0).unwrap())
            .unwrap();
        cert.set_not_after(&Asn1Time::days_from_now(1).unwrap())
            .unwrap();

        let basic_constraints = BasicConstraints::new().critical().ca().build().unwrap();
        cert.append_extension(basic_constraints).unwrap();

        let mut alt_name = SubjectAlternativeName::new();
        match san {
            San::Dns(host) => alt_name.dns(host),
            San::Ip(addr) => alt_name.ip(addr),
        };
        let alt_name = alt_name.build(&cert.x509v3_context(None, None)).unwrap();
        cert.append_extension(alt_name).unwrap();

        cert.sign(&key, MessageDigest::sha256()).unwrap();
        (cert.build(), key)
    }

    /// Starts a TLS server on an ephemeral localhost port that performs one handshake
    /// on a background thread.
    ///
    /// Returns the port and a shared slot into which the servername callback stores the
    /// host name it observed (or `None`).
    fn spawn_tls_server(cert: X509, key: PKey<Private>) -> (u16, Arc<Mutex<Option<String>>>) {
        let socket = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = socket.local_addr().unwrap().port();

        let seen_sni = Arc::new(Mutex::new(None::<String>));
        let sink = Arc::clone(&seen_sni);

        let mut tls = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
        tls.set_private_key(&key).unwrap();
        tls.set_certificate(&cert).unwrap();
        tls.set_servername_callback(move |ssl, _alert| {
            let requested = ssl.servername(NameType::HOST_NAME).map(str::to_owned);
            *sink.lock().unwrap() = requested;
            Ok(())
        });
        let tls = tls.build();

        std::thread::spawn(move || {
            // The handshake outcome is judged from the client side; ignore it here.
            if let Ok((conn, _peer)) = socket.accept() {
                let _ = tls.accept(conn);
            }
        });

        (port, seen_sni)
    }

    /// Builds a `Config` pointing at `https://127.0.0.1:{port}` that trusts `ca` and
    /// uses the given `tls_server_name`.
    fn config_for(port: u16, ca: &X509, tls_server_name: Option<&str>) -> Config {
        let mut cfg = Config::new(format!("https://127.0.0.1:{port}").parse().unwrap());
        cfg.tls_server_name = tls_server_name.map(String::from);
        cfg.root_cert = Some(vec![ca.to_der().unwrap()]);
        cfg
    }

    /// Builds the OpenSSL HTTPS connector for `config` over a plain `HttpConnector`.
    fn connector_for(config: &Config) -> hyper_openssl::client::legacy::HttpsConnector<HttpConnector> {
        let mut plain = HttpConnector::new();
        // Allow `https://` URIs through the underlying HTTP connector.
        plain.enforce_http(false);
        config.openssl_https_connector_with_connector(plain).unwrap()
    }

    #[tokio::test]
    async fn tls_server_name_drives_sni_and_verification() {
        let (cert, key) = self_signed_cert(San::Dns(SERVER_NAME));
        let (port, seen_sni) = spawn_tls_server(cert.clone(), key);
        let config = config_for(port, &cert, Some(SERVER_NAME));

        let target: http::Uri = config.cluster_url.clone();
        let handshake = connector_for(&config).oneshot(target).await;
        handshake.expect("handshake should succeed when verification targets tls_server_name");

        assert_eq!(
            seen_sni.lock().unwrap().as_deref(),
            Some(SERVER_NAME),
            "ClientHello SNI must equal tls_server_name, not the connection host"
        );
    }

    #[tokio::test]
    async fn without_tls_server_name_verification_uses_connection_host() {
        let (cert, key) = self_signed_cert(San::Dns(SERVER_NAME));
        let (port, _seen_sni) = spawn_tls_server(cert.clone(), key);
        let config = config_for(port, &cert, None);

        let target: http::Uri = config.cluster_url.clone();
        let handshake = connector_for(&config).oneshot(target).await;

        assert!(
            handshake.is_err(),
            "handshake must fail when the cert does not match the connection host"
        );
    }

    #[tokio::test]
    async fn tls_server_name_as_ip_verifies_without_sni() {
        let (cert, key) = self_signed_cert(San::Ip("127.0.0.1"));
        let (port, seen_sni) = spawn_tls_server(cert.clone(), key);
        let config = config_for(port, &cert, Some("127.0.0.1"));

        let target: http::Uri = config.cluster_url.clone();
        let handshake = connector_for(&config).oneshot(target).await;
        handshake.expect("handshake should succeed when the IP tls_server_name matches the cert");

        assert_eq!(
            *seen_sni.lock().unwrap(),
            None,
            "SNI must not be sent for an IP tls_server_name"
        );
    }

    #[tokio::test]
    async fn accept_invalid_certs_skips_verification() {
        let (cert, key) = self_signed_cert(San::Dns(SERVER_NAME));
        let (port, _seen_sni) = spawn_tls_server(cert.clone(), key);
        let mut config = config_for(port, &cert, None);
        config.accept_invalid_certs = true;

        let target: http::Uri = config.cluster_url.clone();
        let handshake = connector_for(&config).oneshot(target).await;
        handshake.expect("handshake should succeed when accept_invalid_certs disables verification");
    }
}
// Sealing support for `ConfigExt`: the module is not `pub`, so `Sealed` cannot be named
// (and therefore `ConfigExt` cannot be implemented) from outside this module tree.
mod private {
/// Marker supertrait required by `ConfigExt`.
pub trait Sealed {}
// `Config` is the only type opted in here.
impl Sealed for super::Config {}
}
impl ConfigExt for Config {
    /// Builds a [`BaseUriLayer`] from a clone of this config's `cluster_url`.
    fn base_uri_layer(&self) -> BaseUriLayer {
        BaseUriLayer::new(self.cluster_url.clone())
    }

    /// Derives the request authentication layer from `auth_info`.
    ///
    /// Basic and bearer credentials become a sensitive `Authorization` header layer,
    /// a refreshable token becomes an async filter layer, and `Ok(None)` is returned
    /// when nothing has to be added per request.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Auth`] when `auth_info` cannot be converted into an [`Auth`].
    fn auth_layer(&self) -> Result<Option<AuthLayer>> {
        Ok(match Auth::try_from(&self.auth_info).map_err(Error::Auth)? {
            Auth::None => None,
            Auth::Basic(user, pass) => Some(AuthLayer(Either::Left(
                AddAuthorizationLayer::basic(&user, pass.expose_secret()).as_sensitive(true),
            ))),
            Auth::Bearer(token) => Some(AuthLayer(Either::Left(
                AddAuthorizationLayer::bearer(token.expose_secret()).as_sensitive(true),
            ))),
            Auth::RefreshableToken(refreshable) => {
                Some(AuthLayer(Either::Right(AsyncFilterLayer::new(refreshable))))
            }
            // Certificate credentials are not sent per request; the TLS builders pick
            // them up through `exec_identity_pem`.
            Auth::Certificate(_client_certificate_data, _client_key_data, _) => None,
        })
    }

    /// Collects the configured extra headers plus one `impersonate-user` header and one
    /// `impersonate-group` header per group, when impersonation is configured.
    ///
    /// # Errors
    ///
    /// Returns [`Error::HttpError`] when an impersonation value is not a valid header
    /// value.
    fn extra_headers_layer(&self) -> Result<ExtraHeadersLayer> {
        let mut headers = self.headers.clone();
        if let Some(impersonate_user) = &self.auth_info.impersonate {
            headers.push((
                HeaderName::from_static("impersonate-user"),
                HeaderValue::from_str(impersonate_user)
                    .map_err(http::Error::from)
                    .map_err(Error::HttpError)?,
            ));
        }
        if let Some(impersonate_groups) = &self.auth_info.impersonate_groups {
            for group in impersonate_groups {
                headers.push((
                    HeaderName::from_static("impersonate-group"),
                    HeaderValue::from_str(group)
                        .map_err(http::Error::from)
                        .map_err(Error::HttpError)?,
                ));
            }
        }
        Ok(ExtraHeadersLayer {
            headers: Arc::new(headers),
        })
    }

    /// Builds the rustls client configuration for this config.
    ///
    /// The client identity comes from `exec_identity_pem` when it yields one, otherwise
    /// from `identity_pem`. When certificate verification is enabled and
    /// `root_cert_file` is set, the verifier is replaced by a `ReloadingVerifier`
    /// created from that path.
    ///
    /// # Errors
    ///
    /// Propagates `identity_pem` failures and wraps TLS setup failures in
    /// [`Error::RustlsTls`].
    #[cfg(feature = "rustls-tls")]
    fn rustls_client_config(&self) -> Result<rustls::ClientConfig> {
        let identity = match self.exec_identity_pem().0 {
            Some(identity) => Some(identity),
            None => self.identity_pem()?,
        };
        let mut config = tls::rustls_tls::rustls_client_config(
            identity.as_deref(),
            self.root_cert.as_deref(),
            self.accept_invalid_certs,
        )
        .map_err(Error::RustlsTls)?;
        // A custom verifier is pointless when verification is switched off entirely.
        if !self.accept_invalid_certs
            && let Some(path) = &self.root_cert_file
        {
            let verifier =
                tls::rustls_tls::ReloadingVerifier::new(path.clone()).map_err(Error::RustlsTls)?;
            config
                .dangerous()
                .set_certificate_verifier(Arc::new(verifier));
        }
        Ok(config)
    }

    /// Builds a rustls HTTPS connector over a default [`HttpConnector`] that is allowed
    /// to carry non-`http` schemes.
    #[cfg(feature = "rustls-tls")]
    fn rustls_https_connector(&self) -> Result<hyper_rustls::HttpsConnector<HttpConnector>> {
        let mut connector = HttpConnector::new();
        connector.enforce_http(false);
        self.rustls_https_connector_with_connector(connector)
    }

    /// Builds a rustls HTTPS connector (HTTP/1, `https` or `http`) around `connector`.
    ///
    /// When `tls_server_name` is set, it is installed as a fixed server name resolver.
    ///
    /// # Errors
    ///
    /// Returns [`Error::RustlsTls`] when `tls_server_name` is not a valid server name,
    /// and propagates failures from `rustls_client_config`.
    #[cfg(feature = "rustls-tls")]
    fn rustls_https_connector_with_connector<H>(
        &self,
        connector: H,
    ) -> Result<hyper_rustls::HttpsConnector<H>> {
        use hyper_rustls::FixedServerNameResolver;
        use crate::client::tls::rustls_tls;

        let rustls_config = self.rustls_client_config()?;
        let mut builder = hyper_rustls::HttpsConnectorBuilder::new()
            .with_tls_config(rustls_config)
            .https_or_http();
        if let Some(tsn) = self.tls_server_name.as_ref() {
            builder = builder.with_server_name_resolver(FixedServerNameResolver::new(
                tsn.clone()
                    .try_into()
                    .map_err(rustls_tls::Error::InvalidServerName)
                    .map_err(Error::RustlsTls)?,
            ));
        }
        Ok(builder.enable_http1().wrap_connector(connector))
    }

    /// Builds an OpenSSL connector builder from the client identity and `root_cert`.
    ///
    /// The client identity comes from `exec_identity_pem` when it yields one, otherwise
    /// from `identity_pem`.
    ///
    /// # Errors
    ///
    /// Propagates `identity_pem` failures and wraps builder failures in
    /// [`Error::OpensslTls`].
    #[cfg(feature = "openssl-tls")]
    fn openssl_ssl_connector_builder(&self) -> Result<openssl::ssl::SslConnectorBuilder> {
        let identity = match self.exec_identity_pem().0 {
            Some(identity) => Some(identity),
            None => self.identity_pem()?,
        };
        tls::openssl_tls::ssl_connector_builder(identity.as_ref(), self.root_cert.as_ref())
            .map_err(|e| Error::OpensslTls(tls::openssl_tls::Error::CreateSslConnector(e)))
    }

    /// Builds an OpenSSL HTTPS connector over a default [`HttpConnector`] that is
    /// allowed to carry non-`http` schemes.
    #[cfg(feature = "openssl-tls")]
    fn openssl_https_connector(
        &self,
    ) -> Result<hyper_openssl::client::legacy::HttpsConnector<HttpConnector>> {
        let mut connector = HttpConnector::new();
        connector.enforce_http(false);
        self.openssl_https_connector_with_connector(connector)
    }

    /// Builds an OpenSSL HTTPS connector around `connector`.
    ///
    /// A per-connection callback is installed only when needed:
    /// - `accept_invalid_certs` disables peer verification;
    /// - `tls_server_name` replaces the connection host for both SNI and certificate
    ///   verification. An IP address is verified as an IP and sends no SNI; any other
    ///   value is sent as SNI and verified as a host name.
    ///
    /// # Errors
    ///
    /// Returns [`Error::OpensslTls`] when the connector cannot be created, and
    /// propagates failures from `openssl_ssl_connector_builder`.
    #[cfg(feature = "openssl-tls")]
    fn openssl_https_connector_with_connector<H>(
        &self,
        connector: H,
    ) -> Result<hyper_openssl::client::legacy::HttpsConnector<H>>
    where
        H: tower::Service<http::Uri> + Send,
        H::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        H::Future: Send + 'static,
        H::Response: Read + Write + hyper_util::client::legacy::connect::Connection + Unpin,
    {
        use std::net::IpAddr;
        use openssl::x509::verify::X509CheckFlags;

        let mut https = hyper_openssl::client::legacy::HttpsConnector::with_connector(
            connector,
            self.openssl_ssl_connector_builder()?,
        )
        .map_err(|e| Error::OpensslTls(tls::openssl_tls::Error::CreateHttpsConnector(e)))?;

        let accept_invalid_certs = self.accept_invalid_certs;
        // Classify the override once, here, instead of re-parsing the same string on
        // every connection inside the callback.
        let server_name_override = self
            .tls_server_name
            .as_ref()
            .map(|name| (name.clone(), name.parse::<IpAddr>().ok()));

        if accept_invalid_certs || server_name_override.is_some() {
            https.set_callback(move |ssl, _uri| {
                if accept_invalid_certs {
                    ssl.set_verify(openssl::ssl::SslVerifyMode::NONE);
                }
                if let Some((name, ip)) = &server_name_override {
                    // Stop the connector from deriving SNI and the verification target
                    // from the connection host; both are set explicitly below.
                    ssl.set_use_server_name_indication(false);
                    ssl.set_verify_hostname(false);
                    // SNI is only sent for host names, never for IP addresses.
                    if ip.is_none() {
                        ssl.set_hostname(name)?;
                    }
                    let param = ssl.param_mut();
                    param.set_hostflags(X509CheckFlags::NO_PARTIAL_WILDCARDS);
                    match ip {
                        Some(ip) => param.set_ip(*ip)?,
                        None => param.set_host(name)?,
                    }
                }
                Ok(())
            });
        }
        Ok(https)
    }
}
impl Config {
    /// Assembles the client identity PEM from certificate-based auth, if configured.
    ///
    /// When `auth_info` resolves to [`Auth::Certificate`], returns the private key,
    /// a newline, the certificate and a trailing newline as one byte buffer, together
    /// with the expiration carried by that auth value. Any other outcome (including a
    /// failed conversion) yields `(None, None)`.
    pub(crate) fn exec_identity_pem(&self) -> (Option<Vec<u8>>, Option<Timestamp>) {
        const LINE_BREAK: &[u8] = b"\n";

        let Ok(Auth::Certificate(cert_pem, key_pem, expires_at)) =
            Auth::try_from(&self.auth_info)
        else {
            return (None, None);
        };

        // Key first, then certificate, each terminated by a newline.
        let identity = [
            key_pem.expose_secret().as_bytes(),
            LINE_BREAK,
            cert_pem.as_bytes(),
            LINE_BREAK,
        ]
        .concat();
        (Some(identity), expires_at)
    }
}