1use async_trait::async_trait;
6use http_body_util::BodyExt;
7use hyper::{Method, Request, http::uri::InvalidUri};
8use hyper_util::client::legacy::Client;
9use hyper_util::rt::TokioExecutor;
10use percent_encoding::{NON_ALPHANUMERIC, utf8_percent_encode};
11use regex::Regex;
12use serde_json::{Value, json};
13use unftp_core::auth::{AuthenticationError, Authenticator, Credentials, Principal};
14
/// [`Authenticator`] that validates FTP credentials by calling a REST endpoint.
///
/// Instances are created through [`Builder`]. The `*_placeholder` strings are substituted
/// into the `url` template (with percent-encoded values) and into the `body` template (with
/// JSON-escaped values, without surrounding quotes); an empty placeholder disables that
/// substitution. The JSON value found at the `selector` pointer in the response is
/// serialized back to JSON text and matched against `regex` to decide the outcome.
#[derive(Clone, Debug)]
pub struct RestAuthenticator {
    /// Text in the URL/body templates that is replaced by the username.
    username_placeholder: String,
    /// Text in the URL/body templates that is replaced by the password.
    password_placeholder: String,
    /// Text in the URL/body templates that is replaced by the client's source IP.
    source_ip_placeholder: String,

    /// HTTP method used for the authentication request.
    method: Method,
    /// URL template; placeholders are replaced with percent-encoded values.
    url: String,
    /// Request body template; placeholders are replaced with JSON-escaped values.
    body: String,
    /// JSON pointer (as accepted by `serde_json::Value::pointer`) selecting the response value to test.
    selector: String,
    /// Pattern the serialized selected value must match for authentication to succeed.
    regex: Regex,
}
31
/// Builder for [`RestAuthenticator`].
///
/// Every setting starts from its `Default` value: empty strings for the placeholders, URL,
/// body, selector and regex source, and `Method::default()` for the HTTP method (GET in the
/// `http` crate — NOTE(review): confirm for the version in use).
///
/// Beware that the regex source defaults to the empty pattern, which matches any input: an
/// authenticator built without [`Builder::with_regex`] accepts any credentials for which the
/// endpoint returns a JSON body.
#[derive(Clone, Debug, Default)]
pub struct Builder {
    /// Text replaced by the username; empty disables the substitution.
    username_placeholder: String,
    /// Text replaced by the password; empty disables the substitution.
    password_placeholder: String,
    /// Text replaced by the source IP; empty disables the substitution.
    source_ip_placeholder: String,

    /// HTTP method for the request.
    method: Method,
    /// URL template.
    url: String,
    /// Request body template.
    body: String,
    /// JSON pointer into the response body.
    selector: String,
    /// Regex source text, compiled by [`Builder::build`].
    regex: String,
}
45
46impl Builder {
47 pub fn new() -> Builder {
63 Builder { ..Default::default() }
64 }
65
66 pub fn with_username_placeholder(mut self, s: String) -> Self {
88 self.username_placeholder = s;
89 self
90 }
91
92 pub fn with_password_placeholder(mut self, s: String) -> Self {
114 self.password_placeholder = s;
115 self
116 }
117
118 pub fn with_source_ip_placeholder(mut self, s: String) -> Self {
141 self.source_ip_placeholder = s;
142 self
143 }
144
145 pub fn with_method(mut self, s: Method) -> Self {
147 self.method = s;
148 self
149 }
150
151 pub fn with_url(mut self, s: String) -> Self {
153 self.url = s;
154 self
155 }
156
157 pub fn with_body(mut self, s: String) -> Self {
159 self.body = s;
160 self
161 }
162
163 pub fn with_selector(mut self, s: String) -> Self {
166 self.selector = s;
167 self
168 }
169
170 pub fn with_regex(mut self, s: String) -> Self {
172 self.regex = s;
173 self
174 }
175
176 pub fn build(self) -> Result<RestAuthenticator, Box<dyn std::error::Error>> {
178 Ok(RestAuthenticator {
179 username_placeholder: self.username_placeholder,
180 password_placeholder: self.password_placeholder,
181 source_ip_placeholder: self.source_ip_placeholder,
182 method: self.method,
183 url: self.url,
184 body: self.body,
185 selector: self.selector,
186 regex: Regex::new(&self.regex)?,
187 })
188 }
189}
190
191impl RestAuthenticator {
192 fn fill_encoded_placeholders(&self, string: &str, username: &str, password: &str, source_ip: &str) -> String {
193 let mut result = string.to_owned();
194
195 if !self.username_placeholder.is_empty() {
196 result = result.replace(&self.username_placeholder, username);
197 }
198 if !self.password_placeholder.is_empty() {
199 result = result.replace(&self.password_placeholder, password);
200 }
201 if !self.source_ip_placeholder.is_empty() {
202 result = result.replace(&self.source_ip_placeholder, source_ip);
203 }
204
205 result
206 }
207}
208
/// Removes one pair of enclosing double quotes from a string.
trait TrimQuotes {
    /// Returns the contents between a leading and a trailing `"`, or the whole string when
    /// it is not wrapped in such a pair (a lone `"` is returned unchanged).
    fn trim_quotes(&self) -> &str;
}

