//! neoray 0.0.0-alpha-1
//!
//! A proxy service.
use std::collections::HashMap;
use std::path::Path;

use config::ConfigError;
use serde::Deserialize;
use serde::Serialize;
use serde_repr::Deserialize_repr;
use serde_repr::Serialize_repr;

use self::inbound::*;
use self::outbound::*;
use core::fmt;
use std::net::SocketAddr;
use std::str::FromStr;

impl Config {
    /// Loads a [`Config`] by combining two sources, added in this order:
    ///
    /// 1. the file at `path`, handed to [`config::File::with_name`] (so the extension may be
    ///    omitted and resolved to any format the `config` crate supports);
    /// 2. environment variables whose names carry the `NEORAY` prefix.
    ///
    /// In the `config` crate, sources added later take precedence, so environment variables
    /// override values from the file.
    ///
    /// # Errors
    ///
    /// Returns a [`ConfigError`] in two cases:
    /// - building the combined configuration fails;
    /// - the merged values cannot be deserialized into [`Config`].
    pub fn read(path: &Path) -> Result<Self, ConfigError> {
        let file = config::File::with_name(&path.display().to_string());
        let env = config::Environment::with_prefix("NEORAY");

        let merged = config::Config::builder()
            .add_source(file)
            .add_source(env)
            .build()?;

        merged.try_deserialize()
    }
}

/// Top-level configuration, as produced by [`Config::read`].
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct Config {
    /// Version of the configuration format; serialized as a bare integer (see [`Version`]).
    pub version: Version,
    /// Flows, keyed by a string identifier.
    pub flows: HashMap<String, Flow>,
}

/// Configuration format version.
///
/// `serde_repr` (de)serializes it as its `u32` discriminant, so `V1` appears as `1`.
#[derive(Debug, Clone, Serialize_repr, Deserialize_repr, PartialEq)]
#[repr(u32)]
pub enum Version {
    /// Version `1`, the only version currently defined.
    V1 = 1,
}

/// One entry of [`Config::flows`], grouping outbound pools with inbound endpoints.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct Flow {
    /// Outbound pools configured for this flow.
    pub outbound_pools: Vec<OutboundPool>,
    /// Inbound endpoints configured for this flow.
    pub inbounds: Vec<Inbound>,
}

/// Authentication settings, generic over the credential type `T`.
///
/// Variant names are serialized in snake_case with serde's default (external) tagging,
/// e.g. `{ "const": [ ... ] }`.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum Authentication<T> {
    /// A constant list of credentials of type `T`.
    Const(Vec<T>),
}

/// Username/password credentials. Either field may be omitted.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct UserPassAuth {
    /// Username, if provided.
    pub username: Option<String>,
    /// Password, if provided.
    pub password: Option<String>,
}

/// A network address: either a hostname with a port, or a literal IP socket address.
///
/// The textual form is `hostname:port`, `ip:port` or `[ipv6]:port`. It is read through the
/// `FromStr` impl and written through the `Display` impl. The serde impls (de)serialize the
/// address as that same single string.
#[derive(Debug, Clone, PartialEq)]
pub enum SocketAddress {
    /// A host name and port, stored as given (no resolution happens in this type).
    Hostname { hostname: String, port: u16 },
    /// A literal IPv4 or IPv6 socket address.
    Ip(SocketAddr),
}

/// Types describing the outbound side of a flow.
pub mod outbound {
    use super::Authentication;
    use super::SocketAddress;
    use super::UserPassAuth;
    use serde::{Deserialize, Serialize};

    /// A pool of outbound endpoints.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    pub struct OutboundPool {
        /// The outbounds belonging to this pool.
        pub outbounds: Vec<Outbound>,
    }

    /// A single outbound endpoint. Variant tags are serialized in snake_case (`http`, `socks`).
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    #[serde(rename_all = "snake_case")]
    pub enum Outbound {
        /// An `http` outbound.
        Http {
            /// Address of the endpoint.
            address: SocketAddress,
            /// Optional username/password authentication.
            auth: Option<Authentication<UserPassAuth>>,
        },
        /// A `socks` outbound.
        Socks {
            /// Address of the endpoint.
            address: SocketAddress,
            /// Optional username/password authentication.
            auth: Option<Authentication<UserPassAuth>>,
        },
    }
}

