use std::{
io::{self, BufReader, BufWriter, Write},
net::{SocketAddr, TcpStream},
};
use serde::{Deserialize, Serialize};
use crate::{
inject_io_failure,
parser::ControlOp,
parser::{expect_info, parse_control_op},
split_tls, AuthStyle, FinalizedOptions, Reader, SecureString, ServerInfo, Writer,
};
/// Value used for [`ConnectInfo::echo`] when the field is missing while
/// deserializing: echo is on unless something explicitly turns it off.
fn default_echo() -> bool {
    true
}
/// Body of the `CONNECT` message the client sends after reading the server's `INFO`.
///
/// `authenticate` serializes it with `serde_json::to_string` and sends it as
/// `CONNECT <json>`. Optional credential/name fields are omitted from the JSON
/// when they are `None` or empty, and `echo` is omitted when it is `true`.
/// Field declaration order is the order keys appear in the serialized JSON.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[doc(hidden)]
#[allow(clippy::module_name_repetitions)]
pub struct ConnectInfo {
    /// Always serialized (no skip attribute).
    pub verbose: bool,
    /// Always serialized (no skip attribute).
    pub pedantic: bool,
    /// User JWT; serialized under the key `jwt`, omitted when `None` or empty.
    #[serde(rename = "jwt", skip_serializing_if = "is_empty_or_none")]
    pub user_jwt: Option<SecureString>,
    /// Signature produced from the server-supplied nonce (see `authenticate`);
    /// serialized under the key `sig`, omitted when `None` or empty.
    #[serde(rename = "sig", skip_serializing_if = "is_empty_or_none")]
    pub signature: Option<SecureString>,
    /// Optional client name; omitted when `None` or empty.
    #[serde(skip_serializing_if = "is_empty_or_none")]
    pub name: Option<SecureString>,
    /// Omitted from the JSON when `true`; defaults to `true` (via `default_echo`)
    /// when absent during deserialization.
    #[serde(skip_serializing_if = "is_true", default = "default_echo")]
    pub echo: bool,
    /// Client language identifier (filled from `crate::LANG` in `authenticate`).
    pub lang: String,
    /// Client version string (filled from `crate::VERSION` in `authenticate`).
    pub version: String,
    /// Defaults to `false` when absent during deserialization.
    #[serde(default)]
    pub tls_required: bool,
    /// Username for user/password auth; omitted when `None` or empty.
    #[serde(skip_serializing_if = "is_empty_or_none")]
    pub user: Option<SecureString>,
    /// Password for user/password auth; omitted when `None` or empty.
    #[serde(skip_serializing_if = "is_empty_or_none")]
    pub pass: Option<SecureString>,
    /// Token for token auth; omitted when `None` or empty.
    #[serde(skip_serializing_if = "is_empty_or_none")]
    pub auth_token: Option<SecureString>,
}
/// `skip_serializing_if` predicate for [`ConnectInfo::echo`]: reports whether the
/// flag is set, so the default value (`true`) is left out of the serialized JSON.
///
/// serde hands the predicate a reference to the field, so the `&bool` parameter is
/// mandated by the callback signature; that is why the pass-by-ref lint is silenced.
#[allow(clippy::trivially_copy_pass_by_ref)]
const fn is_true(field: &bool) -> bool {
    matches!(*field, true)
}
/// `skip_serializing_if` predicate for optional secret fields of [`ConnectInfo`].
///
/// Returns `true` when the value is absent or present but empty, so such fields
/// are omitted from the serialized `CONNECT` payload instead of being sent as
/// `null` or `""`.
///
/// serde calls the predicate with a reference to the field, so the parameter must
/// be `&Option<SecureString>` (not `Option<&SecureString>`). No lint suppression
/// is needed: `Option<SecureString>` is not `Copy`, so
/// `clippy::trivially_copy_pass_by_ref` does not apply here.
#[inline]
fn is_empty_or_none(field: &Option<SecureString>) -> bool {
    match field {
        Some(inner) => inner.is_empty(),
        None => true,
    }
}
pub(crate) fn connect_to_socket_addr(
addr: SocketAddr,
host: &str,
tls_required: bool,
options: &FinalizedOptions,
) -> io::Result<(Reader, Writer, ServerInfo)> {
inject_io_failure()?;
let mut stream = TcpStream::connect(&addr)?;
let server_info = expect_info(&mut stream)?;
let (mut reader, writer) = authenticate(stream, &server_info, options, tls_required, host)?;
let parsed_op = parse_control_op(&mut reader)?;
match parsed_op {
ControlOp::Pong => Ok((reader, writer, server_info)),
ControlOp::Err(e) => Err(io::Error::new(io::ErrorKind::ConnectionRefused, e)),
ControlOp::Ping | ControlOp::Msg(_) | ControlOp::Info(_) | ControlOp::Unknown(_) => {
log::error!(
"encountered unexpected control op during connection: {:?}",
parsed_op
);
Err(io::Error::new(
io::ErrorKind::ConnectionRefused,
"Protocol Error",
))
}
}
}
fn authenticate(
stream: TcpStream,
server_info: &ServerInfo,
options: &FinalizedOptions,
tls_required: bool,
host: &str,
) -> io::Result<(Reader, Writer)> {
let mut connect_info = ConnectInfo {
tls_required,
name: options.name.clone().map(SecureString::from),
pedantic: false,
verbose: false,
lang: crate::LANG.to_string(),
version: crate::VERSION.to_string(),
user: None,
pass: None,
auth_token: None,
user_jwt: None,
signature: None,
echo: !options.no_echo,
};
match &options.auth {
AuthStyle::UserPass(user, pass) => {
connect_info.user = Some(SecureString::from(user.to_string()));
connect_info.pass = Some(SecureString::from(pass.to_string()));
}
AuthStyle::Token(token) => {
connect_info.auth_token = Some(token.to_string().into());
}
AuthStyle::Credentials { jwt_cb, sig_cb } => {
let jwt = jwt_cb()?;
let sig = sig_cb(server_info.nonce.as_bytes())?;
connect_info.user_jwt = Some(jwt);
connect_info.signature = Some(sig);
}
AuthStyle::None => {}
}
let op = format!(
"CONNECT {}\r\nPING\r\n",
serde_json::to_string(&connect_info)?
);
let (reader, mut writer) = if options.tls_required || server_info.tls_required || tls_required {
let attempt = if let Some(ref tls_connector) = options.tls_connector {
tls_connector.connect(host, stream)
} else {
match native_tls::TlsConnector::new() {
Ok(connector) => connector.connect(host, stream),
Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)),
}
};
match attempt {
Ok(tls) => {
let (tls_reader, tls_writer) = split_tls(tls);
let reader = Reader::Tls(BufReader::with_capacity(64 * 1024, tls_reader));
let writer = Writer::Tls(BufWriter::with_capacity(64 * 1024, tls_writer));
(reader, writer)
}
Err(e) => {
log::error!("failed to upgrade TLS: {:?}", e);
return Err(io::Error::new(io::ErrorKind::PermissionDenied, e));
}
}
} else {
let reader = Reader::Tcp(BufReader::with_capacity(64 * 1024, stream.try_clone()?));
let writer = Writer::Tcp(BufWriter::with_capacity(64 * 1024, stream));
(reader, writer)
};
writer.write_all(op.as_bytes())?;
writer.flush()?;
Ok((reader, writer))
}