use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
use refluxer::model::id::{ChannelId, UserId};
use refluxer::model::message::Message;
use refluxer::{Client, Context, EventHandler};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
/// A pending reminder held in the store's map, keyed by its numeric id.
struct Reminder {
/// User who created the reminder; compared against the command author when
/// listing and cancelling.
user_id: UserId,
/// Reminder text, shown in the `!reminders` listing.
text: String,
/// Instant at which the reminder is due; used to compute the time remaining.
fires_at: Instant,
/// Handle of the spawned delivery task; `abort()` is called on it when the
/// reminder is cancelled.
handle: JoinHandle<()>,
}
/// Shared reminder state.
///
/// Both fields are `Arc`s, so a clone refers to the same id counter and the
/// same map as the value it was cloned from.
#[derive(Clone)]
struct Store {
/// Counter from which reminder ids are allocated (via `fetch_add`).
next_id: Arc<AtomicU64>,
/// Pending reminders keyed by id, behind an async mutex.
inner: Arc<Mutex<HashMap<u64, Reminder>>>,
}
impl Store {
    /// Creates an empty store whose first allocated reminder id will be `1`.
    fn new() -> Self {
        let next_id = Arc::new(AtomicU64::new(1));
        let inner = Arc::new(Mutex::new(HashMap::new()));
        Self { next_id, inner }
    }
}
/// Event handler for the bot; owns the reminder store used by the
/// `!remind`, `!reminders` and `!cancel` commands.
struct Handler {
/// Reminder state shared with the spawned delivery tasks.
store: Store,
}
#[async_trait::async_trait]
impl EventHandler for Handler {
    /// Dispatches an incoming message to the matching command handler.
    ///
    /// Messages from bots and messages that are not one of `!remind <args>`,
    /// `!reminders` or `!cancel <args>` are ignored without a reply.
    async fn message_create(&self, ctx: Context, msg: Message) {
        let from_bot = msg.author.bot.unwrap_or(false);
        if from_bot {
            return;
        }

        /// A recognised command with its arguments copied out of the message,
        /// so that `msg` itself can be moved into the handler.
        enum Command {
            Remind(String),
            List,
            Cancel(String),
        }

        let command = {
            let body = msg.content.trim();
            if body == "!reminders" {
                Command::List
            } else if let Some(args) = body.strip_prefix("!remind ") {
                Command::Remind(args.to_owned())
            } else if let Some(args) = body.strip_prefix("!cancel ") {
                Command::Cancel(args.trim().to_owned())
            } else {
                // Not a command for this bot.
                return;
            }
        };

        match command {
            Command::List => self.handle_list(ctx, msg).await,
            Command::Remind(args) => self.handle_remind(ctx, msg, &args).await,
            Command::Cancel(args) => self.handle_cancel(ctx, msg, &args).await,
        }
    }
}
impl Handler {
    /// Handles `!remind <duration> <text>`.
    ///
    /// `rest` is everything after the `!remind ` prefix. The arguments are
    /// validated, a delivery task is spawned, and the reminder is recorded in
    /// the store. Every outcome (usage error, invalid or too-long duration,
    /// empty text, success) is reported as a reply to `msg`.
    async fn handle_remind(&self, ctx: Context, msg: Message, rest: &str) {
        let rest = rest.trim();
        let Some((dur_str, text)) = rest.split_once(' ') else {
            reply(
                &ctx,
                msg.channel_id,
                msg.id,
                "Usage: `!remind <duration> <text>`",
            )
            .await;
            return;
        };
        let Some(dur) = parse::parse_duration(dur_str) else {
            reply(
                &ctx,
                msg.channel_id,
                msg.id,
                &format!("Invalid duration `{dur_str}` (use e.g. `30s`, `5m`, `2h`, `1d`)."),
            )
            .await;
            return;
        };
        let text = text.trim();
        if text.is_empty() {
            reply(
                &ctx,
                msg.channel_id,
                msg.id,
                "Reminder text cannot be empty.",
            )
            .await;
            return;
        }
        // `Instant + Duration` panics on overflow and the duration comes from
        // the user, so the deadline is computed with a checked add instead.
        let Some(fires_at) = Instant::now().checked_add(dur) else {
            reply(
                &ctx,
                msg.channel_id,
                msg.id,
                &format!("Duration `{dur_str}` is too long."),
            )
            .await;
            return;
        };

        let id = self.store.next_id.fetch_add(1, Ordering::Relaxed);
        let channel_id = msg.channel_id;
        let user_id = msg.author.id;
        let text_owned = text.to_string();
        let store = self.store.clone();
        let ctx_task = ctx.clone();
        let text_task = text_owned.clone();
        {
            // Hold the store lock across spawn + insert: the task's final
            // `remove` needs the same lock, so it can never run before the
            // entry exists and leave a stale reminder behind.
            let mut map = self.store.inner.lock().await;
            let handle = tokio::spawn(async move {
                tokio::time::sleep(dur).await;
                // NOTE(review): `text_task` is user-supplied and echoed
                // verbatim, so it may contain mention syntax; confirm whether
                // the send builder should restrict mentions.
                let sent = ctx_task
                    .send(channel_id)
                    .content(format!("<@{user_id}> ⏰ reminder: {text_task}"))
                    .await;
                if let Err(e) = sent {
                    tracing::warn!(error = ?e, "failed to deliver reminder");
                }
                store.inner.lock().await.remove(&id);
            });
            map.insert(
                id,
                Reminder {
                    user_id,
                    text: text_owned,
                    fires_at,
                    handle,
                },
            );
        }
        reply(
            &ctx,
            msg.channel_id,
            msg.id,
            &format!("Reminder #{id} set for {}.", parse::format_duration(dur)),
        )
        .await;
    }

