scalesocket 0.2.4

A collaborative websocket server and autoscaler
use {
    clap::builder::ArgPredicate,
    clap::{ArgAction, Parser},
    std::net::SocketAddr,
    std::ops::Range,
    std::path::PathBuf,
};

use crate::types::{Cache, Frame, Log};

/// Buffer sizes accepted for the `SIZE` component of the `--cache` option (see `parse_cache`).
const CACHE_SIZES: &[usize; 3] = &[1, 8, 64];

/// Server configuration
#[derive(Parser, Debug, Clone)]
#[clap(author, version, about, long_about = None)]
pub struct Config {
    /// Interface to bind to
    #[clap(long, default_value = "0.0.0.0:9000")]
    pub addr: SocketAddr,

    /// Set scalesocket to experimental binary mode
    #[clap(short, long, action)]
    pub binary: bool,

    /// Cache server message history for room and replay it to new clients
    ///
    /// The cache buffer retains the last <SIZE> chunks, determined by <TYPE>:
    /// When <TYPE> is `all`, all server messages are cached.
    /// When <TYPE> is `tagged`, only server messages with `_cache: true` are cached.
    #[clap(long, value_parser = parse_cache, value_name = "[TYPE:]SIZE", verbatim_doc_comment)]
    pub cache: Option<Cache>,

    /// Preserve server message history for room even after last client disconnects
    #[clap(long = "cachepersist", action)]
    pub cache_persist: bool,

    /// Delay before attaching to child
    ///
    /// [default: 1 with --tcp]
    #[clap(
        long = "delay",
        value_name = "SECONDS",
        default_value_if("tcp", ArgPredicate::Equals("true".into()), Some("1"))
    )]
    pub delay: Option<u64>,

    /// Process output items are terminated by given characters
    ///
    /// See --null for null termination.
    #[clap(
        long,
        value_parser,
        default_value = "\n",
        default_value_if("binary", ArgPredicate::Equals("true".into()), Some("")),
        // `--null` promises NUL-terminated output items, so the delimiter must be the NUL
        // character; an empty string would be the same value `--binary` selects above.
        default_value_if("null", ArgPredicate::Equals("true".into()), Some("\0")),
        require_equals = true,
        conflicts_with = "binary",
    )]
    pub delimiters: Option<String>,

    /// Emit message to child on client connect (use #ID for id)
    #[clap(
        long = "joinmsg",
        value_name = "MSG",
        default_value_if("json", ArgPredicate::Equals("true".into()), Some(r#"{"t":"Join","_from":#ID}"#))
    )]
    pub join_msg: Option<String>,

    /// Enable JSON framing with default join and leave messages
    ///
    /// This option is equivalent to
    /// --frame=json --joinmsg '{"t":"Join","_from":#ID}' --leavemsg '{"t":"Leave","_from":#ID}'
    #[clap(
        long,
        action,
        conflicts_with = "client_frame",
        conflicts_with = "server_frame",
        conflicts_with = "frame"
    )]
    pub json: bool,

    /// Emit message to child on client disconnect (use #ID for id)
    #[clap(
        long = "leavemsg",
        value_name = "MSG",
        default_value_if("json", ArgPredicate::Equals("true".into()), Some(r#"{"t":"Leave","_from":#ID}"#))
    )]
    pub leave_msg: Option<String>,

    /// Log format
    ///
    /// [default: text, possible values: text, json]
    #[clap(
        long,
        action,
        value_parser,
        value_name = "FMT",
        default_value = "text",
        hide_possible_values = true,
        hide_default_value = true
    )]
    pub log: Log,

    /// Expose OpenMetrics endpoint at /metrics
    #[clap(long, action)]
    pub metrics: bool,

    /// Process output items are terminated by a null character
    #[clap(long, action)]
    pub null: bool,

    /// Serve only once
    #[clap(long, action)]
    pub oneshot: bool,

    /// List of envvars to pass to child
    #[clap(
        long,
        value_name = "LIST",
        value_delimiter = ',',
        default_value = "PATH,DYLD_LIBRARY_PATH"
    )]
    pub passenv: Vec<String>,

    /// List of valid rooms
    ///
    /// When set, websocket connections are only accepted on the specified paths `/<ROOM>`.
    #[clap(long, value_name = "LIST", value_delimiter = ',')]
    pub rooms: Option<Vec<String>>,

    /// Maximum number of rooms
    ///
    /// When set, websocket connections are accepted on up to <NUM> rooms.
    /// Since a child process is spawned for each room, this is equivalent to limiting the maximum number of processes.
    #[clap(
        long = "maxrooms",
        alias = "maxforks",
        value_name = "NUM",
        default_value_if("oneshot", ArgPredicate::Equals("true".into()), Some("1")),
        conflicts_with = "tcp_ports"
    )]
    pub max_rooms: Option<usize>,

    /// Enable framing and routing for all messages
    ///
    /// Client messages are tagged with an ID header (u32). Server messages with optional client ID are routed to clients.
    ///
    /// When set to `json`, messages are parsed as JSON.
    /// Client messages are amended with an "_from" field.
    /// Server messages are routed to clients based on an optional "_to" field.
    ///
    /// Server messages with `_meta: true` will be dropped, and stored as room metadata accessible via the API.
    ///
    /// When set to `gwsocket`, messages are parsed according to gwsocket's strict mode.
    /// Unparseable messages may be dropped.
    ///
    /// See --serverframe and --clientframe for specifying framing independently.
    ///
    /// [default: json with --json, possible values: gwsocket, json]
    #[clap(
        long,
        value_parser,
        value_name = "MODE",
        default_missing_value = "json",
        default_value_if("json", ArgPredicate::Equals("true".into()), Some("json")),
        // A bare `--frame` takes zero values and falls back to `default_missing_value`;
        // otherwise at most one value, since the field holds a single `Frame`.
        num_args = 0..=1,
        require_equals = true,
        hide_possible_values = true
    )]
    pub frame: Option<Frame>,

    /// Enable framing and routing for client originated messages
    ///
    /// See --frame for options.
    #[clap(
        long = "clientframe",
        value_parser,
        value_name = "MODE",
        conflicts_with = "frame",
        require_equals = true,
        hide_possible_values = true
    )]
    pub client_frame: Option<Frame>,

    /// Enable framing and routing for server originated messages
    ///
    /// See --frame for options.
    #[clap(
        long = "serverframe",
        value_parser,
        value_name = "MODE",
        conflicts_with = "frame",
        require_equals = true,
        hide_possible_values = true
    )]
    pub server_frame: Option<Frame>,

    /// Serve static files from directory over HTTP
    #[clap(long = "staticdir", value_parser, value_name = "DIR")]
    pub static_dir: Option<PathBuf>,

    /// Expose room metadata API under /api/
    ///
    /// The exposed endpoints are:
    /// * /api/rooms/          - list rooms
    /// * /api/<ROOM>/         - get room metadata
    /// * /api/<ROOM>/<METRIC> - get room individual metric
    #[clap(long, action, verbatim_doc_comment)]
    pub api: bool,

    /// Connect to child using TCP instead of stdio. Use PORT to bind
    #[clap(long, action)]
    pub tcp: bool,

    /// Port range for TCP
    ///
    /// [default: 9001:9999 with --tcp]
    #[clap(
        long = "tcpports",
        value_parser = parse_ports,
        value_name = "START:END",
        requires = "tcp",
        default_value_if("tcp", ArgPredicate::Equals("true".into()), Some("9001:9999"))
    )]
    pub tcp_ports: Option<Range<u16>>,

    /// Increase level of verbosity
    #[clap(short, action = ArgAction::Count)]
    pub verbosity: u8,

    /// Command to wrap
    #[clap(required = true)]
    pub cmd: String,

    /// Arguments to command
    #[clap(last = true)]
    pub args: Vec<String>,
}