impl TrimQuotes for String {
    fn trim_quotes(&self) -> &str {
        // Both strips must succeed; a single `"` cannot serve as both the opening and the
        // closing quote because the suffix is stripped from what remains after the prefix.
        match self.strip_prefix('"').and_then(|inner| inner.strip_suffix('"')) {
            Some(inner) => inner,
            None => self,
        }
    }
}
223
224#[async_trait]
225impl Authenticator for RestAuthenticator {
226 #[tracing_attributes::instrument]
227 async fn authenticate(&self, username: &str, creds: &Credentials) -> Result<Principal, AuthenticationError> {
228 let username_url = utf8_percent_encode(username, NON_ALPHANUMERIC).collect::<String>();
229 let password = creds.password.as_ref().ok_or(AuthenticationError::BadPassword)?.as_ref();
230 let password_url = utf8_percent_encode(password, NON_ALPHANUMERIC).collect::<String>();
231 let source_ip = creds.source_ip.to_string();
232 let source_ip_url = utf8_percent_encode(&source_ip, NON_ALPHANUMERIC).collect::<String>();
233
234 let url = self.fill_encoded_placeholders(&self.url, &username_url, &password_url, &source_ip_url);
235
236 let username = serde_json::to_string(username)
237 .map_err(|e| AuthenticationError::ImplPropagated(e.to_string(), None))?
238 .trim_quotes()
239 .to_string();
240 let password = serde_json::to_string(password)
241 .map_err(|e| AuthenticationError::ImplPropagated(e.to_string(), None))?
242 .trim_quotes()
243 .to_string();
244 let source_ip = serde_json::to_string(&source_ip)
245 .map_err(|e| AuthenticationError::ImplPropagated(e.to_string(), None))?
246 .trim_quotes()
247 .to_string();
248
249 let body = self.fill_encoded_placeholders(&self.body, &username, &password, &source_ip);
250
251 let req = Request::builder()
252 .method(&self.method)
253 .header("Content-type", "application/json")
254 .uri(url)
255 .body(body)
256 .map_err(|e| AuthenticationError::with_source("rest authenticator http client error", e))?;
257
258 let https = hyper_rustls::HttpsConnectorBuilder::new()
259 .with_native_roots()
260 .expect("no native root CA certificates found")
261 .https_or_http()
262 .enable_http1()
263 .build();
264
265 let client = Client::builder(TokioExecutor::new()).build(https);
266
267 let resp = client
268 .request(req)
269 .await
270 .map_err(|e| AuthenticationError::with_source("rest authenticator http client error", e))?;
271
272 let (parts, body) = resp.into_parts();
273 let status_context = format!("http status={}", parts.status.as_str());
274 let body = BodyExt::collect(body)
275 .await
276 .map_err(|e| AuthenticationError::with_source(format!("error while receiving http response ({})", status_context), e))?
277 .to_bytes();
278 let body: Value = serde_json::from_slice(&body)
279 .map_err(|e| AuthenticationError::with_source(format!("rest authenticator unmarshalling error ({})", status_context), e))?;
280
281 let parsed = match body.pointer(&self.selector) {
282 Some(parsed) => parsed.to_string(),
283 None => json!(null).to_string(),
284 };
285
286 if self.regex.is_match(&parsed) {
287 Ok(Principal {
288 username: username.to_string(),
289 })
290 } else {
291 Err(AuthenticationError::BadPassword)
292 }
293 }
294}
295
/// Errors related to talking to the REST authentication backend.
///
/// NOTE(review): the `authenticate` implementation visible in this file reports failures via
/// `AuthenticationError` and does not construct this type; confirm where it is used.
#[allow(missing_docs)] #[derive(Debug)]
pub enum RestError {
    /// Wraps a URI parsing error (`hyper::http::uri::InvalidUri`).
    InvalidUri(InvalidUri),
    /// Carries an HTTP status code — presumably an unexpected/non-success status; confirm at
    /// construction sites.
    HttpStatusError(u16),
    /// Wraps a `hyper::Error`. Note that the `From<hyper::Error>` impl in this file produces
    /// `HttpError`, not this variant.
    HyperError(hyper::Error),
    /// HTTP failure described by a message; `From<hyper::Error>` stores the error's
    /// `Display` text here.
    HttpError(String),
    /// Wraps a `serde_json` error; `From<serde_json::Error>` maps every such error here.
    JsonDeserializationError(serde_json::Error),
    /// Wraps a `serde_json` error raised while serializing; not produced by the `From`
    /// conversion.
    JsonSerializationError(serde_json::Error),
}
307
308impl From<hyper::Error> for RestError {
309 fn from(e: hyper::Error) -> Self {
310 Self::HttpError(e.to_string())
311 }
312}
313
314impl From<serde_json::error::Error> for RestError {
315 fn from(e: serde_json::error::Error) -> Self {
316 Self::JsonDeserializationError(e)
317 }
318}