use reqwest::{
header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE},
Method, Response,
};
use serde::de::DeserializeOwned;
use serde::Serialize;
use crate::error::Error;
use nordnet_model::auth::Session;
/// Asynchronous HTTP client for the API rooted at `base_url`, optionally
/// authenticated with a [`Session`].
///
/// NOTE(review): `Debug` is derived, so formatting a `Client` also formats the
/// attached `Session`. If `Session`'s `Debug` prints `session_key`, the credential
/// can end up in logs — confirm, and consider a redacting manual `Debug` impl.
#[derive(Debug, Clone)]
pub struct Client {
    // Underlying `reqwest` client that sends every request.
    http: reqwest::Client,
    // API root with trailing `/` characters removed by `Client::new`; request
    // paths are appended to it by `Client::url`.
    base_url: String,
    // Source of the `Authorization` header; `None` means requests carry no
    // `Authorization` header.
    session: Option<Session>,
}
impl Client {
    /// Creates an unauthenticated client for the API rooted at `base_url`.
    ///
    /// Trailing `/` characters are removed from `base_url` so that
    /// [`Client::url`] can append paths with a single separator.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Transport`] if the underlying `reqwest` client cannot be built.
    pub fn new(base_url: impl Into<String>) -> Result<Self, Error> {
        let http = reqwest::Client::builder()
            .build()
            .map_err(Error::Transport)?;
        Ok(Self {
            http,
            base_url: base_url.into().trim_end_matches('/').to_owned(),
            session: None,
        })
    }

    /// Attaches `session` (replacing any previous one) and returns the client.
    ///
    /// Subsequent requests send an `Authorization` header built from
    /// `Session::basic_auth_header`.
    pub fn with_session(mut self, session: Session) -> Self {
        self.session = Some(session);
        self
    }

    /// Returns the attached session, or `None` if no session has been set.
    pub fn session(&self) -> Option<&Session> {
        self.session.as_ref()
    }

    /// Returns the base URL with any trailing `/` removed.
    pub fn base_url(&self) -> &str {
        &self.base_url
    }

    /// Sends `GET path` with no body and decodes the JSON response as `T`.
    ///
    /// # Errors
    ///
    /// * [`Error::InvalidHeader`] if the session's auth value is not a valid header value.
    /// * [`Error::Transport`] if the request cannot be sent or the body cannot be read.
    /// * Whatever `Error::from_status` returns for a non-2xx status.
    /// * [`Error::Decode`] if a successful body does not decode as `T`.
    pub async fn get<T: DeserializeOwned>(&self, path: &str) -> Result<T, Error> {
        self.send::<T, ()>(Method::GET, path, None).await
    }

    /// Sends `POST path` with `body` serialized as JSON and decodes the JSON response as `T`.
    ///
    /// # Errors
    ///
    /// As for [`Client::get`]. A failure to serialize `body` is reported by
    /// reqwest when the request is sent, and therefore surfaces as [`Error::Transport`].
    pub async fn post<T: DeserializeOwned, B: Serialize>(
        &self,
        path: &str,
        body: &B,
    ) -> Result<T, Error> {
        self.send(Method::POST, path, Some(Body::Json(body))).await
    }

    /// Sends `PUT path` with `body` serialized as JSON and decodes the JSON response as `T`.
    ///
    /// # Errors
    ///
    /// As for [`Client::post`].
    pub async fn put<T: DeserializeOwned, B: Serialize>(
        &self,
        path: &str,
        body: &B,
    ) -> Result<T, Error> {
        self.send(Method::PUT, path, Some(Body::Json(body))).await
    }

    /// Sends `POST path` with `body` URL-encoded as a form and decodes the JSON response as `T`.
    ///
    /// # Errors
    ///
    /// As for [`Client::get`], plus [`Error::EncodeForm`] if `body` cannot be form-encoded.
    pub async fn post_form<T: DeserializeOwned, B: Serialize>(
        &self,
        path: &str,
        body: &B,
    ) -> Result<T, Error> {
        self.send(Method::POST, path, Some(Body::Form(body))).await
    }

    /// Sends `PUT path` with `body` URL-encoded as a form and decodes the JSON response as `T`.
    ///
    /// # Errors
    ///
    /// As for [`Client::post_form`].
    pub async fn put_form<T: DeserializeOwned, B: Serialize>(
        &self,
        path: &str,
        body: &B,
    ) -> Result<T, Error> {
        self.send(Method::PUT, path, Some(Body::Form(body))).await
    }

    /// Sends `PUT path` with no body and decodes the JSON response as `T`.
    ///
    /// # Errors
    ///
    /// As for [`Client::get`].
    pub async fn put_empty<T: DeserializeOwned>(&self, path: &str) -> Result<T, Error> {
        self.send::<T, ()>(Method::PUT, path, None).await
    }

    /// Sends `DELETE path` with no body and decodes the JSON response as `T`.
    ///
    /// # Errors
    ///
    /// As for [`Client::get`].
    pub async fn delete<T: DeserializeOwned>(&self, path: &str) -> Result<T, Error> {
        self.send::<T, ()>(Method::DELETE, path, None).await
    }

