use super::{
EntityUID, Expr, ExprKind, ExpressionConstructionError, Literal, Name, PartialValue, Type,
Unknown, Value, ValueKind,
};
use crate::entities::json::err::JsonSerializationError;
use crate::extensions::Extensions;
use crate::parser::err::ParseErrors;
use crate::parser::{self, Loc};
use miette::Diagnostic;
use smol_str::{SmolStr, ToSmolStr};
use std::hash::{Hash, Hasher};
use std::ops::Deref;
use std::sync::Arc;
use thiserror::Error;
/// An owned [`Expr`] that is known (or, for unchecked construction in release
/// builds, assumed) to satisfy the `is_restricted` check.
///
/// Per `is_restricted`, a restricted expression may only consist of literals,
/// unknowns, sets, records, and extension function calls whose
/// sub-expressions are themselves restricted.
#[derive(Hash, Debug, Clone, PartialEq, Eq)]
pub struct RestrictedExpr(Expr);
impl RestrictedExpr {
pub fn new(expr: Expr) -> Result<Self, RestrictedExpressionError> {
is_restricted(&expr)?;
Ok(Self(expr))
}
pub fn new_unchecked(expr: Expr) -> Self {
if cfg!(debug_assertions) {
#[expect(
clippy::unwrap_used,
reason = "We're in debug mode and panicking intentionally"
)]
Self::new(expr).unwrap()
} else {
Self(expr)
}
}
pub fn with_maybe_source_loc(self, source_loc: Option<Loc>) -> Self {
Self(self.0.with_maybe_source_loc(source_loc))
}
pub fn val(v: impl Into<Literal>) -> Self {
Self::new_unchecked(Expr::val(v))
}
pub fn unknown(u: Unknown) -> Self {
Self::new_unchecked(Expr::unknown(u))
}
pub fn set(exprs: impl IntoIterator<Item = RestrictedExpr>) -> Self {
Self::new_unchecked(Expr::set(exprs.into_iter().map(Into::into)))
}
pub fn record(
pairs: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
) -> Result<Self, ExpressionConstructionError> {
Ok(Self::new_unchecked(Expr::record(
pairs.into_iter().map(|(k, v)| (k, v.into())),
)?))
}
pub fn call_extension_fn(
function_name: Name,
args: impl IntoIterator<Item = RestrictedExpr>,
) -> Self {
Self::new_unchecked(Expr::call_extension_fn(
function_name,
args.into_iter().map(Into::into).collect(),
))
}
pub fn to_natural_json(&self) -> Result<serde_json::Value, JsonSerializationError> {
self.as_borrowed().to_natural_json()
}
pub fn as_bool(&self) -> Option<bool> {
match self.expr_kind() {
ExprKind::Lit(Literal::Bool(b)) => Some(*b),
_ => None,
}
}
pub fn as_long(&self) -> Option<i64> {
match self.expr_kind() {
ExprKind::Lit(Literal::Long(i)) => Some(*i),
_ => None,
}
}
pub fn as_string(&self) -> Option<&SmolStr> {
match self.expr_kind() {
ExprKind::Lit(Literal::String(s)) => Some(s),
_ => None,
}
}
pub fn as_euid(&self) -> Option<&EntityUID> {
match self.expr_kind() {
ExprKind::Lit(Literal::EntityUID(e)) => Some(e),
_ => None,
}
}
pub fn as_unknown(&self) -> Option<&Unknown> {
match self.expr_kind() {
ExprKind::Unknown(u) => Some(u),
_ => None,
}
}
pub fn as_set_elements(&self) -> Option<impl Iterator<Item = BorrowedRestrictedExpr<'_>>> {
match self.expr_kind() {
ExprKind::Set(set) => Some(set.iter().map(BorrowedRestrictedExpr::new_unchecked)), _ => None,
}
}
pub fn as_record_pairs(
&self,
) -> Option<impl Iterator<Item = (&SmolStr, BorrowedRestrictedExpr<'_>)>> {
match self.expr_kind() {
ExprKind::Record(map) => Some(
map.iter()
.map(|(k, v)| (k, BorrowedRestrictedExpr::new_unchecked(v))),
), _ => None,
}
}
pub fn as_extn_fn_call(
&self,
) -> Option<(&Name, impl Iterator<Item = BorrowedRestrictedExpr<'_>>)> {
match self.expr_kind() {
ExprKind::ExtensionFunctionApp { fn_name, args } => Some((
fn_name,
args.iter().map(BorrowedRestrictedExpr::new_unchecked),
)), _ => None,
}
}
}
impl From<Value> for RestrictedExpr {
fn from(value: Value) -> RestrictedExpr {
RestrictedExpr::from(value.value).with_maybe_source_loc(value.loc)
}
}
impl From<ValueKind> for RestrictedExpr {
    /// Convert each kind of value into the matching restricted expression:
    /// literals via `val`, sets via `set`, records via `record`, and
    /// extension values via their own `Into<RestrictedExpr>` conversion.
    fn from(value: ValueKind) -> RestrictedExpr {
        match value {
            ValueKind::Lit(lit) => Self::val(lit),
            ValueKind::Set(set) => Self::set(set.iter().cloned().map(Self::from)),
            ValueKind::Record(record) => {
                let fields = Arc::unwrap_or_clone(record)
                    .into_iter()
                    .map(|(key, val)| (key, Self::from(val)));
                #[expect(
                    clippy::expect_used,
                    reason = "cannot have duplicate key because the input was already a BTreeMap"
                )]
                Self::record(fields)
                    .expect("can't have duplicate keys, because the input `map` was already a BTreeMap")
            }
            ValueKind::ExtensionValue(ev) => Arc::unwrap_or_clone(ev).into(),
        }
    }
}
impl TryFrom<PartialValue> for RestrictedExpr {
    type Error = PartialValueToRestrictedExprError;

