use super::{
EntityUID, Expr, ExprConstructionError, ExprKind, Literal, Name, PartialValue, Unknown, Value,
ValueKind,
};
use crate::entities::JsonSerializationError;
use crate::parser::err::ParseErrors;
use crate::parser::{self, Loc};
use miette::Diagnostic;
use serde::{Deserialize, Serialize};
use smol_str::{SmolStr, ToSmolStr};
use std::hash::{Hash, Hasher};
use std::ops::Deref;
use std::sync::Arc;
use thiserror::Error;
/// An [`Expr`] that is known to be a *restricted* expression.
///
/// A restricted expression is built only from literals, unknowns, sets, records, and
/// extension-function calls, where every nested element/value/argument is itself
/// restricted (this is exactly what `is_restricted` in this file accepts).
///
/// Serializes and deserializes exactly like the wrapped [`Expr`] (`serde(transparent)`).
///
/// NOTE(review): the derived `Deserialize` wraps whatever `Expr` it reads and does not
/// run the restricted-expression check.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct RestrictedExpr(Expr);
impl RestrictedExpr {
pub fn new(expr: Expr) -> Result<Self, RestrictedExprError> {
is_restricted(&expr)?;
Ok(Self(expr))
}
pub fn new_unchecked(expr: Expr) -> Self {
if cfg!(debug_assertions) {
#[allow(clippy::unwrap_used)]
Self::new(expr).unwrap()
} else {
Self(expr)
}
}
pub fn with_maybe_source_loc(self, source_loc: Option<Loc>) -> Self {
Self(self.0.with_maybe_source_loc(source_loc))
}
pub fn val(v: impl Into<Literal>) -> Self {
Self::new_unchecked(Expr::val(v))
}
pub fn unknown(u: Unknown) -> Self {
Self::new_unchecked(Expr::unknown(u))
}
pub fn set(exprs: impl IntoIterator<Item = RestrictedExpr>) -> Self {
Self::new_unchecked(Expr::set(exprs.into_iter().map(Into::into)))
}
pub fn record(
pairs: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
) -> Result<Self, ExprConstructionError> {
Ok(Self::new_unchecked(Expr::record(
pairs.into_iter().map(|(k, v)| (k, v.into())),
)?))
}
pub fn call_extension_fn(
function_name: Name,
args: impl IntoIterator<Item = RestrictedExpr>,
) -> Self {
Self::new_unchecked(Expr::call_extension_fn(
function_name,
args.into_iter().map(Into::into).collect(),
))
}
pub fn to_natural_json(&self) -> Result<serde_json::Value, JsonSerializationError> {
self.as_borrowed().to_natural_json()
}
pub fn as_bool(&self) -> Option<bool> {
match self.expr_kind() {
ExprKind::Lit(Literal::Bool(b)) => Some(*b),
_ => None,
}
}
pub fn as_long(&self) -> Option<i64> {
match self.expr_kind() {
ExprKind::Lit(Literal::Long(i)) => Some(*i),
_ => None,
}
}
pub fn as_string(&self) -> Option<&SmolStr> {
match self.expr_kind() {
ExprKind::Lit(Literal::String(s)) => Some(s),
_ => None,
}
}
pub fn as_euid(&self) -> Option<&EntityUID> {
match self.expr_kind() {
ExprKind::Lit(Literal::EntityUID(e)) => Some(e),
_ => None,
}
}
pub fn as_unknown(&self) -> Option<&Unknown> {
match self.expr_kind() {
ExprKind::Unknown(u) => Some(u),
_ => None,
}
}
pub fn as_set_elements(&self) -> Option<impl Iterator<Item = BorrowedRestrictedExpr<'_>>> {
match self.expr_kind() {
ExprKind::Set(set) => Some(set.iter().map(BorrowedRestrictedExpr::new_unchecked)), _ => None,
}
}
pub fn as_record_pairs(
&self,
) -> Option<impl Iterator<Item = (&SmolStr, BorrowedRestrictedExpr<'_>)>> {
match self.expr_kind() {
ExprKind::Record(map) => Some(
map.iter()
.map(|(k, v)| (k, BorrowedRestrictedExpr::new_unchecked(v))),
), _ => None,
}
}
pub fn as_extn_fn_call(
&self,
) -> Option<(&Name, impl Iterator<Item = BorrowedRestrictedExpr<'_>>)> {
match self.expr_kind() {
ExprKind::ExtensionFunctionApp { fn_name, args } => Some((
fn_name,
args.iter().map(BorrowedRestrictedExpr::new_unchecked),
)), _ => None,
}
}
}
impl From<Value> for RestrictedExpr {
    /// Converts the value's kind and carries its source location (if any) over.
    fn from(value: Value) -> RestrictedExpr {
        let loc = value.loc;
        let kind = value.value;
        RestrictedExpr::from(kind).with_maybe_source_loc(loc)
    }
}
impl From<ValueKind> for RestrictedExpr {
    /// Structural conversion.
    ///
    /// Literals stay literals. Sets and records are converted element by element.
    /// Extension values become a call to their `constructor` on their `args`.
    fn from(value: ValueKind) -> RestrictedExpr {
        match value {
            ValueKind::Lit(lit) => Self::val(lit),
            ValueKind::Set(set) => {
                let elements = set.iter().cloned().map(Self::from);
                Self::set(elements)
            }
            // The keys come straight out of an existing map, so `record` cannot see a
            // duplicate; the `expect` documents that invariant.
            #[allow(clippy::expect_used)]
            ValueKind::Record(record) => {
                let fields = Arc::unwrap_or_clone(record)
                    .into_iter()
                    .map(|(key, field)| (key, Self::from(field)));
                Self::record(fields).expect(
                    "can't have duplicate keys, because the input `map` was already a BTreeMap",
                )
            }
            ValueKind::ExtensionValue(ext) => {
                let ext = Arc::unwrap_or_clone(ext);
                Self::call_extension_fn(ext.constructor, ext.args)
            }
        }
    }
}
impl TryFrom<PartialValue> for RestrictedExpr {
    type Error = PartialValueToRestrictedExprError;

