use std::{error::Error, fmt::Display, str::FromStr};
/// Expands to an integer expression equal to the number of comma-separated
/// identifiers it is given. A trailing comma is accepted and the empty
/// invocation yields `0`.
macro_rules! count_idents {
    () => {
        0
    };
    ($head:ident $(, $tail:ident)* $(,)?) => {
        // One for the head, then recurse on whatever follows it.
        1 + count_idents!($($tail),*)
    };
}
/// Expands to a string literal containing one `{}` placeholder per
/// identifier, joined by `.` (e.g. `a, b` becomes `"{}.{}"`). The empty
/// invocation yields `""`. A trailing comma is accepted and does not add a
/// separator.
macro_rules! formatting_string_literal {
    // Two or more identifiers: emit `{}.` for the first and recurse. The
    // `+` repetition guarantees the recursive call is never empty, so a lone
    // identifier with a trailing comma falls through to the next arm instead
    // of producing a dangling `.`.
    ($first:ident $(, $rest:ident)+ $(,)?) => {
        concat!("{}.", formatting_string_literal!($($rest),+))
    };
    ($only:ident $(,)?) => {
        "{}"
    };
    () => {
        ""
    };
}
/// Generates the access-control selector machinery from a list of selector
/// struct definitions.
///
/// For every `struct Name { field: Type, .. }` in the input this produces:
///
/// - the selector struct `Name`, whose fields are wrapped in [`ValueOrGlob`],
///   with constructors, builders, matching helpers and `Display`/`FromStr`
///   impls using the `Name(field1.field2)` syntax;
/// - an `AccessControlSelector::Name` variant plus conversions;
/// - a `<name>_selectors` bucket map in `AccessControlSelectorSet`, with
///   `with_<name>_selector` and `check_<name>` methods;
/// - a `<name>_default_permissions` fallback in `AuthorizationResolver`.
///
/// `<name>` is the snake_case form of `Name`. Each `Name` must also be a
/// variant of `rusqlite::hooks::AuthAction` that has fields with the declared
/// names, because `AuthorizationResolver::authorization` destructures the
/// action by those names.
macro_rules! define_access_control_selector_types {
    (
        $(
            $(#[$struct_meta: meta])*
            $struct_vis: vis struct $struct_ident: ident {
                $(
                    $(#[$field_meta: meta])*
                    $field_vis: vis $field_ident: ident: $field_ty: ty
                ),* $(,)?
            }
        )*
    ) => {
        paste::paste! {
            /// Decides how SQLite authorizer callbacks are answered.
            ///
            /// Each recognised action is first checked against `selector_set`;
            /// when no rule matches, that action's `*_default_permissions`
            /// field is returned. Actions without a selector type are denied.
            pub struct AuthorizationResolver {
                /// Explicit allow/deny rules, consulted before the per-action defaults.
                pub selector_set: AccessControlSelectorSet,
                $(
                    #[doc = concat!(
                        "Authorization returned for [`",
                        stringify!($struct_ident),
                        "`] actions when no rule in `selector_set` matches."
                    )]
                    pub [< $struct_ident:snake _default_permissions >]: rusqlite::hooks::Authorization,
                )*
            }
            const _: () = {
                impl AuthorizationResolver {
                    /// Creates a resolver with no rules whose default for every
                    /// action that has a selector type is `Allow`.
                    pub fn new_allow_everything() -> Self {
                        Self {
                            selector_set: Default::default(),
                            $(
                                [< $struct_ident:snake _default_permissions >]: rusqlite::hooks::Authorization::Allow,
                            )*
                        }
                    }
                    /// Creates a resolver with no rules whose default for every
                    /// action that has a selector type is `Deny`.
                    pub fn new_deny_everything() -> Self {
                        Self {
                            selector_set: Default::default(),
                            $(
                                [< $struct_ident:snake _default_permissions >]: rusqlite::hooks::Authorization::Deny,
                            )*
                        }
                    }
                    /// Adds an allow (`allow == true`) or deny rule to `selector_set`.
                    pub fn with_selector(mut self, selector: impl Into<AccessControlSelector>, allow: bool) -> Self {
                        self.selector_set = self.selector_set.with_selector(selector, allow);
                        self
                    }
                    $(
                        #[doc = concat!(
                            "Sets the default authorization returned for [`",
                            stringify!($struct_ident),
                            "`] actions when no rule matches."
                        )]
                        pub const fn [< with_ $struct_ident:snake _default_permissions >](
                            mut self,
                            default_permissions: rusqlite::hooks::Authorization
                        ) -> Self {
                            self.[< $struct_ident:snake _default_permissions >] = default_permissions;
                            self
                        }
                    )*
                    /// Answers a single SQLite authorizer callback.
                    ///
                    /// The most specific matching rule in `selector_set` for
                    /// `ctx.action` decides (deny beats allow at equal
                    /// specificity); otherwise the action's default applies.
                    /// Actions without a selector type are denied.
                    pub fn authorization(
                        &self,
                        ctx: rusqlite::hooks::AuthContext<'_>
                    ) -> rusqlite::hooks::Authorization {
                        match ctx.action {
                            $(
                                rusqlite::hooks::AuthAction::$struct_ident { $($field_ident,)* .. } => {
                                    self
                                        .selector_set
                                        .[< check_ $struct_ident: snake >]($(&$field_ident),*)
                                        .map(Into::into)
                                        .unwrap_or(self.[< $struct_ident:snake _default_permissions >])
                                }
                            )*
                            // Actions with no selector type (including unknown
                            // ones) are never allowed.
                            _ => rusqlite::hooks::Authorization::Deny,
                        }
                    }
                }
            };
            /// Allow/deny rules grouped per selector type.
            ///
            /// Within each type, rules are bucketed by their selector's
            /// `specificity()` (the number of non-glob fields).
            #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
            pub struct AccessControlSelectorSet {
                $(
                    $(#[$struct_meta])*
                    pub [<$struct_ident: snake _selectors>]: std::collections::BTreeMap<usize, Vec<Permission<$struct_ident>>>,
                )*
            }
            const _: () = {
                impl AccessControlSelectorSet {
                    /// Creates an empty rule set.
                    pub fn new() -> Self {
                        Default::default()
                    }
                    /// Adds `selector` as an allow (`true`) or deny (`false`) rule
                    /// to the bucket map for its variant.
                    pub fn with_selector(
                        self,
                        selector: impl Into<AccessControlSelector>,
                        allow: bool
                    ) -> Self {
                        let selector = selector.into();
                        match selector {
                            $(
                                AccessControlSelector::$struct_ident(selector)
                                    => self.[<with_ $struct_ident: snake _selector>](
                                        selector,
                                        allow
                                    )
                            ),*
                        }
                    }
                    $(
                        #[doc = concat!(
                            "Adds a [`", stringify!($struct_ident),
                            "`] rule, inserting it into the ",
                            "bucket matching its specificity."
                        )]
                        pub fn [<with_ $struct_ident: snake _selector>](
                            mut self,
                            selector: $struct_ident,
                            allow: bool
                        ) -> Self {
                            self.[<$struct_ident: snake _selectors>]
                                .entry(selector.specificity())
                                .or_default()
                                .push(Permission {
                                    selector,
                                    allow
                                });
                            self
                        }
                        #[doc = concat!(
                            "Resolves the configured [`",
                            stringify!($struct_ident),
                            "`] rules against the given field values. ",
                            "Iterates from highest to lowest specificity; ",
                            "at each level, deny wins if both allow and deny ",
                            "match. Returns `None` when no rule matches."
                        )]
                        #[allow(clippy::ptr_arg)]
                        pub fn [< check_ $struct_ident: snake >](&self, $($field_ident: &(impl PartialEq<$field_ty> + ?Sized)),*) -> Option<Authorization> {
                            // Buckets are keyed by specificity, so iterating in
                            // reverse visits the most specific rules first.
                            for permissions in self.[<$struct_ident: snake _selectors>].values().rev() {
                                let mut allow_encountered = false;
                                for authorization in permissions.iter().filter_map(|permission| permission.check($($field_ident),*)) {
                                    match authorization {
                                        // Deny wins within a bucket, so the first
                                        // matching deny settles it.
                                        Authorization::Deny => return Some(Authorization::Deny),
                                        Authorization::Allow => allow_encountered = true,
                                    }
                                }
                                if allow_encountered {
                                    return Some(Authorization::Allow);
                                }
                                // Nothing matched at this specificity; fall
                                // through to less specific rules.
                            }
                            None
                        }
                    )*
                }
            };
            /// Any one selector, tagged by the SQLite action it targets.
            #[allow(clippy::enum_variant_names)]
            #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
            pub enum AccessControlSelector {
                $(
                    $(#[$struct_meta])*
                    $struct_ident($struct_ident)
                ),*
            }
            const _: () = {
                use core::fmt::{Display, Formatter, Result as FmtResult};
                use $crate::access_control::AccessControlSelectorParseError;
                use $crate::access_control::AccessControlSelectorParseError::*;
                impl AccessControlSelector {
                    $(
                        #[doc = concat!(
                            "Returns `true` if this is the [`",
                            stringify!($struct_ident),
                            "`] variant."
                        )]
                        pub const fn [<is_ $struct_ident: snake>](&self) -> bool {
                            matches!(self, Self::$struct_ident(..))
                        }
                        #[doc = concat!(
                            "Returns a reference to the inner [`",
                            stringify!($struct_ident),
                            "`] if this is that variant, ",
                            "or `None` otherwise."
                        )]
                        pub const fn [<as_ $struct_ident: snake>](&self) -> Option<&$struct_ident> {
                            if let Self::$struct_ident(value) = self {
                                Some(value)
                            } else {
                                None
                            }
                        }
                        #[doc = concat!(
                            "Consumes `self` and returns the ",
                            "inner [`",
                            stringify!($struct_ident),
                            "`] if this is that variant, ",
                            "or `None` otherwise."
                        )]
                        pub fn [<into_ $struct_ident: snake>](self) -> Option<$struct_ident> {
                            if let Self::$struct_ident(value) = self {
                                Some(value)
                            } else {
                                None
                            }
                        }
                    )*
                }
                impl Display for AccessControlSelector {
                    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
                        match self {
                            $(
                                Self::$struct_ident(selector) => Display::fmt(selector, f),
                            )*
                        }
                    }
                }
                impl From<AccessControlSelector> for String {
                    fn from(v: AccessControlSelector) -> Self {
                        v.to_string()
                    }
                }
                impl FromStr for AccessControlSelector {
                    type Err = AccessControlSelectorParseError;
                    fn from_str(s: &str) -> Result<Self, Self::Err> {
                        // The identifier is everything before the first `(` (or
                        // the whole input). It is trimmed so that surrounding
                        // whitespace is handled the same way the per-selector
                        // `FromStr` impls handle it.
                        let selector = s.split_once('(').map_or(s, |(ident, _)| ident).trim();
                        match selector {
                            $(
                                stringify!($struct_ident) => s.parse().map(Self::$struct_ident),
                            )*
                            other => Err(InvalidAccessControlSelectorIdentifier {
                                found: Some(other.into()),
                                expected: &[$(stringify!($struct_ident)),*],
                                access_control_selector_string: s.into(),
                            })
                        }
                    }
                }
                impl TryFrom<String> for AccessControlSelector {
                    type Error = AccessControlSelectorParseError;
                    fn try_from(v: String) -> Result<Self, Self::Error> {
                        v.parse()
                    }
                }
                $(
                    #[doc = concat!(
                        "Converts a [`",
                        stringify!($struct_ident),
                        "`] into its corresponding [`",
                        "AccessControlSelector`] variant."
                    )]
                    impl From<$struct_ident> for AccessControlSelector {
                        fn from(value: $struct_ident) -> Self {
                            Self::$struct_ident(value)
                        }
                    }
                )*
            };
            $(
                $(#[$struct_meta])*
                #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
                $struct_vis struct $struct_ident {
                    $(
                        $(#[$field_meta])*
                        $field_vis $field_ident: ValueOrGlob<$field_ty>
                    ),*
                }
                const _: () = {
                    use core::fmt::{Display, Formatter, Result as FmtResult};
                    use core::convert::*;
                    use core::str::*;
                    use $crate::access_control::{AccessControlSelectorParseError};
                    use $crate::access_control::AccessControlSelectorParseError::*;
                    impl $struct_ident {
                        #[doc = concat!(
                            "The number of fields declared on [`",
                            stringify!($struct_ident),
                            "`], used for field-count validation during parsing."
                        )]
                        const FIELD_COUNT: usize = count_idents!($($field_ident),*);
                        #[doc = concat!(
                            "Creates a new [`", stringify!($struct_ident),
                            "`] with the provided field values."
                        )]
                        #[doc = ""]
                        #[doc = concat!(
                            "Each argument accepts any type that implements ",
                            "`Into<ValueOrGlob<T>>`, so callers can pass a raw ",
                            "value, a [`ValueOrGlob`], or an `Option<T>`."
                        )]
                        pub fn new(
                            $(
                                $field_ident: impl Into<ValueOrGlob<$field_ty>>
                            ),*
                        ) -> Self {
                            Self {
                                $(
                                    $field_ident: $field_ident.into()
                                ),*
                            }
                        }
                        #[doc = concat!(
                            "Creates a [`", stringify!($struct_ident),
                            "`] where every field is a [`Glob`](ValueOrGlob::Glob), ",
                            "matching all possible values."
                        )]
                        pub fn empty() -> Self {
                            Self {
                                $(
                                    $field_ident: Default::default()
                                ),*
                            }
                        }
                        $(
                            #[doc = concat!(
                                "Sets [`", stringify!($struct_ident),
                                "::", stringify!($field_ident),
                                "`] and returns `self` for method chaining."
                            )]
                            #[doc = ""]
                            $(#[$field_meta])*
                            pub fn [< with_ $field_ident >](mut self, value: impl Into<ValueOrGlob<$field_ty>>) -> Self {
                                self.$field_ident = value.into();
                                self
                            }
                        )*
                        #[doc = concat!(
                            "Returns `true` if every field on this [`",
                            stringify!($struct_ident),
                            "`] is a [`Glob`](ValueOrGlob::Glob). ",
                            "Always `false` for a selector without fields."
                        )]
                        pub const fn is_all_glob(&self) -> bool {
                            Self::FIELD_COUNT != 0 $(&& self.$field_ident.is_glob())*
                        }
                        #[doc = concat!(
                            "Returns `true` if every field on this [`",
                            stringify!($struct_ident),
                            "`] is a [`Value`](ValueOrGlob::Value). ",
                            "Always `false` for a selector without fields."
                        )]
                        pub const fn is_all_value(&self) -> bool {
                            Self::FIELD_COUNT != 0 $(&& self.$field_ident.is_value())*
                        }
                        #[doc = concat!(
                            "Returns `true` if at least one field on this [`",
                            stringify!($struct_ident),
                            "`] is a [`Glob`](ValueOrGlob::Glob)."
                        )]
                        pub const fn is_any_glob(&self) -> bool {
                            false $(|| self.$field_ident.is_glob())*
                        }
                        #[doc = concat!(
                            "Returns `true` if at least one field on this [`",
                            stringify!($struct_ident),
                            "`] is a [`Value`](ValueOrGlob::Value)."
                        )]
                        pub const fn is_any_value(&self) -> bool {
                            false $(|| self.$field_ident.is_value())*
                        }
                        #[doc = concat!(
                            "Returns the number of ",
                            "[`Value`](ValueOrGlob::Value) ",
                            "fields on this [`",
                            stringify!($struct_ident),
                            "`], used to rank competing ",
                            "policy rules during resolution. ",
                            "A higher specificity means the ",
                            "rule targets a narrower set of ",
                            "operations."
                        )]
                        #[allow(unused_mut)]
                        pub const fn specificity(&self) -> usize {
                            let mut specificity = 0;
                            $(
                                if self.$field_ident.is_value() {
                                    specificity += 1
                                }
                            )*
                            specificity
                        }
                        #[doc = concat!(
                            "Returns `true` if this selector ",
                            "covers the given field values. ",
                            "A [`Glob`](ValueOrGlob::Glob) ",
                            "field matches any value; a ",
                            "[`Value`](ValueOrGlob::Value) ",
                            "field matches only when equal."
                        )]
                        #[allow(clippy::ptr_arg)]
                        pub fn matches(&self, $($field_ident: &(impl PartialEq<$field_ty> + ?Sized)),*) -> bool {
                            true $(&& (self.$field_ident.is_glob() || self.$field_ident.as_value().is_some_and(|field_value| PartialEq::<$field_ty>::eq($field_ident, field_value))))*
                        }
                    }
                    #[allow(clippy::ptr_arg)]
                    impl Permission<$struct_ident> {
                        #[doc = concat!(
                            "If the inner [`",
                            stringify!($struct_ident),
                            "`] selector matches the given ",
                            "fields, returns the ",
                            "[`Authorization`] implied by ",
                            "this permission's `allow` flag. ",
                            "Returns `None` on no match."
                        )]
                        pub fn check(&self, $($field_ident: &(impl PartialEq<$field_ty> + ?Sized)),*) -> Option<Authorization> {
                            self.selector.matches($($field_ident),*).then(|| match self.allow {
                                true => Authorization::Allow,
                                false => Authorization::Deny,
                            })
                        }
                    }
                    #[doc = concat!(
                        "Defaults to [`", stringify!($struct_ident),
                        "::empty`], producing a selector where every field ",
                        "is a glob."
                    )]
                    impl Default for $struct_ident {
                        fn default() -> Self {
                            Self::empty()
                        }
                    }
                    #[doc = concat!(
                        "Formats this [`", stringify!($struct_ident),
                        "`] using the CLI selector syntax. When all fields ",
                        "are globs, renders the bare identifier `",
                        stringify!($struct_ident),
                        "`; otherwise renders the identifier with ",
                        "parenthesized dot-separated fields."
                    )]
                    impl Display for $struct_ident {
                        fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
                            if Self::FIELD_COUNT == 0 || self.is_all_glob() {
                                f.write_str(stringify!($struct_ident))
                            } else {
                                core::write!(
                                    f,
                                    concat!(
                                        stringify!($struct_ident),
                                        "(",
                                        formatting_string_literal!($($field_ident),*),
                                        ")"
                                    ),
                                    $(self.$field_ident),*
                                )
                            }
                        }
                    }
                    #[doc = concat!(
                        "Converts a [`", stringify!($struct_ident),
                        "`] into its string representation via [`Display`]."
                    )]
                    impl From<$struct_ident> for String {
                        fn from(value: $struct_ident) -> Self {
                            value.to_string()
                        }
                    }
                    #[doc = concat!(
                        "Parses a selector string into a [`",
                        stringify!($struct_ident),
                        "`]. Accepts the syntax `",
                        stringify!($struct_ident),
                        "` or `", stringify!($struct_ident),
                        "(field1.field2)`."
                    )]
                    #[allow(unused_mut, unused_variables)]
                    impl FromStr for $struct_ident {
                        type Err = AccessControlSelectorParseError;
                        fn from_str(s: &str) -> Result<Self, Self::Err> {
                            let parsed_access_control_selector = ParsedAccessControlSelector::<'_, {Self::FIELD_COUNT}>::new(s)?;
                            if parsed_access_control_selector.access_control_selector_ident != stringify!($struct_ident) {
                                return Err(IncorrectAccessControlSelectorIdentifier {
                                    expected: stringify!($struct_ident).into(),
                                    found: parsed_access_control_selector.access_control_selector_ident.into(),
                                    access_control_selector_string: s.into()
                                })
                            }
                            let mut access_control_selector = Self::default();
                            // Fields are assigned positionally; missing trailing
                            // fields keep their default (glob).
                            let mut fields = parsed_access_control_selector.fields.into_iter().flatten().map(str::trim);
                            $(
                                let field = fields.next();
                                if let Some(field) = field {
                                    let field_value = ValueOrGlob::<$field_ty>::from_str(field).map_err(|err| FailedToParseFieldValue {
                                        field_ident: stringify!($field_ident).into(),
                                        field_value: field.into(),
                                        error: Box::new(err) as _
                                    })?;
                                    access_control_selector = access_control_selector.[< with_ $field_ident >](field_value)
                                }
                            )*
                            Ok(access_control_selector)
                        }
                    }
                    #[doc = concat!(
                        "Parses an owned `String` into a [`",
                        stringify!($struct_ident),
                        "`] by delegating to [`FromStr`]."
                    )]
                    impl TryFrom<String> for $struct_ident {
                        type Error = AccessControlSelectorParseError;
                        fn try_from(v: String) -> Result<Self, Self::Error> {
                            v.parse()
                        }
                    }
                };
            )*
        }
    };
}
// One selector type per supported `rusqlite::hooks::AuthAction` variant. Field
// names must match that variant's field names, because
// `AuthorizationResolver::authorization` destructures the action by them.
// Declaration order also fixes the variant order of `AccessControlSelector`.
define_access_control_selector_types! {
    /// Rule target for [`rusqlite::hooks::AuthAction::CreateIndex`].
    pub struct CreateIndex {
        pub table_name: String,
        pub index_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::CreateTable`].
    pub struct CreateTable {
        pub table_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::CreateTempIndex`].
    pub struct CreateTempIndex {
        pub table_name: String,
        pub index_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::CreateTempTable`].
    pub struct CreateTempTable {
        pub table_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::CreateTempTrigger`].
    pub struct CreateTempTrigger {
        pub table_name: String,
        pub trigger_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::CreateTempView`].
    pub struct CreateTempView {
        pub view_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::CreateTrigger`].
    pub struct CreateTrigger {
        pub table_name: String,
        pub trigger_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::CreateView`].
    pub struct CreateView {
        pub view_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::Delete`].
    pub struct Delete {
        pub table_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::DropIndex`].
    pub struct DropIndex {
        pub table_name: String,
        pub index_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::DropTable`].
    pub struct DropTable {
        pub table_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::DropTempIndex`].
    pub struct DropTempIndex {
        pub table_name: String,
        pub index_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::DropTempTable`].
    pub struct DropTempTable {
        pub table_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::DropTempTrigger`].
    pub struct DropTempTrigger {
        pub table_name: String,
        pub trigger_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::DropTempView`].
    pub struct DropTempView {
        pub view_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::DropTrigger`].
    pub struct DropTrigger {
        pub table_name: String,
        pub trigger_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::DropView`].
    pub struct DropView {
        pub view_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::Insert`].
    pub struct Insert {
        pub table_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::Pragma`] (matched by name only).
    pub struct Pragma {
        pub pragma_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::Read`].
    pub struct Read {
        pub table_name: String,
        pub column_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::Select`]; has no fields.
    pub struct Select {}
    /// Rule target for [`rusqlite::hooks::AuthAction::Transaction`].
    pub struct Transaction {
        pub operation: TransactionOperation,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::Update`].
    pub struct Update {
        pub table_name: String,
        pub column_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::Attach`].
    pub struct Attach {
        pub filename: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::Detach`].
    pub struct Detach {
        pub database_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::AlterTable`].
    pub struct AlterTable {
        pub database_name: String,
        pub table_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::Reindex`].
    pub struct Reindex {
        pub index_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::Analyze`].
    pub struct Analyze {
        pub table_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::CreateVtable`].
    pub struct CreateVtable {
        pub table_name: String,
        pub module_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::DropVtable`].
    pub struct DropVtable {
        pub table_name: String,
        pub module_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::Function`].
    pub struct Function {
        pub function_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::Savepoint`].
    pub struct Savepoint {
        pub operation: TransactionOperation,
        pub savepoint_name: String,
    }
    /// Rule target for [`rusqlite::hooks::AuthAction::Recursive`]; has no fields.
    pub struct Recursive {}
}
/// The `*` wildcard used in selector fields.
///
/// A field holding a `Glob` matches every value of that field.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Glob;

