Skip to main content

construct/channels/
discord_history.rs

1use super::traits::{Channel, ChannelMessage, SendMessage};
2use async_trait::async_trait;
3use futures_util::{SinkExt, StreamExt};
4use parking_lot::Mutex;
5use serde_json::json;
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio_tungstenite::tungstenite::Message;
9use uuid::Uuid;
10
11use crate::memory::{Memory, MemoryCategory};
12
13/// Discord History channel — connects via Gateway WebSocket, stores ALL non-bot messages
14/// to a dedicated discord.db, and forwards @mention messages to the agent.
15pub struct DiscordHistoryChannel {
16    bot_token: String,
17    guild_id: Option<String>,
18    allowed_users: Vec<String>,
19    /// Channel IDs to watch. Empty = watch all channels.
20    channel_ids: Vec<String>,
21    /// Dedicated discord.db memory backend.
22    discord_memory: Arc<dyn Memory>,
23    typing_handles: Mutex<HashMap<String, tokio::task::JoinHandle<()>>>,
24    proxy_url: Option<String>,
25    /// When false, DM messages are not stored in discord.db.
26    store_dms: bool,
27    /// When false, @mentions in DMs are not forwarded to the agent.
28    respond_to_dms: bool,
29}
30
31impl DiscordHistoryChannel {
32    pub fn new(
33        bot_token: String,
34        guild_id: Option<String>,
35        allowed_users: Vec<String>,
36        channel_ids: Vec<String>,
37        discord_memory: Arc<dyn Memory>,
38        store_dms: bool,
39        respond_to_dms: bool,
40    ) -> Self {
41        Self {
42            bot_token,
43            guild_id,
44            allowed_users,
45            channel_ids,
46            discord_memory,
47            typing_handles: Mutex::new(HashMap::new()),
48            proxy_url: None,
49            store_dms,
50            respond_to_dms,
51        }
52    }
53
54    pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
55        self.proxy_url = proxy_url;
56        self
57    }
58
59    fn http_client(&self) -> reqwest::Client {
60        crate::config::build_channel_proxy_client(
61            "channel.discord_history",
62            self.proxy_url.as_deref(),
63        )
64    }
65
66    fn is_user_allowed(&self, user_id: &str) -> bool {
67        if self.allowed_users.is_empty() {
68            return true; // default open for logging channel
69        }
70        self.allowed_users.iter().any(|u| u == "*" || u == user_id)
71    }
72
73    fn is_channel_watched(&self, channel_id: &str) -> bool {
74        self.channel_ids.is_empty() || self.channel_ids.iter().any(|c| c == channel_id)
75    }
76
77    fn bot_user_id_from_token(token: &str) -> Option<String> {
78        let part = token.split('.').next()?;
79        base64_decode(part)
80    }
81
82    async fn resolve_channel_name(&self, channel_id: &str) -> String {
83        // 1. Check persistent database (via discord_memory)
84        let cache_key = format!("cache:channel_name:{}", channel_id);
85
86        if let Ok(Some(cached_mem)) = self.discord_memory.get(&cache_key).await {
87            // Check if it's still fresh (e.g., less than 24 hours old)
88            // Note: cached_mem.timestamp is an RFC3339 string
89            let is_fresh =
90                if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(&cached_mem.timestamp) {
91                    chrono::Utc::now().signed_duration_since(ts.with_timezone(&chrono::Utc))
92                        < chrono::Duration::hours(24)
93                } else {
94                    false
95                };
96
97            if is_fresh {
98                return cached_mem.content.clone();
99            }
100        }
101
102        // 2. Fetch from API (either not in DB or stale)
103        let url = format!("https://discord.com/api/v10/channels/{channel_id}");
104        let resp = self
105            .http_client()
106            .get(&url)
107            .header("Authorization", format!("Bot {}", self.bot_token))
108            .send()
109            .await;
110
111        let name = if let Ok(r) = resp {
112            if let Ok(json) = r.json::<serde_json::Value>().await {
113                json.get("name")
114                    .and_then(|n| n.as_str())
115                    .map(|s| s.to_string())
116                    .or_else(|| {
117                        // For DMs, there might not be a 'name', use the recipient's username if available
118                        json.get("recipients")
119                            .and_then(|r| r.as_array())
120                            .and_then(|a| a.first())
121                            .and_then(|u| u.get("username"))
122                            .and_then(|un| un.as_str())
123                            .map(|s| format!("dm-{}", s))
124                    })
125            } else {
126                None
127            }
128        } else {
129            None
130        };
131
132        let resolved = name.unwrap_or_else(|| channel_id.to_string());
133
134        // 3. Store in persistent database
135        let _ = self
136            .discord_memory
137            .store(
138                &cache_key,
139                &resolved,
140                crate::memory::MemoryCategory::Custom("channel_cache".to_string()),
141                Some(channel_id),
142            )
143            .await;
144
145        resolved
146    }
147}
148
/// Standard (RFC 4648 §4) base64 alphabet; each byte's index is its 6-bit value.
const BASE64_ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Decode a standard-alphabet base64 string into UTF-8 text.
///
/// Padding is optional: unpadded input, such as the first segment of a Discord bot token,
/// is padded internally. Returns `None` in any of these cases:
/// * the input contains a byte outside the alphabet;
/// * its length cannot be valid base64 (`len % 4 == 1`);
/// * `=` appears anywhere other than the last one or two positions of the final quantum;
/// * the decoded bytes are not valid UTF-8.
fn base64_decode(input: &str) -> Option<String> {
    let padded = match input.len() % 4 {
        0 => input.to_string(),
        2 => format!("{input}=="),
        3 => format!("{input}="),
        // A lone trailing character carries only 6 bits and cannot form a byte.
        _ => return None,
    };

    // 6-bit value of one alphabet byte; `None` for anything else, including '='.
    let sextet = |b: u8| -> Option<u32> {
        BASE64_ALPHABET
            .iter()
            .position(|&a| a == b)
            .and_then(|i| u32::try_from(i).ok())
    };

    let quanta = padded.as_bytes().chunks_exact(4);
    let quantum_count = quanta.len();
    let mut bytes = Vec::with_capacity(quantum_count * 3);
    for (index, quantum) in quanta.enumerate() {
        let pad = quantum.iter().rev().take_while(|&&b| b == b'=').count();
        // At most two '=' are legal, and only at the very end of the input.
        if pad > 2 || (pad > 0 && index + 1 != quantum_count) {
            return None;
        }
        // Pack four sextets into a 24-bit word; padded positions contribute zero bits.
        let mut word: u32 = 0;
        for (pos, &b) in quantum.iter().enumerate() {
            let value = if pos < 4 - pad { sextet(b)? } else { 0 };
            word |= value << (18 - 6 * pos);
        }
        let [_, b0, b1, b2] = word.to_be_bytes();
        bytes.extend_from_slice(&[b0, b1, b2][..3 - pad]);
    }
    String::from_utf8(bytes).ok()
}
182
/// Report whether `content` mentions the bot as `<@ID>` or as `<@!ID>`.
///
/// An empty `bot_user_id` never matches. It is empty, for example, when the ID could not
/// be derived from the token.
fn contains_bot_mention(content: &str, bot_user_id: &str) -> bool {
    !bot_user_id.is_empty()
        && [format!("<@{bot_user_id}>"), format!("<@!{bot_user_id}>")]
            .iter()
            .any(|tag| content.contains(tag.as_str()))
}
190
/// Remove every `<@ID>` and `<@!ID>` mention of the bot from `content`.
///
/// Each mention becomes a single space, so neighbouring words stay separated. The result
/// is then trimmed at both ends.
fn strip_bot_mention(content: &str, bot_user_id: &str) -> String {
    let plain_tag = format!("<@{bot_user_id}>");
    let bang_tag = format!("<@!{bot_user_id}>");
    content
        .replace(&plain_tag, " ")
        .replace(&bang_tag, " ")
        .trim()
        .to_owned()
}
198
199#[async_trait]
200impl Channel for DiscordHistoryChannel {
201    fn name(&self) -> &str {
202        "discord_history"
203    }
204
205    /// Send a reply back to Discord (used when agent responds to @mention).
206    async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
207        let content = super::strip_tool_call_tags(&message.content);
208        let url = format!(
209            "https://discord.com/api/v10/channels/{}/messages",
210            message.recipient
211        );
212        self.http_client()
213            .post(&url)
214            .header("Authorization", format!("Bot {}", self.bot_token))
215            .json(&json!({"content": content}))
216            .send()
217            .await?;
218        Ok(())
219    }
220
221    #[allow(clippy::too_many_lines)]
222    async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
223        let bot_user_id = Self::bot_user_id_from_token(&self.bot_token).unwrap_or_default();
224
225        // Get Gateway URL
226        let gw_resp: serde_json::Value = self
227            .http_client()
228            .get("https://discord.com/api/v10/gateway/bot")
229            .header("Authorization", format!("Bot {}", self.bot_token))
230            .send()
231            .await?
232            .json()
233            .await?;
234
235        let gw_url = gw_resp
236            .get("url")
237            .and_then(|u| u.as_str())
238            .unwrap_or("wss://gateway.discord.gg");
239
240        let ws_url = format!("{gw_url}/?v=10&encoding=json");
241        tracing::info!("DiscordHistory: connecting to gateway...");
242
243        let (ws_stream, _) = crate::config::ws_connect_with_proxy(
244            &ws_url,
245            "channel.discord",
246            self.proxy_url.as_deref(),
247        )
248        .await?;
249        let (mut write, mut read) = ws_stream.split();
250
251        // Read Hello (opcode 10)
252        let hello = read.next().await.ok_or(anyhow::anyhow!("No hello"))??;
253        let hello_data: serde_json::Value = serde_json::from_str(&hello.to_string())?;
254        let heartbeat_interval = hello_data
255            .get("d")
256            .and_then(|d| d.get("heartbeat_interval"))
257            .and_then(serde_json::Value::as_u64)
258            .unwrap_or(41250);
259
260        // Identify with intents for guild + DM messages + message content
261        let identify = json!({
262            "op": 2,
263            "d": {
264                "token": self.bot_token,
265                "intents": 37377,
266                "properties": {
267                    "os": "linux",
268                    "browser": "construct",
269                    "device": "construct"
270                }
271            }
272        });
273        write
274            .send(Message::Text(identify.to_string().into()))
275            .await?;
276
277        tracing::info!("DiscordHistory: connected and identified");
278
279        let mut sequence: i64 = -1;
280
281        let (hb_tx, mut hb_rx) = tokio::sync::mpsc::channel::<()>(1);
282        tokio::spawn(async move {
283            let mut interval =
284                tokio::time::interval(std::time::Duration::from_millis(heartbeat_interval));
285            loop {
286                interval.tick().await;
287                if hb_tx.send(()).await.is_err() {
288                    break;
289                }
290            }
291        });
292
293        let guild_filter = self.guild_id.clone();
294        let discord_memory = Arc::clone(&self.discord_memory);
295        let store_dms = self.store_dms;
296        let respond_to_dms = self.respond_to_dms;
297
298        loop {
299            tokio::select! {
300                _ = hb_rx.recv() => {
301                    let d = if sequence >= 0 { json!(sequence) } else { json!(null) };
302                    let hb = json!({"op": 1, "d": d});
303                    if write.send(Message::Text(hb.to_string().into())).await.is_err() {
304                        break;
305                    }
306                }
307                msg = read.next() => {
308                    let msg = match msg {
309                        Some(Ok(Message::Text(t))) => t,
310                        Some(Ok(Message::Ping(payload))) => {
311                            if write.send(Message::Pong(payload)).await.is_err() {
312                                break;
313                            }
314                            continue;
315                        }
316                        Some(Ok(Message::Close(_))) | None => break,
317                        Some(Err(e)) => {
318                            tracing::warn!("DiscordHistory: websocket error: {e}");
319                            break;
320                        }
321                        _ => continue,
322                    };
323
324                    let event: serde_json::Value = match serde_json::from_str(msg.as_ref()) {
325                        Ok(e) => e,
326                        Err(_) => continue,
327                    };
328
329                    if let Some(s) = event.get("s").and_then(serde_json::Value::as_i64) {
330                        sequence = s;
331                    }
332
333                    let op = event.get("op").and_then(serde_json::Value::as_u64).unwrap_or(0);
334                    match op {
335                        1 => {
336                            let d = if sequence >= 0 { json!(sequence) } else { json!(null) };
337                            let hb = json!({"op": 1, "d": d});
338                            if write.send(Message::Text(hb.to_string().into())).await.is_err() {
339                                break;
340                            }
341                            continue;
342                        }
343                        7 => { tracing::warn!("DiscordHistory: Reconnect (op 7)"); break; }
344                        9 => { tracing::warn!("DiscordHistory: Invalid Session (op 9)"); break; }
345                        _ => {}
346                    }
347
348                    let event_type = event.get("t").and_then(|t| t.as_str()).unwrap_or("");
349                    if event_type != "MESSAGE_CREATE" {
350                        continue;
351                    }
352
353                    let Some(d) = event.get("d") else { continue };
354
355                    // Skip messages from the bot itself
356                    let author_id = d
357                        .get("author")
358                        .and_then(|a| a.get("id"))
359                        .and_then(|i| i.as_str())
360                        .unwrap_or("");
361                    let username = d
362                        .get("author")
363                        .and_then(|a| a.get("username"))
364                        .and_then(|i| i.as_str())
365                        .unwrap_or(author_id);
366
367                    if author_id == bot_user_id {
368                        continue;
369                    }
370
371                    // Skip other bots
372                    if d.get("author")
373                        .and_then(|a| a.get("bot"))
374                        .and_then(serde_json::Value::as_bool)
375                        .unwrap_or(false)
376                    {
377                        continue;
378                    }
379
380                    let channel_id = d
381                        .get("channel_id")
382                        .and_then(|c| c.as_str())
383                        .unwrap_or("")
384                        .to_string();
385
386                    // DM detection: DMs have no guild_id
387                    let is_dm_event = d.get("guild_id").and_then(serde_json::Value::as_str).is_none();
388
389                    // Resolve channel name (with cache)
390                    let channel_display = if is_dm_event {
391                        "dm".to_string()
392                    } else {
393                        self.resolve_channel_name(&channel_id).await
394                    };
395
396                    if is_dm_event && !store_dms && !respond_to_dms {
397                        continue;
398                    }
399
400                    // Guild filter
401                    if let Some(ref gid) = guild_filter {
402                        let msg_guild = d.get("guild_id").and_then(serde_json::Value::as_str);
403                        if let Some(g) = msg_guild {
404                            if g != gid {
405                                continue;
406                            }
407                        }
408                    }
409
410                    // Channel filter
411                    if !self.is_channel_watched(&channel_id) {
412                        continue;
413                    }
414
415                    if !self.is_user_allowed(author_id) {
416                        continue;
417                    }
418
419                    let content = d.get("content").and_then(|c| c.as_str()).unwrap_or("");
420                    let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or("");
421                    let is_mention = contains_bot_mention(content, &bot_user_id);
422
423                    // Collect attachment URLs
424                    let attachments: Vec<String> = d
425                        .get("attachments")
426                        .and_then(|a| a.as_array())
427                        .map(|arr| {
428                            arr.iter()
429                                .filter_map(|a| a.get("url").and_then(|u| u.as_str()))
430                                .map(|u| u.to_string())
431                                .collect()
432                        })
433                        .unwrap_or_default();
434
435                    // Store messages to discord.db (skip DMs if store_dms=false)
436                    if (!is_dm_event || store_dms) && (!content.is_empty() || !attachments.is_empty()) {
437                        let ts = chrono::Utc::now().to_rfc3339();
438                        let mut mem_content = format!(
439                            "@{username} in #{channel_display} at {ts}: {content}"
440                        );
441                        if !attachments.is_empty() {
442                            mem_content.push_str(" [attachments: ");
443                            mem_content.push_str(&attachments.join(", "));
444                            mem_content.push(']');
445                        }
446                        let mem_key = format!(
447                            "discord_{}",
448                            if message_id.is_empty() {
449                                Uuid::new_v4().to_string()
450                            } else {
451                                message_id.to_string()
452                            }
453                        );
454                        let channel_id_for_session = if channel_id.is_empty() {
455                            None
456                        } else {
457                            Some(channel_id.as_str())
458                        };
459                        if let Err(err) = discord_memory
460                            .store(
461                                &mem_key,
462                                &mem_content,
463                                MemoryCategory::Custom("discord".to_string()),
464                                channel_id_for_session,
465                            )
466                            .await
467                        {
468                            tracing::warn!("discord_history: failed to store message: {err}");
469                        } else {
470                            tracing::debug!(
471                                "discord_history: stored message from @{username} in #{channel_display}"
472                            );
473                        }
474                    }
475
476                    // Forward @mention to agent (skip DMs if respond_to_dms=false)
477                    if is_mention && (!is_dm_event || respond_to_dms) {
478                        let clean_content = strip_bot_mention(content, &bot_user_id);
479                        if clean_content.is_empty() {
480                            continue;
481                        }
482                        let channel_msg = ChannelMessage {
483                            id: if message_id.is_empty() {
484                                Uuid::new_v4().to_string()
485                            } else {
486                                format!("discord_{message_id}")
487                            },
488                            sender: author_id.to_string(),
489                            reply_target: if channel_id.is_empty() {
490                                author_id.to_string()
491                            } else {
492                                channel_id.clone()
493                            },
494                            content: clean_content,
495                            channel: "discord_history".to_string(),
496                            timestamp: std::time::SystemTime::now()
497                                .duration_since(std::time::UNIX_EPOCH)
498                                .unwrap_or_default()
499                                .as_secs(),
500                            thread_ts: None,
501                            interruption_scope_id: None,
502                            attachments: Vec::new(),
503                        };
504                        if tx.send(channel_msg).await.is_err() {
505                            break;
506                        }
507                    }
508                }
509            }
510        }
511
512        Ok(())
513    }
514
515    async fn health_check(&self) -> bool {
516        self.http_client()
517            .get("https://discord.com/api/v10/users/@me")
518            .header("Authorization", format!("Bot {}", self.bot_token))
519            .send()
520            .await
521            .map(|r| r.status().is_success())
522            .unwrap_or(false)
523    }
524
525    async fn start_typing(&self, recipient: &str) -> anyhow::Result<()> {
526        let mut guard = self.typing_handles.lock();
527        if let Some(h) = guard.remove(recipient) {
528            h.abort();
529        }
530        let client = self.http_client();
531        let token = self.bot_token.clone();
532        let channel_id = recipient.to_string();
533        let handle = tokio::spawn(async move {
534            let url = format!("https://discord.com/api/v10/channels/{channel_id}/typing");
535            loop {
536                let _ = client
537                    .post(&url)
538                    .header("Authorization", format!("Bot {token}"))
539                    .send()
540                    .await;
541                tokio::time::sleep(std::time::Duration::from_secs(8)).await;
542            }
543        });
544        guard.insert(recipient.to_string(), handle);
545        Ok(())
546    }
547
548    async fn stop_typing(&self, recipient: &str) -> anyhow::Result<()> {
549        let mut guard = self.typing_handles.lock();
550        if let Some(handle) = guard.remove(recipient) {
551            handle.abort();
552        }
553        Ok(())
554    }
555}