#[cfg(unix)]
use std::path::PathBuf;
use std::{
collections::HashMap,
fmt::{self, Debug, Display},
net::SocketAddr,
str::{self, FromStr},
sync::Arc,
time::Duration,
};
use base64::Engine as _;
use byte_string::ByteStr;
use bytes::Bytes;
use cfg_if::cfg_if;
use log::{error, warn};
use thiserror::Error;
use url::{self, Url};
#[cfg(any(feature = "stream-cipher", feature = "aead-cipher"))]
use crate::crypto::v1::openssl_bytes_to_key;
use crate::{crypto::CipherKind, plugin::PluginConfig, relay::socks5::Address};
/// Base64 engine for user keys and identity keys: standard alphabet, padded
/// output, and decoding that accepts input with or without padding.
const USER_KEY_BASE64_ENGINE: base64::engine::GeneralPurpose = base64::engine::GeneralPurpose::new(
&base64::alphabet::STANDARD,
base64::engine::GeneralPurposeConfig::new()
.with_encode_padding(true)
.with_decode_padding_mode(base64::engine::DecodePaddingMode::Indifferent),
);
/// Base64 engine used by `make_derived_key` to decode AEAD-2022 passwords.
/// Configured identically to `USER_KEY_BASE64_ENGINE` (standard alphabet,
/// padded output, padding-indifferent decoding).
#[cfg(feature = "aead-cipher-2022")]
const AEAD2022_PASSWORD_BASE64_ENGINE: base64::engine::GeneralPurpose = base64::engine::GeneralPurpose::new(
&base64::alphabet::STANDARD,
base64::engine::GeneralPurposeConfig::new()
.with_encode_padding(true)
.with_decode_padding_mode(base64::engine::DecodePaddingMode::Indifferent),
);
/// Base64 engine for `ss://` URLs: URL-safe alphabet, unpadded output, and
/// decoding that accepts input with or without padding.
const URL_PASSWORD_BASE64_ENGINE: base64::engine::GeneralPurpose = base64::engine::GeneralPurpose::new(
&base64::alphabet::URL_SAFE,
base64::engine::GeneralPurposeConfig::new()
.with_encode_padding(false)
.with_decode_padding_mode(base64::engine::DecodePaddingMode::Indifferent),
);
/// Distinguishes the two kinds of endpoint: `Local` and `Server`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum ServerType {
    /// The local side.
    Local,
    /// The server side.
    Server,
}

impl ServerType {
    /// Returns `true` when the value is [`ServerType::Local`].
    pub fn is_local(self) -> bool {
        matches!(self, Self::Local)
    }

    /// Returns `true` when the value is [`ServerType::Server`].
    pub fn is_server(self) -> bool {
        matches!(self, Self::Server)
    }
}
/// Transport selection encoded as a two-bit set: bit `0x01` is TCP and bit
/// `0x02` is UDP.
///
/// Derives `PartialEq`/`Eq` so modes (and `Result`s holding them) can be
/// compared directly, matching the other plain enums in this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Mode {
    /// TCP only.
    TcpOnly = 0x01,
    /// Both TCP and UDP.
    TcpAndUdp = 0x03,
    /// UDP only.
    UdpOnly = 0x02,
}

impl Mode {
    /// Returns `true` if UDP is part of this mode.
    pub fn enable_udp(self) -> bool {
        matches!(self, Self::UdpOnly | Self::TcpAndUdp)
    }

    /// Returns `true` if TCP is part of this mode.
    pub fn enable_tcp(self) -> bool {
        matches!(self, Self::TcpOnly | Self::TcpAndUdp)
    }

    /// Returns the union of `self` and `mode` (bitwise OR of the discriminants).
    pub fn merge(&self, mode: Self) -> Self {
        let combined = (*self as u8) | (mode as u8);
        match combined {
            0x01 => Self::TcpOnly,
            0x02 => Self::UdpOnly,
            0x03 => Self::TcpAndUdp,
            // Every discriminant is within 0x01..=0x03, so the OR is too.
            _ => unreachable!(),
        }
    }

    /// Returns the textual name of the mode, the same text `FromStr` accepts.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::TcpOnly => "tcp_only",
            Self::TcpAndUdp => "tcp_and_udp",
            Self::UdpOnly => "udp_only",
        }
    }
}

impl fmt::Display for Mode {
    /// Writes the same text as [`Mode::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(self.as_str())
    }
}

impl FromStr for Mode {
    type Err = ();

    /// Parses `"tcp_only"`, `"tcp_and_udp"` or `"udp_only"` (case-sensitive);
    /// any other input yields `Err(())`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "tcp_only" => Ok(Self::TcpOnly),
            "tcp_and_udp" => Ok(Self::TcpAndUdp),
            "udp_only" => Ok(Self::UdpOnly),
            _ => Err(()),
        }
    }
}
/// Serde visitor that builds a [`Mode`] from its textual representation.
struct ModeVisitor;

impl serde::de::Visitor<'_> for ModeVisitor {
    type Value = Mode;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("Mode")
    }

    /// Parses the string through `Mode`'s `FromStr` implementation.
    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        v.parse::<Mode>()
            .map_err(|_| E::invalid_value(serde::de::Unexpected::Str(v), &self))
    }

    fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        self.visit_str(&v)
    }

    /// Accepts raw bytes as long as they are valid UTF-8.
    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        let Ok(text) = str::from_utf8(v) else {
            return Err(E::invalid_value(serde::de::Unexpected::Bytes(v), &self));
        };
        self.visit_str(text)
    }

    /// Accepts an owned byte buffer as long as it is valid UTF-8.
    fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        let text = String::from_utf8(v)
            .map_err(|e| E::invalid_value(serde::de::Unexpected::Bytes(&e.into_bytes()), &self))?;
        self.visit_string(text)
    }
}

