rocket-api-base 0.2.0

Base utilities for secure API construction with Rocket.rs
Documentation
#![feature(proc_macro_hygiene, decl_macro)]

use rocket::http::uri::Uri;
use rocket::http::{Header, RawStr};
use rocket::local::{Client, LocalResponse};
use rocket::{get, routes, Rocket};
use rocket_api_base::{AuthError, BaseContent, BaseResponse, NoContent, RocketBasicAuthFairing};
use rocket_contrib::json::Json;
use serde_json::Value;
use std::fs::File;
use std::io::Read;

const MISSING_HEADER_FILE: &str = "tests/jwt_auth_handler_missing_header.json";
const INVALID_HEADER_FILE: &str = "tests/jwt_auth_handler_invalid_header.json";
const INVALID_FORMAT_FILE: &str = "tests/basic_auth_handler_invalid_format.json";
const INVALID_CREDENTIALS_FILE: &str = "tests/basic_auth_handler_invalid_credentials.json";
const OK_FILE: &str = "tests/jwt_auth_handler_ok.json";

fn file_to_value(file_name: &'static str) -> Value {
    let mut file = File::open(file_name).unwrap();
    let mut expected_string = String::new();
    file.read_to_string(&mut expected_string).unwrap();
    serde_json::from_str(&expected_string).unwrap()
}

fn response_to_value(mut res: LocalResponse) -> Value {
    serde_json::from_str(&res.body_string().unwrap()).unwrap()
}

#[get("/test")]
fn test_endpoint() -> Json<BaseResponse<NoContent>> {
    Json(BaseResponse::<NoContent> {
        result: "ok".to_string(),
        error_message: "".to_string(),
        error_code: 0,
        content: BaseContent::None,
    })
}

#[get("/test2")]
fn test2_endpoint() -> Json<BaseResponse<NoContent>> {
    Json(BaseResponse::<NoContent> {
        result: "ok".to_string(),
        error_message: "".to_string(),
        error_code: 0,
        content: BaseContent::None,
    })
}

#[get("/other")]
fn other_endpoint() -> Json<BaseResponse<NoContent>> {
    Json(BaseResponse::<NoContent> {
        result: "ok".to_string(),
        error_message: "".to_string(),
        error_code: 0,
        content: BaseContent::None,
    })
}

#[get("/auth_error?<message>")]
fn auth_error_endpoint(message: &RawStr) -> Json<BaseResponse<NoContent>> {
    Json(BaseResponse::<NoContent> {
        result: "error".to_string(),
        error_message: Uri::percent_decode_lossy(message.as_ref()).to_string(),
        error_code: 403,
        content: BaseContent::None,
    })
}

fn auth_error_handler(error: AuthError) -> String {
    "/auth_error?message=".to_string() + &error.get_message_encoded()
}

const USER: &str = "user";
const PASSWORD: &str = "password";

fn basic_auth_callback(user: String, password: String) -> Result<(), String> {
    if user.eq(USER) && password.eq(PASSWORD) {
        Ok(())
    } else {
        Err("/auth_error?message=".to_string() + &Uri::percent_encode("Invalid credentials"))
    }
}

fn start_test_server() -> Rocket {
    rocket::ignite()
        .attach(RocketBasicAuthFairing::new(
            basic_auth_callback,
            auth_error_handler,
        ))
        .mount(
            "/",
            routes![
                test_endpoint,
                test2_endpoint,
                other_endpoint,
                auth_error_endpoint
            ],
        )
}

fn start_test_server_with_excludes(excludes: Vec<&str>) -> Rocket {
    rocket::ignite()
        .attach(RocketBasicAuthFairing::new_with_excludes(
            basic_auth_callback,
            auth_error_handler,
            excludes,
        ))
        .mount(
            "/",
            routes![
                test_endpoint,
                test2_endpoint,
                other_endpoint,
                auth_error_endpoint
            ],
        )
}

fn start_test_server_with_includes(includes: Vec<&str>) -> Rocket {
    rocket::ignite()
        .attach(RocketBasicAuthFairing::new_with_includes(
            basic_auth_callback,
            auth_error_handler,
            includes,
        ))
        .mount(
            "/",
            routes![
                test_endpoint,
                test2_endpoint,
                other_endpoint,
                auth_error_endpoint
            ],
        )
}

