reddit-rs 0.1.1

A wrapper around the Reddit API.
Documentation
use std::{borrow::Cow, env, time::Duration};

use chrono::Utc;
use hyper::{
    body::{aggregate, Buf},
    header::{ACCEPT, AUTHORIZATION, USER_AGENT},
    Body, Request, Uri,
};
use serde::de::DeserializeOwned;
use thiserror::Error;

use crate::models::{
    auth::RedditAuth,
    fullname::FullName,
    link::{RedditListing, Sort},
};

/// Client for the Reddit API, authenticated with an application's
/// `client_id` / `client_secret` pair.
///
/// Construct it with [`RedditClient::from_env`],
/// [`RedditClient::try_from_env`] or [`RedditClient::from_config`].
pub struct RedditClient {
    /// HTTPS-capable hyper client used for every outgoing request.
    pub(crate) client: hyper::Client<hyper_rustls::HttpsConnector<hyper::client::HttpConnector>>,

    /// Application secret; kept so the access token can be re-requested once
    /// it expires.
    pub(crate) client_secret: String,

    /// Application id; kept so the access token can be re-requested once it
    /// expires.
    pub(crate) client_id: String,

    /// Base URL of the Reddit website (`https://www.reddit.com`).
    // Stored but never read in this file, hence the lint exemption.
    #[allow(dead_code)]
    pub(crate) base_path: Uri,

    /// Current authorization: its access token is sent as the `Bearer`
    /// credential and its expiry is checked before each request.
    pub(crate) auth: RedditAuth,

    /// Value sent in the `User-Agent` header of API requests.
    pub(crate) user_agent: &'static str,
}

/// Error returned by [`RedditClient::try_from_env`] when one of the
/// environment variables it needs could not be read.
#[derive(Debug, thiserror::Error)]
pub enum MissingEnvVariableError {
    /// `REDDIT_CLIENT_ID` could not be read from the environment.
    #[error("REDDIT_CLIENT_ID environment variable missing")]
    RedditClientId,
    /// `REDDIT_CLIENT_SECRET` could not be read from the environment.
    #[error("REDDIT_CLIENT_SECRET environment variable missing")]
    RedditClientSecret,
}

/// Configuration struct for [`RedditClient`].
pub struct Config<'a> {
    /// Application id, used as the user name of the HTTP Basic credentials
    /// when requesting an access token.
    pub client_id: &'a str,
    /// Application secret, used as the password of the HTTP Basic credentials
    /// when requesting an access token.
    pub client_secret: &'a str,
    /// The name of the client, to be used in the `User-Agent` header in all
    /// requests to reddit.
    ///
    /// <https://github.com/reddit-archive/reddit/wiki/API#rules>
    pub client_name: &'a str,
}

impl RedditClient {
    /// Builds a client from the `REDDIT_CLIENT_ID` and `REDDIT_CLIENT_SECRET`
    /// environment variables, which supply the `client_id` and
    /// `client_secret` respectively.
    ///
    /// This is the panicking convenience wrapper around
    /// [`RedditClient::try_from_env`]; as that function loads `dotenv` first,
    /// both variables may also be kept in a `.env` file.
    ///
    /// # Errors
    /// Returns a [`RedditError`] when authenticating against Reddit with the
    /// credentials found in the environment fails.
    ///
    /// # Panics
    /// Panics when either environment variable cannot be read. Use
    /// [`RedditClient::try_from_env`] to handle that case gracefully.
    pub async fn from_env(client_name: &str) -> Result<Self, RedditError> {
        let outcome = Self::try_from_env(client_name).await;

        // A missing variable is the documented panic condition.
        #[allow(clippy::unwrap_used)]
        outcome.unwrap()
    }

    /// Non-panicking way to build a client from the environment.
    ///
    /// Reads the `client_id` from `REDDIT_CLIENT_ID` and the `client_secret`
    /// from `REDDIT_CLIENT_SECRET`. `dotenv` is loaded first, so both
    /// variables can be kept in a `.env` file.
    ///
    /// # Errors
    /// The outer `Result` is `Err` when one of the two variables cannot be
    /// read; the inner `Result` is `Err` when authenticating against Reddit
    /// with those credentials fails.
    pub async fn try_from_env(
        client_name: &str,
    ) -> Result<Result<Self, RedditError>, MissingEnvVariableError> {
        // Best effort: any failure to load a `.env` file is ignored.
        dotenv::dotenv().ok();

        #[allow(clippy::disallowed_method)]
        let id = env::var("REDDIT_CLIENT_ID")
            .map_err(|_| MissingEnvVariableError::RedditClientId)?;

        #[allow(clippy::disallowed_method)]
        let secret = env::var("REDDIT_CLIENT_SECRET")
            .map_err(|_| MissingEnvVariableError::RedditClientSecret)?;

        let config = Config {
            client_id: &id,
            client_secret: &secret,
            client_name,
        };

        Self::from_config(config).await
    }

