use std::io;
use std::net::{Ipv6Addr, SocketAddr};
use std::sync::Arc;

use axum::Router;
use axum::body::{Body, to_bytes};
use axum::extract::{ConnectInfo, State};
use axum::http::header::{ALLOW, AUTHORIZATION, CONTENT_TYPE, LOCATION};
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode, Uri};
use axum::response::Response;
use percent_encoding::percent_decode_str;
use tokio::net::TcpListener;
use url::Url;

use crate::code::{generate_code, is_valid_code};
use crate::config::Config;
use crate::database::{Database, DatabaseError};
/// Upper bound, in bytes, on a request body read when creating a short code (8 KiB).
const MAX_BODY_SIZE: usize = 8 * 1024;
/// URL-shortening HTTP server; owns the state handed to every request handler.
pub struct Shortener {
state: AppState,
}
/// State shared by all request handlers.
///
/// Both fields sit behind `Arc`, so the derived `Clone` only bumps reference counts.
#[derive(Clone)]
struct AppState {
// Server configuration (listen port, auth flag, URL prefix, ...).
config: Arc<Config>,
// Storage for codes, their target URLs, and API keys.
database: Arc<Database>,
}
impl Shortener {
    /// Opens the SQLite database named by `config.sqlite_db`, initialises it,
    /// and bundles it with the configuration into the shared handler state.
    ///
    /// # Errors
    /// Returns any [`DatabaseError`] raised while opening or initialising the database.
    pub fn new(config: Config) -> Result<Self, DatabaseError> {
        // NOTE(review): the meaning of the `true` flag is defined by `Database::new`; confirm there.
        let db = Database::new(&config.sqlite_db, true)?;
        db.init()?;
        let state = AppState {
            config: Arc::new(config),
            database: Arc::new(db),
        };
        Ok(Self { state })
    }

