rig/integrations/
discord_bot.rs1use crate::OneOrMany;
4use crate::agent::Agent;
5use crate::completion::{AssistantContent, CompletionModel, request::Chat};
6use crate::message::{Message as RigMessage, UserContent};
7use serenity::all::{
8 Command, CommandInteraction, Context, CreateCommand, CreateThread, EventHandler,
9 GatewayIntents, Interaction, Message, Ready, async_trait,
10};
11use std::collections::HashMap;
12use std::env;
13use std::sync::Arc;
14use thiserror::Error;
15use tokio::sync::RwLock;
16
17#[derive(Debug, Error)]
18pub enum DiscordBotError {
19 #[error("Discord bot token missing from environment: {0}")]
20 MissingToken(#[from] env::VarError),
21 #[error("Failed to build Discord client: {0}")]
22 ClientBuild(#[from] serenity::Error),
23}
24
25struct BotState<M: CompletionModel> {
27 agent: Agent<M>,
28 conversations: Arc<RwLock<HashMap<u64, Vec<RigMessage>>>>,
29}
30
31impl<M: CompletionModel> BotState<M> {
32 fn new(agent: Agent<M>) -> Self {
33 Self {
34 agent,
35 conversations: Arc::new(RwLock::new(HashMap::new())),
36 }
37 }
38}
39
40struct Handler<M: CompletionModel> {
42 state: Arc<BotState<M>>,
43}
44
45#[async_trait]
46impl<M> EventHandler for Handler<M>
47where
48 M: CompletionModel + Send + Sync + 'static,
49{
50 async fn ready(&self, ctx: Context, ready: Ready) {
51 println!("{} is connected!", ready.user.name);
52
53 let register_cmd =
54 CreateCommand::new("new").description("Start a new chat session with the bot");
55
56 let command = Command::create_global_command(&ctx.http, register_cmd).await;
58
59 match command {
60 Ok(cmd) => println!("Registered global command: {}", cmd.name),
61 Err(e) => eprintln!("Failed to register command: {}", e),
62 }
63 }
64
65 async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
66 if let Interaction::Command(command) = interaction {
67 self.handle_command(&ctx, &command).await;
68 }
69 }
70
71 async fn message(&self, ctx: Context, msg: Message) {
72 if msg.author.bot {
74 return;
75 }
76
77 let conversations = self.state.conversations.read().await;
79 if conversations.contains_key(&msg.channel_id.get()) {
80 drop(conversations);
81 self.handle_thread_message(&ctx, &msg).await;
82 }
83 }
84}
85
86impl<M> Handler<M>
87where
88 M: CompletionModel + Send + Sync + 'static,
89{
90 async fn handle_command(&self, ctx: &Context, command: &CommandInteraction) {
91 if command.data.name.as_str() == "new" {
92 if let Err(e) = command.defer(&ctx.http).await {
94 eprintln!("Failed to defer command: {}", e);
95 return;
96 }
97
98 let thread_name = format!("AI Conversation - {}", command.user.name);
100
101 let thread = match command
102 .channel_id
103 .create_thread(
104 &ctx.http,
105 CreateThread::new(thread_name)
106 .kind(serenity::all::ChannelType::PublicThread)
107 .auto_archive_duration(serenity::all::AutoArchiveDuration::OneDay),
108 )
109 .await
110 {
111 Ok(t) => t,
112 Err(e) => {
113 eprintln!("Failed to create thread: {}", e);
114 let _ = command
115 .edit_response(
116 &ctx.http,
117 serenity::all::EditInteractionResponse::new()
118 .content("Failed to create thread. Please try again."),
119 )
120 .await;
121 return;
122 }
123 };
124
125 let mut conversations = self.state.conversations.write().await;
127 conversations.insert(thread.id.get(), Vec::new());
128 drop(conversations);
129
130 if let Err(e) = command
132 .edit_response(
133 &ctx.http,
134 serenity::all::EditInteractionResponse::new()
135 .content(format!(
136 "Started a new conversation in <#{}>! Send messages there to chat with the AI.",
137 thread.id
138 ))
139 )
140 .await
141 {
142 eprintln!("Failed to edit response: {}", e);
143 }
144
145 if let Err(e) = thread
147 .send_message(
148 &ctx.http,
149 serenity::all::CreateMessage::new()
150 .content("Hello! I'm ready to help. What would you like to talk about?"),
151 )
152 .await
153 {
154 eprintln!("Failed to send welcome message: {}", e);
155 }
156 }
157 }
158
159 async fn handle_thread_message(&self, ctx: &Context, msg: &Message) {
160 let thread_id = msg.channel_id.get();
161
162 {
164 let mut conversations = self.state.conversations.write().await;
165 if let Some(history) = conversations.get_mut(&thread_id) {
166 history.push(RigMessage::User {
167 content: OneOrMany::one(UserContent::text(msg.content.clone())),
168 });
169 }
170 }
171
172 let _ = msg.channel_id.broadcast_typing(&ctx.http).await;
174
175 let conversations = self.state.conversations.read().await;
177 let history = if let Some(history) = conversations.get(&thread_id) {
178 history.clone()
179 } else {
180 vec![]
181 };
182 drop(conversations);
183
184 let response = match self.state.agent.chat(&msg.content, history).await {
186 Ok(resp) => resp,
187 Err(e) => {
188 eprintln!("Agent error: {}", e);
189 let _ = msg
190 .channel_id
191 .say(
192 &ctx.http,
193 "Sorry, I encountered an error processing your message.",
194 )
195 .await;
196 return;
197 }
198 };
199
200 {
202 let mut conversations = self.state.conversations.write().await;
203 if let Some(history) = conversations.get_mut(&thread_id) {
204 history.push(RigMessage::Assistant {
205 content: OneOrMany::one(AssistantContent::text(msg.content.clone())),
206 id: None,
207 });
208 }
209 }
210
211 let chunks: Vec<String> = response
213 .chars()
214 .collect::<Vec<_>>()
215 .chunks(1900)
216 .map(|c| c.iter().collect())
217 .collect();
218
219 for chunk in chunks {
220 if let Err(e) = msg.channel_id.say(&ctx.http, &chunk).await {
221 eprintln!("Failed to send message: {}", e);
222 }
223 }
224 }
225}
226
227pub trait DiscordExt: Sized + Send + Sync
230where
231 Self: 'static,
232{
233 fn into_discord_bot(
234 self,
235 token: &str,
236 ) -> impl std::future::Future<Output = Result<serenity::Client, DiscordBotError>> + Send;
237
238 fn into_discord_bot_from_env(
239 self,
240 ) -> impl std::future::Future<Output = Result<serenity::Client, DiscordBotError>> + Send {
241 async move {
242 let token = std::env::var("DISCORD_BOT_TOKEN")?;
243 DiscordExt::into_discord_bot(self, &token).await
244 }
245 }
246}
247
248impl<M> DiscordExt for Agent<M>
249where
250 M: CompletionModel + Send + Sync + 'static,
251{
252 async fn into_discord_bot(self, token: &str) -> Result<serenity::Client, DiscordBotError> {
253 let intents = GatewayIntents::GUILDS
254 | GatewayIntents::GUILD_MESSAGES
255 | GatewayIntents::MESSAGE_CONTENT;
256
257 let state = Arc::new(BotState::new(self));
258 let handler = Handler {
259 state: state.clone(),
260 };
261
262 serenity::Client::builder(token, intents)
263 .event_handler(handler)
264 .await
265 .map_err(DiscordBotError::from)
266 }
267}