rig_core/integrations/
discord_bot.rs1use crate::agent::Agent;
4use crate::completion::{CompletionModel, request::Chat};
5use crate::message::Message as RigMessage;
6use serenity::all::{
7 Command, CommandInteraction, Context, CreateCommand, CreateThread, EventHandler,
8 GatewayIntents, Interaction, Message, Ready, async_trait,
9};
10use std::collections::HashMap;
11use std::env;
12use std::sync::Arc;
13use thiserror::Error;
14use tokio::sync::RwLock;
15
16#[derive(Debug, Error)]
17pub enum DiscordBotError {
18 #[error("Discord bot token missing from environment: {0}")]
19 MissingToken(#[from] env::VarError),
20 #[error("Failed to build Discord client: {0}")]
21 ClientBuild(#[from] serenity::Error),
22}
23
24struct BotState<M: CompletionModel> {
26 agent: Agent<M>,
27 conversations: Arc<RwLock<HashMap<u64, Vec<RigMessage>>>>,
28}
29
30impl<M: CompletionModel> BotState<M> {
31 fn new(agent: Agent<M>) -> Self {
32 Self {
33 agent,
34 conversations: Arc::new(RwLock::new(HashMap::new())),
35 }
36 }
37}
38
39struct Handler<M: CompletionModel> {
41 state: Arc<BotState<M>>,
42}
43
44#[async_trait]
45impl<M> EventHandler for Handler<M>
46where
47 M: CompletionModel + Send + Sync + 'static,
48{
49 async fn ready(&self, ctx: Context, ready: Ready) {
50 println!("{} is connected!", ready.user.name);
51
52 let register_cmd =
53 CreateCommand::new("new").description("Start a new chat session with the bot");
54
55 let command = Command::create_global_command(&ctx.http, register_cmd).await;
57
58 match command {
59 Ok(cmd) => println!("Registered global command: {}", cmd.name),
60 Err(e) => eprintln!("Failed to register command: {}", e),
61 }
62 }
63
64 async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
65 if let Interaction::Command(command) = interaction {
66 self.handle_command(&ctx, &command).await;
67 }
68 }
69
70 async fn message(&self, ctx: Context, msg: Message) {
71 if msg.author.bot {
73 return;
74 }
75
76 let conversations = self.state.conversations.read().await;
78 if conversations.contains_key(&msg.channel_id.get()) {
79 drop(conversations);
80 self.handle_thread_message(&ctx, &msg).await;
81 }
82 }
83}
84
85impl<M> Handler<M>
86where
87 M: CompletionModel + Send + Sync + 'static,
88{
89 async fn handle_command(&self, ctx: &Context, command: &CommandInteraction) {
90 if command.data.name.as_str() == "new" {
91 if let Err(e) = command.defer(&ctx.http).await {
93 eprintln!("Failed to defer command: {}", e);
94 return;
95 }
96
97 let thread_name = format!("AI Conversation - {}", command.user.name);
99
100 let thread = match command
101 .channel_id
102 .create_thread(
103 &ctx.http,
104 CreateThread::new(thread_name)
105 .kind(serenity::all::ChannelType::PublicThread)
106 .auto_archive_duration(serenity::all::AutoArchiveDuration::OneDay),
107 )
108 .await
109 {
110 Ok(t) => t,
111 Err(e) => {
112 eprintln!("Failed to create thread: {}", e);
113 let _ = command
114 .edit_response(
115 &ctx.http,
116 serenity::all::EditInteractionResponse::new()
117 .content("Failed to create thread. Please try again."),
118 )
119 .await;
120 return;
121 }
122 };
123
124 let mut conversations = self.state.conversations.write().await;
126 conversations.insert(thread.id.get(), Vec::new());
127 drop(conversations);
128
129 if let Err(e) = command
131 .edit_response(
132 &ctx.http,
133 serenity::all::EditInteractionResponse::new()
134 .content(format!(
135 "Started a new conversation in <#{}>! Send messages there to chat with the AI.",
136 thread.id
137 ))
138 )
139 .await
140 {
141 eprintln!("Failed to edit response: {}", e);
142 }
143
144 if let Err(e) = thread
146 .send_message(
147 &ctx.http,
148 serenity::all::CreateMessage::new()
149 .content("Hello! I'm ready to help. What would you like to talk about?"),
150 )
151 .await
152 {
153 eprintln!("Failed to send welcome message: {}", e);
154 }
155 }
156 }
157
158 async fn handle_thread_message(&self, ctx: &Context, msg: &Message) {
159 let thread_id = msg.channel_id.get();
160
161 let _ = msg.channel_id.broadcast_typing(&ctx.http).await;
163
164 let conversations = self.state.conversations.read().await;
166 let mut history = if let Some(history) = conversations.get(&thread_id) {
167 history.clone()
168 } else {
169 vec![]
170 };
171 drop(conversations);
172
173 let response = match self.state.agent.chat(&msg.content, &mut history).await {
176 Ok(resp) => resp,
177 Err(e) => {
178 eprintln!("Agent error: {}", e);
179 let _ = msg
180 .channel_id
181 .say(
182 &ctx.http,
183 "Sorry, I encountered an error processing your message.",
184 )
185 .await;
186 return;
187 }
188 };
189
190 {
192 let mut conversations = self.state.conversations.write().await;
193 conversations.insert(thread_id, history);
194 }
195
196 let chunks: Vec<String> = response
198 .chars()
199 .collect::<Vec<_>>()
200 .chunks(1900)
201 .map(|c| c.iter().collect())
202 .collect();
203
204 for chunk in chunks {
205 if let Err(e) = msg.channel_id.say(&ctx.http, &chunk).await {
206 eprintln!("Failed to send message: {}", e);
207 }
208 }
209 }
210}
211
212pub trait DiscordExt: Sized + Send + Sync
215where
216 Self: 'static,
217{
218 fn into_discord_bot(
219 self,
220 token: &str,
221 ) -> impl std::future::Future<Output = Result<serenity::Client, DiscordBotError>> + Send;
222
223 fn into_discord_bot_from_env(
224 self,
225 ) -> impl std::future::Future<Output = Result<serenity::Client, DiscordBotError>> + Send {
226 async move {
227 let token = std::env::var("DISCORD_BOT_TOKEN")?;
228 DiscordExt::into_discord_bot(self, &token).await
229 }
230 }
231}
232
233impl<M> DiscordExt for Agent<M>
234where
235 M: CompletionModel + Send + Sync + 'static,
236{
237 async fn into_discord_bot(self, token: &str) -> Result<serenity::Client, DiscordBotError> {
238 let intents = GatewayIntents::GUILDS
239 | GatewayIntents::GUILD_MESSAGES
240 | GatewayIntents::MESSAGE_CONTENT;
241
242 let state = Arc::new(BotState::new(self));
243 let handler = Handler {
244 state: state.clone(),
245 };
246
247 serenity::Client::builder(token, intents)
248 .event_handler(handler)
249 .await
250 .map_err(DiscordBotError::from)
251 }
252}