    /// A concrete value always converts. A residual converts only if it already is a
    /// restricted expression.
    ///
    /// On failure, the reported residual is the offending subexpression identified by
    /// the restriction check.
    fn try_from(pvalue: PartialValue) -> Result<RestrictedExpr, PartialValueToRestrictedExprError> {
        let residual = match pvalue {
            PartialValue::Value(v) => return Ok(RestrictedExpr::from(v)),
            PartialValue::Residual(expr) => expr,
        };
        RestrictedExpr::new(residual).map_err(
            |RestrictedExprError::InvalidRestrictedExpression { expr, .. }| {
                PartialValueToRestrictedExprError::NontrivialResidual {
                    residual: Box::new(expr),
                }
            },
        )
    }
}
#[derive(Debug, PartialEq, Diagnostic, Error)]
pub enum PartialValueToRestrictedExprError {
#[error("residual is not a valid restricted expression: `{residual}`")]
NontrivialResidual {
residual: Box<Expr>,
},
}
impl std::str::FromStr for RestrictedExpr {
    type Err = RestrictedExprParseError;

    /// Delegates to `parser::parse_restrictedexpr`.
    ///
    /// Failures are either parse errors or restricted-expression violations; see
    /// [`RestrictedExprParseError`].
    fn from_str(text: &str) -> Result<Self, Self::Err> {
        parser::parse_restrictedexpr(text)
    }
}
/// Borrowed counterpart of [`RestrictedExpr`].
///
/// It is a reference to an [`Expr`] that satisfies the restricted-expression invariant.
/// It is `Copy`, so it is normally passed by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
pub struct BorrowedRestrictedExpr<'a>(&'a Expr);
impl<'a> BorrowedRestrictedExpr<'a> {
    /// Borrows `expr` as a restricted expression after verifying that it is one.
    ///
    /// # Errors
    ///
    /// Returns [`RestrictedExprError::InvalidRestrictedExpression`] if `expr` or any
    /// nested subexpression uses a construct not allowed in restricted expressions.
    pub fn new(expr: &'a Expr) -> Result<Self, RestrictedExprError> {
        is_restricted(expr)?;
        Ok(Self(expr))
    }

