use std::fmt::Display;
use std::io::{self, BufRead, BufReader, Error, ErrorKind, Read, Write};
use std::net::{TcpStream, ToSocketAddrs};
use std::num::ParseIntError;
use std::result::Result;
use std::str::FromStr;
use std::sync::Arc;
use rustls::{ClientConnection, Stream};
/// Expands to an expression yielding a `&'static regex::Regex` for the given
/// pattern literal.
///
/// Each expansion site owns its own hidden `static` cell, so the pattern is
/// compiled at most once per site, on first use.
///
/// # Panics
///
/// Panics on first use if the literal is not a valid regular expression.
macro_rules! regex {
    ($pattern:literal $(,)?) => {{
        static COMPILED: once_cell::sync::OnceCell<regex::Regex> =
            once_cell::sync::OnceCell::new();
        COMPILED.get_or_init(|| regex::Regex::new($pattern).unwrap())
    }};
}
/// Opens an unencrypted TCP connection to `addr` and consumes the server's
/// greeting before handing the stream back.
///
/// # Errors
///
/// Returns the connect error if the TCP connection cannot be established, or
/// the error produced by `read_greeting` if the greeting cannot be read.
pub fn insecure<A: ToSocketAddrs>(addr: A) -> Result<TcpStream, Error> {
    TcpStream::connect(addr).and_then(|mut connection| {
        // The greeting must be drained so the first command's response is
        // the next thing on the wire.
        connection.read_greeting()?;
        Ok(connection)
    })
}
pub fn secure<A: ToSocketAddrs + Display>(
addr: A,
hostname: &str,
) -> Result<SecureConnection, Error> {
let mut root_store = rustls::RootCertStore::empty();
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let rc_config = Arc::new(config);
let server_name = hostname.try_into().unwrap();
let mut client = ClientConnection::new(rc_config, server_name).unwrap();
let mut socket = TcpStream::connect(addr)?;
let mut stream = Stream::new(&mut client, &mut socket);
stream.read_greeting()?;
Ok(SecureConnection { client, socket })
}
/// A TLS session paired with the TCP socket that carries it.
///
/// The two halves are stored separately; call [`SecureConnection::stream`]
/// to obtain a combined view borrowing both.
pub struct SecureConnection {
    /// TLS client session state.
    client: ClientConnection,
    /// Underlying transport socket.
    socket: TcpStream,
}

impl SecureConnection {
    /// Borrows the session and the socket together as a single `Stream`.
    ///
    /// The returned value holds mutable borrows of both fields, so it must
    /// be dropped before `stream` can be called again.
    pub fn stream(&mut self) -> Stream<'_, ClientConnection, TcpStream> {
        let Self { client, socket } = self;
        Stream::new(client, socket)
    }
}
/// Command/response operations over a line-oriented connection
/// (the commands and codes used in this file look like NNTP).
pub trait Connection {
    /// Reads and discards the greeting the server sends on connect.
    fn read_greeting(&mut self) -> io::Result<()>;

    /// Logs in with `username` and `password`, returning the final status on
    /// success.
    fn authenticate(&mut self, username: &str, password: &str) -> io::Result<Status>;

    /// Asks the server to end the session.
    fn quit(&mut self) -> io::Result<()>;

    /// Sends `command` and reads back its response: the status line plus any
    /// text block that followed it.
    fn execute(&mut self, command: &Command) -> io::Result<(Status, Vec<u8>)>;

    /// Writes `command` to the connection without waiting for a response.
    fn send_command(&mut self, command: &Command) -> io::Result<()>;

