use std::fmt::{self, Debug, Display};
use std::net::SocketAddr;
use std::slice;
use std::str::FromStr;
use safelog::Redactable;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::HasAddrs;
/// Identifier for the transport a channel uses.
///
/// This is either the built-in transport (which is also the `Default`) or,
/// when the `pt-client` feature is enabled, a named pluggable transport.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct TransportId(Inner);

/// Private representation of a [`TransportId`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, educe::Educe)]
#[educe(Default)]
enum Inner {
    /// The built-in transport; this is the value `Default` produces.
    #[educe(Default)]
    BuiltIn,
    /// A pluggable transport, identified by its name.
    #[cfg(feature = "pt-client")]
    Pluggable(PtTransportName),
}
/// The name of a pluggable transport.
///
/// Values built through `FromStr` or `TryFrom<String>` are checked with
/// `is_well_formed_id`. Serialization goes through `Display`, and
/// deserialization goes through `FromStr`, so deserialized names are checked too.
///
/// NOTE(review): the derived `Default` produces an empty name, which
/// `is_well_formed_id` would reject. Confirm whether that is intended.
#[derive(
    Clone,
    Debug,
    Default,
    PartialEq,
    Eq,
    Hash,
    serde_with::SerializeDisplay,
    serde_with::DeserializeFromStr,
)]
pub struct PtTransportName(String);
impl FromStr for PtTransportName {
    type Err = TransportIdError;

    /// Parse a transport name, rejecting anything that is not well-formed.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::try_from(s.to_owned())
    }
}

impl TryFrom<String> for PtTransportName {
    type Error = TransportIdError;

    /// Wrap `s` as a transport name if it is well-formed.
    ///
    /// Otherwise, return `TransportIdError::BadId` carrying the rejected string.
    fn try_from(s: String) -> Result<PtTransportName, Self::Error> {
        if !is_well_formed_id(&s) {
            return Err(TransportIdError::BadId(s));
        }
        Ok(PtTransportName(s))
    }
}

impl Display for PtTransportName {
    /// Write the bare name, honoring any width/alignment flags, exactly as
    /// `str`'s own `Display` does.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(&self.0)
    }
}
/// Strings that all denote the built-in transport when parsed.
///
/// The first entry is the canonical spelling that `Display` emits.
const BUILT_IN_IDS: &[&str] = &["-", "", "bridge", "<none>"];

impl FromStr for TransportId {
    type Err = TransportIdError;

    /// Parse a transport identifier.
    ///
    /// Any entry of [`BUILT_IN_IDS`] yields the built-in transport. Any other
    /// string must be a well-formed pluggable transport name. Without the
    /// `pt-client` feature, such names are rejected with `NoSupport`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let is_builtin_alias = BUILT_IN_IDS.iter().any(|alias| *alias == s);
        if is_builtin_alias {
            return Ok(TransportId(Inner::BuiltIn));
        }
        #[cfg(feature = "pt-client")]
        {
            s.parse::<PtTransportName>()
                .map(|name| TransportId(Inner::Pluggable(name)))
        }
        #[cfg(not(feature = "pt-client"))]
        Err(TransportIdError::NoSupport)
    }
}
impl Display for TransportId {
    /// Built-in transports print as the canonical `"-"`.
    /// Pluggable transports print as their name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.0 {
            Inner::BuiltIn => f.write_str(BUILT_IN_IDS[0]),
            #[cfg(feature = "pt-client")]
            Inner::Pluggable(name) => write!(f, "{}", name),
        }
    }
}