    /// Convert a `PartialValue` into a restricted expression.
    ///
    /// A concrete value always converts. A residual converts only if it
    /// passes [`RestrictedExpr::new`].
    ///
    /// # Errors
    ///
    /// Returns `NontrivialResidual`, carrying the expression reported by the
    /// validation error, when the residual is not a valid restricted
    /// expression.
    fn try_from(pvalue: PartialValue) -> Result<RestrictedExpr, PartialValueToRestrictedExprError> {
        let residual = match pvalue {
            PartialValue::Value(v) => return Ok(RestrictedExpr::from(v)),
            PartialValue::Residual(expr) => expr,
        };
        RestrictedExpr::new(residual).map_err(|err| {
            // `RestrictedExpressionError` has a single variant, so this
            // pattern is irrefutable.
            let RestrictedExpressionError::InvalidRestrictedExpression(
                restricted_expr_errors::InvalidRestrictedExpressionError { expr, .. },
            ) = err;
            PartialValueToRestrictedExprError::NontrivialResidual {
                residual: Box::new(expr),
            }
        })
    }
}
/// Error returned when a `PartialValue` cannot be converted into a
/// [`RestrictedExpr`].
#[derive(Debug, PartialEq, Eq, Diagnostic, Error)]
pub enum PartialValueToRestrictedExprError {
/// The `PartialValue` was a residual expression that failed
/// restricted-expression validation.
#[error("residual is not a valid restricted expression: `{residual}`")]
NontrivialResidual {
/// The expression reported by the validation failure (boxed).
residual: Box<Expr>,
},
}
impl std::str::FromStr for RestrictedExpr {
    type Err = RestrictedExpressionParseError;

