Skip to main content

discord_echo/
lib.rs

1pub mod client;
2pub mod config;
3pub mod dedup;
4pub mod gateway;
5pub mod tool;
6pub mod types;
7
8use std::any::Any;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12
13use pulse_system_types::plugin::{Plugin, PluginContext, PluginResult, PluginRole};
14use pulse_system_types::{HealthStatus, PluginMeta, SetupPrompt};
15use tokio::sync::{mpsc, Mutex, Notify};
16
17use crate::client::DiscordClient;
18use crate::config::Config;
19use crate::dedup::DedupState;
20use crate::types::IncomingMessage;
21
22/// Discord text integration for echo-system entities.
23///
24/// Manages the gateway WebSocket connection (reading messages)
25/// and provides a REST client (posting messages).
26pub struct DiscordEcho {
27    config: Arc<Config>,
28    client: Arc<DiscordClient>,
29    shutdown: Arc<Notify>,
30    dedup: Arc<Mutex<DedupState>>,
31    gateway_handle: Option<tokio::task::JoinHandle<()>>,
32    forwarder_handle: Option<tokio::task::JoinHandle<()>>,
33}
34
35impl DiscordEcho {
36    /// Create a new DiscordEcho instance from config.
37    pub fn new(config: Config) -> Self {
38        let client = DiscordClient::new(config.bot_token.clone(), config.channels.clone());
39        let config = Arc::new(config);
40        Self {
41            config,
42            client,
43            shutdown: Arc::new(Notify::new()),
44            dedup: Arc::new(Mutex::new(DedupState::new())),
45            gateway_handle: None,
46            forwarder_handle: None,
47        }
48    }
49
50    /// Get a reference to the Discord client for tool use.
51    pub fn client(&self) -> Arc<DiscordClient> {
52        Arc::clone(&self.client)
53    }
54
55    /// Health check.
56    fn health_check(&self) -> HealthStatus {
57        if self.gateway_handle.is_some() {
58            HealthStatus::Healthy
59        } else {
60            HealthStatus::Down("Not started".to_string())
61        }
62    }
63
64    /// Setup prompts for first-time configuration.
65    fn get_setup_prompts() -> Vec<SetupPrompt> {
66        vec![
67            SetupPrompt {
68                key: "bot_token".to_string(),
69                question: "Discord bot token:".to_string(),
70                default: None,
71                required: true,
72                secret: true,
73            },
74            SetupPrompt {
75                key: "guild_id".to_string(),
76                question: "Discord server (guild) ID:".to_string(),
77                default: None,
78                required: true,
79                secret: false,
80            },
81        ]
82    }
83}
84
85/// Factory function — creates a fully initialized discord-echo plugin.
86pub async fn create(
87    config: &serde_json::Value,
88    _ctx: &PluginContext,
89) -> Result<Box<dyn Plugin>, Box<dyn std::error::Error + Send + Sync>> {
90    let cfg: Config = serde_json::from_value(config.clone())?;
91    Ok(Box::new(DiscordEcho::new(cfg)))
92}
93
94impl Plugin for DiscordEcho {
95    fn meta(&self) -> PluginMeta {
96        PluginMeta {
97            name: "discord-echo".into(),
98            version: env!("CARGO_PKG_VERSION").into(),
99            description: "Discord text integration".into(),
100        }
101    }
102
103    fn role(&self) -> PluginRole {
104        PluginRole::Interface
105    }
106
107    fn start(&mut self) -> PluginResult<'_> {
108        Box::pin(async move {
109            if self.gateway_handle.is_some() {
110                return Err("Already running".into());
111            }
112
113            let (message_tx, message_rx) = mpsc::channel::<IncomingMessage>(64);
114
115            let gw_config = Arc::clone(&self.config);
116            let gw_shutdown = Arc::clone(&self.shutdown);
117            self.gateway_handle = Some(tokio::spawn(async move {
118                gateway::run_gateway(gw_config, message_tx, gw_shutdown).await;
119            }));
120
121            let fwd_client = Arc::clone(&self.client);
122            let fwd_config = Arc::clone(&self.config);
123            let fwd_shutdown = Arc::clone(&self.shutdown);
124            let fwd_dedup = Arc::clone(&self.dedup);
125            self.forwarder_handle = Some(tokio::spawn(async move {
126                message_forwarder(message_rx, fwd_client, fwd_config, fwd_shutdown, fwd_dedup)
127                    .await;
128            }));
129
130            tracing::info!("Discord text integration started");
131            Ok(())
132        })
133    }
134
135    fn stop(&mut self) -> PluginResult<'_> {
136        Box::pin(async move {
137            self.shutdown.notify_waiters();
138
139            if let Some(h) = self.gateway_handle.take() {
140                let _ = tokio::time::timeout(std::time::Duration::from_secs(5), h).await;
141            }
142            if let Some(h) = self.forwarder_handle.take() {
143                let _ = tokio::time::timeout(std::time::Duration::from_secs(5), h).await;
144            }
145
146            self.shutdown = Arc::new(Notify::new());
147
148            tracing::info!("Discord text integration stopped");
149            Ok(())
150        })
151    }
152
153    fn health(&self) -> Pin<Box<dyn Future<Output = HealthStatus> + Send + '_>> {
154        Box::pin(async move { self.health_check() })
155    }
156
157    fn setup_prompts(&self) -> Vec<SetupPrompt> {
158        Self::get_setup_prompts()
159    }
160
161    fn as_any(&self) -> &dyn Any {
162        self
163    }
164}
165
/// Prefixes that mark an entity response as intentionally silent.
///
/// A response beginning with one of these (after trimming surrounding
/// whitespace) is not posted to Discord. This lets the entity decide, per
/// message, whether to reply or stay quiet.
const SILENT_MARKERS: &[&str] = &["[SILENT]", "[NO_RESPONSE]", "No response requested"];

