use std::{num::ParseIntError, str::FromStr};
use http::{header::AsHeaderName, HeaderMap, StatusCode};
use serde::Serialize;
use crate::{
error::{ApiError, DeserializeError, FromHttpError, HeaderError, IntoHttpError},
AuthRequirement, Context, Metadata,
};
/// Page selection sent to a paginated endpoint.
///
/// Defaults to page `1` with `10` items per page (pages appear to be
/// numbered from 1, since that is the default — confirm against the API).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
pub struct Pagination {
    /// Page number to request.
    pub page: usize,
    /// Number of items requested per page.
    pub limit: usize,
}

impl Pagination {
    /// Value returned by [`Default::default`]: page 1, 10 items per page.
    const DEFAULT: Self = Self { page: 1, limit: 10 };

    /// Creates a selection for `page` with `limit` items per page.
    #[inline]
    #[must_use]
    pub const fn new(page: usize, limit: usize) -> Self {
        Self { page, limit }
    }
}

impl Default for Pagination {
    /// Returns [`Pagination::DEFAULT`] (page 1, 10 items per page).
    fn default() -> Self {
        Self::DEFAULT
    }
}
/// One page of items plus the pagination metadata reported in response headers.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct PaginationResponse<T> {
    /// Items contained in this page.
    pub items: Vec<T>,
    /// Page number, from the `X-Pagination-Page` header.
    pub current_page: usize,
    /// Page size, from the `X-Pagination-Limit` header.
    pub items_per_page: usize,
    /// Total number of pages, from the `X-Pagination-Page-Count` header.
    pub total_pages: usize,
    /// Total number of items, from the `X-Pagination-Item-Count` header.
    pub total_items: usize,
}

impl<T> PaginationResponse<T> {
    /// Combines already-deserialized `items` with the pagination headers in `map`.
    ///
    /// # Errors
    ///
    /// Returns the first [`DeserializeError`] produced by [`parse_from_header`]
    /// while reading, in order, `X-Pagination-Page`, `X-Pagination-Limit`,
    /// `X-Pagination-Page-Count` and `X-Pagination-Item-Count`.
    pub fn from_headers(items: Vec<T>, map: &HeaderMap) -> Result<Self, DeserializeError> {
        // Struct-literal fields are evaluated top to bottom, so the headers are
        // parsed (and errors surface) in the order listed above.
        Ok(Self {
            items,
            current_page: parse_from_header(map, "X-Pagination-Page")?,
            items_per_page: parse_from_header(map, "X-Pagination-Limit")?,
            total_pages: parse_from_header(map, "X-Pagination-Page-Count")?,
            total_items: parse_from_header(map, "X-Pagination-Item-Count")?,
        })
    }

    /// Returns the selection for the page after this one, keeping the same page
    /// size, or `None` once `current_page` has reached `total_pages`.
    #[inline]
    #[must_use]
    pub const fn next_page(&self) -> Option<Pagination> {
        if self.current_page >= self.total_pages {
            return None;
        }
        // Cannot overflow: current_page < total_pages <= usize::MAX here.
        Some(Pagination::new(self.current_page + 1, self.items_per_page))
    }
}
/// Looks up header `key` in `map` and parses its value as the integer type `T`.
///
/// # Errors
///
/// - [`HeaderError::MissingHeader`] (converted into [`DeserializeError`]) when the
///   header is absent.
/// - [`HeaderError::ToStrError`] (converted into [`DeserializeError`]) when the
///   value cannot be viewed as a `&str`.
/// - [`DeserializeError::ParseInt`] when the text is not a valid `T`.
pub fn parse_from_header<T, K>(map: &HeaderMap, key: K) -> Result<T, DeserializeError>
where
    T: FromStr<Err = ParseIntError>,
    K: AsHeaderName,
{
    let raw = match map.get(key) {
        Some(raw) => raw,
        None => return Err(HeaderError::MissingHeader.into()),
    };
    let text = raw.to_str().map_err(HeaderError::ToStrError)?;
    text.parse::<T>().map_err(DeserializeError::ParseInt)
}
/// Deserializes the JSON body of `response` into `T` when its status equals `expected`.
///
/// # Errors
///
/// - [`FromHttpError::Api`], built from the actual status, when it differs from
///   `expected`; the body is not inspected in that case.
/// - A [`DeserializeError::Json`] (converted into [`FromHttpError`]) when the body
///   is not valid JSON for `T`.
pub fn handle_response_body<B, T>(
    response: &http::Response<B>,
    expected: StatusCode,
) -> Result<T, FromHttpError>
where
    B: AsRef<[u8]>,
    T: serde::de::DeserializeOwned,
{
    let status = response.status();
    if status != expected {
        return Err(FromHttpError::Api(ApiError::from(status)));
    }
    let parsed: T =
        serde_json::from_slice(response.body().as_ref()).map_err(DeserializeError::Json)?;
    Ok(parsed)
}
/// Builds an HTTP request for the endpoint described by `md`.
///
/// The URL comes from `crate::construct_url(ctx.base_url, md.endpoint, path, query)`.
/// Every request carries `Content-Type: application/json`, `trakt-api-version: 2`
/// and `trakt-api-key: <ctx.client_id>`. An `Authorization: Bearer <token>` header
/// is added whenever `ctx.oauth_token` is set and `md.auth` is not
/// [`AuthRequirement::None`].
///
/// # Errors
///
/// - Any error returned by `crate::construct_url`.
/// - [`IntoHttpError::MissingToken`] when `md.auth` is [`AuthRequirement::Required`]
///   and `ctx.oauth_token` is `None`.
/// - Any error reported by the `http` request builder when the body is attached
///   (e.g. an invalid header value), converted into [`IntoHttpError`].
pub fn construct_req<B>(
    ctx: &Context,
    md: &Metadata,
    path: &impl Serialize,
    query: &impl Serialize,
    body: B,
) -> Result<http::Request<B>, IntoHttpError> {
    let url = crate::construct_url(ctx.base_url, md.endpoint, path, query)?;

    // Decide which bearer token (if any) goes on the request. A required token
    // that is absent aborts before the builder's deferred errors are examined.
    let bearer = match (md.auth, ctx.oauth_token) {
        (AuthRequirement::Required, None) => return Err(IntoHttpError::MissingToken),
        (AuthRequirement::None, _) => None,
        (AuthRequirement::Optional | AuthRequirement::Required, token) => token,
    };

    let mut builder = http::Request::builder()
        .method(&md.method)
        .uri(url)
        .header("Content-Type", "application/json")
        .header("trakt-api-version", "2")
        .header("trakt-api-key", ctx.client_id);
    if let Some(token) = bearer {
        builder = builder.header("Authorization", format!("Bearer {token}"));
    }

    Ok(builder.body(body)?)
}
#[cfg(test)]
mod tests {
    use super::*;
    use http::HeaderValue;

