mod agent;
pub(crate) mod follow_up_question;
pub(crate) mod handler;
pub use agent::DiscordAgent;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Mutex, oneshot};
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
/// A question awaiting a Discord reply: the sender that delivers the chosen
/// answer, paired with the option strings the answer is picked from.
type PendingDiscordQuestion = (oneshot::Sender<String>, Vec<String>);
/// Shared state for the Discord integration.
///
/// Each field is guarded by its own `tokio::sync::Mutex`.
// NOTE: field order is also drop order; keep it stable unless that is intended.
pub struct DiscordState {
    /// HTTP client; `Some` once `set_connected` has been called.
    http: Mutex<Option<Arc<serenity::http::Http>>>,
    /// Owner channel ID, if one has been recorded.
    owner_channel_id: Mutex<Option<u64>>,
    /// The bot's own user ID, once recorded.
    bot_user_id: Mutex<Option<u64>>,
    /// Guild ID, once recorded.
    guild_id: Mutex<Option<u64>>,
    /// Channel ID registered for each session.
    session_channels: Mutex<HashMap<Uuid, u64>>,
    /// Senders awaiting an `(approved, always)` decision, keyed by approval id.
    pending_approvals: Mutex<HashMap<String, oneshot::Sender<(bool, bool)>>>,
    /// Questions awaiting an answer, keyed by question id.
    pending_questions: Mutex<HashMap<String, PendingDiscordQuestion>>,
    /// Cancellation token for each session's current agent call.
    cancel_tokens: Mutex<HashMap<Uuid, CancellationToken>>,
}
impl Default for DiscordState {
fn default() -> Self {
Self::new()
}
}
impl DiscordState {
    /// Creates an empty, disconnected state: no HTTP client, no recorded IDs,
    /// and no registered sessions, approvals, questions, or cancel tokens.
    pub fn new() -> Self {
        Self {
            http: Mutex::new(None),
            owner_channel_id: Mutex::new(None),
            bot_user_id: Mutex::new(None),
            guild_id: Mutex::new(None),
            session_channels: Mutex::new(HashMap::new()),
            pending_approvals: Mutex::new(HashMap::new()),
            pending_questions: Mutex::new(HashMap::new()),
            cancel_tokens: Mutex::new(HashMap::new()),
        }
    }

    /// Stores the HTTP client, which makes [`is_connected`](Self::is_connected) return `true`.
    ///
    /// When `channel_id` is `Some`, it also replaces the owner channel. When it is
    /// `None`, any previously recorded owner channel is left untouched.
    pub async fn set_connected(&self, http: Arc<serenity::http::Http>, channel_id: Option<u64>) {
        *self.http.lock().await = Some(http);
        if let Some(id) = channel_id {
            *self.owner_channel_id.lock().await = Some(id);
        }
    }

    /// Records (or replaces) the owner channel ID.
    pub async fn set_owner_channel(&self, channel_id: u64) {
        *self.owner_channel_id.lock().await = Some(channel_id);
    }

    /// Returns a clone of the stored HTTP client, if connected.
    pub async fn http(&self) -> Option<Arc<serenity::http::Http>> {
        self.http.lock().await.clone()
    }

    /// Returns the recorded owner channel ID, if any.
    pub async fn owner_channel_id(&self) -> Option<u64> {
        *self.owner_channel_id.lock().await
    }

    /// Records (or replaces) the bot's own user ID.
    pub async fn set_bot_user_id(&self, id: u64) {
        *self.bot_user_id.lock().await = Some(id);
    }

    /// Returns the recorded bot user ID, if any.
    pub async fn bot_user_id(&self) -> Option<u64> {
        *self.bot_user_id.lock().await
    }

    /// Records (or replaces) the guild ID.
    pub async fn set_guild_id(&self, id: u64) {
        *self.guild_id.lock().await = Some(id);
    }

    /// Returns the recorded guild ID, if any.
    pub async fn guild_id(&self) -> Option<u64> {
        *self.guild_id.lock().await
    }

    /// Returns `true` once an HTTP client has been stored via `set_connected`.
    pub async fn is_connected(&self) -> bool {
        self.http.lock().await.is_some()
    }

