use crate::channel::{Channel, IncomingMessage, OutgoingMessage};
use crate::config::DiscordConfig;
use anyhow::{Context, Result};
use async_trait::async_trait;
use serenity::Client;
use serenity::all::{
ChannelId, Context as SerenityCtx, EventHandler, GatewayIntents, Message, Ready,
};
use serenity::http::Http;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::{OnceCell, mpsc};
use tracing::{debug, info, warn};
/// Serenity event handler that filters incoming Discord messages and forwards
/// the accepted ones over an mpsc channel.
struct DiscordHandler {
    /// Sender used to forward each accepted message as an `IncomingMessage`.
    tx: mpsc::Sender<IncomingMessage>,
    /// Channel IDs messages may come from; an empty set accepts every channel.
    allowed_channel_ids: HashSet<u64>,
    /// Author IDs messages may come from; an empty set accepts every user.
    allowed_user_ids: HashSet<u64>,
}
#[serenity::async_trait]
impl EventHandler for DiscordHandler {
async fn ready(&self, _ctx: SerenityCtx, ready: Ready) {
info!("Discord bot connected as {}", ready.user.name);
}
async fn message(&self, _ctx: SerenityCtx, msg: Message) {
if msg.author.bot {
return;
}
let channel_id = msg.channel_id.get();
if !self.allowed_channel_ids.is_empty() && !self.allowed_channel_ids.contains(&channel_id) {
debug!("Ignoring message from channel {channel_id} (not in allowed list)");
return;
}
let user_id = msg.author.id.get();
if !self.allowed_user_ids.is_empty() && !self.allowed_user_ids.contains(&user_id) {
debug!("Ignoring message from user {user_id} (not in allowed list)");
return;
}
let content = msg.content.trim().to_string();
if content.is_empty() {
return;
}
let incoming = IncomingMessage {
id: msg.id.to_string(),
sender: msg.author.id.to_string(),
content,
room_id: msg.channel_id.to_string(),
timestamp: msg.timestamp.unix_timestamp() as u64 * 1000,
thread_id: None,
};
if let Err(e) = self.tx.send(incoming).await {
warn!("Failed to forward Discord message: {e}");
}
}
}
/// Discord implementation of `Channel`, backed by a serenity gateway client.
pub struct DiscordChannel {
    /// Bot token, used to build both the gateway client and standalone REST clients.
    token: String,
    /// Channel IDs to accept messages from; an empty set accepts all channels.
    channel_ids: HashSet<u64>,
    /// User IDs to accept messages from; an empty set accepts all users.
    allowed_user_ids: HashSet<u64>,
    /// REST client taken from the first gateway client started by `listen`;
    /// unset until then.
    http: Arc<OnceCell<Arc<Http>>>,
}
impl std::fmt::Debug for DiscordChannel {
    /// Shows only the channel filter; the bot token and the remaining fields are
    /// intentionally left out (rendered as `..`).
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("DiscordChannel");
        builder.field("channel_ids", &self.channel_ids);
        builder.finish_non_exhaustive()
    }
}
impl DiscordChannel {
pub fn new(cfg: &DiscordConfig) -> Result<Self> {
let channel_ids = cfg
.channel_ids
.iter()
.map(|s| s.parse::<u64>().context("Invalid Discord channel_id"))
.collect::<Result<HashSet<_>>>()?;
let allowed_user_ids = cfg
.allowed_users
.iter()
.map(|s| {
s.parse::<u64>()
.context("Invalid Discord user ID in allowed_users")
})
.collect::<Result<HashSet<_>>>()?;
Ok(Self {
token: cfg.bot_token.clone(),
channel_ids,
allowed_user_ids,
http: Arc::new(OnceCell::new()),
})
}
fn get_http_or_new(&self) -> Arc<Http> {
if let Some(http) = self.http.get() {
Arc::clone(http)
} else {
Arc::new(Http::new(&self.token))
}
}
}
#[async_trait]
impl Channel for DiscordChannel {
fn name(&self) -> &str {
"discord"
}
async fn send(&self, message: &OutgoingMessage) -> Result<()> {
let http = self.get_http_or_new();
let channel_id: u64 = message
.room_id
.parse()
.context("Discord room_id is not a valid channel ID")?;
let channel_id = ChannelId::new(channel_id);
for chunk in split_for_discord(&message.content) {
channel_id
.say(http.as_ref(), chunk)
.await
.context("Failed to send Discord message")?;
}
Ok(())
}
async fn listen(&self, tx: mpsc::Sender<IncomingMessage>) -> Result<()> {
let intents = GatewayIntents::GUILD_MESSAGES
| GatewayIntents::DIRECT_MESSAGES
| GatewayIntents::MESSAGE_CONTENT;
let min_backoff = std::time::Duration::from_secs(1);
let max_backoff = std::time::Duration::from_secs(300);
let stable_threshold = std::time::Duration::from_secs(60);
let mut backoff = min_backoff;
loop {
let handler = DiscordHandler {
tx: tx.clone(),
allowed_channel_ids: self.channel_ids.clone(),
allowed_user_ids: self.allowed_user_ids.clone(),
};
let mut client = Client::builder(&self.token, intents)
.event_handler(handler)
.await
.context("Failed to build Discord client")?;
let _ = self.http.set(Arc::clone(&client.http));
info!("Starting Discord gateway...");
let started = std::time::Instant::now();
match client.start().await {
Ok(()) => {
warn!("Discord gateway exited without error; reconnecting in {backoff:?}");
}
Err(e) => {
warn!("Discord gateway exited with error: {e}; reconnecting in {backoff:?}");
}
}
tokio::time::sleep(backoff).await;
if started.elapsed() >= stable_threshold {
backoff = min_backoff;
} else {
backoff = (backoff * 2).min(max_backoff);
}
info!("Reconnecting Discord gateway...");
}
}
async fn start_typing(&self, room_id: &str) -> Result<()> {
let http = self.get_http_or_new();
let channel_id: u64 = room_id.parse().context("Invalid Discord channel ID")?;
ChannelId::new(channel_id)
.broadcast_typing(http.as_ref())
.await
.context("Failed to send typing indicator")?;
Ok(())
}
}
/// Splits `content` into pieces that each fit into a single Discord message.
///
/// Content of at most `LIMIT` bytes is returned unchanged as one chunk, even when
/// empty. Longer content is cut into chunks of at most `LIMIT` bytes. Each cut
/// falls on a UTF-8 character boundary and, when the window contains a newline,
/// just after the last newline, so lines stay intact.
///
/// Leading newlines of every chunk are dropped, so no chunk consists solely of line
/// breaks. As a result, long content made only of newlines yields an empty `Vec`.
fn split_for_discord(content: &str) -> Vec<String> {
    // NOTE(review): presumably headroom below Discord's 2000-character limit. Byte
    // length is an upper bound on character count, so a byte cap is conservative.
    const LIMIT: usize = 1990;
    if content.len() <= LIMIT {
        return vec![content.to_owned()];
    }

    let mut chunks = Vec::new();
    // Without this trim, a leading '\n' would be the window's last newline at
    // index 0 and would be emitted as a newline-only chunk.
    let mut remaining = content.trim_start_matches('\n');
    while remaining.len() > LIMIT {
        // Back off to a char boundary (at most 3 bytes for UTF-8).
        let mut split = LIMIT;
        while !remaining.is_char_boundary(split) {
            split -= 1;
        }
        // `remaining` never starts with '\n' here, so any newline found is at
        // index >= 1 and the chunk keeps at least one non-newline character.
        if let Some(nl) = remaining[..split].rfind('\n') {
            split = nl + 1;
        }
        chunks.push(remaining[..split].to_owned());
        remaining = remaining[split..].trim_start_matches('\n');
    }
    if !remaining.is_empty() {
        chunks.push(remaining.to_owned());
    }
    chunks
}