    /// Borrows `expr` as a restricted expression without checking it in release builds.
    ///
    /// # Panics
    ///
    /// When `debug_assertions` are enabled the check still runs, and a non-restricted
    /// `expr` panics.
    pub fn new_unchecked(expr: &'a Expr) -> Self {
        if cfg!(debug_assertions) {
            // A failure here means the caller broke the documented precondition.
            #[allow(clippy::unwrap_used)]
            Self::new(expr).unwrap()
        } else {
            Self(expr)
        }
    }

    /// Converts to "natural" JSON via `CedarValueJson::from_expr`.
    ///
    /// # Errors
    ///
    /// Propagates failures from `CedarValueJson::from_expr` or from `serde_json`
    /// serialization, converted into [`JsonSerializationError`].
    pub fn to_natural_json(self) -> Result<serde_json::Value, JsonSerializationError> {
        Ok(serde_json::to_value(
            crate::entities::CedarValueJson::from_expr(self)?,
        )?)
    }

    /// Clones the underlying expression into an owned [`RestrictedExpr`].
    pub fn to_owned(self) -> RestrictedExpr {
        RestrictedExpr::new_unchecked(self.0.clone())
    }

    /// Returns the boolean if this expression is a boolean literal.
    pub fn as_bool(&self) -> Option<bool> {
        match self.0.expr_kind() {
            ExprKind::Lit(Literal::Bool(b)) => Some(*b),
            _ => None,
        }
    }

    /// Returns the integer if this expression is a `Long` literal.
    pub fn as_long(&self) -> Option<i64> {
        match self.0.expr_kind() {
            ExprKind::Lit(Literal::Long(i)) => Some(*i),
            _ => None,
        }
    }

