cowirc 0.2.0

Asynchronous IRCv3 library for Rust
Documentation
/*
 * This file is part of CowIRC.
 *
 * CowIRC is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * CowIRC is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with CowIRC. If not, see <http://www.gnu.org/licenses/>.
 */

use std::collections::HashMap;
use std::{error, fmt};
use thiserror::Error;
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio_native_tls::native_tls::{HandshakeError, TlsConnector};

use crate::parser;
use crate::Message;

/// Newtype around the native-tls handshake error for a TCP stream.
///
/// Wrapping it locally lets this module give it its own `Display` and `Error`
/// implementations and store it inside [`IrcConnectionError::TlsError`].
#[derive(Debug)]
pub struct HandshakeErrorWrapper(HandshakeError<TcpStream>);

impl fmt::Display for HandshakeErrorWrapper {
    /// Formats the wrapped handshake error using its own `Display` implementation.
    ///
    /// The inner error's human-readable message is used rather than its `Debug`
    /// representation, so that user-facing messages (such as the one produced by
    /// `IrcConnectionError::TlsError`) do not contain enum/struct debug noise.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // The inner type implements `Display`: it is coerced to `&dyn Error` in the
        // `error::Error` impl for this wrapper, and `Error` requires `Display`.
        fmt::Display::fmt(&self.0, f)
    }
}

impl error::Error for HandshakeErrorWrapper {
    /// Exposes the wrapped handshake error as the underlying cause, so error
    /// reporters can walk the chain past this wrapper.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        let inner: &(dyn error::Error + 'static) = &self.0;
        Some(inner)
    }
}

impl From<parser::Error> for IrcConnectionError {
    fn from(error: parser::Error) -> Self {
        Self::ParserError(error)
    }
}