impl Display for Glob {
    /// Renders as a single asterisk, the same text `ValueOrGlob` parses as a glob.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("*")
    }
}
/// A selector field value: either a concrete `T` or the [`Glob`] wildcard.
///
/// The wildcard parses from and displays as `*`. Concrete values use `T`'s
/// own [`FromStr`] and [`Display`] implementations.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ValueOrGlob<T> {
    /// A concrete value that must compare equal to match.
    Value(T),
    /// The `*` wildcard, matching anything.
    Glob(Glob),
}

/// The default is the wildcard.
impl<T> Default for ValueOrGlob<T> {
    fn default() -> Self {
        Self::Glob(Glob)
    }
}

impl<T> ValueOrGlob<T> {
    /// Wraps anything convertible into `T` as [`ValueOrGlob::Value`].
    pub fn new_value(value: impl Into<T>) -> Self {
        let converted: T = value.into();
        Self::Value(converted)
    }

    /// Returns the wildcard variant.
    pub const fn new_glob() -> Self {
        Self::Glob(Glob)
    }

    /// `true` for [`ValueOrGlob::Value`].
    pub const fn is_value(&self) -> bool {
        // Only two variants exist, so "not a glob" means "a value".
        !self.is_glob()
    }

    /// `true` for [`ValueOrGlob::Glob`].
    pub const fn is_glob(&self) -> bool {
        matches!(self, Self::Glob(_))
    }

