//! ws-markdown-server 1.0.2
//!
//! A simple WebSocket (and plain TCP socket) server that compiles Markdown to HTML on the fly.
use clap::Parser;
use comrak::{markdown_to_html, ComrakOptions};
use futures_util::{SinkExt, StreamExt};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufStream};
use tokio_tungstenite::WebSocketStream;

use std::net::ToSocketAddrs;
use std::{
    net::SocketAddr,
    sync::{Arc, Mutex},
};
use tokio::net::{TcpListener, TcpStream};
use tokio_tungstenite::tungstenite::Result;
use tracing::{debug, error, info, trace};

/// Server configuration, read from command-line flags or the listed environment variables.
#[derive(Parser, Debug, Clone)]
#[clap(
    author = "Aalekh Patel <aalekh.gwpeck.7998@icloud.com>",
    about = "A simple markdown-to-html websocket server backed by comrak and tokio-tungstenite.",
    // Bare `version` takes the crate version from Cargo instead of a stale hard-coded string.
    version
)]
pub struct Opts {
    /// Host/interface the WebSocket and TCP servers bind to.
    // `-H` instead of the derived `-h`: `-h` belongs to `--help` and was also derived for
    // `healthcheck_port`; clap rejects duplicate short flags.
    #[clap(short = 'H', long, default_value = "0.0.0.0", env = "MD_SERVER_HOST")]
    pub host: String,
    /// Port of the WebSocket server.
    #[clap(short, long, default_value_t = 9003, env = "MD_SERVER_WS_PORT")]
    pub ws_port: u16,
    /// Port of the optional raw TCP server; it is not started when unset.
    #[clap(short, long, env = "MD_SERVER_TCP_PORT")]
    pub port: Option<u16>,
    /// Port of the HTTP healthcheck endpoint (always bound on 0.0.0.0).
    // Explicit `-c` so it no longer collides with `-h` (help / host).
    #[clap(short = 'c', long, default_value_t = 9005, env = "MD_HEALTHCHECK_PORT")]
    pub healthcheck_port: u16,
}

/// Markdown-to-HTML renderer holding the comrak options used for every render.
///
/// `Default` uses `ComrakOptions::default()`.
#[derive(Debug, Clone, Default)]
pub struct MarkdownEngine {
    /// Options passed to `markdown_to_html` on each call to [`MarkdownEngine::render`].
    pub options: ComrakOptions,
}

impl MarkdownEngine {
    /// Creates an engine with default comrak options.
    pub fn new() -> Self {
        Self::default()
    }

    /// Renders the Markdown text `md` to an HTML string using `self.options`.
    // `md` is skipped so the trace span does not capture the entire document via `Debug`;
    // only its length is recorded.
    #[tracing::instrument(level = "trace", skip(self, md), fields(text_len = md.len()))]
    pub fn render(&self, md: &str) -> String {
        debug!(text_len = md.len(), "Rendering markdown...");
        markdown_to_html(md, &self.options)
    }
}

/// Reference-counted, mutex-guarded value that can be shared across tasks and threads.
pub type SharedState<T> = Arc<Mutex<T>>;
/// Shared handle to the markdown renderer; callers lock it for the duration of a render.
pub type Engine = Arc<Mutex<MarkdownEngine>>;

pub async fn accept_ws_connection(
    stream: TcpStream,
    engine: Engine,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let addr = stream.peer_addr()?;
    let ws_stream = tokio_tungstenite::accept_async(stream).await?;
    debug!("Accepted WS connection from {}", addr);

    if let Err(e) = handle_ws_connection(addr, ws_stream, engine).await {
        match e {
            tokio_tungstenite::tungstenite::Error::ConnectionClosed
            | tokio_tungstenite::tungstenite::Error::AlreadyClosed
            | tokio_tungstenite::tungstenite::Error::Protocol(_) => {
                info!("Connection from {} closed", addr);
            }
            err => {
                error!("Error handling connection from {}: {}", addr, err);
            }
        }
    }
    Ok(())
}

