use std::time::Duration;
use irc_proto::chan::ChannelExt;
use tokio::io::{AsyncWriteExt, BufWriter};
use tokio::net::TcpStream;
/// Keepalive interval that `State::connect` installs by default: 30 seconds.
///
/// NOTE(review): presumably the idle period between keepalive probes; the
/// code that consumes it is not in this file.
pub const DEFAULT_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30);
/// Keepalive timeout that `State::connect` installs by default: 10 seconds.
///
/// NOTE(review): presumably how long to wait for a keepalive reply before the
/// connection is considered dead; confirm against the keepalive loop.
pub const DEFAULT_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10);
/// Flood-control burst value used for every newly built `State`: 4.
///
/// NOTE(review): presumably the number of messages that may be sent
/// back-to-back; the rate limiter itself is not in this file.
pub const DEFAULT_FLOOD_BURST: usize = 4;
/// Flood-control rate value used for every newly built `State`: 500 ms.
///
/// NOTE(review): presumably the spacing enforced between messages once the
/// burst allowance is used up; confirm against the rate limiter.
pub const DEFAULT_FLOOD_RATE: Duration = Duration::from_millis(500);
/// State of one IRC connection: identity, channel list, timing settings and
/// the two halves of the underlying TCP stream.
///
/// Built by `State::connect`, or on Unix by `State::try_inherit_from_env`
/// from a socket handed over through environment variables.
pub struct State {
/// Nickname. `connect` sends it in the `NICK` and `USER` registration lines.
pub nick: String,
/// Channel names. `connect` stores each one after passing it through
/// `normalise_channel`; `try_inherit_from_env` stores them as read from the
/// environment. Nothing in this file sends `JOIN` for them.
pub channels: Vec<String>,
/// Server address: the string given to `connect` (and passed on to
/// `TcpStream::connect`), or the value read from the environment.
pub server: String,
/// Keepalive interval; defaults to `DEFAULT_KEEPALIVE_INTERVAL`, overridden
/// by `with_keepalive`.
pub(crate) keepalive_interval: Duration,
/// Keepalive timeout; defaults to `DEFAULT_KEEPALIVE_TIMEOUT`, overridden by
/// `with_keepalive`.
pub(crate) keepalive_timeout: Duration,
/// Flood-control burst value; defaults to `DEFAULT_FLOOD_BURST`, overridden
/// by `with_flood_control`.
pub(crate) flood_burst: usize,
/// Flood-control rate value; defaults to `DEFAULT_FLOOD_RATE`, overridden by
/// `with_flood_control`.
pub(crate) flood_rate: Duration,
/// Buffered read half of the split TCP stream.
pub(crate) reader: tokio::io::BufReader<tokio::net::tcp::OwnedReadHalf>,
/// Write half of the split TCP stream (unbuffered).
pub(crate) write_half: tokio::net::tcp::OwnedWriteHalf,
/// Raw file descriptor of the TCP socket. This is a plain copy of the
/// number: the socket itself is owned by `reader` and `write_half`, so the
/// descriptor is only valid while those are alive.
#[cfg(unix)]
pub raw_fd: std::os::unix::io::RawFd,
}
impl State {
    /// Returns `ch` unchanged when `is_channel_name()` accepts it; otherwise
    /// returns it with a `#` prefix.
    fn normalise_channel(ch: String) -> String {
        if ch.is_channel_name() {
            ch
        } else {
            format!("#{ch}")
        }
    }

    /// Opens a TCP connection to `server` and sends the `NICK` and `USER`
    /// registration lines for `nick`.
    ///
    /// The returned `State` carries the default keepalive and flood-control
    /// settings; use [`State::with_keepalive`] and
    /// [`State::with_flood_control`] to change them. `channels` are normalised
    /// and stored, but this function does not send `JOIN` and does not read
    /// any reply from the server.
    ///
    /// # Errors
    ///
    /// Returns an error when `nick` contains a carriage return or line feed,
    /// when the TCP connection cannot be established, or when writing the
    /// registration lines fails.
    pub async fn connect(
        nick: String,
        server: &str,
        channels: Vec<String>,
    ) -> Result<State, Box<dyn std::error::Error + Send + Sync>> {
        // The nick is interpolated verbatim into the registration lines below.
        // A CR or LF inside it would end the line early and make the server
        // parse the remainder as a separate command, so refuse it up front,
        // before any connection is opened.
        if nick.contains(|c: char| c == '\r' || c == '\n') {
            return Err("nick must not contain CR or LF".into());
        }

        let channels: Vec<String> = channels.into_iter().map(Self::normalise_channel).collect();
        let stream = TcpStream::connect(server).await?;

        // Record the descriptor number before the stream is split; the halves
        // keep the socket open, so the number stays valid while they live.
        #[cfg(unix)]
        let raw_fd = {
            use std::os::unix::io::AsRawFd;
            stream.as_raw_fd()
        };

        let (read_half, write_half) = stream.into_split();
        let reader = tokio::io::BufReader::new(read_half);

        // Buffer both registration lines and push them out with one flush.
        let mut writer = BufWriter::new(write_half);
        writer
            .write_all(format!("NICK {nick}\r\n").as_bytes())
            .await?;
        writer
            .write_all(format!("USER {nick} 0 * :{nick}\r\n").as_bytes())
            .await?;
        writer.flush().await?;
        // The buffer is empty after the flush, so unwrapping loses nothing.
        let write_half = writer.into_inner();

        Ok(State {
            nick,
            channels,
            server: server.to_string(),
            keepalive_interval: DEFAULT_KEEPALIVE_INTERVAL,
            keepalive_timeout: DEFAULT_KEEPALIVE_TIMEOUT,
            flood_burst: DEFAULT_FLOOD_BURST,
            flood_rate: DEFAULT_FLOOD_RATE,
            reader,
            write_half,
            #[cfg(unix)]
            raw_fd,
        })
    }

