use std::fs::File;
use std::io::BufReader;
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use futures_util::{FutureExt, SinkExt, StreamExt};
use log::{debug, error, info, warn};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;
use tokio_rustls::TlsAcceptor;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::accept_async;
/// An application-level WebSocket message exchanged with a client.
///
/// Only data frames and close requests are modelled; ping/pong/raw frames have
/// no variant of their own (see the `From<Message>` conversion for how they are
/// mapped).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsMessage {
    /// A UTF-8 text frame.
    Text(String),
    /// A binary frame.
    Binary(Vec<u8>),
    /// A close request; no close code or reason is carried.
    Close,
}
impl From<Message> for WsMessage {
    /// Converts a tungstenite frame into the server's message type.
    ///
    /// Text, binary and close frames map one-to-one (the close code/reason is
    /// dropped). Ping, Pong and raw `Frame` values have no `WsMessage`
    /// equivalent and, because `From` cannot fail, become an empty `Text`
    /// message; callers that must tell them apart should inspect the `Message`
    /// before converting.
    fn from(frame: Message) -> Self {
        match frame {
            Message::Text(payload) => Self::Text(payload),
            Message::Binary(bytes) => Self::Binary(bytes),
            Message::Close(_) => Self::Close,
            Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => Self::Text(String::new()),
        }
    }
}
impl From<WsMessage> for Message {
    /// Converts an outbound [`WsMessage`] into a tungstenite frame.
    ///
    /// `WsMessage::Close` is sent as a close frame without a code or reason.
    fn from(outgoing: WsMessage) -> Self {
        match outgoing {
            WsMessage::Close => Self::Close(None),
            WsMessage::Binary(bytes) => Self::Binary(bytes),
            WsMessage::Text(body) => Self::Text(body),
        }
    }
}
/// Identifier the server assigns to each accepted connection.
///
/// Generated ids are prefixed with `client-`; treat the rest of the string as
/// opaque and only compare, hash or display it.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct ClientId(pub String);
/// Settings for a [`WsServer`].
#[derive(Clone, Debug)]
pub struct WsServerConfig {
    /// Address to listen on, passed verbatim to `TcpListener::bind`
    /// (for example `"127.0.0.1:9000"`).
    pub addr: String,
    /// PEM file holding the server's certificate chain.
    pub cert_path: PathBuf,
    /// PEM file holding the server's private key (PKCS#8, PKCS#1 or SEC1).
    pub key_path: PathBuf,
    /// PEM bundle of CA certificates used to build the client-certificate verifier.
    pub ca_cert_path: PathBuf,
    /// Maximum number of concurrent client connections the server is meant to
    /// allow; see `WsServer::start` for how it is enforced.
    pub max_connections: usize,
    /// Timeout in seconds, applied separately to the TLS handshake and to the
    /// WebSocket upgrade of each connection.
    pub connection_timeout: u64,
    /// When `true`, clients must present a certificate that verifies against
    /// the CAs in `ca_cert_path`.
    pub client_cert_required: bool,
}

