//! t2_bus 0.1.0
//!
//! An inter- or intra-process message bus supporting publish/subscribe and request/response.
//!
//! Documentation
use std::collections::HashSet;
use std::{fmt::Display, net::SocketAddr, path::PathBuf};
use clap::{command, Parser, Subcommand};
use serde::{Deserialize, Serialize};
use t2_bus::prelude::*;
use regex::Regex;
use std::net::AddrParseError;

/// Default bus address. Not referenced elsewhere in this file.
pub const DEFAULT_BUS_ADDR: &str = ".t2";
/// Default bus TCP port. Not referenced elsewhere in this file.
pub const DEFAULT_BUS_PORT: u16 = 4242;
/// Valid registration names: one or more lowercase ASCII letters or underscores.
const BUS_ADDR_NAME_RGX: &str = r"^[a-z_]+$";
/// A bus address argument, `<type>:<value>`.
/// Group 1 is the type (`tcp`, `unix` or `name`); group 2 is the value.
const BUS_ADDR_RGX: &str = r"^(tcp|unix|name):(.+)$";
/// One line of the `~/.t2` registry, `<name> <tcp|unix>:<address>`.
/// Group 1 is the name, group 2 the transport and group 3 the address.
/// Multi-line mode (`(?m)`) anchors `^`/`$` at line boundaries, so the final line is still
/// parsed when the file does not end with a newline.
const BUS_ADDR_CONFIG_RGX: &str = r"(?m)^(.+?) (tcp|unix):(.*)$";

/// A bus address after any `name:` indirection has been looked up.
#[derive(Debug)]
enum ResolvedBusAddr {
    /// Textual TCP socket address; it is parsed as a `SocketAddr` where it is used.
    Tcp(String),
    /// Filesystem path handed to the Unix-socket transport.
    Unix(PathBuf),
}

/// Renders the address with its `tcp:` / `unix:` prefix. This is the same form accepted
/// on the command line and written to the `~/.t2` registry.
impl Display for ResolvedBusAddr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Tcp(addr) => write!(f, "tcp:{addr}"),
            Self::Unix(path) => {
                let shown = path.display();
                write!(f, "unix:{}", shown)
            }
        }
    }
}

/// Clap `value_parser` for registration names.
///
/// Accepts `s` only if it matches `BUS_ADDR_NAME_RGX` (one or more lowercase ASCII letters
/// or underscores). On success it returns `s` unchanged as an owned `String`.
fn validate_bus_addr_name(s: &str) -> Result<String, String> {
    let pattern = Regex::new(BUS_ADDR_NAME_RGX).unwrap();
    if !pattern.is_match(s) {
        return Err(
            "Invalid bus address name, must contain only lowercase and underscore.".to_string(),
        );
    }
    Ok(s.to_string())
}

/// Clap `value_parser` for bus address arguments.
///
/// Accepts `s` only if it matches `BUS_ADDR_RGX` (`tcp:`, `unix:` or `name:` followed by a
/// non-empty value). On success it returns `s` unchanged as an owned `String`.
fn validate_bus_addr(s: &str) -> Result<String, String> {
    Regex::new(BUS_ADDR_RGX)
        .unwrap()
        .is_match(s)
        .then(|| s.to_owned())
        .ok_or_else(|| {
            "Invalid bus address. Must be in the format of (tcp|unix|name):<address or name>"
                .to_owned()
        })
}

// Top-level command-line interface. Each invocation selects exactly one subcommand.
// Plain `//` comments are used here on purpose: `///` doc comments on clap-derived types
// become part of the generated help text.
// NOTE(review): `version = "1.0"` is hard-coded. A bare `#[command(version)]` would report
// the Cargo package version instead — confirm which is intended.
#[derive(Parser)]
#[command(version = "1.0", author = "Felix Watts", about = "Utilities related to the t2 bus.")]
struct Cli {
    // The subcommand to run; see `Commands`.
    #[command(subcommand)]
    command: Commands,
}