#[cfg(feature = "pt-client")]
impl From<PtTransportName> for TransportId {
    /// Wrap a pluggable transport name as a transport identifier.
    fn from(name: PtTransportName) -> Self {
        Self(Inner::Pluggable(name))
    }
}
/// Return true if `s` is acceptable as a pluggable transport name.
///
/// A well-formed name is non-empty, starts with an ASCII letter or `_`, and
/// otherwise contains only ASCII letters, digits, and `_`. The name "bridge"
/// is excluded in any letter case.
fn is_well_formed_id(s: &str) -> bool {
    match s.as_bytes().split_first() {
        None => false,
        Some((&head, tail)) => {
            let head_ok = head == b'_' || head.is_ascii_alphabetic();
            let tail_ok = tail.iter().all(|&c| c == b'_' || c.is_ascii_alphanumeric());
            head_ok && tail_ok && !s.eq_ignore_ascii_case("bridge")
        }
    }
}
#[derive(Clone, Debug, thiserror::Error)]
#[non_exhaustive]
pub enum TransportIdError {
#[error("Not compiled with pluggable transport support")]
NoSupport,
#[error("{0:?} is not a valid pluggable transport ID")]
BadId(String),
}
impl TransportId {
pub fn new_builtin() -> Self {
TransportId(Inner::BuiltIn)
}
#[cfg(feature = "pt-client")]
pub fn new_pluggable(pt: PtTransportName) -> Self {
pt.into()
}
pub fn is_builtin(&self) -> bool {
self.0 == Inner::BuiltIn
}
#[cfg(feature = "pt-client")]
pub fn as_pluggable(&self) -> Option<&PtTransportName> {
match &self.0 {
Inner::BuiltIn => None,
#[cfg(feature = "pt-client")]
Inner::Pluggable(pt) => Some(pt),
}
}
#[cfg(feature = "pt-client")]
pub fn into_pluggable(self) -> Option<PtTransportName> {
match self.0 {
Inner::BuiltIn => None,
#[cfg(feature = "pt-client")]
Inner::Pluggable(pt) => Some(pt),
}
}
}
/// Textual form of [`PtTargetAddr::None`], used both when displaying and when parsing.
const NONE_ADDR: &str = "-";

/// Target address for a pluggable transport connection.
///
/// Serialized via `Display` and deserialized via `FromStr`.
#[derive(
    Clone, Debug, PartialEq, Eq, Hash, serde_with::SerializeDisplay, serde_with::DeserializeFromStr,
)]
#[non_exhaustive]
pub enum PtTargetAddr {
    /// An IP address and port.
    IpPort(SocketAddr),
    /// A hostname and port.
    HostPort(String, u16),
    /// No address; displayed and parsed as `"-"`.
    None,
}
/// An address for a bridge: either an IP address and port, or a hostname and port.
///
/// `Display` is derived from the inner representation. Serialization uses
/// `Display` and deserialization uses `FromStr`.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Hash,
    derive_more::Display,
    serde_with::SerializeDisplay,
    serde_with::DeserializeFromStr,
)]
pub struct BridgeAddr(BridgeAddrInner<SocketAddr, String>);

/// Shared representation behind [`BridgeAddr`] (owned form).
///
/// It also backs the borrowed view produced by `PtTargetAddr::as_bridge_ref`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum BridgeAddrInner<SA, HN> {
    /// A socket address.
    IpPort(SA),
    /// A hostname and port.
    HostPort(HN, u16),
}
impl BridgeAddr {
pub fn new_addr_from_sockaddr(sa: SocketAddr) -> Self {
BridgeAddr(BridgeAddrInner::IpPort(sa))
}
pub fn as_socketaddr(&self) -> Option<&SocketAddr> {
match &self.0 {
BridgeAddrInner::IpPort(sa) => Some(sa),
BridgeAddrInner::HostPort(..) => None,
}
}
pub fn new_named_host_port(hostname: impl Into<String>, port: u16) -> Self {
BridgeAddr(BridgeAddrInner::HostPort(hostname.into(), port))
}
pub fn as_host_port(&self) -> Option<(&str, u16)> {
match &self.0 {
BridgeAddrInner::IpPort(..) => None,
BridgeAddrInner::HostPort(hn, port) => Some((hn, *port)),
}
}
}
impl From<PtTargetAddr> for Option<BridgeAddr> {
fn from(pt: PtTargetAddr) -> Option<BridgeAddr> {
match pt {
PtTargetAddr::IpPort(sa) => Some(BridgeAddrInner::IpPort(sa)),
PtTargetAddr::HostPort(hn, p) => Some(BridgeAddrInner::HostPort(hn, p)),
PtTargetAddr::None => None,
}
.map(BridgeAddr)
}
}
impl From<Option<BridgeAddr>> for PtTargetAddr {
fn from(pt: Option<BridgeAddr>) -> PtTargetAddr {
match pt.map(|ba| ba.0) {
Some(BridgeAddrInner::IpPort(sa)) => PtTargetAddr::IpPort(sa),
Some(BridgeAddrInner::HostPort(hn, p)) => PtTargetAddr::HostPort(hn, p),
None => PtTargetAddr::None,
}
}
}
#[derive(Clone, Debug, thiserror::Error)]
#[non_exhaustive]
pub enum BridgeAddrError {
#[error("Not compiled with pluggable transport support.")]
NoSupport,
#[error("Cannot parse {0:?} as an address.")]
BadAddress(String),
}
impl FromStr for BridgeAddr {
type Err = BridgeAddrError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(BridgeAddr(if let Ok(addr) = s.parse() {
BridgeAddrInner::IpPort(addr)
} else if let Some((name, port)) = s.rsplit_once(':') {
let port = port
.parse()
.map_err(|_| BridgeAddrError::BadAddress(s.to_string()))?;
BridgeAddrInner::HostPort(name.to_string(), port)
} else {
return Err(BridgeAddrError::BadAddress(s.to_string()));
}))
}
}
impl FromStr for PtTargetAddr {
type Err = BridgeAddrError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(if s == NONE_ADDR {
PtTargetAddr::None
} else {
Some(BridgeAddr::from_str(s)?).into()
})
}
}
impl PtTargetAddr {
    /// Borrowed view of this address as a `BridgeAddrInner`.
    ///
    /// Returns `None` for `PtTargetAddr::None`.
    fn as_bridge_ref(&self) -> Option<BridgeAddrInner<&SocketAddr, &str>> {
        Some(match self {
            PtTargetAddr::IpPort(addr) => BridgeAddrInner::IpPort(addr),
            PtTargetAddr::HostPort(host, port) => BridgeAddrInner::HostPort(host.as_str(), *port),
            PtTargetAddr::None => return None,
        })
    }
}

