use std::{collections::BTreeSet, fmt::Display};
use cedar_policy_core::{
ast::{CallStyle, Expr},
parser::SourceInfo,
};
use super::types::Type;
use itertools::Itertools;
use thiserror::Error;
/// A type error found while checking a policy expression.
///
/// The error records *what* went wrong in `kind`, and *where* via either an
/// explicit `source_location` or the offending expression `on_expr` (whose own
/// source info is used as a fallback by `kind_and_location`).
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TypeError {
    /// Expression the error is attached to, when the constructor kept it.
    pub(crate) on_expr: Option<Expr>,
    /// Explicit source location; preferred over `on_expr`'s source info.
    pub(crate) source_location: Option<SourceInfo>,
    /// The specific kind of error, including its message data.
    pub(crate) kind: TypeErrorKind,
}
impl TypeError {
    /// Consumes the error and returns only its [`TypeErrorKind`].
    pub fn type_error_kind(self) -> TypeErrorKind {
        // The location is not needed here, so avoid computing it.
        self.kind
    }

    /// Consumes the error and returns the source location it refers to, if any.
    ///
    /// See [`TypeError::kind_and_location`] for how the location is chosen.
    pub fn source_location(self) -> Option<SourceInfo> {
        self.kind_and_location().1
    }

    /// Consumes the error and returns its kind together with its source location.
    ///
    /// An explicitly recorded `source_location` takes precedence. When none was
    /// recorded, the source info carried by `on_expr` (if present) is used instead.
    pub fn kind_and_location(self) -> (TypeErrorKind, Option<SourceInfo>) {
        let Self {
            on_expr,
            source_location,
            kind,
        } = self;
        let location = source_location.or_else(|| on_expr.and_then(|e| e.into_source_info()));
        (kind, location)
    }

    /// Builds a [`TypeErrorKind::UnexpectedType`] error attached to `on_expr`,
    /// recording the accepted `expected` types and the `actual` type observed.
    pub(crate) fn expected_one_of_types(
        on_expr: Expr,
        expected: impl IntoIterator<Item = Type>,
        actual: Type,
    ) -> Self {
        Self {
            on_expr: Some(on_expr),
            source_location: None,
            kind: TypeErrorKind::UnexpectedType(UnexpectedType {
                expected: expected.into_iter().collect::<BTreeSet<_>>(),
                actual,
            }),
        }
    }

    /// Builds a [`TypeErrorKind::IncompatibleTypes`] error attached to `on_expr`
    /// for the given `types`, for which no upper bound could be found.
    pub(crate) fn incompatible_types(on_expr: Expr, types: impl IntoIterator<Item = Type>) -> Self {
        Self {
            on_expr: Some(on_expr),
            source_location: None,
            kind: TypeErrorKind::IncompatibleTypes(IncompatibleTypes {
                types: types.into_iter().collect::<BTreeSet<_>>(),
            }),
        }
    }

    /// Builds a [`TypeErrorKind::TypesMustMatch`] error for operands of `on_expr`
    /// whose `types` are not equal.
    ///
    /// `on_expr` may carry any annotation type `T`. Because the stored
    /// expression field is un-annotated, only its source info is retained.
    pub(crate) fn types_must_match<T>(
        on_expr: Expr<T>,
        types: impl IntoIterator<Item = Type>,
    ) -> Self {
        Self {
            on_expr: None,
            source_location: on_expr.into_source_info(),
            kind: TypeErrorKind::TypesMustMatch(TypesMustMatch {
                types: types.into_iter().collect::<BTreeSet<_>>(),
            }),
        }
    }

    /// Builds a [`TypeErrorKind::MissingAttribute`] error attached to `on_expr`
    /// for the attribute named `missing`, with an optional `suggestion`.
    pub(crate) fn missing_attribute(
        on_expr: Expr,
        missing: String,
        suggestion: Option<String>,
    ) -> Self {
        Self {
            on_expr: Some(on_expr),
            source_location: None,
            kind: TypeErrorKind::MissingAttribute(MissingAttribute {
                missing,
                suggestion,
            }),
        }
    }

    /// Builds a [`TypeErrorKind::UnsafeOptionalAttributeAccess`] error attached
    /// to `on_expr` for the optional attribute named `optional`.
    pub(crate) fn unsafe_optional_attribute_access(on_expr: Expr, optional: String) -> Self {
        Self {
            on_expr: Some(on_expr),
            source_location: None,
            kind: TypeErrorKind::UnsafeOptionalAttributeAccess(UnsafeOptionalAttributeAccess {
                optional,
            }),
        }
    }

