use super::{timed_io, Bidir, TimeLimitedTokio};
use crate::ircmsg::ClientCodec;
use std::{pin::Pin, time::Duration};
use tokio::{
io::{AsyncBufRead, AsyncWrite, BufReader},
net::TcpStream,
};
impl<'a> super::ServerAddr<'a> {
pub async fn connect_tokio_no_tls(&self) -> std::io::Result<BufReader<StreamTokio>> {
let string = self.utf8_address()?;
let sock = tokio::net::TcpStream::connect((string, self.port_num())).await?;
Ok(BufReader::with_capacity(super::BUFSIZE, StreamTokio { stream: StreamInner::Tcp(sock) }))
}
#[cfg(feature = "tls-tokio")]
pub async fn connect_tokio(
&self,
tls_fn: impl FnOnce() -> std::io::Result<crate::client::tls::TlsConfig>,
) -> std::io::Result<BufReader<StreamTokio>> {
use std::io::{Error, ErrorKind};
let string = self.utf8_address()?;
let stream = if self.tls {
let name = rustls::pki_types::ServerName::try_from(string)
.map_err(|e| Error::new(ErrorKind::InvalidInput, e))?;
let config = tls_fn()?;
let conn: tokio_rustls::TlsConnector = config.into();
let sock = tokio::net::TcpStream::connect((string, self.port_num())).await?;
let tls = conn.connect(name.to_owned(), sock).await?;
StreamInner::Tls(tls)
} else {
let sock = tokio::net::TcpStream::connect((string, self.port_num())).await?;
StreamInner::Tcp(sock)
};
Ok(BufReader::with_capacity(super::BUFSIZE, StreamTokio { stream }))
}
}
#[derive(Debug, Default)]
pub struct StreamTokio {
stream: StreamInner,
}
#[derive(Debug, Default)]
enum StreamInner {
#[default]
Closed,
Tcp(TcpStream),
#[cfg(feature = "tls-tokio")]
Tls(tokio_rustls::client::TlsStream<TcpStream>),
}
impl tokio::io::AsyncRead for StreamTokio {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
match &mut (self.get_mut()).stream {
StreamInner::Closed => std::task::Poll::Ready(Ok(())),
StreamInner::Tcp(tcp) => Pin::new(tcp).poll_read(cx, buf),
#[cfg(feature = "tls-tokio")]
StreamInner::Tls(tls) => Pin::new(tls).poll_read(cx, buf),
}
}
}
impl tokio::io::AsyncWrite for StreamTokio {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
match &mut (self.get_mut()).stream {
StreamInner::Closed => std::task::Poll::Ready(Ok(0)),
StreamInner::Tcp(tcp) => Pin::new(tcp).poll_write(cx, buf),
#[cfg(feature = "tls-tokio")]
StreamInner::Tls(tls) => Pin::new(tls).poll_write(cx, buf),
}
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match &mut (self.get_mut()).stream {
StreamInner::Closed => std::task::Poll::Ready(Ok(())),
StreamInner::Tcp(tcp) => Pin::new(tcp).poll_flush(cx),
#[cfg(feature = "tls-tokio")]
StreamInner::Tls(tls) => Pin::new(tls).poll_flush(cx),
}
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match &mut (self.get_mut()).stream {
StreamInner::Closed => std::task::Poll::Ready(Ok(())),
StreamInner::Tcp(tcp) => Pin::new(tcp).poll_shutdown(cx),
#[cfg(feature = "tls-tokio")]
StreamInner::Tls(tls) => Pin::new(tls).poll_shutdown(cx),
}
}
}
pub trait ConnectionTokio: Unpin {
type AsyncBufRead: tokio::io::AsyncBufRead;
type AsyncWrite: tokio::io::AsyncWrite + Unpin;
fn as_bufread(&mut self) -> Pin<&mut Self::AsyncBufRead>;
fn as_write(&mut self) -> &mut Self::AsyncWrite;
}
impl<R: AsyncBufRead + Unpin, W: AsyncWrite + Unpin> ConnectionTokio for Bidir<R, W> {
type AsyncBufRead = R;
type AsyncWrite = W;
fn as_bufread(&mut self) -> Pin<&mut Self::AsyncBufRead> {
Pin::new(&mut self.0)
}
fn as_write(&mut self) -> &mut Self::AsyncWrite {
&mut self.1
}
}
impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> ConnectionTokio for BufReader<T> {
type AsyncBufRead = Self;
type AsyncWrite = T;
fn as_bufread(&mut self) -> Pin<&mut Self::AsyncBufRead> {
Pin::new(self)
}
fn as_write(&mut self) -> &mut Self::AsyncWrite {
self.get_mut()
}
}
impl<C: ConnectionTokio, S> crate::client::Client<C, S> {
pub async fn run_tokio(&mut self) -> std::io::Result<Option<(&[usize], &[usize])>> {
let finished_at = loop {
let wait_for = self.flush_partial_tokio().await?;
if self.handlers.is_empty() {
if let Some(wait_for) = wait_for {
tokio::time::sleep(wait_for).await;
continue;
}
return Ok(Some((Default::default(), Default::default())));
}
let mut conn = TimeLimitedTokio::new(&mut self.conn, &self.timeout);
let finished_at = if self.handlers.wants_owning() {
let fut = ClientCodec::read_owning_from_tokio(&mut conn, &mut self.buf_i);
let msg = match timed_io(fut, wait_for, self.timeout.read_timeout()).await? {
Ok(m) => m,
Err(true) => continue,
Err(false) => return Ok(None),
};
#[cfg(feature = "tracing")]
tracing::debug!(target: "vinezombie::recv", "{}", msg);
self.queue.adjust(&msg);
self.handlers.handle(&msg, &mut self.state, &mut self.queue)
} else {
let fut = ClientCodec::read_borrowing_from_tokio(&mut conn, &mut self.buf_i);
let msg = match timed_io(fut, wait_for, self.timeout.read_timeout()).await? {
Ok(m) => m,
Err(true) => continue,
Err(false) => return Ok(None),
};
#[cfg(feature = "tracing")]
tracing::debug!(target: "vinezombie::recv", "{}", msg);
self.queue.adjust(&msg);
let fa = self.handlers.handle(&msg, &mut self.state, &mut self.queue);
self.buf_i.clear();
fa
};
if self.handlers.has_results(finished_at) {
self.flush_partial_tokio().await?;
break finished_at;
}
};
Ok(Some(self.handlers.last_run_results(finished_at)))
}
pub async fn flush_partial_tokio(&mut self) -> std::io::Result<Option<Duration>> {
use tokio::io::AsyncWriteExt;
if self.queue.is_empty() {
return Ok(None);
}
let mut timeout = None;
while let Some(popped) = self.queue.pop(|new_timeout| timeout = new_timeout) {
#[cfg(feature = "tracing")]
tracing::debug!(target: "vinezombie::send", "{}", popped);
let _ = ClientCodec::write_to(&popped, &mut self.buf_o);
self.buf_o.extend_from_slice(b"\r\n");
}
let mut conn = TimeLimitedTokio::new(&mut self.conn, &self.timeout);
let result = conn.write_all(&self.buf_o).await;
self.buf_o.clear();
result?;
conn.flush().await?;
Ok(timeout)
}
}