impl<SA: Display, HN: Display> Display for BridgeAddrInner<SA, HN> {
    /// `IpPort` prints the socket address; `HostPort` prints `host:port`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            BridgeAddrInner::HostPort(host, port) => write!(f, "{}:{}", host, port),
            BridgeAddrInner::IpPort(addr) => write!(f, "{}", addr),
        }
    }
}

impl Display for PtTargetAddr {
    /// Print the address, or `"-"` when there is none.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Some(view) = self.as_bridge_ref() {
            write!(f, "{}", view)
        } else {
            write!(f, "{}", NONE_ADDR)
        }
    }
}
impl<SA: Debug + Redactable, HN: Debug + Display + AsRef<str>> Redactable
    for BridgeAddrInner<SA, HN>
{
    /// Write a redacted form of this address.
    ///
    /// Socket addresses delegate to their own redaction. Hostnames are
    /// reduced to at most their first two characters followed by `…`, and
    /// the port is kept.
    fn display_redacted(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            BridgeAddrInner::IpPort(a) => a.display_redacted(f),
            BridgeAddrInner::HostPort(host, port) => {
                let host = host.as_ref();
                // Slice by characters, not bytes. A fixed `[..2]` panics on
                // hosts shorter than two bytes (e.g. the empty host parsed
                // from ":80") and on non-ASCII hosts where byte 2 is not a
                // char boundary.
                let end = host
                    .char_indices()
                    .nth(2)
                    .map_or(host.len(), |(idx, _)| idx);
                write!(f, "{}…:{}", &host[..end], port)
            }
        }
    }
}
impl Redactable for BridgeAddr {
fn display_redacted(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.display_redacted(f)
}
}
impl Redactable for PtTargetAddr {
fn display_redacted(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.as_bridge_ref() {
Some(b) => b.display_redacted(f),
None => write!(f, "{}", NONE_ADDR),
}
}
}
/// Extra key/value settings passed along with a pluggable transport target.
///
/// Keys may not contain `=` or whitespace, and values may not contain
/// whitespace. Serde handles this type as a list of `(key, value)` pairs and
/// validates them (through `TryFrom`) when deserializing.
#[derive(Clone, Debug, Default, Eq, PartialEq, Hash, Serialize, Deserialize)]
#[serde(into = "Vec<(String, String)>", try_from = "Vec<(String, String)>")]
pub struct PtTargetSettings {
    /// The settings, in insertion order.
    settings: Vec<(String, String)>,
}

impl PtTargetSettings {
    /// Check that `k` and `v` are acceptable as a setting.
    ///
    /// The key is checked first.
    fn check_setting(k: &str, v: &str) -> Result<(), PtTargetInvalidSetting> {
        let bad_key = k.contains(|c: char| c == '=' || c.is_whitespace());
        if bad_key {
            return Err(PtTargetInvalidSetting::Key(k.to_string()));
        }
        let bad_value = v.contains(char::is_whitespace);
        if bad_value {
            return Err(PtTargetInvalidSetting::Value(v.to_string()));
        }
        Ok(())
    }