/// Serves one accepted raw-TCP client with `handle_socket_connection`, logging any
/// error that occurs while handling it.
///
/// # Errors
/// Returns an error only if the peer address cannot be read; handling errors are logged.
pub async fn accept_socket_connection(
    stream: TcpStream,
    engine: Engine,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let addr = stream.peer_addr()?;
    // This path serves plain TCP, not WebSocket.
    debug!("Accepted TCP connection from {}", addr);

    if let Err(e) = handle_socket_connection(addr, stream, engine).await {
        error!("Error handling connection from {}: {}", addr, e);
    }
    Ok(())
}

/// Serves an established WebSocket session until the peer stops sending.
///
/// Each text or binary frame is read as UTF-8 Markdown, rendered, and answered with the
/// resulting HTML. Frames of any other kind are skipped.
///
/// # Errors
/// Returns the first error from reading a frame, converting a frame to text (e.g. a
/// binary frame that is not valid UTF-8), or sending the reply.
pub async fn handle_ws_connection(
    peer: SocketAddr,
    mut ws_stream: WebSocketStream<TcpStream>,
    engine: Engine,
) -> Result<()> {
    loop {
        let incoming = match ws_stream.next().await {
            Some(frame) => frame?,
            None => break,
        };
        if !(incoming.is_text() || incoming.is_binary()) {
            continue;
        }
        trace!("Received text/binary message from {}: {:?}", peer, incoming);
        let markdown = incoming.to_text()?;
        // The guard lives only inside this block, so it is released before the `.await`.
        let html = {
            let renderer = engine.lock().unwrap();
            renderer.render(markdown)
        };
        ws_stream.send(html.into()).await?;
    }
    Ok(())
}

/// Serves one raw-TCP client: reads the whole request up to EOF (the client shutting down
/// its write side), renders it as Markdown, and writes the HTML back.
///
/// Exactly one document is handled per connection. An empty request (immediate EOF) gets
/// no response. The stream is closed when it is dropped on return.
///
/// # Errors
/// Returns I/O errors, including the invalid-UTF-8 error raised by `read_to_string`,
/// converted into tungstenite's error type.
pub async fn handle_socket_connection(
    peer: SocketAddr,
    mut stream: TcpStream,
    engine: Engine,
) -> Result<()> {
    debug!("Accepted TcpStream from {}", peer);
    // `read_to_string` only completes at EOF, after which the client can send nothing more,
    // so a single read/render/write is the whole exchange.
    let mut markdown = String::new();
    if stream.read_to_string(&mut markdown).await? == 0 {
        return Ok(());
    }
    let html = engine.lock().unwrap().render(&markdown);
    stream.write_all(html.as_bytes()).await?;
    Ok(())
}

/// Binds to the first address `addr` resolves to and serves WebSocket clients, spawning one
/// detached `accept_ws_connection` task per accepted connection.
///
/// Resolution uses `std::net::ToSocketAddrs`, which may perform a blocking DNS lookup when
/// `addr` contains a host name.
///
/// # Errors
/// Returns an error if `addr` cannot be resolved or resolves to no address, if binding
/// fails, or if `accept` fails with anything other than a per-connection error
/// (aborted / reset / refused), which are skipped. It never returns `Ok`.
pub async fn create_ws_server<A: ToSocketAddrs>(
    addr: A,
    engine: Engine,
) -> Result<(), Box<dyn std::error::Error>> {
    let addr = addr
        .to_socket_addrs()?
        .next()
        .ok_or("address did not resolve to any socket address")?;
    let listener = TcpListener::bind(&addr).await?;
    info!("Listening for WebSocket connections on {}", addr);

    loop {
        match listener.accept().await {
            Ok((stream, _)) => {
                tokio::spawn(accept_ws_connection(stream, engine.clone()));
            }
            // A single client vanished before being accepted; the listener is still fine.
            Err(e)
                if matches!(
                    e.kind(),
                    std::io::ErrorKind::ConnectionAborted
                        | std::io::ErrorKind::ConnectionReset
                        | std::io::ErrorKind::ConnectionRefused
                ) => {}
            Err(e) => return Err(e.into()),
        }
    }
}

