use super::{id::Id, PrincipalOrResource, UnreservedId};
use educe::Educe;
use itertools::Itertools;
use miette::Diagnostic;
use ref_cast::RefCast;
use regex::Regex;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use smol_str::ToSmolStr;
use std::collections::HashSet;
use std::fmt::Display;
use std::str::FromStr;
use std::sync::Arc;
use thiserror::Error;
use crate::parser::err::{ParseError, ParseErrors, ToASTError, ToASTErrorKind};
use crate::parser::Loc;
use crate::FromNormalizedStr;
/// A possibly-qualified name such as `foo` or `A::B::foo`.
///
/// Unlike [`Name`], an `InternalName` may contain reserved components
/// (see [`InternalName::is_reserved`]); converting to `Name` rejects those.
#[derive(Educe, Clone)]
#[educe(Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct InternalName {
    /// The final component (basename) of the name.
    pub(crate) id: Id,
    /// The namespace components preceding the basename, outermost first.
    /// Held in an `Arc` so cloning a name does not copy the component list.
    pub(crate) path: Arc<Vec<Id>>,
    /// Source location, if known. Ignored by `PartialEq`, `Hash`,
    /// `PartialOrd`, and `Debug`, so names differing only in location are equal.
    #[educe(PartialEq(ignore))]
    #[educe(Hash(ignore))]
    #[educe(PartialOrd(ignore))]
    #[educe(Debug(ignore))]
    pub(crate) loc: Option<Loc>,
}
impl From<Id> for InternalName {
fn from(value: Id) -> Self {
Self::unqualified_name(value, None)
}
}
impl TryFrom<InternalName> for Id {
    type Error = ();

    /// Succeeds only when the name has no namespace components, yielding its basename.
    fn try_from(value: InternalName) -> Result<Self, Self::Error> {
        if !value.is_unqualified() {
            return Err(());
        }
        Ok(value.id)
    }
}
impl InternalName {
pub fn new(basename: Id, path: impl IntoIterator<Item = Id>, loc: Option<Loc>) -> Self {
Self {
id: basename,
path: Arc::new(path.into_iter().collect()),
loc,
}
}
pub fn unqualified_name(id: Id, loc: Option<Loc>) -> Self {
Self {
id,
path: Arc::new(vec![]),
loc,
}
}
pub fn __cedar() -> Self {
Self::unqualified_name(Id::new_unchecked_const("__cedar"), None)
}
pub fn parse_unqualified_name(s: &str) -> Result<Self, ParseErrors> {
Ok(Self {
id: s.parse()?,
path: Arc::new(vec![]),
loc: None,
})
}
pub fn type_in_namespace(
basename: Id,
namespace: InternalName,
loc: Option<Loc>,
) -> InternalName {
let mut path = Arc::unwrap_or_clone(namespace.path);
path.push(namespace.id);
InternalName::new(basename, path, loc)
}
pub fn loc(&self) -> Option<&Loc> {
self.loc.as_ref()
}
pub fn basename(&self) -> &Id {
&self.id
}
pub fn namespace_components(&self) -> impl Iterator<Item = &Id> {
self.path.iter()
}
pub fn namespace(&self) -> String {
self.path.iter().join("::")
}
pub fn qualify_with(&self, namespace: Option<&InternalName>) -> InternalName {
if self.is_unqualified() {
match namespace {
Some(namespace) => Self::new(
self.basename().clone(),
namespace
.namespace_components()
.chain(std::iter::once(namespace.basename()))
.cloned(),
self.loc.clone(),
),
None => self.clone(),
}
} else {
self.clone()
}
}
pub fn qualify_with_name(&self, namespace: Option<&Name>) -> InternalName {
let ns = namespace.map(AsRef::as_ref);
self.qualify_with(ns)
}
pub fn is_unqualified(&self) -> bool {
self.path.is_empty()
}
pub fn is_reserved(&self) -> bool {
self.path
.iter()
.chain(std::iter::once(&self.id))
.any(|id| id.is_reserved())
}
}
impl std::fmt::Display for InternalName {
    /// Writes each namespace component followed by `::`, then the basename.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.path
            .iter()
            .try_for_each(|component| write!(f, "{component}::"))?;
        write!(f, "{}", self.id)
    }
}
impl Serialize for InternalName {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
self.to_smolstr().serialize(serializer)
}
}
impl std::str::FromStr for InternalName {
    type Err = ParseErrors;

