use crate::entities::{ContextJsonDeserializationError, ContextJsonParser, NullContextSchema};
use crate::evaluator::{EvaluationError, RestrictedEvaluator};
use crate::extensions::Extensions;
use crate::parser::Loc;
use miette::Diagnostic;
use serde::Serialize;
use smol_str::SmolStr;
use std::sync::Arc;
use thiserror::Error;
use super::{
BorrowedRestrictedExpr, EntityUID, Expr, ExprConstructionError, ExprKind, PartialValue,
PartialValueSerializedAsExpr, RestrictedExpr, Unknown, Value, ValueKind, Var,
};
/// An authorization request: principal, action, and resource entity UIDs,
/// plus the request context.
///
/// Each of the three entity slots is an [`EntityUIDEntry`], so it may be
/// either a known UID or unknown. The context is `None` when it is unknown.
#[derive(Debug, Clone, Serialize)]
pub struct Request {
    /// Principal of the request (known UID or unknown).
    pub(crate) principal: EntityUIDEntry,
    /// Action of the request (known UID or unknown).
    pub(crate) action: EntityUIDEntry,
    /// Resource of the request (known UID or unknown).
    pub(crate) resource: EntityUIDEntry,
    /// Context of the request; `None` means the context is unknown.
    pub(crate) context: Option<Context>,
}
/// One entity slot (principal, action, or resource) of a [`Request`].
#[derive(Debug, Clone, Serialize)]
pub enum EntityUIDEntry {
    /// The slot holds a concrete entity UID.
    Known {
        /// The entity UID; shared via `Arc` so cloning the entry is cheap.
        euid: Arc<EntityUID>,
        /// Optional source location associated with this entry.
        loc: Option<Loc>,
    },
    /// The slot's value is not known; evaluating it yields an unknown.
    Unknown {
        /// Optional source location associated with this entry.
        loc: Option<Loc>,
    },
}
impl EntityUIDEntry {
    /// Evaluate this entry as the value of the request variable `var`.
    ///
    /// A known entry yields its entity UID as a concrete value carrying the
    /// entry's source location. An unknown entry yields a residual expression
    /// containing an untyped unknown named after `var`.
    pub fn evaluate(&self, var: Var) -> PartialValue {
        match self {
            EntityUIDEntry::Known { euid, loc } => {
                // Clone the inner UID directly. Cloning the `Arc` and then
                // calling `Arc::unwrap_or_clone` always fell back to a clone
                // anyway, since `self` keeps its own reference alive.
                Value::new((**euid).clone(), loc.clone()).into()
            }
            EntityUIDEntry::Unknown { loc } => Expr::unknown(Unknown::new_untyped(var.to_string()))
                .with_maybe_source_loc(loc.clone())
                .into(),
        }
    }

    /// Create a known entry from a concrete `euid` and optional source `loc`.
    pub fn concrete(euid: EntityUID, loc: Option<Loc>) -> Self {
        Self::Known {
            euid: Arc::new(euid),
            loc,
        }
    }

    /// The entity UID if this entry is known, or `None` if it is unknown.
    pub fn uid(&self) -> Option<&EntityUID> {
        match self {
            Self::Known { euid, .. } => Some(euid),
            Self::Unknown { .. } => None,
        }
    }
}
impl Request {
    /// Create a request in which principal, action, resource, and context are
    /// all known.
    ///
    /// Each entity is passed as `(uid, optional source location)`. When
    /// `schema` is `Some`, the finished request is validated against it.
    ///
    /// # Errors
    ///
    /// Returns whatever error `S::validate_request` reports.
    pub fn new<S: RequestSchema>(
        principal: (EntityUID, Option<Loc>),
        action: (EntityUID, Option<Loc>),
        resource: (EntityUID, Option<Loc>),
        context: Context,
        schema: Option<&S>,
        extensions: Extensions<'_>,
    ) -> Result<Self, S::Error> {
        let (principal_uid, principal_loc) = principal;
        let (action_uid, action_loc) = action;
        let (resource_uid, resource_loc) = resource;
        Self::new_with_unknowns(
            EntityUIDEntry::concrete(principal_uid, principal_loc),
            EntityUIDEntry::concrete(action_uid, action_loc),
            EntityUIDEntry::concrete(resource_uid, resource_loc),
            Some(context),
            schema,
            extensions,
        )
    }