    /// Associates `session_id` with `channel_id`, replacing any previous mapping.
    pub async fn register_session_channel(&self, session_id: Uuid, channel_id: u64) {
        self.session_channels
            .lock()
            .await
            .insert(session_id, channel_id);
    }

    /// Returns the channel ID registered for `session_id`, if any.
    pub async fn session_channel(&self, session_id: Uuid) -> Option<u64> {
        self.session_channels.lock().await.get(&session_id).copied()
    }

    /// Registers a sender that will receive the `(approved, always)` decision
    /// for approval `id`.
    ///
    /// A sender already registered under the same `id` is dropped.
    pub async fn register_pending_approval(&self, id: String, tx: oneshot::Sender<(bool, bool)>) {
        self.pending_approvals.lock().await.insert(id, tx);
    }

    /// Delivers `(approved, always)` to the pending approval `id` and unregisters it.
    ///
    /// Returns `true` if an approval with `id` was pending. The return value is
    /// `true` even if the receiving side had already been dropped, because the
    /// send result is ignored.
    pub async fn resolve_pending_approval(&self, id: &str, approved: bool, always: bool) -> bool {
        // Bind first so the map lock is released before sending.
        let pending = self.pending_approvals.lock().await.remove(id);
        match pending {
            Some(tx) => {
                let _ = tx.send((approved, always));
                true
            }
            None => false,
        }
    }

    /// Registers question `id` with its answer sender and selectable `options`.
    ///
    /// A question already registered under the same `id` is replaced, and its
    /// sender is dropped.
    pub async fn register_pending_question(
        &self,
        id: String,
        tx: oneshot::Sender<String>,
        options: Vec<String>,
    ) {
        self.pending_questions
            .lock()
            .await
            .insert(id, (tx, options));
    }

    /// Answers pending question `id` with the option at index `idx` and unregisters it.
    ///
    /// Returns the chosen option text. Returns `None` if no question with `id` is
    /// pending, or if `idx` is out of range. An out-of-range index leaves the
    /// question pending, so a later valid selection can still answer it.
    ///
    /// The answer is returned even if the receiving side had already been
    /// dropped, because the send result is ignored.
    pub async fn resolve_pending_question(&self, id: &str, idx: usize) -> Option<String> {
        let (tx, answer) = {
            let mut pending = self.pending_questions.lock().await;
            // Validate before removing: removing first would drop the sender on a
            // bad index, leaving the waiter with only a closed channel.
            let answer = pending.get(id)?.1.get(idx)?.clone();
            let (tx, _options) = pending.remove(id)?;
            (tx, answer)
        };
        let _ = tx.send(answer.clone());
        Some(answer)
    }

    /// Stores `token` as the cancellation token for `session_id`.
    ///
    /// If a token was already stored for the session, it is cancelled (with a
    /// warning) before being replaced.
    pub async fn store_cancel_token(&self, session_id: Uuid, token: CancellationToken) {
        let mut tokens = self.cancel_tokens.lock().await;
        if let Some(old) = tokens.insert(session_id, token) {
            tracing::warn!(
                "Discord: cancelling previous in-flight agent call for session {}",
                session_id
            );
            old.cancel();
        }
    }

    /// Cancels and removes the token stored for `session_id`.
    ///
    /// Returns `true` if a token was stored.
    pub async fn cancel_session(&self, session_id: Uuid) -> bool {
        // Bind first so the map lock is released before cancelling.
        let token = self.cancel_tokens.lock().await.remove(&session_id);
        match token {
            Some(token) => {
                token.cancel();
                true
            }
            None => false,
        }
    }

    /// Removes the token stored for `session_id`, but only if it is already cancelled.
    ///
    /// A token that is not cancelled is left in place. This keeps a fresh token
    /// stored by a newer call from being removed.
    // NOTE(review): a token whose call finished without being cancelled stays in
    // the map. The next `store_cancel_token` for the session then logs the
    // "cancelling previous in-flight" warning for it. Confirm this is intended.
    pub async fn remove_cancel_token(&self, session_id: Uuid) {
        let mut tokens = self.cancel_tokens.lock().await;
        if tokens
            .get(&session_id)
            .is_some_and(|token| token.is_cancelled())
        {
            tokens.remove(&session_id);
        }
    }
}