use super::SetOpt;
use curl::easy::{Easy2, List};
use http::Uri;
use std::{convert::TryFrom, fmt, net::SocketAddr, str::FromStr};
/// Error returned when a string cannot be parsed into a [`Dialer`].
///
/// The unit payload keeps the error opaque so it can gain detail later
/// without changing its public shape.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DialerParseError(());

impl fmt::Display for DialerParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "invalid dial address syntax")
    }
}

impl std::error::Error for DialerParseError {}
/// Describes how to connect to the server, optionally overriding the host and
/// port named in the request URL.
///
/// A `Dialer` can be built from a [`SocketAddr`], from a Unix domain socket
/// path (Unix only), or parsed from a string such as `tcp:127.0.0.1:80` or
/// `unix:/path/to/socket`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Dialer(Inner);

/// Internal representation of the connection strategy.
#[derive(Clone, Debug, Eq, PartialEq)]
enum Inner {
    /// No override; connect as the request URL dictates.
    Default,

    /// A curl `CONNECT_TO` entry of the form `::HOST:PORT`. The leading host
    /// and port fields are left empty, which per curl's `CURLOPT_CONNECT_TO`
    /// docs matches any request host/port.
    IpSocket(String),

    /// Path of a Unix domain socket to connect through.
    #[cfg(unix)]
    UnixSocket(std::path::PathBuf),
}

impl Dialer {
    /// Connect to the given IP socket address instead of the URL's host.
    pub fn ip_socket(addr: impl Into<SocketAddr>) -> Self {
        let addr: SocketAddr = addr.into();
        // `SocketAddr`'s `Display` wraps IPv6 hosts in brackets (`[::1]:80`),
        // keeping the colon-separated CONNECT_TO entry unambiguous.
        Self(Inner::IpSocket(format!("::{}", addr)))
    }

    /// Connect through the Unix domain socket located at `path`.
    #[cfg(unix)]
    pub fn unix_socket(path: impl Into<std::path::PathBuf>) -> Self {
        Self(Inner::UnixSocket(path.into()))
    }
}

impl Default for Dialer {
    /// A dialer that applies no connection override.
    fn default() -> Self {
        Self(Inner::Default)
    }
}

impl From<SocketAddr> for Dialer {
    /// Equivalent to [`Dialer::ip_socket`].
    fn from(socket_addr: SocketAddr) -> Self {
        Self::ip_socket(socket_addr)
    }
}
impl FromStr for Dialer {
    type Err = DialerParseError;