impl<'de> serde::Deserialize<'de> for Mode {
    /// Deserializes a `Mode` from a string via [`ModeVisitor`].
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let visitor = ModeVisitor;
        deserializer.deserialize_string(visitor)
    }
}

impl serde::Serialize for Mode {
    /// Serializes the mode as the string returned by `Mode::as_str`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let name = self.as_str();
        serializer.serialize_str(name)
    }
}
/// Pair of weights, one for TCP and one for UDP, each kept within `0.0..=1.0`
/// by the setters.
#[derive(Debug, Clone)]
pub struct ServerWeight {
    tcp_weight: f32,
    udp_weight: f32,
}

impl Default for ServerWeight {
    /// Both weights start at `1.0`.
    fn default() -> Self {
        Self {
            tcp_weight: 1.0,
            udp_weight: 1.0,
        }
    }
}

impl ServerWeight {
    /// Creates a weight with both components set to `1.0`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Current TCP weight.
    pub fn tcp_weight(&self) -> f32 {
        self.tcp_weight
    }

    /// Sets the TCP weight.
    ///
    /// # Panics
    ///
    /// Panics if `weight` is outside `0.0..=1.0` (including NaN).
    pub fn set_tcp_weight(&mut self, weight: f32) {
        assert!((0.0..=1.0).contains(&weight));
        self.tcp_weight = weight;
    }

    /// Current UDP weight.
    pub fn udp_weight(&self) -> f32 {
        self.udp_weight
    }

    /// Sets the UDP weight.
    ///
    /// # Panics
    ///
    /// Panics if `weight` is outside `0.0..=1.0` (including NaN).
    pub fn set_udp_weight(&mut self, weight: f32) {
        assert!((0.0..=1.0).contains(&weight));
        self.udp_weight = weight;
    }
}
#[derive(Clone)]
pub struct ServerUser {
name: String,
key: Bytes,
identity_hash: Bytes,
}
impl Debug for ServerUser {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("ServerUser")
.field("name", &self.name)
.field("key", &USER_KEY_BASE64_ENGINE.encode(&self.key))
.field("identity_hash", &ByteStr::new(&self.identity_hash))
.finish()
}
}
impl ServerUser {
pub fn new<N, K>(name: N, key: K) -> Self
where
N: Into<String>,
K: Into<Bytes>,
{
let name = name.into();
let key = key.into();
let hash = blake3::hash(&key);
let identity_hash = Bytes::from(hash.as_bytes()[0..16].to_owned());
Self {
name,
key,
identity_hash,
}
}
pub fn with_encoded_key<N>(name: N, key: &str) -> Result<Self, ServerUserError>
where
N: Into<String>,
{
let key = USER_KEY_BASE64_ENGINE.decode(key)?;
Ok(Self::new(name, key))
}
pub fn name(&self) -> &str {
self.name.as_str()
}
pub fn key(&self) -> &[u8] {
self.key.as_ref()
}
pub fn encoded_key(&self) -> String {
USER_KEY_BASE64_ENGINE.encode(&self.key)
}
pub fn identity_hash(&self) -> &[u8] {
self.identity_hash.as_ref()
}
pub fn clone_identity_hash(&self) -> Bytes {
self.identity_hash.clone()
}
}
/// Errors produced while constructing a `ServerUser`.
#[derive(Debug, Clone, Error)]
pub enum ServerUserError {
/// The supplied key was not valid base64; displays the underlying decode error.
#[error("{0}")]
InvalidKeyEncoding(#[from] base64::DecodeError),
}
/// Collection of users indexed by their identity hash.
#[derive(Clone, Debug)]
pub struct ServerUserManager {
    /// Maps each user's identity hash to the shared user record.
    users: HashMap<Bytes, Arc<ServerUser>>,
}

impl ServerUserManager {
    /// Creates an empty manager.
    pub fn new() -> Self {
        Self { users: HashMap::new() }
    }

    /// Adds a user, replacing any existing user with the same identity hash.
    pub fn add_user(&mut self, user: ServerUser) {
        let hash = user.clone_identity_hash();
        self.users.insert(hash, Arc::new(user));
    }

    /// Looks up a user by identity hash.
    pub fn get_user_by_hash(&self, user_hash: &[u8]) -> Option<&ServerUser> {
        self.users.get(user_hash).map(|user| &**user)
    }

    /// Looks up a user by identity hash and returns a shared handle to it.
    pub fn clone_user_by_hash(&self, user_hash: &[u8]) -> Option<Arc<ServerUser>> {
        self.users.get(user_hash).map(Arc::clone)
    }

    /// Number of registered users.
    pub fn user_count(&self) -> usize {
        self.users.len()
    }