    /// Returns the string if this expression is a string literal.
    ///
    /// The reference borrows from the underlying expression (`'a`), not from `self`.
    /// So it stays valid after this `Copy` wrapper is dropped. We read through `self.0`
    /// rather than auto-`Deref` so the compiler does not shorten the lifetime to the
    /// `&self` borrow.
    pub fn as_string(&self) -> Option<&'a SmolStr> {
        match self.0.expr_kind() {
            ExprKind::Lit(Literal::String(s)) => Some(s),
            _ => None,
        }
    }

    /// Returns the entity UID if this expression is an entity-UID literal.
    ///
    /// As with [`Self::as_string`], the reference lives for `'a`.
    pub fn as_euid(&self) -> Option<&'a EntityUID> {
        match self.0.expr_kind() {
            ExprKind::Lit(Literal::EntityUID(e)) => Some(e),
            _ => None,
        }
    }

    /// Returns the unknown if this expression is an unknown.
    ///
    /// As with [`Self::as_string`], the reference lives for `'a`.
    pub fn as_unknown(&self) -> Option<&'a Unknown> {
        match self.0.expr_kind() {
            ExprKind::Unknown(u) => Some(u),
            _ => None,
        }
    }

    /// If this expression is a set, iterates over its elements.
    pub fn as_set_elements(&self) -> Option<impl Iterator<Item = BorrowedRestrictedExpr<'_>>> {
        // Signature kept tied to `&self`: widening an iterator's `Item` lifetime is not
        // guaranteed to be source-compatible for every caller.
        match self.expr_kind() {
            ExprKind::Set(set) => Some(set.iter().map(BorrowedRestrictedExpr::new_unchecked)),
            _ => None,
        }
    }

    /// If this expression is a record, iterates over its `(key, value)` pairs.
    pub fn as_record_pairs(
        &self,
    ) -> Option<impl Iterator<Item = (&'_ SmolStr, BorrowedRestrictedExpr<'_>)>> {
        match self.expr_kind() {
            ExprKind::Record(map) => Some(
                map.iter()
                    .map(|(k, v)| (k, BorrowedRestrictedExpr::new_unchecked(v))),
            ),
            _ => None,
        }
    }

    /// If this expression is an extension-function call, returns the function name and
    /// an iterator over its arguments.
    pub fn as_extn_fn_call(
        &self,
    ) -> Option<(&Name, impl Iterator<Item = BorrowedRestrictedExpr<'_>>)> {
        match self.expr_kind() {
            ExprKind::ExtensionFunctionApp { fn_name, args } => Some((
                fn_name,
                args.iter().map(BorrowedRestrictedExpr::new_unchecked),
            )),
            _ => None,
        }
    }
}
fn is_restricted(expr: &Expr) -> Result<(), RestrictedExprError> {
match expr.expr_kind() {
ExprKind::Lit(_) => Ok(()),
ExprKind::Unknown(_) => Ok(()),
ExprKind::Var(_) => Err(RestrictedExprError::InvalidRestrictedExpression {
feature: "variables".into(),
expr: expr.clone(),
}),
ExprKind::Slot(_) => Err(RestrictedExprError::InvalidRestrictedExpression {
feature: "template slots".into(),
expr: expr.clone(),
}),
ExprKind::If { .. } => Err(RestrictedExprError::InvalidRestrictedExpression {
feature: "if-then-else".into(),
expr: expr.clone(),
}),
ExprKind::And { .. } => Err(RestrictedExprError::InvalidRestrictedExpression {
feature: "&&".into(),
expr: expr.clone(),
}),
ExprKind::Or { .. } => Err(RestrictedExprError::InvalidRestrictedExpression {
feature: "||".into(),
expr: expr.clone(),
}),
ExprKind::UnaryApp { op, .. } => Err(RestrictedExprError::InvalidRestrictedExpression {
feature: op.to_smolstr(),
expr: expr.clone(),
}),
ExprKind::BinaryApp { op, .. } => Err(RestrictedExprError::InvalidRestrictedExpression {
feature: op.to_smolstr(),
expr: expr.clone(),
}),
ExprKind::GetAttr { .. } => Err(RestrictedExprError::InvalidRestrictedExpression {
feature: "attribute accesses".into(),
expr: expr.clone(),
}),
ExprKind::HasAttr { .. } => Err(RestrictedExprError::InvalidRestrictedExpression {
feature: "'has'".into(),
expr: expr.clone(),
}),
ExprKind::Like { .. } => Err(RestrictedExprError::InvalidRestrictedExpression {
feature: "'like'".into(),
expr: expr.clone(),
}),
ExprKind::Is { .. } => Err(RestrictedExprError::InvalidRestrictedExpression {
feature: "'is'".into(),
expr: expr.clone(),
}),
ExprKind::ExtensionFunctionApp { args, .. } => args.iter().try_for_each(is_restricted),
ExprKind::Set(exprs) => exprs.iter().try_for_each(is_restricted),
ExprKind::Record(map) => map.values().try_for_each(is_restricted),
}
}
impl From<RestrictedExpr> for Expr {
fn from(r: RestrictedExpr) -> Expr {
r.0
}
}
impl AsRef<Expr> for RestrictedExpr {
fn as_ref(&self) -> &Expr {
&self.0
}
}
impl Deref for RestrictedExpr {
type Target = Expr;
fn deref(&self) -> &Expr {
self.as_ref()
}
}
impl std::fmt::Display for RestrictedExpr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", &self.0)
}
}
impl<'a> From<BorrowedRestrictedExpr<'a>> for &'a Expr {
fn from(r: BorrowedRestrictedExpr<'a>) -> &'a Expr {
r.0
}
}
impl<'a> AsRef<Expr> for BorrowedRestrictedExpr<'a> {
fn as_ref(&self) -> &'a Expr {
self.0
}
}
impl RestrictedExpr {
pub fn as_borrowed(&self) -> BorrowedRestrictedExpr<'_> {
BorrowedRestrictedExpr::new_unchecked(self.as_ref())
}
}
impl<'a> Deref for BorrowedRestrictedExpr<'a> {
type Target = Expr;
fn deref(&self) -> &'a Expr {
self.0
}
}
impl<'a> std::fmt::Display for BorrowedRestrictedExpr<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", &self.0)
}
}
/// Wrapper around a [`BorrowedRestrictedExpr`] whose equality and hashing use only the
/// expression's shape, via `eq_shape` and `hash_shape`.
///
/// NOTE(review): presumably "shape" ignores source locations; confirm against
/// `Expr::eq_shape`.
#[derive(Debug, Clone, Eq)]
pub struct RestrictedExprShapeOnly<'a>(BorrowedRestrictedExpr<'a>);

impl<'a> RestrictedExprShapeOnly<'a> {
    /// Wraps `e` for shape-only comparison and hashing.
    pub fn new(e: BorrowedRestrictedExpr<'a>) -> Self {
        Self(e)
    }
}