    /// Validate the setting `k=v` and, if acceptable, append it.
    fn push_setting(
        &mut self,
        k: impl Into<String>,
        v: impl Into<String>,
    ) -> Result<(), PtTargetInvalidSetting> {
        let (key, value): (String, String) = (k.into(), v.into());
        Self::check_setting(&key, &value)?;
        self.settings.push((key, value));
        Ok(())
    }

    /// Consume `self`, returning the settings in insertion order.
    pub fn into_inner(self) -> Vec<(String, String)> {
        self.settings
    }
}

impl TryFrom<Vec<(String, String)>> for PtTargetSettings {
    type Error = PtTargetInvalidSetting;

    /// Accept `settings` only if every pair passes validation.
    ///
    /// On failure, report the first bad pair.
    fn try_from(settings: Vec<(String, String)>) -> Result<Self, Self::Error> {
        settings
            .iter()
            .try_for_each(|(k, v)| Self::check_setting(k, v))?;
        Ok(Self { settings })
    }
}

impl From<PtTargetSettings> for Vec<(String, String)> {
    fn from(settings: PtTargetSettings) -> Self {
        settings.into_inner()
    }
}
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct PtTarget {
transport: PtTransportName,
addr: PtTargetAddr,
#[serde(default)]
settings: PtTargetSettings,
}
#[derive(Error, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum PtTargetInvalidSetting {
#[error("key {0:?} has invalid or unsupported syntax")]
Key(String),
#[error("value {0:?} has invalid or unsupported syntax")]
Value(String),
}
impl PtTarget {
pub fn new(transport: PtTransportName, addr: PtTargetAddr) -> Self {
PtTarget {
transport,
addr,
settings: Default::default(),
}
}
pub fn push_setting(
&mut self,
k: impl Into<String>,
v: impl Into<String>,
) -> Result<(), PtTargetInvalidSetting> {
self.settings.push_setting(k, v)
}
pub fn transport(&self) -> &PtTransportName {
&self.transport
}
pub fn addr(&self) -> &PtTargetAddr {
&self.addr
}
pub fn settings(&self) -> impl Iterator<Item = (&str, &str)> {
self.settings.settings.iter().map(|(k, v)| (&**k, &**v))
}
pub fn socket_addrs(&self) -> Option<&[std::net::SocketAddr]> {
match self {
PtTarget {
addr: PtTargetAddr::IpPort(addr),
..
} => Some(std::slice::from_ref(addr)),
_ => None,
}
}
pub fn into_parts(self) -> (PtTransportName, PtTargetAddr, PtTargetSettings) {
(self.transport, self.addr, self.settings)
}
}
/// How a channel is to be established.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ChannelMethod {
    /// Connect directly, using the listed socket addresses.
    Direct(Vec<std::net::SocketAddr>),
    /// Connect through a pluggable transport.
    #[cfg(feature = "pt-client")]
    Pluggable(PtTarget),
}
impl ChannelMethod {
    /// Return the socket addresses known for this method, if any.
    ///
    /// For a pluggable transport, this is `Some` only when the target is an
    /// IP-and-port address.
    pub fn socket_addrs(&self) -> Option<&[std::net::SocketAddr]> {
        match self {
            ChannelMethod::Direct(addrs) => Some(addrs.as_slice()),
            #[cfg(feature = "pt-client")]
            ChannelMethod::Pluggable(target) => target.socket_addrs(),
        }
    }

    /// Return the address to report as this method's target.
    ///
    /// For `Direct`, this is the first address, or `None` if the list is
    /// empty. For `Pluggable`, it is the target's address, which may itself
    /// be `PtTargetAddr::None`.
    pub fn target_addr(&self) -> Option<PtTargetAddr> {
        match self {
            ChannelMethod::Direct(addrs) => addrs.first().copied().map(PtTargetAddr::IpPort),
            #[cfg(feature = "pt-client")]
            ChannelMethod::Pluggable(target) => Some(target.addr.clone()),
        }
    }

    /// Return true if this is a direct connection.
    pub fn is_direct(&self) -> bool {
        matches!(*self, ChannelMethod::Direct(..))
    }

