// progscrape_scrapers/backends/reddit.rs

use std::{borrow::Cow, collections::HashMap};

use serde::{Deserialize, Serialize};
use serde_json::Value;

use super::{
    GenericScrape, ScrapeConfigSource, ScrapeCore, ScrapeSource, ScrapeSourceDef, ScrapeStory,
    Scraper, scrape_story, utils::html::unescape_entities,
};
use crate::{
    datasci::titletrimmer::{AWKWARD_LENGTH, IDEAL_LENGTH, remove_tags, trim_title},
    types::*,
};

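/// Scrape source definition for Reddit, tying together the config, story, and
/// scraper types and handling the mapping between story IDs and comment URLs.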
pub struct Reddit {}

impl ScrapeSourceDef for Reddit {
    type Config = RedditConfig;
    type Scrape = RedditStory;
    type Scraper = RedditScraper;

    fn comments_url(id: &str, subsource: Option<&str>) -> String {
        if let Some(subsource) = subsource {
            format!("https://www.reddit.com/r/{subsource}/comments/{id}/")
        } else {
            format!("https://www.reddit.com/comments/{id}/")
        }
    }

    fn id_from_comments_url(url: &str) -> Option<(&str, Option<&str>)> {
        let url = url.trim_end_matches('/');
        if let Some(url) = url.strip_prefix("https://www.reddit.com/comments/") {
            Some((url, None))
        } else {
            // URLs of the form https://www.reddit.com/r/{subreddit}/comments/{id}
            let url = url.strip_prefix("https://www.reddit.com/r/")?;
            if let Some((subreddit, rest)) = url.split_once('/') {
                if let Some((_, id)) = rest.split_once('/') {
                    Some((id, Some(subreddit)))
                } else {
                    None
                }
            } else {
                None
            }
        }
    }

    fn is_comments_host(host: &str) -> bool {
        host.ends_with("reddit.com")
    }
}

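/// Configuration for the Reddit scraper: the listing API template, how many
/// subreddits to batch into one request, the per-request story limit, and
/// per-subreddit tagging options.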
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct RedditConfig {
    api: String,
    subreddit_batch: usize,
    limit: usize,
    subreddits: HashMap<String, SubredditConfig>,
}

impl ScrapeConfigSource for RedditConfig {
    fn subsources(&self) -> Vec<String> {
        self.subreddits.keys().cloned().collect()
    }

    fn provide_urls(&self, subsources: Vec<String>) -> Vec<String> {
        let mut output = vec![];
        for chunk in subsources.chunks(self.subreddit_batch) {
            // Interpolate a `+`-joined batch of subreddits into the API
            // template, then append the story limit as a query parameter.
            output.push(
                self.api.replace("${subreddits}", &chunk.join("+"))
                    + &format!("?limit={}", self.limit),
            )
        }
        output
    }
}

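/// Per-subreddit options controlling how tags are derived for stories from
/// that subreddit.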
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct SubredditConfig {
    /// If set, the subreddit name itself is applied as a tag.
    #[serde(default)]
    is_tag: bool,
    /// If set, a single-word flair on a story is applied as a tag.
    #[serde(default)]
    flair_is_tag: bool,
}

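/// Scraper implementation that parses Reddit's JSON listing format.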
#[derive(Default)]
pub struct RedditScraper {}

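// The `scrape_story!` macro (imported from the parent module) declares the
// `RedditStory` scrape type, combining the shared scrape fields (id, title,
// url, date) with these Reddit-specific fields.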
scrape_story! {
    RedditStory {
        flair: String,
        position: u32,
        upvotes: u32,
        downvotes: u32,
        num_comments: u32,
        score: u32,
        upvote_ratio: f32,
    }
}

impl ScrapeStory for RedditStory {
    const TYPE: ScrapeSource = ScrapeSource::Reddit;

    fn merge(&mut self, other: RedditStory) {
        // When the same story is scraped more than once, keep the highest
        // observed value for each metric.
        self.position = std::cmp::max(self.position, other.position);
        self.upvotes = std::cmp::max(self.upvotes, other.upvotes);
        self.downvotes = std::cmp::max(self.downvotes, other.downvotes);
        self.num_comments = std::cmp::max(self.num_comments, other.num_comments);
        self.score = std::cmp::max(self.score, other.score);
        self.upvote_ratio = f32::max(self.upvote_ratio, other.upvote_ratio);
    }
}

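// JSON field-extraction helpers. Reddit's API is loosely typed (for example,
// `created_utc` may arrive as a float), so the numeric helpers accept any
// JSON number representation.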
impl RedditScraper {
    fn require_string(&self, data: &Value, key: &str) -> Result<String, String> {
        Ok(data[key]
            .as_str()
            .ok_or(format!("Missing field {key:?}"))?
            .to_owned())
    }

    fn optional_string(&self, data: &Value, key: &str) -> Result<String, String> {
        Ok(data[key].as_str().unwrap_or_default().to_owned())
    }

    fn require_integer<T: TryFrom<i64> + TryFrom<u64>>(
        &self,
        data: &Value,
        key: &str,
    ) -> Result<T, String> {
        if let Value::Number(n) = &data[key] {
            // Try the lossless representations first, then fall back to
            // truncating a float.
            if let Some(n) = n.as_u64() {
                if let Ok(n) = n.try_into() {
                    return Ok(n);
                }
            }
            if let Some(n) = n.as_i64() {
                if let Ok(n) = n.try_into() {
                    return Ok(n);
                }
            }
            if let Some(n) = n.as_f64() {
                let n = n as i64;
                if let Ok(n) = n.try_into() {
                    return Ok(n);
                }
            }
            Err(format!(
                "Failed to parse {key} as integer (value was {n:?})"
            ))
        } else {
            Err(format!(
                "Missing or invalid field {:?} (value was {:?})",
                key, data[key]
            ))
        }
    }

    fn require_float(&self, data: &Value, key: &str) -> Result<f64, String> {
        if let Value::Number(n) = &data[key] {
            if let Some(n) = n.as_u64() {
                return Ok(n as f64);
            }
            if let Some(n) = n.as_i64() {
                return Ok(n as f64);
            }
            if let Some(n) = n.as_f64() {
                return Ok(n);
            }
            Err(format!("Failed to parse {key} as float (value was {n:?})"))
        } else {
            Err(format!(
                "Missing or invalid field {:?} (value was {:?})",
                key, data[key]
            ))
        }
    }

    /// Map a single `t3` (link) listing child into a `RedditStory`.
    fn map_story(
        &self,
        child: &Value,
        positions: &mut HashMap<String, u32>,
    ) -> Result<GenericScrape<<Self as Scraper>::Output>, String> {
        let kind = child["kind"].as_str();
        let data = if kind == Some("t3") {
            &child["data"]
        } else {
            return Err(format!("Unknown story type: {kind:?}"));
        };

        let id = self.require_string(data, "id")?;
        let subreddit = self.require_string(data, "subreddit")?.to_ascii_lowercase();
        if let Some(true) = data["stickied"].as_bool() {
            return Err(format!("Ignoring stickied story {subreddit}/{id}"));
        }
        // Track a 1-based position per subreddit within this listing.
        let position = *positions
            .entry(subreddit.clone())
            .and_modify(|n| *n += 1)
            .or_default()
            + 1;
        let seconds: i64 = self.require_integer(data, "created_utc")?;
        let millis = seconds * 1000;
        let date = StoryDate::from_millis(millis).ok_or_else(|| "Unmappable date".to_string())?;
        let url = StoryUrl::parse(unescape_entities(&self.require_string(data, "url")?))
            .ok_or_else(|| "Unmappable URL".to_string())?;
        let raw_title = unescape_entities(&self.require_string(data, "title")?);
        let num_comments = self.require_integer(data, "num_comments")?;
        let score = self.require_integer(data, "score")?;
        let downvotes = self.require_integer(data, "downs")?;
        let upvotes = self.require_integer(data, "ups")?;
        let upvote_ratio = self.require_float(data, "upvote_ratio")? as f32;
        let flair = unescape_entities(&self.optional_string(data, "link_flair_text")?);
        let story = RedditStory::new_subsource(
            id,
            subreddit,
            date,
            raw_title,
            url,
            flair,
            position,
            upvotes,
            downvotes,
            num_comments,
            score,
            upvote_ratio,
        );
        Ok(story)
    }
}

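// Top-level scraper entry points: `scrape` parses one listing response into
// stories plus non-fatal per-story errors; `extract_core` projects a scraped
// story into the source-independent `ScrapeCore` form.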
impl Scraper for RedditScraper {
    type Config = <Reddit as ScrapeSourceDef>::Config;
    type Output = <Reddit as ScrapeSourceDef>::Scrape;

    fn scrape(
        &self,
        _args: &RedditConfig,
        input: &str,
    ) -> Result<(Vec<GenericScrape<Self::Output>>, Vec<String>), ScrapeError> {
        let root: Value = serde_json::from_str(input)?;
        // Drill down into the listing's `data.children` array.
        let mut value = &root;
        for path in ["data", "children"] {
            if let Some(object) = value.as_object() {
                if let Some(nested_value) = object.get(path) {
                    value = nested_value;
                } else {
                    return Err(ScrapeError::StructureError(
                        "Failed to parse Reddit JSON data.children".to_owned(),
                    ));
                }
            }
        }

        if let Some(children) = value.as_array() {
            let mut vec = vec![];
            let mut errors = vec![];
            let mut positions = HashMap::new();
            // A failure to map one story is recorded but does not abort the scrape.
            for child in children {
                match self.map_story(child, &mut positions) {
                    Ok(story) => vec.push(story),
                    Err(e) => errors.push(e),
                }
            }
            Ok((vec, errors))
        } else {
            Err(ScrapeError::StructureError(
                "Missing children element".to_owned(),
            ))
        }
    }

    fn extract_core<'a>(
        &self,
        args: &Self::Config,
        input: &'a GenericScrape<Self::Output>,
    ) -> ScrapeCore<'a> {
        let mut tags = vec![];
        if let Some(ref subreddit) = input.shared.id.subsource {
            if let Some(config) = args.subreddits.get(subreddit) {
                // Only single-word flair is usable as a tag; multi-word flair is ignored.
                if config.flair_is_tag && !input.data.flair.contains(' ') {
                    tags.push(Cow::Owned(input.data.flair.to_lowercase()));
                }
                if config.is_tag {
                    tags.push(Cow::Borrowed(subreddit.as_str()));
                }
            }
        }

        // Trim any [tag] prefixes or suffixes
        let (title, _, _) = remove_tags(&input.raw_title);

        let title = trim_title(title, IDEAL_LENGTH, AWKWARD_LENGTH);

        ScrapeCore {
            source: &input.shared.id,
            title,
            url: &input.shared.url,
            date: input.shared.date,
            // `position` is 1-based; rank is 0-based, so 0 maps to `None`.
            rank: (input.data.position as usize).checked_sub(1),
            tags,
        }
    }
}
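
// A minimal sketch of how the comment-URL round trip could be exercised; the
// IDs and subreddit names below are illustrative, not taken from the real
// test suite.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn comments_url_round_trip() {
        // Story with a subreddit subsource.
        let url = Reddit::comments_url("abc123", Some("rust"));
        assert_eq!(url, "https://www.reddit.com/r/rust/comments/abc123/");
        assert_eq!(
            Reddit::id_from_comments_url(&url),
            Some(("abc123", Some("rust")))
        );

        // Story without a subsource.
        let url = Reddit::comments_url("abc123", None);
        assert_eq!(url, "https://www.reddit.com/comments/abc123/");
        assert_eq!(Reddit::id_from_comments_url(&url), Some(("abc123", None)));
    }

    #[test]
    fn comments_host_matching() {
        assert!(Reddit::is_comments_host("www.reddit.com"));
        assert!(Reddit::is_comments_host("old.reddit.com"));
        assert!(!Reddit::is_comments_host("example.com"));
    }
}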