dispatch-proxy 0.2.4

A SOCKS proxy that balances traffic between network interfaces.
use std::{
    collections::HashMap,
    fmt::{Display, Formatter},
    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
    num::NonZeroUsize,
    str::FromStr,
    sync::Arc,
};

use color_eyre::Help;
use eyre::{Context, Result};
use network_interface::NetworkInterfaceConfig;
use tokio::sync::Mutex;
use tracing::instrument;

use crate::net::get_valid_addresses;

use super::Dispatch;

#[derive(Clone, Debug)]
pub struct RawWeightedAddress {
    interface: RawInterface,
    weight: NonZeroUsize,
}

impl FromStr for RawWeightedAddress {
    type Err = eyre::Report;

    fn from_str(src: &str) -> Result<Self> {
        let mut items = src.split('/');

        let interface: RawInterface = items.next().unwrap().parse()?;

        let weight = match items.next() {
            Some(priority) => priority.parse()?,
            None => NonZeroUsize::new(1).unwrap(),
        };

        Ok(RawWeightedAddress { interface, weight })
    }
}

/// An interface name or IP address kept exactly as it was given, before any
/// resolution against the host's interfaces.
#[derive(Clone, Debug)]
pub struct RawInterface(String);

impl RawInterface {
    /// Borrows the unresolved specification text.
    pub fn as_str(&self) -> &str {
        let RawInterface(raw) = self;
        raw
    }
}

impl FromStr for RawInterface {
    type Err = eyre::Report;

    fn from_str(src: &str) -> Result<Self> {
        Ok(RawInterface(src.to_string()))
    }
}

impl Display for RawInterface {
    /// Writes the raw text, honouring width, fill, alignment and precision
    /// flags the same way `str` does.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.pad(self.0.as_str())
    }
}

/// A resolved local endpoint whose address(es) the dispatcher hands out.
#[derive(Clone, Debug)]
pub enum Interface {
    /// A network interface looked up by name, together with the IPv4 and/or
    /// IPv6 address selected from it.
    Named {
        name: String,
        ipv4: Option<Ipv4Addr>,
        ipv6: Option<Ipv6Addr>,
    },
    /// A literal IP address parsed from the specification string.
    Ip(IpAddr),
}

/// A resolved local endpoint paired with its round-robin weight.
#[derive(Clone, Debug)]
pub struct WeightedAddress {
    interface: Interface,
    // Number of consecutive dispatches each of this endpoint's addresses gets
    // per round-robin turn.
    weight: NonZeroUsize,
}

impl WeightedAddress {
    pub fn resolve(addresses: Vec<RawWeightedAddress>) -> Result<Vec<WeightedAddress>> {
        let interfaces = network_interface::NetworkInterface::show()?;
        let interfaces_by_name = interfaces
            .iter()
            .map(|interface| (interface.name.as_str(), interface))
            .collect::<HashMap<_, _>>();

        let mut resolved = Vec::with_capacity(addresses.len());

        'interfaces: for RawWeightedAddress { interface, weight } in addresses {
            if let Some(net_interface) = interfaces_by_name.get(interface.as_str()) {
                let mut ipv4_addrs = vec![];
                let mut ipv6_addrs = vec![];

                let addresses = get_valid_addresses(&net_interface.addr);

                for addr in addresses {
                    match addr {
                        IpAddr::V4(v4) => {
                            ipv4_addrs.push(v4);
                        }
                        IpAddr::V6(v6) => {
                            ipv6_addrs.push(v6);
                        }
                    }
                }

                let ipv4 = if let Some(ipv4) = ipv4_addrs.into_iter().next() {
                    if ipv4.is_loopback() {
                        return Err(eyre::eyre!(
                            "Local address `{}` is a loopback address",
                            ipv4
                        ));
                    }
                    Some(ipv4)
                } else {
                    None
                };

                let ipv6 = if let Some(ipv6) = ipv6_addrs.into_iter().next() {
                    if ipv6.is_loopback() {
                        return Err(eyre::eyre!(
                            "Local address `{}` is a loopback address",
                            ipv6
                        ));
                    }
                    Some(ipv6)
                } else {
                    None
                };

                if ipv4.is_none() && ipv6.is_none() {
                    return Err(eyre::eyre!(
                        "No IP addresses found for network interface `{}`",
                        net_interface.name
                    ));
                }

                resolved.push(WeightedAddress {
                    interface: Interface::Named {
                        name: net_interface.name.clone(),
                        ipv4,
                        ipv6,
                    },
                    weight,
                });
                continue 'interfaces;
            }

            let ip: IpAddr = interface.as_str().parse().with_context(|| {
                format!(
                    "Failed to parse `{}` as an IP address or network interface name",
                    interface.as_str()
                )
            })?;

            resolved.push(WeightedAddress {
                interface: Interface::Ip(ip),
                weight,
            });
        }

        Ok(resolved)
    }
}

impl Display for WeightedAddress {
    /// Formats a named interface as `name/weight`, followed by each selected
    /// address in parentheses. Formats a literal IP as `ip/weight`.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
        let weight = self.weight;
        let (ipv4, ipv6) = match &self.interface {
            Interface::Ip(ip) => return write!(f, "{ip}/{weight}"),
            Interface::Named { name, ipv4, ipv6 } => {
                write!(f, "{name}/{weight}")?;
                (ipv4, ipv6)
            }
        };

        if let Some(v4) = ipv4 {
            write!(f, " ({v4})")?;
        }
        if let Some(v6) = ipv6 {
            write!(f, " ({v6})")?;
        }
        Ok(())
    }
}