    /// Create a request whose entity slots may be unknown and whose context
    /// may be unknown (`None`).
    ///
    /// When `schema` is `Some`, the finished request is validated against it.
    ///
    /// # Errors
    ///
    /// Returns whatever error `S::validate_request` reports.
    pub fn new_with_unknowns<S: RequestSchema>(
        principal: EntityUIDEntry,
        action: EntityUIDEntry,
        resource: EntityUIDEntry,
        context: Option<Context>,
        schema: Option<&S>,
        extensions: Extensions<'_>,
    ) -> Result<Self, S::Error> {
        let request = Self::new_unchecked(principal, action, resource, context);
        if let Some(validator) = schema {
            validator.validate_request(&request, extensions)?;
        }
        Ok(request)
    }

    /// Assemble a request from its parts without any schema validation.
    pub fn new_unchecked(
        principal: EntityUIDEntry,
        action: EntityUIDEntry,
        resource: EntityUIDEntry,
        context: Option<Context>,
    ) -> Self {
        Self {
            principal,
            action,
            resource,
            context,
        }
    }

    /// The principal slot of this request.
    pub fn principal(&self) -> &EntityUIDEntry {
        &self.principal
    }

    /// The action slot of this request.
    pub fn action(&self) -> &EntityUIDEntry {
        &self.action
    }

    /// The resource slot of this request.
    pub fn resource(&self) -> &EntityUIDEntry {
        &self.resource
    }

    /// The request context, or `None` if the context is unknown.
    pub fn context(&self) -> Option<&Context> {
        self.context.as_ref()
    }
}
impl std::fmt::Display for Request {
    /// Render as `request with principal P, action A, resource R, and context C`.
    /// Any unknown slot, or an unknown context, is printed as `unknown`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Writes a single entity slot straight into the formatter.
        fn write_slot(
            out: &mut std::fmt::Formatter<'_>,
            entry: &EntityUIDEntry,
        ) -> std::fmt::Result {
            match entry {
                EntityUIDEntry::Known { euid, .. } => write!(out, "{euid}"),
                EntityUIDEntry::Unknown { .. } => out.write_str("unknown"),
            }
        }

        f.write_str("request with principal ")?;
        write_slot(f, &self.principal)?;
        f.write_str(", action ")?;
        write_slot(f, &self.action)?;
        f.write_str(", resource ")?;
        write_slot(f, &self.resource)?;
        f.write_str(", and context ")?;
        match &self.context {
            Some(ctx) => write!(f, "{ctx}"),
            None => f.write_str("unknown"),
        }
    }
}
/// The context of a [`Request`].
///
/// Stored as a partial value, which may contain unknowns. The constructors in
/// this file only accept records, and iterating a context panics if the
/// stored value is not a record.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct Context {
    /// The context value. `flatten` makes serde serialize its contents in
    /// place of a nested `context` field.
    #[serde(flatten)]
    context: PartialValueSerializedAsExpr,
}
impl Context {
    /// Create a context that is an empty record (no fields).
    pub fn empty() -> Self {
        let empty_record = PartialValue::Value(Value::empty_record(None));
        Self {
            context: empty_record.into(),
        }
    }

