use std::collections::HashMap;
use std::path::Path;
use config::ConfigError;
use serde::Deserialize;
use serde::Serialize;
use serde_repr::Deserialize_repr;
use serde_repr::Serialize_repr;
use self::inbound::*;
use self::outbound::*;
use core::fmt;
use std::net::SocketAddr;
use std::str::FromStr;
impl Config {
    /// Loads the configuration from the file at `path`, then adds a source for
    /// environment variables prefixed with `NEORAY`.
    ///
    /// The environment source is added after the file source; in config-rs,
    /// later sources take precedence over earlier ones.
    ///
    /// # Errors
    ///
    /// Returns a [`ConfigError`] if building the sources fails (for example the
    /// file is missing or malformed), or if the merged values cannot be
    /// deserialized into [`Config`].
    pub fn read(path: &Path) -> Result<Self, ConfigError> {
        config::Config::builder()
            // Hand the `Path` over directly: `path.display()` is a lossy,
            // display-only conversion that would mangle non-UTF-8 paths.
            .add_source(config::File::from(path))
            .add_source(config::Environment::with_prefix("NEORAY"))
            .build()?
            .try_deserialize()
    }
}
/// Top-level configuration document.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct Config {
    /// Configuration format version.
    pub version: Version,
    /// Flows, keyed by a string identifier.
    pub flows: HashMap<String, Flow>,
}
/// Configuration format version.
///
/// Serialized as its `u32` discriminant (so `V1` is written as `1`) via
/// `serde_repr`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize_repr, Deserialize_repr)]
#[repr(u32)]
pub enum Version {
    /// Format version 1, currently the only one defined.
    V1 = 1,
}
/// A single flow: its outbound pools together with its inbounds.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct Flow {
    /// Pools of outbounds belonging to this flow.
    pub outbound_pools: Vec<OutboundPool>,
    /// Inbounds belonging to this flow.
    pub inbounds: Vec<Inbound>,
}
/// Authentication settings, generic over the credential type `T`.
///
/// Variant names are serialized in snake_case, so `Const` appears as `const`.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Authentication<T> {
    /// A fixed list of credential entries.
    Const(Vec<T>),
}
/// Username/password credentials; either field may be omitted.
///
/// `Debug` is implemented by hand so `{:?}` output (for example when a whole
/// configuration is debug-printed) never contains the password: a present
/// password is shown as `Some("<redacted>")`.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserPassAuth {
    pub username: Option<String>,
    pub password: Option<String>,
}

impl std::fmt::Debug for UserPassAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UserPassAuth")
            .field("username", &self.username)
            // Only reveal whether a password is set, never its value.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
/// A network endpoint, given either as `hostname:port` or as a literal IP
/// socket address.
///
/// Every field type is `Eq + Hash`, so addresses can be used as set or map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SocketAddress {
    /// A host name plus port, kept unresolved.
    Hostname { hostname: String, port: u16 },
    /// A literal IPv4 or IPv6 socket address.
    Ip(SocketAddr),
}
pub mod outbound {
use super::Authentication;
use super::SocketAddress;
use super::UserPassAuth;
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct OutboundPool {
pub outbounds: Vec<Outbound>,
}
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum Outbound {
Http {
address: SocketAddress,
auth: Option<Authentication<UserPassAuth>>,
},
Socks {
address: SocketAddress,
auth: Option<Authentication<UserPassAuth>>,
},
}
}
pub mod inbound {
use super::Authentication;
use super::SocketAddress;
use super::UserPassAuth;
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum Inbound {
Http {
address: SocketAddress,
auth: Option<Authentication<UserPassAuth>>,
},
Socks {
address: SocketAddress,
auth: Option<Authentication<UserPassAuth>>,
},
}
}
#[cfg(test)]
mod tests {
    use super::SocketAddress;
    use std::{net::SocketAddr, str::FromStr};

