use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use tokio::sync::RwLock;
use tracing::{info, warn};
use punch_types::{PunchError, PunchResult};
use crate::{ChannelAdapter, ChannelPlatform, ChannelStatus, IncomingMessage};
/// Base URL that bearer-token-authenticated API requests are built on
/// (e.g. `{REDDIT_API_BASE}/api/comment`).
const REDDIT_API_BASE: &str = "https://oauth.reddit.com";
/// Endpoint the adapter POSTs to in order to exchange the refresh token for
/// an access token.
const REDDIT_TOKEN_URL: &str = "https://www.reddit.com/api/v1/access_token";
/// Value sent in the `User-Agent` header of every request this adapter makes.
const REDDIT_USER_AGENT: &str = "punch-agent-os:v0.1.0 (by /u/punch-bot)";
/// Channel adapter that talks to Reddit over HTTP.
///
/// Holds the OAuth2 credentials, the current access token, a shared HTTP
/// client, and the counters/flags reported through `ChannelAdapter::status`.
pub struct RedditAdapter {
/// OAuth2 client id; sent as the basic-auth username to the token endpoint.
client_id: String,
/// OAuth2 client secret; sent as the basic-auth password to the token endpoint.
client_secret: String,
/// Refresh token submitted as the `refresh_token` form field when
/// requesting a new access token.
refresh_token: String,
// Stored by `new` but not read anywhere in this file, hence the lint
// suppression. NOTE(review): presumably reserved for subreddit filtering
// or polling — confirm before removing.
#[allow(dead_code)]
subreddit: String,
/// Current bearer token. Empty until `refresh_access_token` succeeds.
access_token: RwLock<String>,
/// HTTP client reused for all requests.
client: reqwest::Client,
/// Set to `true` by `start` and `false` by `stop`; reported as `connected`.
running: AtomicBool,
/// Time of the most recent successful `start`; `stop` does not clear it.
started_at: RwLock<Option<DateTime<Utc>>>,
/// Number of webhook payloads successfully parsed into messages.
messages_received: AtomicU64,
/// Counter of outgoing responses, reported by `status`.
messages_sent: AtomicU64,
}
impl RedditAdapter {
pub fn new(
client_id: String,
client_secret: String,
refresh_token: String,
subreddit: String,
) -> Self {
Self {
client_id,
client_secret,
refresh_token,
subreddit,
access_token: RwLock::new(String::new()),
client: reqwest::Client::new(),
running: AtomicBool::new(false),
started_at: RwLock::new(None),
messages_received: AtomicU64::new(0),
messages_sent: AtomicU64::new(0),
}
}
pub async fn refresh_access_token(&self) -> PunchResult<()> {
let resp = self
.client
.post(REDDIT_TOKEN_URL)
.basic_auth(&self.client_id, Some(&self.client_secret))
.header("User-Agent", REDDIT_USER_AGENT)
.form(&[
("grant_type", "refresh_token"),
("refresh_token", &self.refresh_token),
])
.send()
.await
.map_err(|e| PunchError::Channel {
channel: "reddit".to_string(),
message: format!("failed to refresh token: {e}"),
})?;
let status = resp.status();
if !status.is_success() {
let body_text = resp.text().await.unwrap_or_default();
return Err(PunchError::Channel {
channel: "reddit".to_string(),
message: format!("token refresh failed ({status}): {body_text}"),
});
}
let body: serde_json::Value = resp.json().await.map_err(|e| PunchError::Channel {
channel: "reddit".to_string(),
message: format!("failed to parse token response: {e}"),
})?;
let token = body
.get("access_token")
.and_then(|v| v.as_str())
.ok_or_else(|| PunchError::Channel {
channel: "reddit".to_string(),
message: "no access_token in response".to_string(),
})?;
*self.access_token.write().await = token.to_string();
info!("Reddit OAuth2 access token refreshed");
Ok(())
}
pub fn parse_webhook_payload(&self, payload: &serde_json::Value) -> Option<IncomingMessage> {
let kind = payload.get("kind")?.as_str()?;
let data = payload.get("data")?;
let author = data.get("author")?.as_str()?;
let body = data.get("body")?.as_str()?;
if body.is_empty() {
return None;
}
let thing_id = data.get("id")?.as_str()?;
let created_utc = data
.get("created_utc")
.and_then(|v| v.as_f64())
.unwrap_or(0.0);
let timestamp = DateTime::from_timestamp(created_utc as i64, 0).unwrap_or_else(Utc::now);
let (channel_id, is_group) = match kind {
"t1" => {
let link_id = data
.get("link_id")
.and_then(|v| v.as_str())
.unwrap_or(thing_id);
(link_id.to_string(), true)
}
"t4" => (author.to_string(), false),
_ => return None,
};
let mut metadata = HashMap::new();
if let Some(subreddit) = data.get("subreddit").and_then(|v| v.as_str()) {
metadata.insert(
"subreddit".to_string(),
serde_json::Value::String(subreddit.to_string()),
);
}
if let Some(parent_id) = data.get("parent_id").and_then(|v| v.as_str()) {
metadata.insert(
"parent_id".to_string(),
serde_json::Value::String(parent_id.to_string()),
);
}
let full_id = format!("{}_{}", kind, thing_id);
self.messages_received.fetch_add(1, Ordering::Relaxed);
Some(IncomingMessage {
channel_id,
user_id: author.to_string(),
display_name: author.to_string(),
text: body.to_string(),
timestamp,
platform: ChannelPlatform::Reddit,
platform_message_id: full_id,
is_group,
metadata,
})
}
async fn api_post_comment(&self, parent_fullname: &str, text: &str) -> PunchResult<()> {
let url = format!("{}/api/comment", REDDIT_API_BASE);
let token = self.access_token.read().await.clone();
let resp = self
.client
.post(&url)
.header("Authorization", format!("Bearer {token}"))
.header("User-Agent", REDDIT_USER_AGENT)
.form(&[("thing_id", parent_fullname), ("text", text)])
.send()
.await
.map_err(|e| PunchError::Channel {
channel: "reddit".to_string(),
message: format!("failed to post comment: {e}"),
})?;
let status = resp.status();
if let Some(remaining) = resp
.headers()
.get("x-ratelimit-remaining")
.and_then(|v| v.to_str().ok())
&& let Ok(remaining_f) = remaining.parse::<f64>()
&& remaining_f < 5.0
{
warn!(remaining = %remaining, "Reddit rate limit nearly exhausted");
}
if !status.is_success() {
let body_text = resp.text().await.unwrap_or_default();
warn!("Reddit post comment failed ({status}): {body_text}");
}
self.messages_sent.fetch_add(1, Ordering::Relaxed);
Ok(())
}
}
#[async_trait]
impl ChannelAdapter for RedditAdapter {
    /// Stable identifier of this adapter.
    fn name(&self) -> &str {
        "reddit"
    }

