reddit_rs/
client.rs

1use std::{borrow::Cow, env, time::Duration};
2
3use chrono::Utc;
4use hyper::{
5    body::{aggregate, Buf},
6    header::{ACCEPT, AUTHORIZATION, USER_AGENT},
7    Body, Request, Uri,
8};
9use serde::de::DeserializeOwned;
10use thiserror::Error;
11
12use crate::models::{
13    auth::RedditAuth,
14    fullname::FullName,
15    link::{RedditListing, Sort},
16};
17
18pub struct RedditClient {
19    pub(crate) client: hyper::Client<hyper_rustls::HttpsConnector<hyper::client::HttpConnector>>,
20
21    #[allow(dead_code)]
22    pub(crate) client_secret: String,
23
24    #[allow(dead_code)]
25    pub(crate) client_id: String,
26
27    #[allow(dead_code)]
28    pub(crate) base_path: Uri,
29
30    pub(crate) auth: RedditAuth,
31
32    pub(crate) user_agent: &'static str,
33}
34
35#[derive(Debug, thiserror::Error)]
36pub enum MissingEnvVariableError {
37    #[error("REDDIT_CLIENT_ID environment variable missing")]
38    RedditClientId,
39    #[error("REDDIT_CLIENT_SECRET environment variable missing")]
40    RedditClientSecret,
41}
42
/// Configuration struct for [`RedditClient`], consumed by
/// [`RedditClient::from_config`].
pub struct Config<'a> {
    /// The OAuth client id of the Reddit application.
    pub client_id: &'a str,
    /// The OAuth client secret of the Reddit application.
    pub client_secret: &'a str,
    /// The name of the client, to be used in the `User-Agent` header in all
    /// requests to reddit.
    ///
    /// <https://github.com/reddit-archive/reddit/wiki/API#rules>
    pub client_name: &'a str,
}
53
54impl RedditClient {
55    /// Creates a new client, reading from the `REDDIT_CLIENT_ID` and
56    /// `REDDIT_CLIENT_SECRET` for the `client_id` and `client_secret`,
57    /// respectively. This function uses dotenv, so the variables can be kept in
58    /// a `.env` file at the root of the workspace.
59    ///
60    /// # Errors
61    /// This function will error if the credentials provided in the
62    /// aformentioned environment variables are not able to authenticate on
63    /// Reddit.
64    ///
65    /// # Panics
66    /// This function will panic if either of the aformentioned environment
67    /// variables are not present.
68    pub async fn from_env(client_name: &str) -> Result<Self, RedditError> {
69        #[allow(clippy::unwrap_used)]
70        Self::try_from_env(client_name).await.unwrap()
71    }
72
73    /// Creates a new client, reading from the `REDDIT_CLIENT_ID` and
74    /// `REDDIT_CLIENT_SECRET` for the `client_id` and `client_secret`,
75    /// respectively. This function uses dotenv, so the variables can be kept in
76    /// a `.env` file at the root of the workspace.
77    ///
78    /// # Errors
79    /// This function will error if the aformentioned environment variables are
80    /// not present or if the credentials provided in them are not able to
81    /// authenticate on Reddit.
82    pub async fn try_from_env(
83        client_name: &str,
84    ) -> Result<Result<Self, RedditError>, MissingEnvVariableError> {
85        dotenv::dotenv().ok();
86
87        #[allow(clippy::disallowed_method)]
88        let client_id = match env::var("REDDIT_CLIENT_ID") {
89            Ok(client_id) => client_id,
90            Err(_) => return Err(MissingEnvVariableError::RedditClientId),
91        };
92
93        #[allow(clippy::disallowed_method)]
94        let client_secret = match env::var("REDDIT_CLIENT_SECRET") {
95            Ok(client_secret) => client_secret,
96            Err(_) => return Err(MissingEnvVariableError::RedditClientSecret),
97        };
98
99        Self::from_config(Config {
100            client_id: &client_id,
101            client_secret: &client_secret,
102            client_name,
103        })
104        .await
105    }
106
107    /// Creates a new client with the provided [`Config`]. See the documentation
108    /// of [`Config`] for more information.
109    ///
110    /// # Errors
111    /// This function will error if the aformentioned environment variables are
112    /// not present or if the credentials provided in them are not able to
113    /// authenticate on Reddit.
114    pub async fn from_config(
115        Config {
116            client_id,
117            client_secret,
118            client_name,
119        }: Config<'_>,
120    ) -> Result<Result<Self, RedditError>, MissingEnvVariableError> {
121        let base_path = "https://www.reddit.com".parse().expect("infallible");
122
123        let https = hyper_rustls::HttpsConnector::with_native_roots();
124        let client = hyper::Client::builder().build::<_, hyper::Body>(https);
125
126        let auth = match RedditClient::authorize(&client, client_id, client_secret).await {
127            Ok(auth) => auth,
128            Err(err) => return Ok(Err(err)),
129        };
130
131        // https://github.com/reddit-archive/reddit/wiki/API#rules
132        let user_agent: &'static str = {
133            let version = env!("CARGO_PKG_VERSION");
134            let user_agent = format!(
135                "ubuntu:{name}:{version} (by /u/benluelo)",
136                name = client_name,
137                version = version
138            );
139            Box::leak(Box::new(user_agent))
140        };
141
142        Ok(Ok(Self {
143            client,
144            client_secret: client_secret.into(),
145            client_id: client_id.into(),
146            base_path,
147            auth,
148            user_agent,
149        }))
150    }
151
152    async fn authorize(
153        client: &hyper::Client<hyper_rustls::HttpsConnector<hyper::client::HttpConnector>>,
154        client_id: &str,
155        client_secret: &str,
156    ) -> Result<RedditAuth, RedditError> {
157        let auth_path = format!(
158            "{}/{}",
159            std::env::temp_dir().display(),
160            "reddit_client_authorization.json"
161        );
162
163        if let Ok(Ok(auth)) = std::fs::read_to_string(&auth_path)
164            .as_deref()
165            .map(serde_json::from_str::<RedditAuth>)
166        {
167            if auth.expires_at > Utc::now() {
168                println!("cache hit for auth");
169                return Ok(auth);
170            }
171            println!("fetching new auth");
172        }
173
174        let response = client
175            .request(
176                Request::post("https://www.reddit.com/api/v1/access_token")
177                    .header(
178                        AUTHORIZATION,
179                        dbg!(format!(
180                            "Basic {}",
181                            base64::encode_config(
182                                format!("{}:{}", client_id, client_secret),
183                                base64::URL_SAFE
184                            )
185                        )),
186                    )
187                    .body(hyper::Body::from("grant_type=client_credentials"))
188                    .expect("infallible"),
189            )
190            .await?;
191
192        let buf = aggregate(response).await?;
193
194        let auth = serde_json::from_reader::<_, RedditAuth>(buf.reader())?;
195
196        if let Err(why) = std::fs::write(
197            &auth_path,
198            serde_json::to_string_pretty(&auth).expect("infallible"),
199        ) {
200            log::warn!(target: "reddit_client", "Unable to cache auth file: {}", why);
201        }
202
203        Ok(auth)
204    }
205
206    pub(crate) fn auth_header_value(&self) -> String {
207        format!("Bearer {}", &self.auth.access_token)
208    }
209
210    /// Gets posts from a subreddit, using the specified `sorting` and `after`
211    /// options.
212    ///
213    /// # Errors
214    /// This function will error for many reasons. See the documentation for
215    /// [`RedditError`] for more information.
216    ///
217    /// # Panics
218    /// This function will panic if either of the aformentioned environment
219    /// variables are not present.
220    pub async fn get_subreddit_posts(
221        &mut self,
222        subreddit: &str,
223        sorting: Sort,
224        after: Option<&FullName>,
225    ) -> Result<RedditListing, RedditError> {
226        let after: Cow<_> = if let Some(fullname) = after {
227            format!("after={}", fullname).into()
228        } else {
229            "".into()
230        };
231
232        let uri = Uri::builder()
233            .scheme("https")
234            .authority("oauth.reddit.com")
235            .path_and_query(format!(
236                "/r/{subreddit}/{sorting}.json?{after}",
237                subreddit = subreddit,
238                sorting = sorting,
239                after = &after
240            ))
241            .build()
242            .expect("Uri builder shouldn't fail"); //oauth.reddit.com/r/PKMNTCGDeals/new.json");
243
244        self.request(uri).await
245    }
246
247    async fn request<T: DeserializeOwned>(&mut self, uri: Uri) -> Result<T, RedditError> {
248        self.check_auth().await?;
249
250        tokio::time::sleep(Duration::from_secs(1)).await;
251
252        let request = self
253            .base_request(uri)
254            .body(Body::empty())
255            .expect("infallible");
256
257        let response = self.client.request(request).await?;
258
259        let mut buf = aggregate(response).await?.reader();
260
261        let mut bytes = Vec::new();
262
263        std::io::copy(&mut buf, &mut bytes)?;
264
265        let json = String::from_utf8(bytes)?;
266
267        // println!("{}", &json);
268
269        let listings = serde_json::from_str(&json)?;
270
271        Ok(listings)
272    }
273
274    fn base_request(&self, uri: Uri) -> hyper::http::request::Builder {
275        Request::get(uri)
276            .header(AUTHORIZATION, &self.auth_header_value())
277            .header(ACCEPT, "*/*")
278            .header(USER_AGENT, self.user_agent)
279    }
280
281    pub(crate) async fn check_auth(&mut self) -> Result<(), RedditError> {
282        if self.auth.expires_at <= Utc::now() {
283            self.auth =
284                RedditClient::authorize(&self.client, &self.client_id, &self.client_secret).await?;
285            Ok(())
286        } else {
287            Ok(())
288        }
289    }
290
291    /// Gets the comments for the provided post.
292    ///
293    /// # Errors
294    /// This function will error for many reasons. See the documentation for
295    /// [`RedditError`] for more information.
296    pub async fn get_comments(&mut self) -> Result<Vec<RedditListing>, RedditError> {
297        let uri: Uri = "https://oauth.reddit.com/r/benluelo_testing/comments/qbq1jr/yeet/.json"
298            .parse()
299            .expect("");
300
301        self.request(uri).await
302    }
303}
304
305#[derive(Debug, Error)]
306pub enum RedditError {
307    #[error("error fetching resource")]
308    Request(#[from] hyper::Error),
309
310    #[error("error deserializing resource")]
311    Deserialize(#[from] serde_json::Error),
312
313    #[error("Payload was not valid UTF8")]
314    Utf8(#[from] std::string::FromUtf8Error),
315
316    #[error("IO error")]
317    Io(#[from] std::io::Error),
318}
319
320// TODO: Move to integration test
321// #[cfg(test)]
322// mod test_reddit_client {
323//     use super::*;
324
325//     #[tokio::test]
326//     async fn test_get_subreddit_posts() {
327//         let mut client = RedditClient::new("TCG Collector Discord Bot")
328//             .await
329//             .unwrap();
330
331//         let mut after = None;
332//         loop {
333//             let listing = client
334//                 .get_subreddit_posts("PKMNTCGDeals", Sort::New,
335// after.as_ref())                 .await
336//                 .unwrap();
337//             println!("got posts");
338//             after = match listing {
339//                 RedditListing::Listing {
340//                     after, children, ..
341//                 } => {
342//                     for child in children {
343//                         println!("link: {:#?}", &child);
344//                     }
345//                     after
346//                 }
347//                 RedditListing::Link(_) => unreachable!(),
348//             };
349//         }
350//     }
351// }