/// Types describing the inbound side of a flow.
pub mod inbound {
    use super::Authentication;
    use super::SocketAddress;
    use super::UserPassAuth;
    use serde::{Deserialize, Serialize};

    /// A single inbound endpoint. Variant tags are serialized in snake_case (`http`, `socks`).
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    #[serde(rename_all = "snake_case")]
    pub enum Inbound {
        /// An `http` inbound.
        Http {
            /// Address of the endpoint.
            address: SocketAddress,
            /// Optional username/password authentication.
            auth: Option<Authentication<UserPassAuth>>,
        },
        /// A `socks` inbound.
        Socks {
            /// Address of the endpoint.
            address: SocketAddress,
            /// Optional username/password authentication.
            auth: Option<Authentication<UserPassAuth>>,
        },
    }
}

#[cfg(test)]
mod tests {
    use super::SocketAddress;
    use std::net::SocketAddr;
    use std::str::FromStr;

    /// Every input must parse to its expected result, and every successful parse must
    /// print back to exactly the text it was parsed from.
    #[test]
    fn socket_addr_can_be_parsed_and_printed() {
        let host = |name: &str, port: u16| SocketAddress::Hostname {
            hostname: name.into(),
            port,
        };
        let ip = |text: &str| SocketAddress::Ip(text.parse::<SocketAddr>().unwrap());

        let cases: [(&str, Result<SocketAddress, String>); 6] = [
            ("myserver.com:80", Ok(host("myserver.com", 80))),
            ("localhost:80", Ok(host("localhost", 80))),
            (
                "localhost",
                Err("expected `hostname:port`, but port not found".to_string()),
            ),
            ("10:localhost", Err("invalid socket address syntax".to_string())),
            ("10.10.20.30", Err("invalid socket address syntax".to_string())),
            ("10.10.20.30:1020", Ok(ip("10.10.20.30:1020"))),
        ];

        for (input, expected) in cases {
            let parsed = SocketAddress::from_str(input);
            assert_eq!(expected, parsed, "unexpected parse result for {input:?}");

            if let Ok(addr) = parsed {
                assert_eq!(addr.to_string(), input, "{input:?} did not round-trip");
            }
        }
    }
}

impl Serialize for SocketAddress {
    /// Serializes the address as one string in its `Display` form,
    /// e.g. `myserver.com:80` or `10.10.20.30:1020`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let rendered = self.to_string();
        serializer.serialize_str(&rendered)
    }
}

impl<'de> Deserialize<'de> for SocketAddress {
    /// Deserializes an address from a string, accepting exactly the forms its `FromStr`
    /// impl accepts. A parse failure is reported as the deserializer's custom error, which
    /// carries the parser's message.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Only `visit_str` is implemented. Owned and borrowed strings reach it through
        // serde's default forwarding; any other input kind gets serde's default type error.
        struct AddressStrVisitor;

        impl<'de> serde::de::Visitor<'de> for AddressStrVisitor {
            type Value = SocketAddress;

            fn expecting(&self, out: &mut fmt::Formatter) -> fmt::Result {
                out.write_str("ip:port or hostname:port")
            }

            fn visit_str<E>(self, text: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                text.parse::<SocketAddress>().map_err(E::custom)
            }
        }

        deserializer.deserialize_str(AddressStrVisitor)
    }
}

impl std::fmt::Display for SocketAddress {
    /// Writes `hostname:port` for named hosts. IP addresses are delegated to
    /// [`SocketAddr`]'s own `Display` (e.g. `10.0.0.1:80`, `[::1]:80`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (hostname, port) = match self {
            SocketAddress::Ip(addr) => return std::fmt::Display::fmt(addr, f),
            SocketAddress::Hostname { hostname, port } => (hostname, port),
        };
        write!(f, "{hostname}:{port}")
    }
}

impl FromStr for SocketAddress {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s.is_empty() {
            return Err("empty socket address".into());
        }

        if s.chars().next().unwrap().is_alphabetic() {
            let Some((hostname, port)) = s.split_once(':') else {
                return Err("expected `hostname:port`, but port not found".into());
            };
            let Ok(port) = port.parse::<u16>() else {
                return Err("invalid port".into());
            };
            Ok(Self::Hostname {
                hostname: hostname.to_string(),
                port,
            })
        } else {
            SocketAddr::from_str(&s)
                .map_err(|e| e.to_string())
                .map(Self::Ip)
        }
    }
}