Skip to main content

construct/channels/
reddit.rs

1use super::traits::{Channel, ChannelMessage, SendMessage};
2use anyhow::{Result, bail};
3use async_trait::async_trait;
4use parking_lot::Mutex;
5use serde::Deserialize;
6use std::time::{Duration, Instant};
7
8/// Reddit channel — polls for mentions, DMs, and comment replies via Reddit OAuth2 API.
9pub struct RedditChannel {
10    client_id: String,
11    client_secret: String,
12    refresh_token: String,
13    username: String,
14    subreddit: Option<String>,
15    auth: Mutex<RedditAuth>,
16}
17
/// Cached OAuth2 access token together with the instant it stops being reused.
struct RedditAuth {
    /// Bearer token for API calls; empty until the first successful refresh.
    access_token: String,
    /// Local reuse deadline, set 60 seconds ahead of the server-reported expiry.
    expires_at: Instant,
}
22
23#[derive(Deserialize)]
24struct RedditTokenResponse {
25    access_token: String,
26    expires_in: u64,
27}
28
29#[derive(Deserialize)]
30struct RedditListing {
31    data: RedditListingData,
32}
33
34#[derive(Deserialize)]
35struct RedditListingData {
36    children: Vec<RedditChild>,
37}
38
39#[derive(Deserialize)]
40struct RedditChild {
41    data: RedditItemData,
42}
43
44#[allow(dead_code)]
45#[derive(Deserialize)]
46struct RedditItemData {
47    name: Option<String>,
48    author: Option<String>,
49    body: Option<String>,
50    subject: Option<String>,
51    parent_id: Option<String>,
52    link_id: Option<String>,
53    subreddit: Option<String>,
54    created_utc: Option<f64>,
55    new: Option<bool>,
56    #[serde(rename = "type")]
57    message_type: Option<String>,
58    context: Option<String>,
59}
60
/// Base URL for bearer-token authenticated API calls.
const REDDIT_API_BASE: &str = "https://oauth.reddit.com";
/// Endpoint that exchanges the refresh token for a short-lived access token.
const REDDIT_TOKEN_URL: &str = "https://www.reddit.com/api/v1/access_token";
/// User-Agent header sent with every request.
const USER_AGENT: &str = "construct:channel:v0.1.0 (by /u/construct-bot)";
/// Delay between two inbox polls.
///
/// Reddit enforces 60 requests per minute. Each poll issues one inbox fetch
/// plus at most one mark-as-read call, so a 5 s interval uses at most
/// 24 requests per minute. That leaves headroom for replies and token refreshes.
const POLL_INTERVAL: Duration = Duration::from_millis(5_000);
66
67impl RedditChannel {
68    pub fn new(
69        client_id: String,
70        client_secret: String,
71        refresh_token: String,
72        username: String,
73        subreddit: Option<String>,
74    ) -> Self {
75        Self {
76            client_id,
77            client_secret,
78            refresh_token,
79            username,
80            subreddit,
81            auth: Mutex::new(RedditAuth {
82                access_token: String::new(),
83                expires_at: Instant::now(),
84            }),
85        }
86    }
87
88    fn http_client(&self) -> reqwest::Client {
89        crate::config::build_runtime_proxy_client("channel.reddit")
90    }
91
92    /// Refresh the OAuth2 access token using the refresh token.
93    async fn refresh_access_token(&self) -> Result<()> {
94        let client = self.http_client();
95        let resp = client
96            .post(REDDIT_TOKEN_URL)
97            .basic_auth(&self.client_id, Some(&self.client_secret))
98            .header("User-Agent", USER_AGENT)
99            .form(&[
100                ("grant_type", "refresh_token"),
101                ("refresh_token", &self.refresh_token),
102            ])
103            .send()
104            .await?;
105
106        let status = resp.status();
107        if !status.is_success() {
108            let body = resp
109                .text()
110                .await
111                .unwrap_or_else(|e| format!("<failed to read response: {e}>"));
112            bail!("Reddit token refresh failed ({status}): {body}");
113        }
114
115        let token_resp: RedditTokenResponse = resp.json().await?;
116        let mut auth = self.auth.lock();
117        auth.access_token = token_resp.access_token;
118        auth.expires_at =
119            Instant::now() + Duration::from_secs(token_resp.expires_in.saturating_sub(60));
120        Ok(())
121    }
122
123    /// Get a valid access token, refreshing if expired.
124    async fn get_access_token(&self) -> Result<String> {
125        {
126            let auth = self.auth.lock();
127            if !auth.access_token.is_empty() && Instant::now() < auth.expires_at {
128                return Ok(auth.access_token.clone());
129            }
130        }
131        self.refresh_access_token().await?;
132        let auth = self.auth.lock();
133        Ok(auth.access_token.clone())
134    }
135
136    /// Fetch unread inbox items (mentions, DMs, comment replies).
137    async fn fetch_inbox(&self) -> Result<Vec<RedditChild>> {
138        let token = self.get_access_token().await?;
139        let client = self.http_client();
140
141        let resp = client
142            .get(format!("{REDDIT_API_BASE}/message/unread"))
143            .bearer_auth(&token)
144            .header("User-Agent", USER_AGENT)
145            .query(&[("limit", "25")])
146            .send()
147            .await?;
148
149        let status = resp.status();
150        if !status.is_success() {
151            let body = resp
152                .text()
153                .await
154                .unwrap_or_else(|e| format!("<failed to read response: {e}>"));
155            tracing::warn!("Reddit inbox fetch failed ({status}): {body}");
156            return Ok(Vec::new());
157        }
158
159        let listing: RedditListing = resp.json().await?;
160        Ok(listing.data.children)
161    }
162
163    /// Mark inbox items as read.
164    async fn mark_read(&self, fullnames: &[String]) -> Result<()> {
165        if fullnames.is_empty() {
166            return Ok(());
167        }
168        let token = self.get_access_token().await?;
169        let client = self.http_client();
170
171        let ids = fullnames.join(",");
172        let resp = client
173            .post(format!("{REDDIT_API_BASE}/api/read_message"))
174            .bearer_auth(&token)
175            .header("User-Agent", USER_AGENT)
176            .form(&[("id", ids.as_str())])
177            .send()
178            .await?;
179
180        if !resp.status().is_success() {
181            tracing::warn!("Reddit mark_read failed: {}", resp.status());
182        }
183        Ok(())
184    }
185
186    /// Parse a Reddit inbox item into a ChannelMessage.
187    fn parse_item(&self, item: &RedditItemData) -> Option<ChannelMessage> {
188        let author = item.author.as_deref().unwrap_or("");
189        let body = item.body.as_deref().unwrap_or("");
190        let name = item.name.as_deref().unwrap_or("");
191
192        // Skip messages from ourselves
193        if author.eq_ignore_ascii_case(&self.username) || author.is_empty() || body.is_empty() {
194            return None;
195        }
196
197        // If a subreddit filter is set, skip items from other subreddits
198        if let Some(ref sub) = self.subreddit {
199            if let Some(ref item_sub) = item.subreddit {
200                if !item_sub.eq_ignore_ascii_case(sub) {
201                    return None;
202                }
203            }
204        }
205
206        // Determine reply target: for comment replies use the parent thing name,
207        // for DMs reply to the author.
208        let reply_target =
209            if item.message_type.as_deref() == Some("comment_reply") || item.parent_id.is_some() {
210                // For comment replies, the recipient is the parent fullname
211                item.parent_id.clone().unwrap_or_else(|| name.to_string())
212            } else {
213                // For DMs, reply to the author
214                author.to_string()
215            };
216
217        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
218        let timestamp = item.created_utc.unwrap_or(0.0) as u64;
219
220        Some(ChannelMessage {
221            id: format!("reddit_{name}"),
222            sender: author.to_string(),
223            reply_target,
224            content: body.to_string(),
225            channel: "reddit".to_string(),
226            timestamp,
227            thread_ts: item.parent_id.clone(),
228            interruption_scope_id: None,
229            attachments: vec![],
230        })
231    }
232}
233
234#[async_trait]
235impl Channel for RedditChannel {
236    fn name(&self) -> &str {
237        "reddit"
238    }
239
240    async fn send(&self, message: &SendMessage) -> Result<()> {
241        let token = self.get_access_token().await?;
242        let client = self.http_client();
243
244        // If recipient looks like a Reddit fullname (t1_, t3_, t4_), it's a comment reply.
245        // Otherwise treat it as a DM to a username.
246        if message.recipient.starts_with("t1_")
247            || message.recipient.starts_with("t3_")
248            || message.recipient.starts_with("t4_")
249        {
250            // Comment reply
251            let resp = client
252                .post(format!("{REDDIT_API_BASE}/api/comment"))
253                .bearer_auth(&token)
254                .header("User-Agent", USER_AGENT)
255                .form(&[
256                    ("thing_id", message.recipient.as_str()),
257                    ("text", &message.content),
258                ])
259                .send()
260                .await?;
261
262            let status = resp.status();
263            if !status.is_success() {
264                let body = resp
265                    .text()
266                    .await
267                    .unwrap_or_else(|e| format!("<failed to read response: {e}>"));
268                bail!("Reddit comment reply failed ({status}): {body}");
269            }
270        } else {
271            // Direct message
272            let subject = message
273                .subject
274                .as_deref()
275                .unwrap_or("Message from Construct");
276            let resp = client
277                .post(format!("{REDDIT_API_BASE}/api/compose"))
278                .bearer_auth(&token)
279                .header("User-Agent", USER_AGENT)
280                .form(&[
281                    ("to", message.recipient.as_str()),
282                    ("subject", subject),
283                    ("text", &message.content),
284                ])
285                .send()
286                .await?;
287
288            let status = resp.status();
289            if !status.is_success() {
290                let body = resp
291                    .text()
292                    .await
293                    .unwrap_or_else(|e| format!("<failed to read response: {e}>"));
294                bail!("Reddit DM failed ({status}): {body}");
295            }
296        }
297
298        Ok(())
299    }
300
301    async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> Result<()> {
302        // Initial auth
303        self.refresh_access_token().await?;
304
305        tracing::info!(
306            "Reddit channel listening as u/{} {}...",
307            self.username,
308            self.subreddit
309                .as_ref()
310                .map(|s| format!("in r/{s}"))
311                .unwrap_or_default()
312        );
313
314        loop {
315            tokio::time::sleep(POLL_INTERVAL).await;
316
317            let items = match self.fetch_inbox().await {
318                Ok(items) => items,
319                Err(e) => {
320                    tracing::warn!("Reddit poll error: {e}");
321                    continue;
322                }
323            };
324
325            let mut read_ids = Vec::new();
326            for child in &items {
327                if let Some(ref name) = child.data.name {
328                    read_ids.push(name.clone());
329                }
330                if let Some(msg) = self.parse_item(&child.data) {
331                    if tx.send(msg).await.is_err() {
332                        return Ok(());
333                    }
334                }
335            }
336
337            if let Err(e) = self.mark_read(&read_ids).await {
338                tracing::warn!("Reddit mark_read error: {e}");
339            }
340        }
341    }
342
343    async fn health_check(&self) -> bool {
344        self.get_access_token().await.is_ok()
345    }
346}
347
348#[cfg(test)]
349mod tests {
350    use super::*;
351
352    fn make_channel() -> RedditChannel {
353        RedditChannel::new(
354            "client_id".into(),
355            "client_secret".into(),
356            "refresh_token".into(),
357            "testbot".into(),
358            None,
359        )
360    }
361
362    fn make_channel_with_sub(sub: &str) -> RedditChannel {
363        RedditChannel::new(
364            "client_id".into(),
365            "client_secret".into(),
366            "refresh_token".into(),
367            "testbot".into(),
368            Some(sub.into()),
369        )
370    }
371
372    #[test]
373    fn parse_comment_reply() {
374        let ch = make_channel();
375        let item = RedditItemData {
376            name: Some("t1_abc123".into()),
377            author: Some("user1".into()),
378            body: Some("hello bot".into()),
379            subject: None,
380            parent_id: Some("t1_parent1".into()),
381            link_id: Some("t3_post1".into()),
382            subreddit: Some("rust".into()),
383            created_utc: Some(1_700_000_000.0),
384            new: Some(true),
385            message_type: Some("comment_reply".into()),
386            context: None,
387        };
388
389        let msg = ch.parse_item(&item).unwrap();
390        assert_eq!(msg.sender, "user1");
391        assert_eq!(msg.content, "hello bot");
392        assert_eq!(msg.reply_target, "t1_parent1");
393        assert_eq!(msg.channel, "reddit");
394        assert_eq!(msg.id, "reddit_t1_abc123");
395    }
396
397    #[test]
398    fn parse_dm() {
399        let ch = make_channel();
400        let item = RedditItemData {
401            name: Some("t4_dm456".into()),
402            author: Some("user2".into()),
403            body: Some("private message".into()),
404            subject: Some("Hello".into()),
405            parent_id: None,
406            link_id: None,
407            subreddit: None,
408            created_utc: Some(1_700_000_100.0),
409            new: Some(true),
410            message_type: None,
411            context: None,
412        };
413
414        let msg = ch.parse_item(&item).unwrap();
415        assert_eq!(msg.sender, "user2");
416        assert_eq!(msg.content, "private message");
417        assert_eq!(msg.reply_target, "user2"); // DM reply goes to author
418    }
419
420    #[test]
421    fn skip_self_messages() {
422        let ch = make_channel();
423        let item = RedditItemData {
424            name: Some("t1_self".into()),
425            author: Some("testbot".into()),
426            body: Some("my own message".into()),
427            subject: None,
428            parent_id: None,
429            link_id: None,
430            subreddit: None,
431            created_utc: Some(1_700_000_000.0),
432            new: Some(true),
433            message_type: None,
434            context: None,
435        };
436
437        assert!(ch.parse_item(&item).is_none());
438    }
439
440    #[test]
441    fn skip_empty_body() {
442        let ch = make_channel();
443        let item = RedditItemData {
444            name: Some("t1_empty".into()),
445            author: Some("user1".into()),
446            body: Some(String::new()),
447            subject: None,
448            parent_id: None,
449            link_id: None,
450            subreddit: None,
451            created_utc: Some(1_700_000_000.0),
452            new: Some(true),
453            message_type: None,
454            context: None,
455        };
456
457        assert!(ch.parse_item(&item).is_none());
458    }
459
460    #[test]
461    fn subreddit_filter() {
462        let ch = make_channel_with_sub("rust");
463        let item = RedditItemData {
464            name: Some("t1_other".into()),
465            author: Some("user1".into()),
466            body: Some("hello".into()),
467            subject: None,
468            parent_id: None,
469            link_id: None,
470            subreddit: Some("python".into()),
471            created_utc: Some(1_700_000_000.0),
472            new: Some(true),
473            message_type: None,
474            context: None,
475        };
476
477        assert!(ch.parse_item(&item).is_none());
478
479        let matching_item = RedditItemData {
480            name: Some("t1_match".into()),
481            author: Some("user1".into()),
482            body: Some("hello".into()),
483            subject: None,
484            parent_id: None,
485            link_id: None,
486            subreddit: Some("rust".into()),
487            created_utc: Some(1_700_000_000.0),
488            new: Some(true),
489            message_type: None,
490            context: None,
491        };
492
493        assert!(ch.parse_item(&matching_item).is_some());
494    }
495
496    #[test]
497    fn send_message_formatting() {
498        // Verify SendMessage can be constructed for both DM and comment reply
499        let dm = SendMessage::new("hello", "user1");
500        assert_eq!(dm.recipient, "user1");
501        assert_eq!(dm.content, "hello");
502
503        let reply = SendMessage::new("response", "t1_abc123");
504        assert!(reply.recipient.starts_with("t1_"));
505    }
506}