use std::fmt;
use std::io::{Read, Write};
use std::net::{Shutdown, TcpStream, ToSocketAddrs};
use std::str::FromStr;
use std::time::Duration;
use crate::{
error::Error,
pipeline::{Pipeline, ResponseItem},
query::Query,
};
/// Connection settings for a server, used to open a [`Connection`].
///
/// Holds the address to connect to plus optional session settings that are
/// sent to the server when [`IrrClient::connect`] establishes a connection.
// The lint fires because the type name repeats the module name; the name is
// kept deliberately for clarity at the use site.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct IrrClient<A> {
    /// Address passed to `TcpStream::connect`.
    addr: A,
    /// Client identifier sent on connect; `None` means the crate default
    /// (`IrrClient::DEFAULT_CLIENT_ID`) is used.
    client_id: Option<String>,
    /// Timeout sent to the server on connect; `None` means no timeout query
    /// is sent.
    server_timeout: Option<Duration>,
}
impl<A> IrrClient<A>
where
    A: ToSocketAddrs + fmt::Display,
{
    /// Identifier announced when no explicit client id has been configured,
    /// built from the crate name and version as `<name>-<version>`.
    pub const DEFAULT_CLIENT_ID: &'static str =
        concat!(env!("CARGO_PKG_NAME"), "-", env!("CARGO_PKG_VERSION"));

    /// Creates settings for `addr` with no client-id override and no
    /// server timeout.
    pub const fn new(addr: A) -> Self {
        Self {
            addr,
            client_id: None,
            server_timeout: None,
        }
    }

    /// Overrides the client identifier sent to the server on connect.
    ///
    /// Passing `None` clears any override, so [`Self::DEFAULT_CLIENT_ID`]
    /// is used instead.
    pub fn client_id<S: AsRef<str>>(&mut self, id: Option<S>) {
        self.client_id = id.map(|value| String::from(value.as_ref()));
    }

    /// Sets the timeout to request from the server on connect.
    ///
    /// With `None`, no timeout query is sent.
    pub fn server_timeout(&mut self, duration: Option<Duration>) {
        self.server_timeout = duration;
    }

    /// Opens a new [`Connection`] using these settings.
    ///
    /// # Errors
    ///
    /// Fails if the TCP connection cannot be established, if configuring or
    /// writing to the socket fails, or if queueing the session setup queries
    /// (client id, optional timeout) fails.
    #[tracing::instrument(skip(self), fields(addr = %self.addr), level = "debug")]
    pub fn connect(&self) -> Result<Connection, Error> {
        Connection::connect(self)
    }

    /// The configured client id, or [`Self::DEFAULT_CLIENT_ID`] when unset.
    fn effective_client_id(&self) -> &str {
        self.client_id
            .as_deref()
            .unwrap_or(Self::DEFAULT_CLIENT_ID)
    }
}
/// An open session with the server, created by [`IrrClient::connect`].
///
/// Queries are issued through [`Pipeline`]s borrowed from the connection.
/// Dropping the connection sends a quit command and shuts down the socket.
#[derive(Debug)]
pub struct Connection {
    /// Underlying TCP socket used for all reads and writes.
    conn: TcpStream,
}
impl Connection {
pub const DEFAULT_CAPACITY: usize = 1 << 20;
#[allow(clippy::cognitive_complexity)]
fn connect<A>(builder: &IrrClient<A>) -> Result<Self, Error>
where
A: ToSocketAddrs + fmt::Display,
{
tracing::info!("trying to connect to {}", builder.addr);
let mut conn = TcpStream::connect(&builder.addr)?;
tracing::debug!("disabling Nagle's algorithm");
conn.set_nodelay(true)?;
tracing::debug!("requesting multiple command mode");
conn.write_all(b"!!\n")?;
conn.flush()?;
tracing::info!("connected to {}", builder.addr);
let mut this = Self { conn };
{
let mut init_pipeline = this.pipeline_with_capacity(8);
_ = init_pipeline.push(Query::SetClientId(builder.effective_client_id().to_owned()))?;
if let Some(server_timeout) = builder.server_timeout {
_ = init_pipeline.push(Query::SetTimeout(server_timeout))?;
}
}
Ok(this)
}
pub fn pipeline(&mut self) -> Pipeline<'_> {
self.pipeline_with_capacity(Self::DEFAULT_CAPACITY)
}
pub fn pipeline_from_initial<T, F, I>(
&mut self,
initial: Query,
f: F,
) -> Result<Pipeline<'_>, Error>
where
T: FromStr + fmt::Debug,
T::Err: std::error::Error + Send + Sync + 'static,
F: FnMut(Result<ResponseItem<T>, Error>) -> Option<I>,
I: IntoIterator<Item = Query>,
{
Pipeline::from_initial(self, initial, f)
}
pub fn pipeline_from_iter<I>(&mut self, iter: I) -> Pipeline<'_>
where
I: IntoIterator<Item = Query>,
{
let mut pipeline = self.pipeline();
pipeline.extend(iter);
pipeline
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn pipeline_with_capacity(&mut self, capacity: usize) -> Pipeline<'_> {
tracing::debug!("constructing new query pipeline");
Pipeline::new(self, capacity)
}
pub fn version(&mut self) -> Result<String, Error> {
Ok(self
.pipeline()
.push(Query::Version)?
.pop::<String>()
.unwrap_or_else(|| Err(Error::Dequeue))?
.next()
.unwrap_or_else(|| Err(Error::EmptyResponse(Query::Version)))?
.content()
.clone())
}
#[tracing::instrument(skip(self), level = "debug")]
pub(crate) fn send(&mut self, query: &str) -> Result<(), Error> {
tracing::debug!("sending query");
self.conn.write_all(query.as_bytes())?;
self.conn.flush().map_err(Error::from)
}
pub(crate) fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
self.conn.read(buf).map_err(Error::from)
}
}
impl Drop for Connection {
    /// Ends the session: sends the quit command (`!q`) and shuts down both
    /// directions of the socket.
    ///
    /// `drop` cannot return an error, so failures are logged rather than
    /// propagated; the connection is being discarded either way.
    fn drop(&mut self) {
        tracing::info!("closing connection");
        // `write_all`, not `write`: a single `write` may send only part of the
        // buffer and still return `Ok(n)`, silently truncating the command.
        if let Err(err) = self.conn.write_all(b"!q\n") {
            tracing::error!("failed to send quit command: {err}");
        }
        if let Err(err) = self.conn.shutdown(Shutdown::Both) {
            tracing::error!("failed to close connection: {err}");
        }
    }
}