#![deny(clippy::all)]
#![deny(missing_docs)]
#![forbid(unsafe_code)]
use async_trait::async_trait;
use hyper::{http::uri::InvalidUri, Body, Client, Method, Request};
use libunftp::auth::{AuthenticationError, Authenticator, Credentials, DefaultUser};
use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
use regex::Regex;
use serde_json::{json, Value};
use std::string::String;
/// [`Authenticator`] that delegates credential checks to an HTTP (REST) endpoint.
///
/// Construct one with [`Builder`]. For every login the username and password are
/// substituted into the configured URL and request body, the request is sent, and the
/// JSON response is inspected: the value found at the JSON pointer `selector` is
/// serialized to text and must match `regex` for the login to succeed.
#[derive(Clone, Debug)]
pub struct RestAuthenticator {
    /// Text in the URL/body templates that is replaced by the (encoded) username.
    username_placeholder: String,
    /// Text in the URL/body templates that is replaced by the (encoded) password.
    password_placeholder: String,
    /// HTTP method used for the authentication request.
    method: Method,
    /// URL template; placeholders are filled with percent-encoded values.
    url: String,
    /// Request body template; placeholders are filled with JSON-string-escaped values.
    body: String,
    /// JSON pointer selecting the value of interest in the response body.
    selector: String,
    /// Pattern the serialized selected value must match for a successful login.
    regex: Regex,
}
/// Builder for [`RestAuthenticator`].
///
/// All settings start from their `Default` values (empty strings and
/// `Method::default()`); set them with the `with_*` methods and finish with
/// [`Builder::build`].
#[derive(Clone, Debug, Default)]
pub struct Builder {
    /// Text replaced by the username in the URL and body templates.
    username_placeholder: String,
    /// Text replaced by the password in the URL and body templates.
    password_placeholder: String,
    /// HTTP method for the authentication request.
    method: Method,
    /// URL template.
    url: String,
    /// Request body template.
    body: String,
    /// JSON pointer into the response body.
    selector: String,
    /// Regex source; compiled by `build`.
    regex: String,
}
impl Builder {
    /// Creates a builder with every setting at its default value.
    ///
    /// The URL, placeholders, selector and regex all start out empty, so callers are
    /// expected to set them before calling [`Builder::build`].
    pub fn new() -> Builder {
        Self::default()
    }

    /// Sets the text replaced by the username in the URL and body templates.
    ///
    /// In the URL the username is inserted percent-encoded; in the body it is inserted
    /// escaped for use inside a JSON string. Should be a non-empty marker that does
    /// not otherwise occur in the templates.
    pub fn with_username_placeholder(mut self, s: String) -> Self {
        self.username_placeholder = s;
        self
    }

    /// Sets the text replaced by the password in the URL and body templates.
    ///
    /// Encoding is the same as for the username placeholder.
    pub fn with_password_placeholder(mut self, s: String) -> Self {
        self.password_placeholder = s;
        self
    }

    /// Sets the HTTP method of the authentication request.
    pub fn with_method(mut self, s: Method) -> Self {
        self.method = s;
        self
    }

    /// Sets the URL template of the authentication request.
    pub fn with_url(mut self, s: String) -> Self {
        self.url = s;
        self
    }

    /// Sets the request body template; it is sent with `Content-type: application/json`.
    pub fn with_body(mut self, s: String) -> Self {
        self.body = s;
        self
    }

    /// Sets the JSON pointer (e.g. `/status`) used to select a value from the response.
    pub fn with_selector(mut self, s: String) -> Self {
        self.selector = s;
        self
    }

    /// Sets the regex the selected value must match for a login to succeed.
    ///
    /// The selected value is matched in its serialized JSON form, so string values
    /// include their surrounding double quotes and a missing value reads as `null`.
    pub fn with_regex(mut self, s: String) -> Self {
        self.regex = s;
        self
    }