#[test]
fn test_no_header() {
    let rocket = start_test_server();
    let client = Client::new(rocket).expect("valid rocket instance");
    let req = client.get("/test");
    let response = req.dispatch();
    assert_eq!(
        file_to_value(MISSING_HEADER_FILE),
        response_to_value(response)
    );
}

#[test]
fn test_invalid_header() {
    let rocket = start_test_server();
    let client = Client::new(rocket).expect("valid rocket instance");
    let mut req = client.get("/test");
    req.add_header(Header::new("Authorization", "Bearer bad_scheme"));
    let response = req.dispatch();
    assert_eq!(
        file_to_value(INVALID_HEADER_FILE),
        response_to_value(response)
    );
}

#[test]
fn test_invalid_format() {
    let rocket = start_test_server();
    let credentials = base64::encode(format!("{}-{}", USER, PASSWORD));
    let client = Client::new(rocket).expect("valid rocket instance");
    let mut req = client.get("/test");
    req.add_header(Header::new(
        "Authorization",
        format!("Basic {}", credentials),
    ));
    let response = req.dispatch();
    assert_eq!(
        file_to_value(INVALID_FORMAT_FILE),
        response_to_value(response)
    );
}

#[test]
fn test_invalid_credentials() {
    let rocket = start_test_server();
    let credentials = base64::encode("user:bad_password");
    let client = Client::new(rocket).expect("valid rocket instance");
    let mut req = client.get("/test");
    req.add_header(Header::new(
        "Authorization",
        format!("Basic {}", credentials),
    ));
    let response = req.dispatch();
    assert_eq!(
        file_to_value(INVALID_CREDENTIALS_FILE),
        response_to_value(response)
    );
}

#[test]
fn test_ok() {
    let rocket = start_test_server();
    let credentials = base64::encode(format!("{}:{}", USER, PASSWORD));
    let client = Client::new(rocket).expect("valid rocket instance");
    let mut req = client.get("/test");
    req.add_header(Header::new(
        "Authorization",
        format!("Basic {}", credentials),
    ));
    let response = req.dispatch();
    assert_eq!(file_to_value(OK_FILE), response_to_value(response));
}

#[test]
fn test_excluded() {
    let rocket = start_test_server_with_excludes(vec!["/test"]);
    let client = Client::new(rocket).expect("valid rocket instance");
    let req = client.get("/test");
    let response = req.dispatch();
    assert_eq!(file_to_value(OK_FILE), response_to_value(response));
    let req = client.get("/test2");
    let response = req.dispatch();
    assert_eq!(
        file_to_value(MISSING_HEADER_FILE),
        response_to_value(response)
    );
}

#[test]
fn test_included() {
    let rocket = start_test_server_with_includes(vec!["/test"]);
    let client = Client::new(rocket).expect("valid rocket instance");
    let req = client.get("/test2");
    let response = req.dispatch();
    assert_eq!(file_to_value(OK_FILE), response_to_value(response));
    let req = client.get("/test");
    let response = req.dispatch();
    assert_eq!(
        file_to_value(MISSING_HEADER_FILE),
        response_to_value(response)
    );
}

#[test]
fn test_excluded_glob() {
    let rocket = start_test_server_with_excludes(vec!["/test*"]);
    let client = Client::new(rocket).expect("valid rocket instance");
    let req = client.get("/test");
    let response = req.dispatch();
    assert_eq!(file_to_value(OK_FILE), response_to_value(response));
    let req = client.get("/test2");
    let response = req.dispatch();
    assert_eq!(file_to_value(OK_FILE), response_to_value(response));
    let req = client.get("/other");
    let response = req.dispatch();
    assert_eq!(
        file_to_value(MISSING_HEADER_FILE),
        response_to_value(response)
    );
}

#[test]
fn test_included_glob() {
    let rocket = start_test_server_with_includes(vec!["/test*"]);
    let client = Client::new(rocket).expect("valid rocket instance");
    let req = client.get("/test");
    let response = req.dispatch();
    assert_eq!(
        file_to_value(MISSING_HEADER_FILE),
        response_to_value(response)
    );
    let req = client.get("/test2");
    let response = req.dispatch();
    assert_eq!(
        file_to_value(MISSING_HEADER_FILE),
        response_to_value(response)
    );
    let req = client.get("/other");
    let response = req.dispatch();
    assert_eq!(file_to_value(OK_FILE), response_to_value(response));
}