use std::{
fs::File,
io::{self, BufReader},
sync::Arc,
};
use async_tls::TlsAcceptor;
use async_tungstenite::accept_async;
use rustls::{Certificate as RustlsCertificate, PrivateKey, ServerConfig};
use rustls_pemfile::{certs, read_one, Item};
use smol::net::TcpListener;
use tokio::{sync::Mutex, task::spawn_local};
use tungstenite::Error;
use crate::{
protocol::{
errors::{Error as ShdpError, ErrorKind},
prelude::common::utils::{Certificate, Listener, DEVICES},
},
server::ws::handle_connection,
};
/// Runs a TLS-secured WebSocket server bound to `127.0.0.1:<port>`.
///
/// The TLS acceptor is built from `cert`. The TCP listener is registered in
/// [`DEVICES`] under the key `("127.0.0.1", port)`, and connections are then
/// accepted until `accept` fails or that registry entry is no longer present
/// (checked before each accept). Each connection's TLS and WebSocket
/// handshakes run on their own task spawned with `spawn_local` (so this future
/// must be driven inside a `tokio::task::LocalSet`) before the resulting stream
/// is handed to `handle_connection`. A slow client therefore cannot stall the
/// accept loop.
///
/// # Errors
///
/// * code `500`, [`ErrorKind::InternalServerError`] — the certificate or key
///   could not be loaded.
/// * code `0b1111`, [`ErrorKind::Conflict`] — the port could not be bound.
///
/// # Panics
///
/// Panics if locking [`DEVICES`] returns an error (e.g. a poisoned mutex).
pub async fn listen(port: String, cert: Certificate) -> Result<(), ShdpError> {
    const HOST: &str = "127.0.0.1";

    let acceptor = load_acceptor(cert).map_err(|e| ShdpError {
        code: 500,
        message: format!("Error loading acceptor: {}", e),
        kind: ErrorKind::InternalServerError,
    })?;

    let listener = TcpListener::bind(format!("{}:{}", HOST, port))
        .await
        .map_err(|e| ShdpError {
            code: 0b1111,
            message: format!("Error binding to port: {}", e),
            kind: ErrorKind::Conflict,
        })?;

    let key = (HOST.to_string(), port.clone());
    DEVICES
        .lock()
        .unwrap()
        .insert(key.clone(), Listener::StdServer(listener));

    println!("[SHDP:WS] Listening on port {}", port);

    loop {
        // Clone the listener handle out of the registry. The lock guard is a
        // temporary of this `let` statement, so it is released before the
        // `.await` below; holding a blocking mutex guard across an await would
        // block every other `DEVICES` user while we wait for clients.
        let listener = match DEVICES.lock().unwrap().get(&key) {
            Some(device) => device.get_std_server().clone(),
            // The entry was removed: stop serving instead of panicking.
            None => break,
        };

        let stream = match listener.accept().await {
            Ok((stream, _)) => stream,
            Err(e) => {
                println!("[SHDP:WS] Error accepting TCP connection: {}", e);
                break;
            }
        };

        if let Ok(addr) = stream.peer_addr() {
            println!("[SHDP:WS] New connection from {}", addr);
        }

        let acceptor = acceptor.clone();
        // Handshakes happen on the per-connection task so the accept loop is
        // immediately free to take the next client.
        spawn_local(async move {
            let tls_stream = match acceptor.accept(stream).await {
                Ok(tls_stream) => tls_stream,
                Err(e) => {
                    println!("[SHDP:WS] Error accepting TLS connection: {}", e);
                    return;
                }
            };
            match accept_async(tls_stream).await {
                Ok(ws_stream) => {
                    handle_connection(Arc::new(Mutex::new(ws_stream))).await;
                }
                Err(e) => {
                    println!("[SHDP:WS] Error accepting WebSocket connection: {}", e);
                }
            }
        });
    }

    Ok(())
}
/// Reads every PEM-encoded certificate from the file at `path`, in file order.
///
/// # Errors
///
/// Returns an [`io::Error`] if:
/// * the file cannot be opened;
/// * its PEM content cannot be parsed (`InvalidInput`);
/// * it contains no certificates (`NotFound`).
///
/// The empty-chain check means a wrong or empty file is reported here,
/// rather than producing a TLS configuration with no certificate.
fn load_certs(path: &str) -> io::Result<Vec<RustlsCertificate>> {
    let file = File::open(path)?;
    let mut reader = BufReader::new(file);
    let certs: Vec<RustlsCertificate> = certs(&mut reader)
        .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "could not read certs"))?
        .into_iter()
        .map(RustlsCertificate)
        .collect();
    if certs.is_empty() {
        return Err(io::Error::new(
            io::ErrorKind::NotFound,
            "no certificates found in file",
        ));
    }
    Ok(certs)
}
/// Reads the first supported private key from the PEM file at `path`.
///
/// PEM sections that are not supported keys (for example the certificates of
/// a combined cert+key bundle) are skipped, so the key does not need to be the
/// first entry in the file. PKCS#1 (`Item::RSAKey`) and PKCS#8
/// (`Item::PKCS8Key`) keys are accepted.
///
/// # Errors
///
/// Returns an [`io::Error`] if:
/// * the file cannot be opened;
/// * its PEM content cannot be parsed (`InvalidInput`, "could not read private key");
/// * the file contains PEM items but none is a supported key (`InvalidInput`,
///   "unexpected key format");
/// * the file contains no PEM items at all (`NotFound`, "no keys found in file").
fn load_private_key(path: &str) -> io::Result<PrivateKey> {
    let file = File::open(path)?;
    let mut reader = BufReader::new(file);
    // NOTE(review): SEC1 EC keys (`Item::ECKey`) are still rejected; add that
    // variant to the key arm if the rustls_pemfile version in use provides it.
    let mut saw_other_item = false;
    loop {
        let item = read_one(&mut reader).map_err(|_| {
            io::Error::new(io::ErrorKind::InvalidInput, "could not read private key")
        })?;
        match item {
            Some(Item::RSAKey(data) | Item::PKCS8Key(data)) => return Ok(PrivateKey(data)),
            // Not a usable key (e.g. a certificate): remember it and keep scanning.
            Some(_) => saw_other_item = true,
            None if saw_other_item => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "unexpected key format",
                ))
            }
            None => {
                return Err(io::Error::new(
                    io::ErrorKind::NotFound,
                    "no keys found in file",
                ))
            }
        }
    }
}
fn load_acceptor(cert: Certificate) -> Result<TlsAcceptor, Error> {
let certs = load_certs(&cert.cert_path)?;
let key = load_private_key(&cert.key_path)?;
let config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, key)
.map_err(|e| Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?;
Ok(TlsAcceptor::from(Arc::new(config)))
}