    /// Validates the settings and creates the [`RestAuthenticator`].
    ///
    /// # Errors
    ///
    /// * If no regex was configured. An empty pattern matches every string, which
    ///   would make the authenticator accept any credentials whenever the endpoint
    ///   returns valid JSON, so this fails closed instead.
    /// * If the regex does not compile.
    pub fn build(self) -> Result<RestAuthenticator, Box<dyn std::error::Error>> {
        if self.regex.is_empty() {
            return Err("rest authenticator: a regex must be configured".into());
        }
        let regex = Regex::new(&self.regex)?;
        Ok(RestAuthenticator {
            username_placeholder: self.username_placeholder,
            password_placeholder: self.password_placeholder,
            method: self.method,
            url: self.url,
            body: self.body,
            selector: self.selector,
            regex,
        })
    }
}
impl RestAuthenticator {
    /// Returns `template` with the username and password placeholders replaced by
    /// `username` and `password`, which the caller has already encoded for the
    /// target context (URL or JSON body).
    ///
    /// The template is scanned once, left to right. Inserted values are therefore
    /// never re-scanned, so a username that happens to contain the password
    /// placeholder text is inserted verbatim instead of receiving the password. At a
    /// position where both placeholders match, the username placeholder wins.
    ///
    /// An empty placeholder is treated as "not configured" and never matches. With
    /// `str::replace`, an empty pattern matches between every character and would
    /// splice the value throughout the template.
    fn fill_encoded_placeholders(&self, template: &str, username: &str, password: &str) -> String {
        let substitutions = [
            (self.username_placeholder.as_str(), username),
            (self.password_placeholder.as_str(), password),
        ];
        let mut out = String::with_capacity(template.len());
        let mut rest = template;
        while let Some(ch) = rest.chars().next() {
            let hit = substitutions.iter().find_map(|&(placeholder, value)| {
                if placeholder.is_empty() {
                    None
                } else {
                    rest.strip_prefix(placeholder).map(|after| (value, after))
                }
            });
            match hit {
                Some((value, after)) => {
                    out.push_str(value);
                    rest = after;
                }
                None => {
                    out.push(ch);
                    rest = &rest[ch.len_utf8()..];
                }
            }
        }
        out
    }
}
#[async_trait]
impl Authenticator<DefaultUser> for RestAuthenticator {
#[allow(clippy::type_complexity)]
#[tracing_attributes::instrument]
async fn authenticate(&self, username: &str, creds: &Credentials) -> Result<DefaultUser, AuthenticationError> {
let username_url = utf8_percent_encode(username, NON_ALPHANUMERIC).collect::<String>();
let password = creds.password.as_ref().ok_or(AuthenticationError::BadPassword)?.as_ref();
let password_url = utf8_percent_encode(password, NON_ALPHANUMERIC).collect::<String>();
let url = self.fill_encoded_placeholders(&self.url, &username_url, &password_url);
let username_json = encode_string_json(username);
let password_json = encode_string_json(password);
let body = self.fill_encoded_placeholders(&self.body, &username_json, &password_json);
let method = self.method.clone();
let selector = self.selector.clone();
let regex = self.regex.clone();
let req = Request::builder()
.method(method)
.header("Content-type", "application/json")
.uri(url)
.body(Body::from(body))
.map_err(|e| AuthenticationError::with_source("rest authenticator http client error", e))?;
let client = Client::new();
let resp = client
.request(req)
.await
.map_err(|e| AuthenticationError::with_source("rest authenticator http client error", e))?;
let body_bytes = hyper::body::to_bytes(resp.into_body())
.await
.map_err(|e| AuthenticationError::with_source("rest authenticator http client error", e))?;
let body: Value = serde_json::from_slice(&body_bytes).map_err(|e| AuthenticationError::with_source("rest authenticator unmarshalling error", e))?;
let parsed = match body.pointer(&selector) {
Some(parsed) => parsed.to_string(),
None => json!(null).to_string(),
};
if regex.is_match(&parsed) {
Ok(DefaultUser {})
} else {
Err(AuthenticationError::BadPassword)
}
}
}
/// Escapes `string` for embedding between the double quotes of a JSON string literal.
///
/// No surrounding quotes are added; the body template supplies them. Escaping:
/// - `"` and `\` are backslash-escaped.
/// - The control characters U+0000..=U+001F, which RFC 8259 §7 forbids unescaped,
///   become `\n`, `\r`, `\t`, `\b`, `\f` or `\u00XX`.
/// - Every other character, including non-ASCII text, is copied unchanged, since JSON
///   text is UTF-8.
///
/// The previous version silently dropped everything outside printable ASCII. Distinct
/// passwords such as "pässword" and "pssword" were then sent as the same string.
fn encode_string_json(string: &str) -> String {
    use std::fmt::Write;

    let mut res = String::with_capacity(string.len());
    for c in string.chars() {
        match c {
            '\\' => res.push_str("\\\\"),
            '"' => res.push_str("\\\""),
            '\n' => res.push_str("\\n"),
            '\r' => res.push_str("\\r"),
            '\t' => res.push_str("\\t"),
            '\u{8}' => res.push_str("\\b"),
            '\u{c}' => res.push_str("\\f"),
            '\u{0}'..='\u{1f}' => {
                // Writing into a String cannot fail.
                let _ = write!(res, "\\u{:04x}", u32::from(c));
            }
            _ => res.push(c),
        }
    }
    res
}
/// Errors that can occur while talking to a REST service.
///
/// NOTE(review): in the code visible here, `RestAuthenticator::authenticate` reports
/// failures as `AuthenticationError` and never constructs this type.
#[derive(Debug)]
pub enum RestError {
    /// A URI could not be parsed.
    InvalidUri(InvalidUri),
    /// The service answered with an HTTP status treated as a failure; holds the status
    /// code. Presumably a non-success status — confirm against its producers.
    HttpStatusError(u16),
    /// An error reported by the `hyper` HTTP client, kept intact.
    HyperError(hyper::Error),
    /// An HTTP-level failure described as text. `From<hyper::Error>` produces this
    /// variant, carrying the error's message.
    HttpError(String),
    /// JSON could not be parsed. `From<serde_json::Error>` produces this variant.
    JsonDeserializationError(serde_json::Error),
    /// A value could not be serialized to JSON.
    JsonSerializationError(serde_json::Error),
}
impl From<hyper::Error> for RestError {
fn from(e: hyper::Error) -> Self {
Self::HttpError(e.to_string())
}
}
impl From<serde_json::error::Error> for RestError {
fn from(e: serde_json::error::Error) -> Self {
Self::JsonDeserializationError(e)
}
}