    /// Rebuilds a `State` around an already-open socket described by the
    /// `crate::hot_reload` environment variables.
    ///
    /// Returns `Ok(None)` when the descriptor variable cannot be read, i.e.
    /// there is nothing to inherit. Otherwise the nick, server, channel list
    /// and keepalive settings (in milliseconds) are read from the
    /// environment, all hand-off variables are removed, and the descriptor is
    /// adopted as a non-blocking tokio TCP stream. Flood-control settings are
    /// not part of the hand-off and are reset to the defaults. No
    /// registration lines are sent.
    ///
    /// # Errors
    ///
    /// Returns an error when the descriptor is not a non-negative integer,
    /// when any of the other variables is missing or fails to parse, or when
    /// the socket cannot be switched to non-blocking mode or registered with
    /// tokio. If the failure happens while reading the variables, they are
    /// left in place.
    ///
    /// # Panics
    ///
    /// `tokio::net::TcpStream::from_std` panics when called outside a tokio
    /// runtime with IO enabled, so this function must be called from within
    /// one.
    #[cfg(unix)]
    pub fn try_inherit_from_env() -> Result<Option<State>, Box<dyn std::error::Error + Send + Sync>>
    {
        use crate::hot_reload::{
            ENV_CHANNELS, ENV_FD, ENV_KA_INTERVAL, ENV_KA_TIMEOUT, ENV_NICK, ENV_SERVER,
        };
        use std::os::unix::io::{FromRawFd, RawFd};

        let fd_str = match std::env::var(ENV_FD) {
            Ok(v) => v,
            Err(_) => return Ok(None),
        };
        let raw_fd: RawFd = fd_str.parse()?;
        // `RawFd` is a signed integer, so "-1" parses fine, but a negative
        // value can never be an open descriptor and must not reach
        // `from_raw_fd`.
        if raw_fd < 0 {
            return Err(format!(
                "inherited socket file descriptor must be non-negative, got {raw_fd}"
            )
            .into());
        }

        let nick = std::env::var(ENV_NICK)?;
        let server = std::env::var(ENV_SERVER)?;
        let channels_raw = std::env::var(ENV_CHANNELS)?;
        let ka_interval_ms: u64 = std::env::var(ENV_KA_INTERVAL)?.parse()?;
        let ka_timeout_ms: u64 = std::env::var(ENV_KA_TIMEOUT)?.parse()?;

        // Everything was read successfully: clear the hand-off variables.
        // NOTE(review): presumably so that processes spawned later do not try
        // to adopt the same socket; confirm against `crate::hot_reload`.
        for var in &[
            ENV_FD,
            ENV_NICK,
            ENV_SERVER,
            ENV_CHANNELS,
            ENV_KA_INTERVAL,
            ENV_KA_TIMEOUT,
        ] {
            std::env::remove_var(var);
        }

        // Comma-separated list. Empty segments (an empty variable, a trailing
        // comma, or a doubled comma) are skipped rather than turned into
        // empty channel names.
        let channels: Vec<String> = channels_raw
            .split(',')
            .filter(|name| !name.is_empty())
            .map(str::to_string)
            .collect();

        // SAFETY: `from_raw_fd` requires an open descriptor that nothing else
        // in this process owns. Only non-negativity can be checked here; that
        // the descriptor is an open TCP socket handed over exclusively to
        // this process is the contract of whoever set the hand-off variables.
        let std_stream = unsafe { std::net::TcpStream::from_raw_fd(raw_fd) };
        // tokio requires the socket to be in non-blocking mode before adoption.
        std_stream.set_nonblocking(true)?;
        let stream = TcpStream::from_std(std_stream)?;
        let (read_half, write_half) = stream.into_split();
        let reader = tokio::io::BufReader::new(read_half);

        Ok(Some(State {
            nick,
            channels,
            server,
            keepalive_interval: Duration::from_millis(ka_interval_ms),
            keepalive_timeout: Duration::from_millis(ka_timeout_ms),
            flood_burst: DEFAULT_FLOOD_BURST,
            flood_rate: DEFAULT_FLOOD_RATE,
            reader,
            write_half,
            raw_fd,
        }))
    }

    /// Replaces the keepalive interval and timeout, returning the updated
    /// `State`.
    pub fn with_keepalive(mut self, interval: Duration, timeout: Duration) -> Self {
        self.keepalive_interval = interval;
        self.keepalive_timeout = timeout;
        self
    }

    /// Replaces the flood-control burst and rate values, returning the
    /// updated `State`.
    pub fn with_flood_control(mut self, burst: usize, rate: Duration) -> Self {
        self.flood_burst = burst;
        self.flood_rate = rate;
        self
    }

    /// Returns the configured keepalive interval.
    pub fn keepalive_interval(&self) -> Duration {
        self.keepalive_interval
    }

    /// Returns the configured keepalive timeout.
    pub fn keepalive_timeout(&self) -> Duration {
        self.keepalive_timeout
    }
}