#[macro_use]
mod macros;
pub mod conn;
pub mod read;
pub mod util;
pub mod write;
use self::conn::{OpenStreamError, TlsConfig, TlsConfigError};
use self::read::{ReadStream, RecvError};
use self::write::WriteStream;
use crate::irc::Command;
use crate::IrcMessage;
use futures_util::StreamExt;
use rand::{thread_rng, Rng};
use std::fmt::{Display, Write};
use std::future::Future;
use std::io;
use std::time::Duration;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio_rustls::rustls::client::InvalidDnsNameError;
use tokio_rustls::rustls::ServerName;
use tokio_stream::wrappers::LinesStream;
use util::Timeout;
/// Timeout used by [`Config::default`]: 30 seconds.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// Login name and optional token used during the connection handshake.
///
/// Credentials without a token are considered anonymous
/// (see [`Credentials::is_anon`]).
#[derive(Clone)]
pub struct Credentials {
/// Login name; the handshake sends it in the `NICK` line.
pub login: String,
/// Token the handshake sends in the `PASS` line; `None` for anonymous credentials.
pub token: Option<String>,
}
impl Credentials {
    /// Range the numeric suffix of an anonymous login is drawn from.
    /// Half-open: the upper bound itself is never produced.
    const ANON_RANGE: std::ops::Range<u32> = 10000..99999;

    /// Builds credentials from a login name and a token.
    ///
    /// Both arguments are converted with `to_string`.
    pub fn new(login: impl ToString, token: impl ToString) -> Self {
        let login = login.to_string();
        let token = Some(token.to_string());
        Self { login, token }
    }

    /// Builds anonymous credentials: no token, and a login of the form
    /// `justinfan<N>` where `N` is picked at random from `ANON_RANGE`.
    pub fn anon() -> Self {
        let suffix = thread_rng().gen_range(Self::ANON_RANGE);
        let login = format!("justinfan{}", suffix);
        Self { login, token: None }
    }

    /// Returns `true` when no token is set.
    pub fn is_anon(&self) -> bool {
        matches!(self.token, None)
    }

    /// Returns the login name.
    pub fn login(&self) -> &str {
        &self.login
    }

    /// Returns the token, if one is set.
    pub fn token(&self) -> Option<&str> {
        let Self { token, .. } = self;
        token.as_deref()
    }
}
impl Default for Credentials {
fn default() -> Self {
Self::anon()
}
}
/// Hand-written `Debug` that prints only the login.
///
/// The token is left out and the output ends in `..` to signal the omitted
/// field (presumably so the token cannot end up in logs).
impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            // Label the value with the struct's actual field name.
            .field("login", &self.login)
            .finish_non_exhaustive()
    }
}
/// Retry policy applied by `Client::reconnect`.
#[derive(Clone, Copy, Debug)]
pub struct Backoff {
    /// Maximum number of connection attempts; `None` means no limit.
    pub max_tries: Option<u64>,
    /// Delay slept before the first attempt.
    pub initial_delay: Duration,
    /// Factor the delay is multiplied by after every attempt.
    pub delay_multiplier: u32,
    /// Upper bound applied to the delay after each multiplication.
    pub max_delay: Duration,
}

impl Default for Backoff {
    /// At most 8 tries, starting at 1 second, tripling after each attempt,
    /// capped at 12 seconds.
    fn default() -> Self {
        let initial_delay = Duration::from_secs(1);
        let max_delay = Duration::from_secs(12);
        Backoff {
            max_tries: Some(8),
            delay_multiplier: 3,
            initial_delay,
            max_delay,
        }
    }
}
/// Settings used to establish and re-establish a [`Client`] connection.
#[derive(Clone, Debug)]
pub struct Config {
/// Credentials sent during the handshake.
pub credentials: Credentials,
/// Time limit applied separately to opening the stream and to the handshake.
pub timeout: Duration,
/// Retry policy used when reconnecting.
pub backoff: Backoff,
}
impl Default for Config {
fn default() -> Self {
Self {
credentials: Default::default(),
timeout: DEFAULT_TIMEOUT,
backoff: Default::default(),
}
}
}
/// Builder that collects a [`Config`] and then connects a [`Client`] with it.
///
/// Obtained from [`Client::builder`].
pub struct ClientBuilder {
// Configuration accumulated by the setter methods.
config: Config,
}
impl ClientBuilder {
    /// Sets the credentials to authenticate with.
    pub fn credentials(self, credentials: Credentials) -> Self {
        let Self { mut config } = self;
        config.credentials = credentials;
        Self { config }
    }