    /// Parse `s` as a restricted expression by delegating to
    /// `parser::parse_restrictedexpr`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        parser::parse_restrictedexpr(s)
    }
}
/// Borrowed counterpart of [`RestrictedExpr`]: a `Copy` wrapper around a
/// `&Expr` that is known (or, for unchecked construction in release builds,
/// assumed) to satisfy the `is_restricted` check.
#[derive(Hash, Debug, Clone, PartialEq, Eq, Copy)]
pub struct BorrowedRestrictedExpr<'a>(&'a Expr);
impl<'a> BorrowedRestrictedExpr<'a> {
pub fn new(expr: &'a Expr) -> Result<Self, RestrictedExpressionError> {
is_restricted(expr)?;
Ok(Self(expr))
}
pub fn new_unchecked(expr: &'a Expr) -> Self {
if cfg!(debug_assertions) {
#[expect(
clippy::unwrap_used,
reason = "We're in debug mode and panicking intentionally"
)]
Self::new(expr).unwrap()
} else {
Self(expr)
}
}
pub fn to_natural_json(self) -> Result<serde_json::Value, JsonSerializationError> {
Ok(serde_json::to_value(
crate::entities::json::CedarValueJson::from_expr(self)?,
)?)
}
pub fn to_owned(self) -> RestrictedExpr {
RestrictedExpr::new_unchecked(self.0.clone())
}
pub fn as_bool(&self) -> Option<bool> {
match self.expr_kind() {
ExprKind::Lit(Literal::Bool(b)) => Some(*b),
_ => None,
}
}
pub fn as_long(&self) -> Option<i64> {
match self.expr_kind() {
ExprKind::Lit(Literal::Long(i)) => Some(*i),
_ => None,
}
}
pub fn as_string(&self) -> Option<&SmolStr> {
match self.expr_kind() {
ExprKind::Lit(Literal::String(s)) => Some(s),
_ => None,
}
}
pub fn as_euid(&self) -> Option<&EntityUID> {
match self.expr_kind() {
ExprKind::Lit(Literal::EntityUID(e)) => Some(e),
_ => None,
}
}
pub fn as_unknown(&self) -> Option<&Unknown> {
match self.expr_kind() {
ExprKind::Unknown(u) => Some(u),
_ => None,
}
}
pub fn as_set_elements(&self) -> Option<impl Iterator<Item = BorrowedRestrictedExpr<'_>>> {
match self.expr_kind() {
ExprKind::Set(set) => Some(set.iter().map(BorrowedRestrictedExpr::new_unchecked)), _ => None,
}
}
pub fn as_record_pairs(
&self,
) -> Option<impl Iterator<Item = (&'_ SmolStr, BorrowedRestrictedExpr<'_>)>> {
match self.expr_kind() {
ExprKind::Record(map) => Some(
map.iter()
.map(|(k, v)| (k, BorrowedRestrictedExpr::new_unchecked(v))),
), _ => None,
}
}
pub fn as_extn_fn_call(
&self,
) -> Option<(&Name, impl Iterator<Item = BorrowedRestrictedExpr<'_>>)> {
match self.expr_kind() {
ExprKind::ExtensionFunctionApp { fn_name, args } => Some((
fn_name,
args.iter().map(BorrowedRestrictedExpr::new_unchecked),
)), _ => None,
}
}
pub fn try_type_of(&self, extensions: &Extensions<'_>) -> Option<Type> {
self.0.try_type_of(extensions)
}
}
fn is_restricted(expr: &Expr) -> Result<(), RestrictedExpressionError> {
match expr.expr_kind() {
ExprKind::Lit(_) => Ok(()),
ExprKind::Unknown(_) => Ok(()),
ExprKind::Var(_) => Err(restricted_expr_errors::InvalidRestrictedExpressionError {
feature: "variables".into(),
expr: expr.clone(),
}
.into()),
ExprKind::Slot(_) => Err(restricted_expr_errors::InvalidRestrictedExpressionError {
feature: "template slots".into(),
expr: expr.clone(),
}
.into()),
ExprKind::If { .. } => Err(restricted_expr_errors::InvalidRestrictedExpressionError {
feature: "if-then-else".into(),
expr: expr.clone(),
}
.into()),
ExprKind::And { .. } => Err(restricted_expr_errors::InvalidRestrictedExpressionError {
feature: "&&".into(),
expr: expr.clone(),
}
.into()),
ExprKind::Or { .. } => Err(restricted_expr_errors::InvalidRestrictedExpressionError {
feature: "||".into(),
expr: expr.clone(),
}
.into()),
ExprKind::UnaryApp { op, .. } => {
Err(restricted_expr_errors::InvalidRestrictedExpressionError {
feature: op.to_smolstr(),
expr: expr.clone(),
}
.into())
}
ExprKind::BinaryApp { op, .. } => {
Err(restricted_expr_errors::InvalidRestrictedExpressionError {
feature: op.to_smolstr(),
expr: expr.clone(),
}
.into())
}
ExprKind::GetAttr { .. } => Err(restricted_expr_errors::InvalidRestrictedExpressionError {
feature: "attribute accesses".into(),
expr: expr.clone(),
}
.into()),
ExprKind::HasAttr { .. } => Err(restricted_expr_errors::InvalidRestrictedExpressionError {
feature: "'has'".into(),
expr: expr.clone(),
}
.into()),
ExprKind::Like { .. } => Err(restricted_expr_errors::InvalidRestrictedExpressionError {
feature: "'like'".into(),
expr: expr.clone(),
}
.into()),
ExprKind::Is { .. } => Err(restricted_expr_errors::InvalidRestrictedExpressionError {
feature: "'is'".into(),
expr: expr.clone(),
}
.into()),
ExprKind::ExtensionFunctionApp { args, .. } => args.iter().try_for_each(is_restricted),
ExprKind::Set(exprs) => exprs.iter().try_for_each(is_restricted),
ExprKind::Record(map) => map.values().try_for_each(is_restricted),
#[cfg(feature = "tolerant-ast")]
ExprKind::Error { .. } => Ok(()),
}
}
impl From<RestrictedExpr> for Expr {
fn from(r: RestrictedExpr) -> Expr {
r.0
}
}
impl AsRef<Expr> for RestrictedExpr {
fn as_ref(&self) -> &Expr {
&self.0
}
}
impl Deref for RestrictedExpr {
    type Target = Expr;