/// Returns `true` when `response`, once trimmed, starts with any entry of
/// [`SILENT_MARKERS`]. A marker appearing later in the text does not count.
fn is_silent(response: &str) -> bool {
    let body = response.trim();
    for marker in SILENT_MARKERS {
        if body.starts_with(marker) {
            return true;
        }
    }
    false
}
178
179/// Receives messages from the gateway, forwards to the entity's chat endpoint,
180/// and posts responses back to Discord.
181///
182/// The `dedup` state is shared with the `DiscordEcho` struct and persists
183/// across gateway reconnects, preventing duplicate processing.
184async fn message_forwarder(
185    mut rx: mpsc::Receiver<IncomingMessage>,
186    client: Arc<DiscordClient>,
187    config: Arc<Config>,
188    shutdown: Arc<Notify>,
189    dedup: Arc<Mutex<DedupState>>,
190) {
191    let http = reqwest::Client::new();
192
193    loop {
194        tokio::select! {
195            msg = rx.recv() => {
196                let msg = match msg {
197                    Some(m) => m,
198                    None => return, // gateway dropped
199                };
200
201                // Defense-in-depth: reject messages older than the staleness threshold.
202                // This prevents replay of old messages after a gateway reconnect.
203                {
204                    let state = dedup.lock().await;
205                    if state.is_stale(&msg.timestamp) {
206                        tracing::debug!(
207                            "Ignoring stale message {} (age exceeds threshold)",
208                            msg.message_id
209                        );
210                        continue;
211                    }
212                }
213
214                // Dedup: skip if we've already processed this message ID.
215                // The dedup state survives gateway reconnects.
216                {
217                    let mut state = dedup.lock().await;
218                    if state.check_and_record_seen(&msg.message_id) {
219                        tracing::debug!("Skipping duplicate message {}", msg.message_id);
220                        continue;
221                    }
222                }
223
224                let channel_label = msg.channel_name.as_deref().unwrap_or("discord");
225                tracing::info!(
226                    "Message from {} in #{}: {}",
227                    msg.author_name,
228                    channel_label,
229                    if msg.content.len() > 80 { &msg.content[..80] } else { &msg.content }
230                );
231
232                // Forward to chat endpoint
233                let mut req = http
234                    .post(&config.chat_endpoint)
235                    .json(&serde_json::json!({
236                        "message": msg.content,
237                        "channel": config.chat_channel_name,
238                        "sender": msg.author_name,
239                    }));
240
241                if let Some(ref secret) = config.chat_secret {
242                    req = req.header("X-Echo-Secret", secret);
243                }
244
245                match req.send().await {
246                    Ok(resp) if resp.status().is_success() => {
247                        if let Ok(data) = resp.json::<serde_json::Value>().await {
248                            let response_text = data["response"]
249                                .as_str()
250                                .or_else(|| data["text"].as_str())
251                                .unwrap_or("");
252
253                            if !response_text.is_empty() && !is_silent(response_text) {
254                                // Check responded_ids before posting to Discord.
255                                // This is a second line of defense: even if the message
256                                // passed the seen_ids check (e.g., buffer eviction),
257                                // we won't post a duplicate response.
258                                {
259                                    let state = dedup.lock().await;
260                                    if state.has_responded(&msg.message_id) {
261                                        tracing::debug!(
262                                            "Already responded to message {}, skipping",
263                                            msg.message_id
264                                        );
265                                        continue;
266                                    }
267                                }
268
269                                if let Err(e) = client.send_message_by_id(&msg.channel_id, response_text).await {
270                                    tracing::error!("Failed to reply in Discord: {e}");
271                                } else {
272                                    // Record successful response
273                                    let mut state = dedup.lock().await;
274                                    state.record_responded(&msg.message_id);
275                                }
276                            } else if is_silent(response_text) {
277                                tracing::debug!(
278                                    "Silent response for message from {} in #{}",
279                                    msg.author_name,
280                                    channel_label,
281                                );
282                            }
283                        }
284                    }
285                    Ok(resp) => {
286                        tracing::warn!(
287                            "Chat endpoint returned {}",
288                            resp.status()
289                        );
290                    }
291                    Err(e) => {
292                        tracing::error!("Failed to forward to chat endpoint: {e}");
293                    }
294                }
295            }
296            _ = shutdown.notified() => return,
297        }
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds the minimal test configuration shared by every test, using the
    /// given channel-name -> channel-ID map.
    fn test_config(channels: HashMap<String, String>) -> Config {
        Config {
            bot_token: "test".to_string(),
            guild_id: "123".to_string(),
            listen_channels: vec![],
            allowed_user_ids: vec![],
            chat_endpoint: "http://localhost:3100/chat".to_string(),
            chat_secret: None,
            chat_channel_name: "discord".to_string(),
            channels,
        }
    }

    #[tokio::test]
    async fn test_health_down_before_start() {
        let echo = DiscordEcho::new(test_config(HashMap::new()));
        let health = Plugin::health(&echo).await;
        assert!(matches!(health, HealthStatus::Down(_)));
    }

    #[test]
    fn test_setup_prompts_not_empty() {
        let echo = DiscordEcho::new(test_config(HashMap::new()));
        let prompts = Plugin::setup_prompts(&echo);
        assert!(!prompts.is_empty());
        for key in ["bot_token", "guild_id"] {
            assert!(
                prompts.iter().any(|p| p.key == key),
                "missing setup prompt {key:?}"
            );
        }
    }

    #[test]
    fn test_is_silent() {
        let silent = [
            "[SILENT]",
            "[SILENT] I have nothing to add",
            "[NO_RESPONSE]",
            "No response requested",
            "No response requested.",
            "  [SILENT]  ", // trimmed
        ];
        let not_silent = [
            "Hello, how are you?",
            "",
            "I think [SILENT] is interesting", // not at start
        ];

        for text in silent {
            assert!(is_silent(text), "expected silent: {text:?}");
        }
        for text in not_silent {
            assert!(!is_silent(text), "expected non-silent: {text:?}");
        }
    }

    #[test]
    fn test_client_reference() {
        let channels = HashMap::from([("test".to_string(), "456".to_string())]);
        let echo = DiscordEcho::new(test_config(channels));
        let client = echo.client();
        assert_eq!(client.resolve_channel("test"), Some("456"));
    }

    /// Verify that the dedup state is shared (Arc) and survives cloning
    /// the reference — simulating what happens across gateway reconnects.
    #[tokio::test]
    async fn test_dedup_state_persists_across_clones() {
        let echo = DiscordEcho::new(test_config(HashMap::new()));

        // Clone the Arc, as the forwarder task does.
        let dedup_ref = Arc::clone(&echo.dedup);

        // First sighting through the clone: not yet seen.
        assert!(!dedup_ref.lock().await.check_and_record_seen("msg-abc"));

        // The original Arc observes the same state: now a duplicate.
        assert!(echo.dedup.lock().await.check_and_record_seen("msg-abc"));
    }
}