    /// Iterates over all registered users (in the map's iteration order).
    pub fn users_iter(&self) -> impl Iterator<Item = &ServerUser> {
        self.users.values().map(|user| &**user)
    }
}

impl Default for ServerUserManager {
    /// Same as [`ServerUserManager::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Records where a server configuration came from.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ServerSource {
    /// Initial value assigned by `ServerConfig::new`.
    Default,
    /// Presumably: loaded from a configuration file.
    Configuration,
    /// Presumably: supplied on the command line.
    CommandLine,
    /// Presumably: fetched from an online configuration service.
    OnlineConfig,
}
/// Errors produced while deriving keys from a password for a `ServerConfig`.
#[derive(Debug, Clone, Error)]
pub enum ServerConfigError {
/// The password of an AEAD-2022 method was not valid base64.
#[error("invalid key encoding for {0}, {1}")]
InvalidKeyEncoding(CipherKind, base64::DecodeError),
/// One of the identity keys (iPSK) in the password was not valid base64.
#[error("invalid iPSK encoding for {0}, {1}")]
InvalidUserKeyEncoding(CipherKind, base64::DecodeError),
/// The decoded key length differs from the expected one.
/// Fields: method, expected length, actual length.
#[error("invalid key length for {0}, expecting {1} bytes, but found {2} bytes")]
InvalidKeyLength(CipherKind, usize, usize),
/// A decoded identity key has the wrong length.
/// Fields: method, expected length, actual length.
#[error("invalid user key length for {0}, expecting {1} bytes, but found {2} bytes")]
InvalidUserKeyLength(CipherKind, usize, usize),
}
/// Configuration of a single server: address, credentials and optional
/// plugin, user, and bookkeeping settings.
#[derive(Clone, Debug)]
pub struct ServerConfig {
/// Address of the server.
addr: ServerAddr,
/// Password as returned by `password_to_keys` (for EIH-capable methods this
/// is only the last `:`-separated segment of the supplied password).
password: String,
/// Encryption method.
method: CipherKind,
/// Key derived from `password` for `method`.
enc_key: Box<[u8]>,
/// Optional timeout; `None` until `set_timeout` is called.
timeout: Option<Duration>,
/// Decoded identity keys taken from the password segments before the last
/// one; empty when the password carries none.
identity_keys: Arc<Vec<Bytes>>,
/// Optional user manager; `None` until `set_user_manager` is called.
user_manager: Option<Arc<ServerUserManager>>,
/// Optional plugin configuration.
plugin: Option<PluginConfig>,
/// Address returned by `tcp_external_addr` / `udp_external_addr` instead of
/// `addr` when a plugin is configured for that transport.
plugin_addr: Option<ServerAddr>,
/// Optional free-form remarks (the URL fragment in `ss://` URLs).
remarks: Option<String>,
/// Optional identifier.
id: Option<String>,
/// Transport mode; `ServerConfig::new` starts with `Mode::TcpAndUdp`.
mode: Mode,
/// TCP/UDP weights.
weight: ServerWeight,
/// Where this configuration came from.
source: ServerSource,
}
/// Fills `enc_key` with the key derived from `password` for `method`.
///
/// For AEAD-2022 methods the password is a base64-encoded key that must decode
/// to exactly `enc_key.len()` bytes. For other methods the key is derived with
/// `openssl_bytes_to_key`.
///
/// # Errors
///
/// * [`ServerConfigError::InvalidKeyEncoding`] if an AEAD-2022 password is not
///   valid base64.
/// * [`ServerConfigError::InvalidKeyLength`] if the decoded key has the wrong
///   length.
#[inline]
fn make_derived_key(method: CipherKind, password: &str, enc_key: &mut [u8]) -> Result<(), ServerConfigError> {
    #[cfg(feature = "aead-cipher-2022")]
    if method.is_aead_2022() {
        let decoded = AEAD2022_PASSWORD_BASE64_ENGINE
            .decode(password)
            .map_err(|err| ServerConfigError::InvalidKeyEncoding(method, err))?;
        if decoded.len() != enc_key.len() {
            return Err(ServerConfigError::InvalidKeyLength(
                method,
                enc_key.len(),
                decoded.len(),
            ));
        }
        enc_key.copy_from_slice(&decoded);
        return Ok(());
    }

    cfg_if! {
        if #[cfg(any(feature = "stream-cipher", feature = "aead-cipher"))] {
            let _ = method;
            openssl_bytes_to_key(password.as_bytes(), enc_key);
            Ok(())
        } else {
            // No cipher family that derives keys from passwords is compiled in.
            let _ = password;
            let _ = enc_key;
            unreachable!("{method} don't know how to make a derived key");
        }
    }
}
/// Returns `true` if `method` is one of the two methods that carry identity
/// keys in the password: `AEAD2022_BLAKE3_AES_128_GCM` or
/// `AEAD2022_BLAKE3_AES_256_GCM`.
#[cfg(feature = "aead-cipher-2022")]
#[inline]
pub fn method_support_eih(method: CipherKind) -> bool {
matches!(
method,
CipherKind::AEAD2022_BLAKE3_AES_128_GCM | CipherKind::AEAD2022_BLAKE3_AES_256_GCM
)
}
/// Splits a password into `(stored password, encryption key, identity keys)`.
///
/// * `CipherKind::NONE`: the key is empty; a non-empty password only triggers
///   a warning.
/// * `CipherKind::SS_TABLE` (with `stream-cipher`): the key is the password's
///   raw bytes.
/// * Methods supporting EIH (with `aead-cipher-2022`): the password is read as
///   `iPSK0:iPSK1:...:uPSK`. The last segment becomes the stored password and
///   the encryption key; earlier segments are decoded into identity keys, kept
///   in their left-to-right order.
/// * Everything else: the key is derived from the whole password by
///   `make_derived_key`.
///
/// # Errors
///
/// Propagates errors from `make_derived_key`, and returns
/// `InvalidUserKeyEncoding` / `InvalidUserKeyLength` for malformed identity
/// keys.
#[allow(clippy::type_complexity)]
fn password_to_keys<P>(method: CipherKind, password: P) -> Result<(String, Box<[u8]>, Vec<Bytes>), ServerConfigError>
where
    P: Into<String>,
{
    let password: String = password.into();

    match method {
        CipherKind::NONE => {
            debug_assert_eq!(method.key_len(), 0);
            if !password.is_empty() {
                warn!(
                    "method \"none\" doesn't need a password, which should be set as an empty String, but password.len() = {}",
                    password.len()
                );
            }
            let empty_key: Box<[u8]> = Vec::new().into_boxed_slice();
            return Ok((password, empty_key, Vec::new()));
        }
        #[cfg(feature = "stream-cipher")]
        CipherKind::SS_TABLE => {
            let enc_key: Box<[u8]> = Box::from(password.as_bytes());
            return Ok((password, enc_key, Vec::new()));
        }
        #[allow(unreachable_patterns)]
        _ => {}
    }

    #[cfg(feature = "aead-cipher-2022")]
    if method_support_eih(method) {
        // Walk the segments right-to-left: the uPSK comes first, then the
        // iPSKs from last to first.
        let mut segments = password.rsplit(':');
        let upsk = segments.next().expect("uPSK");

        let mut enc_key = vec![0u8; method.key_len()].into_boxed_slice();
        make_derived_key(method, upsk, &mut enc_key)?;

        let mut identity_keys = Vec::new();
        for ipsk in segments {
            let raw = USER_KEY_BASE64_ENGINE
                .decode(ipsk)
                .map_err(|err| ServerConfigError::InvalidUserKeyEncoding(method, err))?;
            let expected_len = match method {
                CipherKind::AEAD2022_BLAKE3_AES_128_GCM => 16,
                CipherKind::AEAD2022_BLAKE3_AES_256_GCM => 32,
                _ => unreachable!("{} doesn't support EIH", method),
            };
            if raw.len() != expected_len {
                return Err(ServerConfigError::InvalidUserKeyLength(method, expected_len, raw.len()));
            }
            identity_keys.push(Bytes::from(raw));
        }
        // Restore the left-to-right order of the original password.
        identity_keys.reverse();

        return Ok((upsk.to_owned(), enc_key, identity_keys));
    }

    let mut enc_key = vec![0u8; method.key_len()].into_boxed_slice();
    make_derived_key(method, &password, &mut enc_key)?;
    Ok((password, enc_key, Vec::new()))
}
impl ServerConfig {
/// Creates a configuration for `addr` using `password` and `method`.
///
/// Every optional setting starts unset; the mode is `Mode::TcpAndUdp`, the
/// weight is `ServerWeight::new()` and the source is `ServerSource::Default`.
///
/// # Errors
///
/// Returns a [`ServerConfigError`] if keys cannot be derived from `password`
/// for `method`.
pub fn new<A, P>(addr: A, password: P, method: CipherKind) -> Result<Self, ServerConfigError>
where
    A: Into<ServerAddr>,
    P: Into<String>,
{
    password_to_keys(method, password).map(|(password, enc_key, identity_keys)| Self {
        addr: addr.into(),
        password,
        method,
        enc_key,
        identity_keys: Arc::new(identity_keys),
        user_manager: None,
        timeout: None,
        plugin: None,
        plugin_addr: None,
        remarks: None,
        id: None,
        mode: Mode::TcpAndUdp,
        weight: ServerWeight::new(),
        source: ServerSource::Default,
    })
}
/// Replaces the encryption method and password, re-deriving the encryption
/// key and identity keys.
///
/// The update is all-or-nothing: keys are derived before any field is
/// modified, so on error the configuration keeps its previous method,
/// password and keys.
///
/// # Errors
///
/// Returns a [`ServerConfigError`] if keys cannot be derived from `password`
/// for `method`.
pub fn set_method<P>(&mut self, method: CipherKind, password: P) -> Result<(), ServerConfigError>
where
    P: Into<String>,
{
    // Derive first; assigning `self.method` before this call would leave the
    // new method paired with the old password/keys when derivation fails.
    let (password, enc_key, identity_keys) = password_to_keys(method, password)?;

    self.method = method;
    self.password = password;
    self.enc_key = enc_key;
    self.identity_keys = Arc::new(identity_keys);
    Ok(())
}
/// Sets the plugin configuration, replacing any previous one.
pub fn set_plugin(&mut self, p: PluginConfig) {
    self.plugin = Some(p);
}

/// Sets the server address.
pub fn set_addr<A>(&mut self, a: A)
where
    A: Into<ServerAddr>,
{
    let addr: ServerAddr = a.into();
    self.addr = addr;
}

/// Server address.
pub fn addr(&self) -> &ServerAddr {
    &self.addr
}

/// Encryption key derived from the password.
pub fn key(&self) -> &[u8] {
    &self.enc_key
}

/// Stored password.
pub fn password(&self) -> &str {
    &self.password
}

/// Identity keys extracted from the password.
pub fn identity_keys(&self) -> &[Bytes] {
    self.identity_keys.as_slice()
}

/// Returns a shared handle to the identity keys.
pub fn clone_identity_keys(&self) -> Arc<Vec<Bytes>> {
    Arc::clone(&self.identity_keys)
}

/// Sets the user manager, replacing any previous one.
pub fn set_user_manager(&mut self, user_manager: ServerUserManager) {
    let shared = Arc::new(user_manager);
    self.user_manager = Some(shared);
}

/// User manager, if one has been set.
pub fn user_manager(&self) -> Option<&ServerUserManager> {
    self.user_manager.as_ref().map(|manager| &**manager)
}

/// Returns a shared handle to the user manager, if one has been set.
pub fn clone_user_manager(&self) -> Option<Arc<ServerUserManager>> {
    self.user_manager.as_ref().map(Arc::clone)
}

/// Encryption method.
pub fn method(&self) -> CipherKind {
    self.method
}

/// Plugin configuration, if any.
pub fn plugin(&self) -> Option<&PluginConfig> {
    self.plugin.as_ref()
}

/// Sets the plugin address.
pub fn set_plugin_addr(&mut self, a: ServerAddr) {
    self.plugin_addr = Some(a);
}

/// Plugin address, if any.
pub fn plugin_addr(&self) -> Option<&ServerAddr> {
    self.plugin_addr.as_ref()
}
pub fn tcp_external_addr(&self) -> &ServerAddr {
if let Some(plugin) = self.plugin()
&& plugin.plugin_mode.enable_tcp() {
return self.plugin_addr.as_ref().unwrap_or(&self.addr);
}
&self.addr
}
pub fn udp_external_addr(&self) -> &ServerAddr {
if let Some(plugin) = self.plugin()
&& plugin.plugin_mode.enable_udp() {
return self.plugin_addr.as_ref().unwrap_or(&self.addr);
}
&self.addr
}
/// Sets the timeout.
pub fn set_timeout(&mut self, timeout: Duration) {
    self.timeout = Some(timeout);
}

/// Timeout, if one has been set.
pub fn timeout(&self) -> Option<Duration> {
    self.timeout
}

/// Remarks, if any.
pub fn remarks(&self) -> Option<&str> {
    self.remarks.as_deref()
}

/// Sets the remarks.
pub fn set_remarks<S>(&mut self, remarks: S)
where
    S: Into<String>,
{
    let text: String = remarks.into();
    self.remarks = Some(text);
}

/// Identifier, if any.
pub fn id(&self) -> Option<&str> {
    self.id.as_deref()
}

/// Sets the identifier.
pub fn set_id<S>(&mut self, id: S)
where
    S: Into<String>,
{
    let text: String = id.into();
    self.id = Some(text);
}

/// Transport mode.
pub fn mode(&self) -> Mode {
    self.mode
}

/// Sets the transport mode.
pub fn set_mode(&mut self, mode: Mode) {
    self.mode = mode;
}

/// TCP/UDP weights.
pub fn weight(&self) -> &ServerWeight {
    &self.weight
}

/// Sets the TCP/UDP weights.
pub fn set_weight(&mut self, weight: ServerWeight) {
    self.weight = weight;
}

/// Where this configuration came from.
pub fn source(&self) -> ServerSource {
    self.source
}

/// Records where this configuration came from.
pub fn set_source(&mut self, source: ServerSource) {
    self.source = source;
}
/// Builds the legacy URL form `ss://BASE64(method:password@addr)`, using the
/// URL-safe, unpadded base64 engine.
pub fn to_qrcode_url(&self) -> String {
    let plain = format!("{}:{}@{}", self.method(), self.password(), self.addr());
    let encoded = URL_PASSWORD_BASE64_ENGINE.encode(plain);

    let mut url = String::from("ss://");
    url.push_str(&encoded);
    url
}
/// Builds the URL form `ss://userinfo@addr[/?plugin=...][#remarks]`.
///
/// The user info is `method:password`. For AEAD-2022 methods (when that
/// feature is compiled in) the password is percent-encoded in place;
/// otherwise the whole `method:password` string is base64-encoded with the
/// URL-safe, unpadded engine. The plugin (with `;opts` appended when present)
/// and the remarks are percent-encoded.
pub fn to_url(&self) -> String {
    cfg_if! {
        if #[cfg(feature = "aead-cipher-2022")] {
            let user_info = if self.method().is_aead_2022() {
                format!("{}:{}", self.method(), percent_encoding::utf8_percent_encode(self.password(), percent_encoding::NON_ALPHANUMERIC))
            } else {
                URL_PASSWORD_BASE64_ENGINE.encode(format!("{}:{}", self.method(), self.password()))
            };
        } else {
            let user_info = URL_PASSWORD_BASE64_ENGINE.encode(format!("{}:{}", self.method(), self.password()));
        }
    }

