#![cfg(all(target_os = "macos", target_arch = "aarch64"))]
use std::fmt;
use std::fs::File;
use std::io::{BufReader, ErrorKind, Read, Write};
use std::net::{TcpListener, TcpStream};
use std::sync::Arc;
use std::time::Duration;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::{ServerConfig, ServerConnection};
use crate::devices::virtio::vsock::device::Vsock;
/// Configuration for the host-side TLS terminator launched by `start`.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Address the TLS listener binds to (handed to `TcpListener::bind`, e.g. `"0.0.0.0:443"`).
    pub listen_addr: String,
    /// Guest port used to look up the forwarding host port in the vsock muxer;
    /// `None` falls back to the muxer's first host port.
    pub vm_port: Option<u32>,
    /// Path to the PEM file holding the certificate chain.
    pub cert_path: String,
    /// Path to the PEM file holding the private key.
    pub key_path: String,
}
/// Reasons the TLS terminator can fail to start.
#[derive(Debug)]
pub enum StartError {
    /// The listen address could not be bound.
    Bind {
        addr: String,
        source: std::io::Error,
    },
    /// Loading the certificate/key or building the rustls config failed;
    /// carries a human-readable description.
    Config(String),
    /// The bound listener's local address could not be queried.
    LocalAddr {
        addr: String,
        source: std::io::Error,
    },
    /// A named background thread could not be spawned.
    ThreadSpawn {
        name: String,
        source: std::io::Error,
    },
}

impl fmt::Display for StartError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Each message embeds the underlying error text directly.
        match self {
            Self::Bind { addr, source } => write!(f, "TLS: bind {addr}: {source}"),
            Self::Config(msg) => write!(f, "TLS: load cert/key: {msg}"),
            Self::LocalAddr { addr, source } => write!(f, "TLS: local_addr {addr}: {source}"),
            Self::ThreadSpawn { name, source } => write!(f, "spawn thread {name}: {source}"),
        }
    }
}

