use std::fmt::Display;
use ractor::ActorRef;
use ractor::MessagingErr;
use tokio::net::TcpStream;
use tokio::net::ToSocketAddrs;
use tokio_rustls::rustls::pki_types::ServerName;
/// Errors that the client-side connection helpers in this module can return.
#[derive(Debug)]
pub enum ClientConnectErr {
/// An I/O error on the TCP socket. This is the variant produced by the
/// `From<tokio::io::Error>` conversion, so it is what `?` yields for failures
/// while connecting, setting `TCP_NODELAY`, or reading the socket addresses.
Socket(tokio::io::Error),
/// Sending the "connection opened" message to the node server failed.
Messaging(MessagingErr<super::NodeServerMessage>),
/// The TLS connector returned an I/O error while establishing the encrypted
/// session on top of an already-connected TCP stream.
Encryption(tokio::io::Error),
}
impl std::error::Error for ClientConnectErr {
    /// Returns the lower-level error wrapped by this variant.
    ///
    /// This implements `source` rather than the deprecated `cause`. The default
    /// `source` returns `None`, so overriding only `cause` hid the wrapped error
    /// from anything that walks the error chain through `source()`. The default
    /// `cause` delegates to `source`, so callers still using `cause()` get the
    /// same result as before.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Socket(cause) => Some(cause),
            Self::Messaging(cause) => Some(cause),
            Self::Encryption(cause) => Some(cause),
        }
    }
}
impl Display for ClientConnectErr {
    /// Renders the error using its `Debug` representation.
    ///
    /// The value is formatted through a fresh `format_args!` invocation, so any
    /// width/alignment/alternate flags on the outer formatter are not forwarded
    /// to the `Debug` implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("{:?}", self))
    }
}
impl From<tokio::io::Error> for ClientConnectErr {
fn from(value: tokio::io::Error) -> Self {
Self::Socket(value)
}
}
impl From<MessagingErr<super::NodeServerMessage>> for ClientConnectErr {
fn from(value: MessagingErr<super::NodeServerMessage>) -> Self {
Self::Messaging(value)
}
}
/// Opens an unencrypted TCP connection to `address` and hands it to the node
/// server.
///
/// * `node_server` - The node server actor that receives the new connection.
/// * `address` - The address to connect to.
///
/// On success the stream has `TCP_NODELAY` set and has been sent to
/// `node_server` in a `ConnectionOpened` message with `is_server: false`.
///
/// # Errors
///
/// Returns [`ClientConnectErr::Socket`] if connecting, setting `TCP_NODELAY`,
/// or reading the peer/local address fails, and
/// [`ClientConnectErr::Messaging`] if the message to `node_server` cannot be
/// sent.
pub async fn connect<T>(
    node_server: &ActorRef<super::NodeServerMessage>,
    address: T,
) -> Result<(), ClientConnectErr>
where
    T: ToSocketAddrs,
{
    let socket = TcpStream::connect(address).await?;
    socket.set_nodelay(true)?;

    // Capture both endpoints before the socket is moved into the message.
    let addr = socket.peer_addr()?;
    let local_addr = socket.local_addr()?;

    let session = crate::net::NetworkStream::Raw {
        stream: socket,
        peer_addr: addr,
        local_addr,
    };
    let opened = super::NodeServerMessage::ConnectionOpened {
        stream: Box::new(session),
        is_server: false,
    };
    node_server.cast(opened)?;

    tracing::info!("TCP Session opened for {addr}");
    Ok(())
}
/// Opens a TCP connection to `address`, wraps it in TLS, and hands the
/// encrypted stream to the node server.
///
/// * `node_server` - The node server actor that receives the new connection.
/// * `address` - The address to connect to.
/// * `encryption_settings` - The TLS connector used to establish the session.
/// * `domain` - The server name passed to the TLS connector.
///
/// On success the encrypted stream has been sent to `node_server` in a
/// `ConnectionOpened` message with `is_server: false`.
///
/// # Errors
///
/// Returns [`ClientConnectErr::Socket`] for failures on the plain TCP socket,
/// [`ClientConnectErr::Encryption`] if the TLS connector fails, and
/// [`ClientConnectErr::Messaging`] if the message to `node_server` cannot be
/// sent.
pub async fn connect_enc<T>(
    node_server: &ActorRef<super::NodeServerMessage>,
    address: T,
    encryption_settings: tokio_rustls::TlsConnector,
    domain: ServerName<'static>,
) -> Result<(), ClientConnectErr>
where
    T: ToSocketAddrs,
{
    let tcp = TcpStream::connect(address).await?;
    tcp.set_nodelay(true)?;

    // Addresses are read from the plain socket before TLS takes ownership of it.
    let addr = tcp.peer_addr()?;
    let local_addr = tcp.local_addr()?;

    // A TLS failure is an I/O error too; map it explicitly so it is reported as
    // `Encryption` rather than picked up by the blanket `Socket` conversion.
    let tls = match encryption_settings.connect(domain, tcp).await {
        Ok(tls) => tls,
        Err(tls_err) => return Err(ClientConnectErr::Encryption(tls_err)),
    };

    let session = crate::net::NetworkStream::TlsClient {
        stream: tls,
        peer_addr: addr,
        local_addr,
    };
    let opened = super::NodeServerMessage::ConnectionOpened {
        stream: Box::new(session),
        is_server: false,
    };
    node_server.cast(opened)?;

    tracing::info!("TCP Session opened for {addr}");
    Ok(())
}
pub async fn connect_external(
node_server: &ActorRef<super::NodeServerMessage>,
stream: Box<dyn crate::net::ClusterBidiStream>,
) -> Result<(), ClientConnectErr> {
node_server.cast(super::NodeServerMessage::ConnectionOpenedExternal {
stream,
is_server: false,
})?;
tracing::info!("External session opened (client)");
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ClientConnectErr` from an I/O error with the given kind and
    /// message, going through the `From<tokio::io::Error>` conversion.
    fn from_io(kind: tokio::io::ErrorKind, msg: &'static str) -> ClientConnectErr {
        tokio::io::Error::new(kind, msg).into()
    }

    /// An I/O error converts into the `Socket` variant.
    #[test]
    fn test_client_connect_err_from_tokio_io_error() {
        let client_err = from_io(
            tokio::io::ErrorKind::ConnectionRefused,
            "Connection refused",
        );
        if !matches!(client_err, ClientConnectErr::Socket(_)) {
            panic!("Expected ClientConnectErr::Socket");
        }
    }

    /// Formatting the error through `Display` does not panic.
    #[test]
    fn test_client_connect_err_display() {
        let client_err = from_io(tokio::io::ErrorKind::Other, "Connection error");
        let _rendered = client_err.to_string();
    }

    /// The error type can be used as a `std::error::Error` trait object.
    #[test]
    fn test_client_connect_err_is_error() {
        use std::error::Error;

        fn accepts_error(_: &dyn Error) {}

        let client_err = from_io(tokio::io::ErrorKind::Other, "Some error");
        accepts_error(&client_err);
    }
}