    /// Platform tag attached to this adapter.
    fn platform(&self) -> ChannelPlatform {
        ChannelPlatform::Reddit
    }

    /// Fetches an access token, then marks the adapter as running and records
    /// the start time. Fails without changing any state if the token refresh fails.
    async fn start(&self) -> PunchResult<()> {
        self.refresh_access_token().await?;
        self.running.store(true, Ordering::Relaxed);

        let now = Some(Utc::now());
        {
            let mut started = self.started_at.write().await;
            *started = now;
        }

        info!("Reddit adapter started");
        Ok(())
    }

    /// Clears the running flag. The recorded start time is left untouched.
    async fn stop(&self) -> PunchResult<()> {
        self.running.store(false, Ordering::Relaxed);
        info!("Reddit adapter stopped");
        Ok(())
    }

    /// Posts `message` as a comment, using `channel_id` as the parent thing id.
    ///
    /// NOTE(review): for private messages `parse_webhook_payload` sets the
    /// channel id to the author name, which is passed here unchanged —
    /// confirm that is the intended reply target.
    async fn send_response(&self, channel_id: &str, message: &str) -> PunchResult<()> {
        let outcome = self.api_post_comment(channel_id, message).await;
        outcome
    }

    /// Snapshot of the adapter's flags and counters.
    ///
    /// `started_at` is read without waiting; if the lock is currently held for
    /// writing it is reported as `None`. `last_error` is always `None`.
    fn status(&self) -> ChannelStatus {
        let connected = self.running.load(Ordering::Relaxed);
        let started_at = match self.started_at.try_read() {
            Ok(guard) => *guard,
            Err(_) => None,
        };
        let messages_received = self.messages_received.load(Ordering::Relaxed);
        let messages_sent = self.messages_sent.load(Ordering::Relaxed);

        ChannelStatus {
            connected,
            started_at,
            messages_received,
            messages_sent,
            last_error: None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an adapter with placeholder credentials; construction performs no I/O.
    fn make_adapter() -> RedditAdapter {
        let client_id = String::from("test-client-id");
        let client_secret = String::from("test-client-secret");
        let refresh_token = String::from("test-refresh-token");
        let subreddit = String::from("testsubreddit");
        RedditAdapter::new(client_id, client_secret, refresh_token, subreddit)
    }

    /// Runs a payload through a fresh adapter's webhook parser.
    fn parse(payload: serde_json::Value) -> Option<IncomingMessage> {
        make_adapter().parse_webhook_payload(&payload)
    }

    /// The adapter reports its fixed name and platform.
    #[test]
    fn test_reddit_adapter_creation() {
        let adapter = make_adapter();
        assert_eq!(adapter.name(), "reddit");
        assert_eq!(adapter.platform(), ChannelPlatform::Reddit);
    }

    /// A `t1` comment maps to a group message keyed by its `link_id`.
    #[test]
    fn test_parse_reddit_comment_mention() {
        let comment = serde_json::json!({
            "kind": "t1",
            "data": {
                "id": "abc123",
                "author": "testuser",
                "body": "u/punch-bot help me with this",
                "subreddit": "testsubreddit",
                "link_id": "t3_post789",
                "parent_id": "t1_parent456",
                "created_utc": 1700000000.0
            }
        });

        let msg = parse(comment).unwrap();

        assert_eq!(msg.platform, ChannelPlatform::Reddit);
        assert_eq!(msg.user_id, "testuser");
        assert_eq!(msg.text, "u/punch-bot help me with this");
        assert_eq!(msg.channel_id, "t3_post789");
        assert!(msg.is_group);
        let subreddit = msg.metadata.get("subreddit").unwrap();
        assert_eq!(
            subreddit,
            &serde_json::Value::String("testsubreddit".to_string())
        );
    }

    /// A `t4` private message maps to a direct message keyed by the author.
    #[test]
    fn test_parse_reddit_private_message() {
        let private_message = serde_json::json!({
            "kind": "t4",
            "data": {
                "id": "msg789",
                "author": "alice",
                "body": "Hey bot, what's up?",
                "created_utc": 1700000000.0
            }
        });

        let msg = parse(private_message).unwrap();

        assert!(!msg.is_group);
        assert_eq!(msg.channel_id, "alice");
        assert_eq!(msg.platform_message_id, "t4_msg789");
    }

    /// An empty body is rejected.
    #[test]
    fn test_parse_reddit_empty_body() {
        let empty_comment = serde_json::json!({
            "kind": "t1",
            "data": {
                "id": "abc",
                "author": "user",
                "body": "",
                "subreddit": "test",
                "link_id": "t3_x",
                "created_utc": 1700000000.0
            }
        });

        assert!(parse(empty_comment).is_none());
    }

    /// Kinds other than `t1`/`t4` (here a `t3` submission) are rejected.
    #[test]
    fn test_parse_reddit_unknown_kind() {
        let submission = serde_json::json!({
            "kind": "t3",
            "data": {
                "id": "post123",
                "author": "user",
                "body": "A submission",
                "created_utc": 1700000000.0
            }
        });

        assert!(parse(submission).is_none());
    }
}