    /// Create a context by partially evaluating a restricted expression.
    ///
    /// # Errors
    ///
    /// - [`ContextCreationError::NotARecord`] if `expr` is not a record
    ///   expression. This is checked before any evaluation.
    /// - [`ContextCreationError::Evaluation`] if evaluating the record fails.
    pub fn from_expr(
        expr: BorrowedRestrictedExpr<'_>,
        extensions: Extensions<'_>,
    ) -> Result<Self, ContextCreationError> {
        // Only record expressions are acceptable as a context.
        if !matches!(expr.expr_kind(), ExprKind::Record { .. }) {
            return Err(ContextCreationError::NotARecord {
                expr: Box::new(expr.to_owned()),
            });
        }
        let evaluator = RestrictedEvaluator::new(&extensions);
        // Partial interpretation keeps any unknowns as residuals.
        let partial = evaluator.partial_interpret(expr)?;
        Ok(Self {
            context: partial.into(),
        })
    }

    /// Create a context from `(key, value)` pairs.
    ///
    /// # Errors
    ///
    /// - [`ContextCreationError::ExprConstruction`] if a key is repeated. The
    ///   duplicate-key error is annotated with `"in context"`.
    /// - Anything [`Context::from_expr`] reports for the assembled record.
    pub fn from_pairs(
        pairs: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
        extensions: Extensions<'_>,
    ) -> Result<Self, ContextCreationError> {
        let record = match RestrictedExpr::record(pairs) {
            Ok(record) => record,
            Err(ExprConstructionError::DuplicateKey(err)) => {
                return Err(
                    ExprConstructionError::DuplicateKey(err.with_context("in context")).into(),
                );
            }
        };
        Self::from_expr(record.as_borrowed(), extensions)
    }

    /// Parse a context from a JSON string, with no schema and all available
    /// extensions.
    pub fn from_json_str(json: &str) -> Result<Self, ContextJsonDeserializationError> {
        let parser =
            ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available());
        parser.from_json_str(json)
    }

    /// Parse a context from an already-parsed JSON value, with no schema and
    /// all available extensions.
    pub fn from_json_value(
        json: serde_json::Value,
    ) -> Result<Self, ContextJsonDeserializationError> {
        let parser =
            ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available());
        parser.from_json_value(json)
    }

    /// Parse a context from a JSON reader, with no schema and all available
    /// extensions.
    pub fn from_json_file(
        json: impl std::io::Read,
    ) -> Result<Self, ContextJsonDeserializationError> {
        let parser =
            ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available());
        parser.from_json_file(json)
    }

    /// Consume the context and yield its fields as `(key, value)` pairs.
    ///
    /// Fields of a concrete record become `PartialValue::Value`. Fields of a
    /// residual record expression become `PartialValue::Residual`.
    ///
    /// # Panics
    ///
    /// Panics if the stored value is not a record. The constructors above
    /// never produce such a value.
    // Panicking is deliberate here: a non-record context is an internal
    // invariant violation, not a recoverable error.
    #[allow(clippy::panic)]
    fn into_values(self) -> Box<dyn Iterator<Item = (SmolStr, PartialValue)>> {
        let whole: PartialValue = self.context.into();
        match whole {
            PartialValue::Value(Value {
                value: ValueKind::Record(fields),
                ..
            }) => {
                let pairs = Arc::unwrap_or_clone(fields)
                    .into_iter()
                    .map(|(key, val)| (key, PartialValue::Value(val)));
                Box::new(pairs)
            }
            PartialValue::Residual(residual) => match residual.into_expr_kind() {
                ExprKind::Record(fields) => {
                    let pairs = Arc::unwrap_or_clone(fields)
                        .into_iter()
                        .map(|(key, expr)| (key, PartialValue::Residual(expr)));
                    Box::new(pairs)
                }
                kind => panic!("internal invariant violation: expected a record, got {kind:?}"),
            },
            v => panic!("internal invariant violation: expected a record, got {v:?}"),
        }
    }
}
mod iter {
    use super::*;

    /// Owning iterator over the `(key, value)` fields of a [`Context`].
    ///
    /// Wraps a boxed iterator so that the concrete iterator type stays
    /// private; that type differs between concrete and residual contexts.
    pub struct IntoIter(pub(super) Box<dyn Iterator<Item = (SmolStr, PartialValue)>>);

