use std::sync::{Arc, OnceLock};
use nu_protocol::ShellError;
use nu_protocol::shell_error::generic::GenericError;
use rustls::crypto::CryptoProvider;
use ureq::tls::{RootCerts, TlsConfig};
/// Process-wide, set-once slot holding the rustls [`CryptoProvider`] used for TLS.
///
/// The stored value is the *outcome* of installing a provider: either the
/// provider itself (behind an [`Arc`]) or the [`ShellError`] produced while
/// building it. Once a value has been stored it is never replaced.
#[derive(Debug)]
pub struct NuCryptoProvider(OnceLock<Result<Arc<CryptoProvider>, ShellError>>);

/// The global crypto provider slot; [`tls_config`] reads it via [`NuCryptoProvider::get`].
pub static CRYPTO_PROVIDER: NuCryptoProvider = NuCryptoProvider(OnceLock::new());

impl NuCryptoProvider {
    /// Returns the installed crypto provider.
    ///
    /// # Errors
    ///
    /// - If the provider was installed as an error, a clone of that stored error
    ///   (the same error is returned on every call).
    /// - If nothing has been installed yet, an internal error whose help text
    ///   points at `nu_command::tls::CRYPTO_PROVIDER`.
    pub fn get(&self) -> Result<Arc<CryptoProvider>, ShellError> {
        match self.0.get() {
            Some(val) => val.clone(),
            None => Err(ShellError::Generic(
                GenericError::new_internal(
                    "tls crypto provider not found",
                    "no crypto provider for rustls was defined",
                )
                .with_help("ensure that nu_command::tls::CRYPTO_PROVIDER is set"),
            )),
        }
    }

    /// Installs the provider produced by `f`, unless a value is already stored.
    ///
    /// `f` is only invoked when the slot is still empty, so a provider
    /// constructor is never run just to have its result discarded. Whatever `f`
    /// returns, success or error, is stored as-is and later surfaced by
    /// [`Self::get`].
    ///
    /// Returns `true` if this call stored the value, `false` if the slot was
    /// already initialised (including by a concurrent caller).
    pub fn set(&self, f: impl FnOnce() -> Result<CryptoProvider, ShellError>) -> bool {
        let mut installed = false;
        // `get_or_init` runs the closure at most once across all callers, so
        // `installed` is only flipped by the call that actually fills the slot.
        self.0.get_or_init(|| {
            installed = true;
            f().map(Arc::new)
        });
        installed
    }

    /// Installs rustls' `ring` default provider.
    ///
    /// Returns the same `bool` as [`Self::set`].
    pub fn default(&self) -> bool {
        self.set(|| Ok(rustls::crypto::ring::default_provider()))
    }
}
#[doc = include_str!("./tls_config.rustdoc.md")]
pub fn tls_config(allow_insecure: bool) -> Result<TlsConfig, ShellError> {
    // Fails early with a descriptive error if no provider has been installed.
    let crypto_provider = CRYPTO_PROVIDER.get()?;

    // Apply the configured provider on both the verified and the insecure path,
    // so every request uses the same crypto backend instead of leaving the
    // insecure path's provider selection to ureq.
    let builder = TlsConfig::builder().unversioned_rustls_crypto_provider(crypto_provider);

    let config = if allow_insecure {
        // Certificate verification is turned off, so no root store is configured.
        builder.disable_verification(true).build()
    } else {
        // Root store selection depends on the build: the platform verifier on
        // most OS builds, the natively loaded certificate store on Android, and
        // the bundled WebPKI roots when the `os` feature is disabled.
        #[cfg(all(feature = "os", not(target_os = "android")))]
        let certs = RootCerts::PlatformVerifier;
        #[cfg(all(feature = "os", target_os = "android"))]
        let certs = native_certs();
        #[cfg(not(feature = "os"))]
        let certs = RootCerts::WebPki;

        builder.root_certs(certs).build()
    };

    Ok(config)
}
/// Loads the operating system's native root certificates into a ureq [`RootCerts`].
///
/// Uses `rustls_native_certs::load_native_certs` and converts every loaded
/// certificate into an owned `ureq::tls::Certificate`.
///
/// # Panics
///
/// In debug builds only, panics if any error was reported while loading
/// certificates; the panic message includes those errors. Release builds ignore
/// the errors and keep whichever certificates did load.
#[cfg(feature = "os")]
pub fn native_certs() -> RootCerts {
    use rustls_native_certs::CertificateResult;
    use ureq::tls::Certificate;

    let CertificateResult { certs, errors, .. } = rustls_native_certs::load_native_certs();

    // Include the actual errors so a debug-build failure can be diagnosed.
    debug_assert!(
        errors.is_empty(),
        "encountered errors while loading tls certificates: {errors:?}"
    );

    // `to_owned` detaches each certificate from the borrowed DER bytes so the
    // resulting list can live in `RootCerts::Specific`.
    let certs: Vec<_> = certs
        .into_iter()
        .map(|cert| Certificate::from_der(&cert).to_owned())
        .collect();

    RootCerts::Specific(Arc::new(certs))
}