use serde::Deserialize;
use std::{fmt::Debug, path::PathBuf, str::FromStr};
use tor_config_path::{
addr::{CfgAddr, CfgAddrError},
CfgPath, CfgPathError, CfgPathResolver,
};
use tor_general_addr::general;
use crate::HasClientErrorAction;
/// A connect point as parsed from its TOML text.
///
/// This is the unresolved form: its addresses and paths are still `CfgAddr` /
/// `CfgPath` values. Obtain one through the `FromStr` implementation, then
/// call `resolve` to turn it into a `ResolvedConnectPoint`.
#[derive(Clone, Debug)]
pub struct ParsedConnectPoint(ConnectPointEnum<Unresolved>);
/// A connect point in resolved form, as produced by `ParsedConnectPoint::resolve`.
///
/// Its addresses are `general::SocketAddr` values and its paths are `PathBuf`s.
/// A `connect` entry produced by `ParsedConnectPoint::resolve` has also passed
/// `Connect::validate`.
#[derive(Clone, Debug)]
pub struct ResolvedConnectPoint(pub(crate) ConnectPointEnum<Resolved>);
impl ParsedConnectPoint {
pub fn resolve(
&self,
resolver: &CfgPathResolver,
) -> Result<ResolvedConnectPoint, ResolveError> {
use ConnectPointEnum as CPE;
Ok(ResolvedConnectPoint(match &self.0 {
CPE::Connect(connect) => CPE::Connect(connect.resolve(resolver)?),
CPE::Builtin(builtin) => CPE::Builtin(builtin.clone()),
}))
}
}
impl FromStr for ParsedConnectPoint {
type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let de: ConnectPointDe = toml::from_str(s).map_err(ParseError::InvalidConnectPoint)?;
Ok(ParsedConnectPoint(de.try_into()?))
}
}
/// An error produced while parsing a connect point from a string.
#[derive(Clone, Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ParseError {
/// The input could not be deserialized from TOML into a connect point.
#[error("Invalid connect point")]
InvalidConnectPoint(#[source] toml::de::Error),
/// The input contained both a `connect` and a `builtin` member.
#[error("Conflicting members in connect point")]
ConflictingMembers,
/// The input contained neither a `connect` nor a `builtin` member.
#[error("Unrecognized format on connect point")]
UnrecognizedFormat,
}
impl HasClientErrorAction for ParseError {
    /// Report which client action each parse error maps to.
    fn client_action(&self) -> crate::ClientErrorAction {
        use crate::ClientErrorAction as Action;

        match self {
            Self::InvalidConnectPoint(_) | Self::ConflictingMembers => Action::Abort,
            Self::UnrecognizedFormat => Action::Decline,
        }
    }
}
/// An error produced while resolving (and validating) a parsed connect point.
#[derive(Clone, Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ResolveError {
/// Resolving a `CfgPath` failed.
#[error("Unable to resolve variables in path")]
InvalidPath(#[from] CfgPathError),
/// Resolving a `CfgAddr` failed.
#[error("Unable to resolve variables in address")]
InvalidAddr(#[from] CfgAddrError),
/// Substitutions applied to an address, but the resolved address could not
/// be rendered back into a string.
#[error("Cannot represent expanded path as string")]
PathNotString,
/// The socket is an `Inet` address whose IP is not a loopback address.
#[error("Tried to bind or connect to a non-loopback TCP address")]
AddressNotLoopback,
/// The `auth` setting cannot be used with this kind of socket address
/// (reported for `Auth::None` on an `Inet` socket).
#[error("Authorization type not compatible with address family")]
AuthNotCompatible,
/// The `auth` setting was deserialized as `Auth::Unrecognized`.
#[error("Authorization type not recognized as a supported type")]
AuthNotRecognized,
/// The socket address is neither an `Inet` nor a `Unix` address.
#[error("Address type not recognized")]
AddressTypeNotRecognized,
}
impl HasClientErrorAction for ResolveError {
    /// Report which client action each resolve error maps to.
    ///
    /// Errors that wrap a path or address error defer to the wrapped error.
    fn client_action(&self) -> crate::ClientErrorAction {
        use crate::ClientErrorAction as Action;

        match self {
            Self::InvalidPath(inner) => inner.client_action(),
            Self::InvalidAddr(inner) => inner.client_action(),
            Self::AuthNotCompatible => Action::Abort,
            Self::PathNotString
            | Self::AddressNotLoopback
            | Self::AuthNotRecognized
            | Self::AddressTypeNotRecognized => Action::Decline,
        }
    }
}
/// The two kinds of connect point, generic over the address representation `R`.
#[derive(Clone, Debug)]
pub(crate) enum ConnectPointEnum<R: Addresses> {
/// A `connect` entry: a socket address plus an `auth` setting.
Connect(Connect<R>),
/// A `builtin` entry.
Builtin(Builtin),
}
/// Selects the concrete types used for socket addresses and for paths.
///
/// Implemented by marker types, so the same connect-point structures can hold
/// either unresolved or resolved values.
pub(crate) trait Addresses {
    /// Type used to represent a socket address.
    type SocketAddr: Debug + Clone;
    /// Type used to represent a path.
    type Path: Debug + Clone;
}
/// Deserialization helper mirroring the top-level TOML layout of a connect point.
///
/// Both members are optional here, so that the `TryFrom` conversion into
/// `ConnectPointEnum` can report a missing or conflicting combination as a
/// `ParseError`.
#[derive(Deserialize, Clone, Debug)]
struct ConnectPointDe {
// The `connect` member, if present.
connect: Option<Connect<Unresolved>>,
// The `builtin` member, if present.
builtin: Option<Builtin>,
}
impl TryFrom<ConnectPointDe> for ConnectPointEnum<Unresolved> {
    type Error = ParseError;

    /// Accept a deserialized connect point only if exactly one of its
    /// `connect` and `builtin` members is present.
    fn try_from(value: ConnectPointDe) -> Result<Self, Self::Error> {
        let ConnectPointDe { connect, builtin } = value;
        match (connect, builtin) {
            (Some(connect), None) => Ok(Self::Connect(connect)),
            (None, Some(builtin)) => Ok(Self::Builtin(builtin)),
            // Both members present: the input contradicts itself.
            (Some(_), Some(_)) => Err(ParseError::ConflictingMembers),
            // Neither member present: not a format we recognize.
            (None, None) => Err(ParseError::UnrecognizedFormat),
        }
    }
}
/// Contents of the `builtin` member of a connect point.
#[derive(Deserialize, Clone, Debug)]
pub(crate) struct Builtin {
/// The value of the `builtin` key inside that member.
pub(crate) builtin: BuiltinVariant,
}
/// The recognized values of the `builtin` key.
///
/// Variant names are deserialized from their lowercase form.
#[derive(Deserialize, Clone, Debug)]
#[serde(rename_all = "lowercase")]
pub(crate) enum BuiltinVariant {
/// Written as `"abort"` in the input.
///
/// NOTE(review): the effect of selecting this variant is not visible in this
/// file — confirm against the code that consumes it.
Abort,
}
/// Contents of the `connect` member of a connect point, generic over the
/// address representation `R`.
///
/// The explicit `serde(bound)` replaces the bounds the derive would infer: it
/// requires `Deserialize` of the concrete field types rather than of `R` itself.
#[derive(Deserialize, Clone, Debug)]
#[serde(bound = "R::Path : Deserialize<'de>, AddrWithStr<R::SocketAddr> : Deserialize<'de>")]
pub(crate) struct Connect<R: Addresses> {
/// The `socket` address, kept together with a string form of it.
pub(crate) socket: AddrWithStr<R::SocketAddr>,
/// The optional `socket_canonical` address.
///
/// NOTE(review): how this is used differently from `socket` is not visible
/// in this file — confirm against callers.
pub(crate) socket_canonical: Option<AddrWithStr<R::SocketAddr>>,
/// The `auth` setting.
pub(crate) auth: Auth<R>,
}
impl Connect<Unresolved> {
    /// Resolve every part of this `connect` entry with `resolver`, then
    /// validate the combined result.
    ///
    /// The parts are resolved in order: `socket`, `socket_canonical` (when
    /// present), `auth`. The first failure is returned.
    fn resolve(&self, resolver: &CfgPathResolver) -> Result<Connect<Resolved>, ResolveError> {
        let socket = self.socket.resolve(resolver)?;
        let socket_canonical = match &self.socket_canonical {
            Some(canonical) => Some(canonical.resolve(resolver)?),
            None => None,
        };
        let auth = self.auth.resolve(resolver)?;

        let resolved = Connect {
            socket,
            socket_canonical,
            auth,
        };
        resolved.validate()
    }
}
impl Connect<Resolved> {
    /// Return this `Connect` unchanged if its socket address and `auth`
    /// setting are acceptable together; otherwise return the reason they are
    /// not.
    ///
    /// Only `socket` is inspected here; `socket_canonical` is not checked.
    fn validate(self) -> Result<Self, ResolveError> {
        use general::SocketAddr::{Inet, Unix};

        match self.socket.as_ref() {
            // A non-loopback inet address is rejected whatever the auth setting is.
            Inet(addr) if !addr.ip().is_loopback() => Err(ResolveError::AddressNotLoopback),
            Inet(_) => match &self.auth {
                Auth::None => Err(ResolveError::AuthNotCompatible),
                Auth::Unrecognized {} => Err(ResolveError::AuthNotRecognized),
                Auth::Cookie { .. } => Ok(self),
            },
            Unix(_) => match &self.auth {
                Auth::Unrecognized {} => Err(ResolveError::AuthNotRecognized),
                _ => Ok(self),
            },
            // Any other address family: an unrecognized auth setting is still
            // reported in preference to the unrecognized address type.
            _ => match &self.auth {
                Auth::Unrecognized {} => Err(ResolveError::AuthNotRecognized),
                _ => Err(ResolveError::AddressTypeNotRecognized),
            },
        }
    }
}
/// The `auth` setting of a `connect` entry, generic over the address
/// representation `R`.
///
/// Named variants are deserialized from their lowercase names.
#[derive(Deserialize, Clone, Debug)]
#[serde(rename_all = "lowercase")]
pub(crate) enum Auth<R: Addresses> {
/// Written as `auth = "none"`. Validation rejects this for `Inet` sockets.
None,
/// Written as a `cookie` table holding a `path`.
Cookie {
/// The `path` given for the cookie.
path: R::Path,
},
/// Untagged fallback for input that matches none of the named variants.
///
/// Resolving keeps it as-is; validation then rejects it.
#[serde(untagged)]
Unrecognized {},
}
impl Auth<Unresolved> {
    /// Convert this `auth` setting to its resolved form.
    ///
    /// Only the cookie path needs `resolver`; the other variants carry no
    /// data and map straight across.
    fn resolve(&self, resolver: &CfgPathResolver) -> Result<Auth<Resolved>, ResolveError> {
        let resolved = match self {
            Self::None => Auth::None,
            Self::Cookie { path } => {
                let path = path.path(resolver)?;
                Auth::Cookie { path }
            }
            Self::Unrecognized {} => Auth::Unrecognized {},
        };
        Ok(resolved)
    }
}
/// Marker type selecting the unresolved representations (`CfgAddr` and
/// `CfgPath`) through its `Addresses` implementation.
#[derive(Clone, Debug)]
struct Unresolved;
/// Unresolved connect points hold configuration-level paths and addresses.
impl Addresses for Unresolved {
    type Path = CfgPath;
    type SocketAddr = CfgAddr;
}
/// Marker type selecting the resolved representations (`general::SocketAddr`
/// and `PathBuf`) through its `Addresses` implementation.
#[derive(Clone, Debug)]
pub(crate) struct Resolved;
/// Resolved connect points hold concrete filesystem paths and socket addresses.
impl Addresses for Resolved {
    type Path = PathBuf;
    type SocketAddr = general::SocketAddr;
}
/// An address of type `A` stored together with a string form of it.
///
/// Deserialization goes through the `FromStr` implementation
/// (`DeserializeFromStr`), and `as_ref()` yields the `addr` field.
#[derive(Clone, Debug, derive_more::AsRef, serde_with::DeserializeFromStr)]
pub(crate) struct AddrWithStr<A>
where
A: Clone + Debug,
{
// String form of the address: the text it was parsed from, or the
// re-rendered address when `resolve` applied substitutions.
string: String,
// The address itself.
#[as_ref]
addr: A,
}
impl<A> AddrWithStr<A>
where
    A: Clone + Debug,
{
    /// Return the string form stored alongside the address.
    pub(crate) fn as_str(&self) -> &str {
        &self.string
    }
}
impl AddrWithStr<CfgAddr> {
    /// Resolve the address with `resolver`, keeping the string form
    /// consistent with the result.
    ///
    /// If substitutions apply to the address, the string is regenerated from
    /// the resolved address; otherwise the original text is kept.
    ///
    /// # Errors
    ///
    /// Returns [`ResolveError::InvalidAddr`] if the address cannot be
    /// resolved, or [`ResolveError::PathNotString`] if a substituted address
    /// cannot be rendered as a string.
    pub(crate) fn resolve(
        &self,
        resolver: &CfgPathResolver,
    ) -> Result<AddrWithStr<general::SocketAddr>, ResolveError> {
        // Ask the unresolved address, before resolving it, whether the
        // original text will still describe the result.
        let text_is_stale = self.addr.substitutions_will_apply();
        let resolved = self.addr.address(resolver)?;

        let text = if text_is_stale {
            resolved
                .try_to_string()
                .ok_or(ResolveError::PathNotString)?
        } else {
            self.string.clone()
        };

        Ok(AddrWithStr {
            string: text,
            addr: resolved,
        })
    }
}
impl<A> FromStr for AddrWithStr<A>
where
    A: Clone + Debug + FromStr,
{
    type Err = <A as FromStr>::Err;

    /// Parse `s` as an `A`, and remember `s` itself as the string form.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        s.parse::<A>().map(|addr| Self {
            string: s.to_owned(),
            addr,
        })
    }
}
// Unit tests for parsing and resolving connect points.
#[cfg(test)]
mod test {
// Lint relaxations applied to the test code in this module.
#![allow(clippy::bool_assert_comparison)]
#![allow(clippy::clone_on_copy)]
#![allow(clippy::dbg_macro)]
#![allow(clippy::mixed_attributes_style)]
#![allow(clippy::print_stderr)]
#![allow(clippy::print_stdout)]
#![allow(clippy::single_char_pattern)]
#![allow(clippy::unwrap_used)]
#![allow(clippy::unchecked_duration_subtraction)]
#![allow(clippy::useless_vec)]
#![allow(clippy::needless_pass_by_value)]
use super::*;
use assert_matches::assert_matches;
// Parse `s` as a connect point, panicking if parsing fails.
fn parse(s: &str) -> ParsedConnectPoint {
s.parse().unwrap()
}
// Each of these example inputs must parse successfully.
#[test]
fn examples() {
// A builtin entry.
let _e1 = parse(
r#"
[builtin]
builtin = "abort"
"#,
);
// A unix socket with no auth.
let _e2 = parse(
r#"
[connect]
socket = "unix:/var/run/arti/rpc_socket"
auth = "none"
"#,
);
// An inet socket with cookie auth, written as an inline table.
let _e3 = parse(
r#"
[connect]
socket = "inet:[::1]:9191"
socket_canonical = "inet:[::1]:2020"
auth = { cookie = { path = "/home/user/.arti_rpc/cookie" } }
"#,
);
// The same settings, with the cookie written as a sub-table.
let _e4 = parse(
r#"
[connect]
socket = "inet:[::1]:9191"
socket_canonical = "inet:[::1]:2020"
[connect.auth.cookie]
path = "/home/user/.arti_rpc/cookie"
"#,
);
}
// Each kind of `ParseError` is reported for the matching malformed input.
#[test]
fn parse_errors() {
// Text that cannot be deserialized at all.
let r: Result<ParsedConnectPoint, _> = "not a toml string".parse();
assert_matches!(r, Err(ParseError::InvalidConnectPoint(_)));
// Deserializes, but has neither a `connect` nor a `builtin` member.
let r: Result<ParsedConnectPoint, _> = "[squidcakes]".parse();
assert_matches!(r, Err(ParseError::UnrecognizedFormat));
// Has both a `builtin` and a `connect` member.
let r: Result<ParsedConnectPoint, _> = r#"
[builtin]
builtin = "abort"
[connect]
socket = "inet:[::1]:9191"
socket_canonical = "inet:[::1]:2020"
auth = { cookie = { path = "/home/user/.arti_rpc/cookie" } }
"#
.parse();
assert_matches!(r, Err(ParseError::ConflictingMembers));
}
// An unknown auth type parses successfully but is rejected when resolving.
#[test]
fn resolve_errors() {
let resolver = CfgPathResolver::default();
let r: ParsedConnectPoint = r#"
[connect]
socket = "inet:[::1]:9191"
socket_canonical = "inet:[::1]:2020"
[connect.auth.esp]
telekinetic_handshake = 3
"#
.parse()
.unwrap();
let err = r.resolve(&resolver).err();
assert_matches!(err, Some(ResolveError::AuthNotRecognized));
}
}