/// Marker impl so `StartError` works with `?`, `Box<dyn Error>` and friends.
impl std::error::Error for StartError {}
pub fn start(cfg: TlsConfig, vsock: Arc<Vsock>) -> Result<(), StartError> {
let server_config =
build_server_config(&cfg.cert_path, &cfg.key_path).map_err(StartError::Config)?;
let listener = TcpListener::bind(&cfg.listen_addr).map_err(|source| StartError::Bind {
addr: cfg.listen_addr.clone(),
source,
})?;
let local = listener
.local_addr()
.map_err(|source| StartError::LocalAddr {
addr: cfg.listen_addr.clone(),
source,
})?;
eprintln!(
" TLS terminator on {local} -> guest vm_port={:?}",
cfg.vm_port
);
let server_config = Arc::new(server_config);
let vm_port_opt = cfg.vm_port;
let name = "tls-acceptor".to_string();
std::thread::Builder::new()
.name(name.clone())
.spawn(move || {
for stream in listener.incoming() {
let stream = match stream {
Ok(s) => s,
Err(e) => {
eprintln!("[tls] accept err: {e}");
continue;
}
};
let _ = stream.set_nodelay(true);
let cfg = server_config.clone();
let vsock_c = vsock.clone();
std::thread::Builder::new()
.name("tls-conn".to_string())
.spawn(move || handle_conn(stream, cfg, vsock_c, vm_port_opt))
.ok();
}
})
.map_err(|source| StartError::ThreadSpawn { name, source })?;
Ok(())
}
/// Builds a rustls `ServerConfig` without client authentication from two PEM files.
///
/// `cert_path` supplies the certificate chain (must contain at least one
/// certificate); `key_path` supplies the first private key found in it. Before
/// building, it attempts to install the `ring` crypto provider as the process
/// default and ignores the result.
///
/// # Errors
/// Returns a human-readable message if a file cannot be opened or parsed, if
/// either file lacks the expected PEM item, or if rustls rejects the pair.
fn build_server_config(cert_path: &str, key_path: &str) -> Result<ServerConfig, String> {
    // The result is deliberately ignored; an already-installed default provider is fine.
    let _ = rustls::crypto::ring::default_provider().install_default();

    let mut cert_reader = File::open(cert_path)
        .map(BufReader::new)
        .map_err(|err| format!("open cert {cert_path}: {err}"))?;
    let mut chain: Vec<CertificateDer<'static>> = Vec::new();
    for item in rustls_pemfile::certs(&mut cert_reader) {
        chain.push(item.map_err(|err| format!("parse certs: {err}"))?);
    }
    if chain.is_empty() {
        return Err(format!("no certs in {cert_path}"));
    }

    let mut key_reader = File::open(key_path)
        .map(BufReader::new)
        .map_err(|err| format!("open key {key_path}: {err}"))?;
    let key = match rustls_pemfile::private_key(&mut key_reader) {
        Ok(Some(found)) => found,
        Ok(None) => return Err(format!("no private key in {key_path}")),
        Err(err) => return Err(format!("parse key: {err}")),
    };

    ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(chain, PrivateKeyDer::from(key))
        .map_err(|err| format!("ServerConfig: {err}"))
}
/// Serves one accepted TLS client.
///
/// Resolves the host port for `vm_port` (waiting briefly for it to appear),
/// connects to it on `127.0.0.1`, sets up a rustls server session, and relays
/// bytes between the two sockets via `pump` until both sides finish.
/// Setup failures are logged to stderr and end the connection. The error that
/// ends `pump` is printed only when `SUPERMACHINE_TLS_TRACE` is set.
fn handle_conn(
    tls_sock: TcpStream,
    cfg: Arc<ServerConfig>,
    vsock: Arc<Vsock>,
    vm_port: Option<u32>,
) {
    let host_port = match wait_for_host_port(&vsock, vm_port) {
        Some(p) => p,
        None => {
            eprintln!("[tls] no host port (vm_port={vm_port:?}) after wait");
            return;
        }
    };
    let plain_sock = match TcpStream::connect(("127.0.0.1", host_port)) {
        Ok(s) => s,
        Err(e) => {
            eprintln!("[tls] connect 127.0.0.1:{host_port}: {e}");
            return;
        }
    };
    let _ = plain_sock.set_nodelay(true);
    let mut conn = match ServerConnection::new(cfg) {
        Ok(c) => c,
        Err(e) => {
            eprintln!("[tls] ServerConnection: {e}");
            return;
        }
    };
    // `pump` polls both sockets in turn. A socket left in blocking mode would
    // park the thread on one side and stall the other direction, so treat
    // failure here as fatal for the connection.
    for sock in [&tls_sock, &plain_sock] {
        if let Err(e) = sock.set_nonblocking(true) {
            eprintln!("[tls] set_nonblocking: {e}");
            return;
        }
    }
    if let Err(e) = pump(&mut conn, tls_sock, plain_sock) {
        if std::env::var_os("SUPERMACHINE_TLS_TRACE").is_some() {
            eprintln!("[tls] pump end: {e}");
        }
    }
}
/// Relays bytes between a TLS client socket and a plaintext backend socket
/// until both directions have finished.
///
/// Both sockets must already be non-blocking. Each loop iteration polls every
/// stage once and sleeps 200µs only when nothing progressed.
///
/// * client → backend: `read_tls` → `process_new_packets` → `reader()` →
///   `to_plain` → `plain_sock`.
/// * backend → client: `plain_sock` → `to_tls` → `writer()` → `write_tls`.
///
/// Each direction keeps a pending buffer and is only refilled once that buffer
/// drains. This means a `WouldBlock` on the backend socket, or rustls accepting
/// only part of a write, applies backpressure instead of dropping bytes or
/// aborting the connection.
///
/// End-of-stream is propagated in both directions:
/// * When the client finishes (TCP EOF, `close_notify`, or unclean EOF) and all
///   of its data has reached the backend, the backend's write half is shut down.
/// * When the backend finishes, a `close_notify` is queued for the client.
///
/// The loop ends once both sides are finished and nothing is left to flush.
///
/// # Errors
/// Returns an error if reading/writing the TLS socket fails, or if rustls
/// rejects incoming records. In the latter case it first makes a best-effort
/// attempt to send the alert rustls queued.
fn pump(
    conn: &mut ServerConnection,
    mut tls_sock: TcpStream,
    mut plain_sock: TcpStream,
) -> std::io::Result<()> {
    let mut buf = [0u8; 16 * 1024];
    // Decrypted client bytes not yet accepted by the backend socket.
    let mut to_plain: Vec<u8> = Vec::new();
    // Backend bytes not yet accepted by rustls for encryption.
    let mut to_tls: Vec<u8> = Vec::new();
    // The client will send nothing more.
    let mut closed_tls = false;
    // The backend will send nothing more (EOF or read error).
    let mut closed_plain = false;
    // Nothing more will be written to the backend: either the client's EOF was
    // propagated with `shutdown(Write)`, or writing to the backend failed.
    let mut plain_write_done = false;

    loop {
        let mut did_work = false;

        // 1. Client ciphertext -> rustls.
        if !closed_tls && conn.wants_read() {
            match conn.read_tls(&mut tls_sock) {
                Ok(0) => {
                    closed_tls = true;
                    did_work = true;
                }
                Ok(_) => {
                    did_work = true;
                    if let Err(e) = conn.process_new_packets() {
                        // Best effort: deliver the alert rustls queued for the peer.
                        let _ = conn.write_tls(&mut tls_sock);
                        return Err(std::io::Error::new(ErrorKind::Other, format!("{e}")));
                    }
                }
                Err(e) if is_transient(&e) => {}
                Err(e) => return Err(e),
            }
        }

        // 2. Decrypted plaintext -> `to_plain`. Unread plaintext left inside rustls
        //    keeps `wants_read()` false, which throttles reads from the client.
        if to_plain.is_empty() {
            loop {
                match conn.reader().read(&mut buf) {
                    // Clean close: the client sent close_notify. `wants_read()` is
                    // now false, so this is the only place that observes it.
                    Ok(0) => {
                        closed_tls = true;
                        did_work = true;
                        break;
                    }
                    Ok(n) => {
                        did_work = true;
                        // If the backend can no longer receive, keep draining (so EOF
                        // is still observed) but discard the bytes.
                        if !plain_write_done {
                            to_plain.extend_from_slice(&buf[..n]);
                        }
                    }
                    Err(e) if e.kind() == ErrorKind::WouldBlock => break,
                    // Presumably UnexpectedEof: TCP closed without close_notify.
                    Err(_) => {
                        closed_tls = true;
                        break;
                    }
                }
            }
        }

        // 3. `to_plain` -> backend, keeping whatever the socket does not accept.
        if !to_plain.is_empty() {
            match plain_sock.write(&to_plain) {
                Ok(n) if n > 0 => {
                    did_work = true;
                    to_plain.drain(..n);
                }
                Err(e) if is_transient(&e) => {}
                // Ok(0) or a hard error: the backend can no longer receive.
                _ => {
                    plain_write_done = true;
                    to_plain.clear();
                    did_work = true;
                }
            }
        }

        // 4. Client finished and all of its bytes were delivered: half-close the
        //    backend so it sees EOF instead of waiting for more input.
        if closed_tls && to_plain.is_empty() && !plain_write_done {
            let _ = plain_sock.shutdown(std::net::Shutdown::Write);
            plain_write_done = true;
            did_work = true;
        }

        // 5. Backend -> `to_tls` (only once the previous batch was handed to rustls,
        //    so close_notify is never queued ahead of pending data).
        if to_tls.is_empty() && !closed_plain {
            match plain_sock.read(&mut buf) {
                Ok(0) => {
                    closed_plain = true;
                    conn.send_close_notify();
                    did_work = true;
                }
                Ok(n) => {
                    did_work = true;
                    to_tls.extend_from_slice(&buf[..n]);
                }
                Err(e) if is_transient(&e) => {}
                Err(_) => {
                    closed_plain = true;
                    conn.send_close_notify();
                    did_work = true;
                }
            }
        }

        // 6. `to_tls` -> rustls. `write` may accept fewer bytes (even zero) while
        //    rustls's outgoing TLS buffer is full; the rest is retried later.
        if !to_tls.is_empty() {
            let accepted = conn.writer().write(&to_tls)?;
            if accepted > 0 {
                did_work = true;
                to_tls.drain(..accepted);
            }
        }

        // 7. rustls ciphertext -> client.
        if conn.wants_write() {
            match conn.write_tls(&mut tls_sock) {
                Ok(n) => did_work |= n > 0,
                Err(e) if is_transient(&e) => {}
                Err(e) => return Err(e),
            }
        }

        if closed_tls
            && closed_plain
            && to_plain.is_empty()
            && to_tls.is_empty()
            && !conn.wants_write()
        {
            break;
        }
        if !did_work {
            std::thread::sleep(Duration::from_micros(200));
        }
    }
    let _ = tls_sock.shutdown(std::net::Shutdown::Both);
    let _ = plain_sock.shutdown(std::net::Shutdown::Both);
    Ok(())
}

/// True for I/O errors that only mean "retry later" on a non-blocking socket.
fn is_transient(e: &std::io::Error) -> bool {
    matches!(e.kind(), ErrorKind::WouldBlock | ErrorKind::Interrupted)
}
/// Resolves the host TCP port that forwards to the guest, polling the vsock
/// muxer until it is known.
///
/// With `Some(port)` the muxer is asked for the host port mapped to that guest
/// port; with `None` the muxer's first host port is used. The first lookup is
/// followed by up to 50 retries spaced 20ms apart (roughly one second in total)
/// before giving up with `None`.
fn wait_for_host_port(vsock: &Vsock, vm_port: Option<u32>) -> Option<u16> {
    const RETRIES: u32 = 50;
    const PAUSE: Duration = Duration::from_millis(20);

    let probe = || match vm_port {
        Some(guest_port) => vsock.muxer().host_port_for_vm_port(guest_port),
        None => vsock.muxer().first_host_port(),
    };

    for attempt in 0..=RETRIES {
        if let Some(host_port) = probe() {
            return Some(host_port);
        }
        // No pause after the final attempt.
        if attempt < RETRIES {
            std::thread::sleep(PAUSE);
        }
    }
    None
}