use std::sync::Arc;
use actix_http::http::header::{AUTHORIZATION, CONTENT_TYPE, ORIGIN, WWW_AUTHENTICATE};
use actix_http::http::{HeaderValue, Method};
use actix_web::web::{block, Data, Path, Payload};
use actix_web::{HttpRequest, HttpResponse};
use actix_web_actors::ws::start;
use base64::decode as base64_decode;
use serde_json::Value;
use crate::communication::{ProxyRequest, ProxyResponse};
use crate::proxy_error::{ProxyError, ProxyResult};
use crate::server::client_manager::ClientManager;
use crate::server::websocket_handler::WebsocketHandler;
fn error_response(message: impl ToString) -> Value {
json!({
"success": false,
"message": [{
"error": message.to_string()
}]
})
}
/// Extracts `(username, password)` from an HTTP Basic `Authorization` header.
///
/// The scheme name is matched case-insensitively (RFC 7235 §2.1). The decoded
/// credentials are split on the *first* colon only (RFC 7617 §2), so passwords
/// may themselves contain `:`. Both parts are whitespace-trimmed, and a missing
/// part becomes an empty string.
///
/// # Errors
/// Returns a [`ProxyError`] when the header is missing, is not valid
/// visible-ASCII text, does not use the `Basic` scheme, is not valid base64,
/// or does not decode to UTF-8.
fn get_user_credentials(request: &HttpRequest) -> ProxyResult<(String, String)> {
    const SCHEME: &str = "Basic ";

    let header = request
        .headers()
        .get(AUTHORIZATION)
        .ok_or_else(|| ProxyError::message(String::from("Authorization is missing")))?;
    // The header is client-controlled: reject non-ASCII bytes instead of panicking.
    let header = header
        .to_str()
        .map_err(|_| ProxyError::message(String::from("Authorization header is invalid")))?;

    // Strip only the leading scheme; the base64 payload itself may contain "Basic".
    let encoded = match header.get(..SCHEME.len()) {
        Some(scheme) if scheme.eq_ignore_ascii_case(SCHEME) => header[SCHEME.len()..].trim(),
        _ => {
            return Err(ProxyError::message(String::from(
                "Only Basic auth is supported",
            )))
        }
    };

    let decoded = base64_decode(encoded)
        .map_err(|_| ProxyError::message(String::from("Basic auth encoding is invalid")))?;
    let decoded = String::from_utf8(decoded).map_err(|_| {
        ProxyError::message(String::from("Basic auth character encoding is invalid"))
    })?;

    // user-id cannot contain ':', but the password can; split on the first one only.
    let mut parts = decoded.splitn(2, ':');
    let username = parts.next().unwrap_or("").trim().to_string();
    let password = parts.next().unwrap_or("").trim().to_string();
    Ok((username, password))
}
pub async fn handle_connect(
path: Path<(String,)>,
client_manager: Data<Arc<ClientManager>>,
request: HttpRequest,
stream: Payload,
) -> HttpResponse {
let client_id = path.into_inner().0;
let client_manager = Arc::clone(client_manager.into_inner().as_ref());
if client_manager.is_client_registered(&client_id) {
return HttpResponse::BadRequest().json(error_response(format!(
"Client {} is already registered",
&client_id
)));
}
let socket_handler = WebsocketHandler::new(client_manager.clone(), client_id);
let response = start(socket_handler, &request, stream);
match response {
Ok(response) => response,
_ => HttpResponse::BadRequest()
.json(error_response("Cannot upgrade connection to websocket")),
}
}
pub async fn handle_proxy(
path: Path<(String, String)>,
body: Option<String>,
client_manager: Data<Arc<ClientManager>>,
request: HttpRequest,
) -> HttpResponse {
if request.method().eq(&Method::OPTIONS) {
return HttpResponse::Ok().body("");
}
let (client_id, request_uri) = path.into_inner();
let (username, password) = match get_user_credentials(&request) {
Ok(credentials) => credentials,
Err(error) => {
let mut response =
HttpResponse::Unauthorized().json(error_response(format!("{}", error)));
response.headers_mut().insert(
WWW_AUTHENTICATE,
HeaderValue::from_str(r#"Basic realm="New Home Proxy""#).unwrap(),
);
return response;
}
};
let origin = request
.headers()
.get(ORIGIN)
.unwrap_or(&HeaderValue::from_str("").unwrap())
.to_str()
.unwrap_or("")
.to_string();
let request = ProxyRequest {
method: request.method().to_string(),
url: request_uri,
origin,
request_body: match body {
Some(body) => match serde_json::from_str(body.as_str()) {
Ok(body) => body,
_ => Value::Null,
},
_ => Value::Null,
},
username,
password,
};
match block(move || client_manager.send_request(&client_id, request)).await {
Ok(response) => proxy_to_server_response(response),
Err(error) => HttpResponse::InternalServerError().json(error_response(error)),
}
}
fn proxy_to_server_response(proxy_response: ProxyResponse) -> HttpResponse {
if proxy_response.content_type.eq("application/json") {
return HttpResponse::Ok().json(proxy_response.response);
}
let mut server_response = HttpResponse::Ok();
server_response.set_header(CONTENT_TYPE, proxy_response.content_type);
server_response.body(String::from(proxy_response.response.as_str().unwrap_or("")))
}