// shortener 0.1.2
//
// A simple URL shortener.
use std::io;
use std::net::{Ipv6Addr, SocketAddr};
use std::sync::Arc;

use axum::Router;
use axum::body::{Body, to_bytes};
use axum::extract::{ConnectInfo, State};
use axum::http::header::{AUTHORIZATION, CONTENT_TYPE, LOCATION};
use axum::http::{HeaderMap, Method, StatusCode, Uri};
use axum::response::Response;
use percent_encoding::percent_decode_str;
use tokio::net::TcpListener;
use url::Url;

use crate::code::{generate_code, is_valid_code};
use crate::config::Config;
use crate::database::{Database, DatabaseError};

/// Upper bound, in bytes, on how much of a `POST` request body is read
/// (passed as the limit to `to_bytes` in `create_code_handler`).
const MAX_BODY_SIZE: usize = 8192;

/// The URL shortener HTTP server.
///
/// Built with [`Shortener::new`] and run with
/// [`Shortener::listen_and_serve`].
pub struct Shortener {
  /// Configuration and database handle passed to the router as its state.
  state: AppState,
}

/// State attached to the router with `with_state` and extracted by
/// `handle_request`. Both fields are `Arc`s, so cloning copies handles only.
#[derive(Clone)]
struct AppState {
  /// Server configuration.
  config: Arc<Config>,
  /// Database used to look up codes, create codes and check API keys.
  database: Arc<Database>,
}

impl Shortener {
  /// Builds a `Shortener` from `config`.
  ///
  /// Opens the database at `config.sqlite_db` and calls `init()` on it
  /// before wrapping both the configuration and the database in `Arc`s.
  ///
  /// # Errors
  ///
  /// Returns the `DatabaseError` from opening or initializing the database.
  pub fn new(config: Config) -> Result<Self, DatabaseError> {
    let database = Database::new(&config.sqlite_db, true)?;
    database.init()?;

    let state = AppState {
      config: Arc::new(config),
      database: Arc::new(database),
    };
    Ok(Self { state })
  }

  /// Listens on `config.listen_port` of the unspecified IPv6 address (`::`)
  /// and serves HTTP requests, routing every request to `handle_request`.
  ///
  /// # Errors
  ///
  /// Returns the I/O error from binding the listener or from serving.
  pub async fn listen_and_serve(self) -> io::Result<()> {
    let port = self.state.config.listen_port;
    let address = SocketAddr::from((Ipv6Addr::UNSPECIFIED, port));

    // Connect info is required so handlers can read the peer address.
    let service = Router::new()
      .fallback(handle_request)
      .with_state(self.state)
      .into_make_service_with_connect_info::<SocketAddr>();

    let listener = TcpListener::bind(address).await?;
    axum::serve(listener, service).await
  }
}

/// Single entry point for every request; dispatches on the HTTP method.
///
/// - `POST` goes to `create_code_handler`.
/// - `GET` on `/` (or an empty path) goes to `homepage_handler`; any other
///   `GET` path goes to `redirect_handler`.
/// - Every other method is logged and answered with `405`.
async fn handle_request(
  State(state): State<AppState>,
  ConnectInfo(remote_addr): ConnectInfo<SocketAddr>,
  method: Method,
  uri: Uri,
  headers: HeaderMap,
  body: Body,
) -> Response {
  if method == Method::POST {
    return create_code_handler(state, remote_addr, method, uri, headers, body)
      .await;
  }

  if method != Method::GET {
    let client_host = get_client_host(&state.config, remote_addr, &headers);
    log::info!("{} {} {} Method not allowed", client_host, method, uri);
    return http_error(StatusCode::METHOD_NOT_ALLOWED, "Method not allowed");
  }

  let is_root = matches!(uri.path(), "" | "/");
  if is_root {
    homepage_handler(state, remote_addr, method, uri, headers)
  } else {
    redirect_handler(state, remote_addr, method, uri, headers)
  }
}

/// Serves the root path: logs the request, then redirects to
/// `config.main_page` when one is configured and otherwise replies with a
/// fixed plain-text greeting.
fn homepage_handler(
  state: AppState,
  remote_addr: SocketAddr,
  method: Method,
  uri: Uri,
  headers: HeaderMap,
) -> Response {
  log::info!(
    "{} {} {}",
    get_client_host(&state.config, remote_addr, &headers),
    method,
    uri
  );

  if let Some(main_page) = &state.config.main_page {
    return redirect_response(main_page);
  }
  plain_response("hello, world\n")
}

