use boring::asn1::Asn1Time;
use boring::hash::MessageDigest;
use boring::pkey::PKey;
use boring::rsa::Rsa;
use boring::ssl::{SslAcceptor, SslMethod};
use boring::x509::X509;
use bytes::Bytes;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
/// Generates a throwaway self-signed certificate and its private key.
///
/// The certificate:
/// - uses a freshly generated 2048-bit RSA key;
/// - has subject and issuer both set to `CN=127.0.0.1`;
/// - is valid from now for 365 days;
/// - is signed with SHA-256 by its own key.
///
/// NOTE(review): no serial number and no subjectAltName extension are set.
/// A client that verifies the peer certificate against a host name will
/// likely reject it. Confirm that callers disable verification or pin this cert.
///
/// # Panics
///
/// Panics if any BoringSSL key/certificate operation fails.
pub fn generate_self_signed_cert() -> (X509, PKey<boring::pkey::Private>) {
    let rsa = Rsa::generate(2048).unwrap();
    let pkey = PKey::from_rsa(rsa).unwrap();

    // Self-signed: the same name is used as both subject and issuer.
    let mut name = boring::x509::X509Name::builder().unwrap();
    name.append_entry_by_text("CN", "127.0.0.1").unwrap();
    let name = name.build();

    let mut builder = X509::builder().unwrap();
    // The X.509 version field is zero-based: 2 means X.509 v3.
    builder.set_version(2).unwrap();
    builder.set_subject_name(&name).unwrap();
    builder.set_issuer_name(&name).unwrap();
    builder.set_pubkey(&pkey).unwrap();

    let not_before = Asn1Time::days_from_now(0).unwrap();
    let not_after = Asn1Time::days_from_now(365).unwrap();
    builder.set_not_before(&not_before).unwrap();
    builder.set_not_after(&not_after).unwrap();

    builder.sign(&pkey, MessageDigest::sha256()).unwrap();
    (builder.build(), pkey)
}
/// A mock HTTP/2-over-TLS server listening on the loopback interface.
///
/// Create one with [`TlsMockServer::start`]. Then serve connections with
/// `handle_next_h2` (one stream) or `handle_next_h2_multi` (several streams).
pub struct TlsMockServer {
    /// Address of the IPv4 listener: `127.0.0.1` with an OS-assigned port.
    pub addr: SocketAddr,
    /// TLS acceptor holding the self-signed certificate and its key.
    acceptor: Arc<SslAcceptor>,
    /// Listener bound to `127.0.0.1`.
    listener_v4: TcpListener,
    /// Listener bound to `[::1]` on the same port as `listener_v4`.
    /// `None` when that bind failed.
    listener_v6: Option<TcpListener>,
}
#[allow(dead_code)]
impl TlsMockServer {
pub async fn start() -> Self {
let listener_v4 = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener_v4.local_addr().unwrap();
let port = addr.port();
let listener_v6 = TcpListener::bind(format!("[::1]:{}", port)).await.ok();
let (cert, pkey) = generate_self_signed_cert();
let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
acceptor.set_private_key(&pkey).unwrap();
acceptor.set_certificate(&cert).unwrap();
acceptor.set_alpn_select_callback(|_, _| Ok(b"h2"));
let acceptor = Arc::new(acceptor.build());
Self {
addr,
acceptor,
listener_v4,
listener_v6,
}
}
pub async fn handle_next_h2<F, Fut>(&self, handler: F)
where
F: FnOnce(http::Request<http2::RecvStream>, http2::server::SendResponse<Bytes>) -> Fut
+ Send
+ 'static,
Fut: std::future::Future<Output = ()> + Send + 'static,
{
let socket = match &self.listener_v6 {
Some(lv6) => {
tokio::select! {
res = self.listener_v4.accept() => res.unwrap().0,
res = lv6.accept() => res.unwrap().0,
}
}
None => self.listener_v4.accept().await.unwrap().0,
};
let ssl_stream = tokio_boring::accept(&self.acceptor, socket).await.unwrap();
let mut h2_conn = http2::server::handshake(ssl_stream).await.unwrap();
if let Some(result) = h2_conn.accept().await {
let (req, resp) = result.unwrap();
struct PollClose(
http2::server::Connection<
tokio_boring::SslStream<tokio::net::TcpStream>,
bytes::Bytes,
>,
);
impl std::future::Future for PollClose {
type Output = Result<(), http2::Error>;
fn poll(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
self.0.poll_closed(cx)
}
}
tokio::spawn(async move {
let _ = PollClose(h2_conn).await;
});
handler(req, resp).await;
}
}
pub async fn handle_next_h2_multi<F, Fut>(&self, num_streams: usize, mut handler: F)
where
F: FnMut(http::Request<http2::RecvStream>, http2::server::SendResponse<Bytes>) -> Fut
+ Send
+ 'static,
Fut: std::future::Future<Output = ()> + Send + 'static,
{
let socket = match &self.listener_v6 {
Some(lv6) => {
tokio::select! {
res = self.listener_v4.accept() => res.unwrap().0,
res = lv6.accept() => res.unwrap().0,
}
}
None => self.listener_v4.accept().await.unwrap().0,
};
let ssl_stream = tokio_boring::accept(&self.acceptor, socket).await.unwrap();
let mut h2_conn = http2::server::handshake(ssl_stream).await.unwrap();
for _ in 0..num_streams {
if let Some(result) = h2_conn.accept().await {
let (req, resp) = result.unwrap();
handler(req, resp).await;
}
}
struct PollClose(
http2::server::Connection<tokio_boring::SslStream<tokio::net::TcpStream>, bytes::Bytes>,
);
impl std::future::Future for PollClose {
type Output = Result<(), http2::Error>;
fn poll(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
self.0.poll_closed(cx)
}
}
tokio::spawn(async move {
let _ = PollClose(h2_conn).await;
});
}
}