    /// Handles `!reminders`: replies with the author's pending reminders in
    /// ascending id order, each with its remaining time, or with a notice when
    /// the author has none.
    async fn handle_list(&self, ctx: Context, msg: Message) {
        let now = Instant::now();
        // Build the listing inside a scope so the lock is released before the
        // reply is sent.
        let listing = {
            let map = self.store.inner.lock().await;
            let mut mine: Vec<_> = map
                .iter()
                .filter(|(_, r)| r.user_id == msg.author.id)
                .collect();
            // Ids are unique map keys, so an unstable sort gives the same order.
            mine.sort_unstable_by_key(|(id, _)| **id);
            if mine.is_empty() {
                None
            } else {
                let mut out = String::from("Your reminders:\n");
                for (id, r) in mine {
                    // Saturates to zero for a reminder that is already due.
                    let left = r.fires_at.saturating_duration_since(now);
                    out.push_str(&format!(
                        "• #{id} in {} — {}\n",
                        parse::format_duration(left),
                        r.text
                    ));
                }
                Some(out)
            }
        };
        let text = listing
            .as_deref()
            .unwrap_or("You have no pending reminders.");
        reply(&ctx, msg.channel_id, msg.id, text).await;
    }

    /// Handles `!cancel <id>`: aborts and removes the reminder with that id if
    /// it belongs to the author of `msg`.
    ///
    /// Replies with a usage hint when `arg` is not a `u64`, and with an
    /// explanatory message when the id is unknown or owned by another user.
    async fn handle_cancel(&self, ctx: Context, msg: Message, arg: &str) {
        let Ok(id) = arg.parse::<u64>() else {
            reply(&ctx, msg.channel_id, msg.id, "Usage: `!cancel <id>`").await;
            return;
        };
        let response = {
            let mut map = self.store.inner.lock().await;
            // `None` = no such id, `Some(false)` = someone else's reminder.
            let owned_by_author = map.get(&id).map(|r| r.user_id == msg.author.id);
            match owned_by_author {
                Some(true) => {
                    if let Some(r) = map.remove(&id) {
                        r.handle.abort();
                    }
                    format!("Reminder #{id} cancelled.")
                }
                Some(false) => "That reminder belongs to someone else.".to_string(),
                None => format!("No reminder with id {id}."),
            }
        };
        reply(&ctx, msg.channel_id, msg.id, &response).await;
    }
}
async fn reply(
ctx: &Context,
channel_id: ChannelId,
reply_to: refluxer::model::id::MessageId,
text: &str,
) {
if let Err(e) = ctx.send(channel_id).content(text).reply(reply_to).await {
tracing::warn!(error = ?e, "failed to send reply");
}
}
/// Entry point: initialises tracing, builds the client from the
/// `FLUXER_TOKEN` environment variable and runs it until it stops.
///
/// Panics when the token is missing or the client cannot be built; an error
/// returned by the running client is printed to stderr.
#[tokio::main]
async fn main() {
    tracing_subscriber::fmt::init();

    let token = std::env::var("FLUXER_TOKEN").expect("FLUXER_TOKEN env var required");
    let handler = Handler {
        store: Store::new(),
    };
    let client = Client::builder()
        .token(&token)
        .event_handler(handler)
        .build()
        .expect("failed to build client");

    match client.start().await {
        Ok(_) => {}
        Err(e) => eprintln!("Client error: {e}"),
    }
}
/// Parsing and formatting of the short duration syntax used by the commands
/// (`30s`, `5m`, `2h`, `1d`).
mod parse {
    use std::time::Duration;

