use firewall::builder::{AcceptDenyOverride, Firewall, TlsAccept};
use firewall::openssl::init;
use firewall::{Accept, ClientHello};
use openssl::asn1::Asn1Time;
use openssl::bn::{BigNum, MsbOption};
use openssl::error::ErrorStack;
use openssl::hash::MessageDigest;
use openssl::pkey::{PKey, Private};
use openssl::rsa::Rsa;
use openssl::ssl::{HandshakeError, SslAcceptor, SslMethod, SslVerifyMode};
use openssl::x509::extension::{ExtendedKeyUsage, KeyUsage, SubjectAlternativeName};
use openssl::x509::{X509NameBuilder, X509};
use std::fmt::{Debug, Display, Formatter};
use std::io::{copy, sink, Error as IOError, Write};
use std::net::{TcpListener, TcpStream};
use std::thread;
use std::time::Duration;
/// Failures this example can report; each variant wraps the originating
/// library error unchanged so no diagnostic detail is lost.
#[derive(Debug)]
enum Error {
    /// An OpenSSL operation returned an error stack.
    OpenSSL(ErrorStack),
    /// A TLS handshake over a `TcpStream` did not complete.
    Handshake(HandshakeError<TcpStream>),
    /// A standard-library I/O operation failed.
    IO(IOError),
}
impl Display for Error {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Error::OpenSSL(err) => write!(f, "{:?}", err),
Error::Handshake(err) => write!(f, "{:?}", err),
Error::IO(err) => write!(f, "{:?}", err),
}
}
}
impl From<ErrorStack> for Error {
fn from(value: ErrorStack) -> Self {
Self::OpenSSL(value)
}
}
impl From<HandshakeError<TcpStream>> for Error {
fn from(value: HandshakeError<TcpStream>) -> Self {
Self::Handshake(value)
}
}
impl From<IOError> for Error {
fn from(value: IOError) -> Self {
Self::IO(value)
}
}
fn generate_self_signed_certificates() -> Result<(PKey<Private>, X509), Error> {
let pkey: PKey<Private> = Rsa::generate(2048)?.try_into()?;
let mut builder = X509::builder()?;
builder.set_version(2)?;
builder.set_serial_number(BigNum::from_u32(1)?.to_asn1_integer()?.as_ref())?;
let mut name = X509NameBuilder::new()?;
name.append_entry_by_text("CN", "localhost")?;
let name = name.build();
builder.set_issuer_name(&name)?;
builder.set_subject_name(&name)?;
builder.set_pubkey(&pkey)?;
builder.set_not_before(Asn1Time::days_from_now(0)?.as_ref())?;
builder.set_not_after(Asn1Time::days_from_now(365)?.as_ref())?;
let mut san1 = SubjectAlternativeName::new();
san1.dns("localhost");
let extension1 = san1.build(&builder.x509v3_context(None, None))?;
builder.append_extension(extension1)?;
builder.append_extension(KeyUsage::new().digital_signature().build()?)?;
builder.append_extension(ExtendedKeyUsage::new().server_auth().build()?)?;
builder.sign(&pkey, MessageDigest::sha512())?;
let cert = builder.build();
Ok((pkey, cert))
}
/// Firewall hook (installed via `with_exception` in `main`) whose verdict
/// depends only on the ALPN protocols reported by the ClientHello.
struct RequireHttp1 {}

impl TlsAccept for RequireHttp1 {
    /// Returns [`AcceptDenyOverride::Accept`] when `has_alpn` reports
    /// `http/1.1` or `http/1.0`, and [`AcceptDenyOverride::Deny`] otherwise.
    fn accept(&self, client_hello: impl ClientHello) -> AcceptDenyOverride {
        let offers_http1 =
            client_hello.has_alpn(b"http/1.1") || client_hello.has_alpn(b"http/1.0");
        if !offers_http1 {
            return AcceptDenyOverride::Deny;
        }
        AcceptDenyOverride::Accept
    }
}
/// Builds the connection firewall and runs the HTTPS server until it fails.
///
/// The firewall is configured with `require_sni`, the server name
/// `localhost`, the loopback addresses `127.0.0.1` and `::1`, and
/// [`RequireHttp1`] as its exception hook.
///
/// # Panics
///
/// Panics if `try_allow_ip` rejects one of the hard-coded loopback literals.
/// That would be a programming error, not a runtime condition.
///
/// # Errors
///
/// Propagates any [`Error`] returned by [`start_server`].
fn main() -> Result<(), Error> {
    let firewall = Firewall::default()
        .require_sni()
        .allow_server_name("localhost")
        .try_allow_ip("127.0.0.1")
        .expect("hard-coded IPv4 loopback literal must be accepted")
        .try_allow_ip("::1")
        .expect("hard-coded IPv6 loopback literal must be accepted")
        .with_exception(RequireHttp1 {});
    start_server(firewall)
}
fn start_server(firewall: Firewall<RequireHttp1>) -> Result<(), Error> {
let mut acceptor = SslAcceptor::mozilla_modern_v5(SslMethod::tls())?;
let (pkey, cert) = generate_self_signed_certificates()?;
acceptor.set_private_key(&pkey)?;
acceptor.set_certificate(&cert)?;
acceptor.set_verify(SslVerifyMode::NONE);
let index = init(&mut acceptor, vec![b"http/1.1"])?;
let acceptor = acceptor.build();
let listener = TcpListener::bind("127.0.0.1:443")?;
println!("https://localhost");
loop {
let (tcp_stream, remote_addr) = listener.accept()?;
if let Ok(mut tls_stream) = acceptor.accept(tcp_stream) {
let ssl = tls_stream.ssl();
if firewall.accept(remote_addr.ip(), Some((ssl, index))) {
thread::spawn(move || {
let _ = tls_stream
.write(
b"\
HTTP/1.1 200 OK\r\n\
Cache-Control: no-cache\r\n\
Connection: close\r\n\
Content-Type: text/plain;charset=UTF-8\r\n\
Content-Length: 2\r\n\
\r\n\
OK\
",
)
.unwrap();
let _ = copy(&mut tls_stream, &mut sink()).unwrap();
let _ = tls_stream.shutdown();
});
}
}
}
}