// progscrape_scrapers/backends/reddit.rs

1use std::{borrow::Cow, collections::HashMap};
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6use super::{
7    scrape_story, utils::html::unescape_entities, GenericScrape, ScrapeConfigSource, ScrapeCore,
8    ScrapeShared, ScrapeSource, ScrapeSourceDef, ScrapeStory, Scraper,
9};
10use crate::{
11    datasci::titletrimmer::{trim_title, AWKWARD_LENGTH, IDEAL_LENGTH},
12    types::*,
13};
14
/// Marker type for the Reddit source: ties together the Reddit config, story and scraper
/// types, and knows how to build and parse Reddit comments-page URLs.
pub struct Reddit {}
16
17impl ScrapeSourceDef for Reddit {
18    type Config = RedditConfig;
19    type Scrape = RedditStory;
20    type Scraper = RedditScraper;
21
22    fn comments_url(id: &str, subsource: Option<&str>) -> String {
23        if let Some(subsource) = subsource {
24            format!("https://www.reddit.com/r/{}/comments/{}/", subsource, id)
25        } else {
26            format!("https://www.reddit.com/comments/{}/", id)
27        }
28    }
29
30    fn id_from_comments_url(url: &str) -> Option<(&str, Option<&str>)> {
31        let url = url.trim_end_matches('/');
32        if let Some(url) = url.strip_prefix("https://www.reddit.com/comments/") {
33            Some((url, None))
34        } else {
35            let url = url.strip_prefix("https://www.reddit.com/r/")?;
36            if let Some((subreddit, rest)) = url.split_once('/') {
37                if let Some((_, id)) = rest.split_once('/') {
38                    Some((id, Some(subreddit)))
39                } else {
40                    None
41                }
42            } else {
43                None
44            }
45        }
46    }
47
48    fn is_comments_host(host: &str) -> bool {
49        host.ends_with("reddit.com")
50    }
51}
52
53#[derive(Clone, Default, Serialize, Deserialize)]
54pub struct RedditConfig {
55    api: String,
56    subreddit_batch: usize,
57    limit: usize,
58    subreddits: HashMap<String, SubredditConfig>,
59}
60
61impl ScrapeConfigSource for RedditConfig {
62    fn subsources(&self) -> Vec<String> {
63        self.subreddits.iter().map(|s| s.0.clone()).collect()
64    }
65
66    fn provide_urls(&self, subsources: Vec<String>) -> Vec<String> {
67        let mut output = vec![];
68        for chunk in subsources.chunks(self.subreddit_batch) {
69            output.push(
70                self.api.replace("${subreddits}", &chunk.join("+"))
71                    + &format!("?limit={}", self.limit),
72            )
73        }
74        output
75    }
76}
77
78#[derive(Clone, Default, Serialize, Deserialize)]
79pub struct SubredditConfig {
80    #[serde(default)]
81    is_tag: bool,
82    #[serde(default)]
83    flair_is_tag: bool,
84}
85
/// Stateless scraper that turns Reddit listing JSON into `RedditStory` scrapes.
#[derive(Default)]
pub struct RedditScraper {}
88
// Reddit-specific scrape payload, generated by `scrape_story!`. Alongside the shared story
// fields it keeps the link flair and the vote/comment counters read from the listing JSON
// in `RedditScraper::map_story`. Field order matters: `RedditStory::new_subsource` takes
// these values positionally in this order.
// `position` is the 1-based position of the story within its subreddit in one scraped
// listing; `extract_core` turns it into a 0-based rank.
scrape_story! {
    RedditStory {
        flair: String,
        position: u32,
        upvotes: u32,
        downvotes: u32,
        num_comments: u32,
        score: u32,
        upvote_ratio: f32,
    }
}
100
101impl ScrapeStory for RedditStory {
102    const TYPE: ScrapeSource = ScrapeSource::Reddit;
103
104    fn merge(&mut self, other: RedditStory) {
105        self.position = std::cmp::max(self.position, other.position);
106        self.upvotes = std::cmp::max(self.upvotes, other.upvotes);
107        self.downvotes = std::cmp::max(self.downvotes, other.downvotes);
108        self.num_comments = std::cmp::max(self.num_comments, other.num_comments);
109        self.score = std::cmp::max(self.score, other.score);
110        self.upvote_ratio = f32::max(self.upvote_ratio, other.upvote_ratio);
111    }
112}
113
114impl RedditScraper {
115    fn require_string(&self, data: &Value, key: &str) -> Result<String, String> {
116        Ok(data[key]
117            .as_str()
118            .ok_or(format!("Missing field {:?}", key))?
119            .to_owned())
120    }
121
122    fn optional_string(&self, data: &Value, key: &str) -> Result<String, String> {
123        Ok(data[key].as_str().unwrap_or_default().to_owned())
124    }
125
126    fn require_integer<T: TryFrom<i64> + TryFrom<u64>>(
127        &self,
128        data: &Value,
129        key: &str,
130    ) -> Result<T, String> {
131        if let Value::Number(n) = &data[key] {
132            if let Some(n) = n.as_u64() {
133                if let Ok(n) = n.try_into() {
134                    return Ok(n);
135                }
136            }
137            if let Some(n) = n.as_i64() {
138                if let Ok(n) = n.try_into() {
139                    return Ok(n);
140                }
141            }
142            if let Some(n) = n.as_f64() {
143                let n = n as i64;
144                if let Ok(n) = n.try_into() {
145                    return Ok(n);
146                }
147            }
148            Err(format!(
149                "Failed to parse {} as integer (value was {:?})",
150                key, n
151            ))
152        } else {
153            Err(format!(
154                "Missing or invalid field {:?} (value was {:?})",
155                key, data[key]
156            ))
157        }
158    }
159
160    fn require_float(&self, data: &Value, key: &str) -> Result<f64, String> {
161        if let Value::Number(n) = &data[key] {
162            if let Some(n) = n.as_u64() {
163                return Ok(n as f64);
164            }
165            if let Some(n) = n.as_i64() {
166                return Ok(n as f64);
167            }
168            if let Some(n) = n.as_f64() {
169                return Ok(n);
170            }
171            Err(format!(
172                "Failed to parse {} as float (value was {:?})",
173                key, n
174            ))
175        } else {
176            Err(format!(
177                "Missing or invalid field {:?} (value was {:?})",
178                key, data[key]
179            ))
180        }
181    }
182
183    fn map_story(
184        &self,
185        child: &Value,
186        positions: &mut HashMap<String, u32>,
187    ) -> Result<GenericScrape<<Self as Scraper>::Output>, String> {
188        let kind = child["kind"].as_str();
189        let data = if kind == Some("t3") {
190            &child["data"]
191        } else {
192            return Err(format!("Unknown story type: {:?}", kind));
193        };
194
195        let id = self.require_string(data, "id")?;
196        let subreddit = self.require_string(data, "subreddit")?.to_ascii_lowercase();
197        if let Some(true) = data["stickied"].as_bool() {
198            return Err(format!("Ignoring stickied story {}/{}", subreddit, id));
199        }
200        let position = *positions
201            .entry(subreddit.clone())
202            .and_modify(|n| *n += 1)
203            .or_default()
204            + 1;
205        let seconds: i64 = self.require_integer(data, "created_utc")?;
206        let millis = seconds * 1000;
207        let date = StoryDate::from_millis(millis).ok_or_else(|| "Unmappable date".to_string())?;
208        let url = StoryUrl::parse(unescape_entities(&self.require_string(data, "url")?))
209            .ok_or_else(|| "Unmappable URL".to_string())?;
210        let raw_title = unescape_entities(&self.require_string(data, "title")?);
211        let num_comments = self.require_integer(data, "num_comments")?;
212        let score = self.require_integer(data, "score")?;
213        let downvotes = self.require_integer(data, "downs")?;
214        let upvotes = self.require_integer(data, "ups")?;
215        let upvote_ratio = self.require_float(data, "upvote_ratio")? as f32;
216        let flair = unescape_entities(&self.optional_string(data, "link_flair_text")?);
217        let story = RedditStory::new_subsource(
218            id,
219            subreddit,
220            date,
221            raw_title,
222            url,
223            flair,
224            position,
225            upvotes,
226            downvotes,
227            num_comments,
228            score,
229            upvote_ratio,
230        );
231        Ok(story)
232    }
233}
234
235impl Scraper for RedditScraper {
236    type Config = <Reddit as ScrapeSourceDef>::Config;
237    type Output = <Reddit as ScrapeSourceDef>::Scrape;
238
239    fn scrape(
240        &self,
241        _args: &RedditConfig,
242        input: &str,
243    ) -> Result<(Vec<GenericScrape<Self::Output>>, Vec<String>), ScrapeError> {
244        let root: Value = serde_json::from_str(input)?;
245        let mut value = &root;
246        for path in ["data", "children"] {
247            if let Some(object) = value.as_object() {
248                if let Some(nested_value) = object.get(path) {
249                    value = nested_value;
250                } else {
251                    return Err(ScrapeError::StructureError(
252                        "Failed to parse Reddit JSON data.children".to_owned(),
253                    ));
254                }
255            }
256        }
257
258        if let Some(children) = value.as_array() {
259            let mut vec = vec![];
260            let mut errors = vec![];
261            let mut positions = HashMap::new();
262            for child in children {
263                match self.map_story(child, &mut positions) {
264                    Ok(story) => vec.push(story),
265                    Err(e) => errors.push(e),
266                }
267            }
268            Ok((vec, errors))
269        } else {
270            Err(ScrapeError::StructureError(
271                "Missing children element".to_owned(),
272            ))
273        }
274    }
275
276    fn extract_core<'a>(
277        &self,
278        args: &Self::Config,
279        input: &'a GenericScrape<Self::Output>,
280    ) -> ScrapeCore<'a> {
281        let mut tags = vec![];
282        if let Some(ref subreddit) = input.shared.id.subsource {
283            if let Some(config) = args.subreddits.get(subreddit) {
284                if config.flair_is_tag && !input.data.flair.contains(' ') {
285                    tags.push(Cow::Owned(input.data.flair.to_lowercase()));
286                }
287                if config.is_tag {
288                    tags.push(Cow::Borrowed(subreddit.as_str()));
289                }
290            }
291        }
292
293        let title = trim_title(&input.shared.raw_title, IDEAL_LENGTH, AWKWARD_LENGTH);
294
295        ScrapeCore {
296            source: &input.shared.id,
297            title,
298            url: &input.shared.url,
299            date: input.shared.date,
300            rank: (input.data.position as usize).checked_sub(1),
301            tags,
302        }
303    }
304}