    /// Creates a new client with the provided [`Config`]. See the documentation
    /// of [`Config`] for more information.
    ///
    /// # Errors
    /// This function will error if the aformentioned environment variables are
    /// not present or if the credentials provided in them are not able to
    /// authenticate on Reddit.
    pub async fn from_config(
        Config {
            client_id,
            client_secret,
            client_name,
        }: Config<'_>,
    ) -> Result<Result<Self, RedditError>, MissingEnvVariableError> {
        let base_path = "https://www.reddit.com".parse().expect("infallible");

        let https = hyper_rustls::HttpsConnector::with_native_roots();
        let client = hyper::Client::builder().build::<_, hyper::Body>(https);

        let auth = match RedditClient::authorize(&client, client_id, client_secret).await {
            Ok(auth) => auth,
            Err(err) => return Ok(Err(err)),
        };

        // https://github.com/reddit-archive/reddit/wiki/API#rules
        let user_agent: &'static str = {
            let version = env!("CARGO_PKG_VERSION");
            let user_agent = format!(
                "ubuntu:{name}:{version} (by /u/benluelo)",
                name = client_name,
                version = version
            );
            Box::leak(Box::new(user_agent))
        };

        Ok(Ok(Self {
            client,
            client_secret: client_secret.into(),
            client_id: client_id.into(),
            base_path,
            auth,
            user_agent,
        }))
    }

    /// Obtains an authorization for the given application credentials.
    ///
    /// A previously obtained authorization is cached as JSON in
    /// `reddit_client_authorization.json` inside the system temp directory.
    /// If that file can be read and parsed and its `expires_at` is still in
    /// the future, it is returned without contacting Reddit. Otherwise a new
    /// token is requested from `/api/v1/access_token` with the
    /// `client_credentials` grant and written back to the cache (a failed
    /// write is only logged).
    ///
    /// # Errors
    /// Returns a [`RedditError`] when the request fails or when the response
    /// body cannot be deserialized into a [`RedditAuth`].
    async fn authorize(
        client: &hyper::Client<hyper_rustls::HttpsConnector<hyper::client::HttpConnector>>,
        client_id: &str,
        client_secret: &str,
    ) -> Result<RedditAuth, RedditError> {
        // NOTE(review): the cache file name does not depend on `client_id`,
        // so every client running on this machine shares one cached token.
        let auth_path = std::env::temp_dir().join("reddit_client_authorization.json");

        if let Ok(Ok(auth)) = std::fs::read_to_string(&auth_path)
            .as_deref()
            .map(serde_json::from_str::<RedditAuth>)
        {
            if auth.expires_at > Utc::now() {
                log::debug!(target: "reddit_client", "cache hit for auth");
                return Ok(auth);
            }
            log::debug!(target: "reddit_client", "fetching new auth");
        }

        // HTTP Basic credentials use the *standard* base64 alphabet
        // (RFC 7617); the URL-safe alphabet yields a different string
        // whenever the encoding contains `+` or `/`. The value embeds the
        // client secret and must never be printed or logged.
        let credentials = base64::encode_config(
            format!("{}:{}", client_id, client_secret),
            base64::STANDARD,
        );

        // NOTE(review): no `User-Agent` header is sent with this request,
        // unlike the API requests made later — confirm whether Reddit
        // tolerates that for the token endpoint.
        let request = Request::post("https://www.reddit.com/api/v1/access_token")
            .header(AUTHORIZATION, format!("Basic {}", credentials))
            // The body below is form-encoded; say so explicitly.
            .header("content-type", "application/x-www-form-urlencoded")
            .body(hyper::Body::from("grant_type=client_credentials"))
            .expect("static URL and base64 credentials always form a valid request");

        let response = client.request(request).await?;

        let buf = aggregate(response).await?;

        let auth = serde_json::from_reader::<_, RedditAuth>(buf.reader())?;

        // Caching is best effort: the fresh token is still returned when the
        // file cannot be written.
        if let Err(why) = std::fs::write(
            &auth_path,
            serde_json::to_string_pretty(&auth).expect("infallible"),
        ) {
            log::warn!(target: "reddit_client", "Unable to cache auth file: {}", why);
        }

        Ok(auth)
    }

    /// Renders the current access token as a `Bearer` credential, ready to
    /// be used as the value of the `Authorization` header.
    pub(crate) fn auth_header_value(&self) -> String {
        let token = &self.auth.access_token;
        format!("Bearer {}", token)
    }

    /// Gets posts from a subreddit, using the specified `sorting` and `after`
    /// options.
    ///
    /// The request goes to
    /// `https://oauth.reddit.com/r/{subreddit}/{sorting}.json`. When `after`
    /// is `Some`, it is appended as the `after` query parameter; when it is
    /// `None`, no query string is sent.
    ///
    /// # Errors
    /// This function will error for many reasons. See the documentation for
    /// [`RedditError`] for more information.
    ///
    /// # Panics
    /// This function will panic if `subreddit`, `sorting` and `after` do not
    /// render to a valid URI path and query (for example, a `subreddit`
    /// containing a space).
    pub async fn get_subreddit_posts(
        &mut self,
        subreddit: &str,
        sorting: Sort,
        after: Option<&FullName>,
    ) -> Result<RedditListing, RedditError> {
        // Only emit the query string when there is something to put in it,
        // instead of sending a dangling `?`.
        let path_and_query = match after {
            Some(fullname) => format!(
                "/r/{subreddit}/{sorting}.json?after={after}",
                subreddit = subreddit,
                sorting = sorting,
                after = fullname
            ),
            None => format!(
                "/r/{subreddit}/{sorting}.json",
                subreddit = subreddit,
                sorting = sorting
            ),
        };

        let uri = Uri::builder()
            .scheme("https")
            .authority("oauth.reddit.com")
            .path_and_query(path_and_query)
            .build()
            .expect("subreddit, sorting and after must form a valid URI path and query");

        self.request(uri).await
    }

    /// Sends an authenticated `GET` to `uri` and deserializes the JSON
    /// response body into `T`.
    ///
    /// The stored authorization is refreshed first if it has expired, and
    /// every call then waits one second before the request is sent.
    ///
    /// # Errors
    /// Returns a [`RedditError`] when re-authorizing or sending the request
    /// fails, when the body cannot be read, is not valid UTF-8, or is not
    /// valid JSON for `T`.
    async fn request<T: DeserializeOwned>(&mut self, uri: Uri) -> Result<T, RedditError> {
        self.check_auth().await?;

        // Fixed pause before every request — presumably to stay within
        // Reddit's rate limits.
        tokio::time::sleep(Duration::from_secs(1)).await;

        let response = self
            .client
            .request(
                self.base_request(uri)
                    .body(Body::empty())
                    .expect("infallible"),
            )
            .await?;

        // Collect the whole body, then decode it as UTF-8 before parsing so
        // that an encoding problem is reported as such.
        let mut reader = aggregate(response).await?.reader();
        let mut raw = Vec::new();
        std::io::copy(&mut reader, &mut raw)?;

        let text = String::from_utf8(raw)?;

        Ok(serde_json::from_str(&text)?)
    }

    /// Starts a `GET` request builder for `uri` carrying the headers shared
    /// by every API call: the `Bearer` authorization, `Accept: */*` and the
    /// client's `User-Agent`.
    fn base_request(&self, uri: Uri) -> hyper::http::request::Builder {
        let bearer = self.auth_header_value();

        Request::get(uri)
            .header(AUTHORIZATION, bearer.as_str())
            .header(ACCEPT, "*/*")
            .header(USER_AGENT, self.user_agent)
    }

    /// Makes sure the stored authorization is still usable, replacing it
    /// with the result of a new `authorize` call once `expires_at` is no
    /// longer in the future.
    ///
    /// # Errors
    /// Returns a [`RedditError`] when re-authorizing fails.
    pub(crate) async fn check_auth(&mut self) -> Result<(), RedditError> {
        // Still valid: nothing to do.
        if self.auth.expires_at > Utc::now() {
            return Ok(());
        }

        self.auth =
            RedditClient::authorize(&self.client, &self.client_id, &self.client_secret).await?;

        Ok(())
    }

    /// Gets the comments of one fixed post,
    /// `/r/benluelo_testing/comments/qbq1jr/yeet`.
    ///
    /// The post is hard-coded: this method takes no argument selecting
    /// another one.
    ///
    /// # Errors
    /// This function will error for many reasons. See the documentation for
    /// [`RedditError`] for more information.
    pub async fn get_comments(&mut self) -> Result<Vec<RedditListing>, RedditError> {
        let uri: Uri = "https://oauth.reddit.com/r/benluelo_testing/comments/qbq1jr/yeet/.json"
            .parse()
            .expect("hard-coded comments URL is a valid URI");

        self.request(uri).await
    }
}