    let mut url = format!("ss://{}@{}", user_info, self.addr());

    if let Some(config) = self.plugin() {
        let plugin = match config.plugin_opts {
            Some(ref opts) => format!("{};{}", config.plugin, opts),
            None => config.plugin.clone(),
        };
        url.push_str("/?plugin=");
        url.extend(percent_encoding::utf8_percent_encode(
            &plugin,
            percent_encoding::NON_ALPHANUMERIC,
        ));
    }

    if let Some(remark) = self.remarks() {
        url.push('#');
        url.extend(percent_encoding::utf8_percent_encode(
            remark,
            percent_encoding::NON_ALPHANUMERIC,
        ));
    }

    url
}
pub fn from_url(encoded: &str) -> Result<Self, UrlParseError> {
let parsed = Url::parse(encoded).map_err(UrlParseError::from)?;
if parsed.scheme() != "ss" {
return Err(UrlParseError::InvalidScheme);
}
let user_info = parsed.username();
if user_info.is_empty() {
let encoded = match parsed.host_str() {
Some(e) => e,
None => return Err(UrlParseError::MissingHost),
};
let mut decoded_body = match URL_PASSWORD_BASE64_ENGINE.decode(encoded) {
Ok(b) => match String::from_utf8(b) {
Ok(b) => b,
Err(..) => return Err(UrlParseError::InvalidServerAddr),
},
Err(err) => {
error!("failed to parse legacy ss://ENCODED with Base64, err: {}", err);
return Err(UrlParseError::InvalidServerAddr);
}
};
decoded_body.insert_str(0, "ss://");
return Self::from_url(&decoded_body);
}
let (method, pwd) = match parsed.password() {
Some(password) => {
let m = match percent_encoding::percent_decode_str(user_info).decode_utf8() {
Ok(m) => m,
Err(err) => {
error!("failed to parse percent-encoded method in userinfo, err: {}", err);
return Err(UrlParseError::InvalidAuthInfo);
}
};
let p = match percent_encoding::percent_decode_str(password).decode_utf8() {
Ok(m) => m,
Err(err) => {
error!("failed to parse percent-encoded password in userinfo, err: {}", err);
return Err(UrlParseError::InvalidAuthInfo);
}
};
(m, p)
}
None => {
let decoded_user_info = match percent_encoding::percent_decode_str(user_info).decode_utf8() {
Ok(m) => m,
Err(err) => {
error!("failed to parse percent-encoded userinfo, err: {}", err);
return Err(UrlParseError::InvalidAuthInfo);
}
};
let decoded_user_info: &str = &decoded_user_info;
let account = match URL_PASSWORD_BASE64_ENGINE.decode(decoded_user_info) {
Ok(account) => match String::from_utf8(account) {
Ok(ac) => ac,
Err(..) => return Err(UrlParseError::InvalidAuthInfo),
},
Err(err) => {
error!("failed to parse UserInfo with Base64, err: {}", err);
return Err(UrlParseError::InvalidUserInfo);
}
};
let mut sp2 = account.splitn(2, ':');
let (m, p) = match (sp2.next(), sp2.next()) {
(Some(m), Some(p)) => (m, p),
_ => return Err(UrlParseError::InvalidUserInfo),
};
(m.to_owned().into(), p.to_owned().into())
}
};
let host = match parsed.host_str() {
Some(host) => host,
None => return Err(UrlParseError::MissingHost),
};
let port = parsed.port().unwrap_or(8388);
let addr = format!("{host}:{port}");
let addr = match addr.parse::<ServerAddr>() {
Ok(a) => a,
Err(err) => {
error!("failed to parse \"{}\" to ServerAddr, err: {:?}", addr, err);
return Err(UrlParseError::InvalidServerAddr);
}
};
let method = match method.parse::<CipherKind>() {
Ok(m) => m,
Err(err) => {
error!("failed to parse \"{}\" to CipherKind, err: {:?}", method, err);
return Err(UrlParseError::InvalidMethod);
}
};
let mut svrconfig = Self::new(addr, pwd, method)?;
if let Some(q) = parsed.query() {
let query = match serde_urlencoded::from_bytes::<Vec<(String, String)>>(q.as_bytes()) {
Ok(q) => q,
Err(err) => {
error!("failed to parse QueryString, err: {}", err);
return Err(UrlParseError::InvalidQueryString);
}
};
for (key, value) in query {
if key != "plugin" {
continue;
}
let mut vsp = value.splitn(2, ';');
match vsp.next() {
None => {}
Some(p) => {
let plugin = PluginConfig {
plugin: p.to_owned(),
plugin_opts: vsp.next().map(ToOwned::to_owned),
plugin_args: Vec::new(), plugin_mode: Mode::TcpOnly, };
svrconfig.set_plugin(plugin);
}
}
}
}
if let Some(frag) = parsed.fragment() {
match percent_encoding::percent_decode_str(frag).decode_utf8() {
Ok(m) => svrconfig.set_remarks(m),
Err(..) => svrconfig.set_remarks(frag),
}
}
Ok(svrconfig)
}
/// Returns `true` when neither remarks nor an id has been set.
pub fn is_basic(&self) -> bool {
    matches!((&self.remarks, &self.id), (None, None))
}
}
/// Errors returned by `ServerConfig::from_url`.
#[derive(Debug, Clone, Error)]
pub enum UrlParseError {
/// The input is not a syntactically valid URL.
#[error("{0}")]
ParseError(#[from] url::ParseError),
/// The URL scheme is not `ss`.
#[error("URL must have \"ss://\" scheme")]
InvalidScheme,
/// The method name could not be parsed into a `CipherKind`.
#[error("unknown encryption method")]
InvalidMethod,
/// The user info is not valid base64 or lacks the `method:password` separator.
#[error("invalid user info")]
InvalidUserInfo,
/// The URL has no host component.
#[error("missing host")]
MissingHost,
/// The user info is not valid percent-encoded UTF-8, or its base64 payload
/// is not UTF-8.
#[error("invalid authentication info")]
InvalidAuthInfo,
/// The host/port (or the legacy base64 body) could not be interpreted.
#[error("invalid server address")]
InvalidServerAddr,
/// The query string could not be decoded.
#[error("invalid query string")]
InvalidQueryString,
/// Building the `ServerConfig` from the parsed pieces failed.
#[error("{0}")]
ServerConfigError(#[from] ServerConfigError),
}
impl FromStr for ServerConfig {
type Err = UrlParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::from_url(s)
}
}
/// Address of a server: either a resolved socket address or a domain name
/// with a port.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ServerAddr {
    /// IP address and port.
    SocketAddr(SocketAddr),
    /// Domain name and port.
    DomainName(String, u16),
}