    /// Parses a dialer from an address string.
    ///
    /// Accepted forms:
    /// - `tcp:HOST:PORT` or `tcp://HOST:PORT`, where `HOST:PORT` must be a
    ///   literal [`SocketAddr`] (IPv6 hosts in brackets); hostnames are not
    ///   resolved.
    /// - `unix:/path` or `unix://path` (Unix only). Any run of slashes after
    ///   the scheme is collapsed and the resulting path is always absolute.
    ///
    /// # Errors
    ///
    /// Returns [`DialerParseError`] for any other scheme, or when the text
    /// after `tcp:` is not a valid socket address.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if let Some(rest) = s.strip_prefix("tcp:") {
            return rest
                .trim_start_matches('/')
                .parse::<SocketAddr>()
                .map(Self::ip_socket)
                .map_err(|_| DialerParseError(()));
        }

        #[cfg(unix)]
        {
            if let Some(rest) = s.strip_prefix("unix:") {
                // Rebuild from "/" so both `unix:/a` and `unix://a` (the form
                // `Uri::to_string` produces) yield the absolute path `/a`.
                let mut path = std::path::PathBuf::from("/");
                path.push(rest.trim_start_matches('/'));
                return Ok(Self(Inner::UnixSocket(path)));
            }
        }

        Err(DialerParseError(()))
    }
}
impl TryFrom<&'_ str> for Dialer {
type Error = DialerParseError;
fn try_from(str: &str) -> Result<Self, Self::Error> {
str.parse()
}
}
impl TryFrom<String> for Dialer {
type Error = DialerParseError;
fn try_from(string: String) -> Result<Self, Self::Error> {
string.parse()
}
}
impl TryFrom<Uri> for Dialer {
type Error = DialerParseError;
fn try_from(uri: Uri) -> Result<Self, Self::Error> {
uri.to_string().parse()
}
}
impl SetOpt for Dialer {
    /// Applies this dialer to a curl handle.
    ///
    /// Both options are set on every call: an empty `CONNECT_TO` list when not
    /// dialing an IP socket, and (on Unix) a `None` socket path when not dialing
    /// a Unix socket. NOTE(review): presumably this is so a reused handle does
    /// not keep a previous dialer's settings — confirm against curl semantics.
    fn set_opt<H>(&self, easy: &mut Easy2<H>) -> Result<(), curl::Error> {
        let connect_to = match &self.0 {
            Inner::IpSocket(entry) => {
                let mut list = List::new();
                list.append(entry)?;
                list
            }
            _ => List::new(),
        };
        easy.connect_to(connect_to)?;

        #[cfg(unix)]
        {
            let socket_path = if let Inner::UnixSocket(path) = &self.0 {
                Some(path)
            } else {
                None
            };
            easy.unix_socket_path(socket_path)?;
        }

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_tcp_socket_and_port_uri() {
        let dialer = "tcp:127.0.0.1:1200".parse::<Dialer>().unwrap();
        assert_eq!(dialer.0, Inner::IpSocket("::127.0.0.1:1200".into()));
    }

    #[test]
    fn parse_tcp_uri_with_slashes() {
        let dialer = "tcp://127.0.0.1:1200".parse::<Dialer>().unwrap();
        assert_eq!(dialer.0, Inner::IpSocket("::127.0.0.1:1200".into()));
    }

    #[test]
    fn parse_tcp_ipv6_uri_keeps_brackets() {
        let dialer = "tcp:[::1]:1200".parse::<Dialer>().unwrap();
        assert_eq!(dialer.0, Inner::IpSocket("::[::1]:1200".into()));
    }

    #[test]
    fn parse_invalid_tcp_uri() {
        let result = "tcp:127.0.0.1-1200".parse::<Dialer>();
        assert!(result.is_err());
    }

    #[test]
    fn parse_tcp_hostname_is_rejected() {
        // Only literal socket addresses are accepted; no DNS resolution.
        assert!("tcp:localhost:1200".parse::<Dialer>().is_err());
    }

    #[test]
    fn parse_unknown_or_empty_scheme_is_rejected() {
        assert_eq!(
            "http://127.0.0.1:80".parse::<Dialer>().unwrap_err(),
            DialerParseError(())
        );
        assert!("".parse::<Dialer>().is_err());
        assert!("tcp:".parse::<Dialer>().is_err());
    }

    #[test]
    fn default_dialer_has_no_override() {
        assert_eq!(Dialer::default().0, Inner::Default);
    }

    #[test]
    fn from_socket_addr_matches_ip_socket() {
        let addr = std::net::SocketAddr::from(([10, 0, 0, 1], 8080));
        assert_eq!(
            Dialer::from(addr).0,
            Inner::IpSocket("::10.0.0.1:8080".into())
        );
    }

    #[test]
    #[cfg(unix)]
    fn parse_unix_socket_uri() {
        let dialer = "unix:/path/to/my.sock".parse::<Dialer>().unwrap();
        assert_eq!(dialer.0, Inner::UnixSocket("/path/to/my.sock".into()));
    }

    #[test]
    #[cfg(unix)]
    fn from_unix_socket_uri() {
        let uri = "unix://path/to/my.sock".parse::<http::Uri>().unwrap();
        let dialer = Dialer::try_from(uri).unwrap();
        assert_eq!(dialer.0, Inner::UnixSocket("/path/to/my.sock".into()));
    }
}