    /// Binds `[::]:<listen_port>` and serves every request through `handle_request`
    /// until the server stops.
    ///
    /// The peer's socket address is made available to handlers via `ConnectInfo`.
    ///
    /// # Errors
    /// Returns the I/O error from binding the listener or from the server itself.
    pub async fn listen_and_serve(self) -> io::Result<()> {
        let Self { state } = self;
        let bind_to = SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), state.config.listen_port);
        let router = Router::new().fallback(handle_request).with_state(state);
        let listener = TcpListener::bind(&bind_to).await?;
        let service = router.into_make_service_with_connect_info::<SocketAddr>();
        axum::serve(listener, service).await
    }
}
async fn handle_request(
State(state): State<AppState>,
ConnectInfo(remote_addr): ConnectInfo<SocketAddr>,
method: Method,
uri: Uri,
headers: HeaderMap,
body: Body,
) -> Response {
match method {
Method::GET => {
if uri.path().is_empty() || uri.path() == "/" {
homepage_handler(state, remote_addr, method, uri, headers)
} else {
redirect_handler(state, remote_addr, method, uri, headers)
}
}
Method::POST => {
create_code_handler(state, remote_addr, method, uri, headers, body).await
}
_ => {
let client_host = get_client_host(&state.config, remote_addr, &headers);
log::info!("{} {} {} Method not allowed", client_host, method, uri);
http_error(StatusCode::METHOD_NOT_ALLOWED, "Method not allowed")
}
}
}
fn homepage_handler(
state: AppState,
remote_addr: SocketAddr,
method: Method,
uri: Uri,
headers: HeaderMap,
) -> Response {
log::info!(
"{} {} {}",
get_client_host(&state.config, remote_addr, &headers),
method,
uri
);
match &state.config.main_page {
Some(main_page) => redirect_response(main_page),
None => plain_response("hello, world\n"),
}
}
fn redirect_handler(
state: AppState,
remote_addr: SocketAddr,
method: Method,
uri: Uri,
headers: HeaderMap,
) -> Response {
let code = code_from_path(uri.path());
let client_host = get_client_host(&state.config, remote_addr, &headers);
match state.database.get_url(&code) {
Ok(url) => {
log::info!("{} {} {} => {}", client_host, method, uri, url);
redirect_response(&url)
}
Err(DatabaseError::NotFound) => {
log::info!("{} {} {} [Not found]", client_host, method, uri);
http_error(StatusCode::NOT_FOUND, "Not found")
}
Err(error) => {
log::info!("{} {} {} [{}]", client_host, method, uri, error);
http_error(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error")
}
}
}
async fn create_code_handler(
state: AppState,
remote_addr: SocketAddr,
method: Method,
uri: Uri,
headers: HeaderMap,
body: Body,
) -> Response {
let custom_code = if uri.path().is_empty() || uri.path() == "/" {
String::new()
} else {
code_from_path(uri.path())
};
let mut username = String::new();
if state.config.auth {
let auth_header = headers
.get(AUTHORIZATION)
.and_then(|value| value.to_str().ok())
.unwrap_or_default();
if !auth_header.starts_with("Bearer ") {
log::info!(
"{} {} {} [Missing credentials]",
get_client_host(&state.config, remote_addr, &headers),
method,
uri
);
return http_error(StatusCode::UNAUTHORIZED, "Unauthorized");
}
match state
.database
.check_api_key(&auth_header["Bearer ".len()..])
{
Ok(owner) => username = owner,
Err(_) => {
log::info!(
"{} {} {} [Invalid credentials]",
get_client_host(&state.config, remote_addr, &headers),
method,
uri
);
return http_error(StatusCode::UNAUTHORIZED, "Unauthorized");
}
}
}
let client_host = get_client_host(&state.config, remote_addr, &headers);
let created_by = if state.config.auth {
username.as_str()
} else {
client_host.as_str()
};
let body = match to_bytes(body, MAX_BODY_SIZE).await {
Ok(body) => body,
Err(_) => {
log::info!(
"{} {} {} [Request body too large]",
client_host,
method,
uri
);
return http_error(
StatusCode::PAYLOAD_TOO_LARGE,
"Request body too large",
);
}
};
let body = String::from_utf8_lossy(&body);
let target_url = body.trim().to_owned();
if !is_valid_http_url(&target_url) {
log::info!("{} {} {} [Invalid URL]", client_host, method, uri);
return http_error(StatusCode::BAD_REQUEST, "Invalid URL");
}
if !custom_code.is_empty() && !is_valid_code(&custom_code) {
log::info!("{} {} {} [Invalid code]", client_host, method, uri);
return http_error(StatusCode::BAD_REQUEST, "Invalid code");
}
let code = if custom_code.is_empty() {
match create_generated_code(
&state,
&target_url,
created_by,
&client_host,
&method,
&uri,
) {
Some(code) => code,
None => {
return http_error(
StatusCode::INTERNAL_SERVER_ERROR,
"Internal server error",
);
}
}
} else {
match state
.database
.create_code(&target_url, &custom_code, created_by)
{
Ok(()) => custom_code,
Err(DatabaseError::CodeAlreadyInUse) => {
log::info!("{} {} {} [Code already in use]", client_host, method, uri);
return http_error(StatusCode::CONFLICT, "Code already in use");
}
Err(error) => {
log::info!("{} {} {} [{}]", client_host, method, uri, error);
return http_error(
StatusCode::INTERNAL_SERVER_ERROR,
"Internal server error",
);
}
}
};
let new_url = format!("{}{}", state.config.url_prefix, code);
log::info!(
"{} {} {} ({}) => {}",
client_host,
method,
uri,
target_url,
new_url
);
plain_response(format!("{}\n", new_url))
}
/// Stores `target_url` under a freshly generated random code.
///
/// Makes up to three attempts, each with a new code of `config.code_length`.
/// A failed attempt (code collision or any other database error) is logged
/// before the next one. Returns the stored code, or `None` (after logging)
/// when every attempt failed.
fn create_generated_code(
    state: &AppState,
    target_url: &str,
    created_by: &str,
    client_host: &str,
    method: &Method,
    uri: &Uri,
) -> Option<String> {
    const MAX_ATTEMPTS: u32 = 3;
    for attempt in 0..MAX_ATTEMPTS {
        let code = generate_code(state.config.code_length);
        // Success is returned directly rather than through an empty-string
        // sentinel, so a stored code can never be misreported as a failure.
        match state.database.create_code(target_url, &code, created_by) {
            Ok(()) => return Some(code),
            Err(DatabaseError::CodeAlreadyInUse) => {
                log::info!(
                    "{} {} {} [Attempt {}: {}: Code already in use]",
                    client_host,
                    method,
                    uri,
                    attempt,
                    code
                );
            }
            Err(error) => {
                log::info!(
                    "{} {} {} [Attempt {}: {}: {}]",
                    client_host,
                    method,
                    uri,
                    attempt,
                    code,
                    error
                );
            }
        }
    }
    log::info!(
        "{} {} {} [Could not generate code]",
        client_host,
        method,
        uri
    );
    None
}
/// Returns `true` when `input` parses as an absolute URL with an `http` or
/// `https` scheme.
fn is_valid_http_url(input: &str) -> bool {
    match Url::parse(input) {
        Ok(parsed) => {
            let scheme = parsed.scheme();
            scheme == "http" || scheme == "https"
        }
        Err(_) => false,
    }
}
/// Returns the client address to use for logging and attribution.
///
/// When `config.trust_proxy` is set, this is the first, whitespace-trimmed entry
/// of the `X-Forwarded-For` header, if that entry is non-empty. In every other
/// case it is the IP of the connected peer.
///
/// NOTE(review): the first `X-Forwarded-For` entry is whatever the client sent,
/// unless the proxy overwrites the header. Confirm the proxy setup before relying
/// on it for attribution.
fn get_client_host(
    config: &Config,
    remote_addr: SocketAddr,
    headers: &HeaderMap,
) -> String {
    let forwarded = config
        .trust_proxy
        .then(|| {
            headers
                .get("x-forwarded-for")
                .and_then(|value| value.to_str().ok())
                .and_then(|list| list.split(',').next())
                .map(str::trim)
                .filter(|entry| !entry.is_empty())
        })
        .flatten();
    match forwarded {
        Some(entry) => entry.to_owned(),
        None => remote_addr.ip().to_string(),
    }
}
/// Turns a request path into a short code: drops one leading `/` and
/// percent-decodes the rest, replacing invalid UTF-8 with U+FFFD.
fn code_from_path(path: &str) -> String {
    let without_slash = match path.strip_prefix('/') {
        Some(rest) => rest,
        None => path,
    };
    String::from(percent_decode_str(without_slash).decode_utf8_lossy())
}
fn plain_response(body: impl Into<Body>) -> Response {
Response::builder()
.status(StatusCode::OK)
.body(body.into())
.expect("response should build")
}
/// Builds an error response with `status` and a plain-text body of
/// `message` plus a newline.
fn http_error(status: StatusCode, message: &str) -> Response {
    let mut text = String::with_capacity(message.len() + 1);
    text.push_str(message);
    text.push('\n');

    let builder = Response::builder()
        .status(status)
        .header(CONTENT_TYPE, "text/plain; charset=utf-8")
        .header("X-Content-Type-Options", "nosniff");
    // Only static, valid header values are used, so building cannot fail.
    builder
        .body(Body::from(text))
        .expect("response should build")
}
fn redirect_response(location: &str) -> Response {
Response::builder()
.status(StatusCode::FOUND)
.header(LOCATION, location)
.body(Body::from(format!("{}\n", location)))
.expect("response should build")
}