impl ServerAddr {
    /// Host part as a string: the IP address or the domain name.
    pub fn host(&self) -> String {
        match self {
            Self::SocketAddr(s) => s.ip().to_string(),
            Self::DomainName(dm, _) => dm.clone(),
        }
    }

    /// Port number.
    pub fn port(&self) -> u16 {
        match self {
            Self::SocketAddr(s) => s.port(),
            Self::DomainName(_, p) => *p,
        }
    }
}

/// Error returned when a string cannot be parsed into a [`ServerAddr`].
#[derive(Debug)]
pub struct ServerAddrError;

impl Display for ServerAddrError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("invalid ServerAddr")
    }
}

impl std::error::Error for ServerAddrError {}

impl FromStr for ServerAddr {
    type Err = ServerAddrError;

    /// Parses `ip:port`, `[ipv6]:port` or `domain:port`.
    ///
    /// The domain form requires a non-empty domain followed by exactly one
    /// `:` and a valid `u16` port. Input with additional colons that is not a
    /// valid socket address (e.g. `host:80:junk` or an unbracketed IPv6
    /// address) is rejected rather than silently truncated.
    fn from_str(s: &str) -> Result<Self, ServerAddrError> {
        if let Ok(addr) = s.parse::<SocketAddr>() {
            return Ok(Self::SocketAddr(addr));
        }

        // Split at the first colon only: anything after it must be the port,
        // so a second colon makes the port fail to parse.
        let (dn, port) = s.split_once(':').ok_or(ServerAddrError)?;
        if dn.is_empty() {
            return Err(ServerAddrError);
        }

        port.parse::<u16>()
            .map(|port| Self::DomainName(dn.to_owned(), port))
            .map_err(|_| ServerAddrError)
    }
}