    /// Reads one response: the status line plus any text block that follows.
    fn read_response(&mut self) -> io::Result<(Status, Vec<u8>)>;
}
/// Status codes whose initial response line is followed by a dot-terminated
/// multi-line text block that must be read before the next response.
///
/// The list follows the multi-line responses of RFC 3977 (NNTP). Code 211 is
/// intentionally absent: it is single-line for `GROUP` and multi-line only
/// for `LISTGROUP`, which cannot be told apart from the code alone.
const TEXT_FOLLOWS: &[StatusCode] = &[
    StatusCode(100), // HELP
    StatusCode(101), // CAPABILITIES
    StatusCode(215), // LIST
    StatusCode(220), // ARTICLE
    StatusCode(221), // HEAD
    StatusCode(222), // BODY
    StatusCode(224), // OVER
    StatusCode(225), // HDR
    StatusCode(230), // NEWNEWS
    StatusCode(231), // NEWGROUPS
];
/// Blanket implementation of [`Connection`] for every readable and writable
/// byte stream, so plain sockets and TLS streams share one code path.
impl<T> Connection for T
where
    T: Read + Write,
{
    /// Reads one response and discards it.
    ///
    /// Any failure is reported as a generic `ErrorKind::Other` error; the
    /// underlying cause is not preserved.
    fn read_greeting(&mut self) -> io::Result<()> {
        self.read_response()
            .map(|(_, _)| ())
            .map_err(|_| Error::new(ErrorKind::Other, "Failed to read greeting response"))
    }

    /// Sends `AUTHINFO USER`, followed by `AUTHINFO PASS` when the server
    /// answers 381, and succeeds only if the final status is 281.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from the exchange, or an `ErrorKind::Other`
    /// error carrying the final status when it is not 281.
    fn authenticate(&mut self, username: &str, password: &str) -> io::Result<Status> {
        let (user_status, user_data) = self.execute(&Command(vec![
            "AUTHINFO".to_string(),
            "USER".to_string(),
            username.to_string(),
        ]))?;

        // 381 means the server wants the password; any other answer to the
        // USER step is already the final verdict.
        let (status, _) = if user_status.status_code == StatusCode(381) {
            self.execute(&Command(vec![
                "AUTHINFO".to_string(),
                "PASS".to_string(),
                password.to_string(),
            ]))?
        } else {
            (user_status, user_data)
        };

        if status.status_code == StatusCode(281) {
            Ok(status)
        } else {
            Err(Error::new(
                ErrorKind::Other,
                format!("Authentication failed: {status:?}"),
            ))
        }
    }

    /// Sends `QUIT` and discards the server's reply.
    fn quit(&mut self) -> io::Result<()> {
        self.execute(&Command(vec!["QUIT".to_string()])).map(|_| ())
    }

    /// Sends `command`, then reads and returns its response.
    fn execute(&mut self, command: &Command) -> io::Result<(Status, Vec<u8>)> {
        self.send_command(command)?;
        self.read_response()
    }

    /// Writes the command's parts joined by single spaces, terminated by
    /// CRLF, and flushes the stream.
    fn send_command(&mut self, command: &Command) -> io::Result<()> {
        let mut encoded = command.0.join(" ");
        encoded.push_str("\r\n");
        self.write_all(encoded.as_bytes())?;
        self.flush()
    }

    /// Reads one status line and, for codes listed in `TEXT_FOLLOWS`, the
    /// text block that follows it.
    ///
    /// The returned bytes are the block as received with the closing
    /// `"\r\n.\r\n"` removed, so the last line carries no trailing CRLF and
    /// no dot-unstuffing is applied. An empty block (a lone `".\r\n"`) yields
    /// an empty vector.
    ///
    /// # Errors
    ///
    /// * `ErrorKind::UnexpectedEof` if the stream ends before a status line,
    ///   or before the text block's terminator has been received.
    /// * `ErrorKind::InvalidData` if the status line cannot be parsed.
    /// * Any I/O error from the underlying stream.
    fn read_response(&mut self) -> io::Result<(Status, Vec<u8>)> {
        // End of a non-empty text block: CRLF "." CRLF.
        const TERMINATOR: &[u8] = b"\r\n.\r\n";
        // An empty text block is just "." CRLF right after the status line.
        const EMPTY_BLOCK: &[u8] = b".\r\n";

        // NOTE(review): the reader is local to this call, so any bytes it has
        // buffered beyond this response are dropped with it; this looks fine
        // for strict request/response use but not for pipelining -- confirm.
        let mut reader = BufReader::new(self);

        let mut line = String::new();
        if reader.read_line(&mut line)? == 0 {
            return Err(Error::from(ErrorKind::UnexpectedEof));
        }
        let status = line
            .trim()
            .parse::<Status>()
            .map_err(|_| Error::new(ErrorKind::InvalidData, "Failed to parse status."))?;

        if !TEXT_FOLLOWS.contains(&status.status_code) {
            return Ok((status, Vec::new()));
        }

        let mut data: Vec<u8> = Vec::new();
        loop {
            if data.ends_with(TERMINATOR) {
                data.truncate(data.len() - TERMINATOR.len());
                break;
            }
            if data.as_slice() == EMPTY_BLOCK {
                data.clear();
                break;
            }

            let chunk = reader.fill_buf()?;
            if chunk.is_empty() {
                // The peer closed the stream mid-block; without this check
                // the loop would spin forever on empty reads.
                return Err(Error::from(ErrorKind::UnexpectedEof));
            }
            let taken = chunk.len();
            data.extend_from_slice(chunk);
            reader.consume(taken);
        }
        Ok((status, data))
    }
}
/// A command to send to the server, stored as its individual parts.
///
/// The parts are kept exactly as supplied; no validation or escaping is
/// applied here.
#[derive(Debug, Clone)]
pub struct Command(Vec<String>);

