use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, OnceLock, RwLock};
use std::time::Duration;
use regex::Regex;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufWriter};
use tokio::sync::mpsc;
use crate::{
connection::State,
context::{Context, User},
handler::{HandlerEntry, Trigger},
irc::{ChannelExt, CtcpMessage, Message},
BoxError,
};
use irc_proto::{prefix::Prefix, Command, Response};
/// Character that marks a `PRIVMSG` as a bot command (`Trigger::Command`).
const CMD_PREFIX: char = '!';
/// Token carried by the bot's own keepalive `PING`s; a `PONG` echoing it
/// proves the connection is still alive.
const KEEPALIVE_TOKEN: &str = "ircbot-keepalive";
/// Shared handler list. Readers clone the inner `Arc` under the read lock and
/// release it immediately, so the lock is never held while handlers run.
/// NOTE(review): presumably writers elsewhere swap in a whole new
/// `Arc<Vec<_>>` rather than mutating in place — confirm.
pub type HandlerSet<T> = Arc<RwLock<Arc<Vec<HandlerEntry<T>>>>>;
pub async fn run_bot_internal<T: Send + Sync + 'static>(
bot: Arc<T>,
state: State,
handlers: HandlerSet<T>,
) -> Result<(), BoxError> {
let State {
nick,
channels,
server: _,
keepalive_interval,
keepalive_timeout,
flood_burst,
flood_rate,
reader,
write_half,
#[cfg(unix)]
raw_fd: _,
} = state;
let (write_tx, mut write_rx) = mpsc::unbounded_channel::<String>();
let write_task = tokio::spawn(async move {
let mut writer = BufWriter::new(write_half);
let max_tokens = flood_burst as f64;
let mut tokens = max_tokens;
let token_rate = 1.0 / flood_rate.as_secs_f64(); let mut last_refill = tokio::time::Instant::now();
while let Some(msg) = write_rx.recv().await {
let now = tokio::time::Instant::now();
let elapsed = (now - last_refill).as_secs_f64();
tokens = (tokens + elapsed * token_rate).min(max_tokens);
last_refill = now;
if tokens < 1.0 {
let wait = Duration::from_secs_f64((1.0 - tokens) / token_rate);
tokio::time::sleep(wait).await;
tokens = 0.0;
last_refill = tokio::time::Instant::now();
} else {
tokens -= 1.0;
}
if writer.write_all(msg.as_bytes()).await.is_err() {
break;
}
if writer.flush().await.is_err() {
break;
}
}
});
let bot_nick = nick.clone();
let cron_snapshot: Arc<Vec<HandlerEntry<T>>> = {
let guard = handlers.read().unwrap_or_else(|e| e.into_inner());
Arc::clone(&*guard)
};
let mut cron_tasks: Vec<tokio::task::JoinHandle<()>> = Vec::new();
for idx in 0..cron_snapshot.len() {
let (schedule_str, tz_str, cron_target) = match &cron_snapshot[idx].trigger {
Trigger::Cron {
schedule,
tz,
target,
} => (
schedule.clone(),
tz.clone(),
target.clone().unwrap_or_default(),
),
_ => continue,
};
let cron_is_channel = cron_target.is_channel_name();
let bot_cron = Arc::clone(&bot);
let write_tx_cron = write_tx.clone();
let bot_nick_cron = bot_nick.clone();
let snapshot_cron = Arc::clone(&cron_snapshot);
let task = tokio::spawn(async move {
let schedule: cron::Schedule = match schedule_str.parse() {
Ok(s) => s,
Err(e) => {
eprintln!("[ircbot] invalid cron expression {schedule_str:?}: {e}");
return;
}
};
let tz: chrono_tz::Tz = match tz_str.parse() {
Ok(tz) => tz,
Err(e) => {
eprintln!("[ircbot] invalid timezone {tz_str:?}: {e}");
return;
}
};
loop {
let now = chrono::Utc::now().with_timezone(&tz);
let Some(next) = schedule.upcoming(tz).next() else {
eprintln!(
"[ircbot] cron schedule {schedule_str:?} has no upcoming occurrences"
);
return;
};
let delay = (next - now).to_std().unwrap_or(Duration::ZERO);
tokio::time::sleep(delay).await;
let raw = format!(":{nick}!cron@cron PING :cron", nick = bot_nick_cron)
.parse::<Message>()
.unwrap_or_else(|_| {
format!(
":{nick}!cron@cron PRIVMSG #cron :cron",
nick = bot_nick_cron
)
.parse()
.unwrap()
});
let ctx = Context {
tx: write_tx_cron.clone(),
target: cron_target.clone(),
is_channel: cron_is_channel,
sender: None,
raw,
bot_nick: bot_nick_cron.clone(),
captures: vec![],
};
let fut = (snapshot_cron[idx].handler)(Arc::clone(&bot_cron), ctx);
if let Err(e) = fut.await {
eprintln!("[ircbot] cron handler error: {e}");
}
}
});
cron_tasks.push(task);
}
let pong_received = Arc::new(AtomicBool::new(true));
let pong_received_keepalive = Arc::clone(&pong_received);
let keepalive_write_tx = write_tx.clone();
let (keepalive_fail_tx, keepalive_fail_rx) = tokio::sync::oneshot::channel::<()>();
let keepalive_task = tokio::spawn(async move {
let mut fail_tx = Some(keepalive_fail_tx);
loop {
tokio::time::sleep(keepalive_interval).await;
pong_received_keepalive.store(false, Ordering::Relaxed);
if keepalive_write_tx
.send(format!("PING {KEEPALIVE_TOKEN}\r\n"))
.is_err()
{
break;
}
tokio::time::sleep(keepalive_timeout).await;
if !pong_received_keepalive.load(Ordering::Relaxed) {
eprintln!("[ircbot] keepalive timeout — reconnecting");
if let Some(tx) = fail_tx.take() {
let _ = tx.send(());
}
break;
}
}
});
let mut joined = false;
let mut lines = reader.lines();
let mut keepalive_fail_rx = keepalive_fail_rx;
let loop_result: Result<(), BoxError> = async {
loop {
tokio::select! {
result = lines.next_line() => {
let Some(line) = result? else { break; };
let line = line.trim_end_matches('\r').to_string();
if line.is_empty() {
continue;
}
if let Ok(msg) = line.parse::<Message>() {
match &msg.command {
Command::PING(srv, _) => {
if let Err(e) = write_tx.send(format!("PONG :{srv}\r\n")) {
eprintln!("[ircbot] failed to send PONG: {e}");
}
}
Command::PONG(a, b) => {
let token = b.as_deref().unwrap_or(a.as_str());
if token == KEEPALIVE_TOKEN {
pong_received.store(true, Ordering::Relaxed);
}
}
Command::Response(Response::RPL_WELCOME, _) => {
if !joined {
joined = true;
for ch in &channels {
if let Err(e) = write_tx.send(format!("JOIN {ch}\r\n")) {
eprintln!("[ircbot] failed to send JOIN {ch}: {e}");
}
}
}
dispatch(&bot, &handlers, &msg, &bot_nick, write_tx.clone()).await;
}
Command::PRIVMSG(_, _) => {
handle_privmsg(
&bot,
&handlers,
&msg,
&bot_nick,
write_tx.clone(),
)
.await;
}
_ => {
dispatch(&bot, &handlers, &msg, &bot_nick, write_tx.clone()).await;
}
}
}
}
_ = &mut keepalive_fail_rx => {
break;
}
}
}
Ok(())
}
.await;
keepalive_task.abort();
for task in &cron_tasks {
task.abort();
}
drop(write_tx);
let _ = write_task.await;
loop_result
}
#[must_use]
pub fn check_trigger(trigger: &Trigger, msg: &Message, bot_nick: &str) -> Option<Vec<String>> {
match trigger {
Trigger::Command { name, target } => {
let Command::PRIVMSG(msg_target, text) = &msg.command else {
return None;
};
if let Some(t) = target {
if msg_target.as_str() != t.as_str() {
return None;
}
}
let text = text.strip_prefix(CMD_PREFIX)?;
let (cmd, rest) = text
.split_once(' ')
.map_or((text, ""), |(c, r)| (c, r.trim()));
if !cmd.eq_ignore_ascii_case(name) {
return None;
}
Some(if rest.is_empty() {
vec![]
} else {
vec![rest.to_string()]
})
}
Trigger::Message { pattern, target } => {
let Command::PRIVMSG(msg_target, text) = &msg.command else {
return None;
};
if let Some(t) = target {
if msg_target.as_str() != t.as_str() {
return None;
}
}
glob_match(pattern, text)
}
Trigger::Event {
event,
target,
regex,
} => {
if !command_name(msg).eq_ignore_ascii_case(event) {
return None;
}
if let Some(t) = target {
if target_param(msg) != Some(t.as_str()) {
return None;
}
}
if let Some(re_str) = regex {
let text = trailing_param(msg).unwrap_or("");
let re = cached_regex(re_str)?;
let caps = re.captures(text)?;
let groups: Vec<String> = caps
.iter()
.skip(1)
.filter_map(|m| m.map(|m| m.as_str().to_string()))
.collect();
Some(groups)
} else {
Some(vec![])
}
}
Trigger::Cron { .. } => None,
Trigger::Mention { target } => {
let Command::PRIVMSG(msg_target, text) = &msg.command else {
return None;
};
if let Some(t) = target {
if msg_target.as_str() != t.as_str() {
return None;
}
}
let lower = text.to_ascii_lowercase();
let nick_lower = bot_nick.to_ascii_lowercase();
let rest = [": ", ", "].iter().find_map(|sep| {
let prefix = format!("{}{}", nick_lower, sep);
if lower.starts_with(prefix.as_str()) {
Some(text[prefix.len()..].trim().to_string())
} else {
None
}
})?;
Some(if rest.is_empty() { vec![] } else { vec![rest] })
}
}
}
/// Returns the command verb of `msg`, used for `Trigger::Event` matching.
///
/// `Raw` commands yield their stored name unchanged (borrowed). Every other
/// command is serialized with `String::from`, and everything before the first
/// space is returned upper-cased.
fn command_name(msg: &Message) -> std::borrow::Cow<'_, str> {
    use std::borrow::Cow;
    if let Command::Raw(name, _) = &msg.command {
        return Cow::Borrowed(name.as_str());
    }
    let wire = String::from(&msg.command);
    let verb = wire.split_once(' ').map_or(wire.as_str(), |(head, _)| head);
    Cow::Owned(verb.to_ascii_uppercase())
}
/// Returns the text an `Event` trigger's regex is matched against.
///
/// This is the free-text parameter when the command has one (message body,
/// PONG token, PART/KICK reason, QUIT message, topic). Otherwise it is the
/// main argument: server for PING/PONG, channel for JOIN and for PART/TOPIC
/// without text. For numeric replies and raw commands it is the last
/// argument. Returns `None` for QUIT/KICK without text and for all other
/// commands.
fn trailing_param(msg: &Message) -> Option<&str> {
    match &msg.command {
        Command::PRIVMSG(_, body) | Command::NOTICE(_, body) => Some(body.as_str()),
        Command::PING(server, _) => Some(server.as_str()),
        Command::PONG(server, token) => Some(token.as_deref().unwrap_or(server.as_str())),
        Command::JOIN(channel, _, _) => Some(channel.as_str()),
        Command::PART(channel, reason) => Some(reason.as_deref().unwrap_or(channel.as_str())),
        Command::QUIT(message) => message.as_deref(),
        Command::KICK(_, _, reason) => reason.as_deref(),
        Command::TOPIC(channel, topic) => Some(topic.as_deref().unwrap_or(channel.as_str())),
        Command::Response(_, args) | Command::Raw(_, args) => args.last().map(String::as_str),
        _ => None,
    }
}
/// Returns the channel or nick that `msg` is addressed to or concerns.
///
/// For numeric replies and raw commands this is the first argument. Returns
/// `None` for commands without such a parameter, or when the argument list is
/// empty.
fn target_param(msg: &Message) -> Option<&str> {
    let target = match &msg.command {
        Command::PRIVMSG(target, _)
        | Command::NOTICE(target, _)
        | Command::JOIN(target, _, _)
        | Command::PART(target, _)
        | Command::KICK(target, _, _)
        | Command::TOPIC(target, _)
        | Command::INVITE(_, target)
        | Command::ChannelMODE(target, _)
        | Command::UserMODE(target, _) => target,
        Command::Response(_, args) | Command::Raw(_, args) => args.first()?,
        _ => return None,
    };
    Some(target.as_str())
}
/// Compiles `pattern` into a [`Regex`], memoizing the outcome process-wide.
///
/// Returns `None` when `pattern` is not a valid regex. Failed compilations
/// are cached too, so an invalid trigger pattern is compiled once rather than
/// on every message it is checked against. A poisoned lock is recovered, the
/// same way the handler-list lock is treated elsewhere in this module.
///
/// NOTE(review): the cache is never evicted. That is fine for patterns coming
/// from registered triggers, but callers of the public `glob_match` passing
/// unbounded, user-supplied patterns would grow it without limit.
fn cached_regex(pattern: &str) -> Option<Arc<Regex>> {
    static CACHE: OnceLock<RwLock<HashMap<String, Option<Arc<Regex>>>>> = OnceLock::new();
    let cache = CACHE.get_or_init(|| RwLock::new(HashMap::new()));
    {
        let guard = cache.read().unwrap_or_else(|e| e.into_inner());
        if let Some(cached) = guard.get(pattern) {
            return cached.clone();
        }
    }
    // Compile without holding any lock so other readers are not blocked.
    let compiled = Regex::new(pattern).ok().map(Arc::new);
    let mut guard = cache.write().unwrap_or_else(|e| e.into_inner());
    // Another thread may have inserted this pattern meanwhile; keep and
    // return whichever entry got there first.
    let winner = guard.entry(pattern.to_string()).or_insert(compiled);
    winner.clone()
}
/// Matches `text` against the shell-style glob `pattern`, ignoring case.
///
/// `*` matches any run of characters and `?` exactly one character (newlines
/// excluded, as with regex `.`). Every other character is literal, and the
/// pattern must cover the whole of `text`. On a match, returns the substring
/// captured by each `*`, in order (possibly empty). Otherwise returns `None`.
#[must_use]
pub fn glob_match(pattern: &str, text: &str) -> Option<Vec<String>> {
    let matcher = cached_regex(&glob_to_regex(pattern))?;
    let captures = matcher.captures(text)?;
    Some(
        captures
            .iter()
            .skip(1)
            .flatten()
            .map(|group| group.as_str().to_owned())
            .collect(),
    )
}
/// Translates a glob into an anchored, case-insensitive regex source string.
///
/// `*` becomes a capturing `(.*)`, `?` becomes `.`, regex metacharacters are
/// backslash-escaped, and everything else is copied through unchanged.
fn glob_to_regex(pattern: &str) -> String {
    const REGEX_META: &str = ".$+^{}[]|\\()";
    let mut regex = pattern
        .chars()
        .fold(String::from("^(?i)"), |mut acc, ch| {
            match ch {
                '*' => acc.push_str("(.*)"),
                '?' => acc.push('.'),
                _ if REGEX_META.contains(ch) => {
                    acc.push('\\');
                    acc.push(ch);
                }
                _ => acc.push(ch),
            }
            acc
        });
    regex.push('$');
    regex
}
/// Handles an incoming `PRIVMSG`, answering built-in CTCP queries itself.
///
/// CTCP `PING` is echoed back (with its argument, if any) and CTCP `VERSION`
/// is answered with the crate version. Both replies go via `NOTICE` to the
/// sender's nick, and nothing is sent when the message has no source nick.
/// These two queries are consumed here and never reach user handlers. Every
/// other message — other CTCP commands and non-`PRIVMSG` input included — is
/// passed on to `dispatch`.
async fn handle_privmsg<T: Send + Sync + 'static>(
    bot: &Arc<T>,
    handlers: &HandlerSet<T>,
    msg: &Message,
    bot_nick: &str,
    tx: tokio::sync::mpsc::UnboundedSender<String>,
) {
    let text = match &msg.command {
        Command::PRIVMSG(_, text) => text,
        _ => return dispatch(bot, handlers, msg, bot_nick, tx).await,
    };
    if let Some(ctcp) = CtcpMessage::parse(text) {
        let query = ctcp.command.as_str();
        if query == "PING" {
            if let Some(sender) = msg.source_nickname() {
                // Separate the echoed argument with a space only when present.
                let separator = if ctcp.arg.is_empty() { "" } else { " " };
                let reply = format!(
                    "NOTICE {sender} :\x01PING{}{}\x01\r\n",
                    separator, ctcp.arg,
                );
                if let Err(e) = tx.send(reply) {
                    eprintln!("[ircbot] failed to send CTCP PING reply: {e}");
                }
            }
            return;
        }
        if query == "VERSION" {
            if let Some(sender) = msg.source_nickname() {
                let reply = format!(
                    "NOTICE {sender} :\x01VERSION ircbot {}\x01\r\n",
                    env!("CARGO_PKG_VERSION"),
                );
                if let Err(e) = tx.send(reply) {
                    eprintln!("[ircbot] failed to send CTCP VERSION reply: {e}");
                }
            }
            return;
        }
    }
    dispatch(bot, handlers, msg, bot_nick, tx).await;
}
/// Runs every handler whose trigger matches `msg`, one after another.
///
/// The handler list is snapshotted first, so its lock is not held while
/// handlers run. Each matching handler gets a fresh `Context` carrying:
/// - a clone of `tx`;
/// - the message's target parameter (empty string if none);
/// - the sender (only when the prefix is a nickname with a non-empty user part);
/// - the captures produced by `check_trigger`.
///
/// Handlers are awaited sequentially in list order. A handler error is logged
/// and does not stop later handlers.
///
/// NOTE(review): for a private `PRIVMSG` the target parameter is the bot's
/// own nick, so `ctx.target` is not the party to reply to — confirm that
/// `Context` accounts for this via `sender`.
async fn dispatch<T: Send + Sync + 'static>(
    bot: &Arc<T>,
    handlers: &HandlerSet<T>,
    msg: &Message,
    bot_nick: &str,
    tx: tokio::sync::mpsc::UnboundedSender<String>,
) {
    let snapshot: Arc<Vec<HandlerEntry<T>>> = {
        let handlers_guard = handlers.read().unwrap_or_else(|e| e.into_inner());
        Arc::clone(&*handlers_guard)
    };

    let sender = if let Some(Prefix::Nickname(nick, user, host)) = &msg.prefix {
        (!user.is_empty()).then(|| User {
            nick: nick.clone(),
            user: user.clone(),
            host: host.clone(),
        })
    } else {
        None
    };

    let target = target_param(msg).map_or_else(String::new, str::to_owned);
    let is_channel = target.is_channel_name();

    let matching = snapshot.iter().filter_map(|entry| {
        check_trigger(&entry.trigger, msg, bot_nick).map(|captures| (entry, captures))
    });
    for (entry, captures) in matching {
        let ctx = Context {
            tx: tx.clone(),
            target: target.clone(),
            is_channel,
            sender: sender.clone(),
            raw: msg.clone(),
            bot_nick: bot_nick.to_string(),
            captures,
        };
        if let Err(e) = (entry.handler)(Arc::clone(bot), ctx).await {
            eprintln!("[ircbot] handler error: {e}");
        }
    }
}