neptunium/
reddit.rs

1use std::error::Error;
2use serde_derive::Deserialize;
3use serde_derive::Serialize;
4use serde_json::Value;
5use isahc::{ReadResponseExt, Request};
6use rand::seq::SliceRandom;
7use rand::thread_rng;
8
9
10#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
11#[serde(rename_all = "camelCase")]
12pub struct RedditRoot {
13    pub kind: String,
14    pub data: RedditPostCollection,
15}
16
17#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
18#[serde(rename_all = "camelCase")]
19pub struct RedditPostCollection {
20    pub children: Vec<Children>,
21}
22
23#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
24#[serde(rename_all = "camelCase")]
25pub struct Children {
26    pub kind: String,
27    pub data: Data2,
28}
29
30#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
31#[serde(rename_all = "camelCase")]
32pub struct Data2 {
33    pub subreddit: String,
34    pub selftext: String,
35    #[serde(rename = "author_fullname")]
36    pub author_fullname: String,
37    pub title: String,
38    #[serde(rename = "subreddit_name_prefixed")]
39    pub subreddit_name_prefixed: String,
40    pub downs: i64,
41    pub name: String,
42    #[serde(rename = "upvote_ratio")]
43    pub upvote_ratio: f64,
44    pub ups: i64,
45    #[serde(rename = "total_awards_received")]
46    pub total_awards_received: i64,
47    #[serde(rename = "thumbnail_width")]
48    pub thumbnail_width: Option<i64>,
49    #[serde(rename = "is_original_content")]
50    pub is_original_content: bool,
51    pub category: Value,
52    pub score: i64,
53    pub thumbnail: String,
54    pub edited: Value,
55    pub created: f64,
56    pub domain: String,
57    #[serde(rename = "over_18")]
58    pub nsfw: bool,
59    #[serde(rename = "media_only")]
60    pub media_only: bool,
61    pub id: String,
62    pub author: String,
63    #[serde(rename = "num_comments")]
64    pub num_comments: i64,
65    pub permalink: String,
66    pub url: String,
67    #[serde(rename = "created_utc")]
68    pub created_utc: f64,
69    pub media: Value,
70}
71
/// File extensions (lowercase, without the leading dot) that mark a post URL
/// as a direct image link.
const ALLOWED_EXTENSIONS: &[&str] = &[
    "jpg", "jpeg", "gif", "png", "webp",
];
73
74/// Returns a full list of posts in a subreddit.
75///
76/// Max 50.
77pub fn get_subreddit(subreddit: String) -> Result<RedditPostCollection, Box<dyn Error>> {
78    let search_url = format!("https://www.reddit.com/r/{}.json", subreddit);
79
80    #[cfg(debug_assertions)]
81    println!(
82        "Reddit search: {:?}\nConstructed URL: {}",
83        subreddit, search_url
84    );
85
86    let mut reddit_response = Request::get(&search_url)
87        .header(
88            "User-Agent",
89            "EcchiBot - contact <privateger@privateger.me>",
90        )
91        .body(())
92        .map_err(Into::into)
93        .and_then(isahc::send)?;
94
95    #[cfg(debug_assertions)]
96    println!(
97        "Reddit response: {}\nURL: {}",
98        reddit_response.status(),
99        &search_url
100    );
101
102    let res: RedditRoot = reddit_response.json()?;
103
104    Ok(res.data)
105}
106
107/// Returns a Vector of direct links to images posted in a subreddit.
108///
109/// This Vector is pre-shuffled.
110pub fn images(subreddit: String, limit: u32) -> Option<Vec<String>> {
111    if limit == 0 {
112        return None
113    }
114
115    let mut subreddit_data : RedditPostCollection = get_subreddit(subreddit).ok()?;
116    let mut result_collector: Vec<String> = vec![];
117
118    // shuffle the post collection to randomize picked images
119    subreddit_data.children.shuffle(&mut thread_rng());
120
121    for post in subreddit_data.children.iter() {
122        for allowed_extensions in ALLOWED_EXTENSIONS.iter() {
123            if post.data.url.ends_with(allowed_extensions) {
124                result_collector.push(post.data.url.clone());
125            }
126            if result_collector.len()+1 > limit as usize {
127                return Some(result_collector)
128            }
129        }
130    }
131
132    Some(result_collector)
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: these tests query the live Reddit API and need network access.

    #[test]
    pub fn test_subreddit() {
        let response = get_subreddit("pics".to_string()).expect("failed");
        let first = &response
            .children
            .first()
            .expect("listing returned no posts")
            .data;

        assert_eq!(first.subreddit, "pics");
        assert!(!first.url.is_empty(), "first post has an empty url");
    }

    #[test]
    pub fn test_image_links() {
        let links = images("pics".to_string(), 5).expect("got error back");

        assert!(links.len() <= 5, "excessive length {:?}", links);
        assert!(!links.is_empty(), "no results despite expecting them");
    }

    #[test]
    pub fn test_error_on_empty_limit() {
        // `images` rejects a zero limit with `None` before making any request.
        // Asserting on the value (rather than `#[should_panic]`) means an
        // unrelated panic can no longer make this test pass.
        assert!(images("pics".to_string(), 0).is_none());
    }

    #[test]
    pub fn verify_randomization() {
        let sample_1: Vec<String> = images("pics".to_string(), 20).expect("got error");
        let sample_2: Vec<String> = images("pics".to_string(), 20).expect("got error");

        // With fewer than two links there is no order to shuffle, so identical
        // samples would not indicate a broken shuffle.
        if sample_1.len() < 2 || sample_2.len() < 2 {
            eprintln!("skipping randomization check: too few image links returned");
            return;
        }

        let matching = sample_1
            .iter()
            .zip(&sample_2)
            .filter(|&(a, b)| a == b)
            .count();
        assert_ne!(
            matching,
            sample_2.len(),
            "shuffle not working or you won the lottery"
        );
    }
}