use crate::config::{Tls, TlsIdentity};
use anyhow::{Context, Result};
use fxhash::FxHashMap;
use log::{debug, info, warn};
use parking_lot::Mutex;
use rustls_pki_types::{
pem::{PemObject, SectionKind},
PrivateKeyDer,
};
use smallvec::SmallVec;
use std::{
collections::{BTreeMap, HashMap},
fmt, mem,
sync::{Arc, LazyLock},
};
use x509_parser::prelude::GeneralName;
/// Read every PEM encoded certificate found in the file at `path`.
///
/// # Errors
///
/// Fails if the file cannot be opened or if reading any certificate
/// section from it fails.
pub(crate) fn load_certs(
    path: &str,
) -> Result<Vec<rustls_pki_types::CertificateDer<'static>>> {
    use std::{fs::File, io::BufReader};
    let mut reader = BufReader::new(File::open(path)?);
    let mut certs = Vec::new();
    for cert in rustls_pemfile::certs(&mut reader) {
        // stop at the first section that fails to read
        certs.push(cert?);
    }
    Ok(certs)
}
/// The names extracted from an X.509 certificate by `get_names`.
#[derive(Debug)]
pub struct Names {
/// The first common name found in the certificate subject.
pub cn: String,
/// The certificate's single DNS subjectAltName entry.
pub alt_name: String,
}
pub(crate) fn get_names(cert: &[u8]) -> Result<Option<Names>> {
let (_, cert) = x509_parser::parse_x509_certificate(&cert)?;
let cn = cert
.subject()
.iter_common_name()
.next()
.and_then(|cn| cn.as_str().ok().map(String::from));
let mut alt_names: SmallVec<[String; 4]> = cert
.subject_alternative_name()?
.map(|alt_name| &alt_name.value.general_names)
.unwrap_or(&vec![])
.into_iter()
.filter_map(|gn| match gn {
GeneralName::DNSName(d) => Some(String::from(*d)),
GeneralName::DirectoryName(_)
| GeneralName::EDIPartyName(_)
| GeneralName::IPAddress(_)
| GeneralName::OtherName(_, _)
| GeneralName::RFC822Name(_)
| GeneralName::RegisteredID(_)
| GeneralName::URI(_)
| GeneralName::X400Address(_)
| GeneralName::Invalid(_, _) => None,
})
.collect();
let alt_name = if alt_names.len() == 0 {
bail!("certificate is missing subjectAltName")
} else if alt_names.len() > 1 {
bail!("certificate must have exactly 1 subjectAltName for use with netidx")
} else {
alt_names.pop().unwrap()
};
Ok(cn.map(|cn| Names { cn, alt_name }))
}
/// Process wide cache of private key passwords, keyed by the path of the
/// key file each password belongs to. Filled by `pre_cache_password` and
/// `load_key_password`, emptied by `clear_cached_passwords`.
static CACHED: LazyLock<Mutex<FxHashMap<String, String>>> =
LazyLock::new(|| Mutex::new(HashMap::default()));
/// Store `password` in the in-process cache as the password for the key
/// file at `path`, replacing any password already cached for that path.
pub fn pre_cache_password(path: &str, password: &str) {
    let mut cache = CACHED.lock();
    cache.insert(String::from(path), String::from(password));
}
pub fn clear_cached_passwords() {
for (_, p) in CACHED.lock().drain() {
let mut v = p.into_bytes();
for i in 0..v.len() {
v[i] = 0;
}
}
}
pub fn load_key_password(askpass: Option<&str>, path: &str) -> Result<String> {
use keyring::Entry;
use std::process::Command;
info!("loading password for {} from the system keyring", path);
let mut cache = CACHED.lock();
match cache.get(path) {
Some(pass) => Ok(pass.into()),
None => {
let entry = Entry::new("netidx", path)?;
match entry.get_password() {
Ok(password) => {
cache.insert(path.into(), password.clone());
Ok(password)
}
Err(e) => match askpass {
None => {
bail!("password isn't in the keychain and no askpass specified")
}
Some(askpass) => {
info!(
"failed to find password entry for netidx {}, error {}",
path, e
);
let res = Command::new(askpass).arg(path).output()?;
let password = String::from_utf8_lossy(&res.stdout);
let password = password.trim_matches(|c| c == '\r' || c == '\n');
if let Err(e) = entry.set_password(password) {
warn!(
"failed to set password entry for netidx {}, error {}",
path, e
);
}
let password = String::from(password);
cache.insert(path.into(), password.clone());
Ok(password)
}
},
}
}
}
}
/// Save `password` in the system keyring as the password for the key file
/// at `path` (service `"netidx"`, user `path`).
///
/// # Errors
///
/// Fails if the keyring entry cannot be created or written.
pub fn save_password_for_key(path: &str, password: &str) -> Result<()> {
    use keyring::Entry;
    Entry::new("netidx", path)?.set_password(password)?;
    Ok(())
}
/// Load the PKCS#8 private key stored in PEM format at `path`.
///
/// If the PEM label says the key is encrypted, the password is obtained
/// with `load_key_password(askpass, path)` and the key is decrypted. The
/// local copy of the password is wiped as soon as decryption has been
/// attempted, whether or not it succeeded.
///
/// # Errors
///
/// Fails if the file cannot be read, is not valid PEM, carries a label
/// other than the PKCS#8 private key or encrypted private key labels,
/// cannot be decrypted, or if the password cannot be obtained.
pub fn load_private_key(
    askpass: Option<&str>,
    path: &str,
) -> Result<PrivateKeyDer<'static>> {
    use pkcs8::{
        der::{pem::PemLabel, zeroize::Zeroize},
        EncryptedPrivateKeyInfo, PrivateKeyInfo, SecretDocument,
    };
    debug!("reading key from {}", path);
    let pem = std::fs::read_to_string(path)?;
    let (label, doc) = match SecretDocument::from_pem(&pem) {
        Ok((label, doc)) => (label, doc),
        Err(e) => bail!("failed to load pem {}, error: {}", path, e),
    };
    debug!("key label is {}", label);
    let der = if label == EncryptedPrivateKeyInfo::PEM_LABEL {
        let encrypted = match EncryptedPrivateKeyInfo::try_from(doc.as_bytes()) {
            Ok(encrypted) => encrypted,
            Err(e) => bail!("failed to parse encrypted key {}", e),
        };
        debug!("decrypting key");
        let mut password = load_key_password(askpass, path)?;
        let decrypted = encrypted.decrypt(&password);
        // wipe the password before looking at the result so it is not
        // left in memory when decryption fails
        password.zeroize();
        match decrypted {
            Ok(decrypted) => decrypted,
            Err(e) => bail!("failed to decrypt key {}", e),
        }
    } else if label == PrivateKeyInfo::PEM_LABEL {
        doc
    } else {
        bail!("expected a key in pem format")
    };
    PrivateKeyDer::from_pem(SectionKind::PrivateKey, Vec::from(der.as_bytes()))
        .ok_or_else(|| anyhow!("invalid key"))
}
/// Build a TLS connector that trusts the CA certificates in
/// `root_certificates` and authenticates itself with `certificate` and
/// `private_key`.
///
/// All three arguments are paths to PEM files. `askpass` is passed on to
/// `load_private_key`. The client config keeps up to 256 sessions in
/// memory for resumption.
///
/// # Errors
///
/// Fails if any file cannot be loaded, if a root certificate is rejected
/// by the root store, or if rustls rejects the client certificate and key.
pub(crate) fn create_tls_connector(
    askpass: Option<&str>,
    root_certificates: &str,
    certificate: &str,
    private_key: &str,
) -> Result<tokio_rustls::TlsConnector> {
    let mut roots = rustls::RootCertStore::empty();
    load_certs(root_certificates)
        .context("loading root certs")?
        .into_iter()
        .try_for_each(|ca| roots.add(ca).context("adding root cert"))?;
    let client_certs = load_certs(certificate).context("loading user cert")?;
    let client_key =
        load_private_key(askpass, private_key).context("loading user private key")?;
    let builder = rustls::ClientConfig::builder().with_root_certificates(roots);
    let mut client_config = builder
        .with_client_auth_cert(client_certs, client_key)
        .context("building rustls client config")?;
    client_config.resumption = rustls::client::Resumption::in_memory_sessions(256);
    let client_config = Arc::new(client_config);
    Ok(tokio_rustls::TlsConnector::from(client_config))
}
/// Build a TLS acceptor that presents `certificate` and `private_key` and
/// verifies client certificates against the CA certificates in
/// `root_certificates`.
///
/// All three arguments are paths to PEM files. `askpass` is passed on to
/// `load_private_key`. The server config uses an in-memory session cache
/// created with a size of 1024.
///
/// # Errors
///
/// Fails if any file cannot be loaded, if a root certificate is rejected
/// by the root store, if the client verifier cannot be built, or if
/// rustls rejects the server certificate and key. Each failure carries a
/// context message naming the step that failed, as the errors from
/// `create_tls_connector` do.
pub(crate) fn create_tls_acceptor(
    askpass: Option<&str>,
    root_certificates: &str,
    certificate: &str,
    private_key: &str,
) -> Result<tokio_rustls::TlsAcceptor> {
    let client_auth = {
        debug!("creating tls client auth trust store");
        let mut root_store = rustls::RootCertStore::empty();
        debug!("loading CA certificates");
        for cert in load_certs(root_certificates).context("loading root certs")? {
            root_store.add(cert).context("adding root cert")?;
        }
        rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
            .build()
            .context("building client cert verifier")?
    };
    debug!("loading server certificate");
    let certs = load_certs(certificate).context("loading server cert")?;
    debug!("loading server private key");
    let private_key =
        load_private_key(askpass, private_key).context("loading server private key")?;
    debug!("creating tls acceptor");
    let mut config = rustls::ServerConfig::builder()
        .with_client_cert_verifier(client_auth)
        .with_single_cert(certs, private_key)
        .context("building rustls server config")?;
    config.session_storage = rustls::server::ServerSessionMemoryCache::new(1024);
    Ok(tokio_rustls::TlsAcceptor::from(Arc::new(config)))
}
/// Find the most specific entry of `m` whose key is a prefix of
/// `identity` (a key equal to `identity` counts as a prefix).
///
/// When several keys are prefixes of `identity` the longest one wins, so
/// with keys `com.mydomain.` and `com.mydomain.foo.` the identity
/// `com.mydomain.foo.bar.` selects the latter and `com.mydomain.bar.`
/// selects the former. Returns `None` when no key is a prefix.
pub(crate) fn get_match<'a: 'b, 'b, U>(
    m: &'a BTreeMap<String, U>,
    identity: &'b str,
) -> Option<&'a U> {
    // All prefixes of one string are prefixes of each other, and a longer
    // prefix sorts after a shorter one. Walking the sorted map backwards
    // therefore meets the longest matching key first.
    m.iter()
        .rev()
        .find(|(k, _)| identity.starts_with(k.as_str()))
        .map(|(_, v)| v)
}
/// The mutable part of a `Cached`, protected by the mutex in `CachedInner`.
struct CachedInnerLocked<T> {
/// Scratch buffer in which `Cached::load` builds the reversed identity.
tmp: String,
/// Values built so far, keyed by the reversed identity they were loaded for.
cached: BTreeMap<String, T>,
}
/// Shared state behind a `Cached`: the TLS configuration plus the locked cache.
struct CachedInner<T> {
/// The TLS configuration that identities are looked up in.
tls: Tls,
/// The cache and scratch buffer, behind a mutex.
t: Mutex<CachedInnerLocked<T>>,
}
/// A handle to a cache of `T` values built from the identities of a `Tls`
/// configuration. Clones share the same underlying state through the `Arc`.
#[derive(Clone)]
struct Cached<T>(Arc<CachedInner<T>>);
/// Formats as the fixed string `CachedTls`; none of the contents are printed.
impl<T> fmt::Debug for Cached<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("CachedTls")
    }
}
impl<T: Clone + 'static> Cached<T> {
    /// Create an empty cache over the TLS configuration `tls`.
    fn new(tls: Tls) -> Self {
        let locked = CachedInnerLocked {
            tmp: String::with_capacity(256),
            cached: BTreeMap::new(),
        };
        Self(Arc::new(CachedInner { tls, t: Mutex::new(locked) }))
    }

    /// The default identity of the underlying TLS configuration.
    fn default_identity(&self) -> &TlsIdentity {
        let tls = &self.0.tls;
        tls.default_identity()
    }

    /// Look up the identity stored under exactly `id` in the configuration.
    fn get_identity(&self, id: &str) -> Option<&TlsIdentity> {
        let tls = &self.0.tls;
        tls.identities.get(id)
    }

    /// Return the value for `identity`, building it with `f` if it is not
    /// cached yet.
    ///
    /// `identity` is reversed with `Tls::reverse_domain_name` and matched
    /// with `get_match`, first against the cache and then against the
    /// configured identities. On a cache miss `f` is called with the
    /// configured askpass program and the matched identity's `trusted`,
    /// `certificate` and `private_key`, and its result is cached under
    /// the reversed identity.
    ///
    /// The lock is not held while `f` runs.
    ///
    /// # Errors
    ///
    /// Fails if no configured identity matches, or if `f` fails.
    fn load(
        &self,
        identity: &str,
        f: fn(Option<&str>, &str, &str, &str) -> Result<T>,
    ) -> Result<T> {
        let shared = &*self.0;
        let rev_identity = {
            let mut state = shared.t.lock();
            state.tmp.clear();
            state.tmp.push_str(identity);
            Tls::reverse_domain_name(&mut state.tmp);
            if let Some(hit) = get_match(&state.cached, &state.tmp) {
                return Ok(hit.clone());
            }
            // take the buffer out so it can be used after the lock is released
            mem::take(&mut state.tmp)
        };
        let Some(TlsIdentity { name: _, trusted, certificate, private_key }) =
            get_match(&shared.tls.identities, &rev_identity)
        else {
            // hand the scratch buffer back for the next call
            shared.t.lock().tmp = rev_identity;
            bail!("no plausible identity matches {}", identity)
        };
        let askpass = shared.tls.askpass.as_ref().map(|s| s.as_str());
        let built = f(askpass, trusted, certificate, private_key)?;
        shared.t.lock().cached.insert(rev_identity, built.clone());
        Ok(built)
    }
}
/// A cache of TLS connectors, one per identity that has been loaded.
#[derive(Debug, Clone)]
pub(crate) struct CachedConnector(Cached<tokio_rustls::TlsConnector>);

