1use crate::entities::{ContextJsonDeserializationError, ContextJsonParser, NullContextSchema};
18use crate::evaluator::{EvaluationError, RestrictedEvaluator};
19use crate::extensions::Extensions;
20use crate::parser::Loc;
21use miette::Diagnostic;
22use serde::Serialize;
23use smol_str::SmolStr;
24use std::sync::Arc;
25use thiserror::Error;
26
27use super::{
28 BorrowedRestrictedExpr, EntityUID, Expr, ExprConstructionError, ExprKind, PartialValue,
29 PartialValueSerializedAsExpr, RestrictedExpr, Unknown, Value, ValueKind, Var,
30};
31
/// A request: the principal, action, resource and context being asked about.
///
/// Any of the four components may be unknown: the entity slots can hold
/// [`EntityUIDEntry::Unknown`], and `context` can be `None`.
#[derive(Debug, Clone, Serialize)]
pub struct Request {
    /// Principal of the request; may be unknown.
    pub(crate) principal: EntityUIDEntry,

    /// Action of the request; may be unknown.
    pub(crate) action: EntityUIDEntry,

    /// Resource of the request; may be unknown.
    pub(crate) resource: EntityUIDEntry,

    /// Context of the request; `None` means the context is unknown.
    pub(crate) context: Option<Context>,
}
48
/// One of the principal / action / resource slots of a [`Request`]:
/// either a concrete entity UID or an unknown.
#[derive(Debug, Clone, Serialize)]
pub enum EntityUIDEntry {
    /// A concrete entity UID.
    Known {
        /// The entity UID, held behind a shared `Arc`.
        euid: Arc<EntityUID>,
        /// Source location of this entry, if any.
        loc: Option<Loc>,
    },
    /// An unknown entity. `EntityUIDEntry::evaluate` turns this into an
    /// unknown value named after the request variable.
    Unknown {
        /// Source location of this entry, if any.
        loc: Option<Loc>,
    },
}
67
68impl EntityUIDEntry {
69 pub fn evaluate(&self, var: Var) -> PartialValue {
73 match self {
74 EntityUIDEntry::Known { euid, loc } => {
75 Value::new(Arc::unwrap_or_clone(Arc::clone(euid)), loc.clone()).into()
76 }
77 EntityUIDEntry::Unknown { loc } => Expr::unknown(Unknown::new_untyped(var.to_string()))
78 .with_maybe_source_loc(loc.clone())
79 .into(),
80 }
81 }
82
83 pub fn concrete(euid: EntityUID, loc: Option<Loc>) -> Self {
85 Self::Known {
86 euid: Arc::new(euid),
87 loc,
88 }
89 }
90
91 pub fn uid(&self) -> Option<&EntityUID> {
93 match self {
94 Self::Known { euid, .. } => Some(euid),
95 Self::Unknown { .. } => None,
96 }
97 }
98}
99
100impl Request {
101 pub fn new<S: RequestSchema>(
106 principal: (EntityUID, Option<Loc>),
107 action: (EntityUID, Option<Loc>),
108 resource: (EntityUID, Option<Loc>),
109 context: Context,
110 schema: Option<&S>,
111 extensions: Extensions<'_>,
112 ) -> Result<Self, S::Error> {
113 let req = Self {
114 principal: EntityUIDEntry::concrete(principal.0, principal.1),
115 action: EntityUIDEntry::concrete(action.0, action.1),
116 resource: EntityUIDEntry::concrete(resource.0, resource.1),
117 context: Some(context),
118 };
119 if let Some(schema) = schema {
120 schema.validate_request(&req, extensions)?;
121 }
122 Ok(req)
123 }
124
125 pub fn new_with_unknowns<S: RequestSchema>(
131 principal: EntityUIDEntry,
132 action: EntityUIDEntry,
133 resource: EntityUIDEntry,
134 context: Option<Context>,
135 schema: Option<&S>,
136 extensions: Extensions<'_>,
137 ) -> Result<Self, S::Error> {
138 let req = Self {
139 principal,
140 action,
141 resource,
142 context,
143 };
144 if let Some(schema) = schema {
145 schema.validate_request(&req, extensions)?;
146 }
147 Ok(req)
148 }
149
150 pub fn new_unchecked(
153 principal: EntityUIDEntry,
154 action: EntityUIDEntry,
155 resource: EntityUIDEntry,
156 context: Option<Context>,
157 ) -> Self {
158 Self {
159 principal,
160 action,
161 resource,
162 context,
163 }
164 }
165
166 pub fn principal(&self) -> &EntityUIDEntry {
168 &self.principal
169 }
170
171 pub fn action(&self) -> &EntityUIDEntry {
173 &self.action
174 }
175
176 pub fn resource(&self) -> &EntityUIDEntry {
178 &self.resource
179 }
180
181 pub fn context(&self) -> Option<&Context> {
184 self.context.as_ref()
185 }
186}
187
188impl std::fmt::Display for Request {
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 let display_euid = |maybe_euid: &EntityUIDEntry| match maybe_euid {
191 EntityUIDEntry::Known { euid, .. } => format!("{euid}"),
192 EntityUIDEntry::Unknown { .. } => "unknown".to_string(),
193 };
194 write!(
195 f,
196 "request with principal {}, action {}, resource {}, and context {}",
197 display_euid(&self.principal),
198 display_euid(&self.action),
199 display_euid(&self.resource),
200 match &self.context {
201 Some(x) => format!("{x}"),
202 None => "unknown".to_string(),
203 }
204 )
205 }
206}
207
/// The context of a request.
///
/// Invariant: the wrapped value is always a record. It is either a concrete
/// record value or a residual record expression. Iterating a context panics
/// if this invariant is violated.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct Context {
    // `#[serde(flatten)]` inlines this field's serialized form into the
    // enclosing `Context` rather than nesting it under a `context` key.
    #[serde(flatten)]
    context: PartialValueSerializedAsExpr,
}
220
221impl Context {
222 pub fn empty() -> Self {
226 Self {
227 context: PartialValue::Value(Value::empty_record(None)).into(),
228 }
229 }
230
231 pub fn from_expr(
238 expr: BorrowedRestrictedExpr<'_>,
239 extensions: Extensions<'_>,
240 ) -> Result<Self, ContextCreationError> {
241 match expr.expr_kind() {
242 ExprKind::Record { .. } => {
245 let evaluator = RestrictedEvaluator::new(&extensions);
246 let pval = evaluator.partial_interpret(expr)?;
247 Ok(Self {
248 context: pval.into(),
249 })
250 }
251 _ => Err(ContextCreationError::NotARecord {
252 expr: Box::new(expr.to_owned()),
253 }),
254 }
255 }
256
257 pub fn from_pairs(
263 pairs: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
264 extensions: Extensions<'_>,
265 ) -> Result<Self, ContextCreationError> {
266 match RestrictedExpr::record(pairs) {
268 Ok(record) => Self::from_expr(record.as_borrowed(), extensions),
269 Err(ExprConstructionError::DuplicateKey(err)) => {
270 Err(ExprConstructionError::DuplicateKey(err.with_context("in context")).into())
271 }
272 }
273 }
274
275 pub fn from_json_str(json: &str) -> Result<Self, ContextJsonDeserializationError> {
282 ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
284 .from_json_str(json)
285 }
286
287 pub fn from_json_value(
294 json: serde_json::Value,
295 ) -> Result<Self, ContextJsonDeserializationError> {
296 ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
298 .from_json_value(json)
299 }
300
301 pub fn from_json_file(
308 json: impl std::io::Read,
309 ) -> Result<Self, ContextJsonDeserializationError> {
310 ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
312 .from_json_file(json)
313 }
314
315 fn into_values(self) -> Box<dyn Iterator<Item = (SmolStr, PartialValue)>> {
321 #[allow(clippy::panic)]
323 match self.context.into() {
324 PartialValue::Value(Value {
325 value: ValueKind::Record(record),
326 ..
327 }) => Box::new(
328 Arc::unwrap_or_clone(record)
329 .into_iter()
330 .map(|(k, v)| (k, PartialValue::Value(v))),
331 ),
332 PartialValue::Residual(expr) => match expr.into_expr_kind() {
333 ExprKind::Record(map) => Box::new(
334 Arc::unwrap_or_clone(map)
335 .into_iter()
336 .map(|(k, v)| (k, PartialValue::Residual(v))),
337 ),
338 kind => panic!("internal invariant violation: expected a record, got {kind:?}"),
339 },
340 v => panic!("internal invariant violation: expected a record, got {v:?}"),
341 }
342 }
343}
344
345mod iter {
347 use super::*;
348
349 pub struct IntoIter(pub(super) Box<dyn Iterator<Item = (SmolStr, PartialValue)>>);
351
352 impl std::fmt::Debug for IntoIter {
353 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
354 write!(f, "IntoIter(<context>)")
355 }
356 }
357
358 impl Iterator for IntoIter {
359 type Item = (SmolStr, PartialValue);
360
361 fn next(&mut self) -> Option<Self::Item> {
362 self.0.next()
363 }
364 }
365}
366
367impl IntoIterator for Context {
368 type Item = (SmolStr, PartialValue);
369
370 type IntoIter = iter::IntoIter;
371
372 fn into_iter(self) -> Self::IntoIter {
373 iter::IntoIter(self.into_values())
374 }
375}
376
377impl AsRef<PartialValue> for Context {
378 fn as_ref(&self) -> &PartialValue {
379 &self.context
380 }
381}
382
383impl From<Context> for PartialValue {
384 fn from(ctx: Context) -> PartialValue {
385 ctx.context.into()
386 }
387}
388
389impl std::default::Default for Context {
390 fn default() -> Context {
391 Context::empty()
392 }
393}
394
395impl std::fmt::Display for Context {
396 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
397 write!(f, "{}", self.context)
398 }
399}
400
/// Errors that can occur while constructing a [`Context`].
#[derive(Debug, Diagnostic, Error)]
pub enum ContextCreationError {
    /// The expression given as a context was not a record.
    #[error("expression is not a record: `{expr}`")]
    NotARecord {
        /// The offending (non-record) expression.
        expr: Box<RestrictedExpr>,
    },
    /// Partially evaluating the context's record expression failed.
    #[error(transparent)]
    #[diagnostic(transparent)]
    Evaluation(#[from] EvaluationError),
    /// Building the record expression failed, e.g. because of a duplicate
    /// key in `Context::from_pairs`.
    #[error(transparent)]
    #[diagnostic(transparent)]
    ExprConstruction(#[from] ExprConstructionError),
}
420
/// Something that can validate a [`Request`], such as a schema.
pub trait RequestSchema {
    /// Error returned when a request fails validation.
    type Error: miette::Diagnostic;
    /// Check `request` (using `extensions` as needed), returning `Err` if it
    /// does not conform.
    fn validate_request(
        &self,
        request: &Request,
        extensions: Extensions<'_>,
    ) -> Result<(), Self::Error>;
}
432
433#[derive(Debug, Clone)]
435pub struct RequestSchemaAllPass;
436impl RequestSchema for RequestSchemaAllPass {
437 type Error = Infallible;
438 fn validate_request(
439 &self,
440 _request: &Request,
441 _extensions: Extensions<'_>,
442 ) -> Result<(), Self::Error> {
443 Ok(())
444 }
445}
446
/// Error type for validation that cannot fail.
///
/// It wraps [`std::convert::Infallible`], so no value of this type can ever
/// exist. The wrapper derives `Diagnostic`, which satisfies the bound on
/// `RequestSchema::Error`.
#[derive(Debug, Diagnostic, Error)]
#[error(transparent)]
pub struct Infallible(pub std::convert::Infallible);
452
#[cfg(test)]
mod test {
    use super::*;
    use cool_asserts::assert_matches;

    /// Non-record inputs must be rejected with `NotARecord`, both when
    /// building a context from an expression and when parsing it from JSON.
    #[test]
    fn test_json_from_str_non_record() {
        // The string literal "1" is not a record expression.
        let from_expr =
            Context::from_expr(RestrictedExpr::val("1").as_borrowed(), Extensions::none());
        assert_matches!(from_expr, Err(ContextCreationError::NotARecord { .. }));

        // The JSON number `1` is not an object, so it cannot be a context.
        let from_json = Context::from_json_str("1");
        assert_matches!(
            from_json,
            Err(ContextJsonDeserializationError::ContextCreation(
                ContextCreationError::NotARecord { .. }
            ))
        );
    }
}