impl Default for WsServerConfig {
    /// Local-development defaults: listen on `127.0.0.1:9000`, read TLS files from
    /// `./crate_cert/` (relative to the working directory), allow 1000 connections,
    /// use a 30 s handshake timeout and require client certificates.
    fn default() -> Self {
        let addr = String::from("127.0.0.1:9000");
        let cert_path: PathBuf = "./crate_cert/a_cert.pem".into();
        let key_path: PathBuf = "./crate_cert/a_key.pem".into();
        let ca_cert_path: PathBuf = "./crate_cert/ca_cert.pem".into();

        Self {
            addr,
            cert_path,
            key_path,
            ca_cert_path,
            max_connections: 1000,
            connection_timeout: 30,
            client_cert_required: true,
        }
    }
}
/// Application callbacks invoked by [`WsServer`] for connection lifecycle events.
///
/// One handler instance is shared by every connection task behind an `Arc`,
/// hence the `Send + Sync + 'static` bounds. All methods are synchronous and are
/// called directly from the connection's async task, so blocking or slow work
/// here holds up that connection; offload it (e.g. via
/// `tokio::task::spawn_blocking`) when necessary.
pub trait ServerHandler: Send + Sync + 'static {
    /// Called once the TLS and WebSocket handshakes have succeeded and the
    /// client has been registered with the server.
    fn on_connect(&self, client_id: ClientId, addr: SocketAddr);

    /// Called for each message received from `client_id`.
    ///
    /// Returning `Some(reply)` queues `reply` for delivery to the same client;
    /// returning `None` sends nothing back.
    fn on_message(&self, client_id: ClientId, message: WsMessage) -> Option<WsMessage>;

    /// Called when sending to or receiving from a client fails; that
    /// connection's task then ends. `client_id` is presumably `None` for errors
    /// not tied to a particular client (the server itself always passes `Some`).
    fn on_error(&self, client_id: Option<ClientId>, error: String);

    /// Called once after the client has been removed from the server's registry.
    fn on_disconnect(&self, client_id: ClientId);
}
/// Registry entry for one connected client.
struct ClientConnection {
    /// Identifier passed to the [`ServerHandler`] callbacks for this client.
    client_id: ClientId,
    /// Sending half of the client's outbound queue; messages pushed here are
    /// written to the socket by the connection's send task.
    tx: mpsc::Sender<WsMessage>,
}
/// A TLS-terminating WebSocket server that reports client events to a
/// [`ServerHandler`].
pub struct WsServer {
    /// Settings supplied to [`WsServer::new`].
    config: WsServerConfig,
    /// Application callbacks, shared with every connection task.
    handler: Arc<dyn ServerHandler>,
    /// TLS acceptor built from `config` when the server is constructed.
    tls_acceptor: TlsAcceptor,
    /// Currently registered clients; connection tasks add an entry after the
    /// WebSocket handshake and remove it when the connection ends.
    clients: Arc<tokio::sync::Mutex<Vec<ClientConnection>>>,
}
impl WsServer {
pub fn new(config: WsServerConfig, handler: impl ServerHandler) -> Result<Self, String> {
let tls_acceptor = Self::create_tls_acceptor(&config)
.map_err(|e| format!("Failed to create TLS acceptor: {}", e))?;
Ok(Self {
config,
handler: Arc::new(handler),
tls_acceptor,
clients: Arc::new(tokio::sync::Mutex::new(Vec::new())),
})
}
pub async fn start(&self) -> Result<(), String> {
let listener = TcpListener::bind(&self.config.addr)
.await
.map_err(|e| format!("Failed to bind to address {}: {}", self.config.addr, e))?;
info!("WebSocket server started on {}", self.config.addr);
loop {
match listener.accept().await {
Ok((stream, addr)) => {
debug!("New TCP connection from: {}", addr);
let acceptor = self.tls_acceptor.clone();
let handler = self.handler.clone();
let clients = self.clients.clone();
let connection_timeout = Duration::from_secs(self.config.connection_timeout);
let client_id = ClientId(format!("client-{}", uuid_simple()));
let client_id_clone = client_id.clone();
tokio::spawn(async move {
if let Err(e) = Self::handle_connection(
stream,
addr,
acceptor,
handler,
clients,
client_id_clone,
connection_timeout
).await {
error!("Connection error for {}: {}", addr, e);
}
});
}
Err(e) => {
error!("Failed to accept connection: {}", e);
}
}
let client_count = self.clients.lock().await.len();
if client_count >= self.config.max_connections {
warn!("Maximum connections reached: {}", client_count);
tokio::time::sleep(Duration::from_millis(100)).await;
}
}
}
pub async fn broadcast(&self, message: WsMessage) -> Result<usize, String> {
let clients = self.clients.lock().await;
let mut sent_count = 0;
for client in clients.iter() {
if client.tx.send(message.clone()).await.is_ok() {
sent_count += 1;
}
}
Ok(sent_count)
}
pub async fn send_to_client(&self, client_id: &ClientId, message: WsMessage) -> Result<(), String> {
let clients = self.clients.lock().await;
for client in clients.iter() {
if client.client_id == *client_id {
return client.tx.send(message)
.await
.map_err(|_| format!("Failed to send message to client {}", client_id.0));
}
}
Err(format!("Client not found: {}", client_id.0))
}
pub async fn client_count(&self) -> usize {
self.clients.lock().await.len()
}
pub async fn client_list(&self) -> Vec<ClientId> {
let clients = self.clients.lock().await;
clients.iter().map(|c| c.client_id.clone()).collect()
}
async fn handle_connection(
stream: TcpStream,
addr: SocketAddr,
acceptor: TlsAcceptor,
handler: Arc<dyn ServerHandler>,
clients: Arc<tokio::sync::Mutex<Vec<ClientConnection>>>,
client_id: ClientId,
connection_timeout: Duration,
) -> Result<(), String> {
let tls_handshake = tokio::time::timeout(
connection_timeout,
acceptor.accept(stream),
).await
.map_err(|_| format!("TLS handshake timed out after {} seconds", connection_timeout.as_secs()))?
.map_err(|e| format!("TLS handshake failed: {}", e))?;
debug!("TLS handshake successful for {}", addr);
let ws_stream = tokio::time::timeout(
connection_timeout,
accept_async(tls_handshake),
).await
.map_err(|_| format!("WebSocket handshake timed out after {} seconds", connection_timeout.as_secs()))?
.map_err(|e| format!("WebSocket handshake failed: {}", e))?;
debug!("WebSocket handshake successful for {}", addr);
let (tx, mut rx) = mpsc::channel::<WsMessage>(100);
{
let mut clients_lock = clients.lock().await;
clients_lock.push(ClientConnection {
client_id: client_id.clone(),
tx: tx.clone(),
});
info!("Client connected: {} from {}", client_id.0, addr);
}
handler.on_connect(client_id.clone(), addr);
let (ws_sender, ws_receiver) = ws_stream.split();
let mut send_task = {
let mut ws_sender = ws_sender;
let client_id_for_send = client_id.clone();
let handler_for_send = handler.clone();
async move {
while let Some(msg) = rx.recv().await {
match ws_sender.send(msg.into()).await {
Ok(_) => {
debug!("Message sent to client {}", client_id_for_send.0);
}
Err(e) => {
let error_msg = format!("Failed to send message: {}", e);
handler_for_send.on_error(Some(client_id_for_send.clone()), error_msg);
break;
}
}
}
let _ = ws_sender.close().await;
debug!("Send task completed for client {}", client_id_for_send.0);
}.boxed()
};
let mut receive_task = {
let mut ws_receiver = ws_receiver;
let handler_for_recv = handler.clone();
let client_id_for_recv = client_id.clone();
let tx_for_recv = tx.clone();
async move {
while let Some(result) = ws_receiver.next().await {
match result {
Ok(msg) => {
if msg.is_close() {
debug!("Client {} requested close", client_id_for_recv.0);
break;
}
let ws_msg = WsMessage::from(msg);
if let Some(response) = handler_for_recv.on_message(client_id_for_recv.clone(), ws_msg) {
if tx_for_recv.send(response).await.is_err() {
break;
}
}
}
Err(e) => {
let error_msg = format!("Error receiving message: {}", e);
handler_for_recv.on_error(Some(client_id_for_recv.clone()), error_msg);
break;
}
}
}
debug!("Receive task completed for client {}", client_id_for_recv.0);
}.boxed()
};
tokio::select! {
_ = &mut send_task => {},
_ = &mut receive_task => {},
}
Self::remove_client(clients, client_id.clone()).await;
handler.on_disconnect(client_id.clone());
info!("Client disconnected: {} from {}", client_id.0, addr);
Ok(())
}
async fn remove_client(
clients: Arc<tokio::sync::Mutex<Vec<ClientConnection>>>,
client_id: ClientId,
) {
let mut clients_lock = clients.lock().await;
if let Some(pos) = clients_lock.iter().position(|c| c.client_id == client_id) {
clients_lock.remove(pos);
}
}
fn create_tls_acceptor(config: &WsServerConfig) -> Result<TlsAcceptor, Box<dyn std::error::Error>> {
info!("Loading certificates and keys...");
let certs = load_certs(&config.cert_path)?;
let key = load_private_key(&config.key_path)?;
let ca_certs = load_certs(&config.ca_cert_path)?;
let mut root_cert_store = rustls::RootCertStore::empty();
for cert in ca_certs {
root_cert_store.add(cert)?;
}
let server_config = if config.client_cert_required {
let client_verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(root_cert_store))
.build()?;
rustls::ServerConfig::builder()
.with_client_cert_verifier(client_verifier)
.with_single_cert(certs, key)?
} else {
rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)?
};
Ok(TlsAcceptor::from(Arc::new(server_config)))
}
}
/// Reads every PEM-encoded certificate from the file at `path`.
///
/// # Errors
///
/// Fails if the file cannot be opened, if any PEM certificate entry fails to
/// parse (the first such error is returned), or if the file contains no
/// certificates at all.
fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error>> {
    let mut pem = BufReader::new(File::open(path)?);
    let certs = rustls_pemfile::certs(&mut pem).collect::<Result<Vec<_>, _>>()?;
    if certs.is_empty() {
        Err(format!("No certificates found in {}", path.display()).into())
    } else {
        Ok(certs)
    }
}
/// Loads the first private key found in the PEM file at `path`.
///
/// Formats are tried in preference order: PKCS#8, then PKCS#1 (RSA), then
/// SEC1 (EC). The key returned is the first key of the first format that
/// yields any key at all.
///
/// # Errors
///
/// Fails if the file cannot be opened, if an entry of the format being scanned
/// fails to parse, or if no key of any supported format is present.
fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error>> {
    // Each pass consumes its reader, so every format scan starts from a freshly
    // opened file.
    let reopen = || File::open(path).map(BufReader::new);

    let pkcs8 = rustls_pemfile::pkcs8_private_keys(&mut reopen()?)
        .collect::<Result<Vec<_>, _>>()?;
    if let Some(key) = pkcs8.into_iter().next() {
        return Ok(PrivateKeyDer::Pkcs8(key));
    }

    let pkcs1 = rustls_pemfile::rsa_private_keys(&mut reopen()?)
        .collect::<Result<Vec<_>, _>>()?;
    if let Some(key) = pkcs1.into_iter().next() {
        return Ok(PrivateKeyDer::Pkcs1(key));
    }

    let sec1 = rustls_pemfile::ec_private_keys(&mut reopen()?)
        .collect::<Result<Vec<_>, _>>()?;
    match sec1.into_iter().next() {
        Some(key) => Ok(PrivateKeyDer::Sec1(key)),
        None => Err(format!("No private keys found in {}", path.display()).into()),
    }
}
/// Returns an opaque identifier that is unique within this process.
///
/// The id is `<secs-hex><nanos-hex, 8 digits>-<sequence-hex>`. The timestamp
/// alone is not unique, because two calls within one clock tick would collide.
/// Without padding, `{secs:x}{nanos:x}` is also ambiguous (`1`+`23` vs
/// `12`+`3`). The zero-padded nanoseconds plus a process-wide counter after
/// the `-` make every returned id distinct.
fn uuid_simple() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static SEQUENCE: AtomicU64 = AtomicU64::new(0);

    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    // Only the atomicity of the increment matters here (no other memory is
    // published through this counter), so Relaxed ordering is sufficient.
    let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);
    format!("{:x}{:08x}-{:x}", now.as_secs(), now.subsec_nanos(), seq)
}