/// A single local IP address together with its round-robin weight.
#[derive(Clone, Debug)]
pub struct WeightedIp {
    ip: IpAddr,
    // Number of consecutive dispatches this address receives before the
    // rotation moves on to the next address.
    weight: NonZeroUsize,
}

/// Mutable dispatcher state. It keeps separate round-robin rotations for
/// IPv4 and IPv6 remote addresses.
#[derive(Debug)]
struct WeightedRoundRobinDispatcherInner {
    // Rotation used for connections to IPv4 remote addresses.
    ipv4: State,
    // Rotation used for connections to IPv6 remote addresses.
    ipv6: State,
}

/// Round-robin cursor over the local addresses of one IP family.
#[derive(Debug)]
struct State {
    // Candidate local addresses, in the order the addresses were supplied.
    ips: Vec<WeightedIp>,
    // Index into `ips` of the address currently being handed out.
    ip_idx: usize,
    // How many times `ips[ip_idx]` has been handed out in its current turn;
    // reset to 0 when it reaches that address's weight.
    count: usize,
}

impl WeightedRoundRobinDispatcherInner {
    fn new(addresses: Vec<WeightedAddress>) -> WeightedRoundRobinDispatcherInner {
        debug_assert!(
            !addresses.is_empty(),
            "dispatcher should have at least one address"
        );

        let mut ipv4s = vec![];
        let mut ipv6s = vec![];

        for address in addresses {
            match address.interface {
                Interface::Named { ipv4, ipv6, .. } => {
                    if let Some(ipv4) = ipv4 {
                        ipv4s.push(WeightedIp {
                            ip: IpAddr::V4(ipv4),
                            weight: address.weight,
                        });
                    }
                    if let Some(ipv6) = ipv6 {
                        ipv6s.push(WeightedIp {
                            ip: IpAddr::V6(ipv6),
                            weight: address.weight,
                        });
                    }
                }
                Interface::Ip(ip) => match ip {
                    IpAddr::V4(v4) => ipv4s.push(WeightedIp {
                        ip: IpAddr::V4(v4),
                        weight: address.weight,
                    }),
                    IpAddr::V6(v6) => ipv6s.push(WeightedIp {
                        ip: IpAddr::V6(v6),
                        weight: address.weight,
                    }),
                },
            }
        }

        WeightedRoundRobinDispatcherInner {
            ipv4: State {
                ips: ipv4s,
                ip_idx: 0,
                count: 0,
            },
            ipv6: State {
                ips: ipv6s,
                ip_idx: 0,
                count: 0,
            },
        }
    }

    fn dispatch(&mut self, remote_addr: &SocketAddr) -> Result<IpAddr> {
        let state = self.select_state(remote_addr)?;

        let ip = &state.ips[state.ip_idx];

        state.count += 1;
        if state.count == usize::from(ip.weight) {
            state.count = 0;
            state.ip_idx = (state.ip_idx + 1) % state.ips.len();
        }

        Ok(ip.ip)
    }

    fn select_state(&mut self, remote_addr: &SocketAddr) -> Result<&mut State> {
        let state = match remote_addr.ip() {
            IpAddr::V4(_) => &mut self.ipv4,
            IpAddr::V6(_) => &mut self.ipv6,
        };

        if state.ips.is_empty() {
            return Err(eyre::eyre!(
                "Address type mismatch: no configured local address or interface can connect to \
                remote address `{}` ({}) because the address types are incompatible",
                remote_addr,
                addr_type(remote_addr.ip())
            )
            .suggestion(format!(
                "Please ensure that the local addresses or network interfaces you have \
                configured support {}",
                addr_type(remote_addr.ip())
            ))
            .suggestion(
                "As a last resort, you can try to disable IPv6 support in the settings of your main \
                network interface to force your OS to use IPv4 everywhere",
            ));
        }

        Ok(state)
    }
}

/// Cloneable handle to a weighted round-robin dispatcher.
///
/// Clones share one underlying state through an `Arc`, guarded by an async
/// `Mutex`.
#[derive(Debug, Clone)]
pub struct WeightedRoundRobinDispatcher(Arc<Mutex<WeightedRoundRobinDispatcherInner>>);

impl WeightedRoundRobinDispatcher {
    /// Builds a dispatcher that rotates over `addresses`.
    pub fn new(addresses: Vec<WeightedAddress>) -> WeightedRoundRobinDispatcher {
        let inner = WeightedRoundRobinDispatcherInner::new(addresses);
        let shared = Arc::new(Mutex::new(inner));
        Self(shared)
    }
}

#[async_trait::async_trait]
impl Dispatch for WeightedRoundRobinDispatcher {
    /// Picks the local address to use for a connection to `remote_addr`.
    ///
    /// `self` is skipped by `#[instrument]`. Recording it would Debug-format
    /// the whole dispatcher state (every configured IP plus the rotation
    /// cursors) into the span on every call, and tokio's `Mutex` Debug impl
    /// briefly `try_lock`s the state to do so.
    #[instrument(skip(self))]
    async fn dispatch(&self, remote_addr: &SocketAddr) -> Result<IpAddr> {
        let mut dispatcher = self.0.lock().await;
        dispatcher.dispatch(remote_addr)
    }
}

/// Human-readable name of the address family of `addr`, used in error
/// messages.
fn addr_type(addr: IpAddr) -> &'static str {
    if addr.is_ipv4() {
        "IPv4"
    } else {
        "IPv6"
    }
}