impl Command {
    /// Builds a command from `params`, copying each part in order.
    pub fn new(params: &[String]) -> Command {
        let parts: Vec<String> = params.iter().cloned().collect();
        Command(parts)
    }
}
/// The parsed first line of a server response.
#[derive(Debug, Clone)]
pub struct Status {
    /// Numeric code at the start of the line.
    pub status_code: StatusCode,
    /// Text that followed the code on the same line.
    pub message: String,
}
/// Parses a status line of the form `"<3 digits> <message>"`.
///
/// A line consisting of the three-digit code alone is also accepted and
/// yields an empty message, since the trailing text is optional in the
/// protocol grammar. Everything else that the stricter
/// `"<3 digits> <non-empty text>"` form rejects is still rejected.
impl FromStr for Status {
    type Err = StatusParseError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // First alternative: code, one space, non-empty message.
        // Second alternative: the code on its own, with nothing after it.
        let caps = regex!(r"^(?P<status_code>\d{3})(?: (?P<message>.+)|$)")
            .captures(s)
            .ok_or(StatusParseError)?;

        let status_code = caps
            .name("status_code")
            .ok_or(StatusParseError)?
            .as_str()
            .parse::<StatusCode>()
            .map_err(|_| StatusParseError)?;

        // The group is absent only for a bare code.
        let message = caps
            .name("message")
            .map(|m| m.as_str().to_string())
            .unwrap_or_default();

        Ok(Status {
            status_code,
            message,
        })
    }
}
/// Error returned when a status line cannot be parsed.
///
/// Implements `Debug`, `Display` and `std::error::Error` so that results
/// carrying it can be unwrapped, formatted, and converted into boxed or
/// `io::Error` values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StatusParseError;

impl std::fmt::Display for StatusParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("invalid status line")
    }
}

impl std::error::Error for StatusParseError {}
/// Numeric response code, wrapped so it cannot be confused with other
/// integers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct StatusCode(usize);

impl From<usize> for StatusCode {
    /// Wraps a raw number without any range check.
    fn from(val: usize) -> Self {
        StatusCode(val)
    }
}

impl FromStr for StatusCode {
    type Err = ParseIntError;

    /// Parses `s` as a `usize` and wraps the result.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let raw = usize::from_str(s)?;
        Ok(StatusCode(raw))
    }
}