impl CachedConnector {
    /// Create an empty connector cache over the TLS configuration `cfg`.
    pub(crate) fn new(cfg: Tls) -> Self {
        let cache = Cached::new(cfg);
        Self(cache)
    }

    /// Return the connector for `identity`, building it with
    /// `create_tls_connector` if it is not cached yet.
    pub(crate) fn load(&self, identity: &str) -> Result<tokio_rustls::TlsConnector> {
        let Self(cache) = self;
        cache.load(identity, create_tls_connector)
    }

    /// The default identity of the underlying TLS configuration.
    pub(crate) fn default_identity(&self) -> &TlsIdentity {
        let Self(cache) = self;
        cache.default_identity()
    }

    /// Look up the identity stored under exactly `name` in the configuration.
    pub(crate) fn get_identity(&self, name: &str) -> Option<&TlsIdentity> {
        let Self(cache) = self;
        cache.get_identity(name)
    }
}
/// A cache of TLS acceptors, one per identity that has been loaded.
#[derive(Debug, Clone)]
pub(crate) struct CachedAcceptor(Cached<tokio_rustls::TlsAcceptor>);

impl CachedAcceptor {
    /// Create an empty acceptor cache over the TLS configuration `tls`.
    pub(crate) fn new(tls: Tls) -> Self {
        let cache = Cached::new(tls);
        Self(cache)
    }

    /// Return the acceptor for `identity`, building it with
    /// `create_tls_acceptor` if it is not cached yet. When `identity` is
    /// `None` the name of the default identity is used.
    pub(crate) fn load(
        &self,
        identity: Option<&str>,
    ) -> Result<tokio_rustls::TlsAcceptor> {
        let Self(cache) = self;
        let identity: &str = match identity {
            Some(name) => name,
            None => &cache.default_identity().name,
        };
        cache.load(identity, create_tls_acceptor)
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Identities below `com.mydomain.` that have no entry of their own
    /// resolve to the `com.mydomain.` entry.
    #[test]
    fn test_get_match() {
        let m: BTreeMap<String, i32> = [("com.mydomain.", 1), ("com.mydomain.foo.", 2)]
            .into_iter()
            .map(|(k, v)| (k.to_string(), v))
            .collect();
        for identity in ["com.mydomain.bar.", "com.mydomain.qux."] {
            assert_eq!(get_match(&m, identity).copied(), Some(1));
        }
    }
}