    /// Builds a [`TypeErrorKind::ImpossiblePolicy`] error attached to `on_expr`.
    pub(crate) fn impossible_policy(on_expr: Expr) -> Self {
        Self {
            on_expr: Some(on_expr),
            source_location: None,
            kind: TypeErrorKind::ImpossiblePolicy,
        }
    }

    /// Builds a [`TypeErrorKind::UndefinedFunction`] error attached to `on_expr`
    /// for the extension function `name`.
    pub(crate) fn undefined_extension(on_expr: Expr, name: String) -> Self {
        Self {
            on_expr: Some(on_expr),
            source_location: None,
            kind: TypeErrorKind::UndefinedFunction(UndefinedFunction { name }),
        }
    }

    /// Builds a [`TypeErrorKind::MultiplyDefinedFunction`] error attached to
    /// `on_expr` for the extension function `name`.
    pub(crate) fn multiply_defined_extension(on_expr: Expr, name: String) -> Self {
        Self {
            on_expr: Some(on_expr),
            source_location: None,
            kind: TypeErrorKind::MultiplyDefinedFunction(MultiplyDefinedFunction { name }),
        }
    }

    /// Builds a [`TypeErrorKind::WrongNumberArguments`] error attached to
    /// `on_expr`, recording the `expected` and `actual` argument counts.
    pub(crate) fn wrong_number_args(on_expr: Expr, expected: usize, actual: usize) -> Self {
        Self {
            on_expr: Some(on_expr),
            source_location: None,
            kind: TypeErrorKind::WrongNumberArguments(WrongNumberArguments { expected, actual }),
        }
    }

    /// Builds a [`TypeErrorKind::FunctionArgumentValidationError`] error
    /// attached to `on_expr` with the validation message `msg`.
    pub(crate) fn arg_validation_error(on_expr: Expr, msg: String) -> Self {
        Self {
            on_expr: Some(on_expr),
            source_location: None,
            kind: TypeErrorKind::FunctionArgumentValidationError(FunctionArgumentValidationError {
                msg,
            }),
        }
    }

    /// Builds a [`TypeErrorKind::EmptySetForbidden`] error located at `on_expr`.
    ///
    /// Only the expression's source info is retained (see `types_must_match`).
    pub(crate) fn empty_set_forbidden<T>(on_expr: Expr<T>) -> Self {
        Self {
            on_expr: None,
            source_location: on_expr.into_source_info(),
            kind: TypeErrorKind::EmptySetForbidden,
        }
    }

    /// Builds a [`TypeErrorKind::NonLitExtConstructor`] error located at `on_expr`.
    ///
    /// Only the expression's source info is retained (see `types_must_match`).
    pub(crate) fn non_lit_ext_constructor<T>(on_expr: Expr<T>) -> Self {
        Self {
            on_expr: None,
            source_location: on_expr.into_source_info(),
            kind: TypeErrorKind::NonLitExtConstructor,
        }
    }
}
/// Formats the error as the message of its [`TypeErrorKind`]. Neither the
/// source location nor the offending expression appears in the text.
impl Display for TypeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Display::fmt(&self.kind, f)
    }
}