impl Display for ServerAddr {
    /// Formats as `ip:port` (via `SocketAddr`) or `domain:port`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::SocketAddr(a) => write!(f, "{a}"),
            Self::DomainName(d, port) => write!(f, "{d}:{port}"),
        }
    }
}
/// Serde visitor that builds a [`ServerAddr`] from its textual representation.
struct ServerAddrVisitor;

impl serde::de::Visitor<'_> for ServerAddrVisitor {
    type Value = ServerAddr;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("ServerAddr")
    }

    /// Parses the string through `ServerAddr`'s `FromStr` implementation.
    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        v.parse::<ServerAddr>()
            .map_err(|_| E::invalid_value(serde::de::Unexpected::Str(v), &self))
    }

    fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        self.visit_str(&v)
    }

    /// Accepts raw bytes as long as they are valid UTF-8.
    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        let Ok(text) = str::from_utf8(v) else {
            return Err(E::invalid_value(serde::de::Unexpected::Bytes(v), &self));
        };
        self.visit_str(text)
    }

    /// Accepts an owned byte buffer as long as it is valid UTF-8.
    fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        let text = String::from_utf8(v)
            .map_err(|e| E::invalid_value(serde::de::Unexpected::Bytes(&e.into_bytes()), &self))?;
        self.visit_string(text)
    }
}

