mcrx-core 0.2.3

Runtime-agnostic and portable multicast receiver library for IPv4 and IPv6 ASM/SSM.
Documentation
use std::ffi::CString;
use std::net::IpAddr;

/// Options extracted from a receiver command line by `parse_receive_cli_args`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct ReceiveCliArgs {
    /// Group address to receive from; `parse_receive_cli_args` only accepts multicast groups.
    pub(crate) group: IpAddr,
    /// Destination port; `parse_receive_cli_args` rejects 0.
    pub(crate) dst_port: u16,
    /// Optional source address (positional or `--source`).
    pub(crate) source: Option<IpAddr>,
    /// Interface address, when the interface was given as an IP address.
    pub(crate) interface: Option<IpAddr>,
    /// Interface index, from an IPv6 `%scope` suffix or a bare numeric index.
    pub(crate) interface_index: Option<u32>,
}

/// Parsed `--interface` value: `(interface address, interface index)`.
type ParsedInterface = (Option<IpAddr>, Option<u32>);
/// Parsed trailing arguments: `(source address, interface address, interface index)`.
type ParsedSourceAndInterface = (Option<IpAddr>, Option<IpAddr>, Option<u32>);

pub(crate) fn parse_receive_cli_args(args: &[String]) -> Result<ReceiveCliArgs, String> {
    if args.len() < 3 {
        return Err("invalid arguments".to_string());
    }

    let group = parse_ip("group", &args[1])?;
    let dst_port = parse_port(&args[2])?;
    let remainder = &args[3..];

    let (source, interface, interface_index) = parse_mixed_args(group, remainder)?;

    if !group.is_multicast() {
        return Err(format!("group address {group} is not multicast"));
    }

    Ok(ReceiveCliArgs {
        group,
        dst_port,
        source,
        interface,
        interface_index,
    })
}

/// Parses a sequence of `--source <ip>` / `--interface <value>` pairs.
///
/// A repeated flag overwrites the earlier value. `--interface` sets both the interface
/// address and the interface index slots from `parse_interface_value`.
///
/// # Errors
///
/// Fails if a flag has no following value, if a value does not parse, or if any token is not
/// one of the two known flags.
fn parse_flag_args(
    group: IpAddr,
    remainder: &[String],
) -> Result<ParsedSourceAndInterface, String> {
    let mut parsed: ParsedSourceAndInterface = (None, None, None);
    let mut tokens = remainder.iter();

    while let Some(flag) = tokens.next() {
        match flag.as_str() {
            "--source" => {
                let Some(value) = tokens.next() else {
                    return Err("missing value after --source".to_string());
                };
                parsed.0 = Some(parse_ip("source", value)?);
            }
            "--interface" => {
                let Some(value) = tokens.next() else {
                    return Err("missing value after --interface".to_string());
                };
                let (address, index) = parse_interface_value(group, value)?;
                parsed.1 = address;
                parsed.2 = index;
            }
            other => return Err(format!("unexpected argument '{other}'")),
        }
    }

    Ok(parsed)
}

/// Parses the optional trailing arguments, accepting `--flag value` pairs and legacy
/// positional `[source] [interface]` operands in any mix.
///
/// Flags take precedence. Positionals then fill whichever of the source and interface slots
/// a flag did not set, in that order. Any positional left over after that is an error.
///
/// # Errors
///
/// Fails if a flag has no value, if any flag or value is rejected by `parse_flag_args`, if a
/// positional value does not parse, or if there are more positionals than free slots.
fn parse_mixed_args(
    group: IpAddr,
    remainder: &[String],
) -> Result<ParsedSourceAndInterface, String> {
    let mut positional: Vec<&str> = Vec::new();
    let mut flagged: Vec<String> = Vec::new();
    let mut tokens = remainder.iter();

    while let Some(token) = tokens.next() {
        if token.starts_with("--") {
            // Every flag consumes the next token as its value, even if it looks like a flag.
            let value = tokens
                .next()
                .ok_or_else(|| format!("missing value after {token}"))?;
            flagged.push(token.clone());
            flagged.push(value.clone());
        } else {
            positional.push(token);
        }
    }

    let (mut source, mut interface, mut interface_index) = parse_flag_args(group, &flagged)?;

    let mut positional = positional.into_iter();

    if source.is_none()
        && let Some(value) = positional.next()
    {
        source = Some(parse_ip("source", value)?);
    }

    // A bare IPv6 index (`--interface 9`) sets only `interface_index`, so the slot counts as
    // taken if either half is set. Otherwise a stray positional would silently override the flag.
    if interface.is_none()
        && interface_index.is_none()
        && let Some(value) = positional.next()
    {
        (interface, interface_index) = parse_interface_value(group, value)?;
    }

    if positional.next().is_some() {
        return Err("invalid arguments".to_string());
    }

    Ok((source, interface, interface_index))
}