    /// Sets the time limit stored in [`Config::timeout`].
    pub fn timeout(self, timeout: Duration) -> Self {
        let Self { mut config } = self;
        config.timeout = timeout;
        Self { config }
    }

    /// Sets the retry policy stored in [`Config::backoff`].
    pub fn backoff(self, backoff: Backoff) -> Self {
        let Self { mut config } = self;
        config.backoff = backoff;
        Self { config }
    }

    /// Consumes the builder and connects with the collected configuration.
    pub fn connect(self) -> impl Future<Output = Result<Client, ConnectError>> {
        let Self { config } = self;
        Client::connect(config)
    }
}
/// A connection: both halves of the stream plus the state needed to reconnect.
pub struct Client {
// Incoming half of the stream, as produced by `split`.
reader: ReadStream,
// Outgoing half of the same stream.
writer: WriteStream,
// Reusable text buffer; the handshake assembles its outgoing lines here before writing them.
scratch: String,
// TLS configuration, kept so `reconnect` can open a new stream.
tls: TlsConfig,
// Configuration supplied when connecting.
config: Config,
}
impl Client {
    /// Returns a [`ClientBuilder`] that starts from the default [`Config`].
    pub fn builder() -> ClientBuilder {
        ClientBuilder {
            config: Default::default(),
        }
    }

    /// Connects with the default configuration (anonymous credentials).
    pub fn anonymous() -> impl Future<Output = Result<Client, ConnectError>> {
        Self::connect(Config::default())
    }

    /// Opens a stream and performs the handshake.
    ///
    /// Opening the stream and the handshake are each limited by
    /// `config.timeout`.
    ///
    /// # Errors
    ///
    /// Returns a [`ConnectError`] if the TLS configuration cannot be loaded,
    /// the stream cannot be opened, a step times out, or the handshake fails.
    async fn connect(config: Config) -> Result<Client, ConnectError> {
        trace!("connecting");
        let tls = TlsConfig::load(ServerName::try_from(conn::HOST)?)?;

        trace!("opening connection to twitch");
        let timeout = config.timeout;
        // First `?`: timeout elapsed. Second `?`: the operation itself failed.
        let stream = conn::open(tls.clone()).timeout(timeout).await??;
        let (reader, writer) = split(stream);

        let mut client = Client {
            reader,
            writer,
            scratch: String::with_capacity(1024),
            tls,
            config,
        };
        client.handshake().timeout(timeout).await??;
        Ok(client)
    }

    /// Replaces the current stream with a fresh one and repeats the handshake,
    /// retrying according to [`Config::backoff`].
    ///
    /// Every attempt (including the first) is preceded by a sleep. The sleep
    /// starts at `initial_delay` and is multiplied by `delay_multiplier` after
    /// each attempt, never exceeding `max_delay`.
    ///
    /// # Errors
    ///
    /// Only errors for which `ConnectError::should_retry` is true (I/O errors)
    /// lead to another attempt. A timeout or any other error aborts at once.
    /// When all attempts are used up, the last retryable error is returned as
    /// the cause; if no attempt was made at all (`max_tries` of `Some(0)`),
    /// the cause is [`ConnectError::Timeout`].
    pub async fn reconnect(&mut self) -> Result<(), ReconnectError> {
        trace!("reconnecting");
        let backoff = self.config.backoff;
        let timeout = self.config.timeout;

        let mut tries = backoff.max_tries;
        let mut delay = backoff.initial_delay;
        let mut cause = ConnectError::Timeout;

        // `None` means unlimited tries; otherwise continue while tries remain.
        while matches!(tries, None | Some(1..)) {
            tokio::time::sleep(delay).await;
            if let Some(tries) = &mut tries {
                *tries -= 1;
            }
            // Saturate instead of panicking if a user-supplied policy would
            // overflow `Duration`; the result is capped at `max_delay` anyway.
            delay = std::cmp::min(
                backoff.max_delay,
                delay.saturating_mul(backoff.delay_multiplier),
            );

            trace!("opening connection to twitch");
            let stream = match conn::open(self.tls.clone()).timeout(timeout).await? {
                Ok(stream) => stream,
                Err(e @ OpenStreamError::Io(_)) => {
                    cause = e.into();
                    continue;
                }
            };
            (self.reader, self.writer) = split(stream);

            if let Err(e) = self.handshake().timeout(timeout).await? {
                if e.should_retry() {
                    cause = e;
                    continue;
                }
                return Err(e.into());
            };
            return Ok(());
        }

        Err(ReconnectError { cause })
    }