    /// Dereference to the wrapped `Expr`, so `Expr`'s methods can be called
    /// directly on a `RestrictedExpr`.
    fn deref(&self) -> &Expr {
        &self.0
    }
}
impl std::fmt::Display for RestrictedExpr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", &self.0)
}
}
impl<'a> From<BorrowedRestrictedExpr<'a>> for &'a Expr {
fn from(r: BorrowedRestrictedExpr<'a>) -> &'a Expr {
r.0
}
}
impl<'a> AsRef<Expr> for BorrowedRestrictedExpr<'a> {
fn as_ref(&self) -> &'a Expr {
self.0
}
}
impl RestrictedExpr {
    /// View this owned expression as a [`BorrowedRestrictedExpr`] without
    /// cloning the underlying `Expr`.
    pub fn as_borrowed(&self) -> BorrowedRestrictedExpr<'_> {
        BorrowedRestrictedExpr::new_unchecked(&self.0)
    }
}
impl<'a> Deref for BorrowedRestrictedExpr<'a> {
type Target = Expr;
fn deref(&self) -> &'a Expr {
self.0
}
}
impl std::fmt::Display for BorrowedRestrictedExpr<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", &self.0)
}
}
/// Wrapper around a [`BorrowedRestrictedExpr`] whose `PartialEq` and `Hash`
/// implementations delegate to the expression's `eq_shape` and `hash_shape`
/// methods instead of the derived, field-by-field comparison.
#[derive(Eq, Debug, Clone)]
pub struct RestrictedExprShapeOnly<'a>(BorrowedRestrictedExpr<'a>);
impl<'a> RestrictedExprShapeOnly<'a> {
pub fn new(e: BorrowedRestrictedExpr<'a>) -> RestrictedExprShapeOnly<'a> {
RestrictedExprShapeOnly(e)
}
}
impl PartialEq for RestrictedExprShapeOnly<'_> {
fn eq(&self, other: &Self) -> bool {
self.0.eq_shape(&other.0)
}
}
impl Hash for RestrictedExprShapeOnly<'_> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.0.hash_shape(state);
}
}
/// Error returned when validating an expression as a restricted expression
/// fails.
#[derive(Debug, Clone, PartialEq, Eq, Error, Diagnostic)]
pub enum RestrictedExpressionError {
/// The expression used a construct that is not allowed in a restricted
/// expression. The message and diagnostic are forwarded from the inner
/// error.
#[error(transparent)]
#[diagnostic(transparent)]
InvalidRestrictedExpression(#[from] restricted_expr_errors::InvalidRestrictedExpressionError),
}
/// Error structs produced when an expression fails restricted-expression
/// validation.
pub mod restricted_expr_errors {
use super::Expr;
use crate::impl_diagnostic_from_method_on_field;
use miette::Diagnostic;
use smol_str::SmolStr;
use thiserror::Error;
/// An expression used a construct that is not permitted in a restricted
/// expression.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[error("not allowed to use {feature} in a restricted expression: `{expr}`")]
pub struct InvalidRestrictedExpressionError {
/// Name of the disallowed construct as it appears in the error message
/// (e.g. `variables`, `'has'`, or an operator's display form).
pub(crate) feature: SmolStr,
/// The (sub-)expression in which the disallowed construct was found.
pub(crate) expr: Expr,
}
// NOTE(review): presumably the macro implements the `Diagnostic` methods
// using the `source_loc` method of the `expr` field; confirm against the
// macro definition.
impl Diagnostic for InvalidRestrictedExpressionError {
impl_diagnostic_from_method_on_field!(expr, source_loc);
}
}
/// Error returned when parsing text as a restricted expression fails (this
/// is the `Err` type of `RestrictedExpr`'s `FromStr` implementation).
#[derive(Debug, Clone, PartialEq, Eq, Diagnostic, Error)]
pub enum RestrictedExpressionParseError {
/// The parser reported one or more errors. The message and diagnostic are
/// forwarded from the inner `ParseErrors`.
#[error(transparent)]
#[diagnostic(transparent)]
Parse(#[from] ParseErrors),
/// The expression was rejected by restricted-expression validation. The
/// message and diagnostic are forwarded from the inner error.
#[error(transparent)]
#[diagnostic(transparent)]
InvalidRestrictedExpression(#[from] RestrictedExpressionError),
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::ast::expression_construction_errors;
    use crate::parser::err::{ParseError, ToASTError, ToASTErrorKind};
    use crate::parser::Loc;
    use std::str::FromStr;
    use std::sync::Arc;

    /// Records with a repeated key must be rejected, both when built through
    /// `RestrictedExpr::record` and when parsed from text.
    #[test]
    fn duplicate_key() {
        // Every programmatic case below repeats the key "foo", so they all
        // fail with this same error.
        let duplicate_foo = || -> ExpressionConstructionError {
            expression_construction_errors::DuplicateKeyError {
                key: "foo".into(),
                context: "in record literal",
            }
            .into()
        };

        let cases: Vec<Vec<(SmolStr, RestrictedExpr)>> = vec![
            // duplicate key, values of different types
            vec![
                ("foo".into(), RestrictedExpr::val(37)),
                ("foo".into(), RestrictedExpr::val("hello")),
            ],
            // duplicate key, different values of the same type
            vec![
                ("foo".into(), RestrictedExpr::val(37)),
                ("foo".into(), RestrictedExpr::val(101)),
            ],
            // duplicate key, identical values
            vec![
                ("foo".into(), RestrictedExpr::val(37)),
                ("foo".into(), RestrictedExpr::val(37)),
            ],
            // duplicate key surrounded by other, distinct keys
            vec![
                ("bar".into(), RestrictedExpr::val(-3)),
                ("foo".into(), RestrictedExpr::val(37)),
                ("spam".into(), RestrictedExpr::val("eggs")),
                ("foo".into(), RestrictedExpr::val(37)),
                ("eggs".into(), RestrictedExpr::val("spam")),
            ],
        ];
        for pairs in cases {
            assert_eq!(RestrictedExpr::record(pairs), Err(duplicate_foo()));
        }

        // The same mistake in source text surfaces as a parse error whose
        // location spans the whole record literal.
        let src = r#"{ foo: 37, bar: "hi", foo: 101 }"#;
        let expected_parse_error = ParseError::ToAST(ToASTError::new(
            ToASTErrorKind::ExpressionConstructionError(
                expression_construction_errors::DuplicateKeyError {
                    key: "foo".into(),
                    context: "in record literal",
                }
                .into(),
            ),
            Some(Loc::new(0..32, Arc::from(src))),
        ));
        assert_eq!(
            RestrictedExpr::from_str(src),
            Err(RestrictedExpressionParseError::Parse(
                ParseErrors::singleton(expected_parse_error)
            )),
        )
    }
}