use crate::v2::errors::SmugMugError;
use base64::prelude::*;
use bytes::Bytes;
use chrono::{DateTime, Duration, TimeZone, Utc};
use hmac::{Hmac, Mac};
use num_enum::TryFromPrimitive;
use rand::Rng;
use rand::distr::Alphanumeric;
use reqwest::Response as ReqwestResponse;
use reqwest::header::HeaderMap;
use serde::Deserialize;
use serde::de::DeserializeOwned;
use sha1::Sha1;
use std::collections::BTreeMap;
use std::sync::{Arc, RwLock};
use urlencoding::encode as url_encode;
/// HMAC instantiated with SHA-1; `Creds::create_oauth1_header` uses it to compute the
/// signature it advertises as `HMAC-SHA1`.
type HmacSha1 = Hmac<Sha1>;
/// Scheme and host of the SmugMug API, without a trailing slash.
pub(crate) const API_ORIGIN: &str = "https://api.smugmug.com";
/// Public handle for talking to the API.
///
/// All state lives in a reference-counted [`ClientRef`], so cloning a `Client` copies the
/// pointer and every clone shares the same credentials, HTTP client and last observed
/// rate-limit window.
#[derive(Default, Clone)]
pub struct Client {
// Shared state; every public method forwards to it.
inner: Arc<ClientRef>,
}
impl Client {
    /// Builds a client that authenticates with `creds`.
    pub fn new(creds: Creds) -> Self {
        let inner = Arc::new(ClientRef::new(creds));
        Self { inner }
    }

    /// Sends a GET to `url` (plus optional query `params`) and deserializes the JSON
    /// `Response` member into `T`.
    ///
    /// The request is OAuth-signed when the credentials are complete; otherwise it is
    /// sent unsigned.
    pub async fn get<T: DeserializeOwned>(
        &self,
        url: &str,
        params: Option<&ApiParams<'_>>,
    ) -> Result<Response<T>, SmugMugError> {
        ClientRef::get::<T>(&self.inner, url, params).await
    }

    /// Sends a GET to `url` and returns the raw response body.
    ///
    /// The returned [`Response`] never carries rate-limit information.
    pub async fn get_binary_data(
        &self,
        url: &str,
        params: Option<&ApiParams<'_>>,
    ) -> Result<Response<Bytes>, SmugMugError> {
        ClientRef::get_binary_data(&self.inner, url, params).await
    }

    /// Sends an OAuth-signed PATCH with `data` as the JSON body and deserializes the
    /// JSON `Response` member into `T`.
    ///
    /// Fails with an authentication error when a token needed for signing is missing.
    pub async fn patch<T: DeserializeOwned>(
        &self,
        url: &str,
        data: Vec<u8>,
        params: Option<&ApiParams<'_>>,
    ) -> Result<Response<T>, SmugMugError> {
        ClientRef::patch::<T>(&self.inner, url, data, params).await
    }

    /// Sends an OAuth-signed POST with `data` as the JSON body and deserializes the
    /// JSON `Response` member into `T`.
    ///
    /// Fails with an authentication error when a token needed for signing is missing.
    pub async fn post<T: DeserializeOwned>(
        &self,
        url: &str,
        data: Vec<u8>,
        params: Option<&ApiParams<'_>>,
    ) -> Result<Response<T>, SmugMugError> {
        ClientRef::post::<T>(&self.inner, url, data, params).await
    }