    /// Each input must parse to the expected result, and every successful
    /// parse must print back to exactly the original input.
    #[test]
    fn socket_addr_can_be_parsed_and_printed() {
        let cases = [
            (
                "myserver.com:80",
                Ok(SocketAddress::Hostname {
                    hostname: "myserver.com".into(),
                    port: 80,
                }),
            ),
            (
                "localhost:80",
                Ok(SocketAddress::Hostname {
                    hostname: "localhost".into(),
                    port: 80,
                }),
            ),
            (
                "localhost",
                Err("expected `hostname:port`, but port not found".into()),
            ),
            ("10:localhost", Err("invalid socket address syntax".into())),
            ("10.10.20.30", Err("invalid socket address syntax".into())),
            (
                "10.10.20.30:1020",
                Ok(SocketAddress::Ip(
                    SocketAddr::from_str("10.10.20.30:1020").unwrap(),
                )),
            ),
            // Edge cases: empty input, empty port, out-of-range port, IPv6.
            ("", Err("empty socket address".into())),
            ("localhost:", Err("invalid port".into())),
            ("localhost:65536", Err("invalid port".into())),
            (
                "[::1]:8080",
                Ok(SocketAddress::Ip(
                    SocketAddr::from_str("[::1]:8080").unwrap(),
                )),
            ),
        ];
        for (input, expected) in cases {
            assert_eq!(expected, SocketAddress::from_str(input));
            if let Ok(addr) = expected {
                assert_eq!(addr.to_string(), input);
            }
        }
    }
}
impl Serialize for SocketAddress {
    /// Emits the address as a single string in its `Display` form.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.collect_str(self)
    }
}
impl<'de> Deserialize<'de> for SocketAddress {
    /// Reads a string and parses it with `SocketAddress`'s `FromStr` impl, so
    /// the accepted syntax and error messages are the same as for `from_str`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Visitor that only accepts string input.
        struct AddrVisitor;

        impl<'de> serde::de::Visitor<'de> for AddrVisitor {
            type Value = SocketAddress;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("ip:port or hostname:port")
            }

            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                match v.parse::<SocketAddress>() {
                    Ok(addr) => Ok(addr),
                    Err(msg) => Err(E::custom(msg)),
                }
            }
        }

        deserializer.deserialize_str(AddrVisitor)
    }
}
impl std::fmt::Display for SocketAddress {
    /// Writes `hostname:port` for host names; IP addresses use
    /// [`SocketAddr`]'s own `Display`, so IPv6 is bracketed (`[::1]:80`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SocketAddress::Hostname { hostname, port } => write!(f, "{hostname}:{port}"),
            // Name the trait explicitly: a bare `socket_addr.fmt(f)` only
            // resolves if a fmt trait is imported, and would silently use
            // `Debug` if that were the only one in scope.
            SocketAddress::Ip(socket_addr) => std::fmt::Display::fmt(socket_addr, f),
        }
    }
}
/// Parses either `hostname:port` or a literal IP socket address.
///
/// Input whose first character is alphabetic is treated as `hostname:port`,
/// split at the first `:`. Anything else, including bracketed IPv6 such as
/// `[::1]:80`, is handed to [`SocketAddr`]'s parser.
///
/// # Errors
///
/// Returns a human-readable message if:
/// - the input is empty;
/// - a hostname has no `:port` suffix;
/// - the port is not a plain decimal `u16`;
/// - the IP socket address is malformed (this message comes from
///   [`SocketAddr`]'s parser).
impl FromStr for SocketAddress {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let Some(first) = s.chars().next() else {
            return Err("empty socket address".into());
        };
        if !first.is_alphabetic() {
            return SocketAddr::from_str(s)
                .map(Self::Ip)
                .map_err(|e| e.to_string());
        }
        let Some((hostname, port_text)) = s.split_once(':') else {
            return Err("expected `hostname:port`, but port not found".into());
        };
        // `u16::from_str` also accepts a leading `+`, which `SocketAddr`'s
        // parser rejects; require plain digits so both branches agree.
        let port = match port_text.parse::<u16>() {
            Ok(port) if port_text.bytes().all(|b| b.is_ascii_digit()) => port,
            _ => return Err("invalid port".into()),
        };
        Ok(Self::Hostname {
            hostname: hostname.to_owned(),
            port,
        })
    }
}