// Subcommands of the CLI.
//
// The `///` doc comments in this enum are part of the program's behaviour: clap's derive
// turns them into the `--help` text for each subcommand and argument. Maintainer-only notes
// therefore use `//`.
//
// TCP addresses are later parsed as `std::net::SocketAddr` (in `build_client` and in the
// `Serve` arm of `run`), and that parser requires an explicit port. The `tcp:` examples in
// the help text therefore include one; a bare `tcp:127.0.0.1` is rejected at runtime.
#[derive(Subcommand)]
enum Commands {
    /// Start a bus server
    Serve {
        /// One or more addresses to serve on e.g. `tcp:127.0.0.1:4242` or `unix:my_bus` or `name:bus`
        // NOTE(review): clap does not make a positional `Vec` required by default, so `serve`
        // with no address is accepted and configures no listeners here — confirm intended.
        #[arg(value_parser = validate_bus_addr)]
        addr: Vec<String>,
    },
    /// Connect to a bus server, subscribe to a topic and print received messages to the console
    Sub {
        /// The topic to subscribe to
        topic: String,
        /// The address of the bus server to connect to e.g. `tcp:127.0.0.1:4242` or `unix:my_bus` or `name:bus`. If none is provided the default will be used
        #[arg(value_parser = validate_bus_addr)]
        addr: Option<String>,
    },
    /// Connect to a bus server and publish a message to a topic. This utility only supports `f32` and `String` type messages.
    Pub {
        /// The topic to publish on. It must start with `f32/` or `string/` to denote which protocol it should be published on.
        topic: String,
        /// The message to publish
        value: String,
        /// The address of the bus server to connect to e.g. `tcp:127.0.0.1:4242` or `unix:my_bus` or `name:bus`. If none is provided the default will be used
        #[arg(value_parser = validate_bus_addr)]
        addr: Option<String>,
    },
    /// Connect to a bus, subscribe to the given topic and list the topics of any messages received
    Ls {
        /// The topic to subscribe to
        topic: String,
        /// The address of the bus server to connect to e.g. `tcp:127.0.0.1:4242` or `unix:my_bus` or `name:bus`. If none is provided the default will be used
        #[arg(value_parser = validate_bus_addr)]
        addr: Option<String>,
    },
    /// Register a bus server by a name for easy connection in future
    Register {
        /// The name to give the server. You can then use `name:<name>` as a bus address argument to this program
        #[arg(value_parser = validate_bus_addr_name)]
        name: String,
        /// The address of the bus server e.g. `tcp:127.0.0.1:4242` or `unix:my_bus`
        #[arg(value_parser = validate_bus_addr)]
        addr: String,
        /// True if this should be your default bus server. The default server will be used for any commands
        /// where the bus address is not specified
        #[arg(long)]
        default: bool,
    },
    /// Unregister a registered bus server
    Unregister {
        /// The name of the server to unregister
        #[arg(value_parser = validate_bus_addr_name)]
        name: String,
    },
}

impl Commands{
    fn validate(&self) -> Result<(), Error> {
        match self{
            Commands::Serve { .. } => Ok(()),
            Commands::Sub { ..} => Ok(()),
            Commands::Ls { ..} => Ok(()),
            Commands::Pub { topic, value, .. } => {
                if !(topic.starts_with("f32/") || topic.starts_with("string/")) {
                    return Err(Error("Unknown protocol".into()))
                }

                if topic.starts_with("f32/") && value.parse::<f32>().is_err() {
                    return Err(Error("When the topic starts with f32/ then the value must be a valid f32".into()))
                }

                Ok(())
            },
            Commands::Register { .. } => Ok(()),
            Commands::Unregister { .. } => Ok(()),
        }
    }
}

struct Error(String);

impl From<std::io::Error> for Error{
    fn from(value: std::io::Error) -> Self {
        Self(value.to_string())
    }
}

