use crate::entities::json::{
err::JsonSerializationError, ContextJsonDeserializationError, ContextJsonParser,
NullContextSchema,
};
use crate::entities::CedarValueJson;
use crate::evaluator::{EvaluationError, RestrictedEvaluator};
use crate::extensions::Extensions;
use crate::parser::Loc;
use miette::Diagnostic;
use smol_str::{SmolStr, ToSmolStr};
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;
use thiserror::Error;
use super::{
BorrowedRestrictedExpr, BoundedDisplay, EntityType, EntityUID, Expr, ExprKind,
ExpressionConstructionError, PartialValue, RestrictedExpr, Unknown, Value, ValueKind, Var,
};
/// A Cedar authorization request.
///
/// Each of the principal, action and resource slots may hold either a known
/// entity UID or an unknown (see [`EntityUIDEntry`]). The context is `None`
/// when it is unknown.
#[derive(Debug, Clone)]
pub struct Request {
    /// Principal of the request (known or unknown).
    pub(crate) principal: EntityUIDEntry,
    /// Action of the request (known or unknown).
    pub(crate) action: EntityUIDEntry,
    /// Resource of the request (known or unknown).
    pub(crate) resource: EntityUIDEntry,
    /// Context of the request; `None` means the context is unknown.
    pub(crate) context: Option<Context>,
}
/// The type signature of a request: the principal's entity type, the exact
/// action UID, and the resource's entity type.
///
/// Produced by `Request::to_request_type`, which only succeeds when all three
/// of those components are known.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "entity-manifest",
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct RequestType {
    /// Entity type of the principal.
    pub principal: EntityType,
    /// The action entity UID.
    pub action: EntityUID,
    /// Entity type of the resource.
    pub resource: EntityType,
}
/// One of the principal / action / resource slots of a [`Request`]: either a
/// concrete entity UID or an unknown.
#[derive(Debug, Clone)]
pub enum EntityUIDEntry {
    /// A concrete entity UID.
    Known {
        /// The UID, held behind an `Arc`.
        euid: Arc<EntityUID>,
        /// Source location associated with this entry, if any.
        loc: Option<Loc>,
    },
    /// An unknown entity, optionally with a known entity type.
    Unknown {
        /// The entity type of the unknown, when it is known.
        ty: Option<EntityType>,
        /// Source location associated with this entry, if any.
        loc: Option<Loc>,
    },
}
/// Wraps a concrete [`EntityUID`] as an [`EntityUIDEntry::Known`].
///
/// The entry's source location is taken from the UID itself.
impl From<EntityUID> for EntityUIDEntry {
    fn from(euid: EntityUID) -> Self {
        // Read the location before moving `euid` into the `Arc`; this avoids
        // cloning the whole UID just to inspect it afterwards.
        let loc = match &euid {
            EntityUID::EntityUID(euid) => euid.loc(),
            #[cfg(feature = "tolerant-ast")]
            EntityUID::Error => None,
        };
        Self::Known {
            euid: Arc::new(euid),
            loc,
        }
    }
}
impl EntityUIDEntry {
pub fn evaluate(&self, var: Var) -> PartialValue {
match self {
EntityUIDEntry::Known { euid, loc } => {
Value::new(Arc::unwrap_or_clone(Arc::clone(euid)), loc.clone()).into()
}
EntityUIDEntry::Unknown { ty: None, loc } => {
Expr::unknown(Unknown::new_untyped(var.to_smolstr()))
.with_maybe_source_loc(loc.clone())
.into()
}
EntityUIDEntry::Unknown {
ty: Some(known_type),
loc,
} => Expr::unknown(Unknown::new_with_type(
var.to_smolstr(),
super::Type::Entity {
ty: known_type.clone(),
},
))
.with_maybe_source_loc(loc.clone())
.into(),
}
}
pub fn known(euid: EntityUID, loc: Option<Loc>) -> Self {
Self::Known {
euid: Arc::new(euid),
loc,
}
}
pub fn unknown() -> Self {
Self::Unknown {
ty: None,
loc: None,
}
}
pub fn unknown_with_type(ty: EntityType, loc: Option<Loc>) -> Self {
Self::Unknown { ty: Some(ty), loc }
}
pub fn uid(&self) -> Option<&EntityUID> {
match self {
Self::Known { euid, .. } => Some(euid),
Self::Unknown { .. } => None,
}
}
pub fn get_type(&self) -> Option<&EntityType> {
match self {
Self::Known { euid, .. } => Some(euid.entity_type()),
Self::Unknown { ty, .. } => ty.as_ref(),
}
}
}
impl Request {
    /// Create a request in which principal, action, resource and context are
    /// all known.
    ///
    /// Each of `principal`, `action` and `resource` is a UID paired with an
    /// optional source location. If `schema` is `Some`, the request is
    /// validated against it before being returned.
    ///
    /// # Errors
    ///
    /// Returns the schema's error if validation fails.
    pub fn new<S: RequestSchema>(
        principal: (EntityUID, Option<Loc>),
        action: (EntityUID, Option<Loc>),
        resource: (EntityUID, Option<Loc>),
        context: Context,
        schema: Option<&S>,
        extensions: &Extensions<'_>,
    ) -> Result<Self, S::Error> {
        // A fully-known request is the special case of `new_with_unknowns`
        // where every slot is known, so share its validation path.
        Self::new_with_unknowns(
            EntityUIDEntry::known(principal.0, principal.1),
            EntityUIDEntry::known(action.0, action.1),
            EntityUIDEntry::known(resource.0, resource.1),
            Some(context),
            schema,
            extensions,
        )
    }
    /// Create a request in which any component may be unknown.
    ///
    /// A `context` of `None` means the context is unknown. If `schema` is
    /// `Some`, the request is validated against it before being returned.
    ///
    /// # Errors
    ///
    /// Returns the schema's error if validation fails.
    pub fn new_with_unknowns<S: RequestSchema>(
        principal: EntityUIDEntry,
        action: EntityUIDEntry,
        resource: EntityUIDEntry,
        context: Option<Context>,
        schema: Option<&S>,
        extensions: &Extensions<'_>,
    ) -> Result<Self, S::Error> {
        let req = Self::new_unchecked(principal, action, resource, context);
        if let Some(schema) = schema {
            schema.validate_request(&req, extensions)?;
        }
        Ok(req)
    }
    /// Create a request without any schema validation.
    pub fn new_unchecked(
        principal: EntityUIDEntry,
        action: EntityUIDEntry,
        resource: EntityUIDEntry,
        context: Option<Context>,
    ) -> Self {
        Self {
            principal,
            action,
            resource,
            context,
        }
    }
    /// The principal slot of this request.
    pub fn principal(&self) -> &EntityUIDEntry {
        &self.principal
    }
    /// The action slot of this request.
    pub fn action(&self) -> &EntityUIDEntry {
        &self.action
    }
    /// The resource slot of this request.
    pub fn resource(&self) -> &EntityUIDEntry {
        &self.resource
    }
    /// The context of this request, or `None` if it is unknown.
    pub fn context(&self) -> Option<&Context> {
        self.context.as_ref()
    }
    /// The [`RequestType`] of this request.
    ///
    /// Returns `None` if any of principal, action or resource is unknown.
    pub fn to_request_type(&self) -> Option<RequestType> {
        Some(RequestType {
            principal: self.principal().uid()?.entity_type().clone(),
            action: self.action().uid()?.clone(),
            resource: self.resource().uid()?.entity_type().clone(),
        })
    }
}
impl std::fmt::Display for Request {
    /// Render the request as one human-readable sentence.
    ///
    /// Unknown slots print as `unknown`, or `unknown of type T` when the
    /// entity type is recorded. An unknown context also prints as `unknown`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        fn describe(entry: &EntityUIDEntry) -> String {
            match entry {
                EntityUIDEntry::Known { euid, .. } => euid.to_string(),
                EntityUIDEntry::Unknown { ty: Some(ty), .. } => {
                    format!("unknown of type {ty}")
                }
                EntityUIDEntry::Unknown { ty: None, .. } => String::from("unknown"),
            }
        }
        let principal = describe(&self.principal);
        let action = describe(&self.action);
        let resource = describe(&self.resource);
        let context = self
            .context
            .as_ref()
            .map_or_else(|| String::from("unknown"), |ctx| ctx.to_string());
        write!(
            f,
            "request with principal {}, action {}, resource {}, and context {}",
            principal, action, resource, context
        )
    }
}
/// The context of a request: a record mapping attribute names to values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Context {
    /// A context whose attributes are all concrete values.
    Value(Arc<BTreeMap<SmolStr, Value>>),
    /// A context produced when evaluating the context record left a residual.
    /// Attributes are stored as expressions that are treated as restricted
    /// expressions (see `into_pairs`).
    RestrictedResidual(Arc<BTreeMap<SmolStr, Expr>>),
}
impl Context {
    /// Create an empty context (a record with no attributes).
    pub fn empty() -> Self {
        Self::Value(Arc::new(BTreeMap::new()))
    }
    /// Convert the result of restricted (partial) evaluation into a `Context`.
    ///
    /// Only the record shape of `value` is checked. A concrete record becomes
    /// [`Context::Value`] and a residual record expression becomes
    /// [`Context::RestrictedResidual`]. The residual is stored as restricted
    /// without re-checking, so callers must pass a value obtained by
    /// evaluating a restricted expression.
    ///
    /// # Errors
    ///
    /// [`ContextCreationError::NotARecord`] if `value` is not a record.
    fn from_restricted_partial_val_unchecked(
        value: PartialValue,
    ) -> Result<Self, ContextCreationError> {
        match value {
            PartialValue::Value(v) => {
                if let ValueKind::Record(attrs) = v.value {
                    Ok(Context::Value(attrs))
                } else {
                    Err(ContextCreationError::not_a_record(v.into()))
                }
            }
            PartialValue::Residual(e) => {
                if let ExprKind::Record(attrs) = e.expr_kind() {
                    Ok(Context::RestrictedResidual(attrs.clone()))
                } else {
                    Err(ContextCreationError::not_a_record(e))
                }
            }
        }
    }
    /// Create a `Context` from a restricted expression, which must be a record.
    ///
    /// The record is evaluated with the restricted evaluator using
    /// `extensions`. If evaluation only partially succeeds, the resulting
    /// residual record is stored as [`Context::RestrictedResidual`].
    ///
    /// # Errors
    ///
    /// * [`ContextCreationError::NotARecord`] if `expr` is not a record.
    /// * [`ContextCreationError::Evaluation`] if evaluating the record fails.
    pub fn from_expr(
        expr: BorrowedRestrictedExpr<'_>,
        extensions: &Extensions<'_>,
    ) -> Result<Self, ContextCreationError> {
        match expr.expr_kind() {
            ExprKind::Record { .. } => {
                let evaluator = RestrictedEvaluator::new(extensions);
                let pval = evaluator.partial_interpret(expr)?;
                #[expect(
                    clippy::expect_used,
                    reason = "partially evaluating a record expression yields a record value \
                              or a residual record, both of which are accepted"
                )]
                Ok(Self::from_restricted_partial_val_unchecked(pval).expect(
                    "`from_restricted_partial_val_unchecked` should succeed when called on a record.",
                ))
            }
            _ => Err(ContextCreationError::not_a_record(expr.to_owned().into())),
        }
    }
    /// Create a `Context` from `(attribute name, restricted expression)` pairs.
    ///
    /// # Errors
    ///
    /// * [`ContextCreationError::ExpressionConstruction`] if an attribute name
    ///   appears more than once.
    /// * Any error from [`Context::from_expr`] when evaluating the record.
    pub fn from_pairs(
        pairs: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
        extensions: &Extensions<'_>,
    ) -> Result<Self, ContextCreationError> {
        match RestrictedExpr::record(pairs) {
            Ok(record) => Self::from_expr(record.as_borrowed(), extensions),
            Err(ExpressionConstructionError::DuplicateKey(err)) => Err(
                ExpressionConstructionError::DuplicateKey(err.with_context("in context")).into(),
            ),
        }
    }
    /// Parse a context from a JSON string, without a schema and with all
    /// extensions available.
    pub fn from_json_str(json: &str) -> Result<Self, ContextJsonDeserializationError> {
        ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
            .from_json_str(json)
    }
    /// Parse a context from a [`serde_json::Value`], without a schema and with
    /// all extensions available.
    pub fn from_json_value(
        json: serde_json::Value,
    ) -> Result<Self, ContextJsonDeserializationError> {
        ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
            .from_json_value(json)
    }
    /// Parse a context from a JSON reader, without a schema and with all
    /// extensions available.
    pub fn from_json_file(
        json: impl std::io::Read,
    ) -> Result<Self, ContextJsonDeserializationError> {
        ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
            .from_json_file(json)
    }
    /// Serialize this context as a JSON object mapping each attribute name to
    /// its Cedar JSON encoding.
    ///
    /// # Errors
    ///
    /// Returns a [`JsonSerializationError`] if any attribute cannot be
    /// converted to Cedar JSON.
    pub fn to_json_value(&self) -> Result<serde_json::Value, JsonSerializationError> {
        match self {
            Self::Value(record) => record
                .iter()
                .map(|(k, v)| {
                    let cjson = CedarValueJson::from_value(v.clone())?;
                    Ok((k.to_string(), serde_json::to_value(cjson)?))
                })
                .collect(),
            Self::RestrictedResidual(record) => record
                .iter()
                .map(|(k, v)| {
                    // Residual attributes are stored as restricted expressions.
                    let cjson =
                        CedarValueJson::from_expr(BorrowedRestrictedExpr::new_unchecked(v))?;
                    Ok((k.to_string(), serde_json::to_value(cjson)?))
                })
                .collect(),
        }
    }
    /// Number of attributes in this context.
    pub fn num_keys(&self) -> usize {
        match self {
            Context::Value(record) => record.len(),
            Context::RestrictedResidual(record) => record.len(),
        }
    }
    /// Consume the context into `(attribute name, restricted expression)` pairs.
    ///
    /// The attribute map is only cloned if the `Arc` is shared.
    fn into_pairs(self) -> Box<dyn Iterator<Item = (SmolStr, RestrictedExpr)>> {
        match self {
            Context::Value(record) => Box::new(
                Arc::unwrap_or_clone(record)
                    .into_iter()
                    .map(|(k, v)| (k, RestrictedExpr::from(v))),
            ),
            Context::RestrictedResidual(record) => Box::new(
                Arc::unwrap_or_clone(record)
                    .into_iter()
                    .map(|(k, v)| (k, RestrictedExpr::new_unchecked(v))),
            ),
        }
    }
    /// Substitute values from `mapping` into a residual context and
    /// re-evaluate it with all extensions available.
    ///
    /// `mapping` is presumably keyed by unknown names; confirm against
    /// `Expr::substitute`. A context that is already a concrete
    /// [`Context::Value`] is returned unchanged.
    ///
    /// # Errors
    ///
    /// Returns an [`EvaluationError`] if re-evaluation fails.
    pub fn substitute(self, mapping: &HashMap<SmolStr, Value>) -> Result<Self, EvaluationError> {
        match self {
            Context::RestrictedResidual(residual_context) => {
                let expr = Expr::record_arc(residual_context).substitute(mapping);
                let expr = BorrowedRestrictedExpr::new_unchecked(&expr);
                let extns = Extensions::all_available();
                let eval = RestrictedEvaluator::new(extns);
                let partial_value = eval.partial_interpret(expr)?;
                #[expect(
                    clippy::expect_used,
                    reason = "substituting into a record expression yields a record expression, \
                              whose evaluation is a record value or a residual record"
                )]
                Ok(
                    Self::from_restricted_partial_val_unchecked(partial_value).expect(
                        "`from_restricted_partial_val_unchecked` should succeed when called on a record.",
                    ),
                )
            }
            Context::Value(_) => Ok(self),
        }
    }
}
/// Owning iterator type for [`Context`].
mod iter {
    use super::*;
    /// Owning iterator over the `(attribute name, restricted expression)`
    /// pairs of a [`Context`], returned by `Context::into_iter`.
    pub struct IntoIter(pub(super) Box<dyn Iterator<Item = (SmolStr, RestrictedExpr)>>);
    impl std::fmt::Debug for IntoIter {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            // The boxed inner iterator is not `Debug`, so print a placeholder.
            write!(f, "IntoIter(<context>)")
        }
    }
    impl Iterator for IntoIter {
        type Item = (SmolStr, RestrictedExpr);
        fn next(&mut self) -> Option<Self::Item> {
            self.0.next()
        }
        // Forward the inner iterator's bounds so that `collect`/`extend` can
        // preallocate; the default implementation would report `(0, None)`.
        fn size_hint(&self) -> (usize, Option<usize>) {
            self.0.size_hint()
        }
    }
}
impl IntoIterator for Context {
type Item = (SmolStr, RestrictedExpr);
type IntoIter = iter::IntoIter;
fn into_iter(self) -> Self::IntoIter {
iter::IntoIter(self.into_pairs())
}
}
/// Converts a [`Context`] into a record restricted expression.
///
/// A concrete context becomes a record value with no source location. A
/// residual context becomes a record expression over its stored attributes.
impl From<Context> for RestrictedExpr {
    fn from(value: Context) -> Self {
        match value {
            Context::RestrictedResidual(attrs) => {
                let record_expr = Expr::record_arc(attrs);
                RestrictedExpr::new_unchecked(record_expr)
            }
            Context::Value(attrs) => {
                let record_val = Value::record_arc(attrs, None);
                record_val.into()
            }
        }
    }
}
/// Converts a [`Context`] into a [`PartialValue`].
///
/// A concrete context becomes a record value with no source location. A
/// residual context becomes a residual record expression.
impl From<Context> for PartialValue {
    fn from(ctx: Context) -> PartialValue {
        match ctx {
            Context::RestrictedResidual(attrs) => {
                let record_expr = Expr::record_arc(attrs);
                PartialValue::Residual(record_expr)
            }
            Context::Value(attrs) => {
                let record_val = Value::record_arc(attrs, None);
                record_val.into()
            }
        }
    }
}
impl std::default::Default for Context {
fn default() -> Context {
Context::empty()
}
}
/// Displays a [`Context`] exactly as its [`PartialValue`] form is displayed.
impl std::fmt::Display for Context {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let as_partial = PartialValue::from(self.clone());
        write!(f, "{}", as_partial)
    }
}
/// Bounded display of a [`Context`], delegating to its [`PartialValue`] form
/// with the same bound `n`.
impl BoundedDisplay for Context {
    fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
        let as_partial = PartialValue::from(self.clone());
        BoundedDisplay::fmt(&as_partial, f, n)
    }
}
/// Errors that can occur while constructing a [`Context`].
#[derive(Debug, Diagnostic, Error)]
pub enum ContextCreationError {
    /// The expression or value used to build the context was not a record.
    #[error(transparent)]
    #[diagnostic(transparent)]
    NotARecord(#[from] context_creation_errors::NotARecord),
    /// Evaluating the context record failed.
    #[error(transparent)]
    #[diagnostic(transparent)]
    Evaluation(#[from] EvaluationError),
    /// Building the context record failed (for example, a duplicate
    /// attribute name in `Context::from_pairs`).
    #[error(transparent)]
    #[diagnostic(transparent)]
    ExpressionConstruction(#[from] ExpressionConstructionError),
}
impl ContextCreationError {
    /// Build a [`ContextCreationError::NotARecord`] reporting `expr` as the
    /// offending (non-record) expression.
    pub(crate) fn not_a_record(expr: Expr) -> Self {
        let boxed = Box::new(expr);
        context_creation_errors::NotARecord { expr: boxed }.into()
    }
}
/// Detail types for the variants of [`ContextCreationError`].
pub mod context_creation_errors {
    use super::Expr;
    use crate::impl_diagnostic_from_method_on_field;
    use miette::Diagnostic;
    use thiserror::Error;
    /// The expression used to build a context was not a record.
    #[derive(Debug, Error)]
    #[error("expression is not a record: {expr}")]
    pub struct NotARecord {
        /// The offending expression.
        pub(super) expr: Box<Expr>,
    }
    // Diagnostic information is derived from the `expr` field via a crate
    // macro. NOTE(review): presumably it reports `expr.source_loc()` as the
    // error location; confirm against the macro definition.
    impl Diagnostic for NotARecord {
        impl_diagnostic_from_method_on_field!(expr, source_loc);
    }
}
/// Something that can validate a [`Request`] and its components.
///
/// `Request::new` and `Request::new_with_unknowns` call
/// [`RequestSchema::validate_request`] when a schema is supplied.
pub trait RequestSchema {
    /// Error returned when validation fails.
    type Error: miette::Diagnostic;
    /// Validate a whole request.
    fn validate_request(
        &self,
        request: &Request,
        extensions: &Extensions<'_>,
    ) -> Result<(), Self::Error>;
    /// Validate `context` as the context of a request whose action is `action`.
    fn validate_context<'a>(
        &self,
        context: &Context,
        action: &EntityUID,
        extensions: &Extensions<'a>,
    ) -> std::result::Result<(), Self::Error>;
    /// Validate the scope variables (principal, action, resource) of a request.
    ///
    /// NOTE(review): `None` presumably denotes an unknown component; confirm
    /// with implementors.
    fn validate_scope_variables(
        &self,
        principal: Option<&EntityUID>,
        action: Option<&EntityUID>,
        resource: Option<&EntityUID>,
    ) -> std::result::Result<(), Self::Error>;
}
/// A [`RequestSchema`] that accepts every request, context and set of scope
/// variables. Its error type, [`Infallible`], can never be constructed.
#[derive(Debug, Clone)]
pub struct RequestSchemaAllPass;
impl RequestSchema for RequestSchemaAllPass {
    type Error = Infallible;
    /// Accepts any request unconditionally.
    fn validate_request(
        &self,
        _req: &Request,
        _extns: &Extensions<'_>,
    ) -> Result<(), Self::Error> {
        Ok(())
    }
    /// Accepts any context for any action unconditionally.
    fn validate_context<'a>(
        &self,
        _ctx: &Context,
        _act: &EntityUID,
        _extns: &Extensions<'a>,
    ) -> std::result::Result<(), Self::Error> {
        Ok(())
    }
    /// Accepts any combination of scope variables unconditionally.
    fn validate_scope_variables(
        &self,
        _p: Option<&EntityUID>,
        _a: Option<&EntityUID>,
        _r: Option<&EntityUID>,
    ) -> std::result::Result<(), Self::Error> {
        Ok(())
    }
}
/// An error that can never occur.
///
/// Wraps [`std::convert::Infallible`] so the type can derive `Diagnostic` and
/// `Error`, as required for example by `RequestSchema::Error`.
#[derive(Debug, Diagnostic, Error)]
#[error(transparent)]
pub struct Infallible(pub std::convert::Infallible);
/// Uninhabited types that can be eliminated into [`std::convert::Infallible`].
pub trait IsInfallible {
    /// Convert this (necessarily nonexistent) value into
    /// `std::convert::Infallible`. Because no value exists, this can never
    /// actually be called at runtime.
    fn never_returns(self) -> std::convert::Infallible;
}
impl IsInfallible for std::convert::Infallible {
    /// `std::convert::Infallible` is already the target type, so the
    /// (never-existing) value is returned unchanged.
    fn never_returns(self) -> std::convert::Infallible {
        self
    }
}
impl IsInfallible for Infallible {
    /// Unwraps the inner `std::convert::Infallible`.
    fn never_returns(self) -> std::convert::Infallible {
        let Self(never) = self;
        never
    }
}
/// Extract the success value from a container whose failure case cannot occur.
pub trait UnwrapInfallible<A> {
    /// Return the success value; the failure case is statically impossible.
    fn unwrap_infallible(self) -> A;
}
/// A `Result` whose error type is uninhabited always holds `Ok`, so it can be
/// unwrapped without any possibility of panicking.
impl<A, B: IsInfallible> UnwrapInfallible<A> for Result<A, B> {
    fn unwrap_infallible(self) -> A {
        match self {
            #[expect(unreachable_code, reason = "error type is uninhabited")]
            Err(never) => match never.never_returns() {},
            Ok(value) => value,
        }
    }
}
/// Unit tests for `Context` construction and JSON round-tripping.
#[cfg(test)]
mod test {
    use super::super::Name;
    use super::*;
    use cool_asserts::assert_matches;
    use std::str::FromStr;
    /// Serialize `context` to JSON and parse it back; panics if either step fails.
    #[track_caller]
    fn roundtrip_json(context: &Context) -> Context {
        Context::from_json_value(context.to_json_value().unwrap()).unwrap()
    }
    // A non-record must be rejected both when built from an expression and
    // when parsed from JSON.
    #[test]
    fn test_json_from_str_non_record() {
        assert_matches!(
            Context::from_expr(RestrictedExpr::val("1").as_borrowed(), Extensions::none()),
            Err(ContextCreationError::NotARecord { .. })
        );
        assert_matches!(
            Context::from_json_str("1"),
            Err(ContextJsonDeserializationError::ContextCreation(
                ContextCreationError::NotARecord { .. }
            ))
        );
    }
    // The empty context survives a JSON round trip.
    #[test]
    fn test_roundtrip_empty() {
        let context = Context::empty();
        assert_eq!(context, roundtrip_json(&context));
    }
    // A context with every kind of attribute value (booleans, integers, strings
    // with special characters, entity UIDs, sets, nested records, and the
    // decimal/ip/datetime extension types) survives a JSON round trip.
    #[test]
    fn test_roundtrip_complex() {
        let context = Context::from_pairs(
            [
                ("b".into(), RestrictedExpr::val(false)),
                ("i".into(), RestrictedExpr::val(32)),
                (
                    "s".into(),
                    RestrictedExpr::val("hi I have spaces and \" special ch@ract&rs: !{} \""),
                ),
                (
                    "uid".into(),
                    RestrictedExpr::val(EntityUID::from_str("Group::\"admins\"").unwrap()),
                ),
                (
                    "multi".into(),
                    RestrictedExpr::set([
                        RestrictedExpr::val(0),
                        RestrictedExpr::val(22),
                        RestrictedExpr::val(-310),
                    ]),
                ),
                (
                    "record".into(),
                    RestrictedExpr::record([
                        ("inner".into(), RestrictedExpr::val(-210)),
                        (
                            "inner_uid".into(),
                            RestrictedExpr::val(EntityUID::from_str("Group::\"interns\"").unwrap()),
                        ),
                        (
                            "inner_set".into(),
                            RestrictedExpr::set([
                                RestrictedExpr::val("my name is"),
                                RestrictedExpr::val("inigo montoya"),
                            ]),
                        ),
                    ])
                    .unwrap(),
                ),
                (
                    "dec".into(),
                    RestrictedExpr::call_extension_fn(
                        Name::parse_unqualified_name("decimal").unwrap(),
                        [RestrictedExpr::val("-1.111")],
                    ),
                ),
                (
                    "ipv6".into(),
                    RestrictedExpr::call_extension_fn(
                        Name::parse_unqualified_name("ip").unwrap(),
                        [RestrictedExpr::val("ffff::1/16")],
                    ),
                ),
                (
                    "dt".into(),
                    RestrictedExpr::call_extension_fn(
                        Name::parse_unqualified_name("datetime").unwrap(),
                        [RestrictedExpr::val("2026-01-01T03:04:05Z")],
                    ),
                ),
            ],
            &Extensions::all_available(),
        )
        .unwrap();
        assert_eq!(context, roundtrip_json(&context));
    }
}