use crate::utils::{CongestionControl, UdpRelayMode};
use humantime::Duration as HumanDuration;
use lexopt::{Arg, Error as ArgumentError, Parser};
use log::LevelFilter;
use serde::{de::Error as DeError, Deserialize, Deserializer};
use serde_json::Error as SerdeError;
use std::{
env::ArgsOs,
fmt::Display,
fs::File,
io::Error as IoError,
net::{IpAddr, SocketAddr},
path::PathBuf,
str::FromStr,
sync::Arc,
time::Duration,
};
use thiserror::Error;
use uuid::Uuid;
const HELP_MSG: &str = r#"
Usage tuic-client [arguments]
Arguments:
-c, --config <path> Path to the config file (required)
-v, --version Print the version
-h, --help Print this help message
"#;
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Config {
pub relay: Relay,
pub local: Local,
#[serde(default = "default::log_level")]
pub log_level: LevelFilter,
}
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Relay {
#[serde(deserialize_with = "deserialize_server")]
pub server: (String, u16),
pub uuid: Uuid,
#[serde(deserialize_with = "deserialize_password")]
pub password: Arc<[u8]>,
pub ip: Option<IpAddr>,
#[serde(default = "default::relay::certificates")]
pub certificates: Vec<PathBuf>,
#[serde(
default = "default::relay::udp_relay_mode",
deserialize_with = "deserialize_from_str"
)]
pub udp_relay_mode: UdpRelayMode,
#[serde(
default = "default::relay::congestion_control",
deserialize_with = "deserialize_from_str"
)]
pub congestion_control: CongestionControl,
#[serde(
default = "default::relay::alpn",
deserialize_with = "deserialize_alpn"
)]
pub alpn: Vec<Vec<u8>>,
#[serde(default = "default::relay::zero_rtt_handshake")]
pub zero_rtt_handshake: bool,
#[serde(default = "default::relay::disable_sni")]
pub disable_sni: bool,
#[serde(
default = "default::relay::timeout",
deserialize_with = "deserialize_duration"
)]
pub timeout: Duration,
#[serde(
default = "default::relay::heartbeat",
deserialize_with = "deserialize_duration"
)]
pub heartbeat: Duration,
#[serde(default = "default::relay::disable_native_certs")]
pub disable_native_certs: bool,
#[serde(default = "default::relay::send_window")]
pub send_window: u64,
#[serde(default = "default::relay::receive_window")]
pub receive_window: u32,
#[serde(
default = "default::relay::gc_interval",
deserialize_with = "deserialize_duration"
)]
pub gc_interval: Duration,
#[serde(
default = "default::relay::gc_lifetime",
deserialize_with = "deserialize_duration"
)]
pub gc_lifetime: Duration,
}
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Local {
pub server: SocketAddr,
#[serde(deserialize_with = "deserialize_optional_bytes", default)]
pub username: Option<Vec<u8>>,
#[serde(deserialize_with = "deserialize_optional_bytes", default)]
pub password: Option<Vec<u8>>,
pub dual_stack: Option<bool>,
#[serde(default = "default::local::max_packet_size")]
pub max_packet_size: usize,
}
impl Config {
pub fn parse(args: ArgsOs) -> Result<Self, ConfigError> {
let mut parser = Parser::from_iter(args);
let mut path = None;
while let Some(arg) = parser.next()? {
match arg {
Arg::Short('c') | Arg::Long("config") => {
if path.is_none() {
path = Some(parser.value()?);
} else {
return Err(ConfigError::Argument(arg.unexpected()));
}
}
Arg::Short('v') | Arg::Long("version") => {
return Err(ConfigError::Version(env!("CARGO_PKG_VERSION")))
}
Arg::Short('h') | Arg::Long("help") => return Err(ConfigError::Help(HELP_MSG)),
_ => return Err(ConfigError::Argument(arg.unexpected())),
}
}
if path.is_none() {
return Err(ConfigError::NoConfig);
}
let file = File::open(path.unwrap())?;
Ok(serde_json::from_reader(file)?)
}
}
mod default {
use log::LevelFilter;
pub mod relay {
use crate::utils::{CongestionControl, UdpRelayMode};
use std::{path::PathBuf, time::Duration};
pub fn certificates() -> Vec<PathBuf> {
Vec::new()
}
pub fn udp_relay_mode() -> UdpRelayMode {
UdpRelayMode::Native
}
pub fn congestion_control() -> CongestionControl {
CongestionControl::Cubic
}
pub fn alpn() -> Vec<Vec<u8>> {
Vec::new()
}
pub fn zero_rtt_handshake() -> bool {
false
}
pub fn disable_sni() -> bool {
false
}
pub fn timeout() -> Duration {
Duration::from_secs(8)
}
pub fn heartbeat() -> Duration {
Duration::from_secs(3)
}
pub fn disable_native_certs() -> bool {
false
}
pub fn send_window() -> u64 {
8 * 1024 * 1024 * 2
}
pub fn receive_window() -> u32 {
8 * 1024 * 1024
}
pub fn gc_interval() -> Duration {
Duration::from_secs(3)
}
pub fn gc_lifetime() -> Duration {
Duration::from_secs(15)
}
}
pub mod local {
pub fn max_packet_size() -> usize {
1500
}
}
pub fn log_level() -> LevelFilter {
LevelFilter::Warn
}
}
pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result<T, D::Error>
where
T: FromStr,
<T as FromStr>::Err: Display,
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
T::from_str(&s).map_err(DeError::custom)
}
pub fn deserialize_server<'de, D>(deserializer: D) -> Result<(String, u16), D::Error>
where
D: Deserializer<'de>,
{
let mut s = String::deserialize(deserializer)?;
let (domain, port) = s
.rsplit_once(':')
.ok_or(DeError::custom("invalid server address"))?;
let port = port.parse().map_err(DeError::custom)?;
s.truncate(domain.len());
Ok((s, port))
}
pub fn deserialize_password<'de, D>(deserializer: D) -> Result<Arc<[u8]>, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
Ok(Arc::from(s.into_bytes().into_boxed_slice()))
}
pub fn deserialize_alpn<'de, D>(deserializer: D) -> Result<Vec<Vec<u8>>, D::Error>
where
D: Deserializer<'de>,
{
let s = Vec::<String>::deserialize(deserializer)?;
Ok(s.into_iter().map(|alpn| alpn.into_bytes()).collect())
}
pub fn deserialize_optional_bytes<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
Ok(Some(s.into_bytes()))
}
pub fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
s.parse::<HumanDuration>()
.map(|d| *d)
.map_err(DeError::custom)
}
#[derive(Debug, Error)]
pub enum ConfigError {
#[error(transparent)]
Argument(#[from] ArgumentError),
#[error("no config file specified")]
NoConfig,
#[error("{0}")]
Version(&'static str),
#[error("{0}")]
Help(&'static str),
#[error(transparent)]
Io(#[from] IoError),
#[error(transparent)]
Serde(#[from] SerdeError),
}