#[cfg(feature = "acme")]
use acme_lib::{create_rsa_key, Directory, DirectoryUrl, Error, persist::MemoryPersist};
use chrono::{DateTime, Utc};
use x509_certificate::X509Certificate;
use std::collections::HashMap;
use std::io::Read;
use std::sync::Mutex;
use std::time::{Duration, Instant};
use std::{
fs::File,
io::{self, BufReader},
sync::Arc,
};
use lazy_static::lazy_static;
use rustls::{
pki_types::{CertificateDer, PrivateKeyDer},
ServerConnection,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::{Accept, TlsAcceptor};
lazy_static! {
    /// Per-domain time of the most recent ACME certificate request; consulted
    /// before starting a new request so the same domain is not re-requested
    /// within the applicable delay.
    static ref CACHE_REQUEST: Mutex<HashMap<String, Instant>> = Mutex::new(HashMap::default());
    /// Throttle interval used while a domain has no certificate loaded yet
    /// (also the time budget for validating a new ACME order).
    static ref REQUEST_NEW_DELAY: Duration = Duration::new(30, 0);
    /// Throttle interval used when a domain already has a certificate and is
    /// being renewed (also the time budget for validating the renewal order).
    static ref REQUEST_UPDATE_DELAY: Duration = Duration::new(300, 0);
}
/// TLS acceptor wrapper that serves either a fixed certificate loaded from
/// explicit paths (`new_cert`) or a per-domain certificate stored under
/// `.well-known/` and obtained through ACME (`new`).
#[derive(Clone)]
pub struct WrapTlsAccepter {
    /// Domain whose certificate is read from `.well-known/{domain}.pem` and
    /// `.well-known/{domain}.key`; `None` for acceptors built from explicit paths.
    pub domain: Option<String>,
    /// `true` when the certificate is requested/renewed through ACME.
    pub is_acme: bool,
    /// Acceptor built from the currently loaded certificate; `None` until a
    /// certificate has been loaded successfully.
    pub accepter: Option<TlsAcceptor>,
    /// `notAfter` time of the currently loaded certificate, if any.
    pub expired: Option<DateTime<Utc>>,
    /// Creation time or time of the last reload attempt; used to rate-limit reloads.
    pub last: Instant,
}
impl WrapTlsAccepter {
fn load_certs(path: &Option<String>) -> io::Result<(DateTime<Utc>, Vec<CertificateDer<'static>>)> {
if let Some(path) = path {
match File::open(&path) {
Ok(mut file) => {
let mut content = String::new();
file.read_to_string(&mut content)?;
let mut reader = BufReader::new(content.as_bytes());
let certs = rustls_pemfile::certs(&mut reader);
let cert = X509Certificate::from_pem(content.as_bytes()).map_err(|_| io::Error::new(io::ErrorKind::Other, "cert error"))?;
Ok((cert.validity_not_after(), certs.into_iter().collect::<Result<Vec<_>, _>>()?))
}
Err(e) => {
log::warn!("加载公钥{}出错,错误内容:{:?}", path, e);
return Err(e);
}
}
} else {
Err(io::Error::new(io::ErrorKind::Other, "unknow certs"))
}
}
fn load_keys(path: &Option<String>) -> io::Result<PrivateKeyDer<'static>> {
if let Some(path) = path {
match File::open(&path) {
Ok(mut file) => {
let mut content = String::new();
file.read_to_string(&mut content)?;
{
let mut reader = BufReader::new(content.as_bytes());
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut reader)
.collect::<Result<Vec<_>, _>>()?;
if keys.len() == 1 {
return Ok(PrivateKeyDer::from(keys.remove(0)));
}
}
{
let mut reader = BufReader::new(content.as_bytes());
let mut keys = rustls_pemfile::rsa_private_keys(&mut reader)
.collect::<Result<Vec<_>, _>>()?;
if keys.len() == 1 {
return Ok(PrivateKeyDer::from(keys.remove(0)));
}
}
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("No pkcs8 or rsa private key found"),
));
}
Err(e) => {
log::warn!("加载私钥{}出错,错误内容:{:?}", path, e);
return Err(e);
}
}
} else {
return Err(io::Error::new(io::ErrorKind::Other, "unknow keys"));
};
}
pub fn new(domain: String) -> WrapTlsAccepter {
let mut wrap = WrapTlsAccepter {
last: Instant::now(),
domain: Some(domain),
accepter: None,
expired: None,
is_acme: true,
};
wrap.try_load_cert();
wrap
}
pub fn try_load_cert(&mut self) -> bool {
match Self::load_ssl(&self.get_cert_path(), &self.get_key_path()) {
Ok((expired, accepter)) => {
self.accepter = Some(accepter);
self.expired = Some(expired);
true
}
Err(e) => {
println!("load ssl error ={:?}", e);
false
}
}
}
pub fn get_cert_path(&self) -> Option<String> {
if let Some(domain) = &self.domain {
Some(format!(".well-known/{}.pem", domain))
} else {
None
}
}
pub fn get_key_path(&self) -> Option<String> {
if let Some(domain) = &self.domain {
Some(format!(".well-known/{}.key", domain))
} else {
None
}
}
pub fn update_last(&mut self) {
if self.last.elapsed() > Duration::from_secs(5) {
self.try_load_cert();
self.last = Instant::now();
}
}
pub fn is_wait_acme(&self) -> bool {
self.accepter.is_none()
}
pub fn load_ssl(cert: &Option<String>, key: &Option<String>) -> io::Result<(DateTime<Utc>, TlsAcceptor)> {
let (expired, one_cert) = Self::load_certs(&cert)?;
let one_key = Self::load_keys(&key)?;
let config = rustls::ServerConfig::builder();
let mut config = config
.with_no_client_auth()
.with_single_cert(one_cert, one_key)
.map_err(|e| {
log::warn!("添加证书时失败:{:?}", e);
io::Error::new(io::ErrorKind::Other, "key error")
})?;
config.alpn_protocols.push("h2".as_bytes().to_vec());
config.alpn_protocols.push("http/1.1".as_bytes().to_vec());
Ok((expired, TlsAcceptor::from(Arc::new(config))))
}
pub fn new_cert(cert: &Option<String>, key: &Option<String>) -> io::Result<WrapTlsAccepter> {
let (expired, config) = Self::load_ssl(cert, key)?;
Ok(WrapTlsAccepter {
last: Instant::now(),
domain: None,
accepter: Some(config),
expired: Some(expired),
is_acme: false,
})
}
#[inline]
pub fn accept<IO>(&self, stream: IO) -> io::Result<Accept<IO>>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
self.accept_with(stream, |_| ())
}
pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> io::Result<Accept<IO>>
where
IO: AsyncRead + AsyncWrite + Unpin,
F: FnOnce(&mut ServerConnection),
{
if let Some(a) = &self.accepter {
if self.is_acme && self.is_tls_will_expired() {
let _ = self.check_and_request_cert();
}
Ok(a.accept_with(stream, f))
} else {
self.check_and_request_cert()
.map_err(|_| io::Error::new(io::ErrorKind::Other, "load https error"))?;
Err(io::Error::new(io::ErrorKind::Other, "try next https error"))
}
}
fn is_tls_will_expired(&self) -> bool {
if let Some(expire) = &self.expired {
let now = Utc::now();
if now.timestamp() > expire.timestamp() - 86400 {
return true;
}
}
false
}
#[cfg(feature = "acme")]
fn get_delay_time(&self) -> Duration {
if self.accepter.is_some() {
*REQUEST_UPDATE_DELAY
} else {
*REQUEST_NEW_DELAY
}
}
#[cfg(not (feature = "acme"))]
fn check_and_request_cert(&self) -> Result<(), io::Error> {
Ok(())
}
#[cfg(feature = "acme")]
fn check_and_request_cert(&self) -> Result<(), Error> {
if self.domain.is_none() {
return Err(io::Error::new(io::ErrorKind::Other, "未知域名").into());
}
{
let mut map = CACHE_REQUEST
.lock()
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Fail get Lock"))?;
if let Some(last) = map.get(self.domain.as_ref().unwrap()) {
if last.elapsed() < self.get_delay_time() {
return Err(io::Error::new(io::ErrorKind::Other, "等待上次请求结束").into());
}
}
map.insert(self.domain.clone().unwrap(), Instant::now());
};
let obj = self.clone();
std::thread::spawn(move || {
let _ = obj.request_cert();
});
Ok(())
}
#[cfg(feature = "acme")]
fn request_cert(&self) -> Result<(), Error> {
let url = DirectoryUrl::LetsEncrypt;
let path = std::path::Path::new(".well-known/acme-challenge");
if !path.exists() {
let _ = std::fs::create_dir_all(path);
}
let persist = MemoryPersist::new();
let dir = Directory::from_url(persist, url)?;
let acc = dir.account("wmproxy@wmproxy.net")?;
let mut ord_new = acc.new_order(&self.domain.clone().unwrap_or_default(), &[])?;
let start = Instant::now();
let ord_csr = loop {
if let Some(ord_csr) = ord_new.confirm_validations() {
break ord_csr;
}
if start.elapsed() > self.get_delay_time() {
println!("获取证书超时");
return Ok(());
}
let auths = ord_new.authorizations()?;
let chall = auths[0].http_challenge();
let token = chall.http_token();
let path = format!(".well-known/acme-challenge/{}", token);
let proof = chall.http_proof();
crate::Helper::write_to_file(&path, proof.as_bytes())?;
chall.validate(5000)?;
ord_new.refresh()?;
};
let pkey_pri = create_rsa_key(2048);
let ord_cert = ord_csr.finalize_pkey(pkey_pri, 5000)?;
let cert = ord_cert.download_and_save_cert()?;
crate::Helper::write_to_file(
&self.get_cert_path().unwrap(),
cert.certificate().as_bytes(),
)?;
crate::Helper::write_to_file(&self.get_key_path().unwrap(), cert.private_key().as_bytes())?;
Ok(())
}
}