use super::{filter_time_error, ReadTimeout, TimeLimitedSync, WriteTimeout};
use crate::ircmsg::ClientCodec;
use std::{
io::{BufRead, BufReader, Read, Write},
net::TcpStream,
time::Duration,
};
impl<'a> super::ServerAddr<'a> {
pub fn connect_no_tls(&self) -> std::io::Result<BufReader<Stream>> {
let string = self.utf8_address()?;
let sock = std::net::TcpStream::connect((string, self.port_num()))?;
Ok(BufReader::with_capacity(super::BUFSIZE, Stream(StreamInner::Tcp(sock))))
}
#[cfg(feature = "tls")]
pub fn connect(
&self,
tls_fn: impl FnOnce() -> std::io::Result<crate::client::tls::TlsConfig>,
) -> std::io::Result<BufReader<Stream>> {
use std::io::{Error, ErrorKind};
let string = self.utf8_address()?;
let stream = if self.tls {
let name = rustls::pki_types::ServerName::try_from(string)
.map_err(|e| Error::new(ErrorKind::InvalidInput, e))?;
let config = tls_fn()?;
let conn = rustls::ClientConnection::new(config, name.to_owned())
.map_err(|e| Error::new(ErrorKind::Other, e))?;
let sock = std::net::TcpStream::connect((string, self.port_num()))?;
let mut tls = rustls::StreamOwned { conn, sock };
tls.flush()?;
StreamInner::Tls(Box::new(tls))
} else {
let sock = std::net::TcpStream::connect((string, self.port_num()))?;
StreamInner::Tcp(sock)
};
Ok(BufReader::with_capacity(super::BUFSIZE, Stream(stream)))
}
}
/// A client transport: closed, a plain TCP socket, or (with the `tls` feature) a
/// rustls client session over TCP.
///
/// [`Read`] and [`Write`] are delegated to the active transport. A closed stream
/// reads as EOF (`Ok(0)`) and accepts zero bytes per `write`.
#[derive(Debug)]
pub struct Stream(StreamInner);
/// The concrete transport behind a [`Stream`].
#[derive(Debug, Default)]
enum StreamInner {
    /// No socket. This is the `Default` variant. Reads return 0, writes accept
    /// 0 bytes, and timeout/shutdown calls succeed without doing anything.
    #[default]
    Closed,
    /// A plain TCP socket.
    Tcp(TcpStream),
    /// A rustls client session that owns its TCP socket.
    /// Boxed — presumably to keep the enum small, since the session state is large.
    #[cfg(feature = "tls")]
    Tls(Box<rustls::StreamOwned<rustls::ClientConnection, TcpStream>>),
}
impl Stream {
pub fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> {
match &self.0 {
StreamInner::Closed => Ok(()),
StreamInner::Tcp(s) => s.shutdown(how),
#[cfg(feature = "tls")]
StreamInner::Tls(s) => s.sock.shutdown(how),
}
}
pub fn set_read_timeout(&self, timeout: Option<Duration>) -> std::io::Result<()> {
match &self.0 {
StreamInner::Closed => Ok(()),
StreamInner::Tcp(s) => s.set_read_timeout(timeout),
#[cfg(feature = "tls")]
StreamInner::Tls(s) => s.sock.set_read_timeout(timeout),
}
}
pub fn set_write_timeout(&self, timeout: Option<Duration>) -> std::io::Result<()> {
match &self.0 {
StreamInner::Closed => Ok(()),
StreamInner::Tcp(s) => s.set_write_timeout(timeout),
#[cfg(feature = "tls")]
StreamInner::Tls(s) => s.sock.set_write_timeout(timeout),
}
}
pub fn read_timeout(&self) -> std::io::Result<Option<Duration>> {
match &self.0 {
StreamInner::Closed => Ok(None),
StreamInner::Tcp(s) => s.read_timeout(),
#[cfg(feature = "tls")]
StreamInner::Tls(s) => s.sock.read_timeout(),
}
}
pub fn write_timeout(&self) -> std::io::Result<Option<Duration>> {
match &self.0 {
StreamInner::Closed => Ok(None),
StreamInner::Tcp(s) => s.write_timeout(),
#[cfg(feature = "tls")]
StreamInner::Tls(s) => s.sock.write_timeout(),
}
}
}
/// Reads from whichever transport backs the stream; a closed stream always reports EOF.
impl Read for Stream {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let source: &mut dyn Read = match &mut self.0 {
            StreamInner::Closed => return Ok(0),
            StreamInner::Tcp(tcp) => tcp,
            #[cfg(feature = "tls")]
            StreamInner::Tls(tls) => &mut **tls,
        };
        source.read(buf)
    }

    fn read_vectored(&mut self, bufs: &mut [std::io::IoSliceMut<'_>]) -> std::io::Result<usize> {
        let source: &mut dyn Read = match &mut self.0 {
            StreamInner::Closed => return Ok(0),
            StreamInner::Tcp(tcp) => tcp,
            #[cfg(feature = "tls")]
            StreamInner::Tls(tls) => &mut **tls,
        };
        source.read_vectored(bufs)
    }
}
/// Writes to whichever transport backs the stream.
///
/// A closed stream accepts zero bytes per `write`, so `write_all` on it fails with
/// `ErrorKind::WriteZero`; flushing a closed stream trivially succeeds.
impl Write for Stream {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let sink: &mut dyn Write = match &mut self.0 {
            StreamInner::Closed => return Ok(0),
            StreamInner::Tcp(tcp) => tcp,
            #[cfg(feature = "tls")]
            StreamInner::Tls(tls) => &mut **tls,
        };
        sink.write(buf)
    }

    fn flush(&mut self) -> std::io::Result<()> {
        let sink: &mut dyn Write = match &mut self.0 {
            StreamInner::Closed => return Ok(()),
            StreamInner::Tcp(tcp) => tcp,
            #[cfg(feature = "tls")]
            StreamInner::Tls(tls) => &mut **tls,
        };
        sink.flush()
    }
}
/// Lets a bare [`TcpStream`] be used wherever a [`ReadTimeout`] is required.
impl ReadTimeout for TcpStream {
    fn set_read_timeout(&mut self, timeout: Option<Duration>) -> std::io::Result<()> {
        // Path-style calls prefer inherent methods over trait methods, so this is std's
        // `TcpStream::set_read_timeout(&self, ..)`, not a recursive call to this impl.
        Self::set_read_timeout(self, timeout)
    }
}
/// Lets a bare [`TcpStream`] be used wherever a [`WriteTimeout`] is required.
impl WriteTimeout for TcpStream {
    fn set_write_timeout(&mut self, timeout: Option<Duration>) -> std::io::Result<()> {
        // Resolves to std's inherent `TcpStream::set_write_timeout(&self, ..)`
        // (inherent methods win for path-style calls), not to this trait method.
        Self::set_write_timeout(self, timeout)
    }
}
/// Exposes [`Stream`]'s inherent read-timeout setter through the [`ReadTimeout`] trait.
impl ReadTimeout for Stream {
    fn set_read_timeout(&mut self, timeout: Option<Duration>) -> std::io::Result<()> {
        // Resolves to the inherent `Stream::set_read_timeout(&self, ..)` (inherent
        // methods take priority for path-style calls), so this does not recurse.
        Self::set_read_timeout(self, timeout)
    }
}
/// Exposes [`Stream`]'s inherent write-timeout setter through the [`WriteTimeout`] trait.
impl WriteTimeout for Stream {
    fn set_write_timeout(&mut self, timeout: Option<Duration>) -> std::io::Result<()> {
        // Resolves to the inherent `Stream::set_write_timeout(&self, ..)` (inherent
        // methods take priority for path-style calls), so this does not recurse.
        Self::set_write_timeout(self, timeout)
    }
}
/// Forwards read-timeout changes on a borrowed rustls stream to the transport it
/// wraps; the TLS session itself is not involved.
#[cfg(feature = "tls")]
impl<'a, S, C, T> ReadTimeout for rustls::Stream<'a, C, T>
where
    S: rustls::SideData,
    C: 'a + std::ops::DerefMut + std::ops::Deref<Target = rustls::ConnectionCommon<S>>,
    T: ReadTimeout + Read + Write,
{
    fn set_read_timeout(&mut self, timeout: Option<Duration>) -> std::io::Result<()> {
        let transport: &mut T = &mut *self.sock;
        transport.set_read_timeout(timeout)
    }
}
/// Forwards write-timeout changes on a borrowed rustls stream to the transport it
/// wraps; the TLS session itself is not involved.
#[cfg(feature = "tls")]
impl<'a, S, C, T> WriteTimeout for rustls::Stream<'a, C, T>
where
    S: rustls::SideData,
    C: 'a + std::ops::DerefMut + std::ops::Deref<Target = rustls::ConnectionCommon<S>>,
    T: WriteTimeout + Read + Write,
{
    fn set_write_timeout(&mut self, timeout: Option<Duration>) -> std::io::Result<()> {
        let transport: &mut T = &mut *self.sock;
        transport.set_write_timeout(timeout)
    }
}
/// Forwards read-timeout changes on an owning rustls stream to the transport it
/// owns; the TLS session itself is not involved.
#[cfg(feature = "tls")]
impl<S, C, T> ReadTimeout for rustls::StreamOwned<C, T>
where
    S: rustls::SideData,
    C: std::ops::DerefMut + std::ops::Deref<Target = rustls::ConnectionCommon<S>>,
    T: ReadTimeout + Read + Write,
{
    fn set_read_timeout(&mut self, timeout: Option<Duration>) -> std::io::Result<()> {
        let transport: &mut T = &mut self.sock;
        transport.set_read_timeout(timeout)
    }
}
/// Forwards write-timeout changes on an owning rustls stream to the transport it
/// owns; the TLS session itself is not involved.
#[cfg(feature = "tls")]
impl<S, C, T> WriteTimeout for rustls::StreamOwned<C, T>
where
    S: rustls::SideData,
    C: std::ops::DerefMut + std::ops::Deref<Target = rustls::ConnectionCommon<S>>,
    T: WriteTimeout + Read + Write,
{
    fn set_write_timeout(&mut self, timeout: Option<Duration>) -> std::io::Result<()> {
        let transport: &mut T = &mut self.sock;
        transport.set_write_timeout(timeout)
    }
}
/// A bidirectional transport whose read and write timeouts can be adjusted and
/// which hands out its reading and writing sides separately.
///
/// `Client::flush_partial` writes outgoing messages through [`Connection::as_write`].
pub trait Connection: ReadTimeout + WriteTimeout {
    /// The buffered reader that incoming data is read from.
    type BufRead: std::io::BufRead;
    /// The writer that outgoing data is written to.
    type Write: Write;
    /// Returns the reading side.
    fn as_bufread(&mut self) -> &mut Self::BufRead;
    /// Returns the writing side.
    fn as_write(&mut self) -> &mut Self::Write;
}
/// A `Bidir` pairs a separate reader (field `0`) with a separate writer (field `1`).
/// It is a [`Connection`] whenever the pair as a whole supports timeouts.
impl<R, W> Connection for super::Bidir<R, W>
where
    R: BufRead,
    W: Write,
    Self: ReadTimeout + WriteTimeout,
{
    type BufRead = R;
    type Write = W;

    fn as_bufread(&mut self) -> &mut R {
        &mut self.0
    }

    fn as_write(&mut self) -> &mut W {
        &mut self.1
    }
}
/// A read timeout on a [`BufReader`] applies to the reader it wraps; the buffer
/// itself is left untouched.
impl<T> ReadTimeout for BufReader<T>
where
    T: ReadTimeout,
{
    fn set_read_timeout(&mut self, timeout: Option<Duration>) -> std::io::Result<()> {
        T::set_read_timeout(self.get_mut(), timeout)
    }
}
/// A write timeout on a [`BufReader`] applies to the (full-duplex) stream it wraps.
impl<T> WriteTimeout for BufReader<T>
where
    T: WriteTimeout,
{
    fn set_write_timeout(&mut self, timeout: Option<Duration>) -> std::io::Result<()> {
        T::set_write_timeout(self.get_mut(), timeout)
    }
}
/// A [`BufReader`] over a full-duplex transport is a complete [`Connection`].
/// Reads go through the buffer; writes bypass it and go straight to the wrapped transport.
impl<T> Connection for BufReader<T>
where
    T: ReadTimeout + WriteTimeout + Read + Write,
{
    type BufRead = Self;
    type Write = T;

    fn as_bufread(&mut self) -> &mut Self {
        self
    }

    fn as_write(&mut self) -> &mut T {
        BufReader::get_mut(self)
    }
}
impl<C: Connection, S> crate::client::Client<C, S> {
    /// Drives the connection until the current handlers produce results.
    ///
    /// Each iteration first sends queued messages with `flush_partial`.
    /// - With no handlers, no message is read. The loop sleeps for any delay
    ///   `flush_partial` reported and retries. Once no delay is reported, it returns
    ///   `Ok(Some(..))` with two empty slices.
    /// - Otherwise one message is read through `TimeLimitedSync`. It is read owning or
    ///   borrowing depending on `handlers.wants_owning()`, then passed to
    ///   `queue.adjust` and `handlers.handle`.
    ///
    /// Returns `Ok(None)` when `filter_time_error` yields no message and
    /// `TimeLimitedSync::new` did not report that the loop should continue. This is
    /// presumably a timeout — confirm against `TimeLimitedSync` / `filter_time_error`.
    /// Returns `Ok(Some(handlers.last_run_results(..)))` once `handlers.has_results`
    /// is true; queued messages are flushed once more before returning.
    ///
    /// # Errors
    /// Propagates I/O errors from `flush_partial` and `TimeLimitedSync::new`, plus any
    /// read error that `filter_time_error` passes through.
    pub fn run(&mut self) -> std::io::Result<Option<(&[usize], &[usize])>> {
        let finished_at = loop {
            // Send what the queue will release now; `wait_for` is the delay it reported.
            let wait_for = self.flush_partial()?;
            if self.handlers.is_empty() {
                // No handlers to feed: don't read, just keep draining the queue
                // until it stops reporting a delay.
                if let Some(wait_for) = wait_for {
                    std::thread::sleep(wait_for);
                    continue;
                }
                return Ok(Some((Default::default(), Default::default())));
            }
            // `wait_for` is handed to TimeLimitedSync, presumably to cap how long
            // the read below may block so queued messages aren't delayed.
            let (mut conn, should_continue) =
                TimeLimitedSync::new(&mut self.conn, &mut self.timeout, wait_for)?;
            let finished_at = if self.handlers.wants_owning() {
                let msg = ClientCodec::read_owning_from(&mut conn, &mut self.buf_i);
                let Some(msg) = filter_time_error(msg)? else {
                    // No message this round: retry if allowed, otherwise give up.
                    if should_continue {
                        continue;
                    }
                    return Ok(None);
                };
                #[cfg(feature = "tracing")]
                tracing::debug!(target: "vinezombie::recv", "{}", msg);
                self.queue.adjust(&msg);
                self.handlers.handle(&msg, &mut self.state, &mut self.queue)
            } else {
                let msg = ClientCodec::read_borrowing_from(&mut conn, &mut self.buf_i);
                let Some(msg) = filter_time_error(msg)? else {
                    // No message this round: retry if allowed, otherwise give up.
                    if should_continue {
                        continue;
                    }
                    return Ok(None);
                };
                #[cfg(feature = "tracing")]
                tracing::debug!(target: "vinezombie::recv", "{}", msg);
                self.queue.adjust(&msg);
                let fa = self.handlers.handle(&msg, &mut self.state, &mut self.queue);
                // The borrowed message presumably points into `buf_i`, so the buffer
                // is only cleared after the handlers are done with it.
                self.buf_i.clear();
                fa
            };
            if self.handlers.has_results(finished_at) {
                // Send anything the handlers just queued before handing back results.
                self.flush_partial()?;
                break finished_at;
            }
        };
        Ok(Some(self.handlers.last_run_results(finished_at)))
    }
    /// Sends every message the queue is currently willing to release, then flushes
    /// the connection.
    ///
    /// Messages are popped until `queue.pop` returns `None`. Each one is serialized
    /// into `buf_o` followed by `\r\n`, and the whole batch goes out in a single
    /// `write_all` on `conn.as_write()`. `buf_o` is cleared whether or not that
    /// write succeeds.
    ///
    /// Returns the last value `queue.pop` passed to its callback, or `None` if it never
    /// called it. `run` uses this value as a delay before trying again; it is presumably
    /// the time until the queue will release more — confirm against the queue's `pop`.
    /// If the queue is empty, returns `Ok(None)` without writing or flushing.
    ///
    /// # Errors
    /// Propagates errors from writing to or flushing the connection.
    pub fn flush_partial(&mut self) -> std::io::Result<Option<Duration>> {
        if self.queue.is_empty() {
            return Ok(None);
        }
        let mut timeout = None;
        while let Some(popped) = self.queue.pop(|new_timeout| timeout = new_timeout) {
            #[cfg(feature = "tracing")]
            tracing::debug!(target: "vinezombie::send", "{}", popped);
            // Result ignored — presumably serializing into the in-memory `buf_o`
            // cannot fail; confirm against `ClientCodec::write_to`.
            let _ = ClientCodec::write_to(&popped, &mut self.buf_o);
            self.buf_o.extend_from_slice(b"\r\n");
        }
        let result = self.conn.as_write().write_all(&self.buf_o);
        // Clear before propagating any error so the next call doesn't resend this batch.
        self.buf_o.clear();
        result?;
        self.conn.as_write().flush()?;
        Ok(timeout)
    }
}