impl From<BusError> for Error{
    fn from(value: BusError) -> Self {
        Self(value.to_string())
    }
}

impl From<AddrParseError> for Error{
    fn from(value: AddrParseError) -> Self {
        Self(value.to_string())
    }
}

impl Display for Error{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", &self.0)
    }
}

/// Entry point: runs the selected subcommand and reports failures.
///
/// On error, the message is written to stderr and the process exits with a failing status.
/// This lets shells and scripts detect that the command did not succeed.
#[tokio::main]
async fn main() -> std::process::ExitCode {
    match run().await {
        Ok(()) => std::process::ExitCode::SUCCESS,
        Err(e) => {
            eprintln!("{e}");
            std::process::ExitCode::FAILURE
        }
    }
}

async fn run() -> Result<(), Error> {
    let cli = Cli::parse();
    cli.command.validate()?;
    match cli.command {
        Commands::Serve { addr } => {
            let mut builder = t2_bus::prelude::ServerBuilder::new();

            for addr in addr.into_iter() {
                let resolved_addr = resolve_addr(&Some(addr))?;
                match resolved_addr {
                    ResolvedBusAddr::Tcp(addr) => {
                        builder = builder.serve_tcp(addr.parse::<SocketAddr>()?);
                    },
                    ResolvedBusAddr::Unix(addr) => {
                        builder = builder.serve_unix_socket(addr);
                    },
                }
            }

            let (stopper, _) = builder.build().await?;

            stopper.join().await?;
        },
        Commands::Sub { addr, topic } => {
            let client = build_client(&addr).await?;

            let mut sub = client.subscribe_bytes(&topic).await?;
            while let Some(msg) = sub.recv().await {
                let val_str = if msg.topic.starts_with("f32/") {
                    let bytes: Vec<u8> = msg.payload.into();
                    let payload: F32Protocol = t2_bus::transport::cbor_codec::deser(&bytes[..])?;
                    payload.0.to_string()
                } else if msg.topic.starts_with("string/") {
                    let bytes: Vec<u8> = msg.payload.into();
                    let payload: StringProtocol = t2_bus::transport::cbor_codec::deser(&bytes[..])?;
                    payload.0
                } else {
                    let bytes: Vec<u8> = msg.payload.into();
                    format!("0x{}", &hex::encode(bytes))
                };

                println!("{}: {val_str}", msg.topic)
            }
        },
        Commands::Ls { addr, topic } => {
            let client = build_client(&addr).await?;
            let mut encountered_topics = HashSet::new();

            let mut sub = client.subscribe_bytes(&topic).await?;
            while let Some(PubMsg{ topic, .. }) = sub.recv().await {
                if !encountered_topics.contains(&topic) {
                    println!("{topic}");
                    encountered_topics.insert(topic);
                }
            }
        },
        Commands::Pub { addr, topic, value } => {
            let client = build_client(&addr).await?;
            let payload = if topic.starts_with("f32/") {
                t2_bus::transport::cbor_codec::ser(&F32Protocol(value.parse().unwrap()))?

            } else {
                t2_bus::transport::cbor_codec::ser(&StringProtocol(value.parse().unwrap()))?
            };

            client.publish_bytes(&topic, payload).await?;
        },
        Commands::Register { name, addr, default } => {
            let home = match std::env::var("HOME") {
                Ok(home) => home,
                Err(_) => return Err(Error("HOME environment variable not set".into()))
            };
            let resolved_addr = resolve_addr(&Some(addr))?;
            let mut config = parse_config()?;
            config.retain(|(n, _)| n != &name);

            if default {
                config.insert(0, (name, resolved_addr));
            } else {
                config.push((name, resolved_addr));
            }

            let config_str = config
                .iter()
                .map(|(name, addr)| format!("{name} {addr}\n"))
                .collect::<Vec<_>>()
                .join("");
            let path = PathBuf::from(home).join(".t2");
            std::fs::write(path, config_str)?;
        },
        Commands::Unregister { name } => {
            let home = match std::env::var("HOME") {
                Ok(home) => home,
                Err(_) => return Err(Error("HOME environment variable not set".into()))
            };
            let mut config = parse_config()?;
            config.retain(|(n, _)| n != &name);
            let config_str = config
                .iter()
                .map(|(name, addr)| format!("{name} {addr}\n"))
                .collect::<Vec<_>>()
                .join("");
            let path = PathBuf::from(home).join(".t2");
            std::fs::write(path, config_str)?;
        }
    }

    Ok(())
}