impl<'de> serde::Deserialize<'de> for ServerAddr {
    /// Deserializes a `ServerAddr` from a string via [`ServerAddrVisitor`].
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let visitor = ServerAddrVisitor;
        deserializer.deserialize_string(visitor)
    }
}

impl serde::Serialize for ServerAddr {
    /// Serializes the address as its `Display` text.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let text = self.to_string();
        serializer.serialize_str(&text)
    }
}
impl From<SocketAddr> for ServerAddr {
    /// Wraps a socket address.
    fn from(addr: SocketAddr) -> Self {
        Self::SocketAddr(addr)
    }
}

impl<I: Into<String>> From<(I, u16)> for ServerAddr {
    /// Builds a domain-name address from a `(name, port)` pair.
    fn from((dname, port): (I, u16)) -> Self {
        let host: String = dname.into();
        Self::DomainName(host, port)
    }
}

impl From<Address> for ServerAddr {
    /// Converts an `Address`, variant for variant.
    fn from(addr: Address) -> Self {
        let (dn, port) = match addr {
            Address::SocketAddress(sa) => return Self::SocketAddr(sa),
            Address::DomainNameAddress(dn, port) => (dn, port),
        };
        Self::DomainName(dn, port)
    }
}

impl From<&Address> for ServerAddr {
    /// Converts a borrowed `Address`, cloning the domain name if present.
    fn from(addr: &Address) -> Self {
        match addr {
            Address::SocketAddress(sa) => Self::SocketAddr(*sa),
            Address::DomainNameAddress(dn, port) => Self::DomainName(dn.clone(), *port),
        }
    }
}

impl From<ServerAddr> for Address {
    /// Converts a `ServerAddr`, variant for variant.
    fn from(addr: ServerAddr) -> Self {
        let (dn, port) = match addr {
            ServerAddr::SocketAddr(sa) => return Self::SocketAddress(sa),
            ServerAddr::DomainName(dn, port) => (dn, port),
        };
        Self::DomainNameAddress(dn, port)
    }
}

impl From<&ServerAddr> for Address {
    /// Converts a borrowed `ServerAddr`, cloning the domain name if present.
    fn from(addr: &ServerAddr) -> Self {
        match addr {
            ServerAddr::SocketAddr(sa) => Self::SocketAddress(*sa),
            ServerAddr::DomainName(dn, port) => Self::DomainNameAddress(dn.clone(), *port),
        }
    }
}
/// Address of a manager endpoint: a socket address, a domain name with port,
/// or (on Unix) a Unix socket path.
#[derive(Debug, Clone)]
pub enum ManagerAddr {
    /// IP address and port.
    SocketAddr(SocketAddr),
    /// Domain name and port.
    DomainName(String, u16),
    /// Path of a Unix domain socket.
    #[cfg(unix)]
    UnixSocketAddr(PathBuf),
}

/// Error returned when a string cannot be parsed into a [`ManagerAddr`].
#[derive(Debug)]
pub struct ManagerAddrError;

impl Display for ManagerAddrError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("invalid ManagerAddr")
    }
}

impl std::error::Error for ManagerAddrError {}

impl FromStr for ManagerAddr {
    type Err = ManagerAddrError;