// No wrapped cause is stored, so the default `source()` (returning `None`) is kept.
impl std::error::Error for TypeError {}
#[derive(Debug, Error, Hash, Eq, PartialEq)]
#[non_exhaustive]
pub enum TypeErrorKind {
#[error("Unexpected type. Expected one of [{}] but saw {}", .0.expected.iter().join(","), .0.actual)]
UnexpectedType(UnexpectedType),
#[error("Unable to find upper bound for types [{}]", .0.types.iter().join(","))]
IncompatibleTypes(IncompatibleTypes),
#[error("Attribute not found in record or entity {}", .0.missing)]
MissingAttribute(MissingAttribute),
#[error("Unable to guarantee safety of access to optional attribute {}", .0.optional)]
UnsafeOptionalAttributeAccess(UnsafeOptionalAttributeAccess),
#[error(
"Policy is impossible. The policy expression evaluates to false for all valid requests"
)]
ImpossiblePolicy,
#[error("Undefined extension function {}", .0.name)]
UndefinedFunction(UndefinedFunction),
#[error("Undefined extension function {}", .0.name)]
MultiplyDefinedFunction(MultiplyDefinedFunction),
#[error("Wrong number of arguments in extension function application. Expected {}, got {}", .0.expected, .0.actual)]
WrongNumberArguments(WrongNumberArguments),
#[error("Wrong call style in extension function application. Expected {}, got {}", .0.expected, .0.actual)]
WrongCallStyle(WrongCallStyle),
#[error("Error during extension function argument validation: {}", .0.msg)]
FunctionArgumentValidationError(FunctionArgumentValidationError),
#[error("Types of operands in this expression are not equal: [{}]", .0.types.iter().join(","))]
TypesMustMatch(TypesMustMatch),
#[error("empty set literals are forbidden in policies")]
EmptySetForbidden,
#[error("extension constructors may not be called with non-literal expressions")]
NonLitExtConstructor,
}
/// Data for [`TypeErrorKind::UnexpectedType`]: which types would have been
/// accepted, and which type was found instead.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct UnexpectedType {
    /// Accepted types; a `BTreeSet`, so the message lists them in sorted order.
    expected: BTreeSet<Type>,
    /// Type actually observed.
    actual: Type,
}
/// Data for [`TypeErrorKind::IncompatibleTypes`]: the types for which no
/// upper bound could be found.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct IncompatibleTypes {
    /// The types involved, deduplicated and ordered by the `BTreeSet`.
    types: BTreeSet<Type>,
}
/// Data for [`TypeErrorKind::TypesMustMatch`]: the operand types that were
/// required to be equal but were not.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TypesMustMatch {
    /// The differing operand types, deduplicated and ordered by the `BTreeSet`.
    pub(crate) types: BTreeSet<Type>,
}
/// Data for [`TypeErrorKind::MissingAttribute`].
#[derive(Debug, Hash, Eq, PartialEq)]
pub struct MissingAttribute {
    /// Name of the attribute that could not be found.
    missing: String,
    /// Optional suggested replacement for `missing`. It is not included in the
    /// `Display` message, so code outside this module reads it via
    /// [`MissingAttribute::suggestion`].
    suggestion: Option<String>,
}

impl MissingAttribute {
    /// Name of the attribute that was not found.
    pub fn missing(&self) -> &str {
        &self.missing
    }

    /// Suggested alternative attribute name, if one was recorded.
    pub fn suggestion(&self) -> Option<&str> {
        self.suggestion.as_deref()
    }
}
/// Data for [`TypeErrorKind::UnsafeOptionalAttributeAccess`].
#[derive(Debug, Hash, Eq, PartialEq)]
pub struct UnsafeOptionalAttributeAccess {
    /// Name of the optional attribute whose access could not be proven safe.
    optional: String,
}

impl UnsafeOptionalAttributeAccess {
    /// Name of the optional attribute that was accessed.
    pub fn optional(&self) -> &str {
        &self.optional
    }
}
/// Data for [`TypeErrorKind::UndefinedFunction`].
#[derive(Debug, Hash, Eq, PartialEq)]
pub struct UndefinedFunction {
    /// Name of the extension function that has no definition.
    name: String,
}

impl UndefinedFunction {
    /// Name of the undefined extension function.
    pub fn name(&self) -> &str {
        &self.name
    }
}
/// Data for [`TypeErrorKind::MultiplyDefinedFunction`].
#[derive(Debug, Hash, Eq, PartialEq)]
pub struct MultiplyDefinedFunction {
    /// Name of the extension function that has more than one definition.
    name: String,
}

impl MultiplyDefinedFunction {
    /// Name of the multiply-defined extension function.
    pub fn name(&self) -> &str {
        &self.name
    }
}
/// Data for [`TypeErrorKind::WrongNumberArguments`].
#[derive(Debug, Hash, Eq, PartialEq)]
pub struct WrongNumberArguments {
    /// Number of arguments the extension function expects.
    expected: usize,
    /// Number of arguments actually supplied.
    actual: usize,
}

impl WrongNumberArguments {
    /// Number of arguments the extension function expects.
    pub fn expected(&self) -> usize {
        self.expected
    }

    /// Number of arguments that were supplied.
    pub fn actual(&self) -> usize {
        self.actual
    }
}
/// Data for [`TypeErrorKind::WrongCallStyle`].
#[derive(Debug, Hash, Eq, PartialEq)]
pub struct WrongCallStyle {
    /// Call style the extension function expects.
    expected: CallStyle,
    /// Call style actually used.
    actual: CallStyle,
}

impl WrongCallStyle {
    /// Call style the extension function expects.
    pub fn expected(&self) -> &CallStyle {
        &self.expected
    }

    /// Call style that was actually used.
    pub fn actual(&self) -> &CallStyle {
        &self.actual
    }
}
/// Data for [`TypeErrorKind::FunctionArgumentValidationError`].
#[derive(Debug, Hash, Eq, PartialEq)]
pub struct FunctionArgumentValidationError {
    /// Message produced by the extension function's argument validation.
    msg: String,
}

impl FunctionArgumentValidationError {
    /// Message produced by the failed argument validation.
    pub fn msg(&self) -> &str {
        &self.msg
    }
}