use actix_web::{http::header::LOCATION, web, HttpResponse};
use futures::future::IntoFuture;
use std::collections::HashMap;
use crate::oauth::OAuthClient;
#[cfg(feature = "authorization")]
use crate::rest_api::auth::authorization::Permission;
use crate::rest_api::{
actix_web_1::{Method, ProtocolVersionRangeGuard, Resource},
ErrorResponse, SPLINTER_PROTOCOL_VERSION,
};
const OAUTH_LOGIN_MIN: u32 = 1;
pub fn make_login_route(client: OAuthClient) -> Resource {
let resource = Resource::build("/oauth/login").add_request_guard(
ProtocolVersionRangeGuard::new(OAUTH_LOGIN_MIN, SPLINTER_PROTOCOL_VERSION),
);
#[cfg(feature = "authorization")]
{
resource.add_method(
Method::Get,
Permission::AllowUnauthenticated,
move |req, _| {
let query: web::Query<HashMap<String, String>> =
if let Ok(q) = web::Query::from_query(req.query_string()) {
q
} else {
return Box::new(
HttpResponse::BadRequest()
.json(ErrorResponse::bad_request("Invalid query"))
.into_future(),
);
};
let client_redirect_url = if let Some(header_value) = query.get("redirect_url") {
header_value
} else {
match req.headers().get("referer") {
Some(url) => match url.to_str() {
Ok(url) => url,
Err(_) => {
return Box::new(
HttpResponse::BadRequest()
.json(ErrorResponse::bad_request(
"Referer header is set, but is not a valid URL",
))
.into_future(),
)
}
},
None => {
return Box::new(
HttpResponse::BadRequest()
.json(ErrorResponse::bad_request(
"No valid redirect URL supplied",
))
.into_future(),
)
}
}
};
Box::new(
match client.get_authorization_url(client_redirect_url.to_string()) {
Ok(auth_url) => HttpResponse::Found().header(LOCATION, auth_url).finish(),
Err(err) => {
error!("{}", err);
HttpResponse::InternalServerError()
.json(ErrorResponse::internal_error())
}
}
.into_future(),
)
},
)
}
#[cfg(not(feature = "authorization"))]
{
resource.add_method(Method::Get, move |req, _| {
let query: web::Query<HashMap<String, String>> =
if let Ok(q) = web::Query::from_query(req.query_string()) {
q
} else {
return Box::new(
HttpResponse::BadRequest()
.json(ErrorResponse::bad_request("Invalid query"))
.into_future(),
);
};
let client_redirect_url = if let Some(header_value) = query.get("redirect_url") {
header_value
} else {
match req.headers().get("referer") {
Some(url) => match url.to_str() {
Ok(url) => url,
Err(_) => {
return Box::new(
HttpResponse::BadRequest()
.json(ErrorResponse::bad_request(
"No valid redirect URL supplied",
))
.into_future(),
)
}
},
None => {
return Box::new(
HttpResponse::BadRequest()
.json(ErrorResponse::bad_request("No valid redirect URL supplied"))
.into_future(),
)
}
}
};
Box::new(
match client.get_authorization_url(client_redirect_url.to_string()) {
Ok(auth_url) => HttpResponse::Found().header(LOCATION, auth_url).finish(),
Err(err) => {
error!("{}", err);
HttpResponse::InternalServerError().json(ErrorResponse::internal_error())
}
}
.into_future(),
)
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    use reqwest::{
        blocking::{Client, Response},
        redirect, StatusCode, Url,
    };

    use crate::oauth::tests::TestProfileProvider;
    use crate::oauth::{
        new_basic_client,
        store::{
            InflightOAuthRequestStore, InflightOAuthRequestStoreError,
            MemoryInflightOAuthRequestStore,
        },
        tests::TestSubjectProvider,
        PendingAuthorization,
    };
    use crate::rest_api::actix_web_1::{RestApiBuilder, RestApiShutdownHandle};

    const CLIENT_ID: &str = "client_id";
    const CLIENT_SECRET: &str = "client_secret";
    const AUTH_URL: &str = "http://oauth/auth";
    const REDIRECT_URL: &str = "http://oauth/callback";
    const TOKEN_ENDPOINT: &str = "/token";
    const CLIENT_REDIRECT_URL: &str = "http://client/redirect";

    /// A login request carrying a `redirect_url` query parameter is redirected (302) to the
    /// provider's authorization URL.
    #[test]
    fn get_login_with_redirect_url() {
        let client = new_test_client(Box::new(TestInflightOAuthRequestStore));
        let (shutdown_handle, join_handle, bind_url) =
            run_rest_api_on_open_port(vec![make_login_route(client)]);

        let resp = send_login_request(&bind_url, Some(CLIENT_REDIRECT_URL), None);
        assert_redirects_to_auth_url(&resp);

        shutdown_rest_api(shutdown_handle, join_handle);
    }

    /// Without a `redirect_url` query parameter, the `Referer` header is used instead and the
    /// request is still redirected to the authorization URL.
    #[test]
    fn get_login_with_referer_header() {
        let client = new_test_client(Box::new(TestInflightOAuthRequestStore));
        let (shutdown_handle, join_handle, bind_url) =
            run_rest_api_on_open_port(vec![make_login_route(client)]);

        let resp = send_login_request(&bind_url, None, Some(CLIENT_REDIRECT_URL));
        assert_redirects_to_auth_url(&resp);

        shutdown_rest_api(shutdown_handle, join_handle);
    }

    /// When both a `redirect_url` query parameter and a `Referer` header are present, the query
    /// parameter wins. If the client stores the pending authorization,
    /// `TestInflightOAuthRequestStore` asserts that the stored URL is `CLIENT_REDIRECT_URL` and
    /// not the referer.
    #[test]
    fn get_login_redirect_url_takes_precedence_over_referer() {
        let client = new_test_client(Box::new(TestInflightOAuthRequestStore));
        let (shutdown_handle, join_handle, bind_url) =
            run_rest_api_on_open_port(vec![make_login_route(client)]);

        let resp = send_login_request(
            &bind_url,
            Some(CLIENT_REDIRECT_URL),
            Some("http://other/referer"),
        );
        assert_redirects_to_auth_url(&resp);

        shutdown_rest_api(shutdown_handle, join_handle);
    }

    /// A login request with neither a `redirect_url` query parameter nor a `Referer` header is
    /// rejected with 400 Bad Request.
    #[test]
    fn get_login_missing_client_redirect() {
        let client = new_test_client(Box::new(MemoryInflightOAuthRequestStore::new()));
        let (shutdown_handle, join_handle, bind_url) =
            run_rest_api_on_open_port(vec![make_login_route(client)]);

        let resp = send_login_request(&bind_url, None, None);
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        shutdown_rest_api(shutdown_handle, join_handle);
    }

    /// Builds an `OAuthClient` configured with the fake provider constants above and backed by
    /// the given in-flight request store.
    fn new_test_client(inflight_request_store: Box<dyn InflightOAuthRequestStore>) -> OAuthClient {
        OAuthClient::new(
            new_basic_client(
                CLIENT_ID.into(),
                CLIENT_SECRET.into(),
                AUTH_URL.into(),
                REDIRECT_URL.into(),
                format!("http://oauth{}", TOKEN_ENDPOINT),
            )
            .expect("Failed to create basic client"),
            vec![],
            vec![],
            Box::new(TestSubjectProvider),
            inflight_request_store,
            Box::new(TestProfileProvider),
        )
    }

    /// Sends `GET /oauth/login` to the server at `bind_url` without following redirects.
    /// `redirect_url` is added as a query parameter and `referer` as the `Referer` header, each
    /// only when provided.
    fn send_login_request(
        bind_url: &str,
        redirect_url: Option<&str>,
        referer: Option<&str>,
    ) -> Response {
        let mut url =
            Url::parse(&format!("http://{}/oauth/login", bind_url)).expect("Failed to parse URL");
        if let Some(redirect_url) = redirect_url {
            url.query_pairs_mut().append_pair("redirect_url", redirect_url);
        }

        let mut request = Client::builder()
            .redirect(redirect::Policy::none())
            .build()
            .expect("Failed to build client")
            .get(url)
            .header("SplinterProtocolVersion", SPLINTER_PROTOCOL_VERSION);
        if let Some(referer) = referer {
            request = request.header("Referer", referer);
        }

        request.send().expect("Failed to perform request")
    }

    /// Asserts that `resp` is a 302 whose `Location` header starts with the provider's
    /// authorization URL.
    fn assert_redirects_to_auth_url(resp: &Response) {
        assert_eq!(resp.status(), StatusCode::FOUND);
        let location = resp
            .headers()
            .get("Location")
            .expect("Location header not set")
            .to_str()
            .expect("Location header should only contain visible ASCII characters");
        assert!(
            location.starts_with(AUTH_URL),
            "unexpected Location header: {}",
            location
        );
    }

    /// Shuts the REST API down and waits for its thread to exit.
    fn shutdown_rest_api(
        shutdown_handle: RestApiShutdownHandle,
        join_handle: std::thread::JoinHandle<()>,
    ) {
        shutdown_handle
            .shutdown()
            .expect("Unable to shutdown rest api");
        join_handle.join().expect("Unable to join rest api thread");
    }

    /// In-flight request store whose `insert_request` asserts that the pending authorization's
    /// client redirect URL is `CLIENT_REDIRECT_URL`; it never returns a stored request.
    #[derive(Clone)]
    pub struct TestInflightOAuthRequestStore;

    impl InflightOAuthRequestStore for TestInflightOAuthRequestStore {
        fn insert_request(
            &self,
            _request_id: String,
            authorization: PendingAuthorization,
        ) -> Result<(), InflightOAuthRequestStoreError> {
            assert_eq!(&authorization.client_redirect_url, CLIENT_REDIRECT_URL);
            Ok(())
        }

        fn remove_request(
            &self,
            _request_id: &str,
        ) -> Result<Option<PendingAuthorization>, InflightOAuthRequestStoreError> {
            Ok(None)
        }

        fn clone_box(&self) -> Box<dyn InflightOAuthRequestStore> {
            Box::new(self.clone())
        }
    }

    /// Starts an insecure REST API with `resources` on an OS-assigned port of 127.0.0.1 and
    /// returns its shutdown handle, the thread's join handle and the `host:port` it is bound to.
    fn run_rest_api_on_open_port(
        resources: Vec<Resource>,
    ) -> (RestApiShutdownHandle, std::thread::JoinHandle<()>, String) {
        #[cfg(not(feature = "https-bind"))]
        let bind = "127.0.0.1:0";
        #[cfg(feature = "https-bind")]
        let bind = crate::rest_api::BindConfig::Http("127.0.0.1:0".into());

        // `resources` is owned and not used afterwards, so it is moved rather than cloned.
        let result = RestApiBuilder::new()
            .with_bind(bind)
            .add_resources(resources)
            .build_insecure()
            .expect("Failed to build REST API")
            .run_insecure();

        let (shutdown_handle, join_handle) =
            result.unwrap_or_else(|err| panic!("Failed to run REST API: {}", err));
        let port = shutdown_handle.port_numbers()[0];
        (shutdown_handle, join_handle, format!("127.0.0.1:{}", port))
    }
}