async fn build_client(addr: &Option<String>) -> Result<Client, Error>{
    let resolved_addr = resolve_addr(addr)?;

    match resolved_addr {
        ResolvedBusAddr::Tcp(addr) => {
            Ok(t2_bus::transport::tcp::connect(addr.parse::<SocketAddr>()?).await?)
        },
        ResolvedBusAddr::Unix(addr) => {
            Ok(t2_bus::transport::unix::connect(&addr).await?)
        }
    }
}

fn resolve_addr(addr: &Option<String>) -> Result<ResolvedBusAddr, Error>{
    match addr{
        Some(addr) => {
            let matches = regex::Regex::new(BUS_ADDR_RGX).unwrap().captures(addr).unwrap();
            let typ = matches.get(1).unwrap().as_str();
            let addr = matches.get(2).unwrap().as_str();
            match typ{
                "tcp" => Ok(ResolvedBusAddr::Tcp(addr.to_string())),
                "unix" => Ok(ResolvedBusAddr::Unix(PathBuf::from(addr))),
                "name" => {
                    let config = parse_config()?;
                    let addr = config.into_iter().find(|(name, _)| name == addr).ok_or(Error(format!("Name is not registered: {addr}")))?;
                    Ok(addr.1)
                },
                _ => Err(Error("Invalid address type".into()))
            }
        },
        None => {
            let mut config = parse_config()?;
            if config.is_empty() {
                return Err(Error("No default address found in config. Use the register command to add one or specify an address".into()));
            }
            let addr = config.remove(0).1;
            Ok(addr)
        }
    }
}

fn parse_config() -> Result<Vec<(String, ResolvedBusAddr)>, Error>{
    let home = match std::env::var("HOME") {
        Ok(home) => home,
        Err(_) => return Ok(Vec::new()),
    };

    let path = PathBuf::from(home).join(".t2");
    if !path.exists() {
        return Ok(Vec::new());
    }
    
    let config = std::fs::read_to_string(path).unwrap();
    
    let addrs= regex::Regex::new(BUS_ADDR_CONFIG_RGX)
        .unwrap()
        .captures_iter(&config)
        .map(|m| {
            let name = m.get(1).unwrap().as_str();
            let addr = m.get(3).unwrap().as_str();
            let addr_type = m.get(2).unwrap().as_str();
            let addr = match addr_type{
                "tcp" => ResolvedBusAddr::Tcp(addr.to_string()),
                "unix" => ResolvedBusAddr::Unix(PathBuf::from(addr)),
                _ => panic!()
            };
            (name.to_string(), addr)
        })
        .collect::<Vec<_>>();

    Ok(addrs)
}

/// Payload on `f32/` topics: a single `f32`.
/// `pub` CBOR-encodes it and `sub` decodes it.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct F32Protocol(f32);

impl PublishProtocol for F32Protocol {
    /// Protocol prefix reported to `t2_bus` for this message type.
    fn prefix() -> &'static str {
        const PREFIX: &str = "f32";
        PREFIX
    }
}

/// Payload on `string/` topics: a single `String`.
/// `pub` CBOR-encodes it and `sub` decodes it.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct StringProtocol(String);

impl PublishProtocol for StringProtocol {
    /// Protocol prefix reported to `t2_bus` for this message type.
    fn prefix() -> &'static str {
        const PREFIX: &str = "string";
        PREFIX
    }
}