impl<'a> PartialEq for RestrictedExprShapeOnly<'a> {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.0, &other.0);
        lhs.eq_shape(rhs)
    }
}

// Kept consistent with `PartialEq` above: both look only at the shape.
impl<'a> Hash for RestrictedExprShapeOnly<'a> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let inner = &self.0;
        inner.hash_shape(state);
    }
}
/// Error raised when an [`Expr`] uses a construct that restricted expressions forbid.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum RestrictedExprError {
    /// `expr` uses `feature`, which is not allowed in a restricted expression.
    #[error("not allowed to use {feature} in a restricted expression: `{expr}`")]
    InvalidRestrictedExpression {
        /// Human-readable name of the forbidden construct (e.g. `variables`, or an
        /// operator rendered as a string).
        feature: SmolStr,
        /// The (sub)expression that uses the forbidden construct.
        expr: Expr,
    },
}

/// Reports the offending expression's source location, when it has one.
///
/// The location is shown as an unlabeled underline over the expression's own source.
impl Diagnostic for RestrictedExprError {
    fn labels(&self) -> Option<Box<dyn Iterator<Item = miette::LabeledSpan> + '_>> {
        let Self::InvalidRestrictedExpression { expr, .. } = self;
        let loc = expr.source_loc()?;
        let underline = miette::LabeledSpan::underline(loc.span);
        Some(Box::new(std::iter::once(underline)) as Box<dyn Iterator<Item = _>>)
    }

    fn source_code(&self) -> Option<&dyn miette::SourceCode> {
        let Self::InvalidRestrictedExpression { expr, .. } = self;
        let loc = expr.source_loc()?;
        Some(&loc.src as &dyn miette::SourceCode)
    }
}
#[derive(Debug, Clone, PartialEq, Eq, Diagnostic, Error)]
pub enum RestrictedExprParseError {
#[error("failed to parse restricted expression: {0}")]
#[diagnostic(transparent)]
Parse(#[from] ParseErrors),
#[error(transparent)]
#[diagnostic(transparent)]
RestrictedExpr(#[from] RestrictedExprError),
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::ast::expression_construction_errors;
    use crate::parser::err::{ParseError, ToASTError, ToASTErrorKind};
    use crate::parser::Loc;
    use std::str::FromStr;
    use std::sync::Arc;

    /// The error every case below must produce: key `foo` repeated in a record literal.
    fn duplicate_foo() -> expression_construction_errors::DuplicateKeyError {
        expression_construction_errors::DuplicateKeyError {
            key: "foo".into(),
            context: "in record literal",
        }
    }

    #[test]
    fn duplicate_key() {
        // Each case must be rejected:
        // - values of different types;
        // - different values of the same type;
        // - identical values;
        // - a duplicate buried among other keys.
        let cases: Vec<Vec<(SmolStr, RestrictedExpr)>> = vec![
            vec![
                ("foo".into(), RestrictedExpr::val(37)),
                ("foo".into(), RestrictedExpr::val("hello")),
            ],
            vec![
                ("foo".into(), RestrictedExpr::val(37)),
                ("foo".into(), RestrictedExpr::val(101)),
            ],
            vec![
                ("foo".into(), RestrictedExpr::val(37)),
                ("foo".into(), RestrictedExpr::val(37)),
            ],
            vec![
                ("bar".into(), RestrictedExpr::val(-3)),
                ("foo".into(), RestrictedExpr::val(37)),
                ("spam".into(), RestrictedExpr::val("eggs")),
                ("foo".into(), RestrictedExpr::val(37)),
                ("eggs".into(), RestrictedExpr::val("spam")),
            ],
        ];
        for pairs in cases {
            assert_eq!(RestrictedExpr::record(pairs), Err(duplicate_foo().into()));
        }

        // Parsing surfaces the same error, located at the whole record literal.
        let src = r#"{ foo: 37, bar: "hi", foo: 101 }"#;
        let expected = RestrictedExprParseError::Parse(ParseErrors(vec![ParseError::ToAST(
            ToASTError::new(
                ToASTErrorKind::ExprConstructionError(duplicate_foo().into()),
                Loc::new(0..32, Arc::from(src)),
            ),
        )]));
        assert_eq!(RestrictedExpr::from_str(src), Err(expected));
    }
}