    /// Parses `s` with the crate's internal-name parser.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let parsed = crate::parser::parse_internal_name(s)?;
        Ok(parsed)
    }
}
impl FromNormalizedStr for InternalName {
    /// Human-readable description of this type (used, e.g., when reporting
    /// non-normalized input).
    fn describe_self() -> &'static str {
        const LABEL: &str = "internal name";
        LABEL
    }
}
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for InternalName {
    /// Generates a name with zero to eight namespace components and no location.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // Draw order (component count, basename, then components) is kept fixed
        // so a given byte input keeps producing the same name.
        let path_len = u.int_in_range(0..=8)?;
        let basename: Id = u.arbitrary()?;
        let components = (0..path_len)
            .map(|_| u.arbitrary::<Id>())
            .collect::<Result<Vec<_>, _>>()?;
        Ok(Self {
            id: basename,
            path: Arc::new(components),
            loc: None,
        })
    }
}
/// Serde visitor that accepts a string and parses it as a normalized
/// [`InternalName`].
struct NameVisitor;

impl serde::de::Visitor<'_> for NameVisitor {
    type Value = InternalName;

    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "a name consisting of an optional namespace and id")
    }

    /// Parses `value` with `from_normalized_str`, reporting failures as a
    /// custom deserializer error that quotes the offending input.
    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        match InternalName::from_normalized_str(value) {
            Ok(name) => Ok(name),
            Err(err) => Err(E::custom(format!("invalid name `{value}`: {err}"))),
        }
    }
}
impl<'de> Deserialize<'de> for InternalName {
    /// Deserializes from a string holding a normalized name.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let visitor = NameVisitor;
        deserializer.deserialize_str(visitor)
    }
}
/// Identifies a template slot: either `?principal` or `?resource`.
///
/// `#[serde(transparent)]` makes it serialize exactly like the inner
/// `ValidSlotId`.
#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct SlotId(pub(crate) ValidSlotId);
impl SlotId {
pub fn principal() -> Self {
Self(ValidSlotId::Principal)
}
pub fn resource() -> Self {
Self(ValidSlotId::Resource)
}
pub fn is_principal(&self) -> bool {
matches!(self, Self(ValidSlotId::Principal))
}
pub fn is_resource(&self) -> bool {
matches!(self, Self(ValidSlotId::Resource))
}
}
impl From<PrincipalOrResource> for SlotId {
fn from(v: PrincipalOrResource) -> Self {
match v {
PrincipalOrResource::Principal => SlotId::principal(),
PrincipalOrResource::Resource => SlotId::resource(),
}
}
}
impl std::fmt::Display for SlotId {
    /// Renders the inner slot (`?principal` or `?resource`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
/// The slots a template may contain.
///
/// The serde renames make each variant serialize as its `?`-prefixed name,
/// the same text its `Display` impl produces.
#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub(crate) enum ValidSlotId {
    /// `?principal`
    #[serde(rename = "?principal")]
    Principal,
    /// `?resource`
    #[serde(rename = "?resource")]
    Resource,
}
impl std::fmt::Display for ValidSlotId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let s = match self {
ValidSlotId::Principal => "principal",
ValidSlotId::Resource => "resource",
};
write!(f, "?{s}")
}
}
/// A slot (`?principal` / `?resource`) together with an optional source location.
///
/// The location is ignored by `PartialEq` and `Hash`, so two `Slot`s with the
/// same `id` compare equal regardless of location.
#[derive(Educe, Debug, Clone)]
#[educe(PartialEq, Eq, Hash)]
pub struct Slot {
    /// Which slot this is.
    pub id: SlotId,
    /// Source location, if known.
    #[educe(PartialEq(ignore))]
    #[educe(Hash(ignore))]
    pub loc: Option<Loc>,
}
#[cfg(test)]
mod vars_test {
    use super::*;

    /// Each constructor yields the slot it is named after, and only that slot.
    #[test]
    fn vars_correct() {
        let principal = SlotId::principal();
        let resource = SlotId::resource();
        assert!(principal.is_principal());
        assert!(!principal.is_resource());
        assert!(resource.is_resource());
        assert!(!resource.is_principal());
        assert_ne!(principal, resource);
    }

    /// `Display` renders both slots with a leading `?`.
    #[test]
    fn display() {
        assert_eq!(format!("{}", SlotId::principal()), "?principal");
        assert_eq!(format!("{}", SlotId::resource()), "?resource");
    }

