progscrape_scrapers/backends/reddit.rs

use std::{borrow::Cow, collections::HashMap};

use serde::{Deserialize, Serialize};
use serde_json::Value;

use super::{
    GenericScrape, ScrapeConfigSource, ScrapeCore, ScrapeSource, ScrapeSourceDef, ScrapeStory,
    Scraper, scrape_story, utils::html::unescape_entities,
};
use crate::{
    datasci::titletrimmer::{AWKWARD_LENGTH, IDEAL_LENGTH, remove_tags, trim_title},
    types::*,
};

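/// Scrape source definition for Reddit: ties together the Reddit config,
/// story type, and scraper implementation below.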
pub struct Reddit {}

impl ScrapeSourceDef for Reddit {
    type Config = RedditConfig;
    type Scrape = RedditStory;
    type Scraper = RedditScraper;

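    /// Builds the canonical comments URL for a story id, including the
    /// subreddit path segment when a subsource is present.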
    fn comments_url(id: &str, subsource: Option<&str>) -> String {
        if let Some(subsource) = subsource {
            format!("https://www.reddit.com/r/{subsource}/comments/{id}/")
        } else {
            format!("https://www.reddit.com/comments/{id}/")
        }
    }

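    /// Inverse of `comments_url`: extracts the story id and optional subreddit
    /// from a Reddit comments URL, or returns `None` if the URL doesn't match.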
    fn id_from_comments_url(url: &str) -> Option<(&str, Option<&str>)> {
        let url = url.trim_end_matches('/');
        if let Some(url) = url.strip_prefix("https://www.reddit.com/comments/") {
            Some((url, None))
        } else {
            let url = url.strip_prefix("https://www.reddit.com/r/")?;
            if let Some((subreddit, rest)) = url.split_once('/') {
                if let Some((_, id)) = rest.split_once('/') {
                    Some((id, Some(subreddit)))
                } else {
                    None
                }
            } else {
                None
            }
        }
    }

    fn is_comments_host(host: &str) -> bool {
        host.ends_with("reddit.com")
    }
}

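/// Scraper configuration: the listing API URL template, how many subreddits to
/// request per batch, the per-request story limit, and per-subreddit options.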
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct RedditConfig {
    api: String,
    subreddit_batch: usize,
    limit: usize,
    subreddits: HashMap<String, SubredditConfig>,
}

impl ScrapeConfigSource for RedditConfig {
    fn subsources(&self) -> Vec<String> {
        self.subreddits.iter().map(|s| s.0.clone()).collect()
    }

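    /// Expands the API template into one URL per batch of subreddits, joining
    /// each batch with `+` and appending the configured story limit.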
    fn provide_urls(&self, subsources: Vec<String>) -> Vec<String> {
        let mut output = vec![];
        for chunk in subsources.chunks(self.subreddit_batch) {
            output.push(
                self.api.replace("${subreddits}", &chunk.join("+"))
                    + &format!("?limit={}", self.limit),
            )
        }
        output
    }
}

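/// Per-subreddit options: `is_tag` adds the subreddit name itself as a story
/// tag, and `flair_is_tag` treats single-word link flair as a tag.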
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct SubredditConfig {
    #[serde(default)]
    is_tag: bool,
    #[serde(default)]
    flair_is_tag: bool,
}

#[derive(Default)]
pub struct RedditScraper {}

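// Reddit-specific story fields captured alongside the shared scrape fields.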
scrape_story! {
    RedditStory {
        flair: String,
        position: u32,
        upvotes: u32,
        downvotes: u32,
        num_comments: u32,
        score: u32,
        upvote_ratio: f32,
    }
}

impl ScrapeStory for RedditStory {
    const TYPE: ScrapeSource = ScrapeSource::Reddit;

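    /// Merges a later scrape of the same story, keeping the maximum observed
    /// position, vote counts, comment count, score, and upvote ratio.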
    fn merge(&mut self, other: RedditStory) {
        self.position = std::cmp::max(self.position, other.position);
        self.upvotes = std::cmp::max(self.upvotes, other.upvotes);
        self.downvotes = std::cmp::max(self.downvotes, other.downvotes);
        self.num_comments = std::cmp::max(self.num_comments, other.num_comments);
        self.score = std::cmp::max(self.score, other.score);
        self.upvote_ratio = f32::max(self.upvote_ratio, other.upvote_ratio);
    }
}

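// Typed field-extraction helpers for the Reddit listing JSON, plus the mapping
// from a single listing child to a `RedditStory`.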
impl RedditScraper {
    fn require_string(&self, data: &Value, key: &str) -> Result<String, String> {
        Ok(data[key]
            .as_str()
            .ok_or(format!("Missing field {key:?}"))?
            .to_owned())
    }

    fn optional_string(&self, data: &Value, key: &str) -> Result<String, String> {
        Ok(data[key].as_str().unwrap_or_default().to_owned())
    }

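    /// Reads a numeric field as the requested integer type, accepting unsigned,
    /// signed, or floating-point JSON numbers.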
    fn require_integer<T: TryFrom<i64> + TryFrom<u64>>(
        &self,
        data: &Value,
        key: &str,
    ) -> Result<T, String> {
        if let Value::Number(n) = &data[key] {
            if let Some(n) = n.as_u64() {
                if let Ok(n) = n.try_into() {
                    return Ok(n);
                }
            }
            if let Some(n) = n.as_i64() {
                if let Ok(n) = n.try_into() {
                    return Ok(n);
                }
            }
            if let Some(n) = n.as_f64() {
                let n = n as i64;
                if let Ok(n) = n.try_into() {
                    return Ok(n);
                }
            }
            Err(format!(
                "Failed to parse {key} as integer (value was {n:?})"
            ))
        } else {
            Err(format!(
                "Missing or invalid field {:?} (value was {:?})",
                key, data[key]
            ))
        }
    }

    fn require_float(&self, data: &Value, key: &str) -> Result<f64, String> {
        if let Value::Number(n) = &data[key] {
            if let Some(n) = n.as_u64() {
                return Ok(n as f64);
            }
            if let Some(n) = n.as_i64() {
                return Ok(n as f64);
            }
            if let Some(n) = n.as_f64() {
                return Ok(n);
            }
            Err(format!("Failed to parse {key} as float (value was {n:?})"))
        } else {
            Err(format!(
                "Missing or invalid field {:?} (value was {:?})",
                key, data[key]
            ))
        }
    }

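    /// Maps a single `t3` listing child to a `RedditStory`, skipping stickied
    /// posts and tracking the 1-based rank position per subreddit.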
    fn map_story(
        &self,
        child: &Value,
        positions: &mut HashMap<String, u32>,
    ) -> Result<GenericScrape<<Self as Scraper>::Output>, String> {
        let kind = child["kind"].as_str();
        let data = if kind == Some("t3") {
            &child["data"]
        } else {
            return Err(format!("Unknown story type: {kind:?}"));
        };

        let id = self.require_string(data, "id")?;
        let subreddit = self.require_string(data, "subreddit")?.to_ascii_lowercase();
        if let Some(true) = data["stickied"].as_bool() {
            return Err(format!("Ignoring stickied story {subreddit}/{id}"));
        }
        let position = *positions
            .entry(subreddit.clone())
            .and_modify(|n| *n += 1)
            .or_default()
            + 1;
        let seconds: i64 = self.require_integer(data, "created_utc")?;
        let millis = seconds * 1000;
        let date = StoryDate::from_millis(millis).ok_or_else(|| "Unmappable date".to_string())?;
        let url = StoryUrl::parse(unescape_entities(&self.require_string(data, "url")?))
            .ok_or_else(|| "Unmappable URL".to_string())?;
        let raw_title = unescape_entities(&self.require_string(data, "title")?);
        let num_comments = self.require_integer(data, "num_comments")?;
        let score = self.require_integer(data, "score")?;
        let downvotes = self.require_integer(data, "downs")?;
        let upvotes = self.require_integer(data, "ups")?;
        let upvote_ratio = self.require_float(data, "upvote_ratio")? as f32;
        let flair = unescape_entities(&self.optional_string(data, "link_flair_text")?);
        let story = RedditStory::new_subsource(
            id,
            subreddit,
            date,
            raw_title,
            url,
            flair,
            position,
            upvotes,
            downvotes,
            num_comments,
            score,
            upvote_ratio,
        );
        Ok(story)
    }
}

impl Scraper for RedditScraper {
    type Config = <Reddit as ScrapeSourceDef>::Config;
    type Output = <Reddit as ScrapeSourceDef>::Scrape;

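    /// Parses a Reddit listing response: walks down into `data.children`, maps
    /// each child to a story, and collects per-story errors separately so one
    /// bad entry doesn't fail the whole scrape.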
    fn scrape(
        &self,
        _args: &RedditConfig,
        input: &str,
    ) -> Result<(Vec<GenericScrape<Self::Output>>, Vec<String>), ScrapeError> {
        let root: Value = serde_json::from_str(input)?;
        let mut value = &root;
        for path in ["data", "children"] {
            if let Some(object) = value.as_object() {
                if let Some(nested_value) = object.get(path) {
                    value = nested_value;
                } else {
                    return Err(ScrapeError::StructureError(
                        "Failed to parse Reddit JSON data.children".to_owned(),
                    ));
                }
            }
        }

        if let Some(children) = value.as_array() {
            let mut vec = vec![];
            let mut errors = vec![];
            let mut positions = HashMap::new();
            for child in children {
                match self.map_story(child, &mut positions) {
                    Ok(story) => vec.push(story),
                    Err(e) => errors.push(e),
                }
            }
            Ok((vec, errors))
        } else {
            Err(ScrapeError::StructureError(
                "Missing children element".to_owned(),
            ))
        }
    }

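    /// Derives the normalized `ScrapeCore`: applies the per-subreddit tag rules,
    /// removes embedded tags from the raw title, trims it toward the ideal
    /// length, and converts the 1-based position into a 0-based rank.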
    fn extract_core<'a>(
        &self,
        args: &Self::Config,
        input: &'a GenericScrape<Self::Output>,
    ) -> ScrapeCore<'a> {
        let mut tags = vec![];
        if let Some(ref subreddit) = input.shared.id.subsource {
            if let Some(config) = args.subreddits.get(subreddit) {
                if config.flair_is_tag && !input.data.flair.contains(' ') {
                    tags.push(Cow::Owned(input.data.flair.to_lowercase()));
                }
                if config.is_tag {
                    tags.push(Cow::Borrowed(subreddit.as_str()));
                }
            }
        }

        let (title, _, _) = remove_tags(&input.raw_title);
        let title = trim_title(title, IDEAL_LENGTH, AWKWARD_LENGTH);

        ScrapeCore {
            source: &input.shared.id,
            title,
            url: &input.shared.url,
            date: input.shared.date,
            rank: (input.data.position as usize).checked_sub(1),
            tags,
        }
    }
}