    /// Appends `path` to the base URL, inserting a `/` separator unless `path`
    /// already starts with one.
    pub fn url(&self, path: &str) -> String {
        if path.starts_with('/') {
            format!("{}{}", self.base_url, path)
        } else {
            format!("{}/{}", self.base_url, path)
        }
    }

    /// Builds the per-request headers. When a session is attached, this is its
    /// `Authorization` header; otherwise the map is empty.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidHeader`] if the session's auth string is not a
    /// valid HTTP header value.
    fn auth_headers(&self) -> Result<HeaderMap, Error> {
        let mut headers = HeaderMap::new();
        if let Some(session) = &self.session {
            let value = session.basic_auth_header();
            let mut header =
                HeaderValue::from_str(&value).map_err(|e| Error::InvalidHeader(e.to_string()))?;
            // The value carries the session credential. Marking it sensitive makes
            // `HeaderValue`'s `Debug` print "Sensitive" instead of the secret (so it
            // stays out of logged headers/requests) and hints HTTP/2 HPACK encoders
            // not to index it.
            header.set_sensitive(true);
            headers.insert(AUTHORIZATION, header);
        }
        Ok(headers)
    }

    /// Resolves `path`, attaches auth headers, sends the request once, and
    /// decodes the response via `parse_response`.
    async fn send<T: DeserializeOwned, B: Serialize>(
        &self,
        method: Method,
        path: &str,
        body: Option<Body<'_, B>>,
    ) -> Result<T, Error> {
        let url = self.url(path);
        let headers = self.auth_headers()?;
        let response = self.execute_once(method, &url, headers, body).await?;
        parse_response::<T>(response).await
    }

    /// Builds and sends a single request (no retries) and returns the raw response.
    ///
    /// # Errors
    ///
    /// * [`Error::EncodeForm`] if a form body cannot be URL-encoded.
    /// * [`Error::Transport`] for send failures, including JSON serialization
    ///   failures, which reqwest defers until `send()`.
    async fn execute_once<B: Serialize>(
        &self,
        method: Method,
        url: &str,
        headers: HeaderMap,
        body: Option<Body<'_, B>>,
    ) -> Result<Response, Error> {
        let mut req = self.http.request(method, url).headers(headers);
        match body {
            Some(Body::Json(b)) => {
                req = req.header(CONTENT_TYPE, "application/json").json(b);
            }
            Some(Body::Form(b)) => {
                // Encode eagerly so form-encoding failures get their own error variant.
                let encoded =
                    serde_urlencoded::to_string(b).map_err(|e| Error::EncodeForm(e.to_string()))?;
                req = req
                    .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
                    .body(encoded);
            }
            None => {}
        }
        req.send().await.map_err(Error::Transport)
    }
}
/// A request payload borrowed from the caller for the duration of one request,
/// tagged with the encoding to use.
///
/// The type itself places no bound on `B`. The functions that encode the payload
/// require `B: Serialize`, which keeps the type usable in generic code without
/// repeating the bound.
enum Body<'a, B> {
    /// Sent as a JSON body.
    Json(&'a B),
    /// Sent as an `application/x-www-form-urlencoded` body.
    Form(&'a B),
}
async fn parse_response<T: DeserializeOwned>(response: Response) -> Result<T, Error> {
let status = response.status();
let body = response.text().await.map_err(Error::Transport)?;
if status.is_success() {
serde_json::from_str::<T>(&body).map_err(|source| Error::Decode { source, body })
} else {
Err(Error::from_status(status.as_u16(), body))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Client` for `base`, failing the test if construction errors.
    fn make_client(base: &str) -> Client {
        Client::new(base).unwrap()
    }

    #[test]
    fn url_handles_leading_slash() {
        let client = make_client("http://example.com/api/2");
        // Both spellings of the path must join with exactly one separator.
        for path in ["/accounts", "accounts"] {
            assert_eq!(client.url(path), "http://example.com/api/2/accounts");
        }
    }

    #[test]
    fn url_strips_trailing_slash_on_base() {
        let joined = make_client("http://example.com/api/2/").url("/x");
        assert_eq!(joined, "http://example.com/api/2/x");
    }

    #[test]
    fn no_session_no_auth_header() {
        let headers = make_client("http://x").auth_headers().unwrap();
        assert!(headers.get(AUTHORIZATION).is_none());
    }

    #[test]
    fn with_session_sets_basic_auth() {
        let session = Session {
            session_key: "abc".into(),
            expires_in: 60,
        };
        let headers = make_client("http://x")
            .with_session(session)
            .auth_headers()
            .unwrap();
        let value = headers.get(AUTHORIZATION).unwrap().to_str().unwrap();
        // "YWJjOmFiYw==" is base64("abc:abc"); presumably `basic_auth_header`
        // uses the session key as both user and password — verify in nordnet_model.
        assert_eq!(value, "Basic YWJjOmFiYw==");
    }
}