    impl std::fmt::Debug for IntoIter {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            // The boxed iterator is opaque, so print a fixed placeholder.
            f.write_str("IntoIter(<context>)")
        }
    }

    impl Iterator for IntoIter {
        type Item = (SmolStr, PartialValue);

        fn next(&mut self) -> Option<Self::Item> {
            self.0.next()
        }

        // Forward the inner iterator's hint so consumers such as `collect`
        // and `extend` can preallocate. The default impl reports `(0, None)`.
        fn size_hint(&self) -> (usize, Option<usize>) {
            self.0.size_hint()
        }
    }
}
impl IntoIterator for Context {
type Item = (SmolStr, PartialValue);
type IntoIter = iter::IntoIter;
fn into_iter(self) -> Self::IntoIter {
iter::IntoIter(self.into_values())
}
}
/// Borrow the context's underlying partial value.
impl AsRef<PartialValue> for Context {
    fn as_ref(&self) -> &PartialValue {
        // Explicit deref through the serialization wrapper.
        &*self.context
    }
}
impl From<Context> for PartialValue {
fn from(ctx: Context) -> PartialValue {
ctx.context.into()
}
}
impl std::default::Default for Context {
fn default() -> Context {
Context::empty()
}
}
/// Display a context by displaying its wrapped value.
impl std::fmt::Display for Context {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let inner = &self.context;
        write!(f, "{inner}")
    }
}
/// Errors that can occur while constructing a [`Context`].
#[derive(Debug, Diagnostic, Error)]
pub enum ContextCreationError {
    /// The expression supplied as the context is not a record.
    #[error("expression is not a record: `{expr}`")]
    NotARecord {
        /// The offending (non-record) expression.
        expr: Box<RestrictedExpr>,
    },
    /// Evaluating the context expression failed.
    #[error(transparent)]
    #[diagnostic(transparent)]
    Evaluation(#[from] EvaluationError),
    /// Building the context expression failed, e.g. because `from_pairs` was
    /// given a duplicate key.
    #[error(transparent)]
    #[diagnostic(transparent)]
    ExprConstruction(#[from] ExprConstructionError),
}
/// A schema against which a [`Request`] can be validated.
///
/// `Request::new` and `Request::new_with_unknowns` call this when a schema is
/// supplied.
pub trait RequestSchema {
    /// Error reported when a request fails validation.
    type Error: miette::Diagnostic;
    /// Check `request` against this schema.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if the request is not valid for this schema.
    fn validate_request(
        &self,
        request: &Request,
        extensions: Extensions<'_>,
    ) -> Result<(), Self::Error>;
}
#[derive(Debug, Clone)]
pub struct RequestSchemaAllPass;
impl RequestSchema for RequestSchemaAllPass {
type Error = Infallible;
fn validate_request(
&self,
_request: &Request,
_extensions: Extensions<'_>,
) -> Result<(), Self::Error> {
Ok(())
}
}
/// Error type for validation that cannot fail.
///
/// A newtype around [`std::convert::Infallible`] that derives
/// `miette::Diagnostic` and `std::error::Error`, so it can be used as a
/// `RequestSchema::Error`. The wrapped type has no values, so this type can
/// never be constructed.
#[derive(Debug, Diagnostic, Error)]
#[error(transparent)]
pub struct Infallible(pub std::convert::Infallible);
#[cfg(test)]
mod test {
    use super::*;
    use cool_asserts::assert_matches;

    /// A non-record top-level value must be rejected both when building a
    /// context from an expression and when parsing one from JSON.
    #[test]
    fn test_json_from_str_non_record() {
        let non_record = RestrictedExpr::val("1");
        let via_expr = Context::from_expr(non_record.as_borrowed(), Extensions::none());
        assert_matches!(via_expr, Err(ContextCreationError::NotARecord { .. }));

        let via_json = Context::from_json_str("1");
        assert_matches!(
            via_json,
            Err(ContextJsonDeserializationError::ContextCreation(
                ContextCreationError::NotARecord { .. }
            ))
        );
    }
}