use super::{AllowedFunctions, AllowedStatements, Command, Error, Guard};
use once_cell::sync::Lazy;
use serde::Deserialize;
use tracing_test::traced_test;
#[derive(Debug, Deserialize, PartialEq)]
struct Function {
name: String,
arguments: Vec<serde_json::Value>,
}
impl Function {
fn to_query(&self, command: &str) -> String {
let argument_list = self
.arguments
.iter()
.map(|argument| argument.to_string())
.collect::<Vec<_>>()
.join(",");
format!("{} {}({})", command, self.name, argument_list)
}
}
static INSERT_CTE: &str = include_str!("../data/insert_cte.sql");
static INSERT_SELECT_VALUES: &str = include_str!("../data/insert_select_values.sql");
const CASE_CARDINALITY_SUBQUERY: &str = include_str!("../data/case_cardinality_subquery.sql");
static COMPLEX_QUERIES: [&str; 3] = [INSERT_CTE, INSERT_SELECT_VALUES, CASE_CARDINALITY_SUBQUERY];
static INVALID_QUERIES: Lazy<Vec<String>> = Lazy::new(|| {
serde_json::from_str(include_str!("../data/invalid_queries.json"))
.expect("error reading query data from file")
});
static SIMPLE_QUERIES: Lazy<Vec<String>> = Lazy::new(|| {
serde_json::from_str(include_str!("../data/simple_queries.json"))
.expect("error reading query data from file")
});
static PRIVILEGED_QUERIES: Lazy<Vec<String>> = Lazy::new(|| {
serde_json::from_str(include_str!("../data/privileged_queries.json"))
.expect("error reading query data from file")
});
static FUNCTIONS: Lazy<Vec<Function>> = Lazy::new(|| {
serde_json::from_str(include_str!("../data/functions.json"))
.expect("error reading query data from file")
});
#[test]
#[traced_test]
fn does_nothing_without_restrictions() {
let guard = Guard::default();
for query in &*INVALID_QUERIES {
guard
.guard(query)
.expect("Rejected query even without restrictions")
}
}
#[test]
#[traced_test]
fn rejects_invalid_queries() {
let guard = Guard::new(AllowedStatements::List(vec![]), AllowedFunctions::All);
for query in &*INVALID_QUERIES {
match guard.guard(query) {
Err(Error::Parse(..)) => (),
Ok(..) => panic!("Failed to reject invalid query"),
Err(..) => panic!("Failed to reject invalid query at the parsing step"),
}
}
}
#[test]
#[traced_test]
fn parses_simple_queries() {
let guard = Guard::new(
AllowedStatements::List(vec![
Command::Select,
Command::Insert,
Command::Update,
Command::Delete,
]),
AllowedFunctions::All,
);
for query in &*SIMPLE_QUERIES {
guard.guard(query).expect("Failed to parse valid query");
}
}
#[test]
fn rejects_privileged_queries() {
let guard = Guard::new(
AllowedStatements::List(vec![
Command::Select,
Command::Insert,
Command::Update,
Command::Delete,
]),
AllowedFunctions::All,
);
for query in &*PRIVILEGED_QUERIES {
guard
.guard(query)
.expect_err("Failed to reject privileged query");
}
}
#[test]
fn parses_complex_queries() {
let guard = Guard::new(
AllowedStatements::List(vec![
Command::Select,
Command::Insert,
Command::Update,
Command::Delete,
]),
AllowedFunctions::All,
);
for query in COMPLEX_QUERIES {
guard.guard(query).expect("Failed to parse valid query");
}
}
#[test]
fn selectively_rejects_disallowed_statements_in_ctes() {
let guard = Guard::new(
AllowedStatements::List(vec![Command::Select]),
AllowedFunctions::All,
);
guard
.guard(INSERT_CTE)
.expect_err("Failed to reject disallowed statement in CTE");
}
#[test]
#[traced_test]
fn rejects_top_level_disallowed_functions() {
let guard = Guard::new(AllowedStatements::All, AllowedFunctions::List(vec![]));
for function in &*FUNCTIONS {
guard
.guard(&function.to_query("select"))
.expect_err("Failed to reject function invocation");
}
}
#[test]
#[traced_test]
fn rejects_nested_disallowed_functions() {
let guard = Guard::new(AllowedStatements::All, AllowedFunctions::List(vec![]));
guard
.guard(CASE_CARDINALITY_SUBQUERY)
.expect_err("Failed to reject function invocation");
}
#[test]
#[traced_test]
fn selectively_rejects_disallowed_functions() {
let allowed_function = &FUNCTIONS[0];
let guard = Guard::new(
AllowedStatements::All,
AllowedFunctions::List(vec![allowed_function.name.clone()]),
);
for function in &*FUNCTIONS {
let result = guard.guard(&function.to_query("select"));
if function == allowed_function {
result.expect("Failed to allow permitted function");
} else {
result.expect_err("Failed to reject function invocation");
}
}
}