use std::collections::HashMap;
use std::{error, fmt};
use thiserror::Error;
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio_native_tls::native_tls::{HandshakeError, TlsConnector};
use crate::parser;
use crate::Message;
/// Newtype around native-tls's [`HandshakeError`], used as the payload of
/// [`IrcConnectionError::TlsError`].
///
/// Connector-construction and handshake failures are both stored here as
/// `HandshakeError::Failure`.
#[derive(Debug)]
pub struct HandshakeErrorWrapper(HandshakeError<TcpStream>);
impl fmt::Display for HandshakeErrorWrapper {
    /// Formats the wrapped handshake error with its own `Display` implementation.
    ///
    /// The error is written as human-readable text rather than `Debug` output, because
    /// this text ends up in user-facing messages via `IrcConnectionError`'s
    /// `#[error(...)]` formatting.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
impl error::Error for HandshakeErrorWrapper {
    /// Reports the wrapped native-tls handshake error as the underlying cause.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        let cause: &(dyn error::Error + 'static) = &self.0;
        Some(cause)
    }
}
impl From<parser::Error> for IrcConnectionError {
fn from(error: parser::Error) -> Self {
Self::ParserError(error)
}
}
/// Errors produced while connecting to, or communicating with, an IRC server.
#[derive(Debug, Error)]
pub enum IrcConnectionError {
    /// An I/O error from the underlying socket.
    ///
    /// NOTE(review): because of `#[from]`, every `std::io::Error` propagated with `?`
    /// lands in this variant. That includes write/flush failures after the connection
    /// is established (e.g. in `IrcConnection::write_line`), even though the message
    /// text says "Failed to connect".
    #[error("Failed to connect to server: {0}")]
    ConnectionError(#[from] std::io::Error),
    /// Building the TLS connector or performing the TLS handshake failed.
    #[error("Failed to establish TLS connection: {0}")]
    TlsError(HandshakeErrorWrapper),
    /// Wraps a [`parser::Error`]. Presumably raised while building an outgoing message
    /// (e.g. `Message::new` / `Message::build_token`). TODO confirm against the parser
    /// module.
    #[error("Failed to send message: {0}")]
    ParserError(parser::Error),
}
/// Reader and writer halves of a plain TCP stream, each guarded by its own mutex.
type TcpStreamParts = (Mutex<ReadHalf<TcpStream>>, Mutex<WriteHalf<TcpStream>>);
type TlsStreamParts = (
Mutex<ReadHalf<tokio_native_tls::TlsStream<tokio::net::TcpStream>>>,
Mutex<WriteHalf<tokio_native_tls::TlsStream<tokio::net::TcpStream>>>,
);
/// The split read/write halves of either a plain TCP or a TLS connection.
///
/// Each half sits behind its own mutex, so a reader and a writer do not contend for
/// the same lock.
#[derive(Debug)]
pub enum MaybeTlsStream {
    /// Unencrypted TCP connection.
    Tcp(TcpStreamParts),
    /// TLS over TCP, via `tokio_native_tls`.
    Tls(TlsStreamParts),
}
impl PartialEq for MaybeTlsStream {
    /// Two streams are equal when they are the same kind (both TCP or both TLS).
    ///
    /// The underlying sockets themselves are not compared.
    fn eq(&self, other: &Self) -> bool {
        std::mem::discriminant(self) == std::mem::discriminant(other)
    }
}
/// A connection to an IRC server, over plain TCP or TLS.
///
/// The read half, the write half and the pending-input buffer are each wrapped in a
/// `tokio::sync::Mutex`, which lets the methods take `&self`.
#[derive(Debug)]
pub struct IrcConnection {
    /// The split socket halves.
    pub stream: MaybeTlsStream,
    /// Host the connection was opened to; also passed as the TLS domain when `tls` is set.
    pub server: String,
    /// TCP port the connection was opened to.
    pub port: u16,
    /// Whether the connection was upgraded to TLS.
    pub tls: bool,
    /// Whether certificate validation was disabled for the TLS handshake.
    pub accept_invalid_tls_cert: bool,
    /// Bytes read from the socket that have not yet been returned as a complete line.
    buf: Mutex<Vec<u8>>,
}
impl PartialEq for IrcConnection {
    /// Compares the connection kind and the connection parameters.
    ///
    /// The internal read buffer is deliberately ignored.
    fn eq(&self, other: &Self) -> bool {
        let lhs = (
            &self.stream,
            &self.server,
            self.port,
            self.tls,
            self.accept_invalid_tls_cert,
        );
        let rhs = (
            &other.stream,
            &other.server,
            other.port,
            other.tls,
            other.accept_invalid_tls_cert,
        );
        lhs == rhs
    }
}
impl IrcConnection {
pub async fn new(
server: &str,
port: u16,
tls: bool,
accept_invalid_tls_cert: bool,
) -> Result<Self, IrcConnectionError> {
let tcpstream = TcpStream::connect(format!("{server}:{port}"))
.await
.map_err(IrcConnectionError::ConnectionError)?;
let stream = if tls {
let connector = TlsConnector::builder()
.danger_accept_invalid_certs(accept_invalid_tls_cert)
.build()
.map_err(|e| {
IrcConnectionError::TlsError(HandshakeErrorWrapper(HandshakeError::Failure(e)))
})?;
let connector = tokio_native_tls::TlsConnector::from(connector);
let tls_stream = connector.connect(server, tcpstream).await.map_err(|e| {
IrcConnectionError::TlsError(HandshakeErrorWrapper(
tokio_native_tls::native_tls::HandshakeError::Failure(e),
))
})?;
let (reader, writer) = tokio::io::split(tls_stream);
MaybeTlsStream::Tls((Mutex::new(reader), Mutex::new(writer)))
} else {
let (reader, writer) = tokio::io::split(tcpstream);
MaybeTlsStream::Tcp((Mutex::new(reader), Mutex::new(writer)))
};
Ok(Self {
stream,
server: server.to_string(),
port,
tls,
accept_invalid_tls_cert,
buf: Mutex::new(Vec::new()),
})
}
pub async fn write_line(&self, message: &str) -> Result<(), IrcConnectionError> {
let message = format!("{message}\r\n");
match &self.stream {
MaybeTlsStream::Tcp(tcp_stream) => {
tcp_stream
.1
.lock()
.await
.write_all(message.as_bytes())
.await?;
tcp_stream.1.lock().await.flush().await?;
}
MaybeTlsStream::Tls(tls_stream) => {
tls_stream
.1
.lock()
.await
.write_all(message.as_bytes())
.await?;
tls_stream.1.lock().await.flush().await?;
}
}
Ok(())
}
pub async fn read_line(&self) -> Result<String, std::io::Error> {
let mut cursor = 0;
loop {
let buf = self.buf.lock().await;
if let Some(offset) = buf[cursor..].windows(2).position(|w| w == b"\r\n") {
let index = cursor + offset;
let response = String::from_utf8_lossy(&buf[..index]).to_string();
drop(buf);
self.buf.lock().await.drain(0..index + 2);
return Ok(response);
}
drop(buf);
cursor = self.buf.lock().await.len();
self.buf.lock().await.resize(cursor + 1024, 0);
let got = match &self.stream {
MaybeTlsStream::Tcp(tcp_stream) => {
tcp_stream
.0
.lock()
.await
.read(&mut self.buf.lock().await[cursor..])
.await?
}
MaybeTlsStream::Tls(tls_stream) => {
tls_stream
.0
.lock()
.await
.read(&mut self.buf.lock().await[cursor..])
.await?
}
};
self.buf.lock().await.resize(cursor + got, 0);
}
}
pub async fn identify(
&self,
nickname: &str,
username: &str,
realname: &str,
) -> Result<(), IrcConnectionError> {
let token = Message::build_token(Message::new(
HashMap::new(),
None,
String::from("NICK"),
vec![nickname.to_string()],
)?)?;
self.write_line(&token).await?;
let token = Message::build_token(Message::new(
HashMap::new(),
None,
String::from("USER"),
vec![
username.to_string(),
"0".to_string(),
"*".to_string(),
realname.to_string(),
],
)?)?;
self.write_line(&token).await?;
Ok(())
}
}