    /// Return the transport identifier for this method.
    ///
    /// Direct connections use the built-in transport.
    pub fn transport_id(&self) -> TransportId {
        match self {
            ChannelMethod::Direct(_) => TransportId::new_builtin(),
            #[cfg(feature = "pt-client")]
            ChannelMethod::Pluggable(target) => TransportId::from(target.transport().clone()),
        }
    }

    /// Drop every socket address for which `pred` returns false.
    ///
    /// Returns `RetainAddrsError::NoAddrsLeft` if filtering removed the last
    /// address. A `Direct` method that was already empty is left alone and is
    /// not an error. A pluggable target whose IP address is rejected is reset
    /// to `PtTargetAddr::None`. Hostname and `None` targets are never filtered.
    pub fn retain_addrs<P>(&mut self, pred: P) -> Result<(), RetainAddrsError>
    where
        P: Fn(&std::net::SocketAddr) -> bool,
    {
        match self {
            ChannelMethod::Direct(addrs) => {
                if !addrs.is_empty() {
                    addrs.retain(|a| pred(a));
                    if addrs.is_empty() {
                        return Err(RetainAddrsError::NoAddrsLeft);
                    }
                }
            }
            #[cfg(feature = "pt-client")]
            ChannelMethod::Pluggable(PtTarget { addr, .. }) => {
                if let PtTargetAddr::IpPort(sa) = addr {
                    if !pred(sa) {
                        *addr = PtTargetAddr::None;
                        return Err(RetainAddrsError::NoAddrsLeft);
                    }
                }
            }
        }
        Ok(())
    }

    /// Return true if `self` is covered by `other`.
    ///
    /// Two `Direct` methods: every address of `self` appears in `other`.
    /// Two `Pluggable` methods: the targets are equal.
    /// Mixed kinds: never contained.
    pub fn contained_by(&self, other: &ChannelMethod) -> bool {
        match (self, other) {
            (ChannelMethod::Direct(ours), ChannelMethod::Direct(theirs)) => {
                ours.iter().all(|addr| theirs.contains(addr))
            }
            #[cfg(feature = "pt-client")]
            (ChannelMethod::Pluggable(ours), ChannelMethod::Pluggable(theirs)) => ours == theirs,
            #[cfg(feature = "pt-client")]
            _ => false,
        }
    }
}
#[derive(Clone, Debug, thiserror::Error)]
pub enum RetainAddrsError {
#[error("All addresses were removed.")]
NoAddrsLeft,
}
impl HasAddrs for PtTargetAddr {
    /// A one-element slice for IP-and-port targets; empty otherwise.
    fn addrs(&self) -> &[SocketAddr] {
        if let PtTargetAddr::IpPort(sa) = self {
            slice::from_ref(sa)
        } else {
            &[]
        }
    }
}