    #[test]
    fn test_parse_from_header() {
        let mut map = HeaderMap::new();
        // "A" is intentionally absent; "B" holds a non-ASCII byte so `to_str` fails;
        // "C" is text that is not a number; "D" is a valid integer.
        map.insert("B", HeaderValue::from_bytes(b"hello\xfa").unwrap());
        map.insert("C", HeaderValue::from_static("hello"));
        map.insert("D", HeaderValue::from_static("10"));

        let missing = parse_from_header::<u32, _>(&map, "A");
        assert!(matches!(
            missing,
            Err(DeserializeError::Header(HeaderError::MissingHeader))
        ));

        let not_text = parse_from_header::<u32, _>(&map, "B");
        assert!(matches!(
            not_text,
            Err(DeserializeError::Header(HeaderError::ToStrError(_)))
        ));

        let not_number = parse_from_header::<u32, _>(&map, "C");
        assert!(matches!(not_number, Err(DeserializeError::ParseInt(_))));

        assert_eq!(parse_from_header::<u32, _>(&map, "D").unwrap(), 10);
    }

    /// Builds a response with the given status and raw body bytes.
    fn response(status: StatusCode, body: &'static [u8]) -> http::Response<&'static [u8]> {
        http::Response::builder().status(status).body(body).unwrap()
    }

    #[test]
    fn test_handle_response_body_ok() {
        let resp = response(StatusCode::OK, b"\"hello\"");
        let parsed = handle_response_body::<_, String>(&resp, StatusCode::OK).unwrap();
        assert_eq!(parsed, "hello");
    }

    #[test]
    fn test_handle_response_body_bad_request() {
        let resp = response(StatusCode::BAD_REQUEST, b"\"hello\"");
        let outcome = handle_response_body::<_, String>(&resp, StatusCode::OK);
        assert!(matches!(outcome, Err(FromHttpError::Api(ApiError::BadRequest))));
    }

    #[test]
    fn test_handle_response_body_deserialize_error() {
        // The 0xFA byte makes the JSON string invalid UTF-8.
        let resp = response(StatusCode::OK, b"\"hello\xfa\"");
        let outcome = handle_response_body::<_, String>(&resp, StatusCode::OK);
        assert!(matches!(
            outcome,
            Err(FromHttpError::Deserialize(DeserializeError::Json(_)))
        ));
    }

    /// Checks everything every fixture request shares, plus the expected
    /// `Authorization` header value (`None` means the header must be absent).
    fn assert_request(req: http::Request<&str>, authorization: Option<&str>) {
        assert_eq!(req.method(), &http::Method::GET);
        assert_eq!(req.uri(), "https://api.trakt.tv/test");
        let headers = req.headers();
        assert_eq!(headers.get("Content-Type").unwrap(), "application/json");
        assert_eq!(headers.get("trakt-api-version").unwrap(), "2");
        assert_eq!(headers.get("trakt-api-key").unwrap(), "client id");
        match authorization {
            Some(expected) => assert_eq!(headers.get("Authorization").unwrap(), expected),
            None => assert!(headers.get("Authorization").is_none()),
        }
        assert_eq!(req.into_body(), "body");
    }

    #[test]
    fn test_construct_req() {
        let mut ctx = Context {
            base_url: "https://api.trakt.tv",
            client_id: "client id",
            oauth_token: None,
        };
        let mut md = Metadata {
            endpoint: "/test",
            method: http::Method::GET,
            auth: AuthRequirement::None,
        };

        // Auth not needed, no token: no Authorization header.
        assert_request(construct_req(&ctx, &md, &(), &(), "body").unwrap(), None);

        // Auth required, token present: bearer header added.
        md.auth = AuthRequirement::Required;
        ctx.oauth_token = Some("token");
        assert_request(
            construct_req(&ctx, &md, &(), &(), "body").unwrap(),
            Some("Bearer token"),
        );

        // Auth required, token missing: error.
        ctx.oauth_token = None;
        let err = construct_req(&ctx, &md, &(), &(), "body").unwrap_err();
        assert!(matches!(err, IntoHttpError::MissingToken));

        // Auth optional, token missing: no Authorization header.
        md.auth = AuthRequirement::Optional;
        assert_request(construct_req(&ctx, &md, &(), &(), "body").unwrap(), None);

        // Auth optional, token present: bearer header added.
        ctx.oauth_token = Some("token");
        assert_request(
            construct_req(&ctx, &md, &(), &(), "body").unwrap(),
            Some("Bearer token"),
        );
    }
}