    /// Returns the rate-limit window recorded from the most recent JSON response, or
    /// `None` when that window carries neither a remaining-request count nor a
    /// retry-after value.
    ///
    /// # Panics
    ///
    /// Panics if the lock guarding the stored window is poisoned.
    pub fn get_last_rate_limit_window_update(&self) -> Option<Arc<RateLimitWindow>> {
        // Keep the read guard only long enough to copy the pointer out.
        let window = {
            let guard = self
                .inner
                .last_rate_window
                .read()
                .expect("Failed read locking for last rate window update");
            Arc::clone(&*guard)
        };
        let usable = window.is_valid();
        usable.then_some(window)
    }
}
/// State shared by every clone of [`Client`].
#[derive(Default)]
struct ClientRef {
// Credentials used to sign requests or, when incomplete, to supply the `APIKey` parameter.
creds: Creds,
// HTTP client used for every outgoing request.
https_client: reqwest::Client,
// Rate-limit window taken from the most recent JSON response; replaced on each response.
last_rate_window: RwLock<Arc<RateLimitWindow>>,
}
impl ClientRef {
fn new(creds: Creds) -> Self {
Self {
creds,
https_client: reqwest::Client::new(),
last_rate_window: RwLock::new(Arc::new(RateLimitWindow {
..Default::default()
})),
}
}
async fn get<T: DeserializeOwned>(
&self,
url: &str,
params: Option<&ApiParams<'_>>,
) -> Result<Response<T>, SmugMugError> {
let req_url = self.create_req(url, params)?;
let resp = if self.creds.are_all_tokens_available() {
let auth_header = self.creds.create_oauth1_header("GET", &req_url)?;
self.https_client
.clone()
.get(req_url)
.header("Accept", "application/json")
.header("Authorization", auth_header)
.send()
.await?
} else {
self.https_client
.clone()
.get(req_url)
.header("Accept", "application/json")
.send()
.await?
};
self.handle_json_response(resp).await
}
async fn get_binary_data(
&self,
url: &str,
params: Option<&ApiParams<'_>>,
) -> Result<Response<Bytes>, SmugMugError> {
let req_url = self.create_req(url, params)?;
let resp = if self.creds.are_all_tokens_available() {
let auth_header = self.creds.create_oauth1_header("GET", &req_url)?;
self.https_client
.clone()
.get(req_url)
.header("Authorization", auth_header)
.send()
.await?
} else {
self.https_client.clone().get(req_url).send().await?
};
self.error_on_http_status(&resp, None)?;
match resp.bytes().await {
Ok(body) => Ok(Response {
payload: Some(body),
rate_limit: None,
}),
Err(err) => Err(err.into()),
}
}
async fn patch<T: DeserializeOwned>(
&self,
url: &str,
data: Vec<u8>,
params: Option<&ApiParams<'_>>,
) -> Result<Response<T>, SmugMugError> {
let req_url = self.create_req(url, params)?;
let auth_header = self.creds.create_oauth1_header("PATCH", &req_url)?;
let resp = self
.https_client
.clone()
.patch(req_url)
.header("Accept", "application/json")
.header("Content-Type", "application/json")
.header("Authorization", auth_header)
.body(data)
.send()
.await?;
self.handle_json_response(resp).await
}
async fn post<T: DeserializeOwned>(
&self,
url: &str,
data: Vec<u8>,
params: Option<&ApiParams<'_>>,
) -> Result<Response<T>, SmugMugError> {
let req_url = self.create_req(url, params)?;
let auth_header = self.creds.create_oauth1_header("POST", &req_url)?;
let resp = self
.https_client
.clone()
.post(req_url)
.header("Accept", "application/json")
.header("Content-Type", "application/json")
.header("Authorization", auth_header)
.body(data)
.send()
.await?;
self.handle_json_response(resp).await
}
fn extract_rate_limits_from_response(&self, resp: &ReqwestResponse) -> Arc<RateLimitWindow> {
let rate_limit = Arc::new(RateLimitWindow::from_reqwest_headers(
Utc::now(),
resp.headers(),
));
self.last_rate_window
.write()
.map(|mut v| {
*v = rate_limit.clone();
})
.unwrap();
rate_limit
}
fn error_on_http_status(
&self,
resp: &ReqwestResponse,
rate_limit: Option<&RateLimitWindow>,
) -> Result<(), SmugMugError> {
let _ = resp.error_for_status_ref().map_err(|v| {
match rate_limit.and_then(|v| v.retry_after_seconds()) {
Some(retry_after) if resp.status().as_u16() == 429 => {
SmugMugError::ApiResponseTooManyRequests(retry_after)
}
_ => SmugMugError::from(v),
}
})?;
Ok(())
}
async fn handle_json_response<T: DeserializeOwned>(
&self,
resp: ReqwestResponse,
) -> Result<Response<T>, SmugMugError> {
let rate_limit = self.extract_rate_limits_from_response(&resp);
self.error_on_http_status(&resp, Some(&rate_limit))?;
let payload_bytes = resp.bytes().await?;
if log::log_enabled!(log::Level::Debug) {
if let Ok(val) = serde_json::from_slice::<serde_json::Value>(&payload_bytes) {
log::debug!("JSON Raw Resp: {}", serde_json::to_string_pretty(&val)?);
}
}
match serde_json::from_slice::<ResponseBody<T>>(payload_bytes.as_ref()) {
Ok(body) => {
if !body.is_code_an_error()? {
return Err(SmugMugError::ApiResponse(body.code, body.message));
}
Ok(Response {
payload: body.response,
rate_limit: Some(rate_limit),
})
}
Err(err) => {
if log::log_enabled!(log::Level::Debug) {
log::debug!(
"Payload parse error: {}",
String::from_utf8(payload_bytes.to_vec()).unwrap()
);
}
Err(SmugMugError::ApiResponseMalformed(err))
}
}
}
fn create_req(
&self,
url: &str,
params: Option<&ApiParams<'_>>,
) -> Result<reqwest::Url, SmugMugError> {
let mut req_url = params.map_or(reqwest::Url::parse(url), |v| {
reqwest::Url::parse_with_params(url, v)
})?;
if self.creds.access_token.is_none() || self.creds.token_secret.is_none() {
req_url = reqwest::Url::parse_with_params(
req_url.as_str(),
[("APIKey", &self.creds.consumer_api_key)],
)?;
}
if log::log_enabled!(log::Level::Trace) {
log::trace!("Outgoing request url: {}", req_url.as_str());
}
Ok(req_url)
}
}
impl std::fmt::Debug for Client {
    /// Writes only the fixed label `ApiClient`; no internal state is exposed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("ApiClient")
    }
}
/// Borrowed slice of `(name, value)` pairs that `ClientRef::create_req` appends to the
/// request URL as query parameters.
pub type ApiParams<'a> = [(&'a str, &'a str)];
/// Numeric codes recognised in the `Code` member of a JSON response envelope.
///
/// The discriminants mirror HTTP status codes. Converting a number that is not listed
/// here with `TryFrom<u32>` fails.
#[derive(Debug, TryFromPrimitive)]
#[repr(u32)]
pub enum ApiErrorCodes {
    /// 200: request succeeded.
    Ok = 200,
    /// 201: resource created.
    CreatedSuccessfully = 201,
    /// 202: request accepted.
    Accepted = 202,
    /// 301: moved permanently.
    MovedPermanently = 301,
    /// 302: moved temporarily.
    MovedTemporarily = 302,
    /// 400: bad request.
    BadRequest = 400,
    /// 401: unauthorized.
    Unauthorized = 401,
    /// 402: payment required.
    PaymentRequired = 402,
    /// 403: forbidden.
    Forbidden = 403,
    /// 404: not found.
    NotFound = 404,
    /// 405: method not allowed.
    MethodNotAllowed = 405,
    /// 406: the `Accept` header could not be satisfied.
    BadAccept = 406,
    /// 409: conflict. (HTTP 407 is "Proxy Authentication Required", not "Conflict".)
    Conflict = 409,
    /// 429: too many requests.
    TooManyRequests = 429,
    /// 500: internal server error.
    InternalServerError = 500,
    /// 503: service unavailable.
    ServiceUnavailable = 503,
}
/// Rate-limit information read from one HTTP response.
///
/// The `Default` value has every optional field unset, so `is_valid` reports `false`
/// for it.
#[derive(Default, Clone)]
pub struct RateLimitWindow {
// Parsed from the `x-ratelimit-remaining` header, when present and numeric.
num_remaining_requests: Option<u64>,
// Parsed from the `x-ratelimit-reset` header, read as whole seconds since the Unix epoch.
current_window_reset_datetime: Option<DateTime<Utc>>,
// Parsed from the `retry-after` header, when present and numeric.
retry_after_seconds: Option<u64>,
// Time supplied by the caller when the headers were read.
timestamp: DateTime<Utc>,
}
impl RateLimitWindow {
    /// Builds a window from the rate-limit headers in `headers`, stamped with `timestamp`.
    ///
    /// A header that is missing, not visible ASCII, or not a number leaves its field
    /// `None`. A `retry-after` header in HTTP-date form is therefore ignored.
    fn from_reqwest_headers(timestamp: DateTime<Utc>, headers: &HeaderMap) -> Self {
        /// Reads header `name` and parses its text as `T`.
        fn parse_header<T: std::str::FromStr>(headers: &HeaderMap, name: &str) -> Option<T> {
            headers.get(name)?.to_str().ok()?.parse().ok()
        }

        let current_window_reset_datetime = parse_header::<i64>(headers, "x-ratelimit-reset")
            .and_then(|epoch_secs| Utc.timestamp_opt(epoch_secs, 0).latest());

        RateLimitWindow {
            num_remaining_requests: parse_header(headers, "x-ratelimit-remaining"),
            current_window_reset_datetime,
            retry_after_seconds: parse_header(headers, "retry-after"),
            timestamp,
        }
    }