    /// Borrows the concrete value, if any.
    pub const fn as_value(&self) -> Option<&T> {
        if let Self::Value(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// Borrows the wildcard, if this is one.
    pub const fn as_glob(&self) -> Option<&Glob> {
        if let Self::Glob(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// Consumes `self`, returning the concrete value, if any.
    pub fn into_value(self) -> Option<T> {
        if let Self::Value(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// Consumes `self`, returning the wildcard, if this is one.
    pub fn into_glob(self) -> Option<Glob> {
        // `Glob` is `Copy`, so borrowing and copying is equivalent to moving.
        self.as_glob().copied()
    }
}

impl<T> Display for ValueOrGlob<T>
where
    T: Display,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Self::Value(inner) = self {
            Display::fmt(inner, f)
        } else {
            Display::fmt(&Glob, f)
        }
    }
}

impl<T> From<ValueOrGlob<T>> for String
where
    T: Display,
{
    fn from(value: ValueOrGlob<T>) -> Self {
        value.to_string()
    }
}

impl<T> FromStr for ValueOrGlob<T>
where
    T: FromStr,
{
    type Err = T::Err;

    /// `"*"` becomes the wildcard; anything else is parsed as `T`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "*" => Ok(Self::new_glob()),
            other => other.parse().map(Self::Value),
        }
    }
}

impl<T> TryFrom<String> for ValueOrGlob<T>
where
    T: FromStr,
{
    type Error = T::Err;

    fn try_from(value: String) -> Result<Self, Self::Error> {
        Self::from_str(&value)
    }
}

/// `Some(v)` becomes a concrete value; `None` becomes the wildcard.
impl<T> From<Option<T>> for ValueOrGlob<T> {
    fn from(value: Option<T>) -> Self {
        value.map_or_else(Self::new_glob, Self::Value)
    }
}
/// Intermediate result of splitting a selector string such as
/// `Read(Students.name)` into its identifier (`Read`) and its raw field
/// segments (`Students`, `name`), before any field is parsed.
///
/// `FIELD_COUNT` is the maximum number of dot-separated segments accepted.
struct ParsedAccessControlSelector<'a, const FIELD_COUNT: usize> {
    /// Text before the first `(` (or the whole input), trimmed of whitespace.
    pub access_control_selector_ident: &'a str,
    /// Dot-separated segments between the parentheses, or `None` when the
    /// input has no `(`. Segments are not trimmed here.
    pub fields: Option<std::str::Split<'a, char>>,
}
impl<'a, const FIELD_COUNT: usize> ParsedAccessControlSelector<'a, FIELD_COUNT> {
    /// Splits `string` into an identifier and optional field segments.
    ///
    /// # Errors
    ///
    /// - [`AccessControlSelectorParseError::NoClosingBraceFound`] if a `(` is
    ///   present without a following `)`.
    /// - [`AccessControlSelectorParseError::InvalidNumberOfFields`] if the
    ///   parenthesised part has more than `FIELD_COUNT` dot-separated segments.
    pub fn new(
        string: &'a str,
    ) -> Result<Self, AccessControlSelectorParseError> {
        use AccessControlSelectorParseError::*;
        // Without a `(` the whole input is the identifier, so extracting an
        // identifier can never fail.
        let Some((ident, remaining_string)) = string.split_once('(') else {
            return Ok(Self {
                access_control_selector_ident: string.trim(),
                fields: None,
            });
        };
        let access_control_selector_ident = ident.trim();
        // NOTE(review): any text after the first `)` is currently ignored,
        // e.g. `Read(a.b)xyz` parses like `Read(a.b)`.
        let Some((inner_fields, _)) = remaining_string.split_once(')') else {
            return Err(NoClosingBraceFound {
                access_control_selector_string: string.to_string(),
            });
        };
        // N separators means N + 1 segments; an empty `()` counts as one.
        let number_of_fields = inner_fields.matches('.').count() + 1;
        if number_of_fields > FIELD_COUNT {
            return Err(InvalidNumberOfFields {
                found: number_of_fields,
                maximum_expected: FIELD_COUNT,
                access_control_selector_string: string.to_string(),
            });
        }
        Ok(Self {
            access_control_selector_ident,
            fields: Some(inner_fields.split('.')),
        })
    }
}
/// A single access-control rule: a selector paired with whether actions it
/// matches are allowed or denied.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Permission<T> {
    /// What the rule matches; the generated code uses the selector structs here.
    pub selector: T,
    /// `true` yields `Authorization::Allow` on a match, `false` yields `Authorization::Deny`.
    pub allow: bool,
}
/// Outcome of evaluating access-control rules; convertible into
/// [`rusqlite::hooks::Authorization`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Authorization {
    /// The action is permitted.
    Allow,
    /// The action is rejected.
    Deny,
}
/// Maps each local decision onto the identically named rusqlite variant.
impl From<Authorization> for rusqlite::hooks::Authorization {
    fn from(authorization: Authorization) -> Self {
        match authorization {
            Authorization::Deny => rusqlite::hooks::Authorization::Deny,
            Authorization::Allow => rusqlite::hooks::Authorization::Allow,
        }
    }
}
/// A transaction operation a selector can match, mirroring
/// [`rusqlite::hooks::TransactionOperation`].
///
/// Displays and parses as SCREAMING_SNAKE_CASE, e.g. `BEGIN`.
#[derive(
    Clone,
    Copy,
    Debug,
    Eq,
    PartialEq,
    PartialOrd,
    Ord,
    Hash,
    strum::Display,
    strum::EnumString,
)]
#[strum(serialize_all = "SCREAMING_SNAKE_CASE")]
pub enum TransactionOperation {
    /// Corresponds to `rusqlite::hooks::TransactionOperation::Unknown`.
    Unknown,
    /// Corresponds to `rusqlite::hooks::TransactionOperation::Begin`.
    Begin,
    /// Corresponds to `rusqlite::hooks::TransactionOperation::Release`.
    Release,
    /// Corresponds to `rusqlite::hooks::TransactionOperation::Rollback`.
    Rollback,
}
/// Maps each local transaction operation onto the identically named rusqlite variant.
impl From<TransactionOperation> for rusqlite::hooks::TransactionOperation {
    fn from(operation: TransactionOperation) -> Self {
        match operation {
            TransactionOperation::Begin => Self::Begin,
            TransactionOperation::Release => Self::Release,
            TransactionOperation::Rollback => Self::Rollback,
            TransactionOperation::Unknown => Self::Unknown,
        }
    }
}
/// Compares by converting the local value into rusqlite's representation.
impl PartialEq<rusqlite::hooks::TransactionOperation> for TransactionOperation {
    fn eq(&self, other: &rusqlite::hooks::TransactionOperation) -> bool {
        let converted = rusqlite::hooks::TransactionOperation::from(*self);
        converted == *other
    }
}
/// Symmetric counterpart: compares with the operands swapped.
impl PartialEq<TransactionOperation> for rusqlite::hooks::TransactionOperation {
    fn eq(&self, other: &TransactionOperation) -> bool {
        *other == *self
    }
}
/// Errors produced while parsing selector strings such as `Read(Students.name)`.
#[derive(Debug, thiserror::Error)]
pub enum AccessControlSelectorParseError {
    /// No selector identifier could be extracted from the input.
    #[error(
        "no access control selector identifier found in \"{access_control_selector_string}\", \
        expected \"{expected}\""
    )]
    NoAccessControlSelectorIdentifierFound {
        expected: String,
        access_control_selector_string: String,
    },
    /// The identifier does not match the specific selector type being parsed.
    #[error(
        "incorrect access control selector identifier in \"{access_control_selector_string}\": \
        expected \"{expected}\", found \"{found}\""
    )]
    IncorrectAccessControlSelectorIdentifier {
        expected: String,
        found: String,
        access_control_selector_string: String,
    },
    /// The identifier is not any known selector type; `expected` lists the valid ones.
    #[error(
        "invalid access control selector identifier in \"{access_control_selector_string}\": \
        found {found:?}, expected one of {expected:?}"
    )]
    InvalidAccessControlSelectorIdentifier {
        found: Option<String>,
        expected: &'static [&'static str],
        access_control_selector_string: String,
    },
    /// A `(` was present without a following `)`.
    #[error(
        "no closing parenthesis found in \"{access_control_selector_string}\""
    )]
    NoClosingBraceFound {
        access_control_selector_string: String,
    },
    /// More dot-separated fields were given than the selector type declares.
    #[error(
        "invalid number of fields in \"{access_control_selector_string}\": \
        found {found}, maximum expected {maximum_expected}"
    )]
    InvalidNumberOfFields {
        found: usize,
        maximum_expected: usize,
        access_control_selector_string: String,
    },
    /// A field's text could not be parsed into the field's type; `error` is the
    /// underlying parse error (also included in the message).
    #[error(
        "failed to parse field \"{field_ident}\" with value \"{field_value}\": {error}"
    )]
    FailedToParseFieldValue {
        field_ident: String,
        field_value: String,
        error: Box<dyn Error + Send + Sync>,
    },
}
// Named authorization presets, convertible into an `AuthorizationResolver`.
// Displays/parses as kebab-case (e.g. `read-only`) and defaults to `ReadOnly`.
// Plain `//` comments are used here on purpose: `clap::ValueEnum` turns `///`
// doc comments on variants into CLI help text.
#[derive(
    Clone,
    Copy,
    Debug,
    Default,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    strum::Display,
    strum::EnumString,
    clap::ValueEnum,
)]
#[strum(serialize_all = "kebab-case")]
pub enum Preset {
    // `AuthorizationResolver::new_deny_everything`.
    DenyEverything,
    // `AuthorizationResolver::new_read_only`.
    #[default]
    ReadOnly,
    // `AuthorizationResolver::new_read_write`.
    ReadWrite,
    // `AuthorizationResolver::new_full_ddl`.
    FullDdl,
    // `AuthorizationResolver::new_allow_everything`.
    AllowEverything,
}
/// Builds the resolver corresponding to each named [`Preset`].
impl From<Preset> for AuthorizationResolver {
    fn from(preset: Preset) -> Self {
        // Choose the constructor first, then invoke it once.
        let build: fn() -> Self = match preset {
            Preset::DenyEverything => Self::new_deny_everything,
            Preset::ReadOnly => Self::new_read_only,
            Preset::ReadWrite => Self::new_read_write,
            Preset::FullDdl => Self::new_full_ddl,
            Preset::AllowEverything => Self::new_allow_everything,
        };
        build()
    }
}
impl AuthorizationResolver {
    /// Applies each per-action default setter in `setters`, in order, with
    /// [`rusqlite::hooks::Authorization::Allow`].
    fn allowing(self, setters: &[fn(Self, rusqlite::hooks::Authorization) -> Self]) -> Self {
        setters
            .iter()
            .fold(self, |resolver, set| set(resolver, rusqlite::hooks::Authorization::Allow))
    }