    /// Parses a duration written as a positive integer followed by a single
    /// unit suffix: `s` (seconds), `m` (minutes), `h` (hours) or `d` (days).
    ///
    /// Surrounding whitespace is ignored. Returns `None` when the input is
    /// empty, has no number, has an unknown suffix, is zero, or when the
    /// number of seconds does not fit in a `u64`.
    pub fn parse_duration(s: &str) -> Option<Duration> {
        let s = s.trim();
        // Split before the last *character*, not the last byte: the input is
        // user-supplied and `split_at` panics on a non-boundary index, which a
        // multi-byte final character (e.g. `5é`) would otherwise trigger.
        let (suffix_start, _) = s.char_indices().next_back()?;
        let (num, suffix) = s.split_at(suffix_start);
        // An empty `num` (suffix-only input) fails to parse and yields `None`.
        let n: u64 = num.parse().ok()?;
        if n == 0 {
            return None;
        }
        let secs_per_unit: u64 = match suffix {
            "s" => 1,
            "m" => 60,
            "h" => 3600,
            "d" => 86_400,
            _ => return None,
        };
        n.checked_mul(secs_per_unit).map(Duration::from_secs)
    }

    /// Formats a duration compactly as days, hours, minutes and seconds,
    /// e.g. `1d1h1m1s`, omitting zero components.
    ///
    /// Only whole seconds are considered; anything below one second is
    /// rendered as `0s`.
    pub fn format_duration(d: Duration) -> String {
        // Units from largest to smallest, as (seconds per unit, label).
        const UNITS: [(u64, char); 4] = [(86_400, 'd'), (3600, 'h'), (60, 'm'), (1, 's')];

        let mut remaining = d.as_secs();
        if remaining == 0 {
            return "0s".to_string();
        }
        let mut out = String::new();
        for (unit_secs, label) in UNITS {
            let count = remaining / unit_secs;
            remaining %= unit_secs;
            if count > 0 {
                out.push_str(&count.to_string());
                out.push(label);
            }
        }
        out
    }
}
#[cfg(test)]
mod tests {
    use super::parse::*;
    use std::time::Duration;

    /// Every supported unit suffix converts to the expected number of seconds.
    #[test]
    fn parses_all_suffixes() {
        let cases = [("30s", 30), ("5m", 300), ("2h", 7200), ("1d", 86_400)];
        for (input, secs) in cases {
            assert_eq!(
                parse_duration(input),
                Some(Duration::from_secs(secs)),
                "input: {input:?}"
            );
        }
    }

    /// Empty, suffix-only, zero, unknown-suffix and non-numeric inputs are
    /// all rejected.
    #[test]
    fn rejects_invalid_input() {
        for input in ["", "m", "0s", "5x", "abc"] {
            assert_eq!(parse_duration(input), None, "input: {input:?}");
        }
    }

    /// Zero renders as `0s`; otherwise only non-zero components appear,
    /// largest unit first.
    #[test]
    fn formats_compound_duration() {
        let cases = [
            (0, "0s"),
            (90, "1m30s"),
            (3661, "1h1m1s"),
            (90_061, "1d1h1m1s"),
        ];
        for (secs, expected) in cases {
            assert_eq!(
                format_duration(Duration::from_secs(secs)),
                expected,
                "secs: {secs}"
            );
        }
    }
}