/// Binds to the first address `addr` resolves to and serves raw-TCP Markdown clients,
/// spawning one detached `accept_socket_connection` task per accepted connection.
///
/// Resolution uses `std::net::ToSocketAddrs`, which may perform a blocking DNS lookup when
/// `addr` contains a host name.
///
/// # Errors
/// Returns an error if `addr` cannot be resolved or resolves to no address, if binding
/// fails, or if `accept` fails with anything other than a per-connection error
/// (aborted / reset / refused), which are skipped. It never returns `Ok`.
pub async fn create_tcp_server<A: ToSocketAddrs>(
    addr: A,
    engine: Engine,
) -> Result<(), Box<dyn std::error::Error>> {
    let addr = addr
        .to_socket_addrs()?
        .next()
        .ok_or("address did not resolve to any socket address")?;
    let listener = TcpListener::bind(&addr).await?;
    info!("Listening for TCP (Unix) socket connections on {}", addr);

    loop {
        match listener.accept().await {
            Ok((stream, _)) => {
                tokio::spawn(accept_socket_connection(stream, engine.clone()));
            }
            // A single client vanished before being accepted; the listener is still fine.
            Err(e)
                if matches!(
                    e.kind(),
                    std::io::ErrorKind::ConnectionAborted
                        | std::io::ErrorKind::ConnectionReset
                        | std::io::ErrorKind::ConnectionRefused
                ) => {}
            Err(e) => return Err(e.into()),
        }
    }
}

/// Serves a minimal HTTP healthcheck on `0.0.0.0:port`.
///
/// Every accepted connection immediately receives a fixed `200 OK` response with the JSON
/// body `{"status":"ok"}`, after which the stream is shut down. The request is never read.
///
/// # Errors
/// Returns an error if binding fails, or if `accept` fails with anything other than a
/// per-connection error (aborted / reset / refused), which are skipped. It never returns `Ok`.
pub async fn create_healthcheck_http_server(port: u16) -> Result<(), Box<dyn std::error::Error>> {
    let listener = TcpListener::bind(SocketAddr::from(([0, 0, 0, 0], port))).await?;
    info!("Listening for healthcheck requests on port {}", port);

    loop {
        let mut stream = match listener.accept().await {
            Ok((stream, _)) => stream,
            // A single client vanished before being accepted; the listener is still fine.
            Err(e)
                if matches!(
                    e.kind(),
                    std::io::ErrorKind::ConnectionAborted
                        | std::io::ErrorKind::ConnectionReset
                        | std::io::ErrorKind::ConnectionRefused
                ) =>
            {
                continue
            }
            Err(e) => return Err(e.into()),
        };
        // All socket work happens inside the task, so a slow or broken client can neither
        // stall the accept loop nor terminate the server.
        tokio::spawn(async move {
            const RESPONSE: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: application/json; charset=UTF-8\r\n\r\n{\"status\":\"ok\"}\r\n";
            // Best effort: a client that disconnects early just misses the response; that
            // must not panic the task.
            if stream.write_all(RESPONSE).await.is_ok() {
                let _ = stream.shutdown().await;
            }
        });
    }
}

/// Entry point: parses options, initialises logging, and runs the WebSocket server, the
/// healthcheck server and, when `--port` is given, the raw TCP server concurrently.
///
/// # Errors
/// Returns the first error produced by any server (e.g. a failed bind). The remaining
/// servers are then dropped, so the process exits instead of running half-configured
/// (for instance, reporting "ok" on the healthcheck while the WebSocket port failed to bind).
#[tokio::main]
pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let opts = Opts::parse();
    tracing_subscriber::fmt::init();
    let engine = Arc::new(Mutex::new(MarkdownEngine::new()));

    let ws_server = create_ws_server(format!("{}:{}", opts.host, opts.ws_port), engine.clone());
    let healthcheck_server = create_healthcheck_http_server(opts.healthcheck_port);

    match opts.port {
        Some(port) => {
            let tcp_server = create_tcp_server(format!("{}:{}", opts.host, port), engine.clone());
            tokio::try_join!(ws_server, tcp_server, healthcheck_server)?;
        }
        None => {
            tokio::try_join!(ws_server, healthcheck_server)?;
        }
    }
    Ok(())
}