/// Parse a `--tcpports` argument of the form `START:END` into a port range.
///
/// `START` is inclusive and `END` is exclusive (the result is `START..END`).
///
/// # Errors
///
/// Returns an error message if the argument has no `:`, if either bound is not a valid
/// `u16`, or if the range would be empty (`START >= END`), since an empty range provides
/// no usable ports.
fn parse_ports(arg: &str) -> Result<Range<u16>, &'static str> {
    const PARSE_ERR: &str = "Could not parse port range";

    let (start, end) = arg.split_once(':').ok_or(PARSE_ERR)?;
    let start: u16 = start.parse().map_err(|_| PARSE_ERR)?;
    let end: u16 = end.parse().map_err(|_| PARSE_ERR)?;

    // A reversed or zero-width range parses fine numerically but can never yield a port.
    if start >= end {
        return Err("Port range START must be less than END");
    }
    Ok(start..end)
}

/// Parse a `--cache` argument of the form `[TYPE:]SIZE` into a [`Cache`] setting.
///
/// `TYPE` is `all` or `tagged`; when it is omitted (a bare `SIZE`), `all` is assumed.
/// `SIZE` must be one of the values in [`CACHE_SIZES`].
///
/// # Errors
///
/// Returns an error message if `TYPE` is not recognised, `SIZE` is not a number, or
/// `SIZE` is not one of the supported sizes.
fn parse_cache(arg: &str) -> Result<Cache, &'static str> {
    const ERR: &str = "Expected <TYPE>:<SIZE> or <SIZE> where SIZE is 1, 8 or 64";

    // Without an explicit TYPE, fall back to `all`. The previous placeholder type
    // ("messages") matched no branch below, so a bare SIZE was always rejected.
    let (kind, size) = arg.split_once(':').unwrap_or(("all", arg));
    let size: usize = size.parse().map_err(|_| ERR)?;
    if !CACHE_SIZES.contains(&size) {
        return Err(ERR);
    }

    match kind {
        "all" => Ok(Cache::All(size)),
        "tagged" => Ok(Cache::Tagged(size)),
        _ => Err(ERR),
    }
}