/// Parses `value` as an IPv4 or IPv6 address.
///
/// `name` labels the argument in the error message, e.g. `invalid group 'x': ...`.
fn parse_ip(name: &str, value: &str) -> Result<IpAddr, String> {
    match value.parse::<IpAddr>() {
        Ok(address) => Ok(address),
        Err(err) => Err(format!("invalid {name} '{value}': {err}")),
    }
}

/// Interprets an interface argument for the given group.
///
/// For IPv6 groups three forms are accepted:
/// - `addr%scope`: an IPv6 address plus a numeric index or interface name.
/// - a bare all-digit interface index.
/// - a plain IP address.
///
/// For IPv4 groups only a plain IP address is accepted.
fn parse_interface_value(group: IpAddr, value: &str) -> Result<ParsedInterface, String> {
    if !group.is_ipv6() {
        return Ok((Some(parse_ip("interface", value)?), None));
    }

    match value.rsplit_once('%') {
        Some((address, scope)) => {
            let address: std::net::Ipv6Addr = address
                .parse()
                .map_err(|err| format!("invalid interface '{value}': {err}"))?;
            Ok((Some(IpAddr::V6(address)), Some(parse_interface_scope(scope)?)))
        }
        None if value.bytes().all(|byte| byte.is_ascii_digit()) => {
            Ok((None, Some(parse_interface_scope(value)?)))
        }
        None => Ok((Some(parse_ip("interface", value)?), None)),
    }
}

/// Resolves an IPv6 scope to an interface index.
///
/// An all-digit scope is parsed as an index, and 0 is rejected. Anything else is treated as
/// an interface name and looked up with `interface_name_to_index`.
fn parse_interface_scope(value: &str) -> Result<u32, String> {
    let is_numeric = value.chars().all(|ch| ch.is_ascii_digit());
    if !is_numeric {
        return interface_name_to_index(value)
            .map_err(|err| format!("invalid interface scope '{value}': {err}"));
    }

    match value.parse::<u32>() {
        Ok(0) => Err("interface index must not be 0".to_string()),
        Ok(index) => Ok(index),
        Err(err) => Err(format!("invalid interface index '{value}': {err}")),
    }
}

/// Resolves an interface name (for example `lo` or `lo0`) to its OS interface index using
/// `if_nametoindex`.
///
/// # Errors
///
/// Fails if `name` contains a NUL byte, or if `if_nametoindex` returns 0.
fn interface_name_to_index(name: &str) -> Result<u32, String> {
    let name =
        CString::new(name).map_err(|_| "interface name must not contain NUL bytes".to_string())?;

    // SAFETY: `name` is a NUL-terminated C string that stays alive for the duration of the
    // call, and `if_nametoindex` only reads through the pointer.
    #[cfg(windows)]
    let index = unsafe {
        windows_sys::Win32::NetworkManagement::IpHelper::if_nametoindex(name.as_ptr().cast())
    };

    // SAFETY: same invariant as the Windows branch; `name` is a live NUL-terminated string
    // that `if_nametoindex` only reads.
    #[cfg(not(windows))]
    let index = unsafe { libc::if_nametoindex(name.as_ptr()) };

    // `if_nametoindex` signals failure by returning 0, which is never a valid index.
    if index == 0 {
        return Err("unknown interface name".to_string());
    }
    Ok(index)
}