    /// `From<PrincipalOrResource>` maps each variant to the matching slot.
    #[test]
    fn from_principal_or_resource() {
        assert_eq!(
            SlotId::from(PrincipalOrResource::Principal),
            SlotId::principal()
        );
        assert_eq!(
            SlotId::from(PrincipalOrResource::Resource),
            SlotId::resource()
        );
    }
}
/// A qualified name whose components are all unreserved identifiers.
///
/// Transparent wrapper around [`InternalName`]. `#[repr(transparent)]` together
/// with `RefCast` lets a `&InternalName` be viewed as a `&Name` without copying,
/// and `#[serde(transparent)]` makes it serialize like the wrapped name. The
/// "no reserved component" invariant is checked by the `TryFrom` conversions
/// in this module.
#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash, Serialize, RefCast)]
#[repr(transparent)]
#[serde(transparent)]
pub struct Name(pub(crate) InternalName);
impl From<UnreservedId> for Name {
fn from(value: UnreservedId) -> Self {
Self::unqualified_name(value)
}
}
impl TryFrom<Name> for UnreservedId {
    type Error = ();

    /// Succeeds only for unqualified names, yielding the basename.
    fn try_from(value: Name) -> Result<Self, Self::Error> {
        if !value.is_unqualified() {
            return Err(());
        }
        Ok(value.basename())
    }
}
impl Display for Name {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl FromStr for Name {
type Err = ParseErrors;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let n: InternalName = s.parse()?;
n.try_into().map_err(ParseErrors::singleton)
}
}
/// Matches exactly one identifier: a letter or underscore, followed by any
/// number of letters, digits, or underscores.
#[expect(clippy::unwrap_used, reason = "this is a valid Regex pattern")]
static VALID_ANY_IDENT_REGEX: std::sync::LazyLock<Regex> = std::sync::LazyLock::new(|| {
    let pattern = "^[_a-zA-Z][_a-zA-Z0-9]*$";
    Regex::new(pattern).unwrap()
});
/// Matches one or more identifiers separated by `::`, with no whitespace.
#[expect(clippy::unwrap_used, reason = "this is a valid Regex pattern")]
static VALID_NAME_REGEX: std::sync::LazyLock<Regex> = std::sync::LazyLock::new(|| {
    let pattern = "^[_a-zA-Z][_a-zA-Z0-9]*(?:::[_a-zA-Z][_a-zA-Z0-9]*)*$";
    Regex::new(pattern).unwrap()
});
/// Identifiers rejected by `is_normalized_ident`, and rejected as any component
/// by `Name::from_normalized_str`.
///
/// Built directly from an array; this avoids allocating a temporary `Vec`
/// just to collect it.
static RESERVED_IDS: std::sync::LazyLock<HashSet<&'static str>> = std::sync::LazyLock::new(|| {
    HashSet::from([
        "true", "false", "if", "then", "else", "in", "is", "like", "has", "__cedar",
    ])
});
/// Returns `true` if `s` is a single identifier in normalized form (letters,
/// digits and underscores, not starting with a digit, no surrounding
/// whitespace) and is not one of the reserved identifiers.
pub fn is_normalized_ident(s: &str) -> bool {
    let reserved = RESERVED_IDS.contains(s);
    !reserved && VALID_ANY_IDENT_REGEX.is_match(s)
}
impl FromNormalizedStr for Name {
    /// Parses `s`, requiring it to already be in normalized form: `::`-separated
    /// identifiers with no whitespace, none of them reserved. The resulting
    /// name's location spans all of `s`.
    ///
    /// # Errors
    /// Returns the error produced by `parse_err_from_str` for `s`.
    fn from_normalized_str(s: &str) -> Result<Self, ParseErrors> {
        if !VALID_NAME_REGEX.is_match(s) {
            return Err(Self::parse_err_from_str(s));
        }

        let parts: Vec<&str> = s.split("::").collect();
        if parts.iter().any(|part| RESERVED_IDS.contains(part)) {
            return Err(Self::parse_err_from_str(s));
        }

        let Some((basename, namespace)) = parts.split_last() else {
            return Err(Self::parse_err_from_str(s));
        };

        let whole_input = Loc::new(0..s.len(), s.into());
        Ok(Self(InternalName::new(
            Id::new_unchecked(*basename),
            namespace.iter().map(|component| Id::new_unchecked(*component)),
            Some(whole_input),
        )))
    }

