1use std::{borrow::Cow, env, time::Duration};
2
3use chrono::Utc;
4use hyper::{
5 body::{aggregate, Buf},
6 header::{ACCEPT, AUTHORIZATION, USER_AGENT},
7 Body, Request, Uri,
8};
9use serde::de::DeserializeOwned;
10use thiserror::Error;
11
12use crate::models::{
13 auth::RedditAuth,
14 fullname::FullName,
15 link::{RedditListing, Sort},
16};
17
18pub struct RedditClient {
19 pub(crate) client: hyper::Client<hyper_rustls::HttpsConnector<hyper::client::HttpConnector>>,
20
21 #[allow(dead_code)]
22 pub(crate) client_secret: String,
23
24 #[allow(dead_code)]
25 pub(crate) client_id: String,
26
27 #[allow(dead_code)]
28 pub(crate) base_path: Uri,
29
30 pub(crate) auth: RedditAuth,
31
32 pub(crate) user_agent: &'static str,
33}
34
35#[derive(Debug, thiserror::Error)]
36pub enum MissingEnvVariableError {
37 #[error("REDDIT_CLIENT_ID environment variable missing")]
38 RedditClientId,
39 #[error("REDDIT_CLIENT_SECRET environment variable missing")]
40 RedditClientSecret,
41}
42
/// Credentials and identity used by `RedditClient::from_config`.
///
/// All fields are borrowed, so the struct is `Copy` and can be reused, for
/// example to retry construction after an authorization failure. `Debug` is
/// intentionally not derived so the client secret cannot end up in logs.
#[derive(Clone, Copy)]
pub struct Config<'a> {
    /// OAuth application id.
    pub client_id: &'a str,
    /// OAuth application secret.
    pub client_secret: &'a str,
    /// Application name embedded in the `User-Agent` header.
    pub client_name: &'a str,
}
53
54impl RedditClient {
55 pub async fn from_env(client_name: &str) -> Result<Self, RedditError> {
69 #[allow(clippy::unwrap_used)]
70 Self::try_from_env(client_name).await.unwrap()
71 }
72
73 pub async fn try_from_env(
83 client_name: &str,
84 ) -> Result<Result<Self, RedditError>, MissingEnvVariableError> {
85 dotenv::dotenv().ok();
86
87 #[allow(clippy::disallowed_method)]
88 let client_id = match env::var("REDDIT_CLIENT_ID") {
89 Ok(client_id) => client_id,
90 Err(_) => return Err(MissingEnvVariableError::RedditClientId),
91 };
92
93 #[allow(clippy::disallowed_method)]
94 let client_secret = match env::var("REDDIT_CLIENT_SECRET") {
95 Ok(client_secret) => client_secret,
96 Err(_) => return Err(MissingEnvVariableError::RedditClientSecret),
97 };
98
99 Self::from_config(Config {
100 client_id: &client_id,
101 client_secret: &client_secret,
102 client_name,
103 })
104 .await
105 }
106
107 pub async fn from_config(
115 Config {
116 client_id,
117 client_secret,
118 client_name,
119 }: Config<'_>,
120 ) -> Result<Result<Self, RedditError>, MissingEnvVariableError> {
121 let base_path = "https://www.reddit.com".parse().expect("infallible");
122
123 let https = hyper_rustls::HttpsConnector::with_native_roots();
124 let client = hyper::Client::builder().build::<_, hyper::Body>(https);
125
126 let auth = match RedditClient::authorize(&client, client_id, client_secret).await {
127 Ok(auth) => auth,
128 Err(err) => return Ok(Err(err)),
129 };
130
131 let user_agent: &'static str = {
133 let version = env!("CARGO_PKG_VERSION");
134 let user_agent = format!(
135 "ubuntu:{name}:{version} (by /u/benluelo)",
136 name = client_name,
137 version = version
138 );
139 Box::leak(Box::new(user_agent))
140 };
141
142 Ok(Ok(Self {
143 client,
144 client_secret: client_secret.into(),
145 client_id: client_id.into(),
146 base_path,
147 auth,
148 user_agent,
149 }))
150 }
151
152 async fn authorize(
153 client: &hyper::Client<hyper_rustls::HttpsConnector<hyper::client::HttpConnector>>,
154 client_id: &str,
155 client_secret: &str,
156 ) -> Result<RedditAuth, RedditError> {
157 let auth_path = format!(
158 "{}/{}",
159 std::env::temp_dir().display(),
160 "reddit_client_authorization.json"
161 );
162
163 if let Ok(Ok(auth)) = std::fs::read_to_string(&auth_path)
164 .as_deref()
165 .map(serde_json::from_str::<RedditAuth>)
166 {
167 if auth.expires_at > Utc::now() {
168 println!("cache hit for auth");
169 return Ok(auth);
170 }
171 println!("fetching new auth");
172 }
173
174 let response = client
175 .request(
176 Request::post("https://www.reddit.com/api/v1/access_token")
177 .header(
178 AUTHORIZATION,
179 dbg!(format!(
180 "Basic {}",
181 base64::encode_config(
182 format!("{}:{}", client_id, client_secret),
183 base64::URL_SAFE
184 )
185 )),
186 )
187 .body(hyper::Body::from("grant_type=client_credentials"))
188 .expect("infallible"),
189 )
190 .await?;
191
192 let buf = aggregate(response).await?;
193
194 let auth = serde_json::from_reader::<_, RedditAuth>(buf.reader())?;
195
196 if let Err(why) = std::fs::write(
197 &auth_path,
198 serde_json::to_string_pretty(&auth).expect("infallible"),
199 ) {
200 log::warn!(target: "reddit_client", "Unable to cache auth file: {}", why);
201 }
202
203 Ok(auth)
204 }
205
206 pub(crate) fn auth_header_value(&self) -> String {
207 format!("Bearer {}", &self.auth.access_token)
208 }
209
210 pub async fn get_subreddit_posts(
221 &mut self,
222 subreddit: &str,
223 sorting: Sort,
224 after: Option<&FullName>,
225 ) -> Result<RedditListing, RedditError> {
226 let after: Cow<_> = if let Some(fullname) = after {
227 format!("after={}", fullname).into()
228 } else {
229 "".into()
230 };
231
232 let uri = Uri::builder()
233 .scheme("https")
234 .authority("oauth.reddit.com")
235 .path_and_query(format!(
236 "/r/{subreddit}/{sorting}.json?{after}",
237 subreddit = subreddit,
238 sorting = sorting,
239 after = &after
240 ))
241 .build()
242 .expect("Uri builder shouldn't fail"); self.request(uri).await
245 }
246
247 async fn request<T: DeserializeOwned>(&mut self, uri: Uri) -> Result<T, RedditError> {
248 self.check_auth().await?;
249
250 tokio::time::sleep(Duration::from_secs(1)).await;
251
252 let request = self
253 .base_request(uri)
254 .body(Body::empty())
255 .expect("infallible");
256
257 let response = self.client.request(request).await?;
258
259 let mut buf = aggregate(response).await?.reader();
260
261 let mut bytes = Vec::new();
262
263 std::io::copy(&mut buf, &mut bytes)?;
264
265 let json = String::from_utf8(bytes)?;
266
267 let listings = serde_json::from_str(&json)?;
270
271 Ok(listings)
272 }
273
274 fn base_request(&self, uri: Uri) -> hyper::http::request::Builder {
275 Request::get(uri)
276 .header(AUTHORIZATION, &self.auth_header_value())
277 .header(ACCEPT, "*/*")
278 .header(USER_AGENT, self.user_agent)
279 }
280
281 pub(crate) async fn check_auth(&mut self) -> Result<(), RedditError> {
282 if self.auth.expires_at <= Utc::now() {
283 self.auth =
284 RedditClient::authorize(&self.client, &self.client_id, &self.client_secret).await?;
285 Ok(())
286 } else {
287 Ok(())
288 }
289 }
290
291 pub async fn get_comments(&mut self) -> Result<Vec<RedditListing>, RedditError> {
297 let uri: Uri = "https://oauth.reddit.com/r/benluelo_testing/comments/qbq1jr/yeet/.json"
298 .parse()
299 .expect("");
300
301 self.request(uri).await
302 }
303}
304
305#[derive(Debug, Error)]
306pub enum RedditError {
307 #[error("error fetching resource")]
308 Request(#[from] hyper::Error),
309
310 #[error("error deserializing resource")]
311 Deserialize(#[from] serde_json::Error),
312
313 #[error("Payload was not valid UTF8")]
314 Utf8(#[from] std::string::FromUtf8Error),
315
316 #[error("IO error")]
317 Io(#[from] std::io::Error),
318}
319
320