1use super::traits::{Channel, ChannelMessage, SendMessage};
2use anyhow::{Result, bail};
3use async_trait::async_trait;
4use parking_lot::Mutex;
5use serde::Deserialize;
6use std::time::{Duration, Instant};
7
8pub struct RedditChannel {
10 client_id: String,
11 client_secret: String,
12 refresh_token: String,
13 username: String,
14 subreddit: Option<String>,
15 auth: Mutex<RedditAuth>,
16}
17
/// Cached OAuth access token together with the instant after which it must
/// be refreshed rather than reused.
struct RedditAuth {
    /// Bearer token; an empty string means "nothing cached yet".
    access_token: String,
    /// Reuse deadline for `access_token`.
    expires_at: Instant,
}
22
23#[derive(Deserialize)]
24struct RedditTokenResponse {
25 access_token: String,
26 expires_in: u64,
27}
28
29#[derive(Deserialize)]
30struct RedditListing {
31 data: RedditListingData,
32}
33
34#[derive(Deserialize)]
35struct RedditListingData {
36 children: Vec<RedditChild>,
37}
38
39#[derive(Deserialize)]
40struct RedditChild {
41 data: RedditItemData,
42}
43
44#[allow(dead_code)]
45#[derive(Deserialize)]
46struct RedditItemData {
47 name: Option<String>,
48 author: Option<String>,
49 body: Option<String>,
50 subject: Option<String>,
51 parent_id: Option<String>,
52 link_id: Option<String>,
53 subreddit: Option<String>,
54 created_utc: Option<f64>,
55 new: Option<bool>,
56 #[serde(rename = "type")]
57 message_type: Option<String>,
58 context: Option<String>,
59}
60
/// Host for authenticated (bearer-token) API calls.
const REDDIT_API_BASE: &str = "https://oauth.reddit.com";
/// Endpoint that exchanges the refresh token for an access token.
const REDDIT_TOKEN_URL: &str = "https://www.reddit.com/api/v1/access_token";
/// User-Agent sent on every request, shaped
/// `<platform>:<app id>:<version> (by /u/<user>)`.
const USER_AGENT: &str = "construct:channel:v0.1.0 (by /u/construct-bot)";
/// Pause before each inbox poll in the channel's listen loop.
const POLL_INTERVAL: Duration = Duration::from_millis(5_000);
66
67impl RedditChannel {
68 pub fn new(
69 client_id: String,
70 client_secret: String,
71 refresh_token: String,
72 username: String,
73 subreddit: Option<String>,
74 ) -> Self {
75 Self {
76 client_id,
77 client_secret,
78 refresh_token,
79 username,
80 subreddit,
81 auth: Mutex::new(RedditAuth {
82 access_token: String::new(),
83 expires_at: Instant::now(),
84 }),
85 }
86 }
87
88 fn http_client(&self) -> reqwest::Client {
89 crate::config::build_runtime_proxy_client("channel.reddit")
90 }
91
92 async fn refresh_access_token(&self) -> Result<()> {
94 let client = self.http_client();
95 let resp = client
96 .post(REDDIT_TOKEN_URL)
97 .basic_auth(&self.client_id, Some(&self.client_secret))
98 .header("User-Agent", USER_AGENT)
99 .form(&[
100 ("grant_type", "refresh_token"),
101 ("refresh_token", &self.refresh_token),
102 ])
103 .send()
104 .await?;
105
106 let status = resp.status();
107 if !status.is_success() {
108 let body = resp
109 .text()
110 .await
111 .unwrap_or_else(|e| format!("<failed to read response: {e}>"));
112 bail!("Reddit token refresh failed ({status}): {body}");
113 }
114
115 let token_resp: RedditTokenResponse = resp.json().await?;
116 let mut auth = self.auth.lock();
117 auth.access_token = token_resp.access_token;
118 auth.expires_at =
119 Instant::now() + Duration::from_secs(token_resp.expires_in.saturating_sub(60));
120 Ok(())
121 }
122
123 async fn get_access_token(&self) -> Result<String> {
125 {
126 let auth = self.auth.lock();
127 if !auth.access_token.is_empty() && Instant::now() < auth.expires_at {
128 return Ok(auth.access_token.clone());
129 }
130 }
131 self.refresh_access_token().await?;
132 let auth = self.auth.lock();
133 Ok(auth.access_token.clone())
134 }
135
136 async fn fetch_inbox(&self) -> Result<Vec<RedditChild>> {
138 let token = self.get_access_token().await?;
139 let client = self.http_client();
140
141 let resp = client
142 .get(format!("{REDDIT_API_BASE}/message/unread"))
143 .bearer_auth(&token)
144 .header("User-Agent", USER_AGENT)
145 .query(&[("limit", "25")])
146 .send()
147 .await?;
148
149 let status = resp.status();
150 if !status.is_success() {
151 let body = resp
152 .text()
153 .await
154 .unwrap_or_else(|e| format!("<failed to read response: {e}>"));
155 tracing::warn!("Reddit inbox fetch failed ({status}): {body}");
156 return Ok(Vec::new());
157 }
158
159 let listing: RedditListing = resp.json().await?;
160 Ok(listing.data.children)
161 }
162
163 async fn mark_read(&self, fullnames: &[String]) -> Result<()> {
165 if fullnames.is_empty() {
166 return Ok(());
167 }
168 let token = self.get_access_token().await?;
169 let client = self.http_client();
170
171 let ids = fullnames.join(",");
172 let resp = client
173 .post(format!("{REDDIT_API_BASE}/api/read_message"))
174 .bearer_auth(&token)
175 .header("User-Agent", USER_AGENT)
176 .form(&[("id", ids.as_str())])
177 .send()
178 .await?;
179
180 if !resp.status().is_success() {
181 tracing::warn!("Reddit mark_read failed: {}", resp.status());
182 }
183 Ok(())
184 }
185
186 fn parse_item(&self, item: &RedditItemData) -> Option<ChannelMessage> {
188 let author = item.author.as_deref().unwrap_or("");
189 let body = item.body.as_deref().unwrap_or("");
190 let name = item.name.as_deref().unwrap_or("");
191
192 if author.eq_ignore_ascii_case(&self.username) || author.is_empty() || body.is_empty() {
194 return None;
195 }
196
197 if let Some(ref sub) = self.subreddit {
199 if let Some(ref item_sub) = item.subreddit {
200 if !item_sub.eq_ignore_ascii_case(sub) {
201 return None;
202 }
203 }
204 }
205
206 let reply_target =
209 if item.message_type.as_deref() == Some("comment_reply") || item.parent_id.is_some() {
210 item.parent_id.clone().unwrap_or_else(|| name.to_string())
212 } else {
213 author.to_string()
215 };
216
217 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
218 let timestamp = item.created_utc.unwrap_or(0.0) as u64;
219
220 Some(ChannelMessage {
221 id: format!("reddit_{name}"),
222 sender: author.to_string(),
223 reply_target,
224 content: body.to_string(),
225 channel: "reddit".to_string(),
226 timestamp,
227 thread_ts: item.parent_id.clone(),
228 interruption_scope_id: None,
229 attachments: vec![],
230 })
231 }
232}
233
234#[async_trait]
235impl Channel for RedditChannel {
236 fn name(&self) -> &str {
237 "reddit"
238 }
239
240 async fn send(&self, message: &SendMessage) -> Result<()> {
241 let token = self.get_access_token().await?;
242 let client = self.http_client();
243
244 if message.recipient.starts_with("t1_")
247 || message.recipient.starts_with("t3_")
248 || message.recipient.starts_with("t4_")
249 {
250 let resp = client
252 .post(format!("{REDDIT_API_BASE}/api/comment"))
253 .bearer_auth(&token)
254 .header("User-Agent", USER_AGENT)
255 .form(&[
256 ("thing_id", message.recipient.as_str()),
257 ("text", &message.content),
258 ])
259 .send()
260 .await?;
261
262 let status = resp.status();
263 if !status.is_success() {
264 let body = resp
265 .text()
266 .await
267 .unwrap_or_else(|e| format!("<failed to read response: {e}>"));
268 bail!("Reddit comment reply failed ({status}): {body}");
269 }
270 } else {
271 let subject = message
273 .subject
274 .as_deref()
275 .unwrap_or("Message from Construct");
276 let resp = client
277 .post(format!("{REDDIT_API_BASE}/api/compose"))
278 .bearer_auth(&token)
279 .header("User-Agent", USER_AGENT)
280 .form(&[
281 ("to", message.recipient.as_str()),
282 ("subject", subject),
283 ("text", &message.content),
284 ])
285 .send()
286 .await?;
287
288 let status = resp.status();
289 if !status.is_success() {
290 let body = resp
291 .text()
292 .await
293 .unwrap_or_else(|e| format!("<failed to read response: {e}>"));
294 bail!("Reddit DM failed ({status}): {body}");
295 }
296 }
297
298 Ok(())
299 }
300
301 async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> Result<()> {
302 self.refresh_access_token().await?;
304
305 tracing::info!(
306 "Reddit channel listening as u/{} {}...",
307 self.username,
308 self.subreddit
309 .as_ref()
310 .map(|s| format!("in r/{s}"))
311 .unwrap_or_default()
312 );
313
314 loop {
315 tokio::time::sleep(POLL_INTERVAL).await;
316
317 let items = match self.fetch_inbox().await {
318 Ok(items) => items,
319 Err(e) => {
320 tracing::warn!("Reddit poll error: {e}");
321 continue;
322 }
323 };
324
325 let mut read_ids = Vec::new();
326 for child in &items {
327 if let Some(ref name) = child.data.name {
328 read_ids.push(name.clone());
329 }
330 if let Some(msg) = self.parse_item(&child.data) {
331 if tx.send(msg).await.is_err() {
332 return Ok(());
333 }
334 }
335 }
336
337 if let Err(e) = self.mark_read(&read_ids).await {
338 tracing::warn!("Reddit mark_read error: {e}");
339 }
340 }
341 }
342
343 async fn health_check(&self) -> bool {
344 self.get_access_token().await.is_ok()
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    fn make_channel() -> RedditChannel {
        RedditChannel::new(
            "client_id".into(),
            "client_secret".into(),
            "refresh_token".into(),
            "testbot".into(),
            None,
        )
    }

    fn make_channel_with_sub(sub: &str) -> RedditChannel {
        RedditChannel::new(
            "client_id".into(),
            "client_secret".into(),
            "refresh_token".into(),
            "testbot".into(),
            Some(sub.into()),
        )
    }

    /// Baseline inbox item: no parent, subreddit or type, created at
    /// 1_700_000_000. Tests override fields with struct-update syntax.
    fn inbox_item(name: &str, author: &str, body: &str) -> RedditItemData {
        RedditItemData {
            name: Some(name.into()),
            author: Some(author.into()),
            body: Some(body.into()),
            subject: None,
            parent_id: None,
            link_id: None,
            subreddit: None,
            created_utc: Some(1_700_000_000.0),
            new: Some(true),
            message_type: None,
            context: None,
        }
    }

    #[test]
    fn parse_comment_reply() {
        let ch = make_channel();
        let item = RedditItemData {
            parent_id: Some("t1_parent1".into()),
            link_id: Some("t3_post1".into()),
            subreddit: Some("rust".into()),
            message_type: Some("comment_reply".into()),
            ..inbox_item("t1_abc123", "user1", "hello bot")
        };

        let msg = ch.parse_item(&item).unwrap();
        assert_eq!(msg.sender, "user1");
        assert_eq!(msg.content, "hello bot");
        assert_eq!(msg.reply_target, "t1_parent1");
        assert_eq!(msg.channel, "reddit");
        assert_eq!(msg.id, "reddit_t1_abc123");
        assert_eq!(msg.timestamp, 1_700_000_000);
        assert_eq!(msg.thread_ts.as_deref(), Some("t1_parent1"));
    }

    #[test]
    fn comment_reply_without_parent_targets_item_itself() {
        let item = RedditItemData {
            message_type: Some("comment_reply".into()),
            ..inbox_item("t1_orphan", "user1", "hi")
        };

        let msg = make_channel().parse_item(&item).unwrap();
        assert_eq!(msg.reply_target, "t1_orphan");
        assert!(msg.thread_ts.is_none());
    }

    #[test]
    fn parse_dm() {
        let ch = make_channel();
        let item = RedditItemData {
            subject: Some("Hello".into()),
            created_utc: Some(1_700_000_100.0),
            ..inbox_item("t4_dm456", "user2", "private message")
        };

        let msg = ch.parse_item(&item).unwrap();
        assert_eq!(msg.sender, "user2");
        assert_eq!(msg.content, "private message");
        assert_eq!(msg.reply_target, "user2");
        assert_eq!(msg.timestamp, 1_700_000_100);
    }

    #[test]
    fn dm_reply_in_conversation_targets_parent() {
        let item = RedditItemData {
            parent_id: Some("t4_first".into()),
            ..inbox_item("t4_second", "user2", "follow-up")
        };

        let msg = make_channel().parse_item(&item).unwrap();
        assert_eq!(msg.reply_target, "t4_first");
    }

    #[test]
    fn missing_timestamp_defaults_to_zero() {
        let item = RedditItemData {
            created_utc: None,
            ..inbox_item("t4_x", "user1", "hi")
        };

        assert_eq!(make_channel().parse_item(&item).unwrap().timestamp, 0);
    }

    #[test]
    fn skip_self_messages() {
        let ch = make_channel();
        assert!(ch
            .parse_item(&inbox_item("t1_self", "testbot", "my own message"))
            .is_none());
        // The username comparison ignores ASCII case.
        assert!(ch
            .parse_item(&inbox_item("t1_self2", "TestBot", "my own message"))
            .is_none());
    }

    #[test]
    fn skip_empty_body() {
        let ch = make_channel();
        assert!(ch.parse_item(&inbox_item("t1_empty", "user1", "")).is_none());

        let no_body = RedditItemData {
            body: None,
            ..inbox_item("t1_nobody", "user1", "unused")
        };
        assert!(ch.parse_item(&no_body).is_none());
    }

    #[test]
    fn skip_missing_author() {
        let item = RedditItemData {
            author: None,
            ..inbox_item("t1_anon", "user1", "hello")
        };

        assert!(make_channel().parse_item(&item).is_none());
    }

    #[test]
    fn subreddit_filter() {
        let ch = make_channel_with_sub("rust");

        let other = RedditItemData {
            subreddit: Some("python".into()),
            ..inbox_item("t1_other", "user1", "hello")
        };
        assert!(ch.parse_item(&other).is_none());

        let matching = RedditItemData {
            subreddit: Some("rust".into()),
            ..inbox_item("t1_match", "user1", "hello")
        };
        assert!(ch.parse_item(&matching).is_some());
    }

    #[test]
    fn subreddit_filter_ignores_case_and_passes_untagged_items() {
        let ch = make_channel_with_sub("rust");

        let mixed_case = RedditItemData {
            subreddit: Some("Rust".into()),
            ..inbox_item("t1_case", "user1", "hello")
        };
        assert!(ch.parse_item(&mixed_case).is_some());

        // Items without a subreddit (e.g. private messages) are not filtered.
        assert!(ch.parse_item(&inbox_item("t4_dm", "user1", "hello")).is_some());
    }

    #[test]
    fn send_message_formatting() {
        let dm = SendMessage::new("hello", "user1");
        assert_eq!(dm.recipient, "user1");
        assert_eq!(dm.content, "hello");

        let reply = SendMessage::new("response", "t1_abc123");
        assert!(reply.recipient.starts_with("t1_"));
    }
}
506}