use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, OnceLock, RwLock};
use std::time::Duration;
use leaky_bucket::RateLimiter;
use regex::Regex;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufWriter};
use tokio::sync::mpsc;
use crate::{
connection::State,
context::{Context, User},
handler::{HandlerEntry, Trigger},
irc::{CtcpMessage, Message},
types::{Nick, Target},
BoxError,
};
use irc_proto::{prefix::Prefix, Command, Response};
/// Character that introduces a bot command in a PRIVMSG (e.g. `!help`).
const CMD_PREFIX: char = '!';
/// Token carried by keepalive `PING`s; a `PONG` echoing it marks the link alive.
const KEEPALIVE_TOKEN: &str = "ircbot-keepalive";
/// Number of fallback nicks tried during registration before giving up.
const MAX_NICK_ATTEMPTS: u32 = 8;
/// Longest the cron supervisor sleeps before re-scanning the handler list.
const CRON_RESCAN_INTERVAL: Duration = Duration::from_secs(60);
/// Shared handler list. Readers take a brief read lock and clone the inner
/// `Arc` (see `snapshot`), so the lock is not held while handlers run.
pub type HandlerSet<T> = Arc<RwLock<Arc<Vec<HandlerEntry<T>>>>>;
/// Drives one connected IRC session until the server closes the stream, a read
/// fails, or the keepalive detects a dead link.
///
/// Three helper tasks are spawned:
/// - a writer that drains an unbounded channel of raw protocol lines to the
///   socket, throttled by a leaky-bucket limiter (capacity `flood_burst`, at
///   least 1; starts with `flood_burst` tokens; one token back per `flood_rate`);
/// - the cron supervisor;
/// - a keepalive that sends `PING` every `keepalive_interval` and signals a
///   failure if the matching `PONG` has not arrived after `keepalive_timeout`.
///
/// The read loop answers server `PING`s, records keepalive `PONG`s, joins the
/// configured channels on the first `RPL_WELCOME`, retries with fallback nicks
/// while unregistered, and dispatches messages to the handlers. Lines that do
/// not parse are ignored. Handlers are awaited inline, so a slow handler delays
/// reading (including keepalive `PONG`s).
///
/// # Errors
/// Returns the error from the reader if a read fails. A clean EOF or a
/// keepalive timeout returns `Ok(())`; reconnecting is presumably the caller's
/// job.
///
/// Before returning, the keepalive and cron tasks are aborted and the writer
/// task is awaited. The writer finishes once every `write_tx` clone is gone.
pub async fn run_bot_internal<T: Send + Sync + 'static>(
    bot: Arc<T>,
    state: State,
    handlers: HandlerSet<T>,
) -> Result<(), BoxError> {
    let State {
        nick,
        channels,
        server: _,
        keepalive_interval,
        keepalive_timeout,
        flood_burst,
        flood_rate,
        ctcp_version,
        reader,
        write_half,
        #[cfg(unix)]
        raw_fd: _,
    } = state;
    // All outbound traffic funnels through this channel so the flood limiter
    // sees every line.
    let (write_tx, mut write_rx) = mpsc::unbounded_channel::<String>();
    let write_task = tokio::spawn(async move {
        let mut writer = BufWriter::new(write_half);
        let limiter = RateLimiter::builder()
            .max(flood_burst.max(1))
            .initial(flood_burst)
            .refill(1)
            .interval(flood_rate)
            .build();
        while let Some(msg) = write_rx.recv().await {
            limiter.acquire_one().await;
            // Any write/flush failure ends the writer; subsequent sends on
            // `write_tx` then fail and are logged by their callers.
            if writer.write_all(msg.as_bytes()).await.is_err() {
                break;
            }
            if writer.flush().await.is_err() {
                break;
            }
        }
    });
    // Current nick; updated below while registration cycles through fallbacks.
    let mut bot_nick = nick.clone();
    // NOTE(review): the cron supervisor receives the nick as of startup and
    // does not see later fallback changes.
    let cron_task = tokio::spawn(run_cron_supervisor(
        Arc::clone(&bot),
        Arc::clone(&handlers),
        write_tx.clone(),
        bot_nick.clone(),
    ));
    // Cleared by the keepalive task before each PING, set again by the read
    // loop when a PONG carrying KEEPALIVE_TOKEN arrives.
    let pong_received = Arc::new(AtomicBool::new(true));
    let pong_received_keepalive = Arc::clone(&pong_received);
    let keepalive_write_tx = write_tx.clone();
    // Signalled once on keepalive timeout. If the keepalive task exits without
    // sending, dropping the sender also wakes the receiver.
    let (keepalive_fail_tx, keepalive_fail_rx) = tokio::sync::oneshot::channel::<()>();
    let keepalive_task = tokio::spawn(async move {
        let mut fail_tx = Some(keepalive_fail_tx);
        loop {
            tokio::time::sleep(keepalive_interval).await;
            pong_received_keepalive.store(false, Ordering::Relaxed);
            // The writer side is gone; stop (this drops `fail_tx`).
            if keepalive_write_tx
                .send(format!("PING {KEEPALIVE_TOKEN}\r\n"))
                .is_err()
            {
                break;
            }
            tokio::time::sleep(keepalive_timeout).await;
            if !pong_received_keepalive.load(Ordering::Relaxed) {
                eprintln!("[ircbot] keepalive timeout — reconnecting");
                if let Some(tx) = fail_tx.take() {
                    let _ = tx.send(());
                }
                break;
            }
        }
    });
    // Becomes true on the first RPL_WELCOME; gates channel joins and nick retries.
    let mut joined = false;
    let mut nick_attempt = 0u32;
    let mut lines = reader.lines();
    let mut keepalive_fail_rx = keepalive_fail_rx;
    // Run the loop inside an async block so a read error propagated with `?`
    // still reaches the teardown below.
    let loop_result: Result<(), BoxError> = async {
        loop {
            tokio::select! {
                result = lines.next_line() => {
                    // `None` means the server closed the stream.
                    let Some(line) = result? else { break; };
                    let line = line.trim_end_matches('\r').to_string();
                    if line.is_empty() {
                        continue;
                    }
                    // Unparseable lines are dropped silently.
                    if let Ok(msg) = line.parse::<Message>() {
                        match &msg.command {
                            // Server PINGs are answered here and not dispatched.
                            Command::PING(srv, _) => {
                                if let Err(e) = write_tx.send(format!("PONG :{srv}\r\n")) {
                                    eprintln!("[ircbot] failed to send PONG: {e}");
                                }
                            }
                            // The token may arrive as either parameter; PONGs are not dispatched.
                            Command::PONG(a, b) => {
                                let token = b.as_deref().unwrap_or(a.as_str());
                                if token == KEEPALIVE_TOKEN {
                                    pong_received.store(true, Ordering::Relaxed);
                                }
                            }
                            Command::Response(Response::RPL_WELCOME, _) => {
                                if !joined {
                                    joined = true;
                                    for ch in &channels {
                                        if let Err(e) = write_tx.send(format!("JOIN {ch}\r\n")) {
                                            eprintln!("[ircbot] failed to send JOIN {ch}: {e}");
                                        }
                                    }
                                }
                                dispatch(&bot, &handlers, &msg, &bot_nick, write_tx.clone()).await;
                            }
                            Command::Response(
                                Response::ERR_NICKNAMEINUSE | Response::ERR_UNAVAILRESOURCE,
                                _,
                            ) => {
                                // Retry only while unregistered; afterwards these are just dispatched.
                                if !joined {
                                    nick_attempt += 1;
                                    if nick_attempt <= MAX_NICK_ATTEMPTS {
                                        // Fallbacks derive from the configured nick, not the last attempt.
                                        let candidate = fallback_nick(nick.as_str(), nick_attempt);
                                        eprintln!(
                                            "[ircbot] nick {bot_nick:?} unavailable — retrying as {candidate:?}"
                                        );
                                        if let Err(e) =
                                            write_tx.send(format!("NICK {candidate}\r\n"))
                                        {
                                            eprintln!(
                                                "[ircbot] failed to send NICK {candidate}: {e}"
                                            );
                                        }
                                        bot_nick = Nick::from(candidate);
                                    } else {
                                        // NOTE(review): this only logs; the connection stays
                                        // unregistered until the server drops it.
                                        eprintln!(
                                            "[ircbot] giving up on registration after \
                                             {MAX_NICK_ATTEMPTS} nick attempts"
                                        );
                                    }
                                }
                                dispatch(&bot, &handlers, &msg, &bot_nick, write_tx.clone()).await;
                            }
                            // CTCP PING/VERSION are answered by handle_privmsg; the rest are dispatched.
                            Command::PRIVMSG(_, _) => {
                                handle_privmsg(
                                    &bot,
                                    &handlers,
                                    &msg,
                                    &bot_nick,
                                    ctcp_version.as_deref(),
                                    write_tx.clone(),
                                )
                                .await;
                            }
                            _ => {
                                dispatch(&bot, &handlers, &msg, &bot_nick, write_tx.clone()).await;
                            }
                        }
                    }
                }
                // Keepalive timed out, or its task ended and dropped the sender.
                _ = &mut keepalive_fail_rx => {
                    break;
                }
            }
        }
        Ok(())
    }
    .await;
    // Teardown: stop background tasks, drop our sender, and wait for the
    // writer to finish once all remaining senders are gone.
    keepalive_task.abort();
    cron_task.abort();
    drop(write_tx);
    let _ = write_task.await;
    loop_result
}
/// Runs every `Trigger::Cron` handler when its schedule comes due; never returns.
///
/// Scheduling is tracked with a `cursor` timestamp. Each round:
/// 1. Find the earliest slot `next_cron_fire(.., &cursor)` across all handlers.
/// 2. Sleep until then, but never longer than `CRON_RESCAN_INTERVAL`, so newly
///    registered handlers are picked up.
/// 3. Take `due_by = now`. Run each cron handler that has a slot in
///    `(cursor, due_by]`, at most once per round.
/// 4. Set `cursor = due_by`.
///
/// The cursor never jumps over time spent inside handlers. So a slot that
/// passes while another cron handler is still running fires late rather than
/// being skipped.
///
/// Handlers are awaited sequentially on this task. Their errors are logged and
/// do not stop the supervisor. Each handler gets a synthetic `raw` message, no
/// sender, and its configured target (empty string if unset).
///
/// NOTE(review): `bot_nick` is the nick at spawn time; later nick changes are
/// not visible here.
async fn run_cron_supervisor<T: Send + Sync + 'static>(
    bot: Arc<T>,
    handlers: HandlerSet<T>,
    tx: mpsc::UnboundedSender<String>,
    bot_nick: Nick,
) {
    let mut cursor = chrono::Utc::now();
    loop {
        let fire_at = snapshot(&handlers)
            .iter()
            .filter_map(|e| next_cron_fire(&e.trigger, &cursor))
            .min();
        let wait = fire_at.map_or(CRON_RESCAN_INTERVAL, |at| {
            // A slot already in the past yields a negative span -> fire immediately.
            (at - chrono::Utc::now())
                .to_std()
                .unwrap_or(Duration::ZERO)
                .min(CRON_RESCAN_INTERVAL)
        });
        tokio::time::sleep(wait).await;

        let due_by = chrono::Utc::now();
        if !fire_at.is_some_and(|at| at <= due_by) {
            // Either there are no cron slots, or the earliest known slot is
            // still ahead (rescan wake-up). Nothing known falls in
            // (cursor, due_by], so advancing loses nothing. It also stops
            // newly added handlers from back-filling slots that passed before
            // they existed.
            cursor = due_by;
            continue;
        }

        let live = snapshot(&handlers);
        for entry in live.iter() {
            let Trigger::Cron { target, .. } = &entry.trigger else {
                continue;
            };
            // Only handlers with a slot in (cursor, due_by] fire this round.
            if !next_cron_fire(&entry.trigger, &cursor).is_some_and(|next| next <= due_by) {
                continue;
            }
            let cron_target = target.clone().unwrap_or_default();
            let ctx = Context {
                tx: tx.clone(),
                target: Target::from_raw(&cron_target),
                sender: None,
                raw: synthesize_cron_message(bot_nick.as_str()),
                bot_nick: bot_nick.clone(),
                captures: vec![],
            };
            if let Err(e) = (entry.handler)(Arc::clone(&bot), ctx).await {
                eprintln!("[ircbot] cron handler error: {e}");
            }
        }
        // Advance only to the pre-dispatch time. Slots that came due while the
        // handlers ran are picked up (late) next round instead of being lost.
        cursor = due_by;
    }
}
/// Returns the current handler list by cloning its `Arc` under a short read
/// lock. A poisoned lock is recovered and its last stored value used, rather
/// than propagating the poison.
fn snapshot<T>(handlers: &HandlerSet<T>) -> Arc<Vec<HandlerEntry<T>>> {
    let guard = match handlers.read() {
        Ok(live) => live,
        Err(poisoned) => poisoned.into_inner(),
    };
    Arc::clone(&*guard)
}
/// Next scheduled time after `after` for a cron trigger. The schedule is
/// evaluated in the trigger's timezone and the result is converted back to UTC.
///
/// Returns `None` in any of these cases:
/// - the trigger is not `Trigger::Cron`;
/// - its schedule or timezone string fails to parse;
/// - the schedule has no upcoming time.
///
/// Both strings are re-parsed on every call.
fn next_cron_fire(
    trigger: &Trigger,
    after: &chrono::DateTime<chrono::Utc>,
) -> Option<chrono::DateTime<chrono::Utc>> {
    match trigger {
        Trigger::Cron { schedule, tz, .. } => {
            let parsed = schedule.parse::<cron::Schedule>().ok()?;
            let zone = tz.parse::<chrono_tz::Tz>().ok()?;
            let local_after = after.with_timezone(&zone);
            let mut upcoming = parsed.after(&local_after);
            upcoming
                .next()
                .map(|local| local.with_timezone(&chrono::Utc))
        }
        _ => None,
    }
}
/// Builds the placeholder `raw` message handed to cron handlers: a `PING :cron`
/// that appears to come from `<bot_nick>!cron@cron`. If that fails to parse, a
/// `PRIVMSG #cron :cron` with the same prefix is used instead.
///
/// # Panics
/// Panics if neither form parses.
fn synthesize_cron_message(bot_nick: &str) -> Message {
    let preferred = format!(":{bot_nick}!cron@cron PING :cron");
    if let Ok(msg) = preferred.parse::<Message>() {
        return msg;
    }
    let fallback = format!(":{bot_nick}!cron@cron PRIVMSG #cron :cron");
    fallback.parse().unwrap()
}
#[must_use]
pub fn check_trigger(trigger: &Trigger, msg: &Message, bot_nick: &str) -> Option<Vec<String>> {
match trigger {
Trigger::Command { name, target } => {
let Command::PRIVMSG(msg_target, text) = &msg.command else {
return None;
};
if let Some(t) = target {
if msg_target.as_str() != t.as_str() {
return None;
}
}
let text = text.strip_prefix(CMD_PREFIX)?;
let (cmd, rest) = text
.split_once(' ')
.map_or((text, ""), |(c, r)| (c, r.trim()));
if !cmd.eq_ignore_ascii_case(name) {
return None;
}
Some(if rest.is_empty() {
vec![]
} else {
vec![rest.to_string()]
})
}
Trigger::Message { pattern, target } => {
let Command::PRIVMSG(msg_target, text) = &msg.command else {
return None;
};
if let Some(t) = target {
if msg_target.as_str() != t.as_str() {
return None;
}
}
glob_match(pattern, text)
}
Trigger::Event {
event,
target,
regex,
} => {
if !command_name(msg).eq_ignore_ascii_case(event) {
return None;
}
if let Some(t) = target {
if target_param(msg) != Some(t.as_str()) {
return None;
}
}
if let Some(re_str) = regex {
let text = trailing_param(msg).unwrap_or("");
let re = cached_regex(re_str)?;
let caps = re.captures(text)?;
let groups: Vec<String> = caps
.iter()
.skip(1)
.filter_map(|m| m.map(|m| m.as_str().to_string()))
.collect();
Some(groups)
} else {
Some(vec![])
}
}
Trigger::Cron { .. } => None,
Trigger::Mention { target } => {
let Command::PRIVMSG(msg_target, text) = &msg.command else {
return None;
};
if let Some(t) = target {
if msg_target.as_str() != t.as_str() {
return None;
}
}
let lower = text.to_ascii_lowercase();
let nick_lower = bot_nick.to_ascii_lowercase();
let rest = [": ", ", "].iter().find_map(|sep| {
let prefix = format!("{}{}", nick_lower, sep);
if lower.starts_with(prefix.as_str()) {
Some(text[prefix.len()..].trim().to_string())
} else {
None
}
})?;
Some(if rest.is_empty() { vec![] } else { vec![rest] })
}
}
}
/// Alternative nick for registration attempt `attempt` when `base` is taken.
///
/// Attempts 1–3 append that many underscores (`bot_`, `bot__`, `bot___`).
/// Later attempts append the attempt number (`bot4`, `bot5`, …). Attempt 0
/// yields `base` unchanged.
#[must_use]
fn fallback_nick(base: &str, attempt: u32) -> String {
    match attempt {
        0..=3 => {
            let underscores = "_".repeat(attempt as usize);
            format!("{base}{underscores}")
        }
        n => format!("{base}{n}"),
    }
}
/// Name of `msg`'s command, used to match `Trigger::Event`.
///
/// A `Command::Raw` returns its name exactly as received (borrowed). Any other
/// command is serialized, and its first space-separated word is returned
/// upper-cased. For numeric replies this is presumably the three-digit code.
fn command_name(msg: &Message) -> std::borrow::Cow<'_, str> {
    use std::borrow::Cow;
    if let Command::Raw(name, _) = &msg.command {
        return Cow::Borrowed(name.as_str());
    }
    let serialized = String::from(&msg.command);
    let first_word = serialized.split(' ').next().unwrap_or_default();
    Cow::Owned(first_word.to_ascii_uppercase())
}
/// The parameter `Trigger::Event` regexes are matched against.
///
/// - Message text for PRIVMSG/NOTICE.
/// - The reason/topic/token when present, falling back to the channel or
///   server for PART/TOPIC/PONG.
/// - The first parameter for PING and JOIN.
/// - The quit message or kick reason when present.
/// - The last argument for numeric replies and raw commands.
///
/// Anything else yields `None`.
fn trailing_param(msg: &Message) -> Option<&str> {
    match &msg.command {
        Command::PRIVMSG(_, param)
        | Command::NOTICE(_, param)
        | Command::PING(param, _)
        | Command::JOIN(param, _, _)
        | Command::PONG(_, Some(param))
        | Command::PONG(param, None)
        | Command::PART(_, Some(param))
        | Command::PART(param, None)
        | Command::QUIT(Some(param))
        | Command::KICK(_, _, Some(param))
        | Command::TOPIC(_, Some(param))
        | Command::TOPIC(param, None) => Some(param.as_str()),
        Command::Response(_, args) | Command::Raw(_, args) => args.last().map(String::as_str),
        _ => None,
    }
}
/// The channel or nick a message is addressed to, used for `Context::target`
/// and for target filters on `Trigger::Event`.
///
/// - The first parameter for PRIVMSG/NOTICE/JOIN/PART/KICK/TOPIC and both MODE forms.
/// - The channel (second parameter) for INVITE.
/// - The first argument for numeric replies and raw commands.
///
/// Anything else yields `None`.
fn target_param(msg: &Message) -> Option<&str> {
    match &msg.command {
        Command::PRIVMSG(name, _)
        | Command::NOTICE(name, _)
        | Command::JOIN(name, _, _)
        | Command::PART(name, _)
        | Command::KICK(name, _, _)
        | Command::TOPIC(name, _)
        | Command::INVITE(_, name)
        | Command::ChannelMODE(name, _)
        | Command::UserMODE(name, _) => Some(name.as_str()),
        Command::Response(_, args) | Command::Raw(_, args) => args.first().map(String::as_str),
        _ => None,
    }
}
/// Compiles `pattern` once and memoizes the result process-wide.
///
/// Failed compilations are cached too (as `None`). Otherwise a bad
/// user-supplied pattern would be recompiled for every incoming message.
/// A poisoned lock is recovered instead of bypassing the cache: the map holds
/// only finished `insert`s of immutable values.
///
/// Returns `None` if `pattern` is not a valid regex.
fn cached_regex(pattern: &str) -> Option<Arc<Regex>> {
    static CACHE: OnceLock<RwLock<HashMap<String, Option<Arc<Regex>>>>> = OnceLock::new();
    let cache = CACHE.get_or_init(|| RwLock::new(HashMap::new()));

    // The read guard is a temporary of this statement, so it is released
    // before the write lock below is taken.
    let hit = cache
        .read()
        .unwrap_or_else(|e| e.into_inner())
        .get(pattern)
        .cloned();
    if let Some(cached) = hit {
        return cached;
    }

    let compiled = Regex::new(pattern).ok().map(Arc::new);
    // If another thread raced us, keep and return whichever entry landed first.
    cache
        .write()
        .unwrap_or_else(|e| e.into_inner())
        .entry(pattern.to_owned())
        .or_insert(compiled)
        .clone()
}
/// Matches `text` against a glob, case-insensitively and anchored at both ends.
/// The glob is translated by `glob_to_regex`.
///
/// Returns the text matched by each `*`, in order. Returns `None` when the text
/// does not match or the generated regex fails to compile.
#[must_use]
pub fn glob_match(pattern: &str, text: &str) -> Option<Vec<String>> {
    let compiled = cached_regex(&glob_to_regex(pattern))?;
    compiled.captures(text).map(|caps| {
        caps.iter()
            .skip(1)
            .flatten()
            .map(|m| m.as_str().to_string())
            .collect()
    })
}
/// Translates a glob into an anchored, case-insensitive regex source string.
///
/// - `*` becomes a capturing `(.*)`.
/// - `?` becomes `.` (any single character).
/// - Regex metacharacters are backslash-escaped.
/// - Every other character is copied verbatim.
fn glob_to_regex(pattern: &str) -> String {
    const META: &str = ".$+^{}[]|\\()";
    let mut regex = pattern.chars().fold(String::from("^(?i)"), |mut acc, ch| {
        match ch {
            '*' => acc.push_str("(.*)"),
            '?' => acc.push('.'),
            other => {
                if META.contains(other) {
                    acc.push('\\');
                }
                acc.push(other);
            }
        }
        acc
    });
    regex.push('$');
    regex
}
/// Handles an incoming PRIVMSG, answering CTCP `PING` and `VERSION` requests
/// itself.
///
/// * CTCP `PING` — NOTICE back to the sender echoing the request argument.
/// * CTCP `VERSION` — NOTICE with `ctcp_version`, or `ircbot <crate version>`
///   when none is configured.
///
/// These two requests never reach the handlers. That holds even when the
/// sender nick is unknown and no reply is sent. Every other message, including
/// other CTCP commands and non-PRIVMSGs, goes to `dispatch`.
async fn handle_privmsg<T: Send + Sync + 'static>(
    bot: &Arc<T>,
    handlers: &HandlerSet<T>,
    msg: &Message,
    bot_nick: &Nick,
    ctcp_version: Option<&str>,
    tx: tokio::sync::mpsc::UnboundedSender<String>,
) {
    let request = match &msg.command {
        Command::PRIVMSG(_, text) => CtcpMessage::parse(text),
        _ => None,
    };
    if let Some(request) = request {
        let kind = request.command.as_str();
        if matches!(kind, "PING" | "VERSION") {
            if let Some(sender) = msg.source_nickname() {
                let reply = if kind == "PING" {
                    let sep = if request.arg.is_empty() { "" } else { " " };
                    format!("NOTICE {sender} :\x01PING{sep}{}\x01\r\n", request.arg)
                } else {
                    let version = ctcp_version.map_or_else(
                        || format!("ircbot {}", env!("CARGO_PKG_VERSION")),
                        ToString::to_string,
                    );
                    format!("NOTICE {sender} :\x01VERSION {version}\x01\r\n")
                };
                if let Err(e) = tx.send(reply) {
                    eprintln!("[ircbot] failed to send CTCP {kind} reply: {e}");
                }
            }
            return;
        }
    }
    dispatch(bot, handlers, msg, bot_nick, tx).await;
}
async fn dispatch<T: Send + Sync + 'static>(
bot: &Arc<T>,
handlers: &HandlerSet<T>,
msg: &Message,
bot_nick: &Nick,
tx: tokio::sync::mpsc::UnboundedSender<String>,
) {
let current: Arc<Vec<HandlerEntry<T>>> = {
let guard = handlers.read().unwrap_or_else(|e| e.into_inner());
Arc::clone(&*guard)
};
let sender = match msg.prefix.as_ref() {
Some(Prefix::Nickname(nick, user, host)) if !user.is_empty() => Some(User {
nick: Nick::from(nick.clone()),
user: user.clone(),
host: host.clone(),
}),
_ => None,
};
let target = Target::from_raw(target_param(msg).unwrap_or(""));
for entry in current.iter() {
if let Some(captures) = check_trigger(&entry.trigger, msg, bot_nick.as_str()) {
let ctx = Context {
tx: tx.clone(),
target: target.clone(),
sender: sender.clone(),
raw: msg.clone(),
bot_nick: bot_nick.clone(),
captures,
};
let bot_clone = Arc::clone(bot);
let fut = (entry.handler)(bot_clone, ctx);
if let Err(e) = fut.await {
eprintln!("[ircbot] handler error: {e}");
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // Attempts 1-3 append one underscore per attempt.
    #[test]
    fn fallback_nick_appends_underscores_for_early_attempts() {
        assert_eq!(fallback_nick("bot", 1), "bot_");
        assert_eq!(fallback_nick("bot", 2), "bot__");
        assert_eq!(fallback_nick("bot", 3), "bot___");
    }
    // From attempt 4 onward the attempt number is appended instead.
    #[test]
    fn fallback_nick_appends_number_for_later_attempts() {
        assert_eq!(fallback_nick("bot", 4), "bot4");
        assert_eq!(fallback_nick("bot", 8), "bot8");
    }
    // Registration must never propose the same nick twice within the retry budget.
    #[test]
    fn fallback_nick_candidates_are_distinct_across_all_attempts() {
        let mut seen = std::collections::HashSet::new();
        for attempt in 1..=MAX_NICK_ATTEMPTS {
            assert!(
                seen.insert(fallback_nick("bot", attempt)),
                "duplicate fallback nick at attempt {attempt}"
            );
        }
    }
    /// Sends a CTCP VERSION request from `alice` through `handle_privmsg` and
    /// returns the first line queued for the server. The handler set is built
    /// from an empty vec, presumably leaving no handlers. Panics if nothing was
    /// queued.
    async fn ctcp_version_reply(custom: Option<&str>) -> String {
        let bot = std::sync::Arc::new(());
        let handlers = crate::internal::make_handler_set::<()>(vec![]);
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let msg = ":alice!u@h PRIVMSG mybot :\x01VERSION\x01"
            .parse::<Message>()
            .unwrap();
        handle_privmsg(&bot, &handlers, &msg, &Nick::from("mybot"), custom, tx).await;
        rx.try_recv().expect("a CTCP VERSION reply was sent")
    }
    // A configured CTCP version string is sent verbatim.
    #[tokio::test]
    async fn ctcp_version_uses_custom_string_when_set() {
        assert_eq!(
            ctcp_version_reply(Some("rustbutler 1.2.3")).await,
            "NOTICE alice :\x01VERSION rustbutler 1.2.3\x01\r\n",
        );
    }
    // Without one, the reply falls back to "ircbot <crate version>".
    #[tokio::test]
    async fn ctcp_version_defaults_to_ircbot_crate_version() {
        let reply = ctcp_version_reply(None).await;
        assert_eq!(
            reply,
            format!(
                "NOTICE alice :\x01VERSION ircbot {}\x01\r\n",
                env!("CARGO_PKG_VERSION")
            ),
        );
    }
}