    fn describe_self() -> &'static str {
        "Name"
    }
}
impl<'de> Deserialize<'de> for Name {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer
.deserialize_str(NameVisitor)
.and_then(|n| n.try_into().map_err(serde::de::Error::custom))
}
}
impl Name {
    /// Parses `s` as a single unqualified identifier.
    ///
    /// # Errors
    /// Fails if `s` is not a valid identifier, or if it is reserved.
    pub fn parse_unqualified_name(s: &str) -> Result<Self, ParseErrors> {
        let internal = InternalName::parse_unqualified_name(s)?;
        Self::try_from(internal).map_err(ParseErrors::singleton)
    }

    /// Builds an unqualified `Name` with no source location.
    pub fn unqualified_name(id: UnreservedId) -> Self {
        Self(InternalName::new(id.0, std::iter::empty(), None))
    }

    /// The basename as a borrowed `Id`, without conversion.
    pub fn basename_as_ref(&self) -> &Id {
        &self.0.id
    }

    /// The basename as an owned `UnreservedId`.
    #[allow(
        clippy::unwrap_used,
        reason = "Any component of a `Name` is a `UnreservedId`"
    )]
    pub fn basename(&self) -> UnreservedId {
        let id: Id = self.0.basename().clone();
        id.try_into().unwrap()
    }

    /// Whether the name has no namespace components.
    pub fn is_unqualified(&self) -> bool {
        self.0.is_unqualified()
    }

    /// Delegates to [`InternalName::qualify_with`].
    pub fn qualify_with(&self, namespace: Option<&InternalName>) -> InternalName {
        self.0.qualify_with(namespace)
    }

    /// Like [`Name::qualify_with`], but both input and output are `Name`s.
    pub fn qualify_with_name(&self, namespace: Option<&Self>) -> Self {
        let qualified = self.0.qualify_with(namespace.map(|ns| &ns.0));
        Self(qualified)
    }

    /// Source location of this name, if known.
    pub fn loc(&self) -> Option<&Loc> {
        self.0.loc.as_ref()
    }

    /// Builds the error returned when `s` is rejected by `from_normalized_str`.
    ///
    /// If `s` does not parse at all, the parser's own errors are returned.
    /// Otherwise `s` was merely non-normalized, so a `NonNormalizedString` error
    /// is produced. Its location points at the first byte where `s` differs from
    /// the normalized rendering, or at the end of the shorter string when one is
    /// a prefix of the other.
    fn parse_err_from_str(s: &str) -> ParseErrors {
        let parsed = match Self::from_str(s) {
            Ok(parsed) => parsed,
            Err(parse_errs) => return parse_errs,
        };
        let normalized_src = parsed.to_string();
        let diff_byte = s
            .bytes()
            .zip(normalized_src.bytes())
            .position(|(ours, canonical)| ours != canonical)
            .unwrap_or_else(|| s.len().min(normalized_src.len()));
        let error_kind = ToASTErrorKind::NonNormalizedString {
            kind: Self::describe_self(),
            src: s.to_string(),
            normalized_src,
        };
        ParseErrors::singleton(ParseError::ToAST(ToASTError::new(
            error_kind,
            Some(Loc::new(diff_byte, s.into())),
        )))
    }
}
/// Error returned when an [`InternalName`] cannot become a [`Name`] because
/// [`InternalName::is_reserved`] reports a reserved component.
#[derive(Debug, Clone, PartialEq, Eq, Error, Diagnostic, Hash)]
#[error("The name `{0}` contains `__cedar`, which is reserved")]
pub struct ReservedNameError(pub(crate) InternalName);
impl ReservedNameError {
pub fn name(&self) -> &InternalName {
&self.0
}
}
impl From<ReservedNameError> for ParseError {
    /// Converts into a `ToAST` parse error located at the name's source location.
    ///
    /// If the name carries no location, the location falls back to a span
    /// covering the name's own rendered text.
    fn from(value: ReservedNameError) -> Self {
        // Resolve the location before `value` is moved into the error kind, so
        // only the `Loc` is cloned rather than the whole error.
        let loc = value.0.loc.clone().unwrap_or_else(|| {
            let name_str = value.0.to_string();
            Loc::new(0..name_str.len(), name_str.into())
        });
        ParseError::ToAST(ToASTError::new(value.into(), Some(loc)))
    }
}
impl TryFrom<InternalName> for Name {
type Error = ReservedNameError;
fn try_from(value: InternalName) -> Result<Self, Self::Error> {
if value.is_reserved() {
Err(ReservedNameError(value))
} else {
Ok(Self(value))
}
}
}
impl<'a> TryFrom<&'a InternalName> for &'a Name {
type Error = ReservedNameError;
fn try_from(value: &'a InternalName) -> Result<&'a Name, ReservedNameError> {
if value.is_reserved() {
Err(ReservedNameError(value.clone()))
} else {
Ok(<Name as RefCast>::ref_cast(value))
}
}
}
impl From<Name> for InternalName {
fn from(value: Name) -> Self {
value.0
}
}
impl AsRef<InternalName> for Name {
fn as_ref(&self) -> &InternalName {
&self.0
}
}
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Name {
    /// Generates a `Name` from unreserved components: a basename plus zero to
    /// eight namespace components, with no source location.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // Draw order (component count, basename, then namespace components) is
        // kept fixed so a given byte input keeps producing the same name.
        let component_count = u.int_in_range(0..=8)?;
        let basename: UnreservedId = u.arbitrary()?;
        let namespace = (0..component_count)
            .map(|_| u.arbitrary::<UnreservedId>())
            .collect::<Result<Vec<_>, _>>()?;
        let internal = InternalName::new(
            basename.into(),
            namespace.into_iter().map(|component| component.into()),
            None,
        );
        #[expect(
            clippy::unwrap_used,
            reason = "`name` is made of `UnreservedId`s and thus should be a valid `Name`"
        )]
        let name: Name = internal.try_into().unwrap();
        Ok(name)
    }

    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        <InternalName as arbitrary::Arbitrary>::size_hint(depth)
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Only canonical spellings pass `from_normalized_str`.
    #[test]
    fn normalized_name() {
        let accepted = ["foo", "foo::bar"];
        let rejected = [r#"foo::"bar""#, " foo", "foo ", "foo\n", "foo//comment"];
        for src in accepted {
            InternalName::from_normalized_str(src).expect("should be OK");
        }
        for src in rejected {
            InternalName::from_normalized_str(src).expect_err("shouldn't be OK");
        }
    }

    /// Only unqualified names get the namespace prefix; qualified names and a
    /// `None` namespace are left unchanged.
    #[test]
    fn qualify_with() {
        let cases: [(&str, Option<&str>, &str); 5] = [
            ("baz", Some("foo::bar"), "foo::bar::baz"),
            ("C::D", Some("A::B"), "C::D"),
            ("D", Some("A::B::C"), "A::B::C::D"),
            ("B::C::D", Some("A"), "B::C::D"),
            ("A", None, "A"),
        ];
        for (name, namespace, expected) in cases {
            let namespace: Option<InternalName> = namespace.map(|ns| ns.parse().unwrap());
            let qualified = InternalName::from_normalized_str(name)
                .unwrap()
                .qualify_with(namespace.as_ref());
            assert_eq!(expected, qualified.to_smolstr());
        }
    }

    /// `__cedar` is reserved in any position; near-miss spellings are not.
    #[test]
    fn test_reserved() {
        let reserved = [
            "__cedar",
            "__cedar::A",
            "__cedar::A::B",
            "A::__cedar",
            "A::__cedar::B",
        ];
        let not_reserved = ["__cedarr", "A::_cedar", "A::___cedar::B"];
        for n in reserved {
            assert!(InternalName::from_normalized_str(n).unwrap().is_reserved());
        }
        for n in not_reserved {
            assert!(!InternalName::from_normalized_str(n).unwrap().is_reserved());
        }
    }

    /// Policy keywords are valid name components, but reserved identifiers are not.
    #[test]
    fn test_name_identifier_intersection() {
        const NOT_RESERVED_FOR_IDS: [&str; 8] = [
            "permit",
            "forbid",
            "when",
            "unless",
            "principal",
            "action",
            "resource",
            "context",
        ];
        for id in NOT_RESERVED_FOR_IDS {
            Name::from_normalized_str(&format!("A::{id}")).unwrap();
        }
        for id in RESERVED_IDS.iter() {
            Name::from_normalized_str(&format!("A::{id}")).unwrap_err();
        }
    }
}