/// A wrapper for errors that can occur during the establishment of the IRC connection.
#[derive(Debug, Error)]
pub enum IrcConnectionError {
    /// Indicates an error connecting to the server.
    #[error("Failed to connect to server: {0}")]
    ConnectionError(#[from] std::io::Error),
    /// Indicates an error establishing the TLS connection.
    #[error("Failed to establish TLS connection: {0}")]
    TlsError(HandshakeErrorWrapper),
    /// Indicates an error parsing a response from the server.
    #[error("Failed to send message: {0}")]
    ParserError(parser::Error),
}

/// The read and write halves of a plain-text TCP connection, each guarded by
/// its own async mutex so reading and writing can proceed independently.
type TcpStreamParts = (
    Mutex<ReadHalf<TcpStream>>,
    Mutex<WriteHalf<TcpStream>>,
);

/// Represents the parts of a TLS stream.
type TlsStreamParts = (
    Mutex<ReadHalf<tokio_native_tls::TlsStream<tokio::net::TcpStream>>>,
    Mutex<WriteHalf<tokio_native_tls::TlsStream<tokio::net::TcpStream>>>,
);

/// A connection stream that is either plain TCP or TLS layered over TCP.
///
/// Each variant stores the stream split into a separately locked read half
/// (tuple field `.0`) and write half (tuple field `.1`).
#[derive(Debug)]
pub enum MaybeTlsStream {
    /// An unencrypted TCP connection.
    Tcp(TcpStreamParts),
    /// A TLS-encrypted connection on top of TCP.
    Tls(TlsStreamParts),
}

impl PartialEq for MaybeTlsStream {
    /// Two streams compare equal when they are the same kind (both TCP or both
    /// TLS); the underlying sockets themselves are not compared.
    fn eq(&self, other: &Self) -> bool {
        std::mem::discriminant(self) == std::mem::discriminant(other)
    }
}

/// Represents an IRC connection.
///
/// Created with [`IrcConnection::new`]; lines are sent with
/// [`IrcConnection::write_line`] and received with [`IrcConnection::read_line`].
#[derive(Debug)]
pub struct IrcConnection {
    /// The underlying stream, split into separately locked read/write halves.
    pub stream: MaybeTlsStream,
    /// The server address this connection was opened to, as passed to `new`.
    pub server: String,
    /// The server port this connection was opened to.
    pub port: u16,
    /// Whether the connection is wrapped in TLS.
    pub tls: bool,
    /// Whether invalid TLS certificates were accepted when connecting.
    pub accept_invalid_tls_cert: bool,
    /// Bytes read from the stream that have not yet been returned as a line.
    buf: Mutex<Vec<u8>>,
}

impl PartialEq for IrcConnection {
    /// Compares connections by stream kind and connection settings.
    fn eq(&self, other: &Self) -> bool {
        // Destructure exhaustively so that adding a field forces a decision here.
        // The internal read buffer (`buf`) is not part of the comparison.
        let Self {
            stream,
            server,
            port,
            tls,
            accept_invalid_tls_cert,
            buf: _,
        } = self;
        *stream == other.stream
            && *server == other.server
            && *port == other.port
            && *tls == other.tls
            && *accept_invalid_tls_cert == other.accept_invalid_tls_cert
    }
}

impl IrcConnection {
    /// Initializes a new IRC connection.
    ///
    /// Opens a TCP connection to `server:port` and, when `tls` is set, performs a TLS
    /// handshake on top of it with `server` as the TLS domain name. The resulting stream
    /// is split into separately locked read and write halves.
    ///
    /// # Arguments
    ///
    /// * `server` - The server to connect to.
    /// * `port` - The port to connect to.
    /// * `tls` - Whether to use TLS for the connection.
    /// * `accept_invalid_tls_cert` - Whether to accept invalid TLS certificates.
    ///
    /// # Returns
    ///
    /// * `Result<Self, IrcConnectionError>` - A result containing either an IRC connection,
    ///   or on failure an error of type `IrcConnectionError` with more details.
    ///
    /// # Errors
    ///
    /// Will return `IrcConnectionError::ConnectionError` if the TCP connection cannot be
    /// established, and `IrcConnectionError::TlsError` if the TLS connector cannot be built
    /// or the TLS handshake fails.
    ///
    /// # Examples
    ///
    /// ```rust,no_run
    /// use cowirc::networking::IrcConnection;
    /// use std::error::Error;
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn Error>> {
    ///     match IrcConnection::new("irc.server.com", 6667, false, false).await {
    ///         Ok(connection) => {
    ///             println!("Successfully connected to irc.server.com: {:?}", connection);
    ///         }
    ///         Err(e) => {
    ///             eprintln!("error: failed to connect: {e}");
    ///         }
    ///     }
    ///     Ok(())
    /// }
    /// ```
    pub async fn new(
        server: &str,
        port: u16,
        tls: bool,
        accept_invalid_tls_cert: bool,
    ) -> Result<Self, IrcConnectionError> {
        let tcpstream = TcpStream::connect(format!("{server}:{port}"))
            .await
            .map_err(IrcConnectionError::ConnectionError)?;
        let stream = if tls {
            // When `accept_invalid_tls_cert` is set, native-tls accepts invalid
            // certificates (`danger_accept_invalid_certs`).
            let connector = TlsConnector::builder()
                .danger_accept_invalid_certs(accept_invalid_tls_cert)
                .build()
                .map_err(|e| {
                    IrcConnectionError::TlsError(HandshakeErrorWrapper(HandshakeError::Failure(e)))
                })?;
            let connector = tokio_native_tls::TlsConnector::from(connector);
            // `server` is passed as the TLS domain name for the handshake.
            let tls_stream = connector.connect(server, tcpstream).await.map_err(|e| {
                IrcConnectionError::TlsError(HandshakeErrorWrapper(HandshakeError::Failure(e)))
            })?;
            let (reader, writer) = tokio::io::split(tls_stream);
            MaybeTlsStream::Tls((Mutex::new(reader), Mutex::new(writer)))
        } else {
            let (reader, writer) = tokio::io::split(tcpstream);
            MaybeTlsStream::Tcp((Mutex::new(reader), Mutex::new(writer)))
        };
        Ok(Self {
            stream,
            server: server.to_string(),
            port,
            tls,
            accept_invalid_tls_cert,
            buf: Mutex::new(Vec::new()),
        })
    }

    /// Writes a message to the IRC connection.
    ///
    /// The message is terminated with CR LF and flushed. It is sent as-is: any CR or LF
    /// characters already contained in `message` are not stripped or escaped.
    ///
    /// # Arguments
    ///
    /// * `message` - The message to write.
    ///
    /// # Errors
    ///
    /// Will return `IrcConnectionError::ConnectionError` if an error occurs writing to or
    /// flushing the `WriteHalf`.
    ///
    /// # Returns
    ///
    /// * `Result<(), IrcConnectionError>` - A result containing either an empty tuple indicating
    ///   success, or on failure an error of type `IrcConnectionError` with more details.
    ///
    /// # Examples
    ///
    /// ```rust,no_run
    /// use cowirc::networking::IrcConnection;
    /// use std::error::Error;
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn Error>> {
    ///     let connection = IrcConnection::new("irc.server.com", 6667, false, false).await.unwrap();
    ///     match connection.write_line("PRIVMSG #channel :Hello, world!").await {
    ///         Ok(()) => {
    ///             println!("Sent message `PRIVMSG #channel :Hello, world!` successfully!");
    ///         }
    ///         Err(e) => {
    ///             eprintln!("error: could not send message: {e}");
    ///         }
    ///     }
    ///     Ok(())
    /// }
    /// ```
    pub async fn write_line(&self, message: &str) -> Result<(), IrcConnectionError> {
        let line = format!("{message}\r\n");
        // Hold the writer lock across both the write and the flush so no other task can
        // slip its own output in between.
        match &self.stream {
            MaybeTlsStream::Tcp((_, writer)) => {
                let mut writer = writer.lock().await;
                writer.write_all(line.as_bytes()).await?;
                writer.flush().await?;
            }
            MaybeTlsStream::Tls((_, writer)) => {
                let mut writer = writer.lock().await;
                writer.write_all(line.as_bytes()).await?;
                writer.flush().await?;
            }
        }
        Ok(())
    }

    /// Reads a line from the IRC connection.
    ///
    /// Returns the next CR LF terminated line, without the terminator. Bytes are decoded
    /// lossily as UTF-8. Any data received after the terminator is kept for the next call.
    ///
    /// # Returns
    ///
    /// * `Result<String, std::io::Error>` - A result containing either the read line or, on
    ///   failure, an error of type `std::io::Error` with more details.
    ///
    /// # Errors
    ///
    /// Will return `std::io::Error` if an error occurs reading from the `ReadHalf`, or an
    /// error of kind `std::io::ErrorKind::UnexpectedEof` if the connection is closed before
    /// a complete line is received. Partially received data stays buffered in either case.
    ///
    /// # Examples
    ///
    /// ```rust,no_run
    /// use cowirc::networking::IrcConnection;
    /// use std::error::Error;
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn Error>> {
    ///     let connection = IrcConnection::new("irc.server.com", 6667, false, false).await.unwrap();
    ///     let line = connection.read_line().await?;
    ///     println!("Received line: {line}");
    ///     Ok(())
    /// }
    /// ```
    pub async fn read_line(&self) -> Result<String, std::io::Error> {
        // Hold the buffer lock for the whole call so concurrent readers cannot drain or
        // extend the buffer underneath this one.
        let mut buf = self.buf.lock().await;
        // Where to start searching for "\r\n". Begins at 0 so data left over from a
        // previous call is examined before reading more.
        let mut search_from = 0;
        // Reads land in a scratch chunk and are only appended on success, so an I/O error
        // or a cancelled future never leaves placeholder bytes in `buf`.
        let mut chunk = [0u8; 1024];

        loop {
            if let Some(offset) = buf[search_from..]
                .windows(2)
                .position(|w| w == b"\r\n")
            {
                let end = search_from + offset;
                let line = String::from_utf8_lossy(&buf[..end]).into_owned();
                buf.drain(..end + 2);
                return Ok(line);
            }

            // Step back one byte so a "\r\n" split across two reads is still found.
            search_from = buf.len().saturating_sub(1);

            let got = match &self.stream {
                MaybeTlsStream::Tcp((reader, _)) => reader.lock().await.read(&mut chunk).await?,
                MaybeTlsStream::Tls((reader, _)) => reader.lock().await.read(&mut chunk).await?,
            };
            // A zero-byte read means the peer closed the connection; retrying would spin.
            if got == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "connection closed before a complete line was received",
                ));
            }
            buf.extend_from_slice(&chunk[..got]);
        }
    }

