progscrape_scrapers/backends/reddit.rs
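
//! Reddit scraper backend: provides subreddit listing URLs from `RedditConfig`
//! and parses the returned JSON into `RedditStory` scrapes.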
use std::{borrow::Cow, collections::HashMap};

use serde::{Deserialize, Serialize};
use serde_json::Value;

use super::{
    scrape_story, utils::html::unescape_entities, GenericScrape, ScrapeConfigSource, ScrapeCore,
    ScrapeShared, ScrapeSource, ScrapeSourceDef, ScrapeStory, Scraper,
};
use crate::{
    datasci::titletrimmer::{trim_title, AWKWARD_LENGTH, IDEAL_LENGTH},
    types::*,
};

pub struct Reddit {}

impl ScrapeSourceDef for Reddit {
    type Config = RedditConfig;
    type Scrape = RedditStory;
    type Scraper = RedditScraper;

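    /// Builds the URL of a story's Reddit comments page from its ID and optional subreddit.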
    fn comments_url(id: &str, subsource: Option<&str>) -> String {
        if let Some(subsource) = subsource {
            format!("https://www.reddit.com/r/{}/comments/{}/", subsource, id)
        } else {
            format!("https://www.reddit.com/comments/{}/", id)
        }
    }

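    /// Inverse of `comments_url`: recovers the story ID (and subreddit, if present) from a comments URL.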
    fn id_from_comments_url(url: &str) -> Option<(&str, Option<&str>)> {
        let url = url.trim_end_matches('/');
        if let Some(url) = url.strip_prefix("https://www.reddit.com/comments/") {
            Some((url, None))
        } else {
            let url = url.strip_prefix("https://www.reddit.com/r/")?;
            if let Some((subreddit, rest)) = url.split_once('/') {
                if let Some((_, id)) = rest.split_once('/') {
                    Some((id, Some(subreddit)))
                } else {
                    None
                }
            } else {
                None
            }
        }
    }

    fn is_comments_host(host: &str) -> bool {
        host.ends_with("reddit.com")
    }
}

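/// Reddit source configuration: the listing API URL template (containing a
/// `${subreddits}` placeholder), how many subreddits to request per call, the
/// per-request story limit, and per-subreddit tagging options.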
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct RedditConfig {
    api: String,
    subreddit_batch: usize,
    limit: usize,
    subreddits: HashMap<String, SubredditConfig>,
}

impl ScrapeConfigSource for RedditConfig {
    fn subsources(&self) -> Vec<String> {
        self.subreddits.iter().map(|s| s.0.clone()).collect()
    }

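    /// Expands the API template into one URL per batch of `subreddit_batch` subreddits
    /// (joined with `+`), with the configured `limit` appended as a query parameter.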
    fn provide_urls(&self, subsources: Vec<String>) -> Vec<String> {
        let mut output = vec![];
        for chunk in subsources.chunks(self.subreddit_batch) {
            output.push(
                self.api.replace("${subreddits}", &chunk.join("+"))
                    + &format!("?limit={}", self.limit),
            )
        }
        output
    }
}

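/// Per-subreddit options controlling how tags are derived in `extract_core`.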
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct SubredditConfig {
    #[serde(default)]
    is_tag: bool,
    #[serde(default)]
    flair_is_tag: bool,
}

#[derive(Default)]
pub struct RedditScraper {}

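// Reddit-specific fields captured for each story, in addition to the shared scrape fields.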
scrape_story! {
    RedditStory {
        flair: String,
        position: u32,
        upvotes: u32,
        downvotes: u32,
        num_comments: u32,
        score: u32,
        upvote_ratio: f32,
    }
}

impl ScrapeStory for RedditStory {
    const TYPE: ScrapeSource = ScrapeSource::Reddit;

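    /// When the same story is scraped more than once, keep the maximum of each metric.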
    fn merge(&mut self, other: RedditStory) {
        self.position = std::cmp::max(self.position, other.position);
        self.upvotes = std::cmp::max(self.upvotes, other.upvotes);
        self.downvotes = std::cmp::max(self.downvotes, other.downvotes);
        self.num_comments = std::cmp::max(self.num_comments, other.num_comments);
        self.score = std::cmp::max(self.score, other.score);
        self.upvote_ratio = f32::max(self.upvote_ratio, other.upvote_ratio);
    }
}

impl RedditScraper {
    fn require_string(&self, data: &Value, key: &str) -> Result<String, String> {
        Ok(data[key]
            .as_str()
            .ok_or(format!("Missing field {:?}", key))?
            .to_owned())
    }

    fn optional_string(&self, data: &Value, key: &str) -> Result<String, String> {
        Ok(data[key].as_str().unwrap_or_default().to_owned())
    }

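    /// Reads a numeric field, accepting unsigned, signed, or floating-point JSON numbers
    /// and converting the value into the requested integer type.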
    fn require_integer<T: TryFrom<i64> + TryFrom<u64>>(
        &self,
        data: &Value,
        key: &str,
    ) -> Result<T, String> {
        if let Value::Number(n) = &data[key] {
            if let Some(n) = n.as_u64() {
                if let Ok(n) = n.try_into() {
                    return Ok(n);
                }
            }
            if let Some(n) = n.as_i64() {
                if let Ok(n) = n.try_into() {
                    return Ok(n);
                }
            }
            if let Some(n) = n.as_f64() {
                let n = n as i64;
                if let Ok(n) = n.try_into() {
                    return Ok(n);
                }
            }
            Err(format!(
                "Failed to parse {} as integer (value was {:?})",
                key, n
            ))
        } else {
            Err(format!(
                "Missing or invalid field {:?} (value was {:?})",
                key, data[key]
            ))
        }
    }

    fn require_float(&self, data: &Value, key: &str) -> Result<f64, String> {
        if let Value::Number(n) = &data[key] {
            if let Some(n) = n.as_u64() {
                return Ok(n as f64);
            }
            if let Some(n) = n.as_i64() {
                return Ok(n as f64);
            }
            if let Some(n) = n.as_f64() {
                return Ok(n);
            }
            Err(format!(
                "Failed to parse {} as float (value was {:?})",
                key, n
            ))
        } else {
            Err(format!(
                "Missing or invalid field {:?} (value was {:?})",
                key, data[key]
            ))
        }
    }

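    /// Maps one `t3` listing child into a `RedditStory`, skipping stickied posts and
    /// assigning a 1-based position per subreddit.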
    fn map_story(
        &self,
        child: &Value,
        positions: &mut HashMap<String, u32>,
    ) -> Result<GenericScrape<<Self as Scraper>::Output>, String> {
        let kind = child["kind"].as_str();
        let data = if kind == Some("t3") {
            &child["data"]
        } else {
            return Err(format!("Unknown story type: {:?}", kind));
        };

        let id = self.require_string(data, "id")?;
        let subreddit = self.require_string(data, "subreddit")?.to_ascii_lowercase();
        if let Some(true) = data["stickied"].as_bool() {
            return Err(format!("Ignoring stickied story {}/{}", subreddit, id));
        }
        let position = *positions
            .entry(subreddit.clone())
            .and_modify(|n| *n += 1)
            .or_default()
            + 1;
        let seconds: i64 = self.require_integer(data, "created_utc")?;
        let millis = seconds * 1000;
        let date = StoryDate::from_millis(millis).ok_or_else(|| "Unmappable date".to_string())?;
        let url = StoryUrl::parse(unescape_entities(&self.require_string(data, "url")?))
            .ok_or_else(|| "Unmappable URL".to_string())?;
        let raw_title = unescape_entities(&self.require_string(data, "title")?);
        let num_comments = self.require_integer(data, "num_comments")?;
        let score = self.require_integer(data, "score")?;
        let downvotes = self.require_integer(data, "downs")?;
        let upvotes = self.require_integer(data, "ups")?;
        let upvote_ratio = self.require_float(data, "upvote_ratio")? as f32;
        let flair = unescape_entities(&self.optional_string(data, "link_flair_text")?);
        let story = RedditStory::new_subsource(
            id,
            subreddit,
            date,
            raw_title,
            url,
            flair,
            position,
            upvotes,
            downvotes,
            num_comments,
            score,
            upvote_ratio,
        );
        Ok(story)
    }
}

impl Scraper for RedditScraper {
    type Config = <Reddit as ScrapeSourceDef>::Config;
    type Output = <Reddit as ScrapeSourceDef>::Scrape;

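    /// Parses a listing response, descending into `data.children` and mapping each child
    /// to a story; per-story failures are collected and returned alongside the successes.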
    fn scrape(
        &self,
        _args: &RedditConfig,
        input: &str,
    ) -> Result<(Vec<GenericScrape<Self::Output>>, Vec<String>), ScrapeError> {
        let root: Value = serde_json::from_str(input)?;
        let mut value = &root;
        for path in ["data", "children"] {
            if let Some(object) = value.as_object() {
                if let Some(nested_value) = object.get(path) {
                    value = nested_value;
                } else {
                    return Err(ScrapeError::StructureError(
                        "Failed to parse Reddit JSON data.children".to_owned(),
                    ));
                }
            }
        }

        if let Some(children) = value.as_array() {
            let mut vec = vec![];
            let mut errors = vec![];
            let mut positions = HashMap::new();
            for child in children {
                match self.map_story(child, &mut positions) {
                    Ok(story) => vec.push(story),
                    Err(e) => errors.push(e),
                }
            }
            Ok((vec, errors))
        } else {
            Err(ScrapeError::StructureError(
                "Missing children element".to_owned(),
            ))
        }
    }

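    /// Derives the source-agnostic scrape core: tags from flair and subreddit (per the
    /// `SubredditConfig` flags), a trimmed title, and a zero-based rank.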
    fn extract_core<'a>(
        &self,
        args: &Self::Config,
        input: &'a GenericScrape<Self::Output>,
    ) -> ScrapeCore<'a> {
        let mut tags = vec![];
        if let Some(ref subreddit) = input.shared.id.subsource {
            if let Some(config) = args.subreddits.get(subreddit) {
                if config.flair_is_tag && !input.data.flair.contains(' ') {
                    tags.push(Cow::Owned(input.data.flair.to_lowercase()));
                }
                if config.is_tag {
                    tags.push(Cow::Borrowed(subreddit.as_str()));
                }
            }
        }

        let title = trim_title(&input.shared.raw_title, IDEAL_LENGTH, AWKWARD_LENGTH);

        ScrapeCore {
            source: &input.shared.id,
            title,
            url: &input.shared.url,
            date: input.shared.date,
            rank: (input.data.position as usize).checked_sub(1),
            tags,
        }
    }
}