/// Errors that can occur while authorizing with, or requesting data from,
/// Reddit.
#[derive(Debug, Error)]
pub enum RedditError {
    /// The HTTP request failed, or the response body could not be received.
    #[error("error fetching resource")]
    Request(#[from] hyper::Error),

    /// The response body was not valid JSON for the expected type.
    #[error("error deserializing resource")]
    Deserialize(#[from] serde_json::Error),

    /// The response body was not valid UTF-8.
    #[error("Payload was not valid UTF8")]
    Utf8(#[from] std::string::FromUtf8Error),

    /// Copying the response body into memory failed.
    #[error("IO error")]
    Io(#[from] std::io::Error),
}

// TODO: Move to integration test
// #[cfg(test)]
// mod test_reddit_client {
//     use super::*;

//     #[tokio::test]
//     async fn test_get_subreddit_posts() {
//         let mut client = RedditClient::new("TCG Collector Discord Bot")
//             .await
//             .unwrap();

//         let mut after = None;
//         loop {
//             let listing = client
//                 .get_subreddit_posts("PKMNTCGDeals", Sort::New, after.as_ref())
//                 .await
//                 .unwrap();
//             println!("got posts");
//             after = match listing {
//                 RedditListing::Listing {
//                     after, children, ..
//                 } => {
//                     for child in children {
//                         println!("link: {:#?}", &child);
//                     }
//                     after
//                 }
//                 RedditListing::Link(_) => unreachable!(),
//             };
//         }
//     }
// }