    /// Read-only preset. Starts from deny-everything and allows reads,
    /// `SELECT`, transactions, function calls, recursive queries and pragmas.
    ///
    /// NOTE(review): `Pragma` rules match on the pragma name only, so every
    /// pragma is allowed here, possibly including ones that assign a value.
    /// Confirm this is intended for a read-only preset.
    pub fn new_read_only() -> Self {
        Self::new_deny_everything().allowing(&[
            Self::with_read_default_permissions,
            Self::with_select_default_permissions,
            Self::with_transaction_default_permissions,
            Self::with_function_default_permissions,
            Self::with_recursive_default_permissions,
            Self::with_pragma_default_permissions,
        ])
    }

    /// Read-write preset. Extends read-only with data modification,
    /// savepoints, `ANALYZE`/`REINDEX`, and creating or dropping temporary
    /// schema objects.
    pub fn new_read_write() -> Self {
        Self::new_read_only().allowing(&[
            Self::with_insert_default_permissions,
            Self::with_update_default_permissions,
            Self::with_delete_default_permissions,
            Self::with_savepoint_default_permissions,
            Self::with_analyze_default_permissions,
            Self::with_reindex_default_permissions,
            Self::with_create_temp_table_default_permissions,
            Self::with_create_temp_index_default_permissions,
            Self::with_create_temp_view_default_permissions,
            Self::with_create_temp_trigger_default_permissions,
            Self::with_drop_temp_table_default_permissions,
            Self::with_drop_temp_index_default_permissions,
            Self::with_drop_temp_view_default_permissions,
            Self::with_drop_temp_trigger_default_permissions,
        ])
    }