    /// Requests left in the current window, if the server reported it.
    pub fn num_remaining_requests(&self) -> Option<u64> {
        self.num_remaining_requests
    }

    /// Time at which the current window resets, if the server reported it.
    pub fn window_reset_datetime(&self) -> Option<DateTime<Utc>> {
        self.current_window_reset_datetime
    }

    /// Seconds to wait before retrying, if the server reported it.
    pub fn retry_after_seconds(&self) -> Option<u64> {
        self.retry_after_seconds
    }

    /// Whether the response carried a remaining-request count or a retry-after value.
    pub fn is_valid(&self) -> bool {
        self.num_remaining_requests.is_some() || self.retry_after_seconds.is_some()
    }

    /// Time at which requests may resume: `timestamp` plus the retry-after delay.
    ///
    /// Returns `None` when no retry-after value was reported, or when the value is too
    /// large to be represented. The value comes from a response header, so an absurd
    /// number must not panic or wrap into the past.
    pub fn resume_after(&self) -> Option<DateTime<Utc>> {
        let wait = std::time::Duration::from_secs(self.retry_after_seconds?);
        let delay = Duration::from_std(wait).ok()?;
        self.timestamp.checked_add_signed(delay)
    }

    /// Time at which this window was recorded.
    pub fn timestamp(&self) -> DateTime<Utc> {
        self.timestamp
    }
}
/// Result of a successful API call.
pub struct Response<T> {
/// Decoded payload: the envelope's `Response` member for JSON calls (absent when the
/// envelope has none), or the raw body for binary downloads.
pub payload: Option<T>,
/// Rate-limit window read from the response headers; `None` for binary downloads.
pub rate_limit: Option<Arc<RateLimitWindow>>,
}
/// Credentials used to authenticate requests.
///
/// Requests are OAuth-signed only when the key is non-empty and all three optional
/// values are present.
#[derive(Default, Clone)]
pub struct Creds {
// Sent as `oauth_consumer_key` when signing, or as the `APIKey` query parameter otherwise.
consumer_api_key: String,
// First half of the HMAC signing key.
consumer_api_secret: Option<String>,
// Sent as `oauth_token` when signing.
access_token: Option<String>,
// Second half of the HMAC signing key.
token_secret: Option<String>,
}
impl Creds {
    /// Creates credentials from borrowed token strings, copying each one.
    ///
    /// Pass `None` for any value that is not available; requests are then sent unsigned
    /// with the consumer key as the `APIKey` query parameter.
    pub fn from_tokens(
        consumer_api_key: &str,
        consumer_api_secret: Option<&str>,
        access_token: Option<&str>,
        token_secret: Option<&str>,
    ) -> Self {
        Self {
            consumer_api_key: consumer_api_key.to_owned(),
            consumer_api_secret: consumer_api_secret.map(str::to_owned),
            access_token: access_token.map(str::to_owned),
            token_secret: token_secret.map(str::to_owned),
        }
    }