/// Parses a destination port, rejecting values that are not a `u16` as well as port 0.
fn parse_port(value: &str) -> Result<u16, String> {
    match value.parse::<u16>() {
        Ok(0) => Err("dst_port must not be 0".to_string()),
        Ok(port) => Ok(port),
        Err(err) => Err(format!("invalid dst_port '{value}': {err}")),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    fn argv(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|part| (*part).to_string()).collect()
    }

    #[test]
    fn parses_legacy_positional_asm_args() {
        let args = argv(&["mcrx-recv", "239.1.2.3", "5000"]);

        let parsed = parse_receive_cli_args(&args).unwrap();

        assert_eq!(parsed.group, IpAddr::V4(Ipv4Addr::new(239, 1, 2, 3)));
        assert_eq!(parsed.dst_port, 5000);
        assert_eq!(parsed.source, None);
        assert_eq!(parsed.interface, None);
        assert_eq!(parsed.interface_index, None);
    }

    #[test]
    fn parses_legacy_positional_ssm_args() {
        let args = argv(&[
            "mcrx-recv",
            "232.1.2.3",
            "5000",
            "192.168.1.10",
            "192.168.1.20",
        ]);

        let parsed = parse_receive_cli_args(&args).unwrap();

        assert_eq!(
            parsed.source,
            Some(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 10)))
        );
        assert_eq!(
            parsed.interface,
            Some(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 20)))
        );
        assert_eq!(parsed.interface_index, None);
    }

    #[test]
    fn parses_flagged_interface_for_ipv6_asm() {
        let args = argv(&["mcrx-recv-meta", "ff01::1234", "5000", "--interface", "::1"]);

        let parsed = parse_receive_cli_args(&args).unwrap();

        assert_eq!(
            parsed.group,
            IpAddr::V6("ff01::1234".parse::<Ipv6Addr>().unwrap())
        );
        assert_eq!(parsed.dst_port, 5000);
        assert_eq!(parsed.source, None);
        assert_eq!(parsed.interface, Some(IpAddr::V6(Ipv6Addr::LOCALHOST)));
        assert_eq!(parsed.interface_index, None);
    }

    #[test]
    fn parses_flagged_source_and_interface() {
        let args = argv(&[
            "mcrx-recv",
            "232.1.2.3",
            "5000",
            "--source",
            "192.168.1.10",
            "--interface",
            "192.168.1.20",
        ]);

        let parsed = parse_receive_cli_args(&args).unwrap();

        assert_eq!(
            parsed.source,
            Some(IpAddr::V4("192.168.1.10".parse::<Ipv4Addr>().unwrap()))
        );
        assert_eq!(
            parsed.interface,
            Some(IpAddr::V4("192.168.1.20".parse::<Ipv4Addr>().unwrap()))
        );
        assert_eq!(parsed.interface_index, None);
    }

    #[test]
    fn parses_positional_source_with_flagged_interface() {
        let args = argv(&[
            "mcrx-recv-meta",
            "ff12::1234",
            "5000",
            "fd06:ba51:f296:0:1caf:6b66:e6f7:4b10",
            "--interface",
            "fd06:ba51:f296:0:1caf:6b66:e6f7:4b10",
        ]);

        let parsed = parse_receive_cli_args(&args).unwrap();

        assert_eq!(
            parsed.source,
            Some(IpAddr::V6(
                "fd06:ba51:f296:0:1caf:6b66:e6f7:4b10"
                    .parse::<Ipv6Addr>()
                    .unwrap()
            ))
        );
        assert_eq!(
            parsed.interface,
            Some(IpAddr::V6(
                "fd06:ba51:f296:0:1caf:6b66:e6f7:4b10"
                    .parse::<Ipv6Addr>()
                    .unwrap()
            ))
        );
        assert_eq!(parsed.interface_index, None);
    }

    #[test]
    fn parses_scoped_ipv6_interface() {
        let args = argv(&[
            "mcrx-recv-meta",
            "ff32::8000:1234",
            "5000",
            "--interface",
            "fe80::1%7",
        ]);

        let parsed = parse_receive_cli_args(&args).unwrap();

        assert_eq!(
            parsed.interface,
            Some(IpAddr::V6("fe80::1".parse().unwrap()))
        );
        assert_eq!(parsed.interface_index, Some(7));
    }

    #[cfg(any(
        target_os = "macos",
        target_os = "linux",
        target_os = "android",
        target_os = "freebsd",
        target_os = "openbsd",
        target_os = "netbsd",
        target_os = "dragonfly"
    ))]
    #[test]
    fn parses_scoped_ipv6_interface_with_name() {
        #[cfg(any(
            target_os = "macos",
            target_os = "freebsd",
            target_os = "openbsd",
            target_os = "netbsd",
            target_os = "dragonfly"
        ))]
        const LOOPBACK_INTERFACE: &str = "lo0";
        #[cfg(any(target_os = "linux", target_os = "android"))]
        const LOOPBACK_INTERFACE: &str = "lo";

        let scoped_interface = format!("fe80::1%{LOOPBACK_INTERFACE}");
        let args = argv(&[
            "mcrx-recv-meta",
            "ff32::8000:1234",
            "5000",
            "--interface",
            &scoped_interface,
        ]);

        let parsed = parse_receive_cli_args(&args).unwrap();

        assert_eq!(
            parsed.interface,
            Some(IpAddr::V6("fe80::1".parse().unwrap()))
        );
        assert!(parsed.interface_index.unwrap() > 0);
    }

    #[test]
    fn parses_numeric_ipv6_interface_index() {
        let args = argv(&[
            "mcrx-recv-meta",
            "ff3e::8000:1234",
            "5000",
            "--interface",
            "9",
        ]);

        let parsed = parse_receive_cli_args(&args).unwrap();

        assert_eq!(parsed.interface, None);
        assert_eq!(parsed.interface_index, Some(9));
    }

    // Error paths: each malformed command line must be rejected with a specific message.

    #[test]
    fn rejects_missing_operands() {
        let args = argv(&["mcrx-recv", "239.1.2.3"]);

        assert_eq!(
            parse_receive_cli_args(&args),
            Err("invalid arguments".to_string())
        );
    }

    #[test]
    fn rejects_non_multicast_group() {
        let args = argv(&["mcrx-recv", "192.168.1.1", "5000"]);

        assert_eq!(
            parse_receive_cli_args(&args),
            Err("group address 192.168.1.1 is not multicast".to_string())
        );
    }

    #[test]
    fn rejects_zero_port() {
        let args = argv(&["mcrx-recv", "239.1.2.3", "0"]);

        assert_eq!(
            parse_receive_cli_args(&args),
            Err("dst_port must not be 0".to_string())
        );
    }

    #[test]
    fn rejects_flag_without_value() {
        let args = argv(&["mcrx-recv", "232.1.2.3", "5000", "--source"]);

        assert_eq!(
            parse_receive_cli_args(&args),
            Err("missing value after --source".to_string())
        );
    }

    #[test]
    fn rejects_unknown_flag() {
        let args = argv(&["mcrx-recv", "232.1.2.3", "5000", "--port", "1"]);

        assert_eq!(
            parse_receive_cli_args(&args),
            Err("unexpected argument '--port'".to_string())
        );
    }

    #[test]
    fn rejects_extra_positional_arguments() {
        let args = argv(&[
            "mcrx-recv",
            "232.1.2.3",
            "5000",
            "192.168.1.10",
            "192.168.1.20",
            "192.168.1.30",
        ]);

        assert_eq!(
            parse_receive_cli_args(&args),
            Err("invalid arguments".to_string())
        );
    }

    #[test]
    fn rejects_zero_ipv6_interface_index() {
        let args = argv(&[
            "mcrx-recv-meta",
            "ff3e::8000:1234",
            "5000",
            "--interface",
            "0",
        ]);

        assert_eq!(
            parse_receive_cli_args(&args),
            Err("interface index must not be 0".to_string())
        );
    }
}