    /// Full-DDL preset. Extends read-write with creating, altering and
    /// dropping persistent tables, indexes, triggers and views.
    pub fn new_full_ddl() -> Self {
        Self::new_read_write().allowing(&[
            Self::with_create_table_default_permissions,
            Self::with_drop_table_default_permissions,
            Self::with_alter_table_default_permissions,
            Self::with_create_index_default_permissions,
            Self::with_drop_index_default_permissions,
            Self::with_create_trigger_default_permissions,
            Self::with_drop_trigger_default_permissions,
            Self::with_create_view_default_permissions,
            Self::with_drop_view_default_permissions,
        ])
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn glob_display_shows_asterisk() {
let glob = Glob;
let displayed = glob.to_string();
assert_eq!(displayed, "*");
}
#[test]
fn value_or_glob_new_value_creates_value_variant() {
let input = "hello";
let vog = ValueOrGlob::<String>::new_value(input);
assert_eq!(vog, ValueOrGlob::Value("hello".to_string()));
}
#[test]
fn value_or_glob_new_glob_creates_glob_variant() {
let vog = ValueOrGlob::<String>::new_glob();
assert_eq!(vog, ValueOrGlob::Glob(Glob));
}
#[test]
fn value_or_glob_default_returns_glob() {
let vog = ValueOrGlob::<String>::default();
assert_eq!(vog, ValueOrGlob::Glob(Glob));
}
#[test]
fn value_or_glob_is_value_returns_true_for_value() {
let vog = ValueOrGlob::<String>::new_value("test");
let result = vog.is_value();
assert!(result);
}
#[test]
fn value_or_glob_is_value_returns_false_for_glob() {
let vog = ValueOrGlob::<String>::new_glob();
let result = vog.is_value();
assert!(!result);
}
#[test]
fn value_or_glob_is_glob_returns_true_for_glob() {
let vog = ValueOrGlob::<String>::new_glob();
let result = vog.is_glob();
assert!(result);
}
#[test]
fn value_or_glob_is_glob_returns_false_for_value() {
let vog = ValueOrGlob::<String>::new_value("test");
let result = vog.is_glob();
assert!(!result);
}
#[test]
fn value_or_glob_as_value_returns_some_for_value() {
let vog = ValueOrGlob::<String>::new_value("hello");
let result = vog.as_value();
assert_eq!(result, Some(&"hello".to_string()));
}
#[test]
fn value_or_glob_as_value_returns_none_for_glob() {
let vog = ValueOrGlob::<String>::new_glob();
let result = vog.as_value();
assert_eq!(result, None);
}
#[test]
fn value_or_glob_as_glob_returns_some_for_glob() {
let vog = ValueOrGlob::<String>::new_glob();
let result = vog.as_glob();
assert_eq!(result, Some(&Glob));
}
#[test]
fn value_or_glob_as_glob_returns_none_for_value() {
let vog = ValueOrGlob::<String>::new_value("test");
let result = vog.as_glob();
assert_eq!(result, None);
}
#[test]
fn value_or_glob_into_value_returns_some_for_value() {
let vog = ValueOrGlob::<String>::new_value("hello");
let result = vog.into_value();
assert_eq!(result, Some("hello".to_string()));
}
#[test]
fn value_or_glob_into_value_returns_none_for_glob() {
let vog = ValueOrGlob::<String>::new_glob();
let result = vog.into_value();
assert_eq!(result, None);
}
#[test]
fn value_or_glob_into_glob_returns_some_for_glob() {
let vog = ValueOrGlob::<String>::new_glob();
let result = vog.into_glob();
assert_eq!(result, Some(Glob));
}
#[test]
fn value_or_glob_into_glob_returns_none_for_value() {
let vog = ValueOrGlob::<String>::new_value("test");
let result = vog.into_glob();
assert_eq!(result, None);
}
#[test]
fn value_or_glob_display_shows_inner_value_for_value() {
let vog = ValueOrGlob::<String>::new_value("Students");
let displayed = vog.to_string();
assert_eq!(displayed, "Students");
}
#[test]
fn value_or_glob_display_shows_asterisk_for_glob() {
let vog = ValueOrGlob::<String>::new_glob();
let displayed = vog.to_string();
assert_eq!(displayed, "*");
}
#[test]
fn value_or_glob_from_str_with_asterisk_returns_glob() {
let input = "*";
let result = input.parse::<ValueOrGlob<String>>().unwrap();
assert_eq!(result, ValueOrGlob::Glob(Glob));
}
#[test]
fn value_or_glob_from_str_with_text_returns_value() {
let input = "Students";
let result = input.parse::<ValueOrGlob<String>>().unwrap();
assert_eq!(result, ValueOrGlob::Value("Students".to_string()));
}
#[test]
fn value_or_glob_from_some_option_gives_value() {
let opt = Some("name".to_string());
let vog = ValueOrGlob::<String>::from(opt);
assert_eq!(vog, ValueOrGlob::Value("name".to_string()));
}
#[test]
fn value_or_glob_from_none_option_gives_glob() {
let opt = Option::<String>::None;
let vog = ValueOrGlob::<String>::from(opt);
assert_eq!(vog, ValueOrGlob::Glob(Glob));
}
#[test]
fn value_or_glob_into_string_for_value() {
let vog = ValueOrGlob::<String>::new_value("hello");
let s = String::from(vog);
assert_eq!(s, "hello");
}
#[test]
fn value_or_glob_into_string_for_glob() {
let vog = ValueOrGlob::<String>::new_glob();
let s = String::from(vog);
assert_eq!(s, "*");
}
#[test]
fn read_new_creates_with_given_values() {
let table = ValueOrGlob::<String>::new_value("Students");
let column = ValueOrGlob::<String>::new_value("name");
let read = Read::new(table, column);
assert_eq!(read.table_name, ValueOrGlob::Value("Students".to_string()));
assert_eq!(read.column_name, ValueOrGlob::Value("name".to_string()));
}
#[test]
fn read_empty_creates_with_all_globs() {
let read = Read::empty();
assert_eq!(read.table_name, ValueOrGlob::Glob(Glob));
assert_eq!(read.column_name, ValueOrGlob::Glob(Glob));
}
#[test]
fn read_with_table_name_sets_table_name() {
let read = Read::empty();
let read =
read.with_table_name(ValueOrGlob::<String>::new_value("Grades"));
assert_eq!(read.table_name, ValueOrGlob::Value("Grades".to_string()));
assert_eq!(read.column_name, ValueOrGlob::Glob(Glob));
}
#[test]
fn read_with_column_name_sets_column_name() {
let read = Read::empty();
let read =
read.with_column_name(ValueOrGlob::<String>::new_value("score"));
assert_eq!(read.table_name, ValueOrGlob::Glob(Glob));
assert_eq!(read.column_name, ValueOrGlob::Value("score".to_string()));
}
#[test]
fn read_is_all_glob_true_when_both_glob() {
let read = Read::empty();
let result = read.is_all_glob();
assert!(result);
}
#[test]
fn read_is_all_glob_false_when_table_is_value() {
let read = Read::empty()
.with_table_name(ValueOrGlob::<String>::new_value("Students"));
let result = read.is_all_glob();
assert!(!result);
}
#[test]
fn read_is_all_glob_false_when_column_is_value() {
let read = Read::empty()
.with_column_name(ValueOrGlob::<String>::new_value("name"));
let result = read.is_all_glob();
assert!(!result);
}
#[test]
fn read_is_all_value_true_when_both_value() {
let read = Read::new(
ValueOrGlob::<String>::new_value("Students"),
ValueOrGlob::<String>::new_value("name"),
);
let result = read.is_all_value();
assert!(result);
}
#[test]
fn read_is_all_value_false_when_one_is_glob() {
let read = Read::empty()
.with_table_name(ValueOrGlob::<String>::new_value("Students"));
let result = read.is_all_value();
assert!(!result);
}
#[test]
fn read_is_any_glob_true_when_one_is_glob() {
let read = Read::empty()
.with_table_name(ValueOrGlob::<String>::new_value("Students"));
let result = read.is_any_glob();
assert!(result);
}
#[test]
fn read_is_any_glob_false_when_both_value() {
let read = Read::new(
ValueOrGlob::<String>::new_value("Students"),
ValueOrGlob::<String>::new_value("name"),
);
let result = read.is_any_glob();
assert!(!result);
}
#[test]
fn read_is_any_value_true_when_one_is_value() {
let read = Read::empty()
.with_column_name(ValueOrGlob::<String>::new_value("name"));
let result = read.is_any_value();
assert!(result);
}
#[test]
fn read_is_any_value_false_when_both_glob() {
let read = Read::empty();
let result = read.is_any_value();
assert!(!result);
}
#[test]
fn read_display_all_glob_shows_read_without_parens() {
let read = Read::empty();
let displayed = read.to_string();
assert_eq!(displayed, "Read");
}
#[test]
fn read_display_specific_shows_read_with_table_and_column() {
let read = Read::new(
ValueOrGlob::<String>::new_value("Students"),
ValueOrGlob::<String>::new_value("name"),
);
let displayed = read.to_string();
assert_eq!(displayed, "Read(Students.name)");
}
#[test]
fn read_display_glob_table_value_column() {
let read = Read::empty()
.with_column_name(ValueOrGlob::<String>::new_value("name"));
let displayed = read.to_string();
assert_eq!(displayed, "Read(*.name)");
}
#[test]
fn read_display_value_table_glob_column() {
let read = Read::empty()
.with_table_name(ValueOrGlob::<String>::new_value("Students"));
let displayed = read.to_string();
assert_eq!(displayed, "Read(Students.*)");
}
/// The bare `"Read"` string parses to a selector whose table and column are both globs.
#[test]
fn read_from_str_bare_read_parses_to_all_glob() {
    let parsed: Read = "Read".parse().unwrap();
    assert_eq!(parsed.table_name, ValueOrGlob::Glob(Glob));
    assert_eq!(parsed.column_name, ValueOrGlob::Glob(Glob));
}
/// `Read(table.column)` parses both parts into concrete values.
#[test]
fn read_from_str_specific_table_and_column_parses_correctly() {
    let parsed: Read = "Read(Students.name)".parse().unwrap();
    assert_eq!(parsed.table_name, ValueOrGlob::Value("Students".to_string()));
    assert_eq!(parsed.column_name, ValueOrGlob::Value("name".to_string()));
}
/// A `*` in table position parses to a glob while the column stays a value.
#[test]
fn read_from_str_glob_table_value_column_parses_correctly() {
    let parsed: Read = "Read(*.name)".parse().unwrap();
    assert_eq!(parsed.table_name, ValueOrGlob::Glob(Glob));
    assert_eq!(parsed.column_name, ValueOrGlob::Value("name".to_string()));
}
/// A `*` in column position parses to a glob while the table stays a value.
#[test]
fn read_from_str_value_table_glob_column_parses_correctly() {
    let parsed: Read = "Read(Students.*)".parse().unwrap();
    assert_eq!(parsed.table_name, ValueOrGlob::Value("Students".to_string()));
    assert_eq!(parsed.column_name, ValueOrGlob::Glob(Glob));
}
/// Parsing `Read(Students.name)`, displaying it, and parsing the output again must
/// reproduce both the original text and the original field values.
#[test]
fn read_from_str_round_trip_produces_same_result() {
    let input = "Read(Students.name)";
    let read = input.parse::<Read>().unwrap();
    let displayed = read.to_string();
    // Pin the rendered text itself; otherwise the round trip only proves that *some*
    // parseable string was produced (the bare-`Read` round-trip test does the same).
    assert_eq!(displayed, input);
    let reparsed = displayed.parse::<Read>().unwrap();
    assert_eq!(
        reparsed.table_name,
        ValueOrGlob::Value("Students".to_string())
    );
    assert_eq!(
        reparsed.column_name,
        ValueOrGlob::Value("name".to_string())
    );
}
/// The bare `"Read"` form survives a display/parse round trip with its text unchanged.
#[test]
fn read_from_str_round_trip_bare_read_produces_same_result() {
    let original: Read = "Read".parse().unwrap();
    let rendered = original.to_string();
    let reparsed: Read = rendered.parse().unwrap();
    assert_eq!(reparsed.table_name, ValueOrGlob::Glob(Glob));
    assert_eq!(reparsed.column_name, ValueOrGlob::Glob(Glob));
    assert_eq!(rendered, "Read");
}
/// Converting a `Read` into a `String` yields its display form.
#[test]
fn read_into_string_works() {
    let selector = Read::new(
        ValueOrGlob::<String>::new_value("Students"),
        ValueOrGlob::<String>::new_value("name"),
    );
    let rendered: String = selector.into();
    assert_eq!(rendered, "Read(Students.name)");
}
/// `Read::try_from` accepts an owned `String` and parses both fields.
#[test]
fn read_try_from_string_works() {
    let owned = String::from("Read(Students.name)");
    let parsed = Read::try_from(owned).unwrap();
    assert_eq!(parsed.table_name, ValueOrGlob::Value("Students".to_string()));
    assert_eq!(parsed.column_name, ValueOrGlob::Value("name".to_string()));
}
/// Parsing `"Read"` as the selector enum selects the `Read` variant with all-glob fields.
#[test]
fn access_control_selector_from_str_bare_read_parses_to_read_selector() {
    let parsed: AccessControlSelector = "Read".parse().unwrap();
    let read = match parsed {
        AccessControlSelector::Read(read) => read,
        other => panic!("expected ReadSelector, got {other}"),
    };
    assert_eq!(read.table_name, ValueOrGlob::Glob(Glob));
    assert_eq!(read.column_name, ValueOrGlob::Glob(Glob));
}
/// Parsing `Read(Students.name)` as the enum yields the `Read` variant with both values.
#[test]
fn access_control_selector_from_str_specific_parses_correctly() {
    let parsed: AccessControlSelector = "Read(Students.name)".parse().unwrap();
    let read = match parsed {
        AccessControlSelector::Read(read) => read,
        other => panic!("expected ReadSelector, got {other}"),
    };
    assert_eq!(read.table_name, ValueOrGlob::Value("Students".to_string()));
    assert_eq!(read.column_name, ValueOrGlob::Value("name".to_string()));
}
/// Displaying a parsed enum selector and reparsing it keeps the `Read` variant and values.
#[test]
fn access_control_selector_display_round_trip_works() {
    let first: AccessControlSelector = "Read(Students.name)".parse().unwrap();
    let rendered = first.to_string();
    let second: AccessControlSelector = rendered.parse().unwrap();
    let read = match second {
        AccessControlSelector::Read(read) => read,
        other => panic!("expected ReadSelector, got {other}"),
    };
    assert_eq!(read.table_name, ValueOrGlob::Value("Students".to_string()));
    assert_eq!(read.column_name, ValueOrGlob::Value("name".to_string()));
}
/// `AccessControlSelector::try_from` accepts an owned `String` and yields the `Read` variant.
#[test]
fn access_control_selector_try_from_string_works() {
    let owned = String::from("Read(Students.name)");
    let converted = AccessControlSelector::try_from(owned).unwrap();
    let read = match converted {
        AccessControlSelector::Read(read) => read,
        other => panic!("expected ReadSelector, got {other}"),
    };
    assert_eq!(read.table_name, ValueOrGlob::Value("Students".to_string()));
    assert_eq!(read.column_name, ValueOrGlob::Value("name".to_string()));
}
/// An unterminated field list is reported as `NoClosingBraceFound`.
#[test]
fn parsing_with_missing_closing_paren_returns_error() {
    let outcome = "Read(Students.name".parse::<Read>();
    let is_missing_brace = matches!(
        outcome,
        Err(AccessControlSelectorParseError::NoClosingBraceFound { .. })
    );
    assert!(is_missing_brace);
}
/// Supplying three dot-separated fields to the two-field `Read` is rejected.
#[test]
fn parsing_with_too_many_fields_returns_error() {
    let outcome = "Read(Students.name.extra)".parse::<Read>();
    let is_field_count_error = matches!(
        outcome,
        Err(AccessControlSelectorParseError::InvalidNumberOfFields { .. })
    );
    assert!(is_field_count_error);
}
/// Parsing a `Write(...)` string as the `Read` struct fails with the identifier-mismatch error.
#[test]
fn parsing_with_wrong_identifier_returns_error_for_struct() {
    let outcome = "Write(Students.name)".parse::<Read>();
    let is_identifier_error = matches!(
        outcome,
        Err(AccessControlSelectorParseError::IncorrectAccessControlSelectorIdentifier { .. })
    );
    assert!(is_identifier_error);
}
/// An identifier the enum does not recognise fails with the invalid-identifier error.
#[test]
fn parsing_with_wrong_identifier_returns_error_for_enum() {
    let outcome = "Write(Students.name)".parse::<AccessControlSelector>();
    let is_unknown_identifier = matches!(
        outcome,
        Err(AccessControlSelectorParseError::InvalidAccessControlSelectorIdentifier { .. })
    );
    assert!(is_unknown_identifier);
}
/// An empty string is not a valid enum selector.
#[test]
fn parsing_empty_string_returns_error_for_enum() {
    let outcome = "".parse::<AccessControlSelector>();
    assert!(outcome.is_err());
}
/// An empty string does not carry the `Read` identifier, so the struct parser rejects it.
#[test]
fn parsing_empty_string_returns_error_for_struct() {
    let outcome = "".parse::<Read>();
    let is_identifier_error = matches!(
        outcome,
        Err(AccessControlSelectorParseError::IncorrectAccessControlSelectorIdentifier { .. })
    );
    assert!(is_identifier_error);
}
/// With no rules registered, a read check produces no verdict.
#[test]
fn empty_selector_set_returns_none_for_read() {
    let rules = AccessControlSelectorSet::new();
    assert!(rules.check_read("Students", "name").is_none());
}
/// A single catch-all allow rule matches any read and yields `Allow`.
#[test]
fn single_allow_rule_matches_and_returns_allow() {
    let rules = AccessControlSelectorSet::new().with_read_selector(Read::empty(), true);
    let verdict = rules.check_read("Students", "name");
    assert_eq!(verdict, Some(Authorization::Allow));
}
/// A single catch-all deny rule matches any read and yields `Deny`.
#[test]
fn single_deny_rule_matches_and_returns_deny() {
    let rules = AccessControlSelectorSet::new().with_read_selector(Read::empty(), false);
    let verdict = rules.check_read("Students", "name");
    assert_eq!(verdict, Some(Authorization::Deny));
}
/// An allow rule scoped to another table leaves a `Students` read without a verdict.
#[test]
fn single_allow_rule_that_does_not_match_returns_none() {
    let secrets_only = Read::new(ValueOrGlob::new_value("Secrets"), ValueOrGlob::new_glob());
    let rules = AccessControlSelectorSet::new().with_read_selector(secrets_only, true);
    assert!(rules.check_read("Students", "name").is_none());
}
/// A deny rule scoped to another table leaves a `Students` read without a verdict.
#[test]
fn single_deny_rule_that_does_not_match_returns_none() {
    let secrets_only = Read::new(ValueOrGlob::new_value("Secrets"), ValueOrGlob::new_glob());
    let rules = AccessControlSelectorSet::new().with_read_selector(secrets_only, false);
    assert!(rules.check_read("Students", "name").is_none());
}
/// A table-scoped allow outranks a catch-all deny for reads on that table.
#[test]
fn higher_specificity_allow_beats_lower_specificity_deny() {
    let students_table = Read::new(ValueOrGlob::new_value("Students"), ValueOrGlob::new_glob());
    let rules = AccessControlSelectorSet::new()
        .with_read_selector(Read::empty(), false)
        .with_read_selector(students_table, true);
    assert_eq!(
        rules.check_read("Students", "name"),
        Some(Authorization::Allow)
    );
}
/// A table-scoped deny outranks a catch-all allow for reads on that table.
#[test]
fn higher_specificity_deny_beats_lower_specificity_allow() {
    let secrets_table = Read::new(ValueOrGlob::new_value("Secrets"), ValueOrGlob::new_glob());
    let rules = AccessControlSelectorSet::new()
        .with_read_selector(Read::empty(), true)
        .with_read_selector(secrets_table, false);
    assert_eq!(
        rules.check_read("Secrets", "ssn"),
        Some(Authorization::Deny)
    );
}
/// Three nested rules resolve by specificity: the catch-all allow applies elsewhere, the
/// table deny covers `Secrets`, and the exact `Secrets.id` allow carves out one column.
#[test]
fn specificity_2_beats_specificity_1_beats_specificity_0() {
    let rules = AccessControlSelectorSet::new()
        .with_read_selector(Read::empty(), true)
        .with_read_selector(
            Read::new(ValueOrGlob::new_value("Secrets"), ValueOrGlob::new_glob()),
            false,
        )
        .with_read_selector(
            Read::new(ValueOrGlob::new_value("Secrets"), ValueOrGlob::new_value("id")),
            true,
        );
    let cases = [
        ("Public", "data", Some(Authorization::Allow)),
        ("Secrets", "ssn", Some(Authorization::Deny)),
        ("Secrets", "id", Some(Authorization::Allow)),
    ];
    for (table, column, expected) in cases {
        assert_eq!(rules.check_read(table, column), expected, "{table}.{column}");
    }
}
/// When an allow and a deny of equal specificity both match, the deny prevails.
#[test]
fn same_specificity_allow_and_deny_resolves_to_deny() {
    let allow_table = Read::new(ValueOrGlob::new_value("Students"), ValueOrGlob::new_glob());
    let deny_column = Read::new(ValueOrGlob::new_glob(), ValueOrGlob::new_value("name"));
    let rules = AccessControlSelectorSet::new()
        .with_read_selector(allow_table, true)
        .with_read_selector(deny_column, false);
    assert_eq!(
        rules.check_read("Students", "name"),
        Some(Authorization::Deny)
    );
}
/// Two matching allows at the same specificity, with no deny, still yield `Allow`.
#[test]
fn same_specificity_multiple_allows_no_deny_returns_allow() {
    let by_table = Read::new(ValueOrGlob::new_value("A"), ValueOrGlob::new_glob());
    let by_column = Read::new(ValueOrGlob::new_glob(), ValueOrGlob::new_value("x"));
    let rules = AccessControlSelectorSet::new()
        .with_read_selector(by_table, true)
        .with_read_selector(by_column, true);
    assert_eq!(rules.check_read("A", "x"), Some(Authorization::Allow));
}
/// Two matching denies at the same specificity yield `Deny`.
#[test]
fn same_specificity_multiple_denys_no_allow_returns_deny() {
    let by_table = Read::new(ValueOrGlob::new_value("A"), ValueOrGlob::new_glob());
    let by_column = Read::new(ValueOrGlob::new_glob(), ValueOrGlob::new_value("x"));
    let rules = AccessControlSelectorSet::new()
        .with_read_selector(by_table, false)
        .with_read_selector(by_column, false);
    assert_eq!(rules.check_read("A", "x"), Some(Authorization::Deny));
}
/// A more specific rule for a different table is ignored, so the catch-all allow decides.
#[test]
fn non_matching_rules_are_skipped_to_lower_specificity() {
    let other_table = Read::new(ValueOrGlob::new_value("Other"), ValueOrGlob::new_glob());
    let rules = AccessControlSelectorSet::new()
        .with_read_selector(Read::empty(), true)
        .with_read_selector(other_table, false);
    assert_eq!(
        rules.check_read("Students", "name"),
        Some(Authorization::Allow)
    );
}
/// A catch-all deny plus an exact-pair allow, with no intermediate rule, still carves
/// out just that pair.
#[test]
fn rules_with_specificity_gap_still_resolve_correctly() {
    let exact_pair = Read::new(ValueOrGlob::new_value("Students"), ValueOrGlob::new_value("name"));
    let rules = AccessControlSelectorSet::new()
        .with_read_selector(Read::empty(), false)
        .with_read_selector(exact_pair, true);
    let permitted = rules.check_read("Students", "name");
    let refused = rules.check_read("Students", "ssn");
    assert_eq!(permitted, Some(Authorization::Allow));
    assert_eq!(refused, Some(Authorization::Deny));
}
/// A rule pinning only the column applies to that column in every table and nowhere else.
#[test]
fn glob_table_pinned_column_matches_any_table() {
    let any_ssn = Read::new(ValueOrGlob::new_glob(), ValueOrGlob::new_value("ssn"));
    let rules = AccessControlSelectorSet::new().with_read_selector(any_ssn, false);
    let cases = [
        ("Students", "ssn", Some(Authorization::Deny)),
        ("Employees", "ssn", Some(Authorization::Deny)),
        ("Students", "name", None),
    ];
    for (table, column, expected) in cases {
        assert_eq!(rules.check_read(table, column), expected, "{table}.{column}");
    }
}
/// A rule pinning only the table applies to every column of that table and nowhere else.
#[test]
fn pinned_table_glob_column_matches_any_column() {
    let whole_secrets = Read::new(ValueOrGlob::new_value("Secrets"), ValueOrGlob::new_glob());
    let rules = AccessControlSelectorSet::new().with_read_selector(whole_secrets, false);
    let cases = [
        ("Secrets", "ssn", Some(Authorization::Deny)),
        ("Secrets", "id", Some(Authorization::Deny)),
        ("Public", "data", None),
    ];
    for (table, column, expected) in cases {
        assert_eq!(rules.check_read(table, column), expected, "{table}.{column}");
    }
}
/// A fully pinned rule matches only its exact table/column pair.
#[test]
fn fully_pinned_selector_matches_only_exact_pair() {
    let exact_pair = Read::new(ValueOrGlob::new_value("Students"), ValueOrGlob::new_value("name"));
    let rules = AccessControlSelectorSet::new().with_read_selector(exact_pair, false);
    let cases = [
        ("Students", "name", Some(Authorization::Deny)),
        ("Students", "id", None),
        ("Other", "name", None),
    ];
    for (table, column, expected) in cases {
        assert_eq!(rules.check_read(table, column), expected, "{table}.{column}");
    }
}
/// A read rule only influences read checks; delete checks stay verdict-less.
#[test]
fn read_rules_do_not_affect_delete_checks() {
    let rules = AccessControlSelectorSet::new().with_read_selector(Read::empty(), false);
    assert_eq!(rules.check_read("Students", "name"), Some(Authorization::Deny));
    assert_eq!(rules.check_delete("Students"), None);
}
/// Single-field `Delete` selectors resolve by specificity like `Read` selectors do.
#[test]
fn delete_single_field_selector_works() {
    let audit_log = Delete::new(ValueOrGlob::new_value("AuditLog"));
    let rules = AccessControlSelectorSet::new()
        .with_delete_selector(Delete::empty(), true)
        .with_delete_selector(audit_log, false);
    assert_eq!(rules.check_delete("Students"), Some(Authorization::Allow));
    assert_eq!(rules.check_delete("AuditLog"), Some(Authorization::Deny));
}
/// A zero-field `Select` allow rule yields `Allow` for select checks.
#[test]
fn select_zero_field_selector_allow_works() {
    let rules = AccessControlSelectorSet::new().with_select_selector(Select {}, true);
    assert_eq!(rules.check_select(), Some(Authorization::Allow));
}
/// Conflicting zero-field `Select` rules resolve to `Deny`.
#[test]
fn select_zero_field_allow_and_deny_resolves_to_deny() {
    let rules = AccessControlSelectorSet::new()
        .with_select_selector(Select {}, true)
        .with_select_selector(Select {}, false);
    assert_eq!(rules.check_select(), Some(Authorization::Deny));
}
/// `with_selector` given an `AccessControlSelector::Read` must register the rule for
/// read checks only.
#[test]
fn with_selector_dispatches_to_correct_type() {
    let read = AccessControlSelector::from(Read::empty());
    let set = AccessControlSelectorSet::new().with_selector(read, false);
    let result = set.check_read("anything", "anything");
    assert_eq!(result, Some(Authorization::Deny));
    // Dispatch must not leak the rule into other action types' rule lists.
    assert_eq!(set.check_delete("anything"), None);
    assert_eq!(set.check_select(), None);
}
/// `new_allow_everything` starts with an empty selector set and `Allow` as the default for
/// every action. A cross-section of actions is checked, not just reads.
#[test]
fn resolver_allow_everything_defaults_to_allow() {
    let resolver = AuthorizationResolver::new_allow_everything();
    let result = resolver.selector_set.check_read("X", "Y");
    assert_eq!(result, None);
    for (action, permissions) in [
        ("read", &resolver.read_default_permissions),
        ("select", &resolver.select_default_permissions),
        ("insert", &resolver.insert_default_permissions),
        ("update", &resolver.update_default_permissions),
        ("delete", &resolver.delete_default_permissions),
        ("function", &resolver.function_default_permissions),
        ("create_table", &resolver.create_table_default_permissions),
        ("drop_table", &resolver.drop_table_default_permissions),
        ("attach", &resolver.attach_default_permissions),
        ("detach", &resolver.detach_default_permissions),
    ] {
        assert_eq!(
            permissions,
            &rusqlite::hooks::Authorization::Allow,
            "{action} default should be Allow",
        );
    }
}
/// `new_deny_everything` starts with an empty selector set and `Deny` as the default for
/// every action. A cross-section of actions is checked, not just reads.
#[test]
fn resolver_deny_everything_defaults_to_deny() {
    let resolver = AuthorizationResolver::new_deny_everything();
    let result = resolver.selector_set.check_read("X", "Y");
    assert_eq!(result, None);
    for (action, permissions) in [
        ("read", &resolver.read_default_permissions),
        ("select", &resolver.select_default_permissions),
        ("insert", &resolver.insert_default_permissions),
        ("update", &resolver.update_default_permissions),
        ("delete", &resolver.delete_default_permissions),
        ("function", &resolver.function_default_permissions),
        ("create_table", &resolver.create_table_default_permissions),
        ("drop_table", &resolver.drop_table_default_permissions),
        ("attach", &resolver.attach_default_permissions),
        ("detach", &resolver.detach_default_permissions),
    ] {
        assert_eq!(
            permissions,
            &rusqlite::hooks::Authorization::Deny,
            "{action} default should be Deny",
        );
    }
}
/// Overriding the read default changes only the read default; others keep the preset.
#[test]
fn resolver_per_action_default_override_works() {
    let resolver = AuthorizationResolver::new_allow_everything()
        .with_read_default_permissions(rusqlite::hooks::Authorization::Deny);
    let read_default = &resolver.read_default_permissions;
    let insert_default = &resolver.insert_default_permissions;
    assert_eq!(read_default, &rusqlite::hooks::Authorization::Deny);
    assert_eq!(insert_default, &rusqlite::hooks::Authorization::Allow);
}
/// Documented example: allow all reads except those on the `Secrets` table.
#[test]
fn doc_example_read_only_with_table_blocked() {
    let allow_all: Read = "Read".parse().unwrap();
    let block_secrets: Read = "Read(Secrets)".parse().unwrap();
    let rules = AccessControlSelectorSet::new()
        .with_selector(allow_all, true)
        .with_selector(block_secrets, false);
    assert_eq!(rules.check_read("Public", "data"), Some(Authorization::Allow));
    assert_eq!(rules.check_read("Secrets", "ssn"), Some(Authorization::Deny));
}
/// Documented example: block `Secrets` but carve out its `id` column.
#[test]
fn doc_example_read_only_with_carveout() {
    let allow_all: Read = "Read".parse().unwrap();
    let block_secrets: Read = "Read(Secrets)".parse().unwrap();
    let allow_secret_id: Read = "Read(Secrets.id)".parse().unwrap();
    let rules = AccessControlSelectorSet::new()
        .with_selector(allow_all, true)
        .with_selector(block_secrets, false)
        .with_selector(allow_secret_id, true);
    let cases = [
        ("Public", "data", Some(Authorization::Allow)),
        ("Secrets", "ssn", Some(Authorization::Deny)),
        ("Secrets", "id", Some(Authorization::Allow)),
    ];
    for (table, column, expected) in cases {
        assert_eq!(rules.check_read(table, column), expected, "{table}.{column}");
    }
}
/// Documented example: a table deny and a column allow of equal specificity resolve to `Deny`.
#[test]
fn doc_example_conflicting_same_specificity() {
    let deny_students: Read = "Read(Students)".parse().unwrap();
    let allow_name: Read = "Read(*.name)".parse().unwrap();
    let rules = AccessControlSelectorSet::new()
        .with_selector(deny_students, false)
        .with_selector(allow_name, true);
    assert_eq!(
        rules.check_read("Students", "name"),
        Some(Authorization::Deny)
    );
}
/// Documented example: deny every SQL function except an explicit allow-list.
#[test]
fn doc_example_deny_functions_allow_specific() {
    let deny_all: Function = "Function".parse().unwrap();
    let allow_count: Function = "Function(count)".parse().unwrap();
    let allow_sum: Function = "Function(sum)".parse().unwrap();
    let rules = AccessControlSelectorSet::new()
        .with_selector(deny_all, false)
        .with_selector(allow_count, true)
        .with_selector(allow_sum, true);
    let cases = [
        ("count", Some(Authorization::Allow)),
        ("sum", Some(Authorization::Allow)),
        ("load_extension", Some(Authorization::Deny)),
    ];
    for (name, expected) in cases {
        assert_eq!(rules.check_function(name), expected, "{name}");
    }
}
/// When none of several exact-pair rules match, the read check has no verdict.
#[test]
fn multiple_non_matching_rules_returns_none() {
    let pair_a = Read::new(ValueOrGlob::new_value("A"), ValueOrGlob::new_value("x"));
    let pair_b = Read::new(ValueOrGlob::new_value("B"), ValueOrGlob::new_value("y"));
    let rules = AccessControlSelectorSet::new()
        .with_read_selector(pair_a, true)
        .with_read_selector(pair_b, false);
    assert!(rules.check_read("C", "z").is_none());
}
/// An exact-pair deny is not overridden by a catch-all allow.
#[test]
fn deny_at_higher_specificity_is_not_overridden_by_lower_allow() {
    let secret_ssn = Read::new(ValueOrGlob::new_value("Secrets"), ValueOrGlob::new_value("ssn"));
    let rules = AccessControlSelectorSet::new()
        .with_read_selector(Read::empty(), true)
        .with_read_selector(secret_ssn, false);
    assert_eq!(
        rules.check_read("Secrets", "ssn"),
        Some(Authorization::Deny)
    );
}
/// Rules at the same specificity only contribute when they match the checked table.
#[test]
fn only_matching_rules_at_a_level_contribute_to_verdict() {
    let table_a = Read::new(ValueOrGlob::new_value("A"), ValueOrGlob::new_glob());
    let table_b = Read::new(ValueOrGlob::new_value("B"), ValueOrGlob::new_glob());
    let rules = AccessControlSelectorSet::new()
        .with_read_selector(table_a, false)
        .with_read_selector(table_b, true);
    assert_eq!(rules.check_read("A", "x"), Some(Authorization::Deny));
    assert_eq!(rules.check_read("B", "x"), Some(Authorization::Allow));
}
/// Specificity, not insertion order, decides the verdict: registering the exact-pair allow
/// before or after the catch-all deny must give the same results.
#[test]
fn insertion_order_does_not_affect_specificity_ranking() {
    // Builds a fresh copy of the exact-pair rule for each set.
    let exact_pair = || Read::new(ValueOrGlob::new_value("T"), ValueOrGlob::new_value("c"));
    let specific_first = AccessControlSelectorSet::new()
        .with_read_selector(exact_pair(), true)
        .with_read_selector(Read::empty(), false);
    let broad_first = AccessControlSelectorSet::new()
        .with_read_selector(Read::empty(), false)
        .with_read_selector(exact_pair(), true);
    for set in [&specific_first, &broad_first] {
        // The carve-out wins over the catch-all in both orders...
        assert_eq!(set.check_read("T", "c"), Some(Authorization::Allow));
        // ...and the catch-all still applies to every other column.
        assert_eq!(set.check_read("T", "other"), Some(Authorization::Deny));
    }
}
/// A catch-all allow, a table deny, and an exact-pair allow each win in their own scope.
#[test]
fn three_specificity_levels_middle_deny_overridden_by_top() {
    let rules = AccessControlSelectorSet::new()
        .with_read_selector(Read::empty(), true)
        .with_read_selector(
            Read::new(ValueOrGlob::new_value("T"), ValueOrGlob::new_glob()),
            false,
        )
        .with_read_selector(
            Read::new(ValueOrGlob::new_value("T"), ValueOrGlob::new_value("ok")),
            true,
        );
    let cases = [
        ("X", "y", Some(Authorization::Allow)),
        ("T", "secret", Some(Authorization::Deny)),
        ("T", "ok", Some(Authorization::Allow)),
    ];
    for (table, column, expected) in cases {
        assert_eq!(rules.check_read(table, column), expected, "{table}.{column}");
    }
}
/// Registering the same allow rule twice still yields `Allow`.
#[test]
fn duplicate_allow_rules_still_return_allow() {
    let rules = AccessControlSelectorSet::new()
        .with_read_selector(Read::empty(), true)
        .with_read_selector(Read::empty(), true);
    assert_eq!(rules.check_read("T", "c"), Some(Authorization::Allow));
}
/// Registering the same deny rule twice still yields `Deny`.
#[test]
fn duplicate_deny_rules_still_return_deny() {
    let rules = AccessControlSelectorSet::new()
        .with_read_selector(Read::empty(), false)
        .with_read_selector(Read::empty(), false);
    assert_eq!(rules.check_read("T", "c"), Some(Authorization::Deny));
}
/// The read-only preset allows reading-related actions by default.
#[test]
fn read_only_allows_reads_and_selects() {
    let resolver = AuthorizationResolver::new_read_only();
    for (action, permissions) in [
        ("read", &resolver.read_default_permissions),
        ("select", &resolver.select_default_permissions),
        ("transaction", &resolver.transaction_default_permissions),
        ("function", &resolver.function_default_permissions),
        ("recursive", &resolver.recursive_default_permissions),
        ("pragma", &resolver.pragma_default_permissions),
    ] {
        assert_eq!(permissions, &rusqlite::hooks::Authorization::Allow, "{action}");
    }
}
/// The read-only preset denies data modification, schema changes, and `ATTACH` by default.
#[test]
fn read_only_denies_writes_and_ddl() {
    let resolver = AuthorizationResolver::new_read_only();
    for (action, permissions) in [
        ("insert", &resolver.insert_default_permissions),
        ("update", &resolver.update_default_permissions),
        ("delete", &resolver.delete_default_permissions),
        ("create_table", &resolver.create_table_default_permissions),
        ("drop_table", &resolver.drop_table_default_permissions),
        ("attach", &resolver.attach_default_permissions),
    ] {
        assert_eq!(permissions, &rusqlite::hooks::Authorization::Deny, "{action}");
    }
}
/// The read-write preset allows row-level modification plus savepoints and `ANALYZE`.
#[test]
fn read_write_allows_data_modification() {
    let resolver = AuthorizationResolver::new_read_write();
    for (action, permissions) in [
        ("insert", &resolver.insert_default_permissions),
        ("update", &resolver.update_default_permissions),
        ("delete", &resolver.delete_default_permissions),
        ("savepoint", &resolver.savepoint_default_permissions),
        ("analyze", &resolver.analyze_default_permissions),
    ] {
        assert_eq!(permissions, &rusqlite::hooks::Authorization::Allow, "{action}");
    }
}
/// The read-write preset allows temporary tables but denies permanent schema changes and
/// `ATTACH`.
#[test]
fn read_write_allows_temp_objects_but_denies_permanent_ddl() {
    let resolver = AuthorizationResolver::new_read_write();
    let allowed = [
        ("create_temp_table", &resolver.create_temp_table_default_permissions),
        ("drop_temp_table", &resolver.drop_temp_table_default_permissions),
    ];
    let denied = [
        ("create_table", &resolver.create_table_default_permissions),
        ("drop_table", &resolver.drop_table_default_permissions),
        ("alter_table", &resolver.alter_table_default_permissions),
        ("attach", &resolver.attach_default_permissions),
    ];
    for (action, permissions) in allowed {
        assert_eq!(permissions, &rusqlite::hooks::Authorization::Allow, "{action}");
    }
    for (action, permissions) in denied {
        assert_eq!(permissions, &rusqlite::hooks::Authorization::Deny, "{action}");
    }
}
/// The full-DDL preset allows permanent schema objects to be created, altered, and dropped.
#[test]
fn full_ddl_allows_permanent_schema_changes() {
    let resolver = AuthorizationResolver::new_full_ddl();
    for (action, permissions) in [
        ("create_table", &resolver.create_table_default_permissions),
        ("drop_table", &resolver.drop_table_default_permissions),
        ("alter_table", &resolver.alter_table_default_permissions),
        ("create_index", &resolver.create_index_default_permissions),
        ("create_view", &resolver.create_view_default_permissions),
        ("create_trigger", &resolver.create_trigger_default_permissions),
    ] {
        assert_eq!(permissions, &rusqlite::hooks::Authorization::Allow, "{action}");
    }
}
/// Even the full-DDL preset keeps `ATTACH`/`DETACH` and virtual-table DDL denied.
#[test]
fn full_ddl_still_denies_attach_detach_and_vtables() {
    let resolver = AuthorizationResolver::new_full_ddl();
    for (action, permissions) in [
        ("attach", &resolver.attach_default_permissions),
        ("detach", &resolver.detach_default_permissions),
        ("create_vtable", &resolver.create_vtable_default_permissions),
        ("drop_vtable", &resolver.drop_vtable_default_permissions),
    ] {
        assert_eq!(permissions, &rusqlite::hooks::Authorization::Deny, "{action}");
    }
}
/// A preset can be adjusted per action: enabling `INSERT` on read-only leaves other
/// defaults untouched.
#[test]
fn presets_are_composable_with_per_action_overrides() {
    let resolver = AuthorizationResolver::new_read_only()
        .with_insert_default_permissions(rusqlite::hooks::Authorization::Allow);
    let overridden = &resolver.insert_default_permissions;
    let untouched_update = &resolver.update_default_permissions;
    let untouched_read = &resolver.read_default_permissions;
    assert_eq!(overridden, &rusqlite::hooks::Authorization::Allow);
    assert_eq!(untouched_update, &rusqlite::hooks::Authorization::Deny);
    assert_eq!(untouched_read, &rusqlite::hooks::Authorization::Allow);
}
}