impl HasAddrs for ChannelMethod {
    /// All direct addresses, or the pluggable target's address (if any).
    fn addrs(&self) -> &[SocketAddr] {
        match self {
            ChannelMethod::Direct(addrs) => addrs.as_slice(),
            #[cfg(feature = "pt-client")]
            ChannelMethod::Pluggable(target) => target.addr().addrs(),
        }
    }
}
/// Unit tests for transport identifiers, target addresses, settings, and channel methods.
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// The default ID is built-in and equals the parsed `"<none>"` alias.
    #[test]
    fn builtin() {
        assert!(TransportId::default().is_builtin());
        assert_eq!(
            TransportId::default(),
            "<none>".parse().expect("Couldn't parse default ID")
        );
    }

    /// Without `pt-client`, parsing a pluggable transport name fails with `NoSupport`.
    #[test]
    #[cfg(not(feature = "pt-client"))]
    fn nosupport() {
        assert!(matches!(
            TransportId::from_str("obfs4"),
            Err(TransportIdError::NoSupport)
        ));
    }

    /// Spot-check `is_well_formed_id` on accepted and rejected names.
    #[test]
    #[cfg(feature = "pt-client")]
    fn wellformed() {
        for id in &["snowflake", "obfs4", "_ohai", "Z", "future_WORK2"] {
            assert!(is_well_formed_id(id));
        }
        // Rejected: whitespace, non-ASCII, leading digit, empty string.
        for id in &[" ", "Mölm", "12345", ""] {
            assert!(!is_well_formed_id(id));
        }
    }

    /// Every built-in alias parses to the default ID; pluggable IDs compare by name.
    #[test]
    #[cfg(feature = "pt-client")]
    fn parsing() {
        let obfs = TransportId::from_str("obfs4").unwrap();
        let dflt = TransportId::default();
        let dflt2 = TransportId::from_str("<none>").unwrap();
        let dflt3 = TransportId::from_str("-").unwrap();
        let dflt4 = TransportId::from_str("").unwrap();
        let dflt5 = TransportId::from_str("bridge").unwrap();
        let snow = TransportId::from_str("snowflake").unwrap();
        let obfs_again = TransportId::from_str("obfs4").unwrap();
        assert_eq!(obfs, obfs_again);
        assert_eq!(dflt, dflt2);
        assert_eq!(dflt, dflt3);
        assert_eq!(dflt, dflt4);
        assert_eq!(dflt, dflt5);
        assert_ne!(snow, obfs);
        assert_ne!(snow, dflt);
        // The built-in transport always displays in its canonical form.
        assert_eq!(dflt.to_string(), "-");
        assert!(matches!(
            TransportId::from_str("12345"),
            Err(TransportIdError::BadId(_))
        ));
        // Only lowercase "bridge" is a built-in alias; other casings are
        // rejected as pluggable names by `is_well_formed_id`.
        assert!(matches!(
            TransportId::from_str("Bridge"),
            Err(TransportIdError::BadId(_))
        ));
    }

    /// Round-trip target and bridge addresses through parsing, display, and conversion.
    #[test]
    fn addr() {
        // Check that `addr` parses as a BridgeAddr that matches `a` in both
        // conversion directions.
        let chk_bridge_addr = |a: &PtTargetAddr, addr: &str| {
            let ba: BridgeAddr = addr.parse().unwrap();
            assert_eq!(&ba.to_string(), addr);
            assert_eq!(&PtTargetAddr::from(Some(ba.clone())), a);
            let reba: Option<BridgeAddr> = a.clone().into();
            assert_eq!(reba.as_ref(), Some(&ba));
        };
        // IP-and-port targets expose exactly one socket address.
        for addr in &["1.2.3.4:555", "[::1]:9999"] {
            let a: PtTargetAddr = addr.parse().unwrap();
            assert_eq!(&a.to_string(), addr);
            let sa: SocketAddr = addr.parse().unwrap();
            assert_eq!(a.addrs(), &[sa]);
            chk_bridge_addr(&a, addr);
        }
        // Hostname and "-" targets expose no socket addresses; "-" is not a BridgeAddr.
        for addr in &["www.example.com:9100", "-"] {
            let a: PtTargetAddr = addr.parse().unwrap();
            assert_eq!(&a.to_string(), addr);
            assert_eq!(a.addrs(), &[]);
            if a == PtTargetAddr::None {
                let e = BridgeAddr::from_str(addr).unwrap_err();
                assert!(matches!(e, BridgeAddrError::BadAddress(_)));
            } else {
                chk_bridge_addr(&a, addr);
            }
        }
        // Strings with no ':' port separator fail to parse as either type.
        for addr in &["foobar", "<<<>>>"] {
            let e = PtTargetAddr::from_str(addr).unwrap_err();
            assert!(matches!(e, BridgeAddrError::BadAddress(_)));
            let e = BridgeAddr::from_str(addr).unwrap_err();
            assert!(matches!(e, BridgeAddrError::BadAddress(_)));
        }
    }

    /// Parsing and display of transport IDs, with and without `pt-client`.
    #[test]
    fn transport_id() {
        let id1: TransportId = "<none>".parse().unwrap();
        assert!(id1.is_builtin());
        assert_eq!(id1.to_string(), "-".to_string());
        #[cfg(feature = "pt-client")]
        {
            let id2: TransportId = "obfs4".parse().unwrap();
            assert_ne!(id2, id1);
            assert!(!id2.is_builtin());
            assert_eq!(id2.to_string(), "obfs4");
            assert!(matches!(
                TransportId::from_str("==="),
                Err(TransportIdError::BadId(_))
            ));
        }
        #[cfg(not(feature = "pt-client"))]
        {
            assert!(matches!(
                TransportId::from_str("obfs4"),
                Err(TransportIdError::NoSupport)
            ))
        }
    }

    /// Settings validation: `=` is rejected in keys and whitespace is rejected in values.
    #[test]
    fn settings() {
        let s = PtTargetSettings::try_from(vec![]).unwrap();
        assert_eq!(Vec::<_>::from(s), vec![]);
        let v = vec![("abc".into(), "def".into()), ("ghi".into(), "jkl".into())];
        let s = PtTargetSettings::try_from(v.clone()).unwrap();
        assert_eq!(Vec::<_>::from(s), v);
        let v = vec![("a=b".into(), "def".into())];
        let s = PtTargetSettings::try_from(v);
        assert!(matches!(s, Err(PtTargetInvalidSetting::Key(_))));
        let v = vec![("abc".into(), "d ef".into())];
        let s = PtTargetSettings::try_from(v);
        assert!(matches!(s, Err(PtTargetInvalidSetting::Value(_))));
    }

    /// Accessors, containment, and address filtering for `ChannelMethod::Direct`.
    #[test]
    fn chanmethod_direct() {
        let a1 = "127.0.0.1:8080".parse().unwrap();
        let a2 = "127.0.0.2:8181".parse().unwrap();
        let a3 = "127.0.0.3:8282".parse().unwrap();
        let m = ChannelMethod::Direct(vec![a1, a2]);
        assert_eq!(m.socket_addrs(), Some(&[a1, a2][..]));
        // The target address of a Direct method is its first address.
        assert_eq!((m.target_addr()), Some(PtTargetAddr::IpPort(a1)));
        assert!(m.is_direct());
        assert_eq!(m.transport_id(), TransportId::default());
        let m2 = ChannelMethod::Direct(vec![a1, a2, a3]);
        assert!(m.contained_by(&m));
        assert!(m.contained_by(&m2));
        assert!(!m2.contained_by(&m));
        let mut m3 = m2.clone();
        // Dropping a3 from m2 leaves exactly m.
        m3.retain_addrs(|a| a.port() != 8282).unwrap();
        assert_eq!(m3, m);
        assert_ne!(m3, m2);
    }

    /// Accessors, containment, and address filtering for `ChannelMethod::Pluggable`.
    #[test]
    #[cfg(feature = "pt-client")]
    fn chanmethod_pt() {
        use itertools::Itertools;
        let transport = "giraffe".parse().unwrap();
        let addr1 = PtTargetAddr::HostPort("pt.example.com".into(), 1234);
        let target1 = PtTarget::new("giraffe".parse().unwrap(), addr1.clone());
        let m1 = ChannelMethod::Pluggable(target1);
        let addr2 = PtTargetAddr::IpPort("127.0.0.1:567".parse().unwrap());
        let target2 = PtTarget::new("giraffe".parse().unwrap(), addr2.clone());
        let m2 = ChannelMethod::Pluggable(target2);
        let addr3 = PtTargetAddr::None;
        let target3 = PtTarget::new("giraffe".parse().unwrap(), addr3.clone());
        let m3 = ChannelMethod::Pluggable(target3);
        // Only the IP-and-port target has socket addresses.
        assert_eq!(m1.socket_addrs(), None);
        assert_eq!(
            m2.socket_addrs(),
            Some(&["127.0.0.1:567".parse().unwrap()][..])
        );
        assert_eq!(m3.socket_addrs(), None);
        assert_eq!(m1.target_addr(), Some(addr1));
        assert_eq!(m2.target_addr(), Some(addr2));
        assert_eq!(m3.target_addr(), Some(addr3));
        assert!(!m1.is_direct());
        assert!(!m2.is_direct());
        assert!(!m3.is_direct());
        assert_eq!(m1.transport_id(), transport);
        assert_eq!(m2.transport_id(), transport);
        assert_eq!(m3.transport_id(), transport);
        // Pluggable containment is plain equality.
        for v in [&m1, &m2, &m3].iter().combinations(2) {
            let first = v[0];
            let second = v[1];
            assert_eq!(first.contained_by(second), first == second);
        }
        let mut m1new = m1.clone();
        let mut m2new = m2.clone();
        let mut m3new = m3.clone();
        // A predicate that keeps port 567 leaves every target unchanged.
        m1new.retain_addrs(|a| a.port() == 567).unwrap();
        m2new.retain_addrs(|a| a.port() == 567).unwrap();
        m3new.retain_addrs(|a| a.port() == 567).unwrap();
        assert_eq!(m1new, m1);
        assert_eq!(m2new, m2);
        assert_eq!(m3new, m3);
        // Rejecting the only IP address is an error.
        assert!(matches!(
            m2new.retain_addrs(|a| a.port() == 999),
            Err(RetainAddrsError::NoAddrsLeft)
        ));
    }
}