    /// Whether everything needed to OAuth-sign a request is present.
    fn are_all_tokens_available(&self) -> bool {
        !self.consumer_api_key.is_empty()
            && self.consumer_api_secret.is_some()
            && self.access_token.is_some()
            && self.token_secret.is_some()
    }

    /// Builds the OAuth 1.0 `Authorization` header value for `method` on `url`, signed
    /// with HMAC-SHA1.
    ///
    /// The query parameters of `url` are part of the signature, so `url` must be the
    /// exact URL that will be requested.
    ///
    /// # Errors
    ///
    /// Returns `SmugMugError::Auth` when the access token, consumer secret or token
    /// secret is missing.
    fn create_oauth1_header(
        &self,
        method: &str,
        url: &reqwest::Url,
    ) -> Result<String, SmugMugError> {
        let access_token = self
            .access_token
            .as_deref()
            .ok_or_else(|| SmugMugError::Auth("Access token not found".to_string()))?;
        let consumer_api_secret = self
            .consumer_api_secret
            .as_deref()
            .ok_or_else(|| SmugMugError::Auth("Consumer secret not found".to_string()))?;
        let token_secret = self
            .token_secret
            .as_deref()
            .ok_or_else(|| SmugMugError::Auth("Token secret not found".to_string()))?;

        let timestamp = Utc::now().timestamp().to_string();
        let nonce: String = rand::rng()
            .sample_iter(Alphanumeric)
            .take(32)
            .map(char::from)
            .collect();

        let mut oauth_params: BTreeMap<&str, &str> = BTreeMap::new();
        oauth_params.insert("oauth_consumer_key", &self.consumer_api_key);
        oauth_params.insert("oauth_nonce", &nonce);
        oauth_params.insert("oauth_signature_method", "HMAC-SHA1");
        oauth_params.insert("oauth_timestamp", &timestamp);
        oauth_params.insert("oauth_token", access_token);
        oauth_params.insert("oauth_version", "1.0");

        // Signature parameters: the protocol parameters plus every query pair. Each pair
        // is percent-encoded first and the list is sorted afterwards (by name, then by
        // value). A plain list is used so that repeated query names are all kept.
        let mut signed_pairs: Vec<(String, String)> = oauth_params
            .iter()
            .map(|(key, value)| (url_encode(key).into_owned(), url_encode(value).into_owned()))
            .chain(url.query_pairs().map(|(key, value)| {
                (
                    url_encode(&key).into_owned(),
                    url_encode(&value).into_owned(),
                )
            }))
            .collect();
        signed_pairs.sort();
        let parameter_string = signed_pairs
            .iter()
            .map(|(key, value)| format!("{}={}", key, value))
            .collect::<Vec<String>>()
            .join("&");

        // The base-string URL excludes the query and the fragment.
        let url_to_sign = {
            let mut signing_url = url.clone();
            signing_url.set_query(None);
            signing_url.set_fragment(None);
            signing_url.to_string()
        };
        let base_string = format!(
            "{}&{}&{}",
            method.to_uppercase(),
            url_encode(&url_to_sign),
            url_encode(&parameter_string)
        );
        let signing_key = format!(
            "{}&{}",
            url_encode(consumer_api_secret),
            url_encode(token_secret)
        );

        let mut mac = HmacSha1::new_from_slice(signing_key.as_bytes())
            .expect("HMAC can be initialized with key");
        mac.update(base_string.as_bytes());
        let signature = mac.finalize().into_bytes();
        let signature_base64 = BASE64_STANDARD.encode(signature);

        oauth_params.insert("oauth_signature", &signature_base64);
        let auth_header_value = oauth_params
            .iter()
            .map(|(key, value)| format!("{}=\"{}\"", key, url_encode(value)))
            .collect::<Vec<String>>()
            .join(", ");
        Ok(format!("OAuth {}", auth_header_value))
    }
}
impl std::fmt::Debug for Creds {
    /// Formats the credentials without revealing any secret.
    ///
    /// The consumer key is always shown as `"xxx"`. Each optional value is shown as
    /// `"xxx"` when set and `""` when unset, based on that field's own state.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Placeholder that reveals only whether a secret is set.
        fn mask(secret: Option<&str>) -> &'static str {
            match secret {
                Some(_) => "xxx",
                None => "",
            }
        }

        f.debug_struct("Creds")
            .field("consumer_api_key", &"xxx")
            .field(
                "consumer_api_secret",
                &mask(self.consumer_api_secret.as_deref()),
            )
            .field("access_token", &mask(self.access_token.as_deref()))
            .field("token_secret", &mask(self.token_secret.as_deref()))
            .finish()
    }
}
/// JSON envelope returned by the API, decoded before the payload is handed to the caller.
#[derive(Deserialize, Debug)]
struct ResponseBody<ResponseType> {
// Numeric result code from the `Code` member; interpreted through `ApiErrorCodes`.
#[serde(rename = "Code")]
code: u32,
// Text from the `Message` member; reported to the caller when the code is a failure.
#[serde(rename = "Message")]
message: String,
// Payload from the `Response` member; may be absent.
#[serde(rename = "Response")]
response: Option<ResponseType>,
}
impl<ResponseType> ResponseBody<ResponseType> {
    /// Classifies the envelope's `code`.
    ///
    /// NOTE: despite the name, this returns `Ok(true)` when the code is one of the
    /// success or redirect codes (200, 201, 202, 301, 302) and `Ok(false)` for every
    /// other recognised code. The caller relies on exactly this meaning.
    ///
    /// # Errors
    ///
    /// Fails when `code` is not a value known to `ApiErrorCodes`.
    fn is_code_an_error(&self) -> Result<bool, SmugMugError> {
        let code = ApiErrorCodes::try_from(self.code)?;
        let is_success = matches!(
            code,
            ApiErrorCodes::Ok
                | ApiErrorCodes::CreatedSuccessfully
                | ApiErrorCodes::Accepted
                | ApiErrorCodes::MovedPermanently
                | ApiErrorCodes::MovedTemporarily
        );
        Ok(is_success)
    }
}
/// Paging information decoded from a response.
#[derive(Deserialize, Debug)]
pub(crate) struct Pages {
// Value of the `NextPage` member, if any.
// NOTE(review): presumably the location of the next page of results; confirm against callers.
#[serde(rename = "NextPage")]
pub(crate) next_page: Option<String>,
}