fn redirect_handler(
  state: AppState,
  remote_addr: SocketAddr,
  method: Method,
  uri: Uri,
  headers: HeaderMap,
) -> Response {
  let code = code_from_path(uri.path());
  let client_host = get_client_host(&state.config, remote_addr, &headers);

  match state.database.get_url(&code) {
    Ok(url) => {
      log::info!("{} {} {} => {}", client_host, method, uri, url);
      redirect_response(&url)
    }
    Err(DatabaseError::NotFound) => {
      log::info!("{} {} {} [Not found]", client_host, method, uri);
      http_error(StatusCode::NOT_FOUND, "Not found")
    }
    Err(error) => {
      log::info!("{} {} {} [{}]", client_host, method, uri, error);
      http_error(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error")
    }
  }
}

async fn create_code_handler(
  state: AppState,
  remote_addr: SocketAddr,
  method: Method,
  uri: Uri,
  headers: HeaderMap,
  body: Body,
) -> Response {
  let custom_code = if uri.path().is_empty() || uri.path() == "/" {
    String::new()
  } else {
    code_from_path(uri.path())
  };

  let mut username = String::new();
  if state.config.auth {
    let auth_header = headers
      .get(AUTHORIZATION)
      .and_then(|value| value.to_str().ok())
      .unwrap_or_default();
    if !auth_header.starts_with("Bearer ") {
      log::info!(
        "{} {} {} [Missing credentials]",
        get_client_host(&state.config, remote_addr, &headers),
        method,
        uri
      );
      return http_error(StatusCode::UNAUTHORIZED, "Unauthorized");
    }

    match state
      .database
      .check_api_key(&auth_header["Bearer ".len()..])
    {
      Ok(owner) => username = owner,
      Err(_) => {
        log::info!(
          "{} {} {} [Invalid credentials]",
          get_client_host(&state.config, remote_addr, &headers),
          method,
          uri
        );
        return http_error(StatusCode::UNAUTHORIZED, "Unauthorized");
      }
    }
  }

  let client_host = get_client_host(&state.config, remote_addr, &headers);
  let created_by = if state.config.auth {
    username.as_str()
  } else {
    client_host.as_str()
  };

  let body = match to_bytes(body, MAX_BODY_SIZE).await {
    Ok(body) => body,
    Err(_) => {
      log::info!(
        "{} {} {} [Request body too large]",
        client_host,
        method,
        uri
      );
      return http_error(
        StatusCode::PAYLOAD_TOO_LARGE,
        "Request body too large",
      );
    }
  };
  let body = String::from_utf8_lossy(&body);
  let target_url = body.trim().to_owned();

  if !is_valid_http_url(&target_url) {
    log::info!("{} {} {} [Invalid URL]", client_host, method, uri);
    return http_error(StatusCode::BAD_REQUEST, "Invalid URL");
  }

  if !custom_code.is_empty() && !is_valid_code(&custom_code) {
    log::info!("{} {} {} [Invalid code]", client_host, method, uri);
    return http_error(StatusCode::BAD_REQUEST, "Invalid code");
  }

  let code = if custom_code.is_empty() {
    match create_generated_code(
      &state,
      &target_url,
      created_by,
      &client_host,
      &method,
      &uri,
    ) {
      Some(code) => code,
      None => {
        return http_error(
          StatusCode::INTERNAL_SERVER_ERROR,
          "Internal server error",
        );
      }
    }
  } else {
    match state
      .database
      .create_code(&target_url, &custom_code, created_by)
    {
      Ok(()) => custom_code,
      Err(DatabaseError::CodeAlreadyInUse) => {
        log::info!("{} {} {} [Code already in use]", client_host, method, uri);
        return http_error(StatusCode::CONFLICT, "Code already in use");
      }
      Err(error) => {
        log::info!("{} {} {} [{}]", client_host, method, uri, error);
        return http_error(
          StatusCode::INTERNAL_SERVER_ERROR,
          "Internal server error",
        );
      }
    }
  };

  let new_url = format!("{}{}", state.config.url_prefix, code);
  log::info!(
    "{} {} {} ({}) => {}",
    client_host,
    method,
    uri,
    target_url,
    new_url
  );
  plain_response(format!("{}\n", new_url))
}

fn create_generated_code(
  state: &AppState,
  target_url: &str,
  created_by: &str,
  client_host: &str,
  method: &Method,
  uri: &Uri,
) -> Option<String> {
  let mut code = String::new();

  for attempt in 0..3 {
    code = generate_code(state.config.code_length);
    match state.database.create_code(target_url, &code, created_by) {
      Ok(()) => break,
      Err(DatabaseError::CodeAlreadyInUse) => {
        log::info!(
          "{} {} {} [Attempt {}: {}: Code already in use]",
          client_host,
          method,
          uri,
          attempt,
          code
        );
        code.clear();
      }
      Err(error) => {
        log::info!(
          "{} {} {} [Attempt {}: {}: {}]",
          client_host,
          method,
          uri,
          attempt,
          code,
          error
        );
        code.clear();
      }
    }
  }

  if code.is_empty() {
    log::info!(
      "{} {} {} [Could not generate code]",
      client_host,
      method,
      uri
    );
    return None;
  }

  Some(code)
}

/// Returns `true` when `input` is an absolute `http` or `https` URL that
/// contains no ASCII control characters.
///
/// `Url::parse` follows the WHATWG URL algorithm, which strips
/// leading/trailing C0 control characters and embedded tabs and newlines
/// before parsing. The parsed form is discarded here and the caller keeps
/// the raw string, which would then still carry those characters (and is
/// later used as a `Location` header value). Such input is therefore
/// rejected up front instead of being accepted on the strength of the
/// sanitized parse.
fn is_valid_http_url(input: &str) -> bool {
  if input.bytes().any(|byte| byte.is_ascii_control()) {
    return false;
  }

  Url::parse(input)
    .map(|url| matches!(url.scheme(), "http" | "https"))
    .unwrap_or(false)
}

/// Returns the client identifier used in log lines and as the creator of
/// unauthenticated codes.
///
/// Without `config.trust_proxy` this is the IP of the TCP peer. With it, the
/// first comma-separated entry of the `X-Forwarded-For` header is used,
/// falling back to the peer IP when the header is absent, not valid visible
/// ASCII, or its first entry is blank.
///
/// NOTE(review): the first entry is whatever the outermost sender supplied;
/// it is only trustworthy if the proxy overwrites the header — confirm.
fn get_client_host(
  config: &Config,
  remote_addr: SocketAddr,
  headers: &HeaderMap,
) -> String {
  let peer_ip = || remote_addr.ip().to_string();

  if !config.trust_proxy {
    return peer_ip();
  }

  headers
    .get("x-forwarded-for")
    .and_then(|value| value.to_str().ok())
    .and_then(|list| list.split(',').next())
    .map(str::trim)
    .filter(|first| !first.is_empty())
    .map_or_else(peer_ip, str::to_owned)
}

/// Extracts the code from a request path: drops one leading `/` if present
/// and percent-decodes the rest, replacing invalid UTF-8 sequences.
fn code_from_path(path: &str) -> String {
  let encoded = match path.strip_prefix('/') {
    Some(rest) => rest,
    None => path,
  };
  percent_decode_str(encoded)
    .decode_utf8_lossy()
    .into_owned()
}

fn plain_response(body: impl Into<Body>) -> Response {
  Response::builder()
    .status(StatusCode::OK)
    .body(body.into())
    .expect("response should build")
}

/// Builds a plain-text error response with the given `status` whose body is
/// `message` followed by a newline. Content sniffing is disabled via
/// `X-Content-Type-Options: nosniff`.
///
/// # Panics
///
/// Only if the builder rejects the fixed headers, which would be a
/// programming error.
fn http_error(status: StatusCode, message: &str) -> Response {
  let mut text = String::with_capacity(message.len() + 1);
  text.push_str(message);
  text.push('\n');

  Response::builder()
    .status(status)
    .header(CONTENT_TYPE, "text/plain; charset=utf-8")
    .header("X-Content-Type-Options", "nosniff")
    .body(Body::from(text))
    .expect("response should build")
}

fn redirect_response(location: &str) -> Response {
  Response::builder()
    .status(StatusCode::FOUND)
    .header(LOCATION, location)
    .body(Body::from(format!("{}\n", location)))
    .expect("response should build")
}