    /// Sends `CAP REQ`, `PASS` and `NICK` in a single write, then waits for
    /// the capability acknowledgement followed by the `001` welcome message.
    ///
    /// Anonymous credentials send a placeholder token. The `oauth:` prefix is
    /// added to the token unless it is already present.
    ///
    /// # Errors
    ///
    /// - [`ConnectError::Io`] if writing or flushing fails.
    /// - [`ConnectError::Timeout`] if either reply takes longer than 5 seconds.
    /// - [`ConnectError::Auth`] if the capabilities are not acknowledged or an
    ///   "authentication failed" notice arrives.
    /// - [`ConnectError::Notice`] for any other notice.
    /// - [`ConnectError::Welcome`] for any other unexpected message.
    async fn handshake(&mut self) -> Result<(), ConnectError> {
        /// Time limit for each of the two expected replies.
        const REPLY_TIMEOUT: Duration = Duration::from_secs(5);
        const CAP: &str = "twitch.tv/commands twitch.tv/tags twitch.tv/membership";

        trace!("performing handshake");

        // A previous handshake that failed or was cancelled (e.g. by a timeout)
        // may have left its lines in the buffer; start from a clean slate so
        // they are not sent a second time.
        self.scratch.clear();

        let credentials = &self.config.credentials;
        trace!("CAP REQ {CAP:?}; NICK {:?}; PASS ***", credentials.login);

        write!(&mut self.scratch, "CAP REQ :{CAP}\r\n").expect("writing to a String cannot fail");
        let login = credentials.login.as_str();
        let token = match credentials.token.as_ref() {
            Some(token) => token.as_str(),
            None => "just_a_lil_guy",
        };
        let oauth = if token.starts_with("oauth:") {
            ""
        } else {
            "oauth:"
        };
        write!(&mut self.scratch, "PASS {oauth}{token}\r\n")
            .expect("writing to a String cannot fail");
        write!(&mut self.scratch, "NICK {login}\r\n").expect("writing to a String cannot fail");

        // Clear the buffer whether or not the write succeeded, so the token
        // never lingers in it after an I/O error.
        let written = self.writer.write_all(self.scratch.as_bytes()).await;
        self.scratch.clear();
        written?;
        self.writer.flush().await?;

        trace!("waiting for CAP * ACK");
        let message = self.recv().timeout(REPLY_TIMEOUT).await??;
        trace!(?message, "received message");
        match message.command() {
            Command::Capability => {
                if message.params().is_some_and(|v| v.starts_with("* ACK")) {
                    trace!("received CAP * ACK")
                } else {
                    return Err(ConnectError::Auth);
                }
            }
            _ => {
                trace!("unexpected message");
                return Err(ConnectError::Welcome(message));
            }
        }

        trace!("waiting for NOTICE 001");
        let message = self.recv().timeout(REPLY_TIMEOUT).await??;
        trace!(?message, "received message");
        match message.command() {
            Command::RplWelcome => {
                trace!("connected");
            }
            Command::Notice => {
                if message
                    .params()
                    .is_some_and(|v| v.contains("authentication failed"))
                {
                    trace!("invalid credentials");
                    return Err(ConnectError::Auth);
                }
                trace!("unrecognized error");
                return Err(ConnectError::Notice(message));
            }
            _ => {
                trace!("first message not recognized");
                return Err(ConnectError::Welcome(message));
            }
        }

        Ok(())
    }
}
impl Client {
#[inline]
pub fn config(&self) -> &Config {
&self.config
}
#[inline]
pub fn credentials(&self) -> &Credentials {
&self.config.credentials
}
}
/// Splits a stream into its two halves.
///
/// The read half is wrapped in a `BufReader` and exposed as a fused stream
/// of lines; the write half is returned as-is.
fn split(stream: conn::Stream) -> (ReadStream, WriteStream) {
    let (read_half, write_half) = tokio::io::split(stream);
    let lines = BufReader::new(read_half).lines();
    let reader = LinesStream::new(lines).fuse();
    (reader, write_half)
}
/// Error returned by [`Client::reconnect`].
#[derive(Debug)]
pub struct ReconnectError {
/// The connection error that made reconnecting fail.
pub cause: ConnectError,
}
impl Display for ReconnectError {
    /// Writes a fixed prefix followed by the `Display` form of the cause.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!(
            "all reconnect attempts failed. last error was: {}",
            self.cause
        ))
    }
}
impl std::error::Error for ReconnectError {
    /// Exposes the underlying [`ConnectError`] as the source of this error.
    ///
    /// The deprecated `Error::cause` is not overridden: its provided
    /// implementation already forwards to `source`.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        Some(&self.cause)
    }
}
impl<T: Into<ConnectError>> From<T> for ReconnectError {
fn from(cause: T) -> Self {
Self {
cause: cause.into(),
}
}
}
/// Reasons connecting (or a single reconnect attempt) can fail.
#[derive(Debug)]
pub enum ConnectError {
/// Receiving a message during the handshake failed.
Read(RecvError),
/// An I/O error occurred, e.g. while writing the handshake.
Io(io::Error),
/// The host could not be converted into a server name.
Dns(InvalidDnsNameError),
/// Loading the TLS configuration failed.
Tls(TlsConfigError),
/// Opening the stream failed.
Open(OpenStreamError),
/// A step did not finish within its time limit.
Timeout,
/// An unexpected message was received during the handshake.
Welcome(IrcMessage),
/// The capabilities were not acknowledged or authentication failed.
Auth,
/// A notice other than an authentication failure was received during the handshake.
Notice(IrcMessage),
}
impl ConnectError {
    /// Whether `reconnect` should make another attempt after this error.
    ///
    /// Only I/O failures qualify, whether raised directly or while opening
    /// the stream.
    fn should_retry(&self) -> bool {
        match self {
            Self::Io(_) | Self::Open(OpenStreamError::Io(_)) => true,
            _ => false,
        }
    }
}
impl From<RecvError> for ConnectError {
fn from(value: RecvError) -> Self {
Self::Read(value)
}
}
impl From<io::Error> for ConnectError {
fn from(value: io::Error) -> Self {
Self::Io(value)
}
}
impl From<InvalidDnsNameError> for ConnectError {
fn from(value: InvalidDnsNameError) -> Self {
Self::Dns(value)
}
}
impl From<TlsConfigError> for ConnectError {
fn from(value: TlsConfigError) -> Self {
Self::Tls(value)
}
}
impl From<OpenStreamError> for ConnectError {
fn from(value: OpenStreamError) -> Self {
Self::Open(value)
}
}
impl From<tokio::time::error::Elapsed> for ConnectError {
fn from(_: tokio::time::error::Elapsed) -> Self {
Self::Timeout
}
}
/// User-facing description; every message starts with `failed to connect: `.
///
/// Wrapped errors are rendered inline with their own `Display` form, and
/// wrapped messages with their `Debug` form.
impl Display for ConnectError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ConnectError::Read(e) => write!(f, "failed to connect: {e}"),
            ConnectError::Io(e) => write!(f, "failed to connect: {e}"),
            ConnectError::Dns(e) => write!(f, "failed to connect: {e}"),
            ConnectError::Tls(e) => write!(f, "failed to connect: {e}"),
            ConnectError::Open(e) => write!(f, "failed to connect: {e}"),
            ConnectError::Timeout => write!(f, "failed to connect: connection timed out"),
            // `Welcome` is produced both when the capability acknowledgement
            // and when the welcome message is missing, so the text must not
            // claim a specific expected message.
            ConnectError::Welcome(msg) => write!(
                f,
                "failed to connect: expected `CAP * ACK` followed by `001` or `NOTICE` during handshake, instead received: {msg:?}"
            ),
            ConnectError::Auth => write!(f, "failed to connect: invalid credentials"),
            ConnectError::Notice(msg) => write!(
                f,
                "failed to connect: received unrecognized notice: {msg:?}"
            ),
        }
    }
}
// No `source()` override: wrapped errors are already rendered inline by the `Display` impl.
impl std::error::Error for ConnectError {}
// Compile-time checks that `Client` is `Send` and `Sync`.
// NOTE(review): the macros presumably come from the `macros` module; confirm there.
static_assert_send!(Client);
static_assert_sync!(Client);