    /// Parses a manager address.
    ///
    /// * Input containing `:` is tried as a socket address first, then as
    ///   `domain:port` split at the first `:` (both parts are trimmed). The
    ///   domain must be non-empty, as `ServerAddr` parsing also requires.
    /// * Input without `:` is a Unix socket path on Unix and an error
    ///   elsewhere.
    fn from_str(s: &str) -> Result<Self, ManagerAddrError> {
        match s.split_once(':') {
            Some((sdomain, sport)) => {
                if let Ok(saddr) = s.parse::<SocketAddr>() {
                    return Ok(Self::SocketAddr(saddr));
                }

                let sdomain = sdomain.trim();
                if sdomain.is_empty() {
                    // A domain address without a domain can never be resolved.
                    return Err(ManagerAddrError);
                }

                sport
                    .trim()
                    .parse::<u16>()
                    .map(|port| Self::DomainName(sdomain.to_owned(), port))
                    .map_err(|_| ManagerAddrError)
            }
            #[cfg(unix)]
            None => Ok(Self::UnixSocketAddr(PathBuf::from(s))),
            #[cfg(not(unix))]
            None => Err(ManagerAddrError),
        }
    }
}

impl Display for ManagerAddr {
    /// Formats as `ip:port`, `domain:port`, or the Unix socket path.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::SocketAddr(saddr) => fmt::Display::fmt(saddr, f),
            Self::DomainName(dname, port) => write!(f, "{dname}:{port}"),
            #[cfg(unix)]
            Self::UnixSocketAddr(path) => fmt::Display::fmt(&path.display(), f),
        }
    }
}
/// Serde visitor that builds a [`ManagerAddr`] from its textual representation.
struct ManagerAddrVisitor;

impl serde::de::Visitor<'_> for ManagerAddrVisitor {
    type Value = ManagerAddr;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("ManagerAddr")
    }

    /// Parses the string through `ManagerAddr`'s `FromStr` implementation.
    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        v.parse::<ManagerAddr>()
            .map_err(|_| E::invalid_value(serde::de::Unexpected::Str(v), &self))
    }

    fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        self.visit_str(&v)
    }

    /// Accepts raw bytes as long as they are valid UTF-8.
    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        let Ok(text) = str::from_utf8(v) else {
            return Err(E::invalid_value(serde::de::Unexpected::Bytes(v), &self));
        };
        self.visit_str(text)
    }

    /// Accepts an owned byte buffer as long as it is valid UTF-8.
    fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        let text = String::from_utf8(v)
            .map_err(|e| E::invalid_value(serde::de::Unexpected::Bytes(&e.into_bytes()), &self))?;
        self.visit_string(text)
    }
}

impl<'de> serde::Deserialize<'de> for ManagerAddr {
    /// Deserializes a `ManagerAddr` from a string via [`ManagerAddrVisitor`].
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let visitor = ManagerAddrVisitor;
        deserializer.deserialize_string(visitor)
    }
}

impl serde::Serialize for ManagerAddr {
    /// Serializes the address as its `Display` text.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let text = self.to_string();
        serializer.serialize_str(&text)
    }
}
impl From<SocketAddr> for ManagerAddr {
    /// Wraps a socket address.
    fn from(addr: SocketAddr) -> Self {
        Self::SocketAddr(addr)
    }
}

impl<'a> From<(&'a str, u16)> for ManagerAddr {
    /// Builds a domain-name address from a borrowed `(name, port)` pair.
    fn from((dname, port): (&'a str, u16)) -> Self {
        let host = String::from(dname);
        Self::DomainName(host, port)
    }
}

impl From<(String, u16)> for ManagerAddr {
    /// Builds a domain-name address from an owned `(name, port)` pair.
    fn from(pair: (String, u16)) -> Self {
        let (dname, port) = pair;
        Self::DomainName(dname, port)
    }
}

#[cfg(unix)]
impl From<PathBuf> for ManagerAddr {
    /// Wraps a Unix socket path.
    fn from(p: PathBuf) -> Self {
        Self::UnixSocketAddr(p)
    }
}
/// Policy for handling replay attacks.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)]
pub enum ReplayAttackPolicy {
    /// The default policy; written and parsed as `"default"`.
    #[default]
    Default,
    /// Written and parsed as `"ignore"`.
    Ignore,
    /// Written and parsed as `"detect"`.
    Detect,
    /// Written and parsed as `"reject"`.
    Reject,
}

impl Display for ReplayAttackPolicy {
    /// Writes the lowercase policy name, the same text `FromStr` accepts.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let name = match self {
            Self::Default => "default",
            Self::Ignore => "ignore",
            Self::Detect => "detect",
            Self::Reject => "reject",
        };
        f.write_str(name)
    }
}

/// Error returned when a string is not a known [`ReplayAttackPolicy`] name.
#[derive(Debug, Clone, Copy)]
pub struct ReplayAttackPolicyError;

impl Display for ReplayAttackPolicyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("invalid ReplayAttackPolicy")
    }
}

impl std::error::Error for ReplayAttackPolicyError {}

impl FromStr for ReplayAttackPolicy {
    type Err = ReplayAttackPolicyError;

    /// Parses `"default"`, `"ignore"`, `"detect"` or `"reject"`
    /// (case-sensitive).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "default" => Ok(Self::Default),
            "ignore" => Ok(Self::Ignore),
            "detect" => Ok(Self::Detect),
            "reject" => Ok(Self::Reject),
            _ => Err(ReplayAttackPolicyError),
        }
    }
}
// Unit tests for this module.
#[cfg(test)]
mod test {
use super::*;
// A URL whose method name ("foo") is not a known cipher must be rejected
// with `UrlParseError::InvalidMethod`.
#[test]
fn test_server_config_from_url() {
let server_config = ServerConfig::from_url("ss://foo:bar@127.0.0.1:9999");
assert!(matches!(server_config, Err(UrlParseError::InvalidMethod)));
}
}