    /// Identifies with the IRC server using the specified nickname, username, and real name.
    ///
    /// Sends a `NICK` message followed by a `USER` message whose middle parameters are the
    /// fixed values `0` and `*`.
    ///
    /// # Arguments
    ///
    /// * `nickname` - The nickname to identify with.
    /// * `username` - The username to identify with.
    /// * `realname` - The real name to identify with.
    ///
    /// # Errors
    ///
    /// Will return `IrcConnectionError` if an error occurs constructing a Message,
    /// or sending a message to the server.
    ///
    /// # Returns
    ///
    /// * `Result<(), IrcConnectionError>` - A result containing either an empty tuple indicating
    ///   success, or on failure an error of type `IrcConnectionError` with more details.
    ///
    /// # Examples
    ///
    /// ```rust,no_run
    /// use cowirc::networking::IrcConnection;
    /// use std::error::Error;
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn Error>> {
    ///     let connection = IrcConnection::new("irc.server.com", 6667, false, false).await.unwrap();
    ///     connection.identify("nick", "ident", "real name").await?;
    ///     Ok(())
    /// }
    /// ```
    pub async fn identify(
        &self,
        nickname: &str,
        username: &str,
        realname: &str,
    ) -> Result<(), IrcConnectionError> {
        let token = Message::build_token(Message::new(
            HashMap::new(),
            None,
            String::from("NICK"),
            vec![nickname.to_string()],
        )?)?;
        self.write_line(&token).await?;
        let token = Message::build_token(Message::new(
            HashMap::new(),
            None,
            String::from("USER"),
            vec![
                username.to_string(),
                "0".to_string(),
                "*".to_string(),
                realname.to_string(),
            ],
        )?)?;
        self.write_line(&token).await?;
        Ok(())
    }
}